// This code is based on https://github.com/frmdstryr/zhp/blob/a4b5700c289c3619647206144e10fb414113a888/src/websocket.zig
// Thank you @frmdstryr.
const std = @import("std");
const native_endian = @import("builtin").target.cpu.arch.endian();

const bun = @import("../global.zig");
const string = bun.string;
const Output = bun.Output;
const Global = bun.Global;
const Environment = bun.Environment;
const strings = bun.strings;
const MutableString = bun.MutableString;
const stringZ = bun.stringZ;
const default_allocator = bun.default_allocator;
const C = bun.C;

const uws = @import("uws");
const JSC = @import("javascript_core");
const PicoHTTP = @import("picohttp");
const ObjectPool = @import("../pool.zig").ObjectPool;
const WebsocketHeader = @import("./websocket.zig").WebsocketHeader;
const WebsocketDataFrame = @import("./websocket.zig").WebsocketDataFrame;
const Opcode = @import("./websocket.zig").Opcode;

/// Builds the HTTP/1.1 upgrade request sent to open a WebSocket connection.
///
/// A random 16-byte nonce is base64-encoded into `Sec-WebSocket-Key`. When
/// `client_protocol` is non-empty, a `Sec-WebSocket-Protocol` header is added
/// and the Wyhash of its value is stored in `client_protocol_hash`; otherwise
/// `client_protocol_hash` is left untouched.
///
/// The returned slice is allocated with `vm.allocator`; the caller owns it.
fn buildRequestBody(vm: *JSC.VirtualMachine, pathname: *const JSC.ZigString, host: *const JSC.ZigString, client_protocol: *const JSC.ZigString, client_protocol_hash: *u64) std.mem.Allocator.Error![]u8 {
    const allocator = vm.allocator;
    var input_rand_buf: [16]u8 = undefined;
    std.crypto.random.bytes(&input_rand_buf);
    const temp_buf_size = comptime std.base64.standard.Encoder.calcSize(16);
    var encoded_buf: [temp_buf_size]u8 = undefined;
    const accept_key = std.base64.standard.Encoder.encode(&encoded_buf, &input_rand_buf);

    var headers = [_]PicoHTTP.Header{
        .{
            .name = "Sec-WebSocket-Key",
            .value = accept_key,
        },
        .{
            .name = "Sec-WebSocket-Protocol",
            .value = client_protocol.slice(),
        },
    };

    if (client_protocol.len > 0)
        client_protocol_hash.* = std.hash.Wyhash.hash(0, headers[1].value);

    // Only include the protocol header when a protocol was requested.
    var headers_: []PicoHTTP.Header = headers[0 .. 1 + @as(usize, @boolToInt(client_protocol.len > 0))];
    const pathname_ = pathname.slice();
    const host_ = host.slice();
    const pico_headers = PicoHTTP.Headers{ .headers = headers_ };
    return try std.fmt.allocPrint(
        allocator,
        "GET {s} HTTP/1.1\r\n" ++
            "Host: {s}\r\n" ++
            "Pragma: no-cache\r\n" ++
            "Cache-Control: no-cache\r\n" ++
            "Connection: Upgrade\r\n" ++
            "Upgrade: websocket\r\n" ++
            "Sec-WebSocket-Version: 13\r\n" ++
            "{any}" ++
            "\r\n",
        .{
            pathname_,
            host_,
            pico_headers,
        },
    );
}

/// Failure reasons reported through `WebSocket__didFailWithErrorCode`.
/// The integer values cross an `extern` boundary, so the order is significant.
const ErrorCode = enum(i32) {
    cancel,
    invalid_response,
    expected_101_status_code,
    missing_upgrade_header,
    missing_connection_header,
    missing_websocket_accept_header,
    invalid_upgrade_header,
    invalid_connection_header,
    invalid_websocket_version,
    mismatch_websocket_accept_header,
    missing_client_protocol,
    mismatch_client_protocol,
    timeout,
    closed,
    failed_to_write,
    failed_to_connect,
    headers_too_large,
    ended,
    failed_to_allocate_memory,
    control_frame_is_fragmented,
    invalid_control_frame,
    compression_unsupported,
    unexpected_mask_from_server,
    expected_control_frame,
    unexpected_opcode,
    invalid_utf8,
};

extern fn WebSocket__didConnect(
    websocket_context: *anyopaque,
    socket: *uws.Socket,
    buffered_data: ?[*]u8,
    buffered_len: usize,
) void;
extern fn WebSocket__didFailWithErrorCode(websocket_context: *anyopaque, reason: ErrorCode) void;
extern fn WebSocket__didReceiveText(websocket_context: *anyopaque, clone: bool, text: *const JSC.ZigString) void;
extern fn WebSocket__didReceiveBytes(websocket_context: *anyopaque, bytes: []const u8) void;

const body_buf_len = 16384 - 16;
const BodyBufBytes = [body_buf_len]u8;

const BodyBufPool = ObjectPool(BodyBufBytes, null, true, 4);
const BodyBuf = BodyBufPool.Node;

