diff --git a/lib/std/Uri.zig b/lib/std/Uri.zig index 8b455c6c71f1..0903c129c4fb 100644 --- a/lib/std/Uri.zig +++ b/lib/std/Uri.zig @@ -4,6 +4,7 @@ const Uri = @This(); const std = @import("std.zig"); const testing = std.testing; +const Allocator = std.mem.Allocator; scheme: []const u8, user: ?[]const u8 = null, @@ -15,15 +16,15 @@ query: ?[]const u8 = null, fragment: ?[]const u8 = null, /// Applies URI encoding and replaces all reserved characters with their respective %XX code. -pub fn escapeString(allocator: std.mem.Allocator, input: []const u8) error{OutOfMemory}![]u8 { +pub fn escapeString(allocator: Allocator, input: []const u8) error{OutOfMemory}![]u8 { return escapeStringWithFn(allocator, input, isUnreserved); } -pub fn escapePath(allocator: std.mem.Allocator, input: []const u8) error{OutOfMemory}![]u8 { +pub fn escapePath(allocator: Allocator, input: []const u8) error{OutOfMemory}![]u8 { return escapeStringWithFn(allocator, input, isPathChar); } -pub fn escapeQuery(allocator: std.mem.Allocator, input: []const u8) error{OutOfMemory}![]u8 { +pub fn escapeQuery(allocator: Allocator, input: []const u8) error{OutOfMemory}![]u8 { return escapeStringWithFn(allocator, input, isQueryChar); } @@ -39,7 +40,7 @@ pub fn writeEscapedQuery(writer: anytype, input: []const u8) !void { return writeEscapedStringWithFn(writer, input, isQueryChar); } -pub fn escapeStringWithFn(allocator: std.mem.Allocator, input: []const u8, comptime keepUnescaped: fn (c: u8) bool) std.mem.Allocator.Error![]u8 { +pub fn escapeStringWithFn(allocator: Allocator, input: []const u8, comptime keepUnescaped: fn (c: u8) bool) Allocator.Error![]u8 { var outsize: usize = 0; for (input) |c| { outsize += if (keepUnescaped(c)) @as(usize, 1) else 3; @@ -76,7 +77,7 @@ pub fn writeEscapedStringWithFn(writer: anytype, input: []const u8, comptime kee /// Parses a URI string and unescapes all %XX where XX is a valid hex number. Otherwise, verbatim copies /// them to the output. -pub fn unescapeString(allocator: std.mem.Allocator, input: []const u8) error{OutOfMemory}![]u8 { +pub fn unescapeString(allocator: Allocator, input: []const u8) error{OutOfMemory}![]u8 { var outsize: usize = 0; var inptr: usize = 0; while (inptr < input.len) { @@ -341,7 +342,7 @@ pub fn format( /// The return value will contain unescaped strings pointing into the /// original `text`. Each component that is provided, will be non-`null`. pub fn parse(text: []const u8) ParseError!Uri { - var reader = SliceReader{ .slice = text }; + var reader: SliceReader = .{ .slice = text }; const scheme = reader.readWhile(isSchemeChar); // after the scheme, a ':' must appear @@ -358,111 +359,145 @@ pub fn parse(text: []const u8) ParseError!Uri { return uri; } -/// Implementation of RFC 3986, Section 5.2.4. Removes dot segments from a URI path. -/// -/// `std.fs.path.resolvePosix` is not sufficient here because it may return relative paths and does not preserve trailing slashes. 
-fn removeDotSegments(allocator: std.mem.Allocator, paths: []const []const u8) std.mem.Allocator.Error![]const u8 { - var result = std.ArrayList(u8).init(allocator); - defer result.deinit(); - - for (paths) |p| { - var it = std.mem.tokenizeScalar(u8, p, '/'); - while (it.next()) |component| { - if (std.mem.eql(u8, component, ".")) { - continue; - } else if (std.mem.eql(u8, component, "..")) { - if (result.items.len == 0) - continue; +pub const ResolveInplaceError = ParseError || error{OutOfMemory}; - while (true) { - const ends_with_slash = result.items[result.items.len - 1] == '/'; - result.items.len -= 1; - if (ends_with_slash or result.items.len == 0) break; - } - } else { - try result.ensureUnusedCapacity(1 + component.len); - result.appendAssumeCapacity('/'); - result.appendSliceAssumeCapacity(component); - } - } - } +/// Resolves a URI against a base URI, conforming to RFC 3986, Section 5. +/// Copies `new` to the beginning of `aux_buf`, allowing the slices to overlap, +/// then parses `new` as a URI, and then resolves the path in place. +/// If a merge needs to take place, the newly constructed path will be stored +/// in `aux_buf` just after the copied `new`. +pub fn resolve_inplace(base: Uri, new: []const u8, aux_buf: []u8) ResolveInplaceError!Uri { + std.mem.copyBackwards(u8, aux_buf, new); + // At this point, new is an invalid pointer. + const new_mut = aux_buf[0..new.len]; + + const new_parsed, const has_scheme = p: { + break :p .{ + parse(new_mut) catch |first_err| { + break :p .{ + parseWithoutScheme(new_mut) catch return first_err, + false, + }; + }, + true, + }; + }; - // ensure a trailing slash is kept - const last_path = paths[paths.len - 1]; - if (last_path.len > 0 and last_path[last_path.len - 1] == '/') { - try result.append('/'); - } + // As you can see above, `new_mut` is not a const pointer. + const new_path: []u8 = @constCast(new_parsed.path); + + if (has_scheme) return .{ + .scheme = new_parsed.scheme, + .user = new_parsed.user, + .host = new_parsed.host, + .port = new_parsed.port, + .path = remove_dot_segments(new_path), + .query = new_parsed.query, + .fragment = new_parsed.fragment, + }; - return result.toOwnedSlice(); -} + if (new_parsed.host) |host| return .{ + .scheme = base.scheme, + .user = new_parsed.user, + .host = host, + .port = new_parsed.port, + .path = remove_dot_segments(new_path), + .query = new_parsed.query, + .fragment = new_parsed.fragment, + }; -/// Resolves a URI against a base URI, conforming to RFC 3986, Section 5. -/// -/// Assumes `arena` owns all memory in `base` and `ref`. `arena` will own all memory in the returned URI. 
-pub fn resolve(base: Uri, ref: Uri, strict: bool, arena: std.mem.Allocator) std.mem.Allocator.Error!Uri { - var target: Uri = Uri{ - .scheme = "", - .user = null, - .password = null, - .host = null, - .port = null, - .path = "", - .query = null, - .fragment = null, + const path, const query = b: { + if (new_path.len == 0) + break :b .{ + base.path, + new_parsed.query orelse base.query, + }; + + if (new_path[0] == '/') + break :b .{ + remove_dot_segments(new_path), + new_parsed.query, + }; + + break :b .{ + try merge_paths(base.path, new_path, aux_buf[new_mut.len..]), + new_parsed.query, + }; }; - if (ref.scheme.len > 0 and (strict or !std.mem.eql(u8, ref.scheme, base.scheme))) { - target.scheme = ref.scheme; - target.user = ref.user; - target.host = ref.host; - target.port = ref.port; - target.path = try removeDotSegments(arena, &.{ref.path}); - target.query = ref.query; - } else { - target.scheme = base.scheme; - if (ref.host) |host| { - target.user = ref.user; - target.host = host; - target.port = ref.port; - target.path = ref.path; - target.path = try removeDotSegments(arena, &.{ref.path}); - target.query = ref.query; + return .{ + .scheme = base.scheme, + .user = base.user, + .host = base.host, + .port = base.port, + .path = path, + .query = query, + .fragment = new_parsed.fragment, + }; +} + +/// In-place implementation of RFC 3986, Section 5.2.4. +fn remove_dot_segments(path: []u8) []u8 { + var in_i: usize = 0; + var out_i: usize = 0; + while (in_i < path.len) { + if (std.mem.startsWith(u8, path[in_i..], "./")) { + in_i += 2; + } else if (std.mem.startsWith(u8, path[in_i..], "../")) { + in_i += 3; + } else if (std.mem.startsWith(u8, path[in_i..], "/./")) { + in_i += 2; + } else if (std.mem.eql(u8, path[in_i..], "/.")) { + in_i += 1; + path[in_i] = '/'; + } else if (std.mem.startsWith(u8, path[in_i..], "/../")) { + in_i += 3; + while (out_i > 0) { + out_i -= 1; + if (path[out_i] == '/') break; + } + } else if (std.mem.eql(u8, path[in_i..], "/..")) { + in_i += 2; + path[in_i] = '/'; + while (out_i > 0) { + out_i -= 1; + if (path[out_i] == '/') break; + } + } else if (std.mem.eql(u8, path[in_i..], ".")) { + in_i += 1; + } else if (std.mem.eql(u8, path[in_i..], "..")) { + in_i += 2; } else { - if (ref.path.len == 0) { - target.path = base.path; - target.query = ref.query orelse base.query; - } else { - if (ref.path[0] == '/') { - target.path = try removeDotSegments(arena, &.{ref.path}); - } else { - target.path = try removeDotSegments(arena, &.{ std.fs.path.dirnamePosix(base.path) orelse "", ref.path }); - } - target.query = ref.query; + while (true) { + path[out_i] = path[in_i]; + out_i += 1; + in_i += 1; + if (in_i >= path.len or path[in_i] == '/') break; } - - target.user = base.user; - target.host = base.host; - target.port = base.port; } } - - target.fragment = ref.fragment; - - return target; + return path[0..out_i]; } -test resolve { - const base = try parse("http://a/b/c/d;p?q"); - - var arena = std.heap.ArenaAllocator.init(std.testing.allocator); - defer arena.deinit(); +test remove_dot_segments { + { + var buffer = "/a/b/c/./../../g".*; + try std.testing.expectEqualStrings("/a/g", remove_dot_segments(&buffer)); + } +} - try std.testing.expectEqualDeep(try parse("http://a/b/c/blog/"), try base.resolve(try parseWithoutScheme("blog/"), true, arena.allocator())); - try std.testing.expectEqualDeep(try parse("http://a/b/c/blog/?k"), try base.resolve(try parseWithoutScheme("blog/?k"), true, arena.allocator())); - try std.testing.expectEqualDeep(try parse("http://a/b/blog/"), try 
base.resolve(try parseWithoutScheme("../blog/"), true, arena.allocator())); - try std.testing.expectEqualDeep(try parse("http://a/b/blog"), try base.resolve(try parseWithoutScheme("../blog"), true, arena.allocator())); - try std.testing.expectEqualDeep(try parse("http://e"), try base.resolve(try parseWithoutScheme("//e"), true, arena.allocator())); - try std.testing.expectEqualDeep(try parse("https://a:1/"), try base.resolve(try parse("https://a:1/"), true, arena.allocator())); +/// 5.2.3. Merge Paths +fn merge_paths(base: []const u8, new: []u8, aux: []u8) error{OutOfMemory}![]u8 { + if (aux.len < base.len + 1 + new.len) return error.OutOfMemory; + if (base.len == 0) { + aux[0] = '/'; + @memcpy(aux[1..][0..new.len], new); + return remove_dot_segments(aux[0 .. new.len + 1]); + } + const pos = std.mem.lastIndexOfScalar(u8, base, '/') orelse return remove_dot_segments(new); + @memcpy(aux[0 .. pos + 1], base[0 .. pos + 1]); + @memcpy(aux[pos + 1 ..][0..new.len], new); + return remove_dot_segments(aux[0 .. pos + 1 + new.len]); } const SliceReader = struct { diff --git a/lib/std/array_list.zig b/lib/std/array_list.zig index 79ed5c192c30..1926f627f365 100644 --- a/lib/std/array_list.zig +++ b/lib/std/array_list.zig @@ -937,14 +937,33 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?u29) typ return .{ .context = .{ .self = self, .allocator = allocator } }; } - /// Same as `append` except it returns the number of bytes written, which is always the same - /// as `m.len`. The purpose of this function existing is to match `std.io.Writer` API. + /// Same as `append` except it returns the number of bytes written, + /// which is always the same as `m.len`. The purpose of this function + /// existing is to match `std.io.Writer` API. /// Invalidates element pointers if additional memory is needed. fn appendWrite(context: WriterContext, m: []const u8) Allocator.Error!usize { try context.self.appendSlice(context.allocator, m); return m.len; } + pub const FixedWriter = std.io.Writer(*Self, Allocator.Error, appendWriteFixed); + + /// Initializes a Writer which will append to the list but will return + /// `error.OutOfMemory` rather than increasing capacity. + pub fn fixedWriter(self: *Self) FixedWriter { + return .{ .context = self }; + } + + /// The purpose of this function existing is to match `std.io.Writer` API. + fn appendWriteFixed(self: *Self, m: []const u8) error{OutOfMemory}!usize { + const available_capacity = self.capacity - self.items.len; + if (m.len > available_capacity) + return error.OutOfMemory; + + self.appendSliceAssumeCapacity(m); + return m.len; + } + /// Append a value to the list `n` times. /// Allocates more memory as necessary. /// Invalidates element pointers if additional memory is needed. diff --git a/lib/std/compress/zstandard.zig b/lib/std/compress/zstandard.zig index 4d9421acac6d..cfe5618bde31 100644 --- a/lib/std/compress/zstandard.zig +++ b/lib/std/compress/zstandard.zig @@ -1,5 +1,4 @@ const std = @import("std"); -const Allocator = std.mem.Allocator; const RingBuffer = std.RingBuffer; const types = @import("zstandard/types.zig"); @@ -8,32 +7,41 @@ pub const compressed_block = types.compressed_block; pub const decompress = @import("zstandard/decompress.zig"); -pub const DecompressStreamOptions = struct { +pub const DecompressorOptions = struct { verify_checksum: bool = true, - window_size_max: usize = 1 << 23, // 8MiB default maximum window size + window_buffer: []u8, + + /// Recommended amount by the standard. 
Lower than this may result + /// in inability to decompress common streams. + pub const default_window_buffer_len = 8 * 1024 * 1024; }; -pub fn DecompressStream( - comptime ReaderType: type, - comptime options: DecompressStreamOptions, -) type { +pub fn Decompressor(comptime ReaderType: type) type { return struct { const Self = @This(); - allocator: Allocator, + const table_size_max = types.compressed_block.table_size_max; + source: std.io.CountingReader(ReaderType), state: enum { NewFrame, InFrame, LastBlock }, decode_state: decompress.block.DecodeState, frame_context: decompress.FrameContext, - buffer: RingBuffer, - literal_fse_buffer: []types.compressed_block.Table.Fse, - match_fse_buffer: []types.compressed_block.Table.Fse, - offset_fse_buffer: []types.compressed_block.Table.Fse, - literals_buffer: []u8, - sequence_buffer: []u8, - checksum: if (options.verify_checksum) ?u32 else void, + buffer: WindowBuffer, + literal_fse_buffer: [table_size_max.literal]types.compressed_block.Table.Fse, + match_fse_buffer: [table_size_max.match]types.compressed_block.Table.Fse, + offset_fse_buffer: [table_size_max.offset]types.compressed_block.Table.Fse, + literals_buffer: [types.block_size_max]u8, + sequence_buffer: [types.block_size_max]u8, + verify_checksum: bool, + checksum: ?u32, current_frame_decompressed_size: usize, + const WindowBuffer = struct { + data: []u8 = undefined, + read_index: usize = 0, + write_index: usize = 0, + }; + pub const Error = ReaderType.Error || error{ ChecksumFailure, DictionaryIdFlagUnsupported, @@ -44,19 +52,19 @@ pub fn DecompressStream( pub const Reader = std.io.Reader(*Self, Error, read); - pub fn init(allocator: Allocator, source: ReaderType) Self { - return Self{ - .allocator = allocator, + pub fn init(source: ReaderType, options: DecompressorOptions) Self { + return .{ .source = std.io.countingReader(source), .state = .NewFrame, .decode_state = undefined, .frame_context = undefined, - .buffer = undefined, + .buffer = .{ .data = options.window_buffer }, .literal_fse_buffer = undefined, .match_fse_buffer = undefined, .offset_fse_buffer = undefined, .literals_buffer = undefined, .sequence_buffer = undefined, + .verify_checksum = options.verify_checksum, .checksum = undefined, .current_frame_decompressed_size = undefined, }; @@ -72,53 +80,20 @@ pub fn DecompressStream( .zstandard => |header| { const frame_context = try decompress.FrameContext.init( header, - options.window_size_max, - options.verify_checksum, - ); - - const literal_fse_buffer = try self.allocator.alloc( - types.compressed_block.Table.Fse, - types.compressed_block.table_size_max.literal, + self.buffer.data.len, + self.verify_checksum, ); - errdefer self.allocator.free(literal_fse_buffer); - - const match_fse_buffer = try self.allocator.alloc( - types.compressed_block.Table.Fse, - types.compressed_block.table_size_max.match, - ); - errdefer self.allocator.free(match_fse_buffer); - - const offset_fse_buffer = try self.allocator.alloc( - types.compressed_block.Table.Fse, - types.compressed_block.table_size_max.offset, - ); - errdefer self.allocator.free(offset_fse_buffer); const decode_state = decompress.block.DecodeState.init( - literal_fse_buffer, - match_fse_buffer, - offset_fse_buffer, + &self.literal_fse_buffer, + &self.match_fse_buffer, + &self.offset_fse_buffer, ); - const buffer = try RingBuffer.init(self.allocator, frame_context.window_size); - - const literals_data = try self.allocator.alloc(u8, options.window_size_max); - errdefer self.allocator.free(literals_data); - - const sequence_data = 
try self.allocator.alloc(u8, options.window_size_max); - errdefer self.allocator.free(sequence_data); - - self.literal_fse_buffer = literal_fse_buffer; - self.match_fse_buffer = match_fse_buffer; - self.offset_fse_buffer = offset_fse_buffer; - self.literals_buffer = literals_data; - self.sequence_buffer = sequence_data; - - self.buffer = buffer; self.decode_state = decode_state; self.frame_context = frame_context; - self.checksum = if (options.verify_checksum) null else {}; + self.checksum = null; self.current_frame_decompressed_size = 0; self.state = .InFrame; @@ -126,16 +101,6 @@ pub fn DecompressStream( } } - pub fn deinit(self: *Self) void { - if (self.state == .NewFrame) return; - self.allocator.free(self.decode_state.literal_fse_buffer); - self.allocator.free(self.decode_state.match_fse_buffer); - self.allocator.free(self.decode_state.offset_fse_buffer); - self.allocator.free(self.literals_buffer); - self.allocator.free(self.sequence_buffer); - self.buffer.deinit(self.allocator); - } - pub fn reader(self: *Self) Reader { return .{ .context = self }; } @@ -153,7 +118,6 @@ pub fn DecompressStream( 0 else error.MalformedFrame, - error.OutOfMemory => return error.OutOfMemory, else => return error.MalformedFrame, }; } @@ -165,20 +129,30 @@ pub fn DecompressStream( fn readInner(self: *Self, buffer: []u8) Error!usize { std.debug.assert(self.state != .NewFrame); + var ring_buffer = RingBuffer{ + .data = self.buffer.data, + .read_index = self.buffer.read_index, + .write_index = self.buffer.write_index, + }; + defer { + self.buffer.read_index = ring_buffer.read_index; + self.buffer.write_index = ring_buffer.write_index; + } + const source_reader = self.source.reader(); - while (self.buffer.isEmpty() and self.state != .LastBlock) { + while (ring_buffer.isEmpty() and self.state != .LastBlock) { const header_bytes = source_reader.readBytesNoEof(3) catch return error.MalformedFrame; const block_header = decompress.block.decodeBlockHeader(&header_bytes); decompress.block.decodeBlockReader( - &self.buffer, + &ring_buffer, source_reader, block_header, &self.decode_state, self.frame_context.block_size_max, - self.literals_buffer, - self.sequence_buffer, + &self.literals_buffer, + &self.sequence_buffer, ) catch return error.MalformedBlock; @@ -186,12 +160,12 @@ pub fn DecompressStream( if (self.current_frame_decompressed_size > size) return error.MalformedFrame; } - const size = self.buffer.len(); + const size = ring_buffer.len(); self.current_frame_decompressed_size += size; if (self.frame_context.hasher_opt) |*hasher| { if (size > 0) { - const written_slice = self.buffer.sliceLast(size); + const written_slice = ring_buffer.sliceLast(size); hasher.update(written_slice.first); hasher.update(written_slice.second); } @@ -201,7 +175,7 @@ pub fn DecompressStream( if (self.frame_context.has_checksum) { const checksum = source_reader.readInt(u32, .little) catch return error.MalformedFrame; - if (comptime options.verify_checksum) { + if (self.verify_checksum) { if (self.frame_context.hasher_opt) |*hasher| { if (checksum != decompress.computeChecksum(hasher)) return error.ChecksumFailure; @@ -216,43 +190,28 @@ pub fn DecompressStream( } } - const size = @min(self.buffer.len(), buffer.len); + const size = @min(ring_buffer.len(), buffer.len); if (size > 0) { - self.buffer.readFirstAssumeLength(buffer, size); + ring_buffer.readFirstAssumeLength(buffer, size); } - if (self.state == .LastBlock and self.buffer.len() == 0) { + if (self.state == .LastBlock and ring_buffer.len() == 0) { self.state = .NewFrame; - 
self.allocator.free(self.literal_fse_buffer); - self.allocator.free(self.match_fse_buffer); - self.allocator.free(self.offset_fse_buffer); - self.allocator.free(self.literals_buffer); - self.allocator.free(self.sequence_buffer); - self.buffer.deinit(self.allocator); } return size; } }; } -pub fn decompressStreamOptions( - allocator: Allocator, - reader: anytype, - comptime options: DecompressStreamOptions, -) DecompressStream(@TypeOf(reader, options)) { - return DecompressStream(@TypeOf(reader), options).init(allocator, reader); -} - -pub fn decompressStream( - allocator: Allocator, - reader: anytype, -) DecompressStream(@TypeOf(reader), .{}) { - return DecompressStream(@TypeOf(reader), .{}).init(allocator, reader); +pub fn decompressor(reader: anytype, options: DecompressorOptions) Decompressor(@TypeOf(reader)) { + return Decompressor(@TypeOf(reader)).init(reader, options); } fn testDecompress(data: []const u8) ![]u8 { + const window_buffer = try std.testing.allocator.alloc(u8, 1 << 23); + defer std.testing.allocator.free(window_buffer); + var in_stream = std.io.fixedBufferStream(data); - var zstd_stream = decompressStream(std.testing.allocator, in_stream.reader()); - defer zstd_stream.deinit(); + var zstd_stream = decompressor(in_stream.reader(), .{ .window_buffer = window_buffer }); const result = zstd_stream.reader().readAllAlloc(std.testing.allocator, std.math.maxInt(usize)); return result; } @@ -278,38 +237,48 @@ test "zstandard decompression" { const res19 = try decompress.decode(buffer, compressed19, true); try std.testing.expectEqual(uncompressed.len, res19); try std.testing.expectEqualSlices(u8, uncompressed, buffer); +} + +test "zstandard streaming decompression" { + // default stack size for wasm32 is too low for Decompressor - slightly + // over 1MiB stack space is needed via the --stack CLI flag + if (@import("builtin").target.cpu.arch == .wasm32) return error.SkipZigTest; + + const uncompressed = @embedFile("testdata/rfc8478.txt"); + const compressed3 = @embedFile("testdata/rfc8478.txt.zst.3"); + const compressed19 = @embedFile("testdata/rfc8478.txt.zst.19"); try testReader(compressed3, uncompressed); try testReader(compressed19, uncompressed); } fn expectEqualDecoded(expected: []const u8, input: []const u8) !void { - const allocator = std.testing.allocator; - { - const result = try decompress.decodeAlloc(allocator, input, false, 1 << 23); - defer allocator.free(result); + const result = try decompress.decodeAlloc(std.testing.allocator, input, false, 1 << 23); + defer std.testing.allocator.free(result); try std.testing.expectEqualStrings(expected, result); } { - var buffer = try allocator.alloc(u8, 2 * expected.len); - defer allocator.free(buffer); + var buffer = try std.testing.allocator.alloc(u8, 2 * expected.len); + defer std.testing.allocator.free(buffer); const size = try decompress.decode(buffer, input, false); try std.testing.expectEqualStrings(expected, buffer[0..size]); } +} - { - var in_stream = std.io.fixedBufferStream(input); - var stream = decompressStream(allocator, in_stream.reader()); - defer stream.deinit(); +fn expectEqualDecodedStreaming(expected: []const u8, input: []const u8) !void { + const window_buffer = try std.testing.allocator.alloc(u8, 1 << 23); + defer std.testing.allocator.free(window_buffer); - const result = try stream.reader().readAllAlloc(allocator, std.math.maxInt(usize)); - defer allocator.free(result); + var in_stream = std.io.fixedBufferStream(input); + var stream = decompressor(in_stream.reader(), .{ .window_buffer = window_buffer }); - 
try std.testing.expectEqualStrings(expected, result); - } + const result = try stream.reader().readAllAlloc(std.testing.allocator, std.math.maxInt(usize)); + defer std.testing.allocator.free(result); + + try std.testing.expectEqualStrings(expected, result); } test "zero sized block" { @@ -327,3 +296,23 @@ test "zero sized block" { try expectEqualDecoded("", input_raw); try expectEqualDecoded("", input_rle); } + +test "zero sized block streaming" { + // default stack size for wasm32 is too low for Decompressor - slightly + // over 1MiB stack space is needed via the --stack CLI flag + if (@import("builtin").target.cpu.arch == .wasm32) return error.SkipZigTest; + + const input_raw = + "\x28\xb5\x2f\xfd" ++ // zstandard frame magic number + "\x20\x00" ++ // frame header: only single_segment_flag set, frame_content_size zero + "\x01\x00\x00"; // block header with: last_block set, block_type raw, block_size zero + + const input_rle = + "\x28\xb5\x2f\xfd" ++ // zstandard frame magic number + "\x20\x00" ++ // frame header: only single_segment_flag set, frame_content_size zero + "\x03\x00\x00" ++ // block header with: last_block set, block_type rle, block_size zero + "\xaa"; // block_content + + try expectEqualDecodedStreaming("", input_raw); + try expectEqualDecodedStreaming("", input_rle); +} diff --git a/lib/std/compress/zstandard/decompress.zig b/lib/std/compress/zstandard/decompress.zig index a012312ab1a8..86be16268f35 100644 --- a/lib/std/compress/zstandard/decompress.zig +++ b/lib/std/compress/zstandard/decompress.zig @@ -409,7 +409,7 @@ pub const FrameContext = struct { .hasher_opt = if (should_compute_checksum) std.hash.XxHash64.init(0) else null, .window_size = window_size, .has_checksum = frame_header.descriptor.content_checksum_flag, - .block_size_max = @min(1 << 17, window_size), + .block_size_max = @min(types.block_size_max, window_size), .content_size = content_size, }; } diff --git a/lib/std/compress/zstandard/types.zig b/lib/std/compress/zstandard/types.zig index db4fbdee2d92..41c3797d16bc 100644 --- a/lib/std/compress/zstandard/types.zig +++ b/lib/std/compress/zstandard/types.zig @@ -1,3 +1,5 @@ +pub const block_size_max = 1 << 17; + pub const frame = struct { pub const Kind = enum { zstandard, skippable }; @@ -391,7 +393,7 @@ pub const compressed_block = struct { pub const table_size_max = struct { pub const literal = 1 << table_accuracy_log_max.literal; pub const match = 1 << table_accuracy_log_max.match; - pub const offset = 1 << table_accuracy_log_max.match; + pub const offset = 1 << table_accuracy_log_max.offset; }; }; diff --git a/lib/std/http.zig b/lib/std/http.zig index 9b2bce133814..af966d89e75d 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -1,12 +1,9 @@ -const std = @import("std.zig"); - pub const Client = @import("http/Client.zig"); pub const Server = @import("http/Server.zig"); pub const protocol = @import("http/protocol.zig"); -const headers = @import("http/Headers.zig"); - -pub const Headers = headers.Headers; -pub const Field = headers.Field; +pub const HeadParser = @import("http/HeadParser.zig"); +pub const ChunkParser = @import("http/ChunkParser.zig"); +pub const HeaderIterator = @import("http/HeaderIterator.zig"); pub const Version = enum { @"HTTP/1.0", @@ -18,7 +15,7 @@ pub const Version = enum { /// https://datatracker.ietf.org/doc/html/rfc7231#section-4 Initial definition /// /// https://datatracker.ietf.org/doc/html/rfc5789#section-2 PATCH -pub const Method = enum(u64) { // TODO: should be u192 or u256, but neither is supported by the C backend, and 
therefore cannot pass CI
+pub const Method = enum(u64) {
     GET = parse("GET"),
     HEAD = parse("HEAD"),
     POST = parse("POST"),
@@ -46,10 +43,6 @@ pub const Method = enum(u64) { // TODO: should be u192 or u256, but neither is s
         try w.writeAll(str);
     }
 
-    pub fn format(value: Method, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) @TypeOf(writer).Error!void {
-        return try value.write(writer);
-    }
-
     /// Returns true if a request of this method is allowed to have a body
     /// Actual behavior from servers may vary and should still be checked
     pub fn requestHasBody(self: Method) bool {
@@ -309,9 +302,22 @@ pub const Connection = enum {
     close,
 };
 
+pub const Header = struct {
+    name: []const u8,
+    value: []const u8,
+};
+
+const builtin = @import("builtin");
+const std = @import("std.zig");
+
 test {
     _ = Client;
     _ = Method;
     _ = Server;
     _ = Status;
+    _ = HeadParser;
+    _ = ChunkParser;
+    if (builtin.os.tag != .wasi) {
+        _ = @import("http/test.zig");
+    }
 }
diff --git a/lib/std/http/ChunkParser.zig b/lib/std/http/ChunkParser.zig
new file mode 100644
index 000000000000..adcdc74bc7be
--- /dev/null
+++ b/lib/std/http/ChunkParser.zig
@@ -0,0 +1,131 @@
+//! Parser for transfer-encoding: chunked.
+
+state: State,
+chunk_len: u64,
+
+pub const init: ChunkParser = .{
+    .state = .head_size,
+    .chunk_len = 0,
+};
+
+pub const State = enum {
+    head_size,
+    head_ext,
+    head_r,
+    data,
+    data_suffix,
+    data_suffix_r,
+    invalid,
+};
+
+/// Returns the number of bytes consumed by the chunk size. This is always
+/// less than or equal to `bytes.len`.
+///
+/// After this function returns, `chunk_len` will contain the parsed chunk size
+/// in bytes when `state` is `data`. Alternatively, `state` may become `invalid`,
+/// indicating a syntax error in the input stream.
+///
+/// If the amount returned is less than `bytes.len`, the parser is in the
+/// `data` state and the first byte of the chunk is at `bytes[result]`.
+///
+/// Asserts `state` is neither `data` nor `invalid`.
+pub fn feed(p: *ChunkParser, bytes: []const u8) usize {
+    for (bytes, 0..) 
|c, i| switch (p.state) { + .data_suffix => switch (c) { + '\r' => p.state = .data_suffix_r, + '\n' => p.state = .head_size, + else => { + p.state = .invalid; + return i; + }, + }, + .data_suffix_r => switch (c) { + '\n' => p.state = .head_size, + else => { + p.state = .invalid; + return i; + }, + }, + .head_size => { + const digit = switch (c) { + '0'...'9' => |b| b - '0', + 'A'...'Z' => |b| b - 'A' + 10, + 'a'...'z' => |b| b - 'a' + 10, + '\r' => { + p.state = .head_r; + continue; + }, + '\n' => { + p.state = .data; + return i + 1; + }, + else => { + p.state = .head_ext; + continue; + }, + }; + + const new_len = p.chunk_len *% 16 +% digit; + if (new_len <= p.chunk_len and p.chunk_len != 0) { + p.state = .invalid; + return i; + } + + p.chunk_len = new_len; + }, + .head_ext => switch (c) { + '\r' => p.state = .head_r, + '\n' => { + p.state = .data; + return i + 1; + }, + else => continue, + }, + .head_r => switch (c) { + '\n' => { + p.state = .data; + return i + 1; + }, + else => { + p.state = .invalid; + return i; + }, + }, + .data => unreachable, + .invalid => unreachable, + }; + return bytes.len; +} + +const ChunkParser = @This(); +const std = @import("std"); + +test feed { + const testing = std.testing; + + const data = "Ff\r\nf0f000 ; ext\n0\r\nffffffffffffffffffffffffffffffffffffffff\r\n"; + + var p = init; + const first = p.feed(data[0..]); + try testing.expectEqual(@as(u32, 4), first); + try testing.expectEqual(@as(u64, 0xff), p.chunk_len); + try testing.expectEqual(.data, p.state); + + p = init; + const second = p.feed(data[first..]); + try testing.expectEqual(@as(u32, 13), second); + try testing.expectEqual(@as(u64, 0xf0f000), p.chunk_len); + try testing.expectEqual(.data, p.state); + + p = init; + const third = p.feed(data[first + second ..]); + try testing.expectEqual(@as(u32, 3), third); + try testing.expectEqual(@as(u64, 0), p.chunk_len); + try testing.expectEqual(.data, p.state); + + p = init; + const fourth = p.feed(data[first + second + third ..]); + try testing.expectEqual(@as(u32, 16), fourth); + try testing.expectEqual(@as(u64, 0xffffffffffffffff), p.chunk_len); + try testing.expectEqual(.invalid, p.state); +} diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index a50e814fd4cf..5f580bd53ea7 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -20,9 +20,7 @@ const proto = @import("protocol.zig"); pub const disable_tls = std.options.http_disable_tls; -/// Allocator used for all allocations made by the client. -/// -/// This allocator must be thread-safe. +/// Used for all client allocations. Must be thread-safe. allocator: Allocator, ca_bundle: if (disable_tls) void else std.crypto.Certificate.Bundle = if (disable_tls) {} else .{}, @@ -35,14 +33,25 @@ next_https_rescan_certs: bool = true, /// The pool of connections that can be reused (and currently in use). connection_pool: ConnectionPool = .{}, -/// This is the proxy that will handle http:// connections. It *must not* be modified when the client has any active connections. -http_proxy: ?Proxy = null, - -/// This is the proxy that will handle https:// connections. It *must not* be modified when the client has any active connections. -https_proxy: ?Proxy = null, +/// If populated, all http traffic travels through this third party. +/// This field cannot be modified while the client has active connections. +/// Pointer to externally-owned memory. +http_proxy: ?*Proxy = null, +/// If populated, all https traffic travels through this third party. 
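+/// A setup sketch (illustrative values; the `Proxy` struct is shown at the
+/// end of this diff):
+///
+///     var proxy: std.http.Client.Proxy = .{
+///         .protocol = .plain,
+///         .host = "proxy.example.com",
+///         .authorization = null,
+///         .port = 8080,
+///         .supports_connect = true,
+///     };
+///     client.https_proxy = &proxy;
+///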
+/// This field cannot be modified while the client has active connections. +/// Pointer to externally-owned memory. +https_proxy: ?*Proxy = null, /// A set of linked lists of connections that can be reused. pub const ConnectionPool = struct { + mutex: std.Thread.Mutex = .{}, + /// Open connections that are currently in use. + used: Queue = .{}, + /// Open connections that are not currently in use. + free: Queue = .{}, + free_len: usize = 0, + free_size: usize = 32, + /// The criteria for a connection to be considered a match. pub const Criteria = struct { host: []const u8, @@ -53,14 +62,6 @@ pub const ConnectionPool = struct { const Queue = std.DoublyLinkedList(Connection); pub const Node = Queue.Node; - mutex: std.Thread.Mutex = .{}, - /// Open connections that are currently in use. - used: Queue = .{}, - /// Open connections that are not currently in use. - free: Queue = .{}, - free_len: usize = 0, - free_size: usize = 32, - /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe. /// If no connection is found, null is returned. pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Connection { @@ -189,11 +190,6 @@ pub const ConnectionPool = struct { /// An interface to either a plain or TLS connection. pub const Connection = struct { - pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; - const BufferSize = std.math.IntFittingRange(0, buffer_size); - - pub const Protocol = enum { plain, tls }; - stream: net.Stream, /// undefined unless protocol is tls. tls_client: if (!disable_tls) *std.crypto.tls.Client else void, @@ -219,6 +215,11 @@ pub const Connection = struct { read_buf: [buffer_size]u8 = undefined, write_buf: [buffer_size]u8 = undefined, + pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; + const BufferSize = std.math.IntFittingRange(0, buffer_size); + + pub const Protocol = enum { plain, tls }; + pub fn readvDirectTls(conn: *Connection, buffers: []std.os.iovec) ReadError!usize { return conn.tls_client.readv(conn.stream, buffers) catch |err| { // https://github.com/ziglang/zig/issues/2473 @@ -406,31 +407,63 @@ pub const RequestTransfer = union(enum) { pub const Compression = union(enum) { pub const DeflateDecompressor = std.compress.zlib.Decompressor(Request.TransferReader); pub const GzipDecompressor = std.compress.gzip.Decompressor(Request.TransferReader); - pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.TransferReader, .{}); + // https://github.com/ziglang/zig/issues/18937 + //pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.TransferReader, .{}); deflate: DeflateDecompressor, gzip: GzipDecompressor, - zstd: ZstdDecompressor, + // https://github.com/ziglang/zig/issues/18937 + //zstd: ZstdDecompressor, none: void, }; /// A HTTP response originating from a server. pub const Response = struct { - pub const ParseError = Allocator.Error || error{ + version: http.Version, + status: http.Status, + reason: []const u8, + + /// Points into the user-provided `server_header_buffer`. + location: ?[]const u8 = null, + /// Points into the user-provided `server_header_buffer`. + content_type: ?[]const u8 = null, + /// Points into the user-provided `server_header_buffer`. + content_disposition: ?[]const u8 = null, + + keep_alive: bool = false, + + /// If present, the number of bytes in the response body. + content_length: ?u64 = null, + + /// If present, the transfer encoding of the response body, otherwise none. 
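+    /// For a `transfer-encoding: chunked` response, `parse` sets this to
+    /// `.chunked`; the body is then delimited by chunk framing rather than
+    /// a content-length.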
+ transfer_encoding: http.TransferEncoding = .none, + + /// If present, the compression of the response body, otherwise identity (no compression). + transfer_compression: http.ContentEncoding = .identity, + + parser: proto.HeadersParser, + compression: Compression = .none, + + /// Whether the response body should be skipped. Any data read from the + /// response body will be discarded. + skip: bool = false, + + pub const ParseError = error{ HttpHeadersInvalid, HttpHeaderContinuationsUnsupported, HttpTransferEncodingUnsupported, HttpConnectionHeaderUnsupported, InvalidContentLength, - CompressionNotSupported, + CompressionUnsupported, }; - pub fn parse(res: *Response, bytes: []const u8, trailing: bool) ParseError!void { - var it = mem.tokenizeAny(u8, bytes, "\r\n"); + pub fn parse(res: *Response, bytes: []const u8) ParseError!void { + var it = mem.splitSequence(u8, bytes, "\r\n"); - const first_line = it.next() orelse return error.HttpHeadersInvalid; - if (first_line.len < 12) + const first_line = it.next().?; + if (first_line.len < 12) { return error.HttpHeadersInvalid; + } const version: http.Version = switch (int64(first_line[0..8])) { int64("HTTP/1.0") => .@"HTTP/1.0", @@ -445,24 +478,27 @@ pub const Response = struct { res.status = status; res.reason = reason; - res.headers.clearRetainingCapacity(); - while (it.next()) |line| { - if (line.len == 0) return error.HttpHeadersInvalid; + if (line.len == 0) return; switch (line[0]) { ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, else => {}, } - var line_it = mem.tokenizeAny(u8, line, ": "); - const header_name = line_it.next() orelse return error.HttpHeadersInvalid; + var line_it = mem.splitSequence(u8, line, ": "); + const header_name = line_it.next().?; const header_value = line_it.rest(); - - try res.headers.append(header_name, header_value); - - if (trailing) continue; - - if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { + if (header_value.len == 0) return error.HttpHeadersInvalid; + + if (std.ascii.eqlIgnoreCase(header_name, "connection")) { + res.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close"); + } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) { + res.content_type = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "location")) { + res.location = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-disposition")) { + res.content_disposition = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { // Transfer-Encoding: second, first // Transfer-Encoding: deflate, chunked var iter = mem.splitBackwardsScalar(u8, header_value, ','); @@ -508,6 +544,7 @@ pub const Response = struct { } } } + return error.HttpHeadersInvalid; // missing empty line } inline fn int64(array: *const [8]u8) u64 { @@ -531,60 +568,25 @@ pub const Response = struct { try expectEqual(@as(u10, 999), parseInt3("999")); } - /// The HTTP version this response is using. - version: http.Version, - - /// The status code of the response. - status: http.Status, - - /// The reason phrase of the response. - reason: []const u8, - - /// If present, the number of bytes in the response body. - content_length: ?u64 = null, - - /// If present, the transfer encoding of the response body, otherwise none. - transfer_encoding: http.TransferEncoding = .none, - - /// If present, the compression of the response body, otherwise identity (no compression). - transfer_compression: http.ContentEncoding = .identity, - - /// The headers received from the server. 
- headers: http.Headers, - parser: proto.HeadersParser, - compression: Compression = .none, - - /// Whether the response body should be skipped. Any data read from the response body will be discarded. - skip: bool = false, + pub fn iterateHeaders(r: Response) http.HeaderIterator { + return http.HeaderIterator.init(r.parser.get()); + } }; /// A HTTP request that has been sent. /// /// Order of operations: open -> send[ -> write -> finish] -> wait -> read pub const Request = struct { - /// The uri that this request is being sent to. uri: Uri, - - /// The client that this request was created from. client: *Client, - - /// Underlying connection to the server. This is null when the connection is released. + /// This is null when the connection is released. connection: ?*Connection, + keep_alive: bool, method: http.Method, version: http.Version = .@"HTTP/1.1", - - /// The list of HTTP request headers. - headers: http.Headers, - - /// The transfer encoding of the request body. - transfer_encoding: RequestTransfer = .none, - - /// The redirect quota left for this request. - redirects_left: u32, - - /// Whether the request should follow redirects. - handle_redirects: bool, + transfer_encoding: RequestTransfer, + redirect_behavior: RedirectBehavior, /// Whether the request should handle a 100-continue response before sending the request body. handle_continue: bool, @@ -594,25 +596,60 @@ pub const Request = struct { /// This field is undefined until `wait` is called. response: Response, - /// Used as a allocator for resolving redirects locations. - arena: std.heap.ArenaAllocator, + /// Standard headers that have default, but overridable, behavior. + headers: Headers, + + /// These headers are kept including when following a redirect to a + /// different domain. + /// Externally-owned; must outlive the Request. + extra_headers: []const http.Header, + + /// These headers are stripped when following a redirect to a different + /// domain. + /// Externally-owned; must outlive the Request. + privileged_headers: []const http.Header, + + pub const Headers = struct { + host: Value = .default, + authorization: Value = .default, + user_agent: Value = .default, + connection: Value = .default, + accept_encoding: Value = .default, + content_type: Value = .default, + + pub const Value = union(enum) { + default, + omit, + override: []const u8, + }; + }; - /// Frees all resources associated with the request. - pub fn deinit(req: *Request) void { - switch (req.response.compression) { - .none => {}, - .deflate => {}, - .gzip => {}, - .zstd => |*zstd| zstd.deinit(), + /// Any value other than `not_allowed` or `unhandled` means that integer represents + /// how many remaining redirects are allowed. + pub const RedirectBehavior = enum(u16) { + /// The next redirect will cause an error. + not_allowed = 0, + /// Redirects are passed to the client to analyze the redirect response + /// directly. + unhandled = std.math.maxInt(u16), + _, + + pub fn subtractOne(rb: *RedirectBehavior) void { + switch (rb.*) { + .not_allowed => unreachable, + .unhandled => unreachable, + _ => rb.* = @enumFromInt(@intFromEnum(rb.*) - 1), + } } - req.headers.deinit(); - req.response.headers.deinit(); - - if (req.response.parser.header_bytes_owned) { - req.response.parser.header_bytes.deinit(req.client.allocator); + pub fn remaining(rb: RedirectBehavior) u16 { + assert(rb != .unhandled); + return @intFromEnum(rb); } + }; + /// Frees all resources associated with the request. 
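+    /// A typical lifetime, assuming a request obtained from this client
+    /// (the `open` call itself is outside this diff):
+    ///
+    ///     var req = try client.open(.GET, uri, .{ .server_header_buffer = &buf });
+    ///     defer req.deinit();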
+ pub fn deinit(req: *Request) void { if (req.connection) |connection| { if (!req.response.parser.done) { // If the response wasn't fully read, then we need to close the connection. @@ -620,23 +657,15 @@ pub const Request = struct { } req.client.connection_pool.release(req.client.allocator, connection); } - - req.arena.deinit(); req.* = undefined; } - // This function must deallocate all resources associated with the request, or keep those which will be used - // This needs to be kept in sync with deinit and request + // This function must deallocate all resources associated with the request, + // or keep those which will be used. + // This needs to be kept in sync with deinit and request. fn redirect(req: *Request, uri: Uri) !void { assert(req.response.parser.done); - switch (req.response.compression) { - .none => {}, - .deflate => {}, - .gzip => {}, - .zstd => |*zstd| zstd.deinit(), - } - req.client.connection_pool.release(req.client.allocator, req.connection.?); req.connection = null; @@ -651,15 +680,13 @@ pub const Request = struct { req.uri = uri; req.connection = try req.client.connect(host, port, protocol); - req.redirects_left -= 1; - req.response.headers.clearRetainingCapacity(); + req.redirect_behavior.subtractOne(); req.response.parser.reset(); req.response = .{ .status = undefined, .reason = undefined, .version = undefined, - .headers = req.response.headers, .parser = req.response.parser, }; } @@ -667,15 +694,17 @@ pub const Request = struct { pub const SendError = Connection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding }; pub const SendOptions = struct { - /// Specifies that the uri should be used as is. You guarantee that the uri is already escaped. + /// Specifies that the uri is already escaped. raw_uri: bool = false, }; /// Send the HTTP request headers to the server. 
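+    /// Standard headers can be overridden or omitted beforehand through the
+    /// `headers` field (a sketch; the values shown are illustrative):
+    ///
+    ///     req.headers.user_agent = .{ .override = "my-app/1.0" };
+    ///     req.headers.accept_encoding = .omit;
+    ///     try req.send(.{});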
pub fn send(req: *Request, options: SendOptions) SendError!void { - if (!req.method.requestHasBody() and req.transfer_encoding != .none) return error.UnsupportedTransferEncoding; + if (!req.method.requestHasBody() and req.transfer_encoding != .none) + return error.UnsupportedTransferEncoding; - const w = req.connection.?.writer(); + const connection = req.connection.?; + const w = connection.writer(); try req.method.write(w); try w.writeByte(' '); @@ -684,9 +713,9 @@ pub const Request = struct { try req.uri.writeToStream(.{ .authority = true }, w); } else { try req.uri.writeToStream(.{ - .scheme = req.connection.?.proxied, - .authentication = req.connection.?.proxied, - .authority = req.connection.?.proxied, + .scheme = connection.proxied, + .authentication = connection.proxied, + .authority = connection.proxied, .path = true, .query = true, .raw = options.raw_uri, @@ -696,97 +725,93 @@ pub const Request = struct { try w.writeAll(@tagName(req.version)); try w.writeAll("\r\n"); - if (!req.headers.contains("host")) { - try w.writeAll("Host: "); + if (try emitOverridableHeader("host: ", req.headers.host, w)) { + try w.writeAll("host: "); try req.uri.writeToStream(.{ .authority = true }, w); try w.writeAll("\r\n"); } - if ((req.uri.user != null or req.uri.password != null) and - !req.headers.contains("authorization")) - { - try w.writeAll("Authorization: "); - const authorization = try req.connection.?.allocWriteBuffer( - @intCast(basic_authorization.valueLengthFromUri(req.uri)), - ); - std.debug.assert(basic_authorization.value(req.uri, authorization).len == authorization.len); - try w.writeAll("\r\n"); + if (try emitOverridableHeader("authorization: ", req.headers.authorization, w)) { + if (req.uri.user != null or req.uri.password != null) { + try w.writeAll("authorization: "); + const authorization = try connection.allocWriteBuffer( + @intCast(basic_authorization.valueLengthFromUri(req.uri)), + ); + assert(basic_authorization.value(req.uri, authorization).len == authorization.len); + try w.writeAll("\r\n"); + } } - if (!req.headers.contains("user-agent")) { - try w.writeAll("User-Agent: zig/"); + if (try emitOverridableHeader("user-agent: ", req.headers.user_agent, w)) { + try w.writeAll("user-agent: zig/"); try w.writeAll(builtin.zig_version_string); try w.writeAll(" (std.http)\r\n"); } - if (!req.headers.contains("connection")) { - try w.writeAll("Connection: keep-alive\r\n"); + if (try emitOverridableHeader("connection: ", req.headers.connection, w)) { + if (req.keep_alive) { + try w.writeAll("connection: keep-alive\r\n"); + } else { + try w.writeAll("connection: close\r\n"); + } } - if (!req.headers.contains("accept-encoding")) { - try w.writeAll("Accept-Encoding: gzip, deflate, zstd\r\n"); + if (try emitOverridableHeader("accept-encoding: ", req.headers.accept_encoding, w)) { + // https://github.com/ziglang/zig/issues/18937 + //try w.writeAll("accept-encoding: gzip, deflate, zstd\r\n"); + try w.writeAll("accept-encoding: gzip, deflate\r\n"); } - if (!req.headers.contains("te")) { - try w.writeAll("TE: gzip, deflate, trailers\r\n"); + switch (req.transfer_encoding) { + .chunked => try w.writeAll("transfer-encoding: chunked\r\n"), + .content_length => |len| try w.print("content-length: {d}\r\n", .{len}), + .none => {}, } - const has_transfer_encoding = req.headers.contains("transfer-encoding"); - const has_content_length = req.headers.contains("content-length"); - - if (!has_transfer_encoding and !has_content_length) { - switch (req.transfer_encoding) { - .chunked => try 
w.writeAll("Transfer-Encoding: chunked\r\n"), - .content_length => |content_length| try w.print("Content-Length: {d}\r\n", .{content_length}), - .none => {}, - } - } else { - if (has_transfer_encoding) { - const transfer_encoding = req.headers.getFirstValue("transfer-encoding").?; - if (std.mem.eql(u8, transfer_encoding, "chunked")) { - req.transfer_encoding = .chunked; - } else { - return error.UnsupportedTransferEncoding; - } - } else if (has_content_length) { - const content_length = std.fmt.parseInt(u64, req.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength; - - req.transfer_encoding = .{ .content_length = content_length }; - } else { - req.transfer_encoding = .none; - } + if (try emitOverridableHeader("content-type: ", req.headers.content_type, w)) { + // The default is to omit content-type if not provided because + // "application/octet-stream" is redundant. } - for (req.headers.list.items) |entry| { - if (entry.value.len == 0) continue; + for (req.extra_headers) |header| { + assert(header.value.len != 0); - try w.writeAll(entry.name); + try w.writeAll(header.name); try w.writeAll(": "); - try w.writeAll(entry.value); + try w.writeAll(header.value); try w.writeAll("\r\n"); } - if (req.connection.?.proxied) { - const proxy_headers: ?http.Headers = switch (req.connection.?.protocol) { - .plain => if (req.client.http_proxy) |proxy| proxy.headers else null, - .tls => if (req.client.https_proxy) |proxy| proxy.headers else null, - }; - - if (proxy_headers) |headers| { - for (headers.list.items) |entry| { - if (entry.value.len == 0) continue; + if (connection.proxied) proxy: { + const proxy = switch (connection.protocol) { + .plain => req.client.http_proxy, + .tls => req.client.https_proxy, + } orelse break :proxy; - try w.writeAll(entry.name); - try w.writeAll(": "); - try w.writeAll(entry.value); - try w.writeAll("\r\n"); - } - } + const authorization = proxy.authorization orelse break :proxy; + try w.writeAll("proxy-authorization: "); + try w.writeAll(authorization); + try w.writeAll("\r\n"); } try w.writeAll("\r\n"); - try req.connection.?.flush(); + try connection.flush(); + } + + /// Returns true if the default behavior is required, otherwise handles + /// writing (or not writing) the header. + fn emitOverridableHeader(prefix: []const u8, v: Headers.Value, w: anytype) !bool { + switch (v) { + .default => return true, + .omit => return false, + .override => |x| { + try w.writeAll(prefix); + try w.writeAll(x); + try w.writeAll("\r\n"); + return false; + }, + } } const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; @@ -810,145 +835,169 @@ pub const Request = struct { return index; } - pub const WaitError = RequestError || SendError || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || Uri.ParseError || error{ TooManyHttpRedirects, RedirectRequiresResend, HttpRedirectMissingLocation, CompressionInitializationFailed, CompressionNotSupported }; + pub const WaitError = RequestError || SendError || TransferReadError || + proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || + error{ // TODO: file zig fmt issue for this bad indentation + TooManyHttpRedirects, + RedirectRequiresResend, + HttpRedirectLocationMissing, + HttpRedirectLocationInvalid, + CompressionInitializationFailed, + CompressionUnsupported, + }; /// Waits for a response from the server and parses any headers that are sent. /// This function will block until the final response is received. 
/// - /// If `handle_redirects` is true and the request has no payload, then this function will automatically follow - /// redirects. If a request payload is present, then this function will error with error.RedirectRequiresResend. + /// If handling redirects and the request has no payload, then this + /// function will automatically follow redirects. If a request payload is + /// present, then this function will error with + /// error.RedirectRequiresResend. /// - /// Must be called after `send` and, if any data was written to the request body, then also after `finish`. + /// Must be called after `send` and, if any data was written to the request + /// body, then also after `finish`. pub fn wait(req: *Request) WaitError!void { + const connection = req.connection.?; + while (true) { // handle redirects while (true) { // read headers - try req.connection.?.fill(); + try connection.fill(); - const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.peek()); - req.connection.?.drop(@intCast(nchecked)); + const nchecked = try req.response.parser.checkCompleteHead(connection.peek()); + connection.drop(@intCast(nchecked)); if (req.response.parser.state.isContent()) break; } - try req.response.parse(req.response.parser.header_bytes.items, false); + try req.response.parse(req.response.parser.get()); if (req.response.status == .@"continue") { - req.response.parser.done = true; // we're done parsing the continue response, reset to prepare for the real response + // We're done parsing the continue response; reset to prepare + // for the real response. + req.response.parser.done = true; req.response.parser.reset(); if (req.handle_continue) continue; - return; // we're not handling the 100-continue, return to the caller + return; // we're not handling the 100-continue } // we're switching protocols, so this connection is no longer doing http if (req.method == .CONNECT and req.response.status.class() == .success) { - req.connection.?.closing = false; + connection.closing = false; req.response.parser.done = true; - - return; // the connection is not HTTP past this point, return to the caller + return; // the connection is not HTTP past this point } - // we default to using keep-alive if not provided in the client if the server asks for it - const req_connection = req.headers.getFirstValue("connection"); - const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); - - const res_connection = req.response.headers.getFirstValue("connection"); - const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?); - if (res_keepalive and (req_keepalive or req_connection == null)) { - req.connection.?.closing = false; - } else { - req.connection.?.closing = true; - } + connection.closing = !req.response.keep_alive or !req.keep_alive; - // Any response to a HEAD request and any response with a 1xx (Informational), 204 (No Content), or 304 (Not Modified) - // status code is always terminated by the first empty line after the header fields, regardless of the header fields - // present in the message - if (req.method == .HEAD or req.response.status.class() == .informational or req.response.status == .no_content or req.response.status == .not_modified) { + // Any response to a HEAD request and any response with a 1xx + // (Informational), 204 (No Content), or 304 (Not Modified) status + // code is always terminated by the first empty line after the + // header fields, regardless of the header fields 
present in the + // message. + if (req.method == .HEAD or req.response.status.class() == .informational or + req.response.status == .no_content or req.response.status == .not_modified) + { req.response.parser.done = true; - - return; // the response is empty, no further setup or redirection is necessary + return; // The response is empty; no further setup or redirection is necessary. } - if (req.response.transfer_encoding != .none) { - switch (req.response.transfer_encoding) { - .none => unreachable, - .chunked => { - req.response.parser.next_chunk_length = 0; - req.response.parser.state = .chunk_head_size; - }, - } - } else if (req.response.content_length) |cl| { - req.response.parser.next_chunk_length = cl; + switch (req.response.transfer_encoding) { + .none => { + if (req.response.content_length) |cl| { + req.response.parser.next_chunk_length = cl; - if (cl == 0) req.response.parser.done = true; - } else { - // read until the connection is closed - req.response.parser.next_chunk_length = std.math.maxInt(u64); + if (cl == 0) req.response.parser.done = true; + } else { + // read until the connection is closed + req.response.parser.next_chunk_length = std.math.maxInt(u64); + } + }, + .chunked => { + req.response.parser.next_chunk_length = 0; + req.response.parser.state = .chunk_head_size; + }, } - if (req.response.status.class() == .redirect and req.handle_redirects) { + if (req.response.status.class() == .redirect and req.redirect_behavior != .unhandled) { + // skip the body of the redirect response, this will at least + // leave the connection in a known good state. req.response.skip = true; - - // skip the body of the redirect response, this will at least leave the connection in a known good state. - const empty = @as([*]u8, undefined)[0..0]; - assert(try req.transferRead(empty) == 0); // we're skipping, no buffer is necessary - - if (req.redirects_left == 0) return error.TooManyHttpRedirects; - - const location = req.response.headers.getFirstValue("location") orelse - return error.HttpRedirectMissingLocation; - - const arena = req.arena.allocator(); - - const location_duped = try arena.dupe(u8, location); - - const new_url = Uri.parse(location_duped) catch try Uri.parseWithoutScheme(location_duped); - const resolved_url = try req.uri.resolve(new_url, false, arena); - - // is the redirect location on the same domain, or a subdomain of the original request? - const is_same_domain_or_subdomain = std.ascii.endsWithIgnoreCase(resolved_url.host.?, req.uri.host.?) and (resolved_url.host.?.len == req.uri.host.?.len or resolved_url.host.?[resolved_url.host.?.len - req.uri.host.?.len - 1] == '.'); - - if (resolved_url.host == null or !is_same_domain_or_subdomain or !std.ascii.eqlIgnoreCase(resolved_url.scheme, req.uri.scheme)) { - // we're redirecting to a different domain, strip privileged headers like cookies - _ = req.headers.delete("authorization"); - _ = req.headers.delete("www-authenticate"); - _ = req.headers.delete("cookie"); - _ = req.headers.delete("cookie2"); + assert(try req.transferRead(&.{}) == 0); // we're skipping, no buffer is necessary + + if (req.redirect_behavior == .not_allowed) return error.TooManyHttpRedirects; + + const location = req.response.location orelse + return error.HttpRedirectLocationMissing; + + // This mutates the beginning of header_buffer and uses that + // for the backing memory of the returned new_uri. 
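+                // `resolve_inplace` copies `location` into the unused tail of
+                // the buffer and resolves the path in place there (see the
+                // Uri.zig changes in this diff), so no allocator is required.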
+ const header_buffer = req.response.parser.header_bytes_buffer; + const new_uri = req.uri.resolve_inplace(location, header_buffer) catch + return error.HttpRedirectLocationInvalid; + + // The new URI references the beginning of header_bytes_buffer memory. + // That memory will be kept, but everything after it will be + // reused by the subsequent request. In other words, + // header_bytes_buffer must be large enough to store all + // redirect locations as well as the final request header. + const path_end = new_uri.path.ptr + new_uri.path.len; + // https://github.com/ziglang/zig/issues/1738 + const path_offset = @intFromPtr(path_end) - @intFromPtr(header_buffer.ptr); + const end_offset = @max(path_offset, location.len); + req.response.parser.header_bytes_buffer = header_buffer[end_offset..]; + + const is_same_domain_or_subdomain = + std.ascii.endsWithIgnoreCase(new_uri.host.?, req.uri.host.?) and + (new_uri.host.?.len == req.uri.host.?.len or + new_uri.host.?[new_uri.host.?.len - req.uri.host.?.len - 1] == '.'); + + if (new_uri.host == null or !is_same_domain_or_subdomain or + !std.ascii.eqlIgnoreCase(new_uri.scheme, req.uri.scheme)) + { + // When redirecting to a different domain, strip privileged headers. + req.privileged_headers = &.{}; } - if (req.response.status == .see_other or ((req.response.status == .moved_permanently or req.response.status == .found) and req.method == .POST)) { - // we're redirecting to a GET, so we need to change the method and remove the body + if (switch (req.response.status) { + .see_other => true, + .moved_permanently, .found => req.method == .POST, + else => false, + }) { + // A redirect to a GET must change the method and remove the body. req.method = .GET; req.transfer_encoding = .none; - _ = req.headers.delete("transfer-encoding"); - _ = req.headers.delete("content-length"); - _ = req.headers.delete("content-type"); + req.headers.content_type = .omit; } if (req.transfer_encoding != .none) { - return error.RedirectRequiresResend; // The request body has already been sent. The request is still in a valid state, but the redirect must be handled manually. + // The request body has already been sent. The request is + // still in a valid state, but the redirect must be handled + // manually. 
+ return error.RedirectRequiresResend; } - try req.redirect(resolved_url); - + try req.redirect(new_uri); try req.send(.{}); } else { req.response.skip = false; if (!req.response.parser.done) { switch (req.response.transfer_compression) { .identity => req.response.compression = .none, - .compress, .@"x-compress" => return error.CompressionNotSupported, + .compress, .@"x-compress" => return error.CompressionUnsupported, .deflate => req.response.compression = .{ .deflate = std.compress.zlib.decompressor(req.transferReader()), }, .gzip, .@"x-gzip" => req.response.compression = .{ .gzip = std.compress.gzip.decompressor(req.transferReader()), }, - .zstd => req.response.compression = .{ - .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()), - }, + // https://github.com/ziglang/zig/issues/18937 + //.zstd => req.response.compression = .{ + // .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()), + //}, + .zstd => return error.CompressionUnsupported, } } @@ -957,7 +1006,8 @@ pub const Request = struct { } } - pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || error{ DecompressionFailure, InvalidTrailers }; + pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || + error{ DecompressionFailure, InvalidTrailers }; pub const Reader = std.io.Reader(*Request, ReadError, read); @@ -970,28 +1020,20 @@ pub const Request = struct { const out_index = switch (req.response.compression) { .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure, .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure, - .zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure, + // https://github.com/ziglang/zig/issues/18937 + //.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure, else => try req.transferRead(buffer), }; + if (out_index > 0) return out_index; - if (out_index == 0) { - const has_trail = !req.response.parser.state.isContent(); - - while (!req.response.parser.state.isContent()) { // read trailing headers - try req.connection.?.fill(); - - const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.peek()); - req.connection.?.drop(@intCast(nchecked)); - } + while (!req.response.parser.state.isContent()) { // read trailing headers + try req.connection.?.fill(); - if (has_trail) { - // The response headers before the trailers are already guaranteed to be valid, so they will always be parsed again and cannot return an error. - // This will *only* fail for a malformed trailer. - req.response.parse(req.response.parser.header_bytes.items, true) catch return error.InvalidTrailers; - } + const nchecked = try req.response.parser.checkCompleteHead(req.connection.?.peek()); + req.connection.?.drop(@intCast(nchecked)); } - return out_index; + return 0; } /// Reads data from the response body. Must be called after `wait`. @@ -1061,16 +1103,12 @@ pub const Request = struct { } }; -/// A HTTP proxy server. pub const Proxy = struct { - allocator: Allocator, - headers: http.Headers, - protocol: Connection.Protocol, host: []const u8, + authorization: ?[]const u8, port: u16, - - supports_connect: bool = true, + supports_connect: bool, }; /// Release all associated resources with the client. 
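Since a `Proxy` is now plain data with no per-proxy allocator or `Headers` list, a statically-configured proxy can be written down directly. A hypothetical sketch; the host and port are invented, and `client.http_proxy` holds a pointer, so the value must outlive the client:

```zig
var proxy: std.http.Client.Proxy = .{
    .protocol = .plain,
    .host = "proxy.internal",
    .authorization = null, // or a pre-computed "Basic ..." value
    .port = 8080,
    .supports_connect = true,
};
client.http_proxy = &proxy;
```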
@@ -1082,116 +1120,71 @@ pub fn deinit(client: *Client) void { client.connection_pool.deinit(client.allocator); - if (client.http_proxy) |*proxy| { - proxy.allocator.free(proxy.host); - proxy.headers.deinit(); - } - - if (client.https_proxy) |*proxy| { - proxy.allocator.free(proxy.host); - proxy.headers.deinit(); - } - if (!disable_tls) client.ca_bundle.deinit(client.allocator); client.* = undefined; } -/// Uses the *_proxy environment variable to set any unset proxies for the client. -/// This function *must not* be called when the client has any active connections. -pub fn loadDefaultProxies(client: *Client) !void { +/// Populates `http_proxy` and `https_proxy` via standard proxy environment variables. +/// Asserts the client has no active connections. +/// Uses `arena` for a few small allocations that must outlive the client, or +/// at least until those fields are set to different values. +pub fn initDefaultProxies(client: *Client, arena: Allocator) !void { // Prevent any new connections from being created. client.connection_pool.mutex.lock(); defer client.connection_pool.mutex.unlock(); - assert(client.connection_pool.used.first == null); // There are still active requests. + assert(client.connection_pool.used.first == null); // There are active requests. - if (client.http_proxy == null) http: { - const content: []const u8 = if (std.process.hasEnvVarConstant("http_proxy")) - try std.process.getEnvVarOwned(client.allocator, "http_proxy") - else if (std.process.hasEnvVarConstant("HTTP_PROXY")) - try std.process.getEnvVarOwned(client.allocator, "HTTP_PROXY") - else if (std.process.hasEnvVarConstant("all_proxy")) - try std.process.getEnvVarOwned(client.allocator, "all_proxy") - else if (std.process.hasEnvVarConstant("ALL_PROXY")) - try std.process.getEnvVarOwned(client.allocator, "ALL_PROXY") - else - break :http; - defer client.allocator.free(content); - - const uri = Uri.parse(content) catch - Uri.parseWithoutScheme(content) catch - break :http; - - const protocol = if (uri.scheme.len == 0) - .plain // No scheme, assume http:// - else - protocol_map.get(uri.scheme) orelse break :http; // Unknown scheme, ignore - - const host = if (uri.host) |host| try client.allocator.dupe(u8, host) else break :http; // Missing host, ignore - client.http_proxy = .{ - .allocator = client.allocator, - .headers = .{ .allocator = client.allocator }, - - .protocol = protocol, - .host = host, - .port = uri.port orelse switch (protocol) { - .plain => 80, - .tls => 443, - }, - }; + if (client.http_proxy == null) { + client.http_proxy = try createProxyFromEnvVar(arena, &.{ + "http_proxy", "HTTP_PROXY", "all_proxy", "ALL_PROXY", + }); + } - if (uri.user != null or uri.password != null) { - const authorization = try client.allocator.alloc(u8, basic_authorization.valueLengthFromUri(uri)); - errdefer client.allocator.free(authorization); - std.debug.assert(basic_authorization.value(uri, authorization).len == authorization.len); - try client.http_proxy.?.headers.appendOwned(.{ .unowned = "proxy-authorization" }, .{ .owned = authorization }); - } + if (client.https_proxy == null) { + client.https_proxy = try createProxyFromEnvVar(arena, &.{ + "https_proxy", "HTTPS_PROXY", "all_proxy", "ALL_PROXY", + }); } +} - if (client.https_proxy == null) https: { - const content: []const u8 = if (std.process.hasEnvVarConstant("https_proxy")) - try std.process.getEnvVarOwned(client.allocator, "https_proxy") - else if (std.process.hasEnvVarConstant("HTTPS_PROXY")) - try std.process.getEnvVarOwned(client.allocator, "HTTPS_PROXY") - 
else if (std.process.hasEnvVarConstant("all_proxy")) - try std.process.getEnvVarOwned(client.allocator, "all_proxy") - else if (std.process.hasEnvVarConstant("ALL_PROXY")) - try std.process.getEnvVarOwned(client.allocator, "ALL_PROXY") - else - break :https; - defer client.allocator.free(content); - - const uri = Uri.parse(content) catch - Uri.parseWithoutScheme(content) catch - break :https; - - const protocol = if (uri.scheme.len == 0) - .plain // No scheme, assume http:// - else - protocol_map.get(uri.scheme) orelse break :https; // Unknown scheme, ignore - - const host = if (uri.host) |host| try client.allocator.dupe(u8, host) else break :https; // Missing host, ignore - client.https_proxy = .{ - .allocator = client.allocator, - .headers = .{ .allocator = client.allocator }, - - .protocol = protocol, - .host = host, - .port = uri.port orelse switch (protocol) { - .plain => 80, - .tls => 443, - }, +fn createProxyFromEnvVar(arena: Allocator, env_var_names: []const []const u8) !?*Proxy { + const content = for (env_var_names) |name| { + break std.process.getEnvVarOwned(arena, name) catch |err| switch (err) { + error.EnvironmentVariableNotFound => continue, + else => |e| return e, }; + } else return null; - if (uri.user != null or uri.password != null) { - const authorization = try client.allocator.alloc(u8, basic_authorization.valueLengthFromUri(uri)); - errdefer client.allocator.free(authorization); - std.debug.assert(basic_authorization.value(uri, authorization).len == authorization.len); - try client.https_proxy.?.headers.appendOwned(.{ .unowned = "proxy-authorization" }, .{ .owned = authorization }); - } - } + const uri = Uri.parse(content) catch try Uri.parseWithoutScheme(content); + + const protocol = if (uri.scheme.len == 0) + .plain // No scheme, assume http:// + else + protocol_map.get(uri.scheme) orelse return null; // Unknown scheme, ignore + + const host = uri.host orelse return error.HttpProxyMissingHost; + + const authorization: ?[]const u8 = if (uri.user != null or uri.password != null) a: { + const authorization = try arena.alloc(u8, basic_authorization.valueLengthFromUri(uri)); + assert(basic_authorization.value(uri, authorization).len == authorization.len); + break :a authorization; + } else null; + + const proxy = try arena.create(Proxy); + proxy.* = .{ + .protocol = protocol, + .host = host, + .authorization = authorization, + .port = uri.port orelse switch (protocol) { + .plain => 80, + .tls => 443, + }, + .supports_connect = true, + }; + return proxy; } pub const basic_authorization = struct { @@ -1213,8 +1206,8 @@ pub const basic_authorization = struct { } pub fn value(uri: Uri, out: []u8) []u8 { - std.debug.assert(uri.user == null or uri.user.?.len <= max_user_len); - std.debug.assert(uri.password == null or uri.password.?.len <= max_password_len); + assert(uri.user == null or uri.user.?.len <= max_user_len); + assert(uri.password == null or uri.password.?.len <= max_password_len); @memcpy(out[0..prefix.len], prefix); @@ -1288,14 +1281,12 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec return &conn.data; } -pub const ConnectUnixError = Allocator.Error || std.os.SocketError || error{ NameTooLong, Unsupported } || std.os.ConnectError; +pub const ConnectUnixError = Allocator.Error || std.os.SocketError || error{NameTooLong} || std.os.ConnectError; /// Connect to `path` as a unix domain socket. This will reuse a connection if one is already open. /// /// This function is threadsafe. 
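Stepping back to the proxy initialization above: because the proxy values borrow from `arena`, the natural caller-side pattern for `initDefaultProxies` is an arena that outlives the client. A sketch, with `gpa` standing in for the application allocator:

```zig
var arena_state = std.heap.ArenaAllocator.init(gpa);
defer arena_state.deinit(); // deinitialized after the client (defers run LIFO)

var client: std.http.Client = .{ .allocator = gpa };
defer client.deinit();

// Consults http_proxy/HTTP_PROXY, https_proxy/HTTPS_PROXY, and
// all_proxy/ALL_PROXY as fallbacks, per createProxyFromEnvVar above.
try client.initDefaultProxies(arena_state.allocator());
```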
pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connection { - if (!net.has_unix_sockets) return error.Unsupported; - if (client.connection_pool.findConnection(.{ .host = path, .port = 0, @@ -1325,7 +1316,8 @@ pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connecti return &conn.data; } -/// Connect to `tunnel_host:tunnel_port` using the specified proxy with HTTP CONNECT. This will reuse a connection if one is already open. +/// Connect to `tunnel_host:tunnel_port` using the specified proxy with HTTP +/// CONNECT. This will reuse a connection if one is already open. /// /// This function is threadsafe. pub fn connectTunnel( @@ -1351,7 +1343,7 @@ pub fn connectTunnel( client.connection_pool.release(client.allocator, conn); } - const uri = Uri{ + const uri: Uri = .{ .scheme = "http", .user = null, .password = null, @@ -1362,13 +1354,11 @@ pub fn connectTunnel( .fragment = null, }; - // we can use a small buffer here because a CONNECT response should be very small var buffer: [8096]u8 = undefined; - - var req = client.open(.CONNECT, uri, proxy.headers, .{ - .handle_redirects = false, + var req = client.open(.CONNECT, uri, .{ + .redirect_behavior = .unhandled, .connection = conn, - .header_strategy = .{ .static = &buffer }, + .server_header_buffer = &buffer, }) catch |err| { std.log.debug("err {}", .{err}); break :tunnel err; @@ -1407,45 +1397,51 @@ pub fn connectTunnel( const ConnectErrorPartial = ConnectTcpError || error{ UnsupportedUrlScheme, ConnectionRefused }; pub const ConnectError = ConnectErrorPartial || RequestError; -/// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open. -/// If a proxy is configured for the client, then the proxy will be used to connect to the host. +/// Connect to `host:port` using the specified protocol. This will reuse a +/// connection if one is already open. +/// If a proxy is configured for the client, then the proxy will be used to +/// connect to the host. /// /// This function is threadsafe. -pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*Connection { - // pointer required so that `supports_connect` can be updated if a CONNECT fails - const potential_proxy: ?*Proxy = switch (protocol) { - .plain => if (client.http_proxy) |*proxy_info| proxy_info else null, - .tls => if (client.https_proxy) |*proxy_info| proxy_info else null, - }; - - if (potential_proxy) |proxy| { - // don't attempt to proxy the proxy thru itself. - if (std.mem.eql(u8, proxy.host, host) and proxy.port == port and proxy.protocol == protocol) { - return client.connectTcp(host, port, protocol); - } - - if (proxy.supports_connect) tunnel: { - return connectTunnel(client, proxy, host, port) catch |err| switch (err) { - error.TunnelNotSupported => break :tunnel, - else => |e| return e, - }; - } +pub fn connect( + client: *Client, + host: []const u8, + port: u16, + protocol: Connection.Protocol, +) ConnectError!*Connection { + const proxy = switch (protocol) { + .plain => client.http_proxy, + .tls => client.https_proxy, + } orelse return client.connectTcp(host, port, protocol); + + // Prevent proxying through itself. 
+ if (std.ascii.eqlIgnoreCase(proxy.host, host) and + proxy.port == port and proxy.protocol == protocol) + { + return client.connectTcp(host, port, protocol); + } - // fall back to using the proxy as a normal http proxy - const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); - errdefer { - conn.closing = true; - client.connection_pool.release(conn); - } + if (proxy.supports_connect) tunnel: { + return connectTunnel(client, proxy, host, port) catch |err| switch (err) { + error.TunnelNotSupported => break :tunnel, + else => |e| return e, + }; + } - conn.proxied = true; - return conn; + // fall back to using the proxy as a normal http proxy + const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); + errdefer { + conn.closing = true; + client.connection_pool.release(conn); } - return client.connectTcp(host, port, protocol); + conn.proxied = true; + return conn; } -pub const RequestError = ConnectTcpError || ConnectErrorPartial || Request.SendError || std.fmt.ParseIntError || Connection.WriteError || error{ +pub const RequestError = ConnectTcpError || ConnectErrorPartial || Request.SendError || + std.fmt.ParseIntError || Connection.WriteError || + error{ // TODO: file a zig fmt issue for this bad indentation UnsupportedUrlScheme, UriMissingHost, @@ -1456,36 +1452,44 @@ pub const RequestError = ConnectTcpError || ConnectErrorPartial || Request.SendE pub const RequestOptions = struct { version: http.Version = .@"HTTP/1.1", - /// Automatically ignore 100 Continue responses. This assumes you don't care, and will have sent the body before you - /// wait for the response. + /// Automatically ignore 100 Continue responses. This assumes you don't + /// care, and will have sent the body before you wait for the response. /// - /// If this is not the case AND you know the server will send a 100 Continue, set this to false and wait for a - /// response before sending the body. If you wait AND the server does not send a 100 Continue before you finish the - /// request, then the request *will* deadlock. + /// If this is not the case AND you know the server will send a 100 + /// Continue, set this to false and wait for a response before sending the + /// body. If you wait AND the server does not send a 100 Continue before + /// you finish the request, then the request *will* deadlock. handle_continue: bool = true, - /// Automatically follow redirects. This will only follow redirects for repeatable requests (ie. with no payload or the server has acknowledged the payload) - handle_redirects: bool = true, + /// If false, close the connection after the one request. If true, + /// participate in the client connection pool. + keep_alive: bool = true, + + /// This field specifies whether to automatically follow redirects, and if + /// so, how many redirects to follow before returning an error. + /// + /// This will only follow redirects for repeatable requests (i.e. with no + /// payload, or the server has acknowledged the payload). + redirect_behavior: Request.RedirectBehavior = @enumFromInt(3), - /// How many redirects to follow before returning an error. - max_redirects: u32 = 3, - header_strategy: StorageStrategy = .{ .dynamic = 16 * 1024 }, + /// Externally-owned memory used to store the server's entire HTTP header. + /// `error.HttpHeadersOversize` is returned from read() when the + /// server sends too many bytes of HTTP headers. + server_header_buffer: []u8, /// Must be an already acquired connection. 
connection: ?*Connection = null, - pub const StorageStrategy = union(enum) { - /// In this case, the client's Allocator will be used to store the - /// entire HTTP header. This value is the maximum total size of - /// HTTP headers allowed, otherwise - /// error.HttpHeadersExceededSizeLimit is returned from read(). - dynamic: usize, - /// This is used to store the entire HTTP header. If the HTTP - /// header is too big to fit, `error.HttpHeadersExceededSizeLimit` - /// is returned from read(). When this is used, `error.OutOfMemory` - /// cannot be returned from `read()`. - static: []u8, - }; + /// Standard headers that have default, but overridable, behavior. + headers: Request.Headers = .{}, + /// These headers are kept including when following a redirect to a + /// different domain. + /// Externally-owned; must outlive the Request. + extra_headers: []const http.Header = &.{}, + /// These headers are stripped when following a redirect to a different + /// domain. + /// Externally-owned; must outlive the Request. + privileged_headers: []const http.Header = &.{}, }; pub const protocol_map = std.ComptimeStringMap(Connection.Protocol, .{ @@ -1498,11 +1502,29 @@ pub const protocol_map = std.ComptimeStringMap(Connection.Protocol, .{ /// Open a connection to the host specified by `uri` and prepare to send a HTTP request. /// /// `uri` must remain alive during the entire request. -/// `headers` is cloned and may be freed after this function returns. /// /// The caller is responsible for calling `deinit()` on the `Request`. /// This function is threadsafe. -pub fn open(client: *Client, method: http.Method, uri: Uri, headers: http.Headers, options: RequestOptions) RequestError!Request { +/// +/// Asserts that "\r\n" does not occur in any header name or value. +pub fn open( + client: *Client, + method: http.Method, + uri: Uri, + options: RequestOptions, +) RequestError!Request { + if (std.debug.runtime_safety) { + for (options.extra_headers) |header| { + assert(std.mem.indexOfScalar(u8, header.name, ':') == null); + assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); + assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); + } + for (options.privileged_headers) |header| { + assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); + assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); + } + } + const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme; const port: u16 = uri.port orelse switch (protocol) { @@ -1530,163 +1552,131 @@ pub fn open(client: *Client, method: http.Method, uri: Uri, headers: http.Header .uri = uri, .client = client, .connection = conn, - .headers = try headers.clone(client.allocator), // Headers must be cloned to properly handle header transformations in redirects. 
+ .keep_alive = options.keep_alive, .method = method, .version = options.version, - .redirects_left = options.max_redirects, - .handle_redirects = options.handle_redirects, + .transfer_encoding = .none, + .redirect_behavior = options.redirect_behavior, .handle_continue = options.handle_continue, .response = .{ .status = undefined, .reason = undefined, .version = undefined, - .headers = http.Headers{ .allocator = client.allocator, .owned = false }, - .parser = switch (options.header_strategy) { - .dynamic => |max| proto.HeadersParser.initDynamic(max), - .static => |buf| proto.HeadersParser.initStatic(buf), - }, + .parser = proto.HeadersParser.init(options.server_header_buffer), }, - .arena = undefined, + .headers = options.headers, + .extra_headers = options.extra_headers, + .privileged_headers = options.privileged_headers, }; errdefer req.deinit(); - req.arena = std.heap.ArenaAllocator.init(client.allocator); - return req; } pub const FetchOptions = struct { + server_header_buffer: ?[]u8 = null, + redirect_behavior: ?Request.RedirectBehavior = null, + + /// If the server sends a body, it will be appended to this ArrayList. + /// `max_append_size` provides an upper limit for how much it can grow. + response_storage: ResponseStorage = .ignore, + max_append_size: ?usize = null, + + location: Location, + method: ?http.Method = null, + payload: ?[]const u8 = null, + raw_uri: bool = false, + keep_alive: bool = true, + + /// Standard headers that have default, but overridable, behavior. + headers: Request.Headers = .{}, + /// These headers are kept including when following a redirect to a + /// different domain. + /// Externally-owned; must outlive the Request. + extra_headers: []const http.Header = &.{}, + /// These headers are stripped when following a redirect to a different + /// domain. + /// Externally-owned; must outlive the Request. + privileged_headers: []const http.Header = &.{}, + pub const Location = union(enum) { url: []const u8, uri: Uri, }; - pub const Payload = union(enum) { - string: []const u8, - file: std.fs.File, - none, + pub const ResponseStorage = union(enum) { + ignore, + /// Only the existing capacity will be used. + static: *std.ArrayListUnmanaged(u8), + dynamic: *std.ArrayList(u8), }; - - pub const ResponseStrategy = union(enum) { - storage: RequestOptions.StorageStrategy, - file: std.fs.File, - none, - }; - - header_strategy: RequestOptions.StorageStrategy = .{ .dynamic = 16 * 1024 }, - response_strategy: ResponseStrategy = .{ .storage = .{ .dynamic = 16 * 1024 * 1024 } }, - - location: Location, - method: http.Method = .GET, - headers: http.Headers = http.Headers{ .allocator = std.heap.page_allocator, .owned = false }, - payload: Payload = .none, - raw_uri: bool = false, }; pub const FetchResult = struct { status: http.Status, - body: ?[]const u8 = null, - headers: http.Headers, - - allocator: Allocator, - options: FetchOptions, - - pub fn deinit(res: *FetchResult) void { - if (res.options.response_strategy == .storage and res.options.response_strategy.storage == .dynamic) { - if (res.body) |body| res.allocator.free(body); - } - - res.headers.deinit(); - } }; /// Perform a one-shot HTTP request with the provided options. /// /// This function is threadsafe. 
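Under the reworked options, a one-shot request might look like the following sketch; the URL is illustrative and `gpa` is assumed to be the application allocator:

```zig
var client: std.http.Client = .{ .allocator = gpa };
defer client.deinit();

var body = std.ArrayList(u8).init(gpa);
defer body.deinit();

const result = try client.fetch(.{
    .location = .{ .url = "http://example.com/" },
    .response_storage = .{ .dynamic = &body },
    .max_append_size = 1024 * 1024,
});
std.debug.print("status: {d}, body: {d} bytes\n", .{
    @intFromEnum(result.status), body.items.len,
});
```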
-pub fn fetch(client: *Client, allocator: Allocator, options: FetchOptions) !FetchResult { - const has_transfer_encoding = options.headers.contains("transfer-encoding"); - const has_content_length = options.headers.contains("content-length"); - - if (has_content_length or has_transfer_encoding) return error.UnsupportedHeader; - +pub fn fetch(client: *Client, options: FetchOptions) !FetchResult { const uri = switch (options.location) { .url => |u| try Uri.parse(u), .uri => |u| u, }; - - var req = try open(client, options.method, uri, options.headers, .{ - .header_strategy = options.header_strategy, - .handle_redirects = options.payload == .none, + var server_header_buffer: [16 * 1024]u8 = undefined; + + const method: http.Method = options.method orelse + if (options.payload != null) .POST else .GET; + + var req = try open(client, method, uri, .{ + .server_header_buffer = options.server_header_buffer orelse &server_header_buffer, + .redirect_behavior = options.redirect_behavior orelse + if (options.payload == null) @enumFromInt(3) else .unhandled, + .headers = options.headers, + .extra_headers = options.extra_headers, + .privileged_headers = options.privileged_headers, + .keep_alive = options.keep_alive, }); defer req.deinit(); - { // Block to maintain lock of file to attempt to prevent a race condition where another process modifies the file while we are reading it. - // This relies on other processes actually obeying the advisory lock, which is not guaranteed. - if (options.payload == .file) try options.payload.file.lock(.shared); - defer if (options.payload == .file) options.payload.file.unlock(); + if (options.payload) |payload| req.transfer_encoding = .{ .content_length = payload.len }; - switch (options.payload) { - .string => |str| req.transfer_encoding = .{ .content_length = str.len }, - .file => |file| req.transfer_encoding = .{ .content_length = (try file.stat()).size }, - .none => {}, - } - - try req.send(.{ .raw_uri = options.raw_uri }); + try req.send(.{ .raw_uri = options.raw_uri }); - switch (options.payload) { - .string => |str| try req.writeAll(str), - .file => |file| { - try file.seekTo(0); - var fifo = std.fifo.LinearFifo(u8, .{ .Static = 8192 }).init(); - try fifo.pump(file.reader(), req.writer()); - }, - .none => {}, - } - - try req.finish(); - } + if (options.payload) |payload| try req.writeAll(payload); + try req.finish(); try req.wait(); - var res = FetchResult{ - .status = req.response.status, - .headers = try req.response.headers.clone(allocator), - - .allocator = allocator, - .options = options, - }; - - switch (options.response_strategy) { - .storage => |storage| switch (storage) { - .dynamic => |max| res.body = try req.reader().readAllAlloc(allocator, max), - .static => |buf| res.body = buf[0..try req.reader().readAll(buf)], + switch (options.response_storage) { + .ignore => { + // Take advantage of request internals to discard the response body + // and make the connection available for another request. + req.response.skip = true; + assert(try req.transferRead(&.{}) == 0); // No buffer is necessary when skipping. }, - .file => |file| { - var fifo = std.fifo.LinearFifo(u8, .{ .Static = 8192 }).init(); - try fifo.pump(req.reader(), file.writer()); + .dynamic => |list| { + const max_append_size = options.max_append_size orelse 2 * 1024 * 1024; + try req.reader().readAllArrayList(list, max_append_size); }, - .none => { // Take advantage of request internals to discard the response body and make the connection available for another request. 
- req.response.skip = true; - - const empty = @as([*]u8, undefined)[0..0]; - assert(try req.transferRead(empty) == 0); // we're skipping, no buffer is necessary + .static => |list| { + const buf = b: { + const buf = list.unusedCapacitySlice(); + if (options.max_append_size) |len| { + if (len < buf.len) break :b buf[0..len]; + } + break :b buf; + }; + list.items.len += try req.reader().readAll(buf); }, } - return res; + return .{ + .status = req.response.status, + }; } test { - const native_endian = comptime builtin.cpu.arch.endian(); - if (builtin.zig_backend == .stage2_llvm and native_endian == .big) { - // https://github.com/ziglang/zig/issues/13782 - return error.SkipZigTest; - } - - if (builtin.os.tag == .wasi) return error.SkipZigTest; - - if (builtin.zig_backend == .stage2_x86_64 and - !comptime std.Target.x86.featureSetHas(builtin.cpu.features, .avx)) return error.SkipZigTest; - - std.testing.refAllDecls(@This()); + _ = &initDefaultProxies; } diff --git a/lib/std/http/HeadParser.zig b/lib/std/http/HeadParser.zig new file mode 100644 index 000000000000..bb49faa14b33 --- /dev/null +++ b/lib/std/http/HeadParser.zig @@ -0,0 +1,371 @@ +//! Finds the end of an HTTP head in a stream. + +state: State = .start, + +pub const State = enum { + start, + seen_n, + seen_r, + seen_rn, + seen_rnr, + finished, +}; + +/// Returns the number of bytes consumed by headers. This is always less +/// than or equal to `bytes.len`. +/// +/// If the amount returned is less than `bytes.len`, the parser is in a +/// content state and the first byte of content is located at +/// `bytes[result]`. +pub fn feed(p: *HeadParser, bytes: []const u8) usize { + const vector_len: comptime_int = @max(std.simd.suggestVectorLength(u8) orelse 1, 8); + var index: usize = 0; + + while (true) { + switch (p.state) { + .finished => return index, + .start => switch (bytes.len - index) { + 0 => return index, + 1 => { + switch (bytes[index]) { + '\r' => p.state = .seen_r, + '\n' => p.state = .seen_n, + else => {}, + } + + return index + 1; + }, + 2 => { + const b16 = int16(bytes[index..][0..2]); + const b8 = intShift(u8, b16); + + switch (b8) { + '\r' => p.state = .seen_r, + '\n' => p.state = .seen_n, + else => {}, + } + + switch (b16) { + int16("\r\n") => p.state = .seen_rn, + int16("\n\n") => p.state = .finished, + else => {}, + } + + return index + 2; + }, + 3 => { + const b24 = int24(bytes[index..][0..3]); + const b16 = intShift(u16, b24); + const b8 = intShift(u8, b24); + + switch (b8) { + '\r' => p.state = .seen_r, + '\n' => p.state = .seen_n, + else => {}, + } + + switch (b16) { + int16("\r\n") => p.state = .seen_rn, + int16("\n\n") => p.state = .finished, + else => {}, + } + + switch (b24) { + int24("\r\n\r") => p.state = .seen_rnr, + else => {}, + } + + return index + 3; + }, + 4...vector_len - 1 => { + const b32 = int32(bytes[index..][0..4]); + const b24 = intShift(u24, b32); + const b16 = intShift(u16, b32); + const b8 = intShift(u8, b32); + + switch (b8) { + '\r' => p.state = .seen_r, + '\n' => p.state = .seen_n, + else => {}, + } + + switch (b16) { + int16("\r\n") => p.state = .seen_rn, + int16("\n\n") => p.state = .finished, + else => {}, + } + + switch (b24) { + int24("\r\n\r") => p.state = .seen_rnr, + else => {}, + } + + switch (b32) { + int32("\r\n\r\n") => p.state = .finished, + else => {}, + } + + index += 4; + continue; + }, + else => { + const chunk = bytes[index..][0..vector_len]; + const matches = if (use_vectors) matches: { + const Vector = @Vector(vector_len, u8); + // const BoolVector = @Vector(vector_len, 
bool); + const BitVector = @Vector(vector_len, u1); + const SizeVector = @Vector(vector_len, u8); + + const v: Vector = chunk.*; + const matches_r: BitVector = @bitCast(v == @as(Vector, @splat('\r'))); + const matches_n: BitVector = @bitCast(v == @as(Vector, @splat('\n'))); + const matches_or: SizeVector = matches_r | matches_n; + + break :matches @reduce(.Add, matches_or); + } else matches: { + var matches: u8 = 0; + for (chunk) |byte| switch (byte) { + '\r', '\n' => matches += 1, + else => {}, + }; + break :matches matches; + }; + switch (matches) { + 0 => {}, + 1 => switch (chunk[vector_len - 1]) { + '\r' => p.state = .seen_r, + '\n' => p.state = .seen_n, + else => {}, + }, + 2 => { + const b16 = int16(chunk[vector_len - 2 ..][0..2]); + const b8 = intShift(u8, b16); + + switch (b8) { + '\r' => p.state = .seen_r, + '\n' => p.state = .seen_n, + else => {}, + } + + switch (b16) { + int16("\r\n") => p.state = .seen_rn, + int16("\n\n") => p.state = .finished, + else => {}, + } + }, + 3 => { + const b24 = int24(chunk[vector_len - 3 ..][0..3]); + const b16 = intShift(u16, b24); + const b8 = intShift(u8, b24); + + switch (b8) { + '\r' => p.state = .seen_r, + '\n' => p.state = .seen_n, + else => {}, + } + + switch (b16) { + int16("\r\n") => p.state = .seen_rn, + int16("\n\n") => p.state = .finished, + else => {}, + } + + switch (b24) { + int24("\r\n\r") => p.state = .seen_rnr, + else => {}, + } + }, + 4...vector_len => { + inline for (0..vector_len - 3) |i_usize| { + const i = @as(u32, @truncate(i_usize)); + + const b32 = int32(chunk[i..][0..4]); + const b16 = intShift(u16, b32); + + if (b32 == int32("\r\n\r\n")) { + p.state = .finished; + return index + i + 4; + } else if (b16 == int16("\n\n")) { + p.state = .finished; + return index + i + 2; + } + } + + const b24 = int24(chunk[vector_len - 3 ..][0..3]); + const b16 = intShift(u16, b24); + const b8 = intShift(u8, b24); + + switch (b8) { + '\r' => p.state = .seen_r, + '\n' => p.state = .seen_n, + else => {}, + } + + switch (b16) { + int16("\r\n") => p.state = .seen_rn, + int16("\n\n") => p.state = .finished, + else => {}, + } + + switch (b24) { + int24("\r\n\r") => p.state = .seen_rnr, + else => {}, + } + }, + else => unreachable, + } + + index += vector_len; + continue; + }, + }, + .seen_n => switch (bytes.len - index) { + 0 => return index, + else => { + switch (bytes[index]) { + '\n' => p.state = .finished, + else => p.state = .start, + } + + index += 1; + continue; + }, + }, + .seen_r => switch (bytes.len - index) { + 0 => return index, + 1 => { + switch (bytes[index]) { + '\n' => p.state = .seen_rn, + '\r' => p.state = .seen_r, + else => p.state = .start, + } + + return index + 1; + }, + 2 => { + const b16 = int16(bytes[index..][0..2]); + const b8 = intShift(u8, b16); + + switch (b8) { + '\r' => p.state = .seen_r, + '\n' => p.state = .seen_rn, + else => p.state = .start, + } + + switch (b16) { + int16("\r\n") => p.state = .seen_rn, + int16("\n\r") => p.state = .seen_rnr, + int16("\n\n") => p.state = .finished, + else => {}, + } + + return index + 2; + }, + else => { + const b24 = int24(bytes[index..][0..3]); + const b16 = intShift(u16, b24); + const b8 = intShift(u8, b24); + + switch (b8) { + '\r' => p.state = .seen_r, + '\n' => p.state = .seen_n, + else => p.state = .start, + } + + switch (b16) { + int16("\r\n") => p.state = .seen_rn, + int16("\n\n") => p.state = .finished, + else => {}, + } + + switch (b24) { + int24("\n\r\n") => p.state = .finished, + else => {}, + } + + index += 3; + continue; + }, + }, + .seen_rn => switch (bytes.len - 
index) { + 0 => return index, + 1 => { + switch (bytes[index]) { + '\r' => p.state = .seen_rnr, + '\n' => p.state = .seen_n, + else => p.state = .start, + } + + return index + 1; + }, + else => { + const b16 = int16(bytes[index..][0..2]); + const b8 = intShift(u8, b16); + + switch (b8) { + '\r' => p.state = .seen_rnr, + '\n' => p.state = .seen_n, + else => p.state = .start, + } + + switch (b16) { + int16("\r\n") => p.state = .finished, + int16("\n\n") => p.state = .finished, + else => {}, + } + + index += 2; + continue; + }, + }, + .seen_rnr => switch (bytes.len - index) { + 0 => return index, + else => { + switch (bytes[index]) { + '\n' => p.state = .finished, + else => p.state = .start, + } + + index += 1; + continue; + }, + }, + } + + return index; + } +} + +inline fn int16(array: *const [2]u8) u16 { + return @bitCast(array.*); +} + +inline fn int24(array: *const [3]u8) u24 { + return @bitCast(array.*); +} + +inline fn int32(array: *const [4]u8) u32 { + return @bitCast(array.*); +} + +inline fn intShift(comptime T: type, x: anytype) T { + switch (@import("builtin").cpu.arch.endian()) { + .little => return @truncate(x >> (@bitSizeOf(@TypeOf(x)) - @bitSizeOf(T))), + .big => return @truncate(x), + } +} + +const HeadParser = @This(); +const std = @import("std"); +const use_vectors = builtin.zig_backend != .stage2_x86_64; +const builtin = @import("builtin"); + +test feed { + const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\nHello"; + + for (0..36) |i| { + var p: HeadParser = .{}; + try std.testing.expectEqual(i, p.feed(data[0..i])); + try std.testing.expectEqual(35 - i, p.feed(data[i..])); + } +} diff --git a/lib/std/http/HeaderIterator.zig b/lib/std/http/HeaderIterator.zig new file mode 100644 index 000000000000..8d36374f8c62 --- /dev/null +++ b/lib/std/http/HeaderIterator.zig @@ -0,0 +1,62 @@ +bytes: []const u8, +index: usize, +is_trailer: bool, + +pub fn init(bytes: []const u8) HeaderIterator { + return .{ + .bytes = bytes, + .index = std.mem.indexOfPosLinear(u8, bytes, 0, "\r\n").? + 2, + .is_trailer = false, + }; +} + +pub fn next(it: *HeaderIterator) ?std.http.Header { + const end = std.mem.indexOfPosLinear(u8, it.bytes, it.index, "\r\n").?; + var kv_it = std.mem.splitSequence(u8, it.bytes[it.index..end], ": "); + const name = kv_it.next().?; + const value = kv_it.rest(); + if (value.len == 0) { + if (it.is_trailer) return null; + const next_end = std.mem.indexOfPosLinear(u8, it.bytes, end + 2, "\r\n") orelse + return null; + it.is_trailer = true; + it.index = next_end + 2; + kv_it = std.mem.splitSequence(u8, it.bytes[end + 2 .. 
next_end], ": "); + return .{ + .name = kv_it.next().?, + .value = kv_it.rest(), + }; + } + it.index = end + 2; + return .{ + .name = name, + .value = value, + }; +} + +test next { + var it = HeaderIterator.init("200 OK\r\na: b\r\nc: d\r\n\r\ne: f\r\n\r\n"); + try std.testing.expect(!it.is_trailer); + { + const header = it.next().?; + try std.testing.expect(!it.is_trailer); + try std.testing.expectEqualStrings("a", header.name); + try std.testing.expectEqualStrings("b", header.value); + } + { + const header = it.next().?; + try std.testing.expect(!it.is_trailer); + try std.testing.expectEqualStrings("c", header.name); + try std.testing.expectEqualStrings("d", header.value); + } + { + const header = it.next().?; + try std.testing.expect(it.is_trailer); + try std.testing.expectEqualStrings("e", header.name); + try std.testing.expectEqualStrings("f", header.value); + } + try std.testing.expectEqual(null, it.next()); +} + +const HeaderIterator = @This(); +const std = @import("../std.zig"); diff --git a/lib/std/http/Headers.zig b/lib/std/http/Headers.zig deleted file mode 100644 index 653ec05126de..000000000000 --- a/lib/std/http/Headers.zig +++ /dev/null @@ -1,527 +0,0 @@ -const std = @import("../std.zig"); - -const Allocator = std.mem.Allocator; - -const testing = std.testing; -const ascii = std.ascii; -const assert = std.debug.assert; - -pub const HeaderList = std.ArrayListUnmanaged(Field); -pub const HeaderIndexList = std.ArrayListUnmanaged(usize); -pub const HeaderIndex = std.HashMapUnmanaged([]const u8, HeaderIndexList, CaseInsensitiveStringContext, std.hash_map.default_max_load_percentage); - -pub const CaseInsensitiveStringContext = struct { - pub fn hash(self: @This(), s: []const u8) u64 { - _ = self; - var buf: [64]u8 = undefined; - var i: usize = 0; - - var h = std.hash.Wyhash.init(0); - while (i + 64 < s.len) : (i += 64) { - const ret = ascii.lowerString(buf[0..], s[i..][0..64]); - h.update(ret); - } - - const left = @min(64, s.len - i); - const ret = ascii.lowerString(buf[0..], s[i..][0..left]); - h.update(ret); - - return h.final(); - } - - pub fn eql(self: @This(), a: []const u8, b: []const u8) bool { - _ = self; - return ascii.eqlIgnoreCase(a, b); - } -}; - -/// A single HTTP header field. -pub const Field = struct { - name: []const u8, - value: []const u8, - - fn lessThan(ctx: void, a: Field, b: Field) bool { - _ = ctx; - if (a.name.ptr == b.name.ptr) return false; - - return ascii.lessThanIgnoreCase(a.name, b.name); - } -}; - -/// A list of HTTP header fields. -pub const Headers = struct { - allocator: Allocator, - list: HeaderList = .{}, - index: HeaderIndex = .{}, - - /// When this is false, names and values will not be duplicated. - /// Use with caution. - owned: bool = true, - - /// Initialize an empty list of headers. - pub fn init(allocator: Allocator) Headers { - return .{ .allocator = allocator }; - } - - /// Initialize a pre-populated list of headers from a list of fields. - pub fn initList(allocator: Allocator, list: []const Field) !Headers { - var new = Headers.init(allocator); - - try new.list.ensureTotalCapacity(allocator, list.len); - try new.index.ensureTotalCapacity(allocator, @intCast(list.len)); - for (list) |field| { - try new.append(field.name, field.value); - } - - return new; - } - - /// Deallocate all memory associated with the headers. - /// - /// If the `owned` field is false, this will not free the names and values of the headers. 
- pub fn deinit(headers: *Headers) void { - headers.deallocateIndexListsAndFields(); - headers.index.deinit(headers.allocator); - headers.list.deinit(headers.allocator); - - headers.* = undefined; - } - - /// Appends a header to the list. - /// - /// If the `owned` field is true, both name and value will be copied. - pub fn append(headers: *Headers, name: []const u8, value: []const u8) !void { - try headers.appendOwned(.{ .unowned = name }, .{ .unowned = value }); - } - - pub const OwnedString = union(enum) { - /// A string allocated by the `allocator` field. - owned: []u8, - /// A string to be copied by the `allocator` field. - unowned: []const u8, - }; - - /// Appends a header to the list. - /// - /// If the `owned` field is true, `name` and `value` will be copied if unowned. - pub fn appendOwned(headers: *Headers, name: OwnedString, value: OwnedString) !void { - const n = headers.list.items.len; - try headers.list.ensureUnusedCapacity(headers.allocator, 1); - - const owned_value = switch (value) { - .owned => |owned| owned, - .unowned => |unowned| if (headers.owned) - try headers.allocator.dupe(u8, unowned) - else - unowned, - }; - errdefer if (value == .unowned and headers.owned) headers.allocator.free(owned_value); - - var entry = Field{ .name = undefined, .value = owned_value }; - - if (headers.index.getEntry(switch (name) { - inline else => |string| string, - })) |kv| { - defer switch (name) { - .owned => |owned| headers.allocator.free(owned), - .unowned => {}, - }; - - entry.name = kv.key_ptr.*; - try kv.value_ptr.append(headers.allocator, n); - } else { - const owned_name = switch (name) { - .owned => |owned| owned, - .unowned => |unowned| if (headers.owned) - try std.ascii.allocLowerString(headers.allocator, unowned) - else - unowned, - }; - errdefer if (name == .unowned and headers.owned) headers.allocator.free(owned_name); - - entry.name = owned_name; - - var new_index = try HeaderIndexList.initCapacity(headers.allocator, 1); - errdefer new_index.deinit(headers.allocator); - - new_index.appendAssumeCapacity(n); - try headers.index.put(headers.allocator, owned_name, new_index); - } - - headers.list.appendAssumeCapacity(entry); - } - - /// Returns true if this list of headers contains the given name. - pub fn contains(headers: Headers, name: []const u8) bool { - return headers.index.contains(name); - } - - /// Removes all headers with the given name. - pub fn delete(headers: *Headers, name: []const u8) bool { - if (headers.index.fetchRemove(name)) |kv| { - var index = kv.value; - - // iterate backwards - var i = index.items.len; - while (i > 0) { - i -= 1; - const data_index = index.items[i]; - const removed = headers.list.orderedRemove(data_index); - - assert(ascii.eqlIgnoreCase(removed.name, name)); // ensure the index hasn't been corrupted - if (headers.owned) headers.allocator.free(removed.value); - } - - if (headers.owned) headers.allocator.free(kv.key); - index.deinit(headers.allocator); - headers.rebuildIndex(); - - return true; - } else { - return false; - } - } - - /// Returns the index of the first occurrence of a header with the given name. - pub fn firstIndexOf(headers: Headers, name: []const u8) ?usize { - const index = headers.index.get(name) orelse return null; - - return index.items[0]; - } - - /// Returns a list of indices containing headers with the given name. 
- pub fn getIndices(headers: Headers, name: []const u8) ?[]const usize { - const index = headers.index.get(name) orelse return null; - - return index.items; - } - - /// Returns the entry of the first occurrence of a header with the given name. - pub fn getFirstEntry(headers: Headers, name: []const u8) ?Field { - const first_index = headers.firstIndexOf(name) orelse return null; - - return headers.list.items[first_index]; - } - - /// Returns a slice containing each header with the given name. - /// The caller owns the returned slice, but NOT the values in the slice. - pub fn getEntries(headers: Headers, allocator: Allocator, name: []const u8) !?[]const Field { - const indices = headers.getIndices(name) orelse return null; - - const buf = try allocator.alloc(Field, indices.len); - for (indices, 0..) |idx, n| { - buf[n] = headers.list.items[idx]; - } - - return buf; - } - - /// Returns the value in the entry of the first occurrence of a header with the given name. - pub fn getFirstValue(headers: Headers, name: []const u8) ?[]const u8 { - const first_index = headers.firstIndexOf(name) orelse return null; - - return headers.list.items[first_index].value; - } - - /// Returns a slice containing the value of each header with the given name. - /// The caller owns the returned slice, but NOT the values in the slice. - pub fn getValues(headers: Headers, allocator: Allocator, name: []const u8) !?[]const []const u8 { - const indices = headers.getIndices(name) orelse return null; - - const buf = try allocator.alloc([]const u8, indices.len); - for (indices, 0..) |idx, n| { - buf[n] = headers.list.items[idx].value; - } - - return buf; - } - - fn rebuildIndex(headers: *Headers) void { - // clear out the indexes - var it = headers.index.iterator(); - while (it.next()) |entry| { - entry.value_ptr.shrinkRetainingCapacity(0); - } - - // fill up indexes again; we know capacity is fine from before - for (headers.list.items, 0..) |entry, i| { - headers.index.getEntry(entry.name).?.value_ptr.appendAssumeCapacity(i); - } - } - - /// Sorts the headers in lexicographical order. - pub fn sort(headers: *Headers) void { - std.mem.sort(Field, headers.list.items, {}, Field.lessThan); - headers.rebuildIndex(); - } - - /// Writes the headers to the given stream. - pub fn format( - headers: Headers, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, - out_stream: anytype, - ) !void { - _ = fmt; - _ = options; - - for (headers.list.items) |entry| { - if (entry.value.len == 0) continue; - - try out_stream.writeAll(entry.name); - try out_stream.writeAll(": "); - try out_stream.writeAll(entry.value); - try out_stream.writeAll("\r\n"); - } - } - - /// Writes all of the headers with the given name to the given stream, separated by commas. - /// - /// This is useful for headers like `Set-Cookie` which can have multiple values. RFC 9110, Section 5.2 - pub fn formatCommaSeparated( - headers: Headers, - name: []const u8, - out_stream: anytype, - ) !void { - const indices = headers.getIndices(name) orelse return; - - try out_stream.writeAll(name); - try out_stream.writeAll(": "); - - for (indices, 0..) |idx, n| { - if (n != 0) try out_stream.writeAll(", "); - try out_stream.writeAll(headers.list.items[idx].value); - } - - try out_stream.writeAll("\r\n"); - } - - /// Frees all `HeaderIndexList`s within `index`. - /// Frees names and values of all fields if they are owned. 
- fn deallocateIndexListsAndFields(headers: *Headers) void { - var it = headers.index.iterator(); - while (it.next()) |entry| { - entry.value_ptr.deinit(headers.allocator); - - if (headers.owned) headers.allocator.free(entry.key_ptr.*); - } - - if (headers.owned) { - for (headers.list.items) |entry| { - headers.allocator.free(entry.value); - } - } - } - - /// Clears and frees the underlying data structures. - /// Frees names and values if they are owned. - pub fn clearAndFree(headers: *Headers) void { - headers.deallocateIndexListsAndFields(); - headers.index.clearAndFree(headers.allocator); - headers.list.clearAndFree(headers.allocator); - } - - /// Clears the underlying data structures while retaining their capacities. - /// Frees names and values if they are owned. - pub fn clearRetainingCapacity(headers: *Headers) void { - headers.deallocateIndexListsAndFields(); - headers.index.clearRetainingCapacity(); - headers.list.clearRetainingCapacity(); - } - - /// Creates a copy of the headers using the provided allocator. - pub fn clone(headers: Headers, allocator: Allocator) !Headers { - var new = Headers.init(allocator); - - try new.list.ensureTotalCapacity(allocator, headers.list.capacity); - try new.index.ensureTotalCapacity(allocator, headers.index.capacity()); - for (headers.list.items) |field| { - try new.append(field.name, field.value); - } - - return new; - } -}; - -test "Headers.append" { - var h = Headers{ .allocator = std.testing.allocator }; - defer h.deinit(); - - try h.append("foo", "bar"); - try h.append("hello", "world"); - - try testing.expect(h.contains("Foo")); - try testing.expect(!h.contains("Bar")); -} - -test "Headers.delete" { - var h = Headers{ .allocator = std.testing.allocator }; - defer h.deinit(); - - try h.append("foo", "bar"); - try h.append("hello", "world"); - - try testing.expect(h.contains("Foo")); - - _ = h.delete("Foo"); - - try testing.expect(!h.contains("foo")); -} - -test "Headers consistency" { - var h = Headers{ .allocator = std.testing.allocator }; - defer h.deinit(); - - try h.append("foo", "bar"); - try h.append("hello", "world"); - _ = h.delete("Foo"); - - try h.append("foo", "bar"); - try h.append("bar", "world"); - try h.append("foo", "baz"); - try h.append("baz", "hello"); - - try testing.expectEqual(@as(?usize, 0), h.firstIndexOf("hello")); - try testing.expectEqual(@as(?usize, 1), h.firstIndexOf("foo")); - try testing.expectEqual(@as(?usize, 2), h.firstIndexOf("bar")); - try testing.expectEqual(@as(?usize, 4), h.firstIndexOf("baz")); - try testing.expectEqual(@as(?usize, null), h.firstIndexOf("pog")); - - try testing.expectEqualSlices(usize, &[_]usize{0}, h.getIndices("hello").?); - try testing.expectEqualSlices(usize, &[_]usize{ 1, 3 }, h.getIndices("foo").?); - try testing.expectEqualSlices(usize, &[_]usize{2}, h.getIndices("bar").?); - try testing.expectEqualSlices(usize, &[_]usize{4}, h.getIndices("baz").?); - try testing.expectEqual(@as(?[]const usize, null), h.getIndices("pog")); - - try testing.expectEqualStrings("world", h.getFirstEntry("hello").?.value); - try testing.expectEqualStrings("bar", h.getFirstEntry("foo").?.value); - try testing.expectEqualStrings("world", h.getFirstEntry("bar").?.value); - try testing.expectEqualStrings("hello", h.getFirstEntry("baz").?.value); - - const hello_entries = (try h.getEntries(testing.allocator, "hello")).?; - defer testing.allocator.free(hello_entries); - try testing.expectEqualDeep(@as([]const Field, &[_]Field{ - .{ .name = "hello", .value = "world" }, - }), hello_entries); - - const 
foo_entries = (try h.getEntries(testing.allocator, "foo")).?; - defer testing.allocator.free(foo_entries); - try testing.expectEqualDeep(@as([]const Field, &[_]Field{ - .{ .name = "foo", .value = "bar" }, - .{ .name = "foo", .value = "baz" }, - }), foo_entries); - - const bar_entries = (try h.getEntries(testing.allocator, "bar")).?; - defer testing.allocator.free(bar_entries); - try testing.expectEqualDeep(@as([]const Field, &[_]Field{ - .{ .name = "bar", .value = "world" }, - }), bar_entries); - - const baz_entries = (try h.getEntries(testing.allocator, "baz")).?; - defer testing.allocator.free(baz_entries); - try testing.expectEqualDeep(@as([]const Field, &[_]Field{ - .{ .name = "baz", .value = "hello" }, - }), baz_entries); - - const pog_entries = (try h.getEntries(testing.allocator, "pog")); - try testing.expectEqual(@as(?[]const Field, null), pog_entries); - - try testing.expectEqualStrings("world", h.getFirstValue("hello").?); - try testing.expectEqualStrings("bar", h.getFirstValue("foo").?); - try testing.expectEqualStrings("world", h.getFirstValue("bar").?); - try testing.expectEqualStrings("hello", h.getFirstValue("baz").?); - try testing.expectEqual(@as(?[]const u8, null), h.getFirstValue("pog")); - - const hello_values = (try h.getValues(testing.allocator, "hello")).?; - defer testing.allocator.free(hello_values); - try testing.expectEqualDeep(@as([]const []const u8, &[_][]const u8{"world"}), hello_values); - - const foo_values = (try h.getValues(testing.allocator, "foo")).?; - defer testing.allocator.free(foo_values); - try testing.expectEqualDeep(@as([]const []const u8, &[_][]const u8{ "bar", "baz" }), foo_values); - - const bar_values = (try h.getValues(testing.allocator, "bar")).?; - defer testing.allocator.free(bar_values); - try testing.expectEqualDeep(@as([]const []const u8, &[_][]const u8{"world"}), bar_values); - - const baz_values = (try h.getValues(testing.allocator, "baz")).?; - defer testing.allocator.free(baz_values); - try testing.expectEqualDeep(@as([]const []const u8, &[_][]const u8{"hello"}), baz_values); - - const pog_values = (try h.getValues(testing.allocator, "pog")); - try testing.expectEqual(@as(?[]const []const u8, null), pog_values); - - h.sort(); - - try testing.expectEqualSlices(usize, &[_]usize{0}, h.getIndices("bar").?); - try testing.expectEqualSlices(usize, &[_]usize{1}, h.getIndices("baz").?); - try testing.expectEqualSlices(usize, &[_]usize{ 2, 3 }, h.getIndices("foo").?); - try testing.expectEqualSlices(usize, &[_]usize{4}, h.getIndices("hello").?); - - const formatted_values = try std.fmt.allocPrint(testing.allocator, "{}", .{h}); - defer testing.allocator.free(formatted_values); - - try testing.expectEqualStrings("bar: world\r\nbaz: hello\r\nfoo: bar\r\nfoo: baz\r\nhello: world\r\n", formatted_values); - - var buf: [128]u8 = undefined; - var fbs = std.io.fixedBufferStream(&buf); - const writer = fbs.writer(); - - try h.formatCommaSeparated("foo", writer); - try testing.expectEqualStrings("foo: bar, baz\r\n", fbs.getWritten()); -} - -test "Headers.clearRetainingCapacity and clearAndFree" { - var h = Headers.init(std.testing.allocator); - defer h.deinit(); - - h.clearRetainingCapacity(); - - try h.append("foo", "bar"); - try h.append("bar", "world"); - try h.append("foo", "baz"); - try h.append("baz", "hello"); - try testing.expectEqual(@as(usize, 4), h.list.items.len); - try testing.expectEqual(@as(usize, 3), h.index.count()); - const list_capacity = h.list.capacity; - const index_capacity = h.index.capacity(); - - h.clearRetainingCapacity(); 
- try testing.expectEqual(@as(usize, 0), h.list.items.len); - try testing.expectEqual(@as(usize, 0), h.index.count()); - try testing.expectEqual(list_capacity, h.list.capacity); - try testing.expectEqual(index_capacity, h.index.capacity()); - - try h.append("foo", "bar"); - try h.append("bar", "world"); - try h.append("foo", "baz"); - try h.append("baz", "hello"); - try testing.expectEqual(@as(usize, 4), h.list.items.len); - try testing.expectEqual(@as(usize, 3), h.index.count()); - // Capacity should still be the same since we shouldn't have needed to grow - // when adding back the same fields - try testing.expectEqual(list_capacity, h.list.capacity); - try testing.expectEqual(index_capacity, h.index.capacity()); - - h.clearAndFree(); - try testing.expectEqual(@as(usize, 0), h.list.items.len); - try testing.expectEqual(@as(usize, 0), h.index.count()); - try testing.expectEqual(@as(usize, 0), h.list.capacity); - try testing.expectEqual(@as(usize, 0), h.index.capacity()); -} - -test "Headers.initList" { - var h = try Headers.initList(std.testing.allocator, &.{ - .{ .name = "Accept-Encoding", .value = "gzip" }, - .{ .name = "Authorization", .value = "it's over 9000!" }, - }); - defer h.deinit(); - - const encoding_values = (try h.getValues(testing.allocator, "Accept-Encoding")).?; - defer testing.allocator.free(encoding_values); - try testing.expectEqualDeep(@as([]const []const u8, &[_][]const u8{"gzip"}), encoding_values); - - const authorization_values = (try h.getValues(testing.allocator, "Authorization")).?; - defer testing.allocator.free(authorization_values); - try testing.expectEqualDeep(@as([]const []const u8, &[_][]const u8{"it's over 9000!"}), authorization_values); -} diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 46590417792f..2d360d40a4be 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -1,873 +1,1044 @@ -//! HTTP Server implementation. -//! -//! This server assumes *all* clients are well behaved and standard compliant; it can and will deadlock if a client holds a connection open without sending a request. -//! -//! Example usage: -//! -//! ```zig -//! var server = Server.init(.{ .reuse_address = true }); -//! defer server.deinit(); -//! -//! try server.listen(bind_addr); -//! -//! while (true) { -//! var res = try server.accept(.{ .allocator = gpa }); -//! defer res.deinit(); -//! -//! while (res.reset() != .closing) { -//! res.wait() catch |err| switch (err) { -//! error.HttpHeadersInvalid => break, -//! error.HttpHeadersExceededSizeLimit => { -//! res.status = .request_header_fields_too_large; -//! res.send() catch break; -//! break; -//! }, -//! else => { -//! res.status = .bad_request; -//! res.send() catch break; -//! break; -//! }, -//! } -//! -//! res.status = .ok; -//! res.transfer_encoding = .chunked; -//! -//! try res.send(); -//! try res.writeAll("Hello, World!\n"); -//! try res.finish(); -//! } -//! } -//! ``` - -const std = @import("../std.zig"); -const testing = std.testing; -const http = std.http; -const mem = std.mem; -const net = std.net; -const Uri = std.Uri; -const Allocator = mem.Allocator; -const assert = std.debug.assert; - -const Server = @This(); -const proto = @import("protocol.zig"); - -/// The underlying server socket. -socket: net.StreamServer, - -/// An interface to a plain connection. -pub const Connection = struct { - pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; - pub const Protocol = enum { plain }; +//! Blocking HTTP server implementation. +//! 
Handles a single connection's lifecycle.
+
+connection: net.Server.Connection,
+/// Keeps track of whether the Server is ready to accept a new request on the
+/// same connection, and makes invalid API usage cause assertion failures
+/// rather than HTTP protocol violations.
+state: State,
+/// User-provided buffer that must outlive this Server.
+/// Used to store the client's entire HTTP header.
+read_buffer: []u8,
+/// Amount of available data inside read_buffer.
+read_buffer_len: usize,
+/// Index into `read_buffer` of the first byte of the next HTTP request.
+next_request_start: usize,
+
+pub const State = enum {
+    /// The connection is available to be used for the first time, or reused.
+    ready,
+    /// An error occurred in `receiveHead`.
+    receiving_head,
+    /// A Request object has been obtained and from there a Response can be
+    /// opened.
+    received_head,
+    /// The client is uploading something to this Server.
+    receiving_body,
+    /// The connection is eligible for another HTTP request; however, the
+    /// client and server did not negotiate connection: keep-alive.
+    closing,
+};

-    stream: net.Stream,
-    protocol: Protocol,
-
-    closing: bool = true,
-
-    read_buf: [buffer_size]u8 = undefined,
-    read_start: u16 = 0,
-    read_end: u16 = 0,
-
-    pub fn rawReadAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
-        return switch (conn.protocol) {
-            .plain => conn.stream.readAtLeast(buffer, len),
-            // .tls => conn.tls_client.readAtLeast(conn.stream, buffer, len),
-        } catch |err| {
-            switch (err) {
-                error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
-                else => return error.UnexpectedReadFailure,
-            }
-        };
-    }

+/// Initialize an HTTP server that can respond to multiple requests on the same
+/// connection.
+/// The returned `Server` is ready for `receiveHead` to be called.
+pub fn init(connection: net.Server.Connection, read_buffer: []u8) Server {
+    return .{
+        .connection = connection,
+        .state = .ready,
+        .read_buffer = read_buffer,
+        .read_buffer_len = 0,
+        .next_request_start = 0,
+    };
+}

-    pub fn fill(conn: *Connection) ReadError!void {
-        if (conn.read_end != conn.read_start) return;

+pub const ReceiveHeadError = error{
+    /// Client sent too many bytes of HTTP headers.
+    /// The HTTP specification suggests responding with a 431 status code
+    /// before closing the connection.
+    HttpHeadersOversize,
+    /// Client sent headers that did not conform to the HTTP protocol.
+    HttpHeadersInvalid,
+    /// A low-level I/O error occurred while trying to read the headers.
+    HttpHeadersUnreadable,
+    /// Partial HTTP request was received but the connection was closed before
+    /// fully receiving the headers.
+    HttpRequestTruncated,
+    /// The client sent 0 bytes of headers before closing the stream.
+    /// In other words, a keep-alive connection was finally closed.
+    HttpConnectionClosing,
+};

-    const nread = try conn.rawReadAtLeast(conn.read_buf[0..], 1);
-    if (nread == 0) return error.EndOfStream;
-    conn.read_start = 0;
-    conn.read_end = @as(u16, @intCast(nread));

+/// The header bytes reference the read buffer that Server was initialized with
+/// and remain alive until the next call to receiveHead.
+pub fn receiveHead(s: *Server) ReceiveHeadError!Request {
+    assert(s.state == .ready);
+    s.state = .received_head;
+    errdefer s.state = .receiving_head;
+
+    // In case of a reused connection, move the next request's bytes to the
+    // beginning of the buffer.
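+    //
+    // An illustrative sketch of the assumed buffer layout before the move:
+    //
+    //   [0 .. next_request_start]               consumed previous request + body
+    //   [next_request_start .. read_buffer_len] pipelined bytes of the next request
+    //
+    // Afterwards the pipelined bytes begin at index 0; if nothing was
+    // buffered past the previous request, the buffer length simply resets.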
+ if (s.next_request_start > 0) { + if (s.read_buffer_len > s.next_request_start) { + rebase(s, 0); + } else { + s.read_buffer_len = 0; + } } - pub fn peek(conn: *Connection) []const u8 { - return conn.read_buf[conn.read_start..conn.read_end]; - } + var hp: http.HeadParser = .{}; - pub fn drop(conn: *Connection, num: u16) void { - conn.read_start += num; + if (s.read_buffer_len > 0) { + const bytes = s.read_buffer[0..s.read_buffer_len]; + const end = hp.feed(bytes); + if (hp.state == .finished) + return finishReceivingHead(s, end); } - pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize { - assert(len <= buffer.len); - - var out_index: u16 = 0; - while (out_index < len) { - const available_read = conn.read_end - conn.read_start; - const available_buffer = buffer.len - out_index; - - if (available_read > available_buffer) { // partially read buffered data - @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); - out_index += @as(u16, @intCast(available_buffer)); - conn.read_start += @as(u16, @intCast(available_buffer)); - - break; - } else if (available_read > 0) { // fully read buffered data - @memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..conn.read_end]); - out_index += available_read; - conn.read_start += available_read; - - if (out_index >= len) break; - } - - const leftover_buffer = available_buffer - available_read; - const leftover_len = len - out_index; - - if (leftover_buffer > conn.read_buf.len) { - // skip the buffer if the output is large enough - return conn.rawReadAtLeast(buffer[out_index..], leftover_len); + while (true) { + const buf = s.read_buffer[s.read_buffer_len..]; + if (buf.len == 0) + return error.HttpHeadersOversize; + const read_n = s.connection.stream.read(buf) catch + return error.HttpHeadersUnreadable; + if (read_n == 0) { + if (s.read_buffer_len > 0) { + return error.HttpRequestTruncated; + } else { + return error.HttpConnectionClosing; } - - try conn.fill(); } - - return out_index; - } - - pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { - return conn.readAtLeast(buffer, 1); - } - - pub const ReadError = error{ - ConnectionTimedOut, - ConnectionResetByPeer, - UnexpectedReadFailure, - EndOfStream, - }; - - pub const Reader = std.io.Reader(*Connection, ReadError, read); - - pub fn reader(conn: *Connection) Reader { - return Reader{ .context = conn }; - } - - pub fn writeAll(conn: *Connection, buffer: []const u8) WriteError!void { - return switch (conn.protocol) { - .plain => conn.stream.writeAll(buffer), - // .tls => return conn.tls_client.writeAll(conn.stream, buffer), - } catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; - } - - pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize { - return switch (conn.protocol) { - .plain => conn.stream.write(buffer), - // .tls => return conn.tls_client.write(conn.stream, buffer), - } catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; + s.read_buffer_len += read_n; + const bytes = buf[0..read_n]; + const end = hp.feed(bytes); + if (hp.state == .finished) + return finishReceivingHead(s, s.read_buffer_len - bytes.len + end); } +} - pub const WriteError = error{ - ConnectionResetByPeer, - UnexpectedWriteFailure, +fn finishReceivingHead(s: *Server, head_end: usize) 
ReceiveHeadError!Request { + return .{ + .server = s, + .head_end = head_end, + .head = Request.Head.parse(s.read_buffer[0..head_end]) catch + return error.HttpHeadersInvalid, + .reader_state = undefined, }; +} - pub const Writer = std.io.Writer(*Connection, WriteError, write); - - pub fn writer(conn: *Connection) Writer { - return Writer{ .context = conn }; - } - - pub fn close(conn: *Connection) void { - conn.stream.close(); - } -}; - -/// The mode of transport for responses. -pub const ResponseTransfer = union(enum) { - content_length: u64, - chunked: void, - none: void, -}; - -/// The decompressor for request messages. -pub const Compression = union(enum) { - pub const DeflateDecompressor = std.compress.zlib.Decompressor(Response.TransferReader); - pub const GzipDecompressor = std.compress.gzip.Decompressor(Response.TransferReader); - pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Response.TransferReader, .{}); - - deflate: DeflateDecompressor, - gzip: GzipDecompressor, - zstd: ZstdDecompressor, - none: void, -}; - -/// A HTTP request originating from a client. pub const Request = struct { - pub const ParseError = Allocator.Error || error{ - UnknownHttpMethod, - HttpHeadersInvalid, - HttpHeaderContinuationsUnsupported, - HttpTransferEncodingUnsupported, - HttpConnectionHeaderUnsupported, - InvalidContentLength, - CompressionNotSupported, + server: *Server, + /// Index into Server's read_buffer. + head_end: usize, + head: Head, + reader_state: union { + remaining_content_length: u64, + chunk_parser: http.ChunkParser, + }, + + pub const Compression = union(enum) { + pub const DeflateDecompressor = std.compress.zlib.Decompressor(std.io.AnyReader); + pub const GzipDecompressor = std.compress.gzip.Decompressor(std.io.AnyReader); + pub const ZstdDecompressor = std.compress.zstd.Decompressor(std.io.AnyReader); + + deflate: DeflateDecompressor, + gzip: GzipDecompressor, + zstd: ZstdDecompressor, + none: void, }; - pub fn parse(req: *Request, bytes: []const u8) ParseError!void { - var it = mem.tokenizeAny(u8, bytes, "\r\n"); - - const first_line = it.next() orelse return error.HttpHeadersInvalid; - if (first_line.len < 10) - return error.HttpHeadersInvalid; - - const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid; - if (method_end > 24) return error.HttpHeadersInvalid; - - const method_str = first_line[0..method_end]; - const method: http.Method = @enumFromInt(http.Method.parse(method_str)); - - const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid; - if (version_start == method_end) return error.HttpHeadersInvalid; - - const version_str = first_line[version_start + 1 ..]; - if (version_str.len != 8) return error.HttpHeadersInvalid; - const version: http.Version = switch (int64(version_str[0..8])) { - int64("HTTP/1.0") => .@"HTTP/1.0", - int64("HTTP/1.1") => .@"HTTP/1.1", - else => return error.HttpHeadersInvalid, + pub const Head = struct { + method: http.Method, + target: []const u8, + version: http.Version, + expect: ?[]const u8, + content_type: ?[]const u8, + content_length: ?u64, + transfer_encoding: http.TransferEncoding, + transfer_compression: http.ContentEncoding, + keep_alive: bool, + compression: Compression, + + pub const ParseError = error{ + UnknownHttpMethod, + HttpHeadersInvalid, + HttpHeaderContinuationsUnsupported, + HttpTransferEncodingUnsupported, + HttpConnectionHeaderUnsupported, + InvalidContentLength, + CompressionUnsupported, + MissingFinalNewline, }; - const 
target = first_line[method_end + 1 .. version_start]; - - req.method = method; - req.target = target; - req.version = version; - - while (it.next()) |line| { - if (line.len == 0) return error.HttpHeadersInvalid; - switch (line[0]) { - ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, - else => {}, - } - - var line_it = mem.tokenizeAny(u8, line, ": "); - const header_name = line_it.next() orelse return error.HttpHeadersInvalid; - const header_value = line_it.rest(); - - try req.headers.append(header_name, header_value); - - if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { - if (req.content_length != null) return error.HttpHeadersInvalid; - req.content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength; - } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { - // Transfer-Encoding: second, first - // Transfer-Encoding: deflate, chunked - var iter = mem.splitBackwardsScalar(u8, header_value, ','); - - const first = iter.first(); - const trimmed_first = mem.trim(u8, first, " "); - - var next: ?[]const u8 = first; - if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { - if (req.transfer_encoding != .none) return error.HttpHeadersInvalid; // we already have a transfer encoding - req.transfer_encoding = transfer; - - next = iter.next(); + pub fn parse(bytes: []const u8) ParseError!Head { + var it = mem.splitSequence(u8, bytes, "\r\n"); + + const first_line = it.next().?; + if (first_line.len < 10) + return error.HttpHeadersInvalid; + + const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse + return error.HttpHeadersInvalid; + if (method_end > 24) return error.HttpHeadersInvalid; + + const method_str = first_line[0..method_end]; + const method: http.Method = @enumFromInt(http.Method.parse(method_str)); + + const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse + return error.HttpHeadersInvalid; + if (version_start == method_end) return error.HttpHeadersInvalid; + + const version_str = first_line[version_start + 1 ..]; + if (version_str.len != 8) return error.HttpHeadersInvalid; + const version: http.Version = switch (int64(version_str[0..8])) { + int64("HTTP/1.0") => .@"HTTP/1.0", + int64("HTTP/1.1") => .@"HTTP/1.1", + else => return error.HttpHeadersInvalid, + }; + + const target = first_line[method_end + 1 .. 
version_start]; + + var head: Head = .{ + .method = method, + .target = target, + .version = version, + .expect = null, + .content_type = null, + .content_length = null, + .transfer_encoding = .none, + .transfer_compression = .identity, + .keep_alive = false, + .compression = .none, + }; + + while (it.next()) |line| { + if (line.len == 0) return head; + switch (line[0]) { + ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, + else => {}, } - if (next) |second| { - const trimmed_second = mem.trim(u8, second, " "); - - if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { - if (req.transfer_compression != .identity) return error.HttpHeadersInvalid; // double compression is not supported - req.transfer_compression = transfer; + var line_it = mem.splitSequence(u8, line, ": "); + const header_name = line_it.next().?; + const header_value = line_it.rest(); + if (header_value.len == 0) return error.HttpHeadersInvalid; + + if (std.ascii.eqlIgnoreCase(header_name, "connection")) { + head.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close"); + } else if (std.ascii.eqlIgnoreCase(header_name, "expect")) { + head.expect = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) { + head.content_type = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { + if (head.content_length != null) return error.HttpHeadersInvalid; + head.content_length = std.fmt.parseInt(u64, header_value, 10) catch + return error.InvalidContentLength; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { + if (head.transfer_compression != .identity) return error.HttpHeadersInvalid; + + const trimmed = mem.trim(u8, header_value, " "); + + if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { + head.transfer_compression = ce; } else { return error.HttpTransferEncodingUnsupported; } - } + } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { + // Transfer-Encoding: second, first + // Transfer-Encoding: deflate, chunked + var iter = mem.splitBackwardsScalar(u8, header_value, ','); - if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { - if (req.transfer_compression != .identity) return error.HttpHeadersInvalid; + const first = iter.first(); + const trimmed_first = mem.trim(u8, first, " "); - const trimmed = mem.trim(u8, header_value, " "); + var next: ?[]const u8 = first; + if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { + if (head.transfer_encoding != .none) + return error.HttpHeadersInvalid; // we already have a transfer encoding + head.transfer_encoding = transfer; - if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - req.transfer_compression = ce; - } else { - return error.HttpTransferEncodingUnsupported; - } - } - } - } - - inline fn int64(array: *const [8]u8) u64 { - return @as(u64, @bitCast(array.*)); - } - - /// The HTTP request method. - method: http.Method, - - /// The HTTP request target. - target: []const u8, - - /// The HTTP version of this request. - version: http.Version, - - /// The length of the request body, if known. - content_length: ?u64 = null, - - /// The transfer encoding of the request body, or .none if not present. - transfer_encoding: http.TransferEncoding = .none, - - /// The compression of the request body, or .identity (no compression) if not present. 
- transfer_compression: http.ContentEncoding = .identity, - - /// The list of HTTP request headers - headers: http.Headers, - - parser: proto.HeadersParser, - compression: Compression = .none, -}; - -/// A HTTP response waiting to be sent. -/// -/// Order of operations: -/// ``` -/// [/ <--------------------------------------- \] -/// accept -> wait -> send [ -> write -> finish][ -> reset /] -/// \ -> read / -/// ``` -pub const Response = struct { - version: http.Version = .@"HTTP/1.1", - status: http.Status = .ok, - reason: ?[]const u8 = null, - - transfer_encoding: ResponseTransfer = .none, - - /// The allocator responsible for allocating memory for this response. - allocator: Allocator, - - /// The peer's address - address: net.Address, - - /// The underlying connection for this response. - connection: Connection, + next = iter.next(); + } - /// The HTTP response headers - headers: http.Headers, + if (next) |second| { + const trimmed_second = mem.trim(u8, second, " "); - /// The HTTP request that this response is responding to. - /// - /// This field is only valid after calling `wait`. - request: Request, + if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { + if (head.transfer_compression != .identity) + return error.HttpHeadersInvalid; // double compression is not supported + head.transfer_compression = transfer; + } else { + return error.HttpTransferEncodingUnsupported; + } + } - state: State = .first, + if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; + } + } + return error.MissingFinalNewline; + } - const State = enum { - first, - start, - waited, - responded, - finished, + inline fn int64(array: *const [8]u8) u64 { + return @bitCast(array.*); + } }; - /// Free all resources associated with this response. - pub fn deinit(res: *Response) void { - res.connection.close(); - - res.headers.deinit(); - res.request.headers.deinit(); - - if (res.request.parser.header_bytes_owned) { - res.request.parser.header_bytes.deinit(res.allocator); - } + pub fn iterateHeaders(r: *Request) http.HeaderIterator { + return http.HeaderIterator.init(r.server.read_buffer[0..r.head_end]); } - pub const ResetState = enum { reset, closing }; + pub const RespondOptions = struct { + version: http.Version = .@"HTTP/1.1", + status: http.Status = .ok, + reason: ?[]const u8 = null, + keep_alive: bool = true, + extra_headers: []const http.Header = &.{}, + transfer_encoding: ?http.TransferEncoding = null, + }; - /// Reset this response to its initial state. This must be called before handling a second request on the same connection. - pub fn reset(res: *Response) ResetState { - if (res.state == .first) { - res.state = .start; - return .reset; + /// Send an entire HTTP response to the client, including headers and body. + /// + /// Automatically handles HEAD requests by omitting the body. + /// + /// Unless `transfer_encoding` is specified, uses the "content-length" + /// header. + /// + /// If the request contains a body and the connection is to be reused, + /// discards the request body, leaving the Server in the `ready` state. If + /// this discarding fails, the connection is marked as not to be reused and + /// no error is surfaced. + /// + /// Asserts status is not `continue`. + /// Asserts there are at most 25 extra_headers. + /// Asserts that "\r\n" does not occur in any header name or value. 
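+    ///
+    /// A minimal usage sketch (illustrative only; assumes `request` was
+    /// obtained from `receiveHead` and that these example headers fit the
+    /// caller's needs):
+    ///
+    /// ```zig
+    /// try request.respond("Hello, World!\n", .{
+    ///     .extra_headers = &.{
+    ///         .{ .name = "content-type", .value = "text/plain" },
+    ///     },
+    /// });
+    /// ```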
+ pub fn respond( + request: *Request, + content: []const u8, + options: RespondOptions, + ) Response.WriteError!void { + const max_extra_headers = 25; + assert(options.status != .@"continue"); + assert(options.extra_headers.len <= max_extra_headers); + if (std.debug.runtime_safety) { + for (options.extra_headers) |header| { + assert(std.mem.indexOfScalar(u8, header.name, ':') == null); + assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); + assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); + } } - if (!res.request.parser.done) { - // If the response wasn't fully read, then we need to close the connection. - res.connection.closing = true; - return .closing; + const transfer_encoding_none = (options.transfer_encoding orelse .chunked) == .none; + const server_keep_alive = !transfer_encoding_none and options.keep_alive; + const keep_alive = request.discardBody(server_keep_alive); + + const phrase = options.reason orelse options.status.phrase() orelse ""; + + var first_buffer: [500]u8 = undefined; + var h = std.ArrayListUnmanaged(u8).initBuffer(&first_buffer); + if (request.head.expect != null) { + // reader() and hence discardBody() above sets expect to null if it + // is handled. So the fact that it is not null here means unhandled. + h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n"); + if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"); + h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n"); + try request.server.connection.stream.writeAll(h.items); + return; } + h.fixedWriter().print("{s} {d} {s}\r\n", .{ + @tagName(options.version), @intFromEnum(options.status), phrase, + }) catch unreachable; - // A connection is only keep-alive if the Connection header is present and it's value is not "close". - // The server and client must both agree - // - // send() defaults to using keep-alive if the client requests it. 
- const res_connection = res.headers.getFirstValue("connection"); - const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?); - - const req_connection = res.request.headers.getFirstValue("connection"); - const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); - if (req_keepalive and (res_keepalive or res_connection == null)) { - res.connection.closing = false; - } else { - res.connection.closing = true; - } + if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"); - switch (res.request.compression) { + if (options.transfer_encoding) |transfer_encoding| switch (transfer_encoding) { .none => {}, - .deflate => {}, - .gzip => {}, - .zstd => |*zstd| zstd.deinit(), - } - - res.state = .start; - res.version = .@"HTTP/1.1"; - res.status = .ok; - res.reason = null; - - res.transfer_encoding = .none; - - res.headers.clearRetainingCapacity(); - - res.request.headers.clearAndFree(); // FIXME: figure out why `clearRetainingCapacity` causes a leak in hash_map here - res.request.parser.reset(); - - res.request = Request{ - .version = undefined, - .method = undefined, - .target = undefined, - .headers = res.request.headers, - .parser = res.request.parser, - }; - - if (res.connection.closing) { - return .closing; + .chunked => h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"), } else { - return .reset; + h.fixedWriter().print("content-length: {d}\r\n", .{content.len}) catch unreachable; } - } - pub const SendError = Connection.WriteError || error{ UnsupportedTransferEncoding, InvalidContentLength }; + var chunk_header_buffer: [18]u8 = undefined; + var iovecs: [max_extra_headers * 4 + 3]std.posix.iovec_const = undefined; + var iovecs_len: usize = 0; - /// Send the HTTP response headers to the client. 
- pub fn send(res: *Response) SendError!void { - switch (res.state) { - .waited => res.state = .responded, - .first, .start, .responded, .finished => unreachable, + iovecs[iovecs_len] = .{ + .iov_base = h.items.ptr, + .iov_len = h.items.len, + }; + iovecs_len += 1; + + for (options.extra_headers) |header| { + iovecs[iovecs_len] = .{ + .iov_base = header.name.ptr, + .iov_len = header.name.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .iov_base = ": ", + .iov_len = 2, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .iov_base = header.value.ptr, + .iov_len = header.value.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .iov_base = "\r\n", + .iov_len = 2, + }; + iovecs_len += 1; } - var buffered = std.io.bufferedWriter(res.connection.writer()); - const w = buffered.writer(); - - try w.writeAll(@tagName(res.version)); - try w.writeByte(' '); - try w.print("{d}", .{@intFromEnum(res.status)}); - try w.writeByte(' '); - if (res.reason) |reason| { - try w.writeAll(reason); - } else if (res.status.phrase()) |phrase| { - try w.writeAll(phrase); - } - try w.writeAll("\r\n"); + iovecs[iovecs_len] = .{ + .iov_base = "\r\n", + .iov_len = 2, + }; + iovecs_len += 1; + + if (request.head.method != .HEAD) { + const is_chunked = (options.transfer_encoding orelse .none) == .chunked; + if (is_chunked) { + if (content.len > 0) { + const chunk_header = std.fmt.bufPrint( + &chunk_header_buffer, + "{x}\r\n", + .{content.len}, + ) catch unreachable; + + iovecs[iovecs_len] = .{ + .iov_base = chunk_header.ptr, + .iov_len = chunk_header.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .iov_base = content.ptr, + .iov_len = content.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .iov_base = "\r\n", + .iov_len = 2, + }; + iovecs_len += 1; + } - if (res.status == .@"continue") { - res.state = .waited; // we still need to send another request after this - } else { - if (!res.headers.contains("server")) { - try w.writeAll("Server: zig (std.http)\r\n"); + iovecs[iovecs_len] = .{ + .iov_base = "0\r\n\r\n", + .iov_len = 5, + }; + iovecs_len += 1; + } else if (content.len > 0) { + iovecs[iovecs_len] = .{ + .iov_base = content.ptr, + .iov_len = content.len, + }; + iovecs_len += 1; } + } - if (!res.headers.contains("connection")) { - const req_connection = res.request.headers.getFirstValue("connection"); - const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); - - if (req_keepalive) { - try w.writeAll("Connection: keep-alive\r\n"); - } else { - try w.writeAll("Connection: close\r\n"); - } - } + try request.server.connection.stream.writevAll(iovecs[0..iovecs_len]); + } - const has_transfer_encoding = res.headers.contains("transfer-encoding"); - const has_content_length = res.headers.contains("content-length"); + pub const RespondStreamingOptions = struct { + /// An externally managed slice of memory used to batch bytes before + /// sending. `respondStreaming` asserts this is large enough to store + /// the full HTTP response head. + /// + /// Must outlive the returned Response. + send_buffer: []u8, + /// If provided, the response will use the content-length header; + /// otherwise it will use transfer-encoding: chunked. + content_length: ?u64 = null, + /// Options that are shared with the `respond` method. 
+ respond_options: RespondOptions = .{}, + }; - if (!has_transfer_encoding and !has_content_length) { - switch (res.transfer_encoding) { - .chunked => try w.writeAll("Transfer-Encoding: chunked\r\n"), - .content_length => |content_length| try w.print("Content-Length: {d}\r\n", .{content_length}), - .none => {}, - } + /// The header is buffered but not sent until Response.flush is called. + /// + /// If the request contains a body and the connection is to be reused, + /// discards the request body, leaving the Server in the `ready` state. If + /// this discarding fails, the connection is marked as not to be reused and + /// no error is surfaced. + /// + /// HEAD requests are handled transparently by setting a flag on the + /// returned Response to omit the body. However it may be worth noticing + /// that flag and skipping any expensive work that would otherwise need to + /// be done to satisfy the request. + /// + /// Asserts `send_buffer` is large enough to store the entire response header. + /// Asserts status is not `continue`. + pub fn respondStreaming(request: *Request, options: RespondStreamingOptions) Response { + const o = options.respond_options; + assert(o.status != .@"continue"); + const transfer_encoding_none = (o.transfer_encoding orelse .chunked) == .none; + const server_keep_alive = !transfer_encoding_none and o.keep_alive; + const keep_alive = request.discardBody(server_keep_alive); + const phrase = o.reason orelse o.status.phrase() orelse ""; + + var h = std.ArrayListUnmanaged(u8).initBuffer(options.send_buffer); + + const elide_body = if (request.head.expect != null) eb: { + // reader() and hence discardBody() above sets expect to null if it + // is handled. So the fact that it is not null here means unhandled. + h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n"); + if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"); + h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n"); + break :eb true; + } else eb: { + h.fixedWriter().print("{s} {d} {s}\r\n", .{ + @tagName(o.version), @intFromEnum(o.status), phrase, + }) catch unreachable; + if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"); + + if (o.transfer_encoding) |transfer_encoding| switch (transfer_encoding) { + .chunked => h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"), + .none => {}, + } else if (options.content_length) |len| { + h.fixedWriter().print("content-length: {d}\r\n", .{len}) catch unreachable; } else { - if (has_content_length) { - const content_length = std.fmt.parseInt(u64, res.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength; - - res.transfer_encoding = .{ .content_length = content_length }; - } else if (has_transfer_encoding) { - const transfer_encoding = res.headers.getFirstValue("transfer-encoding").?; - if (std.mem.eql(u8, transfer_encoding, "chunked")) { - res.transfer_encoding = .chunked; - } else { - return error.UnsupportedTransferEncoding; - } - } else { - res.transfer_encoding = .none; - } + h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"); } - try w.print("{}", .{res.headers}); - } - - if (res.request.method == .HEAD) { - res.transfer_encoding = .none; - } + for (o.extra_headers) |header| { + h.appendSliceAssumeCapacity(header.name); + h.appendSliceAssumeCapacity(": "); + h.appendSliceAssumeCapacity(header.value); + h.appendSliceAssumeCapacity("\r\n"); + } - try w.writeAll("\r\n"); + h.appendSliceAssumeCapacity("\r\n"); + break :eb request.head.method == .HEAD; + }; - 
try buffered.flush(); + return .{ + .stream = request.server.connection.stream, + .send_buffer = options.send_buffer, + .send_buffer_start = 0, + .send_buffer_end = h.items.len, + .transfer_encoding = if (o.transfer_encoding) |te| switch (te) { + .chunked => .chunked, + .none => .none, + } else if (options.content_length) |len| .{ + .content_length = len, + } else .chunked, + .elide_body = elide_body, + .chunk_len = 0, + }; } - const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; + pub const ReadError = net.Stream.ReadError || error{ + HttpChunkInvalid, + HttpHeadersOversize, + }; - const TransferReader = std.io.Reader(*Response, TransferReadError, transferRead); + fn read_cl(context: *const anyopaque, buffer: []u8) ReadError!usize { + const request: *Request = @constCast(@alignCast(@ptrCast(context))); + const s = request.server; - fn transferReader(res: *Response) TransferReader { - return .{ .context = res }; + const remaining_content_length = &request.reader_state.remaining_content_length; + if (remaining_content_length.* == 0) { + s.state = .ready; + return 0; + } + assert(s.state == .receiving_body); + const available = try fill(s, request.head_end); + const len = @min(remaining_content_length.*, available.len, buffer.len); + @memcpy(buffer[0..len], available[0..len]); + remaining_content_length.* -= len; + s.next_request_start += len; + if (remaining_content_length.* == 0) + s.state = .ready; + return len; } - fn transferRead(res: *Response, buf: []u8) TransferReadError!usize { - if (res.request.parser.done) return 0; + fn fill(s: *Server, head_end: usize) ReadError![]u8 { + const available = s.read_buffer[s.next_request_start..s.read_buffer_len]; + if (available.len > 0) return available; + s.next_request_start = head_end; + s.read_buffer_len = head_end + try s.connection.stream.read(s.read_buffer[head_end..]); + return s.read_buffer[head_end..s.read_buffer_len]; + } - var index: usize = 0; - while (index == 0) { - const amt = try res.request.parser.read(&res.connection, buf[index..], false); - if (amt == 0 and res.request.parser.done) break; - index += amt; + fn read_chunked(context: *const anyopaque, buffer: []u8) ReadError!usize { + const request: *Request = @constCast(@alignCast(@ptrCast(context))); + const s = request.server; + + const cp = &request.reader_state.chunk_parser; + const head_end = request.head_end; + + // Protect against returning 0 before the end of stream. + var out_end: usize = 0; + while (out_end == 0) { + switch (cp.state) { + .invalid => return 0, + .data => { + assert(s.state == .receiving_body); + const available = try fill(s, head_end); + const len = @min(cp.chunk_len, available.len, buffer.len); + @memcpy(buffer[0..len], available[0..len]); + cp.chunk_len -= len; + if (cp.chunk_len == 0) + cp.state = .data_suffix; + out_end += len; + s.next_request_start += len; + continue; + }, + else => { + assert(s.state == .receiving_body); + const available = try fill(s, head_end); + const n = cp.feed(available); + switch (cp.state) { + .invalid => return error.HttpChunkInvalid, + .data => { + if (cp.chunk_len == 0) { + // The next bytes in the stream are trailers, + // or \r\n to indicate end of chunked body. + // + // This function must append the trailers at + // head_end so that headers and trailers are + // together. + // + // Since returning 0 would indicate end of + // stream, this function must read all the + // trailers before returning. 
+ if (s.next_request_start > head_end) rebase(s, head_end); + var hp: http.HeadParser = .{}; + { + const bytes = s.read_buffer[head_end..s.read_buffer_len]; + const end = hp.feed(bytes); + if (hp.state == .finished) { + cp.state = .invalid; + s.state = .ready; + s.next_request_start = s.read_buffer_len - bytes.len + end; + return out_end; + } + } + while (true) { + const buf = s.read_buffer[s.read_buffer_len..]; + if (buf.len == 0) + return error.HttpHeadersOversize; + const read_n = try s.connection.stream.read(buf); + s.read_buffer_len += read_n; + const bytes = buf[0..read_n]; + const end = hp.feed(bytes); + if (hp.state == .finished) { + cp.state = .invalid; + s.state = .ready; + s.next_request_start = s.read_buffer_len - bytes.len + end; + return out_end; + } + } + } + const data = available[n..]; + const len = @min(cp.chunk_len, data.len, buffer.len); + @memcpy(buffer[0..len], data[0..len]); + cp.chunk_len -= len; + if (cp.chunk_len == 0) + cp.state = .data_suffix; + out_end += len; + s.next_request_start += n + len; + continue; + }, + else => continue, + } + }, + } } - - return index; + return out_end; } - pub const WaitError = Connection.ReadError || proto.HeadersParser.CheckCompleteHeadError || Request.ParseError || error{ CompressionInitializationFailed, CompressionNotSupported }; + pub const ReaderError = Response.WriteError || error{ + /// The client sent an expect HTTP header value other than + /// "100-continue". + HttpExpectationFailed, + }; - /// Wait for the client to send a complete request head. + /// In the case that the request contains "expect: 100-continue", this + /// function writes the continuation header, which means it can fail with a + /// write error. After sending the continuation header, it sets the + /// request's expect field to `null`. /// - /// For correct behavior, the following rules must be followed: - /// - /// * If this returns any error in `Connection.ReadError`, you MUST immediately close the connection by calling `deinit`. - /// * If this returns `error.HttpHeadersInvalid`, you MAY immediately close the connection by calling `deinit`. - /// * If this returns `error.HttpHeadersExceededSizeLimit`, you MUST respond with a 431 status code and then call `deinit`. - /// * If this returns any error in `Request.ParseError`, you MUST respond with a 400 status code and then call `deinit`. - /// * If this returns any other error, you MUST respond with a 400 status code and then call `deinit`. - /// * If the request has an Expect header containing 100-continue, you MUST either: - /// * Respond with a 100 status code, then call `wait` again. - /// * Respond with a 417 status code. - pub fn wait(res: *Response) WaitError!void { - switch (res.state) { - .first, .start => res.state = .waited, - .waited, .responded, .finished => unreachable, + /// Asserts that this function is only called once. 
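+    ///
+    /// A sketch of draining the request body (illustrative only; `request`
+    /// is assumed to come from `receiveHead`, and the buffer size is an
+    /// arbitrary choice):
+    ///
+    /// ```zig
+    /// const body = try request.reader();
+    /// var buf: [4096]u8 = undefined;
+    /// while (true) {
+    ///     const n = try body.read(&buf);
+    ///     if (n == 0) break; // end of request body
+    ///     // ... process buf[0..n] ...
+    /// }
+    /// ```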
+ pub fn reader(request: *Request) ReaderError!std.io.AnyReader { + const s = request.server; + assert(s.state == .received_head); + s.state = .receiving_body; + s.next_request_start = request.head_end; + + if (request.head.expect) |expect| { + if (mem.eql(u8, expect, "100-continue")) { + try request.server.connection.stream.writeAll("HTTP/1.1 100 Continue\r\n\r\n"); + request.head.expect = null; + } else { + return error.HttpExpectationFailed; + } } - while (true) { - try res.connection.fill(); - - const nchecked = try res.request.parser.checkCompleteHead(res.allocator, res.connection.peek()); - res.connection.drop(@as(u16, @intCast(nchecked))); - - if (res.request.parser.state.isContent()) break; + switch (request.head.transfer_encoding) { + .chunked => { + request.reader_state = .{ .chunk_parser = http.ChunkParser.init }; + return .{ + .readFn = read_chunked, + .context = request, + }; + }, + .none => { + request.reader_state = .{ + .remaining_content_length = request.head.content_length orelse 0, + }; + return .{ + .readFn = read_cl, + .context = request, + }; + }, } + } - res.request.headers = .{ .allocator = res.allocator, .owned = true }; - try res.request.parse(res.request.parser.header_bytes.items); - - if (res.request.transfer_encoding != .none) { - switch (res.request.transfer_encoding) { - .none => unreachable, - .chunked => { - res.request.parser.next_chunk_length = 0; - res.request.parser.state = .chunk_head_size; - }, - } - } else if (res.request.content_length) |cl| { - res.request.parser.next_chunk_length = cl; - - if (cl == 0) res.request.parser.done = true; + /// Returns whether the connection: keep-alive header should be sent to the client. + /// If it would fail, it instead sets the Server state to `receiving_body` + /// and returns false. + fn discardBody(request: *Request, keep_alive: bool) bool { + // Prepare to receive another request on the same connection. + // There are two factors to consider: + // * Any body the client sent must be discarded. + // * The Server's read_buffer may already have some bytes in it from + // whatever came after the head, which may be the next HTTP request + // or the request body. + // If the connection won't be kept alive, then none of this matters + // because the connection will be severed after the response is sent. + const s = request.server; + if (keep_alive and request.head.keep_alive) switch (s.state) { + .received_head => { + const r = request.reader() catch return false; + _ = r.discard() catch return false; + assert(s.state == .ready); + return true; + }, + .receiving_body, .ready => return true, + else => unreachable, } else { - res.request.parser.done = true; + s.state = .closing; + return false; } + } +}; - if (!res.request.parser.done) { - switch (res.request.transfer_compression) { - .identity => res.request.compression = .none, - .compress, .@"x-compress" => return error.CompressionNotSupported, - .deflate => res.request.compression = .{ - .deflate = std.compress.zlib.decompressor(res.transferReader()), - }, - .gzip, .@"x-gzip" => res.request.compression = .{ - .gzip = std.compress.gzip.decompressor(res.transferReader()), - }, - .zstd => res.request.compression = .{ - .zstd = std.compress.zstd.decompressStream(res.allocator, res.transferReader()), - }, - } +pub const Response = struct { + stream: net.Stream, + send_buffer: []u8, + /// Index of the first byte in `send_buffer`. + /// This is 0 unless a short write happens in `write`. + send_buffer_start: usize, + /// Index of the last byte + 1 in `send_buffer`. 
+    send_buffer_end: usize,
+    /// Determines how the response body is framed on the wire.
+    /// See `TransferEncoding` for the meaning of each variant.
+    transfer_encoding: TransferEncoding,
+    elide_body: bool,
+    /// Indicates how much of the end of the `send_buffer` corresponds to a
+    /// chunk. This amount of data will be wrapped by an HTTP chunk header.
+    chunk_len: usize,
+
+    pub const TransferEncoding = union(enum) {
+        /// End of connection signals the end of the stream.
+        none,
+        /// As a debugging utility, counts down to zero as bytes are written.
+        content_length: u64,
+        /// Each chunk is wrapped in a header and trailer.
+        chunked,
+    };
+
+    pub const WriteError = net.Stream.WriteError;
+
+    /// When using content-length, asserts that the amount of data sent matches
+    /// the value sent in the header, then calls `flush`.
+    /// When using transfer-encoding: chunked, writes the end-of-stream
+    /// message, then flushes the stream to the system.
+    /// When no transfer encoding is in use, only flushes the stream.
+    /// Respects the value of `elide_body` to omit all data after the headers.
+    pub fn end(r: *Response) WriteError!void {
+        switch (r.transfer_encoding) {
+            .content_length => |len| {
+                assert(len == 0); // Trips when end() called before all bytes written.
+                try flush_cl(r);
+            },
+            .none => {
+                try flush_cl(r);
+            },
+            .chunked => {
+                try flush_chunked(r, &.{});
+            },
+        }
+        r.* = undefined;
+    }
+
+    pub const EndChunkedOptions = struct {
+        trailers: []const http.Header = &.{},
+    };
+
+    /// Asserts that the Response is using transfer-encoding: chunked.
+    /// Writes the end-of-stream message and any optional trailers, then
+    /// flushes the stream to the system.
+    /// Respects the value of `elide_body` to omit all data after the headers.
+    /// Asserts there are at most 25 trailers.
+    pub fn endChunked(r: *Response, options: EndChunkedOptions) WriteError!void {
+        assert(r.transfer_encoding == .chunked);
+        try flush_chunked(r, options.trailers);
+        r.* = undefined;
+    }
+
+    /// If using content-length, asserts that writing these bytes to the client
+    /// would not exceed the content-length value sent in the HTTP header.
+    /// May return 0, which does not indicate end of stream. The caller decides
+    /// when the end of stream occurs by calling `end`.
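+    ///
+    /// Short writes mean callers typically loop or use `writeAll`. An
+    /// illustrative streaming flow (the `request` and `send_buffer` names
+    /// are assumptions):
+    ///
+    /// ```zig
+    /// var send_buffer: [4096]u8 = undefined;
+    /// var response = request.respondStreaming(.{ .send_buffer = &send_buffer });
+    /// try response.writeAll("Hello, World!\n");
+    /// try response.end();
+    /// ```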
+ pub fn write(r: *Response, bytes: []const u8) WriteError!usize { + switch (r.transfer_encoding) { + .content_length, .none => return write_cl(r, bytes), + .chunked => return write_chunked(r, bytes), } + } - const out_index = switch (res.request.compression) { - .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure, - .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure, - .zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure, - else => try res.transferRead(buffer), - }; + fn write_cl(context: *const anyopaque, bytes: []const u8) WriteError!usize { + const r: *Response = @constCast(@alignCast(@ptrCast(context))); - if (out_index == 0) { - const has_trail = !res.request.parser.state.isContent(); + var trash: u64 = std.math.maxInt(u64); + const len = switch (r.transfer_encoding) { + .content_length => |*len| len, + else => &trash, + }; - while (!res.request.parser.state.isContent()) { // read trailing headers - try res.connection.fill(); + if (r.elide_body) { + len.* -= bytes.len; + return bytes.len; + } - const nchecked = try res.request.parser.checkCompleteHead(res.allocator, res.connection.peek()); - res.connection.drop(@as(u16, @intCast(nchecked))); + if (bytes.len + r.send_buffer_end > r.send_buffer.len) { + const send_buffer_len = r.send_buffer_end - r.send_buffer_start; + var iovecs: [2]std.posix.iovec_const = .{ + .{ + .iov_base = r.send_buffer.ptr + r.send_buffer_start, + .iov_len = send_buffer_len, + }, + .{ + .iov_base = bytes.ptr, + .iov_len = bytes.len, + }, + }; + const n = try r.stream.writev(&iovecs); + + if (n >= send_buffer_len) { + // It was enough to reset the buffer. + r.send_buffer_start = 0; + r.send_buffer_end = 0; + const bytes_n = n - send_buffer_len; + len.* -= bytes_n; + return bytes_n; } - if (has_trail) { - res.request.headers = http.Headers{ .allocator = res.allocator, .owned = false }; - - // The response headers before the trailers are already guaranteed to be valid, so they will always be parsed again and cannot return an error. - // This will *only* fail for a malformed trailer. - res.request.parse(res.request.parser.header_bytes.items) catch return error.InvalidTrailers; - } + // It didn't even make it through the existing buffer, let + // alone the new bytes provided. + r.send_buffer_start += n; + return 0; } - return out_index; - } - - /// Reads data from the response body. Must be called after `wait`. - pub fn readAll(res: *Response, buffer: []u8) !usize { - var index: usize = 0; - while (index < buffer.len) { - const amt = try read(res, buffer[index..]); - if (amt == 0) break; - index += amt; - } - return index; + // All bytes can be stored in the remaining space of the buffer. 
+ @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes); + r.send_buffer_end += bytes.len; + len.* -= bytes.len; + return bytes.len; } - pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong }; + fn write_chunked(context: *const anyopaque, bytes: []const u8) WriteError!usize { + const r: *Response = @constCast(@alignCast(@ptrCast(context))); + assert(r.transfer_encoding == .chunked); - pub const Writer = std.io.Writer(*Response, WriteError, write); + if (r.elide_body) + return bytes.len; - pub fn writer(res: *Response) Writer { - return .{ .context = res }; - } + if (bytes.len + r.send_buffer_end > r.send_buffer.len) { + const send_buffer_len = r.send_buffer_end - r.send_buffer_start; + const chunk_len = r.chunk_len + bytes.len; + var header_buf: [18]u8 = undefined; + const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{chunk_len}) catch unreachable; - /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. - /// Must be called after `send` and before `finish`. - pub fn write(res: *Response, bytes: []const u8) WriteError!usize { - switch (res.state) { - .responded => {}, - .first, .waited, .start, .finished => unreachable, + var iovecs: [5]std.posix.iovec_const = .{ + .{ + .iov_base = r.send_buffer.ptr + r.send_buffer_start, + .iov_len = send_buffer_len - r.chunk_len, + }, + .{ + .iov_base = chunk_header.ptr, + .iov_len = chunk_header.len, + }, + .{ + .iov_base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len, + .iov_len = r.chunk_len, + }, + .{ + .iov_base = bytes.ptr, + .iov_len = bytes.len, + }, + .{ + .iov_base = "\r\n", + .iov_len = 2, + }, + }; + // TODO make this writev instead of writevAll, which involves + // complicating the logic of this function. + try r.stream.writevAll(&iovecs); + r.send_buffer_start = 0; + r.send_buffer_end = 0; + r.chunk_len = 0; + return bytes.len; } - switch (res.transfer_encoding) { - .chunked => { - try res.connection.writer().print("{x}\r\n", .{bytes.len}); - try res.connection.writeAll(bytes); - try res.connection.writeAll("\r\n"); - - return bytes.len; - }, - .content_length => |*len| { - if (len.* < bytes.len) return error.MessageTooLong; - - const amt = try res.connection.write(bytes); - len.* -= amt; - return amt; - }, - .none => return error.NotWriteable, - } + // All bytes can be stored in the remaining space of the buffer. + @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes); + r.send_buffer_end += bytes.len; + r.chunk_len += bytes.len; + return bytes.len; } - /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. - /// Must be called after `send` and before `finish`. - pub fn writeAll(req: *Response, bytes: []const u8) WriteError!void { + /// If using content-length, asserts that writing these bytes to the client + /// would not exceed the content-length value sent in the HTTP header. + pub fn writeAll(r: *Response, bytes: []const u8) WriteError!void { var index: usize = 0; while (index < bytes.len) { - index += try write(req, bytes[index..]); + index += try write(r, bytes[index..]); } } - pub const FinishError = WriteError || error{MessageNotCompleted}; - - /// Finish the body of a request. This notifies the server that you have no more data to send. - /// Must be called after `send`. 
- pub fn finish(res: *Response) FinishError!void { - switch (res.state) { - .responded => res.state = .finished, - .first, .waited, .start, .finished => unreachable, + /// Sends all buffered data to the client. + /// This is redundant after calling `end`. + /// Respects the value of `elide_body` to omit all data after the headers. + pub fn flush(r: *Response) WriteError!void { + switch (r.transfer_encoding) { + .none, .content_length => return flush_cl(r), + .chunked => return flush_chunked(r, null), } + } - switch (res.transfer_encoding) { - .chunked => try res.connection.writeAll("0\r\n\r\n"), - .content_length => |len| if (len != 0) return error.MessageNotCompleted, - .none => {}, - } + fn flush_cl(r: *Response) WriteError!void { + try r.stream.writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]); + r.send_buffer_start = 0; + r.send_buffer_end = 0; } -}; -/// Create a new HTTP server. -pub fn init(options: net.StreamServer.Options) Server { - return .{ - .socket = net.StreamServer.init(options), - }; -} + fn flush_chunked(r: *Response, end_trailers: ?[]const http.Header) WriteError!void { + const max_trailers = 25; + if (end_trailers) |trailers| assert(trailers.len <= max_trailers); + assert(r.transfer_encoding == .chunked); -/// Free all resources associated with this server. -pub fn deinit(server: *Server) void { - server.socket.deinit(); -} + const http_headers = r.send_buffer[r.send_buffer_start .. r.send_buffer_end - r.chunk_len]; -pub const ListenError = std.os.SocketError || std.os.BindError || std.os.ListenError || std.os.SetSockOptError || std.os.GetSockNameError; + if (r.elide_body) { + try r.stream.writeAll(http_headers); + r.send_buffer_start = 0; + r.send_buffer_end = 0; + r.chunk_len = 0; + return; + } -/// Start the HTTP server listening on the given address. -pub fn listen(server: *Server, address: net.Address) ListenError!void { - try server.socket.listen(address); -} + var header_buf: [18]u8 = undefined; + const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{r.chunk_len}) catch unreachable; -pub const AcceptError = net.StreamServer.AcceptError || Allocator.Error; - -pub const HeaderStrategy = union(enum) { - /// In this case, the client's Allocator will be used to store the - /// entire HTTP header. This value is the maximum total size of - /// HTTP headers allowed, otherwise - /// error.HttpHeadersExceededSizeLimit is returned from read(). - dynamic: usize, - /// This is used to store the entire HTTP header. If the HTTP - /// header is too big to fit, `error.HttpHeadersExceededSizeLimit` - /// is returned from read(). When this is used, `error.OutOfMemory` - /// cannot be returned from `read()`. - static: []u8, -}; + var iovecs: [max_trailers * 4 + 5]std.posix.iovec_const = undefined; + var iovecs_len: usize = 0; -pub const AcceptOptions = struct { - allocator: Allocator, - header_strategy: HeaderStrategy = .{ .dynamic = 8192 }, -}; + iovecs[iovecs_len] = .{ + .iov_base = http_headers.ptr, + .iov_len = http_headers.len, + }; + iovecs_len += 1; + + if (r.chunk_len > 0) { + iovecs[iovecs_len] = .{ + .iov_base = chunk_header.ptr, + .iov_len = chunk_header.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .iov_base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len, + .iov_len = r.chunk_len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .iov_base = "\r\n", + .iov_len = 2, + }; + iovecs_len += 1; + } -/// Accept a new connection. 
-pub fn accept(server: *Server, options: AcceptOptions) AcceptError!Response { - const in = try server.socket.accept(); - - return Response{ - .allocator = options.allocator, - .address = in.address, - .connection = .{ - .stream = in.stream, - .protocol = .plain, - }, - .headers = .{ .allocator = options.allocator }, - .request = .{ - .version = undefined, - .method = undefined, - .target = undefined, - .headers = .{ .allocator = options.allocator, .owned = false }, - .parser = switch (options.header_strategy) { - .dynamic => |max| proto.HeadersParser.initDynamic(max), - .static => |buf| proto.HeadersParser.initStatic(buf), - }, - }, - }; -} + if (end_trailers) |trailers| { + iovecs[iovecs_len] = .{ + .iov_base = "0\r\n", + .iov_len = 3, + }; + iovecs_len += 1; + + for (trailers) |trailer| { + iovecs[iovecs_len] = .{ + .iov_base = trailer.name.ptr, + .iov_len = trailer.name.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .iov_base = ": ", + .iov_len = 2, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .iov_base = trailer.value.ptr, + .iov_len = trailer.value.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .iov_base = "\r\n", + .iov_len = 2, + }; + iovecs_len += 1; + } -test "HTTP server handles a chunked transfer coding request" { - const builtin = @import("builtin"); + iovecs[iovecs_len] = .{ + .iov_base = "\r\n", + .iov_len = 2, + }; + iovecs_len += 1; + } - // This test requires spawning threads. - if (builtin.single_threaded) { - return error.SkipZigTest; + try r.stream.writevAll(iovecs[0..iovecs_len]); + r.send_buffer_start = 0; + r.send_buffer_end = 0; + r.chunk_len = 0; } - const native_endian = comptime builtin.cpu.arch.endian(); - if (builtin.zig_backend == .stage2_llvm and native_endian == .big) { - // https://github.com/ziglang/zig/issues/13782 - return error.SkipZigTest; + pub fn writer(r: *Response) std.io.AnyWriter { + return .{ + .writeFn = switch (r.transfer_encoding) { + .none, .content_length => write_cl, + .chunked => write_chunked, + }, + .context = r, + }; } +}; - if (builtin.os.tag == .wasi) return error.SkipZigTest; - - const allocator = std.testing.allocator; - const expect = std.testing.expect; - - const max_header_size = 8192; - var server = std.http.Server.init(.{ .reuse_address = true }); - defer server.deinit(); - - const address = try std.net.Address.parseIp("127.0.0.1", 0); - try server.listen(address); - const server_port = server.socket.listen_address.in.getPort(); - - const server_thread = try std.Thread.spawn(.{}, (struct { - fn apply(s: *std.http.Server) !void { - var res = try s.accept(.{ - .allocator = allocator, - .header_strategy = .{ .dynamic = max_header_size }, - }); - defer res.deinit(); - defer _ = res.reset(); - try res.wait(); - - try expect(res.request.transfer_encoding == .chunked); - - const server_body: []const u8 = "message from server!\n"; - res.transfer_encoding = .{ .content_length = server_body.len }; - try res.headers.append("content-type", "text/plain"); - try res.headers.append("connection", "close"); - try res.send(); - - var buf: [128]u8 = undefined; - const n = try res.readAll(&buf); - try expect(std.mem.eql(u8, buf[0..n], "ABCD")); - _ = try res.writer().writeAll(server_body); - try res.finish(); - } - }).apply, .{&server}); - - const request_bytes = - "POST / HTTP/1.1\r\n" ++ - "Content-Type: text/plain\r\n" ++ - "Transfer-Encoding: chunked\r\n" ++ - "\r\n" ++ - "1\r\n" ++ - "A\r\n" ++ - "1\r\n" ++ - "B\r\n" ++ - "2\r\n" ++ - "CD\r\n" ++ - "0\r\n" ++ - "\r\n"; - - const stream = try 
std.net.tcpConnectToHost(allocator, "127.0.0.1", server_port);
-    defer stream.close();
-    _ = try stream.writeAll(request_bytes[0..]);
-
-    server_thread.join();
+fn rebase(s: *Server, index: usize) void {
+    const leftover = s.read_buffer[s.next_request_start..s.read_buffer_len];
+    const dest = s.read_buffer[index..][0..leftover.len];
+    if (leftover.len <= s.next_request_start - index) {
+        @memcpy(dest, leftover);
+    } else {
+        mem.copyBackwards(u8, dest, leftover);
+    }
+    s.read_buffer_len = index + leftover.len;
 }
+
+const std = @import("../std.zig");
+const http = std.http;
+const mem = std.mem;
+const net = std.net;
+const Uri = std.Uri;
+const assert = std.debug.assert;
+
+const Server = @This();
diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig
index 0ccafd2ee5cb..78511f435d67 100644
--- a/lib/std/http/protocol.zig
+++ b/lib/std/http/protocol.zig
@@ -7,15 +7,19 @@ const assert = std.debug.assert;
 const use_vectors = builtin.zig_backend != .stage2_x86_64;
 
 pub const State = enum {
-    /// Begin header parsing states.
     invalid,
+
+    // Begin header and trailer parsing states.
+
     start,
     seen_n,
     seen_r,
     seen_rn,
     seen_rnr,
     finished,
-    /// Begin transfer-encoding: chunked parsing states.
+
+    // Begin transfer-encoding: chunked parsing states.
+
     chunk_head_size,
     chunk_head_ext,
     chunk_head_r,
@@ -34,484 +38,114 @@ pub const State = enum {
 
 pub const HeadersParser = struct {
     state: State = .start,
-    /// Whether or not `header_bytes` is allocated or was provided as a fixed buffer.
-    header_bytes_owned: bool,
-    /// Either a fixed buffer of len `max_header_bytes` or a dynamic buffer that can grow up to `max_header_bytes`.
+    /// A caller-provided fixed buffer in which the message head is stored.
     /// Pointers into this buffer are not stable until after a message is complete.
-    header_bytes: std.ArrayListUnmanaged(u8),
-    /// The maximum allowed size of `header_bytes`.
-    max_header_bytes: usize,
-    next_chunk_length: u64 = 0,
-    /// Whether this parser is done parsing a complete message.
-    /// A message is only done when the entire payload has been read.
-    done: bool = false,
-
-    /// Initializes the parser with a dynamically growing header buffer of up to `max` bytes.
-    pub fn initDynamic(max: usize) HeadersParser {
-        return .{
-            .header_bytes = .{},
-            .max_header_bytes = max,
-            .header_bytes_owned = true,
-        };
-    }
+    header_bytes_buffer: []u8,
+    header_bytes_len: u32,
+    next_chunk_length: u64,
+    /// `false`: headers. `true`: trailers.
+    done: bool,
 
     /// Initializes the parser with a provided buffer `buf`.
-    pub fn initStatic(buf: []u8) HeadersParser {
+    pub fn init(buf: []u8) HeadersParser {
         return .{
-            .header_bytes = .{ .items = buf[0..0], .capacity = buf.len },
-            .max_header_bytes = buf.len,
-            .header_bytes_owned = false,
+            .header_bytes_buffer = buf,
+            .header_bytes_len = 0,
+            .done = false,
+            .next_chunk_length = 0,
         };
     }
 
-    /// Completely resets the parser to it's initial state.
-    /// This must be called after a message is complete.
-    pub fn reset(r: *HeadersParser) void {
-        assert(r.done); // The message must be completely read before reset, otherwise the parser is in an invalid state.
-
-        r.header_bytes.clearRetainingCapacity();
-
-        r.* = .{
-            .header_bytes = r.header_bytes,
-            .max_header_bytes = r.max_header_bytes,
-            .header_bytes_owned = r.header_bytes_owned,
+    /// Reinitialize the parser.
+    /// Asserts the parser is in the "done" state.
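+    ///
+    /// Illustrative reuse pattern (assumes `hp` has fully parsed a message,
+    /// i.e. `hp.done` is true):
+    ///
+    /// ```zig
+    /// hp.reset(); // keeps the same header buffer for the next message
+    /// ```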
+ pub fn reset(hp: *HeadersParser) void { + assert(hp.done); + hp.* = .{ + .state = .start, + .header_bytes_buffer = hp.header_bytes_buffer, + .header_bytes_len = 0, + .done = false, + .next_chunk_length = 0, }; } - /// Returns the number of bytes consumed by headers. This is always less than or equal to `bytes.len`. - /// You should check `r.state.isContent()` after this to check if the headers are done. - /// - /// If the amount returned is less than `bytes.len`, you may assume that the parser is in a content state and the - /// first byte of content is located at `bytes[result]`. - pub fn findHeadersEnd(r: *HeadersParser, bytes: []const u8) u32 { - const vector_len: comptime_int = @max(std.simd.suggestVectorLength(u8) orelse 1, 8); - const len: u32 = @intCast(bytes.len); - var index: u32 = 0; - - while (true) { - switch (r.state) { - .invalid => unreachable, - .finished => return index, - .start => switch (len - index) { - 0 => return index, - 1 => { - switch (bytes[index]) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_n, - else => {}, - } - - return index + 1; - }, - 2 => { - const b16 = int16(bytes[index..][0..2]); - const b8 = intShift(u8, b16); - - switch (b8) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_n, - else => {}, - } - - switch (b16) { - int16("\r\n") => r.state = .seen_rn, - int16("\n\n") => r.state = .finished, - else => {}, - } - - return index + 2; - }, - 3 => { - const b24 = int24(bytes[index..][0..3]); - const b16 = intShift(u16, b24); - const b8 = intShift(u8, b24); - - switch (b8) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_n, - else => {}, - } - - switch (b16) { - int16("\r\n") => r.state = .seen_rn, - int16("\n\n") => r.state = .finished, - else => {}, - } - - switch (b24) { - int24("\r\n\r") => r.state = .seen_rnr, - else => {}, - } - - return index + 3; - }, - 4...vector_len - 1 => { - const b32 = int32(bytes[index..][0..4]); - const b24 = intShift(u24, b32); - const b16 = intShift(u16, b32); - const b8 = intShift(u8, b32); - - switch (b8) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_n, - else => {}, - } - - switch (b16) { - int16("\r\n") => r.state = .seen_rn, - int16("\n\n") => r.state = .finished, - else => {}, - } - - switch (b24) { - int24("\r\n\r") => r.state = .seen_rnr, - else => {}, - } - - switch (b32) { - int32("\r\n\r\n") => r.state = .finished, - else => {}, - } - - index += 4; - continue; - }, - else => { - const chunk = bytes[index..][0..vector_len]; - const matches = if (use_vectors) matches: { - const Vector = @Vector(vector_len, u8); - // const BoolVector = @Vector(vector_len, bool); - const BitVector = @Vector(vector_len, u1); - const SizeVector = @Vector(vector_len, u8); - - const v: Vector = chunk.*; - const matches_r: BitVector = @bitCast(v == @as(Vector, @splat('\r'))); - const matches_n: BitVector = @bitCast(v == @as(Vector, @splat('\n'))); - const matches_or: SizeVector = matches_r | matches_n; - - break :matches @reduce(.Add, matches_or); - } else matches: { - var matches: u8 = 0; - for (chunk) |byte| switch (byte) { - '\r', '\n' => matches += 1, - else => {}, - }; - break :matches matches; - }; - switch (matches) { - 0 => {}, - 1 => switch (chunk[vector_len - 1]) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_n, - else => {}, - }, - 2 => { - const b16 = int16(chunk[vector_len - 2 ..][0..2]); - const b8 = intShift(u8, b16); - - switch (b8) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_n, - else => {}, - } - - switch (b16) { - int16("\r\n") => r.state = .seen_rn, - 
int16("\n\n") => r.state = .finished, - else => {}, - } - }, - 3 => { - const b24 = int24(chunk[vector_len - 3 ..][0..3]); - const b16 = intShift(u16, b24); - const b8 = intShift(u8, b24); - - switch (b8) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_n, - else => {}, - } - - switch (b16) { - int16("\r\n") => r.state = .seen_rn, - int16("\n\n") => r.state = .finished, - else => {}, - } - - switch (b24) { - int24("\r\n\r") => r.state = .seen_rnr, - else => {}, - } - }, - 4...vector_len => { - inline for (0..vector_len - 3) |i_usize| { - const i = @as(u32, @truncate(i_usize)); - - const b32 = int32(chunk[i..][0..4]); - const b16 = intShift(u16, b32); - - if (b32 == int32("\r\n\r\n")) { - r.state = .finished; - return index + i + 4; - } else if (b16 == int16("\n\n")) { - r.state = .finished; - return index + i + 2; - } - } - - const b24 = int24(chunk[vector_len - 3 ..][0..3]); - const b16 = intShift(u16, b24); - const b8 = intShift(u8, b24); - - switch (b8) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_n, - else => {}, - } - - switch (b16) { - int16("\r\n") => r.state = .seen_rn, - int16("\n\n") => r.state = .finished, - else => {}, - } - - switch (b24) { - int24("\r\n\r") => r.state = .seen_rnr, - else => {}, - } - }, - else => unreachable, - } - - index += vector_len; - continue; - }, - }, - .seen_n => switch (len - index) { - 0 => return index, - else => { - switch (bytes[index]) { - '\n' => r.state = .finished, - else => r.state = .start, - } - - index += 1; - continue; - }, - }, - .seen_r => switch (len - index) { - 0 => return index, - 1 => { - switch (bytes[index]) { - '\n' => r.state = .seen_rn, - '\r' => r.state = .seen_r, - else => r.state = .start, - } - - return index + 1; - }, - 2 => { - const b16 = int16(bytes[index..][0..2]); - const b8 = intShift(u8, b16); - - switch (b8) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_rn, - else => r.state = .start, - } - - switch (b16) { - int16("\r\n") => r.state = .seen_rn, - int16("\n\r") => r.state = .seen_rnr, - int16("\n\n") => r.state = .finished, - else => {}, - } - - return index + 2; - }, - else => { - const b24 = int24(bytes[index..][0..3]); - const b16 = intShift(u16, b24); - const b8 = intShift(u8, b24); - - switch (b8) { - '\r' => r.state = .seen_r, - '\n' => r.state = .seen_n, - else => r.state = .start, - } - - switch (b16) { - int16("\r\n") => r.state = .seen_rn, - int16("\n\n") => r.state = .finished, - else => {}, - } - - switch (b24) { - int24("\n\r\n") => r.state = .finished, - else => {}, - } - - index += 3; - continue; - }, - }, - .seen_rn => switch (len - index) { - 0 => return index, - 1 => { - switch (bytes[index]) { - '\r' => r.state = .seen_rnr, - '\n' => r.state = .seen_n, - else => r.state = .start, - } - - return index + 1; - }, - else => { - const b16 = int16(bytes[index..][0..2]); - const b8 = intShift(u8, b16); - - switch (b8) { - '\r' => r.state = .seen_rnr, - '\n' => r.state = .seen_n, - else => r.state = .start, - } - - switch (b16) { - int16("\r\n") => r.state = .finished, - int16("\n\n") => r.state = .finished, - else => {}, - } - - index += 2; - continue; - }, - }, - .seen_rnr => switch (len - index) { - 0 => return index, - else => { - switch (bytes[index]) { - '\n' => r.state = .finished, - else => r.state = .start, - } - - index += 1; - continue; - }, - }, - .chunk_head_size => unreachable, - .chunk_head_ext => unreachable, - .chunk_head_r => unreachable, - .chunk_data => unreachable, - .chunk_data_suffix => unreachable, - .chunk_data_suffix_r => unreachable, - } + pub 
fn get(hp: HeadersParser) []u8 { + return hp.header_bytes_buffer[0..hp.header_bytes_len]; + } - return index; - } + pub fn findHeadersEnd(r: *HeadersParser, bytes: []const u8) u32 { + var hp: std.http.HeadParser = .{ + .state = switch (r.state) { + .start => .start, + .seen_n => .seen_n, + .seen_r => .seen_r, + .seen_rn => .seen_rn, + .seen_rnr => .seen_rnr, + .finished => .finished, + else => unreachable, + }, + }; + const result = hp.feed(bytes); + r.state = switch (hp.state) { + .start => .start, + .seen_n => .seen_n, + .seen_r => .seen_r, + .seen_rn => .seen_rn, + .seen_rnr => .seen_rnr, + .finished => .finished, + }; + return @intCast(result); } - /// Returns the number of bytes consumed by the chunk size. This is always less than or equal to `bytes.len`. - /// You should check `r.state == .chunk_data` after this to check if the chunk size has been fully parsed. - /// - /// If the amount returned is less than `bytes.len`, you may assume that the parser is in the `chunk_data` state - /// and that the first byte of the chunk is at `bytes[result]`. pub fn findChunkedLen(r: *HeadersParser, bytes: []const u8) u32 { - const len = @as(u32, @intCast(bytes.len)); - - for (bytes[0..], 0..) |c, i| { - const index = @as(u32, @intCast(i)); - switch (r.state) { - .chunk_data_suffix => switch (c) { - '\r' => r.state = .chunk_data_suffix_r, - '\n' => r.state = .chunk_head_size, - else => { - r.state = .invalid; - return index; - }, - }, - .chunk_data_suffix_r => switch (c) { - '\n' => r.state = .chunk_head_size, - else => { - r.state = .invalid; - return index; - }, - }, - .chunk_head_size => { - const digit = switch (c) { - '0'...'9' => |b| b - '0', - 'A'...'Z' => |b| b - 'A' + 10, - 'a'...'z' => |b| b - 'a' + 10, - '\r' => { - r.state = .chunk_head_r; - continue; - }, - '\n' => { - r.state = .chunk_data; - return index + 1; - }, - else => { - r.state = .chunk_head_ext; - continue; - }, - }; - - const new_len = r.next_chunk_length *% 16 +% digit; - if (new_len <= r.next_chunk_length and r.next_chunk_length != 0) { - r.state = .invalid; - return index; - } - - r.next_chunk_length = new_len; - }, - .chunk_head_ext => switch (c) { - '\r' => r.state = .chunk_head_r, - '\n' => { - r.state = .chunk_data; - return index + 1; - }, - else => continue, - }, - .chunk_head_r => switch (c) { - '\n' => { - r.state = .chunk_data; - return index + 1; - }, - else => { - r.state = .invalid; - return index; - }, - }, + var cp: std.http.ChunkParser = .{ + .state = switch (r.state) { + .chunk_head_size => .head_size, + .chunk_head_ext => .head_ext, + .chunk_head_r => .head_r, + .chunk_data => .data, + .chunk_data_suffix => .data_suffix, + .chunk_data_suffix_r => .data_suffix_r, + .invalid => .invalid, else => unreachable, - } - } - - return len; + }, + .chunk_len = r.next_chunk_length, + }; + const result = cp.feed(bytes); + r.state = switch (cp.state) { + .head_size => .chunk_head_size, + .head_ext => .chunk_head_ext, + .head_r => .chunk_head_r, + .data => .chunk_data, + .data_suffix => .chunk_data_suffix, + .data_suffix_r => .chunk_data_suffix_r, + .invalid => .invalid, + }; + r.next_chunk_length = cp.chunk_len; + return @intCast(result); } - /// Returns whether or not the parser has finished parsing a complete message. A message is only complete after the - /// entire body has been read and any trailing headers have been parsed. + /// Returns whether or not the parser has finished parsing a complete + /// message. 
A message is only complete after the entire body has been read + /// and any trailing headers have been parsed. pub fn isComplete(r: *HeadersParser) bool { return r.done and r.state == .finished; } - pub const CheckCompleteHeadError = mem.Allocator.Error || error{HttpHeadersExceededSizeLimit}; + pub const CheckCompleteHeadError = error{HttpHeadersOversize}; - /// Pushes `in` into the parser. Returns the number of bytes consumed by the header. Any header bytes are appended - /// to the `header_bytes` buffer. - /// - /// This function only uses `allocator` if `r.header_bytes_owned` is true, and may be undefined otherwise. - pub fn checkCompleteHead(r: *HeadersParser, allocator: std.mem.Allocator, in: []const u8) CheckCompleteHeadError!u32 { - if (r.state.isContent()) return 0; + /// Pushes `in` into the parser. Returns the number of bytes consumed by + /// the header. Any header bytes are appended to `header_bytes_buffer`. + pub fn checkCompleteHead(hp: *HeadersParser, in: []const u8) CheckCompleteHeadError!u32 { + if (hp.state.isContent()) return 0; - const i = r.findHeadersEnd(in); + const i = hp.findHeadersEnd(in); const data = in[0..i]; - if (r.header_bytes.items.len + data.len > r.max_header_bytes) { - return error.HttpHeadersExceededSizeLimit; - } else { - if (r.header_bytes_owned) try r.header_bytes.ensureUnusedCapacity(allocator, data.len); + if (hp.header_bytes_len + data.len > hp.header_bytes_buffer.len) + return error.HttpHeadersOversize; - r.header_bytes.appendSliceAssumeCapacity(data); - } + @memcpy(hp.header_bytes_buffer[hp.header_bytes_len..][0..data.len], data); + hp.header_bytes_len += @intCast(data.len); return i; } @@ -520,7 +154,8 @@ pub const HeadersParser = struct { HttpChunkInvalid, }; - /// Reads the body of the message into `buffer`. Returns the number of bytes placed in the buffer. + /// Reads the body of the message into `buffer`. Returns the number of + /// bytes placed in the buffer. /// /// If `skip` is true, the buffer will be unused and the body will be skipped. /// @@ -571,9 +206,10 @@ pub const HeadersParser = struct { .chunk_data => if (r.next_chunk_length == 0) { if (std.mem.eql(u8, conn.peek(), "\r\n")) { r.state = .finished; - r.done = true; + conn.drop(2); } else { - // The trailer section is formatted identically to the header section. + // The trailer section is formatted identically + // to the header section. 
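+                            // For example, a chunked body that ends with
+                            // "0\r\nX-Checksum: aaaa\r\n\r\n" (as in this
+                            // change's "trailers" test) re-enters the header
+                            // states to consume the trailer fields.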
r.state = .seen_rn; } r.done = true; @@ -713,57 +349,11 @@ const MockBufferedConnection = struct { } }; -test "HeadersParser.findHeadersEnd" { - var r: HeadersParser = undefined; - const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\nHello"; - - for (0..36) |i| { - r = HeadersParser.initDynamic(0); - try std.testing.expectEqual(@as(u32, @intCast(i)), r.findHeadersEnd(data[0..i])); - try std.testing.expectEqual(@as(u32, @intCast(35 - i)), r.findHeadersEnd(data[i..])); - } -} - -test "HeadersParser.findChunkedLen" { - var r: HeadersParser = undefined; - const data = "Ff\r\nf0f000 ; ext\n0\r\nffffffffffffffffffffffffffffffffffffffff\r\n"; - - r = HeadersParser.initDynamic(0); - r.state = .chunk_head_size; - r.next_chunk_length = 0; - - const first = r.findChunkedLen(data[0..]); - try testing.expectEqual(@as(u32, 4), first); - try testing.expectEqual(@as(u64, 0xff), r.next_chunk_length); - try testing.expectEqual(State.chunk_data, r.state); - r.state = .chunk_head_size; - r.next_chunk_length = 0; - - const second = r.findChunkedLen(data[first..]); - try testing.expectEqual(@as(u32, 13), second); - try testing.expectEqual(@as(u64, 0xf0f000), r.next_chunk_length); - try testing.expectEqual(State.chunk_data, r.state); - r.state = .chunk_head_size; - r.next_chunk_length = 0; - - const third = r.findChunkedLen(data[first + second ..]); - try testing.expectEqual(@as(u32, 3), third); - try testing.expectEqual(@as(u64, 0), r.next_chunk_length); - try testing.expectEqual(State.chunk_data, r.state); - r.state = .chunk_head_size; - r.next_chunk_length = 0; - - const fourth = r.findChunkedLen(data[first + second + third ..]); - try testing.expectEqual(@as(u32, 16), fourth); - try testing.expectEqual(@as(u64, 0xffffffffffffffff), r.next_chunk_length); - try testing.expectEqual(State.invalid, r.state); -} - test "HeadersParser.read length" { // mock BufferedConnection for read + var headers_buf: [256]u8 = undefined; - var r = HeadersParser.initDynamic(256); - defer r.header_bytes.deinit(std.testing.allocator); + var r = HeadersParser.init(&headers_buf); const data = "GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\nHello"; var conn: MockBufferedConnection = .{ @@ -773,8 +363,8 @@ test "HeadersParser.read length" { while (true) { // read headers try conn.fill(); - const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek()); - conn.drop(@as(u16, @intCast(nchecked))); + const nchecked = try r.checkCompleteHead(conn.peek()); + conn.drop(@intCast(nchecked)); if (r.state.isContent()) break; } @@ -786,14 +376,14 @@ test "HeadersParser.read length" { try std.testing.expectEqual(@as(usize, 5), len); try std.testing.expectEqualStrings("Hello", buf[0..len]); - try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\n", r.header_bytes.items); + try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\n", r.get()); } test "HeadersParser.read chunked" { // mock BufferedConnection for read - var r = HeadersParser.initDynamic(256); - defer r.header_bytes.deinit(std.testing.allocator); + var headers_buf: [256]u8 = undefined; + var r = HeadersParser.init(&headers_buf); const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\n\r\n"; var conn: MockBufferedConnection = .{ @@ -803,8 +393,8 @@ test "HeadersParser.read chunked" { while (true) { // read headers try conn.fill(); - const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek()); - conn.drop(@as(u16, 
@intCast(nchecked))); + const nchecked = try r.checkCompleteHead(conn.peek()); + conn.drop(@intCast(nchecked)); if (r.state.isContent()) break; } @@ -815,14 +405,14 @@ test "HeadersParser.read chunked" { try std.testing.expectEqual(@as(usize, 5), len); try std.testing.expectEqualStrings("Hello", buf[0..len]); - try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n", r.header_bytes.items); + try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n", r.get()); } test "HeadersParser.read chunked trailer" { // mock BufferedConnection for read - var r = HeadersParser.initDynamic(256); - defer r.header_bytes.deinit(std.testing.allocator); + var headers_buf: [256]u8 = undefined; + var r = HeadersParser.init(&headers_buf); const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\nContent-Type: text/plain\r\n\r\n"; var conn: MockBufferedConnection = .{ @@ -832,8 +422,8 @@ test "HeadersParser.read chunked trailer" { while (true) { // read headers try conn.fill(); - const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek()); - conn.drop(@as(u16, @intCast(nchecked))); + const nchecked = try r.checkCompleteHead(conn.peek()); + conn.drop(@intCast(nchecked)); if (r.state.isContent()) break; } @@ -847,11 +437,11 @@ test "HeadersParser.read chunked trailer" { while (true) { // read headers try conn.fill(); - const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek()); - conn.drop(@as(u16, @intCast(nchecked))); + const nchecked = try r.checkCompleteHead(conn.peek()); + conn.drop(@intCast(nchecked)); if (r.state.isContent()) break; } - try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nContent-Type: text/plain\r\n\r\n", r.header_bytes.items); + try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nContent-Type: text/plain\r\n\r\n", r.get()); } diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig new file mode 100644 index 000000000000..e36b0cdf280f --- /dev/null +++ b/lib/std/http/test.zig @@ -0,0 +1,995 @@ +const builtin = @import("builtin"); +const std = @import("std"); +const http = std.http; +const mem = std.mem; +const native_endian = builtin.cpu.arch.endian(); +const expect = std.testing.expect; +const expectEqual = std.testing.expectEqual; +const expectEqualStrings = std.testing.expectEqualStrings; +const expectError = std.testing.expectError; + +test "trailers" { + const test_server = try createTestServer(struct { + fn run(net_server: *std.net.Server) anyerror!void { + var header_buffer: [1024]u8 = undefined; + var remaining: usize = 1; + while (remaining != 0) : (remaining -= 1) { + const conn = try net_server.accept(); + defer conn.stream.close(); + + var server = http.Server.init(conn, &header_buffer); + + try expectEqual(.ready, server.state); + var request = try server.receiveHead(); + try serve(&request); + try expectEqual(.ready, server.state); + } + } + + fn serve(request: *http.Server.Request) !void { + try expectEqualStrings(request.head.target, "/trailer"); + + var send_buffer: [1024]u8 = undefined; + var response = request.respondStreaming(.{ + .send_buffer = &send_buffer, + }); + try response.writeAll("Hello, "); + try response.flush(); + try response.writeAll("World!\n"); + try response.flush(); + try response.endChunked(.{ + .trailers = &.{ + .{ .name = "X-Checksum", .value = "aaaa" }, + }, + }); + } + }); + defer test_server.destroy(); + + const gpa = std.testing.allocator; + + var client: http.Client = .{ .allocator = 
gpa }; + defer client.deinit(); + + const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/trailer", .{ + test_server.port(), + }); + defer gpa.free(location); + const uri = try std.Uri.parse(location); + + { + var server_header_buffer: [1024]u8 = undefined; + var req = try client.open(.GET, uri, .{ + .server_header_buffer = &server_header_buffer, + }); + defer req.deinit(); + + try req.send(.{}); + try req.wait(); + + const body = try req.reader().readAllAlloc(gpa, 8192); + defer gpa.free(body); + + try expectEqualStrings("Hello, World!\n", body); + + var it = req.response.iterateHeaders(); + { + const header = it.next().?; + try expect(!it.is_trailer); + try expectEqualStrings("connection", header.name); + try expectEqualStrings("keep-alive", header.value); + } + { + const header = it.next().?; + try expect(!it.is_trailer); + try expectEqualStrings("transfer-encoding", header.name); + try expectEqualStrings("chunked", header.value); + } + { + const header = it.next().?; + try expect(it.is_trailer); + try expectEqualStrings("X-Checksum", header.name); + try expectEqualStrings("aaaa", header.value); + } + try expectEqual(null, it.next()); + } + + // connection has been kept alive + try expect(client.connection_pool.free_len == 1); +} + +test "HTTP server handles a chunked transfer coding request" { + const test_server = try createTestServer(struct { + fn run(net_server: *std.net.Server) !void { + var header_buffer: [8192]u8 = undefined; + const conn = try net_server.accept(); + defer conn.stream.close(); + + var server = http.Server.init(conn, &header_buffer); + var request = try server.receiveHead(); + + try expect(request.head.transfer_encoding == .chunked); + + var buf: [128]u8 = undefined; + const n = try (try request.reader()).readAll(&buf); + try expect(mem.eql(u8, buf[0..n], "ABCD")); + + try request.respond("message from server!\n", .{ + .extra_headers = &.{ + .{ .name = "content-type", .value = "text/plain" }, + }, + .keep_alive = false, + }); + } + }); + defer test_server.destroy(); + + const request_bytes = + "POST / HTTP/1.1\r\n" ++ + "Content-Type: text/plain\r\n" ++ + "Transfer-Encoding: chunked\r\n" ++ + "\r\n" ++ + "1\r\n" ++ + "A\r\n" ++ + "1\r\n" ++ + "B\r\n" ++ + "2\r\n" ++ + "CD\r\n" ++ + "0\r\n" ++ + "\r\n"; + + const gpa = std.testing.allocator; + const stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); + defer stream.close(); + try stream.writeAll(request_bytes); +} + +test "echo content server" { + const test_server = try createTestServer(struct { + fn run(net_server: *std.net.Server) anyerror!void { + var read_buffer: [1024]u8 = undefined; + + accept: while (true) { + const conn = try net_server.accept(); + defer conn.stream.close(); + + var http_server = http.Server.init(conn, &read_buffer); + + while (http_server.state == .ready) { + var request = http_server.receiveHead() catch |err| switch (err) { + error.HttpConnectionClosing => continue :accept, + else => |e| return e, + }; + if (mem.eql(u8, request.head.target, "/end")) { + return request.respond("", .{ .keep_alive = false }); + } + if (request.head.expect) |expect_header_value| { + if (mem.eql(u8, expect_header_value, "garbage")) { + try expectError(error.HttpExpectationFailed, request.reader()); + try request.respond("", .{ .keep_alive = false }); + continue; + } + } + handleRequest(&request) catch |err| { + // This message helps the person troubleshooting determine whether + // output comes from the server thread or the client thread. 
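+                        // (Both the client and the server thread of this
+                        // test write to the same stderr.)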
+ std.debug.print("handleRequest failed with '{s}'\n", .{@errorName(err)}); + return err; + }; + } + } + } + + fn handleRequest(request: *http.Server.Request) !void { + //std.debug.print("server received {s} {s} {s}\n", .{ + // @tagName(request.head.method), + // @tagName(request.head.version), + // request.head.target, + //}); + + const body = try (try request.reader()).readAllAlloc(std.testing.allocator, 8192); + defer std.testing.allocator.free(body); + + try expect(mem.startsWith(u8, request.head.target, "/echo-content")); + try expectEqualStrings("Hello, World!\n", body); + try expectEqualStrings("text/plain", request.head.content_type.?); + + var send_buffer: [100]u8 = undefined; + var response = request.respondStreaming(.{ + .send_buffer = &send_buffer, + .content_length = switch (request.head.transfer_encoding) { + .chunked => null, + .none => len: { + try expectEqual(14, request.head.content_length.?); + break :len 14; + }, + }, + }); + + try response.flush(); // Test an early flush to send the HTTP headers before the body. + const w = response.writer(); + try w.writeAll("Hello, "); + try w.writeAll("World!\n"); + try response.end(); + //std.debug.print(" server finished responding\n", .{}); + } + }); + defer test_server.destroy(); + + { + var client: http.Client = .{ .allocator = std.testing.allocator }; + defer client.deinit(); + + try echoTests(&client, test_server.port()); + } +} + +test "Server.Request.respondStreaming non-chunked, unknown content-length" { + // In this case, the response is expected to stream until the connection is + // closed, indicating the end of the body. + const test_server = try createTestServer(struct { + fn run(net_server: *std.net.Server) anyerror!void { + var header_buffer: [1000]u8 = undefined; + var remaining: usize = 1; + while (remaining != 0) : (remaining -= 1) { + const conn = try net_server.accept(); + defer conn.stream.close(); + + var server = http.Server.init(conn, &header_buffer); + + try expectEqual(.ready, server.state); + var request = try server.receiveHead(); + try expectEqualStrings(request.head.target, "/foo"); + var send_buffer: [500]u8 = undefined; + var response = request.respondStreaming(.{ + .send_buffer = &send_buffer, + .respond_options = .{ + .transfer_encoding = .none, + }, + }); + var total: usize = 0; + for (0..500) |i| { + var buf: [30]u8 = undefined; + const line = try std.fmt.bufPrint(&buf, "{d}, ah ha ha!\n", .{i}); + try response.writeAll(line); + total += line.len; + } + try expectEqual(7390, total); + try response.end(); + try expectEqual(.closing, server.state); + } + } + }); + defer test_server.destroy(); + + const request_bytes = "GET /foo HTTP/1.1\r\n\r\n"; + const gpa = std.testing.allocator; + const stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); + defer stream.close(); + try stream.writeAll(request_bytes); + + const response = try stream.reader().readAllAlloc(gpa, 8192); + defer gpa.free(response); + + var expected_response = std.ArrayList(u8).init(gpa); + defer expected_response.deinit(); + + try expected_response.appendSlice("HTTP/1.1 200 OK\r\n\r\n"); + + { + var total: usize = 0; + for (0..500) |i| { + var buf: [30]u8 = undefined; + const line = try std.fmt.bufPrint(&buf, "{d}, ah ha ha!\n", .{i}); + try expected_response.appendSlice(line); + total += line.len; + } + try expectEqual(7390, total); + } + + try expectEqualStrings(expected_response.items, response); +} + +test "receiving arbitrary http headers from the client" { + const test_server = try 
createTestServer(struct { + fn run(net_server: *std.net.Server) anyerror!void { + var read_buffer: [666]u8 = undefined; + var remaining: usize = 1; + while (remaining != 0) : (remaining -= 1) { + const conn = try net_server.accept(); + defer conn.stream.close(); + + var server = http.Server.init(conn, &read_buffer); + try expectEqual(.ready, server.state); + var request = try server.receiveHead(); + try expectEqualStrings("/bar", request.head.target); + var it = request.iterateHeaders(); + { + const header = it.next().?; + try expectEqualStrings("CoNneCtIoN", header.name); + try expectEqualStrings("close", header.value); + try expect(!it.is_trailer); + } + { + const header = it.next().?; + try expectEqualStrings("aoeu", header.name); + try expectEqualStrings("asdf", header.value); + try expect(!it.is_trailer); + } + try request.respond("", .{}); + } + } + }); + defer test_server.destroy(); + + const request_bytes = "GET /bar HTTP/1.1\r\n" ++ + "CoNneCtIoN: close\r\n" ++ + "aoeu: asdf\r\n" ++ + "\r\n"; + const gpa = std.testing.allocator; + const stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); + defer stream.close(); + try stream.writeAll(request_bytes); + + const response = try stream.reader().readAllAlloc(gpa, 8192); + defer gpa.free(response); + + var expected_response = std.ArrayList(u8).init(gpa); + defer expected_response.deinit(); + + try expected_response.appendSlice("HTTP/1.1 200 OK\r\n"); + try expected_response.appendSlice("content-length: 0\r\n\r\n"); + try expectEqualStrings(expected_response.items, response); +} + +test "general client/server API coverage" { + if (builtin.os.tag == .windows) { + // This test was never passing on Windows. + return error.SkipZigTest; + } + + const global = struct { + var handle_new_requests = true; + }; + const test_server = try createTestServer(struct { + fn run(net_server: *std.net.Server) anyerror!void { + var client_header_buffer: [1024]u8 = undefined; + outer: while (global.handle_new_requests) { + var connection = try net_server.accept(); + defer connection.stream.close(); + + var http_server = http.Server.init(connection, &client_header_buffer); + + while (http_server.state == .ready) { + var request = http_server.receiveHead() catch |err| switch (err) { + error.HttpConnectionClosing => continue :outer, + else => |e| return e, + }; + + try handleRequest(&request, net_server.listen_address.getPort()); + } + } + } + + fn handleRequest(request: *http.Server.Request, listen_port: u16) !void { + const log = std.log.scoped(.server); + + log.info("{} {s} {s}", .{ + request.head.method, + @tagName(request.head.version), + request.head.target, + }); + + const gpa = std.testing.allocator; + const body = try (try request.reader()).readAllAlloc(gpa, 8192); + defer gpa.free(body); + + var send_buffer: [100]u8 = undefined; + + if (mem.startsWith(u8, request.head.target, "/get")) { + var response = request.respondStreaming(.{ + .send_buffer = &send_buffer, + .content_length = if (mem.indexOf(u8, request.head.target, "?chunked") == null) + 14 + else + null, + .respond_options = .{ + .extra_headers = &.{ + .{ .name = "content-type", .value = "text/plain" }, + }, + }, + }); + const w = response.writer(); + try w.writeAll("Hello, "); + try w.writeAll("World!\n"); + try response.end(); + // Writing again would cause an assertion failure. 
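+                // (response.end() has marked the response as finished.)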
+ } else if (mem.startsWith(u8, request.head.target, "/large")) { + var response = request.respondStreaming(.{ + .send_buffer = &send_buffer, + .content_length = 14 * 1024 + 14 * 10, + }); + + try response.flush(); // Test an early flush to send the HTTP headers before the body. + + const w = response.writer(); + + var i: u32 = 0; + while (i < 5) : (i += 1) { + try w.writeAll("Hello, World!\n"); + } + + try w.writeAll("Hello, World!\n" ** 1024); + + i = 0; + while (i < 5) : (i += 1) { + try w.writeAll("Hello, World!\n"); + } + + try response.end(); + } else if (mem.eql(u8, request.head.target, "/redirect/1")) { + var response = request.respondStreaming(.{ + .send_buffer = &send_buffer, + .respond_options = .{ + .status = .found, + .extra_headers = &.{ + .{ .name = "location", .value = "../../get" }, + }, + }, + }); + + const w = response.writer(); + try w.writeAll("Hello, "); + try w.writeAll("Redirected!\n"); + try response.end(); + } else if (mem.eql(u8, request.head.target, "/redirect/2")) { + try request.respond("Hello, Redirected!\n", .{ + .status = .found, + .extra_headers = &.{ + .{ .name = "location", .value = "/redirect/1" }, + }, + }); + } else if (mem.eql(u8, request.head.target, "/redirect/3")) { + const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/redirect/2", .{ + listen_port, + }); + defer gpa.free(location); + + try request.respond("Hello, Redirected!\n", .{ + .status = .found, + .extra_headers = &.{ + .{ .name = "location", .value = location }, + }, + }); + } else if (mem.eql(u8, request.head.target, "/redirect/4")) { + try request.respond("Hello, Redirected!\n", .{ + .status = .found, + .extra_headers = &.{ + .{ .name = "location", .value = "/redirect/3" }, + }, + }); + } else if (mem.eql(u8, request.head.target, "/redirect/invalid")) { + const invalid_port = try getUnusedTcpPort(); + const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}", .{invalid_port}); + defer gpa.free(location); + + try request.respond("", .{ + .status = .found, + .extra_headers = &.{ + .{ .name = "location", .value = location }, + }, + }); + } else { + try request.respond("", .{ .status = .not_found }); + } + } + + fn getUnusedTcpPort() !u16 { + const addr = try std.net.Address.parseIp("127.0.0.1", 0); + var s = try addr.listen(.{}); + defer s.deinit(); + return s.listen_address.in.getPort(); + } + }); + defer test_server.destroy(); + + const log = std.log.scoped(.client); + + const gpa = std.testing.allocator; + var client: http.Client = .{ .allocator = gpa }; + errdefer client.deinit(); + // defer client.deinit(); handled below + + const port = test_server.port(); + + { // read content-length response + const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/get", .{port}); + defer gpa.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var server_header_buffer: [1024]u8 = undefined; + var req = try client.open(.GET, uri, .{ + .server_header_buffer = &server_header_buffer, + }); + defer req.deinit(); + + try req.send(.{}); + try req.wait(); + + const body = try req.reader().readAllAlloc(gpa, 8192); + defer gpa.free(body); + + try expectEqualStrings("Hello, World!\n", body); + try expectEqualStrings("text/plain", req.response.content_type.?); + } + + // connection has been kept alive + try expect(client.http_proxy != null or client.connection_pool.free_len == 1); + + { // read large content-length response + const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/large", .{port}); + defer 
gpa.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var server_header_buffer: [1024]u8 = undefined; + var req = try client.open(.GET, uri, .{ + .server_header_buffer = &server_header_buffer, + }); + defer req.deinit(); + + try req.send(.{}); + try req.wait(); + + const body = try req.reader().readAllAlloc(gpa, 8192 * 1024); + defer gpa.free(body); + + try expectEqual(@as(usize, 14 * 1024 + 14 * 10), body.len); + } + + // connection has been kept alive + try expect(client.http_proxy != null or client.connection_pool.free_len == 1); + + { // send head request and not read chunked + const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/get", .{port}); + defer gpa.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var server_header_buffer: [1024]u8 = undefined; + var req = try client.open(.HEAD, uri, .{ + .server_header_buffer = &server_header_buffer, + }); + defer req.deinit(); + + try req.send(.{}); + try req.wait(); + + const body = try req.reader().readAllAlloc(gpa, 8192); + defer gpa.free(body); + + try expectEqualStrings("", body); + try expectEqualStrings("text/plain", req.response.content_type.?); + try expectEqual(14, req.response.content_length.?); + } + + // connection has been kept alive + try expect(client.http_proxy != null or client.connection_pool.free_len == 1); + + { // read chunked response + const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/get?chunked", .{port}); + defer gpa.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var server_header_buffer: [1024]u8 = undefined; + var req = try client.open(.GET, uri, .{ + .server_header_buffer = &server_header_buffer, + }); + defer req.deinit(); + + try req.send(.{}); + try req.wait(); + + const body = try req.reader().readAllAlloc(gpa, 8192); + defer gpa.free(body); + + try expectEqualStrings("Hello, World!\n", body); + try expectEqualStrings("text/plain", req.response.content_type.?); + } + + // connection has been kept alive + try expect(client.http_proxy != null or client.connection_pool.free_len == 1); + + { // send head request and not read chunked + const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/get?chunked", .{port}); + defer gpa.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var server_header_buffer: [1024]u8 = undefined; + var req = try client.open(.HEAD, uri, .{ + .server_header_buffer = &server_header_buffer, + }); + defer req.deinit(); + + try req.send(.{}); + try req.wait(); + + const body = try req.reader().readAllAlloc(gpa, 8192); + defer gpa.free(body); + + try expectEqualStrings("", body); + try expectEqualStrings("text/plain", req.response.content_type.?); + try expect(req.response.transfer_encoding == .chunked); + } + + // connection has been kept alive + try expect(client.http_proxy != null or client.connection_pool.free_len == 1); + + { // read content-length response with connection close + const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/get", .{port}); + defer gpa.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var server_header_buffer: [1024]u8 = undefined; + var req = try client.open(.GET, uri, .{ + .server_header_buffer = &server_header_buffer, + .keep_alive = false, + }); + defer req.deinit(); + + try req.send(.{}); + try req.wait(); + + const body = try req.reader().readAllAlloc(gpa, 8192); + defer 
gpa.free(body); + + try expectEqualStrings("Hello, World!\n", body); + try expectEqualStrings("text/plain", req.response.content_type.?); + } + + // connection has been closed + try expect(client.connection_pool.free_len == 0); + + { // relative redirect + const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/redirect/1", .{port}); + defer gpa.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var server_header_buffer: [1024]u8 = undefined; + var req = try client.open(.GET, uri, .{ + .server_header_buffer = &server_header_buffer, + }); + defer req.deinit(); + + try req.send(.{}); + try req.wait(); + + const body = try req.reader().readAllAlloc(gpa, 8192); + defer gpa.free(body); + + try expectEqualStrings("Hello, World!\n", body); + } + + // connection has been kept alive + try expect(client.http_proxy != null or client.connection_pool.free_len == 1); + + { // redirect from root + const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/redirect/2", .{port}); + defer gpa.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var server_header_buffer: [1024]u8 = undefined; + var req = try client.open(.GET, uri, .{ + .server_header_buffer = &server_header_buffer, + }); + defer req.deinit(); + + try req.send(.{}); + try req.wait(); + + const body = try req.reader().readAllAlloc(gpa, 8192); + defer gpa.free(body); + + try expectEqualStrings("Hello, World!\n", body); + } + + // connection has been kept alive + try expect(client.http_proxy != null or client.connection_pool.free_len == 1); + + { // absolute redirect + const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/redirect/3", .{port}); + defer gpa.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var server_header_buffer: [1024]u8 = undefined; + var req = try client.open(.GET, uri, .{ + .server_header_buffer = &server_header_buffer, + }); + defer req.deinit(); + + try req.send(.{}); + try req.wait(); + + const body = try req.reader().readAllAlloc(gpa, 8192); + defer gpa.free(body); + + try expectEqualStrings("Hello, World!\n", body); + } + + // connection has been kept alive + try expect(client.http_proxy != null or client.connection_pool.free_len == 1); + + { // too many redirects + const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/redirect/4", .{port}); + defer gpa.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var server_header_buffer: [1024]u8 = undefined; + var req = try client.open(.GET, uri, .{ + .server_header_buffer = &server_header_buffer, + }); + defer req.deinit(); + + try req.send(.{}); + req.wait() catch |err| switch (err) { + error.TooManyHttpRedirects => {}, + else => return err, + }; + } + + // connection has been kept alive + try expect(client.http_proxy != null or client.connection_pool.free_len == 1); + + { // check client without segfault by connection error after redirection + const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/redirect/invalid", .{port}); + defer gpa.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var server_header_buffer: [1024]u8 = undefined; + var req = try client.open(.GET, uri, .{ + .server_header_buffer = &server_header_buffer, + }); + defer req.deinit(); + + try req.send(.{}); + const result = req.wait(); + + // a proxy without an upstream is likely to return a 5xx status. 
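+        // Only a direct connection is expected to fail with ConnectionRefused.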
+ if (client.http_proxy == null) { + try expectError(error.ConnectionRefused, result); // expects not segfault but the regular error + } + } + + // connection has been kept alive + try expect(client.http_proxy != null or client.connection_pool.free_len == 1); + + { // issue 16282 *** This test leaves the client in an invalid state, it must be last *** + const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/get", .{port}); + defer gpa.free(location); + const uri = try std.Uri.parse(location); + + const total_connections = client.connection_pool.free_size + 64; + var requests = try gpa.alloc(http.Client.Request, total_connections); + defer gpa.free(requests); + + var header_bufs = std.ArrayList([]u8).init(gpa); + defer header_bufs.deinit(); + defer for (header_bufs.items) |item| gpa.free(item); + + for (0..total_connections) |i| { + const headers_buf = try gpa.alloc(u8, 1024); + try header_bufs.append(headers_buf); + var req = try client.open(.GET, uri, .{ + .server_header_buffer = headers_buf, + }); + req.response.parser.done = true; + req.connection.?.closing = false; + requests[i] = req; + } + + for (0..total_connections) |i| { + requests[i].deinit(); + } + + // free connections should be full now + try expect(client.connection_pool.free_len == client.connection_pool.free_size); + } + + client.deinit(); + + { + global.handle_new_requests = false; + + const conn = try std.net.tcpConnectToAddress(test_server.net_server.listen_address); + conn.close(); + } +} + +fn echoTests(client: *http.Client, port: u16) !void { + const gpa = std.testing.allocator; + var location_buffer: [100]u8 = undefined; + + { // send content-length request + const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/echo-content", .{port}); + defer gpa.free(location); + const uri = try std.Uri.parse(location); + + var server_header_buffer: [1024]u8 = undefined; + var req = try client.open(.POST, uri, .{ + .server_header_buffer = &server_header_buffer, + .extra_headers = &.{ + .{ .name = "content-type", .value = "text/plain" }, + }, + }); + defer req.deinit(); + + req.transfer_encoding = .{ .content_length = 14 }; + + try req.send(.{}); + try req.writeAll("Hello, "); + try req.writeAll("World!\n"); + try req.finish(); + + try req.wait(); + + const body = try req.reader().readAllAlloc(gpa, 8192); + defer gpa.free(body); + + try expectEqualStrings("Hello, World!\n", body); + } + + // connection has been kept alive + try expect(client.http_proxy != null or client.connection_pool.free_len == 1); + + { // send chunked request + const uri = try std.Uri.parse(try std.fmt.bufPrint( + &location_buffer, + "http://127.0.0.1:{d}/echo-content", + .{port}, + )); + + var server_header_buffer: [1024]u8 = undefined; + var req = try client.open(.POST, uri, .{ + .server_header_buffer = &server_header_buffer, + .extra_headers = &.{ + .{ .name = "content-type", .value = "text/plain" }, + }, + }); + defer req.deinit(); + + req.transfer_encoding = .chunked; + + try req.send(.{}); + try req.writeAll("Hello, "); + try req.writeAll("World!\n"); + try req.finish(); + + try req.wait(); + + const body = try req.reader().readAllAlloc(gpa, 8192); + defer gpa.free(body); + + try expectEqualStrings("Hello, World!\n", body); + } + + // connection has been kept alive + try expect(client.http_proxy != null or client.connection_pool.free_len == 1); + + { // Client.fetch() + + const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/echo-content#fetch", .{port}); + defer gpa.free(location); + + var body = 
std.ArrayList(u8).init(gpa); + defer body.deinit(); + + const res = try client.fetch(.{ + .location = .{ .url = location }, + .method = .POST, + .payload = "Hello, World!\n", + .extra_headers = &.{ + .{ .name = "content-type", .value = "text/plain" }, + }, + .response_storage = .{ .dynamic = &body }, + }); + try expectEqual(.ok, res.status); + try expectEqualStrings("Hello, World!\n", body.items); + } + + { // expect: 100-continue + const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/echo-content#expect-100", .{port}); + defer gpa.free(location); + const uri = try std.Uri.parse(location); + + var server_header_buffer: [1024]u8 = undefined; + var req = try client.open(.POST, uri, .{ + .server_header_buffer = &server_header_buffer, + .extra_headers = &.{ + .{ .name = "expect", .value = "100-continue" }, + .{ .name = "content-type", .value = "text/plain" }, + }, + }); + defer req.deinit(); + + req.transfer_encoding = .chunked; + + try req.send(.{}); + try req.writeAll("Hello, "); + try req.writeAll("World!\n"); + try req.finish(); + + try req.wait(); + try expectEqual(.ok, req.response.status); + + const body = try req.reader().readAllAlloc(gpa, 8192); + defer gpa.free(body); + + try expectEqualStrings("Hello, World!\n", body); + } + + { // expect: garbage + const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/echo-content#expect-garbage", .{port}); + defer gpa.free(location); + const uri = try std.Uri.parse(location); + + var server_header_buffer: [1024]u8 = undefined; + var req = try client.open(.POST, uri, .{ + .server_header_buffer = &server_header_buffer, + .extra_headers = &.{ + .{ .name = "content-type", .value = "text/plain" }, + .{ .name = "expect", .value = "garbage" }, + }, + }); + defer req.deinit(); + + req.transfer_encoding = .chunked; + + try req.send(.{}); + try req.wait(); + try expectEqual(.expectation_failed, req.response.status); + } + + _ = try client.fetch(.{ + .location = .{ + .url = try std.fmt.bufPrint(&location_buffer, "http://127.0.0.1:{d}/end", .{port}), + }, + }); +} + +const TestServer = struct { + server_thread: std.Thread, + net_server: std.net.Server, + + fn destroy(self: *@This()) void { + self.server_thread.join(); + self.net_server.deinit(); + std.testing.allocator.destroy(self); + } + + fn port(self: @This()) u16 { + return self.net_server.listen_address.in.getPort(); + } +}; + +fn createTestServer(S: type) !*TestServer { + if (builtin.single_threaded) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_llvm and native_endian == .big) { + // https://github.com/ziglang/zig/issues/13782 + return error.SkipZigTest; + } + + const address = try std.net.Address.parseIp("127.0.0.1", 0); + const test_server = try std.testing.allocator.create(TestServer); + test_server.net_server = try address.listen(.{ .reuse_address = true }); + test_server.server_thread = try std.Thread.spawn(.{}, S.run, .{&test_server.net_server}); + return test_server; +} diff --git a/lib/std/io/Reader.zig b/lib/std/io/Reader.zig index 0d96629e7a8e..9569d8d56548 100644 --- a/lib/std/io/Reader.zig +++ b/lib/std/io/Reader.zig @@ -360,6 +360,18 @@ pub fn readEnum(self: Self, comptime Enum: type, endian: std.builtin.Endian) any return E.InvalidValue; } +/// Reads the stream until the end, ignoring all the data. +/// Returns the number of bytes discarded. 
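+/// For example, this can drain the unread remainder of a message body.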
+pub fn discard(self: Self) anyerror!u64 { + var trash: [4096]u8 = undefined; + var index: u64 = 0; + while (true) { + const n = try self.read(&trash); + if (n == 0) return index; + index += n; + } +} + const std = @import("../std.zig"); const Self = @This(); const math = std.math; diff --git a/lib/std/mem.zig b/lib/std/mem.zig index feb41aedebd3..f263b3e85164 100644 --- a/lib/std/mem.zig +++ b/lib/std/mem.zig @@ -1338,7 +1338,7 @@ pub fn indexOf(comptime T: type, haystack: []const T, needle: []const T) ?usize pub fn lastIndexOfLinear(comptime T: type, haystack: []const T, needle: []const T) ?usize { var i: usize = haystack.len - needle.len; while (true) : (i -= 1) { - if (mem.eql(T, haystack[i .. i + needle.len], needle)) return i; + if (mem.eql(T, haystack[i..][0..needle.len], needle)) return i; if (i == 0) return null; } } @@ -1349,7 +1349,7 @@ pub fn indexOfPosLinear(comptime T: type, haystack: []const T, start_index: usiz var i: usize = start_index; const end = haystack.len - needle.len; while (i <= end) : (i += 1) { - if (eql(T, haystack[i .. i + needle.len], needle)) return i; + if (eql(T, haystack[i..][0..needle.len], needle)) return i; } return null; } diff --git a/lib/std/net.zig b/lib/std/net.zig index 154e2f7375f3..66b90867c642 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -4,15 +4,17 @@ const assert = std.debug.assert; const net = @This(); const mem = std.mem; const os = std.os; +const posix = std.posix; const fs = std.fs; const io = std.io; const native_endian = builtin.target.cpu.arch.endian(); // Windows 10 added support for unix sockets in build 17063, redstone 4 is the // first release to support them. -pub const has_unix_sockets = @hasDecl(os.sockaddr, "un") and - (builtin.target.os.tag != .windows or - builtin.os.version_range.windows.isAtLeast(.win10_rs4) orelse false); +pub const has_unix_sockets = switch (builtin.os.tag) { + .windows => builtin.os.version_range.windows.isAtLeast(.win10_rs4) orelse false, + else => true, +}; pub const IPParseError = error{ Overflow, @@ -122,7 +124,7 @@ pub const Address = extern union { @memset(&sock_addr.path, 0); @memcpy(sock_addr.path[0..path.len], path); - return Address{ .un = sock_addr }; + return .{ .un = sock_addr }; } /// Returns the port in native endian. @@ -206,6 +208,60 @@ pub const Address = extern union { else => unreachable, } } + + pub const ListenError = posix.SocketError || posix.BindError || posix.ListenError || + posix.SetSockOptError || posix.GetSockNameError; + + pub const ListenOptions = struct { + /// How many connections the kernel will accept on the application's behalf. + /// If more than this many connections pool in the kernel, clients will start + /// seeing "Connection refused". + kernel_backlog: u31 = 128, + /// Sets SO_REUSEADDR and SO_REUSEPORT on POSIX. + /// Sets SO_REUSEADDR on Windows, which is roughly equivalent. + reuse_address: bool = false, + /// Deprecated. Does the same thing as reuse_address. + reuse_port: bool = false, + force_nonblocking: bool = false, + }; + + /// The returned `Server` has an open `stream`. 
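+    ///
+    /// A minimal usage sketch, mirroring the updated tests in
+    /// lib/std/net/test.zig (port 0 asks the OS for an ephemeral port):
+    ///
+    ///     const addr = try std.net.Address.parseIp("127.0.0.1", 0);
+    ///     var server = try addr.listen(.{ .reuse_address = true });
+    ///     defer server.deinit();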
+ pub fn listen(address: Address, options: ListenOptions) ListenError!Server { + const nonblock: u32 = if (options.force_nonblocking) posix.SOCK.NONBLOCK else 0; + const sock_flags = posix.SOCK.STREAM | posix.SOCK.CLOEXEC | nonblock; + const proto: u32 = if (address.any.family == posix.AF.UNIX) 0 else posix.IPPROTO.TCP; + + const sockfd = try posix.socket(address.any.family, sock_flags, proto); + var s: Server = .{ + .listen_address = undefined, + .stream = .{ .handle = sockfd }, + }; + errdefer s.stream.close(); + + if (options.reuse_address or options.reuse_port) { + try posix.setsockopt( + sockfd, + posix.SOL.SOCKET, + posix.SO.REUSEADDR, + &mem.toBytes(@as(c_int, 1)), + ); + switch (builtin.os.tag) { + .windows => {}, + else => try posix.setsockopt( + sockfd, + posix.SOL.SOCKET, + posix.SO.REUSEPORT, + &mem.toBytes(@as(c_int, 1)), + ), + } + } + + var socklen = address.getOsSockLen(); + try posix.bind(sockfd, &address.any, socklen); + try posix.listen(sockfd, options.kernel_backlog); + try posix.getsockname(sockfd, &s.listen_address.any, &socklen); + return s; + } }; pub const Ip4Address = extern struct { @@ -657,7 +713,7 @@ pub fn connectUnixSocket(path: []const u8) !Stream { os.SOCK.STREAM | os.SOCK.CLOEXEC | opt_non_block, 0, ); - errdefer os.closeSocket(sockfd); + errdefer Stream.close(.{ .handle = sockfd }); var addr = try std.net.Address.initUnix(path); try os.connect(sockfd, &addr.any, addr.getOsSockLen()); @@ -669,7 +725,7 @@ fn if_nametoindex(name: []const u8) IPv6InterfaceError!u32 { if (builtin.target.os.tag == .linux) { var ifr: os.ifreq = undefined; const sockfd = try os.socket(os.AF.UNIX, os.SOCK.DGRAM | os.SOCK.CLOEXEC, 0); - defer os.closeSocket(sockfd); + defer Stream.close(.{ .handle = sockfd }); @memcpy(ifr.ifrn.name[0..name.len], name); ifr.ifrn.name[name.len] = 0; @@ -738,7 +794,7 @@ pub fn tcpConnectToAddress(address: Address) TcpConnectToAddressError!Stream { const sock_flags = os.SOCK.STREAM | nonblock | (if (builtin.target.os.tag == .windows) 0 else os.SOCK.CLOEXEC); const sockfd = try os.socket(address.any.family, sock_flags, os.IPPROTO.TCP); - errdefer os.closeSocket(sockfd); + errdefer Stream.close(.{ .handle = sockfd }); try os.connect(sockfd, &address.any, address.getOsSockLen()); @@ -1068,7 +1124,7 @@ fn linuxLookupName( var prefixlen: i32 = 0; const sock_flags = os.SOCK.DGRAM | os.SOCK.CLOEXEC; if (os.socket(addr.addr.any.family, sock_flags, os.IPPROTO.UDP)) |fd| syscalls: { - defer os.closeSocket(fd); + defer Stream.close(.{ .handle = fd }); os.connect(fd, da, dalen) catch break :syscalls; key |= DAS_USABLE; os.getsockname(fd, sa, &salen) catch break :syscalls; @@ -1553,7 +1609,7 @@ fn resMSendRc( }, else => |e| return e, }; - defer os.closeSocket(fd); + defer Stream.close(.{ .handle = fd }); // Past this point, there are no errors. Each individual query will // yield either no reply (indicated by zero length) or an answer @@ -1729,13 +1785,15 @@ fn dnsParseCallback(ctx: dpc_ctx, rr: u8, data: []const u8, packet: []const u8) } pub const Stream = struct { - // Underlying socket descriptor. - // Note that on some platforms this may not be interchangeable with a - // regular files descriptor. - handle: os.socket_t, - - pub fn close(self: Stream) void { - os.closeSocket(self.handle); + /// Underlying platform-defined type which may or may not be + /// interchangeable with a file system file descriptor. 
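+    /// Close it with `Stream.close`, which calls `closesocket` on Windows
+    /// and `posix.close` elsewhere.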
+ handle: posix.socket_t, + + pub fn close(s: Stream) void { + switch (builtin.os.tag) { + .windows => std.os.windows.closesocket(s.handle) catch unreachable, + else => posix.close(s.handle), + } } pub const ReadError = os.ReadError; @@ -1839,156 +1897,38 @@ pub const Stream = struct { } }; -pub const StreamServer = struct { - /// Copied from `Options` on `init`. - kernel_backlog: u31, - reuse_address: bool, - reuse_port: bool, - force_nonblocking: bool, - - /// `undefined` until `listen` returns successfully. +pub const Server = struct { listen_address: Address, + stream: std.net.Stream, - sockfd: ?os.socket_t, - - pub const Options = struct { - /// How many connections the kernel will accept on the application's behalf. - /// If more than this many connections pool in the kernel, clients will start - /// seeing "Connection refused". - kernel_backlog: u31 = 128, - - /// Enable SO.REUSEADDR on the socket. - reuse_address: bool = false, - - /// Enable SO.REUSEPORT on the socket. - reuse_port: bool = false, - - /// Force non-blocking mode. - force_nonblocking: bool = false, + pub const Connection = struct { + stream: std.net.Stream, + address: Address, }; - /// After this call succeeds, resources have been acquired and must - /// be released with `deinit`. - pub fn init(options: Options) StreamServer { - return StreamServer{ - .sockfd = null, - .kernel_backlog = options.kernel_backlog, - .reuse_address = options.reuse_address, - .reuse_port = options.reuse_port, - .force_nonblocking = options.force_nonblocking, - .listen_address = undefined, - }; - } - - /// Release all resources. The `StreamServer` memory becomes `undefined`. - pub fn deinit(self: *StreamServer) void { - self.close(); - self.* = undefined; - } - - pub fn listen(self: *StreamServer, address: Address) !void { - const nonblock = 0; - const sock_flags = os.SOCK.STREAM | os.SOCK.CLOEXEC | nonblock; - var use_sock_flags: u32 = sock_flags; - if (self.force_nonblocking) use_sock_flags |= os.SOCK.NONBLOCK; - const proto = if (address.any.family == os.AF.UNIX) @as(u32, 0) else os.IPPROTO.TCP; - - const sockfd = try os.socket(address.any.family, use_sock_flags, proto); - self.sockfd = sockfd; - errdefer { - os.closeSocket(sockfd); - self.sockfd = null; - } - - if (self.reuse_address) { - try os.setsockopt( - sockfd, - os.SOL.SOCKET, - os.SO.REUSEADDR, - &mem.toBytes(@as(c_int, 1)), - ); - } - if (@hasDecl(os.SO, "REUSEPORT") and self.reuse_port) { - try os.setsockopt( - sockfd, - os.SOL.SOCKET, - os.SO.REUSEPORT, - &mem.toBytes(@as(c_int, 1)), - ); - } - - var socklen = address.getOsSockLen(); - try os.bind(sockfd, &address.any, socklen); - try os.listen(sockfd, self.kernel_backlog); - try os.getsockname(sockfd, &self.listen_address.any, &socklen); - } - - /// Stop listening. It is still necessary to call `deinit` after stopping listening. - /// Calling `deinit` will automatically call `close`. It is safe to call `close` when - /// not listening. - pub fn close(self: *StreamServer) void { - if (self.sockfd) |fd| { - os.closeSocket(fd); - self.sockfd = null; - self.listen_address = undefined; - } + pub fn deinit(s: *Server) void { + s.stream.close(); + s.* = undefined; } - pub const AcceptError = error{ - ConnectionAborted, + pub const AcceptError = posix.AcceptError; - /// The per-process limit on the number of open file descriptors has been reached. - ProcessFdQuotaExceeded, - - /// The system-wide limit on the total number of open files has been reached. - SystemFdQuotaExceeded, - - /// Not enough free memory. 
This often means that the memory allocation - /// is limited by the socket buffer limits, not by the system memory. - SystemResources, - - /// Socket is not listening for new connections. - SocketNotListening, - - ProtocolFailure, - - /// Socket is in non-blocking mode and there is no connection to accept. - WouldBlock, - - /// Firewall rules forbid connection. - BlockedByFirewall, - - FileDescriptorNotASocket, - - ConnectionResetByPeer, - - NetworkSubsystemFailed, - - OperationNotSupported, - } || os.UnexpectedError; - - pub const Connection = struct { - stream: Stream, - address: Address, - }; - - /// If this function succeeds, the returned `Connection` is a caller-managed resource. - pub fn accept(self: *StreamServer) AcceptError!Connection { + /// Blocks until a client connects to the server. The returned `Connection` has + /// an open stream. + pub fn accept(s: *Server) AcceptError!Connection { var accepted_addr: Address = undefined; - var adr_len: os.socklen_t = @sizeOf(Address); - const accept_result = os.accept(self.sockfd.?, &accepted_addr.any, &adr_len, os.SOCK.CLOEXEC); - - if (accept_result) |fd| { - return Connection{ - .stream = Stream{ .handle = fd }, - .address = accepted_addr, - }; - } else |err| { - return err; - } + var addr_len: posix.socklen_t = @sizeOf(Address); + const fd = try posix.accept(s.stream.handle, &accepted_addr.any, &addr_len, posix.SOCK.CLOEXEC); + return .{ + .stream = .{ .handle = fd }, + .address = accepted_addr, + }; } }; test { _ = @import("net/test.zig"); + _ = Server; + _ = Stream; + _ = Address; } diff --git a/lib/std/net/test.zig b/lib/std/net/test.zig index e359abb6d5cb..3e316c545643 100644 --- a/lib/std/net/test.zig +++ b/lib/std/net/test.zig @@ -181,11 +181,9 @@ test "listen on a port, send bytes, receive bytes" { // configured. 
    const localhost = try net.Address.parseIp("127.0.0.1", 0);
 
-    var server = net.StreamServer.init(.{});
+    var server = try localhost.listen(.{});
     defer server.deinit();
 
-    try server.listen(localhost);
-
     const S = struct {
         fn clientFn(server_address: net.Address) !void {
             const socket = try net.tcpConnectToAddress(server_address);
@@ -215,17 +213,11 @@ test "listen on an in use port" {
 
     const localhost = try net.Address.parseIp("127.0.0.1", 0);
 
-    var server1 = net.StreamServer.init(net.StreamServer.Options{
-        .reuse_port = true,
-    });
+    var server1 = try localhost.listen(.{ .reuse_port = true });
     defer server1.deinit();
 
-    try server1.listen(localhost);
-
-    var server2 = net.StreamServer.init(net.StreamServer.Options{
-        .reuse_port = true,
-    });
+    var server2 = try server1.listen_address.listen(.{ .reuse_port = true });
     defer server2.deinit();
-
-    try server2.listen(server1.listen_address);
 }
 
 fn testClientToHost(allocator: mem.Allocator, name: []const u8, port: u16) anyerror!void {
@@ -252,7 +244,7 @@ fn testClient(addr: net.Address) anyerror!void {
     try testing.expect(mem.eql(u8, msg, "hello from server\n"));
 }
 
-fn testServer(server: *net.StreamServer) anyerror!void {
+fn testServer(server: *net.Server) anyerror!void {
     if (builtin.os.tag == .wasi) return error.SkipZigTest;
 
     var client = try server.accept();
@@ -274,15 +266,14 @@ test "listen on a unix socket, send bytes, receive bytes" {
         }
     }
 
-    var server = net.StreamServer.init(.{});
-    defer server.deinit();
-
    const socket_path = try generateFileName("socket.unix");
    defer testing.allocator.free(socket_path);
 
    const socket_addr = try net.Address.initUnix(socket_path);
    defer std.fs.cwd().deleteFile(socket_path) catch {};
-    try server.listen(socket_addr);
+
+    var server = try socket_addr.listen(.{});
+    defer server.deinit();
 
     const S = struct {
         fn clientFn(path: []const u8) !void {
@@ -323,9 +314,8 @@ test "non-blocking tcp server" {
     }
 
     const localhost = try net.Address.parseIp("127.0.0.1", 0);
-    var server = net.StreamServer.init(.{ .force_nonblocking = true });
+    var server = try localhost.listen(.{ .force_nonblocking = true });
     defer server.deinit();
-    try server.listen(localhost);
 
     const accept_err = server.accept();
     try testing.expectError(error.WouldBlock, accept_err);
diff --git a/lib/std/os.zig b/lib/std/os.zig
index 6880878c45ef..87402e49a354 100644
--- a/lib/std/os.zig
+++ b/lib/std/os.zig
@@ -3598,14 +3598,6 @@ pub fn shutdown(sock: socket_t, how: ShutdownHow) ShutdownError!void {
     }
 }
 
-pub fn closeSocket(sock: socket_t) void {
-    if (builtin.os.tag == .windows) {
-        windows.closesocket(sock) catch unreachable;
-    } else {
-        close(sock);
-    }
-}
-
 pub const BindError = error{
     /// The address is protected, and the user is not the superuser.
/// For UNIX domain sockets: Search permission is denied on a component diff --git a/lib/std/os/linux/io_uring.zig b/lib/std/os/linux/io_uring.zig index dbde08c2c105..16c542714c28 100644 --- a/lib/std/os/linux/io_uring.zig +++ b/lib/std/os/linux/io_uring.zig @@ -4,6 +4,7 @@ const assert = std.debug.assert; const mem = std.mem; const net = std.net; const os = std.os; +const posix = std.posix; const linux = os.linux; const testing = std.testing; @@ -3730,8 +3731,8 @@ const SocketTestHarness = struct { client: os.socket_t, fn close(self: SocketTestHarness) void { - os.closeSocket(self.client); - os.closeSocket(self.listener); + posix.close(self.client); + posix.close(self.listener); } }; @@ -3739,7 +3740,7 @@ fn createSocketTestHarness(ring: *IO_Uring) !SocketTestHarness { // Create a TCP server socket var address = try net.Address.parseIp4("127.0.0.1", 0); const listener_socket = try createListenerSocket(&address); - errdefer os.closeSocket(listener_socket); + errdefer posix.close(listener_socket); // Submit 1 accept var accept_addr: os.sockaddr = undefined; @@ -3748,7 +3749,7 @@ fn createSocketTestHarness(ring: *IO_Uring) !SocketTestHarness { // Create a TCP client socket const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); - errdefer os.closeSocket(client); + errdefer posix.close(client); _ = try ring.connect(0xcccccccc, client, &address.any, address.getOsSockLen()); try testing.expectEqual(@as(u32, 2), try ring.submit()); @@ -3788,7 +3789,7 @@ fn createSocketTestHarness(ring: *IO_Uring) !SocketTestHarness { fn createListenerSocket(address: *net.Address) !os.socket_t { const kernel_backlog = 1; const listener_socket = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); - errdefer os.closeSocket(listener_socket); + errdefer posix.close(listener_socket); try os.setsockopt(listener_socket, os.SOL.SOCKET, os.SO.REUSEADDR, &mem.toBytes(@as(c_int, 1))); try os.bind(listener_socket, &address.any, address.getOsSockLen()); @@ -3813,7 +3814,7 @@ test "accept multishot" { var address = try net.Address.parseIp4("127.0.0.1", 0); const listener_socket = try createListenerSocket(&address); - defer os.closeSocket(listener_socket); + defer posix.close(listener_socket); // submit multishot accept operation var addr: os.sockaddr = undefined; @@ -3826,7 +3827,7 @@ test "accept multishot" { while (nr > 0) : (nr -= 1) { // connect client const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); - errdefer os.closeSocket(client); + errdefer posix.close(client); try os.connect(client, &address.any, address.getOsSockLen()); // test accept completion @@ -3836,7 +3837,7 @@ test "accept multishot" { try testing.expect(cqe.user_data == userdata); try testing.expect(cqe.flags & linux.IORING_CQE_F_MORE > 0); // more flag is set - os.closeSocket(client); + posix.close(client); } } @@ -3909,7 +3910,7 @@ test "accept_direct" { try ring.register_files(registered_fds[0..]); const listener_socket = try createListenerSocket(&address); - defer os.closeSocket(listener_socket); + defer posix.close(listener_socket); const accept_userdata: u64 = 0xaaaaaaaa; const read_userdata: u64 = 0xbbbbbbbb; @@ -3927,7 +3928,7 @@ test "accept_direct" { // connect const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); try os.connect(client, &address.any, address.getOsSockLen()); - defer os.closeSocket(client); + defer posix.close(client); // accept completion const cqe_accept = try ring.copy_cqe(); @@ -3961,7 +3962,7 @@ test 
"accept_direct" { // connect const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); try os.connect(client, &address.any, address.getOsSockLen()); - defer os.closeSocket(client); + defer posix.close(client); // completion with error const cqe_accept = try ring.copy_cqe(); try testing.expect(cqe_accept.user_data == accept_userdata); @@ -3989,7 +3990,7 @@ test "accept_multishot_direct" { try ring.register_files(registered_fds[0..]); const listener_socket = try createListenerSocket(&address); - defer os.closeSocket(listener_socket); + defer posix.close(listener_socket); const accept_userdata: u64 = 0xaaaaaaaa; @@ -4003,7 +4004,7 @@ test "accept_multishot_direct" { // connect const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); try os.connect(client, &address.any, address.getOsSockLen()); - defer os.closeSocket(client); + defer posix.close(client); // accept completion const cqe_accept = try ring.copy_cqe(); @@ -4018,7 +4019,7 @@ test "accept_multishot_direct" { // connect const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); try os.connect(client, &address.any, address.getOsSockLen()); - defer os.closeSocket(client); + defer posix.close(client); // completion with error const cqe_accept = try ring.copy_cqe(); try testing.expect(cqe_accept.user_data == accept_userdata); @@ -4092,7 +4093,7 @@ test "socket_direct/socket_direct_alloc/close_direct" { // use sockets from registered_fds in connect operation var address = try net.Address.parseIp4("127.0.0.1", 0); const listener_socket = try createListenerSocket(&address); - defer os.closeSocket(listener_socket); + defer posix.close(listener_socket); const accept_userdata: u64 = 0xaaaaaaaa; const connect_userdata: u64 = 0xbbbbbbbb; const close_userdata: u64 = 0xcccccccc; diff --git a/lib/std/os/test.zig b/lib/std/os/test.zig index 5fee5dcc7f16..0d9255641c70 100644 --- a/lib/std/os/test.zig +++ b/lib/std/os/test.zig @@ -817,7 +817,7 @@ test "shutdown socket" { error.SocketNotConnected => {}, else => |e| return e, }; - os.closeSocket(sock); + std.net.Stream.close(.{ .handle = sock }); } test "sigaction" { diff --git a/src/Package/Fetch.zig b/src/Package/Fetch.zig index ed3c6b099fd8..8fbaf79ea560 100644 --- a/src/Package/Fetch.zig +++ b/src/Package/Fetch.zig @@ -354,7 +354,8 @@ pub fn run(f: *Fetch) RunError!void { .{ path_or_url, @errorName(file_err), @errorName(uri_err) }, )); }; - var resource = try f.initResource(uri); + var server_header_buffer: [header_buffer_size]u8 = undefined; + var resource = try f.initResource(uri, &server_header_buffer); return runResource(f, uri.path, &resource, null); } }, @@ -415,7 +416,8 @@ pub fn run(f: *Fetch) RunError!void { f.location_tok, try eb.printString("invalid URI: {s}", .{@errorName(err)}), ); - var resource = try f.initResource(uri); + var server_header_buffer: [header_buffer_size]u8 = undefined; + var resource = try f.initResource(uri, &server_header_buffer); return runResource(f, uri.path, &resource, remote.hash); } @@ -876,7 +878,9 @@ const FileType = enum { } }; -fn initResource(f: *Fetch, uri: std.Uri) RunError!Resource { +const header_buffer_size = 16 * 1024; + +fn initResource(f: *Fetch, uri: std.Uri, server_header_buffer: []u8) RunError!Resource { const gpa = f.arena.child_allocator; const arena = f.arena.allocator(); const eb = &f.error_bundle; @@ -894,10 +898,9 @@ fn initResource(f: *Fetch, uri: std.Uri) RunError!Resource { if (ascii.eqlIgnoreCase(uri.scheme, "http") or ascii.eqlIgnoreCase(uri.scheme, 
"https")) { - var h = std.http.Headers{ .allocator = gpa }; - defer h.deinit(); - - var req = http_client.open(.GET, uri, h, .{}) catch |err| { + var req = http_client.open(.GET, uri, .{ + .server_header_buffer = server_header_buffer, + }) catch |err| { return f.fail(f.location_tok, try eb.printString( "unable to connect to server: {s}", .{@errorName(err)}, @@ -935,7 +938,7 @@ fn initResource(f: *Fetch, uri: std.Uri) RunError!Resource { transport_uri.scheme = uri.scheme["git+".len..]; var redirect_uri: []u8 = undefined; var session: git.Session = .{ .transport = http_client, .uri = transport_uri }; - session.discoverCapabilities(gpa, &redirect_uri) catch |err| switch (err) { + session.discoverCapabilities(gpa, &redirect_uri, server_header_buffer) catch |err| switch (err) { error.Redirected => { defer gpa.free(redirect_uri); return f.fail(f.location_tok, try eb.printString( @@ -961,6 +964,7 @@ fn initResource(f: *Fetch, uri: std.Uri) RunError!Resource { var ref_iterator = session.listRefs(gpa, .{ .ref_prefixes = &.{ want_ref, want_ref_head, want_ref_tag }, .include_peeled = true, + .server_header_buffer = server_header_buffer, }) catch |err| { return f.fail(f.location_tok, try eb.printString( "unable to list refs: {s}", @@ -1003,7 +1007,7 @@ fn initResource(f: *Fetch, uri: std.Uri) RunError!Resource { _ = std.fmt.bufPrint(&want_oid_buf, "{}", .{ std.fmt.fmtSliceHexLower(&want_oid), }) catch unreachable; - var fetch_stream = session.fetch(gpa, &.{&want_oid_buf}) catch |err| { + var fetch_stream = session.fetch(gpa, &.{&want_oid_buf}, server_header_buffer) catch |err| { return f.fail(f.location_tok, try eb.printString( "unable to create fetch stream: {s}", .{@errorName(err)}, @@ -1036,7 +1040,7 @@ fn unpackResource( .http_request => |req| ft: { // Content-Type takes first precedence. - const content_type = req.response.headers.getFirstValue("Content-Type") orelse + const content_type = req.response.content_type orelse return f.fail(f.location_tok, try eb.addString("missing 'Content-Type' header")); // Extract the MIME type, ignoring charset and boundary directives @@ -1069,7 +1073,7 @@ fn unpackResource( } // Next, the filename from 'content-disposition: attachment' takes precedence. 
- if (req.response.headers.getFirstValue("Content-Disposition")) |cd_header| { + if (req.response.content_disposition) |cd_header| { break :ft FileType.fromContentDisposition(cd_header) orelse { return f.fail(f.location_tok, try eb.printString( "unsupported Content-Disposition header value: '{s}' for Content-Type=application/octet-stream", @@ -1105,8 +1109,29 @@ fn unpackResource( var dcp = std.compress.gzip.decompressor(br.reader()); try unpackTarball(f, tmp_directory.handle, dcp.reader()); }, - .@"tar.xz" => try unpackTarballCompressed(f, tmp_directory.handle, resource, std.compress.xz), - .@"tar.zst" => try unpackTarballCompressed(f, tmp_directory.handle, resource, ZstdWrapper), + .@"tar.xz" => { + const gpa = f.arena.child_allocator; + const reader = resource.reader(); + var br = std.io.bufferedReaderSize(std.crypto.tls.max_ciphertext_record_len, reader); + var dcp = std.compress.xz.decompress(gpa, br.reader()) catch |err| { + return f.fail(f.location_tok, try eb.printString( + "unable to decompress tarball: {s}", + .{@errorName(err)}, + )); + }; + defer dcp.deinit(); + try unpackTarball(f, tmp_directory.handle, dcp.reader()); + }, + .@"tar.zst" => { + const window_size = std.compress.zstd.DecompressorOptions.default_window_buffer_len; + const window_buffer = try f.arena.allocator().create([window_size]u8); + const reader = resource.reader(); + var br = std.io.bufferedReaderSize(std.crypto.tls.max_ciphertext_record_len, reader); + var dcp = std.compress.zstd.decompressor(br.reader(), .{ + .window_buffer = window_buffer, + }); + return unpackTarball(f, tmp_directory.handle, dcp.reader()); + }, .git_pack => unpackGitPack(f, tmp_directory.handle, resource) catch |err| switch (err) { error.FetchFailed => return error.FetchFailed, error.OutOfMemory => return error.OutOfMemory, @@ -1118,40 +1143,6 @@ fn unpackResource( } } -// due to slight differences in the API of std.compress.(gzip|xz) and std.compress.zstd, zstd is -// wrapped for generic use in unpackTarballCompressed: see github.com/ziglang/zig/issues/14739 -const ZstdWrapper = struct { - fn DecompressType(comptime T: type) type { - return error{}!std.compress.zstd.DecompressStream(T, .{}); - } - - fn decompress(allocator: Allocator, reader: anytype) DecompressType(@TypeOf(reader)) { - return std.compress.zstd.decompressStream(allocator, reader); - } -}; - -fn unpackTarballCompressed( - f: *Fetch, - out_dir: fs.Dir, - resource: *Resource, - comptime Compression: type, -) RunError!void { - const gpa = f.arena.child_allocator; - const eb = &f.error_bundle; - const reader = resource.reader(); - var br = std.io.bufferedReaderSize(std.crypto.tls.max_ciphertext_record_len, reader); - - var decompress = Compression.decompress(gpa, br.reader()) catch |err| { - return f.fail(f.location_tok, try eb.printString( - "unable to decompress tarball: {s}", - .{@errorName(err)}, - )); - }; - defer decompress.deinit(); - - return unpackTarball(f, out_dir, decompress.reader()); -} - fn unpackTarball(f: *Fetch, out_dir: fs.Dir, reader: anytype) RunError!void { const eb = &f.error_bundle; const gpa = f.arena.child_allocator; diff --git a/src/Package/Fetch/git.zig b/src/Package/Fetch/git.zig index ee8f1ba543f4..dc0c844d1daf 100644 --- a/src/Package/Fetch/git.zig +++ b/src/Package/Fetch/git.zig @@ -494,8 +494,9 @@ pub const Session = struct { session: *Session, allocator: Allocator, redirect_uri: *[]u8, + http_headers_buffer: []u8, ) !void { - var capability_iterator = try session.getCapabilities(allocator, redirect_uri); + var capability_iterator = try 
session.getCapabilities(allocator, redirect_uri, http_headers_buffer); defer capability_iterator.deinit(); while (try capability_iterator.next()) |capability| { if (mem.eql(u8, capability.key, "agent")) { @@ -521,6 +522,7 @@ pub const Session = struct { session: Session, allocator: Allocator, redirect_uri: *[]u8, + http_headers_buffer: []u8, ) !CapabilityIterator { var info_refs_uri = session.uri; info_refs_uri.path = try std.fs.path.resolvePosix(allocator, &.{ "/", session.uri.path, "info/refs" }); @@ -528,12 +530,13 @@ pub const Session = struct { info_refs_uri.query = "service=git-upload-pack"; info_refs_uri.fragment = null; - var headers = std.http.Headers.init(allocator); - defer headers.deinit(); - try headers.append("Git-Protocol", "version=2"); - - var request = try session.transport.open(.GET, info_refs_uri, headers, .{ - .max_redirects = 3, + const max_redirects = 3; + var request = try session.transport.open(.GET, info_refs_uri, .{ + .redirect_behavior = @enumFromInt(max_redirects), + .server_header_buffer = http_headers_buffer, + .extra_headers = &.{ + .{ .name = "Git-Protocol", .value = "version=2" }, + }, }); errdefer request.deinit(); try request.send(.{}); @@ -541,7 +544,8 @@ pub const Session = struct { try request.wait(); if (request.response.status != .ok) return error.ProtocolError; - if (request.redirects_left < 3) { + const any_redirects_occurred = request.redirect_behavior.remaining() < max_redirects; + if (any_redirects_occurred) { if (!mem.endsWith(u8, request.uri.path, "/info/refs")) return error.UnparseableRedirect; var new_uri = request.uri; new_uri.path = new_uri.path[0 .. new_uri.path.len - "/info/refs".len]; @@ -620,6 +624,7 @@ pub const Session = struct { include_symrefs: bool = false, /// Whether to include the peeled object ID for returned tag refs. include_peeled: bool = false, + server_header_buffer: []u8, }; /// Returns an iterator over refs known to the server. @@ -630,11 +635,6 @@ pub const Session = struct { upload_pack_uri.query = null; upload_pack_uri.fragment = null; - var headers = std.http.Headers.init(allocator); - defer headers.deinit(); - try headers.append("Content-Type", "application/x-git-upload-pack-request"); - try headers.append("Git-Protocol", "version=2"); - var body = std.ArrayListUnmanaged(u8){}; defer body.deinit(allocator); const body_writer = body.writer(allocator); @@ -656,8 +656,13 @@ pub const Session = struct { } try Packet.write(.flush, body_writer); - var request = try session.transport.open(.POST, upload_pack_uri, headers, .{ - .handle_redirects = false, + var request = try session.transport.open(.POST, upload_pack_uri, .{ + .redirect_behavior = .unhandled, + .server_header_buffer = options.server_header_buffer, + .extra_headers = &.{ + .{ .name = "Content-Type", .value = "application/x-git-upload-pack-request" }, + .{ .name = "Git-Protocol", .value = "version=2" }, + }, }); errdefer request.deinit(); request.transfer_encoding = .{ .content_length = body.items.len }; @@ -721,18 +726,18 @@ pub const Session = struct { /// Fetches the given refs from the server. A shallow fetch (depth 1) is /// performed if the server supports it. 
- pub fn fetch(session: Session, allocator: Allocator, wants: []const []const u8) !FetchStream { + pub fn fetch( + session: Session, + allocator: Allocator, + wants: []const []const u8, + http_headers_buffer: []u8, + ) !FetchStream { var upload_pack_uri = session.uri; upload_pack_uri.path = try std.fs.path.resolvePosix(allocator, &.{ "/", session.uri.path, "git-upload-pack" }); defer allocator.free(upload_pack_uri.path); upload_pack_uri.query = null; upload_pack_uri.fragment = null; - var headers = std.http.Headers.init(allocator); - defer headers.deinit(); - try headers.append("Content-Type", "application/x-git-upload-pack-request"); - try headers.append("Git-Protocol", "version=2"); - var body = std.ArrayListUnmanaged(u8){}; defer body.deinit(allocator); const body_writer = body.writer(allocator); @@ -756,8 +761,13 @@ pub const Session = struct { try Packet.write(.{ .data = "done\n" }, body_writer); try Packet.write(.flush, body_writer); - var request = try session.transport.open(.POST, upload_pack_uri, headers, .{ - .handle_redirects = false, + var request = try session.transport.open(.POST, upload_pack_uri, .{ + .redirect_behavior = .not_allowed, + .server_header_buffer = http_headers_buffer, + .extra_headers = &.{ + .{ .name = "Content-Type", .value = "application/x-git-upload-pack-request" }, + .{ .name = "Git-Protocol", .value = "version=2" }, + }, }); errdefer request.deinit(); request.transfer_encoding = .{ .content_length = body.items.len }; diff --git a/src/main.zig b/src/main.zig index 1a9c264b7d15..584a34eeee5b 100644 --- a/src/main.zig +++ b/src/main.zig @@ -3322,13 +3322,13 @@ fn buildOutputType( .ip4 => |ip4_addr| { if (build_options.only_core_functionality) unreachable; - var server = std.net.StreamServer.init(.{ + const addr: std.net.Address = .{ .in = ip4_addr }; + + var server = try addr.listen(.{ .reuse_address = true, }); defer server.deinit(); - try server.listen(.{ .in = ip4_addr }); - const conn = try server.accept(); defer conn.stream.close(); @@ -5486,7 +5486,7 @@ fn cmdBuild(gpa: Allocator, arena: Allocator, args: []const []const u8) !void { job_queue.read_only = true; cleanup_build_dir = job_queue.global_cache.handle; } else { - try http_client.loadDefaultProxies(); + try http_client.initDefaultProxies(arena); } try job_queue.all_fetches.ensureUnusedCapacity(gpa, 1); @@ -7442,7 +7442,7 @@ fn cmdFetch( var http_client: std.http.Client = .{ .allocator = gpa }; defer http_client.deinit(); - try http_client.loadDefaultProxies(); + try http_client.initDefaultProxies(arena); var progress: std.Progress = .{ .dont_print_on_dumb = true }; const root_prog_node = progress.start("Fetch", 0); diff --git a/test/standalone.zig b/test/standalone.zig index 3740ddb80adb..e72deb38ae08 100644 --- a/test/standalone.zig +++ b/test/standalone.zig @@ -55,10 +55,6 @@ pub const simple_cases = [_]SimpleCase{ .os_filter = .windows, .link_libc = true, }, - .{ - .src_path = "test/standalone/http.zig", - .all_modes = true, - }, // Ensure the development tools are buildable. Alphabetically sorted. // No need to build `tools/spirv/grammar.zig`. 
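Taken together, the `std.net` changes above collapse the old three-step `StreamServer.init` / `listen` / `deinit` sequence into a single `Address.listen` call that returns a ready-to-accept `Server`. As a reference for reviewers, here is a minimal sketch of the new call shape, using only API surface introduced in this patch; the echo logic and buffer size are illustrative, not part of the change:

```zig
const std = @import("std");

pub fn main() !void {
    // Port 0 asks the kernel for an ephemeral port, as the updated tests do.
    const localhost = try std.net.Address.parseIp("127.0.0.1", 0);

    // Address.listen creates, binds, and listens on the socket in one call,
    // returning a net.Server that owns the listening stream.
    var server = try localhost.listen(.{ .reuse_address = true });
    defer server.deinit();

    std.debug.print("listening on {}\n", .{server.listen_address});

    // accept blocks until a client connects; the returned Connection holds
    // an open, caller-managed stream.
    const conn = try server.accept();
    defer conn.stream.close();

    // Echo one read back to the client.
    var buf: [1024]u8 = undefined;
    const n = try conn.stream.read(&buf);
    try conn.stream.writeAll(buf[0..n]);
}
```

The standalone HTTP test deleted below depended on the removed `net.StreamServer` and allocator-backed `std.http.Headers` APIs, which is why it goes away wholesale in this patch.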
diff --git a/test/standalone/http.zig b/test/standalone/http.zig deleted file mode 100644 index 5002d8910d0a..000000000000 --- a/test/standalone/http.zig +++ /dev/null @@ -1,700 +0,0 @@ -const std = @import("std"); - -const http = std.http; -const Server = http.Server; -const Client = http.Client; - -const mem = std.mem; -const testing = std.testing; - -pub const std_options = .{ - .http_disable_tls = true, -}; - -const max_header_size = 8192; - -var gpa_server = std.heap.GeneralPurposeAllocator(.{ .stack_trace_frames = 12 }){}; -var gpa_client = std.heap.GeneralPurposeAllocator(.{ .stack_trace_frames = 12 }){}; - -const salloc = gpa_server.allocator(); -const calloc = gpa_client.allocator(); - -var server: Server = undefined; - -fn handleRequest(res: *Server.Response) !void { - const log = std.log.scoped(.server); - - log.info("{} {s} {s}", .{ res.request.method, @tagName(res.request.version), res.request.target }); - - if (res.request.headers.contains("expect")) { - if (mem.eql(u8, res.request.headers.getFirstValue("expect").?, "100-continue")) { - res.status = .@"continue"; - try res.send(); - res.status = .ok; - } else { - res.status = .expectation_failed; - try res.send(); - return; - } - } - - const body = try res.reader().readAllAlloc(salloc, 8192); - defer salloc.free(body); - - if (res.request.headers.contains("connection")) { - try res.headers.append("connection", "keep-alive"); - } - - if (mem.startsWith(u8, res.request.target, "/get")) { - if (std.mem.indexOf(u8, res.request.target, "?chunked") != null) { - res.transfer_encoding = .chunked; - } else { - res.transfer_encoding = .{ .content_length = 14 }; - } - - try res.headers.append("content-type", "text/plain"); - - try res.send(); - if (res.request.method != .HEAD) { - try res.writeAll("Hello, "); - try res.writeAll("World!\n"); - try res.finish(); - } else { - try testing.expectEqual(res.writeAll("errors"), error.NotWriteable); - } - } else if (mem.startsWith(u8, res.request.target, "/large")) { - res.transfer_encoding = .{ .content_length = 14 * 1024 + 14 * 10 }; - - try res.send(); - - var i: u32 = 0; - while (i < 5) : (i += 1) { - try res.writeAll("Hello, World!\n"); - } - - try res.writeAll("Hello, World!\n" ** 1024); - - i = 0; - while (i < 5) : (i += 1) { - try res.writeAll("Hello, World!\n"); - } - - try res.finish(); - } else if (mem.startsWith(u8, res.request.target, "/echo-content")) { - try testing.expectEqualStrings("Hello, World!\n", body); - try testing.expectEqualStrings("text/plain", res.request.headers.getFirstValue("content-type").?); - - if (res.request.headers.contains("transfer-encoding")) { - try testing.expectEqualStrings("chunked", res.request.headers.getFirstValue("transfer-encoding").?); - res.transfer_encoding = .chunked; - } else { - res.transfer_encoding = .{ .content_length = 14 }; - try testing.expectEqualStrings("14", res.request.headers.getFirstValue("content-length").?); - } - - try res.send(); - try res.writeAll("Hello, "); - try res.writeAll("World!\n"); - try res.finish(); - } else if (mem.eql(u8, res.request.target, "/trailer")) { - res.transfer_encoding = .chunked; - - try res.send(); - try res.writeAll("Hello, "); - try res.writeAll("World!\n"); - // try res.finish(); - try res.connection.writeAll("0\r\nX-Checksum: aaaa\r\n\r\n"); - } else if (mem.eql(u8, res.request.target, "/redirect/1")) { - res.transfer_encoding = .chunked; - - res.status = .found; - try res.headers.append("location", "../../get"); - - try res.send(); - try res.writeAll("Hello, "); - try 
res.writeAll("Redirected!\n"); - try res.finish(); - } else if (mem.eql(u8, res.request.target, "/redirect/2")) { - res.transfer_encoding = .chunked; - - res.status = .found; - try res.headers.append("location", "/redirect/1"); - - try res.send(); - try res.writeAll("Hello, "); - try res.writeAll("Redirected!\n"); - try res.finish(); - } else if (mem.eql(u8, res.request.target, "/redirect/3")) { - res.transfer_encoding = .chunked; - - const location = try std.fmt.allocPrint(salloc, "http://127.0.0.1:{d}/redirect/2", .{server.socket.listen_address.getPort()}); - defer salloc.free(location); - - res.status = .found; - try res.headers.append("location", location); - - try res.send(); - try res.writeAll("Hello, "); - try res.writeAll("Redirected!\n"); - try res.finish(); - } else if (mem.eql(u8, res.request.target, "/redirect/4")) { - res.transfer_encoding = .chunked; - - res.status = .found; - try res.headers.append("location", "/redirect/3"); - - try res.send(); - try res.writeAll("Hello, "); - try res.writeAll("Redirected!\n"); - try res.finish(); - } else if (mem.eql(u8, res.request.target, "/redirect/invalid")) { - const invalid_port = try getUnusedTcpPort(); - const location = try std.fmt.allocPrint(salloc, "http://127.0.0.1:{d}", .{invalid_port}); - defer salloc.free(location); - - res.status = .found; - try res.headers.append("location", location); - try res.send(); - try res.finish(); - } else { - res.status = .not_found; - try res.send(); - } -} - -var handle_new_requests = true; - -fn runServer(srv: *Server) !void { - outer: while (handle_new_requests) { - var res = try srv.accept(.{ - .allocator = salloc, - .header_strategy = .{ .dynamic = max_header_size }, - }); - defer res.deinit(); - - while (res.reset() != .closing) { - res.wait() catch |err| switch (err) { - error.HttpHeadersInvalid => continue :outer, - error.EndOfStream => continue, - else => return err, - }; - - try handleRequest(&res); - } - } -} - -fn serverThread(srv: *Server) void { - defer srv.deinit(); - defer _ = gpa_server.deinit(); - - runServer(srv) catch |err| { - std.debug.print("server error: {}\n", .{err}); - - if (@errorReturnTrace()) |trace| { - std.debug.dumpStackTrace(trace.*); - } - - _ = gpa_server.deinit(); - std.os.exit(1); - }; -} - -fn killServer(addr: std.net.Address) void { - handle_new_requests = false; - - const conn = std.net.tcpConnectToAddress(addr) catch return; - conn.close(); -} - -fn getUnusedTcpPort() !u16 { - const addr = try std.net.Address.parseIp("127.0.0.1", 0); - var s = std.net.StreamServer.init(.{}); - defer s.deinit(); - try s.listen(addr); - return s.listen_address.in.getPort(); -} - -pub fn main() !void { - const log = std.log.scoped(.client); - - defer _ = gpa_client.deinit(); - - server = Server.init(.{ .reuse_address = true }); - - const addr = std.net.Address.parseIp("127.0.0.1", 0) catch unreachable; - try server.listen(addr); - - const port = server.socket.listen_address.getPort(); - - const server_thread = try std.Thread.spawn(.{}, serverThread, .{&server}); - - var client = Client{ .allocator = calloc }; - errdefer client.deinit(); - // defer client.deinit(); handled below - - try client.loadDefaultProxies(); - - { // read content-length response - var h = http.Headers{ .allocator = calloc }; - defer h.deinit(); - - const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/get", .{port}); - defer calloc.free(location); - const uri = try std.Uri.parse(location); - - log.info("{s}", .{location}); - var req = try client.open(.GET, uri, h, .{}); - defer 
req.deinit(); - - try req.send(.{}); - try req.wait(); - - const body = try req.reader().readAllAlloc(calloc, 8192); - defer calloc.free(body); - - try testing.expectEqualStrings("Hello, World!\n", body); - try testing.expectEqualStrings("text/plain", req.response.headers.getFirstValue("content-type").?); - } - - // connection has been kept alive - try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); - - { // read large content-length response - var h = http.Headers{ .allocator = calloc }; - defer h.deinit(); - - const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/large", .{port}); - defer calloc.free(location); - const uri = try std.Uri.parse(location); - - log.info("{s}", .{location}); - var req = try client.open(.GET, uri, h, .{}); - defer req.deinit(); - - try req.send(.{}); - try req.wait(); - - const body = try req.reader().readAllAlloc(calloc, 8192 * 1024); - defer calloc.free(body); - - try testing.expectEqual(@as(usize, 14 * 1024 + 14 * 10), body.len); - } - - // connection has been kept alive - try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); - - { // send head request and not read chunked - var h = http.Headers{ .allocator = calloc }; - defer h.deinit(); - - const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/get", .{port}); - defer calloc.free(location); - const uri = try std.Uri.parse(location); - - log.info("{s}", .{location}); - var req = try client.open(.HEAD, uri, h, .{}); - defer req.deinit(); - - try req.send(.{}); - try req.wait(); - - const body = try req.reader().readAllAlloc(calloc, 8192); - defer calloc.free(body); - - try testing.expectEqualStrings("", body); - try testing.expectEqualStrings("text/plain", req.response.headers.getFirstValue("content-type").?); - try testing.expectEqualStrings("14", req.response.headers.getFirstValue("content-length").?); - } - - // connection has been kept alive - try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); - - { // read chunked response - var h = http.Headers{ .allocator = calloc }; - defer h.deinit(); - - const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/get?chunked", .{port}); - defer calloc.free(location); - const uri = try std.Uri.parse(location); - - log.info("{s}", .{location}); - var req = try client.open(.GET, uri, h, .{}); - defer req.deinit(); - - try req.send(.{}); - try req.wait(); - - const body = try req.reader().readAllAlloc(calloc, 8192); - defer calloc.free(body); - - try testing.expectEqualStrings("Hello, World!\n", body); - try testing.expectEqualStrings("text/plain", req.response.headers.getFirstValue("content-type").?); - } - - // connection has been kept alive - try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); - - { // send head request and not read chunked - var h = http.Headers{ .allocator = calloc }; - defer h.deinit(); - - const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/get?chunked", .{port}); - defer calloc.free(location); - const uri = try std.Uri.parse(location); - - log.info("{s}", .{location}); - var req = try client.open(.HEAD, uri, h, .{}); - defer req.deinit(); - - try req.send(.{}); - try req.wait(); - - const body = try req.reader().readAllAlloc(calloc, 8192); - defer calloc.free(body); - - try testing.expectEqualStrings("", body); - try testing.expectEqualStrings("text/plain", req.response.headers.getFirstValue("content-type").?); - try 
testing.expectEqualStrings("chunked", req.response.headers.getFirstValue("transfer-encoding").?); - } - - // connection has been kept alive - try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); - - { // check trailing headers - var h = http.Headers{ .allocator = calloc }; - defer h.deinit(); - - const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/trailer", .{port}); - defer calloc.free(location); - const uri = try std.Uri.parse(location); - - log.info("{s}", .{location}); - var req = try client.open(.GET, uri, h, .{}); - defer req.deinit(); - - try req.send(.{}); - try req.wait(); - - const body = try req.reader().readAllAlloc(calloc, 8192); - defer calloc.free(body); - - try testing.expectEqualStrings("Hello, World!\n", body); - try testing.expectEqualStrings("aaaa", req.response.headers.getFirstValue("x-checksum").?); - } - - // connection has been kept alive - try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); - - { // send content-length request - var h = http.Headers{ .allocator = calloc }; - defer h.deinit(); - - try h.append("content-type", "text/plain"); - - const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/echo-content", .{port}); - defer calloc.free(location); - const uri = try std.Uri.parse(location); - - log.info("{s}", .{location}); - var req = try client.open(.POST, uri, h, .{}); - defer req.deinit(); - - req.transfer_encoding = .{ .content_length = 14 }; - - try req.send(.{}); - try req.writeAll("Hello, "); - try req.writeAll("World!\n"); - try req.finish(); - - try req.wait(); - - const body = try req.reader().readAllAlloc(calloc, 8192); - defer calloc.free(body); - - try testing.expectEqualStrings("Hello, World!\n", body); - } - - // connection has been kept alive - try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); - - { // read content-length response with connection close - var h = http.Headers{ .allocator = calloc }; - defer h.deinit(); - - try h.append("connection", "close"); - - const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/get", .{port}); - defer calloc.free(location); - const uri = try std.Uri.parse(location); - - log.info("{s}", .{location}); - var req = try client.open(.GET, uri, h, .{}); - defer req.deinit(); - - try req.send(.{}); - try req.wait(); - - const body = try req.reader().readAllAlloc(calloc, 8192); - defer calloc.free(body); - - try testing.expectEqualStrings("Hello, World!\n", body); - try testing.expectEqualStrings("text/plain", req.response.headers.getFirstValue("content-type").?); - } - - // connection has been closed - try testing.expect(client.connection_pool.free_len == 0); - - { // send chunked request - var h = http.Headers{ .allocator = calloc }; - defer h.deinit(); - - try h.append("content-type", "text/plain"); - - const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/echo-content", .{port}); - defer calloc.free(location); - const uri = try std.Uri.parse(location); - - log.info("{s}", .{location}); - var req = try client.open(.POST, uri, h, .{}); - defer req.deinit(); - - req.transfer_encoding = .chunked; - - try req.send(.{}); - try req.writeAll("Hello, "); - try req.writeAll("World!\n"); - try req.finish(); - - try req.wait(); - - const body = try req.reader().readAllAlloc(calloc, 8192); - defer calloc.free(body); - - try testing.expectEqualStrings("Hello, World!\n", body); - } - - // connection has been kept alive - try testing.expect(client.http_proxy 
!= null or client.connection_pool.free_len == 1); - - { // relative redirect - var h = http.Headers{ .allocator = calloc }; - defer h.deinit(); - - const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/redirect/1", .{port}); - defer calloc.free(location); - const uri = try std.Uri.parse(location); - - log.info("{s}", .{location}); - var req = try client.open(.GET, uri, h, .{}); - defer req.deinit(); - - try req.send(.{}); - try req.wait(); - - const body = try req.reader().readAllAlloc(calloc, 8192); - defer calloc.free(body); - - try testing.expectEqualStrings("Hello, World!\n", body); - } - - // connection has been kept alive - try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); - - { // redirect from root - var h = http.Headers{ .allocator = calloc }; - defer h.deinit(); - - const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/redirect/2", .{port}); - defer calloc.free(location); - const uri = try std.Uri.parse(location); - - log.info("{s}", .{location}); - var req = try client.open(.GET, uri, h, .{}); - defer req.deinit(); - - try req.send(.{}); - try req.wait(); - - const body = try req.reader().readAllAlloc(calloc, 8192); - defer calloc.free(body); - - try testing.expectEqualStrings("Hello, World!\n", body); - } - - // connection has been kept alive - try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); - - { // absolute redirect - var h = http.Headers{ .allocator = calloc }; - defer h.deinit(); - - const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/redirect/3", .{port}); - defer calloc.free(location); - const uri = try std.Uri.parse(location); - - log.info("{s}", .{location}); - var req = try client.open(.GET, uri, h, .{}); - defer req.deinit(); - - try req.send(.{}); - try req.wait(); - - const body = try req.reader().readAllAlloc(calloc, 8192); - defer calloc.free(body); - - try testing.expectEqualStrings("Hello, World!\n", body); - } - - // connection has been kept alive - try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); - - { // too many redirects - var h = http.Headers{ .allocator = calloc }; - defer h.deinit(); - - const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/redirect/4", .{port}); - defer calloc.free(location); - const uri = try std.Uri.parse(location); - - log.info("{s}", .{location}); - var req = try client.open(.GET, uri, h, .{}); - defer req.deinit(); - - try req.send(.{}); - req.wait() catch |err| switch (err) { - error.TooManyHttpRedirects => {}, - else => return err, - }; - } - - // connection has been kept alive - try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); - - { // check client without segfault by connection error after redirection - var h = http.Headers{ .allocator = calloc }; - defer h.deinit(); - - const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/redirect/invalid", .{port}); - defer calloc.free(location); - const uri = try std.Uri.parse(location); - - log.info("{s}", .{location}); - var req = try client.open(.GET, uri, h, .{}); - defer req.deinit(); - - try req.send(.{}); - const result = req.wait(); - - // a proxy without an upstream is likely to return a 5xx status. 
- if (client.http_proxy == null) { - try testing.expectError(error.ConnectionRefused, result); // expects not segfault but the regular error - } - } - - // connection has been kept alive - try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1); - - { // Client.fetch() - var h = http.Headers{ .allocator = calloc }; - defer h.deinit(); - - try h.append("content-type", "text/plain"); - - const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/echo-content#fetch", .{port}); - defer calloc.free(location); - - log.info("{s}", .{location}); - var res = try client.fetch(calloc, .{ - .location = .{ .url = location }, - .method = .POST, - .headers = h, - .payload = .{ .string = "Hello, World!\n" }, - }); - defer res.deinit(); - - try testing.expectEqualStrings("Hello, World!\n", res.body.?); - } - - { // expect: 100-continue - var h = http.Headers{ .allocator = calloc }; - defer h.deinit(); - - try h.append("expect", "100-continue"); - try h.append("content-type", "text/plain"); - - const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/echo-content#expect-100", .{port}); - defer calloc.free(location); - const uri = try std.Uri.parse(location); - - log.info("{s}", .{location}); - var req = try client.open(.POST, uri, h, .{}); - defer req.deinit(); - - req.transfer_encoding = .chunked; - - try req.send(.{}); - try req.writeAll("Hello, "); - try req.writeAll("World!\n"); - try req.finish(); - - try req.wait(); - try testing.expectEqual(http.Status.ok, req.response.status); - - const body = try req.reader().readAllAlloc(calloc, 8192); - defer calloc.free(body); - - try testing.expectEqualStrings("Hello, World!\n", body); - } - - { // expect: garbage - var h = http.Headers{ .allocator = calloc }; - defer h.deinit(); - - try h.append("content-type", "text/plain"); - try h.append("expect", "garbage"); - - const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/echo-content#expect-garbage", .{port}); - defer calloc.free(location); - const uri = try std.Uri.parse(location); - - log.info("{s}", .{location}); - var req = try client.open(.POST, uri, h, .{}); - defer req.deinit(); - - req.transfer_encoding = .chunked; - - try req.send(.{}); - try req.wait(); - try testing.expectEqual(http.Status.expectation_failed, req.response.status); - } - - { // issue 16282 *** This test leaves the client in an invalid state, it must be last *** - const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/get", .{port}); - defer calloc.free(location); - const uri = try std.Uri.parse(location); - - const total_connections = client.connection_pool.free_size + 64; - var requests = try calloc.alloc(http.Client.Request, total_connections); - defer calloc.free(requests); - - for (0..total_connections) |i| { - var req = try client.open(.GET, uri, .{ .allocator = calloc }, .{}); - req.response.parser.done = true; - req.connection.?.closing = false; - requests[i] = req; - } - - for (0..total_connections) |i| { - requests[i].deinit(); - } - - // free connections should be full now - try testing.expect(client.connection_pool.free_len == client.connection_pool.free_size); - } - - client.deinit(); - - killServer(server.socket.listen_address); - server_thread.join(); -}