luajitos

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs

RSA.c (20837B)


      1 /*
      2  * RSA Implementation
      3  * Supports key generation, encryption/decryption, signing/verification
 * Implements PKCS#1 v1.5 padding (encryption and signatures) and RSA-PSS
 * signatures (RFC 8017); OAEP is not implemented.
      5  *
      6  * Note: This is an educational implementation. For production use,
      7  * consider using established libraries like OpenSSL or mbedTLS.
      8  */
      9 
     10 #include "RSA.h"
     11 #include "CSPRNG.h"
     12 #include <stdio.h>
     13 #include <stdlib.h>
     14 #include <string.h>
     15 #include <stdint.h>
     16 #include <time.h>
     17 
     18 // Large integer support using GMP (GNU Multiple Precision)
     19 // If GMP is not available, this provides a basic implementation
     20 #ifdef USE_GMP
     21 #include <gmp.h>
     22 #else
     23 // Simple big integer implementation for demonstration
     24 typedef struct {
     25     uint32_t *data;
     26     size_t size;
     27     size_t capacity;
     28 } bigint;
     29 
     30 static void bigint_init(bigint *n) {
     31     n->capacity = 64;
     32     n->size = 0;
     33     n->data = calloc(n->capacity, sizeof(uint32_t));
     34 }
     35 
     36 static void bigint_free(bigint *n) {
     37     if (n->data) free(n->data);
     38     n->data = NULL;
     39     n->size = 0;
     40 }
     41 
     42 static void bigint_from_uint(bigint *n, uint32_t val) {
     43     n->size = 1;
     44     n->data[0] = val;
     45 }
     46 #endif
     47 
     48 // Type definitions are in RSA.h
     49 
     50 // Random number generation using global CSPRNG
     51 static void get_random_bytes(uint8_t *buf, size_t len) {
     52     random_bytes(buf, len);
     53 }
     54 
     55 // Miller-Rabin primality test (simplified)
     56 static int is_prime_simple(uint32_t n) {
     57     if (n < 2) return 0;
     58     if (n == 2 || n == 3) return 1;
     59     if (n % 2 == 0) return 0;
     60 
     61     for (uint32_t i = 3; i * i <= n; i += 2) {
     62         if (n % i == 0) return 0;
     63     }
     64     return 1;
     65 }
     66 
     67 // Generate a random prime (simplified for demonstration)
     68 static uint32_t generate_small_prime(void) {
     69     uint32_t candidate;
     70     do {
     71         uint8_t bytes[2];
     72         get_random_bytes(bytes, 2);
     73         candidate = (bytes[0] << 8) | bytes[1];
     74         candidate |= 0x8000; // Ensure high bit set
     75     } while (!is_prime_simple(candidate));
     76     return candidate;
     77 }
     78 
     79 // Extended Euclidean algorithm
     80 static int64_t extended_gcd(int64_t a, int64_t b, int64_t *x, int64_t *y) {
     81     if (b == 0) {
     82         *x = 1;
     83         *y = 0;
     84         return a;
     85     }
     86 
     87     int64_t x1, y1;
     88     int64_t gcd = extended_gcd(b, a % b, &x1, &y1);
     89     *x = y1;
     90     *y = x1 - (a / b) * y1;
     91     return gcd;
     92 }
     93 
     94 // Modular inverse
     95 static uint64_t mod_inverse(uint64_t a, uint64_t m) {
     96     int64_t x, y;
     97     int64_t gcd = extended_gcd(a, m, &x, &y);
     98     if (gcd != 1) return 0; // Inverse doesn't exist
     99     return (x % m + m) % m;
    100 }
    101 
    102 // Modular multiplication without overflow (using addition)
    103 static uint64_t mod_mul(uint64_t a, uint64_t b, uint64_t mod) {
    104     uint64_t result = 0;
    105     a = a % mod;
    106     while (b > 0) {
    107         if (b & 1) {
    108             result = (result + a) % mod;
    109         }
    110         a = (a * 2) % mod;
    111         b = b >> 1;
    112     }
    113     return result;
    114 }
    115 
    116 // Modular exponentiation
    117 static uint64_t mod_exp(uint64_t base, uint64_t exp, uint64_t mod) {
    118     uint64_t result = 1;
    119     base = base % mod;
    120 
    121     while (exp > 0) {
    122         if (exp % 2 == 1) {
    123             result = mod_mul(result, base, mod);
    124         }
    125         exp = exp >> 1;
    126         base = mod_mul(base, base, mod);
    127     }
    128 
    129     return result;
    130 }
    131 
    132 // Simple RSA key generation (educational - uses small keys)
    133 int rsa_generate_key_simple(rsa_public_key *pub, rsa_private_key *priv, int bits) {
    134     if (bits > 32) {
    135         fprintf(stderr, "Simple implementation limited to 32 bits\n");
    136         fprintf(stderr, "For production use, compile with -DUSE_GMP\n");
    137         return -1;
    138     }
    139 
    140     // Generate two prime numbers
    141     uint32_t p = generate_small_prime();
    142     uint32_t q = generate_small_prime();
    143 
    144     while (p == q) {
    145         q = generate_small_prime();
    146     }
    147 
    148     printf("Generated primes: p=%u, q=%u\n", p, q);
    149 
    150     // Calculate n = p * q
    151     uint64_t n = (uint64_t)p * q;
    152 
    153     // Calculate φ(n) = (p-1)(q-1)
    154     uint64_t phi = (uint64_t)(p - 1) * (q - 1);
    155 
    156     // Choose e (commonly 65537, but we'll use 65537 or smaller if needed)
    157     uint64_t e = 65537;
    158     if (e >= phi) {
    159         e = 3;
    160     }
    161 
    162     // Ensure gcd(e, phi) = 1
    163     while (e < phi) {
    164         int64_t x, y;
    165         if (extended_gcd(e, phi, &x, &y) == 1) {
    166             break;
    167         }
    168         e += 2;
    169     }
    170 
    171     // Calculate d = e^-1 mod φ(n)
    172     uint64_t d = mod_inverse(e, phi);
    173 
    174     if (d == 0) {
    175         fprintf(stderr, "Failed to calculate private exponent\n");
    176         return -1;
    177     }
    178 
    179     printf("Key parameters: n=%lu, e=%lu, d=%lu\n", n, e, d);
    180 
    181     // Allocate and fill public key
    182     pub->n_len = 8;
    183     pub->n = malloc(pub->n_len);
    184     for (int i = 0; i < 8; i++) {
    185         pub->n[i] = (n >> (56 - i * 8)) & 0xFF;
    186     }
    187 
    188     pub->e_len = 8;
    189     pub->e = malloc(pub->e_len);
    190     for (int i = 0; i < 8; i++) {
    191         pub->e[i] = (e >> (56 - i * 8)) & 0xFF;
    192     }
    193 
    194     // Allocate and fill private key
    195     priv->n_len = pub->n_len;
    196     priv->n = malloc(priv->n_len);
    197     memcpy(priv->n, pub->n, priv->n_len);
    198 
    199     priv->e_len = pub->e_len;
    200     priv->e = malloc(priv->e_len);
    201     memcpy(priv->e, pub->e, priv->e_len);
    202 
    203     priv->d_len = 8;
    204     priv->d = malloc(priv->d_len);
    205     for (int i = 0; i < 8; i++) {
    206         priv->d[i] = (d >> (56 - i * 8)) & 0xFF;
    207     }
    208 
    209     // Store p and q
    210     priv->p_len = 4;
    211     priv->p = malloc(priv->p_len);
    212     for (int i = 0; i < 4; i++) {
    213         priv->p[i] = (p >> (24 - i * 8)) & 0xFF;
    214     }
    215 
    216     priv->q_len = 4;
    217     priv->q = malloc(priv->q_len);
    218     for (int i = 0; i < 4; i++) {
    219         priv->q[i] = (q >> (24 - i * 8)) & 0xFF;
    220     }
    221 
    222     // Calculate CRT parameters
    223     uint64_t dp = d % (p - 1);
    224     uint64_t dq = d % (q - 1);
    225     uint64_t qinv = mod_inverse(q, p);
    226 
    227     priv->dp_len = 8;
    228     priv->dp = malloc(priv->dp_len);
    229     for (int i = 0; i < 8; i++) {
    230         priv->dp[i] = (dp >> (56 - i * 8)) & 0xFF;
    231     }
    232 
    233     priv->dq_len = 8;
    234     priv->dq = malloc(priv->dq_len);
    235     for (int i = 0; i < 8; i++) {
    236         priv->dq[i] = (dq >> (56 - i * 8)) & 0xFF;
    237     }
    238 
    239     priv->qinv_len = 8;
    240     priv->qinv = malloc(priv->qinv_len);
    241     for (int i = 0; i < 8; i++) {
    242         priv->qinv[i] = (qinv >> (56 - i * 8)) & 0xFF;
    243     }
    244 
    245     return 0;
    246 }
    247 
    248 // Convert byte array to uint64_t
    249 static uint64_t bytes_to_u64(const uint8_t *bytes, size_t len) {
    250     uint64_t result = 0;
    251     for (size_t i = 0; i < len && i < 8; i++) {
    252         result = (result << 8) | bytes[i];
    253     }
    254     return result;
    255 }
    256 
    257 // Convert uint64_t to byte array
    258 static void u64_to_bytes(uint64_t val, uint8_t *bytes, size_t len) {
    259     for (size_t i = 0; i < len && i < 8; i++) {
    260         bytes[len - 1 - i] = val & 0xFF;
    261         val >>= 8;
    262     }
    263 }
    264 
    265 // RSA encryption (raw, no padding)
    266 static int rsa_encrypt_raw(const rsa_public_key *key, const uint8_t *plaintext,
    267                            size_t pt_len, uint8_t *ciphertext, size_t *ct_len) {
    268     if (pt_len > key->n_len) {
    269         fprintf(stderr, "Plaintext too long\n");
    270         return -1;
    271     }
    272 
    273     uint64_t m = bytes_to_u64(plaintext, pt_len);
    274     uint64_t n = bytes_to_u64(key->n, key->n_len);
    275     uint64_t e = bytes_to_u64(key->e, key->e_len);
    276 
    277     if (m >= n) {
    278         fprintf(stderr, "Message must be less than modulus\n");
    279         return -1;
    280     }
    281 
    282     uint64_t c = mod_exp(m, e, n);
    283 
    284     *ct_len = key->n_len;
    285     u64_to_bytes(c, ciphertext, *ct_len);
    286 
    287     return 0;
    288 }
    289 
    290 // RSA decryption (raw, no padding)
    291 static int rsa_decrypt_raw(const rsa_private_key *key, const uint8_t *ciphertext,
    292                            size_t ct_len, uint8_t *plaintext, size_t *pt_len) {
    293     if (ct_len != key->n_len) {
    294         fprintf(stderr, "Invalid ciphertext length\n");
    295         return -1;
    296     }
    297 
    298     uint64_t c = bytes_to_u64(ciphertext, ct_len);
    299     uint64_t n = bytes_to_u64(key->n, key->n_len);
    300     uint64_t d = bytes_to_u64(key->d, key->d_len);
    301 
    302     uint64_t m = mod_exp(c, d, n);
    303 
    304     *pt_len = key->n_len;
    305     u64_to_bytes(m, plaintext, *pt_len);
    306 
    307     return 0;
    308 }
    309 
    310 // PKCS#1 v1.5 padding for encryption
    311 static int pkcs1_pad_encrypt(const uint8_t *message, size_t msg_len,
    312                              uint8_t *padded, size_t padded_len) {
    313     if (msg_len > padded_len - 11) {
    314         fprintf(stderr, "Message too long for padding\n");
    315         return -1;
    316     }
    317 
    318     padded[0] = 0x00;
    319     padded[1] = 0x02;
    320 
    321     // Generate random padding
    322     size_t ps_len = padded_len - msg_len - 3;
    323     get_random_bytes(padded + 2, ps_len);
    324 
    325     // Ensure no zero bytes in padding
    326     for (size_t i = 0; i < ps_len; i++) {
    327         if (padded[2 + i] == 0) {
    328             padded[2 + i] = 1;
    329         }
    330     }
    331 
    332     padded[2 + ps_len] = 0x00;
    333     memcpy(padded + 2 + ps_len + 1, message, msg_len);
    334 
    335     return 0;
    336 }
    337 
    338 // PKCS#1 v1.5 unpadding for encryption
    339 static int pkcs1_unpad_encrypt(const uint8_t *padded, size_t padded_len,
    340                                uint8_t *message, size_t *msg_len) {
    341     if (padded_len < 11) {
    342         return -1;
    343     }
    344 
    345     if (padded[0] != 0x00 || padded[1] != 0x02) {
    346         fprintf(stderr, "Invalid padding format\n");
    347         return -1;
    348     }
    349 
    350     size_t i = 2;
    351     while (i < padded_len && padded[i] != 0x00) {
    352         i++;
    353     }
    354 
    355     if (i >= padded_len) {
    356         fprintf(stderr, "No padding separator found\n");
    357         return -1;
    358     }
    359 
    360     i++; // Skip the 0x00 separator
    361     *msg_len = padded_len - i;
    362     memcpy(message, padded + i, *msg_len);
    363 
    364     return 0;
    365 }
    366 
    367 // RSA encryption with PKCS#1 v1.5 padding
    368 int rsa_encrypt(const rsa_public_key *key, const uint8_t *plaintext,
    369                 size_t pt_len, uint8_t *ciphertext, size_t *ct_len) {
    370     uint8_t padded[256];
    371 
    372     if (pkcs1_pad_encrypt(plaintext, pt_len, padded, key->n_len) != 0) {
    373         return -1;
    374     }
    375 
    376     return rsa_encrypt_raw(key, padded, key->n_len, ciphertext, ct_len);
    377 }
    378 
    379 // RSA decryption with PKCS#1 v1.5 padding
    380 int rsa_decrypt(const rsa_private_key *key, const uint8_t *ciphertext,
    381                 size_t ct_len, uint8_t *plaintext, size_t *pt_len) {
    382     uint8_t padded[256];
    383     size_t padded_len;
    384 
    385     if (rsa_decrypt_raw(key, ciphertext, ct_len, padded, &padded_len) != 0) {
    386         return -1;
    387     }
    388 
    389     return pkcs1_unpad_encrypt(padded, padded_len, plaintext, pt_len);
    390 }
    391 
    392 // RSA signing (simplified - just encrypt with private key)
    393 int rsa_sign(const rsa_private_key *key, const uint8_t *message,
    394              size_t msg_len, uint8_t *signature, size_t *sig_len) {
    395     if (msg_len > key->n_len - 11) {
    396         fprintf(stderr, "Message too long for signing\n");
    397         return -1;
    398     }
    399 
    400     // For proper signing, we should hash the message first
    401     // Here we use PKCS#1 v1.5 signature padding
    402     uint8_t padded[256];
    403     padded[0] = 0x00;
    404     padded[1] = 0x01;
    405 
    406     size_t ps_len = key->n_len - msg_len - 3;
    407     memset(padded + 2, 0xFF, ps_len);
    408     padded[2 + ps_len] = 0x00;
    409     memcpy(padded + 2 + ps_len + 1, message, msg_len);
    410 
    411     return rsa_decrypt_raw(key, padded, key->n_len, signature, sig_len);
    412 }
    413 
    414 // RSA signature verification
    415 int rsa_verify(const rsa_public_key *key, const uint8_t *message,
    416                size_t msg_len, const uint8_t *signature, size_t sig_len) {
    417     uint8_t decrypted[256];
    418     size_t dec_len;
    419 
    420     if (rsa_encrypt_raw(key, signature, sig_len, decrypted, &dec_len) != 0) {
    421         return -1;
    422     }
    423 
    424     // Verify padding
    425     if (decrypted[0] != 0x00 || decrypted[1] != 0x01) {
    426         return -1;
    427     }
    428 
    429     size_t i = 2;
    430     while (i < dec_len && decrypted[i] == 0xFF) {
    431         i++;
    432     }
    433 
    434     if (i >= dec_len || decrypted[i] != 0x00) {
    435         return -1;
    436     }
    437 
    438     i++;
    439     size_t recovered_len = dec_len - i;
    440 
    441     if (recovered_len != msg_len) {
    442         return -1;
    443     }
    444 
    445     return memcmp(decrypted + i, message, msg_len) == 0 ? 0 : -1;
    446 }
    447 
    448 // Free keys
    449 void rsa_free_public_key(rsa_public_key *key) {
    450     if (key->n) {
    451         memset(key->n, 0, key->n_len);
    452         free(key->n);
    453     }
    454     if (key->e) {
    455         memset(key->e, 0, key->e_len);
    456         free(key->e);
    457     }
    458     memset(key, 0, sizeof(*key));
    459 }
    460 
    461 void rsa_free_private_key(rsa_private_key *key) {
    462     // Zero all sensitive data before freeing
    463     if (key->n) {
    464         memset(key->n, 0, key->n_len);
    465         free(key->n);
    466     }
    467     if (key->e) {
    468         memset(key->e, 0, key->e_len);
    469         free(key->e);
    470     }
    471     if (key->d) {
    472         memset(key->d, 0, key->d_len);
    473         free(key->d);
    474     }
    475     if (key->p) {
    476         memset(key->p, 0, key->p_len);
    477         free(key->p);
    478     }
    479     if (key->q) {
    480         memset(key->q, 0, key->q_len);
    481         free(key->q);
    482     }
    483     if (key->dp) {
    484         memset(key->dp, 0, key->dp_len);
    485         free(key->dp);
    486     }
    487     if (key->dq) {
    488         memset(key->dq, 0, key->dq_len);
    489         free(key->dq);
    490     }
    491     if (key->qinv) {
    492         memset(key->qinv, 0, key->qinv_len);
    493         free(key->qinv);
    494     }
    495     memset(key, 0, sizeof(*key));
    496 }
    497 
    498 // Print key (hex format)
    499 void rsa_print_public_key(const rsa_public_key *key) {
    500     printf("Public Key:\n");
    501     printf("  n (%zu bytes): ", key->n_len);
    502     for (size_t i = 0; i < key->n_len; i++) {
    503         printf("%02x", key->n[i]);
    504     }
    505     printf("\n  e (%zu bytes): ", key->e_len);
    506     for (size_t i = 0; i < key->e_len; i++) {
    507         printf("%02x", key->e[i]);
    508     }
    509     printf("\n");
    510 }
    511 
