From a6ede7ba86987b9ae2bb6b8aac60f66af56e7b08 Mon Sep 17 00:00:00 2001 From: Jacob Young Date: Tue, 5 Nov 2024 02:24:14 -0500 Subject: [PATCH] std.crypto.tls: support handshake fragments --- lib/std/crypto/tls/Client.zig | 54 +++++++++++++++++++++-------------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index e10a7273c9..922f7b66cc 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -274,13 +274,14 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client } var tls_version: tls.ProtocolVersion = undefined; - // This is used for two purposes: + // These are used for two purposes: // * Detect whether a certificate is the first one presented, in which case // we need to verify the host name. + var cert_index: usize = 0; // * Flip back and forth between the two cleartext buffers in order to keep // the previous certificate in memory so that it can be verified by the // next one. - var cert_index: usize = 0; + var cert_buf_index: usize = 0; var write_seq: u64 = 0; var read_seq: u64 = 0; var prev_cert: Certificate.Parsed = undefined; @@ -315,10 +316,12 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client var main_cert_pub_key: CertificatePublicKey = undefined; const now_sec = std.time.timestamp(); + var cleartext_fragment_start: usize = 0; + var cleartext_fragment_end: usize = 0; var cleartext_bufs: [2][tls.max_ciphertext_inner_record_len]u8 = undefined; var handshake_buffer: [tls.max_ciphertext_record_len]u8 = undefined; var d: tls.Decoder = .{ .buf = &handshake_buffer }; - while (true) { + fragment: while (true) { try d.readAtLeastOurAmt(stream, tls.record_header_len); const record_header = d.buf[d.idx..][0..tls.record_header_len]; const record_ct = d.decode(tls.ContentType); @@ -332,15 +335,16 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client std.debug.assert(tls_version == .tls_1_3); if (record_ct != .application_data) return error.TlsUnexpectedMessage; try record_decoder.ensure(record_len); - const cleartext_buf = &cleartext_bufs[cert_index % 2]; - const cleartext = cleartext: switch (handshake_cipher) { + const cleartext_buf = &cleartext_bufs[cert_buf_index % 2]; + switch (handshake_cipher) { inline else => |*p| { const pv = &p.version.tls_1_3; const P = @TypeOf(p.*).A; if (record_len < P.AEAD.tag_length) return error.TlsRecordOverflow; const ciphertext = record_decoder.slice(record_len - P.AEAD.tag_length); - if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow; - const cleartext = cleartext_buf[0..ciphertext.len]; + const cleartext_fragment_buf = cleartext_buf[cleartext_fragment_end..]; + if (ciphertext.len > cleartext_fragment_buf.len) return error.TlsRecordOverflow; + const cleartext = cleartext_fragment_buf[0..ciphertext.len]; const auth_tag = record_decoder.array(P.AEAD.tag_length).*; const nonce = if (builtin.zig_backend == .stage2_x86_64 and P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) @@ -357,27 +361,29 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client }; P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, pv.server_handshake_key) catch return error.TlsBadRecordMac; - break :cleartext mem.trimRight(u8, cleartext, "\x00"); + cleartext_fragment_end += std.mem.trimRight(u8, cleartext, "\x00").len; }, - }; + } read_seq += 1; - const ct: tls.ContentType = @enumFromInt(cleartext[cleartext.len - 1]); + cleartext_fragment_end -= 1; + const ct: tls.ContentType = @enumFromInt(cleartext_buf[cleartext_fragment_end]); if (ct != .handshake) return error.TlsUnexpectedMessage; - break :content .{ tls.Decoder.fromTheirSlice(@constCast(cleartext[0 .. cleartext.len - 1])), ct }; + break :content .{ tls.Decoder.fromTheirSlice(@constCast(cleartext_buf[cleartext_fragment_start..cleartext_fragment_end])), ct }; }, .application => { std.debug.assert(tls_version == .tls_1_2); if (record_ct != .handshake) return error.TlsUnexpectedMessage; try record_decoder.ensure(record_len); - const cleartext_buf = &cleartext_bufs[cert_index % 2]; - const cleartext = cleartext: switch (handshake_cipher) { + const cleartext_buf = &cleartext_bufs[cert_buf_index % 2]; + switch (handshake_cipher) { inline else => |*p| { const pv = &p.version.tls_1_2; const P = @TypeOf(p.*).A; if (record_len < P.record_iv_length + P.mac_length) return error.TlsRecordOverflow; const message_len: u16 = record_len - P.record_iv_length - P.mac_length; - if (message_len > cleartext_buf.len) return error.TlsRecordOverflow; - const cleartext = cleartext_buf[0..message_len]; + const cleartext_fragment_buf = cleartext_buf[cleartext_fragment_end..]; + if (message_len > cleartext_fragment_buf.len) return error.TlsRecordOverflow; + const cleartext = cleartext_fragment_buf[0..message_len]; const ad = std.mem.toBytes(big(read_seq)) ++ record_header[0 .. 1 + 2] ++ std.mem.toBytes(big(message_len)); @@ -400,16 +406,16 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client const ciphertext = record_decoder.slice(message_len); const auth_tag = record_decoder.array(P.mac_length); P.AEAD.decrypt(cleartext, ciphertext, auth_tag.*, ad, nonce, pv.app_cipher.server_write_key) catch return error.TlsBadRecordMac; - break :cleartext cleartext; + cleartext_fragment_end += message_len; }, - }; + } read_seq += 1; - break :content .{ tls.Decoder.fromTheirSlice(cleartext), record_ct }; + break :content .{ tls.Decoder.fromTheirSlice(cleartext_buf[cleartext_fragment_start..cleartext_fragment_end]), record_ct }; }, }; switch (ct) { .alert => { - try ctd.ensure(2); + ctd.ensure(2) catch continue :fragment; const level = ctd.decode(tls.AlertLevel); const desc = ctd.decode(tls.AlertDescription); _ = level; @@ -420,15 +426,15 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client return error.TlsUnexpectedMessage; }, .change_cipher_spec => { - try ctd.ensure(1); + ctd.ensure(1) catch continue :fragment; if (ctd.decode(tls.ChangeCipherSpecType) != .change_cipher_spec) return error.TlsIllegalParameter; cipher_state = pending_cipher_state; }, .handshake => while (true) { - try ctd.ensure(4); + ctd.ensure(4) catch continue :fragment; const handshake_type = ctd.decode(tls.HandshakeType); const handshake_len = ctd.decode(u24); - var hsd = try ctd.sub(handshake_len); + var hsd = ctd.sub(handshake_len) catch continue :fragment; const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx]; switch (handshake_type) { .server_hello => { @@ -657,6 +663,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client prev_cert = subject; cert_index += 1; } + cert_buf_index += 1; }, .server_key_exchange => { if (tls_version != .tls_1_2) return error.TlsUnexpectedMessage; @@ -892,9 +899,12 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client else => return error.TlsUnexpectedMessage, } if (ctd.eof()) break; + cleartext_fragment_start = ctd.idx; }, else => return error.TlsUnexpectedMessage, } + cleartext_fragment_start = 0; + cleartext_fragment_end = 0; } }