diff options
| -rw-r--r-- | src/bun.js/api/server.zig | 126 | ||||
| -rw-r--r-- | test/bun.js/websocket-server.test.ts | 52 | 
2 files changed, 128 insertions, 50 deletions
| diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig index f47ee9fc0..bed751f1b 100644 --- a/src/bun.js/api/server.zig +++ b/src/bun.js/api/server.zig @@ -2514,7 +2514,10 @@ pub const WebSocketServer = struct {          active_connections: usize = 0,          /// used by publish() -        ssl: bool = false, +        flags: packed struct (u2) { +            ssl: bool = false, +            publish_to_self: bool = true, +        } = .{},          pub fn fromJS(globalObject: *JSC.JSGlobalObject, object: JSC.JSValue) ?Handler {              var handler = Handler{ .globalObject = globalObject }; @@ -2752,6 +2755,17 @@ pub const WebSocketServer = struct {              }          } +        if (object.get(globalObject, "publishToSelf")) |value| { +            if (!value.isUndefinedOrNull()) { +                if (!value.isBoolean()) { +                    globalObject.throwInvalidArguments("websocket expects publishToSelf to be a boolean", .{}); +                    return null; +                } + +                server.handler.flags.publish_to_self = value.toBoolean(); +            } +        } +          server.protect();          return server;      } @@ -3025,7 +3039,9 @@ pub const ServerWebSocket = struct {              log("publish() closed", .{});              return JSValue.jsNumber(0);          }; -        const ssl = this.handler.ssl; +        const flags = this.handler.flags; +        const ssl = flags.ssl; +        const publish_to_self = flags.publish_to_self;          const topic_value = args.ptr[0];          const message_value = args.ptr[1]; @@ -3051,16 +3067,23 @@ pub const ServerWebSocket = struct {              return .zero;          } -        if (message_value.asArrayBuffer(globalThis)) |buffer| { +        if (message_value.asArrayBuffer(globalThis)) |array_buffer| { +            const buffer = array_buffer.slice(); +              if (buffer.len == 0) {                  globalThis.throw("publish requires a non-empty message", .{});                  return .zero;              } +            const result = if (!publish_to_self) +                this.websocket.publish(topic_slice.slice(), buffer, .binary, compress) +            else +                uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .binary, compress); +              return JSValue.jsNumber(                  // if 0, return 0                  // else return number of bytes sent -                @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer.slice(), .binary, compress))) * @intCast(i32, @truncate(u31, buffer.len)), +                if (result) @intCast(i32, @truncate(u31, buffer.len)) else @as(i32, 0),              );          } @@ -3072,10 +3095,16 @@ pub const ServerWebSocket = struct {              }              const buffer = string_slice.slice(); + +            const result = if (!publish_to_self) +                this.websocket.publish(topic_slice.slice(), buffer, .text, compress) +            else +                uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .text, compress); +              return JSValue.jsNumber(                  // if 0, return 0                  // else return number of bytes sent -                @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .text, compress))) * @intCast(i32, @truncate(u31, buffer.len)), +                if (result) @intCast(i32, @truncate(u31, buffer.len)) else @as(i32, 0),              );          } @@ -3099,7 +3128,9 @@ pub const ServerWebSocket = struct {              log("publish() closed", .{});              return JSValue.jsNumber(0);          }; -        const ssl = this.handler.ssl; +        const flags = this.handler.flags; +        const ssl = flags.ssl; +        const publish_to_self = flags.publish_to_self;          const topic_value = args.ptr[0];          const message_value = args.ptr[1]; @@ -3132,10 +3163,16 @@ pub const ServerWebSocket = struct {          }          const buffer = string_slice.slice(); + +        const result = if (!publish_to_self) +            this.websocket.publish(topic_slice.slice(), buffer, .text, compress) +        else +            uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .text, compress); +          return JSValue.jsNumber(              // if 0, return 0              // else return number of bytes sent -            @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .text, compress))) * @intCast(i32, @truncate(u31, buffer.len)), +            if (result) @intCast(i32, @truncate(u31, buffer.len)) else @as(i32, 0),          );      } @@ -3156,7 +3193,9 @@ pub const ServerWebSocket = struct {              log("publish() closed", .{});              return JSValue.jsNumber(0);          }; -        const ssl = this.handler.ssl; +        const flags = this.handler.flags; +        const ssl = flags.ssl; +        const publish_to_self = flags.publish_to_self;          const topic_value = args.ptr[0];          const message_value = args.ptr[1];          const compress_value = args.ptr[2]; @@ -3180,19 +3219,25 @@ pub const ServerWebSocket = struct {              globalThis.throw("publishBinary requires a non-empty message", .{});              return .zero;          } -        const buffer = message_value.asArrayBuffer(globalThis) orelse { +        const array_buffer = message_value.asArrayBuffer(globalThis) orelse {              globalThis.throw("publishBinary expects an ArrayBufferView", .{});              return .zero;          }; +        const buffer = array_buffer.slice();          if (buffer.len == 0) {              return JSC.JSValue.jsNumber(0);          } +        const result = if (!publish_to_self) +            this.websocket.publish(topic_slice.slice(), buffer, .binary, compress) +        else +            uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .binary, compress); +          return JSValue.jsNumber(              // if 0, return 0              // else return number of bytes sent -            @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer.slice(), .binary, compress))) * @intCast(i32, @truncate(u31, buffer.len)), +            if (result) @intCast(i32, @truncate(u31, buffer.len)) else @as(i32, 0),          );      } @@ -3200,13 +3245,15 @@ pub const ServerWebSocket = struct {          this: *ServerWebSocket,          globalThis: *JSC.JSGlobalObject,          topic_str: *JSC.JSString, -        buffer: *JSC.JSUint8Array, +        array: *JSC.JSUint8Array,      ) callconv(.C) JSC.JSValue {          var app = this.handler.app orelse {              log("publish() closed", .{});              return JSValue.jsNumber(0);          }; -        const ssl = this.handler.ssl; +        const flags = this.handler.flags; +        const ssl = flags.ssl; +        const publish_to_self = flags.publish_to_self;          var topic_slice = topic_str.toSlice(globalThis, bun.default_allocator);          defer topic_slice.deinit(); @@ -3217,30 +3264,20 @@ pub const ServerWebSocket = struct {          const compress = true; -        const slice = buffer.slice(); -        if (slice.len == 0) { +        const buffer = array.slice(); +        if (buffer.len == 0) {              return JSC.JSValue.jsNumber(0);          } +        const result = if (!publish_to_self) +            this.websocket.publish(topic_slice.slice(), buffer, .binary, compress) +        else +            uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .binary, compress); +          return JSValue.jsNumber(              // if 0, return 0              // else return number of bytes sent -            @as( -                i32, -                @boolToInt( -                    uws.AnyWebSocket.publishWithOptions( -                        ssl, -                        app, -                        topic_slice.slice(), -                        slice, -                        .binary, -                        compress, -                    ), -                ), -            ) * @intCast( -                i32, -                @truncate(u31, slice.len), -            ), +            if (result) @intCast(i32, @truncate(u31, buffer.len)) else @as(i32, 0),          );      } @@ -3254,7 +3291,9 @@ pub const ServerWebSocket = struct {              log("publish() closed", .{});              return JSValue.jsNumber(0);          }; -        const ssl = this.handler.ssl; +        const flags = this.handler.flags; +        const ssl = flags.ssl; +        const publish_to_self = flags.publish_to_self;          var topic_slice = topic_str.toSlice(globalThis, bun.default_allocator);          defer topic_slice.deinit(); @@ -3266,24 +3305,21 @@ pub const ServerWebSocket = struct {          const compress = true;          const slice = str.toSlice(globalThis, bun.default_allocator); -        if (slice.len == 0) { +        defer slice.deinit(); +        const buffer = slice.slice(); + +        if (buffer.len == 0) {              return JSC.JSValue.jsNumber(0);          } +        const result = if (!publish_to_self) +            this.websocket.publish(topic_slice.slice(), buffer, .text, compress) +        else +            uws.AnyWebSocket.publishWithOptions(ssl, app, topic_slice.slice(), buffer, .text, compress);          return JSValue.jsNumber(              // if 0, return 0              // else return number of bytes sent -            @as(i32, @boolToInt(uws.AnyWebSocket.publishWithOptions( -                ssl, -                app, -                topic_slice.slice(), -                slice.slice(), -                .text, -                compress, -            ))) * @intCast( -                i32, -                @truncate(u31, slice.len), -            ), +            if (result) @intCast(i32, @truncate(u31, buffer.len)) else @as(i32, 0),          );      } @@ -4107,7 +4143,7 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type {              }              if (new_config.websocket) |*ws| { -                ws.handler.ssl = ssl_enabled; +                ws.handler.flags.ssl = ssl_enabled;                  if (ws.handler.onMessage != .zero or ws.handler.onOpen != .zero) {                      if (this.config.websocket) |old_ws| {                          old_ws.unprotect(); @@ -4671,7 +4707,7 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type {              if (this.config.websocket) |*websocket| {                  websocket.globalObject = this.globalThis;                  websocket.handler.app = this.app; -                websocket.handler.ssl = ssl_enabled; +                websocket.handler.flags.ssl = ssl_enabled;                  this.app.ws(                      "/*",                      this, diff --git a/test/bun.js/websocket-server.test.ts b/test/bun.js/websocket-server.test.ts index 1d9c15341..0dc421eb6 100644 --- a/test/bun.js/websocket-server.test.ts +++ b/test/bun.js/websocket-server.test.ts @@ -1,6 +1,6 @@ -import { serve } from "bun";  import { describe, expect, it } from "bun:test";  import { gcTick } from "./gc"; +import { serve } from "bun";  var port = 4321;  function getPort() { @@ -49,6 +49,44 @@ describe("websocket server", () => {      done();    }); +  it("can do publish() with publishToSelf: false", async (done) => { +    var server = serve({ +      port: getPort(), +      websocket: { +        open(ws) { +          ws.subscribe("all"); +          ws.publish("all", "hey"); +          server.publish("all", "hello"); +        }, +        message(ws, msg) { +          if (new TextDecoder().decode(msg) !== "hello") { +            done(new Error("unexpected message")); +          } +        }, +        close(ws) {}, +        publishToSelf: false, +      }, +      fetch(req, server) { +        if (server.upgrade(req)) { +          return; +        } + +        return new Response("success"); +      }, +    }); + +    await new Promise<void>((resolve2, reject2) => { +      var socket = new WebSocket(`ws://${server.hostname}:${server.port}`); + +      socket.onmessage = (e) => { +        expect(e.data).toBe("hello"); +        resolve2(); +      }; +    }); +    server.stop(); +    done(); +  }); +    for (let method of ["publish", "publishText", "publishBinary"]) {      describe(method, () => {        it("in close() should work", async () => { @@ -463,7 +501,9 @@ describe("websocket server", () => {            server.stop();            expect(() => {              server.upgrade(req); -          }).toThrow('To enable websocket support, set the "websocket" object in Bun.serve({})'); +          }).toThrow( +            'To enable websocket support, set the "websocket" object in Bun.serve({})', +          );            return new Response("success");          },        }); @@ -826,9 +866,11 @@ describe("websocket server", () => {        fetch(req) {          gcTick();          server.stop(); -        if (server.upgrade(req, { -          data: { count: 0 }, -        })) +        if ( +          server.upgrade(req, { +            data: { count: 0 }, +          }) +        )            return;          return new Response("noooooo hello world");        }, | 
