aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/bun.js/api/server.classes.ts4
-rw-r--r--src/bun.js/api/server.zig788
-rw-r--r--src/bun.js/bindings/ZigGeneratedClasses.cpp22
-rw-r--r--src/bun.js/bindings/bindings.cpp4
-rw-r--r--src/bun.js/bindings/bindings.zig1
-rw-r--r--src/bun.js/bindings/generated_classes.zig3
-rw-r--r--src/bun.js/builtins/BunBuiltinNames.h1
-rw-r--r--src/bun.js/webcore/response.zig1
-rw-r--r--test/bun.js/websocket-server.test.ts69
9 files changed, 510 insertions, 383 deletions
diff --git a/src/bun.js/api/server.classes.ts b/src/bun.js/api/server.classes.ts
index 5ae4f1739..8cce3336f 100644
--- a/src/bun.js/api/server.classes.ts
+++ b/src/bun.js/api/server.classes.ts
@@ -33,6 +33,10 @@ export default [
fn: "close",
length: 1,
},
+ cork: {
+ fn: "cork",
+ length: 1,
+ },
getBufferedAmount: {
fn: "getBufferedAmount",
length: 0,
diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig
index 2ea6cfc2c..0c0962a4e 100644
--- a/src/bun.js/api/server.zig
+++ b/src/bun.js/api/server.zig
@@ -109,8 +109,7 @@ pub const ServerConfig = struct {
onError: JSC.JSValue = JSC.JSValue.zero,
onRequest: JSC.JSValue = JSC.JSValue.zero,
- websockets: WebSocketServer.List = .{},
- websocket: WebSocketServer = .{},
+ websocket: ?WebSocketServer = null,
pub const SSLConfig = struct {
server_name: [*c]const u8 = null,
@@ -277,8 +276,6 @@ pub const ServerConfig = struct {
args.base_uri = origin;
}
- var websockets = WebSocketServer.List{};
-
if (arguments.next()) |arg| {
if (arg.isUndefinedOrNull() or !arg.isObject()) {
JSC.throwInvalidArguments("Bun.serve expects an object", .{}, global, exception);
@@ -287,7 +284,7 @@ pub const ServerConfig = struct {
if (arg.getTruthy(global, "webSocket") orelse arg.getTruthy(global, "websocket")) |websocket_object| {
if (!websocket_object.isObject()) {
- JSC.throwInvalidArguments("Expected webSocket to be an object", .{}, global, exception);
+ JSC.throwInvalidArguments("Expected websocket to be an object", .{}, global, exception);
if (args.ssl_config) |*conf| {
conf.deinit();
}
@@ -304,62 +301,6 @@ pub const ServerConfig = struct {
}
}
- if (arg.getTruthy(global, "webSockets") orelse arg.getTruthy(global, "websockets")) |websocket_object| {
- if (!websocket_object.isObject()) {
- JSC.throwInvalidArguments("Expected webSockets to be an object", .{}, global, exception);
- if (args.ssl_config) |*conf| {
- conf.deinit();
- }
- return args;
- }
-
- var property_names = JSC.JSPropertyIterator(.{
- .include_value = true,
- .skip_empty_name = true,
- }).init(global, websocket_object.asObjectRef());
-
- defer property_names.deinit();
- websockets.ensureTotalCapacity(bun.default_allocator, property_names.len) catch unreachable;
- while (property_names.next()) |name| {
- var str = name.toSlice(bun.default_allocator);
- defer str.deinit();
- const slice = str.slice();
- if (slice.len == 0) continue;
- if (slice[0] != '/') {
- JSC.throwInvalidArguments("Expected webSocket path to start with /", .{}, global, exception);
- if (args.ssl_config) |*conf| {
- conf.deinit();
- }
- return args;
- }
-
- const object = property_names.value;
- if (object.isEmptyOrUndefinedOrNull() or !object.isObject()) {
- JSC.throwInvalidArguments("Expected webSocket to be an object", .{}, global, exception);
- if (args.ssl_config) |*conf| {
- conf.deinit();
- }
- websockets.deinit(bun.default_allocator);
- return args;
- }
-
- const handler = WebSocketServer.onCreate(global, object) orelse {
- if (args.ssl_config) |*conf| {
- conf.deinit();
- }
- websockets.deinit(bun.default_allocator);
- return args;
- };
-
- websockets.putAssumeCapacity(
- bun.span(bun.default_allocator.dupeZ(u8, slice) catch unreachable),
- handler,
- );
- }
-
- args.websockets = websockets;
- }
-
if (arg.getTruthy(global, "port")) |port_| {
args.port = @intCast(u16, @minimum(@maximum(0, port_.toInt32()), std.math.maxInt(u16)));
}
@@ -599,6 +540,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
method: HTTP.Method,
aborted: bool = false,
finalized: bun.DebugOnly(bool) = bun.DebugOnlyDefault(false),
+ upgrade_context: ?*uws.uws_socket_context_t = null,
/// We can only safely free once the request body promise is finalized
/// and the response is rejected
@@ -661,6 +603,11 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
fn handleResolve(ctx: *RequestContext, value: JSC.JSValue) void {
+ if (ctx.didUpgradeWebSocket()) {
+ ctx.finalize();
+ return;
+ }
+
if (value.isEmptyOrUndefinedOrNull()) {
ctx.renderMissing();
return;
@@ -1471,6 +1418,10 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
const streamLog = Output.scoped(.ReadableStream, false);
+ pub fn didUpgradeWebSocket(this: *RequestContext) bool {
+ return @ptrToInt(this.upgrade_context) == std.math.maxInt(usize);
+ }
+
pub fn onResponse(
ctx: *RequestContext,
this: *ThisServer,
@@ -1486,7 +1437,16 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
ctx.finalizeForAbort();
return;
}
- if (response_value.isEmptyOrUndefinedOrNull() and !ctx.resp.hasResponded()) {
+
+ if (ctx.didUpgradeWebSocket()) {
+ ctx.finalize();
+ return;
+ }
+
+ if (response_value.isEmptyOrUndefinedOrNull()) {
+ if (ctx.didUpgradeWebSocket())
+ return;
+
ctx.renderMissing();
return;
}
@@ -1530,6 +1490,9 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
.Fulfilled => {
const fulfilled_value = promise.result(vm.global.vm());
if (fulfilled_value.isEmptyOrUndefinedOrNull()) {
+ if (ctx.didUpgradeWebSocket())
+ return;
+
ctx.renderMissing();
return;
}
@@ -2292,13 +2255,8 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
pub const WebSocketServer = struct {
- onOpen: JSC.JSValue = .zero,
- onUpgrade: JSC.JSValue = .zero,
- onMessage: JSC.JSValue = .zero,
- onClose: JSC.JSValue = .zero,
- onDrain: JSC.JSValue = .zero,
globalObject: *JSC.JSGlobalObject = undefined,
- active_connections: usize = 0,
+ handler: WebSocketServer.Handler = .{},
maxPayloadLength: u32 = 1024 * 1024 * 16,
maxLifetime: u16 = 0,
@@ -2309,6 +2267,74 @@ pub const WebSocketServer = struct {
resetIdleTimeoutOnSend: bool = true,
closeOnBackpressureLimit: bool = false,
+ pub const Handler = struct {
+ onOpen: JSC.JSValue = .zero,
+ onMessage: JSC.JSValue = .zero,
+ onClose: JSC.JSValue = .zero,
+ onDrain: JSC.JSValue = .zero,
+
+ globalObject: *JSC.JSGlobalObject = undefined,
+ active_connections: usize = 0,
+
+ pub fn fromJS(globalObject: *JSC.JSGlobalObject, object: JSC.JSValue) ?Handler {
+ var handler = Handler{ .globalObject = globalObject };
+ if (object.getTruthy(globalObject, "message")) |message| {
+ if (!message.isCallable(globalObject.vm())) {
+ globalObject.throwInvalidArguments("websocket expects a function for the message option", .{});
+ return null;
+ }
+ handler.onMessage = message;
+ message.ensureStillAlive();
+ }
+
+ if (object.getTruthy(globalObject, "open")) |open| {
+ if (!open.isCallable(globalObject.vm())) {
+ globalObject.throwInvalidArguments("websocket expects a function for the open option", .{});
+ return null;
+ }
+ handler.onOpen = open;
+ open.ensureStillAlive();
+ }
+
+ if (object.getTruthy(globalObject, "close")) |close| {
+ if (!close.isCallable(globalObject.vm())) {
+ globalObject.throwInvalidArguments("websocket expects a function for the close option", .{});
+ return null;
+ }
+ handler.onClose = close;
+ close.ensureStillAlive();
+ }
+
+ if (object.getTruthy(globalObject, "drain")) |drain| {
+ if (!drain.isCallable(globalObject.vm())) {
+ globalObject.throwInvalidArguments("websocket expects a function for the drain option", .{});
+ return null;
+ }
+ handler.onDrain = drain;
+ drain.ensureStillAlive();
+ }
+
+ if (handler.onMessage != .zero or handler.onOpen != .zero)
+ return handler;
+
+ return null;
+ }
+
+ pub fn protect(this: Handler) void {
+ this.onOpen.protect();
+ this.onMessage.protect();
+ this.onClose.protect();
+ this.onDrain.protect();
+ }
+
+ pub fn unprotect(this: Handler) void {
+ this.onOpen.unprotect();
+ this.onMessage.unprotect();
+ this.onClose.unprotect();
+ this.onDrain.unprotect();
+ }
+ };
+
pub fn toBehavior(this: WebSocketServer) uws.WebSocketBehavior {
return .{
.maxPayloadLength = this.maxPayloadLength,
@@ -2323,19 +2349,10 @@ pub const WebSocketServer = struct {
}
pub fn protect(this: WebSocketServer) void {
- this.onUpgrade.protect();
- this.onOpen.protect();
- this.onMessage.protect();
- this.onClose.protect();
- this.onDrain.protect();
+ this.handler.protect();
}
-
pub fn unprotect(this: WebSocketServer) void {
- this.onUpgrade.unprotect();
- this.onOpen.unprotect();
- this.onMessage.unprotect();
- this.onClose.unprotect();
- this.onDrain.unprotect();
+ this.handler.unprotect();
}
const CompressTable = bun.ComptimeStringMap(i32, .{
@@ -2367,53 +2384,69 @@ pub const WebSocketServer = struct {
});
pub fn onCreate(globalObject: *JSC.JSGlobalObject, object: JSValue) ?WebSocketServer {
- if (!object.isObject()) {
- globalObject.throwInvalidArguments("websocket expects an options object", .{});
+ var server = WebSocketServer{};
+
+ if (Handler.fromJS(globalObject, object)) |handler| {
+ server.handler = handler;
+ } else {
+ globalObject.throwInvalidArguments("WebSocketServer expects a message handler", .{});
return null;
}
- var server = WebSocketServer{};
+ if (object.get(globalObject, "perMessageDeflate")) |per_message_deflate| {
+ getter: {
+ if (per_message_deflate.isUndefined()) {
+ break :getter;
+ }
- if (object.getTruthy(globalObject, "compressor")) |compression| {
- if (compression.isBoolean()) {
- server.compression |= if (compression.toBoolean()) uws.SHARED_COMPRESSOR else 0;
- } else if (compression.isString()) {
- var slice = compression.toSlice(globalObject, bun.default_allocator);
- defer slice.deinit();
- server.compression |= CompressTable.get(slice.slice()) orelse {
- globalObject.throwInvalidArguments(
- "websocket expects a valid compressor option, either disable \"shared\" \"dedicated\" \"3KB\" \"4KB\" \"8KB\" \"16KB\" \"32KB\" \"64KB\" \"128KB\" or \"256KB\"",
- .{},
- );
- return null;
- };
- } else {
- globalObject.throwInvalidArguments(
- "websocket expects a valid compressor option, either disable \"shared\" \"dedicated\" \"3KB\" \"4KB\" \"8KB\" \"16KB\" \"32KB\" \"64KB\" \"128KB\" or \"256KB\"",
- .{},
- );
- return null;
- }
- }
- if (object.getTruthy(globalObject, "decompressor")) |compression| {
- if (compression.isBoolean()) {
- server.compression |= if (compression.toBoolean()) uws.SHARED_DECOMPRESSOR else 0;
- } else if (compression.isString()) {
- var slice = compression.toSlice(globalObject, bun.default_allocator);
- defer slice.deinit();
- server.compression |= DecompressTable.get(slice.slice()) orelse {
- globalObject.throwInvalidArguments(
- "websocket expects a valid decompressor option, either \"disable\" \"shared\" \"dedicated\" \"3KB\" \"4KB\" \"8KB\" \"16KB\" \"32KB\" \"64KB\" \"128KB\" or \"256KB\"",
- .{},
- );
- return null;
- };
- } else {
- globalObject.throwInvalidArguments(
- "websocket expects a valid decompressor option, either \"disable\" \"shared\" \"dedicated\" \"3KB\" \"4KB\" \"8KB\" \"16KB\" \"32KB\" \"64KB\" \"128KB\" or \"256KB\"",
- .{},
- );
- return null;
+ if (per_message_deflate.isBoolean() or per_message_deflate.isNull()) {
+ if (per_message_deflate.toBoolean()) {
+ server.compression = uws.SHARED_COMPRESSOR | uws.SHARED_DECOMPRESSOR;
+ } else {
+ server.compression = 0;
+ }
+ break :getter;
+ }
+
+ if (per_message_deflate.getTruthy(globalObject, "compress")) |compression| {
+ if (compression.isBoolean()) {
+ server.compression |= if (compression.toBoolean()) uws.SHARED_COMPRESSOR else 0;
+ } else if (compression.isString()) {
+ server.compression |= CompressTable.getWithEql(compression.getZigString(globalObject), ZigString.eqlComptime) orelse {
+ globalObject.throwInvalidArguments(
+ "WebSocketServer expects a valid compress option, either disable \"shared\" \"dedicated\" \"3KB\" \"4KB\" \"8KB\" \"16KB\" \"32KB\" \"64KB\" \"128KB\" or \"256KB\"",
+ .{},
+ );
+ return null;
+ };
+ } else {
+ globalObject.throwInvalidArguments(
+ "websocket expects a valid compress option, either disable \"shared\" \"dedicated\" \"3KB\" \"4KB\" \"8KB\" \"16KB\" \"32KB\" \"64KB\" \"128KB\" or \"256KB\"",
+ .{},
+ );
+ return null;
+ }
+ }
+
+ if (per_message_deflate.getTruthy(globalObject, "decompress")) |compression| {
+ if (compression.isBoolean()) {
+ server.compression |= if (compression.toBoolean()) uws.SHARED_DECOMPRESSOR else 0;
+ } else if (compression.isString()) {
+ server.compression |= DecompressTable.getWithEql(compression.getZigString(globalObject), ZigString.eqlComptime) orelse {
+ globalObject.throwInvalidArguments(
+ "websocket expects a valid decompress option, either \"disable\" \"shared\" \"dedicated\" \"3KB\" \"4KB\" \"8KB\" \"16KB\" \"32KB\" \"64KB\" \"128KB\" or \"256KB\"",
+ .{},
+ );
+ return null;
+ };
+ } else {
+ globalObject.throwInvalidArguments(
+ "websocket expects a valid decompress option, either \"disable\" \"shared\" \"dedicated\" \"3KB\" \"4KB\" \"8KB\" \"16KB\" \"32KB\" \"64KB\" \"128KB\" or \"256KB\"",
+ .{},
+ );
+ return null;
+ }
+ }
}
}
@@ -2468,71 +2501,29 @@ pub const WebSocketServer = struct {
}
}
- if (object.getTruthy(globalObject, "message")) |message| {
- if (!message.isCallable(globalObject.vm())) {
- globalObject.throwInvalidArguments("websocket expects a function for the message option", .{});
- return null;
- }
- server.onMessage = message;
- message.ensureStillAlive();
- }
-
- if (object.getTruthy(globalObject, "open")) |open| {
- if (!open.isCallable(globalObject.vm())) {
- globalObject.throwInvalidArguments("websocket expects a function for the open option", .{});
- return null;
- }
- server.onOpen = open;
- open.ensureStillAlive();
- }
-
- if (object.getTruthy(globalObject, "close")) |close| {
- if (!close.isCallable(globalObject.vm())) {
- globalObject.throwInvalidArguments("websocket expects a function for the close option", .{});
- return null;
- }
- server.onClose = close;
- close.ensureStillAlive();
- }
-
- if (object.getTruthy(globalObject, "drain")) |drain| {
- if (!drain.isCallable(globalObject.vm())) {
- globalObject.throwInvalidArguments("websocket expects a function for the drain option", .{});
- return null;
- }
- server.onDrain = drain;
- drain.ensureStillAlive();
- }
-
- if (object.getTruthy(globalObject, "upgrade")) |upgrade| {
- if (!upgrade.isCallable(globalObject.vm())) {
- globalObject.throwInvalidArguments("websocket expects a function for the upgrade option", .{});
- return null;
- }
- server.onUpgrade = upgrade;
- upgrade.ensureStillAlive();
- }
-
server.protect();
return server;
}
-
- pub const List = std.StringArrayHashMapUnmanaged(WebSocketServer);
};
const Corker = struct {
- args: []const JSValue,
+ args: []const JSValue = &.{},
globalObject: *JSC.JSGlobalObject,
+ this_value: JSC.JSValue = .zero,
callback: JSC.JSValue,
result: JSValue = .zero,
pub fn run(this: *Corker) void {
- this.result = this.callback.call(this.globalObject, this.args);
+ const this_value = this.this_value;
+ this.result = if (this_value == .zero)
+ this.callback.call(this.globalObject, this.args)
+ else
+ this.callback.callWithThis(this.globalObject, this_value, this.args);
}
};
pub const ServerWebSocket = struct {
- handler: *WebSocketServer,
+ handler: *WebSocketServer.Handler,
this_value: JSValue = .zero,
websocket: uws.AnyWebSocket = undefined,
closed: bool = false,
@@ -2773,6 +2764,47 @@ pub const ServerWebSocket = struct {
return .zero;
}
+ pub fn cork(
+ this: *ServerWebSocket,
+ globalThis: *JSC.JSGlobalObject,
+ callframe: *JSC.CallFrame,
+ ) callconv(.C) JSValue {
+ const args = callframe.arguments(1);
+
+ if (args.len < 1) {
+ globalThis.throw("cork requires at least 1 argument", .{});
+ return .zero;
+ }
+ const callback = args.ptr[0];
+ if (callback.isEmptyOrUndefinedOrNull() or !callback.isCallable(globalThis.vm())) {
+ globalThis.throw("cork requires a function", .{});
+ return .zero;
+ }
+
+ if (this.closed) {
+ return JSValue.jsUndefined();
+ }
+
+ var corker = Corker{
+ .globalObject = globalThis,
+ .this_value = this.this_value,
+ .callback = callback,
+ };
+ this.websocket.cork(&corker, Corker.run);
+
+ const result = corker.result;
+
+ if (result.isEmptyOrUndefinedOrNull())
+ return JSValue.jsUndefined();
+
+ if (result.isAnyError(globalThis)) {
+ globalThis.throwValue(result);
+ return JSValue.jsUndefined();
+ }
+
+ return result;
+ }
+
pub fn send(
this: *ServerWebSocket,
globalThis: *JSC.JSGlobalObject,
@@ -3066,6 +3098,9 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type {
.reload = .{
.rfn = onReload,
},
+ .upgrade = .{
+ .rfn = JSC.wrapSync(ThisServer, "onUpgrade"),
+ },
},
.{
.port = .{
@@ -3089,6 +3124,142 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type {
},
);
+ pub fn onUpgrade(
+ this: *ThisServer,
+ globalThis: *JSC.JSGlobalObject,
+ object: JSC.JSValue,
+ optional: ?JSValue,
+ ) JSValue {
+ if (this.config.websocket == null) {
+ globalThis.throw("To enable WebSocket support, set the \"websocket\" object in Bun.serve({})", .{});
+ return .zero;
+ }
+ var request = object.as(Request) orelse {
+ globalThis.throw("upgrade requires a Request object", .{});
+ return .zero;
+ };
+
+ if (request.upgrader == null) {
+ return JSC.jsBoolean(false);
+ }
+
+ var upgrader = bun.cast(*RequestContext, request.upgrader.?);
+ if (upgrader.aborted) {
+ return JSC.jsBoolean(false);
+ }
+ request.upgrader = null;
+ if (upgrader.upgrade_context == null or @ptrToInt(upgrader.upgrade_context) == std.math.maxInt(usize)) {
+ return JSC.jsBoolean(false);
+ }
+ var ctx = upgrader.upgrade_context.?;
+ // obviously invalid pointer marks it as used
+ upgrader.upgrade_context = @intToPtr(*uws.uws_socket_context_s, std.math.maxInt(usize));
+
+ var sec_websocket_key_str = ZigString.Empty;
+
+ if (request.headers) |head| {
+ sec_websocket_key_str = head.fastGet(.SecWebSocketKey) orelse ZigString.Empty;
+ }
+
+ if (sec_websocket_key_str.len == 0) {
+ sec_websocket_key_str = ZigString.init(upgrader.req.header("sec-websocket-key") orelse "");
+ }
+
+ if (sec_websocket_key_str.len == 0) {
+ return JSC.jsBoolean(false);
+ }
+
+ var sec_websocket_protocol = ZigString.init(upgrader.req.header("sec-websocket-protocol") orelse "");
+ var sec_websocket_extensions = ZigString.init(upgrader.req.header("sec-websocket-extensions") orelse "");
+
+ if (sec_websocket_protocol.len > 0) {
+ sec_websocket_protocol.markUTF8();
+ }
+ if (sec_websocket_extensions.len > 0) {
+ sec_websocket_extensions.markUTF8();
+ }
+
+ var data_value = JSC.JSValue.zero;
+ var fetch_headers_to_deref: ?*JSC.FetchHeaders = null;
+ defer {
+ if (fetch_headers_to_deref) |fh| {
+ fh.deref();
+ }
+ }
+
+ if (optional) |opts| {
+ getter: {
+ if (opts.isEmptyOrUndefinedOrNull()) {
+ break :getter;
+ }
+
+ if (!opts.isObject()) {
+ globalThis.throw("upgrade options must be an object", .{});
+ return .zero;
+ }
+
+ if (opts.fastGet(globalThis, .headers)) |headers_value| {
+ if (headers_value.as(JSC.FetchHeaders)) |fetch_headers| {
+ if (fetch_headers.fastGet(.SecWebSocketProtocol)) |protocol| {
+ sec_websocket_protocol = protocol;
+ }
+
+ if (fetch_headers.fastGet(.SecWebSocketExtensions)) |protocol| {
+ sec_websocket_extensions = protocol;
+ }
+
+ fetch_headers.toUWSResponse(comptime ssl_enabled, upgrader.resp);
+ break :getter;
+ } else if (headers_value.isObject()) {
+ if (JSC.FetchHeaders.createFromJS(globalThis, headers_value)) |fetch_headers| {
+ if (fetch_headers.fastGet(.SecWebSocketProtocol)) |protocol| {
+ sec_websocket_protocol = protocol;
+ }
+
+ if (fetch_headers.fastGet(.SecWebSocketExtensions)) |protocol| {
+ sec_websocket_extensions = protocol;
+ }
+
+ fetch_headers.toUWSResponse(comptime ssl_enabled, upgrader.resp);
+ fetch_headers_to_deref = fetch_headers;
+ break :getter;
+ }
+ }
+
+ globalThis.throw("upgrade options.headers must be an object or Headers", .{});
+ return .zero;
+ }
+
+ if (opts.fastGet(globalThis, .data)) |headers_value| {
+ data_value = headers_value;
+ }
+ }
+ }
+
+ upgrader.resp.clearAborted();
+ var ws = this.vm.allocator.create(ServerWebSocket) catch return .zero;
+ ws.* = .{
+ .handler = &this.config.websocket.?.handler,
+ .this_value = data_value,
+ };
+
+ var sec_websocket_protocol_str = sec_websocket_protocol.toSlice(bun.default_allocator);
+ defer sec_websocket_protocol_str.deinit();
+ var sec_websocket_extensions_str = sec_websocket_extensions.toSlice(bun.default_allocator);
+ defer sec_websocket_extensions_str.deinit();
+
+ upgrader.resp.upgrade(
+ *ServerWebSocket,
+ ws,
+ sec_websocket_key_str.slice(),
+ sec_websocket_protocol_str.slice(),
+ sec_websocket_extensions_str.slice(),
+ ctx,
+ );
+
+ return JSC.jsBoolean(true);
+ }
+
pub fn onReload(
this: *ThisServer,
ctx: js.JSContextRef,
@@ -3117,52 +3288,21 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type {
this.config.onError = new_config.onError;
}
- if (new_config.websocket.onMessage != .zero or new_config.websocket.onOpen != .zero) {
- this.config.websocket.unprotect();
- if (this.config.websocket.onMessage == .zero) {
- this.app.ws("/*", this, 0, ServerWebSocket.behavior(
- ThisServer,
- ssl_enabled,
- this.config.websocket.toBehavior(),
- ));
- }
-
- new_config.websocket.globalObject = ctx;
- this.config.websocket = new_config.websocket;
- } else if (this.config.websocket.onMessage != .zero or this.config.websocket.onOpen != .zero) {
- this.config.websocket.unprotect();
- this.config.websocket = .{};
- }
-
- // we are going to leak the old memory
- const new_keys = new_config.websockets.keys();
- const old_keys = this.config.websockets.keys();
-
- if (new_keys.len + old_keys.len > 0) {
- var all_match = old_keys.len <= new_keys.len;
-
- // any existing websockets will now call the new config
- for (new_keys) |key, i| {
- if (this.config.websockets.getPtr(key)) |old_val| {
- old_val.unprotect();
- old_val.* = new_config.websockets.values()[i];
- old_val.globalObject = ctx;
+ if (new_config.websocket) |*ws| {
+ if (ws.handler.onMessage != .zero or ws.handler.onOpen != .zero) {
+ if (this.config.websocket) |old_ws| {
+ old_ws.unprotect();
} else {
- all_match = false;
- var new_value = &new_config.websockets.values()[i];
- new_value.globalObject = ctx;
- this.app.ws(
- key,
- this,
- i,
- ServerWebSocket.behavior(ThisServer, ssl_enabled, new_value.toBehavior()),
- );
+ this.app.ws("/*", this, 0, ServerWebSocket.behavior(
+ ThisServer,
+ ssl_enabled,
+ ws.toBehavior(),
+ ));
}
- }
- if (!all_match) {
- this.config.websockets = new_config.websockets;
- }
+ ws.globalObject = ctx;
+ this.config.websocket = ws.*;
+ } // we don't remove it
}
return this.thisObject.asObjectRef();
@@ -3337,12 +3477,8 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type {
}
pub fn activeSocketsCount(this: *const ThisServer) u32 {
- var count = this.config.websocket.active_connections;
- for (this.config.websockets.values()) |conn| {
- count += conn.active_connections;
- }
-
- return @truncate(u32, count);
+ const websocket = &(this.config.websocket orelse return 0);
+ return @truncate(u32, websocket.handler.active_connections);
}
pub fn hasActiveWebSockets(this: *const ThisServer) bool {
@@ -3608,13 +3744,58 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type {
}
// We keep the Request object alive for the duration of the request so that we can remove the pointer to the UWS request object.
- var args = [_]JSC.C.JSValueRef{
- request_object.toJS(this.globalThis).asObjectRef(),
+ var args = [_]JSC.JSValue{
+ request_object.toJS(this.globalThis),
+ this.thisObject,
+ };
+ ctx.request_js_object = args[0].asObjectRef();
+ const request_value = args[0];
+ request_value.ensureStillAlive();
+ const response_value = this.config.onRequest.callWithThis(this.globalThis, this.thisObject, &args);
+
+ ctx.onResponse(
+ this,
+ req,
+ request_object,
+ request_value,
+ response_value,
+ );
+ }
+
+ pub fn onWebSocketUpgrade(
+ this: *ThisServer,
+ resp: *App.Response,
+ req: *uws.Request,
+ upgrade_ctx: *uws.uws_socket_context_t,
+ _: usize,
+ ) void {
+ JSC.markBinding(@src());
+ this.pending_requests += 1;
+ req.setYield(false);
+ var ctx = this.request_pool_allocator.create(RequestContext) catch @panic("ran out of memory");
+ ctx.create(this, req, resp);
+ var request_object = this.allocator.create(JSC.WebCore.Request) catch unreachable;
+ request_object.* = .{
+ .url = "",
+ .method = ctx.method,
+ .uws_request = req,
+ .base_url_string_for_joining = this.base_url_string_for_joining,
+ .upgrader = ctx,
+ .body = .{
+ .Empty = .{},
+ },
+ };
+ ctx.upgrade_context = upgrade_ctx;
+
+ // We keep the Request object alive for the duration of the request so that we can remove the pointer to the UWS request object.
+ var args = [_]JSC.JSValue{
+ request_object.toJS(this.globalThis),
+ this.thisObject,
};
- ctx.request_js_object = args[0];
- const request_value = JSValue.c(args[0]);
+ ctx.request_js_object = args[0].asObjectRef();
+ const request_value = args[0];
request_value.ensureStillAlive();
- const response_value = JSC.C.JSObjectCallAsFunctionReturnValue(this.globalThis, this.config.onRequest.asObjectRef(), this.thisObject.asObjectRef(), 1, &args);
+ const response_value = this.config.onRequest.callWithThis(this.globalThis, this.thisObject, &args);
ctx.onResponse(
this,
@@ -3645,28 +3826,14 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type {
this.app = App.create(.{});
}
- const websocket_patterns = this.config.websockets.keys();
- {
- const values = this.config.websockets.values();
- for (websocket_patterns) |pattern, i| {
- values[i].globalObject = this.globalThis;
-
- this.app.ws(pattern, this, i + 1, ServerWebSocket.behavior(
- ThisServer,
- ssl_enabled,
- values[i].toBehavior(),
- ));
- } else {
- if (this.config.websocket.onMessage != .zero or this.config.websocket.onOpen != .zero) {
- this.config.websocket.globalObject = this.globalThis;
- this.app.ws(
- "/*",
- this,
- 0,
- ServerWebSocket.behavior(ThisServer, ssl_enabled, this.config.websocket.toBehavior()),
- );
- }
- }
+ if (this.config.websocket) |*websocket| {
+ websocket.globalObject = this.globalThis;
+ this.app.ws(
+ "/*",
+ this,
+ 0,
+ ServerWebSocket.behavior(ThisServer, ssl_enabled, websocket.toBehavior()),
+ );
}
this.app.any("/*", *ThisServer, this, onRequest);
@@ -3693,131 +3860,6 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type {
.options = 0,
});
}
-
- pub fn onWebSocketUpgrade(this: *ThisServer, resp: *App.Response, req: *uws.Request, ctx: *uws.uws_socket_context_t, id: usize) void {
- var websocket_handler: *WebSocketServer = switch (id) {
- 0 => &this.config.websocket,
- else => &this.config.websockets.values()[id - 1],
- };
- req.setYield(false);
-
- const onUpgrade = websocket_handler.onUpgrade;
-
- if (onUpgrade == .zero) {
- var upgrader = this.allocator.create(ServerWebSocket) catch @panic("Out of memory");
- upgrader.* = .{
- .handler = websocket_handler,
- .this_value = .zero,
- };
- resp.upgrade(
- *ServerWebSocket,
- upgrader,
- req.header("sec-websocket-key") orelse "",
- req.header("sec-websocket-protocol") orelse "",
- req.header("sec-websocket-extensions") orelse "",
- ctx,
- );
- return;
- }
-
- const method = HTTP.Method.which(req.method()) orelse HTTP.Method.GET;
-
- var request_object = this.allocator.create(JSC.WebCore.Request) catch unreachable;
- request_object.* = .{
- .url = "",
- .method = method,
- .uws_request = req,
- .base_url_string_for_joining = this.base_url_string_for_joining,
- .body = .{
- .Empty = .{},
- },
- };
-
- var args = [_]JSC.JSValue{
- request_object.toJS(this.globalThis),
- };
- var request_value = args[0];
- request_value.ensureStillAlive();
- const response_value = websocket_handler.onUpgrade.call(this.globalThis, &args);
- request_value.ensureStillAlive();
- response_value.ensureStillAlive();
-
- if (response_value.isBoolean() or response_value.isEmptyOrUndefinedOrNull()) {
- if (response_value.toBoolean()) {
- var upgrader = this.allocator.create(ServerWebSocket) catch @panic("Out of memory");
- upgrader.* = .{
- .handler = websocket_handler,
- .this_value = .zero,
- };
- resp.upgrade(
- *ServerWebSocket,
- upgrader,
- req.header("sec-websocket-key") orelse "",
- req.header("sec-websocket-protocol") orelse "",
- req.header("sec-websocket-extensions") orelse "",
- ctx,
- );
- return;
- } else {
- req.setYield(true);
- return;
- }
- }
-
- if (response_value.as(Response)) |response| {
- if (response.statusCode() == 101) {
- var upgrader = this.allocator.create(ServerWebSocket) catch @panic("Out of memory");
- upgrader.* = .{
- .handler = websocket_handler,
- .this_value = response_value,
- };
- response_value.ensureStillAlive();
- resp.upgrade(
- *ServerWebSocket,
- upgrader,
- response.header(.SecWebSocketKey) orelse req.header("sec-websocket-key") orelse "",
- response.header(.SecWebSocketProtocol) orelse req.header("sec-websocket-protocol") orelse "",
- response.header(.SecWebSocketExtensions) orelse req.header("sec-websocket-extensions") orelse "",
- ctx,
- );
- return;
- }
- }
-
- // The returned object becomes the data for the ServerWebSocket
- if (response_value.isCell() or (response_value.isString() and response_value.getLengthOfArray(this.globalThis) > 0)) {
- var upgrader = this.allocator.create(ServerWebSocket) catch @panic("Out of memory");
- upgrader.* = .{
- .handler = websocket_handler,
- .this_value = response_value,
- };
- response_value.ensureStillAlive();
-
- resp.upgrade(
- *ServerWebSocket,
- upgrader,
- req.header("sec-websocket-key") orelse "",
- req.header("sec-websocket-protocol") orelse "",
- req.header("sec-websocket-extensions") orelse "",
- ctx,
- );
- return;
- }
-
- var req_ctx = this.request_pool_allocator.create(RequestContext) catch @panic("ran out of memory");
- req_ctx.create(this, req, resp);
- req_ctx.request_js_object = request_value.asObjectRef();
- req_ctx.response_jsvalue = response_value;
- req_ctx.resp = resp;
-
- req_ctx.onResponse(
- this,
- req,
- request_object,
- request_value,
- response_value,
- );
- }
};
}
diff --git a/src/bun.js/bindings/ZigGeneratedClasses.cpp b/src/bun.js/bindings/ZigGeneratedClasses.cpp
index 93e69455b..62952f004 100644
--- a/src/bun.js/bindings/ZigGeneratedClasses.cpp
+++ b/src/bun.js/bindings/ZigGeneratedClasses.cpp
@@ -2410,6 +2410,10 @@ extern "C" EncodedJSValue ServerWebSocketPrototype__close(void* ptr, JSC::JSGlob
JSC_DECLARE_HOST_FUNCTION(ServerWebSocketPrototype__closeCallback);
+extern "C" EncodedJSValue ServerWebSocketPrototype__cork(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame);
+JSC_DECLARE_HOST_FUNCTION(ServerWebSocketPrototype__corkCallback);
+
+
extern "C" JSC::EncodedJSValue ServerWebSocketPrototype__getData(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject);
JSC_DECLARE_CUSTOM_GETTER(ServerWebSocketPrototype__dataGetterWrap);
@@ -2455,6 +2459,7 @@ STATIC_ASSERT_ISO_SUBSPACE_SHARABLE(JSServerWebSocketPrototype, JSServerWebSocke
static const HashTableValue JSServerWebSocketPrototypeTableValues[] = {
{ "close"_s, static_cast<unsigned>(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, ServerWebSocketPrototype__closeCallback, 1 } } ,
+{ "cork"_s, static_cast<unsigned>(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, ServerWebSocketPrototype__corkCallback, 1 } } ,
{ "data"_s, static_cast<unsigned>(JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute), NoIntrinsic, { HashTableValue::GetterSetterType, ServerWebSocketPrototype__dataGetterWrap, ServerWebSocketPrototype__dataSetterWrap } } ,
{ "getBufferedAmount"_s, static_cast<unsigned>(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, ServerWebSocketPrototype__getBufferedAmountCallback, 0 } } ,
{ "isSubscribed"_s, static_cast<unsigned>(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, ServerWebSocketPrototype__isSubscribedCallback, 1 } } ,
@@ -2502,6 +2507,23 @@ JSC_DEFINE_HOST_FUNCTION(ServerWebSocketPrototype__closeCallback, (JSGlobalObjec
}
+JSC_DEFINE_HOST_FUNCTION(ServerWebSocketPrototype__corkCallback, (JSGlobalObject * lexicalGlobalObject, CallFrame* callFrame))
+{
+ auto& vm = lexicalGlobalObject->vm();
+
+ JSServerWebSocket* thisObject = jsDynamicCast<JSServerWebSocket*>(callFrame->thisValue());
+
+ if (UNLIKELY(!thisObject)) {
+ auto throwScope = DECLARE_THROW_SCOPE(vm);
+ return throwVMTypeError(lexicalGlobalObject, throwScope);
+ }
+
+ JSC::EnsureStillAliveScope thisArg = JSC::EnsureStillAliveScope(thisObject);
+
+ return ServerWebSocketPrototype__cork(thisObject->wrapped(), lexicalGlobalObject, callFrame);
+}
+
+
JSC_DEFINE_CUSTOM_GETTER(ServerWebSocketPrototype__dataGetterWrap, (JSGlobalObject * lexicalGlobalObject, EncodedJSValue thisValue, PropertyName attributeName))
{
auto& vm = lexicalGlobalObject->vm();
diff --git a/src/bun.js/bindings/bindings.cpp b/src/bun.js/bindings/bindings.cpp
index e9965f3ad..c8a82abef 100644
--- a/src/bun.js/bindings/bindings.cpp
+++ b/src/bun.js/bindings/bindings.cpp
@@ -3063,6 +3063,7 @@ enum class BuiltinNamesMap : uint8_t {
status,
url,
body,
+ data,
};
static JSC::Identifier builtinNameMap(JSC::JSGlobalObject* globalObject, unsigned char name)
@@ -3084,6 +3085,9 @@ static JSC::Identifier builtinNameMap(JSC::JSGlobalObject* globalObject, unsigne
case BuiltinNamesMap::body: {
return clientData->builtinNames().bodyPublicName();
}
+ case BuiltinNamesMap::data: {
+ return clientData->builtinNames().dataPublicName();
+ }
}
}
diff --git a/src/bun.js/bindings/bindings.zig b/src/bun.js/bindings/bindings.zig
index d69d5b4ed..8d119821d 100644
--- a/src/bun.js/bindings/bindings.zig
+++ b/src/bun.js/bindings/bindings.zig
@@ -3218,6 +3218,7 @@ pub const JSValue = enum(JSValueReprInt) {
status,
url,
body,
+ data,
};
// intended to be more lightweight than ZigString
diff --git a/src/bun.js/bindings/generated_classes.zig b/src/bun.js/bindings/generated_classes.zig
index 0bd06656f..7c954c0c1 100644
--- a/src/bun.js/bindings/generated_classes.zig
+++ b/src/bun.js/bindings/generated_classes.zig
@@ -822,6 +822,8 @@ pub const JSServerWebSocket = struct {
if (@TypeOf(ServerWebSocket.close) != CallbackType)
@compileLog("Expected ServerWebSocket.close to be a callback");
+ if (@TypeOf(ServerWebSocket.cork) != CallbackType)
+ @compileLog("Expected ServerWebSocket.cork to be a callback");
if (@TypeOf(ServerWebSocket.getData) != GetterType)
@compileLog("Expected ServerWebSocket.getData to be a getter");
@@ -848,6 +850,7 @@ pub const JSServerWebSocket = struct {
if (!JSC.is_bindgen) {
@export(ServerWebSocket.close, .{ .name = "ServerWebSocketPrototype__close" });
@export(ServerWebSocket.constructor, .{ .name = "ServerWebSocketClass__construct" });
+ @export(ServerWebSocket.cork, .{ .name = "ServerWebSocketPrototype__cork" });
@export(ServerWebSocket.finalize, .{ .name = "ServerWebSocketClass__finalize" });
@export(ServerWebSocket.getBufferedAmount, .{ .name = "ServerWebSocketPrototype__getBufferedAmount" });
@export(ServerWebSocket.getData, .{ .name = "ServerWebSocketPrototype__getData" });
diff --git a/src/bun.js/builtins/BunBuiltinNames.h b/src/bun.js/builtins/BunBuiltinNames.h
index 2982e8fae..5977f05dd 100644
--- a/src/bun.js/builtins/BunBuiltinNames.h
+++ b/src/bun.js/builtins/BunBuiltinNames.h
@@ -77,6 +77,7 @@ using namespace JSC;
macro(createUninitializedArrayBuffer) \
macro(createWritableStreamFromInternal) \
macro(cwd) \
+ macro(data) \
macro(dataView) \
macro(decode) \
macro(delimiter) \
diff --git a/src/bun.js/webcore/response.zig b/src/bun.js/webcore/response.zig
index 8d2f8ab27..81b0d7f14 100644
--- a/src/bun.js/webcore/response.zig
+++ b/src/bun.js/webcore/response.zig
@@ -5179,6 +5179,7 @@ pub const Request = struct {
body: Body.Value = Body.Value{ .Empty = .{} },
method: Method = Method.GET,
uws_request: ?*uws.Request = null,
+ upgrader: ?*anyopaque = null,
// We must report a consistent value for this
reported_estimated_size: ?u63 = null,
diff --git a/test/bun.js/websocket-server.test.ts b/test/bun.js/websocket-server.test.ts
index a8dd31da4..086529a95 100644
--- a/test/bun.js/websocket-server.test.ts
+++ b/test/bun.js/websocket-server.test.ts
@@ -22,7 +22,50 @@ describe("websocket server", () => {
},
message(ws, msg) {},
},
- fetch(req) {
+ fetch(req, server) {
+ if (server.upgrade(req)) return;
+
+ return new Response("noooooo hello world");
+ },
+ });
+
+ await new Promise((resolve, reject) => {
+ const websocket = new WebSocket(`ws://localhost:${server.port}`);
+
+ websocket.onmessage = (e) => {
+ try {
+ expect(e.data).toBe("hello world");
+ resolve();
+ } catch (r) {
+ reject(r);
+ return;
+ } finally {
+ server?.stop();
+ websocket.close();
+ }
+ };
+ websocket.onerror = (e) => {
+ reject(e);
+ };
+ });
+ });
+
+ it("can do hello world corked", async () => {
+ var server = serve({
+ port: getPort(),
+ websocket: {
+ open(ws) {
+ ws.send("hello world");
+ },
+ message(ws, msg) {
+ ws.cork(() => {
+ ws.send("hello world");
+ });
+ },
+ },
+ fetch(req, server) {
+ if (server.upgrade(req)) return;
+
return new Response("noooooo hello world");
},
});
@@ -53,9 +96,6 @@ describe("websocket server", () => {
var server = serve({
port: getPort(),
websocket: {
- upgrade(ws) {
- return { count: 0 };
- },
open(ws) {},
message(ws, msg) {
if (msg === "first") {
@@ -65,8 +105,13 @@ describe("websocket server", () => {
ws.send(`counter: ${dataCount++}`);
},
},
- fetch(req) {
- return new Response("noooooo hello world");
+ fetch(req, server) {
+ if (
+ server.upgrade(req, {
+ count: 0,
+ })
+ )
+ return new Response("noooooo hello world");
},
});
@@ -128,15 +173,19 @@ describe("websocket server", () => {
var server = serve({
port: getPort(),
websocket: {
- upgrade(ws) {
- return { count: 0 };
- },
open(ws) {},
message(ws, msg) {
ws.send(sendQueue[serverCounter++] + " ");
},
},
- fetch(req) {
+ fetch(req, server) {
+ if (
+ server.upgrade(req, {
+ data: { count: 0 },
+ })
+ )
+ return;
+
return new Response("noooooo hello world");
},
});