diff options
Diffstat (limited to 'src/bun.js/api/server.zig')
-rw-r--r-- | src/bun.js/api/server.zig | 272 |
1 files changed, 211 insertions, 61 deletions
diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig index 85d4dadb5..edf1d6d69 100644 --- a/src/bun.js/api/server.zig +++ b/src/bun.js/api/server.zig @@ -1125,6 +1125,92 @@ fn NewFlags(comptime debug_mode: bool) type { }; } +/// A generic wrapper for the HTTP(s) Server`RequestContext`s. +/// Only really exists because of `NewServer()` and `NewRequestContext()` generics. +pub const AnyRequestContext = struct { + pub const Pointer = bun.TaggedPointerUnion(.{ + HTTPServer.RequestContext, + HTTPSServer.RequestContext, + DebugHTTPServer.RequestContext, + DebugHTTPSServer.RequestContext, + }); + + tagged_pointer: Pointer, + + pub const Null = .{ .tagged_pointer = Pointer.Null }; + + pub fn init(request_ctx: anytype) AnyRequestContext { + return .{ .tagged_pointer = Pointer.init(request_ctx) }; + } + + pub fn getRemoteSocketInfo(self: AnyRequestContext) ?uws.SocketAddress { + if (self.tagged_pointer.isNull()) { + return null; + } + + switch (self.tagged_pointer.tag()) { + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(HTTPServer.RequestContext))) => { + return self.tagged_pointer.as(HTTPServer.RequestContext).getRemoteSocketInfo(); + }, + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(HTTPSServer.RequestContext))) => { + return self.tagged_pointer.as(HTTPSServer.RequestContext).getRemoteSocketInfo(); + }, + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(DebugHTTPServer.RequestContext))) => { + return self.tagged_pointer.as(DebugHTTPServer.RequestContext).getRemoteSocketInfo(); + }, + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(DebugHTTPSServer.RequestContext))) => { + return self.tagged_pointer.as(DebugHTTPSServer.RequestContext).getRemoteSocketInfo(); + }, + else => @panic("Unexpected AnyRequestContext tag"), + } + } + + /// Wont actually set anything if `self` is `.none` + pub fn setRequest(self: AnyRequestContext, req: *uws.Request) void { + if (self.tagged_pointer.isNull()) { + return; + } + + switch (self.tagged_pointer.tag()) { + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(HTTPServer.RequestContext))) => { + self.tagged_pointer.as(HTTPServer.RequestContext).req = req; + }, + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(HTTPSServer.RequestContext))) => { + self.tagged_pointer.as(HTTPSServer.RequestContext).req = req; + }, + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(DebugHTTPServer.RequestContext))) => { + self.tagged_pointer.as(DebugHTTPServer.RequestContext).req = req; + }, + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(DebugHTTPSServer.RequestContext))) => { + self.tagged_pointer.as(DebugHTTPSServer.RequestContext).req = req; + }, + else => @panic("Unexpected AnyRequestContext tag"), + } + } + + pub fn getRequest(self: AnyRequestContext) ?*uws.Request { + if (self.tagged_pointer.isNull()) { + return null; + } + + switch (self.tagged_pointer.tag()) { + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(HTTPServer.RequestContext))) => { + return self.tagged_pointer.as(HTTPServer.RequestContext).req; + }, + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(HTTPSServer.RequestContext))) => { + return self.tagged_pointer.as(HTTPSServer.RequestContext).req; + }, + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(DebugHTTPServer.RequestContext))) => { + return self.tagged_pointer.as(DebugHTTPServer.RequestContext).req; + }, + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(DebugHTTPSServer.RequestContext))) => { + return self.tagged_pointer.as(DebugHTTPSServer.RequestContext).req; + }, + else => @panic("Unexpected AnyRequestContext tag"), + } + } +}; + // This is defined separately partially to work-around an LLVM debugger bug. fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comptime ThisServer: type) type { return struct { @@ -1443,6 +1529,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp } pub fn endStream(this: *RequestContext, closeConnection: bool) void { + ctxLog("endStream", .{}); if (this.resp) |resp| { if (this.flags.is_waiting_for_request_body) { this.flags.is_waiting_for_request_body = false; @@ -1537,8 +1624,17 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp pub fn onAbort(this: *RequestContext, resp: *App.Response) void { std.debug.assert(this.resp == resp); std.debug.assert(!this.flags.aborted); - //mark request as aborted + // mark request as aborted this.flags.aborted = true; + var any_js_calls = false; + var vm = this.server.vm; + defer { + // This is a task in the event loop. + // If we called into JavaScript, we must drain the microtask queue + if (any_js_calls) { + vm.drainMicrotasks(); + } + } // if signal is not aborted, abort the signal if (this.signal) |signal| { @@ -1547,6 +1643,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp const reason = JSC.WebCore.AbortSignal.createAbortError(JSC.ZigString.static("The user aborted a request"), &JSC.ZigString.Empty, this.server.globalThis); reason.ensureStillAlive(); _ = signal.signal(reason); + any_js_calls = true; } _ = signal.unref(); } @@ -1578,6 +1675,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp } else if (body.value.Locked.readable != null) { body.value.Locked.readable.?.abort(this.server.globalThis); body.value.Locked.readable = null; + any_js_calls = true; } body.value.toErrorInstance(JSC.toTypeError(.ABORT_ERR, "Request aborted", .{}, this.server.globalThis), this.server.globalThis); } @@ -1588,6 +1686,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp if (response.body.value.Locked.readable) |*readable| { response.body.value.Locked.readable = null; readable.abort(this.server.globalThis); + any_js_calls = true; } } } @@ -1597,10 +1696,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp this.pending_promises_for_abort += 1; this.promise = null; promise.asAnyPromise().?.reject(this.server.globalThis, JSC.toTypeError(.ABORT_ERR, "Request aborted", .{}, this.server.globalThis)); - } - - if (this.pending_promises_for_abort > 0) { - this.server.vm.tick(); + any_js_calls = true; } } } @@ -1720,6 +1816,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp this: *RequestContext, headers: *JSC.FetchHeaders, ) void { + ctxLog("writeHeaders", .{}); headers.fastRemove(.ContentLength); headers.fastRemove(.TransferEncoding); if (!ssl_enabled) headers.fastRemove(.StrictTransportSecurity); @@ -2091,6 +2188,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp } fn doRenderStream(pair: *StreamPair) void { + ctxLog("doRenderStream", .{}); var this = pair.this; var stream = pair.stream; if (this.resp == null or this.flags.aborted) { @@ -2214,6 +2312,14 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp }, } return; + } else { + // if is not a promise we treat it as Error + streamLog("returned an error", .{}); + if (!this.flags.aborted) resp.clearAborted(); + response_stream.detach(); + this.sink = null; + response_stream.sink.destroy(); + return this.handleReject(assignment_result); } } @@ -2223,6 +2329,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp defer stream.value.unprotect(); response_stream.sink.markDone(); this.finalizeForAbort(); + response_stream.sink.onFirstWrite = null; response_stream.sink.finalize(); return; @@ -2246,7 +2353,12 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp this.setAbortHandler(); streamLog("is in progress, but did not return a Promise. Finalizing request context", .{}); - this.finalize(); + response_stream.sink.onFirstWrite = null; + response_stream.sink.ctx = null; + response_stream.detach(); + stream.cancel(globalThis); + response_stream.sink.markDone(); + this.renderMissing(); } const streamLog = Output.scoped(.ReadableStream, false); @@ -2256,7 +2368,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp } fn toAsyncWithoutAbortHandler(ctx: *RequestContext, req: *uws.Request, request_object: *Request) void { - request_object.uws_request = req; + request_object.request_context.setRequest(req); request_object.ensureURL() catch { request_object.url = bun.String.empty; @@ -2269,7 +2381,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp // This object dies after the stack frame is popped // so we have to clear it in here too - request_object.uws_request = null; + request_object.request_context = JSC.API.AnyRequestContext.Null; } fn toAsync( @@ -2446,7 +2558,6 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp } streamLog("onResolve({any})", .{wrote_anything}); - //aborted so call finalizeForAbort if (req.flags.aborted or req.resp == null) { req.finalizeForAbort(); @@ -2723,7 +2834,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp } pub fn doRender(this: *RequestContext) void { - ctxLog("render", .{}); + ctxLog("doRender", .{}); if (this.flags.aborted) { this.finalizeForAbort(); @@ -2877,7 +2988,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp var needs_content_type = true; const content_type: MimeType = brk: { - if (response.body.init.headers) |headers_| { + if (response.init.headers) |headers_| { if (headers_.fastGet(.ContentType)) |content| { needs_content_type = false; break :brk MimeType.byName(content.slice()); @@ -2897,7 +3008,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp }; var has_content_disposition = false; - if (response.body.init.headers) |headers_| { + if (response.init.headers) |headers_| { has_content_disposition = headers_.fastHas(.ContentDisposition); needs_content_range = needs_content_range and headers_.fastHas(.ContentRange); if (needs_content_range) { @@ -2907,7 +3018,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp this.writeStatus(status); this.writeHeaders(headers_); - response.body.init.headers = null; + response.init.headers = null; headers_.deref(); } else if (needs_content_range) { status = 206; @@ -3039,7 +3150,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp if (last) { var bytes = this.request_body_buf; - defer this.request_body_buf = .{}; + var old = body.value; const total = bytes.items.len + chunk.len; @@ -3070,6 +3181,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp }; // } } + this.request_body_buf = .{}; if (old == .Locked) { var vm = this.server.vm; @@ -3142,6 +3254,10 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp return onStartStreamingRequestBody(bun.cast(*RequestContext, this)); } + pub fn getRemoteSocketInfo(this: *RequestContext) ?uws.SocketAddress { + return (this.resp orelse return null).getRemoteSocketInfo(); + } + pub const Export = shim.exportFunctions(.{ .onResolve = onResolve, .onReject = onReject, @@ -4667,17 +4783,6 @@ pub const ServerWebSocket = struct { return JSValue.jsBoolean(this.websocket.isSubscribed(topic.slice())); } - // pub fn getTopics( - // this: *ServerWebSocket, - // globalThis: *JSC.JSGlobalObject, - // ) callconv(.C) JSValue { - // if (this.closed) { - // return JSValue.createStringArray(globalThis, bun.default_allocator, null, 0, false); - // } - - // this - // } - pub fn getRemoteAddress( this: *ServerWebSocket, globalThis: *JSC.JSGlobalObject, @@ -4704,7 +4809,7 @@ pub const ServerWebSocket = struct { pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comptime debug_mode_: bool) type { return struct { pub const ssl_enabled = ssl_enabled_; - const debug_mode = debug_mode_; + pub const debug_mode = debug_mode_; const ThisServer = @This(); pub const RequestContext = NewRequestContext(ssl_enabled, debug_mode, @This()); @@ -4742,6 +4847,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp pub const doPublish = JSC.wrapInstanceMethod(ThisServer, "publish", false); pub const doReload = onReload; pub const doFetch = onFetch; + pub const doRequestIP = JSC.wrapInstanceMethod(ThisServer, "requestIP", false); pub usingnamespace NamespaceType; @@ -4749,6 +4855,24 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp globalThis.throw("Server() is not a constructor", .{}); return null; } + + extern fn JSSocketAddress__create(global: *JSC.JSGlobalObject, ip: JSValue, port: i32, is_ipv6: bool) JSValue; + + pub fn requestIP(this: *ThisServer, request: *JSC.WebCore.Request) JSC.JSValue { + if (this.config.address == .unix) { + return JSValue.jsNull(); + } + return if (request.request_context.getRemoteSocketInfo()) |info| + JSSocketAddress__create( + this.globalThis, + bun.String.static(info.ip).toJSConst(this.globalThis), + info.port, + info.is_ipv6, + ) + else + JSValue.jsNull(); + } + pub fn publish(this: *ThisServer, globalThis: *JSC.JSGlobalObject, topic: ZigString, message_value: JSValue, compress_value: ?JSValue, exception: JSC.C.ExceptionRef) JSValue { if (this.config.websocket == null) return JSValue.jsNumber(0); @@ -5092,7 +5216,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp return JSPromise.rejectedPromiseValue(ctx, err); } - var request = ctx.bunVM().allocator.create(Request) catch unreachable; + var request = bun.default_allocator.create(Request) catch unreachable; request.* = existing_request; const response_value = this.config.onRequest.callWithThis( @@ -5173,6 +5297,37 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp return JSC.JSValue.jsNumber(@as(i32, @intCast(@as(u31, @truncate(this.activeSocketsCount()))))); } + pub fn getAddress(this: *ThisServer, globalThis: *JSGlobalObject) callconv(.C) JSC.JSValue { + switch (this.config.address) { + .unix => |unix| { + var value = bun.String.create(bun.sliceTo(@constCast(unix), 0)); + defer value.deref(); + return value.toJS(globalThis); + }, + .tcp => { + var port: u16 = this.config.address.tcp.port; + + if (this.listener) |listener| { + port = @intCast(listener.getLocalPort()); + + var buf: [64]u8 = [_]u8{0} ** 64; + var is_ipv6: bool = false; + + if (listener.socket().localAddressText(&buf, &is_ipv6)) |slice| { + var ip = bun.String.create(slice); + return JSSocketAddress__create( + this.globalThis, + ip.toJS(this.globalThis), + port, + is_ipv6, + ); + } + } + return JSValue.jsNull(); + }, + } + } + pub fn getHostname(this: *ThisServer, globalThis: *JSGlobalObject) callconv(.C) JSC.JSValue { if (this.cached_hostname.isEmpty()) { if (this.listener) |listener| { @@ -5254,6 +5409,10 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp var listener = this.listener orelse return; this.listener = null; this.unref(); + + if (!ssl_enabled_) + this.vm.removeListeningSocketForWatchMode(@intCast(listener.socket().fd())); + if (!abrupt) { listener.close(); } else if (!this.flags.terminated) { @@ -5388,24 +5547,18 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp if (error_instance == .zero) { switch (this.config.address) { .tcp => |tcp| { - error_instance = ZigString.init( - std.fmt.bufPrint(&output_buf, "Failed to start server. Is port {d} in use?", .{tcp.port}) catch "Failed to start server", - ).toErrorInstance( - this.globalThis, - ); + error_instance = (JSC.SystemError{ + .message = bun.String.init(std.fmt.bufPrint(&output_buf, "Failed to start server. Is port {d} in use?", .{tcp.port}) catch "Failed to start server"), + .code = bun.String.static("EADDRINUSE"), + .syscall = bun.String.static("listen"), + }).toErrorInstance(this.globalThis); }, .unix => |unix| { - error_instance = ZigString.init( - std.fmt.bufPrint( - &output_buf, - "Failed to listen on unix socket {}", - .{ - strings.QuotedFormatter{ .text = bun.sliceTo(unix, 0) }, - }, - ) catch "Failed to start server", - ).toErrorInstance( - this.globalThis, - ); + error_instance = (JSC.SystemError{ + .message = bun.String.init(std.fmt.bufPrint(&output_buf, "Failed to listen on unix socket {}", .{strings.QuotedFormatter{ .text = bun.sliceTo(unix, 0) }}) catch "Failed to start server"), + .code = bun.String.static("EADDRINUSE"), + .syscall = bun.String.static("listen"), + }).toErrorInstance(this.globalThis); }, } } @@ -5428,6 +5581,8 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp this.listener = socket; this.vm.event_loop_handle = uws.Loop.get(); + if (!ssl_enabled_) + this.vm.addListeningSocketForWatchMode(@intCast(socket.?.socket().fd())); } pub fn ref(this: *ThisServer) void { @@ -5512,21 +5667,19 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp req.setYield(false); var ctx = this.request_pool_allocator.tryGet() catch @panic("ran out of memory"); ctx.create(this, req, resp); + this.vm.jsc.reportExtraMemory(@sizeOf(RequestContext)); var request_object = this.allocator.create(JSC.WebCore.Request) catch unreachable; var body = JSC.WebCore.InitRequestBodyValue(.{ .Null = {} }) catch unreachable; ctx.request_body = body; - const js_signal = JSC.WebCore.AbortSignal.create(this.globalThis); - js_signal.ensureStillAlive(); - if (JSC.WebCore.AbortSignal.fromJS(js_signal)) |signal| { - ctx.signal = signal.ref().ref(); // +2 refs 1 for the request and 1 for the request context - } + var signal = JSC.WebCore.AbortSignal.new(this.globalThis); + ctx.signal = signal; request_object.* = .{ .method = ctx.method, - .uws_request = req, + .request_context = AnyRequestContext.init(ctx), .https = ssl_enabled, - .signal = ctx.signal, + .signal = signal.ref(), .body = body.ref(), }; @@ -5593,7 +5746,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp const response_value = this.config.onRequest.callWithThis(this.globalThis, this.thisObject, &args); defer { // uWS request will not live longer than this function - request_object.uws_request = null; + request_object.request_context = JSC.API.AnyRequestContext.Null; } var should_deinit_context = false; @@ -5608,7 +5761,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp ctx.defer_deinit_until_callback_completes = null; if (should_deinit_context) { - request_object.uws_request = null; + request_object.request_context = JSC.API.AnyRequestContext.Null; ctx.deinit(); return; } @@ -5637,18 +5790,15 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp var body = JSC.WebCore.InitRequestBodyValue(.{ .Null = {} }) catch unreachable; ctx.request_body = body; - const js_signal = JSC.WebCore.AbortSignal.create(this.globalThis); - js_signal.ensureStillAlive(); - if (JSC.WebCore.AbortSignal.fromJS(js_signal)) |signal| { - ctx.signal = signal.ref().ref(); // +2 refs 1 for the request and 1 for the request context - } + var signal = JSC.WebCore.AbortSignal.new(this.globalThis); + ctx.signal = signal; request_object.* = .{ .method = ctx.method, - .uws_request = req, + .request_context = AnyRequestContext.init(ctx), .upgrader = ctx, .https = ssl_enabled, - .signal = ctx.signal, + .signal = signal.ref(), .body = body.ref(), }; ctx.upgrade_context = upgrade_ctx; @@ -5663,7 +5813,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp const response_value = this.config.onRequest.callWithThis(this.globalThis, this.thisObject, &args); defer { // uWS request will not live longer than this function - request_object.uws_request = null; + request_object.request_context = JSC.API.AnyRequestContext.Null; } var should_deinit_context = false; @@ -5678,7 +5828,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp ctx.defer_deinit_until_callback_completes = null; if (should_deinit_context) { - request_object.uws_request = null; + request_object.request_context = JSC.API.AnyRequestContext.Null; ctx.deinit(); return; } |