/// Returns the client type that performs the HTTP upgrade handshake over a
/// plain (`ssl == false`) or TLS (`ssl == true`) uws socket.
pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
    return struct {
        pub const Socket = uws.NewSocketHandler(ssl);
        tcp: Socket,
        outgoing_websocket: *anyopaque,
        input_body_buf: []u8 = &[_]u8{},
        client_protocol: []const u8 = "",
        to_send: []const u8 = "",
        read_length: usize = 0,
        headers_buf: [128]PicoHTTP.Header = undefined,
        body_buf: ?*BodyBuf = null,
        body_written: usize = 0,
        websocket_protocol: u64 = 0,

        pub const name = if (ssl) "WebSocketHTTPSClient" else "WebSocketHTTPClient";

        pub const shim = JSC.Shimmer("Bun", name, @This());

        const HTTPClient = @This();

        /// Records `loop_` as the VM's uws event loop and installs this
        /// type's socket callbacks on the socket context `ctx_`.
        pub fn register(global: *JSC.JSGlobalObject, loop_: *anyopaque, ctx_: *anyopaque) callconv(.C) void {
            var vm = global.bunVM();
            var loop = @ptrCast(*uws.Loop, loop_);
            var ctx: *uws.us_socket_context_t = @ptrCast(*uws.us_socket_context_t, ctx_);

            if (vm.uws_event_loop) |other| {
                std.debug.assert(other == loop);
            }

            vm.uws_event_loop = loop;

            Socket.configure(ctx, HTTPClient, handleOpen, handleClose, handleData, handleWritable, handleTimeout, handleConnectError, handleEnd);
        }

        /// Builds the upgrade request and starts connecting to `host:port`.
        /// Returns the connecting client, or null when the request could not
        /// be built or the connection could not be started.
        pub fn connect(
            global: *JSC.JSGlobalObject,
            socket_ctx: *anyopaque,
            websocket: *anyopaque,
            host: *const JSC.ZigString,
            port: u16,
            pathname: *const JSC.ZigString,
            client_protocol: *const JSC.ZigString,
        ) callconv(.C) ?*HTTPClient {
            std.debug.assert(global.bunVM().uws_event_loop != null);

            var client_protocol_hash: u64 = 0;
            // NOTE(review): the body is allocated with the VM's allocator but
            // released in clearInput() via bun.default_allocator; confirm
            // these are the same allocator.
            var body = buildRequestBody(global.bunVM(), pathname, host, client_protocol, &client_protocol_hash) catch return null;
            var client: HTTPClient = HTTPClient{
                // The socket field is named `tcp`; it is filled in by
                // Socket.connect (which is told the field name below).
                .tcp = undefined,
                .outgoing_websocket = websocket,
                .input_body_buf = body,
                .websocket_protocol = client_protocol_hash,
            };
            var host_ = host.toSlice(bun.default_allocator);
            defer host_.deinit();

            if (Socket.connect(host_.slice(), port, @ptrCast(*uws.us_socket_context_t, socket_ctx), HTTPClient, client, "tcp")) |out| {
                out.tcp.timeout(120);
                return out;
            }

            client.clearData();

            return null;
        }

        /// Frees the pending request bytes.
        pub fn clearInput(this: *HTTPClient) void {
            if (this.input_body_buf.len > 0) bun.default_allocator.free(this.input_body_buf);
            this.input_body_buf.len = 0;
            // `to_send` is always a sub-slice of `input_body_buf` (see
            // handleOpen); drop it so handleWritable never writes freed memory.
            this.to_send = "";
        }

        /// Releases the request bytes and the pooled response buffer.
        pub fn clearData(this: *HTTPClient) void {
            this.clearInput();
            if (this.body_buf) |buf| {
                this.body_buf = null;
                buf.release();
            }
        }

        /// Releases buffers and closes the socket, whether or not the
        /// connection has been established yet.
        pub fn cancel(this: *HTTPClient) callconv(.C) void {
            this.clearData();

            if (!this.tcp.isEstablished()) {
                _ = uws.us_socket_close_connecting(comptime @as(c_int, @boolToInt(ssl)), this.tcp.socket);
            } else {
                this.tcp.close(0, null);
            }
        }

        /// Reports `code` to the WebSocket object, then cancels the connection.
        pub fn fail(this: *HTTPClient, code: ErrorCode) void {
            JSC.markBinding();
            WebSocket__didFailWithErrorCode(this.outgoing_websocket, code);
            this.cancel();
        }

        /// Socket-closed callback: releases buffers and reports `ended`.
        pub fn handleClose(this: *HTTPClient, _: Socket, _: c_int, _: ?*anyopaque) void {
            JSC.markBinding();
            this.clearData();
            WebSocket__didFailWithErrorCode(this.outgoing_websocket, ErrorCode.ended);
        }

        /// Fails with `code` and makes sure the socket ends up closed.
        pub fn terminate(this: *HTTPClient, code: ErrorCode) void {
            this.fail(code);
            if (!this.tcp.isClosed())
                this.tcp.close(0, null);
        }

        /// Socket-open callback: writes as much of the upgrade request as the
        /// socket accepts and remembers the unsent tail in `to_send`.
        pub fn handleOpen(this: *HTTPClient, socket: Socket) void {
            std.debug.assert(socket.socket == this.tcp.socket);

            std.debug.assert(this.input_body_buf.len > 0);
            std.debug.assert(this.to_send.len == 0);

            const wrote = socket.write(this.input_body_buf, true);
            if (wrote < 0) {
                this.terminate(ErrorCode.failed_to_write);
                return;
            }

            this.to_send = this.input_body_buf[@intCast(usize, wrote)..];
        }

        /// Returns the pooled response buffer, acquiring one on first use.
        fn getBody(this: *HTTPClient) *BodyBufBytes {
            if (this.body_buf == null) {
                this.body_buf = BodyBufPool.get(bun.default_allocator);
            }

            return &this.body_buf.?.data;
        }

        /// Data callback: accumulates response bytes into the pooled buffer,
        /// parses the HTTP response once complete, and hands it to
        /// `processResponse` together with any bytes that followed the headers.
        pub fn handleData(this: *HTTPClient, socket: Socket, data: []const u8) void {
            std.debug.assert(socket.socket == this.tcp.socket);

            if (comptime Environment.allow_assert)
                std.debug.assert(!socket.isShutdown());

            var body = this.getBody();
            var remain = body[this.body_written..];
            const is_first = this.body_written == 0;
            if (is_first and data.len >= "HTTP/1.1 101 ".len) {
                // fail early if we receive a non-101 status code
                if (!strings.eqlComptimeIgnoreLen(data[0.."HTTP/1.1 101 ".len], "HTTP/1.1 101 ")) {
                    this.terminate(ErrorCode.expected_101_status_code);
                    return;
                }
            }

            const to_write = remain[0..@minimum(remain.len, data.len)];
            if (data.len > 0 and to_write.len > 0) {
                @memcpy(remain.ptr, data.ptr, to_write.len);
                this.body_written += to_write.len;
            }

            // Bytes that did not fit into the pooled buffer.
            const overflow = data[to_write.len..];

            const available_to_read = body[0..this.body_written];
            const response = PicoHTTP.Response.parse(available_to_read, &this.headers_buf) catch |err| {
                switch (err) {
                    error.Malformed_HTTP_Response => {
                        this.terminate(ErrorCode.invalid_response);
                        return;
                    },
                    error.ShortRead => {
                        // The buffer is full and the headers are still incomplete.
                        if (overflow.len > 0) {
                            this.terminate(ErrorCode.headers_too_large);
                            return;
                        }
                        return;
                    },
                }
            };

            // Only the bytes between the end of the headers and the end of what
            // was actually written are valid; anything past `body_written` is
            // stale pool memory and must not be forwarded.
            const headers_end = @minimum(@intCast(usize, response.bytes_read), this.body_written);
            const buffered_body_data = body[headers_end..this.body_written];

            this.processResponse(response, buffered_body_data, overflow);
        }

        /// End-of-stream callback: the peer closed before the upgrade finished.
        pub fn handleEnd(this: *HTTPClient, socket: Socket) void {
            std.debug.assert(socket.socket == this.tcp.socket);
            this.terminate(ErrorCode.ended);
        }

        /// Validates the upgrade response headers and, on success, passes the
        /// socket plus any already-received frame bytes (`remain_buf` followed
        /// by `overflow_buf`) to `WebSocket__didConnect`.
        pub fn processResponse(this: *HTTPClient, response: PicoHTTP.Response, remain_buf: []const u8, overflow_buf: []const u8) void {
            std.debug.assert(this.body_written > 0);

            var upgrade_header = PicoHTTP.Header{ .name = "", .value = "" };
            var connection_header = PicoHTTP.Header{ .name = "", .value = "" };
            var websocket_accept_header = PicoHTTP.Header{ .name = "", .value = "" };
            var visited_protocol = this.websocket_protocol == 0;
            // var visited_version = false;

            // The status code comes from the server. handleData's early check
            // is skipped when the first chunk is shorter than the status-line
            // prefix, so verify it here instead of asserting.
            if (response.status_code != 101) {
                this.terminate(ErrorCode.expected_101_status_code);
                return;
            }

            // Both `remain_buf` and `overflow_buf` may be non-empty (headers
            // ended inside a full buffer); the copy below handles that case.

            for (response.headers) |header| {
                // Dispatch on name length first, then compare case-insensitively.
                switch (header.name.len) {
                    "Connection".len => {
                        if (connection_header.name.len == 0 and strings.eqlCaseInsensitiveASCII(header.name, "Connection", false)) {
                            connection_header = header;
                            if (visited_protocol and upgrade_header.name.len > 0 and connection_header.name.len > 0 and websocket_accept_header.name.len > 0) {
                                break;
                            }
                        }
                    },
                    "Upgrade".len => {
                        if (upgrade_header.name.len == 0 and strings.eqlCaseInsensitiveASCII(header.name, "Upgrade", false)) {
                            upgrade_header = header;
                            if (visited_protocol and upgrade_header.name.len > 0 and connection_header.name.len > 0 and websocket_accept_header.name.len > 0) {
                                break;
                            }
                        }
                    },
                    "Sec-WebSocket-Version".len => {
                        if (strings.eqlCaseInsensitiveASCII(header.name, "Sec-WebSocket-Version", false)) {
                            // Length-checked compare: the value is server
                            // controlled and may be shorter than "13".
                            if (!strings.eqlComptime(header.value, "13")) {
                                this.terminate(ErrorCode.invalid_websocket_version);
                                return;
                            }
                        }
                    },
                    "Sec-WebSocket-Accept".len => {
                        if (websocket_accept_header.name.len == 0 and strings.eqlCaseInsensitiveASCII(header.name, "Sec-WebSocket-Accept", false)) {
                            websocket_accept_header = header;
                            if (visited_protocol and upgrade_header.name.len > 0 and connection_header.name.len > 0 and websocket_accept_header.name.len > 0) {
                                break;
                            }
                        }
                    },
                    "Sec-WebSocket-Protocol".len => {
                        if (strings.eqlCaseInsensitiveASCII(header.name, "Sec-WebSocket-Protocol", false)) {
                            // The server must echo exactly the protocol we asked for.
                            if (this.websocket_protocol == 0 or std.hash.Wyhash.hash(0, header.value) != this.websocket_protocol) {
                                this.terminate(ErrorCode.mismatch_client_protocol);
                                return;
                            }
                            visited_protocol = true;
                            if (visited_protocol and upgrade_header.name.len > 0 and connection_header.name.len > 0 and websocket_accept_header.name.len > 0) {
                                break;
                            }
                        }
                    },
                    else => {},
                }
            }

            // if (!visited_version) {
            //     this.terminate(ErrorCode.invalid_websocket_version);
            //     return;
            // }

            if (@minimum(upgrade_header.name.len, upgrade_header.value.len) == 0) {
                this.terminate(ErrorCode.missing_upgrade_header);
                return;
            }

            if (@minimum(connection_header.name.len, connection_header.value.len) == 0) {
                this.terminate(ErrorCode.missing_connection_header);
                return;
            }

            if (@minimum(websocket_accept_header.name.len, websocket_accept_header.value.len) == 0) {
                this.terminate(ErrorCode.missing_websocket_accept_header);
                return;
            }

            if (!visited_protocol) {
                this.terminate(ErrorCode.mismatch_client_protocol);
                return;
            }

            if (!strings.eqlComptime(connection_header.value, "Upgrade")) {
                this.terminate(ErrorCode.invalid_connection_header);
                return;
            }

            if (!strings.eqlComptime(upgrade_header.value, "websocket")) {
                this.terminate(ErrorCode.invalid_upgrade_header);
                return;
            }

            // TODO: check websocket_accept_header.value

            const overflow_len = overflow_buf.len + remain_buf.len;
            var overflow: []u8 = &.{};
            if (overflow_len > 0) {
                overflow = bun.default_allocator.alloc(u8, overflow_len) catch {
                    this.terminate(ErrorCode.failed_to_allocate_memory);
                    return;
                };
                if (remain_buf.len > 0) @memcpy(overflow.ptr, remain_buf.ptr, remain_buf.len);
                if (overflow_buf.len > 0) @memcpy(overflow.ptr + remain_buf.len, overflow_buf.ptr, overflow_buf.len);
            }

            // `remain_buf` points into the pooled buffer, so release it only
            // after the copy above.
            this.clearData();
            JSC.markBinding();
            WebSocket__didConnect(this.outgoing_websocket, this.tcp.socket, overflow.ptr, overflow.len);
        }

        /// Writable callback: continues sending the unsent tail of the request.
        pub fn handleWritable(
            this: *HTTPClient,
            socket: Socket,
        ) void {
            std.debug.assert(socket.socket == this.tcp.socket);

            if (this.to_send.len == 0)
                return;

            const wrote = socket.write(this.to_send, true);
            if (wrote < 0) {
                this.terminate(ErrorCode.failed_to_write);
                return;
            }
            // A partial write is legal; writing more than requested is not.
            std.debug.assert(@intCast(usize, wrote) <= this.to_send.len);
            this.to_send = this.to_send[@minimum(@intCast(usize, wrote), this.to_send.len)..];
        }

        pub fn handleTimeout(
            this: *HTTPClient,
            _: Socket,
        ) void {
            this.terminate(ErrorCode.timeout);
        }

        pub fn handleConnectError(this: *HTTPClient, _: Socket, _: c_int) void {
            this.terminate(ErrorCode.failed_to_connect);
        }

        pub const Export = shim.exportFunctions(.{
            .connect = connect,
            .cancel = cancel,
            .register = register,
        });

        comptime {
            if (!JSC.is_bindgen) {
                @export(connect, .{
                    .name = Export[0].symbol_name,
                });
                @export(cancel, .{
                    .name = Export[1].symbol_name,
                });
                @export(register, .{
                    .name = Export[2].symbol_name,
                });
            }
        }
    };
}

