luajitos

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs

crypto_aes_lib.c (10772B)


      1 /*
      2  * crypto_aes_lib.c - AES-256-GCM library (extracted core functions)
      3  */
      4 
      5 #include "crypto_core.h"
      6 #include <stdio.h>
      7 #include <string.h>
      8 #include <immintrin.h>
      9 #include <wmmintrin.h>
     10 
     11 #ifdef __GNUC__
     12 #include <cpuid.h>
     13 #endif
     14 
     15 int crypto_has_aesni = 0;
     16 int crypto_has_pclmulqdq = 0;
     17 
     18 void crypto_detect_features(void) {
     19 #ifdef __GNUC__
     20     unsigned int eax, ebx, ecx, edx;
     21     if (__get_cpuid(1, &eax, &ebx, &ecx, &edx)) {
     22         crypto_has_aesni = (ecx & bit_AES) != 0;
     23         crypto_has_pclmulqdq = (ecx & bit_PCLMUL) != 0;
     24     }
     25 #endif
     26 }
     27 
     28 static inline __m128i reverse_bytes(__m128i x) {
     29     const __m128i mask = _mm_set_epi8(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15);
     30     return _mm_shuffle_epi8(x, mask);
     31 }
     32 
     33 static void aes256_key_expansion(const uint8_t *key, aes256_key_schedule *ks) {
     34     __m128i temp1, temp2, temp3;
     35     ks->nr = 14;
     36 
     37     temp1 = _mm_loadu_si128((__m128i*)key);
     38     temp2 = _mm_loadu_si128((__m128i*)(key + 16));
     39     ks->round_keys[0] = temp1;
     40     ks->round_keys[1] = temp2;
     41 
     42     #define AES_256_ASSIST_1(temp1, temp2, temp3) do { \
     43         temp2 = _mm_aeskeygenassist_si128(temp2, 0x0); \
     44         temp3 = _mm_shuffle_epi32(temp2, 0xff); \
     45         temp2 = _mm_slli_si128(temp1, 0x4); \
     46         temp1 = _mm_xor_si128(temp1, temp2); \
     47         temp2 = _mm_slli_si128(temp2, 0x4); \
     48         temp1 = _mm_xor_si128(temp1, temp2); \
     49         temp2 = _mm_slli_si128(temp2, 0x4); \
     50         temp1 = _mm_xor_si128(temp1, temp2); \
     51         temp1 = _mm_xor_si128(temp1, temp3); \
     52     } while(0)
     53 
     54     #define AES_256_ASSIST_2(temp1, temp2, temp3) do { \
     55         temp3 = _mm_aeskeygenassist_si128(temp1, 0x0); \
     56         temp3 = _mm_shuffle_epi32(temp3, 0xaa); \
     57         temp1 = _mm_slli_si128(temp2, 0x4); \
     58         temp2 = _mm_xor_si128(temp2, temp1); \
     59         temp1 = _mm_slli_si128(temp1, 0x4); \
     60         temp2 = _mm_xor_si128(temp2, temp1); \
     61         temp1 = _mm_slli_si128(temp1, 0x4); \
     62         temp2 = _mm_xor_si128(temp2, temp1); \
     63         temp2 = _mm_xor_si128(temp2, temp3); \
     64     } while(0)
     65 
     66     AES_256_ASSIST_1(temp1, temp2, temp3);
     67     ks->round_keys[2] = temp1;
     68     AES_256_ASSIST_2(temp1, temp2, temp3);
     69     ks->round_keys[3] = temp2;
     70 
     71     temp3 = _mm_aeskeygenassist_si128(temp2, 0x01);
     72     AES_256_ASSIST_1(temp1, temp2, temp3);
     73     ks->round_keys[4] = temp1;
     74     AES_256_ASSIST_2(temp1, temp2, temp3);
     75     ks->round_keys[5] = temp2;
     76 
     77     temp3 = _mm_aeskeygenassist_si128(temp2, 0x02);
     78     AES_256_ASSIST_1(temp1, temp2, temp3);
     79     ks->round_keys[6] = temp1;
     80     AES_256_ASSIST_2(temp1, temp2, temp3);
     81     ks->round_keys[7] = temp2;
     82 
     83     temp3 = _mm_aeskeygenassist_si128(temp2, 0x04);
     84     AES_256_ASSIST_1(temp1, temp2, temp3);
     85     ks->round_keys[8] = temp1;
     86     AES_256_ASSIST_2(temp1, temp2, temp3);
     87     ks->round_keys[9] = temp2;
     88 
     89     temp3 = _mm_aeskeygenassist_si128(temp2, 0x08);
     90     AES_256_ASSIST_1(temp1, temp2, temp3);
     91     ks->round_keys[10] = temp1;
     92     AES_256_ASSIST_2(temp1, temp2, temp3);
     93     ks->round_keys[11] = temp2;
     94 
     95     temp3 = _mm_aeskeygenassist_si128(temp2, 0x10);
     96     AES_256_ASSIST_1(temp1, temp2, temp3);
     97     ks->round_keys[12] = temp1;
     98     AES_256_ASSIST_2(temp1, temp2, temp3);
     99     ks->round_keys[13] = temp2;
    100 
    101     temp3 = _mm_aeskeygenassist_si128(temp2, 0x20);
    102     AES_256_ASSIST_1(temp1, temp2, temp3);
    103     ks->round_keys[14] = temp1;
    104 }
    105 
    106 static inline __m128i aes256_encrypt_block(__m128i plaintext, const aes256_key_schedule *ks) {
    107     __m128i tmp = _mm_xor_si128(plaintext, ks->round_keys[0]);
    108     for (int i = 1; i < 14; i++) {
    109         tmp = _mm_aesenc_si128(tmp, ks->round_keys[i]);
    110     }
    111     tmp = _mm_aesenclast_si128(tmp, ks->round_keys[14]);
    112     return tmp;
    113 }
    114 
    115 static inline __m128i gf_mul(__m128i a, __m128i b) {
    116     __m128i tmp3 = _mm_clmulepi64_si128(a, b, 0x00);
    117     __m128i tmp6 = _mm_clmulepi64_si128(a, b, 0x11);
    118     __m128i tmp4 = _mm_shuffle_epi32(a, 78);
    119     __m128i tmp5 = _mm_shuffle_epi32(b, 78);
    120     tmp4 = _mm_xor_si128(tmp4, a);
    121     tmp5 = _mm_xor_si128(tmp5, b);
    122     tmp4 = _mm_clmulepi64_si128(tmp4, tmp5, 0x00);
    123     tmp4 = _mm_xor_si128(tmp4, tmp3);
    124     tmp4 = _mm_xor_si128(tmp4, tmp6);
    125     tmp5 = _mm_slli_si128(tmp4, 8);
    126     tmp4 = _mm_srli_si128(tmp4, 8);
    127     tmp3 = _mm_xor_si128(tmp3, tmp5);
    128     tmp6 = _mm_xor_si128(tmp6, tmp4);
    129 
    130     __m128i tmp7 = _mm_srli_epi32(tmp3, 31);
    131     __m128i tmp8 = _mm_srli_epi32(tmp6, 31);
    132     tmp3 = _mm_slli_epi32(tmp3, 1);
    133     tmp6 = _mm_slli_epi32(tmp6, 1);
    134     __m128i tmp9 = _mm_srli_si128(tmp7, 12);
    135     tmp8 = _mm_slli_si128(tmp8, 4);
    136     tmp7 = _mm_slli_si128(tmp7, 4);
    137     tmp3 = _mm_or_si128(tmp3, tmp7);
    138     tmp6 = _mm_or_si128(tmp6, tmp8);
    139     tmp6 = _mm_or_si128(tmp6, tmp9);
    140 
    141     tmp7 = _mm_slli_epi32(tmp3, 31);
    142     tmp8 = _mm_slli_epi32(tmp3, 30);
    143     tmp9 = _mm_slli_epi32(tmp3, 25);
    144     tmp7 = _mm_xor_si128(tmp7, tmp8);
    145     tmp7 = _mm_xor_si128(tmp7, tmp9);
    146     tmp8 = _mm_srli_si128(tmp7, 4);
    147     tmp7 = _mm_slli_si128(tmp7, 12);
    148     tmp3 = _mm_xor_si128(tmp3, tmp7);
    149     __m128i tmp2 = _mm_srli_epi32(tmp3, 1);
    150     tmp4 = _mm_srli_epi32(tmp3, 2);
    151     tmp5 = _mm_srli_epi32(tmp3, 7);
    152     tmp2 = _mm_xor_si128(tmp2, tmp4);
    153     tmp2 = _mm_xor_si128(tmp2, tmp5);
    154     tmp2 = _mm_xor_si128(tmp2, tmp8);
    155     tmp3 = _mm_xor_si128(tmp3, tmp2);
    156     tmp6 = _mm_xor_si128(tmp6, tmp3);
    157     return tmp6;
    158 }
    159 
    160 int aes256_gcm_init(aes256_gcm_context *ctx, const uint8_t *key) {
    161     if (!crypto_has_aesni || !crypto_has_pclmulqdq) return -1;
    162 
    163     aes256_key_expansion(key, &ctx->key_schedule);
    164     __m128i zero = _mm_setzero_si128();
    165     ctx->H = aes256_encrypt_block(zero, &ctx->key_schedule);
    166     ctx->H = reverse_bytes(ctx->H);
    167 
    168     ctx->H_powers[0] = ctx->H;
    169     for (int i = 1; i < 8; i++) {
    170         ctx->H_powers[i] = gf_mul(ctx->H_powers[i-1], ctx->H);
    171     }
    172     return 0;
    173 }
    174 
    175 static void ghash(const aes256_gcm_context *ctx, const uint8_t *aad, size_t aad_len,
    176                   const uint8_t *ciphertext, size_t ct_len, __m128i *tag) {
    177     __m128i hash = _mm_setzero_si128();
    178     size_t i;
    179 
    180     for (i = 0; i + 16 <= aad_len; i += 16) {
    181         __m128i block = _mm_loadu_si128((__m128i*)(aad + i));
    182         block = reverse_bytes(block);
    183         hash = _mm_xor_si128(hash, block);
    184         hash = gf_mul(hash, ctx->H);
    185     }
    186     if (i < aad_len) {
    187         uint8_t temp[16] = {0};
    188         memcpy(temp, aad + i, aad_len - i);
    189         __m128i block = _mm_loadu_si128((__m128i*)temp);
    190         block = reverse_bytes(block);
    191         hash = _mm_xor_si128(hash, block);
    192         hash = gf_mul(hash, ctx->H);
    193     }
    194 
    195     for (i = 0; i + 16 <= ct_len; i += 16) {
    196         __m128i block = _mm_loadu_si128((__m128i*)(ciphertext + i));
    197         block = reverse_bytes(block);
    198         hash = _mm_xor_si128(hash, block);
    199         hash = gf_mul(hash, ctx->H);
    200     }
    201     if (i < ct_len) {
    202         uint8_t temp[16] = {0};
    203         memcpy(temp, ciphertext + i, ct_len - i);
    204         __m128i block = _mm_loadu_si128((__m128i*)temp);
    205         block = reverse_bytes(block);
    206         hash = _mm_xor_si128(hash, block);
    207         hash = gf_mul(hash, ctx->H);
    208     }
    209 
    210     uint64_t aad_bits = aad_len * 8;
    211     uint64_t ct_bits = ct_len * 8;
    212     __m128i len_block = _mm_set_epi64x(ct_bits, aad_bits);
    213     len_block = reverse_bytes(len_block);
    214     hash = _mm_xor_si128(hash, len_block);
    215     hash = gf_mul(hash, ctx->H);
    216     *tag = reverse_bytes(hash);
    217 }
    218 
    219 static inline __m128i inc_counter(__m128i counter) {
    220     uint32_t ctr = _mm_extract_epi32(counter, 0);
    221     ctr++;
    222     return _mm_insert_epi32(counter, ctr, 0);
    223 }
    224 
    225 int aes256_gcm_encrypt(aes256_gcm_context *ctx, const uint8_t *iv, size_t iv_len,
    226                        const uint8_t *aad, size_t aad_len,
    227                        const uint8_t *plaintext, size_t pt_len,
    228                        uint8_t *ciphertext, uint8_t *tag, size_t tag_len) {
    229     if (tag_len > 16) return -1;
    230 
    231     __m128i counter;
    232     if (iv_len == 12) {
    233         counter = _mm_set_epi32(1, ((uint32_t*)iv)[2], ((uint32_t*)iv)[1], ((uint32_t*)iv)[0]);
    234     } else {
    235         return -1;  // Simplified - only support 96-bit IV
    236     }
    237 
    238     __m128i J0 = counter;
    239     size_t i;
    240 
    241     for (i = 0; i + 16 <= pt_len; i += 16) {
    242         counter = inc_counter(counter);
    243         __m128i keystream = aes256_encrypt_block(counter, &ctx->key_schedule);
    244         __m128i pt_block = _mm_loadu_si128((__m128i*)(plaintext + i));
    245         __m128i ct_block = _mm_xor_si128(pt_block, keystream);
    246         _mm_storeu_si128((__m128i*)(ciphertext + i), ct_block);
    247     }
    248 
    249     if (i < pt_len) {
    250         counter = inc_counter(counter);
    251         __m128i keystream = aes256_encrypt_block(counter, &ctx->key_schedule);
    252         uint8_t ks_bytes[16];
    253         _mm_storeu_si128((__m128i*)ks_bytes, keystream);
    254         for (size_t j = 0; j < pt_len - i; j++) {
    255             ciphertext[i + j] = plaintext[i + j] ^ ks_bytes[j];
    256         }
    257     }
    258 
    259     __m128i auth_tag;
    260     ghash(ctx, aad, aad_len, ciphertext, pt_len, &auth_tag);
    261     __m128i J0_keystream = aes256_encrypt_block(J0, &ctx->key_schedule);
    262     auth_tag = _mm_xor_si128(auth_tag, J0_keystream);
    263     _mm_storeu_si128((__m128i*)tag, auth_tag);
    264 
    265     return 0;
    266 }
    267 
    268 int aes256_gcm_decrypt(aes256_gcm_context *ctx, const uint8_t *iv, size_t iv_len,
    269                        const uint8_t *aad, size_t aad_len,
    270                        const uint8_t *ciphertext, size_t ct_len,
    271                        const uint8_t *tag, size_t tag_len, uint8_t *plaintext) {
    272     if (tag_len > 16) return -1;
    273 
    274     __m128i counter;
    275     if (iv_len == 12) {
    276         counter = _mm_set_epi32(1, ((uint32_t*)iv)[2], ((uint32_t*)iv)[1], ((uint32_t*)iv)[0]);
    277     } else {
    278         return -1;
    279     }
    280 
    281     __m128i J0 = counter;
    282 
    283     __m128i computed_tag;
    284     ghash(ctx, aad, aad_len, ciphertext, ct_len, &computed_tag);
    285     __m128i J0_keystream = aes256_encrypt_block(J0, &ctx->key_schedule);
    286     computed_tag = _mm_xor_si128(computed_tag, J0_keystream);
    287 
    288     uint8_t computed_tag_bytes[16];
    289     _mm_storeu_si128((__m128i*)computed_tag_bytes, computed_tag);
    290 
    291     int tag_match = 1;
    292     for (size_t i = 0; i < tag_len; i++) {
    293         tag_match &= (computed_tag_bytes[i] == tag[i]);
    294     }
    295     if (!tag_match) return -1;
    296 
    297     size_t i;
    298     for (i = 0; i + 16 <= ct_len; i += 16) {
    299         counter = inc_counter(counter);
    300         __m128i keystream = aes256_encrypt_block(counter, &ctx->key_schedule);
    301         __m128i ct_block = _mm_loadu_si128((__m128i*)(ciphertext + i));
    302         __m128i pt_block = _mm_xor_si128(ct_block, keystream);
    303         _mm_storeu_si128((__m128i*)(plaintext + i), pt_block);
    304     }
    305 
    306     if (i < ct_len) {
    307         counter = inc_counter(counter);
    308         __m128i keystream = aes256_encrypt_block(counter, &ctx->key_schedule);
    309         uint8_t ks_bytes[16];
    310         _mm_storeu_si128((__m128i*)ks_bytes, keystream);
    311         for (size_t j = 0; j < ct_len - i; j++) {
    312             plaintext[i + j] = ciphertext[i + j] ^ ks_bytes[j];
    313         }
    314     }
    315 
    316     return 0;
    317 }