aboutsummaryrefslogtreecommitdiff
path: root/src/http/websocket_http_client.zig
diff options
context:
space:
mode:
Diffstat (limited to 'src/http/websocket_http_client.zig')
-rw-r--r--src/http/websocket_http_client.zig206
1 files changed, 128 insertions, 78 deletions
diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig
index e1bd42984..ae8e40763 100644
--- a/src/http/websocket_http_client.zig
+++ b/src/http/websocket_http_client.zig
@@ -60,7 +60,7 @@ fn buildRequestBody(
extra_headers: NonUTF8Headers,
) std.mem.Allocator.Error![]u8 {
const allocator = vm.allocator;
- const input_rand_buf = vm.rareData().nextUUID();
+ const input_rand_buf = vm.rareData().nextUUID().bytes;
const temp_buf_size = comptime std.base64.standard.Encoder.calcSize(16);
var encoded_buf: [temp_buf_size]u8 = undefined;
const accept_key = std.base64.standard.Encoder.encode(&encoded_buf, &input_rand_buf);
@@ -77,9 +77,9 @@ fn buildRequestBody(
};
if (client_protocol.len > 0)
- client_protocol_hash.* = std.hash.Wyhash.hash(0, static_headers[1].value);
+ client_protocol_hash.* = bun.hash(static_headers[1].value);
- const headers_ = static_headers[0 .. 1 + @as(usize, @boolToInt(client_protocol.len > 0))];
+ const headers_ = static_headers[0 .. 1 + @as(usize, @intFromBool(client_protocol.len > 0))];
const pathname_ = pathname.slice();
const host_ = host.slice();
@@ -145,6 +145,17 @@ const CppWebSocket = opaque {
pub const didCloseWithErrorCode = WebSocket__didCloseWithErrorCode;
pub const didReceiveText = WebSocket__didReceiveText;
pub const didReceiveBytes = WebSocket__didReceiveBytes;
+ extern fn WebSocket__incrementPendingActivity(websocket_context: *CppWebSocket) void;
+ extern fn WebSocket__decrementPendingActivity(websocket_context: *CppWebSocket) void;
+ pub fn ref(this: *CppWebSocket) void {
+ JSC.markBinding(@src());
+ WebSocket__incrementPendingActivity(this);
+ }
+
+ pub fn unref(this: *CppWebSocket) void {
+ JSC.markBinding(@src());
+ WebSocket__decrementPendingActivity(this);
+ }
};
const body_buf_len = 16384 - 16;
@@ -163,8 +174,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
to_send: []const u8 = "",
read_length: usize = 0,
headers_buf: [128]PicoHTTP.Header = undefined,
- body_buf: ?*BodyBuf = null,
- body_written: usize = 0,
+ body: std.ArrayListUnmanaged(u8) = .{},
websocket_protocol: u64 = 0,
hostname: [:0]const u8 = "",
poll_ref: JSC.PollRef = .{},
@@ -280,16 +290,13 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
this.poll_ref.unrefOnNextTick(JSC.VirtualMachine.get());
this.clearInput();
- if (this.body_buf) |buf| {
- this.body_buf = null;
- buf.release();
- }
+ this.body.clearAndFree(bun.default_allocator);
}
pub fn cancel(this: *HTTPClient) callconv(.C) void {
this.clearData();
if (!this.tcp.isEstablished()) {
- _ = uws.us_socket_close_connecting(comptime @as(c_int, @boolToInt(ssl)), this.tcp.socket);
+ _ = uws.us_socket_close_connecting(comptime @as(c_int, @intFromBool(ssl)), this.tcp.socket);
} else {
this.tcp.close(0, null);
}
@@ -355,14 +362,6 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
this.to_send = this.input_body_buf[@intCast(usize, wrote)..];
}
- fn getBody(this: *HTTPClient) *BodyBufBytes {
- if (this.body_buf == null) {
- this.body_buf = BodyBufPool.get(bun.default_allocator);
- }
-
- return &this.body_buf.?.data;
- }
-
pub fn handleData(this: *HTTPClient, socket: Socket, data: []const u8) void {
log("onData", .{});
std.debug.assert(socket.socket == this.tcp.socket);
@@ -374,43 +373,37 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
if (comptime Environment.allow_assert)
std.debug.assert(!socket.isShutdown());
- var body = this.getBody();
- var remain = body[this.body_written..];
- const is_first = this.body_written == 0;
+ var body = data;
+ if (this.body.items.len > 0) {
+ this.body.appendSlice(bun.default_allocator, data) catch @panic("out of memory");
+ body = this.body.items;
+ }
+
+ const is_first = this.body.items.len == 0;
if (is_first) {
// fail early if we receive a non-101 status code
- if (!strings.hasPrefixComptime(data, "HTTP/1.1 101 ")) {
+ if (!strings.hasPrefixComptime(body, "HTTP/1.1 101 ")) {
this.terminate(ErrorCode.expected_101_status_code);
return;
}
}
- const to_write = remain[0..@min(remain.len, data.len)];
- if (data.len > 0 and to_write.len > 0) {
- @memcpy(remain.ptr, data.ptr, to_write.len);
- this.body_written += to_write.len;
- }
-
- const overflow = data[to_write.len..];
-
- const available_to_read = body[0..this.body_written];
- const response = PicoHTTP.Response.parse(available_to_read, &this.headers_buf) catch |err| {
+ const response = PicoHTTP.Response.parse(body, &this.headers_buf) catch |err| {
switch (err) {
error.Malformed_HTTP_Response => {
this.terminate(ErrorCode.invalid_response);
return;
},
error.ShortRead => {
- if (overflow.len > 0) {
- this.terminate(ErrorCode.headers_too_large);
- return;
+ if (this.body.items.len == 0) {
+ this.body.appendSlice(bun.default_allocator, data) catch @panic("out of memory");
}
return;
},
}
};
- this.processResponse(response, available_to_read[@intCast(usize, response.bytes_read)..]);
+ this.processResponse(response, body[@intCast(usize, response.bytes_read)..]);
}
pub fn handleEnd(this: *HTTPClient, socket: Socket) void {
@@ -420,8 +413,6 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
}
pub fn processResponse(this: *HTTPClient, response: PicoHTTP.Response, remain_buf: []const u8) void {
- std.debug.assert(this.body_written > 0);
-
var upgrade_header = PicoHTTP.Header{ .name = "", .value = "" };
var connection_header = PicoHTTP.Header{ .name = "", .value = "" };
var websocket_accept_header = PicoHTTP.Header{ .name = "", .value = "" };
@@ -465,7 +456,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
},
"Sec-WebSocket-Protocol".len => {
if (strings.eqlCaseInsensitiveASCII(header.name, "Sec-WebSocket-Protocol", false)) {
- if (this.websocket_protocol == 0 or std.hash.Wyhash.hash(0, header.value) != this.websocket_protocol) {
+ if (this.websocket_protocol == 0 or bun.hash(header.value) != this.websocket_protocol) {
this.terminate(ErrorCode.mismatch_client_protocol);
return;
}
@@ -524,7 +515,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
this.terminate(ErrorCode.invalid_response);
return;
};
- if (remain_buf.len > 0) @memcpy(overflow.ptr, remain_buf.ptr, remain_buf.len);
+ @memcpy(overflow, remain_buf);
}
this.clearData();
@@ -757,7 +748,7 @@ const Copy = union(enum) {
return WebsocketHeader.frameSizeIncludingMask(byte_len.*);
},
.latin1 => {
- byte_len.* = this.latin1.len;
+ byte_len.* = strings.elementLengthLatin1IntoUTF8([]const u8, this.latin1);
return WebsocketHeader.frameSizeIncludingMask(byte_len.*);
},
.bytes => {
@@ -775,7 +766,7 @@ const Copy = union(enum) {
if (this == .raw) {
std.debug.assert(buf.len >= this.raw.len);
std.debug.assert(buf.ptr != this.raw.ptr);
- @memcpy(buf.ptr, this.raw.ptr, this.raw.len);
+ @memcpy(buf[0..this.raw.len], this.raw);
return;
}
@@ -821,7 +812,10 @@ const Copy = union(enum) {
.latin1 => |latin1| {
const encode_into_result = strings.copyLatin1IntoUTF8(to_mask, []const u8, latin1);
std.debug.assert(@as(usize, encode_into_result.written) == content_byte_len);
+
+ // latin1 can contain non-ascii
std.debug.assert(@as(usize, encode_into_result.read) == latin1.len);
+
header.len = WebsocketHeader.packLength(encode_into_result.written);
header.opcode = Opcode.Text;
var fib = std.io.fixedBufferStream(buf);
@@ -863,6 +857,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
globalThis: *JSC.JSGlobalObject,
poll_ref: JSC.PollRef = JSC.PollRef.init(),
+ initial_data_handler: ?*InitialDataHandler = null,
+
pub const name = if (ssl) "WebSocketClientTLS" else "WebSocketClient";
pub const shim = JSC.Shimmer("Bun", name, @This());
@@ -914,7 +910,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
return;
if (!this.tcp.isEstablished()) {
- _ = uws.us_socket_close_connecting(comptime @as(c_int, @boolToInt(ssl)), this.tcp.socket);
+ _ = uws.us_socket_close_connecting(comptime @as(c_int, @intFromBool(ssl)), this.tcp.socket);
} else {
this.tcp.close(0, null);
}
@@ -924,6 +920,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
JSC.markBinding(@src());
if (this.outgoing_websocket) |ws| {
this.outgoing_websocket = null;
+ log("fail ({s})", .{@tagName(code)});
+
ws.didCloseWithErrorCode(code);
}
@@ -934,7 +932,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
_ = socket;
_ = ssl_error;
JSC.markBinding(@src());
- log("WebSocket.onHandshake({d})", .{success});
+ log("onHandshake({d})", .{success});
JSC.markBinding(@src());
if (success == 0) {
if (this.outgoing_websocket) |ws| {
@@ -1027,7 +1025,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
std.debug.assert(data_.len > 0);
var writable = this.receive_buffer.writableWithSize(data_.len) catch unreachable;
- @memcpy(writable.ptr, data_.ptr, data_.len);
+ @memcpy(writable[0..data_.len], data_);
this.receive_buffer.update(data_.len);
if (left_in_fragment >= data_.len and left_in_fragment - data_.len - this.receive_pending_chunk_len == 0) {
@@ -1041,6 +1039,24 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
}
pub fn handleData(this: *WebSocket, socket: Socket, data_: []const u8) void {
+ // Due to scheduling, it is possible for the websocket onData
+ // handler to run with additional data before the microtask queue is
+ // drained.
+ if (this.initial_data_handler) |initial_handler| {
+ // This calls `handleData`
+ // We deliberately do not set this.initial_data_handler to null here, that's done in handleWithoutDeinit.
+ // We do not free the memory here since the lifetime is managed by the microtask queue (it should free when called from there)
+ initial_handler.handleWithoutDeinit();
+
+ // handleWithoutDeinit is supposed to clear the handler from WebSocket*
+ // to prevent an infinite loop
+ std.debug.assert(this.initial_data_handler == null);
+
+ // If we disconnected for any reason in the re-entrant case, we should just ignore the data
+ if (this.outgoing_websocket == null or this.tcp.isShutdown() or this.tcp.isClosed())
+ return;
+ }
+
var data = data_;
var receive_state = this.receive_state;
var terminated = false;
@@ -1138,6 +1154,30 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
terminated = true;
break;
}
+
+ // Handle when the payload length is 0, but it is a message
+ //
+ // This should become
+ //
+ // - ArrayBuffer(0)
+ // - ""
+ // - Buffer(0) (etc)
+ //
+ if (receive_body_remain == 0 and receive_state == .need_body and is_final) {
+ _ = this.consume(
+ "",
+ receive_body_remain,
+ last_receive_data_type,
+ is_final,
+ );
+
+ // Return to the header state to read the next frame
+ receive_state = .need_header;
+ is_fragmented = false;
+
+ // Bail out if there's nothing left to read
+ if (data.len == 0) break;
+ }
},
.need_mask => {
this.terminate(.unexpected_mask_from_server);
@@ -1177,10 +1217,10 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
.ping => {
const ping_len = @min(data.len, @min(receive_body_remain, 125));
- this.ping_len = @truncate(u8, ping_len);
+ this.ping_len = ping_len;
if (ping_len > 0) {
- @memcpy(this.ping_frame_bytes[6..], data.ptr, ping_len);
+ @memcpy(this.ping_frame_bytes[6..][0..ping_len], data[0..ping_len]);
data = data[ping_len..];
}
@@ -1198,6 +1238,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
if (data.len == 0) break;
},
.need_body => {
+ // Empty messages are valid, but we handle that earlier in the flow.
if (receive_body_remain == 0 and data.len > 0) {
this.terminate(ErrorCode.expected_control_frame);
terminated = true;
@@ -1379,7 +1420,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
std.mem.writeIntSliceBig(u16, final_body_bytes[6..8], code);
if (body) |data| {
- if (body_len > 0) @memcpy(final_body_bytes[8..], data, body_len);
+ if (body_len > 0) @memcpy(final_body_bytes[8..][0..body_len], data[0..body_len]);
}
// we must mask the code
@@ -1431,9 +1472,6 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
return;
}
- if (len == 0)
- return;
-
const slice = ptr[0..len];
const bytes = Copy{ .bytes = slice };
// fast path: small frame, no backpressure, attempt to send without allocating
@@ -1457,9 +1495,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
return;
}
- if (str.len == 0) {
- return;
- }
+ // Note: 0 is valid
{
var inline_buf: [stack_frame_size]u8 = undefined;
@@ -1467,9 +1503,10 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
// fast path: small frame, no backpressure, attempt to send without allocating
if (!str.is16Bit() and str.len < stack_frame_size) {
const bytes = Copy{ .latin1 = str.slice() };
- const frame_size = WebsocketHeader.frameSizeIncludingMask(str.len);
+ var byte_len: usize = 0;
+ const frame_size = bytes.len(&byte_len);
if (!this.hasBackpressure() and frame_size < stack_frame_size) {
- bytes.copy(this.globalThis, inline_buf[0..frame_size], str.len);
+ bytes.copy(this.globalThis, inline_buf[0..frame_size], byte_len);
_ = this.enqueueEncodedBytes(this.tcp, inline_buf[0..frame_size]);
return;
}
@@ -1521,6 +1558,33 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
this.sendCloseWithBody(this.tcp, code, null, 0);
}
+ const InitialDataHandler = struct {
+ adopted: ?*WebSocket,
+ ws: *CppWebSocket,
+ slice: []u8,
+
+ pub const Handle = JSC.AnyTask.New(@This(), handle);
+
+ pub fn handleWithoutDeinit(this: *@This()) void {
+ var this_socket = this.adopted orelse return;
+ this.adopted = null;
+ this_socket.initial_data_handler = null;
+ var ws = this.ws;
+ defer ws.unref();
+
+ if (this_socket.outgoing_websocket != null)
+ this_socket.handleData(this_socket.tcp, this.slice);
+ }
+
+ pub fn handle(this: *@This()) void {
+ defer {
+ bun.default_allocator.free(this.slice);
+ bun.default_allocator.destroy(this);
+ }
+ this.handleWithoutDeinit();
+ }
+ };
+
pub fn init(
outgoing: *CppWebSocket,
input_socket: *anyopaque,
@@ -1550,33 +1614,19 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
var buffered_slice: []u8 = buffered_data[0..buffered_data_len];
if (buffered_slice.len > 0) {
- const InitialDataHandler = struct {
- adopted: *WebSocket,
- slice: []u8,
- task: JSC.AnyTask = undefined,
-
- pub const Handle = JSC.AnyTask.New(@This(), handle);
-
- pub fn handle(this: *@This()) void {
- defer {
- bun.default_allocator.free(this.slice);
- bun.default_allocator.destroy(this);
- }
-
- this.adopted.receive_buffer.ensureUnusedCapacity(this.slice.len) catch return;
- var writable = this.adopted.receive_buffer.writableSlice(0);
- @memcpy(writable.ptr, this.slice.ptr, this.slice.len);
-
- this.adopted.handleData(this.adopted.tcp, writable);
- }
- };
var initial_data = bun.default_allocator.create(InitialDataHandler) catch unreachable;
initial_data.* = .{
.adopted = adopted,
.slice = buffered_slice,
+ .ws = outgoing,
};
- initial_data.task = InitialDataHandler.Handle.init(initial_data);
- globalThis.bunVM().eventLoop().enqueueTask(JSC.Task.init(&initial_data.task));
+
+ // Use a higher-priority callback for the initial onData handler
+ globalThis.queueMicrotaskCallback(initial_data, InitialDataHandler.handle);
+
+ // We need to ref the outgoing websocket so that it doesn't get finalized
+ // before the initial data handler is called
+ outgoing.ref();
}
return @ptrCast(
*anyopaque,