From 21ab99174eabc9ae8efa2b19890d9cab51773b35 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 27 Dec 2022 23:49:15 -0700 Subject: [PATCH] std.crypto.tls.Client: use enums more --- lib/std/crypto/tls.zig | 3 +++ lib/std/crypto/tls/Client.zig | 41 +++++++++++++++++------------------ 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index acfa8558c1..fc2523f02a 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -74,6 +74,7 @@ pub const HandshakeType = enum(u8) { finished = 20, key_update = 24, message_hash = 254, + _, }; pub const ExtensionType = enum(u16) { @@ -121,6 +122,8 @@ pub const ExtensionType = enum(u16) { signature_algorithms_cert = 50, /// RFC 8446 key_share = 51, + + _, }; pub const AlertLevel = enum(u8) { diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 260441295d..fd22a503c1 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -9,7 +9,6 @@ const assert = std.debug.assert; const ApplicationCipher = tls.ApplicationCipher; const CipherSuite = tls.CipherSuite; const ContentType = tls.ContentType; -const HandshakeType = tls.HandshakeType; const HandshakeCipher = tls.HandshakeCipher; const max_ciphertext_len = tls.max_ciphertext_len; const hkdfExpandLabel = tls.hkdfExpandLabel; @@ -91,7 +90,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) extensions_header; const out_handshake = - [_]u8{@enumToInt(HandshakeType.client_hello)} ++ + [_]u8{@enumToInt(tls.HandshakeType.client_hello)} ++ int3(@intCast(u24, client_hello.len + host_len)) ++ client_hello; @@ -142,7 +141,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) return error.TlsAlert; }, .handshake => { - if (frag[0] != @enumToInt(HandshakeType.server_hello)) { + if (frag[0] != @enumToInt(tls.HandshakeType.server_hello)) { return error.TlsUnexpectedMessage; } const length = mem.readIntBig(u24, frag[1..4]); @@ -175,27 +174,27 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) var shared_key: [32]u8 = undefined; var have_shared_key = false; while (i < frag.len) { - const et = mem.readIntBig(u16, frag[i..][0..2]); + const et = @intToEnum(tls.ExtensionType, mem.readIntBig(u16, frag[i..][0..2])); i += 2; const ext_size = mem.readIntBig(u16, frag[i..][0..2]); i += 2; const next_i = i + ext_size; if (next_i > frag.len) return error.TlsBadLength; switch (et) { - @enumToInt(tls.ExtensionType.supported_versions) => { + .supported_versions => { if (supported_version != 0) return error.TlsIllegalParameter; supported_version = mem.readIntBig(u16, frag[i..][0..2]); }, - @enumToInt(tls.ExtensionType.key_share) => { + .key_share => { if (have_shared_key) return error.TlsIllegalParameter; have_shared_key = true; - const named_group = mem.readIntBig(u16, frag[i..][0..2]); + const named_group = @intToEnum(tls.NamedGroup, mem.readIntBig(u16, frag[i..][0..2])); i += 2; const key_size = mem.readIntBig(u16, frag[i..][0..2]); i += 2; switch (named_group) { - @enumToInt(tls.NamedGroup.x25519) => { + .x25519 => { if (key_size != 32) return error.TlsBadLength; const server_pub_key = frag[i..][0..32]; @@ -204,7 +203,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) server_pub_key.*, ) catch return error.TlsDecryptFailure; }, - @enumToInt(tls.NamedGroup.secp256r1) => { + .secp256r1 => { const server_pub_key = frag[i..][0..key_size]; const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey; @@ -217,7 +216,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) shared_key = mul.affineCoordinates().x.toBytes(.Big); }, else => { - std.debug.print("named group: {x}\n", .{named_group}); + //std.debug.print("named group: {x}\n", .{named_group}); return error.TlsIllegalParameter; }, } @@ -380,7 +379,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) .handshake => { var ct_i: usize = 0; while (true) { - const handshake_type = cleartext[ct_i]; + const handshake_type = @intToEnum(tls.HandshakeType, cleartext[ct_i]); ct_i += 1; const handshake_len = mem.readIntBig(u24, cleartext[ct_i..][0..3]); ct_i += 3; @@ -390,7 +389,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) const wrapped_handshake = cleartext[ct_i - 4 .. next_handshake_i]; const handshake = cleartext[ct_i..next_handshake_i]; switch (handshake_type) { - @enumToInt(HandshakeType.encrypted_extensions) => { + .encrypted_extensions => { if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage; handshake_state = .certificate; switch (handshake_cipher) { @@ -400,13 +399,13 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) var hs_i: usize = 2; const end_ext_i = 2 + total_ext_size; while (hs_i < end_ext_i) { - const et = mem.readIntBig(u16, handshake[hs_i..][0..2]); + const et = @intToEnum(tls.ExtensionType, mem.readIntBig(u16, handshake[hs_i..][0..2])); hs_i += 2; const ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]); hs_i += 2; const next_ext_i = hs_i + ext_size; switch (et) { - @enumToInt(tls.ExtensionType.server_name) => {}, + .server_name => {}, else => { std.debug.print("encrypted extension: {any}\n", .{ et, @@ -416,7 +415,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) hs_i = next_ext_i; } }, - @enumToInt(HandshakeType.certificate) => cert: { + .certificate => cert: { switch (handshake_cipher) { inline else => |*p| p.transcript_hash.update(wrapped_handshake), } @@ -488,7 +487,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) hs_i += total_ext_size; } }, - @enumToInt(HandshakeType.certificate_verify) => { + .certificate_verify => { switch (handshake_state) { .trust_chain_established => handshake_state = .finished, .certificate => return error.TlsCertificateNotVerified, @@ -535,7 +534,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) }, } }, - @enumToInt(HandshakeType.finished) => { + .finished => { if (handshake_state != .finished) return error.TlsUnexpectedMessage; // This message is to trick buggy proxies into behaving correctly. const client_change_cipher_spec_msg = [_]u8{ @@ -555,7 +554,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) const handshake_hash = p.transcript_hash.finalResult(); const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key); const out_cleartext = [_]u8{ - @enumToInt(HandshakeType.finished), + @enumToInt(tls.HandshakeType.finished), 0, 0, verify_data.len, // length } ++ verify_data ++ [1]u8{@enumToInt(ContentType.handshake)}; @@ -810,7 +809,7 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { .handshake => { var ct_i: usize = 0; while (true) { - const handshake_type = cleartext[ct_i]; + const handshake_type = @intToEnum(tls.HandshakeType, cleartext[ct_i]); ct_i += 1; const handshake_len = mem.readIntBig(u24, cleartext[ct_i..][0..3]); ct_i += 3; @@ -819,10 +818,10 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { return error.TlsBadLength; const handshake = cleartext[ct_i..next_handshake_i]; switch (handshake_type) { - @enumToInt(HandshakeType.new_session_ticket) => { + .new_session_ticket => { std.debug.print("server sent a new session ticket\n", .{}); }, - @enumToInt(HandshakeType.key_update) => { + .key_update => { switch (c.application_cipher) { inline else => |*p| { const P = @TypeOf(p.*);