pub const Mask = struct {
    const rand_buffer_size = 64;
    bytes: [rand_buffer_size]u8 = [_]u8{0} ** rand_buffer_size,
    offset: usize = 0,
    needs_reset: bool = true,
    queued_reset: bool = false,

    pub fn from(globalObject: *JSC.JSGlobalObject) *Mask {
        return &globalObject.bunVM().rareData().websocket_mask;
    }

    pub fn get(this: *Mask) [4]u8 {
        if (this.needs_reset) {
            this.needs_reset = false;
            this.offset = 0;
            std.crypto.random.bytes(&this.bytes);
        }

        const offset = this.offset;
        const wrapped = offset % rand_buffer_size;
        var mask =
this.bytes[wrapped..][0..4].*;
        // Once the random buffer has wrapped, perturb the reused bytes with
        // the wrapped offset (wrapping add) until the next reset.
        if (offset > rand_buffer_size) {
            const wrapped2 = @truncate(u8, wrapped);
            mask[0] +%= wrapped2;
            mask[1] +%= wrapped2;
            mask[2] +%= wrapped2;
            mask[3] +%= wrapped2;
        }

        this.offset += 4;

        // Schedule a refill of the random bytes on the next event-loop tick
        // each time a whole buffer's worth of masks has been handed out.
        if (!this.queued_reset and this.offset % rand_buffer_size == 0) {
            this.queued_reset = true;
            uws.Loop.get().?.nextTick(*Mask, this, reset);
        }

        return mask;
    }

    /// Takes a fresh mask, stores it in `mask_buf`, and writes `input_` XORed
    /// with the repeating 4-byte mask into `output_`. `output_` must be at
    /// least as long as `input_`; the two may be the same slice.
    pub fn fill(this: *Mask, mask_buf: *[4]u8, output_: []u8, input_: []const u8) void {
        const mask = this.get();
        mask_buf.* = mask;

        var input = input_;
        var output = output_;

        if (comptime Environment.isAarch64 or Environment.isX64) {
            if (input.len >= strings.ascii_vector_size) {
                // Repeat the 4-byte mask across a whole vector.
                const vec: strings.AsciiVector = @as(strings.AsciiVector, mask ** (strings.ascii_vector_size / 4));
                const end_ptr_wrapped_to_last_16 = input.ptr + input.len - (input.len % strings.ascii_vector_size);
                while (input.ptr != end_ptr_wrapped_to_last_16) {
                    const input_vec: strings.AsciiVector = @as(strings.AsciiVector, input[0..strings.ascii_vector_size].*);
                    output.ptr[0..strings.ascii_vector_size].* = input_vec ^ vec;
                    output = output[strings.ascii_vector_size..];
                    input = input[strings.ascii_vector_size..];
                }
            }

            // hint to the compiler not to vectorize the next loop
            std.debug.assert(input.len < strings.ascii_vector_size);
        }

        while (input.len >= 4) {
            const input_vec: [4]u8 = input[0..4].*;
            output.ptr[0..4].* = [4]u8{
                input_vec[0] ^ mask[0],
                input_vec[1] ^ mask[1],
                input_vec[2] ^ mask[2],
                input_vec[3] ^ mask[3],
            };
            output = output[4..];
            input = input[4..];
        }

        // Fewer than four bytes remain.
        for (input) |c, i| {
            output[i] = c ^ mask[i % 4];
        }
    }

    /// Next-tick callback: marks the random bytes for regeneration on the
    /// next `get`.
    pub fn reset(this: *Mask) void {
        this.queued_reset = false;
        this.needs_reset = true;
    }
};

