// This code is based on https://github.com/frmdstryr/zhp/blob/a4b5700c289c3619647206144e10fb414113a888/src/websocket.zig
// Thank you @frmdstryr.
const std = @import("std");

const native_endian = @import("builtin").target.cpu.arch.endian();

const bun = @import("bun");
const string = bun.string;
const Output = bun.Output;
const Global = bun.Global;
const Environment = bun.Environment;
const strings = bun.strings;
const MutableString = bun.MutableString;
const stringZ = bun.stringZ;
const default_allocator = bun.default_allocator;
const C = bun.C;

const uws = @import("bun").uws;
const JSC = @import("bun").JSC;
const PicoHTTP = @import("bun").picohttp;
const ObjectPool = @import("../pool.zig").ObjectPool;
const WebsocketHeader = @import("./websocket.zig").WebsocketHeader;
const WebsocketDataFrame = @import("./websocket.zig").WebsocketDataFrame;
const Opcode = @import("./websocket.zig").Opcode;

const log = Output.scoped(.WebSocketClient, false);

/// Extra request headers supplied by the caller as two parallel arrays
/// (`names[i]` pairs with `values[i]`). Formatting emits one
/// `name: value\r\n` line per pair.
const NonUTF8Headers = struct {
    names: []const JSC.ZigString,
    values: []const JSC.ZigString,

    pub fn format(self: NonUTF8Headers, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void {
        for (self.names) |header_name, i| {
            try std.fmt.format(writer, "{any}: {any}\r\n", .{ header_name, self.values[i] });
        }
    }

    /// Wraps the first `len` entries of `names` and `values`.
    /// Both pointers may be null only when `len` is 0.
    pub fn init(names: ?[*]const JSC.ZigString, values: ?[*]const JSC.ZigString, len: usize) NonUTF8Headers {
        if (len == 0) {
            return .{
                .names = &[_]JSC.ZigString{},
                .values = &[_]JSC.ZigString{},
            };
        }

        return .{
            .names = names.?[0..len],
            .values = values.?[0..len],
        };
    }
};

/// Builds the HTTP/1.1 upgrade request that starts the WebSocket handshake.
///
/// The `Sec-WebSocket-Key` value is the base64 encoding of the bytes returned by
/// `vm.rareData().nextUUID()`. When `client_protocol` is non-empty it is sent as
/// `Sec-WebSocket-Protocol`, and its Wyhash is stored in `client_protocol_hash`
/// so the server's reply can be compared later. Otherwise `client_protocol_hash`
/// is left untouched.
///
/// The returned buffer is allocated with `vm.allocator`; the caller owns it.
fn buildRequestBody(
    vm: *JSC.VirtualMachine,
    pathname: *const JSC.ZigString,
    host: *const JSC.ZigString,
    client_protocol: *const JSC.ZigString,
    client_protocol_hash: *u64,
    extra_headers: NonUTF8Headers,
) std.mem.Allocator.Error![]u8 {
    const allocator = vm.allocator;
    const input_rand_buf = vm.rareData().nextUUID();
    const temp_buf_size = comptime std.base64.standard.Encoder.calcSize(16);
    var encoded_buf: [temp_buf_size]u8 = undefined;
    const accept_key = std.base64.standard.Encoder.encode(&encoded_buf, &input_rand_buf);

    var static_headers = [_]PicoHTTP.Header{
        .{
            .name = "Sec-WebSocket-Key",
            .value = accept_key,
        },
        .{
            .name = "Sec-WebSocket-Protocol",
            .value = client_protocol.slice(),
        },
    };

    if (client_protocol.len > 0)
        client_protocol_hash.* = std.hash.Wyhash.hash(0, static_headers[1].value);

    // Only emit the Sec-WebSocket-Protocol header when a protocol was requested.
    const headers_ = static_headers[0 .. 1 + @as(usize, @boolToInt(client_protocol.len > 0))];
    const pathname_ = pathname.slice();
    const host_ = host.slice();
    const pico_headers = PicoHTTP.Headers{ .headers = headers_ };
    return try std.fmt.allocPrint(
        allocator,
        "GET {s} HTTP/1.1\r\n" ++
            "Host: {s}\r\n" ++
            "Pragma: no-cache\r\n" ++
            "Cache-Control: no-cache\r\n" ++
            "Connection: Upgrade\r\n" ++
            "Upgrade: websocket\r\n" ++
            "Sec-WebSocket-Version: 13\r\n" ++
            "{any}" ++
            "{any}" ++
            "\r\n",
        .{ pathname_, host_, pico_headers, extra_headers },
    );
}

/// Reasons passed to `WebSocket__didCloseWithErrorCode`.
/// The enum is `i32`-backed and crosses the extern boundary.
/// NOTE(review): the other side presumably relies on these numeric values,
/// so only append new members — confirm against the C++ declaration.
const ErrorCode = enum(i32) {
    cancel,
    invalid_response,
    expected_101_status_code,
    missing_upgrade_header,
    missing_connection_header,
    missing_websocket_accept_header,
    invalid_upgrade_header,
    invalid_connection_header,
    invalid_websocket_version,
    mismatch_websocket_accept_header,
    missing_client_protocol,
    mismatch_client_protocol,
    timeout,
    closed,
    failed_to_write,
    failed_to_connect,
    headers_too_large,
    ended,
    failed_to_allocate_memory,
    control_frame_is_fragmented,
    invalid_control_frame,
    compression_unsupported,
    unexpected_mask_from_server,
    expected_control_frame,
    unsupported_control_frame,
    unexpected_opcode,
    invalid_utf8,
};

/// Called once the upgrade handshake succeeds.
/// `buffered_data`/`buffered_len` carry any bytes received after the response
/// headers (possibly empty).
extern fn WebSocket__didConnect(
    websocket_context: *anyopaque,
    socket: *uws.Socket,
    buffered_data: ?[*]u8,
    buffered_len: usize,
) void;
/// Reports that the connection ended, or failed, with `reason`.
extern fn WebSocket__didCloseWithErrorCode(websocket_context: *anyopaque, reason: ErrorCode) void;
/// Delivers a text message. `clone` is true when `text` points at memory the
/// receiver does not own.
extern fn WebSocket__didReceiveText(websocket_context: *anyopaque, clone: bool, text: *const JSC.ZigString) void;
/// Delivers a binary message.
extern fn WebSocket__didReceiveBytes(websocket_context: *anyopaque, bytes: [*]const u8, byte_len: usize) void;

const body_buf_len = 16384 - 16;
const BodyBufBytes = [body_buf_len]u8;

/// Pool of fixed-size buffers that accumulate the HTTP upgrade response.
const BodyBufPool = ObjectPool(BodyBufBytes, null, true, 4);
const BodyBuf = BodyBufPool.Node;

/// Returns the client type that performs the HTTP upgrade handshake for a
/// WebSocket connection, over TLS when `ssl` is true.
pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
    return struct {
        pub const Socket = uws.NewSocketHandler(ssl);
        /// Socket carrying the handshake.
        tcp: Socket,
        /// Opaque context passed back to every `WebSocket__*` callback.
        outgoing_websocket: *anyopaque,
        /// Serialized upgrade request from `buildRequestBody`; freed by `clearInput`.
        input_body_buf: []u8 = &[_]u8{},
        /// Not read by any method of this struct.
        client_protocol: []const u8 = "",
        /// Suffix of `input_body_buf` that has not been written to the socket yet.
        to_send: []const u8 = "",
        /// Not read by any method of this struct.
        read_length: usize = 0,
        /// Scratch storage for parsed response headers (at most 128).
        headers_buf: [128]PicoHTTP.Header = undefined,
        /// Pooled buffer that accumulates the response; `body_written` bytes are in use.
        body_buf: ?*BodyBuf = null,
        body_written: usize = 0,
        /// Wyhash of the requested subprotocol, or 0 when none was requested.
        websocket_protocol: u64 = 0,
        /// TLS only: owned copy of the host name. It is passed to
        /// `configureHTTPClient` on open and then freed.
        hostname: [:0]const u8 = "",
        /// Ref'd in `connect`, unref'd in `clearData`.
        poll_ref: JSC.PollRef = .{},

        pub const name = if (ssl) "WebSocketHTTPSClient" else "WebSocketHTTPClient";

        pub const shim = JSC.Shimmer("Bun", name, @This());

        const HTTPClient = @This();

        /// Every acceptable upgrade response starts with this status-line prefix.
        const expected_status_line_prefix = "HTTP/1.1 101 ";

        /// Installs this type's socket callbacks on the uws context `ctx_`.
        /// It also records `loop_` as the VM's event loop, and prepares the loop
        /// the first time one is set.
        pub fn register(global: *JSC.JSGlobalObject, loop_: *anyopaque, ctx_: *anyopaque) callconv(.C) void {
            var vm = global.bunVM();
            var loop = @ptrCast(*uws.Loop, @alignCast(@alignOf(uws.Loop), loop_));
            var ctx: *uws.SocketContext = @ptrCast(*uws.SocketContext, ctx_);

            if (vm.uws_event_loop) |other| {
                std.debug.assert(other == loop);
            }
            const is_new_loop = vm.uws_event_loop == null;

            vm.uws_event_loop = loop;

            Socket.configure(
                ctx,
                false,
                HTTPClient,
                struct {
                    pub const onOpen = handleOpen;
                    pub const onClose = handleClose;
                    pub const onData = handleData;
                    pub const onWritable = handleWritable;
                    pub const onTimeout = handleTimeout;
                    pub const onConnectError = handleConnectError;
                    pub const onEnd = handleEnd;
                },
            );
            if (is_new_loop) {
                vm.prepareLoop();
            }
        }

        /// Builds the upgrade request and starts connecting to `host:port`.
        /// `client` is passed by value to `Socket.connect`, and the returned
        /// `*HTTPClient` is the instance used from then on.
        /// Returns null when the request cannot be built or the connect call fails.
        pub fn connect(
            global: *JSC.JSGlobalObject,
            socket_ctx: *anyopaque,
            websocket: *anyopaque,
            host: *const JSC.ZigString,
            port: u16,
            pathname: *const JSC.ZigString,
            client_protocol: *const JSC.ZigString,
            header_names: ?[*]const JSC.ZigString,
            header_values: ?[*]const JSC.ZigString,
            header_count: usize,
        ) callconv(.C) ?*HTTPClient {
            std.debug.assert(global.bunVM().uws_event_loop != null);

            var client_protocol_hash: u64 = 0;
            // NOTE(review): allocated with vm.allocator but freed in clearInput with
            // bun.default_allocator; this assumes they are the same allocator — confirm.
            var body = buildRequestBody(
                global.bunVM(),
                pathname,
                host,
                client_protocol,
                &client_protocol_hash,
                NonUTF8Headers.init(header_names, header_values, header_count),
            ) catch return null;
            var client: HTTPClient = HTTPClient{
                .tcp = undefined,
                .outgoing_websocket = websocket,
                .input_body_buf = body,
                .websocket_protocol = client_protocol_hash,
            };
            var host_ = host.toSlice(bun.default_allocator);
            defer host_.deinit();

            var vm = global.bunVM();

            const prev_start_server_on_next_tick = vm.eventLoop().start_server_on_next_tick;
            vm.eventLoop().start_server_on_next_tick = true;
            client.poll_ref.ref(vm);
            const display_host_ = host_.slice();
            const display_host = if (bun.FeatureFlags.hardcode_localhost_to_127_0_0_1 and strings.eqlComptime(display_host_, "localhost"))
                "127.0.0.1"
            else
                display_host_;

            if (Socket.connect(
                display_host,
                port,
                @ptrCast(*uws.SocketContext, socket_ctx),
                HTTPClient,
                client,
                "tcp",
            )) |out| {
                if (comptime ssl) {
                    if (!strings.isIPAddress(host_.slice())) {
                        out.hostname = bun.default_allocator.dupeZ(u8, host_.slice()) catch "";
                    }
                }

                out.tcp.timeout(120);
                return out;
            }

            // Connecting failed: restore the event-loop flag and release what we built.
            vm.eventLoop().start_server_on_next_tick = prev_start_server_on_next_tick;

            client.clearData();

            return null;
        }

        /// Frees the serialized request, if any.
        pub fn clearInput(this: *HTTPClient) void {
            if (this.input_body_buf.len > 0) bun.default_allocator.free(this.input_body_buf);
            this.input_body_buf.len = 0;
        }

        /// Drops the event-loop ref taken in `connect` and releases the
        /// request and response buffers.
        pub fn clearData(this: *HTTPClient) void {
            this.poll_ref.unrefOnNextTick(JSC.VirtualMachine.get());

            this.clearInput();
            if (this.body_buf) |buf| {
                this.body_buf = null;
                buf.release();
            }
        }

        /// Releases buffers and closes the socket.
        /// A socket that is not yet established is closed with
        /// `us_socket_close_connecting`; an established one with `close`.
        pub fn cancel(this: *HTTPClient) callconv(.C) void {
            this.clearData();

            if (!this.tcp.isEstablished()) {
                _ = uws.us_socket_close_connecting(comptime @as(c_int, @boolToInt(ssl)), this.tcp.socket);
            } else {
                this.tcp.close(0, null);
            }
        }

        /// Reports `code` to the WebSocket, then cancels the handshake.
        pub fn fail(this: *HTTPClient, code: ErrorCode) void {
            JSC.markBinding(@src());
            WebSocket__didCloseWithErrorCode(this.outgoing_websocket, code);
            this.cancel();
        }

        pub fn handleClose(this: *HTTPClient, _: Socket, _: c_int, _: ?*anyopaque) void {
            JSC.markBinding(@src());
            this.clearData();
            WebSocket__didCloseWithErrorCode(this.outgoing_websocket, ErrorCode.ended);
        }

        /// Like `fail`, but also closes the socket if it is still open afterwards.
        pub fn terminate(this: *HTTPClient, code: ErrorCode) void {
            this.fail(code);
            if (!this.tcp.isClosed())
                this.tcp.close(0, null);
        }

        /// Socket opened. On TLS, applies the stored host name. Then writes as
        /// much of the request as the socket accepts; the rest goes to `to_send`
        /// for `handleWritable`.
        pub fn handleOpen(this: *HTTPClient, socket: Socket) void {
            log("onOpen", .{});
            std.debug.assert(socket.socket == this.tcp.socket);

            std.debug.assert(this.input_body_buf.len > 0);
            std.debug.assert(this.to_send.len == 0);

            if (comptime ssl) {
                if (this.hostname.len > 0) {
                    socket.getNativeHandle().configureHTTPClient(this.hostname);
                    bun.default_allocator.free(this.hostname);
                    this.hostname = "";
                }
            }

            const wrote = socket.write(this.input_body_buf, true);
            if (wrote < 0) {
                this.terminate(ErrorCode.failed_to_write);
                return;
            }

            this.to_send = this.input_body_buf[@intCast(usize, wrote)..];
        }

        /// Returns the response buffer, taking one from the pool on first use.
        fn getBody(this: *HTTPClient) *BodyBufBytes {
            if (this.body_buf == null) {
                this.body_buf = BodyBufPool.get(bun.default_allocator);
            }

            return &this.body_buf.?.data;
        }

        /// Appends `data` to the response buffer and tries to parse the
        /// response headers. Data that does not fit while the headers are
        /// still incomplete fails with `headers_too_large`.
        pub fn handleData(this: *HTTPClient, socket: Socket, data: []const u8) void {
            log("onData", .{});
            std.debug.assert(socket.socket == this.tcp.socket);

            if (comptime Environment.allow_assert)
                std.debug.assert(!socket.isShutdown());

            const body = this.getBody();
            const remain = body[this.body_written..];

            const to_write = remain[0..@min(remain.len, data.len)];
            if (data.len > 0 and to_write.len > 0) {
                @memcpy(remain.ptr, data.ptr, to_write.len);
                this.body_written += to_write.len;
            }

            const overflow = data[to_write.len..];

            const available_to_read = body[0..this.body_written];

            // Fail early if this cannot be a 101 response. The status line may be
            // split across reads, so compare only as many bytes as are buffered.
            const checked_len = @min(available_to_read.len, expected_status_line_prefix.len);
            if (!std.mem.eql(u8, available_to_read[0..checked_len], expected_status_line_prefix[0..checked_len])) {
                this.terminate(ErrorCode.expected_101_status_code);
                return;
            }

            const response = PicoHTTP.Response.parse(available_to_read, &this.headers_buf) catch |err| {
                switch (err) {
                    error.Malformed_HTTP_Response => {
                        this.terminate(ErrorCode.invalid_response);
                        return;
                    },
                    error.ShortRead => {
                        // The headers are incomplete and the buffer is already full.
                        if (overflow.len > 0) {
                            this.terminate(ErrorCode.headers_too_large);
                            return;
                        }
                        return;
                    },
                }
            };

            this.processResponse(response, available_to_read[@intCast(usize, response.bytes_read)..]);
        }

        pub fn handleEnd(this: *HTTPClient, socket: Socket) void {
            log("onEnd", .{});
            std.debug.assert(socket.socket == this.tcp.socket);
            this.terminate(ErrorCode.ended);
        }

        /// Validates the upgrade response headers (Connection, Upgrade,
        /// Sec-WebSocket-Accept presence, Sec-WebSocket-Version, and the
        /// requested subprotocol).
        ///
        /// On success, copies `remain_buf` (bytes after the headers) before
        /// releasing the response buffer, then hands the socket to
        /// `WebSocket__didConnect`. Any failure terminates with the matching
        /// `ErrorCode`.
        pub fn processResponse(this: *HTTPClient, response: PicoHTTP.Response, remain_buf: []const u8) void {
            std.debug.assert(this.body_written > 0);

            var upgrade_header = PicoHTTP.Header{ .name = "", .value = "" };
            var connection_header = PicoHTTP.Header{ .name = "", .value = "" };
            var websocket_accept_header = PicoHTTP.Header{ .name = "", .value = "" };
            // With no requested protocol (hash 0) there is nothing to match.
            var visited_protocol = this.websocket_protocol == 0;
            std.debug.assert(response.status_code == 101);

            for (response.headers) |header| {
                // Dispatch on length first so most headers skip the case-insensitive compare.
                switch (header.name.len) {
                    "Connection".len => {
                        if (connection_header.name.len == 0 and strings.eqlCaseInsensitiveASCII(header.name, "Connection", false)) {
                            connection_header = header;
                        }
                    },
                    "Upgrade".len => {
                        if (upgrade_header.name.len == 0 and strings.eqlCaseInsensitiveASCII(header.name, "Upgrade", false)) {
                            upgrade_header = header;
                        }
                    },
                    "Sec-WebSocket-Version".len => {
                        if (strings.eqlCaseInsensitiveASCII(header.name, "Sec-WebSocket-Version", false)) {
                            // Exact, length-checked comparison: rejects "1" and "130" alike.
                            if (!strings.eqlComptime(header.value, "13")) {
                                this.terminate(ErrorCode.invalid_websocket_version);
                                return;
                            }
                        }
                    },
                    "Sec-WebSocket-Accept".len => {
                        if (websocket_accept_header.name.len == 0 and strings.eqlCaseInsensitiveASCII(header.name, "Sec-WebSocket-Accept", false)) {
                            websocket_accept_header = header;
                        }
                    },
                    "Sec-WebSocket-Protocol".len => {
                        if (strings.eqlCaseInsensitiveASCII(header.name, "Sec-WebSocket-Protocol", false)) {
                            if (this.websocket_protocol == 0 or std.hash.Wyhash.hash(0, header.value) != this.websocket_protocol) {
                                this.terminate(ErrorCode.mismatch_client_protocol);
                                return;
                            }

                            visited_protocol = true;
                        }
                    },
                    else => {},
                }

                // Stop scanning once every required header has been seen.
                if (visited_protocol and upgrade_header.name.len > 0 and connection_header.name.len > 0 and websocket_accept_header.name.len > 0) {
                    break;
                }
            }

            // A missing Sec-WebSocket-Version header is accepted; only a wrong value fails.

            if (@min(upgrade_header.name.len, upgrade_header.value.len) == 0) {
                this.terminate(ErrorCode.missing_upgrade_header);
                return;
            }

            if (@min(connection_header.name.len, connection_header.value.len) == 0) {
                this.terminate(ErrorCode.missing_connection_header);
                return;
            }

            if (@min(websocket_accept_header.name.len, websocket_accept_header.value.len) == 0) {
                this.terminate(ErrorCode.missing_websocket_accept_header);
                return;
            }

            if (!visited_protocol) {
                this.terminate(ErrorCode.mismatch_client_protocol);
                return;
            }

            if (!strings.eqlCaseInsensitiveASCII(connection_header.value, "Upgrade", true)) {
                this.terminate(ErrorCode.invalid_connection_header);
                return;
            }

            if (!strings.eqlCaseInsensitiveASCII(upgrade_header.value, "websocket", true)) {
                this.terminate(ErrorCode.invalid_upgrade_header);
                return;
            }

            // TODO: check websocket_accept_header.value

            // `remain_buf` points into the pooled body buffer, which clearData() releases,
            // so copy it first.
            var overflow: []u8 = &.{};
            if (remain_buf.len > 0) {
                overflow = bun.default_allocator.dupe(u8, remain_buf) catch {
                    this.terminate(ErrorCode.failed_to_allocate_memory);
                    return;
                };
            }

            this.clearData();
            JSC.markBinding(@src());
            this.tcp.timeout(0);
            log("onDidConnect", .{});

            WebSocket__didConnect(this.outgoing_websocket, this.tcp.socket, overflow.ptr, overflow.len);
        }

        /// Flushes whatever part of the request is still pending in `to_send`.
        pub fn handleWritable(
            this: *HTTPClient,
            socket: Socket,
        ) void {
            std.debug.assert(socket.socket == this.tcp.socket);

            if (this.to_send.len == 0)
                return;

            const wrote = socket.write(this.to_send, true);
            if (wrote < 0) {
                this.terminate(ErrorCode.failed_to_write);
                return;
            }
            this.to_send = this.to_send[@min(@intCast(usize, wrote), this.to_send.len)..];
        }

        pub fn handleTimeout(
            this: *HTTPClient,
            _: Socket,
        ) void {
            this.terminate(ErrorCode.timeout);
        }

        pub fn handleConnectError(this: *HTTPClient, _: Socket, _: c_int) void {
            this.terminate(ErrorCode.failed_to_connect);
        }

        pub const Export = shim.exportFunctions(.{
            .connect = connect,
            .cancel = cancel,
            .register = register,
        });

        // Export the C entry points under the shim's symbol names (skipped during bindgen).
        comptime {
            if (!JSC.is_bindgen) {
                @export(connect, .{
                    .name = Export[0].symbol_name,
                });
                @export(cancel, .{
                    .name = Export[1].symbol_name,
                });
                @export(register, .{
                    .name = Export[2].symbol_name,
                });
            }
        }
    };
}

