diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig
index fd22a503c1..8d37e82117 100644
--- a/lib/std/crypto/tls/Client.zig
+++ b/lib/std/crypto/tls/Client.zig
@@ -18,14 +18,20 @@ const array = tls.array;
 const enum_array = tls.enum_array;
 const Certificate = crypto.Certificate;
 
-application_cipher: ApplicationCipher,
 read_seq: u64,
 write_seq: u64,
-/// The size is enough to contain exactly one TLSCiphertext record.
-partially_read_buffer: [tls.max_ciphertext_record_len]u8,
 /// The number of partially read bytes inside `partially_read_buffer`.
 partially_read_len: u15,
+/// The number of cleartext bytes from decoding `partially_read_buffer` which
+/// have already been transferred via read() calls. This implementation will
+/// re-decrypt bytes from `partially_read_buffer` when the buffer supplied by
+/// the read() API user is not large enough.
+partial_cleartext_index: u15,
+application_cipher: ApplicationCipher,
 eof: bool,
+/// The size is enough to contain exactly one TLSCiphertext record.
+/// Contains encrypted bytes.
+partially_read_buffer: [tls.max_ciphertext_record_len]u8,
 
 /// `host` is only borrowed during this function call.
 pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) !Client {
@@ -596,6 +602,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
         .application_cipher = app_cipher,
         .read_seq = 0,
         .write_seq = 0,
+        .partial_cleartext_index = 0,
         .partially_read_buffer = undefined,
         .partially_read_len = @intCast(u15, len - end),
         .eof = false,
@@ -722,27 +729,85 @@ pub fn writeAll(c: *Client, stream: net.Stream, bytes: []const u8) !void {
     }
 }
 
-/// Returns number of bytes that have been read, which are now populated inside
-/// `buffer`. A return value of zero bytes does not necessarily mean end of
-/// stream. Instead, the `eof` flag is set upon end of stream. The `eof` flag
-/// may be set after any call to `read`, including when greater than zero bytes
-/// are returned, and this function asserts that `eof` is `false`.
-pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize {
+/// Returns the number of bytes read, calling the underlying read function the
+/// minimal number of times until the buffer has at least `len` bytes filled.
+/// If the number read is less than `len` it means the stream reached the end.
+/// Reaching the end of the stream is not an error condition.
+pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize {
+    assert(len <= buffer.len);
+    if (c.eof) return 0;
+    var index: usize = 0;
+    while (index < len) {
+        index += try c.readAdvanced(stream, buffer[index..]);
+        if (c.eof) break;
+    }
+    return index;
+}
+
+pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize {
+    return readAtLeast(c, stream, buffer, 1);
+}
+
+/// Returns the number of bytes read. If the number read is smaller than
+/// `buffer.len`, it means the stream reached the end. Reaching the end of the
+/// stream is not an error condition.
+pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize {
+    return readAtLeast(c, stream, buffer, buffer.len);
+}
+
+/// Returns number of bytes that have been read, populated inside `buffer`. A
+/// return value of zero bytes does not mean end of stream. Instead, the `eof`
+/// flag is set upon end of stream. The `eof` flag may be set after any call to
+/// `read`, including when greater than zero bytes are returned, and this
+/// function asserts that `eof` is `false`.
+/// See `read` for a higher level function that has the same, familiar API
+/// as other read functions, such as `std.fs.File.read`.
+/// It is recommended to use a buffer with a length of at least
+/// `tls.max_ciphertext_len` bytes to avoid redundantly decrypting the same
+/// encrypted data.
+pub fn readAdvanced(c: *Client, stream: net.Stream, buffer: []u8) !usize {
     assert(!c.eof);
     const prev_len = c.partially_read_len;
-    var in_buf: [max_ciphertext_len * 4]u8 = undefined;
-    mem.copy(u8, &in_buf, c.partially_read_buffer[0..prev_len]);
+    // Ideally, this buffer would never be used. It is needed when `buffer` is too small
+    // to fit the cleartext, which may be as large as `max_ciphertext_len`.
+    var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined;
+    // This buffer is typically used. However, as an optimization, when a very large
+    // `buffer` is provided, we use half of it for buffering ciphertext and the
+    // other half for outputting cleartext.
+    var in_stack_buffer: [max_ciphertext_len * 4]u8 = undefined;
+    const half_buffer_len = buffer.len / 2;
+    const out_in: struct { []u8, []u8 } = if (half_buffer_len >= in_stack_buffer.len) .{
+        buffer[0..half_buffer_len],
+        buffer[half_buffer_len..],
+    } else .{
+        buffer,
+        &in_stack_buffer,
+    };
+    const out_buf = out_in[0];
+    const in_buf = out_in[1];
+    mem.copy(u8, in_buf, c.partially_read_buffer[0..prev_len]);
 
     // Capacity of output buffer, in records, rounded up.
-    const buf_cap = (buffer.len +| (max_ciphertext_len - 1)) / max_ciphertext_len;
+    const buf_cap = (out_buf.len +| (max_ciphertext_len - 1)) / max_ciphertext_len;
     const wanted_read_len = buf_cap * (max_ciphertext_len + tls.ciphertext_record_header_len);
-    const ask_slice = in_buf[prev_len..@min(wanted_read_len, in_buf.len)];
-    const actual_read_len = try stream.read(ask_slice);
-    const frag = in_buf[0 .. prev_len + actual_read_len];
-    if (frag.len == 0) {
-        // This is either a truncation attack, or a bug in the server.
-        return error.TlsConnectionTruncated;
-    }
+    const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len);
+    const ask_slice = in_buf[prev_len..][0..@min(ask_len, in_buf.len - prev_len)];
+    assert(ask_slice.len > 0);
+    const frag = frag: {
+        if (prev_len >= 5) {
+            const record_size = mem.readIntBig(u16, in_buf[3..][0..2]);
+            if (prev_len >= 5 + record_size) {
+                // We can use our buffered data without calling read().
+                break :frag in_buf[0..prev_len];
+            }
+        }
+        const actual_read_len = try stream.read(ask_slice);
+        if (actual_read_len == 0) {
+            // This is either a truncation attack, or a bug in the server.
+            return error.TlsConnectionTruncated;
+        }
+        break :frag in_buf[0 .. prev_len + actual_read_len];
+    };
 
     var in: usize = 0;
     var out: usize = 0;
@@ -750,6 +815,7 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize {
         if (in + tls.ciphertext_record_header_len > frag.len) {
             return finishRead(c, frag, in, out);
         }
+        const record_start = in;
         const ct = @intToEnum(ContentType, frag[in]);
         in += 1;
         const legacy_version = mem.readIntBig(u16, frag[in..][0..2]);
@@ -767,7 +833,7 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize {
                 @panic("TODO handle an alert here");
             },
             .application_data => {
-                const cleartext_len = switch (c.application_cipher) {
+                const cleartext = switch (c.application_cipher) {
                     inline else => |*p| c: {
                         const P = @TypeOf(p.*);
                         const V = @Vector(P.AEAD.nonce_length, u8);
@@ -776,29 +842,29 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize {
                         const ciphertext = frag[in..][0..ciphertext_len];
                         in += ciphertext_len;
                         const auth_tag = frag[in..][0..P.AEAD.tag_length].*;
-                        const cleartext = buffer[out..][0..ciphertext_len];
                         const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
+                        // Here we use read_seq and then intentionally don't
+                        // increment it until later when it is certain the same
+                        // ciphertext does not need to be decrypted again.
                         const operand: V = pad ++ @bitCast([8]u8, big(c.read_seq));
-                        c.read_seq += 1;
                         const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_iv) ^ operand;
-                        //std.debug.print("seq: {d} nonce: {} server_key: {} server_iv: {}\n", .{
-                        //    c.read_seq - 1,
-                        //    std.fmt.fmtSliceHexLower(&nonce),
-                        //    std.fmt.fmtSliceHexLower(&p.server_key),
-                        //    std.fmt.fmtSliceHexLower(&p.server_iv),
-                        //});
+                        const cleartext_buf = if (c.partial_cleartext_index == 0 and out + ciphertext.len <= out_buf.len)
+                            out_buf[out..]
+                        else
+                            &cleartext_stack_buffer;
+                        const cleartext = cleartext_buf[0..ciphertext.len];
                         P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_key) catch
                             return error.TlsBadRecordMac;
-                        break :c cleartext.len;
+                        break :c cleartext;
                     },
                 };
 
-                const cleartext = buffer[out..][0..cleartext_len];
                 const inner_ct = @intToEnum(ContentType, cleartext[cleartext.len - 1]);
                 switch (inner_ct) {
                     .alert => {
-                        const level = @intToEnum(tls.AlertLevel, buffer[out]);
-                        const desc = @intToEnum(tls.AlertDescription, buffer[out + 1]);
+                        c.read_seq += 1;
+                        const level = @intToEnum(tls.AlertLevel, cleartext[0]);
+                        const desc = @intToEnum(tls.AlertDescription, cleartext[1]);
                         if (desc == .close_notify) {
                             c.eof = true;
                             return out;
@@ -807,6 +873,7 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize {
                         return error.TlsAlert;
                     },
                     .handshake => {
+                        c.read_seq += 1;
                        var ct_i: usize = 0;
                        while (true) {
                            const handshake_type = @intToEnum(tls.HandshakeType, cleartext[ct_i]);
@@ -819,7 +886,7 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize {
                             const handshake = cleartext[ct_i..next_handshake_i];
                             switch (handshake_type) {
                                 .new_session_ticket => {
-                                    std.debug.print("server sent a new session ticket\n", .{});
+                                    // This client implementation ignores new session tickets.
                                 },
                                 .key_update => {
                                     switch (c.application_cipher) {
@@ -859,7 +926,35 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize {
                         }
                     },
                     .application_data => {
-                        out += cleartext_len - 1;
+                        // Determine whether the output buffer or a stack
+                        // buffer was used for storing the cleartext.
+                        if (c.partial_cleartext_index == 0 and
+                            out + cleartext.len <= out_buf.len)
+                        {
+                            // Output buffer was used directly which means no
+                            // memory copying needs to occur, and we can move
+                            // on to the next ciphertext record.
+                            out += cleartext.len - 1;
+                            c.read_seq += 1;
+                        } else {
+                            // Stack buffer was used, so we must copy to the output buffer.
+                            const dest = out_buf[out..];
+                            const rest = cleartext[c.partial_cleartext_index..];
+                            const src = rest[0..@min(rest.len, dest.len)];
+                            mem.copy(u8, dest, src);
+                            out += src.len;
+                            c.partial_cleartext_index = @intCast(
+                                @TypeOf(c.partial_cleartext_index),
+                                c.partial_cleartext_index + src.len,
+                            );
+                            if (c.partial_cleartext_index >= cleartext.len) {
+                                c.partial_cleartext_index = 0;
+                                c.read_seq += 1;
+                            } else {
+                                in = record_start;
+                                return finishRead(c, frag, in, out);
+                            }
+                        }
                     },
                     else => {
                         std.debug.print("inner content type: {d}\n", .{inner_ct});
diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig
index f1f61cae0c..d27d879663 100644
--- a/lib/std/http/Client.zig
+++ b/lib/std/http/Client.zig
@@ -63,16 +63,10 @@ pub const Request = struct {
     }
 
     pub fn readAtLeast(req: *Request, buffer: []u8, len: usize) !usize {
-        var index: usize = 0;
-        while (index < len) {
-            const amt = try req.read(buffer[index..]);
-            index += amt;
-            switch (req.protocol) {
-                .http => if (amt == 0) break,
-                .https => if (req.tls_client.eof) break,
-            }
+        switch (req.protocol) {
+            .http => return req.stream.readAtLeast(buffer, len),
+            .https => return req.tls_client.readAtLeast(req.stream, buffer, len),
         }
-        return index;
     }
 };
 
diff --git a/lib/std/net.zig b/lib/std/net.zig
index a265fa69a9..0112d5be8c 100644
--- a/lib/std/net.zig
+++ b/lib/std/net.zig
@@ -1680,11 +1680,12 @@ pub const Stream = struct {
     }
 
     /// Returns the number of bytes read, calling the underlying read function
-    /// the minimal number of times until at least the buffer has at least
-    /// `len` bytes filled. If the number read is less than `len` it means the
-    /// stream reached the end. Reaching the end of the stream is not an error
+    /// the minimal number of times until the buffer has at least `len` bytes
+    /// filled. If the number read is less than `len` it means the stream
+    /// reached the end. Reaching the end of the stream is not an error
     /// condition.
     pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize {
+        assert(len <= buffer.len);
         var index: usize = 0;
         while (index < len) {
             const amt = try s.read(buffer[index..]);
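
A minimal usage sketch of the new read API (not part of the patch): it assumes the `std.crypto.tls.Client` and `std.crypto.tls.max_ciphertext_len` names referenced above and the `init`/`writeAll`/`read` signatures shown in the diff; the connected stream, certificate bundle, request bytes, and host name are placeholder inputs.

const std = @import("std");
const tls = std.crypto.tls;

fn fetchPlaceholder(stream: std.net.Stream, ca_bundle: std.crypto.Certificate.Bundle) !void {
    // `host` is only borrowed during init(), per the doc comment above.
    var client = try tls.Client.init(stream, ca_bundle, "example.com");
    try client.writeAll(stream, "GET / HTTP/1.1\r\nHost: example.com\r\nConnection: close\r\n\r\n");

    // A buffer of at least `tls.max_ciphertext_len` bytes avoids the
    // re-decryption path that `readAdvanced` documents for small buffers.
    var buf: [tls.max_ciphertext_len]u8 = undefined;
    while (true) {
        // With the new API, a zero return means `eof` was reached; short
        // reads are otherwise normal and not an error.
        const n = try client.read(stream, &buf);
        if (n == 0) break;
        std.debug.print("{s}", .{buf[0..n]});
    }
}

`std.http.Client.Request.readAtLeast` forwards to these same entry points, so HTTP callers get the same end-of-stream behavior for both the `.http` and `.https` protocols.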