std.crypto.tls.Client: use enums more

This commit is contained in:
Andrew Kelley 2022-12-27 23:49:15 -07:00
parent 477864dca5
commit 21ab99174e
2 changed files with 23 additions and 21 deletions

View File

@ -74,6 +74,7 @@ pub const HandshakeType = enum(u8) {
finished = 20,
key_update = 24,
message_hash = 254,
_,
};
pub const ExtensionType = enum(u16) {
@ -121,6 +122,8 @@ pub const ExtensionType = enum(u16) {
signature_algorithms_cert = 50,
/// RFC 8446
key_share = 51,
_,
};
pub const AlertLevel = enum(u8) {

View File

@ -9,7 +9,6 @@ const assert = std.debug.assert;
const ApplicationCipher = tls.ApplicationCipher;
const CipherSuite = tls.CipherSuite;
const ContentType = tls.ContentType;
const HandshakeType = tls.HandshakeType;
const HandshakeCipher = tls.HandshakeCipher;
const max_ciphertext_len = tls.max_ciphertext_len;
const hkdfExpandLabel = tls.hkdfExpandLabel;
@ -91,7 +90,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
extensions_header;
const out_handshake =
[_]u8{@enumToInt(HandshakeType.client_hello)} ++
[_]u8{@enumToInt(tls.HandshakeType.client_hello)} ++
int3(@intCast(u24, client_hello.len + host_len)) ++
client_hello;
@ -142,7 +141,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
return error.TlsAlert;
},
.handshake => {
if (frag[0] != @enumToInt(HandshakeType.server_hello)) {
if (frag[0] != @enumToInt(tls.HandshakeType.server_hello)) {
return error.TlsUnexpectedMessage;
}
const length = mem.readIntBig(u24, frag[1..4]);
@ -175,27 +174,27 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
var shared_key: [32]u8 = undefined;
var have_shared_key = false;
while (i < frag.len) {
const et = mem.readIntBig(u16, frag[i..][0..2]);
const et = @intToEnum(tls.ExtensionType, mem.readIntBig(u16, frag[i..][0..2]));
i += 2;
const ext_size = mem.readIntBig(u16, frag[i..][0..2]);
i += 2;
const next_i = i + ext_size;
if (next_i > frag.len) return error.TlsBadLength;
switch (et) {
@enumToInt(tls.ExtensionType.supported_versions) => {
.supported_versions => {
if (supported_version != 0) return error.TlsIllegalParameter;
supported_version = mem.readIntBig(u16, frag[i..][0..2]);
},
@enumToInt(tls.ExtensionType.key_share) => {
.key_share => {
if (have_shared_key) return error.TlsIllegalParameter;
have_shared_key = true;
const named_group = mem.readIntBig(u16, frag[i..][0..2]);
const named_group = @intToEnum(tls.NamedGroup, mem.readIntBig(u16, frag[i..][0..2]));
i += 2;
const key_size = mem.readIntBig(u16, frag[i..][0..2]);
i += 2;
switch (named_group) {
@enumToInt(tls.NamedGroup.x25519) => {
.x25519 => {
if (key_size != 32) return error.TlsBadLength;
const server_pub_key = frag[i..][0..32];
@ -204,7 +203,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
server_pub_key.*,
) catch return error.TlsDecryptFailure;
},
@enumToInt(tls.NamedGroup.secp256r1) => {
.secp256r1 => {
const server_pub_key = frag[i..][0..key_size];
const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey;
@ -217,7 +216,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
shared_key = mul.affineCoordinates().x.toBytes(.Big);
},
else => {
std.debug.print("named group: {x}\n", .{named_group});
//std.debug.print("named group: {x}\n", .{named_group});
return error.TlsIllegalParameter;
},
}
@ -380,7 +379,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
.handshake => {
var ct_i: usize = 0;
while (true) {
const handshake_type = cleartext[ct_i];
const handshake_type = @intToEnum(tls.HandshakeType, cleartext[ct_i]);
ct_i += 1;
const handshake_len = mem.readIntBig(u24, cleartext[ct_i..][0..3]);
ct_i += 3;
@ -390,7 +389,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
const wrapped_handshake = cleartext[ct_i - 4 .. next_handshake_i];
const handshake = cleartext[ct_i..next_handshake_i];
switch (handshake_type) {
@enumToInt(HandshakeType.encrypted_extensions) => {
.encrypted_extensions => {
if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage;
handshake_state = .certificate;
switch (handshake_cipher) {
@ -400,13 +399,13 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
var hs_i: usize = 2;
const end_ext_i = 2 + total_ext_size;
while (hs_i < end_ext_i) {
const et = mem.readIntBig(u16, handshake[hs_i..][0..2]);
const et = @intToEnum(tls.ExtensionType, mem.readIntBig(u16, handshake[hs_i..][0..2]));
hs_i += 2;
const ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]);
hs_i += 2;
const next_ext_i = hs_i + ext_size;
switch (et) {
@enumToInt(tls.ExtensionType.server_name) => {},
.server_name => {},
else => {
std.debug.print("encrypted extension: {any}\n", .{
et,
@ -416,7 +415,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
hs_i = next_ext_i;
}
},
@enumToInt(HandshakeType.certificate) => cert: {
.certificate => cert: {
switch (handshake_cipher) {
inline else => |*p| p.transcript_hash.update(wrapped_handshake),
}
@ -488,7 +487,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
hs_i += total_ext_size;
}
},
@enumToInt(HandshakeType.certificate_verify) => {
.certificate_verify => {
switch (handshake_state) {
.trust_chain_established => handshake_state = .finished,
.certificate => return error.TlsCertificateNotVerified,
@ -535,7 +534,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
},
}
},
@enumToInt(HandshakeType.finished) => {
.finished => {
if (handshake_state != .finished) return error.TlsUnexpectedMessage;
// This message is to trick buggy proxies into behaving correctly.
const client_change_cipher_spec_msg = [_]u8{
@ -555,7 +554,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
const handshake_hash = p.transcript_hash.finalResult();
const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key);
const out_cleartext = [_]u8{
@enumToInt(HandshakeType.finished),
@enumToInt(tls.HandshakeType.finished),
0, 0, verify_data.len, // length
} ++ verify_data ++ [1]u8{@enumToInt(ContentType.handshake)};
@ -810,7 +809,7 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize {
.handshake => {
var ct_i: usize = 0;
while (true) {
const handshake_type = cleartext[ct_i];
const handshake_type = @intToEnum(tls.HandshakeType, cleartext[ct_i]);
ct_i += 1;
const handshake_len = mem.readIntBig(u24, cleartext[ct_i..][0..3]);
ct_i += 3;
@ -819,10 +818,10 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize {
return error.TlsBadLength;
const handshake = cleartext[ct_i..next_handshake_i];
switch (handshake_type) {
@enumToInt(HandshakeType.new_session_ticket) => {
.new_session_ticket => {
std.debug.print("server sent a new session ticket\n", .{});
},
@enumToInt(HandshakeType.key_update) => {
.key_update => {
switch (c.application_cipher) {
inline else => |*p| {
const P = @TypeOf(p.*);