diff options
author | 2023-02-27 21:10:03 -0300 | |
---|---|---|
committer | 2023-02-27 16:10:03 -0800 | |
commit | 0afb1693d370c0dd1340b3eb7659d31ef3fc94a3 (patch) | |
tree | 1efa66eafd12df9b00730fe3270888f4657b74f7 | |
parent | 7a4ac03338515b5153e7341dc827919d2f6b5245 (diff) | |
download | bun-0afb1693d370c0dd1340b3eb7659d31ef3fc94a3.tar.gz bun-0afb1693d370c0dd1340b3eb7659d31ef3fc94a3.tar.zst bun-0afb1693d370c0dd1340b3eb7659d31ef3fc94a3.zip |
fix(fetch.signal) capture socket and shutdown on abort signal (#2143)
* capture socket and shutdown on abort signal
* queue shutdown tasks
* little cleanup
* change shutdown strategy
* use fetchSwapRemove on fetch shutdown
* use fetchSwapRemove on fetch shutdown
* fix formatting, remove unused property
-rw-r--r-- | src/bun.js/webcore/response.zig | 31 | ||||
-rw-r--r-- | src/deps/uws.zig | 6 | ||||
-rw-r--r-- | src/http_client_async.zig | 107 |
3 files changed, 86 insertions, 58 deletions
diff --git a/src/bun.js/webcore/response.zig b/src/bun.js/webcore/response.zig index 58a1dcac8..96b7e2409 100644 --- a/src/bun.js/webcore/response.zig +++ b/src/bun.js/webcore/response.zig @@ -608,6 +608,8 @@ pub const Fetch = struct { ); pub const FetchTasklet = struct { + const log = Output.scoped(.FetchTasklet, false); + http: ?*HTTPClient.AsyncHTTP = null, result: HTTPClient.HTTPClientResult = .{}, javascript_vm: *VirtualMachine = undefined, @@ -817,24 +819,12 @@ pub const Fetch = struct { proxy = jsc_vm.bundler.env.getHttpProxy(fetch_options.url); } - fetch_tasklet.http.?.* = HTTPClient.AsyncHTTP.init( - allocator, - fetch_options.method, - fetch_options.url, - fetch_options.headers.entries, - fetch_options.headers.buf.items, - &fetch_tasklet.response_buffer, - fetch_tasklet.request_body.slice(), - fetch_options.timeout, - HTTPClient.HTTPClientResult.Callback.New( - *FetchTasklet, - FetchTasklet.callback, - ).init( - fetch_tasklet, - ), - proxy, - if (fetch_tasklet.signal != null) &fetch_tasklet.aborted else null, - ); + fetch_tasklet.http.?.* = HTTPClient.AsyncHTTP.init(allocator, fetch_options.method, fetch_options.url, fetch_options.headers.entries, fetch_options.headers.buf.items, &fetch_tasklet.response_buffer, fetch_tasklet.request_body.slice(), fetch_options.timeout, HTTPClient.HTTPClientResult.Callback.New( + *FetchTasklet, + FetchTasklet.callback, + ).init( + fetch_tasklet, + ), proxy, if (fetch_tasklet.signal != null) &fetch_tasklet.aborted else null); if (!fetch_options.follow_redirects) { fetch_tasklet.http.?.client.remaining_redirect_count = 0; @@ -851,10 +841,15 @@ pub const Fetch = struct { } pub fn abortListener(this: *FetchTasklet, reason: JSValue) void { + log("abortListener", .{}); reason.ensureStillAlive(); this.abort_reason = reason; reason.protect(); this.aborted.store(true, .Monotonic); + + if (this.http != null) { + HTTPClient.http_thread.scheduleShutdown(this.http.?); + } } const FetchOptions = struct { diff --git a/src/deps/uws.zig b/src/deps/uws.zig index 73eb72649..58a5b1b93 100644 --- a/src/deps/uws.zig +++ b/src/deps/uws.zig @@ -318,6 +318,12 @@ pub fn NewSocketHandler(comptime ssl: bool) type { us_socket_context_on_end(ssl_int, ctx, SocketHandler.on_end); } + pub fn from(socket: *Socket) ThisSocket { + return ThisSocket { + .socket = socket + }; + } + pub fn adopt( socket: *Socket, socket_ctx: *SocketContext, diff --git a/src/http_client_async.zig b/src/http_client_async.zig index 41684acde..0f6045bbb 100644 --- a/src/http_client_async.zig +++ b/src/http_client_async.zig @@ -39,6 +39,9 @@ const Batch = NetworkThread.Batch; const TaggedPointerUnion = @import("./tagged_pointer.zig").TaggedPointerUnion; const DeadSocket = opaque {}; var dead_socket = @intToPtr(*DeadSocket, 1); +//TODO: this needs to be freed when Worker Threads are implemented +var socket_async_http_abort_tracker = std.AutoArrayHashMap(u32, *uws.Socket).init(bun.default_allocator); +var async_http_id: std.atomic.Atomic(u32) = std.atomic.Atomic(u32).init(0); const print_every = 0; var print_every_i: usize = 0; @@ -475,6 +478,7 @@ fn NewHTTPContext(comptime ssl: bool) type { const UnboundedQueue = @import("./bun.js/unbounded_queue.zig").UnboundedQueue; const Queue = UnboundedQueue(AsyncHTTP, .next); +const ShutdownQueue = UnboundedQueue(AsyncHTTP, .next); pub const HTTPThread = struct { var http_thread_loaded: std.atomic.Atomic(bool) = std.atomic.Atomic(bool).init(false); @@ -484,6 +488,7 @@ pub const HTTPThread = struct { https_context: NewHTTPContext(true), queued_tasks: Queue = Queue{}, + queued_shutdowns: ShutdownQueue = ShutdownQueue{}, has_awoken: std.atomic.Atomic(bool) = std.atomic.Atomic(bool).init(false), timer: std.time.Timer = undefined, const threadlog = Output.scoped(.HTTPThread, true); @@ -552,6 +557,18 @@ pub const HTTPThread = struct { } fn drainEvents(this: *@This()) void { + while (this.queued_shutdowns.pop()) |http| { + if (socket_async_http_abort_tracker.fetchSwapRemove(http.async_http_id)) |socket_ptr| { + if (http.client.isHTTPS()) { + const socket = uws.SocketTLS.from(socket_ptr.value); + socket.shutdown(); + } else { + const socket = uws.SocketTCP.from(socket_ptr.value); + socket.shutdown(); + } + } + } + var count: usize = 0; var remaining: usize = AsyncHTTP.max_simultaneous_requests - AsyncHTTP.active_requests_count.loadUnchecked(); if (remaining == 0) return; @@ -600,6 +617,13 @@ pub const HTTPThread = struct { processEvents_(this); unreachable; } + + pub fn scheduleShutdown(this: *@This(), http: *AsyncHTTP) void { + this.queued_shutdowns.push(http); + if (this.has_awoken.load(.Monotonic)) + this.loop.wakeup(); + } + pub fn schedule(this: *@This(), batch: Batch) void { if (batch.len == 0) return; @@ -632,7 +656,9 @@ pub fn onOpen( std.debug.assert(is_ssl == client.url.isHTTPS()); } } - + if (client.aborted != null) { + socket_async_http_abort_tracker.put(client.async_http_id, socket.socket) catch unreachable; + } log("Connected {s} \n", .{client.url.href}); if (client.hasSignalAborted()) { @@ -1000,22 +1026,16 @@ proxy_authorization: ?[]u8 = null, proxy_tunneling: bool = false, proxy_tunnel: ?ProxyTunnel = null, aborted: ?*std.atomic.Atomic(bool) = null, - -pub fn init( - allocator: std.mem.Allocator, - method: Method, - url: URL, - header_entries: Headers.Entries, - header_buf: string, - signal: ?*std.atomic.Atomic(bool), -) HTTPClient { - return HTTPClient{ - .allocator = allocator, - .method = method, - .url = url, - .header_entries = header_entries, - .header_buf = header_buf, - .aborted = signal, +async_http_id: u32 = 0, + +pub fn init(allocator: std.mem.Allocator, method: Method, url: URL, header_entries: Headers.Entries, header_buf: string, signal: ?*std.atomic.Atomic(bool)) HTTPClient { + return HTTPClient { + .allocator = allocator, + .method = method, + .url = url, + .header_entries = header_entries, + .header_buf = header_buf, + .aborted = signal, }; } @@ -1170,6 +1190,7 @@ pub const AsyncHTTP = struct { client: HTTPClient = undefined, err: ?anyerror = null, + async_http_id: u32 = 0, state: AtomicState = AtomicState.init(State.pending), elapsed: u64 = 0, @@ -1207,18 +1228,10 @@ pub const AsyncHTTP = struct { http_proxy: ?URL, signal: ?*std.atomic.Atomic(bool), ) AsyncHTTP { - var this = AsyncHTTP{ - .allocator = allocator, - .url = url, - .method = method, - .request_headers = headers, - .request_header_buf = headers_buf, - .request_body = request_body, - .response_buffer = response_buffer, - .completion_callback = callback, - .http_proxy = http_proxy, - }; + var this = AsyncHTTP{ .allocator = allocator, .url = url, .method = method, .request_headers = headers, .request_header_buf = headers_buf, .request_body = request_body, .response_buffer = response_buffer, .completion_callback = callback, .http_proxy = http_proxy, .async_http_id = if (signal != null) async_http_id.fetchAdd(1, .Monotonic) else 0 }; + this.client = HTTPClient.init(allocator, method, url, headers, headers_buf, signal); + this.client.async_http_id = this.async_http_id; this.client.timeout = timeout; this.client.http_proxy = this.http_proxy; if (http_proxy) |proxy| { @@ -1499,27 +1512,33 @@ pub fn doRedirect(this: *HTTPClient) void { tunnel.deinit(); this.proxy_tunnel = null; } + if (this.aborted != null) { + _ = socket_async_http_abort_tracker.swapRemove(this.async_http_id); + } return this.start("", body_out_str); } - +pub fn isHTTPS(this: *HTTPClient) bool { + if (this.http_proxy) |proxy| { + if (proxy.isHTTPS()) { + return true; + } + return false; + } + if (this.url.isHTTPS()) { + return true; + } + return false; +} pub fn start(this: *HTTPClient, body: []const u8, body_out_str: *MutableString) void { body_out_str.reset(); std.debug.assert(this.state.response_message_buffer.list.capacity == 0); this.state = InternalState.init(body, body_out_str); - if (this.http_proxy) |proxy| { - if (proxy.isHTTPS()) { - this.start_(true); - } else { - this.start_(false); - } + if (this.isHTTPS()) { + this.start_(true); } else { - if (this.url.isHTTPS()) { - this.start_(true); - } else { - this.start_(false); - } + this.start_(false); } } @@ -1538,6 +1557,7 @@ fn start_(this: *HTTPClient, comptime is_ssl: bool) void { if (socket.isClosed() and (this.state.response_stage != .done and this.state.response_stage != .fail)) { this.fail(error.ConnectionClosed); std.debug.assert(this.state.fail != error.NoError); + return; } } @@ -2113,6 +2133,9 @@ pub fn closeAndAbort(this: *HTTPClient, comptime is_ssl: bool, socket: NewHTTPCo } fn fail(this: *HTTPClient, err: anyerror) void { + if (this.aborted != null) { + _ = socket_async_http_abort_tracker.swapRemove(this.async_http_id); + } this.state.request_stage = .fail; this.state.response_stage = .fail; this.state.fail = err; @@ -2157,6 +2180,10 @@ pub fn setTimeout(this: *HTTPClient, socket: anytype, amount: c_uint) void { pub fn done(this: *HTTPClient, comptime is_ssl: bool, ctx: *NewHTTPContext(is_ssl), socket: NewHTTPContext(is_ssl).HTTPSocket) void { if (this.state.stage != .done and this.state.stage != .fail) { + if (this.aborted != null) { + _ = socket_async_http_abort_tracker.swapRemove(this.async_http_id); + } + log("done", .{}); var out_str = this.state.body_out_str.?; |