/// Masking of client-to-server payloads.
pub const Mask = struct {
    /// Writes a fresh 4-byte mask from the VM's entropy source into `mask_buf`,
    /// then writes `input_` XOR-ed with that mask into `output_`.
    ///
    /// `output_` and `input_` may be the same slice; `Copy.copy` masks in place
    /// this way.
    pub fn fill(globalThis: *JSC.JSGlobalObject, mask_buf: *[4]u8, output_: []u8, input_: []const u8) void {
        mask_buf.* = globalThis.bunVM().rareData().entropySlice(4)[0..4].*;
        const mask = mask_buf.*;

        // XOR with an all-zero mask is the identity, so only copy in that case.
        const skip_mask = @bitCast(u32, mask) == 0;
        if (skip_mask) {
            fillWithSkipMask(mask, output_, input_, true);
        } else {
            fillWithSkipMask(mask, output_, input_, false);
        }
    }

    /// Copies `input_` into `output_`, XOR-ing each byte with
    /// `mask[i % 4]` unless `skip_mask` is set.
    ///
    /// Writes go through `output.ptr` without bounds checks, so `output_` must be
    /// at least as long as `input_`.
    fn fillWithSkipMask(mask: [4]u8, output_: []u8, input_: []const u8, comptime skip_mask: bool) void {
        var input = input_;
        var output = output_;

        if (comptime Environment.enableSIMD) {
            if (input.len >= strings.ascii_vector_size) {
                // Tile the 4-byte mask across one vector (assumes the vector size is a multiple of 4).
                const vec: strings.AsciiVector = brk: {
                    var in: [strings.ascii_vector_size]u8 = undefined;
                    comptime var i: usize = 0;
                    inline while (i < strings.ascii_vector_size) : (i += 4) {
                        in[i..][0..4].* = mask;
                    }
                    break :brk @as(strings.AsciiVector, in);
                };
                const end_ptr_wrapped_to_last_16 = input.ptr + input.len - (input.len % strings.ascii_vector_size);

                if (comptime skip_mask) {
                    while (input.ptr != end_ptr_wrapped_to_last_16) {
                        const input_vec: strings.AsciiVector = @as(strings.AsciiVector, input[0..strings.ascii_vector_size].*);
                        output.ptr[0..strings.ascii_vector_size].* = input_vec;
                        output = output[strings.ascii_vector_size..];
                        input = input[strings.ascii_vector_size..];
                    }
                } else {
                    while (input.ptr != end_ptr_wrapped_to_last_16) {
                        const input_vec: strings.AsciiVector = @as(strings.AsciiVector, input[0..strings.ascii_vector_size].*);
                        output.ptr[0..strings.ascii_vector_size].* = input_vec ^ vec;
                        output = output[strings.ascii_vector_size..];
                        input = input[strings.ascii_vector_size..];
                    }
                }
            }

            // hint to the compiler not to vectorize the next loop
            std.debug.assert(input.len < strings.ascii_vector_size);
        }

        if (comptime !skip_mask) {
            while (input.len >= 4) {
                const input_vec: [4]u8 = input[0..4].*;
                output.ptr[0..4].* = [4]u8{
                    input_vec[0] ^ mask[0],
                    input_vec[1] ^ mask[1],
                    input_vec[2] ^ mask[2],
                    input_vec[3] ^ mask[3],
                };
                output = output[4..];
                input = input[4..];
            }
        } else {
            while (input.len >= 4) {
                const input_vec: [4]u8 = input[0..4].*;
                output.ptr[0..4].* = input_vec;
                output = output[4..];
                input = input[4..];
            }
        }

        // Fewer than 4 bytes remain; every earlier step consumed a multiple of 4,
        // so the mask phase restarts at 0.
        if (comptime !skip_mask) {
            for (input) |c, i| {
                output[i] = c ^ mask[i % 4];
            }
        } else {
            for (input) |c, i| {
                output[i] = c;
            }
        }
    }
};

