crypto_aes_lib.c (10772B)
1 /* 2 * crypto_aes_lib.c - AES-256-GCM library (extracted core functions) 3 */ 4 5 #include "crypto_core.h" 6 #include <stdio.h> 7 #include <string.h> 8 #include <immintrin.h> 9 #include <wmmintrin.h> 10 11 #ifdef __GNUC__ 12 #include <cpuid.h> 13 #endif 14 15 int crypto_has_aesni = 0; 16 int crypto_has_pclmulqdq = 0; 17 18 void crypto_detect_features(void) { 19 #ifdef __GNUC__ 20 unsigned int eax, ebx, ecx, edx; 21 if (__get_cpuid(1, &eax, &ebx, &ecx, &edx)) { 22 crypto_has_aesni = (ecx & bit_AES) != 0; 23 crypto_has_pclmulqdq = (ecx & bit_PCLMUL) != 0; 24 } 25 #endif 26 } 27 28 static inline __m128i reverse_bytes(__m128i x) { 29 const __m128i mask = _mm_set_epi8(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15); 30 return _mm_shuffle_epi8(x, mask); 31 } 32 33 static void aes256_key_expansion(const uint8_t *key, aes256_key_schedule *ks) { 34 __m128i temp1, temp2, temp3; 35 ks->nr = 14; 36 37 temp1 = _mm_loadu_si128((__m128i*)key); 38 temp2 = _mm_loadu_si128((__m128i*)(key + 16)); 39 ks->round_keys[0] = temp1; 40 ks->round_keys[1] = temp2; 41 42 #define AES_256_ASSIST_1(temp1, temp2, temp3) do { \ 43 temp2 = _mm_aeskeygenassist_si128(temp2, 0x0); \ 44 temp3 = _mm_shuffle_epi32(temp2, 0xff); \ 45 temp2 = _mm_slli_si128(temp1, 0x4); \ 46 temp1 = _mm_xor_si128(temp1, temp2); \ 47 temp2 = _mm_slli_si128(temp2, 0x4); \ 48 temp1 = _mm_xor_si128(temp1, temp2); \ 49 temp2 = _mm_slli_si128(temp2, 0x4); \ 50 temp1 = _mm_xor_si128(temp1, temp2); \ 51 temp1 = _mm_xor_si128(temp1, temp3); \ 52 } while(0) 53 54 #define AES_256_ASSIST_2(temp1, temp2, temp3) do { \ 55 temp3 = _mm_aeskeygenassist_si128(temp1, 0x0); \ 56 temp3 = _mm_shuffle_epi32(temp3, 0xaa); \ 57 temp1 = _mm_slli_si128(temp2, 0x4); \ 58 temp2 = _mm_xor_si128(temp2, temp1); \ 59 temp1 = _mm_slli_si128(temp1, 0x4); \ 60 temp2 = _mm_xor_si128(temp2, temp1); \ 61 temp1 = _mm_slli_si128(temp1, 0x4); \ 62 temp2 = _mm_xor_si128(temp2, temp1); \ 63 temp2 = _mm_xor_si128(temp2, temp3); \ 64 } while(0) 65 66 AES_256_ASSIST_1(temp1, temp2, temp3); 67 ks->round_keys[2] = temp1; 68 AES_256_ASSIST_2(temp1, temp2, temp3); 69 ks->round_keys[3] = temp2; 70 71 temp3 = _mm_aeskeygenassist_si128(temp2, 0x01); 72 AES_256_ASSIST_1(temp1, temp2, temp3); 73 ks->round_keys[4] = temp1; 74 AES_256_ASSIST_2(temp1, temp2, temp3); 75 ks->round_keys[5] = temp2; 76 77 temp3 = _mm_aeskeygenassist_si128(temp2, 0x02); 78 AES_256_ASSIST_1(temp1, temp2, temp3); 79 ks->round_keys[6] = temp1; 80 AES_256_ASSIST_2(temp1, temp2, temp3); 81 ks->round_keys[7] = temp2; 82 83 temp3 = _mm_aeskeygenassist_si128(temp2, 0x04); 84 AES_256_ASSIST_1(temp1, temp2, temp3); 85 ks->round_keys[8] = temp1; 86 AES_256_ASSIST_2(temp1, temp2, temp3); 87 ks->round_keys[9] = temp2; 88 89 temp3 = _mm_aeskeygenassist_si128(temp2, 0x08); 90 AES_256_ASSIST_1(temp1, temp2, temp3); 91 ks->round_keys[10] = temp1; 92 AES_256_ASSIST_2(temp1, temp2, temp3); 93 ks->round_keys[11] = temp2; 94 95 temp3 = _mm_aeskeygenassist_si128(temp2, 0x10); 96 AES_256_ASSIST_1(temp1, temp2, temp3); 97 ks->round_keys[12] = temp1; 98 AES_256_ASSIST_2(temp1, temp2, temp3); 99 ks->round_keys[13] = temp2; 100 101 temp3 = _mm_aeskeygenassist_si128(temp2, 0x20); 102 AES_256_ASSIST_1(temp1, temp2, temp3); 103 ks->round_keys[14] = temp1; 104 } 105 106 static inline __m128i aes256_encrypt_block(__m128i plaintext, const aes256_key_schedule *ks) { 107 __m128i tmp = _mm_xor_si128(plaintext, ks->round_keys[0]); 108 for (int i = 1; i < 14; i++) { 109 tmp = _mm_aesenc_si128(tmp, ks->round_keys[i]); 110 } 111 tmp = _mm_aesenclast_si128(tmp, ks->round_keys[14]); 112 return tmp; 113 } 114 115 static inline __m128i gf_mul(__m128i a, __m128i b) { 116 __m128i tmp3 = _mm_clmulepi64_si128(a, b, 0x00); 117 __m128i tmp6 = _mm_clmulepi64_si128(a, b, 0x11); 118 __m128i tmp4 = _mm_shuffle_epi32(a, 78); 119 __m128i tmp5 = _mm_shuffle_epi32(b, 78); 120 tmp4 = _mm_xor_si128(tmp4, a); 121 tmp5 = _mm_xor_si128(tmp5, b); 122 tmp4 = _mm_clmulepi64_si128(tmp4, tmp5, 0x00); 123 tmp4 = _mm_xor_si128(tmp4, tmp3); 124 tmp4 = _mm_xor_si128(tmp4, tmp6); 125 tmp5 = _mm_slli_si128(tmp4, 8); 126 tmp4 = _mm_srli_si128(tmp4, 8); 127 tmp3 = _mm_xor_si128(tmp3, tmp5); 128 tmp6 = _mm_xor_si128(tmp6, tmp4); 129 130 __m128i tmp7 = _mm_srli_epi32(tmp3, 31); 131 __m128i tmp8 = _mm_srli_epi32(tmp6, 31); 132 tmp3 = _mm_slli_epi32(tmp3, 1); 133 tmp6 = _mm_slli_epi32(tmp6, 1); 134 __m128i tmp9 = _mm_srli_si128(tmp7, 12); 135 tmp8 = _mm_slli_si128(tmp8, 4); 136 tmp7 = _mm_slli_si128(tmp7, 4); 137 tmp3 = _mm_or_si128(tmp3, tmp7); 138 tmp6 = _mm_or_si128(tmp6, tmp8); 139 tmp6 = _mm_or_si128(tmp6, tmp9); 140 141 tmp7 = _mm_slli_epi32(tmp3, 31); 142 tmp8 = _mm_slli_epi32(tmp3, 30); 143 tmp9 = _mm_slli_epi32(tmp3, 25); 144 tmp7 = _mm_xor_si128(tmp7, tmp8); 145 tmp7 = _mm_xor_si128(tmp7, tmp9); 146 tmp8 = _mm_srli_si128(tmp7, 4); 147 tmp7 = _mm_slli_si128(tmp7, 12); 148 tmp3 = _mm_xor_si128(tmp3, tmp7); 149 __m128i tmp2 = _mm_srli_epi32(tmp3, 1); 150 tmp4 = _mm_srli_epi32(tmp3, 2); 151 tmp5 = _mm_srli_epi32(tmp3, 7); 152 tmp2 = _mm_xor_si128(tmp2, tmp4); 153 tmp2 = _mm_xor_si128(tmp2, tmp5); 154 tmp2 = _mm_xor_si128(tmp2, tmp8); 155 tmp3 = _mm_xor_si128(tmp3, tmp2); 156 tmp6 = _mm_xor_si128(tmp6, tmp3); 157 return tmp6; 158 } 159 160 int aes256_gcm_init(aes256_gcm_context *ctx, const uint8_t *key) { 161 if (!crypto_has_aesni || !crypto_has_pclmulqdq) return -1; 162 163 aes256_key_expansion(key, &ctx->key_schedule); 164 __m128i zero = _mm_setzero_si128(); 165 ctx->H = aes256_encrypt_block(zero, &ctx->key_schedule); 166 ctx->H = reverse_bytes(ctx->H); 167 168 ctx->H_powers[0] = ctx->H; 169 for (int i = 1; i < 8; i++) { 170 ctx->H_powers[i] = gf_mul(ctx->H_powers[i-1], ctx->H); 171 } 172 return 0; 173 } 174 175 static void ghash(const aes256_gcm_context *ctx, const uint8_t *aad, size_t aad_len, 176 const uint8_t *ciphertext, size_t ct_len, __m128i *tag) { 177 __m128i hash = _mm_setzero_si128(); 178 size_t i; 179 180 for (i = 0; i + 16 <= aad_len; i += 16) { 181 __m128i block = _mm_loadu_si128((__m128i*)(aad + i)); 182 block = reverse_bytes(block); 183 hash = _mm_xor_si128(hash, block); 184 hash = gf_mul(hash, ctx->H); 185 } 186 if (i < aad_len) { 187 uint8_t temp[16] = {0}; 188 memcpy(temp, aad + i, aad_len - i); 189 __m128i block = _mm_loadu_si128((__m128i*)temp); 190 block = reverse_bytes(block); 191 hash = _mm_xor_si128(hash, block); 192 hash = gf_mul(hash, ctx->H); 193 } 194 195 for (i = 0; i + 16 <= ct_len; i += 16) { 196 __m128i block = _mm_loadu_si128((__m128i*)(ciphertext + i)); 197 block = reverse_bytes(block); 198 hash = _mm_xor_si128(hash, block); 199 hash = gf_mul(hash, ctx->H); 200 } 201 if (i < ct_len) { 202 uint8_t temp[16] = {0}; 203 memcpy(temp, ciphertext + i, ct_len - i); 204 __m128i block = _mm_loadu_si128((__m128i*)temp); 205 block = reverse_bytes(block); 206 hash = _mm_xor_si128(hash, block); 207 hash = gf_mul(hash, ctx->H); 208 } 209 210 uint64_t aad_bits = aad_len * 8; 211 uint64_t ct_bits = ct_len * 8; 212 __m128i len_block = _mm_set_epi64x(ct_bits, aad_bits); 213 len_block = reverse_bytes(len_block); 214 hash = _mm_xor_si128(hash, len_block); 215 hash = gf_mul(hash, ctx->H); 216 *tag = reverse_bytes(hash); 217 } 218 219 static inline __m128i inc_counter(__m128i counter) { 220 uint32_t ctr = _mm_extract_epi32(counter, 0); 221 ctr++; 222 return _mm_insert_epi32(counter, ctr, 0); 223 } 224 225 int aes256_gcm_encrypt(aes256_gcm_context *ctx, const uint8_t *iv, size_t iv_len, 226 const uint8_t *aad, size_t aad_len, 227 const uint8_t *plaintext, size_t pt_len, 228 uint8_t *ciphertext, uint8_t *tag, size_t tag_len) { 229 if (tag_len > 16) return -1; 230 231 __m128i counter; 232 if (iv_len == 12) { 233 counter = _mm_set_epi32(1, ((uint32_t*)iv)[2], ((uint32_t*)iv)[1], ((uint32_t*)iv)[0]); 234 } else { 235 return -1; // Simplified - only support 96-bit IV 236 } 237 238 __m128i J0 = counter; 239 size_t i; 240 241 for (i = 0; i + 16 <= pt_len; i += 16) { 242 counter = inc_counter(counter); 243 __m128i keystream = aes256_encrypt_block(counter, &ctx->key_schedule); 244 __m128i pt_block = _mm_loadu_si128((__m128i*)(plaintext + i)); 245 __m128i ct_block = _mm_xor_si128(pt_block, keystream); 246 _mm_storeu_si128((__m128i*)(ciphertext + i), ct_block); 247 } 248 249 if (i < pt_len) { 250 counter = inc_counter(counter); 251 __m128i keystream = aes256_encrypt_block(counter, &ctx->key_schedule); 252 uint8_t ks_bytes[16]; 253 _mm_storeu_si128((__m128i*)ks_bytes, keystream); 254 for (size_t j = 0; j < pt_len - i; j++) { 255 ciphertext[i + j] = plaintext[i + j] ^ ks_bytes[j]; 256 } 257 } 258 259 __m128i auth_tag; 260 ghash(ctx, aad, aad_len, ciphertext, pt_len, &auth_tag); 261 __m128i J0_keystream = aes256_encrypt_block(J0, &ctx->key_schedule); 262 auth_tag = _mm_xor_si128(auth_tag, J0_keystream); 263 _mm_storeu_si128((__m128i*)tag, auth_tag); 264 265 return 0; 266 } 267 268 int aes256_gcm_decrypt(aes256_gcm_context *ctx, const uint8_t *iv, size_t iv_len, 269 const uint8_t *aad, size_t aad_len, 270 const uint8_t *ciphertext, size_t ct_len, 271 const uint8_t *tag, size_t tag_len, uint8_t *plaintext) { 272 if (tag_len > 16) return -1; 273 274 __m128i counter; 275 if (iv_len == 12) { 276 counter = _mm_set_epi32(1, ((uint32_t*)iv)[2], ((uint32_t*)iv)[1], ((uint32_t*)iv)[0]); 277 } else { 278 return -1; 279 } 280 281 __m128i J0 = counter; 282 283 __m128i computed_tag; 284 ghash(ctx, aad, aad_len, ciphertext, ct_len, &computed_tag); 285 __m128i J0_keystream = aes256_encrypt_block(J0, &ctx->key_schedule); 286 computed_tag = _mm_xor_si128(computed_tag, J0_keystream); 287 288 uint8_t computed_tag_bytes[16]; 289 _mm_storeu_si128((__m128i*)computed_tag_bytes, computed_tag); 290 291 int tag_match = 1; 292 for (size_t i = 0; i < tag_len; i++) { 293 tag_match &= (computed_tag_bytes[i] == tag[i]); 294 } 295 if (!tag_match) return -1; 296 297 size_t i; 298 for (i = 0; i + 16 <= ct_len; i += 16) { 299 counter = inc_counter(counter); 300 __m128i keystream = aes256_encrypt_block(counter, &ctx->key_schedule); 301 __m128i ct_block = _mm_loadu_si128((__m128i*)(ciphertext + i)); 302 __m128i pt_block = _mm_xor_si128(ct_block, keystream); 303 _mm_storeu_si128((__m128i*)(plaintext + i), pt_block); 304 } 305 306 if (i < ct_len) { 307 counter = inc_counter(counter); 308 __m128i keystream = aes256_encrypt_block(counter, &ctx->key_schedule); 309 uint8_t ks_bytes[16]; 310 _mm_storeu_si128((__m128i*)ks_bytes, keystream); 311 for (size_t j = 0; j < ct_len - i; j++) { 312 plaintext[i + j] = ciphertext[i + j] ^ ks_bytes[j]; 313 } 314 } 315 316 return 0; 317 }