aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/bun.js/api/server.zig126
-rw-r--r--test/bun.js/websocket-server.test.ts52
2 files changed, 128 insertions, 50 deletions
diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig
index f47ee9fc0..bed751f1b 100644
--- a/src/bun.js/api/server.zig
+++ b/src/bun.js/api/server.zig
@@ -2514,7 +2514,10 @@ pub const WebSocketServer = struct {
active_connections: usize = 0,
/// used by publish()
- ssl: bool = false,
+ flags: packed struct (u2) {
+ ssl: bool = false,
+ publish_to_self: bool = true,
+ } = .{},
pub fn fromJS(globalObject: *JSC.JSGlobalObject, object: JSC.JSValue) ?Handler {
var handler = Handler{ .globalObject = globalObject };
@@ -2752,6 +2755,17 @@ pub const WebSocketServer = struct {
}
}
+ if (object.get(globalObject, "publishToSelf")) |value| {
+ if (!value.isUndefinedOrNull()) {
+ if (!value.isBoolean()) {
+ globalObject.throwInvalidArguments("websocket expects publishToSelf to be a boolean", .{});
+ return null;
+ }
+
+ server.handler.flags.publish_to_self = value.toBoolean();
+ }
+ }
+
server.protect();
return server;
}
@@ -3025,7 +3039,9 @@ pub const ServerWebSocket = struct {
log("publish() closed", .{});
return JSValue.jsNumber(0);
};
- const ssl = this.handler.ssl;
+ const flags = this.handler.flags;
+ const ssl = flags.ssl;
+ const publish_to_self = flags.publish_to_self;
const topic_value = args.ptr[0];
const message_value = args.ptr[1];
@@ -3051,16 +3067,23 @@ pub const ServerWebSocket = struct {
return .zero;
}
- if (message_value.asArrayBuffer(globalThis)) |buffer| {
+ if (message_value.asArrayBuffer(globalThis)) |array_buffer| {
+ const buffer = array_buffer.slice();
+
if (buffer.len == 0) {
globalThis.throw("publish requires a non-empty message", .{});
return .zero;
}
+ const result = if (!publish_to_self)
+ this.websocket.publish(topic_slice.slice(), buffer, .binary, compress)
+ else
+ uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .binary, compress);
+
return JSValue.jsNumber(
// if 0, return 0
// else return number of bytes sent
- @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer.slice(), .binary, compress))) * @intCast(i32, @truncate(u31, buffer.len)),
+ if (result) @intCast(i32, @truncate(u31, buffer.len)) else @as(i32, 0),
);
}
@@ -3072,10 +3095,16 @@ pub const ServerWebSocket = struct {
}
const buffer = string_slice.slice();
+
+ const result = if (!publish_to_self)
+ this.websocket.publish(topic_slice.slice(), buffer, .text, compress)
+ else
+ uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .text, compress);
+
return JSValue.jsNumber(
// if 0, return 0
// else return number of bytes sent
- @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .text, compress))) * @intCast(i32, @truncate(u31, buffer.len)),
+ if (result) @intCast(i32, @truncate(u31, buffer.len)) else @as(i32, 0),
);
}
@@ -3099,7 +3128,9 @@ pub const ServerWebSocket = struct {
log("publish() closed", .{});
return JSValue.jsNumber(0);
};
- const ssl = this.handler.ssl;
+ const flags = this.handler.flags;
+ const ssl = flags.ssl;
+ const publish_to_self = flags.publish_to_self;
const topic_value = args.ptr[0];
const message_value = args.ptr[1];
@@ -3132,10 +3163,16 @@ pub const ServerWebSocket = struct {
}
const buffer = string_slice.slice();
+
+ const result = if (!publish_to_self)
+ this.websocket.publish(topic_slice.slice(), buffer, .text, compress)
+ else
+ uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .text, compress);
+
return JSValue.jsNumber(
// if 0, return 0
// else return number of bytes sent
- @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .text, compress))) * @intCast(i32, @truncate(u31, buffer.len)),
+ if (result) @intCast(i32, @truncate(u31, buffer.len)) else @as(i32, 0),
);
}
@@ -3156,7 +3193,9 @@ pub const ServerWebSocket = struct {
log("publish() closed", .{});
return JSValue.jsNumber(0);
};
- const ssl = this.handler.ssl;
+ const flags = this.handler.flags;
+ const ssl = flags.ssl;
+ const publish_to_self = flags.publish_to_self;
const topic_value = args.ptr[0];
const message_value = args.ptr[1];
const compress_value = args.ptr[2];
@@ -3180,19 +3219,25 @@ pub const ServerWebSocket = struct {
globalThis.throw("publishBinary requires a non-empty message", .{});
return .zero;
}
- const buffer = message_value.asArrayBuffer(globalThis) orelse {
+ const array_buffer = message_value.asArrayBuffer(globalThis) orelse {
globalThis.throw("publishBinary expects an ArrayBufferView", .{});
return .zero;
};
+ const buffer = array_buffer.slice();
if (buffer.len == 0) {
return JSC.JSValue.jsNumber(0);
}
+ const result = if (!publish_to_self)
+ this.websocket.publish(topic_slice.slice(), buffer, .binary, compress)
+ else
+ uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .binary, compress);
+
return JSValue.jsNumber(
// if 0, return 0
// else return number of bytes sent
- @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer.slice(), .binary, compress))) * @intCast(i32, @truncate(u31, buffer.len)),
+ if (result) @intCast(i32, @truncate(u31, buffer.len)) else @as(i32, 0),
);
}
@@ -3200,13 +3245,15 @@ pub const ServerWebSocket = struct {
this: *ServerWebSocket,
globalThis: *JSC.JSGlobalObject,
topic_str: *JSC.JSString,
- buffer: *JSC.JSUint8Array,
+ array: *JSC.JSUint8Array,
) callconv(.C) JSC.JSValue {
var app = this.handler.app orelse {
log("publish() closed", .{});
return JSValue.jsNumber(0);
};
- const ssl = this.handler.ssl;
+ const flags = this.handler.flags;
+ const ssl = flags.ssl;
+ const publish_to_self = flags.publish_to_self;
var topic_slice = topic_str.toSlice(globalThis, bun.default_allocator);
defer topic_slice.deinit();
@@ -3217,30 +3264,20 @@ pub const ServerWebSocket = struct {
const compress = true;
- const slice = buffer.slice();
- if (slice.len == 0) {
+ const buffer = array.slice();
+ if (buffer.len == 0) {
return JSC.JSValue.jsNumber(0);
}
+ const result = if (!publish_to_self)
+ this.websocket.publish(topic_slice.slice(), buffer, .binary, compress)
+ else
+ uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .binary, compress);
+
return JSValue.jsNumber(
// if 0, return 0
// else return number of bytes sent
- @as(
- i32,
- @boolToInt(
- uws.AnyWebSocket.publishWithOptions(
- ssl,
- app,
- topic_slice.slice(),
- slice,
- .binary,
- compress,
- ),
- ),
- ) * @intCast(
- i32,
- @truncate(u31, slice.len),
- ),
+ if (result) @intCast(i32, @truncate(u31, buffer.len)) else @as(i32, 0),
);
}
@@ -3254,7 +3291,9 @@ pub const ServerWebSocket = struct {
log("publish() closed", .{});
return JSValue.jsNumber(0);
};
- const ssl = this.handler.ssl;
+ const flags = this.handler.flags;
+ const ssl = flags.ssl;
+ const publish_to_self = flags.publish_to_self;
var topic_slice = topic_str.toSlice(globalThis, bun.default_allocator);
defer topic_slice.deinit();
@@ -3266,24 +3305,21 @@ pub const ServerWebSocket = struct {
const compress = true;
const slice = str.toSlice(globalThis, bun.default_allocator);
- if (slice.len == 0) {
+ defer slice.deinit();
+ const buffer = slice.slice();
+
+ if (buffer.len == 0) {
return JSC.JSValue.jsNumber(0);
}
+ const result = if (!publish_to_self)
+ this.websocket.publish(topic_slice.slice(), buffer, .text, compress)
+ else
+ uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .text, compress);
return JSValue.jsNumber(
// if 0, return 0
// else return number of bytes sent
- @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions(
- ssl,
- app,
- topic_slice.slice(),
- slice.slice(),
- .text,
- compress,
- ))) * @intCast(
- i32,
- @truncate(u31, slice.len),
- ),
+ if (result) @intCast(i32, @truncate(u31, buffer.len)) else @as(i32, 0),
);
}
@@ -4107,7 +4143,7 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type {
}
if (new_config.websocket) |*ws| {
- ws.handler.ssl = ssl_enabled;
+ ws.handler.flags.ssl = ssl_enabled;
if (ws.handler.onMessage != .zero or ws.handler.onOpen != .zero) {
if (this.config.websocket) |old_ws| {
old_ws.unprotect();
@@ -4671,7 +4707,7 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type {
if (this.config.websocket) |*websocket| {
websocket.globalObject = this.globalThis;
websocket.handler.app = this.app;
- websocket.handler.ssl = ssl_enabled;
+ websocket.handler.flags.ssl = ssl_enabled;
this.app.ws(
"/*",
this,
diff --git a/test/bun.js/websocket-server.test.ts b/test/bun.js/websocket-server.test.ts
index 1d9c15341..0dc421eb6 100644
--- a/test/bun.js/websocket-server.test.ts
+++ b/test/bun.js/websocket-server.test.ts
@@ -1,6 +1,6 @@
-import { serve } from "bun";
import { describe, expect, it } from "bun:test";
import { gcTick } from "./gc";
+import { serve } from "bun";
var port = 4321;
function getPort() {
@@ -49,6 +49,44 @@ describe("websocket server", () => {
done();
});
+ it("can do publish() with publishToSelf: false", async (done) => {
+ var server = serve({
+ port: getPort(),
+ websocket: {
+ open(ws) {
+ ws.subscribe("all");
+ ws.publish("all", "hey");
+ server.publish("all", "hello");
+ },
+ message(ws, msg) {
+ if (new TextDecoder().decode(msg) !== "hello") {
+ done(new Error("unexpected message"));
+ }
+ },
+ close(ws) {},
+ publishToSelf: false,
+ },
+ fetch(req, server) {
+ if (server.upgrade(req)) {
+ return;
+ }
+
+ return new Response("success");
+ },
+ });
+
+ await new Promise<void>((resolve2, reject2) => {
+ var socket = new WebSocket(`ws://${server.hostname}:${server.port}`);
+
+ socket.onmessage = (e) => {
+ expect(e.data).toBe("hello");
+ resolve2();
+ };
+ });
+ server.stop();
+ done();
+ });
+
for (let method of ["publish", "publishText", "publishBinary"]) {
describe(method, () => {
it("in close() should work", async () => {
@@ -463,7 +501,9 @@ describe("websocket server", () => {
server.stop();
expect(() => {
server.upgrade(req);
- }).toThrow('To enable websocket support, set the "websocket" object in Bun.serve({})');
+ }).toThrow(
+ 'To enable websocket support, set the "websocket" object in Bun.serve({})',
+ );
return new Response("success");
},
});
@@ -826,9 +866,11 @@ describe("websocket server", () => {
fetch(req) {
gcTick();
server.stop();
- if (server.upgrade(req, {
- data: { count: 0 },
- }))
+ if (
+ server.upgrade(req, {
+ data: { count: 0 },
+ })
+ )
return;
return new Response("noooooo hello world");
},