/// Position of the frame parser between (and within) incoming frames.
const ReceiveState = enum {
    need_header,
    need_mask,
    need_body,
    extended_payload_length_16,
    extended_payload_length_64,
    ping,
    // `pong` and `close` are the tags parseWebSocketHeader returns and the
    // receive loop switches on; the previous `closing` tag was referenced by
    // neither and made that switch non-exhaustive.
    pong,
    close,
    fail,

    pub fn needControlFrame(this: ReceiveState) bool {
        return this != .need_body;
    }
};
const DataType = enum {
    none,
    text,
    binary,
};

/// Decodes the two fixed header bytes of a frame into the out-parameters and
/// returns the state the receiver should enter next.
///
/// `payload_length` receives the raw 7-bit length field (126/127 mean an
/// extended length follows). `is_fragmented` is set for continuation frames
/// and for any frame without the FIN bit, so callers can reject fragmented
/// control frames.
fn parseWebSocketHeader(
    bytes: [2]u8,
    receiving_type: *Opcode,
    payload_length: *usize,
    is_fragmented: *bool,
    is_final: *bool,
    need_compression: *bool,
) ReceiveState {
    // 0                   1                   2                   3
    // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    // +-+-+-+-+-------+-+-------------+-------------------------------+
    // |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
    // |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
    // |N|V|V|V|       |S|             |   (if payload len==126/127)   |
    // | |1|2|3|       |K|             |                               |
    // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
    // |     Extended payload length continued, if payload len == 127  |
    // + - - - - - - - - - - - - - - - +-------------------------------+
    // |                               |Masking-key, if MASK set to 1  |
    // +-------------------------------+-------------------------------+
    // | Masking-key (continued)       |          Payload Data         |
    // +-------------------------------- - - - - - - - - - - - - - - - +
    // :                     Payload Data continued ...                :
    // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
    // |                     Payload Data continued ...                |
    // +---------------------------------------------------------------+
    const header = @bitCast(WebsocketHeader, @bitCast(u16, bytes));
    const payload = @as(usize, header.len);
    payload_length.* = payload;
    receiving_type.* = header.opcode;
    // A frame is part of a fragmented message when it is a continuation or
    // when FIN is clear. Checking only the opcode meant a control frame with
    // FIN clear was never detected as fragmented.
    is_fragmented.* = switch (header.opcode) {
        .Continue => true,
        else => !header.final,
    };
    is_final.* = header.final;
    need_compression.* = header.compressed;

    // Servers must not mask frames; the caller treats this state as an error.
    if (header.mask) {
        return .need_mask;
    }

    return switch (header.opcode) {
        .Text, .Continue, .Binary => switch (payload) {
            0...125 => ReceiveState.need_body,
            126 => ReceiveState.extended_payload_length_16,
            127 => ReceiveState.extended_payload_length_64,
            // `payload` comes from a 7-bit field.
            else => unreachable,
        },
        .Close => ReceiveState.close,
        .Ping => ReceiveState.ping,
        .Pong => ReceiveState.pong,
        else => ReceiveState.fail,
    };
}

/// An outgoing payload together with the way it must be written into the
/// send buffer: text that is transcoded to UTF-8 and framed (`utf16`,
/// `latin1`), binary data that is framed as-is (`bytes`), or bytes that are
/// already a complete encoded frame (`raw`).
const Copy = union(enum) {
    utf16: []const u16,
    latin1: []const u8,
    bytes: []const u8,
    raw: []const u8,

    /// Returns the number of bytes `copy` will write. For framed variants
    /// `byte_len` receives the payload size in bytes; for `raw` it is left
    /// untouched.
    pub fn len(this: @This(), byte_len: *usize) usize {
        switch (this) {
            .utf16 => {
                byte_len.* = strings.elementLengthUTF16IntoUTF8([]const u16, this.utf16);
                return WebsocketHeader.frameSizeIncludingMask(byte_len.*);
            },
            .latin1 => {
                byte_len.* = strings.elementLengthLatin1IntoUTF8([]const u8, this.latin1);
                return WebsocketHeader.frameSizeIncludingMask(byte_len.*);
            },
            .bytes => {
                byte_len.* = this.bytes.len;
                return WebsocketHeader.frameSizeIncludingMask(byte_len.*);
            },
            .raw => return this.raw.len,
        }
    }

    /// Writes a final, masked, uncompressed frame header for a payload of
    /// `content_byte_len` bytes at the start of `buf`.
    fn writeMaskedHeader(buf: []u8, opcode: Opcode, content_byte_len: usize) void {
        var header = @bitCast(WebsocketHeader, @as(u16, 0));
        header.len = WebsocketHeader.packLength(content_byte_len);
        header.mask = true;
        header.opcode = opcode;
        header.compressed = false;
        header.final = true;
        header.writeHeader(std.io.fixedBufferStream(buf), content_byte_len) catch unreachable;
    }

    /// Writes the payload into `buf`, which must be at least
    /// `len()` bytes long. Framed variants are laid out as
    /// header (2 bytes) | length bytes | mask (4 bytes) | masked payload.
    pub fn copy(this: @This(), globalThis: *JSC.JSGlobalObject, buf: []u8, content_byte_len: usize) void {
        switch (this) {
            .utf16 => |utf16| {
                const length_offset = 2;
                const length_length = WebsocketHeader.lengthByteCount(content_byte_len);
                const mask_offset = length_offset + length_length;
                const content_offset = mask_offset + 4;
                var to_mask = buf[content_offset..];
                // Transcode straight into the payload area, then mask in place.
                const encode_into_result = strings.copyUTF16IntoUTF8(utf16, to_mask);
                std.debug.assert(@as(usize, encode_into_result.written) == content_byte_len);
                std.debug.assert(@as(usize, encode_into_result.read) == utf16.len);
                writeMaskedHeader(buf, Opcode.Text, content_byte_len);

                Mask.from(globalThis).fill(buf[mask_offset..][0..4], to_mask, to_mask);
            },
            .latin1 => |latin1| {
                const length_offset = 2;
                const length_length = WebsocketHeader.lengthByteCount(content_byte_len);
                const mask_offset = length_offset + length_length;
                const content_offset = mask_offset + 4;
                var to_mask = buf[content_offset..];

                const encode_into_result = strings.copyLatin1IntoUTF8(latin1, to_mask);
                std.debug.assert(@as(usize, encode_into_result.written) == content_byte_len);
                std.debug.assert(@as(usize, encode_into_result.read) == latin1.len);
                writeMaskedHeader(buf, Opcode.Text, content_byte_len);

                Mask.from(globalThis).fill(buf[mask_offset..][0..4], to_mask, to_mask);
            },
            .bytes => |bytes| {
                const length_offset = 2;
                const length_length = WebsocketHeader.lengthByteCount(bytes.len);
                const mask_offset = length_offset + length_length;
                const content_offset = mask_offset + 4;
                var to_mask = buf[content_offset..][0..bytes.len];
                // Arbitrary bytes are a binary message, not text.
                writeMaskedHeader(buf, Opcode.Binary, bytes.len);

                // Mask from the source bytes into the frame. Masking the
                // destination in place would emit whatever the buffer held,
                // because nothing else copies `bytes` into it.
                Mask.from(globalThis).fill(buf[mask_offset..][0..4], to_mask, bytes);
            },
            .raw => @memcpy(buf.ptr, this.raw.ptr, this.raw.len),
        }
    }
};

