luajitos

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs

TCP.lua (13703B)


      1 -- TCP Protocol Implementation
      2 -- Implements a basic TCP state machine and connection handling
      3 
      4 local TCP = {}
      5 
      6 -- TCP Flags
      7 TCP.FLAGS = {
      8     FIN = 0x01,
      9     SYN = 0x02,
     10     RST = 0x04,
     11     PSH = 0x08,
     12     ACK = 0x10,
     13     URG = 0x20,
     14 }
     15 
     16 -- TCP States
     17 TCP.STATE = {
     18     CLOSED = 0,
     19     LISTEN = 1,
     20     SYN_SENT = 2,
     21     SYN_RECEIVED = 3,
     22     ESTABLISHED = 4,
     23     FIN_WAIT_1 = 5,
     24     FIN_WAIT_2 = 6,
     25     CLOSE_WAIT = 7,
     26     CLOSING = 8,
     27     LAST_ACK = 9,
     28     TIME_WAIT = 10,
     29 }
     30 
     31 -- TCP connections
     32 TCP.connections = {}
     33 TCP.next_port = 49152
     34 
     35 ---Calculate TCP checksum (with pseudo-header)
     36 ---@param src_ip table Source IP (4 bytes)
     37 ---@param dst_ip table Destination IP (4 bytes)
     38 ---@param tcp_packet string TCP packet
     39 ---@return number checksum 16-bit checksum
     40 function TCP.calculate_checksum(src_ip, dst_ip, tcp_packet)
     41     local pseudo = {}
     42 
     43     -- Source IP
     44     for i = 1, 4 do
     45         pseudo[#pseudo + 1] = string.char(src_ip[i])
     46     end
     47 
     48     -- Destination IP
     49     for i = 1, 4 do
     50         pseudo[#pseudo + 1] = string.char(dst_ip[i])
     51     end
     52 
     53     -- Zero + Protocol (6 = TCP)
     54     pseudo[#pseudo + 1] = string.char(0x00, 0x06)
     55 
     56     -- TCP length
     57     local len = #tcp_packet
     58     pseudo[#pseudo + 1] = string.char(math.floor(len / 256), len % 256)
     59 
     60     -- TCP packet
     61     pseudo[#pseudo + 1] = tcp_packet
     62 
     63     -- Calculate checksum on pseudo-header + TCP packet
     64     local data = table.concat(pseudo)
     65     local sum = 0
     66     local i = 1
     67 
     68     while i < #data do
     69         local word = string.byte(data, i) * 256 + string.byte(data, i + 1)
     70         sum = sum + word
     71         i = i + 2
     72     end
     73 
     74     if i == #data then
     75         sum = sum + string.byte(data, i) * 256
     76     end
     77 
     78     -- Fold to 16 bits
     79     while sum > 0xFFFF do
     80         sum = (sum & 0xFFFF) + (sum >> 16)
     81     end
     82 
     83     return ~sum & 0xFFFF
     84 end
     85 
     86 ---Build TCP packet
     87 ---@param src_port number Source port
     88 ---@param dst_port number Destination port
     89 ---@param seq number Sequence number
     90 ---@param ack number Acknowledgment number
     91 ---@param flags number TCP flags
     92 ---@param window number Window size
     93 ---@param payload string Payload data
     94 ---@param src_ip table Source IP
     95 ---@param dst_ip table Destination IP
     96 ---@return string packet TCP packet
     97 function TCP.build_packet(src_port, dst_port, seq, ack, flags, window, payload, src_ip, dst_ip)
     98     local tcp = {}
     99 
    100     -- Source port
    101     tcp[#tcp + 1] = string.char(math.floor(src_port / 256), src_port % 256)
    102 
    103     -- Destination port
    104     tcp[#tcp + 1] = string.char(math.floor(dst_port / 256), dst_port % 256)
    105 
    106     -- Sequence number
    107     tcp[#tcp + 1] = string.char(
    108         (seq >> 24) & 0xFF,
    109         (seq >> 16) & 0xFF,
    110         (seq >> 8) & 0xFF,
    111         seq & 0xFF
    112     )
    113 
    114     -- Acknowledgment number
    115     tcp[#tcp + 1] = string.char(
    116         (ack >> 24) & 0xFF,
    117         (ack >> 16) & 0xFF,
    118         (ack >> 8) & 0xFF,
    119         ack & 0xFF
    120     )
    121 
    122     -- Data offset (5 = 20 bytes, no options) + reserved + flags
    123     local data_offset = 5 << 4  -- 20 bytes header
    124     tcp[#tcp + 1] = string.char(data_offset, flags)
    125 
    126     -- Window size
    127     tcp[#tcp + 1] = string.char(math.floor(window / 256), window % 256)
    128 
    129     -- Checksum placeholder
    130     local checksum_pos = #tcp + 1
    131     tcp[#tcp + 1] = string.char(0x00, 0x00)
    132 
    133     -- Urgent pointer
    134     tcp[#tcp + 1] = string.char(0x00, 0x00)
    135 
    136     -- Payload
    137     tcp[#tcp + 1] = payload
    138 
    139     -- Calculate checksum
    140     local packet = table.concat(tcp)
    141     local checksum = TCP.calculate_checksum(src_ip, dst_ip, packet)
    142 
    143     -- Insert checksum
    144     tcp[checksum_pos] = string.char(math.floor(checksum / 256), checksum % 256)
    145 
    146     return table.concat(tcp)
    147 end
    148 
    149 ---Parse TCP packet
    150 ---@param payload string TCP payload
    151 ---@return table|nil tcp Parsed TCP packet
    152 function TCP.parse_packet(payload)
    153     if #payload < 20 then
    154         return nil
    155     end
    156 
    157     local tcp = {}
    158 
    159     -- Source port
    160     tcp.src_port = string.byte(payload, 1) * 256 + string.byte(payload, 2)
    161 
    162     -- Destination port
    163     tcp.dst_port = string.byte(payload, 3) * 256 + string.byte(payload, 4)
    164 
    165     -- Sequence number
    166     tcp.seq = (string.byte(payload, 5) << 24) |
    167               (string.byte(payload, 6) << 16) |
    168               (string.byte(payload, 7) << 8) |
    169               string.byte(payload, 8)
    170 
    171     -- Acknowledgment number
    172     tcp.ack = (string.byte(payload, 9) << 24) |
    173               (string.byte(payload, 10) << 16) |
    174               (string.byte(payload, 11) << 8) |
    175               string.byte(payload, 12)
    176 
    177     -- Data offset and flags
    178     local data_offset_flags = string.byte(payload, 13)
    179     tcp.data_offset = (data_offset_flags >> 4) * 4  -- Convert to bytes
    180     tcp.flags = string.byte(payload, 14)
    181 
    182     -- Window size
    183     tcp.window = string.byte(payload, 15) * 256 + string.byte(payload, 16)
    184 
    185     -- Checksum
    186     tcp.checksum = string.byte(payload, 17) * 256 + string.byte(payload, 18)
    187 
    188     -- Payload (after header)
    189     tcp.payload = string.sub(payload, tcp.data_offset + 1)
    190 
    191     return tcp
    192 end
    193 
    194 ---Create a new TCP connection
    195 ---@param NetworkStack table Network stack instance
    196 ---@param dst_ip table Destination IP
    197 ---@param dst_port number Destination port
    198 ---@param src_port number|nil Source port (auto-assigned if nil)
    199 ---@return table|nil connection Connection object
    200 function TCP.connect(NetworkStack, dst_ip, dst_port, src_port)
    201     src_port = src_port or TCP.next_port
    202     TCP.next_port = TCP.next_port + 1
    203     if TCP.next_port > 65535 then
    204         TCP.next_port = 49152
    205     end
    206 
    207     local conn_id = src_port .. "_" .. table.concat(dst_ip, ".") .. "_" .. dst_port
    208 
    209     local connection = {
    210         id = conn_id,
    211         src_port = src_port,
    212         dst_ip = dst_ip,
    213         dst_port = dst_port,
    214         state = TCP.STATE.SYN_SENT,
    215         seq = math.random(1, 0x7FFFFFFF),  -- Initial sequence number
    216         ack = 0,
    217         window = 8192,
    218         receive_buffer = {},
    219         send_buffer = {},
    220         callbacks = {},
    221         NetworkStack = NetworkStack,
    222     }
    223 
    224     TCP.connections[conn_id] = connection
    225 
    226     -- Send SYN
    227     local syn_packet = TCP.build_packet(
    228         src_port, dst_port,
    229         connection.seq, 0,
    230         TCP.FLAGS.SYN,
    231         connection.window,
    232         "",
    233         NetworkStack.config.ip, dst_ip
    234     )
    235 
    236     local ip_packet = NetworkStack.build_ipv4(6, NetworkStack.config.ip, dst_ip, syn_packet)
    237 
    238     -- Try to send (may need ARP resolution)
    239     local dst_mac = NetworkStack.arp_resolve(NetworkStack.RTL8139, dst_ip, 5)
    240     if dst_mac then
    241         local our_mac = NetworkStack.RTL8139.getMACAddress()
    242         local frame = NetworkStack.RTL8139.buildEthernetFrame(dst_mac, our_mac, 0x0800, ip_packet)
    243         NetworkStack.RTL8139.send(frame)
    244         connection.seq = connection.seq + 1  -- SYN consumes one sequence number
    245     end
    246 
    247     return connection
    248 end
    249 
    250 ---Send data on TCP connection
    251 ---@param connection table Connection object
    252 ---@param data string Data to send
    253 ---@return boolean success True if queued successfully
    254 function TCP.send(connection, data)
    255     if connection.state ~= TCP.STATE.ESTABLISHED then
    256         return false
    257     end
    258 
    259     local tcp_packet = TCP.build_packet(
    260         connection.src_port, connection.dst_port,
    261         connection.seq, connection.ack,
    262         TCP.FLAGS.PSH | TCP.FLAGS.ACK,
    263         connection.window,
    264         data,
    265         connection.NetworkStack.config.ip, connection.dst_ip
    266     )
    267 
    268     local ip_packet = connection.NetworkStack.build_ipv4(
    269         6, connection.NetworkStack.config.ip, connection.dst_ip, tcp_packet
    270     )
    271 
    272     local dst_mac = connection.NetworkStack.arp_resolve(
    273         connection.NetworkStack.RTL8139, connection.dst_ip, 5
    274     )
    275     if dst_mac then
    276         local our_mac = connection.NetworkStack.RTL8139.getMACAddress()
    277         local frame = connection.NetworkStack.RTL8139.buildEthernetFrame(
    278             dst_mac, our_mac, 0x0800, ip_packet
    279         )
    280         connection.NetworkStack.RTL8139.send(frame)
    281         connection.seq = connection.seq + #data
    282         return true
    283     end
    284 
    285     return false
    286 end
    287 
    288 ---Close TCP connection
    289 ---@param connection table Connection object
    290 function TCP.close(connection)
    291     if connection.state == TCP.STATE.ESTABLISHED then
    292         local fin_packet = TCP.build_packet(
    293             connection.src_port, connection.dst_port,
    294             connection.seq, connection.ack,
    295             TCP.FLAGS.FIN | TCP.FLAGS.ACK,
    296             connection.window,
    297             "",
    298             connection.NetworkStack.config.ip, connection.dst_ip
    299         )
    300 
    301         local ip_packet = connection.NetworkStack.build_ipv4(
    302             6, connection.NetworkStack.config.ip, connection.dst_ip, fin_packet
    303         )
    304 
    305         local dst_mac = connection.NetworkStack.arp_resolve(
    306             connection.NetworkStack.RTL8139, connection.dst_ip, 5
    307         )
    308         if dst_mac then
    309             local our_mac = connection.NetworkStack.RTL8139.getMACAddress()
    310             local frame = connection.NetworkStack.RTL8139.buildEthernetFrame(
    311                 dst_mac, our_mac, 0x0800, ip_packet
    312             )
    313             connection.NetworkStack.RTL8139.send(frame)
    314             connection.seq = connection.seq + 1  -- FIN consumes one sequence number
    315             connection.state = TCP.STATE.FIN_WAIT_1
    316         end
    317     end
    318 end
    319 
    320 ---Handle incoming TCP packet
    321 ---@param NetworkStack table Network stack instance
    322 ---@param src_ip table Source IP
    323 ---@param dst_ip table Destination IP
    324 ---@param tcp table Parsed TCP packet
    325 function TCP.handle_packet(NetworkStack, src_ip, dst_ip, tcp)
    326     -- Find matching connection
    327     local conn_id = tcp.dst_port .. "_" .. table.concat(src_ip, ".") .. "_" .. tcp.src_port
    328     local connection = TCP.connections[conn_id]
    329 
    330     if not connection then
    331         -- No connection found - send RST
    332         return
    333     end
    334 
    335     -- Handle based on state
    336     if connection.state == TCP.STATE.SYN_SENT then
    337         if (tcp.flags & TCP.FLAGS.SYN) ~= 0 and (tcp.flags & TCP.FLAGS.ACK) ~= 0 then
    338             -- SYN-ACK received
    339             connection.ack = tcp.seq + 1
    340             connection.state = TCP.STATE.ESTABLISHED
    341 
    342             -- Send ACK
    343             local ack_packet = TCP.build_packet(
    344                 connection.src_port, connection.dst_port,
    345                 connection.seq, connection.ack,
    346                 TCP.FLAGS.ACK,
    347                 connection.window,
    348                 "",
    349                 NetworkStack.config.ip, connection.dst_ip
    350             )
    351 
    352             local ip_packet = NetworkStack.build_ipv4(6, NetworkStack.config.ip, connection.dst_ip, ack_packet)
    353             local dst_mac = NetworkStack.arp_resolve(NetworkStack.RTL8139, connection.dst_ip, 5)
    354             if dst_mac then
    355                 local our_mac = NetworkStack.RTL8139.getMACAddress()
    356                 local frame = NetworkStack.RTL8139.buildEthernetFrame(dst_mac, our_mac, 0x0800, ip_packet)
    357                 NetworkStack.RTL8139.send(frame)
    358             end
    359 
    360             -- Call connected callback
    361             if connection.callbacks.on_connected then
    362                 connection.callbacks.on_connected()
    363             end
    364         end
    365     elseif connection.state == TCP.STATE.ESTABLISHED then
    366         if #tcp.payload > 0 then
    367             -- Data received
    368             connection.ack = tcp.seq + #tcp.payload
    369 
    370             -- Store in receive buffer
    371             table.insert(connection.receive_buffer, tcp.payload)
    372 
    373             -- Send ACK
    374             local ack_packet = TCP.build_packet(
    375                 connection.src_port, connection.dst_port,
    376                 connection.seq, connection.ack,
    377                 TCP.FLAGS.ACK,
    378                 connection.window,
    379                 "",
    380                 NetworkStack.config.ip, connection.dst_ip
    381             )
    382 
    383             local ip_packet = NetworkStack.build_ipv4(6, NetworkStack.config.ip, connection.dst_ip, ack_packet)
    384             local dst_mac = NetworkStack.arp_resolve(NetworkStack.RTL8139, connection.dst_ip, 5)
    385             if dst_mac then
    386                 local our_mac = NetworkStack.RTL8139.getMACAddress()
    387                 local frame = NetworkStack.RTL8139.buildEthernetFrame(dst_mac, our_mac, 0x0800, ip_packet)
    388                 NetworkStack.RTL8139.send(frame)
    389             end
    390 
    391             -- Call data callback
    392             if connection.callbacks.on_data then
    393                 connection.callbacks.on_data(tcp.payload)
    394             end
    395         end
    396 
    397         if (tcp.flags & TCP.FLAGS.FIN) ~= 0 then
    398             -- FIN received
    399             connection.ack = tcp.seq + 1
    400             connection.state = TCP.STATE.CLOSE_WAIT
    401 
    402             -- Send ACK
    403             local ack_packet = TCP.build_packet(
    404                 connection.src_port, connection.dst_port,
    405                 connection.seq, connection.ack,
    406                 TCP.FLAGS.ACK,
    407                 connection.window,
    408                 "",
    409                 NetworkStack.config.ip, connection.dst_ip
    410             )
    411 
    412             local ip_packet = NetworkStack.build_ipv4(6, NetworkStack.config.ip, connection.dst_ip, ack_packet)
    413             local dst_mac = NetworkStack.arp_resolve(NetworkStack.RTL8139, connection.dst_ip, 5)
    414             if dst_mac then
    415                 local our_mac = NetworkStack.RTL8139.getMACAddress()
    416                 local frame = NetworkStack.RTL8139.buildEthernetFrame(dst_mac, our_mac, 0x0800, ip_packet)
    417                 NetworkStack.RTL8139.send(frame)
    418             end
    419 
    420             -- Call closed callback
    421             if connection.callbacks.on_closed then
    422                 connection.callbacks.on_closed()
    423             end
    424         end
    425     end
    426 end
    427 
    428 ---Read data from connection
    429 ---@param connection table Connection object
    430 ---@return string|nil data Received data
    431 function TCP.read(connection)
    432     if #connection.receive_buffer > 0 then
    433         local data = table.concat(connection.receive_buffer)
    434         connection.receive_buffer = {}
    435         return data
    436     end
    437     return nil
    438 end
    439 
    440 return TCP