diff options
Diffstat (limited to 'src/http/websocket_http_client.zig')
-rw-r--r-- | src/http/websocket_http_client.zig | 206 |
1 files changed, 128 insertions, 78 deletions
diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig index e1bd42984..ae8e40763 100644 --- a/src/http/websocket_http_client.zig +++ b/src/http/websocket_http_client.zig @@ -60,7 +60,7 @@ fn buildRequestBody( extra_headers: NonUTF8Headers, ) std.mem.Allocator.Error![]u8 { const allocator = vm.allocator; - const input_rand_buf = vm.rareData().nextUUID(); + const input_rand_buf = vm.rareData().nextUUID().bytes; const temp_buf_size = comptime std.base64.standard.Encoder.calcSize(16); var encoded_buf: [temp_buf_size]u8 = undefined; const accept_key = std.base64.standard.Encoder.encode(&encoded_buf, &input_rand_buf); @@ -77,9 +77,9 @@ fn buildRequestBody( }; if (client_protocol.len > 0) - client_protocol_hash.* = std.hash.Wyhash.hash(0, static_headers[1].value); + client_protocol_hash.* = bun.hash(static_headers[1].value); - const headers_ = static_headers[0 .. 1 + @as(usize, @boolToInt(client_protocol.len > 0))]; + const headers_ = static_headers[0 .. 1 + @as(usize, @intFromBool(client_protocol.len > 0))]; const pathname_ = pathname.slice(); const host_ = host.slice(); @@ -145,6 +145,17 @@ const CppWebSocket = opaque { pub const didCloseWithErrorCode = WebSocket__didCloseWithErrorCode; pub const didReceiveText = WebSocket__didReceiveText; pub const didReceiveBytes = WebSocket__didReceiveBytes; + extern fn WebSocket__incrementPendingActivity(websocket_context: *CppWebSocket) void; + extern fn WebSocket__decrementPendingActivity(websocket_context: *CppWebSocket) void; + pub fn ref(this: *CppWebSocket) void { + JSC.markBinding(@src()); + WebSocket__incrementPendingActivity(this); + } + + pub fn unref(this: *CppWebSocket) void { + JSC.markBinding(@src()); + WebSocket__decrementPendingActivity(this); + } }; const body_buf_len = 16384 - 16; @@ -163,8 +174,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { to_send: []const u8 = "", read_length: usize = 0, headers_buf: [128]PicoHTTP.Header = undefined, - body_buf: ?*BodyBuf = null, - body_written: usize = 0, + body: std.ArrayListUnmanaged(u8) = .{}, websocket_protocol: u64 = 0, hostname: [:0]const u8 = "", poll_ref: JSC.PollRef = .{}, @@ -280,16 +290,13 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { this.poll_ref.unrefOnNextTick(JSC.VirtualMachine.get()); this.clearInput(); - if (this.body_buf) |buf| { - this.body_buf = null; - buf.release(); - } + this.body.clearAndFree(bun.default_allocator); } pub fn cancel(this: *HTTPClient) callconv(.C) void { this.clearData(); if (!this.tcp.isEstablished()) { - _ = uws.us_socket_close_connecting(comptime @as(c_int, @boolToInt(ssl)), this.tcp.socket); + _ = uws.us_socket_close_connecting(comptime @as(c_int, @intFromBool(ssl)), this.tcp.socket); } else { this.tcp.close(0, null); } @@ -355,14 +362,6 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { this.to_send = this.input_body_buf[@intCast(usize, wrote)..]; } - fn getBody(this: *HTTPClient) *BodyBufBytes { - if (this.body_buf == null) { - this.body_buf = BodyBufPool.get(bun.default_allocator); - } - - return &this.body_buf.?.data; - } - pub fn handleData(this: *HTTPClient, socket: Socket, data: []const u8) void { log("onData", .{}); std.debug.assert(socket.socket == this.tcp.socket); @@ -374,43 +373,37 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { if (comptime Environment.allow_assert) std.debug.assert(!socket.isShutdown()); - var body = this.getBody(); - var remain = body[this.body_written..]; - const is_first = this.body_written == 0; + var body = data; + if (this.body.items.len > 0) { + this.body.appendSlice(bun.default_allocator, data) catch @panic("out of memory"); + body = this.body.items; + } + + const is_first = this.body.items.len == 0; if (is_first) { // fail early if we receive a non-101 status code - if (!strings.hasPrefixComptime(data, "HTTP/1.1 101 ")) { + if (!strings.hasPrefixComptime(body, "HTTP/1.1 101 ")) { this.terminate(ErrorCode.expected_101_status_code); return; } } - const to_write = remain[0..@min(remain.len, data.len)]; - if (data.len > 0 and to_write.len > 0) { - @memcpy(remain.ptr, data.ptr, to_write.len); - this.body_written += to_write.len; - } - - const overflow = data[to_write.len..]; - - const available_to_read = body[0..this.body_written]; - const response = PicoHTTP.Response.parse(available_to_read, &this.headers_buf) catch |err| { + const response = PicoHTTP.Response.parse(body, &this.headers_buf) catch |err| { switch (err) { error.Malformed_HTTP_Response => { this.terminate(ErrorCode.invalid_response); return; }, error.ShortRead => { - if (overflow.len > 0) { - this.terminate(ErrorCode.headers_too_large); - return; + if (this.body.items.len == 0) { + this.body.appendSlice(bun.default_allocator, data) catch @panic("out of memory"); } return; }, } }; - this.processResponse(response, available_to_read[@intCast(usize, response.bytes_read)..]); + this.processResponse(response, body[@intCast(usize, response.bytes_read)..]); } pub fn handleEnd(this: *HTTPClient, socket: Socket) void { @@ -420,8 +413,6 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { } pub fn processResponse(this: *HTTPClient, response: PicoHTTP.Response, remain_buf: []const u8) void { - std.debug.assert(this.body_written > 0); - var upgrade_header = PicoHTTP.Header{ .name = "", .value = "" }; var connection_header = PicoHTTP.Header{ .name = "", .value = "" }; var websocket_accept_header = PicoHTTP.Header{ .name = "", .value = "" }; @@ -465,7 +456,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { }, "Sec-WebSocket-Protocol".len => { if (strings.eqlCaseInsensitiveASCII(header.name, "Sec-WebSocket-Protocol", false)) { - if (this.websocket_protocol == 0 or std.hash.Wyhash.hash(0, header.value) != this.websocket_protocol) { + if (this.websocket_protocol == 0 or bun.hash(header.value) != this.websocket_protocol) { this.terminate(ErrorCode.mismatch_client_protocol); return; } @@ -524,7 +515,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { this.terminate(ErrorCode.invalid_response); return; }; - if (remain_buf.len > 0) @memcpy(overflow.ptr, remain_buf.ptr, remain_buf.len); + @memcpy(overflow, remain_buf); } this.clearData(); @@ -757,7 +748,7 @@ const Copy = union(enum) { return WebsocketHeader.frameSizeIncludingMask(byte_len.*); }, .latin1 => { - byte_len.* = this.latin1.len; + byte_len.* = strings.elementLengthLatin1IntoUTF8([]const u8, this.latin1); return WebsocketHeader.frameSizeIncludingMask(byte_len.*); }, .bytes => { @@ -775,7 +766,7 @@ const Copy = union(enum) { if (this == .raw) { std.debug.assert(buf.len >= this.raw.len); std.debug.assert(buf.ptr != this.raw.ptr); - @memcpy(buf.ptr, this.raw.ptr, this.raw.len); + @memcpy(buf[0..this.raw.len], this.raw); return; } @@ -821,7 +812,10 @@ const Copy = union(enum) { .latin1 => |latin1| { const encode_into_result = strings.copyLatin1IntoUTF8(to_mask, []const u8, latin1); std.debug.assert(@as(usize, encode_into_result.written) == content_byte_len); + + // latin1 can contain non-ascii std.debug.assert(@as(usize, encode_into_result.read) == latin1.len); + header.len = WebsocketHeader.packLength(encode_into_result.written); header.opcode = Opcode.Text; var fib = std.io.fixedBufferStream(buf); @@ -863,6 +857,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { globalThis: *JSC.JSGlobalObject, poll_ref: JSC.PollRef = JSC.PollRef.init(), + initial_data_handler: ?*InitialDataHandler = null, + pub const name = if (ssl) "WebSocketClientTLS" else "WebSocketClient"; pub const shim = JSC.Shimmer("Bun", name, @This()); @@ -914,7 +910,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { return; if (!this.tcp.isEstablished()) { - _ = uws.us_socket_close_connecting(comptime @as(c_int, @boolToInt(ssl)), this.tcp.socket); + _ = uws.us_socket_close_connecting(comptime @as(c_int, @intFromBool(ssl)), this.tcp.socket); } else { this.tcp.close(0, null); } @@ -924,6 +920,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { JSC.markBinding(@src()); if (this.outgoing_websocket) |ws| { this.outgoing_websocket = null; + log("fail ({s})", .{@tagName(code)}); + ws.didCloseWithErrorCode(code); } @@ -934,7 +932,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { _ = socket; _ = ssl_error; JSC.markBinding(@src()); - log("WebSocket.onHandshake({d})", .{success}); + log("onHandshake({d})", .{success}); JSC.markBinding(@src()); if (success == 0) { if (this.outgoing_websocket) |ws| { @@ -1027,7 +1025,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { std.debug.assert(data_.len > 0); var writable = this.receive_buffer.writableWithSize(data_.len) catch unreachable; - @memcpy(writable.ptr, data_.ptr, data_.len); + @memcpy(writable[0..data_.len], data_); this.receive_buffer.update(data_.len); if (left_in_fragment >= data_.len and left_in_fragment - data_.len - this.receive_pending_chunk_len == 0) { @@ -1041,6 +1039,24 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { } pub fn handleData(this: *WebSocket, socket: Socket, data_: []const u8) void { + // Due to scheduling, it is possible for the websocket onData + // handler to run with additional data before the microtask queue is + // drained. + if (this.initial_data_handler) |initial_handler| { + // This calls `handleData` + // We deliberately do not set this.initial_data_handler to null here, that's done in handleWithoutDeinit. + // We do not free the memory here since the lifetime is managed by the microtask queue (it should free when called from there) + initial_handler.handleWithoutDeinit(); + + // handleWithoutDeinit is supposed to clear the handler from WebSocket* + // to prevent an infinite loop + std.debug.assert(this.initial_data_handler == null); + + // If we disconnected for any reason in the re-entrant case, we should just ignore the data + if (this.outgoing_websocket == null or this.tcp.isShutdown() or this.tcp.isClosed()) + return; + } + var data = data_; var receive_state = this.receive_state; var terminated = false; @@ -1138,6 +1154,30 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { terminated = true; break; } + + // Handle when the payload length is 0, but it is a message + // + // This should become + // + // - ArrayBuffer(0) + // - "" + // - Buffer(0) (etc) + // + if (receive_body_remain == 0 and receive_state == .need_body and is_final) { + _ = this.consume( + "", + receive_body_remain, + last_receive_data_type, + is_final, + ); + + // Return to the header state to read the next frame + receive_state = .need_header; + is_fragmented = false; + + // Bail out if there's nothing left to read + if (data.len == 0) break; + } }, .need_mask => { this.terminate(.unexpected_mask_from_server); @@ -1177,10 +1217,10 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { .ping => { const ping_len = @min(data.len, @min(receive_body_remain, 125)); - this.ping_len = @truncate(u8, ping_len); + this.ping_len = ping_len; if (ping_len > 0) { - @memcpy(this.ping_frame_bytes[6..], data.ptr, ping_len); + @memcpy(this.ping_frame_bytes[6..][0..ping_len], data[0..ping_len]); data = data[ping_len..]; } @@ -1198,6 +1238,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { if (data.len == 0) break; }, .need_body => { + // Empty messages are valid, but we handle that earlier in the flow. if (receive_body_remain == 0 and data.len > 0) { this.terminate(ErrorCode.expected_control_frame); terminated = true; @@ -1379,7 +1420,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { std.mem.writeIntSliceBig(u16, final_body_bytes[6..8], code); if (body) |data| { - if (body_len > 0) @memcpy(final_body_bytes[8..], data, body_len); + if (body_len > 0) @memcpy(final_body_bytes[8..][0..body_len], data[0..body_len]); } // we must mask the code @@ -1431,9 +1472,6 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { return; } - if (len == 0) - return; - const slice = ptr[0..len]; const bytes = Copy{ .bytes = slice }; // fast path: small frame, no backpressure, attempt to send without allocating @@ -1457,9 +1495,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { return; } - if (str.len == 0) { - return; - } + // Note: 0 is valid { var inline_buf: [stack_frame_size]u8 = undefined; @@ -1467,9 +1503,10 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { // fast path: small frame, no backpressure, attempt to send without allocating if (!str.is16Bit() and str.len < stack_frame_size) { const bytes = Copy{ .latin1 = str.slice() }; - const frame_size = WebsocketHeader.frameSizeIncludingMask(str.len); + var byte_len: usize = 0; + const frame_size = bytes.len(&byte_len); if (!this.hasBackpressure() and frame_size < stack_frame_size) { - bytes.copy(this.globalThis, inline_buf[0..frame_size], str.len); + bytes.copy(this.globalThis, inline_buf[0..frame_size], byte_len); _ = this.enqueueEncodedBytes(this.tcp, inline_buf[0..frame_size]); return; } @@ -1521,6 +1558,33 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { this.sendCloseWithBody(this.tcp, code, null, 0); } + const InitialDataHandler = struct { + adopted: ?*WebSocket, + ws: *CppWebSocket, + slice: []u8, + + pub const Handle = JSC.AnyTask.New(@This(), handle); + + pub fn handleWithoutDeinit(this: *@This()) void { + var this_socket = this.adopted orelse return; + this.adopted = null; + this_socket.initial_data_handler = null; + var ws = this.ws; + defer ws.unref(); + + if (this_socket.outgoing_websocket != null) + this_socket.handleData(this_socket.tcp, this.slice); + } + + pub fn handle(this: *@This()) void { + defer { + bun.default_allocator.free(this.slice); + bun.default_allocator.destroy(this); + } + this.handleWithoutDeinit(); + } + }; + pub fn init( outgoing: *CppWebSocket, input_socket: *anyopaque, @@ -1550,33 +1614,19 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { var buffered_slice: []u8 = buffered_data[0..buffered_data_len]; if (buffered_slice.len > 0) { - const InitialDataHandler = struct { - adopted: *WebSocket, - slice: []u8, - task: JSC.AnyTask = undefined, - - pub const Handle = JSC.AnyTask.New(@This(), handle); - - pub fn handle(this: *@This()) void { - defer { - bun.default_allocator.free(this.slice); - bun.default_allocator.destroy(this); - } - - this.adopted.receive_buffer.ensureUnusedCapacity(this.slice.len) catch return; - var writable = this.adopted.receive_buffer.writableSlice(0); - @memcpy(writable.ptr, this.slice.ptr, this.slice.len); - - this.adopted.handleData(this.adopted.tcp, writable); - } - }; var initial_data = bun.default_allocator.create(InitialDataHandler) catch unreachable; initial_data.* = .{ .adopted = adopted, .slice = buffered_slice, + .ws = outgoing, }; - initial_data.task = InitialDataHandler.Handle.init(initial_data); - globalThis.bunVM().eventLoop().enqueueTask(JSC.Task.init(&initial_data.task)); + + // Use a higher-priority callback for the initial onData handler + globalThis.queueMicrotaskCallback(initial_data, InitialDataHandler.handle); + + // We need to ref the outgoing websocket so that it doesn't get finalized + // before the initial data handler is called + outgoing.ref(); } return @ptrCast( *anyopaque, |