Salsa20-Poly1305.c (13278B)
1 /* 2 * Salsa20-Poly1305 AEAD Implementation 3 * 4 * Salsa20: Stream cipher by Daniel J. Bernstein 5 * Poly1305: MAC by Daniel J. Bernstein 6 * Combined as AEAD (Authenticated Encryption with Associated Data) 7 */ 8 9 #include "Salsa20.h" 10 #include <string.h> 11 #include <stdint.h> 12 13 /* Salsa20 quarter round */ 14 #define QR(a, b, c, d) \ 15 b ^= ROTL32(a + d, 7); \ 16 c ^= ROTL32(b + a, 9); \ 17 d ^= ROTL32(c + b, 13); \ 18 a ^= ROTL32(d + c, 18) 19 20 #define ROTL32(x, n) (((x) << (n)) | ((x) >> (32 - (n)))) 21 22 /* Load 32-bit little-endian */ 23 static inline uint32_t load32_le(const uint8_t *src) { 24 return (uint32_t)src[0] | 25 ((uint32_t)src[1] << 8) | 26 ((uint32_t)src[2] << 16) | 27 ((uint32_t)src[3] << 24); 28 } 29 30 /* Store 32-bit little-endian */ 31 static inline void store32_le(uint8_t *dst, uint32_t w) { 32 dst[0] = (uint8_t)w; 33 dst[1] = (uint8_t)(w >> 8); 34 dst[2] = (uint8_t)(w >> 16); 35 dst[3] = (uint8_t)(w >> 24); 36 } 37 38 /* Load 64-bit little-endian */ 39 static inline uint64_t load64_le(const uint8_t *src) { 40 return (uint64_t)src[0] | 41 ((uint64_t)src[1] << 8) | 42 ((uint64_t)src[2] << 16) | 43 ((uint64_t)src[3] << 24) | 44 ((uint64_t)src[4] << 32) | 45 ((uint64_t)src[5] << 40) | 46 ((uint64_t)src[6] << 48) | 47 ((uint64_t)src[7] << 56); 48 } 49 50 /* Store 64-bit little-endian */ 51 static inline void store64_le(uint8_t *dst, uint64_t w) { 52 dst[0] = (uint8_t)w; 53 dst[1] = (uint8_t)(w >> 8); 54 dst[2] = (uint8_t)(w >> 16); 55 dst[3] = (uint8_t)(w >> 24); 56 dst[4] = (uint8_t)(w >> 32); 57 dst[5] = (uint8_t)(w >> 40); 58 dst[6] = (uint8_t)(w >> 48); 59 dst[7] = (uint8_t)(w >> 56); 60 } 61 62 /* Salsa20 block function */ 63 static void salsa20_block(uint32_t out[16], const uint32_t in[16]) { 64 uint32_t x[16]; 65 memcpy(x, in, sizeof(x)); 66 67 /* 20 rounds (10 double rounds) */ 68 for (int i = 0; i < 10; i++) { 69 /* Column rounds */ 70 QR(x[0], x[4], x[8], x[12]); 71 QR(x[5], x[9], x[13], x[1]); 72 QR(x[10], x[14], x[2], x[6]); 73 QR(x[15], x[3], x[7], x[11]); 74 /* Row rounds */ 75 QR(x[0], x[1], x[2], x[3]); 76 QR(x[5], x[6], x[7], x[4]); 77 QR(x[10], x[11], x[8], x[9]); 78 QR(x[15], x[12], x[13], x[14]); 79 } 80 81 for (int i = 0; i < 16; i++) { 82 out[i] = x[i] + in[i]; 83 } 84 } 85 86 /* Initialize Salsa20 */ 87 int salsa20_init(salsa20_context *ctx, const uint8_t *key, 88 const uint8_t *nonce, uint64_t counter) { 89 if (!ctx || !key || !nonce) return -1; 90 91 const char *sigma = "expand 32-byte k"; 92 93 ctx->state[0] = load32_le((const uint8_t *)sigma + 0); 94 ctx->state[1] = load32_le(key + 0); 95 ctx->state[2] = load32_le(key + 4); 96 ctx->state[3] = load32_le(key + 8); 97 ctx->state[4] = load32_le(key + 12); 98 ctx->state[5] = load32_le((const uint8_t *)sigma + 4); 99 ctx->state[6] = load32_le(nonce + 0); 100 ctx->state[7] = load32_le(nonce + 4); 101 ctx->state[8] = (uint32_t)counter; 102 ctx->state[9] = (uint32_t)(counter >> 32); 103 ctx->state[10] = load32_le((const uint8_t *)sigma + 8); 104 ctx->state[11] = load32_le(key + 16); 105 ctx->state[12] = load32_le(key + 20); 106 ctx->state[13] = load32_le(key + 24); 107 ctx->state[14] = load32_le(key + 28); 108 ctx->state[15] = load32_le((const uint8_t *)sigma + 12); 109 110 memcpy(ctx->initial_state, ctx->state, sizeof(ctx->state)); 111 ctx->keystream_pos = 64; /* Force generation on first use */ 112 113 return 0; 114 } 115 116 /* Encrypt/Decrypt with Salsa20 */ 117 int salsa20_crypt(salsa20_context *ctx, const uint8_t *input, 118 uint8_t *output, size_t len) { 119 if (!ctx || !input || !output) return -1; 120 121 for (size_t i = 0; i < len; i++) { 122 if (ctx->keystream_pos >= 64) { 123 /* Generate new keystream block */ 124 uint32_t block[16]; 125 salsa20_block(block, ctx->state); 126 127 /* Convert to bytes */ 128 for (int j = 0; j < 16; j++) { 129 store32_le(ctx->keystream + j * 4, block[j]); 130 } 131 132 /* Increment counter */ 133 ctx->state[8]++; 134 if (ctx->state[8] == 0) { 135 ctx->state[9]++; 136 } 137 138 ctx->keystream_pos = 0; 139 } 140 141 output[i] = input[i] ^ ctx->keystream[ctx->keystream_pos++]; 142 } 143 144 return 0; 145 } 146 147 /* Poly1305 implementation */ 148 int poly1305_init(poly1305_context *ctx, const uint8_t *key) { 149 if (!ctx || !key) return -1; 150 151 /* r &= 0xffffffc0ffffffc0ffffffc0fffffff */ 152 ctx->r[0] = (load32_le(key + 0)) & 0x3ffffff; 153 ctx->r[1] = (load32_le(key + 3) >> 2) & 0x3ffff03; 154 ctx->r[2] = (load32_le(key + 6) >> 4) & 0x3ffc0ff; 155 ctx->r[3] = (load32_le(key + 9) >> 6) & 0x3f03fff; 156 ctx->r[4] = (load32_le(key + 12) >> 8) & 0x00fffff; 157 158 /* h = 0 */ 159 ctx->h[0] = 0; 160 ctx->h[1] = 0; 161 ctx->h[2] = 0; 162 ctx->h[3] = 0; 163 ctx->h[4] = 0; 164 165 /* Save pad for later */ 166 ctx->pad[0] = load32_le(key + 16); 167 ctx->pad[1] = load32_le(key + 20); 168 ctx->pad[2] = load32_le(key + 24); 169 ctx->pad[3] = load32_le(key + 28); 170 171 ctx->buffer_len = 0; 172 173 return 0; 174 } 175 176 static void poly1305_block(poly1305_context *ctx, const uint8_t *data, int final) { 177 uint64_t d0, d1, d2, d3, d4; 178 uint32_t h0, h1, h2, h3, h4; 179 uint32_t r0, r1, r2, r3, r4; 180 uint32_t s1, s2, s3, s4; 181 uint64_t c; 182 183 /* Read r and h */ 184 r0 = ctx->r[0]; r1 = ctx->r[1]; r2 = ctx->r[2]; r3 = ctx->r[3]; r4 = ctx->r[4]; 185 h0 = ctx->h[0]; h1 = ctx->h[1]; h2 = ctx->h[2]; h3 = ctx->h[3]; h4 = ctx->h[4]; 186 187 /* s = 5 * r */ 188 s1 = r1 * 5; s2 = r2 * 5; s3 = r3 * 5; s4 = r4 * 5; 189 190 /* h += m[i] */ 191 h0 += (load32_le(data + 0)) & 0x3ffffff; 192 h1 += (load32_le(data + 3) >> 2) & 0x3ffffff; 193 h2 += (load32_le(data + 6) >> 4) & 0x3ffffff; 194 h3 += (load32_le(data + 9) >> 6) & 0x3ffffff; 195 h4 += (load32_le(data + 12) >> 8) | (final ? 0 : (1 << 24)); 196 197 /* h *= r */ 198 d0 = ((uint64_t)h0 * r0) + ((uint64_t)h1 * s4) + ((uint64_t)h2 * s3) + ((uint64_t)h3 * s2) + ((uint64_t)h4 * s1); 199 d1 = ((uint64_t)h0 * r1) + ((uint64_t)h1 * r0) + ((uint64_t)h2 * s4) + ((uint64_t)h3 * s3) + ((uint64_t)h4 * s2); 200 d2 = ((uint64_t)h0 * r2) + ((uint64_t)h1 * r1) + ((uint64_t)h2 * r0) + ((uint64_t)h3 * s4) + ((uint64_t)h4 * s3); 201 d3 = ((uint64_t)h0 * r3) + ((uint64_t)h1 * r2) + ((uint64_t)h2 * r1) + ((uint64_t)h3 * r0) + ((uint64_t)h4 * s4); 202 d4 = ((uint64_t)h0 * r4) + ((uint64_t)h1 * r3) + ((uint64_t)h2 * r2) + ((uint64_t)h3 * r1) + ((uint64_t)h4 * r0); 203 204 /* (partial) h %= p */ 205 c = d0 >> 26; h0 = (uint32_t)d0 & 0x3ffffff; 206 d1 += c; c = d1 >> 26; h1 = (uint32_t)d1 & 0x3ffffff; 207 d2 += c; c = d2 >> 26; h2 = (uint32_t)d2 & 0x3ffffff; 208 d3 += c; c = d3 >> 26; h3 = (uint32_t)d3 & 0x3ffffff; 209 d4 += c; c = d4 >> 26; h4 = (uint32_t)d4 & 0x3ffffff; 210 h0 += (uint32_t)c * 5; c = h0 >> 26; h0 = h0 & 0x3ffffff; 211 h1 += (uint32_t)c; 212 213 ctx->h[0] = h0; ctx->h[1] = h1; ctx->h[2] = h2; ctx->h[3] = h3; ctx->h[4] = h4; 214 } 215 216 int poly1305_update(poly1305_context *ctx, const uint8_t *data, size_t len) { 217 if (!ctx || (len > 0 && !data)) return -1; 218 219 while (len > 0) { 220 size_t take = 16 - ctx->buffer_len; 221 if (take > len) take = len; 222 223 memcpy(ctx->buffer + ctx->buffer_len, data, take); 224 ctx->buffer_len += take; 225 data += take; 226 len -= take; 227 228 if (ctx->buffer_len == 16) { 229 poly1305_block(ctx, ctx->buffer, 0); 230 ctx->buffer_len = 0; 231 } 232 } 233 234 return 0; 235 } 236 237 int poly1305_finish(poly1305_context *ctx, uint8_t *tag) { 238 if (!ctx || !tag) return -1; 239 240 /* Process final block */ 241 if (ctx->buffer_len > 0) { 242 ctx->buffer[ctx->buffer_len] = 1; 243 for (size_t i = ctx->buffer_len + 1; i < 16; i++) { 244 ctx->buffer[i] = 0; 245 } 246 poly1305_block(ctx, ctx->buffer, 1); 247 } 248 249 /* Fully carry h */ 250 uint32_t h0 = ctx->h[0]; 251 uint32_t h1 = ctx->h[1]; 252 uint32_t h2 = ctx->h[2]; 253 uint32_t h3 = ctx->h[3]; 254 uint32_t h4 = ctx->h[4]; 255 256 uint32_t c = h1 >> 26; h1 &= 0x3ffffff; 257 h2 += c; c = h2 >> 26; h2 &= 0x3ffffff; 258 h3 += c; c = h3 >> 26; h3 &= 0x3ffffff; 259 h4 += c; c = h4 >> 26; h4 &= 0x3ffffff; 260 h0 += c * 5; c = h0 >> 26; h0 &= 0x3ffffff; 261 h1 += c; 262 263 /* Compute h - p */ 264 uint32_t g0 = h0 + 5; c = g0 >> 26; g0 &= 0x3ffffff; 265 uint32_t g1 = h1 + c; c = g1 >> 26; g1 &= 0x3ffffff; 266 uint32_t g2 = h2 + c; c = g2 >> 26; g2 &= 0x3ffffff; 267 uint32_t g3 = h3 + c; c = g3 >> 26; g3 &= 0x3ffffff; 268 uint32_t g4 = h4 + c - (1 << 26); 269 270 /* Select h if h < p, or h - p if h >= p */ 271 uint32_t mask = (g4 >> 31) - 1; 272 g0 &= mask; 273 g1 &= mask; 274 g2 &= mask; 275 g3 &= mask; 276 g4 &= mask; 277 mask = ~mask; 278 h0 = (h0 & mask) | g0; 279 h1 = (h1 & mask) | g1; 280 h2 = (h2 & mask) | g2; 281 h3 = (h3 & mask) | g3; 282 h4 = (h4 & mask) | g4; 283 284 /* h = h % (2^128) */ 285 h0 = ((h0) | (h1 << 26)) & 0xffffffff; 286 h1 = ((h1 >> 6) | (h2 << 20)) & 0xffffffff; 287 h2 = ((h2 >> 12) | (h3 << 14)) & 0xffffffff; 288 h3 = ((h3 >> 18) | (h4 << 8)) & 0xffffffff; 289 290 /* mac = (h + pad) % (2^128) */ 291 uint64_t f = (uint64_t)h0 + ctx->pad[0]; h0 = (uint32_t)f; 292 f = (uint64_t)h1 + ctx->pad[1] + (f >> 32); h1 = (uint32_t)f; 293 f = (uint64_t)h2 + ctx->pad[2] + (f >> 32); h2 = (uint32_t)f; 294 f = (uint64_t)h3 + ctx->pad[3] + (f >> 32); h3 = (uint32_t)f; 295 296 store32_le(tag + 0, h0); 297 store32_le(tag + 4, h1); 298 store32_le(tag + 8, h2); 299 store32_le(tag + 12, h3); 300 301 return 0; 302 } 303 304 /* Salsa20-Poly1305 AEAD - typedef already in Salsa20.h */ 305 306 int salsa20_poly1305_init(salsa20_poly1305_context *ctx, const uint8_t *key, const uint8_t *nonce) { 307 if (!ctx || !key || !nonce) return -1; 308 309 /* Initialize Salsa20 with nonce (8 bytes) */ 310 if (salsa20_init(&ctx->cipher, key, nonce, 0) != 0) { 311 return -1; 312 } 313 314 /* Generate Poly1305 key from first 32 bytes of keystream */ 315 uint8_t poly_key[32]; 316 uint8_t zero[32] = {0}; 317 salsa20_crypt(&ctx->cipher, zero, poly_key, 32); 318 319 /* Initialize Poly1305 */ 320 if (poly1305_init(&ctx->mac, poly_key) != 0) { 321 memset(poly_key, 0, 32); 322 return -1; 323 } 324 325 memset(poly_key, 0, 32); 326 return 0; 327 } 328 329 int salsa20_poly1305_encrypt(salsa20_poly1305_context *ctx, 330 const uint8_t *aad, size_t aad_len, 331 const uint8_t *plaintext, size_t pt_len, 332 uint8_t *ciphertext, uint8_t *tag) { 333 if (!ctx || !ciphertext || !tag) return -1; 334 if (pt_len > 0 && !plaintext) return -1; 335 if (aad_len > 0 && !aad) return -1; 336 337 /* Encrypt plaintext */ 338 if (pt_len > 0) { 339 salsa20_crypt(&ctx->cipher, plaintext, ciphertext, pt_len); 340 } 341 342 /* Authenticate AAD */ 343 if (aad_len > 0) { 344 poly1305_update(&ctx->mac, aad, aad_len); 345 /* Pad to 16 bytes */ 346 if (aad_len % 16 != 0) { 347 uint8_t zeros[16] = {0}; 348 poly1305_update(&ctx->mac, zeros, 16 - (aad_len % 16)); 349 } 350 } 351 352 /* Authenticate ciphertext */ 353 if (pt_len > 0) { 354 poly1305_update(&ctx->mac, ciphertext, pt_len); 355 /* Pad to 16 bytes */ 356 if (pt_len % 16 != 0) { 357 uint8_t zeros[16] = {0}; 358 poly1305_update(&ctx->mac, zeros, 16 - (pt_len % 16)); 359 } 360 } 361 362 /* Authenticate lengths */ 363 uint8_t lens[16]; 364 store64_le(lens, aad_len); 365 store64_le(lens + 8, pt_len); 366 poly1305_update(&ctx->mac, lens, 16); 367 368 /* Compute tag */ 369 poly1305_finish(&ctx->mac, tag); 370 371 return 0; 372 } 373 374 int salsa20_poly1305_decrypt(salsa20_poly1305_context *ctx, 375 const uint8_t *aad, size_t aad_len, 376 const uint8_t *ciphertext, size_t ct_len, 377 const uint8_t *tag, uint8_t *plaintext) { 378 if (!ctx || !tag) return -1; 379 if (ct_len > 0 && (!ciphertext || !plaintext)) return -1; 380 if (aad_len > 0 && !aad) return -1; 381 382 /* Authenticate AAD */ 383 if (aad_len > 0) { 384 poly1305_update(&ctx->mac, aad, aad_len); 385 if (aad_len % 16 != 0) { 386 uint8_t zeros[16] = {0}; 387 poly1305_update(&ctx->mac, zeros, 16 - (aad_len % 16)); 388 } 389 } 390 391 /* Authenticate ciphertext */ 392 if (ct_len > 0) { 393 poly1305_update(&ctx->mac, ciphertext, ct_len); 394 if (ct_len % 16 != 0) { 395 uint8_t zeros[16] = {0}; 396 poly1305_update(&ctx->mac, zeros, 16 - (ct_len % 16)); 397 } 398 } 399 400 /* Authenticate lengths */ 401 uint8_t lens[16]; 402 store64_le(lens, aad_len); 403 store64_le(lens + 8, ct_len); 404 poly1305_update(&ctx->mac, lens, 16); 405 406 /* Verify tag */ 407 uint8_t computed_tag[16]; 408 poly1305_finish(&ctx->mac, computed_tag); 409 410 /* Constant-time comparison */ 411 uint8_t diff = 0; 412 for (int i = 0; i < 16; i++) { 413 diff |= tag[i] ^ computed_tag[i]; 414 } 415 if (diff != 0) { 416 memset(computed_tag, 0, 16); 417 return -1; 418 } 419 memset(computed_tag, 0, 16); 420 421 /* Decrypt */ 422 if (ct_len > 0) { 423 salsa20_crypt(&ctx->cipher, ciphertext, plaintext, ct_len); 424 } 425 426 return 0; 427 } 428 429 void salsa20_poly1305_cleanup(salsa20_poly1305_context *ctx) { 430 if (ctx) { 431 memset(ctx, 0, sizeof(salsa20_poly1305_context)); 432 } 433 }