aboutsummaryrefslogtreecommitdiff
path: root/src/bun.js/api/server.zig
diff options
context:
space:
mode:
Diffstat (limited to 'src/bun.js/api/server.zig')
-rw-r--r--src/bun.js/api/server.zig272
1 files changed, 211 insertions, 61 deletions
diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig
index 85d4dadb5..edf1d6d69 100644
--- a/src/bun.js/api/server.zig
+++ b/src/bun.js/api/server.zig
@@ -1125,6 +1125,92 @@ fn NewFlags(comptime debug_mode: bool) type {
};
}
+/// A generic wrapper for the HTTP(s) Server`RequestContext`s.
+/// Only really exists because of `NewServer()` and `NewRequestContext()` generics.
+pub const AnyRequestContext = struct {
+ pub const Pointer = bun.TaggedPointerUnion(.{
+ HTTPServer.RequestContext,
+ HTTPSServer.RequestContext,
+ DebugHTTPServer.RequestContext,
+ DebugHTTPSServer.RequestContext,
+ });
+
+ tagged_pointer: Pointer,
+
+ pub const Null = .{ .tagged_pointer = Pointer.Null };
+
+ pub fn init(request_ctx: anytype) AnyRequestContext {
+ return .{ .tagged_pointer = Pointer.init(request_ctx) };
+ }
+
+ pub fn getRemoteSocketInfo(self: AnyRequestContext) ?uws.SocketAddress {
+ if (self.tagged_pointer.isNull()) {
+ return null;
+ }
+
+ switch (self.tagged_pointer.tag()) {
+ @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(HTTPServer.RequestContext))) => {
+ return self.tagged_pointer.as(HTTPServer.RequestContext).getRemoteSocketInfo();
+ },
+ @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(HTTPSServer.RequestContext))) => {
+ return self.tagged_pointer.as(HTTPSServer.RequestContext).getRemoteSocketInfo();
+ },
+ @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(DebugHTTPServer.RequestContext))) => {
+ return self.tagged_pointer.as(DebugHTTPServer.RequestContext).getRemoteSocketInfo();
+ },
+ @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(DebugHTTPSServer.RequestContext))) => {
+ return self.tagged_pointer.as(DebugHTTPSServer.RequestContext).getRemoteSocketInfo();
+ },
+ else => @panic("Unexpected AnyRequestContext tag"),
+ }
+ }
+
+ /// Wont actually set anything if `self` is `.none`
+ pub fn setRequest(self: AnyRequestContext, req: *uws.Request) void {
+ if (self.tagged_pointer.isNull()) {
+ return;
+ }
+
+ switch (self.tagged_pointer.tag()) {
+ @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(HTTPServer.RequestContext))) => {
+ self.tagged_pointer.as(HTTPServer.RequestContext).req = req;
+ },
+ @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(HTTPSServer.RequestContext))) => {
+ self.tagged_pointer.as(HTTPSServer.RequestContext).req = req;
+ },
+ @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(DebugHTTPServer.RequestContext))) => {
+ self.tagged_pointer.as(DebugHTTPServer.RequestContext).req = req;
+ },
+ @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(DebugHTTPSServer.RequestContext))) => {
+ self.tagged_pointer.as(DebugHTTPSServer.RequestContext).req = req;
+ },
+ else => @panic("Unexpected AnyRequestContext tag"),
+ }
+ }
+
+ pub fn getRequest(self: AnyRequestContext) ?*uws.Request {
+ if (self.tagged_pointer.isNull()) {
+ return null;
+ }
+
+ switch (self.tagged_pointer.tag()) {
+ @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(HTTPServer.RequestContext))) => {
+ return self.tagged_pointer.as(HTTPServer.RequestContext).req;
+ },
+ @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(HTTPSServer.RequestContext))) => {
+ return self.tagged_pointer.as(HTTPSServer.RequestContext).req;
+ },
+ @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(DebugHTTPServer.RequestContext))) => {
+ return self.tagged_pointer.as(DebugHTTPServer.RequestContext).req;
+ },
+ @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(DebugHTTPSServer.RequestContext))) => {
+ return self.tagged_pointer.as(DebugHTTPSServer.RequestContext).req;
+ },
+ else => @panic("Unexpected AnyRequestContext tag"),
+ }
+ }
+};
+
// This is defined separately partially to work-around an LLVM debugger bug.
fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comptime ThisServer: type) type {
return struct {
@@ -1443,6 +1529,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
pub fn endStream(this: *RequestContext, closeConnection: bool) void {
+ ctxLog("endStream", .{});
if (this.resp) |resp| {
if (this.flags.is_waiting_for_request_body) {
this.flags.is_waiting_for_request_body = false;
@@ -1537,8 +1624,17 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
pub fn onAbort(this: *RequestContext, resp: *App.Response) void {
std.debug.assert(this.resp == resp);
std.debug.assert(!this.flags.aborted);
- //mark request as aborted
+ // mark request as aborted
this.flags.aborted = true;
+ var any_js_calls = false;
+ var vm = this.server.vm;
+ defer {
+ // This is a task in the event loop.
+ // If we called into JavaScript, we must drain the microtask queue
+ if (any_js_calls) {
+ vm.drainMicrotasks();
+ }
+ }
// if signal is not aborted, abort the signal
if (this.signal) |signal| {
@@ -1547,6 +1643,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
const reason = JSC.WebCore.AbortSignal.createAbortError(JSC.ZigString.static("The user aborted a request"), &JSC.ZigString.Empty, this.server.globalThis);
reason.ensureStillAlive();
_ = signal.signal(reason);
+ any_js_calls = true;
}
_ = signal.unref();
}
@@ -1578,6 +1675,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
} else if (body.value.Locked.readable != null) {
body.value.Locked.readable.?.abort(this.server.globalThis);
body.value.Locked.readable = null;
+ any_js_calls = true;
}
body.value.toErrorInstance(JSC.toTypeError(.ABORT_ERR, "Request aborted", .{}, this.server.globalThis), this.server.globalThis);
}
@@ -1588,6 +1686,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
if (response.body.value.Locked.readable) |*readable| {
response.body.value.Locked.readable = null;
readable.abort(this.server.globalThis);
+ any_js_calls = true;
}
}
}
@@ -1597,10 +1696,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
this.pending_promises_for_abort += 1;
this.promise = null;
promise.asAnyPromise().?.reject(this.server.globalThis, JSC.toTypeError(.ABORT_ERR, "Request aborted", .{}, this.server.globalThis));
- }
-
- if (this.pending_promises_for_abort > 0) {
- this.server.vm.tick();
+ any_js_calls = true;
}
}
}
@@ -1720,6 +1816,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
this: *RequestContext,
headers: *JSC.FetchHeaders,
) void {
+ ctxLog("writeHeaders", .{});
headers.fastRemove(.ContentLength);
headers.fastRemove(.TransferEncoding);
if (!ssl_enabled) headers.fastRemove(.StrictTransportSecurity);
@@ -2091,6 +2188,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
fn doRenderStream(pair: *StreamPair) void {
+ ctxLog("doRenderStream", .{});
var this = pair.this;
var stream = pair.stream;
if (this.resp == null or this.flags.aborted) {
@@ -2214,6 +2312,14 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
},
}
return;
+ } else {
+ // if is not a promise we treat it as Error
+ streamLog("returned an error", .{});
+ if (!this.flags.aborted) resp.clearAborted();
+ response_stream.detach();
+ this.sink = null;
+ response_stream.sink.destroy();
+ return this.handleReject(assignment_result);
}
}
@@ -2223,6 +2329,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
defer stream.value.unprotect();
response_stream.sink.markDone();
this.finalizeForAbort();
+ response_stream.sink.onFirstWrite = null;
response_stream.sink.finalize();
return;
@@ -2246,7 +2353,12 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
this.setAbortHandler();
streamLog("is in progress, but did not return a Promise. Finalizing request context", .{});
- this.finalize();
+ response_stream.sink.onFirstWrite = null;
+ response_stream.sink.ctx = null;
+ response_stream.detach();
+ stream.cancel(globalThis);
+ response_stream.sink.markDone();
+ this.renderMissing();
}
const streamLog = Output.scoped(.ReadableStream, false);
@@ -2256,7 +2368,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
fn toAsyncWithoutAbortHandler(ctx: *RequestContext, req: *uws.Request, request_object: *Request) void {
- request_object.uws_request = req;
+ request_object.request_context.setRequest(req);
request_object.ensureURL() catch {
request_object.url = bun.String.empty;
@@ -2269,7 +2381,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
// This object dies after the stack frame is popped
// so we have to clear it in here too
- request_object.uws_request = null;
+ request_object.request_context = JSC.API.AnyRequestContext.Null;
}
fn toAsync(
@@ -2446,7 +2558,6 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
streamLog("onResolve({any})", .{wrote_anything});
-
//aborted so call finalizeForAbort
if (req.flags.aborted or req.resp == null) {
req.finalizeForAbort();
@@ -2723,7 +2834,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
pub fn doRender(this: *RequestContext) void {
- ctxLog("render", .{});
+ ctxLog("doRender", .{});
if (this.flags.aborted) {
this.finalizeForAbort();
@@ -2877,7 +2988,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
var needs_content_type = true;
const content_type: MimeType = brk: {
- if (response.body.init.headers) |headers_| {
+ if (response.init.headers) |headers_| {
if (headers_.fastGet(.ContentType)) |content| {
needs_content_type = false;
break :brk MimeType.byName(content.slice());
@@ -2897,7 +3008,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
};
var has_content_disposition = false;
- if (response.body.init.headers) |headers_| {
+ if (response.init.headers) |headers_| {
has_content_disposition = headers_.fastHas(.ContentDisposition);
needs_content_range = needs_content_range and headers_.fastHas(.ContentRange);
if (needs_content_range) {
@@ -2907,7 +3018,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
this.writeStatus(status);
this.writeHeaders(headers_);
- response.body.init.headers = null;
+ response.init.headers = null;
headers_.deref();
} else if (needs_content_range) {
status = 206;
@@ -3039,7 +3150,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
if (last) {
var bytes = this.request_body_buf;
- defer this.request_body_buf = .{};
+
var old = body.value;
const total = bytes.items.len + chunk.len;
@@ -3070,6 +3181,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
};
// }
}
+ this.request_body_buf = .{};
if (old == .Locked) {
var vm = this.server.vm;
@@ -3142,6 +3254,10 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
return onStartStreamingRequestBody(bun.cast(*RequestContext, this));
}
+ pub fn getRemoteSocketInfo(this: *RequestContext) ?uws.SocketAddress {
+ return (this.resp orelse return null).getRemoteSocketInfo();
+ }
+
pub const Export = shim.exportFunctions(.{
.onResolve = onResolve,
.onReject = onReject,
@@ -4667,17 +4783,6 @@ pub const ServerWebSocket = struct {
return JSValue.jsBoolean(this.websocket.isSubscribed(topic.slice()));
}
- // pub fn getTopics(
- // this: *ServerWebSocket,
- // globalThis: *JSC.JSGlobalObject,
- // ) callconv(.C) JSValue {
- // if (this.closed) {
- // return JSValue.createStringArray(globalThis, bun.default_allocator, null, 0, false);
- // }
-
- // this
- // }
-
pub fn getRemoteAddress(
this: *ServerWebSocket,
globalThis: *JSC.JSGlobalObject,
@@ -4704,7 +4809,7 @@ pub const ServerWebSocket = struct {
pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comptime debug_mode_: bool) type {
return struct {
pub const ssl_enabled = ssl_enabled_;
- const debug_mode = debug_mode_;
+ pub const debug_mode = debug_mode_;
const ThisServer = @This();
pub const RequestContext = NewRequestContext(ssl_enabled, debug_mode, @This());
@@ -4742,6 +4847,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
pub const doPublish = JSC.wrapInstanceMethod(ThisServer, "publish", false);
pub const doReload = onReload;
pub const doFetch = onFetch;
+ pub const doRequestIP = JSC.wrapInstanceMethod(ThisServer, "requestIP", false);
pub usingnamespace NamespaceType;
@@ -4749,6 +4855,24 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
globalThis.throw("Server() is not a constructor", .{});
return null;
}
+
+ extern fn JSSocketAddress__create(global: *JSC.JSGlobalObject, ip: JSValue, port: i32, is_ipv6: bool) JSValue;
+
+ pub fn requestIP(this: *ThisServer, request: *JSC.WebCore.Request) JSC.JSValue {
+ if (this.config.address == .unix) {
+ return JSValue.jsNull();
+ }
+ return if (request.request_context.getRemoteSocketInfo()) |info|
+ JSSocketAddress__create(
+ this.globalThis,
+ bun.String.static(info.ip).toJSConst(this.globalThis),
+ info.port,
+ info.is_ipv6,
+ )
+ else
+ JSValue.jsNull();
+ }
+
pub fn publish(this: *ThisServer, globalThis: *JSC.JSGlobalObject, topic: ZigString, message_value: JSValue, compress_value: ?JSValue, exception: JSC.C.ExceptionRef) JSValue {
if (this.config.websocket == null)
return JSValue.jsNumber(0);
@@ -5092,7 +5216,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
return JSPromise.rejectedPromiseValue(ctx, err);
}
- var request = ctx.bunVM().allocator.create(Request) catch unreachable;
+ var request = bun.default_allocator.create(Request) catch unreachable;
request.* = existing_request;
const response_value = this.config.onRequest.callWithThis(
@@ -5173,6 +5297,37 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
return JSC.JSValue.jsNumber(@as(i32, @intCast(@as(u31, @truncate(this.activeSocketsCount())))));
}
+ pub fn getAddress(this: *ThisServer, globalThis: *JSGlobalObject) callconv(.C) JSC.JSValue {
+ switch (this.config.address) {
+ .unix => |unix| {
+ var value = bun.String.create(bun.sliceTo(@constCast(unix), 0));
+ defer value.deref();
+ return value.toJS(globalThis);
+ },
+ .tcp => {
+ var port: u16 = this.config.address.tcp.port;
+
+ if (this.listener) |listener| {
+ port = @intCast(listener.getLocalPort());
+
+ var buf: [64]u8 = [_]u8{0} ** 64;
+ var is_ipv6: bool = false;
+
+ if (listener.socket().localAddressText(&buf, &is_ipv6)) |slice| {
+ var ip = bun.String.create(slice);
+ return JSSocketAddress__create(
+ this.globalThis,
+ ip.toJS(this.globalThis),
+ port,
+ is_ipv6,
+ );
+ }
+ }
+ return JSValue.jsNull();
+ },
+ }
+ }
+
pub fn getHostname(this: *ThisServer, globalThis: *JSGlobalObject) callconv(.C) JSC.JSValue {
if (this.cached_hostname.isEmpty()) {
if (this.listener) |listener| {
@@ -5254,6 +5409,10 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
var listener = this.listener orelse return;
this.listener = null;
this.unref();
+
+ if (!ssl_enabled_)
+ this.vm.removeListeningSocketForWatchMode(@intCast(listener.socket().fd()));
+
if (!abrupt) {
listener.close();
} else if (!this.flags.terminated) {
@@ -5388,24 +5547,18 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
if (error_instance == .zero) {
switch (this.config.address) {
.tcp => |tcp| {
- error_instance = ZigString.init(
- std.fmt.bufPrint(&output_buf, "Failed to start server. Is port {d} in use?", .{tcp.port}) catch "Failed to start server",
- ).toErrorInstance(
- this.globalThis,
- );
+ error_instance = (JSC.SystemError{
+ .message = bun.String.init(std.fmt.bufPrint(&output_buf, "Failed to start server. Is port {d} in use?", .{tcp.port}) catch "Failed to start server"),
+ .code = bun.String.static("EADDRINUSE"),
+ .syscall = bun.String.static("listen"),
+ }).toErrorInstance(this.globalThis);
},
.unix => |unix| {
- error_instance = ZigString.init(
- std.fmt.bufPrint(
- &output_buf,
- "Failed to listen on unix socket {}",
- .{
- strings.QuotedFormatter{ .text = bun.sliceTo(unix, 0) },
- },
- ) catch "Failed to start server",
- ).toErrorInstance(
- this.globalThis,
- );
+ error_instance = (JSC.SystemError{
+ .message = bun.String.init(std.fmt.bufPrint(&output_buf, "Failed to listen on unix socket {}", .{strings.QuotedFormatter{ .text = bun.sliceTo(unix, 0) }}) catch "Failed to start server"),
+ .code = bun.String.static("EADDRINUSE"),
+ .syscall = bun.String.static("listen"),
+ }).toErrorInstance(this.globalThis);
},
}
}
@@ -5428,6 +5581,8 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
this.listener = socket;
this.vm.event_loop_handle = uws.Loop.get();
+ if (!ssl_enabled_)
+ this.vm.addListeningSocketForWatchMode(@intCast(socket.?.socket().fd()));
}
pub fn ref(this: *ThisServer) void {
@@ -5512,21 +5667,19 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
req.setYield(false);
var ctx = this.request_pool_allocator.tryGet() catch @panic("ran out of memory");
ctx.create(this, req, resp);
+ this.vm.jsc.reportExtraMemory(@sizeOf(RequestContext));
var request_object = this.allocator.create(JSC.WebCore.Request) catch unreachable;
var body = JSC.WebCore.InitRequestBodyValue(.{ .Null = {} }) catch unreachable;
ctx.request_body = body;
- const js_signal = JSC.WebCore.AbortSignal.create(this.globalThis);
- js_signal.ensureStillAlive();
- if (JSC.WebCore.AbortSignal.fromJS(js_signal)) |signal| {
- ctx.signal = signal.ref().ref(); // +2 refs 1 for the request and 1 for the request context
- }
+ var signal = JSC.WebCore.AbortSignal.new(this.globalThis);
+ ctx.signal = signal;
request_object.* = .{
.method = ctx.method,
- .uws_request = req,
+ .request_context = AnyRequestContext.init(ctx),
.https = ssl_enabled,
- .signal = ctx.signal,
+ .signal = signal.ref(),
.body = body.ref(),
};
@@ -5593,7 +5746,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
const response_value = this.config.onRequest.callWithThis(this.globalThis, this.thisObject, &args);
defer {
// uWS request will not live longer than this function
- request_object.uws_request = null;
+ request_object.request_context = JSC.API.AnyRequestContext.Null;
}
var should_deinit_context = false;
@@ -5608,7 +5761,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
ctx.defer_deinit_until_callback_completes = null;
if (should_deinit_context) {
- request_object.uws_request = null;
+ request_object.request_context = JSC.API.AnyRequestContext.Null;
ctx.deinit();
return;
}
@@ -5637,18 +5790,15 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
var body = JSC.WebCore.InitRequestBodyValue(.{ .Null = {} }) catch unreachable;
ctx.request_body = body;
- const js_signal = JSC.WebCore.AbortSignal.create(this.globalThis);
- js_signal.ensureStillAlive();
- if (JSC.WebCore.AbortSignal.fromJS(js_signal)) |signal| {
- ctx.signal = signal.ref().ref(); // +2 refs 1 for the request and 1 for the request context
- }
+ var signal = JSC.WebCore.AbortSignal.new(this.globalThis);
+ ctx.signal = signal;
request_object.* = .{
.method = ctx.method,
- .uws_request = req,
+ .request_context = AnyRequestContext.init(ctx),
.upgrader = ctx,
.https = ssl_enabled,
- .signal = ctx.signal,
+ .signal = signal.ref(),
.body = body.ref(),
};
ctx.upgrade_context = upgrade_ctx;
@@ -5663,7 +5813,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
const response_value = this.config.onRequest.callWithThis(this.globalThis, this.thisObject, &args);
defer {
// uWS request will not live longer than this function
- request_object.uws_request = null;
+ request_object.request_context = JSC.API.AnyRequestContext.Null;
}
var should_deinit_context = false;
@@ -5678,7 +5828,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
ctx.defer_deinit_until_callback_completes = null;
if (should_deinit_context) {
- request_object.uws_request = null;
+ request_object.request_context = JSC.API.AnyRequestContext.Null;
ctx.deinit();
return;
}