/* ============================================================================
 * RSA-PSS (Probabilistic Signature Scheme) Implementation
 * RFC 8017 - PKCS #1: RSA Cryptography Specifications Version 2.2
 * ========================================================================= */

/* One-shot SHA-256, defined in hashing/SHA256.c: hashes data[0..len) and
 * writes the 32-byte digest to `digest`. Used by MGF1 and by the PSS
 * encode/verify routines below. */
extern void sha256(const uint8_t *data, size_t len, uint8_t digest[32]);
    519 
    520 /**
    521  * MGF1 - Mask Generation Function based on SHA-256
    522  * RFC 8017 Section B.2.1
    523  */
    524 static void mgf1_sha256(const uint8_t *seed, size_t seed_len,
    525                         uint8_t *mask, size_t mask_len) {
    526     uint8_t counter[4];
    527     size_t offset = 0;
    528     uint32_t count = 0;
    529     uint8_t hash_input[256 + 4];  /* Max seed length + counter */
    530     uint8_t hash[32];
    531 
    532     if (seed_len > 256) seed_len = 256;  /* Safety limit */
    533 
    534     while (offset < mask_len) {
    535         /* Convert counter to big-endian bytes */
    536         counter[0] = (count >> 24) & 0xFF;
    537         counter[1] = (count >> 16) & 0xFF;
    538         counter[2] = (count >> 8) & 0xFF;
    539         counter[3] = count & 0xFF;
    540 
    541         /* Hash seed || counter */
    542         memcpy(hash_input, seed, seed_len);
    543         memcpy(hash_input + seed_len, counter, 4);
    544 
    545         sha256(hash_input, seed_len + 4, hash);
    546 
    547         /* Copy to mask */
    548         size_t to_copy = (mask_len - offset < 32) ? (mask_len - offset) : 32;
    549         memcpy(mask + offset, hash, to_copy);
    550 
    551         offset += to_copy;
    552         count++;
    553     }
    554 }
    555 
    556 /**
    557  * EMSA-PSS-ENCODE operation
    558  * RFC 8017 Section 9.1.1
    559  */
    560 static int emsa_pss_encode(const uint8_t *m_hash, size_t h_len,
    561                            size_t em_bits, size_t s_len,
    562                            uint8_t *em, size_t em_len) {
    563     /* PSS encoding requires em_len >= h_len + s_len + 2 */
    564     if (em_len < h_len + s_len + 2) {
    565         return -1;
    566     }
    567 
    568     /* Generate random salt */
    569     uint8_t salt[64];
    570     if (s_len > 64) s_len = 64;  /* Safety limit */
    571     get_random_bytes(salt, s_len);
    572 
    573     /* M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt */
    574     uint8_t m_prime[8 + 64 + 64];  /* 8 padding + max hash + max salt */
    575     memset(m_prime, 0, 8);
    576     memcpy(m_prime + 8, m_hash, h_len);
    577     memcpy(m_prime + 8 + h_len, salt, s_len);
    578 
    579     /* H = Hash(M') */
    580     uint8_t h[32];
    581     sha256(m_prime, 8 + h_len + s_len, h);
    582 
    583     /* Generate DB = PS || 0x01 || salt */
    584     size_t ps_len = em_len - s_len - h_len - 2;
    585     uint8_t *db = calloc(em_len - h_len - 1, 1);  /* Zero-initialized PS */
    586     if (!db) return -1;
    587 
    588     db[ps_len] = 0x01;
    589     memcpy(db + ps_len + 1, salt, s_len);
    590 
    591     /* dbMask = MGF(H, emLen - hLen - 1) */
    592     size_t db_len = em_len - h_len - 1;
    593     uint8_t *db_mask = malloc(db_len);
    594     if (!db_mask) {
    595         free(db);
    596         return -1;
    597     }
    598 
    599     mgf1_sha256(h, h_len, db_mask, db_len);
    600 
    601     /* maskedDB = DB xor dbMask */
    602     for (size_t i = 0; i < db_len; i++) {
    603         db[i] ^= db_mask[i];
    604     }
    605 
    606     /* Set leftmost bits of maskedDB to zero (8*emLen - emBits bits) */
    607     size_t mask_bits = 8 * em_len - em_bits;
    608     if (mask_bits > 0 && mask_bits < 8) {
    609         db[0] &= (0xFF >> mask_bits);
    610     }
    611 
    612     /* EM = maskedDB || H || 0xbc */
    613     memcpy(em, db, db_len);
    614     memcpy(em + db_len, h, h_len);
    615     em[em_len - 1] = 0xbc;
    616 
    617     free(db);
    618     free(db_mask);
    619     return 0;
    620 }
    621 
    622 /**
    623  * EMSA-PSS-VERIFY operation
    624  * RFC 8017 Section 9.1.2
    625  */
    626 static int emsa_pss_verify(const uint8_t *m_hash, size_t h_len,
    627                            const uint8_t *em, size_t em_len,
    628                            size_t em_bits, size_t s_len) {
    629     /* Check minimum length */
    630     if (em_len < h_len + s_len + 2) {
    631         return -1;
    632     }
    633 
    634     /* Check trailer field (0xbc) */
    635     if (em[em_len - 1] != 0xbc) {
    636         return -1;
    637     }
    638 
    639     /* Split EM = maskedDB || H */
    640     size_t db_len = em_len - h_len - 1;
    641     const uint8_t *masked_db = em;
    642     const uint8_t *h = em + db_len;
    643 
    644     /* Check leftmost bits of maskedDB are zero */
    645     size_t mask_bits = 8 * em_len - em_bits;
    646     if (mask_bits > 0 && mask_bits < 8) {
    647         uint8_t mask = (0xFF << (8 - mask_bits)) & 0xFF;
    648         if ((masked_db[0] & mask) != 0) {
    649             return -1;
    650         }
    651     }
    652 
    653     /* dbMask = MGF(H, emLen - hLen - 1) */
    654     uint8_t *db_mask = malloc(db_len);
    655     if (!db_mask) return -1;
    656 
    657     mgf1_sha256(h, h_len, db_mask, db_len);
    658 
    659     /* DB = maskedDB xor dbMask */
    660     uint8_t *db = malloc(db_len);
    661     if (!db) {
    662         free(db_mask);
    663         return -1;
    664     }
    665 
    666     for (size_t i = 0; i < db_len; i++) {
    667         db[i] = masked_db[i] ^ db_mask[i];
    668     }
    669 
    670     /* Set leftmost bits to zero */
    671     if (mask_bits > 0 && mask_bits < 8) {
    672         db[0] &= (0xFF >> mask_bits);
    673     }
    674 
    675     /* Verify DB = PS || 0x01 || salt */
    676     size_t ps_len = em_len - s_len - h_len - 2;
    677 
    678     /* Check PS is all zeros */
    679     for (size_t i = 0; i < ps_len; i++) {
    680         if (db[i] != 0x00) {
    681             free(db);
    682             free(db_mask);
    683             return -1;
    684         }
    685     }
    686 
    687     /* Check 0x01 separator */
    688     if (db[ps_len] != 0x01) {
    689         free(db);
    690         free(db_mask);
    691         return -1;
    692     }
    693 
    694     /* Extract salt */
    695     uint8_t *salt = db + ps_len + 1;
    696 
    697     /* M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt */
    698     uint8_t m_prime[8 + 64 + 64];
    699     memset(m_prime, 0, 8);
    700     memcpy(m_prime + 8, m_hash, h_len);
    701     memcpy(m_prime + 8 + h_len, salt, s_len);
    702 
    703     /* H' = Hash(M') */
    704     uint8_t h_prime[32];
    705     sha256(m_prime, 8 + h_len + s_len, h_prime);
    706 
    707     /* Verify H == H' */
    708     int result = (memcmp(h, h_prime, h_len) == 0) ? 0 : -1;
    709 
    710     free(db);
    711     free(db_mask);
    712     return result;
    713 }
    714 
    715 /**
    716  * RSA-PSS Sign
    717  * RFC 8017 Section 8.1.1
    718  */
    719 int rsa_sign_pss(const rsa_private_key *key,
    720                  const uint8_t *message, size_t msg_len,
    721                  uint8_t *signature, size_t *sig_len) {
    722     /* PSS parameters (using SHA-256) */
    723     const size_t h_len = 32;  /* SHA-256 output length */
    724     const size_t s_len = 32;  /* Salt length (same as hash length) */
    725 
    726     /* Message must be a SHA-256 hash */
    727     if (msg_len != h_len) {
    728         fprintf(stderr, "RSA-PSS: Message must be a SHA-256 hash (32 bytes)\n");
    729         return -1;
    730     }
    731 
    732     size_t em_len = key->n_len;
    733     size_t em_bits = em_len * 8 - 1;  /* One less than modulus bits */
    734 
    735     if (em_len < h_len + s_len + 2) {
    736         fprintf(stderr, "RSA-PSS: Key too short for PSS padding\n");
    737         return -1;
    738     }
    739 
    740     /* EMSA-PSS encoding */
    741     uint8_t *em = malloc(em_len);
    742     if (!em) return -1;
    743 
    744     if (emsa_pss_encode(message, h_len, em_bits, s_len, em, em_len) != 0) {
    745         free(em);
    746         return -1;
    747     }
    748 
    749     /* RSA signature primitive (decrypt with private key) */
    750     int result = rsa_decrypt_raw(key, em, em_len, signature, sig_len);
    751     free(em);
    752     return result;
    753 }
    754 
    755 /**
    756  * RSA-PSS Verify
    757  * RFC 8017 Section 8.1.2
    758  */
    759 int rsa_verify_pss(const rsa_public_key *key,
    760                    const uint8_t *message, size_t msg_len,
    761                    const uint8_t *signature, size_t sig_len) {
    762     /* PSS parameters (using SHA-256) */
    763     const size_t h_len = 32;  /* SHA-256 output length */
    764     const size_t s_len = 32;  /* Salt length (same as hash length) */
    765 
    766     /* Message must be a SHA-256 hash */
    767     if (msg_len != h_len) {
    768         fprintf(stderr, "RSA-PSS: Message must be a SHA-256 hash (32 bytes)\n");
    769         return -1;
    770     }
    771 
    772     if (sig_len != key->n_len) {
    773         return -1;
    774     }
    775 
    776     size_t em_len = key->n_len;
    777     size_t em_bits = em_len * 8 - 1;
    778 
    779     /* RSA verification primitive (encrypt with public key) */
    780     uint8_t *em = malloc(em_len);
    781     if (!em) return -1;
    782 
    783     size_t recovered_len;
    784     if (rsa_encrypt_raw(key, signature, sig_len, em, &recovered_len) != 0) {
    785         free(em);
    786         return -1;
    787     }
    788 
    789     /* EMSA-PSS verification */
    790     int result = emsa_pss_verify(message, h_len, em, em_len, em_bits, s_len);
    791     free(em);
    792     return result;
    793 }
    794