pub fn NewWebSocketClient(comptime ssl: bool) type {
    return struct {
        pub const Socket = uws.NewSocketHandler(ssl);
        tcp: Socket,
        outgoing_websocket: ?*anyopaque,

        receive_state: ReceiveState = ReceiveState.need_header,
        receive_header: WebsocketHeader = @bitCast(WebsocketHeader, @as(u16, 0)),
        receive_remaining: usize = 0,
        receiving_type: Opcode = Opcode.ResB,

        // `**` binds tighter than `+`, so the repeat count needs parentheses
        // to produce the declared [128 + 6]u8.
        ping_frame_bytes: [128 + 6]u8 = [_]u8{0} ** (128 + 6),
        ping_len: u8 = 0,

        receive_frame: usize = 0,
        receive_body_remain: usize = 0,
        receive_pending_chunk_len: usize = 0,
        receive_body_buf: ?*BodyBuf = null,
        receive_overflow_buffer: std.ArrayListUnmanaged(u8) = .{},

        send_overflow_buffer: std.ArrayListUnmanaged(u8) = .{},
        send_body_buf: ?*BodyBuf = null,

        send_len: usize = 0,
        send_off: usize = 0,

        globalThis: *JSC.JSGlobalObject,

        pub const name = if (ssl) "WebSocketClientTLS" else "WebSocketClient";

        pub const shim = JSC.Shimmer("Bun", name, @This());

        const HTTPClient = @This();

        pub fn register(global: *JSC.JSGlobalObject, loop_: *anyopaque, parent: *anyopaque, ctx_: *anyopaque) callconv(.C) void {
            var vm = global.bunVM();
            var loop = @ptrCast(*uws.Loop, loop_);
            var ctx: *uws.us_socket_context_t =
@ptrCast(*uws.us_socket_context_t, ctx_);

            if (vm.uws_event_loop) |other| {
                std.debug.assert(other == loop);
            }

            vm.uws_event_loop = loop;

            Socket.configureChild(
                ctx,
                parent,
                HTTPClient,
                null,
                handleClose,
                handleData,
                handleWritable,
                handleTimeout,
                handleConnectError,
                handleEnd,
            );
        }

        /// Releases all receive/send buffers and resets the byte counters.
        pub fn clearData(this: *HTTPClient) void {
            this.clearReceiveBuffers(true);
            this.clearSendBuffers(true);
            this.ping_len = 0;
            this.receive_pending_chunk_len = 0;
            this.receive_remaining = 0;
        }

        /// Releases buffers and closes the socket unless it is already
        /// closed or shut down.
        pub fn cancel(this: *HTTPClient) callconv(.C) void {
            this.clearData();

            if (this.tcp.isClosed() or this.tcp.isShutdown())
                return;

            if (!this.tcp.isEstablished()) {
                _ = uws.us_socket_close_connecting(comptime @as(c_int, @boolToInt(ssl)), this.tcp.socket);
            } else {
                this.tcp.close(0, null);
            }
        }

        /// Reports `code` to the WebSocket object (if any), then cancels.
        pub fn fail(this: *HTTPClient, code: ErrorCode) void {
            JSC.markBinding();
            if (this.outgoing_websocket) |ws|
                WebSocket__didFailWithErrorCode(ws, code);
            this.cancel();
        }

        /// Socket-closed callback: releases buffers and reports `ended`.
        pub fn handleClose(this: *HTTPClient, _: Socket, _: c_int, _: ?*anyopaque) void {
            JSC.markBinding();
            this.clearData();
            if (this.outgoing_websocket) |ws|
                WebSocket__didFailWithErrorCode(ws, ErrorCode.ended);
        }

        pub fn terminate(this: *HTTPClient, code: ErrorCode) void {
            this.fail(code);
        }

        /// Returns the pooled send buffer, acquiring one on first use.
        fn getBody(this: *HTTPClient) *BodyBufBytes {
            if (this.send_body_buf == null) {
                this.send_body_buf = BodyBufPool.get(bun.default_allocator);
            }

            return &this.send_body_buf.?.data;
        }

        /// Returns the pooled receive buffer, acquiring one on first use.
        fn getReceiveBody(this: *HTTPClient) *BodyBufBytes {
            if (this.receive_body_buf == null) {
                this.receive_body_buf = BodyBufPool.get(bun.default_allocator);
            }

            return &this.receive_body_buf.?.data;
        }

        /// Returns the pooled receive buffer to the pool and empties the
        /// overflow list, freeing its memory when `free` is true.
        fn clearReceiveBuffers(this: *HTTPClient, free: bool) void {
            if (this.receive_body_buf) |receive_buf| {
                receive_buf.release();
                this.receive_body_buf = null;
            }

            if (free) {
                this.receive_overflow_buffer.clearAndFree(bun.default_allocator);
            } else {
                this.receive_overflow_buffer.clearRetainingCapacity();
            }
        }

        /// Send-side counterpart of `clearReceiveBuffers`; also resets the
        /// send offset and length.
        fn clearSendBuffers(this: *HTTPClient, free: bool) void {
            if (this.send_body_buf) |buf| {
                buf.release();
                this.send_body_buf = null;
            }

            if (free) {
                this.send_overflow_buffer.clearAndFree(bun.default_allocator);
            } else {
                this.send_overflow_buffer.clearRetainingCapacity();
            }
            this.send_off = 0;
            this.send_len = 0;
        }

        /// Delivers one complete message to the WebSocket object and resets
        /// the receive buffers. `kind` must be `.Text` or `.Binary`.
        fn dispatchData(this: *HTTPClient, data_: []const u8, kind: Opcode) void {
            // `outgoing_websocket` is optional but the extern callbacks take
            // a non-optional pointer; with nobody to deliver to, drop the data.
            const ws = this.outgoing_websocket orelse {
                this.clearReceiveBuffers(false);
                return;
            };

            switch (kind) {
                .Text => {
                    // this function encodes to UTF-16 if > 127
                    // so we don't need to worry about latin1 non-ascii code points
                    const utf16_bytes_ = strings.toUTF16Alloc(bun.default_allocator, data_, true) catch {
                        this.terminate(ErrorCode.invalid_utf8);
                        return;
                    };
                    defer this.clearReceiveBuffers(false);
                    var outstring = JSC.ZigString.Empty;
                    if (utf16_bytes_) |utf16| {
                        outstring = JSC.ZigString.from16Slice(utf16);
                        outstring.markUTF16();
                        WebSocket__didReceiveText(ws, false, &outstring);
                    } else {
                        // No transcoding was needed; `data_` points into our
                        // own buffers, so ask the receiver to clone it.
                        outstring = JSC.ZigString.init(data_);
                        WebSocket__didReceiveText(ws, true, &outstring);
                    }
                },
                .Binary => {
                    WebSocket__didReceiveBytes(ws, data_);
                    this.clearReceiveBuffers(false);
                },
                else => unreachable,
            }
        }

        /// Accepts `data_`, a chunk of a frame body of which `max` bytes are
        /// still outstanding. Dispatches immediately when the chunk completes
        /// a final frame, otherwise buffers it. Returns the number of bytes
        /// consumed, or 0 when buffering failed.
        pub fn consume(this: *HTTPClient, data_: []const u8, max: usize, kind: Opcode, is_final: bool) usize {
            std.debug.assert(kind == .Text or kind == .Binary);
            std.debug.assert(data_.len <= max);

            const can_dispatch_data = is_final and data_.len == max;

            // did all the data fit in the buffer?
            // we can avoid copying & allocating a temporary buffer
            if (can_dispatch_data and this.receive_pending_chunk_len == 0) {
                this.dispatchData(data_, kind);
                return data_.len;
            }

            // if we previously allocated a buffer and there's room, attempt to use that one
            const new_pending_chunk_len = max + this.receive_pending_chunk_len;
            if (new_pending_chunk_len <= this.receive_overflow_buffer.capacity) {
                @memcpy(this.receive_overflow_buffer.items.ptr + this.receive_overflow_buffer.items.len, data_.ptr, data_.len);
                this.receive_overflow_buffer.items.len += data_.len;
                if (can_dispatch_data) {
                    this.dispatchData(this.receive_overflow_buffer.items, kind);
                    this.receive_pending_chunk_len = 0;
                } else {
                    this.receive_pending_chunk_len = this.receive_overflow_buffer.items.len;
                }
                return data_.len;
            }

            if (new_pending_chunk_len <= body_buf_len) {
                // if our previously-allocated buffer is too small or we don't have one, use from the pool
                var body = this.getReceiveBody();
                // `body` is a pointer to an array; slice it to get a
                // many-item pointer at the pending offset.
                @memcpy(body[this.receive_pending_chunk_len..].ptr, data_.ptr, data_.len);
                if (can_dispatch_data) {
                    this.dispatchData(body[0..new_pending_chunk_len], kind);
                    this.receive_pending_chunk_len = 0;
                } else {
                    this.receive_pending_chunk_len += data_.len;
                }
                return data_.len;
            }

            {
                // we need to copy the data into a potentially large temporary buffer
                this.receive_overflow_buffer.appendSlice(bun.default_allocator, data_) catch {
                    this.terminate(ErrorCode.failed_to_allocate_memory);
                    return 0;
                };
                if (can_dispatch_data) {
                    this.dispatchData(this.receive_overflow_buffer.items, kind);
                    this.receive_pending_chunk_len = 0;
                } else {
                    this.receive_pending_chunk_len = this.receive_overflow_buffer.items.len;
                }
                return data_.len;
            }
        }

        /// Data callback: runs the frame state machine over `data_`.
        /// Parser state is kept in locals and written back on exit unless the
        /// connection was terminated; a pending ping is answered on exit.
        pub fn handleData(this: *HTTPClient, socket: Socket, data_: []const u8) void {
            var data = data_;
            var receive_state = this.receive_state;
            var terminated = false;
            var is_fragmented = false;
            var receiving_type = this.receiving_type;
            var receive_body_remain = this.receive_body_remain;
            // NOTE(review): `is_final` is not persisted across calls, so a
            // final frame whose body spans several reads is only buffered.
            var is_final = false;
            var last_receive_data_type = receiving_type;

            defer {
                if (!terminated) {
                    this.receive_state = receive_state;
                    this.receiving_type = last_receive_data_type;
                    this.receive_body_remain = receive_body_remain;

                    // if we receive multiple pings in a row
                    // we just send back the last one
                    if (this.ping_len > 0) {
                        _ = this.sendPong(socket);
                        this.ping_len = 0;
                    }
                }
            }

            var header_bytes: [@sizeOf(usize)]u8 = [_]u8{0} ** @sizeOf(usize);

            // Every `break` below leaves this loop (a switch is not a loop).
            while (true) {
                switch (receive_state) {
                    // 0                   1                   2                   3
                    // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
                    // +-+-+-+-+-------+-+-------------+-------------------------------+
                    // |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
                    // |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
                    // |N|V|V|V|       |S|             |   (if payload len==126/127)   |
                    // | |1|2|3|       |K|             |                               |
                    // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
                    // |     Extended payload length continued, if payload len == 127  |
                    // + - - - - - - - - - - - - - - - +-------------------------------+
                    // |                               |Masking-key, if MASK set to 1  |
                    // +-------------------------------+-------------------------------+
                    // | Masking-key (continued)       |          Payload Data         |
                    // +-------------------------------- - - - - - - - - - - - - - - - +
                    // :                     Payload Data continued ...                :
                    // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
                    // |                     Payload Data continued ...                |
                    // +---------------------------------------------------------------+
                    .need_header => {
                        if (data.len < 2) {
                            this.terminate(ErrorCode.control_frame_is_fragmented);
                            terminated = true;
                            break;
                        }

                        header_bytes[0..2].* = data[0..2].*;
                        receive_body_remain = 0;
                        var need_compression = false;
                        is_final = false;

                        // parseWebSocketHeader takes the two bytes by value.
                        receive_state = parseWebSocketHeader(
                            header_bytes[0..2].*,
                            &receiving_type,
                            &receive_body_remain,
                            &is_fragmented,
                            &is_final,
                            &need_compression,
                        );

                        if (receiving_type == .Text or receiving_type == .Binary) {
                            last_receive_data_type = receiving_type;
                        }

                        data = data[2..];

                        if (receiving_type.isControl() and is_fragmented) {
                            // Control frames must not be fragmented.
                            this.terminate(ErrorCode.control_frame_is_fragmented);
                            terminated = true;
                            break;
                        }

                        switch (receiving_type) {
                            .Continue, .Text, .Binary, .Ping, .Pong, .Close => {},
                            else => {
                                this.terminate(ErrorCode.unexpected_opcode);
                                terminated = true;
                                break;
                            },
                        }

                        if (need_compression) {
                            this.terminate(ErrorCode.compression_unsupported);
                            terminated = true;
                            break;
                        }
                    },
                    .need_mask => {
                        this.terminate(.unexpected_mask_from_server);
                        terminated = true;
                        break;
                    },
                    .extended_payload_length_64, .extended_payload_length_16 => |rc| {
                        const byte_size = switch (rc) {
                            .extended_payload_length_64 => @as(usize, 8),
                            .extended_payload_length_16 => @as(usize, 2),
                            else => unreachable,
                        };

                        if (data.len < byte_size) {
                            this.terminate(ErrorCode.control_frame_is_fragmented);
                            terminated = true;
                            break;
                        }

                        // Multibyte length quantities are expressed in network byte order
                        // (readIntBig takes a pointer to the bytes, not the array value)
                        receive_body_remain = switch (byte_size) {
                            8 => @as(usize, std.mem.readIntBig(u64, data[0..8])),
                            2 => @as(usize, std.mem.readIntBig(u16, data[0..2])),
                            else => unreachable,
                        };
                        data = data[byte_size..];

                        receive_state = .need_body;

                        if (receive_body_remain == 0) {
                            // this is an error
                            // the server should've set length to zero
                            this.terminate(ErrorCode.invalid_control_frame);
                            terminated = true;
                            break;
                        }
                    },
                    .ping => {
                        const ping_len = @minimum(data.len, @minimum(receive_body_remain, 125));
                        this.ping_len = @truncate(u8, ping_len);

                        if (ping_len > 0) {
                            // The first 6 bytes are reserved for the pong
                            // header and mask; the payload goes after them.
                            @memcpy(this.ping_frame_bytes[6..][0..ping_len].ptr, data.ptr, ping_len);
                            data = data[ping_len..];
                        }

                        receive_state = .need_header;
                        receive_body_remain = 0;
                        receiving_type = last_receive_data_type;

                        if (data.len == 0) break;
                    },
                    .pong => {
                        // Pong payloads are skipped.
                        const pong_len = @minimum(data.len, @minimum(receive_body_remain, this.ping_frame_bytes.len));
                        data = data[pong_len..];
                        receive_state = .need_header;
                        receiving_type = last_receive_data_type;
                        if (data.len == 0) break;
                    },
                    .need_body => {
                        if (receive_body_remain == 0 and data.len > 0) {
                            this.terminate(ErrorCode.expected_control_frame);
                            terminated = true;
                            break;
                        }

                        if (data.len == 0) return;

                        const to_consume = @minimum(receive_body_remain, data.len);

                        const consumed = this.consume(data[0..to_consume], receive_body_remain, last_receive_data_type, is_final);
                        if (consumed == 0 and last_receive_data_type == .Text) {
                            this.terminate(ErrorCode.invalid_utf8);
                            terminated = true;
                            break;
                        }

                        receive_body_remain -= consumed;
                        data = data[to_consume..];
                        if (receive_body_remain == 0) {
                            receive_state = .need_header;
                            is_fragmented = false;
                        }

                        if (data.len == 0) break;
                    },

                    .close => {
                        // closing frame data is text only.
                        // Clamp to what has actually arrived: the declared
                        // length is server controlled and may exceed `data`.
                        const close_len = @minimum(receive_body_remain, data.len);
                        _ = this.consume(data[0..close_len], receive_body_remain, .Text, true);
                        this.sendClose();
                        terminated = true;
                        break;
                    },
                    .fail => {
                        // Reached for opcodes parseWebSocketHeader does not
                        // recognise. ErrorCode has no `unexpected_control_frame`
                        // member, so report it as an unexpected opcode.
                        this.terminate(ErrorCode.unexpected_opcode);
                        terminated = true;
                        break;
                    },
                }
            }
        }

        pub fn sendClose(this: *HTTPClient) void {
            this.sendCloseWithBody(this.tcp, 1001, null, null, 0);
        }

        /// Sends already-encoded frame bytes, queueing whatever the socket
        /// does not accept right away. Returns false when the write failed.
        fn enqueueEncodedBytesMaybeFinal(
            this: *HTTPClient,
            socket: Socket,
            bytes: []const u8,
            is_closing: bool,
        ) bool {
            // fast path: no backpressure, no queue, just send the bytes.
            if (this.send_len == 0) {
                const wrote = socket.write(bytes, !is_closing);
                const expected = @intCast(c_int, bytes.len);
                if (wrote == expected) {
                    return true;
                }

                if (wrote < 0) {
                    this.terminate(ErrorCode.failed_to_write);
                    return false;
                }

                // Queue the unsent tail, i.e. everything after the first
                // `wrote` bytes.
                _ = this.copyToSendBuffer(bytes[@intCast(usize, wrote)..], false, is_closing);
                return true;
            }

            return this.copyToSendBuffer(bytes, true, is_closing);
        }

        fn copyToSendBuffer(this: *HTTPClient, bytes: []const u8, do_write: bool, is_closing: bool) bool {
            return this.sendData(.{ .raw = bytes }, do_write, is_closing);
        }

        /// Encodes `bytes` at the end of the pending send data, choosing
        /// between the overflow list and the pooled buffer, and optionally
        /// attempts to write everything pending. Returns false on failure.
        fn sendData(this: *HTTPClient, bytes: Copy, do_write: bool, is_closing: bool) bool {
            var content_byte_len: usize = 0;
            const write_len = bytes.len(&content_byte_len);
            std.debug.assert(write_len > 0);

            this.send_len += write_len;
            var out_buf: []const u8 = "";
            // End of the pending data relative to the start of the buffer,
            // computed with the send offset as it is on entry.
            const send_end = this.send_off + this.send_len;
            var ring_buffer = &this.send_overflow_buffer;

            if (send_end <= ring_buffer.capacity) {
                bytes.copy(this.globalThis, ring_buffer.items.ptr[send_end - write_len .. send_end], content_byte_len);
                ring_buffer.items.len += write_len;

                out_buf = ring_buffer.items[this.send_off..send_end];
            } else if (send_end <= body_buf_len) {
                var buf = this.getBody();
                bytes.copy(this.globalThis, buf[send_end - write_len ..][0..write_len], content_byte_len);
                out_buf = buf[this.send_off..send_end];
            } else {
                // Bytes that were pending before this call.
                const pending_len = this.send_len - write_len;

                if (this.send_body_buf) |send_body| {
                    // transfer existing send buffer to overflow buffer
                    ring_buffer.ensureTotalCapacityPrecise(bun.default_allocator, this.send_len) catch {
                        this.terminate(ErrorCode.failed_to_allocate_memory);
                        return false;
                    };

                    // `send_body` is the pool node; its bytes live in `.data`.
                    @memcpy(ring_buffer.items.ptr, send_body.data[this.send_off..][0..pending_len].ptr, pending_len);
                    send_body.release();
                    bytes.copy(this.globalThis, (ring_buffer.items.ptr + pending_len)[0..write_len], content_byte_len);
                    ring_buffer.items.len = this.send_len;
                    this.send_body_buf = null;
                    this.send_off = 0;
                } else if (send_end <= ring_buffer.capacity) {
                    bytes.copy(this.globalThis, (ring_buffer.items.ptr + (send_end - write_len))[0..write_len], content_byte_len);
                    ring_buffer.items.len += write_len;
                    // can we treat it as a ring buffer without re-allocating the array?
                } else if (send_end - this.send_off <= ring_buffer.capacity) {
                    // Move the pending bytes to the front. The destination is
                    // at a lower address than the source, which is the overlap
                    // direction std.mem.copy supports (copyBackwards is for
                    // the opposite one). Only bytes that already exist move.
                    std.mem.copy(
                        u8,
                        ring_buffer.items.ptr[0..pending_len],
                        ring_buffer.items.ptr[this.send_off..][0..pending_len],
                    );
                    bytes.copy(this.globalThis, ring_buffer.items.ptr[pending_len..this.send_len], content_byte_len);
                    ring_buffer.items.len = this.send_len;
                    this.send_off = 0;
                } else {
                    // we need to re-allocate the array
                    ring_buffer.ensureTotalCapacity(bun.default_allocator, this.send_len) catch {
                        this.terminate(ErrorCode.failed_to_allocate_memory);
                        return false;
                    };

                    const data_to_copy = ring_buffer.items[this.send_off..][0..pending_len];
                    const position_in_slice = ring_buffer.items[0..pending_len];
                    if (data_to_copy.len > 0 and data_to_copy.ptr != position_in_slice.ptr)
                        std.mem.copy(
                            u8,
                            position_in_slice,
                            data_to_copy,
                        );

                    ring_buffer.items.len = this.send_len;
                    bytes.copy(this.globalThis, ring_buffer.items[pending_len..], content_byte_len);
                    this.send_off = 0;
                }

                // `send_off` may have been reset above, so derive the window
                // from the current offset and length instead of `send_end`.
                out_buf = ring_buffer.items[this.send_off..][0..this.send_len];
            }

            if (do_write) {
                if (comptime Environment.allow_assert) {
                    std.debug.assert(!this.tcp.isShutdown());
                    std.debug.assert(!this.tcp.isClosed());
                    std.debug.assert(this.tcp.isEstablished());
                }
                return this.sendBuffer(out_buf, is_closing);
            }

            return true;
        }

        /// Writes the pending bytes; clears the send buffers when everything
        /// was written, otherwise advances past the part that was.
        fn sendBuffer(this: *HTTPClient, out_buf: []const u8, is_closing: bool) bool {
            std.debug.assert(out_buf.len > 0);
            const wrote = this.tcp.write(out_buf, !is_closing);
            const expected = @intCast(c_int, out_buf.len);
            if (wrote == expected) {
                this.clearSendBuffers(false);
                return true;
            }

            if (wrote < 0) {
                this.terminate(ErrorCode.failed_to_write);
                return false;
            }

            this.send_len -= @intCast(usize, wrote);
            this.send_off += @intCast(usize, wrote);
            return true;
        }

        fn enqueueEncodedBytes(this: *HTTPClient, socket: Socket, bytes: []const u8) bool {
            return this.enqueueEncodedBytesMaybeFinal(socket, bytes, false);
        }

        fn sendPong(this: *HTTPClient, socket: Socket) bool {
            if (socket.isClosed() or socket.isShutdown()) {
                this.dispatchClose();
                return false;
            }

            var header = @bitCast(WebsocketHeader, @as(u16, 0));
            header.fin = true;
            header.opcode = .Pong;

            var to_mask = &this.ping_frame_bytes[6..][0..this.ping_len];

            header.mask = to_mask.len > 0;
            header.len = @truncate(u7, this.ping_len);
            this.ping_frame_bytes[0..2].* = @bitCast(u16, header);

            if (to_mask.len > 0) {
                Mask.from(this.globalThis).fill(&this.ping_frame_bytes[2..6], &to_mask, &to_mask);
                return this.enqueueEncodedBytes(socket, this.ping_frame_bytes[0 ..
6 + @as(usize, this.ping_len)]); } else { return this.enqueueEncodedBytes(socket, &this.ping_frame_bytes[0..2]); } } fn sendCloseWithBody( this: *HTTPClient, socket: Socket, code: u16, body: ?*[125]u8, body_len: usize, ) void { if (socket.isClosed() or socket.isShutdown()) { this.dispatchClose(); this.clearData(); return; } socket.shutdownRead(); var final_body_bytes: [128 + 8]u8 = undefined; var header = @bitCast(WebsocketHeader, @as(u16, 0)); header.fin = true; header.opcode = .Close; header.mask = true; header.len = body_len + 2; final_body_bytes[0..2].* = @bitCast([2]u8, @bitCast(u16, header)); var mask_buf = &final_body_bytes[2..6]; std.mem.writeIntSliceBig(u16, &final_body_bytes[6..8], code); if (body) |data| { if (body_len > 0) @memcpy(&final_body_bytes[8..], data, body_len); } // we must mask the code var slice = final_body_bytes[0..(8 + body_len)]; Mask.from(this.globalThis).fill(mask_buf, slice[6..], slice[6..]); if (this.enqueueEncodedBytesMaybeFinal(socket, slice, true)) { this.dispatchClose(); this.clearData(); } } pub fn handleEnd(this: *HTTPClient, socket: Socket) void { std.debug.assert(socket.socket == this.tcp.socket); this.terminate(ErrorCode.ended); } pub fn handleWritable( this: *HTTPClient, socket: Socket, ) void { std.debug.assert(socket.socket == this.tcp.socket); if (this.send_len == 0) return; var send_buf: []const u8 = undefined; if (this.send_body_buf) |send_body_buf| { send_buf = send_body_buf[this.send_off..][0..this.send_len]; std.debug.assert(this.send_overflow_buffer.items.len == 0); } else { send_buf = this.send_overflow_buffer.items[this.send_off..][0..this.send_len]; } std.debug.assert(send_buf.len == this.send_len); _ = this.sendBuffer(send_buf, false); } pub fn handleTimeout( this: *HTTPClient, _: Socket, ) void { this.terminate(ErrorCode.timeout); } pub fn handleConnectError(this: *HTTPClient, _: Socket, _: c_int) void { this.terminate(ErrorCode.failed_to_connect); } pub fn writeBinaryData( this: *HTTPClient, ptr: [*]const u8, 
len: usize, ) callconv(.C) void { if (this.tcp.isClosed() or this.tcp.isShutdown()) { this.dispatchClose(); return; } if (len == 0) return; const slice = ptr[0..len]; const bytes = Copy{ .bytes = slice }; // fast path: small frame, no backpressure, attempt to send without allocating const frame_size = WebsocketHeader.frameSizeIncludingMask(len); if (this.send_len == 0 and frame_size < 128 + 4) { var inline_buf: [128 + 4]u8 = undefined; bytes.copy(this.globalThis, &inline_buf[0..frame_size], slice.len); _ = this.enqueueEncodedBytes(this.socket, inline_buf[0..frame_size]); return; } _ = this.sendData(bytes, true, false); } pub fn writeString( this: *HTTPClient, str_: *const JSC.ZigString, ) callconv(.C) void { const str = str_.*; if (this.tcp.isClosed() or this.tcp.isShutdown()) { this.dispatchClose(); return; } if (str.len == 0) { return; } { var inline_buf: [128 + 4]u8 = undefined; // fast path: small frame, no backpressure, attempt to send without allocating if (!str.is16Bit() and str.len < 128 + 4) { const bytes = Copy{ .latin1 = str.slice() }; const frame_size = WebsocketHeader.frameSizeIncludingMask(str.len); if (this.send_len == 0 and frame_size < 128 + 4) { bytes.copy(this.globalThis, &inline_buf[0..frame_size], str.len); _ = this.enqueueEncodedBytes(this.socket, inline_buf[0..frame_size]); return; } // max length of a utf16 -> utf8 conversion is 4 times the length of the utf16 string } else if ((str.len * 4) < (128 + 4) and this.send_len == 0) { const bytes = Copy{ .utf16 = str.slice() }; var byte_len: usize = 0; const frame_size = bytes.len(&byte_len); std.debug.assert(frame_size <= 128 + 4); bytes.copy(this.globalThis, &inline_buf[0..frame_size], byte_len); _ = this.enqueueEncodedBytes(this.socket, inline_buf[0..frame_size]); return; } } _ = this.sendData( if (str.is16Bit()) Copy{ .utf16 = str.utf16SliceAligned() } else Copy{ .latin1 = str.slice() }, true, false, ); } pub fn close(this: *HTTPClient, code: u16, reason: ?*const JSC.ZigString) callconv(.C) 
void { if (this.tcp.isClosed() or this.tcp.isShutdown()) return; var close_reason_buf: [128]u8 = undefined; if (reason) |str| { inner: { var fixed_buffer = std.heap.FixedBufferAllocator.init(&close_reason_buf); const allocator = fixed_buffer.get(); const wrote = std.fmt.allocPrint(allocator, "{}", str.*) catch break :inner; this.sendCloseWithBody(this.tcp, code, wrote.ptr[0..125], wrote.len); return; } } this.sendCloseWithBody(this.tcp, code, null, 0); } pub fn init( outgoing: *anyopaque, input_socket: *anyopaque, socket_ctx: *anyopaque, globalThis: *JSC.JSGlobalObject, ) callconv(.C) *anyopaque { var tcp = @ptrCast(*uws.Socket, input_socket); var ctx = @ptrCast(*uws.us_socket_context_t, socket_ctx); return @ptrCast( *anyopaque, Socket.adopt(tcp, ctx, HTTPClient{ .outgoing_websocket = outgoing, .globalThis = globalThis, }), ); } pub fn finalize(this: *HTTPClient) callconv(.C) void { this.clearData(); if (this.tcp.isClosed()) return; this.tcp.close(0, null); this.outgoing_websocket = null; } pub const Export = shim.exportFunctions(.{ .writeBinaryData = writeBinaryData, .writeString = writeString, .close = close, .register = register, .init = init, .finalize = finalize, }); comptime { if (!JSC.is_bindgen) { @export(writeBinaryData, .{ .name = Export[0].symbol_name }); @export(writeString, .{ .name = Export[1].symbol_name }); @export(close, .{ .name = Export[2].symbol_name }); @export(register, .{ .name = Export[3].symbol_name }); @export(init, .{ .name = Export[4].symbol_name }); @export(finalize, .{ .name = Export[5].symbol_name }); } } }; } pub const WebSocketHTTPClient = NewHTTPUpgradeClient(false); pub const WebSocketHTTPSClient = NewHTTPUpgradeClient(true); pub const WebSocketClient = NewWebSocketClient(false); pub const WebSocketClientTLS = NewWebSocketClient(true);