/// States of the incoming-frame parser.
const ReceiveState = enum {
    need_header,
    need_mask,
    need_body,
    extended_payload_length_16,
    extended_payload_length_64,
    ping,
    pong,
    close,
    fail,

    /// True for every state other than `need_body`.
    pub fn needControlFrame(this: ReceiveState) bool {
        return this != .need_body;
    }
};

const DataType = enum {
    none,
    text,
    binary,
};

/// Decodes the first two bytes of a frame header into the out-parameters and
/// returns the parser state that should handle what follows.
///
/// Any masked frame yields `.need_mask`: a server must never mask, and
/// RFC 6455 section 5.1 requires the client to fail the connection, whatever
/// the opcode.
fn parseWebSocketHeader(
    bytes: [2]u8,
    receiving_type: *Opcode,
    payload_length: *usize,
    is_fragmented: *bool,
    is_final: *bool,
    need_compression: *bool,
) ReceiveState {
    //    0                   1                   2                   3
    //    0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    //   +-+-+-+-+-------+-+-------------+-------------------------------+
    //   |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
    //   |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
    //   |N|V|V|V|       |S|             |   (if payload len==126/127)   |
    //   | |1|2|3|       |K|             |                               |
    //   +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
    //   |     Extended payload length continued, if payload len == 127  |
    //   + - - - - - - - - - - - - - - - +-------------------------------+
    //   |                               |Masking-key, if MASK set to 1  |
    //   +-------------------------------+-------------------------------+
    //   | Masking-key (continued)       |          Payload Data         |
    //   +-------------------------------- - - - - - - - - - - - - - - - +
    //   :                     Payload Data continued ...                :
    //   + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
    //   |                     Payload Data continued ...                |
    //   +---------------------------------------------------------------+
    // NOTE(review): the unconditional @byteSwap presumes a little-endian host.
    const header = @bitCast(WebsocketHeader, @byteSwap(@bitCast(u16, bytes)));
    const payload = @as(usize, header.len);
    payload_length.* = payload;
    receiving_type.* = header.opcode;
    is_fragmented.* = switch (header.opcode) {
        .Continue => true,
        else => false,
    } or !header.final;
    is_final.* = header.final;
    need_compression.* = header.compressed;

    if (header.mask) {
        return .need_mask;
    }

    return switch (header.opcode) {
        // The 7-bit length field is either the payload size or a marker for an
        // extended 16-bit (126) or 64-bit (127) length.
        .Text, .Continue, .Binary => switch (payload) {
            0...125 => ReceiveState.need_body,
            126 => ReceiveState.extended_payload_length_16,
            127 => ReceiveState.extended_payload_length_64,
            else => ReceiveState.fail,
        },
        .Close => ReceiveState.close,
        .Ping => ReceiveState.ping,
        .Pong => ReceiveState.pong,
        else => ReceiveState.fail,
    };
}

