// This code is based on https://github.com/frmdstryr/zhp/blob/a4b5700c289c3619647206144e10fb414113a888/src/websocket.zig // Thank you @frmdstryr. const std = @import("std"); const native_endian = @import("builtin").target.cpu.arch.endian(); const bun = @import("root").bun; const string = bun.string; const Output = bun.Output; const Global = bun.Global; const Environment = bun.Environment; const strings = bun.strings; const MutableString = bun.MutableString; const stringZ = bun.stringZ; const default_allocator = bun.default_allocator; const C = bun.C; const uws = @import("root").bun.uws; const JSC = @import("root").bun.JSC; const PicoHTTP = @import("root").bun.picohttp; const ObjectPool = @import("../pool.zig").ObjectPool; const WebsocketHeader = @import("./websocket.zig").WebsocketHeader; const WebsocketDataFrame = @import("./websocket.zig").WebsocketDataFrame; const Opcode = @import("./websocket.zig").Opcode; const log = Output.scoped(.WebSocketClient, false); const NonUTF8Headers = struct { names: []const JSC.ZigString, values: []const JSC.ZigString, pub fn format(self: NonUTF8Headers, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { const count = self.names.len; var i: usize = 0; while (i < count) : (i += 1) { try std.fmt.format(writer, "{any}: {any}\r\n", .{ self.names[i], self.values[i] }); } } pub fn init(names: ?[*]const JSC.ZigString, values: ?[*]const JSC.ZigString, len: usize) NonUTF8Headers { if (len == 0) { return .{ .names = &[_]JSC.ZigString{}, .values = &[_]JSC.ZigString{}, }; } return .{ .names = names.?[0..len], .values = values.?[0..len], }; } }; fn buildRequestBody( vm: *JSC.VirtualMachine, pathname: *const JSC.ZigString, host: *const JSC.ZigString, client_protocol: *const JSC.ZigString, client_protocol_hash: *u64, extra_headers: NonUTF8Headers, ) std.mem.Allocator.Error![]u8 { const allocator = vm.allocator; const input_rand_buf = vm.rareData().nextUUID().bytes; const temp_buf_size = comptime std.base64.standard.Encoder.calcSize(16); var encoded_buf: [temp_buf_size]u8 = undefined; const accept_key = std.base64.standard.Encoder.encode(&encoded_buf, &input_rand_buf); var static_headers = [_]PicoHTTP.Header{ .{ .name = "Sec-WebSocket-Key", .value = accept_key, }, .{ .name = "Sec-WebSocket-Protocol", .value = client_protocol.slice(), }, }; if (client_protocol.len > 0) client_protocol_hash.* = bun.hash(static_headers[1].value); const pathname_ = pathname.toSlice(allocator); const host_ = host.toSlice(allocator); defer { pathname_.deinit(); host_.deinit(); } const headers_ = static_headers[0 .. 1 + @as(usize, @intFromBool(client_protocol.len > 0))]; const pico_headers = PicoHTTP.Headers{ .headers = headers_ }; return try std.fmt.allocPrint( allocator, "GET {s} HTTP/1.1\r\n" ++ "Host: {s}\r\n" ++ "Pragma: no-cache\r\n" ++ "Cache-Control: no-cache\r\n" ++ "Connection: Upgrade\r\n" ++ "Upgrade: websocket\r\n" ++ "Sec-WebSocket-Version: 13\r\n" ++ "{s}" ++ "{s}" ++ "\r\n", .{ pathname_.slice(), host_.slice(), pico_headers, extra_headers }, ); } const ErrorCode = enum(i32) { cancel, invalid_response, expected_101_status_code, missing_upgrade_header, missing_connection_header, missing_websocket_accept_header, invalid_upgrade_header, invalid_connection_header, invalid_websocket_version, mismatch_websocket_accept_header, missing_client_protocol, mismatch_client_protocol, timeout, closed, failed_to_write, failed_to_connect, headers_too_large, ended, failed_to_allocate_memory, control_frame_is_fragmented, invalid_control_frame, compression_unsupported, unexpected_mask_from_server, expected_control_frame, unsupported_control_frame, unexpected_opcode, invalid_utf8, }; const CppWebSocket = opaque { extern fn WebSocket__didConnect( websocket_context: *CppWebSocket, socket: *uws.Socket, buffered_data: ?[*]u8, buffered_len: usize, ) void; extern fn WebSocket__didAbruptClose(websocket_context: *CppWebSocket, reason: ErrorCode) void; extern fn WebSocket__didClose(websocket_context: *CppWebSocket, code: u16, reason: *const bun.String) void; extern fn WebSocket__didReceiveText(websocket_context: *CppWebSocket, clone: bool, text: *const JSC.ZigString) void; extern fn WebSocket__didReceiveBytes(websocket_context: *CppWebSocket, bytes: [*]const u8, byte_len: usize, opcode: u8) void; pub const didConnect = WebSocket__didConnect; pub const didAbruptClose = WebSocket__didAbruptClose; pub const didClose = WebSocket__didClose; pub const didReceiveText = WebSocket__didReceiveText; pub const didReceiveBytes = WebSocket__didReceiveBytes; extern fn WebSocket__incrementPendingActivity(websocket_context: *CppWebSocket) void; extern fn WebSocket__decrementPendingActivity(websocket_context: *CppWebSocket) void; pub fn ref(this: *CppWebSocket) void { JSC.markBinding(@src()); WebSocket__incrementPendingActivity(this); } pub fn unref(this: *CppWebSocket) void { JSC.markBinding(@src()); WebSocket__decrementPendingActivity(this); } }; const body_buf_len = 16384 - 16; const BodyBufBytes = [body_buf_len]u8; const BodyBufPool = ObjectPool(BodyBufBytes, null, true, 4); const BodyBuf = BodyBufPool.Node; pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { return struct { pub const Socket = uws.NewSocketHandler(ssl); tcp: Socket, outgoing_websocket: ?*CppWebSocket, input_body_buf: []u8 = &[_]u8{}, client_protocol: []const u8 = "", to_send: []const u8 = "", read_length: usize = 0, headers_buf: [128]PicoHTTP.Header = undefined, body: std.ArrayListUnmanaged(u8) = .{}, websocket_protocol: u64 = 0, hostname: [:0]const u8 = "", poll_ref: JSC.PollRef = .{}, pub const name = if (ssl) "WebSocketHTTPSClient" else "WebSocketHTTPClient"; pub const shim = JSC.Shimmer("Bun", name, @This()); const HTTPClient = @This(); pub fn register(global: *JSC.JSGlobalObject, loop_: *anyopaque, ctx_: *anyopaque) callconv(.C) void { var vm = global.bunVM(); var loop = @as(*uws.Loop, @ptrCast(@alignCast(loop_))); var ctx: *uws.SocketContext = @as(*uws.SocketContext, @ptrCast(ctx_)); if (vm.event_loop_handle) |other| { std.debug.assert(other == loop); } const is_new_loop = vm.event_loop_handle == null; vm.event_loop_handle = loop; Socket.configure( ctx, false, HTTPClient, struct { pub const onOpen = handleOpen; pub const onClose = handleClose; pub const onData = handleData; pub const onWritable = handleWritable; pub const onTimeout = handleTimeout; pub const onConnectError = handleConnectError; pub const onEnd = handleEnd; pub const onHandshake = handleHandshake; }, ); if (is_new_loop) { vm.prepareLoop(); } } pub fn connect( global: *JSC.JSGlobalObject, socket_ctx: *anyopaque, websocket: *CppWebSocket, host: *const JSC.ZigString, port: u16, pathname: *const JSC.ZigString, client_protocol: *const JSC.ZigString, header_names: ?[*]const JSC.ZigString, header_values: ?[*]const JSC.ZigString, header_count: usize, ) callconv(.C) ?*HTTPClient { std.debug.assert(global.bunVM().event_loop_handle != null); var client_protocol_hash: u64 = 0; var body = buildRequestBody( global.bunVM(), pathname, host, client_protocol, &client_protocol_hash, NonUTF8Headers.init(header_names, header_values, header_count), ) catch return null; var client: HTTPClient = HTTPClient{ .tcp = undefined, .outgoing_websocket = websocket, .input_body_buf = body, .websocket_protocol = client_protocol_hash, }; var host_ = host.toSlice(bun.default_allocator); defer host_.deinit(); var vm = global.bunVM(); const prev_start_server_on_next_tick = vm.eventLoop().start_server_on_next_tick; vm.eventLoop().start_server_on_next_tick = true; client.poll_ref.ref(vm); const display_host_ = host_.slice(); const display_host = if (bun.FeatureFlags.hardcode_localhost_to_127_0_0_1 and strings.eqlComptime(display_host_, "localhost")) "127.0.0.1" else display_host_; if (Socket.connect( display_host, port, @as(*uws.SocketContext, @ptrCast(socket_ctx)), HTTPClient, client, "tcp", )) |out| { if (comptime ssl) { if (!strings.isIPAddress(host_.slice())) { out.hostname = bun.default_allocator.dupeZ(u8, host_.slice()) catch ""; } } out.tcp.timeout(120); return out; } vm.eventLoop().start_server_on_next_tick = prev_start_server_on_next_tick; client.clearData(); return null; } pub fn clearInput(this: *HTTPClient) void { if (this.input_body_buf.len > 0) bun.default_allocator.free(this.input_body_buf); this.input_body_buf.len = 0; } pub fn clearData(this: *HTTPClient) void { this.poll_ref.unrefOnNextTick(JSC.VirtualMachine.get()); this.clearInput(); this.body.clearAndFree(bun.default_allocator); } pub fn cancel(this: *HTTPClient) callconv(.C) void { this.clearData(); if (!this.tcp.isEstablished()) { _ = uws.us_socket_close_connecting(comptime @as(c_int, @intFromBool(ssl)), this.tcp.socket); } else { this.tcp.close(0, null); } } pub fn fail(this: *HTTPClient, code: ErrorCode) void { log("onFail", .{}); JSC.markBinding(@src()); if (this.outgoing_websocket) |ws| { this.outgoing_websocket = null; ws.didAbruptClose(code); } this.cancel(); } pub fn handleClose(this: *HTTPClient, _: Socket, _: c_int, _: ?*anyopaque) void { log("onClose", .{}); JSC.markBinding(@src()); this.clearData(); if (this.outgoing_websocket) |ws| { this.outgoing_websocket = null; ws.didAbruptClose(ErrorCode.ended); } } pub fn terminate(this: *HTTPClient, code: ErrorCode) void { this.fail(code); if (!this.tcp.isClosed()) this.tcp.close(0, null); } pub fn handleHandshake(this: *HTTPClient, socket: Socket, success: i32, ssl_error: uws.us_bun_verify_error_t) void { _ = socket; _ = ssl_error; log("onHandshake({d})", .{success}); if (success == 0) { this.fail(ErrorCode.failed_to_connect); } } pub fn handleOpen(this: *HTTPClient, socket: Socket) void { log("onOpen", .{}); std.debug.assert(socket.socket == this.tcp.socket); std.debug.assert(this.input_body_buf.len > 0); std.debug.assert(this.to_send.len == 0); if (comptime ssl) { if (this.hostname.len > 0) { socket.getNativeHandle().configureHTTPClient(this.hostname); bun.default_allocator.free(this.hostname); this.hostname = ""; } } // Do not set MSG_MORE, see https://github.com/oven-sh/bun/issues/4010 const wrote = socket.write(this.input_body_buf, false); if (wrote < 0) { this.terminate(ErrorCode.failed_to_write); return; } this.to_send = this.input_body_buf[@as(usize, @intCast(wrote))..]; } pub fn handleData(this: *HTTPClient, socket: Socket, data: []const u8) void { log("onData", .{}); defer JSC.VirtualMachine.get().drainMicrotasks(); std.debug.assert(socket.socket == this.tcp.socket); if (this.outgoing_websocket == null) { this.clearData(); return; } if (comptime Environment.allow_assert) std.debug.assert(!socket.isShutdown()); var body = data; if (this.body.items.len > 0) { this.body.appendSlice(bun.default_allocator, data) catch @panic("out of memory"); body = this.body.items; } const is_first = this.body.items.len == 0; if (is_first) { // fail early if we receive a non-101 status code if (!strings.hasPrefixComptime(body, "HTTP/1.1 101 ")) { this.terminate(ErrorCode.expected_101_status_code); return; } } const response = PicoHTTP.Response.parse(body, &this.headers_buf) catch |err| { switch (err) { error.Malformed_HTTP_Response => { this.terminate(ErrorCode.invalid_response); return; }, error.ShortRead => { if (this.body.items.len == 0) { this.body.appendSlice(bun.default_allocator, data) catch @panic("out of memory"); } return; }, } }; this.processResponse(response, body[@as(usize, @intCast(response.bytes_read))..]); } pub fn handleEnd(this: *HTTPClient, socket: Socket) void { log("onEnd", .{}); std.debug.assert(socket.socket == this.tcp.socket); this.terminate(ErrorCode.ended); } pub fn processResponse(this: *HTTPClient, response: PicoHTTP.Response, remain_buf: []const u8) void { var upgrade_header = PicoHTTP.Header{ .name = "", .value = "" }; var connection_header = PicoHTTP.Header{ .name = "", .value = "" }; var websocket_accept_header = PicoHTTP.Header{ .name = "", .value = "" }; var visited_protocol = this.websocket_protocol == 0; // var visited_version = false; std.debug.assert(response.status_code == 101); for (response.headers) |header| { switch (header.name.len) { "Connection".len => { if (connection_header.name.len == 0 and strings.eqlCaseInsensitiveASCII(header.name, "Connection", false)) { connection_header = header; if (visited_protocol and upgrade_header.name.len > 0 and connection_header.name.len > 0 and websocket_accept_header.name.len > 0) { break; } } }, "Upgrade".len => { if (upgrade_header.name.len == 0 and strings.eqlCaseInsensitiveASCII(header.name, "Upgrade", false)) { upgrade_header = header; if (visited_protocol and upgrade_header.name.len > 0 and connection_header.name.len > 0 and websocket_accept_header.name.len > 0) { break; } } }, "Sec-WebSocket-Version".len => { if (strings.eqlCaseInsensitiveASCII(header.name, "Sec-WebSocket-Version", false)) { if (!strings.eqlComptimeIgnoreLen(header.value, "13")) { this.terminate(ErrorCode.invalid_websocket_version); return; } } }, "Sec-WebSocket-Accept".len => { if (websocket_accept_header.name.len == 0 and strings.eqlCaseInsensitiveASCII(header.name, "Sec-WebSocket-Accept", false)) { websocket_accept_header = header; if (visited_protocol and upgrade_header.name.len > 0 and connection_header.name.len > 0 and websocket_accept_header.name.len > 0) { break; } } }, "Sec-WebSocket-Protocol".len => { if (strings.eqlCaseInsensitiveASCII(header.name, "Sec-WebSocket-Protocol", false)) { if (this.websocket_protocol == 0 or bun.hash(header.value) != this.websocket_protocol) { this.terminate(ErrorCode.mismatch_client_protocol); return; } visited_protocol = true; if (visited_protocol and upgrade_header.name.len > 0 and connection_header.name.len > 0 and websocket_accept_header.name.len > 0) { break; } } }, else => {}, } } // if (!visited_version) { // this.terminate(ErrorCode.invalid_websocket_version); // return; // } if (@min(upgrade_header.name.len, upgrade_header.value.len) == 0) { this.terminate(ErrorCode.missing_upgrade_header); return; } if (@min(connection_header.name.len, connection_header.value.len) == 0) { this.terminate(ErrorCode.missing_connection_header); return; } if (@min(websocket_accept_header.name.len, websocket_accept_header.value.len) == 0) { this.terminate(ErrorCode.missing_websocket_accept_header); return; } if (!visited_protocol) { this.terminate(ErrorCode.mismatch_client_protocol); return; } if (!strings.eqlCaseInsensitiveASCII(connection_header.value, "Upgrade", true)) { this.terminate(ErrorCode.invalid_connection_header); return; } if (!strings.eqlCaseInsensitiveASCII(upgrade_header.value, "websocket", true)) { this.terminate(ErrorCode.invalid_upgrade_header); return; } // TODO: check websocket_accept_header.value const overflow_len = remain_buf.len; var overflow: []u8 = &.{}; if (overflow_len > 0) { overflow = bun.default_allocator.alloc(u8, overflow_len) catch { this.terminate(ErrorCode.invalid_response); return; }; @memcpy(overflow, remain_buf); } this.clearData(); JSC.markBinding(@src()); this.tcp.timeout(0); log("onDidConnect", .{}); this.outgoing_websocket.?.didConnect(this.tcp.socket, overflow.ptr, overflow.len); } pub fn handleWritable( this: *HTTPClient, socket: Socket, ) void { std.debug.assert(socket.socket == this.tcp.socket); if (this.to_send.len == 0) return; // Do not set MSG_MORE, see https://github.com/oven-sh/bun/issues/4010 const wrote = socket.write(this.to_send, false); if (wrote < 0) { this.terminate(ErrorCode.failed_to_write); return; } this.to_send = this.to_send[@min(@as(usize, @intCast(wrote)), this.to_send.len)..]; } pub fn handleTimeout( this: *HTTPClient, _: Socket, ) void { this.terminate(ErrorCode.timeout); } pub fn handleConnectError(this: *HTTPClient, _: Socket, _: c_int) void { this.terminate(ErrorCode.failed_to_connect); } pub const Export = shim.exportFunctions(.{ .connect = connect, .cancel = cancel, .register = register, }); comptime { if (!JSC.is_bindgen) { @export(connect, .{ .name = Export[0].symbol_name, }); @export(cancel, .{ .name = Export[1].symbol_name, }); @export(register, .{ .name = Export[2].symbol_name, }); } } }; } pub const Mask = struct { pub fn fill(globalThis: *JSC.JSGlobalObject, mask_buf: *[4]u8, output_: []u8, input_: []const u8) void { mask_buf.* = globalThis.bunVM().rareData().entropySlice(4)[0..4].*; const mask = mask_buf.*; const skip_mask = @as(u32, @bitCast(mask)) == 0; if (!skip_mask) { fillWithSkipMask(mask, output_, input_, false); } else { fillWithSkipMask(mask, output_, input_, true); } } fn fillWithSkipMask(mask: [4]u8, output_: []u8, input_: []const u8, comptime skip_mask: bool) void { var input = input_; var output = output_; if (comptime Environment.enableSIMD) { if (input.len >= strings.ascii_vector_size) { const vec: strings.AsciiVector = brk: { var in: [strings.ascii_vector_size]u8 = undefined; comptime var i: usize = 0; inline while (i < strings.ascii_vector_size) : (i += 4) { in[i..][0..4].* = mask; } break :brk @as(strings.AsciiVector, in); }; const end_ptr_wrapped_to_last_16 = input.ptr + input.len - (input.len % strings.ascii_vector_size); if (comptime skip_mask) { while (input.ptr != end_ptr_wrapped_to_last_16) { const input_vec: strings.AsciiVector = @as(strings.AsciiVector, input[0..strings.ascii_vector_size].*); output.ptr[0..strings.ascii_vector_size].* = input_vec; output = output[strings.ascii_vector_size..]; input = input[strings.ascii_vector_size..]; } } else { while (input.ptr != end_ptr_wrapped_to_last_16) { const input_vec: strings.AsciiVector = @as(strings.AsciiVector, input[0..strings.ascii_vector_size].*); output.ptr[0..strings.ascii_vector_size].* = input_vec ^ vec; output = output[strings.ascii_vector_size..]; input = input[strings.ascii_vector_size..]; } } } // hint to the compiler not to vectorize the next loop std.debug.assert(input.len < strings.ascii_vector_size); } if (comptime !skip_mask) { while (input.len >= 4) { const input_vec: [4]u8 = input[0..4].*; output.ptr[0..4].* = [4]u8{ input_vec[0] ^ mask[0], input_vec[1] ^ mask[1], input_vec[2] ^ mask[2], input_vec[3] ^ mask[3], }; output = output[4..]; input = input[4..]; } } else { while (input.len >= 4) { const input_vec: [4]u8 = input[0..4].*; output.ptr[0..4].* = input_vec; output = output[4..]; input = input[4..]; } } if (comptime !skip_mask) { for (input, 0..) |c, i| { output[i] = c ^ mask[i % 4]; } } else { for (input, 0..) |c, i| { output[i] = c; } } } }; const ReceiveState = enum { need_header, need_mask, need_body, extended_payload_length_16, extended_payload_length_64, ping, pong, close, fail, pub fn needControlFrame(this: ReceiveState) bool { return this != .need_body; } }; const DataType = enum { none, text, binary, }; fn parseWebSocketHeader( bytes: [2]u8, receiving_type: *Opcode, payload_length: *usize, is_fragmented: *bool, is_final: *bool, need_compression: *bool, ) ReceiveState { // 0 1 2 3 // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 // +-+-+-+-+-------+-+-------------+-------------------------------+ // |F|R|R|R| opcode|M| Payload len | Extended payload length | // |I|S|S|S| (4) |A| (7) | (16/64) | // |N|V|V|V| |S| | (if payload len==126/127) | // | |1|2|3| |K| | | // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + // | Extended payload length continued, if payload len == 127 | // + - - - - - - - - - - - - - - - +-------------------------------+ // | |Masking-key, if MASK set to 1 | // +-------------------------------+-------------------------------+ // | Masking-key (continued) | Payload Data | // +-------------------------------- - - - - - - - - - - - - - - - + // : Payload Data continued ... : // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + // | Payload Data continued ... | // +---------------------------------------------------------------+ const header = @as(WebsocketHeader, @bitCast(@byteSwap(@as(u16, @bitCast(bytes))))); const payload = @as(usize, header.len); payload_length.* = payload; receiving_type.* = header.opcode; is_fragmented.* = switch (header.opcode) { .Continue => true, else => false, } or !header.final; is_final.* = header.final; need_compression.* = header.compressed; if (header.mask and (header.opcode == .Text or header.opcode == .Binary)) { return .need_mask; } return switch (header.opcode) { .Text, .Continue, .Binary => if (payload <= 125) return .need_body else if (payload == 126) return .extended_payload_length_16 else if (payload == 127) return .extended_payload_length_64 else return .fail, .Close => .close, .Ping => .ping, .Pong => .pong, else => .fail, }; } const Copy = union(enum) { utf16: []const u16, latin1: []const u8, bytes: []const u8, raw: []const u8, pub fn len(this: @This(), byte_len: *usize) usize { switch (this) { .utf16 => { byte_len.* = strings.elementLengthUTF16IntoUTF8([]const u16, this.utf16); return WebsocketHeader.frameSizeIncludingMask(byte_len.*); }, .latin1 => { byte_len.* = strings.elementLengthLatin1IntoUTF8([]const u8, this.latin1); return WebsocketHeader.frameSizeIncludingMask(byte_len.*); }, .bytes => { byte_len.* = this.bytes.len; return WebsocketHeader.frameSizeIncludingMask(byte_len.*); }, .raw => { byte_len.* = this.raw.len; return this.raw.len; }, } } pub fn copy(this: @This(), globalThis: *JSC.JSGlobalObject, buf: []u8, content_byte_len: usize, opcode: Opcode) void { if (this == .raw) { std.debug.assert(buf.len >= this.raw.len); std.debug.assert(buf.ptr != this.raw.ptr); @memcpy(buf[0..this.raw.len], this.raw); return; } const how_big_is_the_length_integer = WebsocketHeader.lengthByteCount(content_byte_len); const how_big_is_the_mask = 4; const mask_offset = 2 + how_big_is_the_length_integer; const content_offset = mask_offset + how_big_is_the_mask; // 2 byte header // 4 byte mask // 0, 2, 8 byte length var to_mask = buf[content_offset..]; var header = @as(WebsocketHeader, @bitCast(@as(u16, 0))); // Write extended length if needed switch (how_big_is_the_length_integer) { 0 => {}, 2 => std.mem.writeIntBig(u16, buf[2..][0..2], @as(u16, @truncate(content_byte_len))), 8 => std.mem.writeIntBig(u64, buf[2..][0..8], @as(u64, @truncate(content_byte_len))), else => unreachable, } header.mask = true; header.compressed = false; header.final = true; header.opcode = opcode; std.debug.assert(WebsocketHeader.frameSizeIncludingMask(content_byte_len) == buf.len); switch (this) { .utf16 => |utf16| { header.len = WebsocketHeader.packLength(content_byte_len); const encode_into_result = strings.copyUTF16IntoUTF8(to_mask, []const u16, utf16, true); std.debug.assert(@as(usize, encode_into_result.written) == content_byte_len); std.debug.assert(@as(usize, encode_into_result.read) == utf16.len); header.len = WebsocketHeader.packLength(encode_into_result.written); var fib = std.io.fixedBufferStream(buf); header.writeHeader(fib.writer(), encode_into_result.written) catch unreachable; Mask.fill(globalThis, buf[mask_offset..][0..4], to_mask[0..content_byte_len], to_mask[0..content_byte_len]); }, .latin1 => |latin1| { const encode_into_result = strings.copyLatin1IntoUTF8(to_mask, []const u8, latin1); std.debug.assert(@as(usize, encode_into_result.written) == content_byte_len); // latin1 can contain non-ascii std.debug.assert(@as(usize, encode_into_result.read) == latin1.len); header.len = WebsocketHeader.packLength(encode_into_result.written); var fib = std.io.fixedBufferStream(buf); header.writeHeader(fib.writer(), encode_into_result.written) catch unreachable; Mask.fill(globalThis, buf[mask_offset..][0..4], to_mask[0..content_byte_len], to_mask[0..content_byte_len]); }, .bytes => |bytes| { header.len = WebsocketHeader.packLength(bytes.len); var fib = std.io.fixedBufferStream(buf); header.writeHeader(fib.writer(), bytes.len) catch unreachable; Mask.fill(globalThis, buf[mask_offset..][0..4], to_mask[0..content_byte_len], bytes); }, .raw => unreachable, } } }; pub fn NewWebSocketClient(comptime ssl: bool) type { return struct { pub const Socket = uws.NewSocketHandler(ssl); tcp: Socket, outgoing_websocket: ?*CppWebSocket = null, receive_state: ReceiveState = ReceiveState.need_header, receive_header: WebsocketHeader = @as(WebsocketHeader, @bitCast(@as(u16, 0))), receiving_type: Opcode = Opcode.ResB, ping_frame_bytes: [128 + 6]u8 = [_]u8{0} ** (128 + 6), ping_len: u8 = 0, ping_received: bool = false, receive_frame: usize = 0, receive_body_remain: usize = 0, receive_pending_chunk_len: usize = 0, receive_buffer: bun.LinearFifo(u8, .Dynamic), send_buffer: bun.LinearFifo(u8, .Dynamic), globalThis: *JSC.JSGlobalObject, poll_ref: JSC.PollRef = JSC.PollRef.init(), initial_data_handler: ?*InitialDataHandler = null, pub const name = if (ssl) "WebSocketClientTLS" else "WebSocketClient"; pub const shim = JSC.Shimmer("Bun", name, @This()); const stack_frame_size = 1024; const WebSocket = @This(); pub fn register(global: *JSC.JSGlobalObject, loop_: *anyopaque, ctx_: *anyopaque) callconv(.C) void { var vm = global.bunVM(); var loop = @as(*uws.Loop, @ptrCast(@alignCast(loop_))); var ctx: *uws.SocketContext = @as(*uws.SocketContext, @ptrCast(ctx_)); if (vm.event_loop_handle) |other| { std.debug.assert(other == loop); } vm.event_loop_handle = loop; Socket.configure( ctx, false, WebSocket, struct { pub const onClose = handleClose; pub const onData = handleData; pub const onWritable = handleWritable; pub const onTimeout = handleTimeout; pub const onConnectError = handleConnectError; pub const onEnd = handleEnd; // just by adding it will fix ssl handshake pub const onHandshake = handleHandshake; }, ); } pub fn clearData(this: *WebSocket) void { this.poll_ref.unrefOnNextTick(this.globalThis.bunVM()); this.clearReceiveBuffers(true); this.clearSendBuffers(true); this.ping_received = false; this.ping_len = 0; this.receive_pending_chunk_len = 0; } pub fn cancel(this: *WebSocket) callconv(.C) void { this.clearData(); if (this.tcp.isClosed() or this.tcp.isShutdown()) return; if (!this.tcp.isEstablished()) { _ = uws.us_socket_close_connecting(comptime @as(c_int, @intFromBool(ssl)), this.tcp.socket); } else { this.tcp.close(0, null); } } pub fn fail(this: *WebSocket, code: ErrorCode) void { JSC.markBinding(@src()); if (this.outgoing_websocket) |ws| { this.outgoing_websocket = null; log("fail ({s})", .{@tagName(code)}); ws.didAbruptClose(code); } this.cancel(); } pub fn handleHandshake(this: *WebSocket, socket: Socket, success: i32, ssl_error: uws.us_bun_verify_error_t) void { JSC.markBinding(@src()); _ = socket; _ = ssl_error; JSC.markBinding(@src()); log("onHandshake({d})", .{success}); JSC.markBinding(@src()); if (success == 0) { if (this.outgoing_websocket) |ws| { this.outgoing_websocket = null; ws.didAbruptClose(ErrorCode.failed_to_connect); } } } pub fn handleClose(this: *WebSocket, _: Socket, _: c_int, _: ?*anyopaque) void { log("onClose", .{}); JSC.markBinding(@src()); this.clearData(); if (this.outgoing_websocket) |ws| { this.outgoing_websocket = null; ws.didAbruptClose(ErrorCode.ended); } } pub fn terminate(this: *WebSocket, code: ErrorCode) void { this.fail(code); } fn clearReceiveBuffers(this: *WebSocket, free: bool) void { this.receive_buffer.head = 0; this.receive_buffer.count = 0; if (free) { this.receive_buffer.deinit(); this.receive_buffer.buf.len = 0; } this.receive_pending_chunk_len = 0; this.receive_body_remain = 0; } fn clearSendBuffers(this: *WebSocket, free: bool) void { this.send_buffer.head = 0; this.send_buffer.count = 0; if (free) { this.send_buffer.deinit(); this.send_buffer.buf.len = 0; } } fn dispatchData(this: *WebSocket, data_: []const u8, kind: Opcode) void { var out = this.outgoing_websocket orelse { this.clearData(); return; }; switch (kind) { .Text => { // this function encodes to UTF-16 if > 127 // so we don't need to worry about latin1 non-ascii code points const utf16_bytes_ = strings.toUTF16Alloc(bun.default_allocator, data_, true) catch { this.terminate(ErrorCode.invalid_utf8); return; }; var outstring = JSC.ZigString.Empty; if (utf16_bytes_) |utf16| { outstring = JSC.ZigString.from16Slice(utf16); outstring.mark(); JSC.markBinding(@src()); out.didReceiveText(false, &outstring); } else { outstring = JSC.ZigString.init(data_); JSC.markBinding(@src()); out.didReceiveText(true, &outstring); } }, .Binary, .Ping, .Pong => { JSC.markBinding(@src()); out.didReceiveBytes(data_.ptr, data_.len, @as(u8, @intFromEnum(kind))); }, else => unreachable, } } pub fn consume(this: *WebSocket, data_: []const u8, left_in_fragment: usize, kind: Opcode, is_final: bool) usize { std.debug.assert(data_.len <= left_in_fragment); // did all the data fit in the buffer? // we can avoid copying & allocating a temporary buffer if (is_final and data_.len == left_in_fragment and this.receive_pending_chunk_len == 0) { this.dispatchData(data_, kind); return data_.len; } // this must come after the above check std.debug.assert(data_.len > 0); var writable = this.receive_buffer.writableWithSize(data_.len) catch unreachable; @memcpy(writable[0..data_.len], data_); this.receive_buffer.update(data_.len); if (left_in_fragment >= data_.len and left_in_fragment - data_.len - this.receive_pending_chunk_len == 0) { this.receive_pending_chunk_len = 0; this.dispatchData(this.receive_buffer.readableSlice(0), kind); this.clearReceiveBuffers(false); } else { this.receive_pending_chunk_len -|= left_in_fragment; } return data_.len; } pub fn handleData(this: *WebSocket, socket: Socket, data_: []const u8) void { // This is the start of a task, so we need to drain the microtask queue at the end defer JSC.VirtualMachine.get().drainMicrotasks(); // Due to scheduling, it is possible for the websocket onData // handler to run with additional data before the microtask queue is // drained. if (this.initial_data_handler) |initial_handler| { // This calls `handleData` // We deliberately do not set this.initial_data_handler to null here, that's done in handleWithoutDeinit. // We do not free the memory here since the lifetime is managed by the microtask queue (it should free when called from there) initial_handler.handleWithoutDeinit(); // handleWithoutDeinit is supposed to clear the handler from WebSocket* // to prevent an infinite loop std.debug.assert(this.initial_data_handler == null); // If we disconnected for any reason in the re-entrant case, we should just ignore the data if (this.outgoing_websocket == null or this.tcp.isShutdown() or this.tcp.isClosed()) return; } var data = data_; var receive_state = this.receive_state; var terminated = false; var is_fragmented = false; var receiving_type = this.receiving_type; var receive_body_remain = this.receive_body_remain; var is_final = false; var last_receive_data_type = receiving_type; defer { if (!terminated) { this.receive_state = receive_state; this.receiving_type = last_receive_data_type; this.receive_body_remain = receive_body_remain; // if we receive multiple pings in a row // we just send back the last one if (this.ping_received) { _ = this.sendPong(socket); this.ping_received = false; } } } var header_bytes: [@sizeOf(usize)]u8 = [_]u8{0} ** @sizeOf(usize); while (true) { log("onData ({s})", .{@tagName(receive_state)}); switch (receive_state) { // 0 1 2 3 // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 // +-+-+-+-+-------+-+-------------+-------------------------------+ // |F|R|R|R| opcode|M| Payload len | Extended payload length | // |I|S|S|S| (4) |A| (7) | (16/64) | // |N|V|V|V| |S| | (if payload len==126/127) | // | |1|2|3| |K| | | // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + // | Extended payload length continued, if payload len == 127 | // + - - - - - - - - - - - - - - - +-------------------------------+ // | |Masking-key, if MASK set to 1 | // +-------------------------------+-------------------------------+ // | Masking-key (continued) | Payload Data | // +-------------------------------- - - - - - - - - - - - - - - - + // : Payload Data continued ... : // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + // | Payload Data continued ... | // +---------------------------------------------------------------+ .need_header => { if (data.len < 2) { this.terminate(ErrorCode.control_frame_is_fragmented); terminated = true; break; } header_bytes[0..2].* = data[0..2].*; receive_body_remain = 0; var need_compression = false; is_final = false; receive_state = parseWebSocketHeader( header_bytes[0..2].*, &receiving_type, &receive_body_remain, &is_fragmented, &is_final, &need_compression, ); last_receive_data_type = if (receiving_type == .Text or receiving_type == .Binary) receiving_type else last_receive_data_type; data = data[2..]; if (receiving_type.isControl() and is_fragmented) { // Control frames must not be fragmented. this.terminate(ErrorCode.control_frame_is_fragmented); terminated = true; break; } switch (receiving_type) { .Continue, .Text, .Binary, .Ping, .Pong, .Close => {}, else => { this.terminate(ErrorCode.unsupported_control_frame); terminated = true; break; }, } if (need_compression) { this.terminate(ErrorCode.compression_unsupported); terminated = true; break; } // Handle when the payload length is 0, but it is a message // // This should become // // - ArrayBuffer(0) // - "" // - Buffer(0) (etc) // if (receive_body_remain == 0 and receive_state == .need_body and is_final) { _ = this.consume( "", receive_body_remain, last_receive_data_type, is_final, ); // Return to the header state to read the next frame receive_state = .need_header; is_fragmented = false; // Bail out if there's nothing left to read if (data.len == 0) break; } }, .need_mask => { this.terminate(.unexpected_mask_from_server); terminated = true; break; }, .extended_payload_length_64, .extended_payload_length_16 => |rc| { const byte_size = switch (rc) { .extended_payload_length_64 => @as(usize, 8), .extended_payload_length_16 => @as(usize, 2), else => unreachable, }; if (data.len < byte_size) { this.terminate(ErrorCode.control_frame_is_fragmented); terminated = true; break; } // Multibyte length quantities are expressed in network byte order receive_body_remain = switch (byte_size) { 8 => @as(usize, std.mem.readIntBig(u64, data[0..8])), 2 => @as(usize, std.mem.readIntBig(u16, data[0..2])), else => unreachable, }; data = data[byte_size..]; receive_state = .need_body; if (receive_body_remain == 0) { // this is an error // the server should've set length to zero this.terminate(ErrorCode.invalid_control_frame); terminated = true; break; } }, .ping => { const ping_len = @min(data.len, @min(receive_body_remain, 125)); this.ping_len = ping_len; this.ping_received = true; this.dispatchData(data[0..ping_len], .Ping); if (ping_len > 0) { @memcpy(this.ping_frame_bytes[6..][0..ping_len], data[0..ping_len]); data = data[ping_len..]; } receive_state = .need_header; receive_body_remain = 0; receiving_type = last_receive_data_type; if (data.len == 0) break; }, .pong => { const pong_len = @min(data.len, @min(receive_body_remain, this.ping_frame_bytes.len)); this.dispatchData(data[0..pong_len], .Pong); data = data[pong_len..]; receive_state = .need_header; receive_body_remain = 0; receiving_type = last_receive_data_type; if (data.len == 0) break; }, .need_body => { // Empty messages are valid, but we handle that earlier in the flow. if (receive_body_remain == 0 and data.len > 0) { this.terminate(ErrorCode.expected_control_frame); terminated = true; break; } if (data.len == 0) return; const to_consume = @min(receive_body_remain, data.len); const consumed = this.consume(data[0..to_consume], receive_body_remain, last_receive_data_type, is_final); if (consumed == 0 and last_receive_data_type == .Text) { this.terminate(ErrorCode.invalid_utf8); terminated = true; break; } receive_body_remain -= consumed; data = data[to_consume..]; if (receive_body_remain == 0) { receive_state = .need_header; is_fragmented = false; } if (data.len == 0) break; }, .close => { // closing frame data is text only. // 2 byte close code if (data.len > 2 and receive_body_remain >= 2) { _ = this.consume(data[2..receive_body_remain], receive_body_remain - 2, .Text, true); data = data[receive_body_remain..]; } this.sendClose(); terminated = true; break; }, .fail => { this.terminate(ErrorCode.unsupported_control_frame); terminated = true; break; }, } } } pub fn sendClose(this: *WebSocket) void { this.sendCloseWithBody(this.tcp, 1001, null, 0); } fn enqueueEncodedBytes( this: *WebSocket, socket: Socket, bytes: []const u8, ) bool { // fast path: no backpressure, no queue, just send the bytes. if (!this.hasBackpressure()) { // Do not set MSG_MORE, see https://github.com/oven-sh/bun/issues/4010 const wrote = socket.write(bytes, false); const expected = @as(c_int, @intCast(bytes.len)); if (wrote == expected) { return true; } if (wrote < 0) { this.terminate(ErrorCode.failed_to_write); return false; } _ = this.copyToSendBuffer(bytes[@as(usize, @intCast(wrote))..], false); return true; } return this.copyToSendBuffer(bytes, true); } fn copyToSendBuffer(this: *WebSocket, bytes: []const u8, do_write: bool) bool { return this.sendData(.{ .raw = bytes }, do_write, .Binary); } fn sendData(this: *WebSocket, bytes: Copy, do_write: bool, opcode: Opcode) bool { var content_byte_len: usize = 0; const write_len = bytes.len(&content_byte_len); std.debug.assert(write_len > 0); var writable = this.send_buffer.writableWithSize(write_len) catch unreachable; bytes.copy(this.globalThis, writable[0..write_len], content_byte_len, opcode); this.send_buffer.update(write_len); if (do_write) { if (comptime Environment.allow_assert) { std.debug.assert(!this.tcp.isShutdown()); std.debug.assert(!this.tcp.isClosed()); std.debug.assert(this.tcp.isEstablished()); } return this.sendBuffer(this.send_buffer.readableSlice(0)); } return true; } fn sendBuffer( this: *WebSocket, out_buf: []const u8, ) bool { std.debug.assert(out_buf.len > 0); // Do not set MSG_MORE, see https://github.com/oven-sh/bun/issues/4010 const wrote = this.tcp.write(out_buf, false); if (wrote < 0) { this.terminate(ErrorCode.failed_to_write); return false; } const expected = @as(usize, @intCast(wrote)); var readable = this.send_buffer.readableSlice(0); if (readable.ptr == out_buf.ptr) { this.send_buffer.discard(expected); } return true; } fn sendPong(this: *WebSocket, socket: Socket) bool { if (socket.isClosed() or socket.isShutdown()) { this.dispatchAbruptClose(); return false; } var header = @as(WebsocketHeader, @bitCast(@as(u16, 0))); header.final = true; header.opcode = .Pong; var to_mask = this.ping_frame_bytes[6..][0..this.ping_len]; header.mask = to_mask.len > 0; header.len = @as(u7, @truncate(this.ping_len)); this.ping_frame_bytes[0..2].* = @as([2]u8, @bitCast(header)); if (to_mask.len > 0) { Mask.fill(this.globalThis, this.ping_frame_bytes[2..6], to_mask, to_mask); return this.enqueueEncodedBytes(socket, this.ping_frame_bytes[0 .. 6 + @as(usize, this.ping_len)]); } else { return this.enqueueEncodedBytes(socket, this.ping_frame_bytes[0..2]); } } fn sendCloseWithBody( this: *WebSocket, socket: Socket, code: u16, body: ?*[125]u8, body_len: usize, ) void { log("Sending close with code {d}", .{code}); if (socket.isClosed() or socket.isShutdown()) { this.dispatchAbruptClose(); this.clearData(); return; } socket.shutdownRead(); var final_body_bytes: [128 + 8]u8 = undefined; var header = @as(WebsocketHeader, @bitCast(@as(u16, 0))); header.final = true; header.opcode = .Close; header.mask = true; header.len = @as(u7, @truncate(body_len + 2)); final_body_bytes[0..2].* = @as([2]u8, @bitCast(@as(u16, @bitCast(header)))); var mask_buf: *[4]u8 = final_body_bytes[2..6]; std.mem.writeIntSliceBig(u16, final_body_bytes[6..8], code); var reason = bun.String.empty; if (body) |data| { if (body_len > 0) { reason = bun.String.create(data[0..body_len]); @memcpy(final_body_bytes[8..][0..body_len], data[0..body_len]); } } // we must mask the code var slice = final_body_bytes[0..(8 + body_len)]; Mask.fill(this.globalThis, mask_buf, slice[6..], slice[6..]); if (this.enqueueEncodedBytes(socket, slice)) { this.dispatchClose(code, &reason); this.clearData(); } } pub fn handleEnd(this: *WebSocket, socket: Socket) void { std.debug.assert(socket.socket == this.tcp.socket); this.terminate(ErrorCode.ended); } pub fn handleWritable( this: *WebSocket, socket: Socket, ) void { std.debug.assert(socket.socket == this.tcp.socket); const send_buf = this.send_buffer.readableSlice(0); if (send_buf.len == 0) return; _ = this.sendBuffer(send_buf); } pub fn handleTimeout( this: *WebSocket, _: Socket, ) void { this.terminate(ErrorCode.timeout); } pub fn handleConnectError(this: *WebSocket, _: Socket, _: c_int) void { this.terminate(ErrorCode.failed_to_connect); } pub fn hasBackpressure(this: *const WebSocket) bool { return this.send_buffer.count > 0; } pub fn writeBinaryData( this: *WebSocket, ptr: [*]const u8, len: usize, op: u8, ) callconv(.C) void { if (this.tcp.isClosed() or this.tcp.isShutdown()) { this.dispatchAbruptClose(); return; } const opcode = @as(Opcode, @enumFromInt(@as(u4, @truncate(op)))); const slice = ptr[0..len]; const bytes = Copy{ .bytes = slice }; // fast path: small frame, no backpressure, attempt to send without allocating const frame_size = WebsocketHeader.frameSizeIncludingMask(len); if (!this.hasBackpressure() and frame_size < stack_frame_size) { var inline_buf: [stack_frame_size]u8 = undefined; bytes.copy(this.globalThis, inline_buf[0..frame_size], slice.len, opcode); _ = this.enqueueEncodedBytes(this.tcp, inline_buf[0..frame_size]); return; } _ = this.sendData(bytes, !this.hasBackpressure(), opcode); } pub fn writeString( this: *WebSocket, str_: *const JSC.ZigString, op: u8, ) callconv(.C) void { const str = str_.*; if (this.tcp.isClosed() or this.tcp.isShutdown()) { this.dispatchAbruptClose(); return; } // Note: 0 is valid const opcode = @as(Opcode, @enumFromInt(@as(u4, @truncate(op)))); { var inline_buf: [stack_frame_size]u8 = undefined; // fast path: small frame, no backpressure, attempt to send without allocating if (!str.is16Bit() and str.len < stack_frame_size) { const bytes = Copy{ .latin1 = str.slice() }; var byte_len: usize = 0; const frame_size = bytes.len(&byte_len); if (!this.hasBackpressure() and frame_size < stack_frame_size) { bytes.copy(this.globalThis, inline_buf[0..frame_size], byte_len, opcode); _ = this.enqueueEncodedBytes(this.tcp, inline_buf[0..frame_size]); return; } // max length of a utf16 -> utf8 conversion is 4 times the length of the utf16 string } else if ((str.len * 4) < (stack_frame_size) and !this.hasBackpressure()) { const bytes = Copy{ .utf16 = str.utf16SliceAligned() }; var byte_len: usize = 0; const frame_size = bytes.len(&byte_len); std.debug.assert(frame_size <= stack_frame_size); bytes.copy(this.globalThis, inline_buf[0..frame_size], byte_len, opcode); _ = this.enqueueEncodedBytes(this.tcp, inline_buf[0..frame_size]); return; } } _ = this.sendData( if (str.is16Bit()) Copy{ .utf16 = str.utf16SliceAligned() } else Copy{ .latin1 = str.slice() }, !this.hasBackpressure(), opcode, ); } fn dispatchAbruptClose(this: *WebSocket) void { var out = this.outgoing_websocket orelse return; this.poll_ref.unrefOnNextTick(this.globalThis.bunVM()); JSC.markBinding(@src()); this.outgoing_websocket = null; out.didAbruptClose(ErrorCode.closed); } fn dispatchClose(this: *WebSocket, code: u16, reason: *const bun.String) void { var out = this.outgoing_websocket orelse return; this.poll_ref.unrefOnNextTick(this.globalThis.bunVM()); JSC.markBinding(@src()); this.outgoing_websocket = null; out.didClose(code, reason); } pub fn close(this: *WebSocket, code: u16, reason: ?*const JSC.ZigString) callconv(.C) void { if (this.tcp.isClosed() or this.tcp.isShutdown()) return; var close_reason_buf: [128]u8 = undefined; if (reason) |str| { inner: { var fixed_buffer = std.heap.FixedBufferAllocator.init(&close_reason_buf); const allocator = fixed_buffer.allocator(); const wrote = std.fmt.allocPrint(allocator, "{}", .{str.*}) catch break :inner; this.sendCloseWithBody(this.tcp, code, wrote.ptr[0..125], wrote.len); return; } } this.sendCloseWithBody(this.tcp, code, null, 0); } const InitialDataHandler = struct { adopted: ?*WebSocket, ws: *CppWebSocket, slice: []u8, pub const Handle = JSC.AnyTask.New(@This(), handle); pub fn handleWithoutDeinit(this: *@This()) void { var this_socket = this.adopted orelse return; this.adopted = null; this_socket.initial_data_handler = null; var ws = this.ws; defer ws.unref(); if (this_socket.outgoing_websocket != null) this_socket.handleData(this_socket.tcp, this.slice); } pub fn handle(this: *@This()) void { defer { bun.default_allocator.free(this.slice); bun.default_allocator.destroy(this); } this.handleWithoutDeinit(); } }; pub fn init( outgoing: *CppWebSocket, input_socket: *anyopaque, socket_ctx: *anyopaque, globalThis: *JSC.JSGlobalObject, buffered_data: [*]u8, buffered_data_len: usize, ) callconv(.C) ?*anyopaque { var tcp = @as(*uws.Socket, @ptrCast(input_socket)); var ctx = @as(*uws.SocketContext, @ptrCast(socket_ctx)); var adopted = Socket.adopt( tcp, ctx, WebSocket, "tcp", WebSocket{ .tcp = undefined, .outgoing_websocket = outgoing, .globalThis = globalThis, .send_buffer = bun.LinearFifo(u8, .Dynamic).init(bun.default_allocator), .receive_buffer = bun.LinearFifo(u8, .Dynamic).init(bun.default_allocator), }, ) orelse return null; adopted.send_buffer.ensureTotalCapacity(2048) catch return null; adopted.receive_buffer.ensureTotalCapacity(2048) catch return null; adopted.poll_ref.ref(globalThis.bunVM()); var buffered_slice: []u8 = buffered_data[0..buffered_data_len]; if (buffered_slice.len > 0) { var initial_data = bun.default_allocator.create(InitialDataHandler) catch unreachable; initial_data.* = .{ .adopted = adopted, .slice = buffered_slice, .ws = outgoing, }; // Use a higher-priority callback for the initial onData handler globalThis.queueMicrotaskCallback(initial_data, InitialDataHandler.handle); // We need to ref the outgoing websocket so that it doesn't get finalized // before the initial data handler is called outgoing.ref(); } return @as( *anyopaque, @ptrCast(adopted), ); } pub fn finalize(this: *WebSocket) callconv(.C) void { log("finalize", .{}); this.clearData(); this.outgoing_websocket = null; if (this.tcp.isClosed()) return; this.tcp.close(0, null); } pub const Export = shim.exportFunctions(.{ .writeBinaryData = writeBinaryData, .writeString = writeString, .close = close, .cancel = cancel, .register = register, .init = init, .finalize = finalize, }); comptime { if (!JSC.is_bindgen) { @export(writeBinaryData, .{ .name = Export[0].symbol_name }); @export(writeString, .{ .name = Export[1].symbol_name }); @export(close, .{ .name = Export[2].symbol_name }); @export(cancel, .{ .name = Export[3].symbol_name }); @export(register, .{ .name = Export[4].symbol_name }); @export(init, .{ .name = Export[5].symbol_name }); @export(finalize, .{ .name = Export[6].symbol_name }); } } }; } pub const WebSocketHTTPClient = NewHTTPUpgradeClient(false); pub const WebSocketHTTPSClient = NewHTTPUpgradeClient(true); pub const WebSocketClient = NewWebSocketClient(false); pub const WebSocketClientTLS = NewWebSocketClient(true);