luajitos

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs

RSA_Lua.c (14196B)


      1 /*
      2  * RSA_Lua.c - Lua bindings for RSA public key cryptography
      3  */
      4 
      5 #include "RSA_Lua.h"
      6 #include "RSA.h"
      7 #include "CSPRNG.h"
      8 #include "hashing/hash.h"
      9 #include <string.h>
     10 #include <stdlib.h>
     11 
     12 /* Base64 encoding function (provided by crypto.c) */
     13 extern char* base64_encode(const uint8_t *data, size_t len, size_t *out_len);
     14 
     15 /* Serialize public key to base64 */
     16 static char* serialize_public_key(const rsa_public_key *key, size_t *out_len) {
     17     /* Simple format: n_len(4) || n || e_len(4) || e */
     18     size_t total_len = 4 + key->n_len + 4 + key->e_len;
     19     uint8_t *buffer = malloc(total_len);
     20     if (!buffer) return NULL;
     21 
     22     uint8_t *ptr = buffer;
     23 
     24     /* Write n_len and n */
     25     *(uint32_t*)ptr = key->n_len; ptr += 4;
     26     memcpy(ptr, key->n, key->n_len); ptr += key->n_len;
     27 
     28     /* Write e_len and e */
     29     *(uint32_t*)ptr = key->e_len; ptr += 4;
     30     memcpy(ptr, key->e, key->e_len); ptr += key->e_len;
     31 
     32     char *b64 = base64_encode(buffer, total_len, out_len);
     33     free(buffer);
     34     return b64;
     35 }
     36 
     37 /* Serialize private key to base64 */
     38 static char* serialize_private_key(const rsa_private_key *key, size_t *out_len) {
     39     /* Simple format: n_len(4) || n || d_len(4) || d || p_len(4) || p || q_len(4) || q */
     40     size_t total_len = 4 + key->n_len + 4 + key->d_len + 4 + key->p_len + 4 + key->q_len;
     41     uint8_t *buffer = malloc(total_len);
     42     if (!buffer) return NULL;
     43 
     44     uint8_t *ptr = buffer;
     45 
     46     *(uint32_t*)ptr = key->n_len; ptr += 4;
     47     memcpy(ptr, key->n, key->n_len); ptr += key->n_len;
     48 
     49     *(uint32_t*)ptr = key->d_len; ptr += 4;
     50     memcpy(ptr, key->d, key->d_len); ptr += key->d_len;
     51 
     52     *(uint32_t*)ptr = key->p_len; ptr += 4;
     53     memcpy(ptr, key->p, key->p_len); ptr += key->p_len;
     54 
     55     *(uint32_t*)ptr = key->q_len; ptr += 4;
     56     memcpy(ptr, key->q, key->q_len); ptr += key->q_len;
     57 
     58     char *b64 = base64_encode(buffer, total_len, out_len);
     59     memset(buffer, 0, total_len);  /* Zero sensitive data */
     60     free(buffer);
     61     return b64;
     62 }
     63 
     64 /* rsa.generateKeypair([bits]) - Generate RSA keypair (default 2048 bits) */
     65 int l_rsa_generate_keypair(lua_State *L) {
     66     int bits = luaL_optinteger(L, 1, 2048);
     67 
     68     if (bits < 2048) {
     69         return luaL_error(L, "Key size must be at least 2048 bits for security");
     70     }
     71 
     72     rsa_public_key pub;
     73     rsa_private_key priv;
     74 
     75     if (rsa_generate_key_simple(&pub, &priv, bits) != 0) {
     76         return luaL_error(L, "RSA key generation failed");
     77     }
     78 
     79     /* Serialize keys to base64 */
     80     size_t pub_b64_len, priv_b64_len;
     81     char *pub_b64 = serialize_public_key(&pub, &pub_b64_len);
     82     char *priv_b64 = serialize_private_key(&priv, &priv_b64_len);
     83 
     84     if (!pub_b64 || !priv_b64) {
     85         free(pub_b64);
     86         free(priv_b64);
     87         rsa_free_public_key(&pub);
     88         rsa_free_private_key(&priv);
     89         return luaL_error(L, "Failed to serialize RSA keys");
     90     }
     91 
     92     /* Return public_key, private_key */
     93     lua_pushlstring(L, pub_b64, pub_b64_len);
     94     lua_pushlstring(L, priv_b64, priv_b64_len);
     95 
     96     free(pub_b64);
     97     memset(priv_b64, 0, priv_b64_len);  /* Zero sensitive data */
     98     free(priv_b64);
     99 
    100     rsa_free_public_key(&pub);
    101     rsa_free_private_key(&priv);
    102 
    103     return 2;
    104 }
    105 
    106 /* Deserialize public key from base64 */
    107 static int deserialize_public_key(rsa_public_key *key, const uint8_t *buffer, size_t buf_len) {
    108     if (buf_len < 8) return -1;
    109 
    110     const uint8_t *ptr = buffer;
    111 
    112     /* Read n_len and n */
    113     key->n_len = *(uint32_t*)ptr; ptr += 4;
    114     if (ptr + key->n_len > buffer + buf_len) return -1;
    115     key->n = malloc(key->n_len);
    116     memcpy(key->n, ptr, key->n_len); ptr += key->n_len;
    117 
    118     /* Read e_len and e */
    119     key->e_len = *(uint32_t*)ptr; ptr += 4;
    120     if (ptr + key->e_len > buffer + buf_len) return -1;
    121     key->e = malloc(key->e_len);
    122     memcpy(key->e, ptr, key->e_len);
    123 
    124     return 0;
    125 }
    126 
    127 /* Deserialize private key from base64 */
    128 static int deserialize_private_key(rsa_private_key *key, const uint8_t *buffer, size_t buf_len) {
    129     if (buf_len < 16) return -1;
    130 
    131     const uint8_t *ptr = buffer;
    132 
    133     key->n_len = *(uint32_t*)ptr; ptr += 4;
    134     if (ptr + key->n_len > buffer + buf_len) return -1;
    135     key->n = malloc(key->n_len);
    136     memcpy(key->n, ptr, key->n_len); ptr += key->n_len;
    137 
    138     key->d_len = *(uint32_t*)ptr; ptr += 4;
    139     if (ptr + key->d_len > buffer + buf_len) return -1;
    140     key->d = malloc(key->d_len);
    141     memcpy(key->d, ptr, key->d_len); ptr += key->d_len;
    142 
    143     key->p_len = *(uint32_t*)ptr; ptr += 4;
    144     if (ptr + key->p_len > buffer + buf_len) return -1;
    145     key->p = malloc(key->p_len);
    146     memcpy(key->p, ptr, key->p_len); ptr += key->p_len;
    147 
    148     key->q_len = *(uint32_t*)ptr; ptr += 4;
    149     if (ptr + key->q_len > buffer + buf_len) return -1;
    150     key->q = malloc(key->q_len);
    151     memcpy(key->q, ptr, key->q_len);
    152 
    153     /* Initialize other CRT parameters to NULL */
    154     key->e = NULL; key->e_len = 0;
    155     key->dp = NULL; key->dp_len = 0;
    156     key->dq = NULL; key->dq_len = 0;
    157     key->qinv = NULL; key->qinv_len = 0;
    158 
    159     return 0;
    160 }
    161 
    162 /* Base64 decode (external from crypto.c) */
    163 extern uint8_t* base64_decode(const char *data, size_t len, size_t *out_len);
    164 
    165 /* rsa.encrypt(public_key_b64, plaintext) - Encrypt with RSA public key */
    166 int l_rsa_encrypt(lua_State *L) {
    167     size_t pub_key_len, plaintext_len;
    168     const char *pub_key_b64 = luaL_checklstring(L, 1, &pub_key_len);
    169     const uint8_t *plaintext = (const uint8_t*)luaL_checklstring(L, 2, &plaintext_len);
    170 
    171     /* Decode public key */
    172     size_t pub_key_decoded_len;
    173     uint8_t *pub_key_decoded = base64_decode(pub_key_b64, pub_key_len, &pub_key_decoded_len);
    174     if (!pub_key_decoded) {
    175         return luaL_error(L, "Invalid public key base64");
    176     }
    177 
    178     rsa_public_key pub;
    179     if (deserialize_public_key(&pub, pub_key_decoded, pub_key_decoded_len) != 0) {
    180         free(pub_key_decoded);
    181         return luaL_error(L, "Invalid public key format");
    182     }
    183     free(pub_key_decoded);
    184 
    185     /* Encrypt */
    186     uint8_t *ciphertext = malloc(pub.n_len);
    187     size_t ct_len;
    188     if (rsa_encrypt(&pub, plaintext, plaintext_len, ciphertext, &ct_len) != 0) {
    189         free(ciphertext);
    190         rsa_free_public_key(&pub);
    191         return luaL_error(L, "RSA encryption failed");
    192     }
    193 
    194     /* Return base64-encoded ciphertext */
    195     size_t b64_len;
    196     char *b64 = base64_encode(ciphertext, ct_len, &b64_len);
    197     free(ciphertext);
    198     rsa_free_public_key(&pub);
    199 
    200     if (!b64) {
    201         return luaL_error(L, "Base64 encoding failed");
    202     }
    203 
    204     lua_pushlstring(L, b64, b64_len);
    205     free(b64);
    206     return 1;
    207 }
    208 
    209 /* rsa.decrypt(private_key_b64, ciphertext_b64) - Decrypt with RSA private key */
    210 int l_rsa_decrypt(lua_State *L) {
    211     size_t priv_key_len, ciphertext_len;
    212     const char *priv_key_b64 = luaL_checklstring(L, 1, &priv_key_len);
    213     const char *ciphertext_b64 = luaL_checklstring(L, 2, &ciphertext_len);
    214 
    215     /* Decode private key */
    216     size_t priv_key_decoded_len;
    217     uint8_t *priv_key_decoded = base64_decode(priv_key_b64, priv_key_len, &priv_key_decoded_len);
    218     if (!priv_key_decoded) {
    219         return luaL_error(L, "Invalid private key base64");
    220     }
    221 
    222     rsa_private_key priv;
    223     if (deserialize_private_key(&priv, priv_key_decoded, priv_key_decoded_len) != 0) {
    224         memset(priv_key_decoded, 0, priv_key_decoded_len);
    225         free(priv_key_decoded);
    226         return luaL_error(L, "Invalid private key format");
    227     }
    228     memset(priv_key_decoded, 0, priv_key_decoded_len);
    229     free(priv_key_decoded);
    230 
    231     /* Decode ciphertext */
    232     size_t ct_decoded_len;
    233     uint8_t *ct_decoded = base64_decode(ciphertext_b64, ciphertext_len, &ct_decoded_len);
    234     if (!ct_decoded) {
    235         rsa_free_private_key(&priv);
    236         return luaL_error(L, "Invalid ciphertext base64");
    237     }
    238 
    239     /* Decrypt */
    240     uint8_t *plaintext = malloc(priv.n_len);
    241     size_t pt_len;
    242     if (rsa_decrypt(&priv, ct_decoded, ct_decoded_len, plaintext, &pt_len) != 0) {
    243         free(plaintext);
    244         free(ct_decoded);
    245         rsa_free_private_key(&priv);
    246         return luaL_error(L, "RSA decryption failed");
    247     }
    248 
    249     free(ct_decoded);
    250     rsa_free_private_key(&priv);
    251 
    252     lua_pushlstring(L, (const char*)plaintext, pt_len);
    253     free(plaintext);
    254     return 1;
    255 }
    256 
    257 /* rsa.sign(private_key_b64, message) - Sign message with RSA private key */
    258 int l_rsa_sign(lua_State *L) {
    259     size_t priv_key_len, message_len;
    260     const char *priv_key_b64 = luaL_checklstring(L, 1, &priv_key_len);
    261     const uint8_t *message = (const uint8_t*)luaL_checklstring(L, 2, &message_len);
    262 
    263     /* Decode private key */
    264     size_t priv_key_decoded_len;
    265     uint8_t *priv_key_decoded = base64_decode(priv_key_b64, priv_key_len, &priv_key_decoded_len);
    266     if (!priv_key_decoded) {
    267         return luaL_error(L, "Invalid private key base64");
    268     }
    269 
    270     rsa_private_key priv;
    271     if (deserialize_private_key(&priv, priv_key_decoded, priv_key_decoded_len) != 0) {
    272         memset(priv_key_decoded, 0, priv_key_decoded_len);
    273         free(priv_key_decoded);
    274         return luaL_error(L, "Invalid private key format");
    275     }
    276     memset(priv_key_decoded, 0, priv_key_decoded_len);
    277     free(priv_key_decoded);
    278 
    279     /* Sign */
    280     uint8_t *signature = malloc(priv.n_len);
    281     size_t sig_len;
    282     if (rsa_sign(&priv, message, message_len, signature, &sig_len) != 0) {
    283         free(signature);
    284         rsa_free_private_key(&priv);
    285         return luaL_error(L, "RSA signing failed");
    286     }
    287 
    288     rsa_free_private_key(&priv);
    289 
    290     /* Return base64-encoded signature */
    291     size_t b64_len;
    292     char *b64 = base64_encode(signature, sig_len, &b64_len);
    293     free(signature);
    294 
    295     if (!b64) {
    296         return luaL_error(L, "Base64 encoding failed");
    297     }
    298 
    299     lua_pushlstring(L, b64, b64_len);
    300     free(b64);
    301     return 1;
    302 }
    303 
    304 /* rsa.verify(public_key_b64, message, signature_b64) - Verify RSA signature */
    305 int l_rsa_verify(lua_State *L) {
    306     size_t pub_key_len, message_len, signature_len;
    307     const char *pub_key_b64 = luaL_checklstring(L, 1, &pub_key_len);
    308     const uint8_t *message = (const uint8_t*)luaL_checklstring(L, 2, &message_len);
    309     const char *signature_b64 = luaL_checklstring(L, 3, &signature_len);
    310 
    311     /* Decode public key */
    312     size_t pub_key_decoded_len;
    313     uint8_t *pub_key_decoded = base64_decode(pub_key_b64, pub_key_len, &pub_key_decoded_len);
    314     if (!pub_key_decoded) {
    315         return luaL_error(L, "Invalid public key base64");
    316     }
    317 
    318     rsa_public_key pub;
    319     if (deserialize_public_key(&pub, pub_key_decoded, pub_key_decoded_len) != 0) {
    320         free(pub_key_decoded);
    321         return luaL_error(L, "Invalid public key format");
    322     }
    323     free(pub_key_decoded);
    324 
    325     /* Decode signature */
    326     size_t sig_decoded_len;
    327     uint8_t *sig_decoded = base64_decode(signature_b64, signature_len, &sig_decoded_len);
    328     if (!sig_decoded) {
    329         rsa_free_public_key(&pub);
    330         return luaL_error(L, "Invalid signature base64");
    331     }
    332 
    333     /* Verify */
    334     int result = rsa_verify(&pub, message, message_len, sig_decoded, sig_decoded_len);
    335 
    336     free(sig_decoded);
    337     rsa_free_public_key(&pub);
    338 
    339     lua_pushboolean(L, result == 0);
    340     return 1;
    341 }
    342 
    343 /**
    344  * RSA-PSS Sign function
    345  * Usage: signature_b64 = crypto.RSA.signPSS(private_key_b64, message_hash)
    346  */
    347 int l_rsa_sign_pss(lua_State *L) {
    348     size_t priv_key_len, message_len;
    349     const char *priv_key_b64 = luaL_checklstring(L, 1, &priv_key_len);
    350     const uint8_t *message = (const uint8_t*)luaL_checklstring(L, 2, &message_len);
    351 
    352     /* Message must be 32 bytes (SHA-256 hash) */
    353     if (message_len != 32) {
    354         return luaL_error(L, "RSA-PSS requires SHA-256 hash (32 bytes), got %d bytes", (int)message_len);
    355     }
    356 
    357     /* Decode private key */
    358     size_t priv_key_decoded_len;
    359     uint8_t *priv_key_decoded = base64_decode(priv_key_b64, priv_key_len, &priv_key_decoded_len);
    360     if (!priv_key_decoded) {
    361         return luaL_error(L, "Invalid private key base64");
    362     }
    363 
    364     rsa_private_key priv;
    365     if (deserialize_private_key(&priv, priv_key_decoded, priv_key_decoded_len) != 0) {
    366         free(priv_key_decoded);
    367         return luaL_error(L, "Invalid private key format");
    368     }
    369     free(priv_key_decoded);
    370 
    371     /* Sign with PSS */
    372     uint8_t *signature = malloc(priv.n_len);
    373     size_t sig_len;
    374     if (rsa_sign_pss(&priv, message, message_len, signature, &sig_len) != 0) {
    375         free(signature);
    376         rsa_free_private_key(&priv);
    377         return luaL_error(L, "RSA-PSS signing failed");
    378     }
    379 
    380     /* Encode to base64 */
    381     size_t b64_len;
    382     char *b64 = base64_encode(signature, sig_len, &b64_len);
    383     free(signature);
    384     rsa_free_private_key(&priv);
    385 
    386     if (!b64) return luaL_error(L, "Base64 encoding failed");
    387 
    388     lua_pushlstring(L, b64, b64_len);
    389     free(b64);
    390     return 1;
    391 }
    392 
    393 /**
    394  * RSA-PSS Verify function
    395  * Usage: valid = crypto.RSA.verifyPSS(public_key_b64, message_hash, signature_b64)
    396  */
    397 int l_rsa_verify_pss(lua_State *L) {
    398     size_t pub_key_len, message_len, sig_len;
    399     const char *pub_key_b64 = luaL_checklstring(L, 1, &pub_key_len);
    400     const uint8_t *message = (const uint8_t*)luaL_checklstring(L, 2, &message_len);
    401     const char *sig_b64 = luaL_checklstring(L, 3, &sig_len);
    402 
    403     /* Message must be 32 bytes (SHA-256 hash) */
    404     if (message_len != 32) {
    405         return luaL_error(L, "RSA-PSS requires SHA-256 hash (32 bytes), got %d bytes", (int)message_len);
    406     }
    407 
    408     /* Decode public key */
    409     size_t pub_key_decoded_len;
    410     uint8_t *pub_key_decoded = base64_decode(pub_key_b64, pub_key_len, &pub_key_decoded_len);
    411     if (!pub_key_decoded) {
    412         return luaL_error(L, "Invalid public key base64");
    413     }
    414 
    415     rsa_public_key pub;
    416     if (deserialize_public_key(&pub, pub_key_decoded, pub_key_decoded_len) != 0) {
    417         free(pub_key_decoded);
    418         return luaL_error(L, "Invalid public key format");
    419     }
    420     free(pub_key_decoded);
    421 
    422     /* Decode signature */
    423     size_t sig_decoded_len;
    424     uint8_t *sig_decoded = base64_decode(sig_b64, sig_len, &sig_decoded_len);
    425     if (!sig_decoded) {
    426         rsa_free_public_key(&pub);
    427         return luaL_error(L, "Invalid signature base64");
    428     }
    429 
    430     /* Verify PSS */
    431     int result = rsa_verify_pss(&pub, message, message_len, sig_decoded, sig_decoded_len);
    432 
    433     free(sig_decoded);
    434     rsa_free_public_key(&pub);
    435 
    436     lua_pushboolean(L, result == 0);
    437     return 1;
    438 }