diff options
Diffstat (limited to '')
-rw-r--r-- | src/http/async_socket.zig | 520 |
1 files changed, 520 insertions, 0 deletions
diff --git a/src/http/async_socket.zig b/src/http/async_socket.zig new file mode 100644 index 000000000..83ee904aa --- /dev/null +++ b/src/http/async_socket.zig @@ -0,0 +1,520 @@ +const boring = @import("boringssl"); +const std = @import("std"); +const AsyncIO = @import("io"); +const AsyncMessage = @import("./async_message.zig"); +const AsyncBIO = @import("./async_bio.zig"); +const Completion = AsyncIO.Completion; +const AsyncSocket = @This(); + +const Output = @import("../global.zig").Output; +const NetworkThread = @import("../network_thread.zig"); +const Environment = @import("../global.zig").Environment; + +const extremely_verbose = @import("../http_client_async.zig").extremely_verbose; +const SOCKET_FLAGS: u32 = @import("../http_client_async.zig").SOCKET_FLAGS; +const getAllocator = @import("../http_client_async.zig").getAllocator; +const OPEN_SOCKET_FLAGS: u32 = @import("../http_client_async.zig").OPEN_SOCKET_FLAGS; + +io: *AsyncIO = undefined, +socket: std.os.socket_t = 0, +head: *AsyncMessage = undefined, +tail: *AsyncMessage = undefined, +allocator: std.mem.Allocator, +err: ?anyerror = null, +queued: usize = 0, +sent: usize = 0, +send_frame: @Frame(AsyncSocket.send) = undefined, +read_frame: @Frame(AsyncSocket.read) = undefined, +connect_frame: @Frame(AsyncSocket.connectToAddress) = undefined, +close_frame: @Frame(AsyncSocket.close) = undefined, + +read_context: []u8 = undefined, +read_offset: u64 = 0, +read_completion: AsyncIO.Completion = undefined, +connect_completion: AsyncIO.Completion = undefined, +close_completion: AsyncIO.Completion = undefined, + +const ConnectError = AsyncIO.ConnectError || std.os.SocketError || std.os.SetSockOptError || error{UnknownHostName}; + +pub fn init(io: *AsyncIO, socket: std.os.socket_t, allocator: std.mem.Allocator) !AsyncSocket { + var head = AsyncMessage.get(allocator); + + return AsyncSocket{ .io = io, .socket = socket, .head = head, .tail = head, .allocator = allocator }; +} + +fn on_connect(this: *AsyncSocket, _: *Completion, err: ConnectError!void) void { + err catch |resolved_err| { + this.err = resolved_err; + }; + + resume this.connect_frame; +} + +fn connectToAddress(this: *AsyncSocket, address: std.net.Address) ConnectError!void { + const sockfd = AsyncIO.openSocket(address.any.family, OPEN_SOCKET_FLAGS | std.os.SOCK.STREAM, std.os.IPPROTO.TCP) catch |err| { + if (extremely_verbose) { + Output.prettyErrorln("openSocket error: {s}", .{@errorName(err)}); + } + + return error.ConnectionRefused; + }; + + this.io.connect(*AsyncSocket, this, on_connect, &this.connect_completion, sockfd, address); + suspend { + this.connect_frame = @frame().*; + } + + if (this.err) |e| { + return @errSetCast(ConnectError, e); + } + + this.socket = sockfd; + return; +} + +fn on_close(this: *AsyncSocket, _: *Completion, _: AsyncIO.CloseError!void) void { + resume this.close_frame; +} + +pub fn close(this: *AsyncSocket) void { + if (this.socket == 0) return; + this.io.close(*AsyncSocket, this, on_close, &this.close_completion, this.socket); + suspend { + this.close_frame = @frame().*; + } + this.socket = 0; +} + +pub fn connect(this: *AsyncSocket, name: []const u8, port: u16) ConnectError!void { + this.socket = 0; + outer: while (true) { + // on macOS, getaddrinfo() is very slow + // If you send ~200 network requests, about 1.5s is spent on getaddrinfo() + // So, we cache this. + var address_list = NetworkThread.getAddressList(getAllocator(), name, port) catch |err| { + return @errSetCast(ConnectError, err); + }; + + const list = address_list.address_list; + if (list.addrs.len == 0) return error.ConnectionRefused; + + try_cached_index: { + if (address_list.index) |i| { + const address = list.addrs[i]; + if (address_list.invalidated) continue :outer; + + this.connectToAddress(address) catch |err| { + if (err == error.ConnectionRefused) { + address_list.index = null; + break :try_cached_index; + } + + address_list.invalidate(); + continue :outer; + }; + } + } + + for (list.addrs) |address, i| { + if (address_list.invalidated) continue :outer; + this.connectToAddress(address) catch |err| { + if (err == error.ConnectionRefused) continue; + address_list.invalidate(); + if (err == error.AddressNotAvailable or err == error.UnknownHostName) continue :outer; + return err; + }; + address_list.index = @truncate(u32, i); + return; + } + + if (address_list.invalidated) continue :outer; + + address_list.invalidate(); + return error.ConnectionRefused; + } +} + +fn on_send(msg: *AsyncMessage, _: *Completion, result: SendError!usize) void { + var this = @ptrCast(*AsyncSocket, @alignCast(@alignOf(*AsyncSocket), msg.context)); + const written = result catch |err| { + this.err = err; + resume this.send_frame; + return; + }; + + if (written == 0) { + resume this.send_frame; + return; + } + + msg.sent += @truncate(u16, written); + const has_more = msg.used > msg.sent; + this.sent += written; + + if (has_more) { + this.io.send( + *AsyncMessage, + msg, + on_send, + &msg.completion, + this.socket, + msg.slice(), + SOCKET_FLAGS, + ); + } else { + msg.release(); + } + + // complete + if (this.queued <= this.sent) { + resume this.send_frame; + } +} + +pub fn write(this: *AsyncSocket, buf: []const u8) usize { + this.tail.context = this; + + const resp = this.tail.writeAll(buf); + this.queued += resp.written; + + if (resp.overflow) { + var next = AsyncMessage.get(getAllocator()); + this.tail.next = next; + this.tail = next; + + return @as(usize, resp.written) + this.write(buf[resp.written..]); + } + + return @as(usize, resp.written); +} + +pub const SendError = AsyncIO.SendError; + +pub fn deinit(this: *AsyncSocket) void { + this.head.release(); +} + +pub fn send(this: *AsyncSocket) SendError!usize { + const original_sent = this.sent; + this.head.context = this; + + this.io.send( + *AsyncMessage, + this.head, + on_send, + &this.head.completion, + this.socket, + this.head.slice(), + SOCKET_FLAGS, + ); + + var node = this.head; + while (node.next) |element| { + this.io.send( + *AsyncMessage, + element, + on_send, + &element.completion, + this.socket, + element.slice(), + SOCKET_FLAGS, + ); + node = element.next orelse break; + } + + suspend { + this.send_frame = @frame().*; + } + + if (this.err) |err| { + this.err = null; + return @errSetCast(AsyncSocket.SendError, err); + } + + return this.sent - original_sent; +} + +pub const RecvError = AsyncIO.RecvError; + +const Reader = struct { + pub fn on_read(ctx: *AsyncSocket, _: *AsyncIO.Completion, result: RecvError!usize) void { + const len = result catch |err| { + ctx.err = err; + resume ctx.read_frame; + return; + }; + ctx.read_offset += len; + resume ctx.read_frame; + } +}; + +pub fn read( + this: *AsyncSocket, + bytes: []u8, + offset: u64, +) RecvError!u64 { + this.read_context = bytes; + this.read_offset = offset; + const original_read_offset = this.read_offset; + + this.io.recv( + *AsyncSocket, + this, + Reader.on_read, + &this.read_completion, + this.socket, + bytes, + ); + + suspend { + this.read_frame = @frame().*; + } + + if (this.err) |err| { + this.err = null; + return @errSetCast(RecvError, err); + } + + return this.read_offset - original_read_offset; +} + +pub const SSL = struct { + ssl: *boring.SSL = undefined, + ssl_loaded: bool = false, + socket: AsyncSocket, + handshake_complete: bool = false, + ssl_bio: ?*AsyncBIO = null, + read_bio: ?*AsyncMessage = null, + handshake_frame: @Frame(SSL.handshake) = undefined, + send_frame: @Frame(SSL.send) = undefined, + read_frame: @Frame(SSL.read) = undefined, + hostname: [std.fs.MAX_PATH_BYTES]u8 = undefined, + is_ssl: bool = false, + + const SSLConnectError = ConnectError || HandshakeError; + const HandshakeError = error{OpenSSLError}; + + pub fn connect(this: *SSL, name: []const u8, port: u16) !void { + this.is_ssl = true; + try this.socket.connect(name, port); + + this.handshake_complete = false; + + var ssl = boring.initClient(); + this.ssl = ssl; + this.ssl_loaded = true; + errdefer { + this.ssl_loaded = false; + this.ssl.deinit(); + this.ssl = undefined; + } + + { + std.mem.copy(u8, &this.hostname, name); + this.hostname[name.len] = 0; + var name_ = this.hostname[0..name.len :0]; + ssl.setHostname(name_); + } + + var bio = try AsyncBIO.init(this.socket.allocator); + bio.socket_fd = this.socket.socket; + this.ssl_bio = bio; + + boring.SSL_set_bio(ssl, bio.bio, bio.bio); + + this.read_bio = AsyncMessage.get(this.socket.allocator); + try this.handshake(); + } + + pub fn close(this: *SSL) void { + this.socket.close(); + } + + fn handshake(this: *SSL) HandshakeError!void { + while (!this.ssl.isInitFinished()) { + boring.ERR_clear_error(); + this.ssl_bio.?.enqueueSend(); + const handshake_result = boring.SSL_connect(this.ssl); + if (handshake_result == 0) { + Output.prettyErrorln("ssl accept error", .{}); + Output.flush(); + return error.OpenSSLError; + } + this.handshake_complete = handshake_result == 1 and this.ssl.isInitFinished(); + + if (!this.handshake_complete) { + // accept_result < 0 + const e = boring.SSL_get_error(this.ssl, handshake_result); + if ((e == boring.SSL_ERROR_WANT_READ or e == boring.SSL_ERROR_WANT_WRITE)) { + this.ssl_bio.?.enqueueSend(); + suspend { + this.handshake_frame = @frame().*; + this.ssl_bio.?.pushPendingFrame(&this.handshake_frame); + } + + continue; + } + + Output.prettyErrorln("ssl accept error = {}, return val was {}", .{ e, handshake_result }); + Output.flush(); + return error.OpenSSLError; + } + } + } + + pub fn write(this: *SSL, buffer_: []const u8) usize { + var buffer = buffer_; + var read_bio = this.read_bio; + while (buffer.len > 0) { + const response = read_bio.?.writeAll(buffer); + buffer = buffer[response.written..]; + if (response.overflow) { + read_bio = read_bio.?.next orelse brk: { + read_bio.?.next = AsyncMessage.get(this.socket.allocator); + break :brk read_bio.?.next.?; + }; + } + } + + return buffer_.len; + } + + pub fn send(this: *SSL) !usize { + var bio_ = this.read_bio; + var len: usize = 0; + while (bio_) |bio| { + var slice = bio.slice(); + len += this.ssl.write(slice) catch |err| { + switch (err) { + error.WantRead => { + suspend { + this.send_frame = @frame().*; + this.ssl_bio.?.pushPendingFrame(&this.send_frame); + } + continue; + }, + error.WantWrite => { + this.ssl_bio.?.enqueueSend(); + + suspend { + this.send_frame = @frame().*; + this.ssl_bio.?.pushPendingFrame(&this.send_frame); + } + continue; + }, + else => {}, + } + + if (comptime Environment.isDebug) { + Output.prettyErrorln("SSL error: {s} (buf: {s})\n URL:", .{ + @errorName(err), + bio.slice(), + }); + Output.flush(); + } + + return err; + }; + + bio_ = bio.next; + } + return len; + } + + pub fn read(this: *SSL, buf_: []u8, offset: u64) !u64 { + var buf = buf_[offset..]; + var len: usize = 0; + while (buf.len > 0) { + this.ssl_bio.?.read_buf_len = buf.len; + len = this.ssl.read(buf) catch |err| { + switch (err) { + error.WantWrite => { + this.ssl_bio.?.enqueueSend(); + + if (extremely_verbose) { + Output.prettyErrorln( + "error: {s}: \n Read Wait: {s}\n Send Wait: {s}", + .{ + @errorName(err), + @tagName(this.ssl_bio.?.read_wait), + @tagName(this.ssl_bio.?.send_wait), + }, + ); + Output.flush(); + } + + suspend { + this.read_frame = @frame().*; + this.ssl_bio.?.pushPendingFrame(&this.read_frame); + } + continue; + }, + error.WantRead => { + // this.ssl_bio.enqueueSend(); + + if (extremely_verbose) { + Output.prettyErrorln( + "error: {s}: \n Read Wait: {s}\n Send Wait: {s}", + .{ + @errorName(err), + @tagName(this.ssl_bio.?.read_wait), + @tagName(this.ssl_bio.?.send_wait), + }, + ); + Output.flush(); + } + + suspend { + this.read_frame = @frame().*; + this.ssl_bio.?.pushPendingFrame(&this.read_frame); + } + continue; + }, + else => return err, + } + unreachable; + }; + + break; + } + + return len; + } + + pub inline fn init(allocator: std.mem.Allocator, io: *AsyncIO) !SSL { + return SSL{ + .socket = try AsyncSocket.init(io, 0, allocator), + }; + } + + pub fn deinit(this: *SSL) void { + this.socket.deinit(); + if (!this.is_ssl) return; + + if (this.ssl_bio) |bio| { + _ = boring.BIO_set_data(bio.bio, null); + bio.pending_frame = AsyncBIO.PendingFrame.init(); + bio.socket_fd = 0; + bio.release(); + this.ssl_bio = null; + } + + if (this.ssl_loaded) { + this.ssl.deinit(); + this.ssl_loaded = false; + } + + this.handshake_complete = false; + + if (this.read_bio) |bio| { + var next_ = bio.next; + while (next_) |next| { + next.release(); + next_ = next.next; + } + + bio.release(); + this.read_bio = null; + } + } +}; |