aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Jarred Sumner <jarred@jarredsumner.com> 2023-07-03 20:53:41 -0700
committerGravatar GitHub <noreply@github.com> 2023-07-03 20:53:41 -0700
commit3345a7fc3c81f914000fd5a3c5a24f920a70386a (patch)
treeb4a0d9ffa9d37d17bbf41f74c9d3f3a624ae8aab
parentb26b0d886ce2f9898833e8efa16b71952c39b615 (diff)
downloadbun-3345a7fc3c81f914000fd5a3c5a24f920a70386a.tar.gz
bun-3345a7fc3c81f914000fd5a3c5a24f920a70386a.tar.zst
bun-3345a7fc3c81f914000fd5a3c5a24f920a70386a.zip
Allow zero length WebSocket client & server messages (#3488)
* Allow zero length WebSocket client & server messages * Add test * Clean this up a little * Clean up these tests a little * Hopefully fix the test failure in release build * Don't copy into the receive buffer * Less flaky --------- Co-authored-by: Jarred Sumner <709451+Jarred-Sumner@users.noreply.github.com>
-rw-r--r--src/bun.js/api/server.zig53
-rw-r--r--src/bun.js/bindings/ZigGlobalObject.cpp26
-rw-r--r--src/bun.js/bindings/ZigGlobalObject.h2
-rw-r--r--src/bun.js/bindings/bindings.zig17
-rw-r--r--src/bun.js/bindings/webcore/WebSocket.cpp27
-rw-r--r--src/bun.js/bindings/webcore/WebSocket.h28
-rw-r--r--src/http/websocket_http_client.zig172
-rw-r--r--test/js/bun/websocket/websocket-server.test.ts292
-rw-r--r--test/js/web/websocket/websocket.test.js34
9 files changed, 391 insertions, 260 deletions
diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig
index f52c08301..9625ff693 100644
--- a/src/bun.js/api/server.zig
+++ b/src/bun.js/api/server.zig
@@ -3560,11 +3560,6 @@ pub const ServerWebSocket = struct {
if (message_value.asArrayBuffer(globalThis)) |array_buffer| {
const buffer = array_buffer.slice();
- if (buffer.len == 0) {
- globalThis.throw("publish requires a non-empty message", .{});
- return .zero;
- }
-
const result = if (!publish_to_self)
this.websocket.publish(topic_slice.slice(), buffer, .binary, compress)
else
@@ -3580,9 +3575,6 @@ pub const ServerWebSocket = struct {
{
var string_slice = message_value.toSlice(globalThis, bun.default_allocator);
defer string_slice.deinit();
- if (string_slice.len == 0) {
- return JSValue.jsNumber(0);
- }
const buffer = string_slice.slice();
@@ -3634,10 +3626,6 @@ pub const ServerWebSocket = struct {
var topic_slice = topic_value.toSlice(globalThis, bun.default_allocator);
defer topic_slice.deinit();
- if (topic_slice.len == 0) {
- globalThis.throw("publishText requires a non-empty topic", .{});
- return .zero;
- }
const compress = args.len > 1 and compress_value.toBoolean();
@@ -3648,9 +3636,6 @@ pub const ServerWebSocket = struct {
var string_slice = message_value.toSlice(globalThis, bun.default_allocator);
defer string_slice.deinit();
- if (string_slice.len == 0) {
- return JSValue.jsNumber(0);
- }
const buffer = string_slice.slice();
@@ -3715,10 +3700,6 @@ pub const ServerWebSocket = struct {
};
const buffer = array_buffer.slice();
- if (buffer.len == 0) {
- return JSC.JSValue.jsNumber(0);
- }
-
const result = if (!publish_to_self)
this.websocket.publish(topic_slice.slice(), buffer, .binary, compress)
else
@@ -3883,10 +3864,6 @@ pub const ServerWebSocket = struct {
}
if (message_value.asArrayBuffer(globalThis)) |buffer| {
- if (buffer.len == 0) {
- return JSValue.jsNumber(0);
- }
-
switch (this.websocket.send(buffer.slice(), .binary, compress, true)) {
.backpressure => {
log("send() backpressure ({d} bytes)", .{buffer.len});
@@ -3906,9 +3883,6 @@ pub const ServerWebSocket = struct {
{
var string_slice = message_value.toSlice(globalThis, bun.default_allocator);
defer string_slice.deinit();
- if (string_slice.len == 0) {
- return JSValue.jsNumber(0);
- }
const buffer = string_slice.slice();
switch (this.websocket.send(buffer, .text, compress, true)) {
@@ -3960,9 +3934,6 @@ pub const ServerWebSocket = struct {
var string_slice = message_value.toSlice(globalThis, bun.default_allocator);
defer string_slice.deinit();
- if (string_slice.len == 0) {
- return JSValue.jsNumber(0);
- }
const buffer = string_slice.slice();
switch (this.websocket.send(buffer, .text, compress, true)) {
@@ -3994,9 +3965,6 @@ pub const ServerWebSocket = struct {
var string_slice = message_str.toSlice(globalThis, bun.default_allocator);
defer string_slice.deinit();
- if (string_slice.len == 0) {
- return JSValue.jsNumber(0);
- }
const buffer = string_slice.slice();
switch (this.websocket.send(buffer, .text, compress, true)) {
@@ -4043,10 +4011,6 @@ pub const ServerWebSocket = struct {
return .zero;
};
- if (buffer.len == 0) {
- return JSValue.jsNumber(0);
- }
-
switch (this.websocket.send(buffer.slice(), .binary, compress, true)) {
.backpressure => {
log("sendBinary() backpressure ({d} bytes)", .{buffer.len});
@@ -4076,10 +4040,6 @@ pub const ServerWebSocket = struct {
const buffer = array_buffer.slice();
- if (buffer.len == 0) {
- return JSValue.jsNumber(0);
- }
-
switch (this.websocket.send(buffer, .binary, compress, true)) {
.backpressure => {
log("sendBinary() backpressure ({d} bytes)", .{buffer.len});
@@ -4416,17 +4376,7 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type {
const compress = (compress_value orelse JSValue.jsBoolean(true)).toBoolean();
- if (message_value.isEmptyOrUndefinedOrNull()) {
- JSC.JSError(this.vm.allocator, "publish requires a non-empty message", .{}, globalThis, exception);
- return .zero;
- }
-
if (message_value.asArrayBuffer(globalThis)) |buffer| {
- if (buffer.len == 0) {
- JSC.JSError(this.vm.allocator, "publish requires a non-empty message", .{}, globalThis, exception);
- return .zero;
- }
-
return JSValue.jsNumber(
// if 0, return 0
// else return number of bytes sent
@@ -4437,9 +4387,6 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type {
{
var string_slice = message_value.toSlice(globalThis, bun.default_allocator);
defer string_slice.deinit();
- if (string_slice.len == 0) {
- return JSValue.jsNumber(0);
- }
const buffer = string_slice.slice();
return JSValue.jsNumber(
diff --git a/src/bun.js/bindings/ZigGlobalObject.cpp b/src/bun.js/bindings/ZigGlobalObject.cpp
index c00670289..b3236a4a2 100644
--- a/src/bun.js/bindings/ZigGlobalObject.cpp
+++ b/src/bun.js/bindings/ZigGlobalObject.cpp
@@ -1021,6 +1021,20 @@ JSC_DEFINE_HOST_FUNCTION(functionBunSleepThenCallback,
return JSC::JSValue::encode(promise);
}
+using MicrotaskCallback = void (*)(void*);
+
+JSC_DEFINE_HOST_FUNCTION(functionNativeMicrotaskTrampoline,
+ (JSC::JSGlobalObject * globalObject, JSC::CallFrame* callFrame))
+{
+ JSCell* cellPtr = callFrame->uncheckedArgument(0).asCell();
+ JSCell* callbackPtr = callFrame->uncheckedArgument(1).asCell();
+
+ void* cell = reinterpret_cast<void*>(cellPtr);
+ auto* callback = reinterpret_cast<MicrotaskCallback>(callbackPtr);
+ callback(cell);
+ return JSValue::encode(jsUndefined());
+}
+
JSC_DEFINE_HOST_FUNCTION(functionBunSleep,
(JSC::JSGlobalObject * globalObject, JSC::CallFrame* callFrame))
{
@@ -3027,6 +3041,11 @@ void GlobalObject::finishCreation(VM& vm)
init.set(JSFunction::create(init.vm, init.owner, 4, "performMicrotaskVariadic"_s, jsFunctionPerformMicrotaskVariadic, ImplementationVisibility::Public));
});
+ m_nativeMicrotaskTrampoline.initLater(
+ [](const Initializer<JSFunction>& init) {
+ init.set(JSFunction::create(init.vm, init.owner, 2, ""_s, functionNativeMicrotaskTrampoline, ImplementationVisibility::Public));
+ });
+
m_navigatorObject.initLater(
[](const Initializer<JSObject>& init) {
int cpuCount = 0;
@@ -4225,6 +4244,7 @@ void GlobalObject::visitChildrenImpl(JSCell* cell, Visitor& visitor)
thisObject->m_JSFileSinkControllerPrototype.visit(visitor);
thisObject->m_JSHTTPSResponseControllerPrototype.visit(visitor);
thisObject->m_navigatorObject.visit(visitor);
+ thisObject->m_nativeMicrotaskTrampoline.visit(visitor);
thisObject->m_performanceObject.visit(visitor);
thisObject->m_primordialsObject.visit(visitor);
thisObject->m_processEnvObject.visit(visitor);
@@ -4387,6 +4407,12 @@ extern "C" void JSC__JSGlobalObject__reload(JSC__JSGlobalObject* arg0)
globalObject->reload();
}
+extern "C" void JSC__JSGlobalObject__queueMicrotaskCallback(Zig::GlobalObject* globalObject, void* ptr, MicrotaskCallback callback)
+{
+ JSFunction* function = globalObject->nativeMicrotaskTrampoline();
+ globalObject->queueMicrotask(function, JSValue(reinterpret_cast<JSC::JSCell*>(ptr)), JSValue(reinterpret_cast<JSC::JSCell*>(callback)), jsUndefined(), jsUndefined());
+}
+
JSC::Identifier GlobalObject::moduleLoaderResolve(JSGlobalObject* globalObject,
JSModuleLoader* loader, JSValue key,
JSValue referrer, JSValue origin)
diff --git a/src/bun.js/bindings/ZigGlobalObject.h b/src/bun.js/bindings/ZigGlobalObject.h
index a5b802ced..da6ba92a0 100644
--- a/src/bun.js/bindings/ZigGlobalObject.h
+++ b/src/bun.js/bindings/ZigGlobalObject.h
@@ -369,6 +369,7 @@ public:
mutable WriteBarrier<JSFunction> m_thenables[promiseFunctionsSize + 1];
JSObject* navigatorObject();
+ JSFunction* nativeMicrotaskTrampoline() { return m_nativeMicrotaskTrampoline.getInitializedOnMainThread(this); }
void trackFFIFunction(JSC::JSFunction* function)
{
@@ -466,6 +467,7 @@ private:
*/
LazyProperty<JSGlobalObject, JSC::Structure> m_pendingVirtualModuleResultStructure;
LazyProperty<JSGlobalObject, JSFunction> m_performMicrotaskFunction;
+ LazyProperty<JSGlobalObject, JSFunction> m_nativeMicrotaskTrampoline;
LazyProperty<JSGlobalObject, JSFunction> m_performMicrotaskVariadicFunction;
LazyProperty<JSGlobalObject, JSFunction> m_emitReadableNextTickFunction;
LazyProperty<JSGlobalObject, JSMap> m_lazyReadableStreamPrototypeMap;
diff --git a/src/bun.js/bindings/bindings.zig b/src/bun.js/bindings/bindings.zig
index 277172b81..07882d857 100644
--- a/src/bun.js/bindings/bindings.zig
+++ b/src/bun.js/bindings/bindings.zig
@@ -2732,6 +2732,23 @@ pub const JSGlobalObject = extern struct {
this.vm().throwError(this, this.createErrorInstance(Output.prettyFmt(fmt, false), args));
}
}
+ extern fn JSC__JSGlobalObject__queueMicrotaskCallback(*JSGlobalObject, *anyopaque, Function: *const (fn (*anyopaque) callconv(.C) void)) void;
+ pub fn queueMicrotaskCallback(
+ this: *JSGlobalObject,
+ ctx_val: anytype,
+ comptime Function: fn (ctx: @TypeOf(ctx_val)) void,
+ ) void {
+ JSC.markBinding(@src());
+ const Fn = Function;
+ const ContextType = @TypeOf(ctx_val);
+ const Wrapper = struct {
+ pub fn call(p: *anyopaque) callconv(.C) void {
+ Fn(bun.cast(ContextType, p));
+ }
+ };
+
+ JSC__JSGlobalObject__queueMicrotaskCallback(this, ctx_val, &Wrapper.call);
+ }
pub fn queueMicrotask(
this: *JSGlobalObject,
diff --git a/src/bun.js/bindings/webcore/WebSocket.cpp b/src/bun.js/bindings/webcore/WebSocket.cpp
index a346175df..1d6392f44 100644
--- a/src/bun.js/bindings/webcore/WebSocket.cpp
+++ b/src/bun.js/bindings/webcore/WebSocket.cpp
@@ -458,8 +458,8 @@ ExceptionOr<void> WebSocket::send(const String& message)
return {};
}
- if (message.length() > 0)
- this->sendWebSocketString(message);
+ // 0-length is allowed
+ this->sendWebSocketString(message);
return {};
}
@@ -477,8 +477,8 @@ ExceptionOr<void> WebSocket::send(ArrayBuffer& binaryData)
}
char* data = static_cast<char*>(binaryData.data());
size_t length = binaryData.byteLength();
- if (length > 0)
- this->sendWebSocketData(data, length);
+ // 0-length is allowed
+ this->sendWebSocketData(data, length);
return {};
}
@@ -498,8 +498,8 @@ ExceptionOr<void> WebSocket::send(ArrayBufferView& arrayBufferView)
auto buffer = arrayBufferView.unsharedBuffer().get();
char* baseAddress = reinterpret_cast<char*>(buffer->data()) + arrayBufferView.byteOffset();
size_t length = arrayBufferView.byteLength();
- if (length > 0)
- this->sendWebSocketData(baseAddress, length);
+ // 0-length is allowed
+ this->sendWebSocketData(baseAddress, length);
return {};
}
@@ -1232,14 +1232,19 @@ extern "C" void WebSocket__didCloseWithErrorCode(WebCore::WebSocket* webSocket,
extern "C" void WebSocket__didReceiveText(WebCore::WebSocket* webSocket, bool clone, const ZigString* str)
{
- WTF::String wtf_str = Zig::toString(*str);
- if (clone) {
- wtf_str = wtf_str.isolatedCopy();
- }
-
+ WTF::String wtf_str = clone ? Zig::toStringCopy(*str) : Zig::toString(*str);
webSocket->didReceiveMessage(WTFMove(wtf_str));
}
extern "C" void WebSocket__didReceiveBytes(WebCore::WebSocket* webSocket, uint8_t* bytes, size_t len)
{
webSocket->didReceiveBinaryData({ bytes, len });
}
+
+extern "C" void WebSocket__incrementPendingActivity(WebCore::WebSocket* webSocket)
+{
+ webSocket->incPendingActivityCount();
+}
+extern "C" void WebSocket__decrementPendingActivity(WebCore::WebSocket* webSocket)
+{
+ webSocket->decPendingActivityCount();
+} \ No newline at end of file
diff --git a/src/bun.js/bindings/webcore/WebSocket.h b/src/bun.js/bindings/webcore/WebSocket.h
index 42261cfc4..846bd186b 100644
--- a/src/bun.js/bindings/webcore/WebSocket.h
+++ b/src/bun.js/bindings/webcore/WebSocket.h
@@ -111,6 +111,20 @@ public:
return m_hasPendingActivity.load();
}
+ void incPendingActivityCount()
+ {
+ m_pendingActivityCount++;
+ ref();
+ updateHasPendingActivity();
+ }
+
+ void decPendingActivityCount()
+ {
+ m_pendingActivityCount--;
+ deref();
+ updateHasPendingActivity();
+ }
+
private:
typedef union AnyWebSocket {
WebSocketClient* client;
@@ -147,20 +161,6 @@ private:
void sendWebSocketString(const String& message);
void sendWebSocketData(const char* data, size_t length);
- void incPendingActivityCount()
- {
- m_pendingActivityCount++;
- ref();
- updateHasPendingActivity();
- }
-
- void decPendingActivityCount()
- {
- m_pendingActivityCount--;
- deref();
- updateHasPendingActivity();
- }
-
void failAsynchronously();
enum class BinaryType { Blob,
diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig
index ee0fb9c77..a3ae8c3ba 100644
--- a/src/http/websocket_http_client.zig
+++ b/src/http/websocket_http_client.zig
@@ -145,6 +145,15 @@ const CppWebSocket = opaque {
pub const didCloseWithErrorCode = WebSocket__didCloseWithErrorCode;
pub const didReceiveText = WebSocket__didReceiveText;
pub const didReceiveBytes = WebSocket__didReceiveBytes;
+ extern fn WebSocket__incrementPendingActivity(websocket_context: *CppWebSocket) void;
+ extern fn WebSocket__decrementPendingActivity(websocket_context: *CppWebSocket) void;
+ pub fn ref(this: *CppWebSocket) void {
+ WebSocket__incrementPendingActivity(this);
+ }
+
+ pub fn unref(this: *CppWebSocket) void {
+ WebSocket__decrementPendingActivity(this);
+ }
};
const body_buf_len = 16384 - 16;
@@ -163,8 +172,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
to_send: []const u8 = "",
read_length: usize = 0,
headers_buf: [128]PicoHTTP.Header = undefined,
- body_buf: ?*BodyBuf = null,
- body_written: usize = 0,
+ body: std.ArrayListUnmanaged(u8) = .{},
websocket_protocol: u64 = 0,
hostname: [:0]const u8 = "",
poll_ref: JSC.PollRef = .{},
@@ -280,10 +288,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
this.poll_ref.unrefOnNextTick(JSC.VirtualMachine.get());
this.clearInput();
- if (this.body_buf) |buf| {
- this.body_buf = null;
- buf.release();
- }
+ this.body.clearAndFree(bun.default_allocator);
}
pub fn cancel(this: *HTTPClient) callconv(.C) void {
this.clearData();
@@ -355,14 +360,6 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
this.to_send = this.input_body_buf[@intCast(usize, wrote)..];
}
- fn getBody(this: *HTTPClient) *BodyBufBytes {
- if (this.body_buf == null) {
- this.body_buf = BodyBufPool.get(bun.default_allocator);
- }
-
- return &this.body_buf.?.data;
- }
-
pub fn handleData(this: *HTTPClient, socket: Socket, data: []const u8) void {
log("onData", .{});
std.debug.assert(socket.socket == this.tcp.socket);
@@ -374,43 +371,37 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
if (comptime Environment.allow_assert)
std.debug.assert(!socket.isShutdown());
- var body = this.getBody();
- var remain = body[this.body_written..];
- const is_first = this.body_written == 0;
+ var body = data;
+ if (this.body.items.len > 0) {
+ this.body.appendSlice(bun.default_allocator, data) catch @panic("out of memory");
+ body = this.body.items;
+ }
+
+ const is_first = this.body.items.len == 0;
if (is_first) {
// fail early if we receive a non-101 status code
- if (!strings.hasPrefixComptime(data, "HTTP/1.1 101 ")) {
+ if (!strings.hasPrefixComptime(body, "HTTP/1.1 101 ")) {
this.terminate(ErrorCode.expected_101_status_code);
return;
}
}
- const to_write = remain[0..@min(remain.len, data.len)];
- if (data.len > 0 and to_write.len > 0) {
- @memcpy(remain[0..to_write.len], data[0..to_write.len]);
- this.body_written += to_write.len;
- }
-
- const overflow = data[to_write.len..];
-
- const available_to_read = body[0..this.body_written];
- const response = PicoHTTP.Response.parse(available_to_read, &this.headers_buf) catch |err| {
+ const response = PicoHTTP.Response.parse(body, &this.headers_buf) catch |err| {
switch (err) {
error.Malformed_HTTP_Response => {
this.terminate(ErrorCode.invalid_response);
return;
},
error.ShortRead => {
- if (overflow.len > 0) {
- this.terminate(ErrorCode.headers_too_large);
- return;
+ if (this.body.items.len == 0) {
+ this.body.appendSlice(bun.default_allocator, data) catch @panic("out of memory");
}
return;
},
}
};
- this.processResponse(response, available_to_read[@intCast(usize, response.bytes_read)..]);
+ this.processResponse(response, body[@intCast(usize, response.bytes_read)..]);
}
pub fn handleEnd(this: *HTTPClient, socket: Socket) void {
@@ -420,8 +411,6 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
}
pub fn processResponse(this: *HTTPClient, response: PicoHTTP.Response, remain_buf: []const u8) void {
- std.debug.assert(this.body_written > 0);
-
var upgrade_header = PicoHTTP.Header{ .name = "", .value = "" };
var connection_header = PicoHTTP.Header{ .name = "", .value = "" };
var websocket_accept_header = PicoHTTP.Header{ .name = "", .value = "" };
@@ -524,7 +513,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
this.terminate(ErrorCode.invalid_response);
return;
};
- if (remain_buf.len > 0) @memcpy(overflow[0..remain_buf.len], remain_buf);
+ @memcpy(overflow, remain_buf);
}
this.clearData();
@@ -866,6 +855,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
globalThis: *JSC.JSGlobalObject,
poll_ref: JSC.PollRef = JSC.PollRef.init(),
+ initial_data_handler: ?*InitialDataHandler = null,
+
pub const name = if (ssl) "WebSocketClientTLS" else "WebSocketClient";
pub const shim = JSC.Shimmer("Bun", name, @This());
@@ -927,6 +918,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
JSC.markBinding(@src());
if (this.outgoing_websocket) |ws| {
this.outgoing_websocket = null;
+ log("fail ({s})", .{@tagName(code)});
+
ws.didCloseWithErrorCode(code);
}
@@ -937,7 +930,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
_ = socket;
_ = ssl_error;
JSC.markBinding(@src());
- log("WebSocket.onHandshake({d})", .{success});
+ log("onHandshake({d})", .{success});
JSC.markBinding(@src());
if (success == 0) {
if (this.outgoing_websocket) |ws| {
@@ -1044,6 +1037,24 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
}
pub fn handleData(this: *WebSocket, socket: Socket, data_: []const u8) void {
+ // Due to scheduling, it is possible for the websocket onData
+ // handler to run with additional data before the microtask queue is
+ // drained.
+ if (this.initial_data_handler) |initial_handler| {
+ // This calls `handleData`
+ // We deliberately do not set this.initial_data_handler to null here, that's done in handleWithoutDeinit.
+ // We do not free the memory here since the lifetime is managed by the microtask queue (it should free when called from there)
+ initial_handler.handleWithoutDeinit();
+
+ // handleWithoutDeinit is supposed to clear the handler from WebSocket*
+ // to prevent an infinite loop
+ std.debug.assert(this.initial_data_handler == null);
+
+ // If we disconnected for any reason in the re-entrant case, we should just ignore the data
+ if (this.outgoing_websocket == null or this.tcp.isShutdown() or this.tcp.isClosed())
+ return;
+ }
+
var data = data_;
var receive_state = this.receive_state;
var terminated = false;
@@ -1141,6 +1152,30 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
terminated = true;
break;
}
+
+ // Handle when the payload length is 0, but it is a message
+ //
+ // This should become
+ //
+ // - ArrayBuffer(0)
+ // - ""
+ // - Buffer(0) (etc)
+ //
+ if (receive_body_remain == 0 and receive_state == .need_body and is_final) {
+ _ = this.consume(
+ "",
+ receive_body_remain,
+ last_receive_data_type,
+ is_final,
+ );
+
+ // Return to the header state to read the next frame
+ receive_state = .need_header;
+ is_fragmented = false;
+
+ // Bail out if there's nothing left to read
+ if (data.len == 0) break;
+ }
},
.need_mask => {
this.terminate(.unexpected_mask_from_server);
@@ -1201,6 +1236,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
if (data.len == 0) break;
},
.need_body => {
+ // Empty messages are valid, but we handle that earlier in the flow.
if (receive_body_remain == 0 and data.len > 0) {
this.terminate(ErrorCode.expected_control_frame);
terminated = true;
@@ -1434,9 +1470,6 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
return;
}
- if (len == 0)
- return;
-
const slice = ptr[0..len];
const bytes = Copy{ .bytes = slice };
// fast path: small frame, no backpressure, attempt to send without allocating
@@ -1460,9 +1493,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
return;
}
- if (str.len == 0) {
- return;
- }
+ // Note: 0 is valid
{
var inline_buf: [stack_frame_size]u8 = undefined;
@@ -1525,6 +1556,33 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
this.sendCloseWithBody(this.tcp, code, null, 0);
}
+ const InitialDataHandler = struct {
+ adopted: ?*WebSocket,
+ ws: *CppWebSocket,
+ slice: []u8,
+
+ pub const Handle = JSC.AnyTask.New(@This(), handle);
+
+ pub fn handleWithoutDeinit(this: *@This()) void {
+ var this_socket = this.adopted orelse return;
+ this.adopted = null;
+ this_socket.initial_data_handler = null;
+ var ws = this.ws;
+ defer ws.unref();
+
+ if (this_socket.outgoing_websocket != null)
+ this_socket.handleData(this_socket.tcp, this.slice);
+ }
+
+ pub fn handle(this: *@This()) void {
+ defer {
+ bun.default_allocator.free(this.slice);
+ bun.default_allocator.destroy(this);
+ }
+ this.handleWithoutDeinit();
+ }
+ };
+
pub fn init(
outgoing: *CppWebSocket,
input_socket: *anyopaque,
@@ -1554,33 +1612,19 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
var buffered_slice: []u8 = buffered_data[0..buffered_data_len];
if (buffered_slice.len > 0) {
- const InitialDataHandler = struct {
- adopted: *WebSocket,
- slice: []u8,
- task: JSC.AnyTask = undefined,
-
- pub const Handle = JSC.AnyTask.New(@This(), handle);
-
- pub fn handle(this: *@This()) void {
- defer {
- bun.default_allocator.free(this.slice);
- bun.default_allocator.destroy(this);
- }
-
- this.adopted.receive_buffer.ensureUnusedCapacity(this.slice.len) catch return;
- var writable = this.adopted.receive_buffer.writableSlice(0);
- @memcpy(writable[0..this.slice.len], this.slice);
-
- this.adopted.handleData(this.adopted.tcp, writable);
- }
- };
var initial_data = bun.default_allocator.create(InitialDataHandler) catch unreachable;
initial_data.* = .{
.adopted = adopted,
.slice = buffered_slice,
+ .ws = outgoing,
};
- initial_data.task = InitialDataHandler.Handle.init(initial_data);
- globalThis.bunVM().eventLoop().enqueueTask(JSC.Task.init(&initial_data.task));
+
+ // Use a higher-priority callback for the initial onData handler
+ globalThis.queueMicrotaskCallback(initial_data, InitialDataHandler.handle);
+
+ // We need to ref the outgoing websocket so that it doesn't get finalized
+ // before the initial data handler is called
+ outgoing.ref();
}
return @ptrCast(
*anyopaque,
diff --git a/test/js/bun/websocket/websocket-server.test.ts b/test/js/bun/websocket/websocket-server.test.ts
index 2c2352f91..7913147f9 100644
--- a/test/js/bun/websocket/websocket-server.test.ts
+++ b/test/js/bun/websocket/websocket-server.test.ts
@@ -3,6 +3,77 @@ import { gcTick } from "harness";
import { serve, ServerWebSocket } from "bun";
describe("websocket server", () => {
+ it("send & receive empty messages", done => {
+ const serverReceived: any[] = [];
+ const clientReceived: any[] = [];
+ var clientDone = false;
+ var serverDone = false;
+
+ let server = Bun.serve({
+ websocket: {
+ open(ws) {
+ ws.send("");
+ ws.send(new ArrayBuffer(0));
+ },
+ message(ws, data) {
+ serverReceived.push(data);
+
+ if (serverReceived.length === 2) {
+ if (serverReceived.find(d => d === "") === undefined) {
+ done(new Error("expected empty string"));
+ }
+
+ if (!serverReceived.find(d => d.byteLength === 0)) {
+ done(new Error("expected empty Buffer"));
+ }
+
+ serverDone = true;
+
+ if (clientDone && serverDone) {
+ z.close();
+ server.stop(true);
+ done();
+ }
+ }
+ },
+ close() {},
+ },
+ fetch(req, server) {
+ if (!server.upgrade(req)) {
+ return new Response(null, { status: 404 });
+ }
+ },
+ port: 0,
+ });
+
+ let z = new WebSocket(`ws://${server.hostname}:${server.port}`);
+ z.onmessage = e => {
+ clientReceived.push(e.data);
+
+ if (clientReceived.length === 2) {
+ if (clientReceived.find(d => d === "") === undefined) {
+ done(new Error("expected empty string"));
+ }
+
+ if (!clientReceived.find(d => d.byteLength === 0)) {
+ done(new Error("expected empty Buffer"));
+ }
+
+ clientDone = true;
+ if (clientDone && serverDone) {
+ server.stop(true);
+ z.close();
+
+ done();
+ }
+ }
+ };
+ z.addEventListener("open", () => {
+ z.send("");
+ z.send(new Buffer(0));
+ });
+ });
+
it("remoteAddress works", done => {
let server = Bun.serve({
websocket: {
@@ -859,16 +930,13 @@ describe("websocket server", () => {
const server = serve({
port: 0,
websocket: {
- open(ws) {
- server.stop();
- },
+ open(ws) {},
message(ws, msg) {
ws.send(sendQueue[serverCounter++] + " ");
- gcTick();
+ serverCounter % 10 === 0 && gcTick();
},
},
fetch(req, server) {
- server.stop();
if (
server.upgrade(req, {
data: { count: 0 },
@@ -879,32 +947,39 @@ describe("websocket server", () => {
return new Response("noooooo hello world");
},
});
+ try {
+ await new Promise<void>((resolve, reject) => {
+ const websocket = new WebSocket(`ws://${server.hostname}:${server.port}`);
+ websocket.onerror = e => {
+ reject(e);
+ };
- await new Promise<void>((resolve, reject) => {
- const websocket = new WebSocket(`ws://${server.hostname}:${server.port}`);
- websocket.onerror = e => {
- reject(e);
- };
+ websocket.onopen = () => {
+ server.stop();
+ websocket.send("first");
+ };
- var counter = 0;
- websocket.onopen = () => websocket.send("first");
- websocket.onmessage = e => {
- try {
- const expected = sendQueue[clientCounter++] + " ";
- expect(e.data).toBe(expected);
- websocket.send("next");
- if (clientCounter === sendQueue.length) {
+ websocket.onmessage = e => {
+ try {
+ const expected = sendQueue[clientCounter++] + " ";
+ expect(e.data).toBe(expected);
+ websocket.send("next");
+ if (clientCounter === sendQueue.length) {
+ websocket.close();
+ resolve();
+ }
+ } catch (r) {
+ reject(r);
+ console.error(r);
websocket.close();
- resolve();
}
- } catch (r) {
- reject(r);
- console.error(r);
- websocket.close();
- }
- };
- });
- server.stop(true);
+ };
+ });
+ } catch (e) {
+ throw e;
+ } finally {
+ server.stop(true);
+ }
});
// this test sends 100 messages to 10 connected clients via pubsub
@@ -913,20 +988,15 @@ describe("websocket server", () => {
var sendQueue: any[] = [];
for (var i = 0; i < 100; i++) {
sendQueue.push(ropey + " " + i);
- gcTick();
}
+
var serverCounter = 0;
var clientCount = 0;
const server = serve({
port: 0,
websocket: {
- // FIXME: update this test to not rely on publishToSelf: true,
- publishToSelf: true,
-
open(ws) {
- server.stop();
ws.subscribe("test");
- gcTick();
if (!ws.isSubscribed("test")) {
throw new Error("not subscribed");
}
@@ -936,15 +1006,15 @@ describe("websocket server", () => {
}
ws.subscribe("test");
clientCount++;
- if (clientCount === 10) setTimeout(() => ws.publish("test", "hello world"), 1);
+ if (clientCount === 10) {
+ setTimeout(() => server.publish("test", "hello world"), 1);
+ }
},
message(ws, msg) {
- if (serverCounter < sendQueue.length) ws.publish("test", sendQueue[serverCounter++] + " ");
+ if (serverCounter < sendQueue.length) server.publish("test", sendQueue[serverCounter++] + " ");
},
},
fetch(req) {
- gcTick();
- server.stop();
if (
server.upgrade(req, {
data: { count: 0 },
@@ -954,89 +1024,89 @@ describe("websocket server", () => {
return new Response("noooooo hello world");
},
});
+ try {
+ const connections = new Array(10);
+ const websockets = new Array(connections.length);
+ var doneCounter = 0;
+ await new Promise<void>(done => {
+ for (var i = 0; i < connections.length; i++) {
+ var j = i;
+ var resolve: (_?: unknown) => void,
+ reject: (_?: unknown) => void,
+ resolveConnection: (_?: unknown) => void,
+ rejectConnection: (_?: unknown) => void;
+ connections[j] = new Promise((res, rej) => {
+ resolveConnection = res;
+ rejectConnection = rej;
+ });
+ websockets[j] = new Promise((res, rej) => {
+ resolve = res;
+ reject = rej;
+ });
+ const websocket = new WebSocket(`ws://${server.hostname}:${server.port}`);
+ websocket.onerror = e => {
+ reject(e);
+ };
+ websocket.onclose = () => {
+ doneCounter++;
+ if (doneCounter === connections.length) {
+ done();
+ }
+ };
+ var hasOpened = false;
+ websocket.onopen = () => {
+ if (!hasOpened) {
+ hasOpened = true;
+ resolve(websocket);
+ }
+ };
- const connections = new Array(10);
- const websockets = new Array(connections.length);
- var doneCounter = 0;
- await new Promise<void>(done => {
- for (var i = 0; i < connections.length; i++) {
- var j = i;
- var resolve: (_?: unknown) => void,
- reject: (_?: unknown) => void,
- resolveConnection: (_?: unknown) => void,
- rejectConnection: (_?: unknown) => void;
- connections[j] = new Promise((res, rej) => {
- resolveConnection = res;
- rejectConnection = rej;
- });
- websockets[j] = new Promise((res, rej) => {
- resolve = res;
- reject = rej;
- });
- gcTick();
- const websocket = new WebSocket(`ws://${server.hostname}:${server.port}`);
- websocket.onerror = e => {
- reject(e);
- };
- websocket.onclose = () => {
- doneCounter++;
- if (doneCounter === connections.length) {
- done();
- }
- };
- var hasOpened = false;
- websocket.onopen = () => {
- if (!hasOpened) {
- hasOpened = true;
- resolve(websocket);
- }
- };
-
- let clientCounter = -1;
- var hasSentThisTick = false;
-
- websocket.onmessage = e => {
- gcTick();
-
- if (!hasOpened) {
- hasOpened = true;
- resolve(websocket);
- }
+ let clientCounter = -1;
+ var hasSentThisTick = false;
- if (e.data === "hello world") {
- clientCounter = 0;
- websocket.send("first");
- return;
- }
+ websocket.onmessage = e => {
+ if (!hasOpened) {
+ hasOpened = true;
+ resolve(websocket);
+ }
- try {
- expect(!!sendQueue.find(a => a + " " === e.data)).toBe(true);
-
- if (!hasSentThisTick) {
- websocket.send("second");
- hasSentThisTick = true;
- queueMicrotask(() => {
- hasSentThisTick = false;
- });
+ if (e.data === "hello world") {
+ clientCounter = 0;
+ websocket.send("first");
+ return;
}
- gcTick();
+ try {
+ expect(!!sendQueue.find(a => a + " " === e.data)).toBe(true);
+
+ if (!hasSentThisTick) {
+ websocket.send("second");
+ hasSentThisTick = true;
+ queueMicrotask(() => {
+ hasSentThisTick = false;
+ });
+ }
- if (clientCounter++ === sendQueue.length - 1) {
+ if (clientCounter++ === sendQueue.length - 1) {
+ websocket.close();
+ resolveConnection();
+ }
+ } catch (r) {
+ console.error(r);
websocket.close();
- resolveConnection();
+ rejectConnection(r);
}
- } catch (r) {
- console.error(r);
- websocket.close();
- rejectConnection(r);
- gcTick();
- }
- };
- }
- });
+ };
+ }
+ });
+ } catch (e) {
+ throw e;
+ } finally {
+ server.stop(true);
+ gcTick();
+ }
+
expect(serverCounter).toBe(sendQueue.length);
- server.stop(true);
}, 30_000);
it("can close with reason and code #2631", done => {
let timeout: any;
diff --git a/test/js/web/websocket/websocket.test.js b/test/js/web/websocket/websocket.test.js
index 867b86123..76ff16ecb 100644
--- a/test/js/web/websocket/websocket.test.js
+++ b/test/js/web/websocket/websocket.test.js
@@ -6,16 +6,33 @@ const TEST_WEBSOCKET_HOST = process.env.TEST_WEBSOCKET_HOST || "wss://ws.postman
describe("WebSocket", () => {
it("should connect", async () => {
- const ws = new WebSocket(TEST_WEBSOCKET_HOST);
- await new Promise((resolve, reject) => {
+ const server = Bun.serve({
+ port: 0,
+ fetch(req, server) {
+ if (server.upgrade(req)) {
+ server.stop();
+ return;
+ }
+
+ return new Response();
+ },
+ websocket: {
+ open(ws) {},
+ message(ws) {
+ ws.close();
+ },
+ },
+ });
+ const ws = new WebSocket(`ws://${server.hostname}:${server.port}`, {});
+ await new Promise(resolve => {
ws.onopen = resolve;
- ws.onerror = reject;
});
- var closed = new Promise((resolve, reject) => {
+ var closed = new Promise(resolve => {
ws.onclose = resolve;
});
ws.close();
await closed;
+ server.stop(true);
});
it("should connect over https", async () => {
@@ -59,17 +76,18 @@ describe("WebSocket", () => {
const server = Bun.serve({
port: 0,
fetch(req, server) {
- server.stop();
done();
+ server.stop();
return new Response();
},
websocket: {
- open(ws) {
+ open(ws) {},
+ message(ws) {
ws.close();
},
},
});
- const ws = new WebSocket(`http://${server.hostname}:${server.port}`, {});
+ new WebSocket(`http://${server.hostname}:${server.port}`, {});
});
describe("nodebuffer", () => {
it("should support 'nodebuffer' binaryType", done => {
@@ -93,6 +111,7 @@ describe("WebSocket", () => {
expect(ws.binaryType).toBe("nodebuffer");
Bun.gc(true);
ws.onmessage = ({ data }) => {
+ ws.close();
expect(Buffer.isBuffer(data)).toBe(true);
expect(data).toEqual(new Uint8Array([1, 2, 3]));
server.stop(true);
@@ -117,6 +136,7 @@ describe("WebSocket", () => {
ws.sendBinary(new Uint8Array([1, 2, 3]));
setTimeout(() => {
client.onmessage = ({ data }) => {
+ client.close();
expect(Buffer.isBuffer(data)).toBe(true);
expect(data).toEqual(new Uint8Array([1, 2, 3]));
server.stop(true);