aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/bun.js/api/server.zig60
-rw-r--r--src/deps/uws.zig16
-rw-r--r--test/bun.js/websocket-server.test.ts63
3 files changed, 111 insertions, 28 deletions
diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig
index ce0b388d7..dea478d7a 100644
--- a/src/bun.js/api/server.zig
+++ b/src/bun.js/api/server.zig
@@ -2289,11 +2289,14 @@ pub const WebSocketServer = struct {
onDrain: JSC.JSValue = .zero,
onError: JSC.JSValue = .zero,
- app: *anyopaque = undefined,
+ app: ?*anyopaque = null,
globalObject: *JSC.JSGlobalObject = undefined,
active_connections: usize = 0,
+ /// used by publish()
+ ssl: bool = false,
+
pub fn fromJS(globalObject: *JSC.JSGlobalObject, object: JSC.JSValue) ?Handler {
var handler = Handler{ .globalObject = globalObject };
if (object.getTruthy(globalObject, "message")) |message| {
@@ -2792,10 +2795,11 @@ pub const ServerWebSocket = struct {
return .zero;
}
- if (this.closed) {
+ var app = this.handler.app orelse {
log("publish() closed", .{});
return JSValue.jsNumber(0);
- }
+ };
+ const ssl = this.handler.ssl;
const topic_value = args.ptr[0];
const message_value = args.ptr[1];
@@ -2830,7 +2834,7 @@ pub const ServerWebSocket = struct {
return JSValue.jsNumber(
// if 0, return 0
// else return number of bytes sent
- @as(i32, @boolToInt(this.websocket.publishWithOptions(this.handler.app, topic_slice.slice(), buffer.slice(), .binary, compress))) * @intCast(i32, @truncate(u31, buffer.len)),
+ @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer.slice(), .binary, compress))) * @intCast(i32, @truncate(u31, buffer.len)),
);
}
@@ -2845,7 +2849,7 @@ pub const ServerWebSocket = struct {
return JSValue.jsNumber(
// if 0, return 0
// else return number of bytes sent
- @as(i32, @boolToInt(this.websocket.publishWithOptions(this.handler.app, topic_slice.slice(), buffer, .text, compress))) * @intCast(i32, @truncate(u31, buffer.len)),
+ @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .text, compress))) * @intCast(i32, @truncate(u31, buffer.len)),
);
}
@@ -2865,10 +2869,11 @@ pub const ServerWebSocket = struct {
return .zero;
}
- if (this.closed) {
+ var app = this.handler.app orelse {
log("publish() closed", .{});
return JSValue.jsNumber(0);
- }
+ };
+ const ssl = this.handler.ssl;
const topic_value = args.ptr[0];
const message_value = args.ptr[1];
@@ -2904,7 +2909,7 @@ pub const ServerWebSocket = struct {
return JSValue.jsNumber(
// if 0, return 0
// else return number of bytes sent
- @as(i32, @boolToInt(this.websocket.publishWithOptions(this.handler.app, topic_slice.slice(), buffer, .text, compress))) * @intCast(i32, @truncate(u31, buffer.len)),
+ @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .text, compress))) * @intCast(i32, @truncate(u31, buffer.len)),
);
}
@@ -2921,11 +2926,11 @@ pub const ServerWebSocket = struct {
return .zero;
}
- if (this.closed) {
- log("publishBinary() closed", .{});
+ var app = this.handler.app orelse {
+ log("publish() closed", .{});
return JSValue.jsNumber(0);
- }
-
+ };
+ const ssl = this.handler.ssl;
const topic_value = args.ptr[0];
const message_value = args.ptr[1];
const compress_value = args.ptr[2];
@@ -2961,7 +2966,7 @@ pub const ServerWebSocket = struct {
return JSValue.jsNumber(
// if 0, return 0
// else return number of bytes sent
- @as(i32, @boolToInt(this.websocket.publishWithOptions(this.handler.app, topic_slice.slice(), buffer.slice(), .binary, compress))) * @intCast(i32, @truncate(u31, buffer.len)),
+ @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer.slice(), .binary, compress))) * @intCast(i32, @truncate(u31, buffer.len)),
);
}
@@ -2971,10 +2976,11 @@ pub const ServerWebSocket = struct {
topic_str: *JSC.JSString,
buffer: *JSC.JSUint8Array,
) callconv(.C) JSC.JSValue {
- if (this.closed) {
- log("publishBinary() closed", .{});
+ var app = this.handler.app orelse {
+ log("publish() closed", .{});
return JSValue.jsNumber(0);
- }
+ };
+ const ssl = this.handler.ssl;
var topic_slice = topic_str.toSlice(globalThis, bun.default_allocator);
defer topic_slice.deinit();
@@ -2996,8 +3002,9 @@ pub const ServerWebSocket = struct {
@as(
i32,
@boolToInt(
- this.websocket.publishWithOptions(
- this.handler.app,
+ uws.AnyWebSocket.publishWithOptions(
+ ssl,
+ app,
topic_slice.slice(),
slice,
.binary,
@@ -3017,10 +3024,11 @@ pub const ServerWebSocket = struct {
topic_str: *JSC.JSString,
str: *JSC.JSString,
) callconv(.C) JSC.JSValue {
- if (this.closed) {
- log("publishBinary() closed", .{});
+ var app = this.handler.app orelse {
+ log("publish() closed", .{});
return JSValue.jsNumber(0);
- }
+ };
+ const ssl = this.handler.ssl;
var topic_slice = topic_str.toSlice(globalThis, bun.default_allocator);
defer topic_slice.deinit();
@@ -3039,8 +3047,9 @@ pub const ServerWebSocket = struct {
return JSValue.jsNumber(
// if 0, return 0
// else return number of bytes sent
- @as(i32, @boolToInt(this.websocket.publishWithOptions(
- this.handler.app,
+ @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions(
+ ssl,
+ app,
topic_slice.slice(),
slice.slice(),
.text,
@@ -3809,6 +3818,7 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type {
}
if (new_config.websocket) |*ws| {
+ ws.handler.ssl = ssl_enabled;
if (ws.handler.onMessage != .zero or ws.handler.onOpen != .zero) {
if (this.config.websocket) |old_ws| {
old_ws.unprotect();
@@ -4007,6 +4017,9 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type {
pub fn deinitIfWeCan(this: *ThisServer) void {
if (this.pending_requests == 0 and this.listener == null and this.has_js_deinited and !this.hasActiveWebSockets()) {
+ if (this.config.websocket) |*ws| {
+ ws.handler.app = null;
+ }
this.unref();
this.deinit();
}
@@ -4349,6 +4362,7 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type {
if (this.config.websocket) |*websocket| {
websocket.globalObject = this.globalThis;
websocket.handler.app = this.app;
+ websocket.handler.ssl = ssl_enabled;
this.app.ws(
"/*",
this,
diff --git a/src/deps/uws.zig b/src/deps/uws.zig
index 330d0e535..9b4903657 100644
--- a/src/deps/uws.zig
+++ b/src/deps/uws.zig
@@ -655,11 +655,17 @@ pub const AnyWebSocket = union(enum) {
.tcp => uws_ws_publish(0, this.tcp.raw(), topic.ptr, topic.len, message.ptr, message.len),
};
}
- pub fn publishWithOptions(this: AnyWebSocket, app: *anyopaque, topic: []const u8, message: []const u8, opcode: Opcode, compress: bool) bool {
- return switch (this) {
- .ssl => uws_publish(1, @ptrCast(*uws_app_t, app), topic.ptr, topic.len, message.ptr, message.len, opcode, compress),
- .tcp => uws_publish(0, @ptrCast(*uws_app_t, app), topic.ptr, topic.len, message.ptr, message.len, opcode, compress),
- };
+ pub fn publishWithOptions(ssl: bool, app: *anyopaque, topic: []const u8, message: []const u8, opcode: Opcode, compress: bool) bool {
+ return uws_publish(
+ @boolToInt(ssl),
+ @ptrCast(*uws_app_t, app),
+ topic.ptr,
+ topic.len,
+ message.ptr,
+ message.len,
+ opcode,
+ compress,
+ );
}
pub fn getBufferedAmount(this: AnyWebSocket) u32 {
return switch (this) {
diff --git a/test/bun.js/websocket-server.test.ts b/test/bun.js/websocket-server.test.ts
index eed5ffdc4..7c7869fad 100644
--- a/test/bun.js/websocket-server.test.ts
+++ b/test/bun.js/websocket-server.test.ts
@@ -12,6 +12,69 @@ function getPort() {
}
describe("websocket server", () => {
+ for (let method of ["publish", "publishText", "publishBinary"]) {
+ describe(method, () => {
+ it("in close() should work", async () => {
+ var server = serve({
+ port: getPort(),
+ websocket: {
+ open(ws) {
+ ws.subscribe("all");
+ },
+ message(ws, msg) {},
+ close(ws) {
+ ws[method](
+ "all",
+ method === "publishBinary" ? Buffer.from("bye!") : "bye!"
+ );
+ },
+ },
+ fetch(req, server) {
+ if (server.upgrade(req)) {
+ return;
+ }
+
+ return new Response("success");
+ },
+ });
+
+ try {
+ const first = await new Promise((resolve2, reject2) => {
+ var socket = new WebSocket(
+ `ws://${server.hostname}:${server.port}`
+ );
+ socket.onopen = () => resolve2(socket);
+ });
+
+ const second = await new Promise((resolve2, reject2) => {
+ var socket = new WebSocket(
+ `ws://${server.hostname}:${server.port}`
+ );
+ socket.onmessage = (ev) => {
+ var msg = ev.data;
+ if (typeof msg !== "string") {
+ msg = new TextDecoder().decode(msg);
+ }
+ if (msg === "bye!") {
+ resolve2(socket);
+ } else {
+ reject2(msg);
+ }
+ };
+ socket.onopen = () => {
+ first.close();
+ };
+ });
+
+ second.close();
+ } catch (r) {
+ } finally {
+ server.stop();
+ }
+ });
+ });
+ }
+
it("close inside open", async () => {
var resolve;
var server = serve({