luajitos

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs

RSA_production.c (27166B)


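/*
 * RSA Production Implementation with GMP
 * Key sizes: any even modulus size of at least 2048 bits (e.g. 2048, 3072, 4096)
 * Implements PKCS#1 v1.5 encryption/signature padding and RSA-PSS (SHA-256)
 *
 * Compile the demo program (main() is only built with -DINCLUDE_MAIN, and the
 * SHA-256 and CSPRNG implementations must be linked in as well):
 *   gcc -O3 -DINCLUDE_MAIN -o rsa_prod RSA_production.c <SHA-256 source> <CSPRNG source> -lgmp
 */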
      8 
      9 #include "RSA.h"
     10 #include "SHA-256.h"
     11 #include "CSPRNG.h"
     12 #include <stdio.h>
     13 #include <stdlib.h>
     14 #include <string.h>
     15 #include <stdint.h>
     16 #include <time.h>
     17 #include <gmp.h>
     18 
     19 // Random number generation using global CSPRNG
     20 static void get_random_bytes(uint8_t *buf, size_t len) {
     21     // Use the global CSPRNG (will auto-initialize if needed)
     22     random_bytes(buf, len);
     23 }
     24 
     25 // Miller-Rabin primality test with GMP
     26 static int is_prime_gmp(mpz_t n, int iterations) {
     27     return mpz_probab_prime_p(n, iterations) > 0;
     28 }
     29 
     30 // Generate a random prime of specified bit size
     31 static void generate_prime(mpz_t prime, int bits, gmp_randstate_t state) {
     32     mpz_t candidate;
     33     mpz_init(candidate);
     34 
     35     do {
     36         // Generate random number of specified bit size
     37         mpz_urandomb(candidate, state, bits);
     38         // Set high bit to ensure correct bit length
     39         mpz_setbit(candidate, bits - 1);
     40         // Set low bit to make it odd
     41         mpz_setbit(candidate, 0);
     42 
     43         // Find next prime
     44         mpz_nextprime(prime, candidate);
     45     } while (mpz_sizeinbase(prime, 2) != (size_t)bits);
     46 
     47     mpz_clear(candidate);
     48 }
     49 
     50 // Convert mpz_t to byte array
     51 static void mpz_to_bytes(const mpz_t n, uint8_t **out, size_t *out_len) {
     52     size_t count;
     53     *out_len = (mpz_sizeinbase(n, 2) + 7) / 8;
     54     *out = malloc(*out_len);
     55 
     56     mpz_export(*out, &count, 1, 1, 1, 0, n);
     57 
     58     // Handle leading zeros
     59     if (count < *out_len) {
     60         memmove(*out + (*out_len - count), *out, count);
     61         memset(*out, 0, *out_len - count);
     62     }
     63 }
     64 
     65 // Convert byte array to mpz_t
     66 static void bytes_to_mpz(mpz_t n, const uint8_t *bytes, size_t len) {
     67     mpz_import(n, len, 1, 1, 1, 0, bytes);
     68 }
     69 
     70 // RSA key generation with GMP
     71 int rsa_generate_key_simple(rsa_public_key *pub, rsa_private_key *priv, int bits) {
     72     if (bits < 2048) {
     73         fprintf(stderr, "Key size must be at least 2048 bits for security\n");
     74         return -1;
     75     }
     76 
     77     if (bits % 2 != 0) {
     78         fprintf(stderr, "Key size must be even\n");
     79         return -1;
     80     }
     81 
     82     printf("Generating %d-bit RSA key pair...\n", bits);
     83     printf("This may take a minute...\n");
     84 
     85     // Initialize GMP random state
     86     gmp_randstate_t state;
     87     gmp_randinit_default(state);
     88 
     89     // Seed with random data
     90     uint8_t seed[32];
     91     get_random_bytes(seed, 32);
     92     mpz_t seed_mpz;
     93     mpz_init(seed_mpz);
     94     bytes_to_mpz(seed_mpz, seed, 32);
     95     gmp_randseed(state, seed_mpz);
     96     mpz_clear(seed_mpz);
     97 
     98     mpz_t p, q, n, phi, e, d, p_minus_1, q_minus_1;
     99     mpz_t dp, dq, qinv, gcd_val;
    100 
    101     mpz_init(p);
    102     mpz_init(q);
    103     mpz_init(n);
    104     mpz_init(phi);
    105     mpz_init(e);
    106     mpz_init(d);
    107     mpz_init(p_minus_1);
    108     mpz_init(q_minus_1);
    109     mpz_init(dp);
    110     mpz_init(dq);
    111     mpz_init(qinv);
    112     mpz_init(gcd_val);
    113 
    114     // Generate two primes
    115     int prime_bits = bits / 2;
    116     printf("Generating first prime (%d bits)...\n", prime_bits);
    117     generate_prime(p, prime_bits, state);
    118 
    119     printf("Generating second prime (%d bits)...\n", prime_bits);
    120     do {
    121         generate_prime(q, prime_bits, state);
    122     } while (mpz_cmp(p, q) == 0); // Ensure p != q
    123 
    124     // Calculate n = p * q
    125     mpz_mul(n, p, q);
    126 
    127     // Calculate φ(n) = (p-1)(q-1)
    128     mpz_sub_ui(p_minus_1, p, 1);
    129     mpz_sub_ui(q_minus_1, q, 1);
    130     mpz_mul(phi, p_minus_1, q_minus_1);
    131 
    132     // Choose e = 65537 (standard public exponent)
    133     mpz_set_ui(e, 65537);
    134 
    135     // Verify gcd(e, φ(n)) = 1
    136     mpz_gcd(gcd_val, e, phi);
    137     if (mpz_cmp_ui(gcd_val, 1) != 0) {
    138         fprintf(stderr, "gcd(e, phi) != 1, key generation failed\n");
    139         goto cleanup;
    140     }
    141 
    142     // Calculate d = e^-1 mod φ(n)
    143     if (!mpz_invert(d, e, phi)) {
    144         fprintf(stderr, "Failed to compute private exponent\n");
    145         goto cleanup;
    146     }
    147 
    148     // Calculate CRT parameters
    149     // dp = d mod (p-1)
    150     mpz_mod(dp, d, p_minus_1);
    151 
    152     // dq = d mod (q-1)
    153     mpz_mod(dq, d, q_minus_1);
    154 
    155     // qinv = q^-1 mod p
    156     if (!mpz_invert(qinv, q, p)) {
    157         fprintf(stderr, "Failed to compute qinv\n");
    158         goto cleanup;
    159     }
    160 
    161     // Convert to byte arrays
    162     mpz_to_bytes(n, &pub->n, &pub->n_len);
    163     mpz_to_bytes(e, &pub->e, &pub->e_len);
    164 
    165     mpz_to_bytes(n, &priv->n, &priv->n_len);
    166     mpz_to_bytes(e, &priv->e, &priv->e_len);
    167     mpz_to_bytes(d, &priv->d, &priv->d_len);
    168     mpz_to_bytes(p, &priv->p, &priv->p_len);
    169     mpz_to_bytes(q, &priv->q, &priv->q_len);
    170     mpz_to_bytes(dp, &priv->dp, &priv->dp_len);
    171     mpz_to_bytes(dq, &priv->dq, &priv->dq_len);
    172     mpz_to_bytes(qinv, &priv->qinv, &priv->qinv_len);
    173 
    174     printf("Key generation complete!\n");
    175     printf("Modulus size: %zu bits\n", mpz_sizeinbase(n, 2));
    176 
    177 cleanup:
    178     mpz_clear(p);
    179     mpz_clear(q);
    180     mpz_clear(n);
    181     mpz_clear(phi);
    182     mpz_clear(e);
    183     mpz_clear(d);
    184     mpz_clear(p_minus_1);
    185     mpz_clear(q_minus_1);
    186     mpz_clear(dp);
    187     mpz_clear(dq);
    188     mpz_clear(qinv);
    189     mpz_clear(gcd_val);
    190     gmp_randclear(state);
    191 
    192     return 0;
    193 }
    194 
    195 // PKCS#1 v1.5 padding for encryption
    196 static int pkcs1_v15_encode(const uint8_t *message, size_t msg_len,
    197                             uint8_t *padded, size_t padded_len) {
    198     if (msg_len > padded_len - 11) {
    199         fprintf(stderr, "Message too long for key size\n");
    200         return -1;
    201     }
    202 
    203     // EM = 0x00 || 0x02 || PS || 0x00 || M
    204     padded[0] = 0x00;
    205     padded[1] = 0x02;
    206 
    207     // PS: padding string (at least 8 random non-zero bytes)
    208     size_t ps_len = padded_len - msg_len - 3;
    209     get_random_bytes(padded + 2, ps_len);
    210 
    211     // Ensure no zero bytes in padding
    212     for (size_t i = 0; i < ps_len; i++) {
    213         while (padded[2 + i] == 0) {
    214             get_random_bytes(&padded[2 + i], 1);
    215         }
    216     }
    217 
    218     padded[2 + ps_len] = 0x00;
    219     memcpy(padded + 2 + ps_len + 1, message, msg_len);
    220 
    221     return 0;
    222 }
    223 
    224 // PKCS#1 v1.5 padding removal for decryption
    225 static int pkcs1_v15_decode(const uint8_t *padded, size_t padded_len,
    226                             uint8_t *message, size_t *msg_len) {
    227     if (padded_len < 11) return -1;
    228     if (padded[0] != 0x00) return -1;
    229     if (padded[1] != 0x02) return -1;
    230 
    231     // Find the 0x00 separator
    232     size_t i;
    233     for (i = 2; i < padded_len; i++) {
    234         if (padded[i] == 0x00) break;
    235     }
    236 
    237     if (i < 10 || i >= padded_len) return -1; // Invalid padding
    238 
    239     // Copy message
    240     *msg_len = padded_len - i - 1;
    241     memcpy(message, padded + i + 1, *msg_len);
    242 
    243     return 0;
    244 }
    245 
    246 // RSA encryption
    247 int rsa_encrypt(const rsa_public_key *key,
    248                 const uint8_t *plaintext, size_t pt_len,
    249                 uint8_t *ciphertext, size_t *ct_len) {
    250     if (pt_len > key->n_len - 11) {
    251         fprintf(stderr, "Plaintext too long for key size\n");
    252         return -1;
    253     }
    254 
    255     // Apply PKCS#1 v1.5 padding
    256     uint8_t *padded = malloc(key->n_len);
    257     if (pkcs1_v15_encode(plaintext, pt_len, padded, key->n_len) != 0) {
    258         free(padded);
    259         return -1;
    260     }
    261 
    262     // Convert to GMP integers
    263     mpz_t m, c, n, e;
    264     mpz_init(m);
    265     mpz_init(c);
    266     mpz_init(n);
    267     mpz_init(e);
    268 
    269     bytes_to_mpz(m, padded, key->n_len);
    270     bytes_to_mpz(n, key->n, key->n_len);
    271     bytes_to_mpz(e, key->e, key->e_len);
    272 
    273     // c = m^e mod n
    274     mpz_powm(c, m, e, n);
    275 
    276     // Convert back to bytes
    277     *ct_len = key->n_len;
    278     size_t count;
    279     mpz_export(ciphertext, &count, 1, 1, 1, 0, c);
    280 
    281     // Pad with leading zeros if necessary
    282     if (count < *ct_len) {
    283         memmove(ciphertext + (*ct_len - count), ciphertext, count);
    284         memset(ciphertext, 0, *ct_len - count);
    285     }
    286 
    287     mpz_clear(m);
    288     mpz_clear(c);
    289     mpz_clear(n);
    290     mpz_clear(e);
    291     free(padded);
    292 
    293     return 0;
    294 }
    295 
    296 // RSA decryption with CRT
    297 int rsa_decrypt(const rsa_private_key *key,
    298                 const uint8_t *ciphertext, size_t ct_len,
    299                 uint8_t *plaintext, size_t *pt_len) {
    300     if (ct_len != key->n_len) {
    301         fprintf(stderr, "Invalid ciphertext length\n");
    302         return -1;
    303     }
    304 
    305     mpz_t c, m, p, q, dp, dq, qinv, m1, m2, h;
    306     mpz_init(c);
    307     mpz_init(m);
    308     mpz_init(p);
    309     mpz_init(q);
    310     mpz_init(dp);
    311     mpz_init(dq);
    312     mpz_init(qinv);
    313     mpz_init(m1);
    314     mpz_init(m2);
    315     mpz_init(h);
    316 
    317     bytes_to_mpz(c, ciphertext, ct_len);
    318     bytes_to_mpz(p, key->p, key->p_len);
    319     bytes_to_mpz(q, key->q, key->q_len);
    320     bytes_to_mpz(dp, key->dp, key->dp_len);
    321     bytes_to_mpz(dq, key->dq, key->dq_len);
    322     bytes_to_mpz(qinv, key->qinv, key->qinv_len);
    323 
    324     // CRT decryption for speed
    325     // m1 = c^dp mod p
    326     mpz_powm(m1, c, dp, p);
    327 
    328     // m2 = c^dq mod q
    329     mpz_powm(m2, c, dq, q);
    330 
    331     // h = qinv * (m1 - m2) mod p
    332     mpz_sub(h, m1, m2);
    333     mpz_mul(h, qinv, h);
    334     mpz_mod(h, h, p);
    335 
    336     // m = m2 + h * q
    337     mpz_mul(h, h, q);
    338     mpz_add(m, m2, h);
    339 
    340     // Convert to bytes
    341     uint8_t *padded = malloc(key->n_len);
    342     size_t count;
    343     mpz_export(padded, &count, 1, 1, 1, 0, m);
    344 
    345     // Pad with leading zeros if necessary
    346     if (count < key->n_len) {
    347         memmove(padded + (key->n_len - count), padded, count);
    348         memset(padded, 0, key->n_len - count);
    349     }
    350 
    351     // Remove PKCS#1 v1.5 padding
    352     int result = pkcs1_v15_decode(padded, key->n_len, plaintext, pt_len);
    353 
    354     mpz_clear(c);
    355     mpz_clear(m);
    356     mpz_clear(p);
    357     mpz_clear(q);
    358     mpz_clear(dp);
    359     mpz_clear(dq);
    360     mpz_clear(qinv);
    361     mpz_clear(m1);
    362     mpz_clear(m2);
    363     mpz_clear(h);
    364     free(padded);
    365 
    366     return result;
    367 }
    368 
    369 // RSA Sign (same as decrypt operation)
    370 int rsa_sign(const rsa_private_key *key,
    371              const uint8_t *message, size_t msg_len,
    372              uint8_t *signature, size_t *sig_len) {
    373     if (msg_len > key->n_len - 11) {
    374         fprintf(stderr, "Message too long. Hash it first with SHA-256.\n");
    375         return -1;
    376     }
    377 
    378     // Apply PKCS#1 v1.5 padding
    379     uint8_t *padded = malloc(key->n_len);
    380     padded[0] = 0x00;
    381     padded[1] = 0x01; // 0x01 for signatures
    382 
    383     size_t ps_len = key->n_len - msg_len - 3;
    384     memset(padded + 2, 0xFF, ps_len); // 0xFF padding for signatures
    385 
    386     padded[2 + ps_len] = 0x00;
    387     memcpy(padded + 2 + ps_len + 1, message, msg_len);
    388 
    389     // Sign using CRT (same as decrypt)
    390     mpz_t m, s, p, q, dp, dq, qinv, s1, s2, h;
    391     mpz_init(m);
    392     mpz_init(s);
    393     mpz_init(p);
    394     mpz_init(q);
    395     mpz_init(dp);
    396     mpz_init(dq);
    397     mpz_init(qinv);
    398     mpz_init(s1);
    399     mpz_init(s2);
    400     mpz_init(h);
    401 
    402     bytes_to_mpz(m, padded, key->n_len);
    403     bytes_to_mpz(p, key->p, key->p_len);
    404     bytes_to_mpz(q, key->q, key->q_len);
    405     bytes_to_mpz(dp, key->dp, key->dp_len);
    406     bytes_to_mpz(dq, key->dq, key->dq_len);
    407     bytes_to_mpz(qinv, key->qinv, key->qinv_len);
    408 
    409     // CRT signing
    410     mpz_powm(s1, m, dp, p);
    411     mpz_powm(s2, m, dq, q);
    412     mpz_sub(h, s1, s2);
    413     mpz_mul(h, qinv, h);
    414     mpz_mod(h, h, p);
    415     mpz_mul(h, h, q);
    416     mpz_add(s, s2, h);
    417 
    418     // Convert to bytes
    419     *sig_len = key->n_len;
    420     size_t count;
    421     mpz_export(signature, &count, 1, 1, 1, 0, s);
    422 
    423     if (count < *sig_len) {
    424         memmove(signature + (*sig_len - count), signature, count);
    425         memset(signature, 0, *sig_len - count);
    426     }
    427 
    428     mpz_clear(m);
    429     mpz_clear(s);
    430     mpz_clear(p);
    431     mpz_clear(q);
    432     mpz_clear(dp);
    433     mpz_clear(dq);
    434     mpz_clear(qinv);
    435     mpz_clear(s1);
    436     mpz_clear(s2);
    437     mpz_clear(h);
    438     free(padded);
    439 
    440     return 0;
    441 }
    442 
    443 // RSA Verify signature
    444 int rsa_verify(const rsa_public_key *key,
    445                const uint8_t *message, size_t msg_len,
    446                const uint8_t *signature, size_t sig_len) {
    447     if (sig_len != key->n_len) {
    448         return -1;
    449     }
    450 
    451     mpz_t s, m, n, e;
    452     mpz_init(s);
    453     mpz_init(m);
    454     mpz_init(n);
    455     mpz_init(e);
    456 
    457     bytes_to_mpz(s, signature, sig_len);
    458     bytes_to_mpz(n, key->n, key->n_len);
    459     bytes_to_mpz(e, key->e, key->e_len);
    460 
    461     // m = s^e mod n
    462     mpz_powm(m, s, e, n);
    463 
    464     // Convert to bytes
    465     uint8_t *decrypted = malloc(key->n_len);
    466     size_t count;
    467     mpz_export(decrypted, &count, 1, 1, 1, 0, m);
    468 
    469     if (count < key->n_len) {
    470         memmove(decrypted + (key->n_len - count), decrypted, count);
    471         memset(decrypted, 0, key->n_len - count);
    472     }
    473 
    474     // Verify padding and message
    475     int result = -1;
    476     if (decrypted[0] == 0x00 && decrypted[1] == 0x01) {
    477         size_t i;
    478         for (i = 2; i < key->n_len; i++) {
    479             if (decrypted[i] == 0x00) break;
    480             if (decrypted[i] != 0xFF) goto cleanup;
    481         }
    482 
    483         if (i >= 10 && i < key->n_len) {
    484             size_t recovered_len = key->n_len - i - 1;
    485             if (recovered_len == msg_len &&
    486                 memcmp(decrypted + i + 1, message, msg_len) == 0) {
    487                 result = 0;
    488             }
    489         }
    490     }
    491 
    492 cleanup:
    493     mpz_clear(s);
    494     mpz_clear(m);
    495     mpz_clear(n);
    496     mpz_clear(e);
    497     free(decrypted);
    498 
    499     return result;
    500 }
    501 
    502 /* ============================================================================
    503  * RSA-PSS (Probabilistic Signature Scheme) Implementation
    504  * RFC 8017 - PKCS #1: RSA Cryptography Specifications Version 2.2
    505  * ========================================================================= */
    506 
    507 /**
    508  * MGF1 - Mask Generation Function based on SHA-256
    509  * Used in PSS padding
    510  */
    511 static void mgf1_sha256(const uint8_t *seed, size_t seed_len,
    512                         uint8_t *mask, size_t mask_len) {
    513     uint8_t counter[4];
    514     size_t offset = 0;
    515     uint32_t count = 0;
    516 
    517     while (offset < mask_len) {
    518         // Convert counter to big-endian bytes
    519         counter[0] = (count >> 24) & 0xFF;
    520         counter[1] = (count >> 16) & 0xFF;
    521         counter[2] = (count >> 8) & 0xFF;
    522         counter[3] = count & 0xFF;
    523 
    524         // Hash seed || counter
    525         uint8_t hash_input[256]; // Max seed length
    526         memcpy(hash_input, seed, seed_len);
    527         memcpy(hash_input + seed_len, counter, 4);
    528 
    529         uint8_t hash[32];
    530         sha256(hash_input, seed_len + 4, hash);
    531 
    532         // Copy to mask
    533         size_t to_copy = (mask_len - offset < 32) ? (mask_len - offset) : 32;
    534         memcpy(mask + offset, hash, to_copy);
    535 
    536         offset += to_copy;
    537         count++;
    538     }
    539 }
    540 
    541 /**
    542  * RSA-PSS Sign
    543  */
    544 int rsa_sign_pss(const rsa_private_key *key,
    545                  const uint8_t *message, size_t msg_len,
    546                  uint8_t *signature, size_t *sig_len) {
    547     // PSS parameters (using SHA-256)
    548     const size_t h_len = 32;  // SHA-256 output length
    549     const size_t s_len = 32;  // Salt length (recommended: same as hash length)
    550 
    551     if (msg_len != h_len) {
    552         fprintf(stderr, "RSA-PSS: Message must be a SHA-256 hash (32 bytes)\n");
    553         return -1;
    554     }
    555 
    556     size_t em_len = key->n_len;  // Encoded message length
    557     if (em_len < h_len + s_len + 2) {
    558         fprintf(stderr, "RSA-PSS: Key too short for PSS padding\n");
    559         return -1;
    560     }
    561 
    562     // Generate salt
    563     uint8_t salt[32];
    564     random_bytes(salt, s_len);
    565 
    566     // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt
    567     uint8_t m_prime[8 + 32 + 32];
    568     memset(m_prime, 0, 8);
    569     memcpy(m_prime + 8, message, h_len);
    570     memcpy(m_prime + 8 + h_len, salt, s_len);
    571 
    572     // H = Hash(M')
    573     uint8_t h[32];
    574     sha256(m_prime, 8 + h_len + s_len, h);
    575 
    576     // Generate DB = PS || 0x01 || salt
    577     size_t ps_len = em_len - s_len - h_len - 2;
    578     uint8_t *db = calloc(em_len - h_len - 1, 1);  // Zero-initialized
    579     db[ps_len] = 0x01;
    580     memcpy(db + ps_len + 1, salt, s_len);
    581 
    582     // dbMask = MGF(H, emLen - hLen - 1)
    583     size_t db_len = em_len - h_len - 1;
    584     uint8_t *db_mask = malloc(db_len);
    585     mgf1_sha256(h, h_len, db_mask, db_len);
    586 
    587     // maskedDB = DB xor dbMask
    588     for (size_t i = 0; i < db_len; i++) {
    589         db[i] ^= db_mask[i];
    590     }
    591 
    592     // Set leftmost bits of maskedDB to zero (emBits - 1 % 8)
    593     size_t em_bits = key->n_len * 8 - 1;  // Adjust for modulus bit length
    594     size_t mask_bits = 8 * em_len - em_bits;
    595     if (mask_bits > 0) {
    596         db[0] &= (0xFF >> mask_bits);
    597     }
    598 
    599     // EM = maskedDB || H || 0xbc
    600     uint8_t *em = malloc(em_len);
    601     memcpy(em, db, db_len);
    602     memcpy(em + db_len, h, h_len);
    603     em[em_len - 1] = 0xbc;
    604 
    605     // Sign using CRT (same as PKCS#1 v1.5)
    606     mpz_t m, s, p, q, dp, dq, qinv, s1, s2, h_crt;
    607     mpz_init(m);
    608     mpz_init(s);
    609     mpz_init(p);
    610     mpz_init(q);
    611     mpz_init(dp);
    612     mpz_init(dq);
    613     mpz_init(qinv);
    614     mpz_init(s1);
    615     mpz_init(s2);
    616     mpz_init(h_crt);
    617 
    618     bytes_to_mpz(m, em, em_len);
    619     bytes_to_mpz(p, key->p, key->p_len);
    620     bytes_to_mpz(q, key->q, key->q_len);
    621     bytes_to_mpz(dp, key->dp, key->dp_len);
    622     bytes_to_mpz(dq, key->dq, key->dq_len);
    623     bytes_to_mpz(qinv, key->qinv, key->qinv_len);
    624 
    625     // CRT signing
    626     mpz_powm(s1, m, dp, p);
    627     mpz_powm(s2, m, dq, q);
    628     mpz_sub(h_crt, s1, s2);
    629     mpz_mul(h_crt, qinv, h_crt);
    630     mpz_mod(h_crt, h_crt, p);
    631     mpz_mul(h_crt, h_crt, q);
    632     mpz_add(s, s2, h_crt);
    633 
    634     // Convert to bytes
    635     *sig_len = key->n_len;
    636     size_t count;
    637     mpz_export(signature, &count, 1, 1, 1, 0, s);
    638     if (count < key->n_len) {
    639         memmove(signature + (key->n_len - count), signature, count);
    640         memset(signature, 0, key->n_len - count);
    641     }
    642 
    643     // Clean up
    644     mpz_clear(m);
    645     mpz_clear(s);
    646     mpz_clear(p);
    647     mpz_clear(q);
    648     mpz_clear(dp);
    649     mpz_clear(dq);
    650     mpz_clear(qinv);
    651     mpz_clear(s1);
    652     mpz_clear(s2);
    653     mpz_clear(h_crt);
    654     free(db);
    655     free(db_mask);
    656     free(em);
    657 
    658     return 0;
    659 }
    660 
    661 /**
    662  * RSA-PSS Verify
    663  */
    664 int rsa_verify_pss(const rsa_public_key *key,
    665                    const uint8_t *message, size_t msg_len,
    666                    const uint8_t *signature, size_t sig_len) {
    667     const size_t h_len = 32;  // SHA-256 output length
    668     const size_t s_len = 32;  // Salt length
    669 
    670     if (msg_len != h_len) {
    671         fprintf(stderr, "RSA-PSS: Message must be a SHA-256 hash (32 bytes)\n");
    672         return -1;
    673     }
    674 
    675     if (sig_len != key->n_len) {
    676         return -1;
    677     }
    678 
    679     // Verify signature: m = s^e mod n
    680     mpz_t s, m, n, e;
    681     mpz_init(s);
    682     mpz_init(m);
    683     mpz_init(n);
    684     mpz_init(e);
    685 
    686     bytes_to_mpz(s, signature, sig_len);
    687     bytes_to_mpz(n, key->n, key->n_len);
    688     bytes_to_mpz(e, key->e, key->e_len);
    689 
    690     mpz_powm(m, s, e, n);
    691 
    692     // Convert to bytes
    693     size_t em_len = key->n_len;
    694     uint8_t *em = malloc(em_len);
    695     size_t count;
    696     mpz_export(em, &count, 1, 1, 1, 0, m);
    697     if (count < em_len) {
    698         memmove(em + (em_len - count), em, count);
    699         memset(em, 0, em_len - count);
    700     }
    701 
    702     mpz_clear(s);
    703     mpz_clear(m);
    704     mpz_clear(n);
    705     mpz_clear(e);
    706 
    707     // Verify EM format
    708     if (em[em_len - 1] != 0xbc) {
    709         free(em);
    710         return -1;
    711     }
    712 
    713     // Split EM = maskedDB || H || 0xbc
    714     size_t db_len = em_len - h_len - 1;
    715     uint8_t *masked_db = em;
    716     uint8_t *h = em + db_len;
    717 
    718     // Generate dbMask
    719     uint8_t *db_mask = malloc(db_len);
    720     mgf1_sha256(h, h_len, db_mask, db_len);
    721 
    722     // DB = maskedDB xor dbMask
    723     uint8_t *db = malloc(db_len);
    724     for (size_t i = 0; i < db_len; i++) {
    725         db[i] = masked_db[i] ^ db_mask[i];
    726     }
    727 
    728     // Set leftmost bits to zero
    729     size_t em_bits = key->n_len * 8 - 1;
    730     size_t mask_bits = 8 * em_len - em_bits;
    731     if (mask_bits > 0) {
    732         db[0] &= (0xFF >> mask_bits);
    733     }
    734 
    735     // Verify DB = PS || 0x01 || salt
    736     size_t ps_len = em_len - s_len - h_len - 2;
    737     for (size_t i = 0; i < ps_len; i++) {
    738         if (db[i] != 0x00) {
    739             free(em);
    740             free(db_mask);
    741             free(db);
    742             return -1;
    743         }
    744     }
    745 
    746     if (db[ps_len] != 0x01) {
    747         free(em);
    748         free(db_mask);
    749         free(db);
    750         return -1;
    751     }
    752 
    753     // Extract salt
    754     uint8_t *salt = db + ps_len + 1;
    755 
    756     // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt
    757     uint8_t m_prime[8 + 32 + 32];
    758     memset(m_prime, 0, 8);
    759     memcpy(m_prime + 8, message, h_len);
    760     memcpy(m_prime + 8 + h_len, salt, s_len);
    761 
    762     // H' = Hash(M')
    763     uint8_t h_prime[32];
    764     sha256(m_prime, 8 + h_len + s_len, h_prime);
    765 
    766     // Verify H == H'
    767     int result = (memcmp(h, h_prime, h_len) == 0) ? 0 : -1;
    768 
    769     free(em);
    770     free(db_mask);
    771     free(db);
    772 
    773     return result;
    774 }
    775 
    776 // Free keys (already implemented in RSA.c, included here for completeness)
    777 void rsa_free_public_key(rsa_public_key *key) {
    778     if (key->n) {
    779         memset(key->n, 0, key->n_len);
    780         free(key->n);
    781     }
    782     if (key->e) {
    783         memset(key->e, 0, key->e_len);
    784         free(key->e);
    785     }
    786     memset(key, 0, sizeof(*key));
    787 }
    788 
    789 void rsa_free_private_key(rsa_private_key *key) {
    790     if (key->n) {
    791         memset(key->n, 0, key->n_len);
    792         free(key->n);
    793     }
    794     if (key->e) {
    795         memset(key->e, 0, key->e_len);
    796         free(key->e);
    797     }
    798     if (key->d) {
    799         memset(key->d, 0, key->d_len);
    800         free(key->d);
    801     }
    802     if (key->p) {
    803         memset(key->p, 0, key->p_len);
    804         free(key->p);
    805     }
    806     if (key->q) {
    807         memset(key->q, 0, key->q_len);
    808         free(key->q);
    809     }
    810     if (key->dp) {
    811         memset(key->dp, 0, key->dp_len);
    812         free(key->dp);
    813     }
    814     if (key->dq) {
    815         memset(key->dq, 0, key->dq_len);
    816         free(key->dq);
    817     }
    818     if (key->qinv) {
    819         memset(key->qinv, 0, key->qinv_len);
    820         free(key->qinv);
    821     }
    822     memset(key, 0, sizeof(*key));
    823 }
    824 
    825 // Print public key
    826 void rsa_print_public_key(const rsa_public_key *key) {
    827     printf("\nPublic Key:\n");
    828     printf("  Modulus (%zu bytes, %zu bits):\n    ",
    829            key->n_len, key->n_len * 8);
    830     for (size_t i = 0; i < (key->n_len < 32 ? key->n_len : 32); i++) {
    831         printf("%02x", key->n[i]);
    832     }
    833     if (key->n_len > 32) printf("...");
    834     printf("\n  Exponent (%zu bytes): ", key->e_len);
    835     for (size_t i = 0; i < key->e_len; i++) {
    836         printf("%02x", key->e[i]);
    837     }
    838     printf("\n");
    839 }
    840 
    841 // Placeholder for export/import (not implemented yet)
    842 int rsa_export_public_key(const rsa_public_key *key,
    843                           uint8_t *buffer, size_t buf_len) {
    844     (void)key;
    845     (void)buffer;
    846     (void)buf_len;
    847     fprintf(stderr, "Export function not yet implemented\n");
    848     return -1;
    849 }
    850 
    851 int rsa_import_public_key(rsa_public_key *key,
    852                           const uint8_t *buffer, size_t buf_len) {
    853     (void)key;
    854     (void)buffer;
    855     (void)buf_len;
    856     fprintf(stderr, "Import function not yet implemented\n");
    857     return -1;
    858 }
    859 
    860 // Test program
    861 #ifdef INCLUDE_MAIN
    862 int main(void) {
    863     printf("╔════════════════════════════════════════════════════╗\n");
    864     printf("║   RSA Production Implementation with GMP          ║\n");
    865     printf("╚════════════════════════════════════════════════════╝\n\n");
    866 
    867     rsa_public_key pub;
    868     rsa_private_key priv;
    869 
    870     // Generate 2048-bit key
    871     if (rsa_generate_key_simple(&pub, &priv, 2048) != 0) {
    872         return 1;
    873     }
    874 
    875     rsa_print_public_key(&pub);
    876 
    877     printf("\n════════════════════════════════════════════════════\n");
    878     printf("Test 1: Encryption/Decryption\n");
    879     printf("════════════════════════════════════════════════════\n\n");
    880 
    881     const char *message = "Hello, RSA with 2048-bit keys!";
    882     size_t msg_len = strlen(message);
    883 
    884     uint8_t ciphertext[512];
    885     size_t ct_len;
    886 
    887     printf("Original message: \"%s\"\n", message);
    888 
    889     if (rsa_encrypt(&pub, (uint8_t*)message, msg_len,
    890                     ciphertext, &ct_len) == 0) {
    891         printf("✓ Encryption successful\n");
    892         printf("Ciphertext (%zu bytes):\n  ", ct_len);
    893         for (size_t i = 0; i < (ct_len < 32 ? ct_len : 32); i++) {
    894             printf("%02x", ciphertext[i]);
    895         }
    896         if (ct_len > 32) printf("...");
    897         printf("\n\n");
    898 
    899         uint8_t decrypted[512];
    900         size_t dec_len;
    901 
    902         if (rsa_decrypt(&priv, ciphertext, ct_len,
    903                         decrypted, &dec_len) == 0) {
    904             decrypted[dec_len] = '\0';
    905             printf("Decrypted: \"%s\"\n", decrypted);
    906 
    907             if (dec_len == msg_len && memcmp(message, decrypted, msg_len) == 0) {
    908                 printf("✓ Decryption successful!\n");
    909             } else {
    910                 printf("✗ Decryption mismatch!\n");
    911             }
    912         } else {
    913             printf("✗ Decryption failed!\n");
    914         }
    915     } else {
    916         printf("✗ Encryption failed!\n");
    917     }
    918 
    919     printf("\n════════════════════════════════════════════════════\n");
    920     printf("Test 2: Digital Signature\n");
    921     printf("════════════════════════════════════════════════════\n\n");
    922 
    923     const char *doc = "Important document to sign";
    924     size_t doc_len = strlen(doc);
    925 
    926     uint8_t signature[512];
    927     size_t sig_len;
    928 
    929     printf("Document: \"%s\"\n", doc);
    930 
    931     if (rsa_sign(&priv, (uint8_t*)doc, doc_len,
    932                  signature, &sig_len) == 0) {
    933         printf("✓ Signature generated\n");
    934         printf("Signature (%zu bytes):\n  ", sig_len);
    935         for (size_t i = 0; i < (sig_len < 32 ? sig_len : 32); i++) {
    936             printf("%02x", signature[i]);
    937         }
    938         if (sig_len > 32) printf("...");
    939         printf("\n\n");
    940 
    941         if (rsa_verify(&pub, (uint8_t*)doc, doc_len,
    942                        signature, sig_len) == 0) {
    943             printf("✓ Signature verified!\n");
    944         } else {
    945             printf("✗ Signature verification failed!\n");
    946         }
    947 
    948         // Test with wrong document
    949         const char *wrong_doc = "Tampered document";
    950         printf("\nTesting with tampered document: \"%s\"\n", wrong_doc);
    951         if (rsa_verify(&pub, (uint8_t*)wrong_doc, strlen(wrong_doc),
    952                        signature, sig_len) == 0) {
    953             printf("✗ Should have rejected!\n");
    954         } else {
    955             printf("✓ Correctly rejected tampered document!\n");
    956         }
    957     } else {
    958         printf("✗ Signing failed!\n");
    959     }
    960 
    961     printf("\n════════════════════════════════════════════════════\n");
    962     printf("Security Notes:\n");
    963     printf("════════════════════════════════════════════════════\n");
    964     printf("• Using 2048-bit RSA keys (production-ready)\n");
    965     printf("• PKCS#1 v1.5 padding implemented\n");
    966     printf("• Chinese Remainder Theorem for fast decryption\n");
    967     printf("• Always hash large documents before signing\n");
    968     printf("• Consider OAEP for encryption (more secure)\n");
    969     printf("• Consider PSS for signatures (more secure)\n");
    970     printf("• Memory is securely zeroed after use\n");
    971 
    972     // Clean up
    973     rsa_free_public_key(&pub);
    974     rsa_free_private_key(&priv);
    975 
    976     return 0;
    977 }
    978 #endif