diff options
author | 2023-02-27 21:10:03 -0300 | |
---|---|---|
committer | 2023-02-27 16:10:03 -0800 | |
commit | 0afb1693d370c0dd1340b3eb7659d31ef3fc94a3 (patch) | |
tree | 1efa66eafd12df9b00730fe3270888f4657b74f7 /src/http_client_async.zig | |
parent | 7a4ac03338515b5153e7341dc827919d2f6b5245 (diff) | |
download | bun-0afb1693d370c0dd1340b3eb7659d31ef3fc94a3.tar.gz bun-0afb1693d370c0dd1340b3eb7659d31ef3fc94a3.tar.zst bun-0afb1693d370c0dd1340b3eb7659d31ef3fc94a3.zip |
fix(fetch.signal) capture socket and shutdown on abort signal (#2143)
* capture socket and shutdown on abort signal
* queue shutdown tasks
* little cleanup
* change shutdown strategy
* use fetchSwapRemove on fetch shutdown
* use fetchSwapRemove on fetch shutdown
* fix formatting, remove unused property
Diffstat (limited to 'src/http_client_async.zig')
-rw-r--r-- | src/http_client_async.zig | 107 |
1 files changed, 67 insertions, 40 deletions
diff --git a/src/http_client_async.zig b/src/http_client_async.zig index 41684acde..0f6045bbb 100644 --- a/src/http_client_async.zig +++ b/src/http_client_async.zig @@ -39,6 +39,9 @@ const Batch = NetworkThread.Batch; const TaggedPointerUnion = @import("./tagged_pointer.zig").TaggedPointerUnion; const DeadSocket = opaque {}; var dead_socket = @intToPtr(*DeadSocket, 1); +//TODO: this needs to be freed when Worker Threads are implemented +var socket_async_http_abort_tracker = std.AutoArrayHashMap(u32, *uws.Socket).init(bun.default_allocator); +var async_http_id: std.atomic.Atomic(u32) = std.atomic.Atomic(u32).init(0); const print_every = 0; var print_every_i: usize = 0; @@ -475,6 +478,7 @@ fn NewHTTPContext(comptime ssl: bool) type { const UnboundedQueue = @import("./bun.js/unbounded_queue.zig").UnboundedQueue; const Queue = UnboundedQueue(AsyncHTTP, .next); +const ShutdownQueue = UnboundedQueue(AsyncHTTP, .next); pub const HTTPThread = struct { var http_thread_loaded: std.atomic.Atomic(bool) = std.atomic.Atomic(bool).init(false); @@ -484,6 +488,7 @@ pub const HTTPThread = struct { https_context: NewHTTPContext(true), queued_tasks: Queue = Queue{}, + queued_shutdowns: ShutdownQueue = ShutdownQueue{}, has_awoken: std.atomic.Atomic(bool) = std.atomic.Atomic(bool).init(false), timer: std.time.Timer = undefined, const threadlog = Output.scoped(.HTTPThread, true); @@ -552,6 +557,18 @@ pub const HTTPThread = struct { } fn drainEvents(this: *@This()) void { + while (this.queued_shutdowns.pop()) |http| { + if (socket_async_http_abort_tracker.fetchSwapRemove(http.async_http_id)) |socket_ptr| { + if (http.client.isHTTPS()) { + const socket = uws.SocketTLS.from(socket_ptr.value); + socket.shutdown(); + } else { + const socket = uws.SocketTCP.from(socket_ptr.value); + socket.shutdown(); + } + } + } + var count: usize = 0; var remaining: usize = AsyncHTTP.max_simultaneous_requests - AsyncHTTP.active_requests_count.loadUnchecked(); if (remaining == 0) return; @@ -600,6 +617,13 @@ pub const HTTPThread = struct { processEvents_(this); unreachable; } + + pub fn scheduleShutdown(this: *@This(), http: *AsyncHTTP) void { + this.queued_shutdowns.push(http); + if (this.has_awoken.load(.Monotonic)) + this.loop.wakeup(); + } + pub fn schedule(this: *@This(), batch: Batch) void { if (batch.len == 0) return; @@ -632,7 +656,9 @@ pub fn onOpen( std.debug.assert(is_ssl == client.url.isHTTPS()); } } - + if (client.aborted != null) { + socket_async_http_abort_tracker.put(client.async_http_id, socket.socket) catch unreachable; + } log("Connected {s} \n", .{client.url.href}); if (client.hasSignalAborted()) { @@ -1000,22 +1026,16 @@ proxy_authorization: ?[]u8 = null, proxy_tunneling: bool = false, proxy_tunnel: ?ProxyTunnel = null, aborted: ?*std.atomic.Atomic(bool) = null, - -pub fn init( - allocator: std.mem.Allocator, - method: Method, - url: URL, - header_entries: Headers.Entries, - header_buf: string, - signal: ?*std.atomic.Atomic(bool), -) HTTPClient { - return HTTPClient{ - .allocator = allocator, - .method = method, - .url = url, - .header_entries = header_entries, - .header_buf = header_buf, - .aborted = signal, +async_http_id: u32 = 0, + +pub fn init(allocator: std.mem.Allocator, method: Method, url: URL, header_entries: Headers.Entries, header_buf: string, signal: ?*std.atomic.Atomic(bool)) HTTPClient { + return HTTPClient { + .allocator = allocator, + .method = method, + .url = url, + .header_entries = header_entries, + .header_buf = header_buf, + .aborted = signal, }; } @@ -1170,6 +1190,7 @@ pub const AsyncHTTP = struct { client: HTTPClient = undefined, err: ?anyerror = null, + async_http_id: u32 = 0, state: AtomicState = AtomicState.init(State.pending), elapsed: u64 = 0, @@ -1207,18 +1228,10 @@ pub const AsyncHTTP = struct { http_proxy: ?URL, signal: ?*std.atomic.Atomic(bool), ) AsyncHTTP { - var this = AsyncHTTP{ - .allocator = allocator, - .url = url, - .method = method, - .request_headers = headers, - .request_header_buf = headers_buf, - .request_body = request_body, - .response_buffer = response_buffer, - .completion_callback = callback, - .http_proxy = http_proxy, - }; + var this = AsyncHTTP{ .allocator = allocator, .url = url, .method = method, .request_headers = headers, .request_header_buf = headers_buf, .request_body = request_body, .response_buffer = response_buffer, .completion_callback = callback, .http_proxy = http_proxy, .async_http_id = if (signal != null) async_http_id.fetchAdd(1, .Monotonic) else 0 }; + this.client = HTTPClient.init(allocator, method, url, headers, headers_buf, signal); + this.client.async_http_id = this.async_http_id; this.client.timeout = timeout; this.client.http_proxy = this.http_proxy; if (http_proxy) |proxy| { @@ -1499,27 +1512,33 @@ pub fn doRedirect(this: *HTTPClient) void { tunnel.deinit(); this.proxy_tunnel = null; } + if (this.aborted != null) { + _ = socket_async_http_abort_tracker.swapRemove(this.async_http_id); + } return this.start("", body_out_str); } - +pub fn isHTTPS(this: *HTTPClient) bool { + if (this.http_proxy) |proxy| { + if (proxy.isHTTPS()) { + return true; + } + return false; + } + if (this.url.isHTTPS()) { + return true; + } + return false; +} pub fn start(this: *HTTPClient, body: []const u8, body_out_str: *MutableString) void { body_out_str.reset(); std.debug.assert(this.state.response_message_buffer.list.capacity == 0); this.state = InternalState.init(body, body_out_str); - if (this.http_proxy) |proxy| { - if (proxy.isHTTPS()) { - this.start_(true); - } else { - this.start_(false); - } + if (this.isHTTPS()) { + this.start_(true); } else { - if (this.url.isHTTPS()) { - this.start_(true); - } else { - this.start_(false); - } + this.start_(false); } } @@ -1538,6 +1557,7 @@ fn start_(this: *HTTPClient, comptime is_ssl: bool) void { if (socket.isClosed() and (this.state.response_stage != .done and this.state.response_stage != .fail)) { this.fail(error.ConnectionClosed); std.debug.assert(this.state.fail != error.NoError); + return; } } @@ -2113,6 +2133,9 @@ pub fn closeAndAbort(this: *HTTPClient, comptime is_ssl: bool, socket: NewHTTPCo } fn fail(this: *HTTPClient, err: anyerror) void { + if (this.aborted != null) { + _ = socket_async_http_abort_tracker.swapRemove(this.async_http_id); + } this.state.request_stage = .fail; this.state.response_stage = .fail; this.state.fail = err; @@ -2157,6 +2180,10 @@ pub fn setTimeout(this: *HTTPClient, socket: anytype, amount: c_uint) void { pub fn done(this: *HTTPClient, comptime is_ssl: bool, ctx: *NewHTTPContext(is_ssl), socket: NewHTTPContext(is_ssl).HTTPSocket) void { if (this.state.stage != .done and this.state.stage != .fail) { + if (this.aborted != null) { + _ = socket_async_http_abort_tracker.swapRemove(this.async_http_id); + } + log("done", .{}); var out_str = this.state.body_out_str.?; |