luajitos

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs

Salsa20-Poly1305.c (13278B)


      1 /*
      2  * Salsa20-Poly1305 AEAD Implementation
      3  *
      4  * Salsa20: Stream cipher by Daniel J. Bernstein
      5  * Poly1305: MAC by Daniel J. Bernstein
      6  * Combined as AEAD (Authenticated Encryption with Associated Data)
      7  */
      8 
      9 #include "Salsa20.h"
     10 #include <string.h>
     11 #include <stdint.h>
     12 
     13 /* Salsa20 quarter round */
     14 #define QR(a, b, c, d) \
     15     b ^= ROTL32(a + d, 7);  \
     16     c ^= ROTL32(b + a, 9);  \
     17     d ^= ROTL32(c + b, 13); \
     18     a ^= ROTL32(d + c, 18)
     19 
     20 #define ROTL32(x, n) (((x) << (n)) | ((x) >> (32 - (n))))
     21 
     22 /* Load 32-bit little-endian */
     23 static inline uint32_t load32_le(const uint8_t *src) {
     24     return (uint32_t)src[0] |
     25            ((uint32_t)src[1] << 8) |
     26            ((uint32_t)src[2] << 16) |
     27            ((uint32_t)src[3] << 24);
     28 }
     29 
     30 /* Store 32-bit little-endian */
     31 static inline void store32_le(uint8_t *dst, uint32_t w) {
     32     dst[0] = (uint8_t)w;
     33     dst[1] = (uint8_t)(w >> 8);
     34     dst[2] = (uint8_t)(w >> 16);
     35     dst[3] = (uint8_t)(w >> 24);
     36 }
     37 
     38 /* Load 64-bit little-endian */
     39 static inline uint64_t load64_le(const uint8_t *src) {
     40     return (uint64_t)src[0] |
     41            ((uint64_t)src[1] << 8) |
     42            ((uint64_t)src[2] << 16) |
     43            ((uint64_t)src[3] << 24) |
     44            ((uint64_t)src[4] << 32) |
     45            ((uint64_t)src[5] << 40) |
     46            ((uint64_t)src[6] << 48) |
     47            ((uint64_t)src[7] << 56);
     48 }
     49 
     50 /* Store 64-bit little-endian */
     51 static inline void store64_le(uint8_t *dst, uint64_t w) {
     52     dst[0] = (uint8_t)w;
     53     dst[1] = (uint8_t)(w >> 8);
     54     dst[2] = (uint8_t)(w >> 16);
     55     dst[3] = (uint8_t)(w >> 24);
     56     dst[4] = (uint8_t)(w >> 32);
     57     dst[5] = (uint8_t)(w >> 40);
     58     dst[6] = (uint8_t)(w >> 48);
     59     dst[7] = (uint8_t)(w >> 56);
     60 }
     61 
     62 /* Salsa20 block function */
     63 static void salsa20_block(uint32_t out[16], const uint32_t in[16]) {
     64     uint32_t x[16];
     65     memcpy(x, in, sizeof(x));
     66 
     67     /* 20 rounds (10 double rounds) */
     68     for (int i = 0; i < 10; i++) {
     69         /* Column rounds */
     70         QR(x[0], x[4], x[8], x[12]);
     71         QR(x[5], x[9], x[13], x[1]);
     72         QR(x[10], x[14], x[2], x[6]);
     73         QR(x[15], x[3], x[7], x[11]);
     74         /* Row rounds */
     75         QR(x[0], x[1], x[2], x[3]);
     76         QR(x[5], x[6], x[7], x[4]);
     77         QR(x[10], x[11], x[8], x[9]);
     78         QR(x[15], x[12], x[13], x[14]);
     79     }
     80 
     81     for (int i = 0; i < 16; i++) {
     82         out[i] = x[i] + in[i];
     83     }
     84 }
     85 
     86 /* Initialize Salsa20 */
     87 int salsa20_init(salsa20_context *ctx, const uint8_t *key,
     88                  const uint8_t *nonce, uint64_t counter) {
     89     if (!ctx || !key || !nonce) return -1;
     90 
     91     const char *sigma = "expand 32-byte k";
     92 
     93     ctx->state[0] = load32_le((const uint8_t *)sigma + 0);
     94     ctx->state[1] = load32_le(key + 0);
     95     ctx->state[2] = load32_le(key + 4);
     96     ctx->state[3] = load32_le(key + 8);
     97     ctx->state[4] = load32_le(key + 12);
     98     ctx->state[5] = load32_le((const uint8_t *)sigma + 4);
     99     ctx->state[6] = load32_le(nonce + 0);
    100     ctx->state[7] = load32_le(nonce + 4);
    101     ctx->state[8] = (uint32_t)counter;
    102     ctx->state[9] = (uint32_t)(counter >> 32);
    103     ctx->state[10] = load32_le((const uint8_t *)sigma + 8);
    104     ctx->state[11] = load32_le(key + 16);
    105     ctx->state[12] = load32_le(key + 20);
    106     ctx->state[13] = load32_le(key + 24);
    107     ctx->state[14] = load32_le(key + 28);
    108     ctx->state[15] = load32_le((const uint8_t *)sigma + 12);
    109 
    110     memcpy(ctx->initial_state, ctx->state, sizeof(ctx->state));
    111     ctx->keystream_pos = 64; /* Force generation on first use */
    112 
    113     return 0;
    114 }
    115 
    116 /* Encrypt/Decrypt with Salsa20 */
    117 int salsa20_crypt(salsa20_context *ctx, const uint8_t *input,
    118                   uint8_t *output, size_t len) {
    119     if (!ctx || !input || !output) return -1;
    120 
    121     for (size_t i = 0; i < len; i++) {
    122         if (ctx->keystream_pos >= 64) {
    123             /* Generate new keystream block */
    124             uint32_t block[16];
    125             salsa20_block(block, ctx->state);
    126 
    127             /* Convert to bytes */
    128             for (int j = 0; j < 16; j++) {
    129                 store32_le(ctx->keystream + j * 4, block[j]);
    130             }
    131 
    132             /* Increment counter */
    133             ctx->state[8]++;
    134             if (ctx->state[8] == 0) {
    135                 ctx->state[9]++;
    136             }
    137 
    138             ctx->keystream_pos = 0;
    139         }
    140 
    141         output[i] = input[i] ^ ctx->keystream[ctx->keystream_pos++];
    142     }
    143 
    144     return 0;
    145 }
    146 
    147 /* Poly1305 implementation */
    148 int poly1305_init(poly1305_context *ctx, const uint8_t *key) {
    149     if (!ctx || !key) return -1;
    150 
    151     /* r &= 0xffffffc0ffffffc0ffffffc0fffffff */
    152     ctx->r[0] = (load32_le(key + 0)) & 0x3ffffff;
    153     ctx->r[1] = (load32_le(key + 3) >> 2) & 0x3ffff03;
    154     ctx->r[2] = (load32_le(key + 6) >> 4) & 0x3ffc0ff;
    155     ctx->r[3] = (load32_le(key + 9) >> 6) & 0x3f03fff;
    156     ctx->r[4] = (load32_le(key + 12) >> 8) & 0x00fffff;
    157 
    158     /* h = 0 */
    159     ctx->h[0] = 0;
    160     ctx->h[1] = 0;
    161     ctx->h[2] = 0;
    162     ctx->h[3] = 0;
    163     ctx->h[4] = 0;
    164 
    165     /* Save pad for later */
    166     ctx->pad[0] = load32_le(key + 16);
    167     ctx->pad[1] = load32_le(key + 20);
    168     ctx->pad[2] = load32_le(key + 24);
    169     ctx->pad[3] = load32_le(key + 28);
    170 
    171     ctx->buffer_len = 0;
    172 
    173     return 0;
    174 }
    175 