/// A payload to be written into the send buffer.
/// - `utf16` and `latin1` become masked Text frames (transcoded to UTF-8).
/// - `bytes` becomes a masked Binary frame.
/// - `raw` is copied verbatim as already-encoded frame bytes.
const Copy = union(enum) {
    utf16: []const u16,
    latin1: []const u8,
    bytes: []const u8,
    raw: []const u8,

    /// Returns the number of bytes `copy` will write, and stores the payload's
    /// byte length (its UTF-8 length for text) in `byte_len`.
    pub fn len(this: @This(), byte_len: *usize) usize {
        switch (this) {
            .utf16 => {
                byte_len.* = strings.elementLengthUTF16IntoUTF8([]const u16, this.utf16);
                return WebsocketHeader.frameSizeIncludingMask(byte_len.*);
            },
            .latin1 => {
                // Latin-1 code points >= 0x80 take two bytes in UTF-8, so the
                // encoded length can exceed `latin1.len`.
                byte_len.* = strings.elementLengthLatin1IntoUTF8([]const u8, this.latin1);
                return WebsocketHeader.frameSizeIncludingMask(byte_len.*);
            },
            .bytes => {
                byte_len.* = this.bytes.len;
                return WebsocketHeader.frameSizeIncludingMask(byte_len.*);
            },
            .raw => {
                byte_len.* = this.raw.len;
                return this.raw.len;
            },
        }
    }

    /// Encodes the payload into `buf`, whose length must equal what `len`
    /// returned. `content_byte_len` must be the `byte_len` reported by `len`.
    ///
    /// Frame layout: 2-byte header, then 0/2/8-byte extended length, then the
    /// 4-byte mask, then the masked payload.
    pub fn copy(this: @This(), globalThis: *JSC.JSGlobalObject, buf: []u8, content_byte_len: usize) void {
        if (this == .raw) {
            std.debug.assert(buf.len >= this.raw.len);
            std.debug.assert(buf.ptr != this.raw.ptr);

            @memcpy(buf.ptr, this.raw.ptr, this.raw.len);
            return;
        }

        const how_big_is_the_length_integer = WebsocketHeader.lengthByteCount(content_byte_len);
        const how_big_is_the_mask = 4;
        const mask_offset = 2 + how_big_is_the_length_integer;
        const content_offset = mask_offset + how_big_is_the_mask;

        var to_mask = buf[content_offset..];

        var header = @bitCast(WebsocketHeader, @as(u16, 0));

        // Write extended length if needed
        switch (how_big_is_the_length_integer) {
            0 => {},
            2 => std.mem.writeIntBig(u16, buf[2..][0..2], @truncate(u16, content_byte_len)),
            8 => std.mem.writeIntBig(u64, buf[2..][0..8], @truncate(u64, content_byte_len)),
            else => unreachable,
        }

        header.mask = true;
        header.compressed = false;
        header.final = true;

        std.debug.assert(WebsocketHeader.frameSizeIncludingMask(content_byte_len) == buf.len);

        switch (this) {
            .utf16 => |utf16| {
                const encode_into_result = strings.copyUTF16IntoUTF8(to_mask, []const u16, utf16);
                std.debug.assert(@as(usize, encode_into_result.written) == content_byte_len);
                std.debug.assert(@as(usize, encode_into_result.read) == utf16.len);
                header.len = WebsocketHeader.packLength(encode_into_result.written);
                header.opcode = Opcode.Text;
                var fib = std.io.fixedBufferStream(buf);
                header.writeHeader(fib.writer(), encode_into_result.written) catch unreachable;

                Mask.fill(globalThis, buf[mask_offset..][0..4], to_mask[0..content_byte_len], to_mask[0..content_byte_len]);
            },
            .latin1 => |latin1| {
                const encode_into_result = strings.copyLatin1IntoUTF8(to_mask, []const u8, latin1);
                std.debug.assert(@as(usize, encode_into_result.written) == content_byte_len);
                std.debug.assert(@as(usize, encode_into_result.read) == latin1.len);
                header.len = WebsocketHeader.packLength(encode_into_result.written);
                header.opcode = Opcode.Text;
                var fib = std.io.fixedBufferStream(buf);
                header.writeHeader(fib.writer(), encode_into_result.written) catch unreachable;
                Mask.fill(globalThis, buf[mask_offset..][0..4], to_mask[0..content_byte_len], to_mask[0..content_byte_len]);
            },
            .bytes => |bytes| {
                header.len = WebsocketHeader.packLength(bytes.len);
                header.opcode = Opcode.Binary;
                var fib = std.io.fixedBufferStream(buf);
                header.writeHeader(fib.writer(), bytes.len) catch unreachable;
                Mask.fill(globalThis, buf[mask_offset..][0..4], to_mask[0..content_byte_len], bytes);
            },
            .raw => unreachable,
        }
    }
};

