AES-256-GCM.c (15171B)
1 /* 2 * AES-256-GCM Implementation with Hardware Acceleration 3 * Supports AES-NI and PCLMULQDQ for high performance 4 * Conforms to NIST SP 800-38D 5 */ 6 7 #include "AES-256-GCM.h" 8 #include <stdio.h> 9 #include <stdlib.h> 10 #include <string.h> 11 #include <stdint.h> 12 #include <immintrin.h> 13 #include <wmmintrin.h> 14 15 // CPU feature detection 16 #ifdef __GNUC__ 17 #include <cpuid.h> 18 #endif 19 20 // Feature flags 21 static int has_aesni = 0; 22 static int has_pclmulqdq = 0; 23 24 // Check CPU capabilities 25 static void detect_cpu_features(void) { 26 #ifdef __GNUC__ 27 unsigned int eax, ebx, ecx, edx; 28 if (__get_cpuid(1, &eax, &ebx, &ecx, &edx)) { 29 has_aesni = (ecx & bit_AES) != 0; 30 has_pclmulqdq = (ecx & bit_PCLMUL) != 0; 31 } 32 #endif 33 } 34 35 // Type definitions are in AES-256-GCM.h 36 37 // Utility: Reverse bytes in __m128i (for GCM) 38 static inline __m128i reverse_bytes(__m128i x) { 39 const __m128i mask = _mm_set_epi8(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15); 40 return _mm_shuffle_epi8(x, mask); 41 } 42 43 // AES-256 key expansion using AES-NI 44 static void aes256_key_expansion(const uint8_t *key, aes256_key_schedule *ks) { 45 __m128i temp1, temp2, temp3; 46 47 ks->nr = 14; // AES-256 has 14 rounds 48 49 // Load the key (32 bytes for AES-256) 50 temp1 = _mm_loadu_si128((__m128i*)key); 51 temp2 = _mm_loadu_si128((__m128i*)(key + 16)); 52 53 ks->round_keys[0] = temp1; 54 ks->round_keys[1] = temp2; 55 56 // Helper macro for key expansion 57 #define AES_256_ASSIST_1(temp1, temp2, temp3) do { \ 58 temp2 = _mm_aeskeygenassist_si128(temp2, 0x0); \ 59 temp3 = _mm_shuffle_epi32(temp2, 0xff); \ 60 temp2 = _mm_slli_si128(temp1, 0x4); \ 61 temp1 = _mm_xor_si128(temp1, temp2); \ 62 temp2 = _mm_slli_si128(temp2, 0x4); \ 63 temp1 = _mm_xor_si128(temp1, temp2); \ 64 temp2 = _mm_slli_si128(temp2, 0x4); \ 65 temp1 = _mm_xor_si128(temp1, temp2); \ 66 temp1 = _mm_xor_si128(temp1, temp3); \ 67 } while(0) 68 69 #define AES_256_ASSIST_2(temp1, temp2, temp3) do { \ 70 temp3 = _mm_aeskeygenassist_si128(temp1, 0x0); \ 71 temp3 = _mm_shuffle_epi32(temp3, 0xaa); \ 72 temp1 = _mm_slli_si128(temp2, 0x4); \ 73 temp2 = _mm_xor_si128(temp2, temp1); \ 74 temp1 = _mm_slli_si128(temp1, 0x4); \ 75 temp2 = _mm_xor_si128(temp2, temp1); \ 76 temp1 = _mm_slli_si128(temp1, 0x4); \ 77 temp2 = _mm_xor_si128(temp2, temp1); \ 78 temp2 = _mm_xor_si128(temp2, temp3); \ 79 } while(0) 80 81 // Generate round keys 82 AES_256_ASSIST_1(temp1, temp2, temp3); 83 ks->round_keys[2] = temp1; 84 AES_256_ASSIST_2(temp1, temp2, temp3); 85 ks->round_keys[3] = temp2; 86 87 temp3 = _mm_aeskeygenassist_si128(temp2, 0x01); 88 AES_256_ASSIST_1(temp1, temp2, temp3); 89 ks->round_keys[4] = temp1; 90 AES_256_ASSIST_2(temp1, temp2, temp3); 91 ks->round_keys[5] = temp2; 92 93 temp3 = _mm_aeskeygenassist_si128(temp2, 0x02); 94 AES_256_ASSIST_1(temp1, temp2, temp3); 95 ks->round_keys[6] = temp1; 96 AES_256_ASSIST_2(temp1, temp2, temp3); 97 ks->round_keys[7] = temp2; 98 99 temp3 = _mm_aeskeygenassist_si128(temp2, 0x04); 100 AES_256_ASSIST_1(temp1, temp2, temp3); 101 ks->round_keys[8] = temp1; 102 AES_256_ASSIST_2(temp1, temp2, temp3); 103 ks->round_keys[9] = temp2; 104 105 temp3 = _mm_aeskeygenassist_si128(temp2, 0x08); 106 AES_256_ASSIST_1(temp1, temp2, temp3); 107 ks->round_keys[10] = temp1; 108 AES_256_ASSIST_2(temp1, temp2, temp3); 109 ks->round_keys[11] = temp2; 110 111 temp3 = _mm_aeskeygenassist_si128(temp2, 0x10); 112 AES_256_ASSIST_1(temp1, temp2, temp3); 113 ks->round_keys[12] = temp1; 114 AES_256_ASSIST_2(temp1, temp2, temp3); 115 ks->round_keys[13] = temp2; 116 117 temp3 = _mm_aeskeygenassist_si128(temp2, 0x20); 118 AES_256_ASSIST_1(temp1, temp2, temp3); 119 ks->round_keys[14] = temp1; 120 } 121 122 // AES-256 single block encryption using AES-NI 123 static inline __m128i aes256_encrypt_block(__m128i plaintext, const aes256_key_schedule *ks) { 124 __m128i tmp = _mm_xor_si128(plaintext, ks->round_keys[0]); 125 126 for (int i = 1; i < 14; i++) { 127 tmp = _mm_aesenc_si128(tmp, ks->round_keys[i]); 128 } 129 130 tmp = _mm_aesenclast_si128(tmp, ks->round_keys[14]); 131 return tmp; 132 } 133 134 // GHASH multiplication using PCLMULQDQ 135 static inline __m128i gf_mul(__m128i a, __m128i b) { 136 __m128i tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7; 137 __m128i tmp8, tmp9, tmp10, tmp11, tmp12; 138 __m128i XMMMASK = _mm_setr_epi32(0xffffffff, 0x0, 0x0, 0x0); 139 140 tmp3 = _mm_clmulepi64_si128(a, b, 0x00); 141 tmp6 = _mm_clmulepi64_si128(a, b, 0x11); 142 143 tmp4 = _mm_shuffle_epi32(a, 78); 144 tmp5 = _mm_shuffle_epi32(b, 78); 145 tmp4 = _mm_xor_si128(tmp4, a); 146 tmp5 = _mm_xor_si128(tmp5, b); 147 148 tmp4 = _mm_clmulepi64_si128(tmp4, tmp5, 0x00); 149 tmp4 = _mm_xor_si128(tmp4, tmp3); 150 tmp4 = _mm_xor_si128(tmp4, tmp6); 151 152 tmp5 = _mm_slli_si128(tmp4, 8); 153 tmp4 = _mm_srli_si128(tmp4, 8); 154 tmp3 = _mm_xor_si128(tmp3, tmp5); 155 tmp6 = _mm_xor_si128(tmp6, tmp4); 156 157 // Reduction 158 tmp7 = _mm_srli_epi32(tmp3, 31); 159 tmp8 = _mm_srli_epi32(tmp6, 31); 160 tmp3 = _mm_slli_epi32(tmp3, 1); 161 tmp6 = _mm_slli_epi32(tmp6, 1); 162 163 tmp9 = _mm_srli_si128(tmp7, 12); 164 tmp8 = _mm_slli_si128(tmp8, 4); 165 tmp7 = _mm_slli_si128(tmp7, 4); 166 tmp3 = _mm_or_si128(tmp3, tmp7); 167 tmp6 = _mm_or_si128(tmp6, tmp8); 168 tmp6 = _mm_or_si128(tmp6, tmp9); 169 170 tmp7 = _mm_slli_epi32(tmp3, 31); 171 tmp8 = _mm_slli_epi32(tmp3, 30); 172 tmp9 = _mm_slli_epi32(tmp3, 25); 173 174 tmp7 = _mm_xor_si128(tmp7, tmp8); 175 tmp7 = _mm_xor_si128(tmp7, tmp9); 176 tmp8 = _mm_srli_si128(tmp7, 4); 177 tmp7 = _mm_slli_si128(tmp7, 12); 178 tmp3 = _mm_xor_si128(tmp3, tmp7); 179 180 tmp2 = _mm_srli_epi32(tmp3, 1); 181 tmp4 = _mm_srli_epi32(tmp3, 2); 182 tmp5 = _mm_srli_epi32(tmp3, 7); 183 tmp2 = _mm_xor_si128(tmp2, tmp4); 184 tmp2 = _mm_xor_si128(tmp2, tmp5); 185 tmp2 = _mm_xor_si128(tmp2, tmp8); 186 tmp3 = _mm_xor_si128(tmp3, tmp2); 187 tmp6 = _mm_xor_si128(tmp6, tmp3); 188 189 return tmp6; 190 } 191 192 // Initialize GCM context 193 int aes256_gcm_init(aes256_gcm_context *ctx, const uint8_t *key) { 194 static int features_detected = 0; 195 if (!features_detected) { 196 detect_cpu_features(); 197 features_detected = 1; 198 } 199 200 if (!has_aesni || !has_pclmulqdq) { 201 fprintf(stderr, "Error: CPU does not support AES-NI and PCLMULQDQ\n"); 202 return -1; 203 } 204 205 // Expand the key 206 aes256_key_expansion(key, &ctx->key_schedule); 207 208 // Compute H = E(K, 0^128) 209 __m128i zero = _mm_setzero_si128(); 210 ctx->H = aes256_encrypt_block(zero, &ctx->key_schedule); 211 ctx->H = reverse_bytes(ctx->H); 212 213 // Precompute powers of H for faster GHASH 214 ctx->H_powers[0] = ctx->H; 215 for (int i = 1; i < 8; i++) { 216 ctx->H_powers[i] = gf_mul(ctx->H_powers[i-1], ctx->H); 217 } 218 219 return 0; 220 } 221 222 // GHASH computation 223 static void ghash(const aes256_gcm_context *ctx, const uint8_t *aad, size_t aad_len, 224 const uint8_t *ciphertext, size_t ct_len, __m128i *tag) { 225 __m128i hash = _mm_setzero_si128(); 226 size_t i; 227 228 // Process AAD 229 for (i = 0; i + 16 <= aad_len; i += 16) { 230 __m128i block = _mm_loadu_si128((__m128i*)(aad + i)); 231 block = reverse_bytes(block); 232 hash = _mm_xor_si128(hash, block); 233 hash = gf_mul(hash, ctx->H); 234 } 235 236 // Handle remaining AAD bytes 237 if (i < aad_len) { 238 uint8_t temp[16] = {0}; 239 memcpy(temp, aad + i, aad_len - i); 240 __m128i block = _mm_loadu_si128((__m128i*)temp); 241 block = reverse_bytes(block); 242 hash = _mm_xor_si128(hash, block); 243 hash = gf_mul(hash, ctx->H); 244 } 245 246 // Process ciphertext 247 for (i = 0; i + 16 <= ct_len; i += 16) { 248 __m128i block = _mm_loadu_si128((__m128i*)(ciphertext + i)); 249 block = reverse_bytes(block); 250 hash = _mm_xor_si128(hash, block); 251 hash = gf_mul(hash, ctx->H); 252 } 253 254 // Handle remaining ciphertext bytes 255 if (i < ct_len) { 256 uint8_t temp[16] = {0}; 257 memcpy(temp, ciphertext + i, ct_len - i); 258 __m128i block = _mm_loadu_si128((__m128i*)temp); 259 block = reverse_bytes(block); 260 hash = _mm_xor_si128(hash, block); 261 hash = gf_mul(hash, ctx->H); 262 } 263 264 // Process length block 265 uint64_t aad_bits = aad_len * 8; 266 uint64_t ct_bits = ct_len * 8; 267 __m128i len_block = _mm_set_epi64x(ct_bits, aad_bits); 268 len_block = reverse_bytes(len_block); 269 hash = _mm_xor_si128(hash, len_block); 270 hash = gf_mul(hash, ctx->H); 271 272 *tag = reverse_bytes(hash); 273 } 274 275 // Increment counter (32-bit increment of rightmost 32 bits) 276 static inline __m128i inc_counter(__m128i counter) { 277 __m128i one = _mm_set_epi32(1, 0, 0, 0); 278 __m128i mask = _mm_set_epi32(0xFFFFFFFF, 0, 0, 0); 279 280 // Extract the counter value 281 uint32_t ctr = _mm_extract_epi32(counter, 0); 282 ctr++; 283 284 // Insert back 285 counter = _mm_insert_epi32(counter, ctr, 0); 286 return counter; 287 } 288 289 // AES-256-GCM Encryption 290 int aes256_gcm_encrypt(aes256_gcm_context *ctx, 291 const uint8_t *iv, size_t iv_len, 292 const uint8_t *aad, size_t aad_len, 293 const uint8_t *plaintext, size_t pt_len, 294 uint8_t *ciphertext, 295 uint8_t *tag, size_t tag_len) { 296 if (tag_len > 16) { 297 return -1; 298 } 299 300 // Prepare initial counter block 301 __m128i counter; 302 if (iv_len == 12) { 303 // Standard case: IV is 96 bits 304 counter = _mm_set_epi32(1, 305 ((uint32_t*)iv)[2], 306 ((uint32_t*)iv)[1], 307 ((uint32_t*)iv)[0]); 308 } else { 309 // Non-standard IV length: use GHASH 310 __m128i hash = _mm_setzero_si128(); 311 size_t i; 312 for (i = 0; i + 16 <= iv_len; i += 16) { 313 __m128i block = _mm_loadu_si128((__m128i*)(iv + i)); 314 block = reverse_bytes(block); 315 hash = _mm_xor_si128(hash, block); 316 hash = gf_mul(hash, ctx->H); 317 } 318 if (i < iv_len) { 319 uint8_t temp[16] = {0}; 320 memcpy(temp, iv + i, iv_len - i); 321 __m128i block = _mm_loadu_si128((__m128i*)temp); 322 block = reverse_bytes(block); 323 hash = _mm_xor_si128(hash, block); 324 hash = gf_mul(hash, ctx->H); 325 } 326 uint64_t iv_bits = iv_len * 8; 327 __m128i len_block = _mm_set_epi64x(iv_bits, 0); 328 len_block = reverse_bytes(len_block); 329 hash = _mm_xor_si128(hash, len_block); 330 hash = gf_mul(hash, ctx->H); 331 counter = reverse_bytes(hash); 332 } 333 334 // Save J0 for tag computation 335 __m128i J0 = counter; 336 337 // Encrypt plaintext 338 size_t i; 339 for (i = 0; i + 16 <= pt_len; i += 16) { 340 counter = inc_counter(counter); 341 __m128i keystream = aes256_encrypt_block(counter, &ctx->key_schedule); 342 __m128i pt_block = _mm_loadu_si128((__m128i*)(plaintext + i)); 343 __m128i ct_block = _mm_xor_si128(pt_block, keystream); 344 _mm_storeu_si128((__m128i*)(ciphertext + i), ct_block); 345 } 346 347 // Handle remaining bytes 348 if (i < pt_len) { 349 counter = inc_counter(counter); 350 __m128i keystream = aes256_encrypt_block(counter, &ctx->key_schedule); 351 uint8_t ks_bytes[16]; 352 _mm_storeu_si128((__m128i*)ks_bytes, keystream); 353 for (size_t j = 0; j < pt_len - i; j++) { 354 ciphertext[i + j] = plaintext[i + j] ^ ks_bytes[j]; 355 } 356 } 357 358 // Compute authentication tag 359 __m128i auth_tag; 360 ghash(ctx, aad, aad_len, ciphertext, pt_len, &auth_tag); 361 362 // Encrypt the tag with J0 363 __m128i J0_keystream = aes256_encrypt_block(J0, &ctx->key_schedule); 364 auth_tag = _mm_xor_si128(auth_tag, J0_keystream); 365 366 // Store tag 367 _mm_storeu_si128((__m128i*)tag, auth_tag); 368 369 return 0; 370 } 371 372 // AES-256-GCM Decryption 373 int aes256_gcm_decrypt(aes256_gcm_context *ctx, 374 const uint8_t *iv, size_t iv_len, 375 const uint8_t *aad, size_t aad_len, 376 const uint8_t *ciphertext, size_t ct_len, 377 const uint8_t *tag, size_t tag_len, 378 uint8_t *plaintext) { 379 if (tag_len > 16) { 380 return -1; 381 } 382 383 // Prepare initial counter block (same as encryption) 384 __m128i counter; 385 if (iv_len == 12) { 386 counter = _mm_set_epi32(1, 387 ((uint32_t*)iv)[2], 388 ((uint32_t*)iv)[1], 389 ((uint32_t*)iv)[0]); 390 } else { 391 __m128i hash = _mm_setzero_si128(); 392 size_t i; 393 for (i = 0; i + 16 <= iv_len; i += 16) { 394 __m128i block = _mm_loadu_si128((__m128i*)(iv + i)); 395 block = reverse_bytes(block); 396 hash = _mm_xor_si128(hash, block); 397 hash = gf_mul(hash, ctx->H); 398 } 399 if (i < iv_len) { 400 uint8_t temp[16] = {0}; 401 memcpy(temp, iv + i, iv_len - i); 402 __m128i block = _mm_loadu_si128((__m128i*)temp); 403 block = reverse_bytes(block); 404 hash = _mm_xor_si128(hash, block); 405 hash = gf_mul(hash, ctx->H); 406 } 407 uint64_t iv_bits = iv_len * 8; 408 __m128i len_block = _mm_set_epi64x(iv_bits, 0); 409 len_block = reverse_bytes(len_block); 410 hash = _mm_xor_si128(hash, len_block); 411 hash = gf_mul(hash, ctx->H); 412 counter = reverse_bytes(hash); 413 } 414 415 __m128i J0 = counter; 416 417 // Verify authentication tag first 418 __m128i computed_tag; 419 ghash(ctx, aad, aad_len, ciphertext, ct_len, &computed_tag); 420 __m128i J0_keystream = aes256_encrypt_block(J0, &ctx->key_schedule); 421 computed_tag = _mm_xor_si128(computed_tag, J0_keystream); 422 423 // Compare tags (constant-time comparison) 424 uint8_t computed_tag_bytes[16]; 425 _mm_storeu_si128((__m128i*)computed_tag_bytes, computed_tag); 426 427 int tag_match = 1; 428 for (size_t i = 0; i < tag_len; i++) { 429 tag_match &= (computed_tag_bytes[i] == tag[i]); 430 } 431 432 if (!tag_match) { 433 return -1; // Authentication failed 434 } 435 436 // Decrypt ciphertext 437 size_t i; 438 for (i = 0; i + 16 <= ct_len; i += 16) { 439 counter = inc_counter(counter); 440 __m128i keystream = aes256_encrypt_block(counter, &ctx->key_schedule); 441 __m128i ct_block = _mm_loadu_si128((__m128i*)(ciphertext + i)); 442 __m128i pt_block = _mm_xor_si128(ct_block, keystream); 443 _mm_storeu_si128((__m128i*)(plaintext + i), pt_block); 444 } 445 446 // Handle remaining bytes 447 if (i < ct_len) { 448 counter = inc_counter(counter); 449 __m128i keystream = aes256_encrypt_block(counter, &ctx->key_schedule); 450 uint8_t ks_bytes[16]; 451 _mm_storeu_si128((__m128i*)ks_bytes, keystream); 452 for (size_t j = 0; j < ct_len - i; j++) { 453 plaintext[i + j] = ciphertext[i + j] ^ ks_bytes[j]; 454 } 455 } 456 457 return 0; 458 } 459 460 /** 461 * Clean up AES-256-GCM context 462 * Zeros all sensitive key material 463 */ 464 void aes256_gcm_cleanup(aes256_gcm_context *ctx) { 465 if (ctx == NULL) return; 466 467 // Zero all sensitive data using volatile to prevent compiler optimization 468 volatile uint8_t *p = (volatile uint8_t *)ctx; 469 size_t n = sizeof(aes256_gcm_context); 470 while (n--) { 471 *p++ = 0; 472 } 473 }