diff options
Diffstat (limited to 'src/bun.js/api/server.zig')
-rw-r--r-- | src/bun.js/api/server.zig | 135 |
1 files changed, 115 insertions, 20 deletions
diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig index fd6cc1f2b..9d4a8a133 100644 --- a/src/bun.js/api/server.zig +++ b/src/bun.js/api/server.zig @@ -1125,6 +1125,92 @@ fn NewFlags(comptime debug_mode: bool) type { }; } +/// A generic wrapper for the HTTP(s) Server`RequestContext`s. +/// Only really exists because of `NewServer()` and `NewRequestContext()` generics. +pub const AnyRequestContext = struct { + pub const Pointer = bun.TaggedPointerUnion(.{ + HTTPServer.RequestContext, + HTTPSServer.RequestContext, + DebugHTTPServer.RequestContext, + DebugHTTPSServer.RequestContext, + }); + + tagged_pointer: Pointer, + + pub const Null = .{ .tagged_pointer = Pointer.Null }; + + pub fn init(request_ctx: anytype) AnyRequestContext { + return .{ .tagged_pointer = Pointer.init(request_ctx) }; + } + + pub fn getRemoteSocketInfo(self: AnyRequestContext) ?uws.SocketAddress { + if (self.tagged_pointer.isNull()) { + return null; + } + + switch (self.tagged_pointer.tag()) { + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(HTTPServer.RequestContext))) => { + return self.tagged_pointer.as(HTTPServer.RequestContext).getRemoteSocketInfo(); + }, + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(HTTPSServer.RequestContext))) => { + return self.tagged_pointer.as(HTTPSServer.RequestContext).getRemoteSocketInfo(); + }, + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(DebugHTTPServer.RequestContext))) => { + return self.tagged_pointer.as(DebugHTTPServer.RequestContext).getRemoteSocketInfo(); + }, + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(DebugHTTPSServer.RequestContext))) => { + return self.tagged_pointer.as(DebugHTTPSServer.RequestContext).getRemoteSocketInfo(); + }, + else => @panic("Unexpected AnyRequestContext tag"), + } + } + + /// Wont actually set anything if `self` is `.none` + pub fn setRequest(self: AnyRequestContext, req: *uws.Request) void { + if (self.tagged_pointer.isNull()) { + return; + } + + switch (self.tagged_pointer.tag()) { + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(HTTPServer.RequestContext))) => { + self.tagged_pointer.as(HTTPServer.RequestContext).req = req; + }, + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(HTTPSServer.RequestContext))) => { + self.tagged_pointer.as(HTTPSServer.RequestContext).req = req; + }, + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(DebugHTTPServer.RequestContext))) => { + self.tagged_pointer.as(DebugHTTPServer.RequestContext).req = req; + }, + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(DebugHTTPSServer.RequestContext))) => { + self.tagged_pointer.as(DebugHTTPSServer.RequestContext).req = req; + }, + else => @panic("Unexpected AnyRequestContext tag"), + } + } + + pub fn getRequest(self: AnyRequestContext) ?*uws.Request { + if (self.tagged_pointer.isNull()) { + return null; + } + + switch (self.tagged_pointer.tag()) { + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(HTTPServer.RequestContext))) => { + return self.tagged_pointer.as(HTTPServer.RequestContext).req; + }, + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(HTTPSServer.RequestContext))) => { + return self.tagged_pointer.as(HTTPSServer.RequestContext).req; + }, + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(DebugHTTPServer.RequestContext))) => { + return self.tagged_pointer.as(DebugHTTPServer.RequestContext).req; + }, + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(DebugHTTPSServer.RequestContext))) => { + return self.tagged_pointer.as(DebugHTTPSServer.RequestContext).req; + }, + else => @panic("Unexpected AnyRequestContext tag"), + } + } +}; + // This is defined separately partially to work-around an LLVM debugger bug. fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comptime ThisServer: type) type { return struct { @@ -2282,7 +2368,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp } fn toAsyncWithoutAbortHandler(ctx: *RequestContext, req: *uws.Request, request_object: *Request) void { - request_object.uws_request = req; + request_object.request_context.setRequest(req); request_object.ensureURL() catch { request_object.url = bun.String.empty; @@ -2295,7 +2381,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp // This object dies after the stack frame is popped // so we have to clear it in here too - request_object.uws_request = null; + request_object.request_context = JSC.API.AnyRequestContext.Null; } fn toAsync( @@ -3168,6 +3254,10 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp return onStartStreamingRequestBody(bun.cast(*RequestContext, this)); } + pub fn getRemoteSocketInfo(this: *RequestContext) ?uws.SocketAddress { + return (this.resp orelse return null).getRemoteSocketInfo(); + } + pub const Export = shim.exportFunctions(.{ .onResolve = onResolve, .onReject = onReject, @@ -4693,17 +4783,6 @@ pub const ServerWebSocket = struct { return JSValue.jsBoolean(this.websocket.isSubscribed(topic.slice())); } - // pub fn getTopics( - // this: *ServerWebSocket, - // globalThis: *JSC.JSGlobalObject, - // ) callconv(.C) JSValue { - // if (this.closed) { - // return JSValue.createStringArray(globalThis, bun.default_allocator, null, 0, false); - // } - - // this - // } - pub fn getRemoteAddress( this: *ServerWebSocket, globalThis: *JSC.JSGlobalObject, @@ -4730,7 +4809,7 @@ pub const ServerWebSocket = struct { pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comptime debug_mode_: bool) type { return struct { pub const ssl_enabled = ssl_enabled_; - const debug_mode = debug_mode_; + pub const debug_mode = debug_mode_; const ThisServer = @This(); pub const RequestContext = NewRequestContext(ssl_enabled, debug_mode, @This()); @@ -4768,6 +4847,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp pub const doPublish = JSC.wrapInstanceMethod(ThisServer, "publish", false); pub const doReload = onReload; pub const doFetch = onFetch; + pub const doRequestIP = JSC.wrapInstanceMethod(ThisServer, "requestIP", false); pub usingnamespace NamespaceType; @@ -4775,6 +4855,21 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp globalThis.throw("Server() is not a constructor", .{}); return null; } + + extern fn JSSocketAddress__create(global: *JSC.JSGlobalObject, ip: JSValue, port: i32, is_ipv6: bool) JSValue; + + pub fn requestIP(this: *ThisServer, request: *JSC.WebCore.Request) JSC.JSValue { + return if (request.request_context.getRemoteSocketInfo()) |info| + JSSocketAddress__create( + this.globalThis, + bun.String.static(info.ip).toJSConst(this.globalThis), + info.port, + info.is_ipv6, + ) + else + JSValue.jsNull(); + } + pub fn publish(this: *ThisServer, globalThis: *JSC.JSGlobalObject, topic: ZigString, message_value: JSValue, compress_value: ?JSValue, exception: JSC.C.ExceptionRef) JSValue { if (this.config.websocket == null) return JSValue.jsNumber(0); @@ -5551,7 +5646,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp request_object.* = .{ .method = ctx.method, - .uws_request = req, + .request_context = AnyRequestContext.init(ctx), .https = ssl_enabled, .signal = ctx.signal, .body = body.ref(), @@ -5620,7 +5715,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp const response_value = this.config.onRequest.callWithThis(this.globalThis, this.thisObject, &args); defer { // uWS request will not live longer than this function - request_object.uws_request = null; + request_object.request_context = JSC.API.AnyRequestContext.Null; } var should_deinit_context = false; @@ -5635,7 +5730,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp ctx.defer_deinit_until_callback_completes = null; if (should_deinit_context) { - request_object.uws_request = null; + request_object.request_context = JSC.API.AnyRequestContext.Null; ctx.deinit(); return; } @@ -5672,7 +5767,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp request_object.* = .{ .method = ctx.method, - .uws_request = req, + .request_context = AnyRequestContext.init(ctx), .upgrader = ctx, .https = ssl_enabled, .signal = ctx.signal, @@ -5690,7 +5785,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp const response_value = this.config.onRequest.callWithThis(this.globalThis, this.thisObject, &args); defer { // uWS request will not live longer than this function - request_object.uws_request = null; + request_object.request_context = JSC.API.AnyRequestContext.Null; } var should_deinit_context = false; @@ -5705,7 +5800,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp ctx.defer_deinit_until_callback_completes = null; if (should_deinit_context) { - request_object.uws_request = null; + request_object.request_context = JSC.API.AnyRequestContext.Null; ctx.deinit(); return; } |