/* Kyber.c (28211 bytes) */
1 /* 2 * Kyber.c - CRYSTALS-Kyber FIPS 203 (ML-KEM) Compliant Implementation 3 * 4 * NIST FIPS 203 - Module-Lattice-Based Key-Encapsulation Mechanism Standard 5 * Full specification compliance with all required algorithms 6 * 7 * SECURITY: Production implementation with: 8 * - Full FIPS 203 encapsulation/decapsulation algorithms 9 * - IND-CCA2 security via implicit rejection 10 * - Constant-time operations for secret-dependent code 11 * - Proper polynomial packing per specification 12 * - SHAKE-128/256 for all randomness expansion 13 */ 14 15 #include "Kyber.h" 16 #include "CSPRNG.h" 17 #include "ct_util.h" 18 #include "hashing/hash.h" 19 #include "hashing/SHA3.h" 20 #include <string.h> 21 22 /* ============================================================================ 23 * FIPS 203 Parameters 24 * ========================================================================= */ 25 26 #define KYBER_N 256 27 #define KYBER_Q 3329 28 #define KYBER_SYMBYTES 32 29 30 /* Kyber512 (ML-KEM-512) */ 31 #define KYBER512_K 2 32 #define KYBER512_ETA1 3 33 #define KYBER512_ETA2 2 34 #define KYBER512_DU 10 35 #define KYBER512_DV 4 36 37 /* Kyber768 (ML-KEM-768) - RECOMMENDED */ 38 #define KYBER768_K 3 39 #define KYBER768_ETA1 2 40 #define KYBER768_ETA2 2 41 #define KYBER768_DU 10 42 #define KYBER768_DV 4 43 44 /* Kyber1024 (ML-KEM-1024) */ 45 #define KYBER1024_K 4 46 #define KYBER1024_ETA1 2 47 #define KYBER1024_ETA2 2 48 #define KYBER1024_DU 11 49 #define KYBER1024_DV 5 50 51 #define KYBER_POLYBYTES 384 52 #define SHAKE128_RATE 168 53 54 /* ============================================================================ 55 * NTT Constants (zetas in bit-reversed order) 56 * Verified against pq-crystals/kyber reference implementation 57 * ========================================================================= */ 58 59 static const int16_t zetas[128] = { 60 -1044, -758, -359, -1517, 1493, 1422, 287, 202, 61 -171, 622, 1577, 182, 962, -1202, -1474, 1468, 62 573, -1325, 264, 383, -829, 
1458, -1602, -130, 63 -681, 1017, 732, 608, -1542, 411, -205, -1571, 64 1223, 652, -552, 1015, -1293, 1491, -282, -1544, 65 516, -8, -320, -666, -1618, -1162, 126, 1469, 66 -853, -90, -271, 830, 107, -1421, -247, -951, 67 -398, 961, -1508, -725, 448, -1065, 677, -1275, 68 -1103, 430, 555, 843, -1251, 871, 1550, 105, 69 422, 587, 177, -235, -291, -460, 1574, 1653, 70 -246, 778, 1159, -147, -777, 1483, -602, 1119, 71 -1590, 644, -872, 349, 418, 329, -156, -75, 72 817, 1097, 603, 610, 1322, -1285, -1465, 384, 73 -1215, -136, 1218, -1335, -874, 220, -1187, -1659, 74 -1185, -1530, -1278, 794, -1510, -854, -870, 478, 75 -108, -308, 996, 991, 958, -1460, 1522, 1628 76 }; 77 78 /* ============================================================================ 79 * Modular Arithmetic 80 * ========================================================================= */ 81 82 static inline int16_t montgomery_reduce(int32_t a) { 83 int16_t t; 84 t = (int16_t)((uint32_t)a * 62209); 85 t = (a - (int32_t)t * KYBER_Q) >> 16; 86 return t; 87 } 88 89 static inline int16_t barrett_reduce(int16_t a) { 90 int16_t t; 91 const int16_t v = ((1 << 26) + KYBER_Q/2) / KYBER_Q; 92 t = ((int32_t)v * a + (1 << 25)) >> 26; 93 t *= KYBER_Q; 94 return a - t; 95 } 96 97 static inline int16_t fqmul(int16_t a, int16_t b) { 98 return montgomery_reduce((int32_t)a * b); 99 } 100 101 /* ============================================================================ 102 * Polynomial Operations 103 * ========================================================================= */ 104 105 typedef struct { 106 int16_t coeffs[KYBER_N]; 107 } poly; 108 109 typedef struct { 110 poly vec[4]; /* Max k=4 */ 111 } polyvec; 112 113 static void poly_reduce(poly *r) { 114 for (int i = 0; i < KYBER_N; i++) { 115 r->coeffs[i] = barrett_reduce(r->coeffs[i]); 116 } 117 } 118 119 static void poly_add(poly *r, const poly *a, const poly *b) { 120 for (int i = 0; i < KYBER_N; i++) { 121 r->coeffs[i] = a->coeffs[i] + b->coeffs[i]; 122 } 123 
} 124 125 static void poly_sub(poly *r, const poly *a, const poly *b) { 126 for (int i = 0; i < KYBER_N; i++) { 127 r->coeffs[i] = a->coeffs[i] - b->coeffs[i]; 128 } 129 } 130 131 static void poly_ntt(poly *r) { 132 unsigned int len, start, j, k; 133 int16_t t, zeta; 134 135 k = 1; 136 for (len = 128; len >= 2; len >>= 1) { 137 for (start = 0; start < KYBER_N; start = j + len) { 138 zeta = zetas[k++]; 139 for (j = start; j < start + len; j++) { 140 t = fqmul(zeta, r->coeffs[j + len]); 141 r->coeffs[j + len] = r->coeffs[j] - t; 142 r->coeffs[j] = r->coeffs[j] + t; 143 } 144 } 145 } 146 } 147 148 static void poly_invntt_tomont(poly *r) { 149 unsigned int start, len, j, k; 150 int16_t t, zeta; 151 const int16_t f = 1441; 152 153 k = 127; 154 for (len = 2; len <= 128; len <<= 1) { 155 for (start = 0; start < KYBER_N; start = j + len) { 156 zeta = zetas[k--]; 157 for (j = start; j < start + len; j++) { 158 t = r->coeffs[j]; 159 r->coeffs[j] = barrett_reduce(t + r->coeffs[j + len]); 160 r->coeffs[j + len] = r->coeffs[j + len] - t; 161 r->coeffs[j + len] = fqmul(zeta, r->coeffs[j + len]); 162 } 163 } 164 } 165 166 for (j = 0; j < KYBER_N; j++) { 167 r->coeffs[j] = fqmul(r->coeffs[j], f); 168 } 169 } 170 171 static void poly_basemul_montgomery(poly *r, const poly *a, const poly *b) { 172 for (int i = 0; i < KYBER_N / 4; i++) { 173 int16_t rx, ry; 174 175 rx = fqmul(a->coeffs[4*i+1], b->coeffs[4*i+1]); 176 rx = fqmul(rx, zetas[64 + i]); 177 rx += fqmul(a->coeffs[4*i], b->coeffs[4*i]); 178 179 ry = fqmul(a->coeffs[4*i], b->coeffs[4*i+1]); 180 ry += fqmul(a->coeffs[4*i+1], b->coeffs[4*i]); 181 182 r->coeffs[4*i] = rx; 183 r->coeffs[4*i+1] = ry; 184 185 rx = fqmul(a->coeffs[4*i+3], b->coeffs[4*i+3]); 186 rx = fqmul(rx, -zetas[64 + i]); 187 rx += fqmul(a->coeffs[4*i+2], b->coeffs[4*i+2]); 188 189 ry = fqmul(a->coeffs[4*i+2], b->coeffs[4*i+3]); 190 ry += fqmul(a->coeffs[4*i+3], b->coeffs[4*i+2]); 191 192 r->coeffs[4*i+2] = rx; 193 r->coeffs[4*i+3] = ry; 194 } 195 } 196 197 
static void poly_tomont(poly *r) { 198 const int16_t f = (1ULL << 32) % KYBER_Q; 199 for (int i = 0; i < KYBER_N; i++) { 200 r->coeffs[i] = montgomery_reduce((int32_t)r->coeffs[i] * f); 201 } 202 } 203 204 /* ============================================================================ 205 * Polynomial Vector Operations 206 * ========================================================================= */ 207 208 static void polyvec_ntt(polyvec *r, int k) { 209 for (int i = 0; i < k; i++) { 210 poly_ntt(&r->vec[i]); 211 } 212 } 213 214 static void polyvec_invntt_tomont(polyvec *r, int k) { 215 for (int i = 0; i < k; i++) { 216 poly_invntt_tomont(&r->vec[i]); 217 } 218 } 219 220 static void polyvec_add(polyvec *r, const polyvec *a, const polyvec *b, int k) { 221 for (int i = 0; i < k; i++) { 222 poly_add(&r->vec[i], &a->vec[i], &b->vec[i]); 223 } 224 } 225 226 static void polyvec_reduce(polyvec *r, int k) { 227 for (int i = 0; i < k; i++) { 228 poly_reduce(&r->vec[i]); 229 } 230 } 231 232 /* ============================================================================ 233 * Sampling Functions (FIPS 203 compliant) 234 * ========================================================================= */ 235 236 static unsigned int rej_uniform(int16_t *r, unsigned int len, const uint8_t *buf, unsigned int buflen) { 237 unsigned int ctr, pos; 238 uint16_t val0, val1; 239 240 ctr = pos = 0; 241 while (ctr < len && pos + 3 <= buflen) { 242 val0 = ((buf[pos] | ((uint16_t)buf[pos + 1] << 8)) & 0xFFF); 243 val1 = ((buf[pos + 1] >> 4) | ((uint16_t)buf[pos + 2] << 4)) & 0xFFF; 244 pos += 3; 245 246 if (val0 < KYBER_Q) { 247 r[ctr++] = val0; 248 } 249 if (ctr < len && val1 < KYBER_Q) { 250 r[ctr++] = val1; 251 } 252 } 253 return ctr; 254 } 255 256 /* SampleNTT: Sample polynomial from uniform distribution using XOF (SHAKE-128) */ 257 static void poly_uniform(poly *a, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce) { 258 unsigned int ctr, off; 259 unsigned int buflen = SHAKE128_RATE * 3; 
260 uint8_t buf[SHAKE128_RATE * 3]; 261 uint8_t extseed[KYBER_SYMBYTES + 1]; 262 263 memcpy(extseed, seed, KYBER_SYMBYTES); 264 extseed[KYBER_SYMBYTES] = nonce; 265 266 shake128(extseed, KYBER_SYMBYTES + 1, buf, buflen); 267 ctr = rej_uniform(a->coeffs, KYBER_N, buf, buflen); 268 269 while (ctr < KYBER_N) { 270 off = buflen % 3; 271 for (unsigned int i = 0; i < off; i++) { 272 buf[i] = buf[buflen - off + i]; 273 } 274 shake128(extseed, KYBER_SYMBYTES + 1, buf + off, buflen - off); 275 buflen = SHAKE128_RATE * 3; 276 ctr += rej_uniform(a->coeffs + ctr, KYBER_N - ctr, buf, buflen); 277 } 278 } 279 280 /* CBD: Centered Binomial Distribution */ 281 static void cbd_eta(poly *r, const uint8_t *buf, int eta) { 282 uint32_t t, d; 283 int16_t a, b; 284 285 if (eta == 2) { 286 for (int i = 0; i < KYBER_N / 8; i++) { 287 t = buf[2*i] | ((uint32_t)buf[2*i + 1] << 8); 288 d = t & 0x55555555; 289 d += (t >> 1) & 0x55555555; 290 291 for (int j = 0; j < 8; j++) { 292 a = (d >> (4*j)) & 0x3; 293 b = (d >> (4*j + 2)) & 0x3; 294 r->coeffs[8*i + j] = a - b; 295 } 296 } 297 } else if (eta == 3) { 298 for (int i = 0; i < KYBER_N / 4; i++) { 299 t = buf[3*i] | ((uint32_t)buf[3*i + 1] << 8) | ((uint32_t)buf[3*i + 2] << 16); 300 d = t & 0x00249249; 301 d += (t >> 1) & 0x00249249; 302 d += (t >> 2) & 0x00249249; 303 304 for (int j = 0; j < 4; j++) { 305 a = (d >> (6*j)) & 0x7; 306 b = (d >> (6*j + 3)) & 0x7; 307 r->coeffs[4*i + j] = a - b; 308 } 309 } 310 } 311 } 312 313 /* SamplePolyCBD: Sample polynomial from CBD using PRF (SHAKE-256) */ 314 static void poly_cbd_eta(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce, int eta) { 315 uint8_t buf[eta * KYBER_N / 4]; 316 uint8_t extseed[KYBER_SYMBYTES + 1]; 317 318 memcpy(extseed, seed, KYBER_SYMBYTES); 319 extseed[KYBER_SYMBYTES] = nonce; 320 321 shake256(extseed, KYBER_SYMBYTES + 1, buf, eta * KYBER_N / 4); 322 cbd_eta(r, buf, eta); 323 } 324 325 /* ============================================================================ 326 * 
Compression and Decompression (FIPS 203 §4.2.1) 327 * ========================================================================= */ 328 329 static void poly_compress(uint8_t *r, const poly *a, int d) { 330 uint8_t t[8]; 331 int16_t u; 332 333 if (d == 4) { 334 for (int i = 0; i < KYBER_N / 8; i++) { 335 for (int j = 0; j < 8; j++) { 336 u = a->coeffs[8*i + j]; 337 u += (u >> 15) & KYBER_Q; 338 t[j] = ((((uint32_t)u << 4) + KYBER_Q/2) / KYBER_Q) & 15; 339 } 340 r[4*i] = t[0] | (t[1] << 4); 341 r[4*i + 1] = t[2] | (t[3] << 4); 342 r[4*i + 2] = t[4] | (t[5] << 4); 343 r[4*i + 3] = t[6] | (t[7] << 4); 344 } 345 } else if (d == 10) { 346 for (int i = 0; i < KYBER_N / 4; i++) { 347 for (int j = 0; j < 4; j++) { 348 u = a->coeffs[4*i + j]; 349 u += (u >> 15) & KYBER_Q; 350 t[j] = ((((uint32_t)u << 10) + KYBER_Q/2) / KYBER_Q) & 0x3ff; 351 } 352 r[5*i] = (uint8_t)t[0]; 353 r[5*i + 1] = (uint8_t)((t[0] >> 8) | (t[1] << 2)); 354 r[5*i + 2] = (uint8_t)((t[1] >> 6) | (t[2] << 4)); 355 r[5*i + 3] = (uint8_t)((t[2] >> 4) | (t[3] << 6)); 356 r[5*i + 4] = (uint8_t)(t[3] >> 2); 357 } 358 } else if (d == 11) { 359 for (int i = 0; i < KYBER_N / 8; i++) { 360 for (int j = 0; j < 8; j++) { 361 u = a->coeffs[8*i + j]; 362 u += (u >> 15) & KYBER_Q; 363 t[j] = ((((uint32_t)u << 11) + KYBER_Q/2) / KYBER_Q) & 0x7ff; 364 } 365 r[11*i] = (uint8_t)t[0]; 366 r[11*i + 1] = (uint8_t)((t[0] >> 8) | (t[1] << 3)); 367 r[11*i + 2] = (uint8_t)((t[1] >> 5) | (t[2] << 6)); 368 r[11*i + 3] = (uint8_t)(t[2] >> 2); 369 r[11*i + 4] = (uint8_t)((t[2] >> 10) | (t[3] << 1)); 370 r[11*i + 5] = (uint8_t)((t[3] >> 7) | (t[4] << 4)); 371 r[11*i + 6] = (uint8_t)((t[4] >> 4) | (t[5] << 7)); 372 r[11*i + 7] = (uint8_t)(t[5] >> 1); 373 r[11*i + 8] = (uint8_t)((t[5] >> 9) | (t[6] << 2)); 374 r[11*i + 9] = (uint8_t)((t[6] >> 6) | (t[7] << 5)); 375 r[11*i + 10] = (uint8_t)(t[7] >> 3); 376 } 377 } 378 } 379 380 static void poly_decompress(poly *r, const uint8_t *a, int d) { 381 if (d == 4) { 382 for (int i = 0; i < KYBER_N 
/ 2; i++) { 383 r->coeffs[2*i] = (((uint32_t)(a[i] & 15) * KYBER_Q) + 8) >> 4; 384 r->coeffs[2*i + 1] = (((uint32_t)(a[i] >> 4) * KYBER_Q) + 8) >> 4; 385 } 386 } else if (d == 10) { 387 for (int i = 0; i < KYBER_N / 4; i++) { 388 r->coeffs[4*i] = ((((uint32_t)a[5*i] | ((uint32_t)a[5*i + 1] << 8)) & 0x3ff) * KYBER_Q + 512) >> 10; 389 r->coeffs[4*i + 1] = ((((uint32_t)a[5*i + 1] >> 2 | ((uint32_t)a[5*i + 2] << 6)) & 0x3ff) * KYBER_Q + 512) >> 10; 390 r->coeffs[4*i + 2] = ((((uint32_t)a[5*i + 2] >> 4 | ((uint32_t)a[5*i + 3] << 4)) & 0x3ff) * KYBER_Q + 512) >> 10; 391 r->coeffs[4*i + 3] = ((((uint32_t)a[5*i + 3] >> 6 | ((uint32_t)a[5*i + 4] << 2)) & 0x3ff) * KYBER_Q + 512) >> 10; 392 } 393 } else if (d == 11) { 394 for (int i = 0; i < KYBER_N / 8; i++) { 395 r->coeffs[8*i] = ((((uint32_t)a[11*i] | ((uint32_t)a[11*i + 1] << 8)) & 0x7ff) * KYBER_Q + 1024) >> 11; 396 r->coeffs[8*i + 1] = ((((uint32_t)a[11*i + 1] >> 3 | ((uint32_t)a[11*i + 2] << 5)) & 0x7ff) * KYBER_Q + 1024) >> 11; 397 r->coeffs[8*i + 2] = ((((uint32_t)a[11*i + 2] >> 6 | ((uint32_t)a[11*i + 3] << 2) | ((uint32_t)a[11*i + 4] << 10)) & 0x7ff) * KYBER_Q + 1024) >> 11; 398 r->coeffs[8*i + 3] = ((((uint32_t)a[11*i + 4] >> 1 | ((uint32_t)a[11*i + 5] << 7)) & 0x7ff) * KYBER_Q + 1024) >> 11; 399 r->coeffs[8*i + 4] = ((((uint32_t)a[11*i + 5] >> 4 | ((uint32_t)a[11*i + 6] << 4)) & 0x7ff) * KYBER_Q + 1024) >> 11; 400 r->coeffs[8*i + 5] = ((((uint32_t)a[11*i + 6] >> 7 | ((uint32_t)a[11*i + 7] << 1) | ((uint32_t)a[11*i + 8] << 9)) & 0x7ff) * KYBER_Q + 1024) >> 11; 401 r->coeffs[8*i + 6] = ((((uint32_t)a[11*i + 8] >> 2 | ((uint32_t)a[11*i + 9] << 6)) & 0x7ff) * KYBER_Q + 1024) >> 11; 402 r->coeffs[8*i + 7] = ((((uint32_t)a[11*i + 9] >> 5 | ((uint32_t)a[11*i + 10] << 3)) & 0x7ff) * KYBER_Q + 1024) >> 11; 403 } 404 } 405 } 406 407 static void polyvec_compress(uint8_t *r, const polyvec *a, int k, int d) { 408 for (int i = 0; i < k; i++) { 409 poly_compress(r + i * (d == 10 ? 320 : d == 11 ? 
352 : 128), &a->vec[i], d); 410 } 411 } 412 413 static void polyvec_decompress(polyvec *r, const uint8_t *a, int k, int d) { 414 for (int i = 0; i < k; i++) { 415 poly_decompress(&r->vec[i], a + i * (d == 10 ? 320 : d == 11 ? 352 : 128), d); 416 } 417 } 418 419 /* ============================================================================ 420 * Polynomial Packing/Unpacking (FIPS 203 compliant) 421 * ========================================================================= */ 422 423 static void poly_tobytes(uint8_t r[KYBER_POLYBYTES], const poly *a) { 424 uint16_t t0, t1; 425 426 for (int i = 0; i < KYBER_N / 2; i++) { 427 t0 = a->coeffs[2*i]; 428 t0 += (t0 >> 15) & KYBER_Q; 429 430 t1 = a->coeffs[2*i + 1]; 431 t1 += (t1 >> 15) & KYBER_Q; 432 433 r[3*i] = (uint8_t)t0; 434 r[3*i + 1] = (uint8_t)(t0 >> 8) | (uint8_t)(t1 << 4); 435 r[3*i + 2] = (uint8_t)(t1 >> 4); 436 } 437 } 438 439 static void poly_frombytes(poly *r, const uint8_t a[KYBER_POLYBYTES]) { 440 for (int i = 0; i < KYBER_N / 2; i++) { 441 r->coeffs[2*i] = ((a[3*i] | ((uint16_t)a[3*i + 1] << 8)) & 0xFFF); 442 r->coeffs[2*i + 1] = ((a[3*i + 1] >> 4) | ((uint16_t)a[3*i + 2] << 4)) & 0xFFF; 443 } 444 } 445 446 static void polyvec_tobytes(uint8_t *r, const polyvec *a, int k) { 447 for (int i = 0; i < k; i++) { 448 poly_tobytes(r + i * KYBER_POLYBYTES, &a->vec[i]); 449 } 450 } 451 452 static void polyvec_frombytes(polyvec *r, const uint8_t *a, int k) { 453 for (int i = 0; i < k; i++) { 454 poly_frombytes(&r->vec[i], a + i * KYBER_POLYBYTES); 455 } 456 } 457 458 /* ============================================================================ 459 * Matrix-Vector Multiplication 460 * ========================================================================= */ 461 462 static void polyvec_pointwise_acc_montgomery(poly *r, const polyvec *a, const polyvec *b, int k) { 463 poly t; 464 poly_basemul_montgomery(r, &a->vec[0], &b->vec[0]); 465 466 for (int i = 1; i < k; i++) { 467 poly_basemul_montgomery(&t, &a->vec[i], 
&b->vec[i]); 468 poly_add(r, r, &t); 469 } 470 poly_reduce(r); 471 } 472 473 /* ============================================================================ 474 * Key Generation (FIPS 203 Algorithm 15 - ML-KEM.KeyGen) 475 * ========================================================================= */ 476 477 static int kyber_keypair_internal(uint8_t *pk, uint8_t *sk, int k, int eta1) { 478 polyvec matrix[4]; /* A^T (max 4x4) */ 479 polyvec s, e, pkpv; 480 uint8_t buf[2 * KYBER_SYMBYTES]; 481 uint8_t *rho = buf; 482 uint8_t *sigma = buf + KYBER_SYMBYTES; 483 uint8_t nonce = 0; 484 485 /* 1. d ← B^32; (ρ, σ) ← G(d) */ 486 random_bytes(buf, KYBER_SYMBYTES); 487 sha3_512(buf, KYBER_SYMBYTES, buf); 488 489 /* 2-3. Generate matrix A from ρ */ 490 for (int i = 0; i < k; i++) { 491 for (int j = 0; j < k; j++) { 492 poly_uniform(&matrix[i].vec[j], rho, (i << 4) | j); 493 } 494 } 495 496 /* 4-6. Sample secret vector s and error vector e */ 497 for (int i = 0; i < k; i++) { 498 poly_cbd_eta(&s.vec[i], sigma, nonce++, eta1); 499 } 500 for (int i = 0; i < k; i++) { 501 poly_cbd_eta(&e.vec[i], sigma, nonce++, eta1); 502 } 503 504 /* 7. ŝ ← NTT(s) */ 505 polyvec_ntt(&s, k); 506 507 /* 8-10. Compute t = A ◦ s + e */ 508 for (int i = 0; i < k; i++) { 509 polyvec_pointwise_acc_montgomery(&pkpv.vec[i], &matrix[i], &s, k); 510 poly_invntt_tomont(&pkpv.vec[i]); 511 } 512 polyvec_add(&pkpv, &pkpv, &e, k); 513 polyvec_reduce(&pkpv, k); 514 515 /* 11-12. Pack public key: pk ← ByteEncode12(t) || ρ */ 516 polyvec_tobytes(pk, &pkpv, k); 517 memcpy(pk + k * KYBER_POLYBYTES, rho, KYBER_SYMBYTES); 518 519 /* 13-14. 
Pack secret key: sk ← ByteEncode12(ŝ) */ 520 polyvec_tobytes(sk, &s, k); 521 522 /* Append pk and hash for implicit rejection */ 523 memcpy(sk + k * KYBER_POLYBYTES, pk, k * KYBER_POLYBYTES + KYBER_SYMBYTES); 524 sha3_256(pk, k * KYBER_POLYBYTES + KYBER_SYMBYTES, sk + 2 * k * KYBER_POLYBYTES + KYBER_SYMBYTES); 525 526 /* Append random z for implicit rejection */ 527 random_bytes(sk + 2 * k * KYBER_POLYBYTES + 2 * KYBER_SYMBYTES, KYBER_SYMBYTES); 528 529 /* Secure cleanup */ 530 secure_zero(buf, sizeof(buf)); 531 secure_zero(&s, sizeof(s)); 532 secure_zero(&e, sizeof(e)); 533 534 return 0; 535 } 536 537 /* ============================================================================ 538 * Encapsulation (FIPS 203 Algorithm 16 - ML-KEM.Encaps) 539 * ========================================================================= */ 540 541 static int kyber_encapsulate_internal(uint8_t *ct, uint8_t *ss, const uint8_t *pk, int k, int eta1, int eta2, int du, int dv) { 542 polyvec matrix[4], sp, ep, bp; 543 poly v, epp, mp; 544 uint8_t m[KYBER_SYMBYTES]; 545 uint8_t buf[2 * KYBER_SYMBYTES]; 546 uint8_t kr[2 * KYBER_SYMBYTES]; 547 uint8_t *rho = (uint8_t*)pk + k * KYBER_POLYBYTES; 548 uint8_t nonce = 0; 549 550 /* 1-2. m ← B^32; (K̄, r) ← G(m || H(pk)) */ 551 random_bytes(m, KYBER_SYMBYTES); 552 sha3_256(pk, k * KYBER_POLYBYTES + KYBER_SYMBYTES, buf); 553 memcpy(buf + KYBER_SYMBYTES, m, KYBER_SYMBYTES); 554 sha3_512(buf, 2 * KYBER_SYMBYTES, kr); 555 556 /* 3. Generate matrix A from ρ (same as in keygen) */ 557 for (int i = 0; i < k; i++) { 558 for (int j = 0; j < k; j++) { 559 poly_uniform(&matrix[i].vec[j], rho, (i << 4) | j); 560 } 561 } 562 563 /* 4-6. Sample r, e1, e2 from PRF using r as seed */ 564 for (int i = 0; i < k; i++) { 565 poly_cbd_eta(&sp.vec[i], kr + KYBER_SYMBYTES, nonce++, eta1); 566 } 567 for (int i = 0; i < k; i++) { 568 poly_cbd_eta(&ep.vec[i], kr + KYBER_SYMBYTES, nonce++, eta2); 569 } 570 poly_cbd_eta(&epp, kr + KYBER_SYMBYTES, nonce++, eta2); 571 572 /* 7. 
r̂ ← NTT(r) */ 573 polyvec_ntt(&sp, k); 574 575 /* 8-10. u ← NTT^{-1}(A^T ◦ r̂) + e1 */ 576 for (int i = 0; i < k; i++) { 577 polyvec_pointwise_acc_montgomery(&bp.vec[i], &matrix[i], &sp, k); 578 poly_invntt_tomont(&bp.vec[i]); 579 } 580 polyvec_add(&bp, &bp, &ep, k); 581 polyvec_reduce(&bp, k); 582 583 /* 11-14. v ← NTT^{-1}(t̂^T ◦ r̂) + e2 + Decompress_q(Decode_1(m), 1) */ 584 polyvec pkpv; 585 polyvec_frombytes(&pkpv, pk, k); 586 polyvec_ntt(&pkpv, k); 587 polyvec_pointwise_acc_montgomery(&v, &pkpv, &sp, k); 588 poly_invntt_tomont(&v); 589 poly_add(&v, &v, &epp); 590 591 /* Add message */ 592 for (int i = 0; i < KYBER_N; i++) { 593 mp.coeffs[i] = ((m[i >> 3] >> (i & 7)) & 1) * (KYBER_Q + 1) / 2; 594 } 595 poly_add(&v, &v, &mp); 596 poly_reduce(&v); 597 598 /* 15-16. c ← ByteEncode_{du}(Compress_q(u, du)) || ByteEncode_{dv}(Compress_q(v, dv)) */ 599 polyvec_compress(ct, &bp, k, du); 600 poly_compress(ct + k * (du == 10 ? 320 : du == 11 ? 352 : 128), &v, dv); 601 602 /* 17. K ← KDF(K̄ || H(c)) */ 603 sha3_256(ct, k * (du == 10 ? 320 : du == 11 ? 352 : 128) + (dv == 4 ? 
128 : 160), buf); 604 memcpy(buf + KYBER_SYMBYTES, kr, KYBER_SYMBYTES); 605 shake256(buf, 2 * KYBER_SYMBYTES, ss, KYBER_SYMBYTES); 606 607 /* Secure cleanup */ 608 secure_zero(m, sizeof(m)); 609 secure_zero(kr, sizeof(kr)); 610 secure_zero(&sp, sizeof(sp)); 611 612 return 0; 613 } 614 615 /* ============================================================================ 616 * Decapsulation (FIPS 203 Algorithm 17 - ML-KEM.Decaps with Implicit Rejection) 617 * ========================================================================= */ 618 619 static int kyber_decapsulate_internal(uint8_t *ss, const uint8_t *ct, const uint8_t *sk, int k, int eta1, int eta2, int du, int dv) { 620 polyvec matrix[4], bp, skpv, mp; 621 poly v, vp, epp; 622 uint8_t m[KYBER_SYMBYTES]; 623 uint8_t m2[KYBER_SYMBYTES]; 624 uint8_t buf[2 * KYBER_SYMBYTES]; 625 uint8_t kr[2 * KYBER_SYMBYTES]; 626 uint8_t ct2[KYBER1024_CIPHERTEXT_BYTES]; /* Max size */ 627 uint8_t *rho = (uint8_t*)sk + k * KYBER_POLYBYTES + k * KYBER_POLYBYTES; 628 uint8_t *pk_hash = (uint8_t*)sk + 2 * k * KYBER_POLYBYTES + KYBER_SYMBYTES; 629 uint8_t *z = (uint8_t*)sk + 2 * k * KYBER_POLYBYTES + 2 * KYBER_SYMBYTES; 630 uint8_t nonce; 631 int fail; 632 633 /* 1-2. Decode ciphertext: u ← Decompress_q(Decode_{du}(c[:32·du·k]), du) */ 634 polyvec_decompress(&bp, ct, k, du); 635 poly_decompress(&v, ct + k * (du == 10 ? 320 : du == 11 ? 352 : 128), dv); 636 637 /* 3-4. ŝ ← ByteDecode_12(dk) */ 638 polyvec_frombytes(&skpv, sk, k); 639 640 /* 5-7. 
m ← ByteEncode_1(Compress_q(v - NTT^{-1}(ŝ^T ◦ NTT(u)), 1)) */ 641 polyvec_ntt(&bp, k); 642 polyvec_pointwise_acc_montgomery(&vp, &skpv, &bp, k); 643 poly_invntt_tomont(&vp); 644 poly_sub(&vp, &v, &vp); 645 poly_reduce(&vp); 646 647 /* Extract message bits */ 648 for (int i = 0; i < KYBER_SYMBYTES; i++) { 649 m[i] = 0; 650 for (int j = 0; j < 8; j++) { 651 int16_t t = vp.coeffs[8*i + j]; 652 t += (t >> 15) & KYBER_Q; 653 t = (((t << 1) + KYBER_Q/2) / KYBER_Q) & 1; 654 m[i] |= t << j; 655 } 656 } 657 658 /* 8-9. (K̄', r') ← G(m || H(pk)) */ 659 memcpy(buf, m, KYBER_SYMBYTES); 660 memcpy(buf + KYBER_SYMBYTES, pk_hash, KYBER_SYMBYTES); 661 sha3_512(buf, 2 * KYBER_SYMBYTES, kr); 662 663 /* 10-15. Re-encrypt to verify: c' ← Encrypt(pk, m, r') */ 664 for (int i = 0; i < k; i++) { 665 for (int j = 0; j < k; j++) { 666 poly_uniform(&matrix[i].vec[j], rho, (i << 4) | j); 667 } 668 } 669 670 nonce = 0; 671 polyvec sp, ep; 672 for (int i = 0; i < k; i++) { 673 poly_cbd_eta(&sp.vec[i], kr + KYBER_SYMBYTES, nonce++, eta1); 674 } 675 for (int i = 0; i < k; i++) { 676 poly_cbd_eta(&ep.vec[i], kr + KYBER_SYMBYTES, nonce++, eta2); 677 } 678 poly_cbd_eta(&epp, kr + KYBER_SYMBYTES, nonce++, eta2); 679 680 polyvec_ntt(&sp, k); 681 for (int i = 0; i < k; i++) { 682 polyvec_pointwise_acc_montgomery(&mp.vec[i], &matrix[i], &sp, k); 683 poly_invntt_tomont(&mp.vec[i]); 684 } 685 polyvec_add(&mp, &mp, &ep, k); 686 polyvec_reduce(&mp, k); 687 688 polyvec pkpv; 689 polyvec_frombytes(&pkpv, (uint8_t*)sk + k * KYBER_POLYBYTES, k); 690 polyvec_ntt(&pkpv, k); 691 polyvec_pointwise_acc_montgomery(&vp, &pkpv, &sp, k); 692 poly_invntt_tomont(&vp); 693 poly_add(&vp, &vp, &epp); 694 695 poly msg_poly; 696 for (int i = 0; i < KYBER_N; i++) { 697 msg_poly.coeffs[i] = ((m[i >> 3] >> (i & 7)) & 1) * (KYBER_Q + 1) / 2; 698 } 699 poly_add(&vp, &vp, &msg_poly); 700 poly_reduce(&vp); 701 702 polyvec_compress(ct2, &mp, k, du); 703 poly_compress(ct2 + k * (du == 10 ? 320 : du == 11 ? 
352 : 128), &vp, dv); 704 705 /* 16. Constant-time comparison */ 706 int ctlen = k * (du == 10 ? 320 : du == 11 ? 352 : 128) + (dv == 4 ? 128 : 160); 707 fail = ct_memcmp(ct, ct2, ctlen); 708 709 /* 17-18. Implicit rejection: use z if decryption fails */ 710 sha3_256(ct, ctlen, buf); 711 712 /* Constant-time select between K̄' and z */ 713 for (int i = 0; i < KYBER_SYMBYTES; i++) { 714 kr[i] = ct_select_u8(fail, z[i], kr[i]); 715 } 716 717 memcpy(buf + KYBER_SYMBYTES, kr, KYBER_SYMBYTES); 718 shake256(buf, 2 * KYBER_SYMBYTES, ss, KYBER_SYMBYTES); 719 720 /* Secure cleanup */ 721 secure_zero(m, sizeof(m)); 722 secure_zero(kr, sizeof(kr)); 723 secure_zero(&skpv, sizeof(skpv)); 724 725 return 0; 726 } 727 728 /* ============================================================================ 729 * Public API 730 * ========================================================================= */ 731 732 int kyber512_keypair(uint8_t *public_key, uint8_t *secret_key) { 733 if (!public_key || !secret_key) return -1; 734 return kyber_keypair_internal(public_key, secret_key, KYBER512_K, KYBER512_ETA1); 735 } 736 737 int kyber512_encapsulate(uint8_t *ciphertext, uint8_t *shared_secret, const uint8_t *public_key) { 738 if (!ciphertext || !shared_secret || !public_key) return -1; 739 return kyber_encapsulate_internal(ciphertext, shared_secret, public_key, 740 KYBER512_K, KYBER512_ETA1, KYBER512_ETA2, KYBER512_DU, KYBER512_DV); 741 } 742 743 int kyber512_decapsulate(uint8_t *shared_secret, const uint8_t *ciphertext, const uint8_t *secret_key) { 744 if (!shared_secret || !ciphertext || !secret_key) return -1; 745 return kyber_decapsulate_internal(shared_secret, ciphertext, secret_key, 746 KYBER512_K, KYBER512_ETA1, KYBER512_ETA2, KYBER512_DU, KYBER512_DV); 747 } 748 749 int kyber768_keypair(uint8_t *public_key, uint8_t *secret_key) { 750 if (!public_key || !secret_key) return -1; 751 return kyber_keypair_internal(public_key, secret_key, KYBER768_K, KYBER768_ETA1); 752 } 753 754 int 
kyber768_encapsulate(uint8_t *ciphertext, uint8_t *shared_secret, const uint8_t *public_key) { 755 if (!ciphertext || !shared_secret || !public_key) return -1; 756 return kyber_encapsulate_internal(ciphertext, shared_secret, public_key, 757 KYBER768_K, KYBER768_ETA1, KYBER768_ETA2, KYBER768_DU, KYBER768_DV); 758 } 759 760 int kyber768_decapsulate(uint8_t *shared_secret, const uint8_t *ciphertext, const uint8_t *secret_key) { 761 if (!shared_secret || !ciphertext || !secret_key) return -1; 762 return kyber_decapsulate_internal(shared_secret, ciphertext, secret_key, 763 KYBER768_K, KYBER768_ETA1, KYBER768_ETA2, KYBER768_DU, KYBER768_DV); 764 } 765 766 int kyber1024_keypair(uint8_t *public_key, uint8_t *secret_key) { 767 if (!public_key || !secret_key) return -1; 768 return kyber_keypair_internal(public_key, secret_key, KYBER1024_K, KYBER1024_ETA1); 769 } 770 771 int kyber1024_encapsulate(uint8_t *ciphertext, uint8_t *shared_secret, const uint8_t *public_key) { 772 if (!ciphertext || !shared_secret || !public_key) return -1; 773 return kyber_encapsulate_internal(ciphertext, shared_secret, public_key, 774 KYBER1024_K, KYBER1024_ETA1, KYBER1024_ETA2, KYBER1024_DU, KYBER1024_DV); 775 } 776 777 int kyber1024_decapsulate(uint8_t *shared_secret, const uint8_t *ciphertext, const uint8_t *secret_key) { 778 if (!shared_secret || !ciphertext || !secret_key) return -1; 779 return kyber_decapsulate_internal(shared_secret, ciphertext, secret_key, 780 KYBER1024_K, KYBER1024_ETA1, KYBER1024_ETA2, KYBER1024_DU, KYBER1024_DV); 781 }