luajitos

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs

AES-256-GCM.c (15171B)


      1 /*
      2  * AES-256-GCM Implementation with Hardware Acceleration
      3  * Supports AES-NI and PCLMULQDQ for high performance
      4  * Conforms to NIST SP 800-38D
      5  */
      6 
      7 #include "AES-256-GCM.h"
      8 #include <stdio.h>
      9 #include <stdlib.h>
     10 #include <string.h>
     11 #include <stdint.h>
     12 #include <immintrin.h>
     13 #include <wmmintrin.h>
     14 
     15 // CPU feature detection
     16 #ifdef __GNUC__
     17 #include <cpuid.h>
     18 #endif
     19 
     20 // Feature flags
     21 static int has_aesni = 0;
     22 static int has_pclmulqdq = 0;
     23 
     24 // Check CPU capabilities
     25 static void detect_cpu_features(void) {
     26 #ifdef __GNUC__
     27     unsigned int eax, ebx, ecx, edx;
     28     if (__get_cpuid(1, &eax, &ebx, &ecx, &edx)) {
     29         has_aesni = (ecx & bit_AES) != 0;
     30         has_pclmulqdq = (ecx & bit_PCLMUL) != 0;
     31     }
     32 #endif
     33 }
     34 
     35 // Type definitions are in AES-256-GCM.h
     36 
     37 // Utility: Reverse bytes in __m128i (for GCM)
     38 static inline __m128i reverse_bytes(__m128i x) {
     39     const __m128i mask = _mm_set_epi8(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15);
     40     return _mm_shuffle_epi8(x, mask);
     41 }
     42 
     43 // AES-256 key expansion using AES-NI
     44 static void aes256_key_expansion(const uint8_t *key, aes256_key_schedule *ks) {
     45     __m128i temp1, temp2, temp3;
     46 
     47     ks->nr = 14;  // AES-256 has 14 rounds
     48 
     49     // Load the key (32 bytes for AES-256)
     50     temp1 = _mm_loadu_si128((__m128i*)key);
     51     temp2 = _mm_loadu_si128((__m128i*)(key + 16));
     52 
     53     ks->round_keys[0] = temp1;
     54     ks->round_keys[1] = temp2;
     55 
     56     // Helper macro for key expansion
     57     #define AES_256_ASSIST_1(temp1, temp2, temp3) do { \
     58         temp2 = _mm_aeskeygenassist_si128(temp2, 0x0); \
     59         temp3 = _mm_shuffle_epi32(temp2, 0xff); \
     60         temp2 = _mm_slli_si128(temp1, 0x4); \
     61         temp1 = _mm_xor_si128(temp1, temp2); \
     62         temp2 = _mm_slli_si128(temp2, 0x4); \
     63         temp1 = _mm_xor_si128(temp1, temp2); \
     64         temp2 = _mm_slli_si128(temp2, 0x4); \
     65         temp1 = _mm_xor_si128(temp1, temp2); \
     66         temp1 = _mm_xor_si128(temp1, temp3); \
     67     } while(0)
     68 
     69     #define AES_256_ASSIST_2(temp1, temp2, temp3) do { \
     70         temp3 = _mm_aeskeygenassist_si128(temp1, 0x0); \
     71         temp3 = _mm_shuffle_epi32(temp3, 0xaa); \
     72         temp1 = _mm_slli_si128(temp2, 0x4); \
     73         temp2 = _mm_xor_si128(temp2, temp1); \
     74         temp1 = _mm_slli_si128(temp1, 0x4); \
     75         temp2 = _mm_xor_si128(temp2, temp1); \
     76         temp1 = _mm_slli_si128(temp1, 0x4); \
     77         temp2 = _mm_xor_si128(temp2, temp1); \
     78         temp2 = _mm_xor_si128(temp2, temp3); \
     79     } while(0)
     80 
     81     // Generate round keys
     82     AES_256_ASSIST_1(temp1, temp2, temp3);
     83     ks->round_keys[2] = temp1;
     84     AES_256_ASSIST_2(temp1, temp2, temp3);
     85     ks->round_keys[3] = temp2;
     86 
     87     temp3 = _mm_aeskeygenassist_si128(temp2, 0x01);
     88     AES_256_ASSIST_1(temp1, temp2, temp3);
     89     ks->round_keys[4] = temp1;
     90     AES_256_ASSIST_2(temp1, temp2, temp3);
     91     ks->round_keys[5] = temp2;
     92 
     93     temp3 = _mm_aeskeygenassist_si128(temp2, 0x02);
     94     AES_256_ASSIST_1(temp1, temp2, temp3);
     95     ks->round_keys[6] = temp1;
     96     AES_256_ASSIST_2(temp1, temp2, temp3);
     97     ks->round_keys[7] = temp2;
     98 
     99     temp3 = _mm_aeskeygenassist_si128(temp2, 0x04);
    100     AES_256_ASSIST_1(temp1, temp2, temp3);
    101     ks->round_keys[8] = temp1;
    102     AES_256_ASSIST_2(temp1, temp2, temp3);
    103     ks->round_keys[9] = temp2;
    104 
    105     temp3 = _mm_aeskeygenassist_si128(temp2, 0x08);
    106     AES_256_ASSIST_1(temp1, temp2, temp3);
    107     ks->round_keys[10] = temp1;
    108     AES_256_ASSIST_2(temp1, temp2, temp3);
    109     ks->round_keys[11] = temp2;
    110 
    111     temp3 = _mm_aeskeygenassist_si128(temp2, 0x10);
    112     AES_256_ASSIST_1(temp1, temp2, temp3);
    113     ks->round_keys[12] = temp1;
    114     AES_256_ASSIST_2(temp1, temp2, temp3);
    115     ks->round_keys[13] = temp2;
    116 
    117     temp3 = _mm_aeskeygenassist_si128(temp2, 0x20);
    118     AES_256_ASSIST_1(temp1, temp2, temp3);
    119     ks->round_keys[14] = temp1;
    120 }
    121 
    122 // AES-256 single block encryption using AES-NI
    123 static inline __m128i aes256_encrypt_block(__m128i plaintext, const aes256_key_schedule *ks) {
    124     __m128i tmp = _mm_xor_si128(plaintext, ks->round_keys[0]);
    125 
    126     for (int i = 1; i < 14; i++) {
    127         tmp = _mm_aesenc_si128(tmp, ks->round_keys[i]);
    128     }
    129 
    130     tmp = _mm_aesenclast_si128(tmp, ks->round_keys[14]);
    131     return tmp;
    132 }
    133 
    134 // GHASH multiplication using PCLMULQDQ
    135 static inline __m128i gf_mul(__m128i a, __m128i b) {
    136     __m128i tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7;
    137     __m128i tmp8, tmp9, tmp10, tmp11, tmp12;
    138     __m128i XMMMASK = _mm_setr_epi32(0xffffffff, 0x0, 0x0, 0x0);
    139 
    140     tmp3 = _mm_clmulepi64_si128(a, b, 0x00);
    141     tmp6 = _mm_clmulepi64_si128(a, b, 0x11);
    142 
    143     tmp4 = _mm_shuffle_epi32(a, 78);
    144     tmp5 = _mm_shuffle_epi32(b, 78);
    145     tmp4 = _mm_xor_si128(tmp4, a);
    146     tmp5 = _mm_xor_si128(tmp5, b);
    147 
    148     tmp4 = _mm_clmulepi64_si128(tmp4, tmp5, 0x00);
    149     tmp4 = _mm_xor_si128(tmp4, tmp3);
    150     tmp4 = _mm_xor_si128(tmp4, tmp6);
    151 
    152     tmp5 = _mm_slli_si128(tmp4, 8);
    153     tmp4 = _mm_srli_si128(tmp4, 8);
    154     tmp3 = _mm_xor_si128(tmp3, tmp5);
    155     tmp6 = _mm_xor_si128(tmp6, tmp4);
    156 
    157     // Reduction
    158     tmp7 = _mm_srli_epi32(tmp3, 31);
    159     tmp8 = _mm_srli_epi32(tmp6, 31);
    160     tmp3 = _mm_slli_epi32(tmp3, 1);
    161     tmp6 = _mm_slli_epi32(tmp6, 1);
    162 
    163     tmp9 = _mm_srli_si128(tmp7, 12);
    164     tmp8 = _mm_slli_si128(tmp8, 4);
    165     tmp7 = _mm_slli_si128(tmp7, 4);
    166     tmp3 = _mm_or_si128(tmp3, tmp7);
    167     tmp6 = _mm_or_si128(tmp6, tmp8);
    168     tmp6 = _mm_or_si128(tmp6, tmp9);
    169 
    170     tmp7 = _mm_slli_epi32(tmp3, 31);
    171     tmp8 = _mm_slli_epi32(tmp3, 30);
    172     tmp9 = _mm_slli_epi32(tmp3, 25);
    173 
    174     tmp7 = _mm_xor_si128(tmp7, tmp8);
    175     tmp7 = _mm_xor_si128(tmp7, tmp9);
    176     tmp8 = _mm_srli_si128(tmp7, 4);
    177     tmp7 = _mm_slli_si128(tmp7, 12);
    178     tmp3 = _mm_xor_si128(tmp3, tmp7);
    179 
    180     tmp2 = _mm_srli_epi32(tmp3, 1);
    181     tmp4 = _mm_srli_epi32(tmp3, 2);
    182     tmp5 = _mm_srli_epi32(tmp3, 7);
    183     tmp2 = _mm_xor_si128(tmp2, tmp4);
    184     tmp2 = _mm_xor_si128(tmp2, tmp5);
    185     tmp2 = _mm_xor_si128(tmp2, tmp8);
    186     tmp3 = _mm_xor_si128(tmp3, tmp2);
    187     tmp6 = _mm_xor_si128(tmp6, tmp3);
    188 
    189     return tmp6;
    190 }
    191 
    192 // Initialize GCM context
    193 int aes256_gcm_init(aes256_gcm_context *ctx, const uint8_t *key) {
    194     static int features_detected = 0;
    195     if (!features_detected) {
    196         detect_cpu_features();
    197         features_detected = 1;
    198     }
    199 
    200     if (!has_aesni || !has_pclmulqdq) {
    201         fprintf(stderr, "Error: CPU does not support AES-NI and PCLMULQDQ\n");
    202         return -1;
    203     }
    204 
    205     // Expand the key
    206     aes256_key_expansion(key, &ctx->key_schedule);
    207 
    208     // Compute H = E(K, 0^128)
    209     __m128i zero = _mm_setzero_si128();
    210     ctx->H = aes256_encrypt_block(zero, &ctx->key_schedule);
    211     ctx->H = reverse_bytes(ctx->H);
    212 
    213     // Precompute powers of H for faster GHASH
    214     ctx->H_powers[0] = ctx->H;
    215     for (int i = 1; i < 8; i++) {
    216         ctx->H_powers[i] = gf_mul(ctx->H_powers[i-1], ctx->H);
    217     }
    218 
    219     return 0;
    220 }
    221 
    222 // GHASH computation
    223 static void ghash(const aes256_gcm_context *ctx, const uint8_t *aad, size_t aad_len,
    224                   const uint8_t *ciphertext, size_t ct_len, __m128i *tag) {
    225     __m128i hash = _mm_setzero_si128();
    226     size_t i;
    227 
    228     // Process AAD
    229     for (i = 0; i + 16 <= aad_len; i += 16) {
    230         __m128i block = _mm_loadu_si128((__m128i*)(aad + i));
    231         block = reverse_bytes(block);
    232         hash = _mm_xor_si128(hash, block);
    233         hash = gf_mul(hash, ctx->H);
    234     }
    235 
    236     // Handle remaining AAD bytes
    237     if (i < aad_len) {
    238         uint8_t temp[16] = {0};
    239         memcpy(temp, aad + i, aad_len - i);
    240         __m128i block = _mm_loadu_si128((__m128i*)temp);
    241         block = reverse_bytes(block);
    242         hash = _mm_xor_si128(hash, block);
    243         hash = gf_mul(hash, ctx->H);
    244     }
    245 
    246     // Process ciphertext
    247     for (i = 0; i + 16 <= ct_len; i += 16) {
    248         __m128i block = _mm_loadu_si128((__m128i*)(ciphertext + i));
    249         block = reverse_bytes(block);
    250         hash = _mm_xor_si128(hash, block);
    251         hash = gf_mul(hash, ctx->H);
    252     }
    253 
    254     // Handle remaining ciphertext bytes
    255     if (i < ct_len) {
    256         uint8_t temp[16] = {0};
    257         memcpy(temp, ciphertext + i, ct_len - i);
    258         __m128i block = _mm_loadu_si128((__m128i*)temp);
    259         block = reverse_bytes(block);
    260         hash = _mm_xor_si128(hash, block);
    261         hash = gf_mul(hash, ctx->H);
    262     }
    263 
    264     // Process length block
    265     uint64_t aad_bits = aad_len * 8;
    266     uint64_t ct_bits = ct_len * 8;
    267     __m128i len_block = _mm_set_epi64x(ct_bits, aad_bits);
    268     len_block = reverse_bytes(len_block);
    269     hash = _mm_xor_si128(hash, len_block);
    270     hash = gf_mul(hash, ctx->H);
    271 
    272     *tag = reverse_bytes(hash);
    273 }
    274 
    275 // Increment counter (32-bit increment of rightmost 32 bits)
    276 static inline __m128i inc_counter(__m128i counter) {
    277     __m128i one = _mm_set_epi32(1, 0, 0, 0);
    278     __m128i mask = _mm_set_epi32(0xFFFFFFFF, 0, 0, 0);
    279 
    280     // Extract the counter value
    281     uint32_t ctr = _mm_extract_epi32(counter, 0);
    282     ctr++;
    283 
    284     // Insert back
    285     counter = _mm_insert_epi32(counter, ctr, 0);
    286     return counter;
    287 }
    288 
    289 // AES-256-GCM Encryption
    290 int aes256_gcm_encrypt(aes256_gcm_context *ctx,
    291                        const uint8_t *iv, size_t iv_len,
    292                        const uint8_t *aad, size_t aad_len,
    293                        const uint8_t *plaintext, size_t pt_len,
    294                        uint8_t *ciphertext,
    295                        uint8_t *tag, size_t tag_len) {
    296     if (tag_len > 16) {
    297         return -1;
    298     }
    299 
    300     // Prepare initial counter block
    301     __m128i counter;
    302     if (iv_len == 12) {
    303         // Standard case: IV is 96 bits
    304         counter = _mm_set_epi32(1,
    305                                  ((uint32_t*)iv)[2],
    306                                  ((uint32_t*)iv)[1],
    307                                  ((uint32_t*)iv)[0]);
    308     } else {
    309         // Non-standard IV length: use GHASH
    310         __m128i hash = _mm_setzero_si128();
    311         size_t i;
    312         for (i = 0; i + 16 <= iv_len; i += 16) {
    313             __m128i block = _mm_loadu_si128((__m128i*)(iv + i));
    314             block = reverse_bytes(block);
    315             hash = _mm_xor_si128(hash, block);
    316             hash = gf_mul(hash, ctx->H);
    317         }
    318         if (i < iv_len) {
    319             uint8_t temp[16] = {0};
    320             memcpy(temp, iv + i, iv_len - i);
    321             __m128i block = _mm_loadu_si128((__m128i*)temp);
    322             block = reverse_bytes(block);
    323             hash = _mm_xor_si128(hash, block);
    324             hash = gf_mul(hash, ctx->H);
    325         }
    326         uint64_t iv_bits = iv_len * 8;
    327         __m128i len_block = _mm_set_epi64x(iv_bits, 0);
    328         len_block = reverse_bytes(len_block);
    329         hash = _mm_xor_si128(hash, len_block);
    330         hash = gf_mul(hash, ctx->H);
    331         counter = reverse_bytes(hash);
    332     }
    333 
    334     // Save J0 for tag computation
    335     __m128i J0 = counter;
    336 
    337     // Encrypt plaintext
    338     size_t i;
    339     for (i = 0; i + 16 <= pt_len; i += 16) {
    340         counter = inc_counter(counter);
    341         __m128i keystream = aes256_encrypt_block(counter, &ctx->key_schedule);
    342         __m128i pt_block = _mm_loadu_si128((__m128i*)(plaintext + i));
    343         __m128i ct_block = _mm_xor_si128(pt_block, keystream);
    344         _mm_storeu_si128((__m128i*)(ciphertext + i), ct_block);
    345     }
    346 
    347     // Handle remaining bytes
    348     if (i < pt_len) {
    349         counter = inc_counter(counter);
    350         __m128i keystream = aes256_encrypt_block(counter, &ctx->key_schedule);
    351         uint8_t ks_bytes[16];
    352         _mm_storeu_si128((__m128i*)ks_bytes, keystream);
    353         for (size_t j = 0; j < pt_len - i; j++) {
    354             ciphertext[i + j] = plaintext[i + j] ^ ks_bytes[j];
    355         }
    356     }
    357 
    358     // Compute authentication tag
    359     __m128i auth_tag;
    360     ghash(ctx, aad, aad_len, ciphertext, pt_len, &auth_tag);
    361 
    362     // Encrypt the tag with J0
    363     __m128i J0_keystream = aes256_encrypt_block(J0, &ctx->key_schedule);
    364     auth_tag = _mm_xor_si128(auth_tag, J0_keystream);
    365 
    366     // Store tag
    367     _mm_storeu_si128((__m128i*)tag, auth_tag);
    368 
    369     return 0;
    370 }
    371 
    372 // AES-256-GCM Decryption
    373 int aes256_gcm_decrypt(aes256_gcm_context *ctx,
    374                        const uint8_t *iv, size_t iv_len,
    375                        const uint8_t *aad, size_t aad_len,
    376                        const uint8_t *ciphertext, size_t ct_len,
    377                        const uint8_t *tag, size_t tag_len,
    378                        uint8_t *plaintext) {
    379     if (tag_len > 16) {
    380         return -1;
    381     }
    382 
    383     // Prepare initial counter block (same as encryption)
    384     __m128i counter;
    385     if (iv_len == 12) {
    386         counter = _mm_set_epi32(1,
    387                                  ((uint32_t*)iv)[2],
    388                                  ((uint32_t*)iv)[1],
    389                                  ((uint32_t*)iv)[0]);
    390     } else {
    391         __m128i hash = _mm_setzero_si128();
    392         size_t i;
    393         for (i = 0; i + 16 <= iv_len; i += 16) {
    394             __m128i block = _mm_loadu_si128((__m128i*)(iv + i));
    395             block = reverse_bytes(block);
    396             hash = _mm_xor_si128(hash, block);
    397             hash = gf_mul(hash, ctx->H);
    398         }
    399         if (i < iv_len) {
    400             uint8_t temp[16] = {0};
    401             memcpy(temp, iv + i, iv_len - i);
    402             __m128i block = _mm_loadu_si128((__m128i*)temp);
    403             block = reverse_bytes(block);
    404             hash = _mm_xor_si128(hash, block);
    405             hash = gf_mul(hash, ctx->H);
    406         }
    407         uint64_t iv_bits = iv_len * 8;
    408         __m128i len_block = _mm_set_epi64x(iv_bits, 0);
    409         len_block = reverse_bytes(len_block);
    410         hash = _mm_xor_si128(hash, len_block);
    411         hash = gf_mul(hash, ctx->H);
    412         counter = reverse_bytes(hash);
    413     }
    414 
    415     __m128i J0 = counter;
    416 
    417     // Verify authentication tag first
    418     __m128i computed_tag;
    419     ghash(ctx, aad, aad_len, ciphertext, ct_len, &computed_tag);
    420     __m128i J0_keystream = aes256_encrypt_block(J0, &ctx->key_schedule);
    421     computed_tag = _mm_xor_si128(computed_tag, J0_keystream);
    422 
    423     // Compare tags (constant-time comparison)
    424     uint8_t computed_tag_bytes[16];
    425     _mm_storeu_si128((__m128i*)computed_tag_bytes, computed_tag);
    426 
    427     int tag_match = 1;
    428     for (size_t i = 0; i < tag_len; i++) {
    429         tag_match &= (computed_tag_bytes[i] == tag[i]);
    430     }
    431 
    432     if (!tag_match) {
    433         return -1;  // Authentication failed
    434     }
    435 
    436     // Decrypt ciphertext
    437     size_t i;
    438     for (i = 0; i + 16 <= ct_len; i += 16) {
    439         counter = inc_counter(counter);
    440         __m128i keystream = aes256_encrypt_block(counter, &ctx->key_schedule);
    441         __m128i ct_block = _mm_loadu_si128((__m128i*)(ciphertext + i));
    442         __m128i pt_block = _mm_xor_si128(ct_block, keystream);
    443         _mm_storeu_si128((__m128i*)(plaintext + i), pt_block);
    444     }
    445 
    446     // Handle remaining bytes
    447     if (i < ct_len) {
    448         counter = inc_counter(counter);
    449         __m128i keystream = aes256_encrypt_block(counter, &ctx->key_schedule);
    450         uint8_t ks_bytes[16];
    451         _mm_storeu_si128((__m128i*)ks_bytes, keystream);
    452         for (size_t j = 0; j < ct_len - i; j++) {
    453             plaintext[i + j] = ciphertext[i + j] ^ ks_bytes[j];
    454         }
    455     }
    456 
    457     return 0;
    458 }
    459 
    460 /**
    461  * Clean up AES-256-GCM context
    462  * Zeros all sensitive key material
    463  */
    464 void aes256_gcm_cleanup(aes256_gcm_context *ctx) {
    465     if (ctx == NULL) return;
    466 
    467     // Zero all sensitive data using volatile to prevent compiler optimization
    468     volatile uint8_t *p = (volatile uint8_t *)ctx;
    469     size_t n = sizeof(aes256_gcm_context);
    470     while (n--) {
    471         *p++ = 0;
    472     }
    473 }