Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 62 additions & 8 deletions lib/std/net.zig
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,13 @@ const net = @This();
const mem = std.mem;
const os = std.os;
const fs = std.fs;
const io = std.io;

pub const has_unix_sockets = @hasDecl(os, "sockaddr_un");
// Windows 10 added support for unix sockets in build 17063, redstone 4 is the
// first release to support them.
pub const has_unix_sockets = @hasDecl(os, "sockaddr_un") and
(builtin.os.tag != .windows or
std.Target.current.os.version_range.windows.isAtLeast(.win10_rs4) orelse false);

pub const Address = extern union {
any: os.sockaddr,
Expand Down Expand Up @@ -596,7 +601,7 @@ pub const Ip6Address = extern struct {
}
};

pub fn connectUnixSocket(path: []const u8) !fs.File {
pub fn connectUnixSocket(path: []const u8) !Stream {
const opt_non_block = if (std.io.is_async) os.SOCK_NONBLOCK else 0;
const sockfd = try os.socket(
os.AF_UNIX,
Expand All @@ -614,7 +619,7 @@ pub fn connectUnixSocket(path: []const u8) !fs.File {
try os.connect(sockfd, &addr.any, addr.getOsSockLen());
}

return fs.File{
return Stream{
.handle = sockfd,
};
}
Expand Down Expand Up @@ -648,7 +653,7 @@ pub const AddressList = struct {
};

/// All memory allocated with `allocator` will be freed before this function returns.
pub fn tcpConnectToHost(allocator: *mem.Allocator, name: []const u8, port: u16) !fs.File {
pub fn tcpConnectToHost(allocator: *mem.Allocator, name: []const u8, port: u16) !Stream {
const list = try getAddressList(allocator, name, port);
defer list.deinit();

Expand All @@ -665,7 +670,7 @@ pub fn tcpConnectToHost(allocator: *mem.Allocator, name: []const u8, port: u16)
return std.os.ConnectError.ConnectionRefused;
}

pub fn tcpConnectToAddress(address: Address) !fs.File {
pub fn tcpConnectToAddress(address: Address) !Stream {
const nonblock = if (std.io.is_async) os.SOCK_NONBLOCK else 0;
const sock_flags = os.SOCK_STREAM | nonblock |
(if (builtin.os.tag == .windows) 0 else os.SOCK_CLOEXEC);
Expand All @@ -679,7 +684,7 @@ pub fn tcpConnectToAddress(address: Address) !fs.File {
try os.connect(sockfd, &address.any, address.getOsSockLen());
}

return fs.File{ .handle = sockfd };
return Stream{ .handle = sockfd };
}

/// Call `AddressList.deinit` on the result.
Expand Down Expand Up @@ -1580,6 +1585,55 @@ fn dnsParseCallback(ctx: dpc_ctx, rr: u8, data: []const u8, packet: []const u8)
}
}

pub const Stream = struct {
// Underlying socket descriptor.
// Note that on some platforms this may not be interchangeable with a
// regular files descriptor.
handle: os.socket_t,

pub fn close(self: Stream) void {
os.closeSocket(self.handle);
}

pub const ReadError = os.ReadError;
pub const WriteError = os.WriteError;

pub const Reader = io.Reader(Stream, ReadError, read);
pub const Writer = io.Writer(Stream, WriteError, write);

pub fn reader(self: Stream) Reader {
return .{ .context = self };
}

pub fn writer(self: Stream) Writer {
return .{ .context = self };
}

pub fn read(self: Stream, buffer: []u8) ReadError!usize {
if (std.Target.current.os.tag == .windows) {
return os.windows.ReadFile(self.handle, buffer, null, io.default_mode);
}

if (std.io.is_async) {
return std.event.Loop.instance.?.read(self.handle, buffer, false);
} else {
return os.read(self.handle, buffer);
}
}

pub fn write(self: Stream, buffer: []const u8) WriteError!usize {
if (std.Target.current.os.tag == .windows) {
return os.windows.WriteFile(self.handle, buffer, null, io.default_mode);
}

if (std.io.is_async) {
return std.event.Loop.instance.?.write(self.handle, buffer, false);
} else {
return os.write(self.handle, buffer);
}
}
};

pub const StreamServer = struct {
/// Copied from `Options` on `init`.
kernel_backlog: u31,
Expand Down Expand Up @@ -1686,7 +1740,7 @@ pub const StreamServer = struct {
} || os.UnexpectedError;

pub const Connection = struct {
file: fs.File,
stream: Stream,
address: Address,
};

Expand All @@ -1705,7 +1759,7 @@ pub const StreamServer = struct {

if (accept_result) |fd| {
return Connection{
.file = fs.File{ .handle = fd },
.stream = Stream{ .handle = fd },
.address = accepted_addr,
};
} else |err| switch (err) {
Expand Down
50 changes: 47 additions & 3 deletions lib/std/net/test.zig
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ test "listen on a port, send bytes, receive bytes" {

// Try only the IPv4 variant as some CI builders have no IPv6 localhost
// configured.
const localhost = try net.Address.parseIp("127.0.0.1", 8080);
const localhost = try net.Address.parseIp("127.0.0.1", 0);

var server = net.StreamServer.init(.{});
defer server.deinit();
Expand All @@ -165,8 +165,9 @@ test "listen on a port, send bytes, receive bytes" {
defer t.wait();

var client = try server.accept();
defer client.stream.close();
var buf: [16]u8 = undefined;
const n = try client.file.reader().read(&buf);
const n = try client.stream.reader().read(&buf);

testing.expectEqual(@as(usize, 12), n);
testing.expectEqualSlices(u8, "Hello world!", buf[0..n]);
Expand Down Expand Up @@ -249,6 +250,49 @@ fn testServer(server: *net.StreamServer) anyerror!void {

var client = try server.accept();

const stream = client.file.writer();
const stream = client.stream.writer();
try stream.print("hello from server\n", .{});
}

test "listen on a unix socket, send bytes, receive bytes" {
if (builtin.single_threaded) return error.SkipZigTest;
if (!net.has_unix_sockets) return error.SkipZigTest;

if (std.builtin.os.tag == .windows) {
_ = try std.os.windows.WSAStartup(2, 2);
}
defer {
if (std.builtin.os.tag == .windows) {
std.os.windows.WSACleanup() catch unreachable;
}
}

var server = net.StreamServer.init(.{});
defer server.deinit();

const socket_path = "socket.unix";

var socket_addr = try net.Address.initUnix(socket_path);
defer std.fs.cwd().deleteFile(socket_path) catch {};
try server.listen(socket_addr);

const S = struct {
fn clientFn(_: void) !void {
const socket = try net.connectUnixSocket(socket_path);
defer socket.close();

_ = try socket.writer().writeAll("Hello world!");
}
};

const t = try std.Thread.spawn({}, S.clientFn);
defer t.wait();

var client = try server.accept();
defer client.stream.close();
var buf: [16]u8 = undefined;
const n = try client.stream.reader().read(&buf);

testing.expectEqual(@as(usize, 12), n);
testing.expectEqualSlices(u8, "Hello world!", buf[0..n]);
}