diff options
author | 2023-01-11 17:13:46 -0800 | |
---|---|---|
committer | 2023-01-11 17:14:07 -0800 | |
commit | 1c20e05d7011a302c04adc1f7aa6fc3ce3963799 (patch) | |
tree | 79fe797287a26af95defca2c706981cea4d6cd76 | |
parent | 4969f068f63e68a19aacc67cc094421fa7297c07 (diff) | |
download | bun-1c20e05d7011a302c04adc1f7aa6fc3ce3963799.tar.gz bun-1c20e05d7011a302c04adc1f7aa6fc3ce3963799.tar.zst bun-1c20e05d7011a302c04adc1f7aa6fc3ce3963799.zip |
[Bun.serve] Introduce publishToSelf boolean on websocket: {} config object
-rw-r--r-- | src/bun.js/api/server.zig | 126 | ||||
-rw-r--r-- | test/bun.js/websocket-server.test.ts | 52 |
2 files changed, 128 insertions, 50 deletions
diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig index f47ee9fc0..bed751f1b 100644 --- a/src/bun.js/api/server.zig +++ b/src/bun.js/api/server.zig @@ -2514,7 +2514,10 @@ pub const WebSocketServer = struct { active_connections: usize = 0, /// used by publish() - ssl: bool = false, + flags: packed struct (u2) { + ssl: bool = false, + publish_to_self: bool = true, + } = .{}, pub fn fromJS(globalObject: *JSC.JSGlobalObject, object: JSC.JSValue) ?Handler { var handler = Handler{ .globalObject = globalObject }; @@ -2752,6 +2755,17 @@ pub const WebSocketServer = struct { } } + if (object.get(globalObject, "publishToSelf")) |value| { + if (!value.isUndefinedOrNull()) { + if (!value.isBoolean()) { + globalObject.throwInvalidArguments("websocket expects publishToSelf to be a boolean", .{}); + return null; + } + + server.handler.flags.publish_to_self = value.toBoolean(); + } + } + server.protect(); return server; } @@ -3025,7 +3039,9 @@ pub const ServerWebSocket = struct { log("publish() closed", .{}); return JSValue.jsNumber(0); }; - const ssl = this.handler.ssl; + const flags = this.handler.flags; + const ssl = flags.ssl; + const publish_to_self = flags.publish_to_self; const topic_value = args.ptr[0]; const message_value = args.ptr[1]; @@ -3051,16 +3067,23 @@ pub const ServerWebSocket = struct { return .zero; } - if (message_value.asArrayBuffer(globalThis)) |buffer| { + if (message_value.asArrayBuffer(globalThis)) |array_buffer| { + const buffer = array_buffer.slice(); + if (buffer.len == 0) { globalThis.throw("publish requires a non-empty message", .{}); return .zero; } + const result = if (!publish_to_self) + this.websocket.publish(topic_slice.slice(), buffer, .binary, compress) + else + uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .binary, compress); + return JSValue.jsNumber( // if 0, return 0 // else return number of bytes sent - @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer.slice(), .binary, compress))) * @intCast(i32, @truncate(u31, buffer.len)), + if (result) @intCast(i32, @truncate(u31, buffer.len)) else @as(i32, 0), ); } @@ -3072,10 +3095,16 @@ pub const ServerWebSocket = struct { } const buffer = string_slice.slice(); + + const result = if (!publish_to_self) + this.websocket.publish(topic_slice.slice(), buffer, .text, compress) + else + uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .text, compress); + return JSValue.jsNumber( // if 0, return 0 // else return number of bytes sent - @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .text, compress))) * @intCast(i32, @truncate(u31, buffer.len)), + if (result) @intCast(i32, @truncate(u31, buffer.len)) else @as(i32, 0), ); } @@ -3099,7 +3128,9 @@ pub const ServerWebSocket = struct { log("publish() closed", .{}); return JSValue.jsNumber(0); }; - const ssl = this.handler.ssl; + const flags = this.handler.flags; + const ssl = flags.ssl; + const publish_to_self = flags.publish_to_self; const topic_value = args.ptr[0]; const message_value = args.ptr[1]; @@ -3132,10 +3163,16 @@ pub const ServerWebSocket = struct { } const buffer = string_slice.slice(); + + const result = if (!publish_to_self) + this.websocket.publish(topic_slice.slice(), buffer, .text, compress) + else + uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .text, compress); + return JSValue.jsNumber( // if 0, return 0 // else return number of bytes sent - @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .text, compress))) * @intCast(i32, @truncate(u31, buffer.len)), + if (result) @intCast(i32, @truncate(u31, buffer.len)) else @as(i32, 0), ); } @@ -3156,7 +3193,9 @@ pub const ServerWebSocket = struct { log("publish() closed", .{}); return JSValue.jsNumber(0); }; - const ssl = this.handler.ssl; + const flags = this.handler.flags; + const ssl = flags.ssl; + const publish_to_self = flags.publish_to_self; const topic_value = args.ptr[0]; const message_value = args.ptr[1]; const compress_value = args.ptr[2]; @@ -3180,19 +3219,25 @@ pub const ServerWebSocket = struct { globalThis.throw("publishBinary requires a non-empty message", .{}); return .zero; } - const buffer = message_value.asArrayBuffer(globalThis) orelse { + const array_buffer = message_value.asArrayBuffer(globalThis) orelse { globalThis.throw("publishBinary expects an ArrayBufferView", .{}); return .zero; }; + const buffer = array_buffer.slice(); if (buffer.len == 0) { return JSC.JSValue.jsNumber(0); } + const result = if (!publish_to_self) + this.websocket.publish(topic_slice.slice(), buffer, .binary, compress) + else + uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .binary, compress); + return JSValue.jsNumber( // if 0, return 0 // else return number of bytes sent - @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer.slice(), .binary, compress))) * @intCast(i32, @truncate(u31, buffer.len)), + if (result) @intCast(i32, @truncate(u31, buffer.len)) else @as(i32, 0), ); } @@ -3200,13 +3245,15 @@ pub const ServerWebSocket = struct { this: *ServerWebSocket, globalThis: *JSC.JSGlobalObject, topic_str: *JSC.JSString, - buffer: *JSC.JSUint8Array, + array: *JSC.JSUint8Array, ) callconv(.C) JSC.JSValue { var app = this.handler.app orelse { log("publish() closed", .{}); return JSValue.jsNumber(0); }; - const ssl = this.handler.ssl; + const flags = this.handler.flags; + const ssl = flags.ssl; + const publish_to_self = flags.publish_to_self; var topic_slice = topic_str.toSlice(globalThis, bun.default_allocator); defer topic_slice.deinit(); @@ -3217,30 +3264,20 @@ pub const ServerWebSocket = struct { const compress = true; - const slice = buffer.slice(); - if (slice.len == 0) { + const buffer = array.slice(); + if (buffer.len == 0) { return JSC.JSValue.jsNumber(0); } + const result = if (!publish_to_self) + this.websocket.publish(topic_slice.slice(), buffer, .binary, compress) + else + uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .binary, compress); + return JSValue.jsNumber( // if 0, return 0 // else return number of bytes sent - @as( - i32, - @boolToInt( - uws.AnyWebSocket.publishWithOptions( - ssl, - app, - topic_slice.slice(), - slice, - .binary, - compress, - ), - ), - ) * @intCast( - i32, - @truncate(u31, slice.len), - ), + if (result) @intCast(i32, @truncate(u31, buffer.len)) else @as(i32, 0), ); } @@ -3254,7 +3291,9 @@ pub const ServerWebSocket = struct { log("publish() closed", .{}); return JSValue.jsNumber(0); }; - const ssl = this.handler.ssl; + const flags = this.handler.flags; + const ssl = flags.ssl; + const publish_to_self = flags.publish_to_self; var topic_slice = topic_str.toSlice(globalThis, bun.default_allocator); defer topic_slice.deinit(); @@ -3266,24 +3305,21 @@ pub const ServerWebSocket = struct { const compress = true; const slice = str.toSlice(globalThis, bun.default_allocator); - if (slice.len == 0) { + defer slice.deinit(); + const buffer = slice.slice(); + + if (buffer.len == 0) { return JSC.JSValue.jsNumber(0); } + const result = if (!publish_to_self) + this.websocket.publish(topic_slice.slice(), buffer, .text, compress) + else + uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .text, compress); return JSValue.jsNumber( // if 0, return 0 // else return number of bytes sent - @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions( - ssl, - app, - topic_slice.slice(), - slice.slice(), - .text, - compress, - ))) * @intCast( - i32, - @truncate(u31, slice.len), - ), + if (result) @intCast(i32, @truncate(u31, buffer.len)) else @as(i32, 0), ); } @@ -4107,7 +4143,7 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type { } if (new_config.websocket) |*ws| { - ws.handler.ssl = ssl_enabled; + ws.handler.flags.ssl = ssl_enabled; if (ws.handler.onMessage != .zero or ws.handler.onOpen != .zero) { if (this.config.websocket) |old_ws| { old_ws.unprotect(); @@ -4671,7 +4707,7 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type { if (this.config.websocket) |*websocket| { websocket.globalObject = this.globalThis; websocket.handler.app = this.app; - websocket.handler.ssl = ssl_enabled; + websocket.handler.flags.ssl = ssl_enabled; this.app.ws( "/*", this, diff --git a/test/bun.js/websocket-server.test.ts b/test/bun.js/websocket-server.test.ts index 1d9c15341..0dc421eb6 100644 --- a/test/bun.js/websocket-server.test.ts +++ b/test/bun.js/websocket-server.test.ts @@ -1,6 +1,6 @@ -import { serve } from "bun"; import { describe, expect, it } from "bun:test"; import { gcTick } from "./gc"; +import { serve } from "bun"; var port = 4321; function getPort() { @@ -49,6 +49,44 @@ describe("websocket server", () => { done(); }); + it("can do publish() with publishToSelf: false", async (done) => { + var server = serve({ + port: getPort(), + websocket: { + open(ws) { + ws.subscribe("all"); + ws.publish("all", "hey"); + server.publish("all", "hello"); + }, + message(ws, msg) { + if (new TextDecoder().decode(msg) !== "hello") { + done(new Error("unexpected message")); + } + }, + close(ws) {}, + publishToSelf: false, + }, + fetch(req, server) { + if (server.upgrade(req)) { + return; + } + + return new Response("success"); + }, + }); + + await new Promise<void>((resolve2, reject2) => { + var socket = new WebSocket(`ws://${server.hostname}:${server.port}`); + + socket.onmessage = (e) => { + expect(e.data).toBe("hello"); + resolve2(); + }; + }); + server.stop(); + done(); + }); + for (let method of ["publish", "publishText", "publishBinary"]) { describe(method, () => { it("in close() should work", async () => { @@ -463,7 +501,9 @@ describe("websocket server", () => { server.stop(); expect(() => { server.upgrade(req); - }).toThrow('To enable websocket support, set the "websocket" object in Bun.serve({})'); + }).toThrow( + 'To enable websocket support, set the "websocket" object in Bun.serve({})', + ); return new Response("success"); }, }); @@ -826,9 +866,11 @@ describe("websocket server", () => { fetch(req) { gcTick(); server.stop(); - if (server.upgrade(req, { - data: { count: 0 }, - })) + if ( + server.upgrade(req, { + data: { count: 0 }, + }) + ) return; return new Response("noooooo hello world"); }, |