aboutsummaryrefslogtreecommitdiff
path: root/src/http/async_socket.zig
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/http/async_socket.zig520
1 files changed, 520 insertions, 0 deletions
diff --git a/src/http/async_socket.zig b/src/http/async_socket.zig
new file mode 100644
index 000000000..83ee904aa
--- /dev/null
+++ b/src/http/async_socket.zig
@@ -0,0 +1,520 @@
+const boring = @import("boringssl");
+const std = @import("std");
+const AsyncIO = @import("io");
+const AsyncMessage = @import("./async_message.zig");
+const AsyncBIO = @import("./async_bio.zig");
+const Completion = AsyncIO.Completion;
+const AsyncSocket = @This();
+
+const Output = @import("../global.zig").Output;
+const NetworkThread = @import("../network_thread.zig");
+const Environment = @import("../global.zig").Environment;
+
+const extremely_verbose = @import("../http_client_async.zig").extremely_verbose;
+const SOCKET_FLAGS: u32 = @import("../http_client_async.zig").SOCKET_FLAGS;
+const getAllocator = @import("../http_client_async.zig").getAllocator;
+const OPEN_SOCKET_FLAGS: u32 = @import("../http_client_async.zig").OPEN_SOCKET_FLAGS;
+
+io: *AsyncIO = undefined,
+socket: std.os.socket_t = 0,
+head: *AsyncMessage = undefined,
+tail: *AsyncMessage = undefined,
+allocator: std.mem.Allocator,
+err: ?anyerror = null,
+queued: usize = 0,
+sent: usize = 0,
+send_frame: @Frame(AsyncSocket.send) = undefined,
+read_frame: @Frame(AsyncSocket.read) = undefined,
+connect_frame: @Frame(AsyncSocket.connectToAddress) = undefined,
+close_frame: @Frame(AsyncSocket.close) = undefined,
+
+read_context: []u8 = undefined,
+read_offset: u64 = 0,
+read_completion: AsyncIO.Completion = undefined,
+connect_completion: AsyncIO.Completion = undefined,
+close_completion: AsyncIO.Completion = undefined,
+
+const ConnectError = AsyncIO.ConnectError || std.os.SocketError || std.os.SetSockOptError || error{UnknownHostName};
+
+pub fn init(io: *AsyncIO, socket: std.os.socket_t, allocator: std.mem.Allocator) !AsyncSocket {
+ var head = AsyncMessage.get(allocator);
+
+ return AsyncSocket{ .io = io, .socket = socket, .head = head, .tail = head, .allocator = allocator };
+}
+
+fn on_connect(this: *AsyncSocket, _: *Completion, err: ConnectError!void) void {
+ err catch |resolved_err| {
+ this.err = resolved_err;
+ };
+
+ resume this.connect_frame;
+}
+
+fn connectToAddress(this: *AsyncSocket, address: std.net.Address) ConnectError!void {
+ const sockfd = AsyncIO.openSocket(address.any.family, OPEN_SOCKET_FLAGS | std.os.SOCK.STREAM, std.os.IPPROTO.TCP) catch |err| {
+ if (extremely_verbose) {
+ Output.prettyErrorln("openSocket error: {s}", .{@errorName(err)});
+ }
+
+ return error.ConnectionRefused;
+ };
+
+ this.io.connect(*AsyncSocket, this, on_connect, &this.connect_completion, sockfd, address);
+ suspend {
+ this.connect_frame = @frame().*;
+ }
+
+ if (this.err) |e| {
+ return @errSetCast(ConnectError, e);
+ }
+
+ this.socket = sockfd;
+ return;
+}
+
+fn on_close(this: *AsyncSocket, _: *Completion, _: AsyncIO.CloseError!void) void {
+ resume this.close_frame;
+}
+
+pub fn close(this: *AsyncSocket) void {
+ if (this.socket == 0) return;
+ this.io.close(*AsyncSocket, this, on_close, &this.close_completion, this.socket);
+ suspend {
+ this.close_frame = @frame().*;
+ }
+ this.socket = 0;
+}
+
+pub fn connect(this: *AsyncSocket, name: []const u8, port: u16) ConnectError!void {
+ this.socket = 0;
+ outer: while (true) {
+ // on macOS, getaddrinfo() is very slow
+ // If you send ~200 network requests, about 1.5s is spent on getaddrinfo()
+ // So, we cache this.
+ var address_list = NetworkThread.getAddressList(getAllocator(), name, port) catch |err| {
+ return @errSetCast(ConnectError, err);
+ };
+
+ const list = address_list.address_list;
+ if (list.addrs.len == 0) return error.ConnectionRefused;
+
+ try_cached_index: {
+ if (address_list.index) |i| {
+ const address = list.addrs[i];
+ if (address_list.invalidated) continue :outer;
+
+ this.connectToAddress(address) catch |err| {
+ if (err == error.ConnectionRefused) {
+ address_list.index = null;
+ break :try_cached_index;
+ }
+
+ address_list.invalidate();
+ continue :outer;
+ };
+ }
+ }
+
+ for (list.addrs) |address, i| {
+ if (address_list.invalidated) continue :outer;
+ this.connectToAddress(address) catch |err| {
+ if (err == error.ConnectionRefused) continue;
+ address_list.invalidate();
+ if (err == error.AddressNotAvailable or err == error.UnknownHostName) continue :outer;
+ return err;
+ };
+ address_list.index = @truncate(u32, i);
+ return;
+ }
+
+ if (address_list.invalidated) continue :outer;
+
+ address_list.invalidate();
+ return error.ConnectionRefused;
+ }
+}
+
+fn on_send(msg: *AsyncMessage, _: *Completion, result: SendError!usize) void {
+ var this = @ptrCast(*AsyncSocket, @alignCast(@alignOf(*AsyncSocket), msg.context));
+ const written = result catch |err| {
+ this.err = err;
+ resume this.send_frame;
+ return;
+ };
+
+ if (written == 0) {
+ resume this.send_frame;
+ return;
+ }
+
+ msg.sent += @truncate(u16, written);
+ const has_more = msg.used > msg.sent;
+ this.sent += written;
+
+ if (has_more) {
+ this.io.send(
+ *AsyncMessage,
+ msg,
+ on_send,
+ &msg.completion,
+ this.socket,
+ msg.slice(),
+ SOCKET_FLAGS,
+ );
+ } else {
+ msg.release();
+ }
+
+ // complete
+ if (this.queued <= this.sent) {
+ resume this.send_frame;
+ }
+}
+
+pub fn write(this: *AsyncSocket, buf: []const u8) usize {
+ this.tail.context = this;
+
+ const resp = this.tail.writeAll(buf);
+ this.queued += resp.written;
+
+ if (resp.overflow) {
+ var next = AsyncMessage.get(getAllocator());
+ this.tail.next = next;
+ this.tail = next;
+
+ return @as(usize, resp.written) + this.write(buf[resp.written..]);
+ }
+
+ return @as(usize, resp.written);
+}
+
+pub const SendError = AsyncIO.SendError;
+
+pub fn deinit(this: *AsyncSocket) void {
+ this.head.release();
+}
+
+pub fn send(this: *AsyncSocket) SendError!usize {
+ const original_sent = this.sent;
+ this.head.context = this;
+
+ this.io.send(
+ *AsyncMessage,
+ this.head,
+ on_send,
+ &this.head.completion,
+ this.socket,
+ this.head.slice(),
+ SOCKET_FLAGS,
+ );
+
+ var node = this.head;
+ while (node.next) |element| {
+ this.io.send(
+ *AsyncMessage,
+ element,
+ on_send,
+ &element.completion,
+ this.socket,
+ element.slice(),
+ SOCKET_FLAGS,
+ );
+ node = element.next orelse break;
+ }
+
+ suspend {
+ this.send_frame = @frame().*;
+ }
+
+ if (this.err) |err| {
+ this.err = null;
+ return @errSetCast(AsyncSocket.SendError, err);
+ }
+
+ return this.sent - original_sent;
+}
+
+pub const RecvError = AsyncIO.RecvError;
+
+const Reader = struct {
+ pub fn on_read(ctx: *AsyncSocket, _: *AsyncIO.Completion, result: RecvError!usize) void {
+ const len = result catch |err| {
+ ctx.err = err;
+ resume ctx.read_frame;
+ return;
+ };
+ ctx.read_offset += len;
+ resume ctx.read_frame;
+ }
+};
+
+pub fn read(
+ this: *AsyncSocket,
+ bytes: []u8,
+ offset: u64,
+) RecvError!u64 {
+ this.read_context = bytes;
+ this.read_offset = offset;
+ const original_read_offset = this.read_offset;
+
+ this.io.recv(
+ *AsyncSocket,
+ this,
+ Reader.on_read,
+ &this.read_completion,
+ this.socket,
+ bytes,
+ );
+
+ suspend {
+ this.read_frame = @frame().*;
+ }
+
+ if (this.err) |err| {
+ this.err = null;
+ return @errSetCast(RecvError, err);
+ }
+
+ return this.read_offset - original_read_offset;
+}
+
+pub const SSL = struct {
+ ssl: *boring.SSL = undefined,
+ ssl_loaded: bool = false,
+ socket: AsyncSocket,
+ handshake_complete: bool = false,
+ ssl_bio: ?*AsyncBIO = null,
+ read_bio: ?*AsyncMessage = null,
+ handshake_frame: @Frame(SSL.handshake) = undefined,
+ send_frame: @Frame(SSL.send) = undefined,
+ read_frame: @Frame(SSL.read) = undefined,
+ hostname: [std.fs.MAX_PATH_BYTES]u8 = undefined,
+ is_ssl: bool = false,
+
+ const SSLConnectError = ConnectError || HandshakeError;
+ const HandshakeError = error{OpenSSLError};
+
+ pub fn connect(this: *SSL, name: []const u8, port: u16) !void {
+ this.is_ssl = true;
+ try this.socket.connect(name, port);
+
+ this.handshake_complete = false;
+
+ var ssl = boring.initClient();
+ this.ssl = ssl;
+ this.ssl_loaded = true;
+ errdefer {
+ this.ssl_loaded = false;
+ this.ssl.deinit();
+ this.ssl = undefined;
+ }
+
+ {
+ std.mem.copy(u8, &this.hostname, name);
+ this.hostname[name.len] = 0;
+ var name_ = this.hostname[0..name.len :0];
+ ssl.setHostname(name_);
+ }
+
+ var bio = try AsyncBIO.init(this.socket.allocator);
+ bio.socket_fd = this.socket.socket;
+ this.ssl_bio = bio;
+
+ boring.SSL_set_bio(ssl, bio.bio, bio.bio);
+
+ this.read_bio = AsyncMessage.get(this.socket.allocator);
+ try this.handshake();
+ }
+
+ pub fn close(this: *SSL) void {
+ this.socket.close();
+ }
+
+ fn handshake(this: *SSL) HandshakeError!void {
+ while (!this.ssl.isInitFinished()) {
+ boring.ERR_clear_error();
+ this.ssl_bio.?.enqueueSend();
+ const handshake_result = boring.SSL_connect(this.ssl);
+ if (handshake_result == 0) {
+ Output.prettyErrorln("ssl accept error", .{});
+ Output.flush();
+ return error.OpenSSLError;
+ }
+ this.handshake_complete = handshake_result == 1 and this.ssl.isInitFinished();
+
+ if (!this.handshake_complete) {
+ // accept_result < 0
+ const e = boring.SSL_get_error(this.ssl, handshake_result);
+ if ((e == boring.SSL_ERROR_WANT_READ or e == boring.SSL_ERROR_WANT_WRITE)) {
+ this.ssl_bio.?.enqueueSend();
+ suspend {
+ this.handshake_frame = @frame().*;
+ this.ssl_bio.?.pushPendingFrame(&this.handshake_frame);
+ }
+
+ continue;
+ }
+
+ Output.prettyErrorln("ssl accept error = {}, return val was {}", .{ e, handshake_result });
+ Output.flush();
+ return error.OpenSSLError;
+ }
+ }
+ }
+
+ pub fn write(this: *SSL, buffer_: []const u8) usize {
+ var buffer = buffer_;
+ var read_bio = this.read_bio;
+ while (buffer.len > 0) {
+ const response = read_bio.?.writeAll(buffer);
+ buffer = buffer[response.written..];
+ if (response.overflow) {
+ read_bio = read_bio.?.next orelse brk: {
+ read_bio.?.next = AsyncMessage.get(this.socket.allocator);
+ break :brk read_bio.?.next.?;
+ };
+ }
+ }
+
+ return buffer_.len;
+ }
+
+ pub fn send(this: *SSL) !usize {
+ var bio_ = this.read_bio;
+ var len: usize = 0;
+ while (bio_) |bio| {
+ var slice = bio.slice();
+ len += this.ssl.write(slice) catch |err| {
+ switch (err) {
+ error.WantRead => {
+ suspend {
+ this.send_frame = @frame().*;
+ this.ssl_bio.?.pushPendingFrame(&this.send_frame);
+ }
+ continue;
+ },
+ error.WantWrite => {
+ this.ssl_bio.?.enqueueSend();
+
+ suspend {
+ this.send_frame = @frame().*;
+ this.ssl_bio.?.pushPendingFrame(&this.send_frame);
+ }
+ continue;
+ },
+ else => {},
+ }
+
+ if (comptime Environment.isDebug) {
+ Output.prettyErrorln("SSL error: {s} (buf: {s})\n URL:", .{
+ @errorName(err),
+ bio.slice(),
+ });
+ Output.flush();
+ }
+
+ return err;
+ };
+
+ bio_ = bio.next;
+ }
+ return len;
+ }
+
+ pub fn read(this: *SSL, buf_: []u8, offset: u64) !u64 {
+ var buf = buf_[offset..];
+ var len: usize = 0;
+ while (buf.len > 0) {
+ this.ssl_bio.?.read_buf_len = buf.len;
+ len = this.ssl.read(buf) catch |err| {
+ switch (err) {
+ error.WantWrite => {
+ this.ssl_bio.?.enqueueSend();
+
+ if (extremely_verbose) {
+ Output.prettyErrorln(
+ "error: {s}: \n Read Wait: {s}\n Send Wait: {s}",
+ .{
+ @errorName(err),
+ @tagName(this.ssl_bio.?.read_wait),
+ @tagName(this.ssl_bio.?.send_wait),
+ },
+ );
+ Output.flush();
+ }
+
+ suspend {
+ this.read_frame = @frame().*;
+ this.ssl_bio.?.pushPendingFrame(&this.read_frame);
+ }
+ continue;
+ },
+ error.WantRead => {
+ // this.ssl_bio.enqueueSend();
+
+ if (extremely_verbose) {
+ Output.prettyErrorln(
+ "error: {s}: \n Read Wait: {s}\n Send Wait: {s}",
+ .{
+ @errorName(err),
+ @tagName(this.ssl_bio.?.read_wait),
+ @tagName(this.ssl_bio.?.send_wait),
+ },
+ );
+ Output.flush();
+ }
+
+ suspend {
+ this.read_frame = @frame().*;
+ this.ssl_bio.?.pushPendingFrame(&this.read_frame);
+ }
+ continue;
+ },
+ else => return err,
+ }
+ unreachable;
+ };
+
+ break;
+ }
+
+ return len;
+ }
+
+ pub inline fn init(allocator: std.mem.Allocator, io: *AsyncIO) !SSL {
+ return SSL{
+ .socket = try AsyncSocket.init(io, 0, allocator),
+ };
+ }
+
+ pub fn deinit(this: *SSL) void {
+ this.socket.deinit();
+ if (!this.is_ssl) return;
+
+ if (this.ssl_bio) |bio| {
+ _ = boring.BIO_set_data(bio.bio, null);
+ bio.pending_frame = AsyncBIO.PendingFrame.init();
+ bio.socket_fd = 0;
+ bio.release();
+ this.ssl_bio = null;
+ }
+
+ if (this.ssl_loaded) {
+ this.ssl.deinit();
+ this.ssl_loaded = false;
+ }
+
+ this.handshake_complete = false;
+
+ if (this.read_bio) |bio| {
+ var next_ = bio.next;
+ while (next_) |next| {
+ next.release();
+ next_ = next.next;
+ }
+
+ bio.release();
+ this.read_bio = null;
+ }
+ }
+};