diff options
-rw-r--r-- | src/bun.js/api/server.zig | 60 | ||||
-rw-r--r-- | src/deps/uws.zig | 16 | ||||
-rw-r--r-- | test/bun.js/websocket-server.test.ts | 63 |
3 files changed, 111 insertions, 28 deletions
diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig index ce0b388d7..dea478d7a 100644 --- a/src/bun.js/api/server.zig +++ b/src/bun.js/api/server.zig @@ -2289,11 +2289,14 @@ pub const WebSocketServer = struct { onDrain: JSC.JSValue = .zero, onError: JSC.JSValue = .zero, - app: *anyopaque = undefined, + app: ?*anyopaque = null, globalObject: *JSC.JSGlobalObject = undefined, active_connections: usize = 0, + /// used by publish() + ssl: bool = false, + pub fn fromJS(globalObject: *JSC.JSGlobalObject, object: JSC.JSValue) ?Handler { var handler = Handler{ .globalObject = globalObject }; if (object.getTruthy(globalObject, "message")) |message| { @@ -2792,10 +2795,11 @@ pub const ServerWebSocket = struct { return .zero; } - if (this.closed) { + var app = this.handler.app orelse { log("publish() closed", .{}); return JSValue.jsNumber(0); - } + }; + const ssl = this.handler.ssl; const topic_value = args.ptr[0]; const message_value = args.ptr[1]; @@ -2830,7 +2834,7 @@ pub const ServerWebSocket = struct { return JSValue.jsNumber( // if 0, return 0 // else return number of bytes sent - @as(i32, @boolToInt(this.websocket.publishWithOptions(this.handler.app, topic_slice.slice(), buffer.slice(), .binary, compress))) * @intCast(i32, @truncate(u31, buffer.len)), + @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer.slice(), .binary, compress))) * @intCast(i32, @truncate(u31, buffer.len)), ); } @@ -2845,7 +2849,7 @@ pub const ServerWebSocket = struct { return JSValue.jsNumber( // if 0, return 0 // else return number of bytes sent - @as(i32, @boolToInt(this.websocket.publishWithOptions(this.handler.app, topic_slice.slice(), buffer, .text, compress))) * @intCast(i32, @truncate(u31, buffer.len)), + @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .text, compress))) * @intCast(i32, @truncate(u31, buffer.len)), ); } @@ -2865,10 +2869,11 @@ pub const ServerWebSocket = struct { return .zero; } - if (this.closed) { + var app = this.handler.app orelse { log("publish() closed", .{}); return JSValue.jsNumber(0); - } + }; + const ssl = this.handler.ssl; const topic_value = args.ptr[0]; const message_value = args.ptr[1]; @@ -2904,7 +2909,7 @@ pub const ServerWebSocket = struct { return JSValue.jsNumber( // if 0, return 0 // else return number of bytes sent - @as(i32, @boolToInt(this.websocket.publishWithOptions(this.handler.app, topic_slice.slice(), buffer, .text, compress))) * @intCast(i32, @truncate(u31, buffer.len)), + @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .text, compress))) * @intCast(i32, @truncate(u31, buffer.len)), ); } @@ -2921,11 +2926,11 @@ pub const ServerWebSocket = struct { return .zero; } - if (this.closed) { - log("publishBinary() closed", .{}); + var app = this.handler.app orelse { + log("publish() closed", .{}); return JSValue.jsNumber(0); - } - + }; + const ssl = this.handler.ssl; const topic_value = args.ptr[0]; const message_value = args.ptr[1]; const compress_value = args.ptr[2]; @@ -2961,7 +2966,7 @@ pub const ServerWebSocket = struct { return JSValue.jsNumber( // if 0, return 0 // else return number of bytes sent - @as(i32, @boolToInt(this.websocket.publishWithOptions(this.handler.app, topic_slice.slice(), buffer.slice(), .binary, compress))) * @intCast(i32, @truncate(u31, buffer.len)), + @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer.slice(), .binary, compress))) * @intCast(i32, @truncate(u31, buffer.len)), ); } @@ -2971,10 +2976,11 @@ pub const ServerWebSocket = struct { topic_str: *JSC.JSString, buffer: *JSC.JSUint8Array, ) callconv(.C) JSC.JSValue { - if (this.closed) { - log("publishBinary() closed", .{}); + var app = this.handler.app orelse { + log("publish() closed", .{}); return JSValue.jsNumber(0); - } + }; + const ssl = this.handler.ssl; var topic_slice = topic_str.toSlice(globalThis, bun.default_allocator); defer topic_slice.deinit(); @@ -2996,8 +3002,9 @@ pub const ServerWebSocket = struct { @as( i32, @boolToInt( - this.websocket.publishWithOptions( - this.handler.app, + uws.AnyWebSocket.publishWithOptions( + ssl, + app, topic_slice.slice(), slice, .binary, @@ -3017,10 +3024,11 @@ pub const ServerWebSocket = struct { topic_str: *JSC.JSString, str: *JSC.JSString, ) callconv(.C) JSC.JSValue { - if (this.closed) { - log("publishBinary() closed", .{}); + var app = this.handler.app orelse { + log("publish() closed", .{}); return JSValue.jsNumber(0); - } + }; + const ssl = this.handler.ssl; var topic_slice = topic_str.toSlice(globalThis, bun.default_allocator); defer topic_slice.deinit(); @@ -3039,8 +3047,9 @@ pub const ServerWebSocket = struct { return JSValue.jsNumber( // if 0, return 0 // else return number of bytes sent - @as(i32, @boolToInt(this.websocket.publishWithOptions( - this.handler.app, + @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions( + ssl, + app, topic_slice.slice(), slice.slice(), .text, @@ -3809,6 +3818,7 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type { } if (new_config.websocket) |*ws| { + ws.handler.ssl = ssl_enabled; if (ws.handler.onMessage != .zero or ws.handler.onOpen != .zero) { if (this.config.websocket) |old_ws| { old_ws.unprotect(); @@ -4007,6 +4017,9 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type { pub fn deinitIfWeCan(this: *ThisServer) void { if (this.pending_requests == 0 and this.listener == null and this.has_js_deinited and !this.hasActiveWebSockets()) { + if (this.config.websocket) |*ws| { + ws.handler.app = null; + } this.unref(); this.deinit(); } @@ -4349,6 +4362,7 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type { if (this.config.websocket) |*websocket| { websocket.globalObject = this.globalThis; websocket.handler.app = this.app; + websocket.handler.ssl = ssl_enabled; this.app.ws( "/*", this, diff --git a/src/deps/uws.zig b/src/deps/uws.zig index 330d0e535..9b4903657 100644 --- a/src/deps/uws.zig +++ b/src/deps/uws.zig @@ -655,11 +655,17 @@ pub const AnyWebSocket = union(enum) { .tcp => uws_ws_publish(0, this.tcp.raw(), topic.ptr, topic.len, message.ptr, message.len), }; } - pub fn publishWithOptions(this: AnyWebSocket, app: *anyopaque, topic: []const u8, message: []const u8, opcode: Opcode, compress: bool) bool { - return switch (this) { - .ssl => uws_publish(1, @ptrCast(*uws_app_t, app), topic.ptr, topic.len, message.ptr, message.len, opcode, compress), - .tcp => uws_publish(0, @ptrCast(*uws_app_t, app), topic.ptr, topic.len, message.ptr, message.len, opcode, compress), - }; + pub fn publishWithOptions(ssl: bool, app: *anyopaque, topic: []const u8, message: []const u8, opcode: Opcode, compress: bool) bool { + return uws_publish( + @boolToInt(ssl), + @ptrCast(*uws_app_t, app), + topic.ptr, + topic.len, + message.ptr, + message.len, + opcode, + compress, + ); } pub fn getBufferedAmount(this: AnyWebSocket) u32 { return switch (this) { diff --git a/test/bun.js/websocket-server.test.ts b/test/bun.js/websocket-server.test.ts index eed5ffdc4..7c7869fad 100644 --- a/test/bun.js/websocket-server.test.ts +++ b/test/bun.js/websocket-server.test.ts @@ -12,6 +12,69 @@ function getPort() { } describe("websocket server", () => { + for (let method of ["publish", "publishText", "publishBinary"]) { + describe(method, () => { + it("in close() should work", async () => { + var server = serve({ + port: getPort(), + websocket: { + open(ws) { + ws.subscribe("all"); + }, + message(ws, msg) {}, + close(ws) { + ws[method]( + "all", + method === "publishBinary" ? Buffer.from("bye!") : "bye!" + ); + }, + }, + fetch(req, server) { + if (server.upgrade(req)) { + return; + } + + return new Response("success"); + }, + }); + + try { + const first = await new Promise((resolve2, reject2) => { + var socket = new WebSocket( + `ws://${server.hostname}:${server.port}` + ); + socket.onopen = () => resolve2(socket); + }); + + const second = await new Promise((resolve2, reject2) => { + var socket = new WebSocket( + `ws://${server.hostname}:${server.port}` + ); + socket.onmessage = (ev) => { + var msg = ev.data; + if (typeof msg !== "string") { + msg = new TextDecoder().decode(msg); + } + if (msg === "bye!") { + resolve2(socket); + } else { + reject2(msg); + } + }; + socket.onopen = () => { + first.close(); + }; + }); + + second.close(); + } catch (r) { + } finally { + server.stop(); + } + }); + }); + } + it("close inside open", async () => { var resolve; var server = serve({ |