/*
 * Absorb one 16-byte block into the Poly1305 accumulator:
 *   h = (h + m) * r  (mod p = 2^130 - 5, only partially reduced)
 * Values are held as five 26-bit limbs in 32-bit words; products are
 * accumulated in 64-bit words.
 *
 * data  : exactly 16 bytes are read (data[0..15]).
 * final : 0 for a full message block, where the 2^128 bit is added above the
 *         16 message bytes. Non-zero for the last, already-padded partial
 *         block: poly1305_finish places the 0x01 marker byte inside the
 *         block itself, so no 2^128 bit is added.
 * Reads ctx->r and updates ctx->h in place.
 */
static void poly1305_block(poly1305_context *ctx, const uint8_t *data, int final) {
    uint64_t d0, d1, d2, d3, d4;
    uint32_t h0, h1, h2, h3, h4;
    uint32_t r0, r1, r2, r3, r4;
    uint32_t s1, s2, s3, s4;
    uint64_t c;

    /* Read r and h */
    r0 = ctx->r[0]; r1 = ctx->r[1]; r2 = ctx->r[2]; r3 = ctx->r[3]; r4 = ctx->r[4];
    h0 = ctx->h[0]; h1 = ctx->h[1]; h2 = ctx->h[2]; h3 = ctx->h[3]; h4 = ctx->h[4];

    /* s = 5 * r: since 2^130 == 5 (mod p), partial products that land at or
     * above 2^130 wrap back to the low limbs multiplied by 5. */
    s1 = r1 * 5; s2 = r2 * 5; s3 = r3 * 5; s4 = r4 * 5;

    /* h += m[i]: split the 128-bit block into 26-bit limbs (limb i starts at
     * byte 3*i, bit offset 2*i). 1 << 24 in limb 4 is bit 128 (104 + 24). */
    h0 += (load32_le(data + 0)) & 0x3ffffff;
    h1 += (load32_le(data + 3) >> 2) & 0x3ffffff;
    h2 += (load32_le(data + 6) >> 4) & 0x3ffffff;
    h3 += (load32_le(data + 9) >> 6) & 0x3ffffff;
    h4 += (load32_le(data + 12) >> 8) | (final ? 0 : (1 << 24));

    /* h *= r: schoolbook limb multiplication with the 5*r wrap-around terms */
    d0 = ((uint64_t)h0 * r0) + ((uint64_t)h1 * s4) + ((uint64_t)h2 * s3) + ((uint64_t)h3 * s2) + ((uint64_t)h4 * s1);
    d1 = ((uint64_t)h0 * r1) + ((uint64_t)h1 * r0) + ((uint64_t)h2 * s4) + ((uint64_t)h3 * s3) + ((uint64_t)h4 * s2);
    d2 = ((uint64_t)h0 * r2) + ((uint64_t)h1 * r1) + ((uint64_t)h2 * r0) + ((uint64_t)h3 * s4) + ((uint64_t)h4 * s3);
    d3 = ((uint64_t)h0 * r3) + ((uint64_t)h1 * r2) + ((uint64_t)h2 * r1) + ((uint64_t)h3 * r0) + ((uint64_t)h4 * s4);
    d4 = ((uint64_t)h0 * r4) + ((uint64_t)h1 * r3) + ((uint64_t)h2 * r2) + ((uint64_t)h3 * r1) + ((uint64_t)h4 * r0);

    /* (partial) h %= p: propagate carries limb to limb. The carry out of limb 4
     * is worth 2^130 == 5 (mod p), so it re-enters limb 0 multiplied by 5. h1
     * may still exceed 26 bits; poly1305_finish does the full carry. */
    c = d0 >> 26; h0 = (uint32_t)d0 & 0x3ffffff;
    d1 += c;      c = d1 >> 26; h1 = (uint32_t)d1 & 0x3ffffff;
    d2 += c;      c = d2 >> 26; h2 = (uint32_t)d2 & 0x3ffffff;
    d3 += c;      c = d3 >> 26; h3 = (uint32_t)d3 & 0x3ffffff;
    d4 += c;      c = d4 >> 26; h4 = (uint32_t)d4 & 0x3ffffff;
    h0 += (uint32_t)c * 5;    c = h0 >> 26; h0 = h0 & 0x3ffffff;
    h1 += (uint32_t)c;

    /* Write the updated accumulator back */
    ctx->h[0] = h0; ctx->h[1] = h1; ctx->h[2] = h2; ctx->h[3] = h3; ctx->h[4] = h4;
}
    215 
    216 int poly1305_update(poly1305_context *ctx, const uint8_t *data, size_t len) {
    217     if (!ctx || (len > 0 && !data)) return -1;
    218 
    219     while (len > 0) {
    220         size_t take = 16 - ctx->buffer_len;
    221         if (take > len) take = len;
    222 
    223         memcpy(ctx->buffer + ctx->buffer_len, data, take);
    224         ctx->buffer_len += take;
    225         data += take;
    226         len -= take;
    227 
    228         if (ctx->buffer_len == 16) {
    229             poly1305_block(ctx, ctx->buffer, 0);
    230             ctx->buffer_len = 0;
    231         }
    232     }
    233 
    234     return 0;
    235 }
    236 
