diff options
author | 2023-03-02 02:40:11 -0300 | |
---|---|---|
committer | 2023-03-01 21:40:11 -0800 | |
commit | 1be834b073420f4a00a8bb9aaf0de575eb39525d (patch) | |
tree | e0baeef23347e70e41bc193e264aa6e0a87fcf49 | |
parent | b9137dbdc81591f8b30cf95a4d27514bfb1ae71c (diff) | |
download | bun-1be834b073420f4a00a8bb9aaf0de575eb39525d.tar.gz bun-1be834b073420f4a00a8bb9aaf0de575eb39525d.tar.zst bun-1be834b073420f4a00a8bb9aaf0de575eb39525d.zip |
fix bun server segfault with abortsignal (#2261)
* removed redundant tests, fixed server segfault
* fix onRejectStream, safer unassign signal
* fix abort Bun.serve signal.addEventListener on async
* move ctx.signal null check up
* keep original behavior of streams onAborted
-rw-r--r-- | src/bun.js/api/server.zig | 101 | ||||
-rw-r--r-- | src/bun.js/webcore/streams.zig | 2 | ||||
-rw-r--r-- | test/bun.js/fetch.test.js | 109 |
3 files changed, 79 insertions, 133 deletions
diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig index daa028d39..5e2604983 100644 --- a/src/bun.js/api/server.zig +++ b/src/bun.js/api/server.zig @@ -650,6 +650,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp /// this prevents an extra pthread_getspecific() call which shows up in profiling allocator: std.mem.Allocator, req: *uws.Request, + signal: ?*JSC.AbortSignal = null, method: HTTP.Method, aborted: bool = false, finalized: bun.DebugOnly(bool) = bun.DebugOnlyDefault(false), @@ -698,11 +699,24 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp } pub fn onResolve(_: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) callconv(.C) JSValue { + ctxLog("onResolve", .{}); + const arguments = callframe.arguments(2); var ctx = arguments.ptr[1].asPromisePtr(@This()); const result = arguments.ptr[0]; result.ensureStillAlive(); + if (ctx.request_js_object != null and ctx.signal == null) { + var request_js = ctx.request_js_object.?.value(); + request_js.ensureStillAlive(); + if (request_js.as(Request)) |request_object| { + if (request_object.signal) |signal| { + ctx.signal = signal; + _ = signal.ref(); + } + } + } + ctx.pending_promises_for_abort -|= 1; if (ctx.aborted) { ctx.finalizeForAbort(); @@ -745,10 +759,23 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp } pub fn onReject(_: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) callconv(.C) JSValue { + ctxLog("onReject", .{}); + const arguments = callframe.arguments(2); var ctx = arguments.ptr[1].asPromisePtr(@This()); const err = arguments.ptr[0]; + if (ctx.request_js_object != null and ctx.signal == null) { + var request_js = ctx.request_js_object.?.value(); + request_js.ensureStillAlive(); + if (request_js.as(Request)) |request_object| { + if (request_object.signal) |signal| { + ctx.signal = signal; + _ = signal.ref(); + } + } + } + ctx.pending_promises_for_abort -|= 1; if (ctx.aborted) { @@ -992,13 +1019,24 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp std.debug.assert(!this.aborted); //mark request as aborted this.aborted = true; + + // if signal is not aborted, abort the signal + if (this.signal) |signal| { + this.signal = null; + if (!signal.aborted()) { + const reason = JSC.AbortSignal.createAbortError(JSC.ZigString.static("The user aborted a request"), &JSC.ZigString.Empty, this.server.globalThis); + reason.ensureStillAlive(); + _ = signal.signal(reason); + } + _ = signal.unref(); + } + //if have sink, call onAborted on sink if (this.sink) |wrapper| { wrapper.detach(); wrapper.sink.onAborted(resp); this.sink = null; wrapper.sink.destroy(); - this.finalizeForAbort(); return; } @@ -1022,7 +1060,6 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp // User called .blob(), .json(), text(), or .arrayBuffer() on the Request object // but we received nothing or the connection was aborted if (request_js.as(Request)) |req| { - this._signalAbort(req); // the promise is pending if (req.body == .Locked and (req.body.Locked.action != .none or req.body.Locked.promise != null)) { @@ -1059,20 +1096,6 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp } } - pub fn _signalAbort(this: *RequestContext, req: *Request) void { - //only call when actually aborted - if (!this.aborted) return; - //check if have a valid signal - if (req.signal) |signal| { - // if signal is not aborted, abort the signal - if (!signal.aborted()) { - const reason = JSC.AbortSignal.createAbortError(JSC.ZigString.static("The user aborted a request"), &JSC.ZigString.Empty, this.server.globalThis); - reason.ensureStillAlive(); - _ = signal.signal(reason); - } - } - } - pub fn markComplete(this: *RequestContext) void { if (!this.has_marked_complete) this.server.onRequestComplete(); this.has_marked_complete = true; @@ -1098,6 +1121,17 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp this.response_jsvalue = JSC.JSValue.zero; } + // if signal is not aborted, abort the signal + if (this.signal) |signal| { + this.signal = null; + if (this.aborted and !signal.aborted()) { + const reason = JSC.AbortSignal.createAbortError(JSC.ZigString.static("The user aborted a request"), &JSC.ZigString.Empty, this.server.globalThis); + reason.ensureStillAlive(); + _ = signal.signal(reason); + } + _ = signal.unref(); + } + if (this.request_js_object != null) { ctxLog("finalizeWithoutDeinit: request_js_object != null", .{}); @@ -1110,7 +1144,6 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp // User called .blob(), .json(), text(), or .arrayBuffer() on the Request object // but we received nothing or the connection was aborted if (request_js.as(Request)) |req| { - this._signalAbort(req); // the promise is pending if (req.body == .Locked and req.body.Locked.action != .none and req.body.Locked.promise != null) { req.body.toErrorInstance(JSC.toTypeError(.ABORT_ERR, "Request aborted", .{}, this.server.globalThis), this.server.globalThis); @@ -1734,6 +1767,12 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp switch (promise.status(vm.global.vm())) { .Pending => {}, .Fulfilled => { + if (ctx.signal == null) { + if (request_object.signal) |signal| { + ctx.signal = signal; + _ = signal.ref(); + } + } const fulfilled_value = promise.result(vm.global.vm()); // if you return a Response object or a Promise<Response> @@ -1776,6 +1815,12 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp return; }, .Rejected => { + if (ctx.signal == null) { + if (request_object.signal) |signal| { + ctx.signal = signal; + _ = signal.ref(); + } + } ctx.handleReject(promise.result(vm.global.vm())); return; }, @@ -1816,8 +1861,11 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp pub fn handleResolveStream(req: *RequestContext) void { streamLog("handleResolveStream", .{}); - //aborted already called finalizeForAbort at this stage - if (req.aborted) return; + //aborted so call finalizeForAbort + if (req.aborted) { + req.finalizeForAbort(); + return; + } var wrote_anything = false; if (req.sink) |wrapper| { @@ -1869,9 +1917,6 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp } pub fn handleRejectStream(req: *@This(), globalThis: *JSC.JSGlobalObject, err: JSValue) void { - //aborted already called finalizeForAbort at this stage - if (req.aborted) return; - streamLog("handleRejectStream", .{}); var wrote_anything = req.has_written_status; @@ -1895,6 +1940,12 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp streamLog("onReject({any})", .{wrote_anything}); + //aborted so call finalizeForAbort + if (req.aborted) { + req.finalizeForAbort(); + return; + } + if (!err.isEmptyOrUndefinedOrNull() and !wrote_anything) { req.response_jsvalue.unprotect(); req.response_jsvalue = JSValue.zero; @@ -4696,8 +4747,12 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type { ctx.request_js_object = args[0].asObjectRef(); const request_value = args[0]; request_value.ensureStillAlive(); - const response_value = this.config.onRequest.callWithThis(this.globalThis, this.thisObject, &args); + const response_value = this.config.onRequest.callWithThis(this.globalThis, this.thisObject, &args); + if (request_object.signal) |signal| { + ctx.signal = signal; + _ = signal.ref(); + } ctx.onResponse( this, req, diff --git a/src/bun.js/webcore/streams.zig b/src/bun.js/webcore/streams.zig index 61a83669d..0a38c7ed0 100644 --- a/src/bun.js/webcore/streams.zig +++ b/src/bun.js/webcore/streams.zig @@ -2760,9 +2760,9 @@ pub fn HTTPServerWritable(comptime ssl: bool) type { pub fn onAborted(this: *@This(), _: *UWSResponse) void { log("onAborted()", .{}); + this.signal.close(null); this.done = true; this.aborted = true; - this.signal.close(null); this.flushPromise(); this.finalize(); } diff --git a/test/bun.js/fetch.test.js b/test/bun.js/fetch.test.js index d8422dca4..1946681ad 100644 --- a/test/bun.js/fetch.test.js +++ b/test/bun.js/fetch.test.js @@ -27,115 +27,6 @@ afterEach(() => { const payload = new Uint8Array(1024 * 1024 * 2); crypto.getRandomValues(payload); -describe("AbortSignalStreamTest", async () => { - async function abortOnStage(body, stage) { - let error = undefined; - var abortController = new AbortController(); - { - const server = getServer({ - async fetch(request) { - let chunk_count = 0; - const reader = request.body.getReader(); - return Response( - new ReadableStream({ - async pull(controller) { - while (true) { - chunk_count++; - - const { done, value } = await reader.read(); - if (chunk_count == stage) { - abortController.abort(); - } - - if (done) { - controller.close(); - return; - } - controller.enqueue(value); - } - }, - }), - ); - }, - }); - - try { - const signal = abortController.signal; - - await fetch(`http://127.0.0.1:${server.port}`, { method: "POST", body, signal: signal }).then(res => - res.arrayBuffer(), - ); - } catch (ex) { - error = ex; - } - expect(error.name).toBe("AbortError"); - expect(error.message).toBe("The operation was aborted."); - expect(error instanceof DOMException).toBeTruthy(); - } - } - - for (let i = 1; i < 7; i++) { - it(`Abort after ${i} chunks`, async () => { - await abortOnStage(payload, i); - }); - } -}); - -describe("AbortSignalDirectStreamTest", () => { - async function abortOnStage(body, stage) { - let error = undefined; - var abortController = new AbortController(); - { - const server = getServer({ - async fetch(request) { - let chunk_count = 0; - const reader = request.body.getReader(); - return Response( - new ReadableStream({ - type: "direct", - async pull(controller) { - while (true) { - chunk_count++; - - const { done, value } = await reader.read(); - if (chunk_count == stage) { - abortController.abort(); - } - - if (done) { - controller.end(); - return; - } - controller.write(value); - } - }, - }), - ); - }, - }); - - try { - const signal = abortController.signal; - - await fetch(`http://127.0.0.1:${server.port}`, { method: "POST", body, signal: signal }).then(res => - res.arrayBuffer(), - ); - } catch (ex) { - error = ex; - } - expect(error.name).toBe("AbortError"); - expect(error.message).toBe("The operation was aborted."); - expect(error instanceof DOMException).toBeTruthy(); - } - } - - for (let i = 1; i < 7; i++) { - it(`Abort after ${i} chunks`, async () => { - await abortOnStage(payload, i); - }); - } -}); - describe("AbortSignal", () => { var server; beforeEach(() => { |