diff options
Diffstat (limited to 'src/bun.js/api/server.zig')
-rw-r--r-- | src/bun.js/api/server.zig | 1261 |
1 files changed, 1157 insertions, 104 deletions
diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig index 11400dde5..9b26f4b11 100644 --- a/src/bun.js/api/server.zig +++ b/src/bun.js/api/server.zig @@ -109,6 +109,9 @@ pub const ServerConfig = struct { onError: JSC.JSValue = JSC.JSValue.zero, onRequest: JSC.JSValue = JSC.JSValue.zero, + websockets: WebSocketServer.List = .{}, + websocket: WebSocketServer = .{}, + pub const SSLConfig = struct { server_name: [*c]const u8 = null, @@ -274,12 +277,89 @@ pub const ServerConfig = struct { args.base_uri = origin; } + var websockets = WebSocketServer.List{}; + if (arguments.next()) |arg| { if (arg.isUndefinedOrNull() or !arg.isObject()) { JSC.throwInvalidArguments("Bun.serve expects an object", .{}, global, exception); return args; } + if (arg.getTruthy(global, "webSocket") orelse arg.getTruthy(global, "websocket")) |websocket_object| { + if (!websocket_object.isObject()) { + JSC.throwInvalidArguments("Expected webSocket to be an object", .{}, global, exception); + if (args.ssl_config) |*conf| { + conf.deinit(); + } + return args; + } + + if (WebSocketServer.onCreate(global, websocket_object)) |wss| { + args.websocket = wss; + } else { + if (args.ssl_config) |*conf| { + conf.deinit(); + } + return args; + } + } + + if (arg.getTruthy(global, "webSockets") orelse arg.getTruthy(global, "websockets")) |websocket_object| { + if (!websocket_object.isObject()) { + JSC.throwInvalidArguments("Expected webSockets to be an object", .{}, global, exception); + if (args.ssl_config) |*conf| { + conf.deinit(); + } + return args; + } + + var property_names = JSC.JSPropertyIterator(.{ + .include_value = true, + .skip_empty_name = true, + }).init(global, websocket_object.asObjectRef()); + + defer property_names.deinit(); + websockets.ensureTotalCapacity(bun.default_allocator, property_names.len) catch unreachable; + while (property_names.next()) |name| { + var str = name.toSlice(bun.default_allocator); + defer str.deinit(); + const slice = str.slice(); + if (slice.len == 0) continue; + if (slice[0] != '/') { + JSC.throwInvalidArguments("Expected webSocket path to start with /", .{}, global, exception); + if (args.ssl_config) |*conf| { + conf.deinit(); + } + return args; + } + + const object = property_names.value; + if (object.isEmptyOrUndefinedOrNull() or !object.isObject()) { + JSC.throwInvalidArguments("Expected webSocket to be an object", .{}, global, exception); + if (args.ssl_config) |*conf| { + conf.deinit(); + } + websockets.deinit(bun.default_allocator); + return args; + } + + const handler = WebSocketServer.onCreate(global, object) orelse { + if (args.ssl_config) |*conf| { + conf.deinit(); + } + websockets.deinit(bun.default_allocator); + return args; + }; + + websockets.putAssumeCapacity( + bun.span(bun.default_allocator.dupeZ(u8, slice) catch unreachable), + handler, + ); + } + + args.websockets = websockets; + } + if (arg.getTruthy(global, "port")) |port_| { args.port = @intCast(u16, @minimum(@maximum(0, port_.toInt32()), std.math.maxInt(u16))); } @@ -1386,6 +1466,118 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp const streamLog = Output.scoped(.ReadableStream, false); + pub fn onResponse( + ctx: *RequestContext, + this: *ThisServer, + req: *uws.Request, + request_object: *Request, + request_value: JSValue, + response_value: JSValue, + ) void { + request_value.ensureStillAlive(); + response_value.ensureStillAlive(); + + if (ctx.aborted) { + ctx.finalizeForAbort(); + return; + } + if (response_value.isEmptyOrUndefinedOrNull() and !ctx.resp.hasResponded()) { + ctx.renderMissing(); + return; + } + + if (response_value.isError() or response_value.isAggregateError(this.globalThis) or response_value.isException(this.globalThis.vm())) { + ctx.runErrorHandler(response_value); + + return; + } + + if (response_value.as(JSC.WebCore.Response)) |response| { + ctx.response_jsvalue = response_value; + ctx.response_jsvalue.ensureStillAlive(); + ctx.response_protected = false; + switch (response.body.value) { + .Blob => |*blob| { + if (blob.needsToReadFile()) { + response_value.protect(); + ctx.response_protected = true; + } + }, + .Locked => { + response_value.protect(); + ctx.response_protected = true; + }, + else => {}, + } + ctx.render(response); + return; + } + + var wait_for_promise = false; + var vm = this.vm; + + if (response_value.asPromise()) |promise| { + // If we immediately have the value available, we can skip the extra event loop tick + switch (promise.status(vm.global.vm())) { + .Pending => {}, + .Fulfilled => { + ctx.handleResolve(promise.result(vm.global.vm())); + return; + }, + .Rejected => { + ctx.handleReject(promise.result(vm.global.vm())); + return; + }, + } + wait_for_promise = true; + // I don't think this case should happen + // But I'm uncertain + } else if (response_value.asInternalPromise()) |promise| { + switch (promise.status(vm.global.vm())) { + .Pending => {}, + .Fulfilled => { + ctx.handleResolve(promise.result(vm.global.vm())); + return; + }, + .Rejected => { + ctx.handleReject(promise.result(vm.global.vm())); + return; + }, + } + wait_for_promise = true; + } + + if (wait_for_promise) { + request_object.uws_request = req; + + request_object.ensureURL() catch { + request_object.url = ""; + }; + + // we have to clone the request headers here since they will soon belong to a different request + if (request_object.headers == null) { + request_object.headers = JSC.FetchHeaders.createFromUWS(this.globalThis, req); + } + + if (comptime debug_mode) { + ctx.pathname = bun.default_allocator.dupe(u8, request_object.url) catch unreachable; + } + + // This object dies after the stack frame is popped + // so we have to clear it in here too + request_object.uws_request = null; + + ctx.setAbortHandler(); + ctx.pending_promises_for_abort += 1; + + response_value.then(this.globalThis, ctx, RequestContext.onResolve, RequestContext.onReject); + return; + } + + // The user returned something that wasn't a promise or a promise with a response + if (!ctx.resp.hasResponded()) ctx.renderMissing(); + } + pub fn handleResolveStream(req: *RequestContext) void { streamLog("onResolve", .{}); var wrote_anything = false; @@ -2060,6 +2252,738 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp }; } +pub const WebSocketServer = struct { + onOpen: JSC.JSValue = .zero, + onUpgrade: JSC.JSValue = .zero, + onMessage: JSC.JSValue = .zero, + onClose: JSC.JSValue = .zero, + onDrain: JSC.JSValue = .zero, + globalObject: *JSC.JSGlobalObject = undefined, + active_connections: usize = 0, + + maxPayloadLength: u32 = 1024 * 1024 * 16, + maxLifetime: u16 = 0, + idleTimeout: u16 = 120, + compression: i32 = 0, + backpressureLimit: u32 = 1024 * 1024 * 16, + sendPingsAutomatically: bool = true, + resetIdleTimeoutOnSend: bool = true, + closeOnBackpressureLimit: bool = false, + + pub fn toBehavior(this: WebSocketServer) uws.WebSocketBehavior { + return .{ + .maxPayloadLength = this.maxPayloadLength, + .idleTimeout = this.idleTimeout, + .compression = this.compression, + .maxBackpressure = this.backpressureLimit, + .sendPingsAutomatically = this.sendPingsAutomatically, + .maxLifetime = this.maxLifetime, + .resetIdleTimeoutOnSend = this.resetIdleTimeoutOnSend, + .closeOnBackpressureLimit = this.closeOnBackpressureLimit, + }; + } + + pub fn protect(this: WebSocketServer) void { + this.onUpgrade.protect(); + this.onOpen.protect(); + this.onMessage.protect(); + this.onClose.protect(); + this.onDrain.protect(); + } + + pub fn unprotect(this: WebSocketServer) void { + this.onUpgrade.unprotect(); + this.onOpen.unprotect(); + this.onMessage.unprotect(); + this.onClose.unprotect(); + this.onDrain.unprotect(); + } + + const CompressTable = bun.ComptimeStringMap(i32, .{ + .{ "disable", 0 }, + .{ "shared", uws.SHARED_COMPRESSOR }, + .{ "dedicated", uws.DEDICATED_COMPRESSOR }, + .{ "3KB", uws.DEDICATED_COMPRESSOR_3KB }, + .{ "4KB", uws.DEDICATED_COMPRESSOR_4KB }, + .{ "8KB", uws.DEDICATED_COMPRESSOR_8KB }, + .{ "16KB", uws.DEDICATED_COMPRESSOR_16KB }, + .{ "32KB", uws.DEDICATED_COMPRESSOR_32KB }, + .{ "64KB", uws.DEDICATED_COMPRESSOR_64KB }, + .{ "128KB", uws.DEDICATED_COMPRESSOR_128KB }, + .{ "256KB", uws.DEDICATED_COMPRESSOR_256KB }, + }); + + const DecompressTable = bun.ComptimeStringMap(i32, .{ + .{ "disable", 0 }, + .{ "shared", uws.SHARED_DECOMPRESSOR }, + .{ "dedicated", uws.DEDICATED_DECOMPRESSOR }, + .{ "3KB", uws.DEDICATED_COMPRESSOR_3KB }, + .{ "4KB", uws.DEDICATED_COMPRESSOR_4KB }, + .{ "8KB", uws.DEDICATED_COMPRESSOR_8KB }, + .{ "16KB", uws.DEDICATED_COMPRESSOR_16KB }, + .{ "32KB", uws.DEDICATED_COMPRESSOR_32KB }, + .{ "64KB", uws.DEDICATED_COMPRESSOR_64KB }, + .{ "128KB", uws.DEDICATED_COMPRESSOR_128KB }, + .{ "256KB", uws.DEDICATED_COMPRESSOR_256KB }, + }); + + pub fn onCreate(globalObject: *JSC.JSGlobalObject, object: JSValue) ?WebSocketServer { + if (!object.isObject()) { + globalObject.throwInvalidArguments("websocket expects an options object", .{}); + return null; + } + + var server = WebSocketServer{}; + + if (object.getTruthy(globalObject, "compressor")) |compression| { + if (compression.isBoolean()) { + server.compression |= if (compression.toBoolean()) uws.SHARED_COMPRESSOR else 0; + } else if (compression.isString()) { + var slice = compression.toSlice(globalObject, bun.default_allocator); + defer slice.deinit(); + server.compression |= CompressTable.get(slice.slice()) orelse { + globalObject.throwInvalidArguments( + "websocket expects a valid compressor option, either disable \"shared\" \"dedicated\" \"3KB\" \"4KB\" \"8KB\" \"16KB\" \"32KB\" \"64KB\" \"128KB\" or \"256KB\"", + .{}, + ); + return null; + }; + } else { + globalObject.throwInvalidArguments( + "websocket expects a valid compressor option, either disable \"shared\" \"dedicated\" \"3KB\" \"4KB\" \"8KB\" \"16KB\" \"32KB\" \"64KB\" \"128KB\" or \"256KB\"", + .{}, + ); + return null; + } + } + if (object.getTruthy(globalObject, "decompressor")) |compression| { + if (compression.isBoolean()) { + server.compression |= if (compression.toBoolean()) uws.SHARED_DECOMPRESSOR else 0; + } else if (compression.isString()) { + var slice = compression.toSlice(globalObject, bun.default_allocator); + defer slice.deinit(); + server.compression |= DecompressTable.get(slice.slice()) orelse { + globalObject.throwInvalidArguments( + "websocket expects a valid decompressor option, either \"disable\" \"shared\" \"dedicated\" \"3KB\" \"4KB\" \"8KB\" \"16KB\" \"32KB\" \"64KB\" \"128KB\" or \"256KB\"", + .{}, + ); + return null; + }; + } else { + globalObject.throwInvalidArguments( + "websocket expects a valid decompressor option, either \"disable\" \"shared\" \"dedicated\" \"3KB\" \"4KB\" \"8KB\" \"16KB\" \"32KB\" \"64KB\" \"128KB\" or \"256KB\"", + .{}, + ); + return null; + } + } + + if (object.get(globalObject, "maxPayloadLength")) |value| { + if (!value.isUndefinedOrNull()) { + if (!value.isAnyInt()) { + globalObject.throwInvalidArguments("websocket expects maxPayloadLength to be an integer", .{}); + return null; + } + server.maxPayloadLength = @intCast(u32, @truncate(i33, @maximum(value.toInt64(), 0))); + } + } + if (object.get(globalObject, "idleTimeout")) |value| { + if (!value.isUndefinedOrNull()) { + if (!value.isAnyInt()) { + globalObject.throwInvalidArguments("websocket expects idleTimeout to be an integer", .{}); + return null; + } + + server.idleTimeout = value.to(u16); + } + } + if (object.get(globalObject, "backpressureLimit")) |value| { + if (!value.isUndefinedOrNull()) { + if (!value.isAnyInt()) { + globalObject.throwInvalidArguments("websocket expects backpressureLimit to be an integer", .{}); + return null; + } + + server.backpressureLimit = @intCast(u32, @truncate(i33, @maximum(value.toInt64(), 0))); + } + } + // if (object.get(globalObject, "sendPings")) |value| { + // if (!value.isUndefinedOrNull()) { + // if (!value.isBoolean()) { + // globalObject.throwInvalidArguments("websocket expects sendPings to be a boolean", .{}); + // return null; + // } + + // server.sendPings = value.toBoolean(); + // } + // } + + if (object.get(globalObject, "closeOnBackpressureLimit")) |value| { + if (!value.isUndefinedOrNull()) { + if (!value.isBoolean()) { + globalObject.throwInvalidArguments("websocket expects closeOnBackpressureLimit to be a boolean", .{}); + return null; + } + + server.closeOnBackpressureLimit = value.toBoolean(); + } + } + + if (object.getTruthy(globalObject, "message")) |message| { + if (!message.isCallable(globalObject.vm())) { + globalObject.throwInvalidArguments("websocket expects a function for the message option", .{}); + return null; + } + server.onMessage = message; + message.ensureStillAlive(); + } + + if (object.getTruthy(globalObject, "open")) |open| { + if (!open.isCallable(globalObject.vm())) { + globalObject.throwInvalidArguments("websocket expects a function for the open option", .{}); + return null; + } + server.onOpen = open; + open.ensureStillAlive(); + } + + if (object.getTruthy(globalObject, "close")) |close| { + if (!close.isCallable(globalObject.vm())) { + globalObject.throwInvalidArguments("websocket expects a function for the close option", .{}); + return null; + } + server.onClose = close; + close.ensureStillAlive(); + } + + if (object.getTruthy(globalObject, "drain")) |drain| { + if (!drain.isCallable(globalObject.vm())) { + globalObject.throwInvalidArguments("websocket expects a function for the drain option", .{}); + return null; + } + server.onDrain = drain; + drain.ensureStillAlive(); + } + + if (object.getTruthy(globalObject, "upgrade")) |upgrade| { + if (!upgrade.isCallable(globalObject.vm())) { + globalObject.throwInvalidArguments("websocket expects a function for the upgrade option", .{}); + return null; + } + server.onUpgrade = upgrade; + upgrade.ensureStillAlive(); + } + + server.protect(); + return server; + } + + pub const List = std.StringArrayHashMapUnmanaged(WebSocketServer); +}; + +const Corker = struct { + args: []const JSValue, + globalObject: *JSC.JSGlobalObject, + callback: JSC.JSValue, + result: JSValue = .zero, + + pub fn run(this: *Corker) void { + this.result = this.callback.call(this.globalObject, this.args); + } +}; + +pub const ServerWebSocket = struct { + handler: *WebSocketServer, + this_value: JSValue = .zero, + websocket: uws.AnyWebSocket = undefined, + closed: bool = false, + + pub usingnamespace JSC.Codegen.JSServerWebSocket; + + const log = Output.scoped(.WebSocketServer, false); + + pub fn onOpen(this: *ServerWebSocket, ws: uws.AnyWebSocket) void { + log("OnOpen", .{}); + + this.websocket = ws; + this.closed = false; + + // the this value is initially set to whatever the user passed in + const value_to_cache = this.this_value; + + var handler = this.handler; + handler.active_connections +|= 1; + var globalObject = handler.globalObject; + + const onOpenHandler = handler.onOpen; + this.this_value = .zero; + if (value_to_cache != .zero) { + const current_this = this.getThisValue(); + ServerWebSocket.dataSetCached(current_this, globalObject, value_to_cache); + } + + if (onOpenHandler.isEmptyOrUndefinedOrNull()) return; + var args = [_]JSValue{this.this_value}; + + var corker = Corker{ + .args = &args, + .globalObject = globalObject, + .callback = onOpenHandler, + }; + ws.cork(&corker, Corker.run); + if (corker.result.isAnyError(globalObject)) { + log("onOpen exception", .{}); + + ws.close(); + _ = ServerWebSocket.dangerouslySetPtr(this.this_value, null); + handler.active_connections -|= 1; + this.this_value.unprotect(); + bun.default_allocator.destroy(this); + globalObject.bunVM().runErrorHandler(corker.result, null); + } + } + + pub fn getThisValue(this: *ServerWebSocket) JSValue { + var this_value = this.this_value; + if (this_value == .zero) { + this_value = this.toJS(this.handler.globalObject); + this_value.protect(); + this.this_value = this_value; + } + return this_value; + } + + pub fn onMessage( + this: *ServerWebSocket, + ws: uws.AnyWebSocket, + message: []const u8, + opcode: uws.Opcode, + ) void { + log("onMessage({d}): {s}", .{ + @enumToInt(opcode), + message, + }); + const onMessageHandler = this.handler.onMessage; + if (onMessageHandler.isEmptyOrUndefinedOrNull()) return; + var globalObject = this.handler.globalObject; + + const arguments = [_]JSValue{ + this.getThisValue(), + switch (opcode) { + .text => brk: { + var str = ZigString.init(message); + str.markUTF8(); + break :brk str.toValueGC(globalObject); + }, + .binary => JSC.ArrayBuffer.create(globalObject, message, .Uint8Array), + else => unreachable, + }, + }; + + var corker = Corker{ + .args = &arguments, + .globalObject = globalObject, + .callback = onMessageHandler, + }; + + ws.cork(&corker, Corker.run); + const result = corker.result; + + if (result.isEmptyOrUndefinedOrNull()) return; + + if (result.isAnyError(globalObject)) { + this.handler.globalObject.bunVM().runErrorHandler(result, null); + return; + } + + if (result.asPromise()) |promise| { + switch (promise.status(globalObject.vm())) { + .Rejected => { + _ = promise.result(globalObject.vm()); + return; + }, + + else => {}, + } + } + } + pub fn onDrain(this: *ServerWebSocket, _: uws.AnyWebSocket) void { + log("onDrain", .{}); + + var handler = this.handler; + if (handler.onDrain != .zero) { + const result = handler.onDrain.call(handler.globalObject, &[_]JSC.JSValue{this.this_value}); + + if (result.isAnyError(handler.globalObject)) { + log("onDrain error", .{}); + handler.globalObject.bunVM().runErrorHandler(result, null); + } + } + } + pub fn onPing(_: *ServerWebSocket, _: uws.AnyWebSocket, _: []const u8) void { + log("onPing", .{}); + } + pub fn onPong(_: *ServerWebSocket, _: uws.AnyWebSocket, _: []const u8) void { + log("onPong", .{}); + } + pub fn onClose(this: *ServerWebSocket, _: uws.AnyWebSocket, code: i32, message: []const u8) void { + log("onClose", .{}); + var handler = this.handler; + this.closed = true; + defer handler.active_connections -|= 1; + + if (handler.onClose != .zero) { + const result = handler.onClose.call( + handler.globalObject, + &[_]JSC.JSValue{ this.this_value, JSValue.jsNumber(code), ZigString.init(message).toValueGC(handler.globalObject) }, + ); + + if (result.isAnyError(handler.globalObject)) { + log("onClose error", .{}); + handler.globalObject.bunVM().runErrorHandler(result, null); + } + } + + this.this_value.unprotect(); + } + + pub fn behavior(comptime ServerType: type, comptime ssl: bool, opts: uws.WebSocketBehavior) uws.WebSocketBehavior { + return uws.WebSocketBehavior.Wrap(ServerType, @This(), ssl).apply(opts); + } + + pub fn constructor(globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) callconv(.C) ?*ServerWebSocket { + globalObject.throw("Cannot construct ServerWebSocket", .{}); + return null; + } + + pub fn finalize(this: *ServerWebSocket) callconv(.C) void { + bun.default_allocator.destroy(this); + } + + pub fn publish( + this: *ServerWebSocket, + globalThis: *JSC.JSGlobalObject, + callframe: *JSC.CallFrame, + ) callconv(.C) JSValue { + const args = callframe.arguments(4); + + if (args.len < 1) { + log("publish()", .{}); + globalThis.throw("publish requires at least 1 argument", .{}); + return .zero; + } + + if (this.closed) { + log("publish() closed", .{}); + return JSValue.jsNumber(0); + } + + const topic_value = args.ptr[0]; + const message_value = args.ptr[1]; + const compress_value = args.ptr[2]; + + if (topic_value.isEmptyOrUndefinedOrNull() or !topic_value.isString()) { + log("publish() topic invalid", .{}); + globalThis.throw("publish requires a topic string", .{}); + return .zero; + } + + var topic_slice = topic_value.toSlice(globalThis, bun.default_allocator); + defer topic_slice.deinit(); + if (topic_slice.len == 0) { + globalThis.throw("publish requires a non-empty topic", .{}); + return JSValue.jsNumber(0); + } + + const compress = args.len > 1 and compress_value.toBoolean(); + + if (message_value.isEmptyOrUndefinedOrNull()) { + globalThis.throw("publish requires a non-empty message", .{}); + return .zero; + } + + if (message_value.asArrayBuffer(globalThis)) |buffer| { + if (buffer.len == 0) { + globalThis.throw("publish requires a non-empty message", .{}); + return .zero; + } + + return JSValue.jsNumber( + // if 0, return 0 + // else return number of bytes sent + @as(i32, @boolToInt(this.websocket.publishWithOptions(topic_slice.slice(), buffer.slice(), .text, compress))) * @intCast(i32, @truncate(u31, buffer.len)), + ); + } + + { + var string_slice = message_value.toSlice(globalThis, bun.default_allocator); + defer string_slice.deinit(); + if (string_slice.len == 0) { + return JSValue.jsNumber(0); + } + + const buffer = string_slice.slice(); + return JSValue.jsNumber( + // if 0, return 0 + // else return number of bytes sent + @as(i32, @boolToInt(this.websocket.publishWithOptions(topic_slice.slice(), buffer, .text, compress))) * @intCast(i32, @truncate(u31, buffer.len)), + ); + } + + return .zero; + } + + pub fn send( + this: *ServerWebSocket, + globalThis: *JSC.JSGlobalObject, + callframe: *JSC.CallFrame, + ) callconv(.C) JSValue { + const args = callframe.arguments(2); + + if (args.len < 1) { + log("send()", .{}); + globalThis.throw("send requires at least 1 argument", .{}); + return .zero; + } + + if (this.closed) { + log("send() closed", .{}); + return JSValue.jsNumber(0); + } + + const message_value = args.ptr[0]; + const compress_value = args.ptr[1]; + + const compress = args.len > 1 and compress_value.toBoolean(); + + if (message_value.isEmptyOrUndefinedOrNull()) { + globalThis.throw("send requires a non-empty message", .{}); + return .zero; + } + + if (message_value.asArrayBuffer(globalThis)) |buffer| { + if (buffer.len == 0) { + globalThis.throw("send requires a non-empty message", .{}); + return .zero; + } + + switch (this.websocket.send(buffer.slice(), .binary, compress, true)) { + .backpressure => { + log("send() backpressure ({d} bytes)", .{buffer.len}); + return JSValue.jsNumber(-1); + }, + .success => { + log("send() success ({d} bytes)", .{buffer.len}); + return JSValue.jsNumber(buffer.slice().len); + }, + .dropped => { + log("send() dropped ({d} bytes)", .{buffer.len}); + return JSValue.jsNumber(0); + }, + } + } + + { + var string_slice = message_value.toSlice(globalThis, bun.default_allocator); + defer string_slice.deinit(); + if (string_slice.len == 0) { + return JSValue.jsNumber(0); + } + + const buffer = string_slice.slice(); + switch (this.websocket.send(buffer, .text, compress, true)) { + .backpressure => { + log("send() backpressure ({d} bytes string)", .{buffer.len}); + return JSValue.jsNumber(-1); + }, + .success => { + log("send() success ({d} bytes string)", .{buffer.len}); + return JSValue.jsNumber(buffer.len); + }, + .dropped => { + log("send() dropped ({d} bytes string)", .{buffer.len}); + return JSValue.jsNumber(0); + }, + } + } + + return .zero; + } + + pub fn getData( + _: *ServerWebSocket, + _: *JSC.JSGlobalObject, + ) callconv(.C) JSValue { + log("getData()", .{}); + return JSValue.jsUndefined(); + } + + pub fn setData( + this: *ServerWebSocket, + globalObject: *JSC.JSGlobalObject, + value: JSC.JSValue, + ) callconv(.C) bool { + log("setData()", .{}); + ServerWebSocket.dataSetCached(this.this_value, globalObject, value); + return true; + } + + pub fn getReadyState( + this: *ServerWebSocket, + _: *JSC.JSGlobalObject, + ) callconv(.C) JSValue { + log("getReadyState()", .{}); + + if (this.closed) { + return JSValue.jsNumber(3); + } + + return JSValue.jsNumber(1); + } + + pub fn close( + this: *ServerWebSocket, + globalThis: *JSC.JSGlobalObject, + callframe: *JSC.CallFrame, + ) callconv(.C) JSValue { + const args = callframe.arguments(2); + log("close()", .{}); + + if (this.closed) { + return .zero; + } + + const code = if (args.len > 0) args.ptr[0].toInt32() else @as(i32, 1000); + var message_value = if (args.len > 1) args.ptr[1].toSlice(globalThis, bun.default_allocator) else ZigString.Slice.empty; + defer message_value.deinit(); + if (code > 0) { + this.websocket.end(code, message_value.slice()); + } else { + this.closed = true; + this.this_value.unprotect(); + this.websocket.close(); + } + + return JSValue.jsUndefined(); + } + pub fn getBufferedAmount( + this: *ServerWebSocket, + _: *JSC.JSGlobalObject, + _: *JSC.CallFrame, + ) callconv(.C) JSValue { + log("getBufferedAmount()", .{}); + + if (this.closed) { + return JSValue.jsNumber(0); + } + + return JSValue.jsNumber(this.websocket.getBufferedAmount()); + } + pub fn subscribe( + this: *ServerWebSocket, + globalThis: *JSC.JSGlobalObject, + callframe: *JSC.CallFrame, + ) callconv(.C) JSValue { + const args = callframe.arguments(1); + if (args.len < 1) { + globalThis.throw("subscribe requires at least 1 argument", .{}); + return .zero; + } + + if (this.closed) { + return JSValue.jsBoolean(true); + } + + var topic = args.ptr[0].toSlice(globalThis, bun.default_allocator); + defer topic.deinit(); + + if (topic.len == 0) { + globalThis.throw("subscribe requires a non-empty topic name", .{}); + return .zero; + } + + return JSValue.jsBoolean(this.websocket.subscribe(topic.slice())); + } + pub fn unsubscribe( + this: *ServerWebSocket, + globalThis: *JSC.JSGlobalObject, + callframe: *JSC.CallFrame, + ) callconv(.C) JSValue { + const args = callframe.arguments(1); + if (args.len < 1) { + globalThis.throw("unsubscribe requires at least 1 argument", .{}); + return .zero; + } + + if (this.closed) { + return JSValue.jsBoolean(true); + } + + var topic = args.ptr[0].toSlice(globalThis, bun.default_allocator); + defer topic.deinit(); + + if (topic.len == 0) { + globalThis.throw("unsubscribe requires a non-empty topic name", .{}); + return .zero; + } + + return JSValue.jsBoolean(this.websocket.unsubscribe(topic.slice())); + } + pub fn isSubscribed( + this: *ServerWebSocket, + globalThis: *JSC.JSGlobalObject, + callframe: *JSC.CallFrame, + ) callconv(.C) JSValue { + const args = callframe.arguments(1); + if (args.len < 1) { + globalThis.throw("isSubscribed requires at least 1 argument", .{}); + return .zero; + } + + if (this.closed) { + return JSValue.jsBoolean(false); + } + + var topic = args.ptr[0].toSlice(globalThis, bun.default_allocator); + defer topic.deinit(); + + if (topic.len == 0) { + globalThis.throw("isSubscribed requires a non-empty topic name", .{}); + return .zero; + } + + return JSValue.jsBoolean(this.websocket.isSubscribed(topic.slice())); + } + + // pub fn getTopics( + // this: *ServerWebSocket, + // globalThis: *JSC.JSGlobalObject, + // ) callconv(.C) JSValue { + // if (this.closed) { + // return JSValue.createStringArray(globalThis, bun.default_allocator, null, 0, false); + // } + + // this + // } + + pub fn getRemoteAddress( + this: *ServerWebSocket, + globalThis: *JSC.JSGlobalObject, + ) callconv(.C) JSValue { + if (this.closed) { + return JSValue.jsUndefined(); + } + + var buf: [512]u8 = undefined; + const address = this.websocket.getRemoteAddress(&buf); + if (address.len == 0) { + return JSValue.jsUndefined(); + } + + return ZigString.init(address).toValueGC(globalThis); + } +}; + pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type { return struct { pub const ssl_enabled = ssl_enabled_; @@ -2119,6 +3043,9 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type { .pendingRequests = .{ .get = JSC.getterWrap(ThisServer, "getPendingRequests"), }, + .pendingSockets = .{ + .get = JSC.getterWrap(ThisServer, "getPendingSockets"), + }, }, ); @@ -2150,6 +3077,54 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type { this.config.onError = new_config.onError; } + if (new_config.websocket.onMessage != .zero or new_config.websocket.onOpen != .zero) { + this.config.websocket.unprotect(); + if (this.config.websocket.onMessage == .zero) { + this.app.ws("/*", this, 0, ServerWebSocket.behavior( + ThisServer, + ssl_enabled, + this.config.websocket.toBehavior(), + )); + } + + new_config.websocket.globalObject = ctx; + this.config.websocket = new_config.websocket; + } else if (this.config.websocket.onMessage != .zero or this.config.websocket.onOpen != .zero) { + this.config.websocket.unprotect(); + this.config.websocket = .{}; + } + + // we are going to leak the old memory + const new_keys = new_config.websockets.keys(); + const old_keys = this.config.websockets.keys(); + + if (new_keys.len + old_keys.len > 0) { + var all_match = old_keys.len <= new_keys.len; + + // any existing websockets will now call the new config + for (new_keys) |key, i| { + if (this.config.websockets.getPtr(key)) |old_val| { + old_val.unprotect(); + old_val.* = new_config.websockets.values()[i]; + old_val.globalObject = ctx; + } else { + all_match = false; + var new_value = &new_config.websockets.values()[i]; + new_value.globalObject = ctx; + this.app.ws( + key, + this, + i, + ServerWebSocket.behavior(ThisServer, ssl_enabled, new_value.toBehavior()), + ); + } + } + + if (!all_match) { + this.config.websockets = new_config.websockets; + } + } + return this.thisObject.asObjectRef(); } @@ -2289,6 +3264,10 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type { return JSC.JSValue.jsNumber(@intCast(i32, @truncate(u31, this.pending_requests))); } + pub fn getPendingSockets(this: *ThisServer) JSC.JSValue { + return JSC.JSValue.jsNumber(@intCast(i32, @truncate(u31, this.activeSocketsCount()))); + } + pub fn getHostname(this: *ThisServer, globalThis: *JSGlobalObject) JSC.JSValue { return ZigString.init(bun.span(this.config.hostname)).toValue(globalThis); } @@ -2317,8 +3296,21 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type { this.deinitIfWeCan(); } + pub fn activeSocketsCount(this: *const ThisServer) u32 { + var count = this.config.websocket.active_connections; + for (this.config.websockets.values()) |conn| { + count += conn.active_connections; + } + + return @truncate(u32, count); + } + + pub fn hasActiveWebSockets(this: *const ThisServer) bool { + return this.activeSocketsCount() > 0; + } + pub fn deinitIfWeCan(this: *ThisServer) void { - if (this.pending_requests == 0 and this.listener == null and this.has_js_deinited) { + if (this.pending_requests == 0 and this.listener == null and this.has_js_deinited and !this.hasActiveWebSockets()) { this.unref(); this.deinit(); } @@ -2517,10 +3509,13 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type { } } - pub fn onRequest(this: *ThisServer, req: *uws.Request, resp: *App.Response) void { + pub fn onRequest( + this: *ThisServer, + req: *uws.Request, + resp: *App.Response, + ) void { JSC.markBinding(); this.pending_requests += 1; - var vm = this.vm; req.setYield(false); var ctx = this.request_pool_allocator.create(RequestContext) catch @panic("ran out of memory"); ctx.create(this, req, resp); @@ -2573,112 +3568,21 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type { } // We keep the Request object alive for the duration of the request so that we can remove the pointer to the UWS request object. - var args = [_]JSC.C.JSValueRef{request_object.toJS(this.globalThis).asObjectRef()}; + var args = [_]JSC.C.JSValueRef{ + request_object.toJS(this.globalThis).asObjectRef(), + }; ctx.request_js_object = args[0]; const request_value = JSValue.c(args[0]); request_value.ensureStillAlive(); const response_value = JSC.C.JSObjectCallAsFunctionReturnValue(this.globalThis, this.config.onRequest.asObjectRef(), this.thisObject.asObjectRef(), 1, &args); - request_value.ensureStillAlive(); - response_value.ensureStillAlive(); - if (ctx.aborted) { - ctx.finalizeForAbort(); - return; - } - if (response_value.isEmptyOrUndefinedOrNull() and !ctx.resp.hasResponded()) { - ctx.renderMissing(); - return; - } - - if (response_value.isError() or response_value.isAggregateError(this.globalThis) or response_value.isException(this.globalThis.vm())) { - ctx.runErrorHandler(response_value); - - return; - } - - if (response_value.as(JSC.WebCore.Response)) |response| { - ctx.response_jsvalue = response_value; - ctx.response_jsvalue.ensureStillAlive(); - ctx.response_protected = false; - switch (response.body.value) { - .Blob => |*blob| { - if (blob.needsToReadFile()) { - response_value.protect(); - ctx.response_protected = true; - } - }, - .Locked => { - response_value.protect(); - ctx.response_protected = true; - }, - else => {}, - } - ctx.render(response); - return; - } - - var wait_for_promise = false; - - if (response_value.asPromise()) |promise| { - // If we immediately have the value available, we can skip the extra event loop tick - switch (promise.status(vm.global.vm())) { - .Pending => {}, - .Fulfilled => { - ctx.handleResolve(promise.result(vm.global.vm())); - return; - }, - .Rejected => { - ctx.handleReject(promise.result(vm.global.vm())); - return; - }, - } - wait_for_promise = true; - // I don't think this case should happen - // But I'm uncertain - } else if (response_value.asInternalPromise()) |promise| { - switch (promise.status(vm.global.vm())) { - .Pending => {}, - .Fulfilled => { - ctx.handleResolve(promise.result(vm.global.vm())); - return; - }, - .Rejected => { - ctx.handleReject(promise.result(vm.global.vm())); - return; - }, - } - wait_for_promise = true; - } - - if (wait_for_promise) { - request_object.uws_request = req; - - request_object.ensureURL() catch { - request_object.url = ""; - }; - - // we have to clone the request headers here since they will soon belong to a different request - if (request_object.headers == null) { - request_object.headers = JSC.FetchHeaders.createFromUWS(this.globalThis, req); - } - - if (comptime debug_mode) { - ctx.pathname = bun.default_allocator.dupe(u8, request_object.url) catch unreachable; - } - - // This object dies after the stack frame is popped - // so we have to clear it in here too - request_object.uws_request = null; - - ctx.setAbortHandler(); - ctx.pending_promises_for_abort += 1; - - response_value.then(this.globalThis, ctx, RequestContext.onResolve, RequestContext.onReject); - return; - } - - // The user returned something that wasn't a promise or a promise with a response - if (!ctx.resp.hasResponded()) ctx.renderMissing(); + ctx.onResponse( + this, + req, + request_object, + request_value, + response_value, + ); } pub fn listen(this: *ThisServer) void { @@ -2701,6 +3605,30 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type { this.app = App.create(.{}); } + const websocket_patterns = this.config.websockets.keys(); + { + const values = this.config.websockets.values(); + for (websocket_patterns) |pattern, i| { + values[i].globalObject = this.globalThis; + + this.app.ws(pattern, this, i + 1, ServerWebSocket.behavior( + ThisServer, + ssl_enabled, + values[i].toBehavior(), + )); + } else { + if (this.config.websocket.onMessage != .zero or this.config.websocket.onOpen != .zero) { + this.config.websocket.globalObject = this.globalThis; + this.app.ws( + "/*", + this, + 0, + ServerWebSocket.behavior(ThisServer, ssl_enabled, this.config.websocket.toBehavior()), + ); + } + } + } + this.app.any("/*", *ThisServer, this, onRequest); if (comptime debug_mode) { @@ -2718,12 +3646,137 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type { this.config.hostname; this.ref(); + this.app.listenWithConfig(*ThisServer, this, onListen, .{ .port = this.config.port, .host = host, .options = 0, }); } + + pub fn onWebSocketUpgrade(this: *ThisServer, resp: *App.Response, req: *uws.Request, ctx: *uws.uws_socket_context_t, id: usize) void { + var websocket_handler: *WebSocketServer = switch (id) { + 0 => &this.config.websocket, + else => &this.config.websockets.values()[id - 1], + }; + req.setYield(false); + + const onUpgrade = websocket_handler.onUpgrade; + + if (onUpgrade == .zero) { + var upgrader = this.allocator.create(ServerWebSocket) catch @panic("Out of memory"); + upgrader.* = .{ + .handler = websocket_handler, + .this_value = .zero, + }; + resp.upgrade( + *ServerWebSocket, + upgrader, + req.header("sec-websocket-key") orelse "", + req.header("sec-websocket-protocol") orelse "", + req.header("sec-websocket-extensions") orelse "", + ctx, + ); + return; + } + + const method = HTTP.Method.which(req.method()) orelse HTTP.Method.GET; + + var request_object = this.allocator.create(JSC.WebCore.Request) catch unreachable; + request_object.* = .{ + .url = "", + .method = method, + .uws_request = req, + .base_url_string_for_joining = this.base_url_string_for_joining, + .body = .{ + .Empty = .{}, + }, + }; + + var args = [_]JSC.JSValue{ + request_object.toJS(this.globalThis), + }; + var request_value = args[0]; + request_value.ensureStillAlive(); + const response_value = websocket_handler.onUpgrade.call(this.globalThis, &args); + request_value.ensureStillAlive(); + response_value.ensureStillAlive(); + + if (response_value.isBoolean() or response_value.isEmptyOrUndefinedOrNull()) { + if (response_value.toBoolean()) { + var upgrader = this.allocator.create(ServerWebSocket) catch @panic("Out of memory"); + upgrader.* = .{ + .handler = websocket_handler, + }; + resp.upgrade( + *ServerWebSocket, + upgrader, + req.header("sec-websocket-key") orelse "", + req.header("sec-websocket-protocol") orelse "", + req.header("sec-websocket-extensions") orelse "", + ctx, + ); + return; + } else { + req.setYield(true); + return; + } + } + + if (response_value.as(Response)) |response| { + if (response.statusCode() == 101) { + var upgrader = this.allocator.create(ServerWebSocket) catch @panic("Out of memory"); + upgrader.* = .{ + .handler = websocket_handler, + .this_value = response_value, + }; + response_value.ensureStillAlive(); + resp.upgrade( + *ServerWebSocket, + upgrader, + response.header(.SecWebSocketKey) orelse req.header("sec-websocket-key") orelse "", + response.header(.SecWebSocketProtocol) orelse req.header("sec-websocket-protocol") orelse "", + response.header(.SecWebSocketExtensions) orelse req.header("sec-websocket-extensions") orelse "", + ctx, + ); + return; + } + } + + // The returned object becomes the data for the ServerWebSocket + if (response_value.isObject() or (response_value.isString() and response_value.getLengthOfArray(this.globalThis) > 0)) { + var upgrader = this.allocator.create(ServerWebSocket) catch @panic("Out of memory"); + upgrader.* = .{ + .handler = websocket_handler, + .this_value = response_value, + }; + response_value.ensureStillAlive(); + + resp.upgrade( + *ServerWebSocket, + upgrader, + req.header("sec-websocket-key") orelse "", + req.header("sec-websocket-protocol") orelse "", + req.header("sec-websocket-extensions") orelse "", + ctx, + ); + return; + } + + var req_ctx = this.request_pool_allocator.create(RequestContext) catch @panic("ran out of memory"); + req_ctx.create(this, req, resp); + req_ctx.request_js_object = request_value.asObjectRef(); + req_ctx.response_jsvalue = response_value; + req_ctx.resp = resp; + + req_ctx.onResponse( + this, + req, + request_object, + request_value, + response_value, + ); + } }; } |