/* RSA_production.c (27166B) */
1 /* 2 * RSA Production Implementation with GMP 3 * Supports 2048, 3072, and 4096-bit keys 4 * Implements PKCS#1 v1.5 padding 5 * 6 * Compile with: gcc -O3 -o rsa_prod RSA_production.c -lgmp 7 */ 8 9 #include "RSA.h" 10 #include "SHA-256.h" 11 #include "CSPRNG.h" 12 #include <stdio.h> 13 #include <stdlib.h> 14 #include <string.h> 15 #include <stdint.h> 16 #include <time.h> 17 #include <gmp.h> 18 19 // Random number generation using global CSPRNG 20 static void get_random_bytes(uint8_t *buf, size_t len) { 21 // Use the global CSPRNG (will auto-initialize if needed) 22 random_bytes(buf, len); 23 } 24 25 // Miller-Rabin primality test with GMP 26 static int is_prime_gmp(mpz_t n, int iterations) { 27 return mpz_probab_prime_p(n, iterations) > 0; 28 } 29 30 // Generate a random prime of specified bit size 31 static void generate_prime(mpz_t prime, int bits, gmp_randstate_t state) { 32 mpz_t candidate; 33 mpz_init(candidate); 34 35 do { 36 // Generate random number of specified bit size 37 mpz_urandomb(candidate, state, bits); 38 // Set high bit to ensure correct bit length 39 mpz_setbit(candidate, bits - 1); 40 // Set low bit to make it odd 41 mpz_setbit(candidate, 0); 42 43 // Find next prime 44 mpz_nextprime(prime, candidate); 45 } while (mpz_sizeinbase(prime, 2) != (size_t)bits); 46 47 mpz_clear(candidate); 48 } 49 50 // Convert mpz_t to byte array 51 static void mpz_to_bytes(const mpz_t n, uint8_t **out, size_t *out_len) { 52 size_t count; 53 *out_len = (mpz_sizeinbase(n, 2) + 7) / 8; 54 *out = malloc(*out_len); 55 56 mpz_export(*out, &count, 1, 1, 1, 0, n); 57 58 // Handle leading zeros 59 if (count < *out_len) { 60 memmove(*out + (*out_len - count), *out, count); 61 memset(*out, 0, *out_len - count); 62 } 63 } 64 65 // Convert byte array to mpz_t 66 static void bytes_to_mpz(mpz_t n, const uint8_t *bytes, size_t len) { 67 mpz_import(n, len, 1, 1, 1, 0, bytes); 68 } 69 70 // RSA key generation with GMP 71 int rsa_generate_key_simple(rsa_public_key *pub, 
rsa_private_key *priv, int bits) { 72 if (bits < 2048) { 73 fprintf(stderr, "Key size must be at least 2048 bits for security\n"); 74 return -1; 75 } 76 77 if (bits % 2 != 0) { 78 fprintf(stderr, "Key size must be even\n"); 79 return -1; 80 } 81 82 printf("Generating %d-bit RSA key pair...\n", bits); 83 printf("This may take a minute...\n"); 84 85 // Initialize GMP random state 86 gmp_randstate_t state; 87 gmp_randinit_default(state); 88 89 // Seed with random data 90 uint8_t seed[32]; 91 get_random_bytes(seed, 32); 92 mpz_t seed_mpz; 93 mpz_init(seed_mpz); 94 bytes_to_mpz(seed_mpz, seed, 32); 95 gmp_randseed(state, seed_mpz); 96 mpz_clear(seed_mpz); 97 98 mpz_t p, q, n, phi, e, d, p_minus_1, q_minus_1; 99 mpz_t dp, dq, qinv, gcd_val; 100 101 mpz_init(p); 102 mpz_init(q); 103 mpz_init(n); 104 mpz_init(phi); 105 mpz_init(e); 106 mpz_init(d); 107 mpz_init(p_minus_1); 108 mpz_init(q_minus_1); 109 mpz_init(dp); 110 mpz_init(dq); 111 mpz_init(qinv); 112 mpz_init(gcd_val); 113 114 // Generate two primes 115 int prime_bits = bits / 2; 116 printf("Generating first prime (%d bits)...\n", prime_bits); 117 generate_prime(p, prime_bits, state); 118 119 printf("Generating second prime (%d bits)...\n", prime_bits); 120 do { 121 generate_prime(q, prime_bits, state); 122 } while (mpz_cmp(p, q) == 0); // Ensure p != q 123 124 // Calculate n = p * q 125 mpz_mul(n, p, q); 126 127 // Calculate φ(n) = (p-1)(q-1) 128 mpz_sub_ui(p_minus_1, p, 1); 129 mpz_sub_ui(q_minus_1, q, 1); 130 mpz_mul(phi, p_minus_1, q_minus_1); 131 132 // Choose e = 65537 (standard public exponent) 133 mpz_set_ui(e, 65537); 134 135 // Verify gcd(e, φ(n)) = 1 136 mpz_gcd(gcd_val, e, phi); 137 if (mpz_cmp_ui(gcd_val, 1) != 0) { 138 fprintf(stderr, "gcd(e, phi) != 1, key generation failed\n"); 139 goto cleanup; 140 } 141 142 // Calculate d = e^-1 mod φ(n) 143 if (!mpz_invert(d, e, phi)) { 144 fprintf(stderr, "Failed to compute private exponent\n"); 145 goto cleanup; 146 } 147 148 // Calculate CRT parameters 149 // dp 
= d mod (p-1) 150 mpz_mod(dp, d, p_minus_1); 151 152 // dq = d mod (q-1) 153 mpz_mod(dq, d, q_minus_1); 154 155 // qinv = q^-1 mod p 156 if (!mpz_invert(qinv, q, p)) { 157 fprintf(stderr, "Failed to compute qinv\n"); 158 goto cleanup; 159 } 160 161 // Convert to byte arrays 162 mpz_to_bytes(n, &pub->n, &pub->n_len); 163 mpz_to_bytes(e, &pub->e, &pub->e_len); 164 165 mpz_to_bytes(n, &priv->n, &priv->n_len); 166 mpz_to_bytes(e, &priv->e, &priv->e_len); 167 mpz_to_bytes(d, &priv->d, &priv->d_len); 168 mpz_to_bytes(p, &priv->p, &priv->p_len); 169 mpz_to_bytes(q, &priv->q, &priv->q_len); 170 mpz_to_bytes(dp, &priv->dp, &priv->dp_len); 171 mpz_to_bytes(dq, &priv->dq, &priv->dq_len); 172 mpz_to_bytes(qinv, &priv->qinv, &priv->qinv_len); 173 174 printf("Key generation complete!\n"); 175 printf("Modulus size: %zu bits\n", mpz_sizeinbase(n, 2)); 176 177 cleanup: 178 mpz_clear(p); 179 mpz_clear(q); 180 mpz_clear(n); 181 mpz_clear(phi); 182 mpz_clear(e); 183 mpz_clear(d); 184 mpz_clear(p_minus_1); 185 mpz_clear(q_minus_1); 186 mpz_clear(dp); 187 mpz_clear(dq); 188 mpz_clear(qinv); 189 mpz_clear(gcd_val); 190 gmp_randclear(state); 191 192 return 0; 193 } 194 195 // PKCS#1 v1.5 padding for encryption 196 static int pkcs1_v15_encode(const uint8_t *message, size_t msg_len, 197 uint8_t *padded, size_t padded_len) { 198 if (msg_len > padded_len - 11) { 199 fprintf(stderr, "Message too long for key size\n"); 200 return -1; 201 } 202 203 // EM = 0x00 || 0x02 || PS || 0x00 || M 204 padded[0] = 0x00; 205 padded[1] = 0x02; 206 207 // PS: padding string (at least 8 random non-zero bytes) 208 size_t ps_len = padded_len - msg_len - 3; 209 get_random_bytes(padded + 2, ps_len); 210 211 // Ensure no zero bytes in padding 212 for (size_t i = 0; i < ps_len; i++) { 213 while (padded[2 + i] == 0) { 214 get_random_bytes(&padded[2 + i], 1); 215 } 216 } 217 218 padded[2 + ps_len] = 0x00; 219 memcpy(padded + 2 + ps_len + 1, message, msg_len); 220 221 return 0; 222 } 223 224 // PKCS#1 v1.5 padding 
removal for decryption 225 static int pkcs1_v15_decode(const uint8_t *padded, size_t padded_len, 226 uint8_t *message, size_t *msg_len) { 227 if (padded_len < 11) return -1; 228 if (padded[0] != 0x00) return -1; 229 if (padded[1] != 0x02) return -1; 230 231 // Find the 0x00 separator 232 size_t i; 233 for (i = 2; i < padded_len; i++) { 234 if (padded[i] == 0x00) break; 235 } 236 237 if (i < 10 || i >= padded_len) return -1; // Invalid padding 238 239 // Copy message 240 *msg_len = padded_len - i - 1; 241 memcpy(message, padded + i + 1, *msg_len); 242 243 return 0; 244 } 245 246 // RSA encryption 247 int rsa_encrypt(const rsa_public_key *key, 248 const uint8_t *plaintext, size_t pt_len, 249 uint8_t *ciphertext, size_t *ct_len) { 250 if (pt_len > key->n_len - 11) { 251 fprintf(stderr, "Plaintext too long for key size\n"); 252 return -1; 253 } 254 255 // Apply PKCS#1 v1.5 padding 256 uint8_t *padded = malloc(key->n_len); 257 if (pkcs1_v15_encode(plaintext, pt_len, padded, key->n_len) != 0) { 258 free(padded); 259 return -1; 260 } 261 262 // Convert to GMP integers 263 mpz_t m, c, n, e; 264 mpz_init(m); 265 mpz_init(c); 266 mpz_init(n); 267 mpz_init(e); 268 269 bytes_to_mpz(m, padded, key->n_len); 270 bytes_to_mpz(n, key->n, key->n_len); 271 bytes_to_mpz(e, key->e, key->e_len); 272 273 // c = m^e mod n 274 mpz_powm(c, m, e, n); 275 276 // Convert back to bytes 277 *ct_len = key->n_len; 278 size_t count; 279 mpz_export(ciphertext, &count, 1, 1, 1, 0, c); 280 281 // Pad with leading zeros if necessary 282 if (count < *ct_len) { 283 memmove(ciphertext + (*ct_len - count), ciphertext, count); 284 memset(ciphertext, 0, *ct_len - count); 285 } 286 287 mpz_clear(m); 288 mpz_clear(c); 289 mpz_clear(n); 290 mpz_clear(e); 291 free(padded); 292 293 return 0; 294 } 295 296 // RSA decryption with CRT 297 int rsa_decrypt(const rsa_private_key *key, 298 const uint8_t *ciphertext, size_t ct_len, 299 uint8_t *plaintext, size_t *pt_len) { 300 if (ct_len != key->n_len) { 301 
fprintf(stderr, "Invalid ciphertext length\n"); 302 return -1; 303 } 304 305 mpz_t c, m, p, q, dp, dq, qinv, m1, m2, h; 306 mpz_init(c); 307 mpz_init(m); 308 mpz_init(p); 309 mpz_init(q); 310 mpz_init(dp); 311 mpz_init(dq); 312 mpz_init(qinv); 313 mpz_init(m1); 314 mpz_init(m2); 315 mpz_init(h); 316 317 bytes_to_mpz(c, ciphertext, ct_len); 318 bytes_to_mpz(p, key->p, key->p_len); 319 bytes_to_mpz(q, key->q, key->q_len); 320 bytes_to_mpz(dp, key->dp, key->dp_len); 321 bytes_to_mpz(dq, key->dq, key->dq_len); 322 bytes_to_mpz(qinv, key->qinv, key->qinv_len); 323 324 // CRT decryption for speed 325 // m1 = c^dp mod p 326 mpz_powm(m1, c, dp, p); 327 328 // m2 = c^dq mod q 329 mpz_powm(m2, c, dq, q); 330 331 // h = qinv * (m1 - m2) mod p 332 mpz_sub(h, m1, m2); 333 mpz_mul(h, qinv, h); 334 mpz_mod(h, h, p); 335 336 // m = m2 + h * q 337 mpz_mul(h, h, q); 338 mpz_add(m, m2, h); 339 340 // Convert to bytes 341 uint8_t *padded = malloc(key->n_len); 342 size_t count; 343 mpz_export(padded, &count, 1, 1, 1, 0, m); 344 345 // Pad with leading zeros if necessary 346 if (count < key->n_len) { 347 memmove(padded + (key->n_len - count), padded, count); 348 memset(padded, 0, key->n_len - count); 349 } 350 351 // Remove PKCS#1 v1.5 padding 352 int result = pkcs1_v15_decode(padded, key->n_len, plaintext, pt_len); 353 354 mpz_clear(c); 355 mpz_clear(m); 356 mpz_clear(p); 357 mpz_clear(q); 358 mpz_clear(dp); 359 mpz_clear(dq); 360 mpz_clear(qinv); 361 mpz_clear(m1); 362 mpz_clear(m2); 363 mpz_clear(h); 364 free(padded); 365 366 return result; 367 } 368 369 // RSA Sign (same as decrypt operation) 370 int rsa_sign(const rsa_private_key *key, 371 const uint8_t *message, size_t msg_len, 372 uint8_t *signature, size_t *sig_len) { 373 if (msg_len > key->n_len - 11) { 374 fprintf(stderr, "Message too long. 
Hash it first with SHA-256.\n"); 375 return -1; 376 } 377 378 // Apply PKCS#1 v1.5 padding 379 uint8_t *padded = malloc(key->n_len); 380 padded[0] = 0x00; 381 padded[1] = 0x01; // 0x01 for signatures 382 383 size_t ps_len = key->n_len - msg_len - 3; 384 memset(padded + 2, 0xFF, ps_len); // 0xFF padding for signatures 385 386 padded[2 + ps_len] = 0x00; 387 memcpy(padded + 2 + ps_len + 1, message, msg_len); 388 389 // Sign using CRT (same as decrypt) 390 mpz_t m, s, p, q, dp, dq, qinv, s1, s2, h; 391 mpz_init(m); 392 mpz_init(s); 393 mpz_init(p); 394 mpz_init(q); 395 mpz_init(dp); 396 mpz_init(dq); 397 mpz_init(qinv); 398 mpz_init(s1); 399 mpz_init(s2); 400 mpz_init(h); 401 402 bytes_to_mpz(m, padded, key->n_len); 403 bytes_to_mpz(p, key->p, key->p_len); 404 bytes_to_mpz(q, key->q, key->q_len); 405 bytes_to_mpz(dp, key->dp, key->dp_len); 406 bytes_to_mpz(dq, key->dq, key->dq_len); 407 bytes_to_mpz(qinv, key->qinv, key->qinv_len); 408 409 // CRT signing 410 mpz_powm(s1, m, dp, p); 411 mpz_powm(s2, m, dq, q); 412 mpz_sub(h, s1, s2); 413 mpz_mul(h, qinv, h); 414 mpz_mod(h, h, p); 415 mpz_mul(h, h, q); 416 mpz_add(s, s2, h); 417 418 // Convert to bytes 419 *sig_len = key->n_len; 420 size_t count; 421 mpz_export(signature, &count, 1, 1, 1, 0, s); 422 423 if (count < *sig_len) { 424 memmove(signature + (*sig_len - count), signature, count); 425 memset(signature, 0, *sig_len - count); 426 } 427 428 mpz_clear(m); 429 mpz_clear(s); 430 mpz_clear(p); 431 mpz_clear(q); 432 mpz_clear(dp); 433 mpz_clear(dq); 434 mpz_clear(qinv); 435 mpz_clear(s1); 436 mpz_clear(s2); 437 mpz_clear(h); 438 free(padded); 439 440 return 0; 441 } 442 443 // RSA Verify signature 444 int rsa_verify(const rsa_public_key *key, 445 const uint8_t *message, size_t msg_len, 446 const uint8_t *signature, size_t sig_len) { 447 if (sig_len != key->n_len) { 448 return -1; 449 } 450 451 mpz_t s, m, n, e; 452 mpz_init(s); 453 mpz_init(m); 454 mpz_init(n); 455 mpz_init(e); 456 457 bytes_to_mpz(s, signature, 
sig_len); 458 bytes_to_mpz(n, key->n, key->n_len); 459 bytes_to_mpz(e, key->e, key->e_len); 460 461 // m = s^e mod n 462 mpz_powm(m, s, e, n); 463 464 // Convert to bytes 465 uint8_t *decrypted = malloc(key->n_len); 466 size_t count; 467 mpz_export(decrypted, &count, 1, 1, 1, 0, m); 468 469 if (count < key->n_len) { 470 memmove(decrypted + (key->n_len - count), decrypted, count); 471 memset(decrypted, 0, key->n_len - count); 472 } 473 474 // Verify padding and message 475 int result = -1; 476 if (decrypted[0] == 0x00 && decrypted[1] == 0x01) { 477 size_t i; 478 for (i = 2; i < key->n_len; i++) { 479 if (decrypted[i] == 0x00) break; 480 if (decrypted[i] != 0xFF) goto cleanup; 481 } 482 483 if (i >= 10 && i < key->n_len) { 484 size_t recovered_len = key->n_len - i - 1; 485 if (recovered_len == msg_len && 486 memcmp(decrypted + i + 1, message, msg_len) == 0) { 487 result = 0; 488 } 489 } 490 } 491 492 cleanup: 493 mpz_clear(s); 494 mpz_clear(m); 495 mpz_clear(n); 496 mpz_clear(e); 497 free(decrypted); 498 499 return result; 500 } 501 502 /* ============================================================================ 503 * RSA-PSS (Probabilistic Signature Scheme) Implementation 504 * RFC 8017 - PKCS #1: RSA Cryptography Specifications Version 2.2 505 * ========================================================================= */ 506 507 /** 508 * MGF1 - Mask Generation Function based on SHA-256 509 * Used in PSS padding 510 */ 511 static void mgf1_sha256(const uint8_t *seed, size_t seed_len, 512 uint8_t *mask, size_t mask_len) { 513 uint8_t counter[4]; 514 size_t offset = 0; 515 uint32_t count = 0; 516 517 while (offset < mask_len) { 518 // Convert counter to big-endian bytes 519 counter[0] = (count >> 24) & 0xFF; 520 counter[1] = (count >> 16) & 0xFF; 521 counter[2] = (count >> 8) & 0xFF; 522 counter[3] = count & 0xFF; 523 524 // Hash seed || counter 525 uint8_t hash_input[256]; // Max seed length 526 memcpy(hash_input, seed, seed_len); 527 memcpy(hash_input + 
seed_len, counter, 4); 528 529 uint8_t hash[32]; 530 sha256(hash_input, seed_len + 4, hash); 531 532 // Copy to mask 533 size_t to_copy = (mask_len - offset < 32) ? (mask_len - offset) : 32; 534 memcpy(mask + offset, hash, to_copy); 535 536 offset += to_copy; 537 count++; 538 } 539 } 540 541 /** 542 * RSA-PSS Sign 543 */ 544 int rsa_sign_pss(const rsa_private_key *key, 545 const uint8_t *message, size_t msg_len, 546 uint8_t *signature, size_t *sig_len) { 547 // PSS parameters (using SHA-256) 548 const size_t h_len = 32; // SHA-256 output length 549 const size_t s_len = 32; // Salt length (recommended: same as hash length) 550 551 if (msg_len != h_len) { 552 fprintf(stderr, "RSA-PSS: Message must be a SHA-256 hash (32 bytes)\n"); 553 return -1; 554 } 555 556 size_t em_len = key->n_len; // Encoded message length 557 if (em_len < h_len + s_len + 2) { 558 fprintf(stderr, "RSA-PSS: Key too short for PSS padding\n"); 559 return -1; 560 } 561 562 // Generate salt 563 uint8_t salt[32]; 564 random_bytes(salt, s_len); 565 566 // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt 567 uint8_t m_prime[8 + 32 + 32]; 568 memset(m_prime, 0, 8); 569 memcpy(m_prime + 8, message, h_len); 570 memcpy(m_prime + 8 + h_len, salt, s_len); 571 572 // H = Hash(M') 573 uint8_t h[32]; 574 sha256(m_prime, 8 + h_len + s_len, h); 575 576 // Generate DB = PS || 0x01 || salt 577 size_t ps_len = em_len - s_len - h_len - 2; 578 uint8_t *db = calloc(em_len - h_len - 1, 1); // Zero-initialized 579 db[ps_len] = 0x01; 580 memcpy(db + ps_len + 1, salt, s_len); 581 582 // dbMask = MGF(H, emLen - hLen - 1) 583 size_t db_len = em_len - h_len - 1; 584 uint8_t *db_mask = malloc(db_len); 585 mgf1_sha256(h, h_len, db_mask, db_len); 586 587 // maskedDB = DB xor dbMask 588 for (size_t i = 0; i < db_len; i++) { 589 db[i] ^= db_mask[i]; 590 } 591 592 // Set leftmost bits of maskedDB to zero (emBits - 1 % 8) 593 size_t em_bits = key->n_len * 8 - 1; // Adjust for modulus bit length 594 size_t mask_bits = 8 * em_len - 
em_bits; 595 if (mask_bits > 0) { 596 db[0] &= (0xFF >> mask_bits); 597 } 598 599 // EM = maskedDB || H || 0xbc 600 uint8_t *em = malloc(em_len); 601 memcpy(em, db, db_len); 602 memcpy(em + db_len, h, h_len); 603 em[em_len - 1] = 0xbc; 604 605 // Sign using CRT (same as PKCS#1 v1.5) 606 mpz_t m, s, p, q, dp, dq, qinv, s1, s2, h_crt; 607 mpz_init(m); 608 mpz_init(s); 609 mpz_init(p); 610 mpz_init(q); 611 mpz_init(dp); 612 mpz_init(dq); 613 mpz_init(qinv); 614 mpz_init(s1); 615 mpz_init(s2); 616 mpz_init(h_crt); 617 618 bytes_to_mpz(m, em, em_len); 619 bytes_to_mpz(p, key->p, key->p_len); 620 bytes_to_mpz(q, key->q, key->q_len); 621 bytes_to_mpz(dp, key->dp, key->dp_len); 622 bytes_to_mpz(dq, key->dq, key->dq_len); 623 bytes_to_mpz(qinv, key->qinv, key->qinv_len); 624 625 // CRT signing 626 mpz_powm(s1, m, dp, p); 627 mpz_powm(s2, m, dq, q); 628 mpz_sub(h_crt, s1, s2); 629 mpz_mul(h_crt, qinv, h_crt); 630 mpz_mod(h_crt, h_crt, p); 631 mpz_mul(h_crt, h_crt, q); 632 mpz_add(s, s2, h_crt); 633 634 // Convert to bytes 635 *sig_len = key->n_len; 636 size_t count; 637 mpz_export(signature, &count, 1, 1, 1, 0, s); 638 if (count < key->n_len) { 639 memmove(signature + (key->n_len - count), signature, count); 640 memset(signature, 0, key->n_len - count); 641 } 642 643 // Clean up 644 mpz_clear(m); 645 mpz_clear(s); 646 mpz_clear(p); 647 mpz_clear(q); 648 mpz_clear(dp); 649 mpz_clear(dq); 650 mpz_clear(qinv); 651 mpz_clear(s1); 652 mpz_clear(s2); 653 mpz_clear(h_crt); 654 free(db); 655 free(db_mask); 656 free(em); 657 658 return 0; 659 } 660 661 /** 662 * RSA-PSS Verify 663 */ 664 int rsa_verify_pss(const rsa_public_key *key, 665 const uint8_t *message, size_t msg_len, 666 const uint8_t *signature, size_t sig_len) { 667 const size_t h_len = 32; // SHA-256 output length 668 const size_t s_len = 32; // Salt length 669 670 if (msg_len != h_len) { 671 fprintf(stderr, "RSA-PSS: Message must be a SHA-256 hash (32 bytes)\n"); 672 return -1; 673 } 674 675 if (sig_len != key->n_len) { 
676 return -1; 677 } 678 679 // Verify signature: m = s^e mod n 680 mpz_t s, m, n, e; 681 mpz_init(s); 682 mpz_init(m); 683 mpz_init(n); 684 mpz_init(e); 685 686 bytes_to_mpz(s, signature, sig_len); 687 bytes_to_mpz(n, key->n, key->n_len); 688 bytes_to_mpz(e, key->e, key->e_len); 689 690 mpz_powm(m, s, e, n); 691 692 // Convert to bytes 693 size_t em_len = key->n_len; 694 uint8_t *em = malloc(em_len); 695 size_t count; 696 mpz_export(em, &count, 1, 1, 1, 0, m); 697 if (count < em_len) { 698 memmove(em + (em_len - count), em, count); 699 memset(em, 0, em_len - count); 700 } 701 702 mpz_clear(s); 703 mpz_clear(m); 704 mpz_clear(n); 705 mpz_clear(e); 706 707 // Verify EM format 708 if (em[em_len - 1] != 0xbc) { 709 free(em); 710 return -1; 711 } 712 713 // Split EM = maskedDB || H || 0xbc 714 size_t db_len = em_len - h_len - 1; 715 uint8_t *masked_db = em; 716 uint8_t *h = em + db_len; 717 718 // Generate dbMask 719 uint8_t *db_mask = malloc(db_len); 720 mgf1_sha256(h, h_len, db_mask, db_len); 721 722 // DB = maskedDB xor dbMask 723 uint8_t *db = malloc(db_len); 724 for (size_t i = 0; i < db_len; i++) { 725 db[i] = masked_db[i] ^ db_mask[i]; 726 } 727 728 // Set leftmost bits to zero 729 size_t em_bits = key->n_len * 8 - 1; 730 size_t mask_bits = 8 * em_len - em_bits; 731 if (mask_bits > 0) { 732 db[0] &= (0xFF >> mask_bits); 733 } 734 735 // Verify DB = PS || 0x01 || salt 736 size_t ps_len = em_len - s_len - h_len - 2; 737 for (size_t i = 0; i < ps_len; i++) { 738 if (db[i] != 0x00) { 739 free(em); 740 free(db_mask); 741 free(db); 742 return -1; 743 } 744 } 745 746 if (db[ps_len] != 0x01) { 747 free(em); 748 free(db_mask); 749 free(db); 750 return -1; 751 } 752 753 // Extract salt 754 uint8_t *salt = db + ps_len + 1; 755 756 // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt 757 uint8_t m_prime[8 + 32 + 32]; 758 memset(m_prime, 0, 8); 759 memcpy(m_prime + 8, message, h_len); 760 memcpy(m_prime + 8 + h_len, salt, s_len); 761 762 // H' = Hash(M') 763 uint8_t 
h_prime[32]; 764 sha256(m_prime, 8 + h_len + s_len, h_prime); 765 766 // Verify H == H' 767 int result = (memcmp(h, h_prime, h_len) == 0) ? 0 : -1; 768 769 free(em); 770 free(db_mask); 771 free(db); 772 773 return result; 774 } 775 776 // Free keys (already implemented in RSA.c, included here for completeness) 777 void rsa_free_public_key(rsa_public_key *key) { 778 if (key->n) { 779 memset(key->n, 0, key->n_len); 780 free(key->n); 781 } 782 if (key->e) { 783 memset(key->e, 0, key->e_len); 784 free(key->e); 785 } 786 memset(key, 0, sizeof(*key)); 787 } 788 789 void rsa_free_private_key(rsa_private_key *key) { 790 if (key->n) { 791 memset(key->n, 0, key->n_len); 792 free(key->n); 793 } 794 if (key->e) { 795 memset(key->e, 0, key->e_len); 796 free(key->e); 797 } 798 if (key->d) { 799 memset(key->d, 0, key->d_len); 800 free(key->d); 801 } 802 if (key->p) { 803 memset(key->p, 0, key->p_len); 804 free(key->p); 805 } 806 if (key->q) { 807 memset(key->q, 0, key->q_len); 808 free(key->q); 809 } 810 if (key->dp) { 811 memset(key->dp, 0, key->dp_len); 812 free(key->dp); 813 } 814 if (key->dq) { 815 memset(key->dq, 0, key->dq_len); 816 free(key->dq); 817 } 818 if (key->qinv) { 819 memset(key->qinv, 0, key->qinv_len); 820 free(key->qinv); 821 } 822 memset(key, 0, sizeof(*key)); 823 } 824 825 // Print public key 826 void rsa_print_public_key(const rsa_public_key *key) { 827 printf("\nPublic Key:\n"); 828 printf(" Modulus (%zu bytes, %zu bits):\n ", 829 key->n_len, key->n_len * 8); 830 for (size_t i = 0; i < (key->n_len < 32 ? 
key->n_len : 32); i++) { 831 printf("%02x", key->n[i]); 832 } 833 if (key->n_len > 32) printf("..."); 834 printf("\n Exponent (%zu bytes): ", key->e_len); 835 for (size_t i = 0; i < key->e_len; i++) { 836 printf("%02x", key->e[i]); 837 } 838 printf("\n"); 839 } 840 841 // Placeholder for export/import (not implemented yet) 842 int rsa_export_public_key(const rsa_public_key *key, 843 uint8_t *buffer, size_t buf_len) { 844 (void)key; 845 (void)buffer; 846 (void)buf_len; 847 fprintf(stderr, "Export function not yet implemented\n"); 848 return -1; 849 } 850 851 int rsa_import_public_key(rsa_public_key *key, 852 const uint8_t *buffer, size_t buf_len) { 853 (void)key; 854 (void)buffer; 855 (void)buf_len; 856 fprintf(stderr, "Import function not yet implemented\n"); 857 return -1; 858 } 859 860 // Test program 861 #ifdef INCLUDE_MAIN 862 int main(void) { 863 printf("╔════════════════════════════════════════════════════╗\n"); 864 printf("║ RSA Production Implementation with GMP ║\n"); 865 printf("╚════════════════════════════════════════════════════╝\n\n"); 866 867 rsa_public_key pub; 868 rsa_private_key priv; 869 870 // Generate 2048-bit key 871 if (rsa_generate_key_simple(&pub, &priv, 2048) != 0) { 872 return 1; 873 } 874 875 rsa_print_public_key(&pub); 876 877 printf("\n════════════════════════════════════════════════════\n"); 878 printf("Test 1: Encryption/Decryption\n"); 879 printf("════════════════════════════════════════════════════\n\n"); 880 881 const char *message = "Hello, RSA with 2048-bit keys!"; 882 size_t msg_len = strlen(message); 883 884 uint8_t ciphertext[512]; 885 size_t ct_len; 886 887 printf("Original message: \"%s\"\n", message); 888 889 if (rsa_encrypt(&pub, (uint8_t*)message, msg_len, 890 ciphertext, &ct_len) == 0) { 891 printf("✓ Encryption successful\n"); 892 printf("Ciphertext (%zu bytes):\n ", ct_len); 893 for (size_t i = 0; i < (ct_len < 32 ? 
ct_len : 32); i++) { 894 printf("%02x", ciphertext[i]); 895 } 896 if (ct_len > 32) printf("..."); 897 printf("\n\n"); 898 899 uint8_t decrypted[512]; 900 size_t dec_len; 901 902 if (rsa_decrypt(&priv, ciphertext, ct_len, 903 decrypted, &dec_len) == 0) { 904 decrypted[dec_len] = '\0'; 905 printf("Decrypted: \"%s\"\n", decrypted); 906 907 if (dec_len == msg_len && memcmp(message, decrypted, msg_len) == 0) { 908 printf("✓ Decryption successful!\n"); 909 } else { 910 printf("✗ Decryption mismatch!\n"); 911 } 912 } else { 913 printf("✗ Decryption failed!\n"); 914 } 915 } else { 916 printf("✗ Encryption failed!\n"); 917 } 918 919 printf("\n════════════════════════════════════════════════════\n"); 920 printf("Test 2: Digital Signature\n"); 921 printf("════════════════════════════════════════════════════\n\n"); 922 923 const char *doc = "Important document to sign"; 924 size_t doc_len = strlen(doc); 925 926 uint8_t signature[512]; 927 size_t sig_len; 928 929 printf("Document: \"%s\"\n", doc); 930 931 if (rsa_sign(&priv, (uint8_t*)doc, doc_len, 932 signature, &sig_len) == 0) { 933 printf("✓ Signature generated\n"); 934 printf("Signature (%zu bytes):\n ", sig_len); 935 for (size_t i = 0; i < (sig_len < 32 ? 
sig_len : 32); i++) { 936 printf("%02x", signature[i]); 937 } 938 if (sig_len > 32) printf("..."); 939 printf("\n\n"); 940 941 if (rsa_verify(&pub, (uint8_t*)doc, doc_len, 942 signature, sig_len) == 0) { 943 printf("✓ Signature verified!\n"); 944 } else { 945 printf("✗ Signature verification failed!\n"); 946 } 947 948 // Test with wrong document 949 const char *wrong_doc = "Tampered document"; 950 printf("\nTesting with tampered document: \"%s\"\n", wrong_doc); 951 if (rsa_verify(&pub, (uint8_t*)wrong_doc, strlen(wrong_doc), 952 signature, sig_len) == 0) { 953 printf("✗ Should have rejected!\n"); 954 } else { 955 printf("✓ Correctly rejected tampered document!\n"); 956 } 957 } else { 958 printf("✗ Signing failed!\n"); 959 } 960 961 printf("\n════════════════════════════════════════════════════\n"); 962 printf("Security Notes:\n"); 963 printf("════════════════════════════════════════════════════\n"); 964 printf("• Using 2048-bit RSA keys (production-ready)\n"); 965 printf("• PKCS#1 v1.5 padding implemented\n"); 966 printf("• Chinese Remainder Theorem for fast decryption\n"); 967 printf("• Always hash large documents before signing\n"); 968 printf("• Consider OAEP for encryption (more secure)\n"); 969 printf("• Consider PSS for signatures (more secure)\n"); 970 printf("• Memory is securely zeroed after use\n"); 971 972 // Clean up 973 rsa_free_public_key(&pub); 974 rsa_free_private_key(&priv); 975 976 return 0; 977 } 978 #endif