From 940d368e7ea95d2bb8185e71af3d1ec0328917dc Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 28 Dec 2022 16:37:22 -0700 Subject: [PATCH] std.crypto.tls.Client: fix the read function The read function has been renamed to readAdvanced since it has slightly different semantics than typical read functions, specifically regarding the end-of-file. A higher level read function is implemented on top. Now, API users may pass small buffers to the read function and everything will work fine. This is done by re-decrypting the same ciphertext record with each call to read() until the record is finished being transmitted. If the buffer supplied to read() is large enough, then any given ciphertext record will only be decrypted once, since it decrypts directly to the read() buffer and therefore does not need any memcpy. On the other hand, if the buffer supplied to read() is small, then the ciphertext is decrypted into a stack buffer, a subset is copied to the read() buffer, and then the entire ciphertext record is saved for the next call to read(). --- lib/std/crypto/tls/Client.zig | 163 +++++++++++++++++++++++++++------- lib/std/http/Client.zig | 12 +-- lib/std/net.zig | 7 +- 3 files changed, 136 insertions(+), 46 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index fd22a503c1..8d37e82117 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -18,14 +18,20 @@ const array = tls.array; const enum_array = tls.enum_array; const Certificate = crypto.Certificate; -application_cipher: ApplicationCipher, read_seq: u64, write_seq: u64, -/// The size is enough to contain exactly one TLSCiphertext record. -partially_read_buffer: [tls.max_ciphertext_record_len]u8, /// The number of partially read bytes inside `partially_read_buffer`. partially_read_len: u15, +/// The number of cleartext bytes from decoding `partially_read_buffer` which +/// have already been transferred via read() calls. This implementation will +/// re-decrypt bytes from `partially_read_buffer` when the buffer supplied by +/// the read() API user is not large enough. +partial_cleartext_index: u15, +application_cipher: ApplicationCipher, eof: bool, +/// The size is enough to contain exactly one TLSCiphertext record. +/// Contains encrypted bytes. +partially_read_buffer: [tls.max_ciphertext_record_len]u8, /// `host` is only borrowed during this function call. pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) !Client { @@ -596,6 +602,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) .application_cipher = app_cipher, .read_seq = 0, .write_seq = 0, + .partial_cleartext_index = 0, .partially_read_buffer = undefined, .partially_read_len = @intCast(u15, len - end), .eof = false, @@ -722,27 +729,85 @@ pub fn writeAll(c: *Client, stream: net.Stream, bytes: []const u8) !void { } } -/// Returns number of bytes that have been read, which are now populated inside -/// `buffer`. A return value of zero bytes does not necessarily mean end of -/// stream. Instead, the `eof` flag is set upon end of stream. The `eof` flag -/// may be set after any call to `read`, including when greater than zero bytes -/// are returned, and this function asserts that `eof` is `false`. -pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { +/// Returns the number of bytes read, calling the underlying read function the +/// minimal number of times until the buffer has at least `len` bytes filled. +/// If the number read is less than `len` it means the stream reached the end. +/// Reaching the end of the stream is not an error condition. +pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize { + assert(len <= buffer.len); + if (c.eof) return 0; + var index: usize = 0; + while (index < len) { + index += try c.readAdvanced(stream, buffer[index..]); + if (c.eof) break; + } + return index; +} + +pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize { + return readAtLeast(c, stream, buffer, 1); +} + +/// Returns the number of bytes read. If the number read is smaller than +/// `buffer.len`, it means the stream reached the end. Reaching the end of the +/// stream is not an error condition. +pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize { + return readAtLeast(c, stream, buffer, buffer.len); +} + +/// Returns number of bytes that have been read, populated inside `buffer`. A +/// return value of zero bytes does not mean end of stream. Instead, the `eof` +/// flag is set upon end of stream. The `eof` flag may be set after any call to +/// `read`, including when greater than zero bytes are returned, and this +/// function asserts that `eof` is `false`. +/// See `read` for a higher level function that has the same, familiar API +/// as other read functions, such as `std.fs.File.read`. +/// It is recommended to use a buffer size with length at least +/// `tls.max_ciphertext_len` bytes to avoid redundantly decrypting the same +/// encoded data. +pub fn readAdvanced(c: *Client, stream: net.Stream, buffer: []u8) !usize { assert(!c.eof); const prev_len = c.partially_read_len; - var in_buf: [max_ciphertext_len * 4]u8 = undefined; - mem.copy(u8, &in_buf, c.partially_read_buffer[0..prev_len]); + // Ideally, this buffer would never be used. It is needed when `buffer` is too small + // to fit the cleartext, which may be as large as `max_ciphertext_len`. + var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined; + // This buffer is typically used, except, as an optimization when a very large + // `buffer` is provided, we use half of it for buffering ciphertext and the + // other half for outputting cleartext. + var in_stack_buffer: [max_ciphertext_len * 4]u8 = undefined; + const half_buffer_len = buffer.len / 2; + const out_in: struct { []u8, []u8 } = if (half_buffer_len >= in_stack_buffer.len) .{ + buffer[0..half_buffer_len], + buffer[half_buffer_len..], + } else .{ + buffer, + &in_stack_buffer, + }; + const out_buf = out_in[0]; + const in_buf = out_in[1]; + mem.copy(u8, in_buf, c.partially_read_buffer[0..prev_len]); // Capacity of output buffer, in records, rounded up. - const buf_cap = (buffer.len +| (max_ciphertext_len - 1)) / max_ciphertext_len; + const buf_cap = (out_buf.len +| (max_ciphertext_len - 1)) / max_ciphertext_len; const wanted_read_len = buf_cap * (max_ciphertext_len + tls.ciphertext_record_header_len); - const ask_slice = in_buf[prev_len..@min(wanted_read_len, in_buf.len)]; - const actual_read_len = try stream.read(ask_slice); - const frag = in_buf[0 .. prev_len + actual_read_len]; - if (frag.len == 0) { - // This is either a truncation attack, or a bug in the server. - return error.TlsConnectionTruncated; - } + const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len); + const ask_slice = in_buf[prev_len..][0..@min(ask_len, in_buf.len - prev_len)]; + assert(ask_slice.len > 0); + const frag = frag: { + if (prev_len >= 5) { + const record_size = mem.readIntBig(u16, in_buf[3..][0..2]); + if (prev_len >= 5 + record_size) { + // We can use our buffered data without calling read(). + break :frag in_buf[0..prev_len]; + } + } + const actual_read_len = try stream.read(ask_slice); + if (actual_read_len == 0) { + // This is either a truncation attack, or a bug in the server. + return error.TlsConnectionTruncated; + } + break :frag in_buf[0 .. prev_len + actual_read_len]; + }; var in: usize = 0; var out: usize = 0; @@ -750,6 +815,7 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { if (in + tls.ciphertext_record_header_len > frag.len) { return finishRead(c, frag, in, out); } + const record_start = in; const ct = @intToEnum(ContentType, frag[in]); in += 1; const legacy_version = mem.readIntBig(u16, frag[in..][0..2]); @@ -767,7 +833,7 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { @panic("TODO handle an alert here"); }, .application_data => { - const cleartext_len = switch (c.application_cipher) { + const cleartext = switch (c.application_cipher) { inline else => |*p| c: { const P = @TypeOf(p.*); const V = @Vector(P.AEAD.nonce_length, u8); @@ -776,29 +842,29 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { const ciphertext = frag[in..][0..ciphertext_len]; in += ciphertext_len; const auth_tag = frag[in..][0..P.AEAD.tag_length].*; - const cleartext = buffer[out..][0..ciphertext_len]; const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + // Here we use read_seq and then intentionally don't + // increment it until later when it is certain the same + // ciphertext does not need to be decrypted again. const operand: V = pad ++ @bitCast([8]u8, big(c.read_seq)); - c.read_seq += 1; const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_iv) ^ operand; - //std.debug.print("seq: {d} nonce: {} server_key: {} server_iv: {}\n", .{ - // c.read_seq - 1, - // std.fmt.fmtSliceHexLower(&nonce), - // std.fmt.fmtSliceHexLower(&p.server_key), - // std.fmt.fmtSliceHexLower(&p.server_iv), - //}); + const cleartext_buf = if (c.partial_cleartext_index == 0 and out + ciphertext.len <= out_buf.len) + out_buf[out..] + else + &cleartext_stack_buffer; + const cleartext = cleartext_buf[0..ciphertext.len]; P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_key) catch return error.TlsBadRecordMac; - break :c cleartext.len; + break :c cleartext; }, }; - const cleartext = buffer[out..][0..cleartext_len]; const inner_ct = @intToEnum(ContentType, cleartext[cleartext.len - 1]); switch (inner_ct) { .alert => { - const level = @intToEnum(tls.AlertLevel, buffer[out]); - const desc = @intToEnum(tls.AlertDescription, buffer[out + 1]); + c.read_seq += 1; + const level = @intToEnum(tls.AlertLevel, out_buf[out]); + const desc = @intToEnum(tls.AlertDescription, out_buf[out + 1]); if (desc == .close_notify) { c.eof = true; return out; @@ -807,6 +873,7 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { return error.TlsAlert; }, .handshake => { + c.read_seq += 1; var ct_i: usize = 0; while (true) { const handshake_type = @intToEnum(tls.HandshakeType, cleartext[ct_i]); @@ -819,7 +886,7 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { const handshake = cleartext[ct_i..next_handshake_i]; switch (handshake_type) { .new_session_ticket => { - std.debug.print("server sent a new session ticket\n", .{}); + // This client implementation ignores new session tickets. }, .key_update => { switch (c.application_cipher) { @@ -859,7 +926,35 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { } }, .application_data => { - out += cleartext_len - 1; + // Determine whether the output buffer or a stack + // buffer was used for storing the cleartext. + if (c.partial_cleartext_index == 0 and + out + cleartext.len <= out_buf.len) + { + // Output buffer was used directly which means no + // memory copying needs to occur, and we can move + // on to the next ciphertext record. + out += cleartext.len - 1; + c.read_seq += 1; + } else { + // Stack buffer was used, so we must copy to the output buffer. + const dest = out_buf[out..]; + const rest = cleartext[c.partial_cleartext_index..]; + const src = rest[0..@min(rest.len, dest.len)]; + mem.copy(u8, dest, src); + out += src.len; + c.partial_cleartext_index = @intCast( + @TypeOf(c.partial_cleartext_index), + c.partial_cleartext_index + src.len, + ); + if (c.partial_cleartext_index >= cleartext.len) { + c.partial_cleartext_index = 0; + c.read_seq += 1; + } else { + in = record_start; + return finishRead(c, frag, in, out); + } + } }, else => { std.debug.print("inner content type: {d}\n", .{inner_ct}); diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index f1f61cae0c..d27d879663 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -63,16 +63,10 @@ pub const Request = struct { } pub fn readAtLeast(req: *Request, buffer: []u8, len: usize) !usize { - var index: usize = 0; - while (index < len) { - const amt = try req.read(buffer[index..]); - index += amt; - switch (req.protocol) { - .http => if (amt == 0) break, - .https => if (req.tls_client.eof) break, - } + switch (req.protocol) { + .http => return req.stream.readAtLeast(buffer, len), + .https => return req.tls_client.readAtLeast(req.stream, buffer, len), } - return index; } }; diff --git a/lib/std/net.zig b/lib/std/net.zig index a265fa69a9..0112d5be8c 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -1680,11 +1680,12 @@ pub const Stream = struct { } /// Returns the number of bytes read, calling the underlying read function - /// the minimal number of times until at least the buffer has at least - /// `len` bytes filled. If the number read is less than `len` it means the - /// stream reached the end. Reaching the end of the stream is not an error + /// the minimal number of times until the buffer has at least `len` bytes + /// filled. If the number read is less than `len` it means the stream + /// reached the end. Reaching the end of the stream is not an error /// condition. pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize { + assert(len <= buffer.len); var index: usize = 0; while (index < len) { const amt = try s.read(buffer[index..]);