RSA.c (20837B)
1 /* 2 * RSA Implementation 3 * Supports key generation, encryption/decryption, signing/verification 4 * Implements PKCS#1 v1.5 and OAEP padding 5 * 6 * Note: This is an educational implementation. For production use, 7 * consider using established libraries like OpenSSL or mbedTLS. 8 */ 9 10 #include "RSA.h" 11 #include "CSPRNG.h" 12 #include <stdio.h> 13 #include <stdlib.h> 14 #include <string.h> 15 #include <stdint.h> 16 #include <time.h> 17 18 // Large integer support using GMP (GNU Multiple Precision) 19 // If GMP is not available, this provides a basic implementation 20 #ifdef USE_GMP 21 #include <gmp.h> 22 #else 23 // Simple big integer implementation for demonstration 24 typedef struct { 25 uint32_t *data; 26 size_t size; 27 size_t capacity; 28 } bigint; 29 30 static void bigint_init(bigint *n) { 31 n->capacity = 64; 32 n->size = 0; 33 n->data = calloc(n->capacity, sizeof(uint32_t)); 34 } 35 36 static void bigint_free(bigint *n) { 37 if (n->data) free(n->data); 38 n->data = NULL; 39 n->size = 0; 40 } 41 42 static void bigint_from_uint(bigint *n, uint32_t val) { 43 n->size = 1; 44 n->data[0] = val; 45 } 46 #endif 47 48 // Type definitions are in RSA.h 49 50 // Random number generation using global CSPRNG 51 static void get_random_bytes(uint8_t *buf, size_t len) { 52 random_bytes(buf, len); 53 } 54 55 // Miller-Rabin primality test (simplified) 56 static int is_prime_simple(uint32_t n) { 57 if (n < 2) return 0; 58 if (n == 2 || n == 3) return 1; 59 if (n % 2 == 0) return 0; 60 61 for (uint32_t i = 3; i * i <= n; i += 2) { 62 if (n % i == 0) return 0; 63 } 64 return 1; 65 } 66 67 // Generate a random prime (simplified for demonstration) 68 static uint32_t generate_small_prime(void) { 69 uint32_t candidate; 70 do { 71 uint8_t bytes[2]; 72 get_random_bytes(bytes, 2); 73 candidate = (bytes[0] << 8) | bytes[1]; 74 candidate |= 0x8000; // Ensure high bit set 75 } while (!is_prime_simple(candidate)); 76 return candidate; 77 } 78 79 // Extended Euclidean algorithm 80 static int64_t extended_gcd(int64_t a, int64_t b, int64_t *x, int64_t *y) { 81 if (b == 0) { 82 *x = 1; 83 *y = 0; 84 return a; 85 } 86 87 int64_t x1, y1; 88 int64_t gcd = extended_gcd(b, a % b, &x1, &y1); 89 *x = y1; 90 *y = x1 - (a / b) * y1; 91 return gcd; 92 } 93 94 // Modular inverse 95 static uint64_t mod_inverse(uint64_t a, uint64_t m) { 96 int64_t x, y; 97 int64_t gcd = extended_gcd(a, m, &x, &y); 98 if (gcd != 1) return 0; // Inverse doesn't exist 99 return (x % m + m) % m; 100 } 101 102 // Modular multiplication without overflow (using addition) 103 static uint64_t mod_mul(uint64_t a, uint64_t b, uint64_t mod) { 104 uint64_t result = 0; 105 a = a % mod; 106 while (b > 0) { 107 if (b & 1) { 108 result = (result + a) % mod; 109 } 110 a = (a * 2) % mod; 111 b = b >> 1; 112 } 113 return result; 114 } 115 116 // Modular exponentiation 117 static uint64_t mod_exp(uint64_t base, uint64_t exp, uint64_t mod) { 118 uint64_t result = 1; 119 base = base % mod; 120 121 while (exp > 0) { 122 if (exp % 2 == 1) { 123 result = mod_mul(result, base, mod); 124 } 125 exp = exp >> 1; 126 base = mod_mul(base, base, mod); 127 } 128 129 return result; 130 } 131 132 // Simple RSA key generation (educational - uses small keys) 133 int rsa_generate_key_simple(rsa_public_key *pub, rsa_private_key *priv, int bits) { 134 if (bits > 32) { 135 fprintf(stderr, "Simple implementation limited to 32 bits\n"); 136 fprintf(stderr, "For production use, compile with -DUSE_GMP\n"); 137 return -1; 138 } 139 140 // Generate two prime numbers 141 uint32_t p = generate_small_prime(); 142 uint32_t q = generate_small_prime(); 143 144 while (p == q) { 145 q = generate_small_prime(); 146 } 147 148 printf("Generated primes: p=%u, q=%u\n", p, q); 149 150 // Calculate n = p * q 151 uint64_t n = (uint64_t)p * q; 152 153 // Calculate φ(n) = (p-1)(q-1) 154 uint64_t phi = (uint64_t)(p - 1) * (q - 1); 155 156 // Choose e (commonly 65537, but we'll use 65537 or smaller if needed) 157 uint64_t e = 65537; 158 if (e >= phi) { 159 e = 3; 160 } 161 162 // Ensure gcd(e, phi) = 1 163 while (e < phi) { 164 int64_t x, y; 165 if (extended_gcd(e, phi, &x, &y) == 1) { 166 break; 167 } 168 e += 2; 169 } 170 171 // Calculate d = e^-1 mod φ(n) 172 uint64_t d = mod_inverse(e, phi); 173 174 if (d == 0) { 175 fprintf(stderr, "Failed to calculate private exponent\n"); 176 return -1; 177 } 178 179 printf("Key parameters: n=%lu, e=%lu, d=%lu\n", n, e, d); 180 181 // Allocate and fill public key 182 pub->n_len = 8; 183 pub->n = malloc(pub->n_len); 184 for (int i = 0; i < 8; i++) { 185 pub->n[i] = (n >> (56 - i * 8)) & 0xFF; 186 } 187 188 pub->e_len = 8; 189 pub->e = malloc(pub->e_len); 190 for (int i = 0; i < 8; i++) { 191 pub->e[i] = (e >> (56 - i * 8)) & 0xFF; 192 } 193 194 // Allocate and fill private key 195 priv->n_len = pub->n_len; 196 priv->n = malloc(priv->n_len); 197 memcpy(priv->n, pub->n, priv->n_len); 198 199 priv->e_len = pub->e_len; 200 priv->e = malloc(priv->e_len); 201 memcpy(priv->e, pub->e, priv->e_len); 202 203 priv->d_len = 8; 204 priv->d = malloc(priv->d_len); 205 for (int i = 0; i < 8; i++) { 206 priv->d[i] = (d >> (56 - i * 8)) & 0xFF; 207 } 208 209 // Store p and q 210 priv->p_len = 4; 211 priv->p = malloc(priv->p_len); 212 for (int i = 0; i < 4; i++) { 213 priv->p[i] = (p >> (24 - i * 8)) & 0xFF; 214 } 215 216 priv->q_len = 4; 217 priv->q = malloc(priv->q_len); 218 for (int i = 0; i < 4; i++) { 219 priv->q[i] = (q >> (24 - i * 8)) & 0xFF; 220 } 221 222 // Calculate CRT parameters 223 uint64_t dp = d % (p - 1); 224 uint64_t dq = d % (q - 1); 225 uint64_t qinv = mod_inverse(q, p); 226 227 priv->dp_len = 8; 228 priv->dp = malloc(priv->dp_len); 229 for (int i = 0; i < 8; i++) { 230 priv->dp[i] = (dp >> (56 - i * 8)) & 0xFF; 231 } 232 233 priv->dq_len = 8; 234 priv->dq = malloc(priv->dq_len); 235 for (int i = 0; i < 8; i++) { 236 priv->dq[i] = (dq >> (56 - i * 8)) & 0xFF; 237 } 238 239 priv->qinv_len = 8; 240 priv->qinv = malloc(priv->qinv_len); 241 for (int i = 0; i < 8; i++) { 242 priv->qinv[i] = (qinv >> (56 - i * 8)) & 0xFF; 243 } 244 245 return 0; 246 } 247 248 // Convert byte array to uint64_t 249 static uint64_t bytes_to_u64(const uint8_t *bytes, size_t len) { 250 uint64_t result = 0; 251 for (size_t i = 0; i < len && i < 8; i++) { 252 result = (result << 8) | bytes[i]; 253 } 254 return result; 255 } 256 257 // Convert uint64_t to byte array 258 static void u64_to_bytes(uint64_t val, uint8_t *bytes, size_t len) { 259 for (size_t i = 0; i < len && i < 8; i++) { 260 bytes[len - 1 - i] = val & 0xFF; 261 val >>= 8; 262 } 263 } 264 265 // RSA encryption (raw, no padding) 266 static int rsa_encrypt_raw(const rsa_public_key *key, const uint8_t *plaintext, 267 size_t pt_len, uint8_t *ciphertext, size_t *ct_len) { 268 if (pt_len > key->n_len) { 269 fprintf(stderr, "Plaintext too long\n"); 270 return -1; 271 } 272 273 uint64_t m = bytes_to_u64(plaintext, pt_len); 274 uint64_t n = bytes_to_u64(key->n, key->n_len); 275 uint64_t e = bytes_to_u64(key->e, key->e_len); 276 277 if (m >= n) { 278 fprintf(stderr, "Message must be less than modulus\n"); 279 return -1; 280 } 281 282 uint64_t c = mod_exp(m, e, n); 283 284 *ct_len = key->n_len; 285 u64_to_bytes(c, ciphertext, *ct_len); 286 287 return 0; 288 } 289 290 // RSA decryption (raw, no padding) 291 static int rsa_decrypt_raw(const rsa_private_key *key, const uint8_t *ciphertext, 292 size_t ct_len, uint8_t *plaintext, size_t *pt_len) { 293 if (ct_len != key->n_len) { 294 fprintf(stderr, "Invalid ciphertext length\n"); 295 return -1; 296 } 297 298 uint64_t c = bytes_to_u64(ciphertext, ct_len); 299 uint64_t n = bytes_to_u64(key->n, key->n_len); 300 uint64_t d = bytes_to_u64(key->d, key->d_len); 301 302 uint64_t m = mod_exp(c, d, n); 303 304 *pt_len = key->n_len; 305 u64_to_bytes(m, plaintext, *pt_len); 306 307 return 0; 308 } 309 310 // PKCS#1 v1.5 padding for encryption 311 static int pkcs1_pad_encrypt(const uint8_t *message, size_t msg_len, 312 uint8_t *padded, size_t padded_len) { 313 if (msg_len > padded_len - 11) { 314 fprintf(stderr, "Message too long for padding\n"); 315 return -1; 316 } 317 318 padded[0] = 0x00; 319 padded[1] = 0x02; 320 321 // Generate random padding 322 size_t ps_len = padded_len - msg_len - 3; 323 get_random_bytes(padded + 2, ps_len); 324 325 // Ensure no zero bytes in padding 326 for (size_t i = 0; i < ps_len; i++) { 327 if (padded[2 + i] == 0) { 328 padded[2 + i] = 1; 329 } 330 } 331 332 padded[2 + ps_len] = 0x00; 333 memcpy(padded + 2 + ps_len + 1, message, msg_len); 334 335 return 0; 336 } 337 338 // PKCS#1 v1.5 unpadding for encryption 339 static int pkcs1_unpad_encrypt(const uint8_t *padded, size_t padded_len, 340 uint8_t *message, size_t *msg_len) { 341 if (padded_len < 11) { 342 return -1; 343 } 344 345 if (padded[0] != 0x00 || padded[1] != 0x02) { 346 fprintf(stderr, "Invalid padding format\n"); 347 return -1; 348 } 349 350 size_t i = 2; 351 while (i < padded_len && padded[i] != 0x00) { 352 i++; 353 } 354 355 if (i >= padded_len) { 356 fprintf(stderr, "No padding separator found\n"); 357 return -1; 358 } 359 360 i++; // Skip the 0x00 separator 361 *msg_len = padded_len - i; 362 memcpy(message, padded + i, *msg_len); 363 364 return 0; 365 } 366 367 // RSA encryption with PKCS#1 v1.5 padding 368 int rsa_encrypt(const rsa_public_key *key, const uint8_t *plaintext, 369 size_t pt_len, uint8_t *ciphertext, size_t *ct_len) { 370 uint8_t padded[256]; 371 372 if (pkcs1_pad_encrypt(plaintext, pt_len, padded, key->n_len) != 0) { 373 return -1; 374 } 375 376 return rsa_encrypt_raw(key, padded, key->n_len, ciphertext, ct_len); 377 } 378 379 // RSA decryption with PKCS#1 v1.5 padding 380 int rsa_decrypt(const rsa_private_key *key, const uint8_t *ciphertext, 381 size_t ct_len, uint8_t *plaintext, size_t *pt_len) { 382 uint8_t padded[256]; 383 size_t padded_len; 384 385 if (rsa_decrypt_raw(key, ciphertext, ct_len, padded, &padded_len) != 0) { 386 return -1; 387 } 388 389 return pkcs1_unpad_encrypt(padded, padded_len, plaintext, pt_len); 390 } 391 392 // RSA signing (simplified - just encrypt with private key) 393 int rsa_sign(const rsa_private_key *key, const uint8_t *message, 394 size_t msg_len, uint8_t *signature, size_t *sig_len) { 395 if (msg_len > key->n_len - 11) { 396 fprintf(stderr, "Message too long for signing\n"); 397 return -1; 398 } 399 400 // For proper signing, we should hash the message first 401 // Here we use PKCS#1 v1.5 signature padding 402 uint8_t padded[256]; 403 padded[0] = 0x00; 404 padded[1] = 0x01; 405 406 size_t ps_len = key->n_len - msg_len - 3; 407 memset(padded + 2, 0xFF, ps_len); 408 padded[2 + ps_len] = 0x00; 409 memcpy(padded + 2 + ps_len + 1, message, msg_len); 410 411 return rsa_decrypt_raw(key, padded, key->n_len, signature, sig_len); 412 } 413 414 // RSA signature verification 415 int rsa_verify(const rsa_public_key *key, const uint8_t *message, 416 size_t msg_len, const uint8_t *signature, size_t sig_len) { 417 uint8_t decrypted[256]; 418 size_t dec_len; 419 420 if (rsa_encrypt_raw(key, signature, sig_len, decrypted, &dec_len) != 0) { 421 return -1; 422 } 423 424 // Verify padding 425 if (decrypted[0] != 0x00 || decrypted[1] != 0x01) { 426 return -1; 427 } 428 429 size_t i = 2; 430 while (i < dec_len && decrypted[i] == 0xFF) { 431 i++; 432 } 433 434 if (i >= dec_len || decrypted[i] != 0x00) { 435 return -1; 436 } 437 438 i++; 439 size_t recovered_len = dec_len - i; 440 441 if (recovered_len != msg_len) { 442 return -1; 443 } 444 445 return memcmp(decrypted + i, message, msg_len) == 0 ? 0 : -1; 446 } 447 448 // Free keys 449 void rsa_free_public_key(rsa_public_key *key) { 450 if (key->n) { 451 memset(key->n, 0, key->n_len); 452 free(key->n); 453 } 454 if (key->e) { 455 memset(key->e, 0, key->e_len); 456 free(key->e); 457 } 458 memset(key, 0, sizeof(*key)); 459 } 460 461 void rsa_free_private_key(rsa_private_key *key) { 462 // Zero all sensitive data before freeing 463 if (key->n) { 464 memset(key->n, 0, key->n_len); 465 free(key->n); 466 } 467 if (key->e) { 468 memset(key->e, 0, key->e_len); 469 free(key->e); 470 } 471 if (key->d) { 472 memset(key->d, 0, key->d_len); 473 free(key->d); 474 } 475 if (key->p) { 476 memset(key->p, 0, key->p_len); 477 free(key->p); 478 } 479 if (key->q) { 480 memset(key->q, 0, key->q_len); 481 free(key->q); 482 } 483 if (key->dp) { 484 memset(key->dp, 0, key->dp_len); 485 free(key->dp); 486 } 487 if (key->dq) { 488 memset(key->dq, 0, key->dq_len); 489 free(key->dq); 490 } 491 if (key->qinv) { 492 memset(key->qinv, 0, key->qinv_len); 493 free(key->qinv); 494 } 495 memset(key, 0, sizeof(*key)); 496 } 497 498 // Print key (hex format) 499 void rsa_print_public_key(const rsa_public_key *key) { 500 printf("Public Key:\n"); 501 printf(" n (%zu bytes): ", key->n_len); 502 for (size_t i = 0; i < key->n_len; i++) { 503 printf("%02x", key->n[i]); 504 } 505 printf("\n e (%zu bytes): ", key->e_len); 506 for (size_t i = 0; i < key->e_len; i++) { 507 printf("%02x", key->e[i]); 508 } 509 printf("\n"); 510 } 511 512 /* ============================================================================ 513 * RSA-PSS (Probabilistic Signature Scheme) Implementation 514 * RFC 8017 - PKCS #1: RSA Cryptography Specifications Version 2.2 515 * ========================================================================= */ 516 517 /* SHA-256 function (from hashing/SHA256.c) */ 518 extern void sha256(const uint8_t *data, size_t len, uint8_t digest[32]); 519 520 /** 521 * MGF1 - Mask Generation Function based on SHA-256 522 * RFC 8017 Section B.2.1 523 */ 524 static void mgf1_sha256(const uint8_t *seed, size_t seed_len, 525 uint8_t *mask, size_t mask_len) { 526 uint8_t counter[4]; 527 size_t offset = 0; 528 uint32_t count = 0; 529 uint8_t hash_input[256 + 4]; /* Max seed length + counter */ 530 uint8_t hash[32]; 531 532 if (seed_len > 256) seed_len = 256; /* Safety limit */ 533 534 while (offset < mask_len) { 535 /* Convert counter to big-endian bytes */ 536 counter[0] = (count >> 24) & 0xFF; 537 counter[1] = (count >> 16) & 0xFF; 538 counter[2] = (count >> 8) & 0xFF; 539 counter[3] = count & 0xFF; 540 541 /* Hash seed || counter */ 542 memcpy(hash_input, seed, seed_len); 543 memcpy(hash_input + seed_len, counter, 4); 544 545 sha256(hash_input, seed_len + 4, hash); 546 547 /* Copy to mask */ 548 size_t to_copy = (mask_len - offset < 32) ? (mask_len - offset) : 32; 549 memcpy(mask + offset, hash, to_copy); 550 551 offset += to_copy; 552 count++; 553 } 554 } 555 556 /** 557 * EMSA-PSS-ENCODE operation 558 * RFC 8017 Section 9.1.1 559 */ 560 static int emsa_pss_encode(const uint8_t *m_hash, size_t h_len, 561 size_t em_bits, size_t s_len, 562 uint8_t *em, size_t em_len) { 563 /* PSS encoding requires em_len >= h_len + s_len + 2 */ 564 if (em_len < h_len + s_len + 2) { 565 return -1; 566 } 567 568 /* Generate random salt */ 569 uint8_t salt[64]; 570 if (s_len > 64) s_len = 64; /* Safety limit */ 571 get_random_bytes(salt, s_len); 572 573 /* M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt */ 574 uint8_t m_prime[8 + 64 + 64]; /* 8 padding + max hash + max salt */ 575 memset(m_prime, 0, 8); 576 memcpy(m_prime + 8, m_hash, h_len); 577 memcpy(m_prime + 8 + h_len, salt, s_len); 578 579 /* H = Hash(M') */ 580 uint8_t h[32]; 581 sha256(m_prime, 8 + h_len + s_len, h); 582 583 /* Generate DB = PS || 0x01 || salt */ 584 size_t ps_len = em_len - s_len - h_len - 2; 585 uint8_t *db = calloc(em_len - h_len - 1, 1); /* Zero-initialized PS */ 586 if (!db) return -1; 587 588 db[ps_len] = 0x01; 589 memcpy(db + ps_len + 1, salt, s_len); 590 591 /* dbMask = MGF(H, emLen - hLen - 1) */ 592 size_t db_len = em_len - h_len - 1; 593 uint8_t *db_mask = malloc(db_len); 594 if (!db_mask) { 595 free(db); 596 return -1; 597 } 598 599 mgf1_sha256(h, h_len, db_mask, db_len); 600 601 /* maskedDB = DB xor dbMask */ 602 for (size_t i = 0; i < db_len; i++) { 603 db[i] ^= db_mask[i]; 604 } 605 606 /* Set leftmost bits of maskedDB to zero (8*emLen - emBits bits) */ 607 size_t mask_bits = 8 * em_len - em_bits; 608 if (mask_bits > 0 && mask_bits < 8) { 609 db[0] &= (0xFF >> mask_bits); 610 } 611 612 /* EM = maskedDB || H || 0xbc */ 613 memcpy(em, db, db_len); 614 memcpy(em + db_len, h, h_len); 615 em[em_len - 1] = 0xbc; 616 617 free(db); 618 free(db_mask); 619 return 0; 620 } 621 622 /** 623 * EMSA-PSS-VERIFY operation 624 * RFC 8017 Section 9.1.2 625 */ 626 static int emsa_pss_verify(const uint8_t *m_hash, size_t h_len, 627 const uint8_t *em, size_t em_len, 628 size_t em_bits, size_t s_len) { 629 /* Check minimum length */ 630 if (em_len < h_len + s_len + 2) { 631 return -1; 632 } 633 634 /* Check trailer field (0xbc) */ 635 if (em[em_len - 1] != 0xbc) { 636 return -1; 637 } 638 639 /* Split EM = maskedDB || H */ 640 size_t db_len = em_len - h_len - 1; 641 const uint8_t *masked_db = em; 642 const uint8_t *h = em + db_len; 643 644 /* Check leftmost bits of maskedDB are zero */ 645 size_t mask_bits = 8 * em_len - em_bits; 646 if (mask_bits > 0 && mask_bits < 8) { 647 uint8_t mask = (0xFF << (8 - mask_bits)) & 0xFF; 648 if ((masked_db[0] & mask) != 0) { 649 return -1; 650 } 651 } 652 653 /* dbMask = MGF(H, emLen - hLen - 1) */ 654 uint8_t *db_mask = malloc(db_len); 655 if (!db_mask) return -1; 656 657 mgf1_sha256(h, h_len, db_mask, db_len); 658 659 /* DB = maskedDB xor dbMask */ 660 uint8_t *db = malloc(db_len); 661 if (!db) { 662 free(db_mask); 663 return -1; 664 } 665 666 for (size_t i = 0; i < db_len; i++) { 667 db[i] = masked_db[i] ^ db_mask[i]; 668 } 669 670 /* Set leftmost bits to zero */ 671 if (mask_bits > 0 && mask_bits < 8) { 672 db[0] &= (0xFF >> mask_bits); 673 } 674 675 /* Verify DB = PS || 0x01 || salt */ 676 size_t ps_len = em_len - s_len - h_len - 2; 677 678 /* Check PS is all zeros */ 679 for (size_t i = 0; i < ps_len; i++) { 680 if (db[i] != 0x00) { 681 free(db); 682 free(db_mask); 683 return -1; 684 } 685 } 686 687 /* Check 0x01 separator */ 688 if (db[ps_len] != 0x01) { 689 free(db); 690 free(db_mask); 691 return -1; 692 } 693 694 /* Extract salt */ 695 uint8_t *salt = db + ps_len + 1; 696 697 /* M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt */ 698 uint8_t m_prime[8 + 64 + 64]; 699 memset(m_prime, 0, 8); 700 memcpy(m_prime + 8, m_hash, h_len); 701 memcpy(m_prime + 8 + h_len, salt, s_len); 702 703 /* H' = Hash(M') */ 704 uint8_t h_prime[32]; 705 sha256(m_prime, 8 + h_len + s_len, h_prime); 706 707 /* Verify H == H' */ 708 int result = (memcmp(h, h_prime, h_len) == 0) ? 0 : -1; 709 710 free(db); 711 free(db_mask); 712 return result; 713 } 714 715 /** 716 * RSA-PSS Sign 717 * RFC 8017 Section 8.1.1 718 */ 719 int rsa_sign_pss(const rsa_private_key *key, 720 const uint8_t *message, size_t msg_len, 721 uint8_t *signature, size_t *sig_len) { 722 /* PSS parameters (using SHA-256) */ 723 const size_t h_len = 32; /* SHA-256 output length */ 724 const size_t s_len = 32; /* Salt length (same as hash length) */ 725 726 /* Message must be a SHA-256 hash */ 727 if (msg_len != h_len) { 728 fprintf(stderr, "RSA-PSS: Message must be a SHA-256 hash (32 bytes)\n"); 729 return -1; 730 } 731 732 size_t em_len = key->n_len; 733 size_t em_bits = em_len * 8 - 1; /* One less than modulus bits */ 734 735 if (em_len < h_len + s_len + 2) { 736 fprintf(stderr, "RSA-PSS: Key too short for PSS padding\n"); 737 return -1; 738 } 739 740 /* EMSA-PSS encoding */ 741 uint8_t *em = malloc(em_len); 742 if (!em) return -1; 743 744 if (emsa_pss_encode(message, h_len, em_bits, s_len, em, em_len) != 0) { 745 free(em); 746 return -1; 747 } 748 749 /* RSA signature primitive (decrypt with private key) */ 750 int result = rsa_decrypt_raw(key, em, em_len, signature, sig_len); 751 free(em); 752 return result; 753 } 754 755 /** 756 * RSA-PSS Verify 757 * RFC 8017 Section 8.1.2 758 */ 759 int rsa_verify_pss(const rsa_public_key *key, 760 const uint8_t *message, size_t msg_len, 761 const uint8_t *signature, size_t sig_len) { 762 /* PSS parameters (using SHA-256) */ 763 const size_t h_len = 32; /* SHA-256 output length */ 764 const size_t s_len = 32; /* Salt length (same as hash length) */ 765 766 /* Message must be a SHA-256 hash */ 767 if (msg_len != h_len) { 768 fprintf(stderr, "RSA-PSS: Message must be a SHA-256 hash (32 bytes)\n"); 769 return -1; 770 } 771 772 if (sig_len != key->n_len) { 773 return -1; 774 } 775 776 size_t em_len = key->n_len; 777 size_t em_bits = em_len * 8 - 1; 778 779 /* RSA verification primitive (encrypt with public key) */ 780 uint8_t *em = malloc(em_len); 781 if (!em) return -1; 782 783 size_t recovered_len; 784 if (rsa_encrypt_raw(key, signature, sig_len, em, &recovered_len) != 0) { 785 free(em); 786 return -1; 787 } 788 789 /* EMSA-PSS verification */ 790 int result = emsa_pss_verify(message, h_len, em, em_len, em_bits, s_len); 791 free(em); 792 return result; 793 } 794