aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Ciro Spaciari <ciro.spaciari@gmail.com> 2023-02-27 21:10:03 -0300
committerGravatar GitHub <noreply@github.com> 2023-02-27 16:10:03 -0800
commit0afb1693d370c0dd1340b3eb7659d31ef3fc94a3 (patch)
tree1efa66eafd12df9b00730fe3270888f4657b74f7
parent7a4ac03338515b5153e7341dc827919d2f6b5245 (diff)
downloadbun-0afb1693d370c0dd1340b3eb7659d31ef3fc94a3.tar.gz
bun-0afb1693d370c0dd1340b3eb7659d31ef3fc94a3.tar.zst
bun-0afb1693d370c0dd1340b3eb7659d31ef3fc94a3.zip
fix(fetch.signal) capture socket and shutdown on abort signal (#2143)
* capture socket and shutdown on abort signal * queue shutdown tasks * little cleanup * change shutdown strategy * use fetchSwapRemove on fetch shutdown * use fetchSwapRemove on fetch shutdown * fix formatting, remove unused property
-rw-r--r--src/bun.js/webcore/response.zig31
-rw-r--r--src/deps/uws.zig6
-rw-r--r--src/http_client_async.zig107
3 files changed, 86 insertions, 58 deletions
diff --git a/src/bun.js/webcore/response.zig b/src/bun.js/webcore/response.zig
index 58a1dcac8..96b7e2409 100644
--- a/src/bun.js/webcore/response.zig
+++ b/src/bun.js/webcore/response.zig
@@ -608,6 +608,8 @@ pub const Fetch = struct {
);
pub const FetchTasklet = struct {
+ const log = Output.scoped(.FetchTasklet, false);
+
http: ?*HTTPClient.AsyncHTTP = null,
result: HTTPClient.HTTPClientResult = .{},
javascript_vm: *VirtualMachine = undefined,
@@ -817,24 +819,12 @@ pub const Fetch = struct {
proxy = jsc_vm.bundler.env.getHttpProxy(fetch_options.url);
}
- fetch_tasklet.http.?.* = HTTPClient.AsyncHTTP.init(
- allocator,
- fetch_options.method,
- fetch_options.url,
- fetch_options.headers.entries,
- fetch_options.headers.buf.items,
- &fetch_tasklet.response_buffer,
- fetch_tasklet.request_body.slice(),
- fetch_options.timeout,
- HTTPClient.HTTPClientResult.Callback.New(
- *FetchTasklet,
- FetchTasklet.callback,
- ).init(
- fetch_tasklet,
- ),
- proxy,
- if (fetch_tasklet.signal != null) &fetch_tasklet.aborted else null,
- );
+ fetch_tasklet.http.?.* = HTTPClient.AsyncHTTP.init(allocator, fetch_options.method, fetch_options.url, fetch_options.headers.entries, fetch_options.headers.buf.items, &fetch_tasklet.response_buffer, fetch_tasklet.request_body.slice(), fetch_options.timeout, HTTPClient.HTTPClientResult.Callback.New(
+ *FetchTasklet,
+ FetchTasklet.callback,
+ ).init(
+ fetch_tasklet,
+ ), proxy, if (fetch_tasklet.signal != null) &fetch_tasklet.aborted else null);
if (!fetch_options.follow_redirects) {
fetch_tasklet.http.?.client.remaining_redirect_count = 0;
@@ -851,10 +841,15 @@ pub const Fetch = struct {
}
pub fn abortListener(this: *FetchTasklet, reason: JSValue) void {
+ log("abortListener", .{});
reason.ensureStillAlive();
this.abort_reason = reason;
reason.protect();
this.aborted.store(true, .Monotonic);
+
+ if (this.http != null) {
+ HTTPClient.http_thread.scheduleShutdown(this.http.?);
+ }
}
const FetchOptions = struct {
diff --git a/src/deps/uws.zig b/src/deps/uws.zig
index 73eb72649..58a5b1b93 100644
--- a/src/deps/uws.zig
+++ b/src/deps/uws.zig
@@ -318,6 +318,12 @@ pub fn NewSocketHandler(comptime ssl: bool) type {
us_socket_context_on_end(ssl_int, ctx, SocketHandler.on_end);
}
+ pub fn from(socket: *Socket) ThisSocket {
+ return ThisSocket {
+ .socket = socket
+ };
+ }
+
pub fn adopt(
socket: *Socket,
socket_ctx: *SocketContext,
diff --git a/src/http_client_async.zig b/src/http_client_async.zig
index 41684acde..0f6045bbb 100644
--- a/src/http_client_async.zig
+++ b/src/http_client_async.zig
@@ -39,6 +39,9 @@ const Batch = NetworkThread.Batch;
const TaggedPointerUnion = @import("./tagged_pointer.zig").TaggedPointerUnion;
const DeadSocket = opaque {};
var dead_socket = @intToPtr(*DeadSocket, 1);
+//TODO: this needs to be freed when Worker Threads are implemented
+var socket_async_http_abort_tracker = std.AutoArrayHashMap(u32, *uws.Socket).init(bun.default_allocator);
+var async_http_id: std.atomic.Atomic(u32) = std.atomic.Atomic(u32).init(0);
const print_every = 0;
var print_every_i: usize = 0;
@@ -475,6 +478,7 @@ fn NewHTTPContext(comptime ssl: bool) type {
const UnboundedQueue = @import("./bun.js/unbounded_queue.zig").UnboundedQueue;
const Queue = UnboundedQueue(AsyncHTTP, .next);
+const ShutdownQueue = UnboundedQueue(AsyncHTTP, .next);
pub const HTTPThread = struct {
var http_thread_loaded: std.atomic.Atomic(bool) = std.atomic.Atomic(bool).init(false);
@@ -484,6 +488,7 @@ pub const HTTPThread = struct {
https_context: NewHTTPContext(true),
queued_tasks: Queue = Queue{},
+ queued_shutdowns: ShutdownQueue = ShutdownQueue{},
has_awoken: std.atomic.Atomic(bool) = std.atomic.Atomic(bool).init(false),
timer: std.time.Timer = undefined,
const threadlog = Output.scoped(.HTTPThread, true);
@@ -552,6 +557,18 @@ pub const HTTPThread = struct {
}
fn drainEvents(this: *@This()) void {
+ while (this.queued_shutdowns.pop()) |http| {
+ if (socket_async_http_abort_tracker.fetchSwapRemove(http.async_http_id)) |socket_ptr| {
+ if (http.client.isHTTPS()) {
+ const socket = uws.SocketTLS.from(socket_ptr.value);
+ socket.shutdown();
+ } else {
+ const socket = uws.SocketTCP.from(socket_ptr.value);
+ socket.shutdown();
+ }
+ }
+ }
+
var count: usize = 0;
var remaining: usize = AsyncHTTP.max_simultaneous_requests - AsyncHTTP.active_requests_count.loadUnchecked();
if (remaining == 0) return;
@@ -600,6 +617,13 @@ pub const HTTPThread = struct {
processEvents_(this);
unreachable;
}
+
+ pub fn scheduleShutdown(this: *@This(), http: *AsyncHTTP) void {
+ this.queued_shutdowns.push(http);
+ if (this.has_awoken.load(.Monotonic))
+ this.loop.wakeup();
+ }
+
pub fn schedule(this: *@This(), batch: Batch) void {
if (batch.len == 0)
return;
@@ -632,7 +656,9 @@ pub fn onOpen(
std.debug.assert(is_ssl == client.url.isHTTPS());
}
}
-
+ if (client.aborted != null) {
+ socket_async_http_abort_tracker.put(client.async_http_id, socket.socket) catch unreachable;
+ }
log("Connected {s} \n", .{client.url.href});
if (client.hasSignalAborted()) {
@@ -1000,22 +1026,16 @@ proxy_authorization: ?[]u8 = null,
proxy_tunneling: bool = false,
proxy_tunnel: ?ProxyTunnel = null,
aborted: ?*std.atomic.Atomic(bool) = null,
-
-pub fn init(
- allocator: std.mem.Allocator,
- method: Method,
- url: URL,
- header_entries: Headers.Entries,
- header_buf: string,
- signal: ?*std.atomic.Atomic(bool),
-) HTTPClient {
- return HTTPClient{
- .allocator = allocator,
- .method = method,
- .url = url,
- .header_entries = header_entries,
- .header_buf = header_buf,
- .aborted = signal,
+async_http_id: u32 = 0,
+
+pub fn init(allocator: std.mem.Allocator, method: Method, url: URL, header_entries: Headers.Entries, header_buf: string, signal: ?*std.atomic.Atomic(bool)) HTTPClient {
+ return HTTPClient {
+ .allocator = allocator,
+ .method = method,
+ .url = url,
+ .header_entries = header_entries,
+ .header_buf = header_buf,
+ .aborted = signal,
};
}
@@ -1170,6 +1190,7 @@ pub const AsyncHTTP = struct {
client: HTTPClient = undefined,
err: ?anyerror = null,
+ async_http_id: u32 = 0,
state: AtomicState = AtomicState.init(State.pending),
elapsed: u64 = 0,
@@ -1207,18 +1228,10 @@ pub const AsyncHTTP = struct {
http_proxy: ?URL,
signal: ?*std.atomic.Atomic(bool),
) AsyncHTTP {
- var this = AsyncHTTP{
- .allocator = allocator,
- .url = url,
- .method = method,
- .request_headers = headers,
- .request_header_buf = headers_buf,
- .request_body = request_body,
- .response_buffer = response_buffer,
- .completion_callback = callback,
- .http_proxy = http_proxy,
- };
+ var this = AsyncHTTP{ .allocator = allocator, .url = url, .method = method, .request_headers = headers, .request_header_buf = headers_buf, .request_body = request_body, .response_buffer = response_buffer, .completion_callback = callback, .http_proxy = http_proxy, .async_http_id = if (signal != null) async_http_id.fetchAdd(1, .Monotonic) else 0 };
+
this.client = HTTPClient.init(allocator, method, url, headers, headers_buf, signal);
+ this.client.async_http_id = this.async_http_id;
this.client.timeout = timeout;
this.client.http_proxy = this.http_proxy;
if (http_proxy) |proxy| {
@@ -1499,27 +1512,33 @@ pub fn doRedirect(this: *HTTPClient) void {
tunnel.deinit();
this.proxy_tunnel = null;
}
+ if (this.aborted != null) {
+ _ = socket_async_http_abort_tracker.swapRemove(this.async_http_id);
+ }
return this.start("", body_out_str);
}
-
+pub fn isHTTPS(this: *HTTPClient) bool {
+ if (this.http_proxy) |proxy| {
+ if (proxy.isHTTPS()) {
+ return true;
+ }
+ return false;
+ }
+ if (this.url.isHTTPS()) {
+ return true;
+ }
+ return false;
+}
pub fn start(this: *HTTPClient, body: []const u8, body_out_str: *MutableString) void {
body_out_str.reset();
std.debug.assert(this.state.response_message_buffer.list.capacity == 0);
this.state = InternalState.init(body, body_out_str);
- if (this.http_proxy) |proxy| {
- if (proxy.isHTTPS()) {
- this.start_(true);
- } else {
- this.start_(false);
- }
+ if (this.isHTTPS()) {
+ this.start_(true);
} else {
- if (this.url.isHTTPS()) {
- this.start_(true);
- } else {
- this.start_(false);
- }
+ this.start_(false);
}
}
@@ -1538,6 +1557,7 @@ fn start_(this: *HTTPClient, comptime is_ssl: bool) void {
if (socket.isClosed() and (this.state.response_stage != .done and this.state.response_stage != .fail)) {
this.fail(error.ConnectionClosed);
std.debug.assert(this.state.fail != error.NoError);
+ return;
}
}
@@ -2113,6 +2133,9 @@ pub fn closeAndAbort(this: *HTTPClient, comptime is_ssl: bool, socket: NewHTTPCo
}
fn fail(this: *HTTPClient, err: anyerror) void {
+ if (this.aborted != null) {
+ _ = socket_async_http_abort_tracker.swapRemove(this.async_http_id);
+ }
this.state.request_stage = .fail;
this.state.response_stage = .fail;
this.state.fail = err;
@@ -2157,6 +2180,10 @@ pub fn setTimeout(this: *HTTPClient, socket: anytype, amount: c_uint) void {
pub fn done(this: *HTTPClient, comptime is_ssl: bool, ctx: *NewHTTPContext(is_ssl), socket: NewHTTPContext(is_ssl).HTTPSocket) void {
if (this.state.stage != .done and this.state.stage != .fail) {
+ if (this.aborted != null) {
+ _ = socket_async_http_abort_tracker.swapRemove(this.async_http_id);
+ }
+
log("done", .{});
var out_str = this.state.body_out_str.?;