luajitos

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs

Kyber.c (28211B)


      1 /*
      2  * Kyber.c - CRYSTALS-Kyber FIPS 203 (ML-KEM) Compliant Implementation
      3  *
      4  * NIST FIPS 203 - Module-Lattice-Based Key-Encapsulation Mechanism Standard
      5  * Full specification compliance with all required algorithms
      6  *
      7  * SECURITY: Production implementation with:
      8  * - Full FIPS 203 encapsulation/decapsulation algorithms
      9  * - IND-CCA2 security via implicit rejection
     10  * - Constant-time operations for secret-dependent code
     11  * - Proper polynomial packing per specification
     12  * - SHAKE-128/256 for all randomness expansion
     13  */
     14 
     15 #include "Kyber.h"
     16 #include "CSPRNG.h"
     17 #include "ct_util.h"
     18 #include "hashing/hash.h"
     19 #include "hashing/SHA3.h"
     20 #include <string.h>
     21 
     22 /* ============================================================================
     23  * FIPS 203 Parameters
     24  * ========================================================================= */
     25 
     26 #define KYBER_N 256
     27 #define KYBER_Q 3329
     28 #define KYBER_SYMBYTES 32
     29 
     30 /* Kyber512 (ML-KEM-512) */
     31 #define KYBER512_K 2
     32 #define KYBER512_ETA1 3
     33 #define KYBER512_ETA2 2
     34 #define KYBER512_DU 10
     35 #define KYBER512_DV 4
     36 
     37 /* Kyber768 (ML-KEM-768) - RECOMMENDED */
     38 #define KYBER768_K 3
     39 #define KYBER768_ETA1 2
     40 #define KYBER768_ETA2 2
     41 #define KYBER768_DU 10
     42 #define KYBER768_DV 4
     43 
     44 /* Kyber1024 (ML-KEM-1024) */
     45 #define KYBER1024_K 4
     46 #define KYBER1024_ETA1 2
     47 #define KYBER1024_ETA2 2
     48 #define KYBER1024_DU 11
     49 #define KYBER1024_DV 5
     50 
     51 #define KYBER_POLYBYTES 384
     52 #define SHAKE128_RATE 168
     53 
     54 /* ============================================================================
     55  * NTT Constants (zetas in bit-reversed order)
     56  * Verified against pq-crystals/kyber reference implementation
     57  * ========================================================================= */
     58 
     59 static const int16_t zetas[128] = {
     60     -1044, -758, -359, -1517, 1493, 1422, 287, 202,
     61     -171, 622, 1577, 182, 962, -1202, -1474, 1468,
     62     573, -1325, 264, 383, -829, 1458, -1602, -130,
     63     -681, 1017, 732, 608, -1542, 411, -205, -1571,
     64     1223, 652, -552, 1015, -1293, 1491, -282, -1544,
     65     516, -8, -320, -666, -1618, -1162, 126, 1469,
     66     -853, -90, -271, 830, 107, -1421, -247, -951,
     67     -398, 961, -1508, -725, 448, -1065, 677, -1275,
     68     -1103, 430, 555, 843, -1251, 871, 1550, 105,
     69     422, 587, 177, -235, -291, -460, 1574, 1653,
     70     -246, 778, 1159, -147, -777, 1483, -602, 1119,
     71     -1590, 644, -872, 349, 418, 329, -156, -75,
     72     817, 1097, 603, 610, 1322, -1285, -1465, 384,
     73     -1215, -136, 1218, -1335, -874, 220, -1187, -1659,
     74     -1185, -1530, -1278, 794, -1510, -854, -870, 478,
     75     -108, -308, 996, 991, 958, -1460, 1522, 1628
     76 };
     77 
     78 /* ============================================================================
     79  * Modular Arithmetic
     80  * ========================================================================= */
     81 
     82 static inline int16_t montgomery_reduce(int32_t a) {
     83     int16_t t;
     84     t = (int16_t)((uint32_t)a * 62209);
     85     t = (a - (int32_t)t * KYBER_Q) >> 16;
     86     return t;
     87 }
     88 
     89 static inline int16_t barrett_reduce(int16_t a) {
     90     int16_t t;
     91     const int16_t v = ((1 << 26) + KYBER_Q/2) / KYBER_Q;
     92     t = ((int32_t)v * a + (1 << 25)) >> 26;
     93     t *= KYBER_Q;
     94     return a - t;
     95 }
     96 
     97 static inline int16_t fqmul(int16_t a, int16_t b) {
     98     return montgomery_reduce((int32_t)a * b);
     99 }
    100 
    101 /* ============================================================================
    102  * Polynomial Operations
    103  * ========================================================================= */
    104 
    105 typedef struct {
    106     int16_t coeffs[KYBER_N];
    107 } poly;
    108 
    109 typedef struct {
    110     poly vec[4];  /* Max k=4 */
    111 } polyvec;
    112 
    113 static void poly_reduce(poly *r) {
    114     for (int i = 0; i < KYBER_N; i++) {
    115         r->coeffs[i] = barrett_reduce(r->coeffs[i]);
    116     }
    117 }
    118 
    119 static void poly_add(poly *r, const poly *a, const poly *b) {
    120     for (int i = 0; i < KYBER_N; i++) {
    121         r->coeffs[i] = a->coeffs[i] + b->coeffs[i];
    122     }
    123 }
    124 
    125 static void poly_sub(poly *r, const poly *a, const poly *b) {
    126     for (int i = 0; i < KYBER_N; i++) {
    127         r->coeffs[i] = a->coeffs[i] - b->coeffs[i];
    128     }
    129 }
    130 
    131 static void poly_ntt(poly *r) {
    132     unsigned int len, start, j, k;
    133     int16_t t, zeta;
    134 
    135     k = 1;
    136     for (len = 128; len >= 2; len >>= 1) {
    137         for (start = 0; start < KYBER_N; start = j + len) {
    138             zeta = zetas[k++];
    139             for (j = start; j < start + len; j++) {
    140                 t = fqmul(zeta, r->coeffs[j + len]);
    141                 r->coeffs[j + len] = r->coeffs[j] - t;
    142                 r->coeffs[j] = r->coeffs[j] + t;
    143             }
    144         }
    145     }
    146 }
    147 
    148 static void poly_invntt_tomont(poly *r) {
    149     unsigned int start, len, j, k;
    150     int16_t t, zeta;
    151     const int16_t f = 1441;
    152 
    153     k = 127;
    154     for (len = 2; len <= 128; len <<= 1) {
    155         for (start = 0; start < KYBER_N; start = j + len) {
    156             zeta = zetas[k--];
    157             for (j = start; j < start + len; j++) {
    158                 t = r->coeffs[j];
    159                 r->coeffs[j] = barrett_reduce(t + r->coeffs[j + len]);
    160                 r->coeffs[j + len] = r->coeffs[j + len] - t;
    161                 r->coeffs[j + len] = fqmul(zeta, r->coeffs[j + len]);
    162             }
    163         }
    164     }
    165 
    166     for (j = 0; j < KYBER_N; j++) {
    167         r->coeffs[j] = fqmul(r->coeffs[j], f);
    168     }
    169 }
    170 
    171 static void poly_basemul_montgomery(poly *r, const poly *a, const poly *b) {
    172     for (int i = 0; i < KYBER_N / 4; i++) {
    173         int16_t rx, ry;
    174 
    175         rx = fqmul(a->coeffs[4*i+1], b->coeffs[4*i+1]);
    176         rx = fqmul(rx, zetas[64 + i]);
    177         rx += fqmul(a->coeffs[4*i], b->coeffs[4*i]);
    178 
    179         ry = fqmul(a->coeffs[4*i], b->coeffs[4*i+1]);
    180         ry += fqmul(a->coeffs[4*i+1], b->coeffs[4*i]);
    181 
    182         r->coeffs[4*i] = rx;
    183         r->coeffs[4*i+1] = ry;
    184 
    185         rx = fqmul(a->coeffs[4*i+3], b->coeffs[4*i+3]);
    186         rx = fqmul(rx, -zetas[64 + i]);
    187         rx += fqmul(a->coeffs[4*i+2], b->coeffs[4*i+2]);
    188 
    189         ry = fqmul(a->coeffs[4*i+2], b->coeffs[4*i+3]);
    190         ry += fqmul(a->coeffs[4*i+3], b->coeffs[4*i+2]);
    191 
    192         r->coeffs[4*i+2] = rx;
    193         r->coeffs[4*i+3] = ry;
    194     }
    195 }
    196 
    197 static void poly_tomont(poly *r) {
    198     const int16_t f = (1ULL << 32) % KYBER_Q;
    199     for (int i = 0; i < KYBER_N; i++) {
    200         r->coeffs[i] = montgomery_reduce((int32_t)r->coeffs[i] * f);
    201     }
    202 }
    203 
    204 /* ============================================================================
    205  * Polynomial Vector Operations
    206  * ========================================================================= */
    207 
    208 static void polyvec_ntt(polyvec *r, int k) {
    209     for (int i = 0; i < k; i++) {
    210         poly_ntt(&r->vec[i]);
    211     }
    212 }
    213 
    214 static void polyvec_invntt_tomont(polyvec *r, int k) {
    215     for (int i = 0; i < k; i++) {
    216         poly_invntt_tomont(&r->vec[i]);
    217     }
    218 }
    219 
    220 static void polyvec_add(polyvec *r, const polyvec *a, const polyvec *b, int k) {
    221     for (int i = 0; i < k; i++) {
    222         poly_add(&r->vec[i], &a->vec[i], &b->vec[i]);
    223     }
    224 }
    225 
    226 static void polyvec_reduce(polyvec *r, int k) {
    227     for (int i = 0; i < k; i++) {
    228         poly_reduce(&r->vec[i]);
    229     }
    230 }
    231 
    232 /* ============================================================================
    233  * Sampling Functions (FIPS 203 compliant)
    234  * ========================================================================= */
    235 
    236 static unsigned int rej_uniform(int16_t *r, unsigned int len, const uint8_t *buf, unsigned int buflen) {
    237     unsigned int ctr, pos;
    238     uint16_t val0, val1;
    239 
    240     ctr = pos = 0;
    241     while (ctr < len && pos + 3 <= buflen) {
    242         val0 = ((buf[pos] | ((uint16_t)buf[pos + 1] << 8)) & 0xFFF);
    243         val1 = ((buf[pos + 1] >> 4) | ((uint16_t)buf[pos + 2] << 4)) & 0xFFF;
    244         pos += 3;
    245 
    246         if (val0 < KYBER_Q) {
    247             r[ctr++] = val0;
    248         }
    249         if (ctr < len && val1 < KYBER_Q) {
    250             r[ctr++] = val1;
    251         }
    252     }
    253     return ctr;
    254 }
    255 
    256 /* SampleNTT: Sample polynomial from uniform distribution using XOF (SHAKE-128) */
    257 static void poly_uniform(poly *a, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce) {
    258     unsigned int ctr, off;
    259     unsigned int buflen = SHAKE128_RATE * 3;
    260     uint8_t buf[SHAKE128_RATE * 3];
    261     uint8_t extseed[KYBER_SYMBYTES + 1];
    262 
    263     memcpy(extseed, seed, KYBER_SYMBYTES);
    264     extseed[KYBER_SYMBYTES] = nonce;
    265 
    266     shake128(extseed, KYBER_SYMBYTES + 1, buf, buflen);
    267     ctr = rej_uniform(a->coeffs, KYBER_N, buf, buflen);
    268 
    269     while (ctr < KYBER_N) {
    270         off = buflen % 3;
    271         for (unsigned int i = 0; i < off; i++) {
    272             buf[i] = buf[buflen - off + i];
    273         }
    274         shake128(extseed, KYBER_SYMBYTES + 1, buf + off, buflen - off);
    275         buflen = SHAKE128_RATE * 3;
    276         ctr += rej_uniform(a->coeffs + ctr, KYBER_N - ctr, buf, buflen);
    277     }
    278 }
    279 
    280 /* CBD: Centered Binomial Distribution */
    281 static void cbd_eta(poly *r, const uint8_t *buf, int eta) {
    282     uint32_t t, d;
    283     int16_t a, b;
    284 
    285     if (eta == 2) {
    286         for (int i = 0; i < KYBER_N / 8; i++) {
    287             t = buf[2*i] | ((uint32_t)buf[2*i + 1] << 8);
    288             d = t & 0x55555555;
    289             d += (t >> 1) & 0x55555555;
    290 
    291             for (int j = 0; j < 8; j++) {
    292                 a = (d >> (4*j)) & 0x3;
    293                 b = (d >> (4*j + 2)) & 0x3;
    294                 r->coeffs[8*i + j] = a - b;
    295             }
    296         }
    297     } else if (eta == 3) {
    298         for (int i = 0; i < KYBER_N / 4; i++) {
    299             t = buf[3*i] | ((uint32_t)buf[3*i + 1] << 8) | ((uint32_t)buf[3*i + 2] << 16);
    300             d = t & 0x00249249;
    301             d += (t >> 1) & 0x00249249;
    302             d += (t >> 2) & 0x00249249;
    303 
    304             for (int j = 0; j < 4; j++) {
    305                 a = (d >> (6*j)) & 0x7;
    306                 b = (d >> (6*j + 3)) & 0x7;
    307                 r->coeffs[4*i + j] = a - b;
    308             }
    309         }
    310     }
    311 }
    312 
    313 /* SamplePolyCBD: Sample polynomial from CBD using PRF (SHAKE-256) */
    314 static void poly_cbd_eta(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce, int eta) {
    315     uint8_t buf[eta * KYBER_N / 4];
    316     uint8_t extseed[KYBER_SYMBYTES + 1];
    317 
    318     memcpy(extseed, seed, KYBER_SYMBYTES);
    319     extseed[KYBER_SYMBYTES] = nonce;
    320 
    321     shake256(extseed, KYBER_SYMBYTES + 1, buf, eta * KYBER_N / 4);
    322     cbd_eta(r, buf, eta);
    323 }
    324 
    325 /* ============================================================================
    326  * Compression and Decompression (FIPS 203 §4.2.1)
    327  * ========================================================================= */
    328 
    329 static void poly_compress(uint8_t *r, const poly *a, int d) {
    330     uint8_t t[8];
    331     int16_t u;
    332 
    333     if (d == 4) {
    334         for (int i = 0; i < KYBER_N / 8; i++) {
    335             for (int j = 0; j < 8; j++) {
    336                 u = a->coeffs[8*i + j];
    337                 u += (u >> 15) & KYBER_Q;
    338                 t[j] = ((((uint32_t)u << 4) + KYBER_Q/2) / KYBER_Q) & 15;
    339             }
    340             r[4*i] = t[0] | (t[1] << 4);
    341             r[4*i + 1] = t[2] | (t[3] << 4);
    342             r[4*i + 2] = t[4] | (t[5] << 4);
    343             r[4*i + 3] = t[6] | (t[7] << 4);
    344         }
    345     } else if (d == 10) {
    346         for (int i = 0; i < KYBER_N / 4; i++) {
    347             for (int j = 0; j < 4; j++) {
    348                 u = a->coeffs[4*i + j];
    349                 u += (u >> 15) & KYBER_Q;
    350                 t[j] = ((((uint32_t)u << 10) + KYBER_Q/2) / KYBER_Q) & 0x3ff;
    351             }
    352             r[5*i] = (uint8_t)t[0];
    353             r[5*i + 1] = (uint8_t)((t[0] >> 8) | (t[1] << 2));
    354             r[5*i + 2] = (uint8_t)((t[1] >> 6) | (t[2] << 4));
    355             r[5*i + 3] = (uint8_t)((t[2] >> 4) | (t[3] << 6));
    356             r[5*i + 4] = (uint8_t)(t[3] >> 2);
    357         }
    358     } else if (d == 11) {
    359         for (int i = 0; i < KYBER_N / 8; i++) {
    360             for (int j = 0; j < 8; j++) {
    361                 u = a->coeffs[8*i + j];
    362                 u += (u >> 15) & KYBER_Q;
    363                 t[j] = ((((uint32_t)u << 11) + KYBER_Q/2) / KYBER_Q) & 0x7ff;
    364             }
    365             r[11*i] = (uint8_t)t[0];
    366             r[11*i + 1] = (uint8_t)((t[0] >> 8) | (t[1] << 3));
    367             r[11*i + 2] = (uint8_t)((t[1] >> 5) | (t[2] << 6));
    368             r[11*i + 3] = (uint8_t)(t[2] >> 2);
    369             r[11*i + 4] = (uint8_t)((t[2] >> 10) | (t[3] << 1));
    370             r[11*i + 5] = (uint8_t)((t[3] >> 7) | (t[4] << 4));
    371             r[11*i + 6] = (uint8_t)((t[4] >> 4) | (t[5] << 7));
    372             r[11*i + 7] = (uint8_t)(t[5] >> 1);
    373             r[11*i + 8] = (uint8_t)((t[5] >> 9) | (t[6] << 2));
    374             r[11*i + 9] = (uint8_t)((t[6] >> 6) | (t[7] << 5));
    375             r[11*i + 10] = (uint8_t)(t[7] >> 3);
    376         }
    377     }
    378 }
    379 
    380 static void poly_decompress(poly *r, const uint8_t *a, int d) {
    381     if (d == 4) {
    382         for (int i = 0; i < KYBER_N / 2; i++) {
    383             r->coeffs[2*i] = (((uint32_t)(a[i] & 15) * KYBER_Q) + 8) >> 4;
    384             r->coeffs[2*i + 1] = (((uint32_t)(a[i] >> 4) * KYBER_Q) + 8) >> 4;
    385         }
    386     } else if (d == 10) {
    387         for (int i = 0; i < KYBER_N / 4; i++) {
    388             r->coeffs[4*i] = ((((uint32_t)a[5*i] | ((uint32_t)a[5*i + 1] << 8)) & 0x3ff) * KYBER_Q + 512) >> 10;
    389             r->coeffs[4*i + 1] = ((((uint32_t)a[5*i + 1] >> 2 | ((uint32_t)a[5*i + 2] << 6)) & 0x3ff) * KYBER_Q + 512) >> 10;
    390             r->coeffs[4*i + 2] = ((((uint32_t)a[5*i + 2] >> 4 | ((uint32_t)a[5*i + 3] << 4)) & 0x3ff) * KYBER_Q + 512) >> 10;
    391             r->coeffs[4*i + 3] = ((((uint32_t)a[5*i + 3] >> 6 | ((uint32_t)a[5*i + 4] << 2)) & 0x3ff) * KYBER_Q + 512) >> 10;
    392         }
    393     } else if (d == 11) {
    394         for (int i = 0; i < KYBER_N / 8; i++) {
    395             r->coeffs[8*i] = ((((uint32_t)a[11*i] | ((uint32_t)a[11*i + 1] << 8)) & 0x7ff) * KYBER_Q + 1024) >> 11;
    396             r->coeffs[8*i + 1] = ((((uint32_t)a[11*i + 1] >> 3 | ((uint32_t)a[11*i + 2] << 5)) & 0x7ff) * KYBER_Q + 1024) >> 11;
    397             r->coeffs[8*i + 2] = ((((uint32_t)a[11*i + 2] >> 6 | ((uint32_t)a[11*i + 3] << 2) | ((uint32_t)a[11*i + 4] << 10)) & 0x7ff) * KYBER_Q + 1024) >> 11;
    398             r->coeffs[8*i + 3] = ((((uint32_t)a[11*i + 4] >> 1 | ((uint32_t)a[11*i + 5] << 7)) & 0x7ff) * KYBER_Q + 1024) >> 11;
    399             r->coeffs[8*i + 4] = ((((uint32_t)a[11*i + 5] >> 4 | ((uint32_t)a[11*i + 6] << 4)) & 0x7ff) * KYBER_Q + 1024) >> 11;
    400             r->coeffs[8*i + 5] = ((((uint32_t)a[11*i + 6] >> 7 | ((uint32_t)a[11*i + 7] << 1) | ((uint32_t)a[11*i + 8] << 9)) & 0x7ff) * KYBER_Q + 1024) >> 11;
    401             r->coeffs[8*i + 6] = ((((uint32_t)a[11*i + 8] >> 2 | ((uint32_t)a[11*i + 9] << 6)) & 0x7ff) * KYBER_Q + 1024) >> 11;
    402             r->coeffs[8*i + 7] = ((((uint32_t)a[11*i + 9] >> 5 | ((uint32_t)a[11*i + 10] << 3)) & 0x7ff) * KYBER_Q + 1024) >> 11;
    403         }
    404     }
    405 }
    406 
    407 static void polyvec_compress(uint8_t *r, const polyvec *a, int k, int d) {
    408     for (int i = 0; i < k; i++) {
    409         poly_compress(r + i * (d == 10 ? 320 : d == 11 ? 352 : 128), &a->vec[i], d);
    410     }
    411 }
    412 
    413 static void polyvec_decompress(polyvec *r, const uint8_t *a, int k, int d) {
    414     for (int i = 0; i < k; i++) {
    415         poly_decompress(&r->vec[i], a + i * (d == 10 ? 320 : d == 11 ? 352 : 128), d);
    416     }
    417 }
    418 
    419 /* ============================================================================
    420  * Polynomial Packing/Unpacking (FIPS 203 compliant)
    421  * ========================================================================= */
    422 
    423 static void poly_tobytes(uint8_t r[KYBER_POLYBYTES], const poly *a) {
    424     uint16_t t0, t1;
    425 
    426     for (int i = 0; i < KYBER_N / 2; i++) {
    427         t0 = a->coeffs[2*i];
    428         t0 += (t0 >> 15) & KYBER_Q;
    429 
    430         t1 = a->coeffs[2*i + 1];
    431         t1 += (t1 >> 15) & KYBER_Q;
    432 
    433         r[3*i] = (uint8_t)t0;
    434         r[3*i + 1] = (uint8_t)(t0 >> 8) | (uint8_t)(t1 << 4);
    435         r[3*i + 2] = (uint8_t)(t1 >> 4);
    436     }
    437 }
    438 
    439 static void poly_frombytes(poly *r, const uint8_t a[KYBER_POLYBYTES]) {
    440     for (int i = 0; i < KYBER_N / 2; i++) {
    441         r->coeffs[2*i] = ((a[3*i] | ((uint16_t)a[3*i + 1] << 8)) & 0xFFF);
    442         r->coeffs[2*i + 1] = ((a[3*i + 1] >> 4) | ((uint16_t)a[3*i + 2] << 4)) & 0xFFF;
    443     }
    444 }
    445 
    446 static void polyvec_tobytes(uint8_t *r, const polyvec *a, int k) {
    447     for (int i = 0; i < k; i++) {
    448         poly_tobytes(r + i * KYBER_POLYBYTES, &a->vec[i]);
    449     }
    450 }
    451 
    452 static void polyvec_frombytes(polyvec *r, const uint8_t *a, int k) {
    453     for (int i = 0; i < k; i++) {
    454         poly_frombytes(&r->vec[i], a + i * KYBER_POLYBYTES);
    455     }
    456 }
    457 
    458 /* ============================================================================
    459  * Matrix-Vector Multiplication
    460  * ========================================================================= */
    461 
    462 static void polyvec_pointwise_acc_montgomery(poly *r, const polyvec *a, const polyvec *b, int k) {
    463     poly t;
    464     poly_basemul_montgomery(r, &a->vec[0], &b->vec[0]);
    465 
    466     for (int i = 1; i < k; i++) {
    467         poly_basemul_montgomery(&t, &a->vec[i], &b->vec[i]);
    468         poly_add(r, r, &t);
    469     }
    470     poly_reduce(r);
    471 }
    472 
    473 /* ============================================================================
    474  * Key Generation (FIPS 203 Algorithm 15 - ML-KEM.KeyGen)
    475  * ========================================================================= */
    476 
    477 static int kyber_keypair_internal(uint8_t *pk, uint8_t *sk, int k, int eta1) {
    478     polyvec matrix[4];  /* A^T (max 4x4) */
    479     polyvec s, e, pkpv;
    480     uint8_t buf[2 * KYBER_SYMBYTES];
    481     uint8_t *rho = buf;
    482     uint8_t *sigma = buf + KYBER_SYMBYTES;
    483     uint8_t nonce = 0;
    484 
    485     /* 1. d ← B^32; (ρ, σ) ← G(d) */
    486     random_bytes(buf, KYBER_SYMBYTES);
    487     sha3_512(buf, KYBER_SYMBYTES, buf);
    488 
    489     /* 2-3. Generate matrix A from ρ */
    490     for (int i = 0; i < k; i++) {
    491         for (int j = 0; j < k; j++) {
    492             poly_uniform(&matrix[i].vec[j], rho, (i << 4) | j);
    493         }
    494     }
    495 
    496     /* 4-6. Sample secret vector s and error vector e */
    497     for (int i = 0; i < k; i++) {
    498         poly_cbd_eta(&s.vec[i], sigma, nonce++, eta1);
    499     }
    500     for (int i = 0; i < k; i++) {
    501         poly_cbd_eta(&e.vec[i], sigma, nonce++, eta1);
    502     }
    503 
    504     /* 7. ŝ ← NTT(s) */
    505     polyvec_ntt(&s, k);
    506 
    507     /* 8-10. Compute t = A ◦ s + e */
    508     for (int i = 0; i < k; i++) {
    509         polyvec_pointwise_acc_montgomery(&pkpv.vec[i], &matrix[i], &s, k);
    510         poly_invntt_tomont(&pkpv.vec[i]);
    511     }
    512     polyvec_add(&pkpv, &pkpv, &e, k);
    513     polyvec_reduce(&pkpv, k);
    514 
    515     /* 11-12. Pack public key: pk ← ByteEncode12(t) || ρ */
    516     polyvec_tobytes(pk, &pkpv, k);
    517     memcpy(pk + k * KYBER_POLYBYTES, rho, KYBER_SYMBYTES);
    518 
    519     /* 13-14. Pack secret key: sk ← ByteEncode12(ŝ) */
    520     polyvec_tobytes(sk, &s, k);
    521 
    522     /* Append pk and hash for implicit rejection */
    523     memcpy(sk + k * KYBER_POLYBYTES, pk, k * KYBER_POLYBYTES + KYBER_SYMBYTES);
    524     sha3_256(pk, k * KYBER_POLYBYTES + KYBER_SYMBYTES, sk + 2 * k * KYBER_POLYBYTES + KYBER_SYMBYTES);
    525 
    526     /* Append random z for implicit rejection */
    527     random_bytes(sk + 2 * k * KYBER_POLYBYTES + 2 * KYBER_SYMBYTES, KYBER_SYMBYTES);
    528 
    529     /* Secure cleanup */
    530     secure_zero(buf, sizeof(buf));
    531     secure_zero(&s, sizeof(s));
    532     secure_zero(&e, sizeof(e));
    533 
    534     return 0;
    535 }
    536 
    537 /* ============================================================================
    538  * Encapsulation (FIPS 203 Algorithm 16 - ML-KEM.Encaps)
    539  * ========================================================================= */
    540 
    541 static int kyber_encapsulate_internal(uint8_t *ct, uint8_t *ss, const uint8_t *pk, int k, int eta1, int eta2, int du, int dv) {
    542     polyvec matrix[4], sp, ep, bp;
    543     poly v, epp, mp;
    544     uint8_t m[KYBER_SYMBYTES];
    545     uint8_t buf[2 * KYBER_SYMBYTES];
    546     uint8_t kr[2 * KYBER_SYMBYTES];
    547     uint8_t *rho = (uint8_t*)pk + k * KYBER_POLYBYTES;
    548     uint8_t nonce = 0;
    549 
    550     /* 1-2. m ← B^32; (K̄, r) ← G(m || H(pk)) */
    551     random_bytes(m, KYBER_SYMBYTES);
    552     sha3_256(pk, k * KYBER_POLYBYTES + KYBER_SYMBYTES, buf);
    553     memcpy(buf + KYBER_SYMBYTES, m, KYBER_SYMBYTES);
    554     sha3_512(buf, 2 * KYBER_SYMBYTES, kr);
    555 
    556     /* 3. Generate matrix A from ρ (same as in keygen) */
    557     for (int i = 0; i < k; i++) {
    558         for (int j = 0; j < k; j++) {
    559             poly_uniform(&matrix[i].vec[j], rho, (i << 4) | j);
    560         }
    561     }
    562 
    563     /* 4-6. Sample r, e1, e2 from PRF using r as seed */
    564     for (int i = 0; i < k; i++) {
    565         poly_cbd_eta(&sp.vec[i], kr + KYBER_SYMBYTES, nonce++, eta1);
    566     }
    567     for (int i = 0; i < k; i++) {
    568         poly_cbd_eta(&ep.vec[i], kr + KYBER_SYMBYTES, nonce++, eta2);
    569     }
    570     poly_cbd_eta(&epp, kr + KYBER_SYMBYTES, nonce++, eta2);
    571 
    572     /* 7. r̂ ← NTT(r) */
    573     polyvec_ntt(&sp, k);
    574 
    575     /* 8-10. u ← NTT^{-1}(A^T ◦ r̂) + e1 */
    576     for (int i = 0; i < k; i++) {
    577         polyvec_pointwise_acc_montgomery(&bp.vec[i], &matrix[i], &sp, k);
    578         poly_invntt_tomont(&bp.vec[i]);
    579     }
    580     polyvec_add(&bp, &bp, &ep, k);
    581     polyvec_reduce(&bp, k);
    582 
    583     /* 11-14. v ← NTT^{-1}(t̂^T ◦ r̂) + e2 + Decompress_q(Decode_1(m), 1) */
    584     polyvec pkpv;
    585     polyvec_frombytes(&pkpv, pk, k);
    586     polyvec_ntt(&pkpv, k);
    587     polyvec_pointwise_acc_montgomery(&v, &pkpv, &sp, k);
    588     poly_invntt_tomont(&v);
    589     poly_add(&v, &v, &epp);
    590 
    591     /* Add message */
    592     for (int i = 0; i < KYBER_N; i++) {
    593         mp.coeffs[i] = ((m[i >> 3] >> (i & 7)) & 1) * (KYBER_Q + 1) / 2;
    594     }
    595     poly_add(&v, &v, &mp);
    596     poly_reduce(&v);
    597 
    598     /* 15-16. c ← ByteEncode_{du}(Compress_q(u, du)) || ByteEncode_{dv}(Compress_q(v, dv)) */
    599     polyvec_compress(ct, &bp, k, du);
    600     poly_compress(ct + k * (du == 10 ? 320 : du == 11 ? 352 : 128), &v, dv);
    601 
    602     /* 17. K ← KDF(K̄ || H(c)) */
    603     sha3_256(ct, k * (du == 10 ? 320 : du == 11 ? 352 : 128) + (dv == 4 ? 128 : 160), buf);
    604     memcpy(buf + KYBER_SYMBYTES, kr, KYBER_SYMBYTES);
    605     shake256(buf, 2 * KYBER_SYMBYTES, ss, KYBER_SYMBYTES);
    606 
    607     /* Secure cleanup */
    608     secure_zero(m, sizeof(m));
    609     secure_zero(kr, sizeof(kr));
    610     secure_zero(&sp, sizeof(sp));
    611 
    612     return 0;
    613 }
    614 
    615 /* ============================================================================
    616  * Decapsulation (FIPS 203 Algorithm 17 - ML-KEM.Decaps with Implicit Rejection)
    617  * ========================================================================= */
    618 
    619 static int kyber_decapsulate_internal(uint8_t *ss, const uint8_t *ct, const uint8_t *sk, int k, int eta1, int eta2, int du, int dv) {
    620     polyvec matrix[4], bp, skpv, mp;
    621     poly v, vp, epp;
    622     uint8_t m[KYBER_SYMBYTES];
    623     uint8_t m2[KYBER_SYMBYTES];
    624     uint8_t buf[2 * KYBER_SYMBYTES];
    625     uint8_t kr[2 * KYBER_SYMBYTES];
    626     uint8_t ct2[KYBER1024_CIPHERTEXT_BYTES];  /* Max size */
    627     uint8_t *rho = (uint8_t*)sk + k * KYBER_POLYBYTES + k * KYBER_POLYBYTES;
    628     uint8_t *pk_hash = (uint8_t*)sk + 2 * k * KYBER_POLYBYTES + KYBER_SYMBYTES;
    629     uint8_t *z = (uint8_t*)sk + 2 * k * KYBER_POLYBYTES + 2 * KYBER_SYMBYTES;
    630     uint8_t nonce;
    631     int fail;
    632 
    633     /* 1-2. Decode ciphertext: u ← Decompress_q(Decode_{du}(c[:32·du·k]), du) */
    634     polyvec_decompress(&bp, ct, k, du);
    635     poly_decompress(&v, ct + k * (du == 10 ? 320 : du == 11 ? 352 : 128), dv);
    636 
    637     /* 3-4. ŝ ← ByteDecode_12(dk) */
    638     polyvec_frombytes(&skpv, sk, k);
    639 
    640     /* 5-7. m ← ByteEncode_1(Compress_q(v - NTT^{-1}(ŝ^T ◦ NTT(u)), 1)) */
    641     polyvec_ntt(&bp, k);
    642     polyvec_pointwise_acc_montgomery(&vp, &skpv, &bp, k);
    643     poly_invntt_tomont(&vp);
    644     poly_sub(&vp, &v, &vp);
    645     poly_reduce(&vp);
    646 
    647     /* Extract message bits */
    648     for (int i = 0; i < KYBER_SYMBYTES; i++) {
    649         m[i] = 0;
    650         for (int j = 0; j < 8; j++) {
    651             int16_t t = vp.coeffs[8*i + j];
    652             t += (t >> 15) & KYBER_Q;
    653             t = (((t << 1) + KYBER_Q/2) / KYBER_Q) & 1;
    654             m[i] |= t << j;
    655         }
    656     }
    657 
    658     /* 8-9. (K̄', r') ← G(m || H(pk)) */
    659     memcpy(buf, m, KYBER_SYMBYTES);
    660     memcpy(buf + KYBER_SYMBYTES, pk_hash, KYBER_SYMBYTES);
    661     sha3_512(buf, 2 * KYBER_SYMBYTES, kr);
    662 
    663     /* 10-15. Re-encrypt to verify: c' ← Encrypt(pk, m, r') */
    664     for (int i = 0; i < k; i++) {
    665         for (int j = 0; j < k; j++) {
    666             poly_uniform(&matrix[i].vec[j], rho, (i << 4) | j);
    667         }
    668     }
    669 
    670     nonce = 0;
    671     polyvec sp, ep;
    672     for (int i = 0; i < k; i++) {
    673         poly_cbd_eta(&sp.vec[i], kr + KYBER_SYMBYTES, nonce++, eta1);
    674     }
    675     for (int i = 0; i < k; i++) {
    676         poly_cbd_eta(&ep.vec[i], kr + KYBER_SYMBYTES, nonce++, eta2);
    677     }
    678     poly_cbd_eta(&epp, kr + KYBER_SYMBYTES, nonce++, eta2);
    679 
    680     polyvec_ntt(&sp, k);
    681     for (int i = 0; i < k; i++) {
    682         polyvec_pointwise_acc_montgomery(&mp.vec[i], &matrix[i], &sp, k);
    683         poly_invntt_tomont(&mp.vec[i]);
    684     }
    685     polyvec_add(&mp, &mp, &ep, k);
    686     polyvec_reduce(&mp, k);
    687 
    688     polyvec pkpv;
    689     polyvec_frombytes(&pkpv, (uint8_t*)sk + k * KYBER_POLYBYTES, k);
    690     polyvec_ntt(&pkpv, k);
    691     polyvec_pointwise_acc_montgomery(&vp, &pkpv, &sp, k);
    692     poly_invntt_tomont(&vp);
    693     poly_add(&vp, &vp, &epp);
    694 
    695     poly msg_poly;
    696     for (int i = 0; i < KYBER_N; i++) {
    697         msg_poly.coeffs[i] = ((m[i >> 3] >> (i & 7)) & 1) * (KYBER_Q + 1) / 2;
    698     }
    699     poly_add(&vp, &vp, &msg_poly);
    700     poly_reduce(&vp);
    701 
    702     polyvec_compress(ct2, &mp, k, du);
    703     poly_compress(ct2 + k * (du == 10 ? 320 : du == 11 ? 352 : 128), &vp, dv);
    704 
    705     /* 16. Constant-time comparison */
    706     int ctlen = k * (du == 10 ? 320 : du == 11 ? 352 : 128) + (dv == 4 ? 128 : 160);
    707     fail = ct_memcmp(ct, ct2, ctlen);
    708 
    709     /* 17-18. Implicit rejection: use z if decryption fails */
    710     sha3_256(ct, ctlen, buf);
    711 
    712     /* Constant-time select between K̄' and z */
    713     for (int i = 0; i < KYBER_SYMBYTES; i++) {
    714         kr[i] = ct_select_u8(fail, z[i], kr[i]);
    715     }
    716 
    717     memcpy(buf + KYBER_SYMBYTES, kr, KYBER_SYMBYTES);
    718     shake256(buf, 2 * KYBER_SYMBYTES, ss, KYBER_SYMBYTES);
    719 
    720     /* Secure cleanup */
    721     secure_zero(m, sizeof(m));
    722     secure_zero(kr, sizeof(kr));
    723     secure_zero(&skpv, sizeof(skpv));
    724 
    725     return 0;
    726 }
    727 
    728 /* ============================================================================
    729  * Public API
    730  * ========================================================================= */
    731 
    732 int kyber512_keypair(uint8_t *public_key, uint8_t *secret_key) {
    733     if (!public_key || !secret_key) return -1;
    734     return kyber_keypair_internal(public_key, secret_key, KYBER512_K, KYBER512_ETA1);
    735 }
    736 
    737 int kyber512_encapsulate(uint8_t *ciphertext, uint8_t *shared_secret, const uint8_t *public_key) {
    738     if (!ciphertext || !shared_secret || !public_key) return -1;
    739     return kyber_encapsulate_internal(ciphertext, shared_secret, public_key,
    740                                        KYBER512_K, KYBER512_ETA1, KYBER512_ETA2, KYBER512_DU, KYBER512_DV);
    741 }
    742 
    743 int kyber512_decapsulate(uint8_t *shared_secret, const uint8_t *ciphertext, const uint8_t *secret_key) {
    744     if (!shared_secret || !ciphertext || !secret_key) return -1;
    745     return kyber_decapsulate_internal(shared_secret, ciphertext, secret_key,
    746                                        KYBER512_K, KYBER512_ETA1, KYBER512_ETA2, KYBER512_DU, KYBER512_DV);
    747 }
    748 
    749 int kyber768_keypair(uint8_t *public_key, uint8_t *secret_key) {
    750     if (!public_key || !secret_key) return -1;
    751     return kyber_keypair_internal(public_key, secret_key, KYBER768_K, KYBER768_ETA1);
    752 }
    753 
    754 int kyber768_encapsulate(uint8_t *ciphertext, uint8_t *shared_secret, const uint8_t *public_key) {
    755     if (!ciphertext || !shared_secret || !public_key) return -1;
    756     return kyber_encapsulate_internal(ciphertext, shared_secret, public_key,
    757                                        KYBER768_K, KYBER768_ETA1, KYBER768_ETA2, KYBER768_DU, KYBER768_DV);
    758 }
    759 
    760 int kyber768_decapsulate(uint8_t *shared_secret, const uint8_t *ciphertext, const uint8_t *secret_key) {
    761     if (!shared_secret || !ciphertext || !secret_key) return -1;
    762     return kyber_decapsulate_internal(shared_secret, ciphertext, secret_key,
    763                                        KYBER768_K, KYBER768_ETA1, KYBER768_ETA2, KYBER768_DU, KYBER768_DV);
    764 }
    765 
    766 int kyber1024_keypair(uint8_t *public_key, uint8_t *secret_key) {
    767     if (!public_key || !secret_key) return -1;
    768     return kyber_keypair_internal(public_key, secret_key, KYBER1024_K, KYBER1024_ETA1);
    769 }
    770 
    771 int kyber1024_encapsulate(uint8_t *ciphertext, uint8_t *shared_secret, const uint8_t *public_key) {
    772     if (!ciphertext || !shared_secret || !public_key) return -1;
    773     return kyber_encapsulate_internal(ciphertext, shared_secret, public_key,
    774                                        KYBER1024_K, KYBER1024_ETA1, KYBER1024_ETA2, KYBER1024_DU, KYBER1024_DV);
    775 }
    776 
    777 int kyber1024_decapsulate(uint8_t *shared_secret, const uint8_t *ciphertext, const uint8_t *secret_key) {
    778     if (!shared_secret || !ciphertext || !secret_key) return -1;
    779     return kyber_decapsulate_internal(shared_secret, ciphertext, secret_key,
    780                                        KYBER1024_K, KYBER1024_ETA1, KYBER1024_ETA2, KYBER1024_DU, KYBER1024_DV);
    781 }