std.crypto.tls: support handshake fragments

This commit is contained in:
Jacob Young 2024-11-05 02:24:14 -05:00
parent de53e6e4f2
commit a6ede7ba86

View File

@ -274,13 +274,14 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
}
var tls_version: tls.ProtocolVersion = undefined;
// This is used for two purposes:
// These are used for two purposes:
// * Detect whether a certificate is the first one presented, in which case
// we need to verify the host name.
var cert_index: usize = 0;
// * Flip back and forth between the two cleartext buffers in order to keep
// the previous certificate in memory so that it can be verified by the
// next one.
var cert_index: usize = 0;
var cert_buf_index: usize = 0;
var write_seq: u64 = 0;
var read_seq: u64 = 0;
var prev_cert: Certificate.Parsed = undefined;
@ -315,10 +316,12 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
var main_cert_pub_key: CertificatePublicKey = undefined;
const now_sec = std.time.timestamp();
var cleartext_fragment_start: usize = 0;
var cleartext_fragment_end: usize = 0;
var cleartext_bufs: [2][tls.max_ciphertext_inner_record_len]u8 = undefined;
var handshake_buffer: [tls.max_ciphertext_record_len]u8 = undefined;
var d: tls.Decoder = .{ .buf = &handshake_buffer };
while (true) {
fragment: while (true) {
try d.readAtLeastOurAmt(stream, tls.record_header_len);
const record_header = d.buf[d.idx..][0..tls.record_header_len];
const record_ct = d.decode(tls.ContentType);
@ -332,15 +335,16 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
std.debug.assert(tls_version == .tls_1_3);
if (record_ct != .application_data) return error.TlsUnexpectedMessage;
try record_decoder.ensure(record_len);
const cleartext_buf = &cleartext_bufs[cert_index % 2];
const cleartext = cleartext: switch (handshake_cipher) {
const cleartext_buf = &cleartext_bufs[cert_buf_index % 2];
switch (handshake_cipher) {
inline else => |*p| {
const pv = &p.version.tls_1_3;
const P = @TypeOf(p.*).A;
if (record_len < P.AEAD.tag_length) return error.TlsRecordOverflow;
const ciphertext = record_decoder.slice(record_len - P.AEAD.tag_length);
if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow;
const cleartext = cleartext_buf[0..ciphertext.len];
const cleartext_fragment_buf = cleartext_buf[cleartext_fragment_end..];
if (ciphertext.len > cleartext_fragment_buf.len) return error.TlsRecordOverflow;
const cleartext = cleartext_fragment_buf[0..ciphertext.len];
const auth_tag = record_decoder.array(P.AEAD.tag_length).*;
const nonce = if (builtin.zig_backend == .stage2_x86_64 and
P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1)
@ -357,27 +361,29 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
};
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, pv.server_handshake_key) catch
return error.TlsBadRecordMac;
break :cleartext mem.trimRight(u8, cleartext, "\x00");
cleartext_fragment_end += std.mem.trimRight(u8, cleartext, "\x00").len;
},
};
}
read_seq += 1;
const ct: tls.ContentType = @enumFromInt(cleartext[cleartext.len - 1]);
cleartext_fragment_end -= 1;
const ct: tls.ContentType = @enumFromInt(cleartext_buf[cleartext_fragment_end]);
if (ct != .handshake) return error.TlsUnexpectedMessage;
break :content .{ tls.Decoder.fromTheirSlice(@constCast(cleartext[0 .. cleartext.len - 1])), ct };
break :content .{ tls.Decoder.fromTheirSlice(@constCast(cleartext_buf[cleartext_fragment_start..cleartext_fragment_end])), ct };
},
.application => {
std.debug.assert(tls_version == .tls_1_2);
if (record_ct != .handshake) return error.TlsUnexpectedMessage;
try record_decoder.ensure(record_len);
const cleartext_buf = &cleartext_bufs[cert_index % 2];
const cleartext = cleartext: switch (handshake_cipher) {
const cleartext_buf = &cleartext_bufs[cert_buf_index % 2];
switch (handshake_cipher) {
inline else => |*p| {
const pv = &p.version.tls_1_2;
const P = @TypeOf(p.*).A;
if (record_len < P.record_iv_length + P.mac_length) return error.TlsRecordOverflow;
const message_len: u16 = record_len - P.record_iv_length - P.mac_length;
if (message_len > cleartext_buf.len) return error.TlsRecordOverflow;
const cleartext = cleartext_buf[0..message_len];
const cleartext_fragment_buf = cleartext_buf[cleartext_fragment_end..];
if (message_len > cleartext_fragment_buf.len) return error.TlsRecordOverflow;
const cleartext = cleartext_fragment_buf[0..message_len];
const ad = std.mem.toBytes(big(read_seq)) ++
record_header[0 .. 1 + 2] ++
std.mem.toBytes(big(message_len));
@ -400,16 +406,16 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
const ciphertext = record_decoder.slice(message_len);
const auth_tag = record_decoder.array(P.mac_length);
P.AEAD.decrypt(cleartext, ciphertext, auth_tag.*, ad, nonce, pv.app_cipher.server_write_key) catch return error.TlsBadRecordMac;
break :cleartext cleartext;
cleartext_fragment_end += message_len;
},
};
}
read_seq += 1;
break :content .{ tls.Decoder.fromTheirSlice(cleartext), record_ct };
break :content .{ tls.Decoder.fromTheirSlice(cleartext_buf[cleartext_fragment_start..cleartext_fragment_end]), record_ct };
},
};
switch (ct) {
.alert => {
try ctd.ensure(2);
ctd.ensure(2) catch continue :fragment;
const level = ctd.decode(tls.AlertLevel);
const desc = ctd.decode(tls.AlertDescription);
_ = level;
@ -420,15 +426,15 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
return error.TlsUnexpectedMessage;
},
.change_cipher_spec => {
try ctd.ensure(1);
ctd.ensure(1) catch continue :fragment;
if (ctd.decode(tls.ChangeCipherSpecType) != .change_cipher_spec) return error.TlsIllegalParameter;
cipher_state = pending_cipher_state;
},
.handshake => while (true) {
try ctd.ensure(4);
ctd.ensure(4) catch continue :fragment;
const handshake_type = ctd.decode(tls.HandshakeType);
const handshake_len = ctd.decode(u24);
var hsd = try ctd.sub(handshake_len);
var hsd = ctd.sub(handshake_len) catch continue :fragment;
const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx];
switch (handshake_type) {
.server_hello => {
@ -657,6 +663,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
prev_cert = subject;
cert_index += 1;
}
cert_buf_index += 1;
},
.server_key_exchange => {
if (tls_version != .tls_1_2) return error.TlsUnexpectedMessage;
@ -892,9 +899,12 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
else => return error.TlsUnexpectedMessage,
}
if (ctd.eof()) break;
cleartext_fragment_start = ctd.idx;
},
else => return error.TlsUnexpectedMessage,
}
cleartext_fragment_start = 0;
cleartext_fragment_end = 0;
}
}