std.crypto.tls: remove hardcoded initial loop

This was preventing TLSv1.2 from working in some cases, because servers
are allowed to send multiple handshake messages in the first handshake
record, whereas this inital loop was assuming that it only contained a
server hello.
This commit is contained in:
Jacob Young 2024-11-04 20:45:18 -05:00
parent 90a761c186
commit 485f20a10a

View File

@ -214,158 +214,7 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
try stream.writevAll(&iovecs);
}
const client_hello_bytes1 = cleartext_header[tls.record_header_len..];
var tls_version: tls.ProtocolVersion = undefined;
var cipher_suite_tag: tls.CipherSuite = undefined;
var handshake_cipher: tls.HandshakeCipher = undefined;
var handshake_buffer: [8000]u8 = undefined;
var d: tls.Decoder = .{ .buf = &handshake_buffer };
{
try d.readAtLeastOurAmt(stream, tls.record_header_len);
const ct = d.decode(tls.ContentType);
d.skip(2); // legacy_record_version
const record_len = d.decode(u16);
try d.readAtLeast(stream, record_len);
const server_hello_fragment = d.buf[d.idx..][0..record_len];
var ptd = try d.sub(record_len);
switch (ct) {
.alert => {
try ptd.ensure(2);
const level = ptd.decode(tls.AlertLevel);
const desc = ptd.decode(tls.AlertDescription);
_ = level;
// if this isn't a error alert, then it's a closure alert, which makes no sense in a handshake
try desc.toError();
// TODO: handle server-side closures
return error.TlsUnexpectedMessage;
},
.handshake => {
try ptd.ensure(4);
const handshake_type = ptd.decode(tls.HandshakeType);
if (handshake_type != .server_hello) return error.TlsUnexpectedMessage;
const length = ptd.decode(u24);
var hsd = try ptd.sub(length);
try hsd.ensure(2 + 32 + 1);
const legacy_version = hsd.decode(u16);
@memcpy(&server_hello_rand, hsd.array(32));
if (mem.eql(u8, &server_hello_rand, &tls.hello_retry_request_sequence)) {
// This is a HelloRetryRequest message. This client implementation
// does not expect to get one.
return error.TlsUnexpectedMessage;
}
const legacy_session_id_echo_len = hsd.decode(u8);
try hsd.ensure(legacy_session_id_echo_len + 2 + 1);
const legacy_session_id_echo = hsd.slice(legacy_session_id_echo_len);
cipher_suite_tag = hsd.decode(tls.CipherSuite);
hsd.skip(1); // legacy_compression_method
var supported_version: ?u16 = null;
if (!hsd.eof()) {
try hsd.ensure(2);
const extensions_size = hsd.decode(u16);
var all_extd = try hsd.sub(extensions_size);
while (!all_extd.eof()) {
try all_extd.ensure(2 + 2);
const et = all_extd.decode(tls.ExtensionType);
const ext_size = all_extd.decode(u16);
var extd = try all_extd.sub(ext_size);
switch (et) {
.supported_versions => {
if (supported_version) |_| return error.TlsIllegalParameter;
try extd.ensure(2);
supported_version = extd.decode(u16);
},
.key_share => {
if (key_share.getSharedSecret()) |_| return error.TlsIllegalParameter;
try extd.ensure(4);
const named_group = extd.decode(tls.NamedGroup);
const key_size = extd.decode(u16);
try extd.ensure(key_size);
try key_share.exchange(named_group, extd.slice(key_size));
},
else => {},
}
}
}
tls_version = @enumFromInt(supported_version orelse legacy_version);
switch (tls_version) {
.tls_1_3 => if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id)) return error.TlsIllegalParameter,
.tls_1_2 => if (mem.eql(u8, server_hello_rand[24..31], "DOWNGRD") and
server_hello_rand[31] >> 1 == 0x00) return error.TlsIllegalParameter,
else => return error.TlsIllegalParameter,
}
switch (cipher_suite_tag) {
inline .AES_128_GCM_SHA256,
.AES_256_GCM_SHA384,
.CHACHA20_POLY1305_SHA256,
.AEGIS_256_SHA512,
.AEGIS_128L_SHA256,
.ECDHE_RSA_WITH_AES_128_GCM_SHA256,
.ECDHE_RSA_WITH_AES_256_GCM_SHA384,
.ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
=> |tag| {
handshake_cipher = @unionInit(tls.HandshakeCipher, @tagName(tag.with()), .{
.transcript_hash = .init(.{}),
.version = undefined,
});
const p = &@field(handshake_cipher, @tagName(tag.with()));
p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1
p.transcript_hash.update(host); // Client Hello part 2
p.transcript_hash.update(server_hello_fragment);
},
else => return error.TlsIllegalParameter,
}
switch (tls_version) {
.tls_1_3 => switch (cipher_suite_tag) {
inline .AES_128_GCM_SHA256,
.AES_256_GCM_SHA384,
.CHACHA20_POLY1305_SHA256,
.AEGIS_256_SHA512,
.AEGIS_128L_SHA256,
=> |tag| {
const sk = key_share.getSharedSecret() orelse return error.TlsIllegalParameter;
const p = &@field(handshake_cipher, @tagName(tag.with()));
const P = @TypeOf(p.*).A;
const hello_hash = p.transcript_hash.peek();
const zeroes = [1]u8{0} ** P.Hash.digest_length;
const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes);
const empty_hash = tls.emptyHash(P.Hash);
p.version = .{ .tls_1_3 = undefined };
const pv = &p.version.tls_1_3;
const hs_derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length);
pv.handshake_secret = P.Hkdf.extract(&hs_derived_secret, sk);
const ap_derived_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "derived", &empty_hash, P.Hash.digest_length);
pv.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes);
const client_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length);
const server_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length);
pv.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length);
pv.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length);
pv.client_handshake_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length);
pv.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length);
pv.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length);
pv.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length);
},
else => return error.TlsIllegalParameter,
},
.tls_1_2 => switch (cipher_suite_tag) {
.ECDHE_RSA_WITH_AES_128_GCM_SHA256,
.ECDHE_RSA_WITH_AES_256_GCM_SHA384,
.ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
=> {},
else => return error.TlsIllegalParameter,
},
else => return error.TlsIllegalParameter,
}
},
else => return error.TlsUnexpectedMessage,
}
}
// This is used for two purposes:
// * Detect whether a certificate is the first one presented, in which case
// we need to verify the host name.
@ -384,13 +233,11 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
/// Application cipher is in use
application,
};
var pending_cipher_state: CipherState = switch (tls_version) {
.tls_1_3 => .handshake,
.tls_1_2 => .cleartext,
else => unreachable,
};
var cipher_state: CipherState = .cleartext;
var pending_cipher_state: CipherState = .cleartext;
var cipher_state = pending_cipher_state;
const HandshakeState = enum {
/// In this state we expect only a server hello message.
hello,
/// In this state we expect only an encrypted_extensions message.
encrypted_extensions,
/// In this state we expect certificate handshake messages.
@ -404,15 +251,14 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
/// In this state, we expect only the finished handshake message.
finished,
};
var handshake_state: HandshakeState = switch (tls_version) {
.tls_1_3 => .encrypted_extensions,
.tls_1_2 => .certificate,
else => unreachable,
};
var cleartext_bufs: [2][8000]u8 = undefined;
var handshake_state: HandshakeState = .hello;
var handshake_cipher: tls.HandshakeCipher = undefined;
var main_cert_pub_key: CertificatePublicKey = undefined;
const now_sec = std.time.timestamp();
var cleartext_bufs: [2][8000]u8 = undefined;
var handshake_buffer: [8000]u8 = undefined;
var d: tls.Decoder = .{ .buf = &handshake_buffer };
while (true) {
try d.readAtLeastOurAmt(stream, tls.record_header_len);
const record_header = d.buf[d.idx..][0..tls.record_header_len];
@ -526,11 +372,132 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
var hsd = try ctd.sub(handshake_len);
const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx];
switch (handshake_type) {
.server_hello => {
if (cipher_state != .cleartext) return error.TlsUnexpectedMessage;
if (handshake_state != .hello) return error.TlsUnexpectedMessage;
try hsd.ensure(2 + 32 + 1);
const legacy_version = hsd.decode(u16);
@memcpy(&server_hello_rand, hsd.array(32));
if (mem.eql(u8, &server_hello_rand, &tls.hello_retry_request_sequence)) {
// This is a HelloRetryRequest message. This client implementation
// does not expect to get one.
return error.TlsUnexpectedMessage;
}
const legacy_session_id_echo_len = hsd.decode(u8);
try hsd.ensure(legacy_session_id_echo_len + 2 + 1);
const legacy_session_id_echo = hsd.slice(legacy_session_id_echo_len);
const cipher_suite_tag = hsd.decode(tls.CipherSuite);
hsd.skip(1); // legacy_compression_method
var supported_version: ?u16 = null;
if (!hsd.eof()) {
try hsd.ensure(2);
const extensions_size = hsd.decode(u16);
var all_extd = try hsd.sub(extensions_size);
while (!all_extd.eof()) {
try all_extd.ensure(2 + 2);
const et = all_extd.decode(tls.ExtensionType);
const ext_size = all_extd.decode(u16);
var extd = try all_extd.sub(ext_size);
switch (et) {
.supported_versions => {
if (supported_version) |_| return error.TlsIllegalParameter;
try extd.ensure(2);
supported_version = extd.decode(u16);
},
.key_share => {
if (key_share.getSharedSecret()) |_| return error.TlsIllegalParameter;
try extd.ensure(4);
const named_group = extd.decode(tls.NamedGroup);
const key_size = extd.decode(u16);
try extd.ensure(key_size);
try key_share.exchange(named_group, extd.slice(key_size));
},
else => {},
}
}
}
tls_version = @enumFromInt(supported_version orelse legacy_version);
switch (tls_version) {
.tls_1_3 => if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id)) return error.TlsIllegalParameter,
.tls_1_2 => if (mem.eql(u8, server_hello_rand[24..31], "DOWNGRD") and
server_hello_rand[31] >> 1 == 0x00) return error.TlsIllegalParameter,
else => return error.TlsIllegalParameter,
}
switch (cipher_suite_tag) {
inline .AES_128_GCM_SHA256,
.AES_256_GCM_SHA384,
.CHACHA20_POLY1305_SHA256,
.AEGIS_256_SHA512,
.AEGIS_128L_SHA256,
.ECDHE_RSA_WITH_AES_128_GCM_SHA256,
.ECDHE_RSA_WITH_AES_256_GCM_SHA384,
.ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
=> |tag| {
handshake_cipher = @unionInit(tls.HandshakeCipher, @tagName(tag.with()), .{
.transcript_hash = .init(.{}),
.version = undefined,
});
const p = &@field(handshake_cipher, @tagName(tag.with()));
p.transcript_hash.update(cleartext_header[tls.record_header_len..]); // Client Hello part 1
p.transcript_hash.update(host); // Client Hello part 2
p.transcript_hash.update(wrapped_handshake);
},
else => return error.TlsIllegalParameter,
}
switch (tls_version) {
.tls_1_3 => {
switch (cipher_suite_tag) {
inline .AES_128_GCM_SHA256,
.AES_256_GCM_SHA384,
.CHACHA20_POLY1305_SHA256,
.AEGIS_256_SHA512,
.AEGIS_128L_SHA256,
=> |tag| {
const sk = key_share.getSharedSecret() orelse return error.TlsIllegalParameter;
const p = &@field(handshake_cipher, @tagName(tag.with()));
const P = @TypeOf(p.*).A;
const hello_hash = p.transcript_hash.peek();
const zeroes = [1]u8{0} ** P.Hash.digest_length;
const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes);
const empty_hash = tls.emptyHash(P.Hash);
p.version = .{ .tls_1_3 = undefined };
const pv = &p.version.tls_1_3;
const hs_derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length);
pv.handshake_secret = P.Hkdf.extract(&hs_derived_secret, sk);
const ap_derived_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "derived", &empty_hash, P.Hash.digest_length);
pv.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes);
const client_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length);
const server_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length);
pv.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length);
pv.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length);
pv.client_handshake_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length);
pv.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length);
pv.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length);
pv.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length);
},
else => return error.TlsIllegalParameter,
}
pending_cipher_state = .handshake;
handshake_state = .encrypted_extensions;
},
.tls_1_2 => switch (cipher_suite_tag) {
.ECDHE_RSA_WITH_AES_128_GCM_SHA256,
.ECDHE_RSA_WITH_AES_256_GCM_SHA384,
.ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
=> handshake_state = .certificate,
else => return error.TlsIllegalParameter,
},
else => return error.TlsIllegalParameter,
}
},
.encrypted_extensions => {
if (tls_version != .tls_1_3) return error.TlsUnexpectedMessage;
if (cipher_state != .handshake) return error.TlsUnexpectedMessage;
if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage;
handshake_state = .certificate;
switch (handshake_cipher) {
inline else => |*p| p.transcript_hash.update(wrapped_handshake),
}
@ -548,16 +515,18 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
else => {},
}
}
handshake_state = .certificate;
},
.certificate => cert: {
switch (handshake_cipher) {
inline else => |*p| p.transcript_hash.update(wrapped_handshake),
}
if (cipher_state == .application) return error.TlsUnexpectedMessage;
switch (handshake_state) {
.certificate => {},
.trust_chain_established => break :cert,
else => return error.TlsUnexpectedMessage,
}
switch (handshake_cipher) {
inline else => |*p| p.transcript_hash.update(wrapped_handshake),
}
switch (tls_version) {
.tls_1_3 => {
@ -614,7 +583,7 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
if (tls_version != .tls_1_2) return error.TlsUnexpectedMessage;
if (cipher_state != .cleartext) return error.TlsUnexpectedMessage;
switch (handshake_state) {
.trust_chain_established => handshake_state = .server_hello_done,
.trust_chain_established => {},
.certificate => return error.TlsCertificateNotVerified,
else => return error.TlsUnexpectedMessage,
}
@ -631,12 +600,12 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
const server_pub_key = hsd.slice(key_size);
try main_cert_pub_key.verifySignature(&hsd, &.{ &client_hello_rand, &server_hello_rand, hsd.buf[0..hsd.idx] });
try key_share.exchange(named_group, server_pub_key);
handshake_state = .server_hello_done;
},
.server_hello_done => {
if (tls_version != .tls_1_2) return error.TlsUnexpectedMessage;
if (cipher_state != .cleartext) return error.TlsUnexpectedMessage;
if (handshake_state != .server_hello_done) return error.TlsUnexpectedMessage;
handshake_state = .finished;
const client_key_exchange_msg = .{@intFromEnum(tls.ContentType.handshake)} ++
int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++
@ -680,7 +649,6 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
.app_cipher = std.mem.bytesToValue(P.Tls_1_2, &key_block),
} };
const pv = &p.version.tls_1_2;
pending_cipher_state = .application;
const nonce: [P.AEAD.nonce_length]u8 = if (builtin.zig_backend == .stage2_x86_64 and
P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1)
nonce: {
@ -715,12 +683,14 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
},
}
write_seq += 1;
pending_cipher_state = .application;
handshake_state = .finished;
},
.certificate_verify => {
if (tls_version != .tls_1_3) return error.TlsUnexpectedMessage;
if (cipher_state != .handshake) return error.TlsUnexpectedMessage;
switch (handshake_state) {
.trust_chain_established => handshake_state = .finished,
.trust_chain_established => {},
.certificate => return error.TlsCertificateNotVerified,
else => return error.TlsUnexpectedMessage,
}
@ -733,6 +703,7 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
p.transcript_hash.update(wrapped_handshake);
},
}
handshake_state = .finished;
},
.finished => {
if (cipher_state == .cleartext) return error.TlsUnexpectedMessage;