/*
 * Finalize Poly1305 and write the 16-byte tag: tag = ((h mod p) + pad) mod 2^128.
 *
 * Any buffered partial block is padded with a 0x01 byte followed by zeros and
 * absorbed with final = 1 (no extra 2^128 bit). The reduction mod p is done
 * with a mask select rather than a data-dependent branch.
 * Returns 0 on success, -1 if ctx or tag is NULL.
 *
 * NOTE(review): ctx->buffer_len is not reset and r/pad/h are not cleared
 * here, so calling finish twice would re-absorb the buffered partial block.
 * Treat the context as single-use.
 */
int poly1305_finish(poly1305_context *ctx, uint8_t *tag) {
    if (!ctx || !tag) return -1;

    /* Process final block: append the 0x01 marker, zero-fill to 16 bytes */
    if (ctx->buffer_len > 0) {
        ctx->buffer[ctx->buffer_len] = 1;
        for (size_t i = ctx->buffer_len + 1; i < 16; i++) {
            ctx->buffer[i] = 0;
        }
        poly1305_block(ctx, ctx->buffer, 1);
    }

    /* Fully carry h so every limb is back below 2^26 (the carry out of limb 4
     * re-enters limb 0 multiplied by 5, since 2^130 == 5 mod p) */
    uint32_t h0 = ctx->h[0];
    uint32_t h1 = ctx->h[1];
    uint32_t h2 = ctx->h[2];
    uint32_t h3 = ctx->h[3];
    uint32_t h4 = ctx->h[4];

    uint32_t c = h1 >> 26; h1 &= 0x3ffffff;
    h2 += c;     c = h2 >> 26; h2 &= 0x3ffffff;
    h3 += c;     c = h3 >> 26; h3 &= 0x3ffffff;
    h4 += c;     c = h4 >> 26; h4 &= 0x3ffffff;
    h0 += c * 5; c = h0 >> 26; h0 &= 0x3ffffff;
    h1 += c;

    /* Compute h - p as g = h + 5 - 2^130; g4 wraps (top bit set) iff h < p */
    uint32_t g0 = h0 + 5; c = g0 >> 26; g0 &= 0x3ffffff;
    uint32_t g1 = h1 + c; c = g1 >> 26; g1 &= 0x3ffffff;
    uint32_t g2 = h2 + c; c = g2 >> 26; g2 &= 0x3ffffff;
    uint32_t g3 = h3 + c; c = g3 >> 26; g3 &= 0x3ffffff;
    uint32_t g4 = h4 + c - (1 << 26);

    /* Select h if h < p, or h - p if h >= p.
     * mask is all-ones when g4's top bit is clear (no borrow: take g) and
     * all-zeros when it is set (borrow: keep h). */
    uint32_t mask = (g4 >> 31) - 1;
    g0 &= mask;
    g1 &= mask;
    g2 &= mask;
    g3 &= mask;
    g4 &= mask;
    mask = ~mask;
    h0 = (h0 & mask) | g0;
    h1 = (h1 & mask) | g1;
    h2 = (h2 & mask) | g2;
    h3 = (h3 & mask) | g3;
    h4 = (h4 & mask) | g4;

    /* h = h % (2^128): repack five 26-bit limbs into four 32-bit words */
    h0 = ((h0) | (h1 << 26)) & 0xffffffff;
    h1 = ((h1 >> 6) | (h2 << 20)) & 0xffffffff;
    h2 = ((h2 >> 12) | (h3 << 14)) & 0xffffffff;
    h3 = ((h3 >> 18) | (h4 << 8)) & 0xffffffff;

    /* mac = (h + pad) % (2^128), with the carry chained through 64-bit f */
    uint64_t f = (uint64_t)h0 + ctx->pad[0]; h0 = (uint32_t)f;
    f = (uint64_t)h1 + ctx->pad[1] + (f >> 32); h1 = (uint32_t)f;
    f = (uint64_t)h2 + ctx->pad[2] + (f >> 32); h2 = (uint32_t)f;
    f = (uint64_t)h3 + ctx->pad[3] + (f >> 32); h3 = (uint32_t)f;

    /* Serialize the tag little-endian */
    store32_le(tag + 0, h0);
    store32_le(tag + 4, h1);
    store32_le(tag + 8, h2);
    store32_le(tag + 12, h3);

    return 0;
}
    303 
    304 /* Salsa20-Poly1305 AEAD - typedef already in Salsa20.h */
    305 
    306 int salsa20_poly1305_init(salsa20_poly1305_context *ctx, const uint8_t *key, const uint8_t *nonce) {
    307     if (!ctx || !key || !nonce) return -1;
    308 
    309     /* Initialize Salsa20 with nonce (8 bytes) */
    310     if (salsa20_init(&ctx->cipher, key, nonce, 0) != 0) {
    311         return -1;
    312     }
    313 
    314     /* Generate Poly1305 key from first 32 bytes of keystream */
    315     uint8_t poly_key[32];
    316     uint8_t zero[32] = {0};
    317     salsa20_crypt(&ctx->cipher, zero, poly_key, 32);
    318 
    319     /* Initialize Poly1305 */
    320     if (poly1305_init(&ctx->mac, poly_key) != 0) {
    321         memset(poly_key, 0, 32);
    322         return -1;
    323     }
    324 
    325     memset(poly_key, 0, 32);
    326     return 0;
    327 }
    328 
    329 int salsa20_poly1305_encrypt(salsa20_poly1305_context *ctx,
    330                              const uint8_t *aad, size_t aad_len,
    331                              const uint8_t *plaintext, size_t pt_len,
    332                              uint8_t *ciphertext, uint8_t *tag) {
    333     if (!ctx || !ciphertext || !tag) return -1;
    334     if (pt_len > 0 && !plaintext) return -1;
    335     if (aad_len > 0 && !aad) return -1;
    336 
    337     /* Encrypt plaintext */
    338     if (pt_len > 0) {
    339         salsa20_crypt(&ctx->cipher, plaintext, ciphertext, pt_len);
    340     }
    341 
    342     /* Authenticate AAD */
    343     if (aad_len > 0) {
    344         poly1305_update(&ctx->mac, aad, aad_len);
    345         /* Pad to 16 bytes */
    346         if (aad_len % 16 != 0) {
    347             uint8_t zeros[16] = {0};
    348             poly1305_update(&ctx->mac, zeros, 16 - (aad_len % 16));
    349         }
    350     }
    351 
    352     /* Authenticate ciphertext */
    353     if (pt_len > 0) {
    354         poly1305_update(&ctx->mac, ciphertext, pt_len);
    355         /* Pad to 16 bytes */
    356         if (pt_len % 16 != 0) {
    357             uint8_t zeros[16] = {0};
    358             poly1305_update(&ctx->mac, zeros, 16 - (pt_len % 16));
    359         }
    360     }
    361 
    362     /* Authenticate lengths */
    363     uint8_t lens[16];
    364     store64_le(lens, aad_len);
    365     store64_le(lens + 8, pt_len);
    366     poly1305_update(&ctx->mac, lens, 16);
    367 
    368     /* Compute tag */
    369     poly1305_finish(&ctx->mac, tag);
    370 
    371     return 0;
    372 }
    373 
    374 int salsa20_poly1305_decrypt(salsa20_poly1305_context *ctx,
    375                              const uint8_t *aad, size_t aad_len,
    376                              const uint8_t *ciphertext, size_t ct_len,
    377                              const uint8_t *tag, uint8_t *plaintext) {
    378     if (!ctx || !tag) return -1;
    379     if (ct_len > 0 && (!ciphertext || !plaintext)) return -1;
    380     if (aad_len > 0 && !aad) return -1;
    381 
    382     /* Authenticate AAD */
    383     if (aad_len > 0) {
    384         poly1305_update(&ctx->mac, aad, aad_len);
    385         if (aad_len % 16 != 0) {
    386             uint8_t zeros[16] = {0};
    387             poly1305_update(&ctx->mac, zeros, 16 - (aad_len % 16));
    388         }
    389     }
    390 
    391     /* Authenticate ciphertext */
    392     if (ct_len > 0) {
    393         poly1305_update(&ctx->mac, ciphertext, ct_len);
    394         if (ct_len % 16 != 0) {
    395             uint8_t zeros[16] = {0};
    396             poly1305_update(&ctx->mac, zeros, 16 - (ct_len % 16));
    397         }
    398     }
    399 
    400     /* Authenticate lengths */
    401     uint8_t lens[16];
    402     store64_le(lens, aad_len);
    403     store64_le(lens + 8, ct_len);
    404     poly1305_update(&ctx->mac, lens, 16);
    405 
    406     /* Verify tag */
    407     uint8_t computed_tag[16];
    408     poly1305_finish(&ctx->mac, computed_tag);
    409 
    410     /* Constant-time comparison */
    411     uint8_t diff = 0;
    412     for (int i = 0; i < 16; i++) {
    413         diff |= tag[i] ^ computed_tag[i];
    414     }
    415     if (diff != 0) {
    416         memset(computed_tag, 0, 16);
    417         return -1;
    418     }
    419     memset(computed_tag, 0, 16);
    420 
    421     /* Decrypt */
    422     if (ct_len > 0) {
    423         salsa20_crypt(&ctx->cipher, ciphertext, plaintext, ct_len);
    424     }
    425 
    426     return 0;
    427 }
    428 
    429 void salsa20_poly1305_cleanup(salsa20_poly1305_context *ctx) {
    430     if (ctx) {
    431         memset(ctx, 0, sizeof(salsa20_poly1305_context));
    432     }
    433 }