aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Ciro Spaciari <ciro.spaciari@gmail.com> 2023-03-02 02:40:11 -0300
committerGravatar GitHub <noreply@github.com> 2023-03-01 21:40:11 -0800
commit1be834b073420f4a00a8bb9aaf0de575eb39525d (patch)
treee0baeef23347e70e41bc193e264aa6e0a87fcf49
parentb9137dbdc81591f8b30cf95a4d27514bfb1ae71c (diff)
downloadbun-1be834b073420f4a00a8bb9aaf0de575eb39525d.tar.gz
bun-1be834b073420f4a00a8bb9aaf0de575eb39525d.tar.zst
bun-1be834b073420f4a00a8bb9aaf0de575eb39525d.zip
fix bun server segfault with abortsignal (#2261)
* removed redundant tests, fixed server segfault * fix onRejectStream, safer unassign signal * fix abort Bun.serve signal.addEventListener on async * move ctx.signal null check up * keep original behavior of streams onAborted
-rw-r--r--src/bun.js/api/server.zig101
-rw-r--r--src/bun.js/webcore/streams.zig2
-rw-r--r--test/bun.js/fetch.test.js109
3 files changed, 79 insertions, 133 deletions
diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig
index daa028d39..5e2604983 100644
--- a/src/bun.js/api/server.zig
+++ b/src/bun.js/api/server.zig
@@ -650,6 +650,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
/// this prevents an extra pthread_getspecific() call which shows up in profiling
allocator: std.mem.Allocator,
req: *uws.Request,
+ signal: ?*JSC.AbortSignal = null,
method: HTTP.Method,
aborted: bool = false,
finalized: bun.DebugOnly(bool) = bun.DebugOnlyDefault(false),
@@ -698,11 +699,24 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
pub fn onResolve(_: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) callconv(.C) JSValue {
+ ctxLog("onResolve", .{});
+
const arguments = callframe.arguments(2);
var ctx = arguments.ptr[1].asPromisePtr(@This());
const result = arguments.ptr[0];
result.ensureStillAlive();
+ if (ctx.request_js_object != null and ctx.signal == null) {
+ var request_js = ctx.request_js_object.?.value();
+ request_js.ensureStillAlive();
+ if (request_js.as(Request)) |request_object| {
+ if (request_object.signal) |signal| {
+ ctx.signal = signal;
+ _ = signal.ref();
+ }
+ }
+ }
+
ctx.pending_promises_for_abort -|= 1;
if (ctx.aborted) {
ctx.finalizeForAbort();
@@ -745,10 +759,23 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
pub fn onReject(_: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) callconv(.C) JSValue {
+ ctxLog("onReject", .{});
+
const arguments = callframe.arguments(2);
var ctx = arguments.ptr[1].asPromisePtr(@This());
const err = arguments.ptr[0];
+ if (ctx.request_js_object != null and ctx.signal == null) {
+ var request_js = ctx.request_js_object.?.value();
+ request_js.ensureStillAlive();
+ if (request_js.as(Request)) |request_object| {
+ if (request_object.signal) |signal| {
+ ctx.signal = signal;
+ _ = signal.ref();
+ }
+ }
+ }
+
ctx.pending_promises_for_abort -|= 1;
if (ctx.aborted) {
@@ -992,13 +1019,24 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
std.debug.assert(!this.aborted);
//mark request as aborted
this.aborted = true;
+
+ // if signal is not aborted, abort the signal
+ if (this.signal) |signal| {
+ this.signal = null;
+ if (!signal.aborted()) {
+ const reason = JSC.AbortSignal.createAbortError(JSC.ZigString.static("The user aborted a request"), &JSC.ZigString.Empty, this.server.globalThis);
+ reason.ensureStillAlive();
+ _ = signal.signal(reason);
+ }
+ _ = signal.unref();
+ }
+
//if have sink, call onAborted on sink
if (this.sink) |wrapper| {
wrapper.detach();
wrapper.sink.onAborted(resp);
this.sink = null;
wrapper.sink.destroy();
- this.finalizeForAbort();
return;
}
@@ -1022,7 +1060,6 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
// User called .blob(), .json(), text(), or .arrayBuffer() on the Request object
// but we received nothing or the connection was aborted
if (request_js.as(Request)) |req| {
- this._signalAbort(req);
// the promise is pending
if (req.body == .Locked and (req.body.Locked.action != .none or req.body.Locked.promise != null)) {
@@ -1059,20 +1096,6 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
}
- pub fn _signalAbort(this: *RequestContext, req: *Request) void {
- //only call when actually aborted
- if (!this.aborted) return;
- //check if have a valid signal
- if (req.signal) |signal| {
- // if signal is not aborted, abort the signal
- if (!signal.aborted()) {
- const reason = JSC.AbortSignal.createAbortError(JSC.ZigString.static("The user aborted a request"), &JSC.ZigString.Empty, this.server.globalThis);
- reason.ensureStillAlive();
- _ = signal.signal(reason);
- }
- }
- }
-
pub fn markComplete(this: *RequestContext) void {
if (!this.has_marked_complete) this.server.onRequestComplete();
this.has_marked_complete = true;
@@ -1098,6 +1121,17 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
this.response_jsvalue = JSC.JSValue.zero;
}
+ // if signal is not aborted, abort the signal
+ if (this.signal) |signal| {
+ this.signal = null;
+ if (this.aborted and !signal.aborted()) {
+ const reason = JSC.AbortSignal.createAbortError(JSC.ZigString.static("The user aborted a request"), &JSC.ZigString.Empty, this.server.globalThis);
+ reason.ensureStillAlive();
+ _ = signal.signal(reason);
+ }
+ _ = signal.unref();
+ }
+
if (this.request_js_object != null) {
ctxLog("finalizeWithoutDeinit: request_js_object != null", .{});
@@ -1110,7 +1144,6 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
// User called .blob(), .json(), text(), or .arrayBuffer() on the Request object
// but we received nothing or the connection was aborted
if (request_js.as(Request)) |req| {
- this._signalAbort(req);
// the promise is pending
if (req.body == .Locked and req.body.Locked.action != .none and req.body.Locked.promise != null) {
req.body.toErrorInstance(JSC.toTypeError(.ABORT_ERR, "Request aborted", .{}, this.server.globalThis), this.server.globalThis);
@@ -1734,6 +1767,12 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
switch (promise.status(vm.global.vm())) {
.Pending => {},
.Fulfilled => {
+ if (ctx.signal == null) {
+ if (request_object.signal) |signal| {
+ ctx.signal = signal;
+ _ = signal.ref();
+ }
+ }
const fulfilled_value = promise.result(vm.global.vm());
// if you return a Response object or a Promise<Response>
@@ -1776,6 +1815,12 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
return;
},
.Rejected => {
+ if (ctx.signal == null) {
+ if (request_object.signal) |signal| {
+ ctx.signal = signal;
+ _ = signal.ref();
+ }
+ }
ctx.handleReject(promise.result(vm.global.vm()));
return;
},
@@ -1816,8 +1861,11 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
pub fn handleResolveStream(req: *RequestContext) void {
streamLog("handleResolveStream", .{});
- //aborted already called finalizeForAbort at this stage
- if (req.aborted) return;
+ //aborted so call finalizeForAbort
+ if (req.aborted) {
+ req.finalizeForAbort();
+ return;
+ }
var wrote_anything = false;
if (req.sink) |wrapper| {
@@ -1869,9 +1917,6 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
pub fn handleRejectStream(req: *@This(), globalThis: *JSC.JSGlobalObject, err: JSValue) void {
- //aborted already called finalizeForAbort at this stage
- if (req.aborted) return;
-
streamLog("handleRejectStream", .{});
var wrote_anything = req.has_written_status;
@@ -1895,6 +1940,12 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
streamLog("onReject({any})", .{wrote_anything});
+ //aborted so call finalizeForAbort
+ if (req.aborted) {
+ req.finalizeForAbort();
+ return;
+ }
+
if (!err.isEmptyOrUndefinedOrNull() and !wrote_anything) {
req.response_jsvalue.unprotect();
req.response_jsvalue = JSValue.zero;
@@ -4696,8 +4747,12 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type {
ctx.request_js_object = args[0].asObjectRef();
const request_value = args[0];
request_value.ensureStillAlive();
- const response_value = this.config.onRequest.callWithThis(this.globalThis, this.thisObject, &args);
+ const response_value = this.config.onRequest.callWithThis(this.globalThis, this.thisObject, &args);
+ if (request_object.signal) |signal| {
+ ctx.signal = signal;
+ _ = signal.ref();
+ }
ctx.onResponse(
this,
req,
diff --git a/src/bun.js/webcore/streams.zig b/src/bun.js/webcore/streams.zig
index 61a83669d..0a38c7ed0 100644
--- a/src/bun.js/webcore/streams.zig
+++ b/src/bun.js/webcore/streams.zig
@@ -2760,9 +2760,9 @@ pub fn HTTPServerWritable(comptime ssl: bool) type {
pub fn onAborted(this: *@This(), _: *UWSResponse) void {
log("onAborted()", .{});
+ this.signal.close(null);
this.done = true;
this.aborted = true;
- this.signal.close(null);
this.flushPromise();
this.finalize();
}
diff --git a/test/bun.js/fetch.test.js b/test/bun.js/fetch.test.js
index d8422dca4..1946681ad 100644
--- a/test/bun.js/fetch.test.js
+++ b/test/bun.js/fetch.test.js
@@ -27,115 +27,6 @@ afterEach(() => {
const payload = new Uint8Array(1024 * 1024 * 2);
crypto.getRandomValues(payload);
-describe("AbortSignalStreamTest", async () => {
- async function abortOnStage(body, stage) {
- let error = undefined;
- var abortController = new AbortController();
- {
- const server = getServer({
- async fetch(request) {
- let chunk_count = 0;
- const reader = request.body.getReader();
- return Response(
- new ReadableStream({
- async pull(controller) {
- while (true) {
- chunk_count++;
-
- const { done, value } = await reader.read();
- if (chunk_count == stage) {
- abortController.abort();
- }
-
- if (done) {
- controller.close();
- return;
- }
- controller.enqueue(value);
- }
- },
- }),
- );
- },
- });
-
- try {
- const signal = abortController.signal;
-
- await fetch(`http://127.0.0.1:${server.port}`, { method: "POST", body, signal: signal }).then(res =>
- res.arrayBuffer(),
- );
- } catch (ex) {
- error = ex;
- }
- expect(error.name).toBe("AbortError");
- expect(error.message).toBe("The operation was aborted.");
- expect(error instanceof DOMException).toBeTruthy();
- }
- }
-
- for (let i = 1; i < 7; i++) {
- it(`Abort after ${i} chunks`, async () => {
- await abortOnStage(payload, i);
- });
- }
-});
-
-describe("AbortSignalDirectStreamTest", () => {
- async function abortOnStage(body, stage) {
- let error = undefined;
- var abortController = new AbortController();
- {
- const server = getServer({
- async fetch(request) {
- let chunk_count = 0;
- const reader = request.body.getReader();
- return Response(
- new ReadableStream({
- type: "direct",
- async pull(controller) {
- while (true) {
- chunk_count++;
-
- const { done, value } = await reader.read();
- if (chunk_count == stage) {
- abortController.abort();
- }
-
- if (done) {
- controller.end();
- return;
- }
- controller.write(value);
- }
- },
- }),
- );
- },
- });
-
- try {
- const signal = abortController.signal;
-
- await fetch(`http://127.0.0.1:${server.port}`, { method: "POST", body, signal: signal }).then(res =>
- res.arrayBuffer(),
- );
- } catch (ex) {
- error = ex;
- }
- expect(error.name).toBe("AbortError");
- expect(error.message).toBe("The operation was aborted.");
- expect(error instanceof DOMException).toBeTruthy();
- }
- }
-
- for (let i = 1; i < 7; i++) {
- it(`Abort after ${i} chunks`, async () => {
- await abortOnStage(payload, i);
- });
- }
-});
-
describe("AbortSignal", () => {
var server;
beforeEach(() => {