diff options
-rw-r--r-- | src/http_client_async.zig | 133 | ||||
-rw-r--r-- | src/thread_pool.zig | 10 |
2 files changed, 126 insertions, 17 deletions
diff --git a/src/http_client_async.zig b/src/http_client_async.zig index b1d50c36c..d8a25a424 100644 --- a/src/http_client_async.zig +++ b/src/http_client_async.zig @@ -67,7 +67,7 @@ else pub const OPEN_SOCKET_FLAGS = SOCK.CLOEXEC; -pub const extremely_verbose = false; +pub const extremely_verbose = Environment.isDebug; fn writeRequest( comptime Writer: type, @@ -113,6 +113,7 @@ socket: AsyncSocket.SSL = undefined, socket_loaded: bool = false, gzip_elapsed: u64 = 0, stage: Stage = Stage.pending, +received_keep_alive: bool = false, /// Some HTTP servers (such as npm) report Last-Modified times but ignore If-Modified-Since. /// This is a workaround for that. @@ -135,7 +136,9 @@ pub fn init( .url = url, .header_entries = header_entries, .header_buf = header_buf, - .socket = undefined, + .socket = AsyncSocket.SSL{ + .socket = undefined, + }, }; } @@ -246,6 +249,62 @@ pub const HTTPChannelContext = struct { } }; +// This causes segfaults when resume connect() +pub const KeepAlive = struct { + const limit = 2; + pub const disabled = true; + fds: [limit]u32 = undefined, + hosts: [limit]u64 = undefined, + ports: [limit]u16 = undefined, + used: u8 = 0, + + pub var instance = KeepAlive{}; + + pub fn append(this: *KeepAlive, host: []const u8, port: u16, fd: os.socket_t) bool { + if (disabled) return false; + if (this.used >= limit or fd > std.math.maxInt(u32)) return false; + + const i = this.used; + const hash = std.hash.Wyhash.hash(0, host); + + this.fds[i] = @truncate(u32, @intCast(u64, fd)); + this.hosts[i] = hash; + this.ports[i] = port; + this.used += 1; + return true; + } + pub fn find(this: *KeepAlive, host: []const u8, port: u16) ?os.socket_t { + if (disabled) return null; + + if (this.used == 0) { + return null; + } + + const hash = std.hash.Wyhash.hash(0, host); + const list = this.hosts[0..this.used]; + for (list) |host_hash, i| { + if (host_hash == hash and this.ports[i] == port) { + const fd = this.fds[i]; + const last = this.used - 1; + + if (i > last) { + const end_host = this.hosts[last]; + const end_fd = this.fds[last]; + const end_port = this.ports[last]; + this.hosts[i] = end_host; + this.fds[i] = end_fd; + this.ports[i] = end_port; + } + this.used -= 1; + + return @intCast(os.socket_t, fd); + } + } + + return null; + } +}; + pub const AsyncHTTP = struct { request: ?picohttp.Request = null, response: ?picohttp.Response = null, @@ -319,6 +378,13 @@ pub const AsyncHTTP = struct { return this; } + fn reset(this: *AsyncHTTP) !void { + const timeout = this.timeout; + this.client = try HTTPClient.init(this.allocator, this.method, this.client.url, this.client.header_entries, this.client.header_buf); + this.client.timeout = timeout; + this.timeout = timeout; + } + pub fn schedule(this: *AsyncHTTP, _: std.mem.Allocator, batch: *ThreadPool.Batch) void { std.debug.assert(NetworkThread.global_loaded.load(.Monotonic) == 1); this.state.store(.scheduled, .Monotonic); @@ -381,6 +447,10 @@ pub const AsyncHTTP = struct { }; pub fn do(sender: *HTTPSender, this: *AsyncHTTP) void { + defer { + NetworkThread.global.pool.schedule(.{ .head = &sender.finisher, .tail = &sender.finisher, .len = 1 }); + } + outer: { this.err = null; this.state.store(.sending, .Monotonic); @@ -394,6 +464,7 @@ pub const AsyncHTTP = struct { if (this.max_retry_count > this.retries_count) { this.retries_count += 1; this.response_buffer.reset(); + NetworkThread.global.pool.schedule(ThreadPool.Batch.from(&this.task)); return; } @@ -408,7 +479,6 @@ pub const AsyncHTTP = struct { if (this.callback) |callback| { callback(this); } - NetworkThread.global.pool.schedule(.{ .head = &sender.finisher, .tail = &sender.finisher, .len = 1 }); } }; @@ -534,20 +604,20 @@ pub fn sendAsync(this: *HTTPClient, body: []const u8, body_out_str: *MutableStri return async this.send(body, body_out_str); } -pub fn send(this: *HTTPClient, body: []const u8, body_out_str: *MutableString) !picohttp.Response { - defer { - if (this.socket_loaded) { - this.socket_loaded = false; - this.socket.deinit(); - } +fn maybeClearSocket(this: *HTTPClient) void { + if (this.socket_loaded) { + this.socket_loaded = false; + + this.socket.deinit(); } +} + +pub fn send(this: *HTTPClient, body: []const u8, body_out_str: *MutableString) !picohttp.Response { + defer this.maybeClearSocket(); // this prevents stack overflow redirect: while (this.remaining_redirect_count >= -1) { - if (this.socket_loaded) { - this.socket_loaded = false; - this.socket.deinit(); - } + this.maybeClearSocket(); _ = AsyncHTTP.active_requests_count.fetchAdd(1, .Monotonic); defer { @@ -596,7 +666,7 @@ pub fn sendHTTP(this: *HTTPClient, body: []const u8, body_out_str: *MutableStrin var socket = &this.socket.socket; try this.connect(*AsyncSocket, socket); this.stage = Stage.request; - defer this.socket.close(); + defer this.closeSocket(); var request = buildRequest(this, body.len); if (this.verbose) { @@ -673,6 +743,10 @@ pub fn processResponse(this: *HTTPClient, comptime report_progress: bool, compti var location: string = ""; var pretend_its_304 = false; + var maybe_keepalive = false; + errdefer { + maybe_keepalive = false; + } for (response.headers) |header| { switch (hashHeaderName(header.name)) { @@ -707,6 +781,13 @@ pub fn processResponse(this: *HTTPClient, comptime report_progress: bool, compti location_header_hash => { location = header.value; }, + hashHeaderName("Connection") => { + if (response.status_code >= 200 and response.status_code <= 299 and !KeepAlive.disabled) { + if (strings.eqlComptime(header.value, "keep-alive")) { + maybe_keepalive = true; + } + } + }, hashHeaderName("Last-Modified") => { if (this.force_last_modified and response.status_code > 199 and response.status_code < 300 and this.if_modified_since.len > 0) { if (strings.eql(this.if_modified_since, header.value)) { @@ -774,6 +855,7 @@ pub fn processResponse(this: *HTTPClient, comptime report_progress: bool, compti if (response.status_code == 304) break :body_getter; if (transfer_encoding == Encoding.chunked) { + maybe_keepalive = false; var decoder = std.mem.zeroes(picohttp.phr_chunked_decoder); var buffer_: *MutableString = body_out_str; @@ -1020,9 +1102,30 @@ pub fn processResponse(this: *HTTPClient, comptime report_progress: bool, compti this.progress_node.?.context.maybeRefresh(); } + if (maybe_keepalive and response.status_code >= 200 and response.status_code < 300) { + this.received_keep_alive = true; + } + return response; } +pub fn closeSocket(this: *HTTPClient) void { + if (this.received_keep_alive) { + this.received_keep_alive = false; + if (this.url.hostname.len > 0 and this.socket.socket.socket > 0) { + if (!this.socket.connect_frame.wait and + (!this.socket.ssl_bio_loaded or + (this.socket.ssl_bio.pending_sends == 0 and this.socket.ssl_bio.pending_reads == 0))) + { + if (KeepAlive.instance.append(this.url.hostname, this.url.getPortAuto(), this.socket.socket.socket)) { + this.socket.socket.socket = 0; + } + } + } + } + this.socket.close(); +} + pub fn sendHTTPS(this: *HTTPClient, body_str: []const u8, body_out_str: *MutableString) !picohttp.Response { this.socket = try AsyncSocket.SSL.init(default_allocator, &AsyncIO.global); this.socket_loaded = true; @@ -1031,7 +1134,7 @@ pub fn sendHTTPS(this: *HTTPClient, body_str: []const u8, body_out_str: *Mutable this.stage = Stage.connect; try this.connect(*AsyncSocket.SSL, socket); this.stage = Stage.request; - defer this.socket.close(); + defer this.closeSocket(); var request = buildRequest(this, body_str.len); if (this.verbose) { diff --git a/src/thread_pool.zig b/src/thread_pool.zig index c398151ca..dc3ff65db 100644 --- a/src/thread_pool.zig +++ b/src/thread_pool.zig @@ -269,8 +269,14 @@ fn _wait(self: *ThreadPool, _is_waking: bool, comptime sleep_on_idle: bool) erro const end_count = HTTP.AsyncHTTP.active_requests_count.loadUnchecked(); if (end_count > 0) { - while (HTTP.AsyncHTTP.active_requests_count.loadUnchecked() > HTTP.AsyncHTTP.max_simultaneous_requests) { - io.run_for_ns(std.time.ns_per_ms) catch {}; + if (comptime sleep_on_idle) { + idle_network_ticks = 0; + } + + var remaining_ticks: i32 = 5; + + while (remaining_ticks > 0 and HTTP.AsyncHTTP.active_requests_count.loadUnchecked() > HTTP.AsyncHTTP.max_simultaneous_requests) : (remaining_ticks -= 1) { + io.run_for_ns(std.time.ns_per_ms * 2) catch {}; io.tick() catch {}; } } |