pub fn NewWebSocketClient(comptime ssl: bool) type {
    return struct {
        pub const Socket = uws.NewSocketHandler(ssl);
        tcp: Socket,
        outgoing_websocket: ?*anyopaque,

        receive_state: ReceiveState = ReceiveState.need_header,
        receive_header: WebsocketHeader = @bitCast(WebsocketHeader, @as(u16, 0)),
        receiving_type: Opcode = Opcode.ResB,

        ping_frame_bytes: [128 + 6]u8 = [_]u8{0} ** (128 + 6),
        ping_len: u8 = 0,

        receive_frame: usize = 0,
        receive_body_remain: usize = 0,
        receive_pending_chunk_len: usize = 0,
        receive_buffer: bun.LinearFifo(u8, .Dynamic),

        send_buffer: bun.LinearFifo(u8, .Dynamic),

        globalThis: *JSC.JSGlobalObject,
        poll_ref: JSC.PollRef = JSC.PollRef.init(),

        pub const name = if (ssl) "WebSocketClientTLS" else "WebSocketClient";

        pub const shim = JSC.Shimmer("Bun", name, @This());
        const stack_frame_size = 1024;

        const WebSocket = @This();

        pub fn register(global: *JSC.JSGlobalObject, loop_: *anyopaque, ctx_: *anyopaque) callconv(.C) void {
            var vm = global.bunVM();
            var loop = @ptrCast(*uws.Loop, @alignCast(@alignOf(uws.Loop), loop_));

            var ctx: *uws.SocketContext = @ptrCast(*uws.SocketContext, ctx_);

            if (vm.uws_event_loop) |other| {
                std.debug.assert(other == loop);
            }

            vm.uws_event_loop = loop;

            Socket.configure(
                ctx,
                false,
                WebSocket,
                struct {
                    pub const onClose = handleClose;
                    pub const onData = handleData;
                    pub const onWritable =
handleWritable; pub const onTimeout = handleTimeout; pub const onConnectError = handleConnectError; pub const onEnd = handleEnd; }, ); } pub fn clearData(this: *WebSocket) void { this.poll_ref.unrefOnNextTick(this.globalThis.bunVM()); this.clearReceiveBuffers(true); this.clearSendBuffers(true); this.ping_len = 0; this.receive_pending_chunk_len = 0; } pub fn cancel(this: *WebSocket) callconv(.C) void { this.clearData(); if (this.tcp.isClosed() or this.tcp.isShutdown()) return; if (!this.tcp.isEstablished()) { _ = uws.us_socket_close_connecting(comptime @as(c_int, @boolToInt(ssl)), this.tcp.socket); } else { this.tcp.close(0, null); } } pub fn fail(this: *WebSocket, code: ErrorCode) void { JSC.markBinding(@src()); if (this.outgoing_websocket) |ws| WebSocket__didCloseWithErrorCode(ws, code); this.cancel(); } pub fn handleClose(this: *WebSocket, _: Socket, _: c_int, _: ?*anyopaque) void { JSC.markBinding(@src()); this.clearData(); if (this.outgoing_websocket) |ws| WebSocket__didCloseWithErrorCode(ws, ErrorCode.ended); } pub fn terminate(this: *WebSocket, code: ErrorCode) void { this.fail(code); } fn getReceiveBody(this: *WebSocket) *BodyBufBytes { if (this.receive_body_buf == null) { this.receive_body_buf = BodyBufPool.get(bun.default_allocator); } return &this.receive_body_buf.?.data; } fn clearReceiveBuffers(this: *WebSocket, free: bool) void { this.receive_buffer.head = 0; this.receive_buffer.count = 0; if (free) { this.receive_buffer.deinit(); this.receive_buffer.buf.len = 0; } this.receive_pending_chunk_len = 0; this.receive_body_remain = 0; } fn clearSendBuffers(this: *WebSocket, free: bool) void { this.send_buffer.head = 0; this.send_buffer.count = 0; if (free) { this.send_buffer.deinit(); this.send_buffer.buf.len = 0; } } fn dispatchData(this: *WebSocket, data_: []const u8, kind: Opcode) void { var out = this.outgoing_websocket orelse { this.clearData(); return; }; switch (kind) { .Text => { // this function encodes to UTF-16 if > 127 // so we don't need to 
worry about latin1 non-ascii code points const utf16_bytes_ = strings.toUTF16Alloc(bun.default_allocator, data_, true) catch { this.terminate(ErrorCode.invalid_utf8); return; }; var outstring = JSC.ZigString.Empty; if (utf16_bytes_) |utf16| { outstring = JSC.ZigString.from16Slice(utf16); outstring.mark(); JSC.markBinding(@src()); WebSocket__didReceiveText(out, false, &outstring); } else { outstring = JSC.ZigString.init(data_); JSC.markBinding(@src()); WebSocket__didReceiveText(out, true, &outstring); } }, .Binary => { JSC.markBinding(@src()); WebSocket__didReceiveBytes(out, data_.ptr, data_.len); }, else => unreachable, } } pub fn consume(this: *WebSocket, data_: []const u8, left_in_fragment: usize, kind: Opcode, is_final: bool) usize { std.debug.assert(kind == .Text or kind == .Binary); std.debug.assert(data_.len <= left_in_fragment); // did all the data fit in the buffer? // we can avoid copying & allocating a temporary buffer if (is_final and data_.len == left_in_fragment and this.receive_pending_chunk_len == 0) { this.dispatchData(data_, kind); return data_.len; } // this must come after the above check std.debug.assert(data_.len > 0); var writable = this.receive_buffer.writableWithSize(data_.len) catch unreachable; @memcpy(writable.ptr, data_.ptr, data_.len); this.receive_buffer.update(data_.len); if (left_in_fragment > data_.len and left_in_fragment - data_.len - this.receive_pending_chunk_len == 0) { this.receive_pending_chunk_len = 0; this.dispatchData(this.receive_buffer.readableSlice(0), kind); this.clearReceiveBuffers(false); } else { this.receive_pending_chunk_len -|= left_in_fragment; } return data_.len; } pub fn handleData(this: *WebSocket, socket: Socket, data_: []const u8) void { var data = data_; var receive_state = this.receive_state; var terminated = false; var is_fragmented = false; var receiving_type = this.receiving_type; var receive_body_remain = this.receive_body_remain; var is_final = false; var last_receive_data_type = receiving_type; 
defer { if (!terminated) { this.receive_state = receive_state; this.receiving_type = last_receive_data_type; this.receive_body_remain = receive_body_remain; // if we receive multiple pings in a row // we just send back the last one if (this.ping_len > 0) { _ = this.sendPong(socket); this.ping_len = 0; } } } var header_bytes: [@sizeOf(usize)]u8 = [_]u8{0} ** @sizeOf(usize); while (true) { log("onData ({s})", .{@tagName(receive_state)}); switch (receive_state) { // 0 1 2 3 // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 // +-+-+-+-+-------+-+-------------+-------------------------------+ // |F|R|R|R| opcode|M| Payload len | Extended payload length | // |I|S|S|S| (4) |A| (7) | (16/64) | // |N|V|V|V| |S| | (if payload len==126/127) | // | |1|2|3| |K| | | // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + // | Extended payload length continued, if payload len == 127 | // + - - - - - - - - - - - - - - - +-------------------------------+ // | |Masking-key, if MASK set to 1 | // +-------------------------------+-------------------------------+ // | Masking-key (continued) | Payload Data | // +-------------------------------- - - - - - - - - - - - - - - - + // : Payload Data continued ... : // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + // | Payload Data continued ... 
| // +---------------------------------------------------------------+ .need_header => { if (data.len < 2) { this.terminate(ErrorCode.control_frame_is_fragmented); terminated = true; break; } header_bytes[0..2].* = data[0..2].*; receive_body_remain = 0; var need_compression = false; is_final = false; receive_state = parseWebSocketHeader( header_bytes[0..2].*, &receiving_type, &receive_body_remain, &is_fragmented, &is_final, &need_compression, ); last_receive_data_type = if (receiving_type == .Text or receiving_type == .Binary) receiving_type else last_receive_data_type; data = data[2..]; if (receiving_type.isControl() and is_fragmented) { // Control frames must not be fragmented. this.terminate(ErrorCode.control_frame_is_fragmented); terminated = true; break; } switch (receiving_type) { .Continue, .Text, .Binary, .Ping, .Pong, .Close => {}, else => { this.terminate(ErrorCode.unsupported_control_frame); terminated = true; break; }, } if (need_compression) { this.terminate(ErrorCode.compression_unsupported); terminated = true; break; } }, .need_mask => { this.terminate(.unexpected_mask_from_server); terminated = true; break; }, .extended_payload_length_64, .extended_payload_length_16 => |rc| { const byte_size = switch (rc) { .extended_payload_length_64 => @as(usize, 8), .extended_payload_length_16 => @as(usize, 2), else => unreachable, }; if (data.len < byte_size) { this.terminate(ErrorCode.control_frame_is_fragmented); terminated = true; break; } // Multibyte length quantities are expressed in network byte order receive_body_remain = switch (byte_size) { 8 => @as(usize, std.mem.readIntBig(u64, data[0..8])), 2 => @as(usize, std.mem.readIntBig(u16, data[0..2])), else => unreachable, }; data = data[byte_size..]; receive_state = .need_body; if (receive_body_remain == 0) { // this is an error // the server should've set length to zero this.terminate(ErrorCode.invalid_control_frame); terminated = true; break; } }, .ping => { const ping_len = @min(data.len, 
@min(receive_body_remain, 125)); this.ping_len = @truncate(u8, ping_len); if (ping_len > 0) { @memcpy(this.ping_frame_bytes[6..], data.ptr, ping_len); data = data[ping_len..]; } receive_state = .need_header; receive_body_remain = 0; receiving_type = last_receive_data_type; if (data.len == 0) break; }, .pong => { const pong_len = @min(data.len, @min(receive_body_remain, this.ping_frame_bytes.len)); data = data[pong_len..]; receive_state = .need_header; receiving_type = last_receive_data_type; if (data.len == 0) break; }, .need_body => { if (receive_body_remain == 0 and data.len > 0) { this.terminate(ErrorCode.expected_control_frame); terminated = true; break; } if (data.len == 0) return; const to_consume = @min(receive_body_remain, data.len); const consumed = this.consume(data[0..to_consume], receive_body_remain, last_receive_data_type, is_final); if (consumed == 0 and last_receive_data_type == .Text) { this.terminate(ErrorCode.invalid_utf8); terminated = true; break; } receive_body_remain -= consumed; data = data[to_consume..]; if (receive_body_remain == 0) { receive_state = .need_header; is_fragmented = false; } if (data.len == 0) break; }, .close => { // closing frame data is text only. // 2 byte close code if (data.len > 2 and receive_body_remain >= 2) { _ = this.consume(data[2..receive_body_remain], receive_body_remain - 2, .Text, true); data = data[receive_body_remain..]; } this.sendClose(); terminated = true; break; }, .fail => { this.terminate(ErrorCode.unsupported_control_frame); terminated = true; break; }, } } } pub fn sendClose(this: *WebSocket) void { this.sendCloseWithBody(this.tcp, 1001, null, 0); } fn enqueueEncodedBytesMaybeFinal( this: *WebSocket, socket: Socket, bytes: []const u8, is_closing: bool, ) bool { // fast path: no backpressure, no queue, just send the bytes. 
// Anything the socket does not accept immediately is copied into `send_buffer`
// and flushed later by `handleWritable`.
            if (!this.hasBackpressure()) {
                const wrote = socket.write(bytes, !is_closing);
                const expected = @intCast(c_int, bytes.len);
                if (wrote == expected) {
                    return true;
                }

                if (wrote < 0) {
                    this.terminate(ErrorCode.failed_to_write);
                    return false;
                }

                // Partial write: queue only the unsent tail.
                _ = this.copyToSendBuffer(bytes[@intCast(usize, wrote)..], false, is_closing);
                return true;
            }

            return this.copyToSendBuffer(bytes, true, is_closing);
        }

        /// Append already-encoded bytes to `send_buffer`, optionally flushing it.
        fn copyToSendBuffer(this: *WebSocket, bytes: []const u8, do_write: bool, is_closing: bool) bool {
            return this.sendData(.{ .raw = bytes }, do_write, is_closing);
        }

        /// Encode `bytes` into `send_buffer`. When `do_write` is set, try to flush
        /// the whole readable region of `send_buffer` to the socket.
        /// Returns false only when the socket write failed (the connection has
        /// then already been terminated).
        fn sendData(this: *WebSocket, bytes: Copy, do_write: bool, is_closing: bool) bool {
            var content_byte_len: usize = 0;
            const write_len = bytes.len(&content_byte_len);
            std.debug.assert(write_len > 0);

            const writable = this.send_buffer.writableWithSize(write_len) catch unreachable;
            bytes.copy(this.globalThis, writable[0..write_len], content_byte_len);
            this.send_buffer.update(write_len);

            if (do_write) {
                if (comptime Environment.allow_assert) {
                    std.debug.assert(!this.tcp.isShutdown());
                    std.debug.assert(!this.tcp.isClosed());
                    std.debug.assert(this.tcp.isEstablished());
                }
                return this.sendBuffer(this.send_buffer.readableSlice(0), is_closing, !is_closing);
            }

            return true;
        }

        /// Write `out_buf` to the TCP socket. If `out_buf` is the readable head of
        /// `send_buffer`, the bytes the socket accepted are discarded from it.
        /// Returns false (after terminating) when the write fails.
        fn sendBuffer(
            this: *WebSocket,
            out_buf: []const u8,
            is_closing: bool,
            _: bool,
        ) bool {
            std.debug.assert(out_buf.len > 0);
            _ = is_closing;
            // set msg_more to false
            // it seems to improve perf by ~20%
            const wrote = this.tcp.write(out_buf, false);
            if (wrote < 0) {
                this.terminate(ErrorCode.failed_to_write);
                return false;
            }
            const expected = @intCast(usize, wrote);
            const readable = this.send_buffer.readableSlice(0);
            if (readable.ptr == out_buf.ptr) {
                this.send_buffer.discard(expected);
            }

            return true;
        }

        /// Send (or queue) an encoded frame that is not the closing frame.
        fn enqueueEncodedBytes(this: *WebSocket, socket: Socket, bytes: []const u8) bool {
            return this.enqueueEncodedBytesMaybeFinal(socket, bytes, false);
        }

        /// Answer the last received ping with a pong echoing its payload, which is
        /// stored in `ping_frame_bytes[6..]` (`ping_len` bytes). Bytes 0..2 hold the
        /// header and 2..6 the masking key.
        fn sendPong(this: *WebSocket, socket: Socket) bool {
            if (socket.isClosed() or socket.isShutdown()) {
                this.dispatchClose();
                return false;
            }

            var header = @bitCast(WebsocketHeader, @as(u16, 0));
            header.final = true;
            header.opcode = .Pong;

            const to_mask = this.ping_frame_bytes[6..][0..this.ping_len];

            // RFC 6455 §5.1/§5.3: every frame a client sends MUST be masked,
            // including control frames with an empty payload.
            header.mask = true;
            header.len = @truncate(u7, this.ping_len);
            this.ping_frame_bytes[0..2].* = @bitCast([2]u8, header);

            if (to_mask.len > 0) {
                Mask.fill(this.globalThis, this.ping_frame_bytes[2..6], to_mask, to_mask);
                return this.enqueueEncodedBytes(socket, this.ping_frame_bytes[0 .. 6 + @as(usize, this.ping_len)]);
            } else {
                // Nothing to XOR, so the key value is irrelevant; it only has to be present.
                this.ping_frame_bytes[2..6].* = [_]u8{0} ** 4;
                return this.enqueueEncodedBytes(socket, this.ping_frame_bytes[0..6]);
            }
        }

        /// Send a masked Close frame carrying `code` plus up to 123 bytes of `body`
        /// (the reason), after shutting down the read side of `socket`. If the
        /// socket is already closed or shut down, only dispatch the close and clear data.
        fn sendCloseWithBody(
            this: *WebSocket,
            socket: Socket,
            code: u16,
            body: ?*[125]u8,
            body_len: usize,
        ) void {
            log("Sending close with code {d}", .{code});
            if (socket.isClosed() or socket.isShutdown()) {
                this.dispatchClose();
                this.clearData();
                return;
            }

            socket.shutdownRead();

            // RFC 6455 §5.5: a control-frame payload is at most 125 bytes, two of which
            // are the status code. Clamp the reason so the 7-bit length field below can
            // neither wrap around nor take the 126/127 "extended length" values.
            const reason_len: usize = @min(body_len, 123);

            var final_body_bytes: [128 + 8]u8 = undefined;
            var header = @bitCast(WebsocketHeader, @as(u16, 0));
            header.final = true;
            header.opcode = .Close;
            header.mask = true;
            header.len = @truncate(u7, reason_len + 2);
            final_body_bytes[0..2].* = @bitCast([2]u8, @bitCast(u16, header));
            const mask_buf: *[4]u8 = final_body_bytes[2..6];
            std.mem.writeIntSliceBig(u16, final_body_bytes[6..8], code);

            if (body) |data| {
                if (reason_len > 0) @memcpy(final_body_bytes[8..], data, reason_len);
            }

            // we must mask the code
            const slice = final_body_bytes[0..(8 + reason_len)];
            Mask.fill(this.globalThis, mask_buf, slice[6..], slice[6..]);

            if (this.enqueueEncodedBytesMaybeFinal(socket, slice, true)) {
                this.dispatchClose();
                this.clearData();
            }
        }

        /// Socket "end" handler: terminates with `ErrorCode.ended`.
        pub fn handleEnd(this: *WebSocket, socket: Socket) void {
            std.debug.assert(socket.socket == this.tcp.socket);
            this.terminate(ErrorCode.ended);
        }

        /// Socket "writable" handler: flushes whatever is queued in `send_buffer`.
        pub fn handleWritable(
            this: *WebSocket,
            socket: Socket,
        ) void {
            std.debug.assert(socket.socket == this.tcp.socket);
            const send_buf = this.send_buffer.readableSlice(0);
            if (send_buf.len == 0)
                return;
            _ = this.sendBuffer(send_buf, false, true);
        }

        /// Socket timeout handler: terminates with `ErrorCode.timeout`.
        pub fn handleTimeout(
            this: *WebSocket,
            _: Socket,
        ) void {
            this.terminate(ErrorCode.timeout);
        }

        /// Socket connect-error handler: terminates with `ErrorCode.failed_to_connect`.
        pub fn handleConnectError(this: *WebSocket, _: Socket, _: c_int) void {
            this.terminate(ErrorCode.failed_to_connect);
        }

        /// True while any bytes are still queued in `send_buffer`.
        pub fn hasBackpressure(this: *const WebSocket) bool {
            return this.send_buffer.count > 0;
        }

        /// C-exported: send `ptr[0..len]` as a binary message. Zero-length input is ignored.
        pub fn writeBinaryData(
            this: *WebSocket,
            ptr: [*]const u8,
            len: usize,
        ) callconv(.C) void {
            if (this.tcp.isClosed() or this.tcp.isShutdown()) {
                this.dispatchClose();
                return;
            }

            if (len == 0)
                return;

            const slice = ptr[0..len];
            const bytes = Copy{ .bytes = slice };
            // fast path: small frame, no backpressure, attempt to send without allocating
            const frame_size = WebsocketHeader.frameSizeIncludingMask(len);
            if (!this.hasBackpressure() and frame_size < stack_frame_size) {
                var inline_buf: [stack_frame_size]u8 = undefined;
                bytes.copy(this.globalThis, inline_buf[0..frame_size], slice.len);
                _ = this.enqueueEncodedBytes(this.tcp, inline_buf[0..frame_size]);
                return;
            }

            _ = this.sendData(bytes, !this.hasBackpressure(), false);
        }

        /// C-exported: send `str_` (Latin-1 or UTF-16) as a text message.
        /// Empty strings are ignored.
        pub fn writeString(
            this: *WebSocket,
            str_: *const JSC.ZigString,
        ) callconv(.C) void {
            const str = str_.*;
            if (this.tcp.isClosed() or this.tcp.isShutdown()) {
                this.dispatchClose();
                return;
            }

            if (str.len == 0) {
                return;
            }

            {
                var inline_buf: [stack_frame_size]u8 = undefined;

                // fast path: small frame, no backpressure, attempt to send without allocating
                if (!str.is16Bit() and str.len < stack_frame_size) {
                    const bytes = Copy{ .latin1 = str.slice() };
                    // Size the frame with Copy.len, exactly as sendData does on the slow
                    // path, so both paths agree on the encoded payload length.
                    var byte_len: usize = 0;
                    const frame_size = bytes.len(&byte_len);
                    if (!this.hasBackpressure() and frame_size < stack_frame_size) {
                        bytes.copy(this.globalThis, inline_buf[0..frame_size], byte_len);
                        _ = this.enqueueEncodedBytes(this.tcp, inline_buf[0..frame_size]);
                        return;
                    }
                    // A UTF-16 code unit encodes to at most 3 UTF-8 bytes; the 4x bound is a
                    // conservative allowance that presumably also covers the frame header
                    // (the assert below checks the frame really fits).
                } else if ((str.len * 4) < (stack_frame_size) and !this.hasBackpressure()) {
                    const bytes = Copy{ .utf16 = str.utf16SliceAligned() };
                    var byte_len: usize = 0;
                    const frame_size = bytes.len(&byte_len);
                    std.debug.assert(frame_size <= stack_frame_size);
                    bytes.copy(this.globalThis, inline_buf[0..frame_size], byte_len);
                    _ = this.enqueueEncodedBytes(this.tcp, inline_buf[0..frame_size]);
                    return;
                }
            }

            _ = this.sendData(
                if (str.is16Bit())
                    Copy{ .utf16 = str.utf16SliceAligned() }
                else
                    Copy{ .latin1 = str.slice() },
                !this.hasBackpressure(),
                false,
            );
        }

        /// Notify the attached JS WebSocket (if any) with `ErrorCode.closed` and
        /// drop `poll_ref` via unrefOnNextTick.
        fn dispatchClose(this: *WebSocket) void {
            const out = this.outgoing_websocket orelse return;
            this.poll_ref.unrefOnNextTick(this.globalThis.bunVM());
            JSC.markBinding(@src());
            WebSocket__didCloseWithErrorCode(out, ErrorCode.closed);
        }

        /// C-exported: start the closing handshake with `code` and an optional
        /// `reason`. This is a no-op if the TCP socket is already closed or shut down.
        /// If the reason cannot be formatted into the 128-byte scratch buffer, it is omitted.
        pub fn close(this: *WebSocket, code: u16, reason: ?*const JSC.ZigString) callconv(.C) void {
            if (this.tcp.isClosed() or this.tcp.isShutdown())
                return;

            var close_reason_buf: [128]u8 = undefined;
            if (reason) |str| {
                inner: {
                    var fixed_buffer = std.heap.FixedBufferAllocator.init(&close_reason_buf);
                    const allocator = fixed_buffer.allocator();
                    const wrote = std.fmt.allocPrint(allocator, "{}", .{str.*}) catch break :inner;

                    // A close frame only has room for 123 reason bytes (125 minus the 2-byte
                    // code). When trimming, step back over UTF-8 continuation bytes
                    // (0b10xxxxxx) so a multi-byte character is never split; this assumes
                    // the formatted reason is UTF-8.
                    var reason_len: usize = @min(wrote.len, 123);
                    if (reason_len < wrote.len) {
                        while (reason_len > 0 and (wrote[reason_len] & 0xC0) == 0x80) {
                            reason_len -= 1;
                        }
                    }

                    this.sendCloseWithBody(this.tcp, code, wrote.ptr[0..125], reason_len);
                    return;
                }
            }

            this.sendCloseWithBody(this.tcp, code, null, 0);
        }

        /// C-exported constructor. Adopts `input_socket` into `socket_ctx` as a
        /// WebSocket, reserves 2048 bytes each for the send and receive buffers, and
        /// refs the event loop. Any `buffered_data` (presumably bytes that arrived
        /// together with the upgrade response) is parsed later via an event-loop task.
        /// When `buffered_data_len > 0`, the task frees `buffered_data` with
        /// `bun.default_allocator`; the caller must have allocated it that way.
        /// Returns null if adoption or buffer allocation fails.
        pub fn init(
            outgoing: *anyopaque,
            input_socket: *anyopaque,
            socket_ctx: *anyopaque,
            globalThis: *JSC.JSGlobalObject,
            buffered_data: [*]u8,
            buffered_data_len: usize,
        ) callconv(.C) ?*anyopaque {
            const tcp = @ptrCast(*uws.Socket, input_socket);
            const ctx = @ptrCast(*uws.SocketContext, socket_ctx);
            const adopted = Socket.adopt(
                tcp,
                ctx,
                WebSocket,
                "tcp",
                WebSocket{
                    .tcp = undefined,
                    .outgoing_websocket = outgoing,
                    .globalThis = globalThis,
                    .send_buffer = bun.LinearFifo(u8, .Dynamic).init(bun.default_allocator),
                    .receive_buffer = bun.LinearFifo(u8, .Dynamic).init(bun.default_allocator),
                },
            ) orelse return null;
            adopted.send_buffer.ensureTotalCapacity(2048) catch return null;
            adopted.receive_buffer.ensureTotalCapacity(2048) catch return null;
            adopted.poll_ref.ref(globalThis.bunVM());

            const buffered_slice: []u8 = buffered_data[0..buffered_data_len];
            if (buffered_slice.len > 0) {
                const InitialDataHandler = struct {
                    adopted: *WebSocket,
                    slice: []u8,
                    task: JSC.AnyTask = undefined,

                    pub const Handle = JSC.AnyTask.New(@This(), handle);

                    pub fn handle(this: *@This()) void {
                        defer {
                            bun.default_allocator.free(this.slice);
                            bun.default_allocator.destroy(this);
                        }

                        this.adopted.receive_buffer.ensureUnusedCapacity(this.slice.len) catch {
                            // Do not silently drop the buffered bytes; fail the connection.
                            this.adopted.terminate(ErrorCode.failed_to_allocate_memory);
                            return;
                        };
                        const writable = this.adopted.receive_buffer.writableSlice(0);
                        @memcpy(writable.ptr, this.slice.ptr, this.slice.len);
                        // Only the first slice.len bytes were just written; the rest of the
                        // writable region is uninitialized and must not be parsed.
                        this.adopted.handleData(this.adopted.tcp, writable[0..this.slice.len]);
                    }
                };
                const initial_data = bun.default_allocator.create(InitialDataHandler) catch unreachable;
                initial_data.* = .{
                    .adopted = adopted,
                    .slice = buffered_slice,
                };
                initial_data.task = InitialDataHandler.Handle.init(initial_data);
                globalThis.bunVM().eventLoop().enqueueTask(JSC.Task.init(&initial_data.task));
            }

            return @ptrCast(
                *anyopaque,
                adopted,
            );
        }

        /// C-exported: clear buffered data, detach the JS WebSocket, and close the
        /// TCP socket if it is still open.
        pub fn finalize(this: *WebSocket) callconv(.C) void {
            log("finalize", .{});
            this.clearData();

            this.outgoing_websocket = null;

            if (this.tcp.isClosed())
                return;

            this.tcp.close(0, null);
        }

        pub const Export = shim.exportFunctions(.{
            .writeBinaryData = writeBinaryData,
            .writeString = writeString,
            .close = close,
            .register = register,
            .init = init,
            .finalize = finalize,
        });

        comptime {
            if (!JSC.is_bindgen) {
                @export(writeBinaryData, .{ .name = Export[0].symbol_name });
                @export(writeString, .{ .name = Export[1].symbol_name });
                @export(close, .{ .name = Export[2].symbol_name });
                @export(register, .{ .name = Export[3].symbol_name });
                @export(init, .{ .name = Export[4].symbol_name });
                @export(finalize, .{ .name = Export[5].symbol_name });
            }
        }
    };
}

pub const WebSocketHTTPClient = NewHTTPUpgradeClient(false);
pub const WebSocketHTTPSClient = NewHTTPUpgradeClient(true);
pub const WebSocketClient = NewWebSocketClient(false);
pub const WebSocketClientTLS = NewWebSocketClient(true);