diff options
author | 2023-07-03 20:53:41 -0700 | |
---|---|---|
committer | 2023-07-03 20:53:41 -0700 | |
commit | 3345a7fc3c81f914000fd5a3c5a24f920a70386a (patch) | |
tree | b4a0d9ffa9d37d17bbf41f74c9d3f3a624ae8aab | |
parent | b26b0d886ce2f9898833e8efa16b71952c39b615 (diff) | |
download | bun-3345a7fc3c81f914000fd5a3c5a24f920a70386a.tar.gz bun-3345a7fc3c81f914000fd5a3c5a24f920a70386a.tar.zst bun-3345a7fc3c81f914000fd5a3c5a24f920a70386a.zip |
Allow zero length WebSocket client & server messages (#3488)
* Allow zero length WebSocket client & server messages
* Add test
* Clean this up a little
* Clean up these tests a little
* Hopefully fix the test failure in release build
* Don't copy into the receive buffer
* Less flaky
---------
Co-authored-by: Jarred Sumner <709451+Jarred-Sumner@users.noreply.github.com>
-rw-r--r-- | src/bun.js/api/server.zig | 53 | ||||
-rw-r--r-- | src/bun.js/bindings/ZigGlobalObject.cpp | 26 | ||||
-rw-r--r-- | src/bun.js/bindings/ZigGlobalObject.h | 2 | ||||
-rw-r--r-- | src/bun.js/bindings/bindings.zig | 17 | ||||
-rw-r--r-- | src/bun.js/bindings/webcore/WebSocket.cpp | 27 | ||||
-rw-r--r-- | src/bun.js/bindings/webcore/WebSocket.h | 28 | ||||
-rw-r--r-- | src/http/websocket_http_client.zig | 172 | ||||
-rw-r--r-- | test/js/bun/websocket/websocket-server.test.ts | 292 | ||||
-rw-r--r-- | test/js/web/websocket/websocket.test.js | 34 |
9 files changed, 391 insertions, 260 deletions
diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig index f52c08301..9625ff693 100644 --- a/src/bun.js/api/server.zig +++ b/src/bun.js/api/server.zig @@ -3560,11 +3560,6 @@ pub const ServerWebSocket = struct { if (message_value.asArrayBuffer(globalThis)) |array_buffer| { const buffer = array_buffer.slice(); - if (buffer.len == 0) { - globalThis.throw("publish requires a non-empty message", .{}); - return .zero; - } - const result = if (!publish_to_self) this.websocket.publish(topic_slice.slice(), buffer, .binary, compress) else @@ -3580,9 +3575,6 @@ pub const ServerWebSocket = struct { { var string_slice = message_value.toSlice(globalThis, bun.default_allocator); defer string_slice.deinit(); - if (string_slice.len == 0) { - return JSValue.jsNumber(0); - } const buffer = string_slice.slice(); @@ -3634,10 +3626,6 @@ pub const ServerWebSocket = struct { var topic_slice = topic_value.toSlice(globalThis, bun.default_allocator); defer topic_slice.deinit(); - if (topic_slice.len == 0) { - globalThis.throw("publishText requires a non-empty topic", .{}); - return .zero; - } const compress = args.len > 1 and compress_value.toBoolean(); @@ -3648,9 +3636,6 @@ pub const ServerWebSocket = struct { var string_slice = message_value.toSlice(globalThis, bun.default_allocator); defer string_slice.deinit(); - if (string_slice.len == 0) { - return JSValue.jsNumber(0); - } const buffer = string_slice.slice(); @@ -3715,10 +3700,6 @@ pub const ServerWebSocket = struct { }; const buffer = array_buffer.slice(); - if (buffer.len == 0) { - return JSC.JSValue.jsNumber(0); - } - const result = if (!publish_to_self) this.websocket.publish(topic_slice.slice(), buffer, .binary, compress) else @@ -3883,10 +3864,6 @@ pub const ServerWebSocket = struct { } if (message_value.asArrayBuffer(globalThis)) |buffer| { - if (buffer.len == 0) { - return JSValue.jsNumber(0); - } - switch (this.websocket.send(buffer.slice(), .binary, compress, true)) { .backpressure => { log("send() backpressure ({d} bytes)", .{buffer.len}); @@ -3906,9 +3883,6 @@ pub const ServerWebSocket = struct { { var string_slice = message_value.toSlice(globalThis, bun.default_allocator); defer string_slice.deinit(); - if (string_slice.len == 0) { - return JSValue.jsNumber(0); - } const buffer = string_slice.slice(); switch (this.websocket.send(buffer, .text, compress, true)) { @@ -3960,9 +3934,6 @@ pub const ServerWebSocket = struct { var string_slice = message_value.toSlice(globalThis, bun.default_allocator); defer string_slice.deinit(); - if (string_slice.len == 0) { - return JSValue.jsNumber(0); - } const buffer = string_slice.slice(); switch (this.websocket.send(buffer, .text, compress, true)) { @@ -3994,9 +3965,6 @@ pub const ServerWebSocket = struct { var string_slice = message_str.toSlice(globalThis, bun.default_allocator); defer string_slice.deinit(); - if (string_slice.len == 0) { - return JSValue.jsNumber(0); - } const buffer = string_slice.slice(); switch (this.websocket.send(buffer, .text, compress, true)) { @@ -4043,10 +4011,6 @@ pub const ServerWebSocket = struct { return .zero; }; - if (buffer.len == 0) { - return JSValue.jsNumber(0); - } - switch (this.websocket.send(buffer.slice(), .binary, compress, true)) { .backpressure => { log("sendBinary() backpressure ({d} bytes)", .{buffer.len}); @@ -4076,10 +4040,6 @@ pub const ServerWebSocket = struct { const buffer = array_buffer.slice(); - if (buffer.len == 0) { - return JSValue.jsNumber(0); - } - switch (this.websocket.send(buffer, .binary, compress, true)) { .backpressure => { log("sendBinary() backpressure ({d} bytes)", .{buffer.len}); @@ -4416,17 +4376,7 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type { const compress = (compress_value orelse JSValue.jsBoolean(true)).toBoolean(); - if (message_value.isEmptyOrUndefinedOrNull()) { - JSC.JSError(this.vm.allocator, "publish requires a non-empty message", .{}, globalThis, exception); - return .zero; - } - if (message_value.asArrayBuffer(globalThis)) |buffer| { - if (buffer.len == 0) { - JSC.JSError(this.vm.allocator, "publish requires a non-empty message", .{}, globalThis, exception); - return .zero; - } - return JSValue.jsNumber( // if 0, return 0 // else return number of bytes sent @@ -4437,9 +4387,6 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type { { var string_slice = message_value.toSlice(globalThis, bun.default_allocator); defer string_slice.deinit(); - if (string_slice.len == 0) { - return JSValue.jsNumber(0); - } const buffer = string_slice.slice(); return JSValue.jsNumber( diff --git a/src/bun.js/bindings/ZigGlobalObject.cpp b/src/bun.js/bindings/ZigGlobalObject.cpp index c00670289..b3236a4a2 100644 --- a/src/bun.js/bindings/ZigGlobalObject.cpp +++ b/src/bun.js/bindings/ZigGlobalObject.cpp @@ -1021,6 +1021,20 @@ JSC_DEFINE_HOST_FUNCTION(functionBunSleepThenCallback, return JSC::JSValue::encode(promise); } +using MicrotaskCallback = void (*)(void*); + +JSC_DEFINE_HOST_FUNCTION(functionNativeMicrotaskTrampoline, + (JSC::JSGlobalObject * globalObject, JSC::CallFrame* callFrame)) +{ + JSCell* cellPtr = callFrame->uncheckedArgument(0).asCell(); + JSCell* callbackPtr = callFrame->uncheckedArgument(1).asCell(); + + void* cell = reinterpret_cast<void*>(cellPtr); + auto* callback = reinterpret_cast<MicrotaskCallback>(callbackPtr); + callback(cell); + return JSValue::encode(jsUndefined()); +} + JSC_DEFINE_HOST_FUNCTION(functionBunSleep, (JSC::JSGlobalObject * globalObject, JSC::CallFrame* callFrame)) { @@ -3027,6 +3041,11 @@ void GlobalObject::finishCreation(VM& vm) init.set(JSFunction::create(init.vm, init.owner, 4, "performMicrotaskVariadic"_s, jsFunctionPerformMicrotaskVariadic, ImplementationVisibility::Public)); }); + m_nativeMicrotaskTrampoline.initLater( + [](const Initializer<JSFunction>& init) { + init.set(JSFunction::create(init.vm, init.owner, 2, ""_s, functionNativeMicrotaskTrampoline, ImplementationVisibility::Public)); + }); + m_navigatorObject.initLater( [](const Initializer<JSObject>& init) { int cpuCount = 0; @@ -4225,6 +4244,7 @@ void GlobalObject::visitChildrenImpl(JSCell* cell, Visitor& visitor) thisObject->m_JSFileSinkControllerPrototype.visit(visitor); thisObject->m_JSHTTPSResponseControllerPrototype.visit(visitor); thisObject->m_navigatorObject.visit(visitor); + thisObject->m_nativeMicrotaskTrampoline.visit(visitor); thisObject->m_performanceObject.visit(visitor); thisObject->m_primordialsObject.visit(visitor); thisObject->m_processEnvObject.visit(visitor); @@ -4387,6 +4407,12 @@ extern "C" void JSC__JSGlobalObject__reload(JSC__JSGlobalObject* arg0) globalObject->reload(); } +extern "C" void JSC__JSGlobalObject__queueMicrotaskCallback(Zig::GlobalObject* globalObject, void* ptr, MicrotaskCallback callback) +{ + JSFunction* function = globalObject->nativeMicrotaskTrampoline(); + globalObject->queueMicrotask(function, JSValue(reinterpret_cast<JSC::JSCell*>(ptr)), JSValue(reinterpret_cast<JSC::JSCell*>(callback)), jsUndefined(), jsUndefined()); +} + JSC::Identifier GlobalObject::moduleLoaderResolve(JSGlobalObject* globalObject, JSModuleLoader* loader, JSValue key, JSValue referrer, JSValue origin) diff --git a/src/bun.js/bindings/ZigGlobalObject.h b/src/bun.js/bindings/ZigGlobalObject.h index a5b802ced..da6ba92a0 100644 --- a/src/bun.js/bindings/ZigGlobalObject.h +++ b/src/bun.js/bindings/ZigGlobalObject.h @@ -369,6 +369,7 @@ public: mutable WriteBarrier<JSFunction> m_thenables[promiseFunctionsSize + 1]; JSObject* navigatorObject(); + JSFunction* nativeMicrotaskTrampoline() { return m_nativeMicrotaskTrampoline.getInitializedOnMainThread(this); } void trackFFIFunction(JSC::JSFunction* function) { @@ -466,6 +467,7 @@ private: */ LazyProperty<JSGlobalObject, JSC::Structure> m_pendingVirtualModuleResultStructure; LazyProperty<JSGlobalObject, JSFunction> m_performMicrotaskFunction; + LazyProperty<JSGlobalObject, JSFunction> m_nativeMicrotaskTrampoline; LazyProperty<JSGlobalObject, JSFunction> m_performMicrotaskVariadicFunction; LazyProperty<JSGlobalObject, JSFunction> m_emitReadableNextTickFunction; LazyProperty<JSGlobalObject, JSMap> m_lazyReadableStreamPrototypeMap; diff --git a/src/bun.js/bindings/bindings.zig b/src/bun.js/bindings/bindings.zig index 277172b81..07882d857 100644 --- a/src/bun.js/bindings/bindings.zig +++ b/src/bun.js/bindings/bindings.zig @@ -2732,6 +2732,23 @@ pub const JSGlobalObject = extern struct { this.vm().throwError(this, this.createErrorInstance(Output.prettyFmt(fmt, false), args)); } } + extern fn JSC__JSGlobalObject__queueMicrotaskCallback(*JSGlobalObject, *anyopaque, Function: *const (fn (*anyopaque) callconv(.C) void)) void; + pub fn queueMicrotaskCallback( + this: *JSGlobalObject, + ctx_val: anytype, + comptime Function: fn (ctx: @TypeOf(ctx_val)) void, + ) void { + JSC.markBinding(@src()); + const Fn = Function; + const ContextType = @TypeOf(ctx_val); + const Wrapper = struct { + pub fn call(p: *anyopaque) callconv(.C) void { + Fn(bun.cast(ContextType, p)); + } + }; + + JSC__JSGlobalObject__queueMicrotaskCallback(this, ctx_val, &Wrapper.call); + } pub fn queueMicrotask( this: *JSGlobalObject, diff --git a/src/bun.js/bindings/webcore/WebSocket.cpp b/src/bun.js/bindings/webcore/WebSocket.cpp index a346175df..1d6392f44 100644 --- a/src/bun.js/bindings/webcore/WebSocket.cpp +++ b/src/bun.js/bindings/webcore/WebSocket.cpp @@ -458,8 +458,8 @@ ExceptionOr<void> WebSocket::send(const String& message) return {}; } - if (message.length() > 0) - this->sendWebSocketString(message); + // 0-length is allowed + this->sendWebSocketString(message); return {}; } @@ -477,8 +477,8 @@ ExceptionOr<void> WebSocket::send(ArrayBuffer& binaryData) } char* data = static_cast<char*>(binaryData.data()); size_t length = binaryData.byteLength(); - if (length > 0) - this->sendWebSocketData(data, length); + // 0-length is allowed + this->sendWebSocketData(data, length); return {}; } @@ -498,8 +498,8 @@ ExceptionOr<void> WebSocket::send(ArrayBufferView& arrayBufferView) auto buffer = arrayBufferView.unsharedBuffer().get(); char* baseAddress = reinterpret_cast<char*>(buffer->data()) + arrayBufferView.byteOffset(); size_t length = arrayBufferView.byteLength(); - if (length > 0) - this->sendWebSocketData(baseAddress, length); + // 0-length is allowed + this->sendWebSocketData(baseAddress, length); return {}; } @@ -1232,14 +1232,19 @@ extern "C" void WebSocket__didCloseWithErrorCode(WebCore::WebSocket* webSocket, extern "C" void WebSocket__didReceiveText(WebCore::WebSocket* webSocket, bool clone, const ZigString* str) { - WTF::String wtf_str = Zig::toString(*str); - if (clone) { - wtf_str = wtf_str.isolatedCopy(); - } - + WTF::String wtf_str = clone ? Zig::toStringCopy(*str) : Zig::toString(*str); webSocket->didReceiveMessage(WTFMove(wtf_str)); } extern "C" void WebSocket__didReceiveBytes(WebCore::WebSocket* webSocket, uint8_t* bytes, size_t len) { webSocket->didReceiveBinaryData({ bytes, len }); } + +extern "C" void WebSocket__incrementPendingActivity(WebCore::WebSocket* webSocket) +{ + webSocket->incPendingActivityCount(); +} +extern "C" void WebSocket__decrementPendingActivity(WebCore::WebSocket* webSocket) +{ + webSocket->decPendingActivityCount(); +}
\ No newline at end of file diff --git a/src/bun.js/bindings/webcore/WebSocket.h b/src/bun.js/bindings/webcore/WebSocket.h index 42261cfc4..846bd186b 100644 --- a/src/bun.js/bindings/webcore/WebSocket.h +++ b/src/bun.js/bindings/webcore/WebSocket.h @@ -111,6 +111,20 @@ public: return m_hasPendingActivity.load(); } + void incPendingActivityCount() + { + m_pendingActivityCount++; + ref(); + updateHasPendingActivity(); + } + + void decPendingActivityCount() + { + m_pendingActivityCount--; + deref(); + updateHasPendingActivity(); + } + private: typedef union AnyWebSocket { WebSocketClient* client; @@ -147,20 +161,6 @@ private: void sendWebSocketString(const String& message); void sendWebSocketData(const char* data, size_t length); - void incPendingActivityCount() - { - m_pendingActivityCount++; - ref(); - updateHasPendingActivity(); - } - - void decPendingActivityCount() - { - m_pendingActivityCount--; - deref(); - updateHasPendingActivity(); - } - void failAsynchronously(); enum class BinaryType { Blob, diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig index ee0fb9c77..a3ae8c3ba 100644 --- a/src/http/websocket_http_client.zig +++ b/src/http/websocket_http_client.zig @@ -145,6 +145,15 @@ const CppWebSocket = opaque { pub const didCloseWithErrorCode = WebSocket__didCloseWithErrorCode; pub const didReceiveText = WebSocket__didReceiveText; pub const didReceiveBytes = WebSocket__didReceiveBytes; + extern fn WebSocket__incrementPendingActivity(websocket_context: *CppWebSocket) void; + extern fn WebSocket__decrementPendingActivity(websocket_context: *CppWebSocket) void; + pub fn ref(this: *CppWebSocket) void { + WebSocket__incrementPendingActivity(this); + } + + pub fn unref(this: *CppWebSocket) void { + WebSocket__decrementPendingActivity(this); + } }; const body_buf_len = 16384 - 16; @@ -163,8 +172,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { to_send: []const u8 = "", read_length: usize = 0, headers_buf: [128]PicoHTTP.Header = undefined, - body_buf: ?*BodyBuf = null, - body_written: usize = 0, + body: std.ArrayListUnmanaged(u8) = .{}, websocket_protocol: u64 = 0, hostname: [:0]const u8 = "", poll_ref: JSC.PollRef = .{}, @@ -280,10 +288,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { this.poll_ref.unrefOnNextTick(JSC.VirtualMachine.get()); this.clearInput(); - if (this.body_buf) |buf| { - this.body_buf = null; - buf.release(); - } + this.body.clearAndFree(bun.default_allocator); } pub fn cancel(this: *HTTPClient) callconv(.C) void { this.clearData(); @@ -355,14 +360,6 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { this.to_send = this.input_body_buf[@intCast(usize, wrote)..]; } - fn getBody(this: *HTTPClient) *BodyBufBytes { - if (this.body_buf == null) { - this.body_buf = BodyBufPool.get(bun.default_allocator); - } - - return &this.body_buf.?.data; - } - pub fn handleData(this: *HTTPClient, socket: Socket, data: []const u8) void { log("onData", .{}); std.debug.assert(socket.socket == this.tcp.socket); @@ -374,43 +371,37 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { if (comptime Environment.allow_assert) std.debug.assert(!socket.isShutdown()); - var body = this.getBody(); - var remain = body[this.body_written..]; - const is_first = this.body_written == 0; + var body = data; + if (this.body.items.len > 0) { + this.body.appendSlice(bun.default_allocator, data) catch @panic("out of memory"); + body = this.body.items; + } + + const is_first = this.body.items.len == 0; if (is_first) { // fail early if we receive a non-101 status code - if (!strings.hasPrefixComptime(data, "HTTP/1.1 101 ")) { + if (!strings.hasPrefixComptime(body, "HTTP/1.1 101 ")) { this.terminate(ErrorCode.expected_101_status_code); return; } } - const to_write = remain[0..@min(remain.len, data.len)]; - if (data.len > 0 and to_write.len > 0) { - @memcpy(remain[0..to_write.len], data[0..to_write.len]); - this.body_written += to_write.len; - } - - const overflow = data[to_write.len..]; - - const available_to_read = body[0..this.body_written]; - const response = PicoHTTP.Response.parse(available_to_read, &this.headers_buf) catch |err| { + const response = PicoHTTP.Response.parse(body, &this.headers_buf) catch |err| { switch (err) { error.Malformed_HTTP_Response => { this.terminate(ErrorCode.invalid_response); return; }, error.ShortRead => { - if (overflow.len > 0) { - this.terminate(ErrorCode.headers_too_large); - return; + if (this.body.items.len == 0) { + this.body.appendSlice(bun.default_allocator, data) catch @panic("out of memory"); } return; }, } }; - this.processResponse(response, available_to_read[@intCast(usize, response.bytes_read)..]); + this.processResponse(response, body[@intCast(usize, response.bytes_read)..]); } pub fn handleEnd(this: *HTTPClient, socket: Socket) void { @@ -420,8 +411,6 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { } pub fn processResponse(this: *HTTPClient, response: PicoHTTP.Response, remain_buf: []const u8) void { - std.debug.assert(this.body_written > 0); - var upgrade_header = PicoHTTP.Header{ .name = "", .value = "" }; var connection_header = PicoHTTP.Header{ .name = "", .value = "" }; var websocket_accept_header = PicoHTTP.Header{ .name = "", .value = "" }; @@ -524,7 +513,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { this.terminate(ErrorCode.invalid_response); return; }; - if (remain_buf.len > 0) @memcpy(overflow[0..remain_buf.len], remain_buf); + @memcpy(overflow, remain_buf); } this.clearData(); @@ -866,6 +855,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { globalThis: *JSC.JSGlobalObject, poll_ref: JSC.PollRef = JSC.PollRef.init(), + initial_data_handler: ?*InitialDataHandler = null, + pub const name = if (ssl) "WebSocketClientTLS" else "WebSocketClient"; pub const shim = JSC.Shimmer("Bun", name, @This()); @@ -927,6 +918,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { JSC.markBinding(@src()); if (this.outgoing_websocket) |ws| { this.outgoing_websocket = null; + log("fail ({s})", .{@tagName(code)}); + ws.didCloseWithErrorCode(code); } @@ -937,7 +930,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { _ = socket; _ = ssl_error; JSC.markBinding(@src()); - log("WebSocket.onHandshake({d})", .{success}); + log("onHandshake({d})", .{success}); JSC.markBinding(@src()); if (success == 0) { if (this.outgoing_websocket) |ws| { @@ -1044,6 +1037,24 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { } pub fn handleData(this: *WebSocket, socket: Socket, data_: []const u8) void { + // Due to scheduling, it is possible for the websocket onData + // handler to run with additional data before the microtask queue is + // drained. + if (this.initial_data_handler) |initial_handler| { + // This calls `handleData` + // We deliberately do not set this.initial_data_handler to null here, that's done in handleWithoutDeinit. + // We do not free the memory here since the lifetime is managed by the microtask queue (it should free when called from there) + initial_handler.handleWithoutDeinit(); + + // handleWithoutDeinit is supposed to clear the handler from WebSocket* + // to prevent an infinite loop + std.debug.assert(this.initial_data_handler == null); + + // If we disconnected for any reason in the re-entrant case, we should just ignore the data + if (this.outgoing_websocket == null or this.tcp.isShutdown() or this.tcp.isClosed()) + return; + } + var data = data_; var receive_state = this.receive_state; var terminated = false; @@ -1141,6 +1152,30 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { terminated = true; break; } + + // Handle when the payload length is 0, but it is a message + // + // This should become + // + // - ArrayBuffer(0) + // - "" + // - Buffer(0) (etc) + // + if (receive_body_remain == 0 and receive_state == .need_body and is_final) { + _ = this.consume( + "", + receive_body_remain, + last_receive_data_type, + is_final, + ); + + // Return to the header state to read the next frame + receive_state = .need_header; + is_fragmented = false; + + // Bail out if there's nothing left to read + if (data.len == 0) break; + } }, .need_mask => { this.terminate(.unexpected_mask_from_server); @@ -1201,6 +1236,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { if (data.len == 0) break; }, .need_body => { + // Empty messages are valid, but we handle that earlier in the flow. if (receive_body_remain == 0 and data.len > 0) { this.terminate(ErrorCode.expected_control_frame); terminated = true; @@ -1434,9 +1470,6 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { return; } - if (len == 0) - return; - const slice = ptr[0..len]; const bytes = Copy{ .bytes = slice }; // fast path: small frame, no backpressure, attempt to send without allocating @@ -1460,9 +1493,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { return; } - if (str.len == 0) { - return; - } + // Note: 0 is valid { var inline_buf: [stack_frame_size]u8 = undefined; @@ -1525,6 +1556,33 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { this.sendCloseWithBody(this.tcp, code, null, 0); } + const InitialDataHandler = struct { + adopted: ?*WebSocket, + ws: *CppWebSocket, + slice: []u8, + + pub const Handle = JSC.AnyTask.New(@This(), handle); + + pub fn handleWithoutDeinit(this: *@This()) void { + var this_socket = this.adopted orelse return; + this.adopted = null; + this_socket.initial_data_handler = null; + var ws = this.ws; + defer ws.unref(); + + if (this_socket.outgoing_websocket != null) + this_socket.handleData(this_socket.tcp, this.slice); + } + + pub fn handle(this: *@This()) void { + defer { + bun.default_allocator.free(this.slice); + bun.default_allocator.destroy(this); + } + this.handleWithoutDeinit(); + } + }; + pub fn init( outgoing: *CppWebSocket, input_socket: *anyopaque, @@ -1554,33 +1612,19 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { var buffered_slice: []u8 = buffered_data[0..buffered_data_len]; if (buffered_slice.len > 0) { - const InitialDataHandler = struct { - adopted: *WebSocket, - slice: []u8, - task: JSC.AnyTask = undefined, - - pub const Handle = JSC.AnyTask.New(@This(), handle); - - pub fn handle(this: *@This()) void { - defer { - bun.default_allocator.free(this.slice); - bun.default_allocator.destroy(this); - } - - this.adopted.receive_buffer.ensureUnusedCapacity(this.slice.len) catch return; - var writable = this.adopted.receive_buffer.writableSlice(0); - @memcpy(writable[0..this.slice.len], this.slice); - - this.adopted.handleData(this.adopted.tcp, writable); - } - }; var initial_data = bun.default_allocator.create(InitialDataHandler) catch unreachable; initial_data.* = .{ .adopted = adopted, .slice = buffered_slice, + .ws = outgoing, }; - initial_data.task = InitialDataHandler.Handle.init(initial_data); - globalThis.bunVM().eventLoop().enqueueTask(JSC.Task.init(&initial_data.task)); + + // Use a higher-priority callback for the initial onData handler + globalThis.queueMicrotaskCallback(initial_data, InitialDataHandler.handle); + + // We need to ref the outgoing websocket so that it doesn't get finalized + // before the initial data handler is called + outgoing.ref(); } return @ptrCast( *anyopaque, diff --git a/test/js/bun/websocket/websocket-server.test.ts b/test/js/bun/websocket/websocket-server.test.ts index 2c2352f91..7913147f9 100644 --- a/test/js/bun/websocket/websocket-server.test.ts +++ b/test/js/bun/websocket/websocket-server.test.ts @@ -3,6 +3,77 @@ import { gcTick } from "harness"; import { serve, ServerWebSocket } from "bun"; describe("websocket server", () => { + it("send & receive empty messages", done => { + const serverReceived: any[] = []; + const clientReceived: any[] = []; + var clientDone = false; + var serverDone = false; + + let server = Bun.serve({ + websocket: { + open(ws) { + ws.send(""); + ws.send(new ArrayBuffer(0)); + }, + message(ws, data) { + serverReceived.push(data); + + if (serverReceived.length === 2) { + if (serverReceived.find(d => d === "") === undefined) { + done(new Error("expected empty string")); + } + + if (!serverReceived.find(d => d.byteLength === 0)) { + done(new Error("expected empty Buffer")); + } + + serverDone = true; + + if (clientDone && serverDone) { + z.close(); + server.stop(true); + done(); + } + } + }, + close() {}, + }, + fetch(req, server) { + if (!server.upgrade(req)) { + return new Response(null, { status: 404 }); + } + }, + port: 0, + }); + + let z = new WebSocket(`ws://${server.hostname}:${server.port}`); + z.onmessage = e => { + clientReceived.push(e.data); + + if (clientReceived.length === 2) { + if (clientReceived.find(d => d === "") === undefined) { + done(new Error("expected empty string")); + } + + if (!clientReceived.find(d => d.byteLength === 0)) { + done(new Error("expected empty Buffer")); + } + + clientDone = true; + if (clientDone && serverDone) { + server.stop(true); + z.close(); + + done(); + } + } + }; + z.addEventListener("open", () => { + z.send(""); + z.send(new Buffer(0)); + }); + }); + it("remoteAddress works", done => { let server = Bun.serve({ websocket: { @@ -859,16 +930,13 @@ describe("websocket server", () => { const server = serve({ port: 0, websocket: { - open(ws) { - server.stop(); - }, + open(ws) {}, message(ws, msg) { ws.send(sendQueue[serverCounter++] + " "); - gcTick(); + serverCounter % 10 === 0 && gcTick(); }, }, fetch(req, server) { - server.stop(); if ( server.upgrade(req, { data: { count: 0 }, @@ -879,32 +947,39 @@ describe("websocket server", () => { return new Response("noooooo hello world"); }, }); + try { + await new Promise<void>((resolve, reject) => { + const websocket = new WebSocket(`ws://${server.hostname}:${server.port}`); + websocket.onerror = e => { + reject(e); + }; - await new Promise<void>((resolve, reject) => { - const websocket = new WebSocket(`ws://${server.hostname}:${server.port}`); - websocket.onerror = e => { - reject(e); - }; + websocket.onopen = () => { + server.stop(); + websocket.send("first"); + }; - var counter = 0; - websocket.onopen = () => websocket.send("first"); - websocket.onmessage = e => { - try { - const expected = sendQueue[clientCounter++] + " "; - expect(e.data).toBe(expected); - websocket.send("next"); - if (clientCounter === sendQueue.length) { + websocket.onmessage = e => { + try { + const expected = sendQueue[clientCounter++] + " "; + expect(e.data).toBe(expected); + websocket.send("next"); + if (clientCounter === sendQueue.length) { + websocket.close(); + resolve(); + } + } catch (r) { + reject(r); + console.error(r); websocket.close(); - resolve(); } - } catch (r) { - reject(r); - console.error(r); - websocket.close(); - } - }; - }); - server.stop(true); + }; + }); + } catch (e) { + throw e; + } finally { + server.stop(true); + } }); // this test sends 100 messages to 10 connected clients via pubsub @@ -913,20 +988,15 @@ describe("websocket server", () => { var sendQueue: any[] = []; for (var i = 0; i < 100; i++) { sendQueue.push(ropey + " " + i); - gcTick(); } + var serverCounter = 0; var clientCount = 0; const server = serve({ port: 0, websocket: { - // FIXME: update this test to not rely on publishToSelf: true, - publishToSelf: true, - open(ws) { - server.stop(); ws.subscribe("test"); - gcTick(); if (!ws.isSubscribed("test")) { throw new Error("not subscribed"); } @@ -936,15 +1006,15 @@ describe("websocket server", () => { } ws.subscribe("test"); clientCount++; - if (clientCount === 10) setTimeout(() => ws.publish("test", "hello world"), 1); + if (clientCount === 10) { + setTimeout(() => server.publish("test", "hello world"), 1); + } }, message(ws, msg) { - if (serverCounter < sendQueue.length) ws.publish("test", sendQueue[serverCounter++] + " "); + if (serverCounter < sendQueue.length) server.publish("test", sendQueue[serverCounter++] + " "); }, }, fetch(req) { - gcTick(); - server.stop(); if ( server.upgrade(req, { data: { count: 0 }, @@ -954,89 +1024,89 @@ describe("websocket server", () => { return new Response("noooooo hello world"); }, }); + try { + const connections = new Array(10); + const websockets = new Array(connections.length); + var doneCounter = 0; + await new Promise<void>(done => { + for (var i = 0; i < connections.length; i++) { + var j = i; + var resolve: (_?: unknown) => void, + reject: (_?: unknown) => void, + resolveConnection: (_?: unknown) => void, + rejectConnection: (_?: unknown) => void; + connections[j] = new Promise((res, rej) => { + resolveConnection = res; + rejectConnection = rej; + }); + websockets[j] = new Promise((res, rej) => { + resolve = res; + reject = rej; + }); + const websocket = new WebSocket(`ws://${server.hostname}:${server.port}`); + websocket.onerror = e => { + reject(e); + }; + websocket.onclose = () => { + doneCounter++; + if (doneCounter === connections.length) { + done(); + } + }; + var hasOpened = false; + websocket.onopen = () => { + if (!hasOpened) { + hasOpened = true; + resolve(websocket); + } + }; - const connections = new Array(10); - const websockets = new Array(connections.length); - var doneCounter = 0; - await new Promise<void>(done => { - for (var i = 0; i < connections.length; i++) { - var j = i; - var resolve: (_?: unknown) => void, - reject: (_?: unknown) => void, - resolveConnection: (_?: unknown) => void, - rejectConnection: (_?: unknown) => void; - connections[j] = new Promise((res, rej) => { - resolveConnection = res; - rejectConnection = rej; - }); - websockets[j] = new Promise((res, rej) => { - resolve = res; - reject = rej; - }); - gcTick(); - const websocket = new WebSocket(`ws://${server.hostname}:${server.port}`); - websocket.onerror = e => { - reject(e); - }; - websocket.onclose = () => { - doneCounter++; - if (doneCounter === connections.length) { - done(); - } - }; - var hasOpened = false; - websocket.onopen = () => { - if (!hasOpened) { - hasOpened = true; - resolve(websocket); - } - }; - - let clientCounter = -1; - var hasSentThisTick = false; - - websocket.onmessage = e => { - gcTick(); - - if (!hasOpened) { - hasOpened = true; - resolve(websocket); - } + let clientCounter = -1; + var hasSentThisTick = false; - if (e.data === "hello world") { - clientCounter = 0; - websocket.send("first"); - return; - } + websocket.onmessage = e => { + if (!hasOpened) { + hasOpened = true; + resolve(websocket); + } - try { - expect(!!sendQueue.find(a => a + " " === e.data)).toBe(true); - - if (!hasSentThisTick) { - websocket.send("second"); - hasSentThisTick = true; - queueMicrotask(() => { - hasSentThisTick = false; - }); + if (e.data === "hello world") { + clientCounter = 0; + websocket.send("first"); + return; } - gcTick(); + try { + expect(!!sendQueue.find(a => a + " " === e.data)).toBe(true); + + if (!hasSentThisTick) { + websocket.send("second"); + hasSentThisTick = true; + queueMicrotask(() => { + hasSentThisTick = false; + }); + } - if (clientCounter++ === sendQueue.length - 1) { + if (clientCounter++ === sendQueue.length - 1) { + websocket.close(); + resolveConnection(); + } + } catch (r) { + console.error(r); websocket.close(); - resolveConnection(); + rejectConnection(r); } - } catch (r) { - console.error(r); - websocket.close(); - rejectConnection(r); - gcTick(); - } - }; - } - }); + }; + } + }); + } catch (e) { + throw e; + } finally { + server.stop(true); + gcTick(); + } + expect(serverCounter).toBe(sendQueue.length); - server.stop(true); }, 30_000); it("can close with reason and code #2631", done => { let timeout: any; diff --git a/test/js/web/websocket/websocket.test.js b/test/js/web/websocket/websocket.test.js index 867b86123..76ff16ecb 100644 --- a/test/js/web/websocket/websocket.test.js +++ b/test/js/web/websocket/websocket.test.js @@ -6,16 +6,33 @@ const TEST_WEBSOCKET_HOST = process.env.TEST_WEBSOCKET_HOST || "wss://ws.postman describe("WebSocket", () => { it("should connect", async () => { - const ws = new WebSocket(TEST_WEBSOCKET_HOST); - await new Promise((resolve, reject) => { + const server = Bun.serve({ + port: 0, + fetch(req, server) { + if (server.upgrade(req)) { + server.stop(); + return; + } + + return new Response(); + }, + websocket: { + open(ws) {}, + message(ws) { + ws.close(); + }, + }, + }); + const ws = new WebSocket(`ws://${server.hostname}:${server.port}`, {}); + await new Promise(resolve => { ws.onopen = resolve; - ws.onerror = reject; }); - var closed = new Promise((resolve, reject) => { + var closed = new Promise(resolve => { ws.onclose = resolve; }); ws.close(); await closed; + server.stop(true); }); it("should connect over https", async () => { @@ -59,17 +76,18 @@ describe("WebSocket", () => { const server = Bun.serve({ port: 0, fetch(req, server) { - server.stop(); done(); + server.stop(); return new Response(); }, websocket: { - open(ws) { + open(ws) {}, + message(ws) { ws.close(); }, }, }); - const ws = new WebSocket(`http://${server.hostname}:${server.port}`, {}); + new WebSocket(`http://${server.hostname}:${server.port}`, {}); }); describe("nodebuffer", () => { it("should support 'nodebuffer' binaryType", done => { @@ -93,6 +111,7 @@ describe("WebSocket", () => { expect(ws.binaryType).toBe("nodebuffer"); Bun.gc(true); ws.onmessage = ({ data }) => { + ws.close(); expect(Buffer.isBuffer(data)).toBe(true); expect(data).toEqual(new Uint8Array([1, 2, 3])); server.stop(true); @@ -117,6 +136,7 @@ describe("WebSocket", () => { ws.sendBinary(new Uint8Array([1, 2, 3])); setTimeout(() => { client.onmessage = ({ data }) => { + client.close(); expect(Buffer.isBuffer(data)).toBe(true); expect(data).toEqual(new Uint8Array([1, 2, 3])); server.stop(true); |