diff options
author | 2022-06-19 03:59:08 -0700 | |
---|---|---|
committer | 2022-06-22 06:56:47 -0700 | |
commit | dda85d92c9bafd0fe86540efd0f30be3e6c08c03 (patch) | |
tree | af7280cb8de11bf521b195e6cfe46604050b627c | |
parent | ab888d2ebebea0d128f4151a4240180211d95f03 (diff) | |
download | bun-dda85d92c9bafd0fe86540efd0f30be3e6c08c03.tar.gz bun-dda85d92c9bafd0fe86540efd0f30be3e6c08c03.tar.zst bun-dda85d92c9bafd0fe86540efd0f30be3e6c08c03.zip |
implement a custom websocket client
-rw-r--r-- | src/deps/uws.zig | 43 | ||||
-rw-r--r-- | src/http/websocket.zig | 23 | ||||
-rw-r--r-- | src/http/websocket_http_client.zig | 1087 | ||||
-rw-r--r-- | src/javascript/jsc/bindings/ScriptExecutionContext.cpp | 14 | ||||
-rw-r--r-- | src/javascript/jsc/bindings/bindings.zig | 4 | ||||
-rw-r--r-- | src/javascript/jsc/bindings/exports.zig | 4 | ||||
-rw-r--r-- | src/javascript/jsc/bindings/webcore/WebSocket.cpp | 65 | ||||
-rw-r--r-- | src/javascript/jsc/bindings/webcore/WebSocket.h | 2 | ||||
-rw-r--r-- | src/javascript/jsc/rare_data.zig | 4 |
9 files changed, 1171 insertions, 75 deletions
diff --git a/src/deps/uws.zig b/src/deps/uws.zig index 08037a890..48c93ab53 100644 --- a/src/deps/uws.zig +++ b/src/deps/uws.zig @@ -70,11 +70,11 @@ pub fn NewSocketHandler(comptime ssl: bool) type { this.socket, ) > 0; } - pub fn isClosed(this: ThisSocket) c_int { + pub fn isClosed(this: ThisSocket) bool { return us_socket_is_closed( comptime ssl_int, this.socket, - ); + ) > 0; } pub fn close(this: ThisSocket, code: c_int, reason: ?*anyopaque) void { _ = us_socket_close( @@ -191,13 +191,38 @@ pub fn NewSocketHandler(comptime ssl: bool) type { } }; - us_socket_context_on_open(ssl_int, ctx, SocketHandler.on_open); - us_socket_context_on_close(ssl_int, ctx, SocketHandler.on_close); - us_socket_context_on_data(ssl_int, ctx, SocketHandler.on_data); - us_socket_context_on_writable(ssl_int, ctx, SocketHandler.on_writable); - us_socket_context_on_timeout(ssl_int, ctx, SocketHandler.on_timeout); - us_socket_context_on_connect_error(ssl_int, ctx, SocketHandler.on_connect_error); - us_socket_context_on_end(ssl_int, ctx, SocketHandler.on_end); + if (comptime @typeInfo(@TypeOf(onOpen)) != .Null) + us_socket_context_on_open(ssl_int, ctx, SocketHandler.on_open); + if (comptime @typeInfo(@TypeOf(onClose)) != .Null) + us_socket_context_on_close(ssl_int, ctx, SocketHandler.on_close); + if (comptime @typeInfo(@TypeOf(onData)) != .Null) + us_socket_context_on_data(ssl_int, ctx, SocketHandler.on_data); + if (comptime @typeInfo(@TypeOf(onWritable)) != .Null) + us_socket_context_on_writable(ssl_int, ctx, SocketHandler.on_writable); + if (comptime @typeInfo(@TypeOf(onTimeout)) != .Null) + us_socket_context_on_timeout(ssl_int, ctx, SocketHandler.on_timeout); + if (comptime @typeInfo(@TypeOf(onConnectError)) != .Null) + us_socket_context_on_connect_error(ssl_int, ctx, SocketHandler.on_connect_error); + if (comptime @typeInfo(@TypeOf(onEnd)) != .Null) + us_socket_context_on_end(ssl_int, ctx, SocketHandler.on_end); + } + + pub fn adopt( + socket: *Socket, + socket_ctx: *us_socket_context_t, + comptime Context: type, + comptime socket_field_name: []const u8, + ctx: Context, + ) ?*Context { + var adopted = Socket{ .socket = us_socket_context_adopt_socket(comptime ssl_int, socket_ctx, socket, @sizeOf(Context)) orelse return null }; + var holder = adopted.ext(Context) orelse { + if (comptime bun.Environment.allow_assert) unreachable; + _ = us_socket_close(comptime ssl_int, socket); + return null; + }; + holder.* = ctx; + @field(holder, socket_field_name) = adopted; + return holder; } }; } diff --git a/src/http/websocket.zig b/src/http/websocket.zig index 293c67757..6bb433ea3 100644 --- a/src/http/websocket.zig +++ b/src/http/websocket.zig @@ -65,13 +65,30 @@ pub const WebsocketHeader = packed struct { } pub fn packLength(length: usize) u7 { - var eight: u7 = 0; - eight = switch (length) { + return switch (length) { 0...126 => @truncate(u7, length), 127...0xFFFF => 126, else => 127, }; - return eight; + } + + const mask_length = 4; + const header_length = 2; + + pub fn lengthByteCount(byte_length: usize) usize { + return switch (byte_length) { + 0...126 => 0, + 127...0xFFFF => @sizeOf(u16), + else => @sizeOf(u64), + }; + } + + pub fn frameSize(byte_length: usize) usize { + return header_length + byte_length + lengthByteCount(byte_length); + } + + pub fn frameSizeIncludingMask(byte_length: usize) usize { + return frameSize(byte_length) + mask_length; } }; diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig index 0a141c4c2..2be382dbe 100644 --- a/src/http/websocket_http_client.zig +++ b/src/http/websocket_http_client.zig @@ -3,12 +3,6 @@ const std = @import("std"); const native_endian = @import("builtin").target.cpu.arch.endian(); -const tcp = std.x.net.tcp; -const ip = std.x.net.ip; - -const IPv4 = std.x.os.IPv4; -const IPv6 = std.x.os.IPv6; -const os = std.os; const bun = @import("../global.zig"); const string = bun.string; const Output = bun.Output; @@ -24,6 +18,9 @@ const uws = @import("uws"); const JSC = @import("javascript_core"); const PicoHTTP = @import("picohttp"); const ObjectPool = @import("../pool.zig").ObjectPool; +const WebsocketHeader = @import("./websocket.zig").WebsocketHeader; +const WebsocketDataFrame = @import("./websocket.zig").WebsocketDataFrame; +const Opcode = @import("./websocket.zig").Opcode; fn buildRequestBody(vm: *JSC.VirtualMachine, pathname: *const JSC.ZigString, host: *const JSC.ZigString, client_protocol: *const JSC.ZigString, client_protocol_hash: *u64) std.mem.Allocator.Error![]u8 { const allocator = vm.allocator; @@ -89,6 +86,14 @@ const ErrorCode = enum(i32) { failed_to_connect, headers_too_large, ended, + failed_to_allocate_memory, + control_frame_is_fragmented, + invalid_control_frame, + compression_unsupported, + unexpected_mask_from_server, + expected_control_frame, + unexpected_opcode, + invalid_utf8, }; extern fn WebSocket__didConnect( websocket_context: *anyopaque, @@ -96,9 +101,12 @@ extern fn WebSocket__didConnect( buffered_data: ?[*]u8, buffered_len: usize, ) void; -extern fn WebSocket__didFailToConnect(websocket_context: *anyopaque, reason: ErrorCode) void; +extern fn WebSocket__didFailWithErrorCode(websocket_context: *anyopaque, reason: ErrorCode) void; +extern fn WebSocket__didReceiveText(websocket_context: *anyopaque, clone: bool, text: *const JSC.ZigString) void; +extern fn WebSocket__didReceiveBytes(websocket_context: *anyopaque, bytes: []const u8) void; -const BodyBufBytes = [16384 - 16]u8; +const body_buf_len = 16384 - 16; +const BodyBufBytes = [body_buf_len]u8; const BodyBufPool = ObjectPool(BodyBufBytes, null, true, 4); const BodyBuf = BodyBufPool.Node; @@ -106,7 +114,7 @@ const BodyBuf = BodyBufPool.Node; pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { return struct { pub const Socket = uws.NewSocketHandler(ssl); - socket: Socket, + tcp: Socket, outgoing_websocket: *anyopaque, input_body_buf: []u8 = &[_]u8{}, client_protocol: []const u8 = "", @@ -159,8 +167,8 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { var host_ = host.toSlice(bun.default_allocator); defer host_.deinit(); - if (Socket.connect(host_.slice(), port, @ptrCast(*uws.us_socket_context_t, socket_ctx), HTTPClient, client, "socket")) |out| { - out.socket.timeout(120); + if (Socket.connect(host_.slice(), port, @ptrCast(*uws.us_socket_context_t, socket_ctx), HTTPClient, client, "tcp")) |out| { + out.tcp.timeout(120); return out; } @@ -183,33 +191,33 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { pub fn cancel(this: *HTTPClient) callconv(.C) void { this.clearData(); - if (!this.socket.isEstablished()) { - _ = uws.us_socket_close_connecting(comptime @as(c_int, @boolToInt(ssl)), this.socket.socket); + if (!this.tcp.isEstablished()) { + _ = uws.us_socket_close_connecting(comptime @as(c_int, @boolToInt(ssl)), this.tcp.socket); } else { - this.socket.close(0, null); + this.tcp.close(0, null); } } pub fn fail(this: *HTTPClient, code: ErrorCode) void { JSC.markBinding(); - WebSocket__didFailToConnect(this.outgoing_websocket, code); + WebSocket__didFailWithErrorCode(this.outgoing_websocket, code); this.cancel(); } pub fn handleClose(this: *HTTPClient, _: Socket, _: c_int, _: ?*anyopaque) void { JSC.markBinding(); this.clearData(); - WebSocket__didFailToConnect(this.outgoing_websocket, ErrorCode.ended); + WebSocket__didFailWithErrorCode(this.outgoing_websocket, ErrorCode.ended); } pub fn terminate(this: *HTTPClient, code: ErrorCode) void { this.fail(code); - if (this.socket.isClosed() == 0) - this.socket.close(0, null); + if (!this.tcp.isClosed()) + this.tcp.close(0, null); } pub fn handleOpen(this: *HTTPClient, socket: Socket) void { - std.debug.assert(socket.socket == this.socket.socket); + std.debug.assert(socket.socket == this.tcp.socket); std.debug.assert(this.input_body_buf.len > 0); std.debug.assert(this.to_send.len == 0); @@ -232,7 +240,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { } pub fn handleData(this: *HTTPClient, socket: Socket, data: []const u8) void { - std.debug.assert(socket.socket == this.socket.socket); + std.debug.assert(socket.socket == this.tcp.socket); if (comptime Environment.allow_assert) std.debug.assert(!socket.isShutdown()); @@ -280,7 +288,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { } pub fn handleEnd(this: *HTTPClient, socket: Socket) void { - std.debug.assert(socket.socket == this.socket.socket); + std.debug.assert(socket.socket == this.tcp.socket); this.terminate(ErrorCode.ended); } @@ -399,14 +407,14 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { this.clearData(); JSC.markBinding(); - WebSocket__didConnect(this.outgoing_websocket, this.socket.socket, overflow.ptr, overflow.len); + WebSocket__didConnect(this.outgoing_websocket, this.tcp.socket, overflow.ptr, overflow.len); } pub fn handleWritable( this: *HTTPClient, socket: Socket, ) void { - std.debug.assert(socket.socket == this.socket.socket); + std.debug.assert(socket.socket == this.tcp.socket); if (this.to_send.len == 0) return; @@ -451,5 +459,1038 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { }; } +pub const Mask = struct { + const rand_buffer_size = 64; + bytes: [rand_buffer_size]u8 = [_]u8{0} ** rand_buffer_size, + offset: usize = 0, + needs_reset: bool = true, + queued_reset: bool = false, + + pub fn from(globalObject: *JSC.JSGlobalObject) *Mask { + return &globalObject.bunVM().rareData().websocket_mask; + } + + pub fn get(this: *Mask) [4]u8 { + if (this.needs_reset) { + this.needs_reset = false; + this.offset = 0; + std.crypto.random.bytes(&this.bytes); + } + const offset = this.offset; + const wrapped = offset % rand_buffer_size; + var mask = this.bytes[wrapped..][0..4].*; + if (offset > rand_buffer_size) { + const wrapped2 = @truncate(u8, wrapped); + mask[0] +%= wrapped2; + mask[1] +%= wrapped2; + mask[2] +%= wrapped2; + mask[3] +%= wrapped2; + } + + this.offset += 4; + + if (!this.queued_reset and this.offset % rand_buffer_size == 0) { + this.queued_reset = true; + uws.Loop.get().?.nextTick(*Mask, this, reset); + } + + return mask; + } + + pub fn fill(this: *Mask, mask_buf: *[4]u8, output_: []u8, input_: []const u8) void { + const mask = this.get(); + mask_buf.* = mask; + var input = input_; + var output = output_; + if (comptime Environment.isAarch64 or Environment.isX64) { + if (input.len >= strings.ascii_vector_size) { + const vec: strings.AsciiVector = @as(strings.AsciiVector, mask ** (strings.ascii_vector_size / 4)); + const end_ptr_wrapped_to_last_16 = input.ptr + input.len - (input.len % strings.ascii_vector_size); + while (input.ptr != end_ptr_wrapped_to_last_16) { + const input_vec: strings.AsciiVector = @as(strings.AsciiVector, input[0..strings.ascii_vector_size].*); + output.ptr[0..strings.ascii_vector_size].* = input_vec ^ vec; + output = output[strings.ascii_vector_size..]; + input = input[strings.ascii_vector_size..]; + } + } + + // hint to the compiler not to vectorize the next loop + std.debug.assert(input.len < strings.ascii_vector_size); + } + + while (input.len >= 4) { + const input_vec: [4]u8 = input[0..4].*; + output.ptr[0..4].* = [4]u8{ + input_vec[0] ^ mask[0], + input_vec[1] ^ mask[1], + input_vec[2] ^ mask[2], + input_vec[3] ^ mask[3], + }; + output = output[4..]; + input = input[4..]; + } + + for (input) |c, i| { + output[i] = c ^ mask[i % 4]; + } + } + + pub fn reset(this: *Mask) void { + this.queued_reset = false; + this.needs_reset = true; + } +}; + +const ReceiveState = enum { + need_header, + need_mask, + need_body, + extended_payload_length_16, + extended_payload_length_64, + ping, + closing, + fail, + + pub fn needControlFrame(this: ReceiveState) bool { + return this != .need_body; + } +}; +const DataType = enum { + none, + text, + binary, +}; + +fn parseWebSocketHeader( + bytes: [2]u8, + receiving_type: *Opcode, + payload_length: *usize, + is_fragmented: *bool, + is_final: *bool, + need_compression: *bool, +) ReceiveState { + // 0 1 2 3 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // +-+-+-+-+-------+-+-------------+-------------------------------+ + // |F|R|R|R| opcode|M| Payload len | Extended payload length | + // |I|S|S|S| (4) |A| (7) | (16/64) | + // |N|V|V|V| |S| | (if payload len==126/127) | + // | |1|2|3| |K| | | + // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + // | Extended payload length continued, if payload len == 127 | + // + - - - - - - - - - - - - - - - +-------------------------------+ + // | |Masking-key, if MASK set to 1 | + // +-------------------------------+-------------------------------+ + // | Masking-key (continued) | Payload Data | + // +-------------------------------- - - - - - - - - - - - - - - - + + // : Payload Data continued ... : + // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + // | Payload Data continued ... | + // +---------------------------------------------------------------+ + const header = @bitCast(WebsocketHeader, @bitCast(u16, bytes)); + const payload = @as(usize, header.len); + payload_length.* = payload; + receiving_type.* = header.opcode; + is_fragmented.* = switch (header.opcode) { + .Continue => true, + else => false, + }; + is_final.* = header.final; + need_compression.* = header.compressed; + + if (header.mask) { + return .need_mask; + } + + return switch (header.opcode) { + .Text, .Continue, .Binary => switch (payload) { + 0...125 => ReceiveState.need_body, + 126 => ReceiveState.extended_payload_length_16, + 127 => ReceiveState.extended_payload_length_64, + else => unreachable, + }, + .Close => ReceiveState.close, + .Ping => ReceiveState.ping, + .Pong => ReceiveState.pong, + else => ReceiveState.fail, + }; +} + +const Copy = union(enum) { + utf16: []const u16, + latin1: []const u8, + bytes: []const u8, + raw: []const u8, + + pub fn len(this: @This(), byte_len: *usize) usize { + switch (this) { + .utf16 => { + byte_len.* = strings.elementLengthUTF16IntoUTF8([]const u16, this.utf16); + return WebsocketHeader.frameSizeIncludingMask(byte_len.*); + }, + .latin1 => { + byte_len.* = strings.elementLengthLatin1IntoUTF8([]const u8, this.latin1); + return WebsocketHeader.frameSizeIncludingMask(byte_len.*); + }, + .bytes => { + byte_len.* = this.bytes.len; + return WebsocketHeader.frameSizeIncludingMask(byte_len.*); + }, + .raw => return this.raw.len, + } + } + + pub fn copy(this: @This(), globalThis: *JSC.JSGlobalObject, buf: []u8, content_byte_len: usize) void { + switch (this) { + .utf16 => |utf16| { + const length_offset = 2; + const length_length = WebsocketHeader.lengthByteCount(content_byte_len); + const mask_offset = length_offset + length_length; + const content_offset = mask_offset + 4; + var to_mask = buf[content_offset..]; + const encode_into_result = strings.copyUTF16IntoUTF8(utf16, to_mask); + std.debug.assert(@as(usize, encode_into_result.written) == content_byte_len); + std.debug.assert(@as(usize, encode_into_result.read) == utf16.len); + var header = @bitCast(WebsocketHeader, @as(u16, 0)); + header.len = WebsocketHeader.packLength(content_byte_len); + header.mask = true; + header.opcode = Opcode.Text; + header.compressed = false; + header.final = true; + header.writeHeader(std.io.fixedBufferStream(buf), content_byte_len) catch unreachable; + Mask.from(globalThis).fill(buf[mask_offset..][0..4], to_mask, to_mask); + }, + .latin1 => |latin1| { + const length_offset = 2; + const length_length = WebsocketHeader.lengthByteCount(content_byte_len); + const mask_offset = length_offset + length_length; + const content_offset = mask_offset + 4; + var to_mask = buf[content_offset..]; + const encode_into_result = strings.copyLatin1IntoUTF8(latin1, to_mask); + std.debug.assert(@as(usize, encode_into_result.written) == content_byte_len); + std.debug.assert(@as(usize, encode_into_result.read) == latin1.len); + var header = @bitCast(WebsocketHeader, @as(u16, 0)); + header.len = WebsocketHeader.packLength(content_byte_len); + header.mask = true; + header.opcode = Opcode.Text; + header.compressed = false; + header.final = true; + header.writeHeader(std.io.fixedBufferStream(buf), content_byte_len) catch unreachable; + Mask.from(globalThis).fill(buf[mask_offset..][0..4], to_mask, to_mask); + }, + .bytes => |bytes| { + const length_offset = 2; + const length_length = WebsocketHeader.lengthByteCount(bytes.len); + const mask_offset = length_offset + length_length; + const content_offset = mask_offset + 4; + var to_mask = buf[content_offset..]; + var header = @bitCast(WebsocketHeader, @as(u16, 0)); + header.len = WebsocketHeader.packLength(bytes.len); + header.mask = true; + header.opcode = Opcode.Text; + header.compressed = false; + header.final = true; + header.writeHeader(std.io.fixedBufferStream(buf), bytes.len) catch unreachable; + Mask.from(globalThis).fill(buf[mask_offset..][0..4], to_mask, to_mask); + }, + .raw => @memcpy(buf.ptr, this.raw.ptr, this.raw.len), + } + } +}; + +pub fn NewWebSocketClient(comptime ssl: bool) type { + return struct { + pub const Socket = uws.NewSocketHandler(ssl); + tcp: Socket, + outgoing_websocket: ?*anyopaque, + + receive_state: ReceiveState = ReceiveState.need_header, + receive_header: WebsocketHeader = @bitCast(WebsocketHeader, @as(u16, 0)), + receive_remaining: usize = 0, + receiving_type: Opcode = Opcode.ResB, + + ping_frame_bytes: [128 + 6]u8 = [_]u8{0} ** 128 + 6, + ping_len: u8 = 0, + + receive_frame: usize = 0, + receive_body_remain: usize = 0, + receive_pending_chunk_len: usize = 0, + receive_body_buf: ?*BodyBuf = null, + receive_overflow_buffer: std.ArrayListUnmanaged(u8) = .{}, + send_overflow_buffer: std.ArrayListUnmanaged(u8) = .{}, + + send_body_buf: ?*BodyBuf = null, + send_len: usize = 0, + send_off: usize = 0, + + globalThis: *JSC.JSGlobalObject, + + pub const name = if (ssl) "WebSocketClientTLS" else "WebSocketClient"; + + pub const shim = JSC.Shimmer("Bun", name, @This()); + + const HTTPClient = @This(); + + pub fn register(global: *JSC.JSGlobalObject, loop_: *anyopaque, parent: *anyopaque, ctx_: *anyopaque) callconv(.C) void { + var vm = global.bunVM(); + var loop = @ptrCast(*uws.Loop, loop_); + var ctx: *uws.us_socket_context_t = @ptrCast(*uws.us_socket_context_t, ctx_); + + if (vm.uws_event_loop) |other| { + std.debug.assert(other == loop); + } + + vm.uws_event_loop = loop; + + Socket.configureChild( + ctx, + parent, + HTTPClient, + null, + handleClose, + handleData, + handleWritable, + handleTimeout, + handleConnectError, + handleEnd, + ); + } + + pub fn clearData(this: *HTTPClient) void { + this.clearReceiveBuffers(true); + this.clearSendBuffers(true); + this.ping_len = 0; + this.receive_pending_chunk_len = 0; + this.receive_remaining = 0; + } + + pub fn cancel(this: *HTTPClient) callconv(.C) void { + this.clearData(); + + if (this.tcp.isClosed() or this.tcp.isShutdown()) + return; + + if (!this.tcp.isEstablished()) { + _ = uws.us_socket_close_connecting(comptime @as(c_int, @boolToInt(ssl)), this.tcp.socket); + } else { + this.tcp.close(0, null); + } + } + + pub fn fail(this: *HTTPClient, code: ErrorCode) void { + JSC.markBinding(); + if (this.outgoing_websocket) |ws| + WebSocket__didFailWithErrorCode(ws, code); + + this.cancel(); + } + + pub fn handleClose(this: *HTTPClient, _: Socket, _: c_int, _: ?*anyopaque) void { + JSC.markBinding(); + this.clearData(); + if (this.outgoing_websocket) |ws| + WebSocket__didFailWithErrorCode(ws, ErrorCode.ended); + } + + pub fn terminate(this: *HTTPClient, code: ErrorCode) void { + this.fail(code); + } + + fn getBody(this: *HTTPClient) *BodyBufBytes { + if (this.send_body_buf == null) { + this.send_body_buf = BodyBufPool.get(bun.default_allocator); + } + + return &this.send_body_buf.?.data; + } + + fn getReceiveBody(this: *HTTPClient) *BodyBufBytes { + if (this.receive_body_buf == null) { + this.receive_body_buf = BodyBufPool.get(bun.default_allocator); + } + + return &this.receive_body_buf.?.data; + } + + fn clearReceiveBuffers(this: *HTTPClient, free: bool) void { + if (this.receive_body_buf) |receive_buf| { + receive_buf.release(); + this.receive_body_buf = null; + } + + if (free) { + this.receive_overflow_buffer.clearAndFree(bun.default_allocator); + } else { + this.receive_overflow_buffer.clearRetainingCapacity(); + } + } + + fn clearSendBuffers(this: *HTTPClient, free: bool) void { + if (this.send_body_buf) |buf| { + buf.release(); + this.send_body_buf = null; + } + + if (free) { + this.send_overflow_buffer.clearAndFree(bun.default_allocator); + } else { + this.send_overflow_buffer.clearRetainingCapacity(); + } + this.send_off = 0; + this.send_len = 0; + } + + fn dispatchData(this: *HTTPClient, data_: []const u8, kind: Opcode) void { + switch (kind) { + .Text => { + // this function encodes to UTF-16 if > 127 + // so we don't need to worry about latin1 non-ascii code points + const utf16_bytes_ = strings.toUTF16Alloc(bun.default_allocator, data_, true) catch { + this.terminate(ErrorCode.invalid_utf8); + return 0; + }; + defer this.clearReceiveBuffers(false); + var outstring = JSC.ZigString.Empty; + if (utf16_bytes_) |utf16| { + outstring = JSC.ZigString.from16Slice(utf16); + outstring.markUTF16(); + WebSocket__didReceiveText(this.outgoing_websocket, false, &outstring); + } else { + outstring = JSC.ZigString.init(data_); + WebSocket__didReceiveText(this.outgoing_websocket, true, &outstring); + } + }, + .Binary => { + WebSocket__didReceiveBytes(this.outgoing_websocket, data_); + this.clearReceiveBuffers(false); + }, + else => unreachable, + } + } + + pub fn consume(this: *HTTPClient, data_: []const u8, max: usize, kind: Opcode, is_final: bool) usize { + std.debug.assert(kind == .Text or kind == .Binary); + std.debug.assert(data_.len <= max); + + const can_dispatch_data = is_final and data_.len == max; + + // did all the data fit in the buffer? + // we can avoid copying & allocating a temporary buffer + if (can_dispatch_data and this.receive_pending_chunk_len == 0) { + this.dispatchData(data_, kind); + return data_.len; + } + + // if we previously allocated a buffer and there's room, attempt to use that one + const new_pending_chunk_len = max + this.receive_pending_chunk_len; + if (new_pending_chunk_len <= this.receive_overflow_buffer.capacity) { + @memcpy(this.receive_overflow_buffer.items.ptr + this.receive_overflow_buffer.items.len, data_.ptr, data_.len); + this.receive_overflow_buffer.items.len += data_.len; + if (can_dispatch_data) { + this.dispatchData(this.receive_overflow_buffer.items, kind); + this.receive_pending_chunk_len = 0; + } else { + this.receive_pending_chunk_len = this.receive_overflow_buffer.items.len; + } + return data_.len; + } + + if (new_pending_chunk_len <= body_buf_len) { + // if our previously-allocated buffer is too small or we don't have one, use from the pool + var body = this.getReceiveBody(); + @memcpy(body + this.receive_pending_chunk_len, data_.ptr, data_.len); + if (can_dispatch_data) { + this.dispatchData(body[0..new_pending_chunk_len], kind); + this.receive_pending_chunk_len = 0; + } else { + this.receive_pending_chunk_len += data_.len; + } + return data_.len; + } + + { + // we need to copy the data into a potentially large temporary buffer + this.receive_overflow_buffer.appendSlice(bun.default_allocator, data_) catch { + this.terminate(ErrorCode.failed_to_allocate_memory); + return 0; + }; + if (can_dispatch_data) { + this.dispatchData(this.receive_overflow_buffer.items, kind); + this.receive_pending_chunk_len = 0; + } else { + this.receive_pending_chunk_len = this.receive_overflow_buffer.items.len; + } + return data_.len; + } + } + + pub fn handleData(this: *HTTPClient, socket: Socket, data_: []const u8) void { + var data = data_; + var receive_state = this.receive_state; + var terminated = false; + var is_fragmented = false; + var receiving_type = this.receiving_type; + var receive_body_remain = this.receive_body_remain; + var is_final = false; + var last_receive_data_type = receiving_type; + + defer { + if (!terminated) { + this.receive_state = receive_state; + this.receiving_type = last_receive_data_type; + this.receive_body_remain = receive_body_remain; + + // if we receive multiple pings in a row + // we just send back the last one + if (this.ping_len > 0) { + _ = this.sendPong(socket); + this.ping_len = 0; + } + } + } + + var header_bytes: [@sizeOf(usize)]u8 = [_]u8{0} ** @sizeOf(usize); + while (true) { + switch (receive_state) { + // 0 1 2 3 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // +-+-+-+-+-------+-+-------------+-------------------------------+ + // |F|R|R|R| opcode|M| Payload len | Extended payload length | + // |I|S|S|S| (4) |A| (7) | (16/64) | + // |N|V|V|V| |S| | (if payload len==126/127) | + // | |1|2|3| |K| | | + // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + // | Extended payload length continued, if payload len == 127 | + // + - - - - - - - - - - - - - - - +-------------------------------+ + // | |Masking-key, if MASK set to 1 | + // +-------------------------------+-------------------------------+ + // | Masking-key (continued) | Payload Data | + // +-------------------------------- - - - - - - - - - - - - - - - + + // : Payload Data continued ... : + // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + // | Payload Data continued ... | + // +---------------------------------------------------------------+ + .need_header => { + if (data.len < 2) { + this.terminate(ErrorCode.control_frame_is_fragmented); + terminated = true; + break; + } + + header_bytes[0..2].* = data[0..2].*; + receive_body_remain = 0; + var need_compression = false; + is_final = false; + + receive_state = parseWebSocketHeader( + header_bytes[0..2], + &receiving_type, + &receive_body_remain, + &is_fragmented, + &is_final, + &need_compression, + ); + if (receiving_type == .Text or receiving_type == .Binary) { + last_receive_data_type = receiving_type; + } + data = data[2..]; + + if (receiving_type.isControl() and is_fragmented) { + // Control frames must not be fragmented. + this.terminate(ErrorCode.control_frame_is_fragmented); + terminated = true; + break; + } + + switch (receiving_type) { + .Continue, .Text, .Binary, .Ping, .Pong, .Close => {}, + else => { + this.terminate(ErrorCode.unexpected_opcode); + terminated = true; + break; + }, + } + + if (need_compression) { + this.terminate(ErrorCode.compression_unsupported); + terminated = true; + break; + } + }, + .need_mask => { + this.terminate(.unexpected_mask_from_server); + terminated = true; + break; + }, + .extended_payload_length_64, .extended_payload_length_16 => |rc| { + const byte_size = switch (rc) { + .extended_payload_length_64 => @as(usize, 8), + .extended_payload_length_16 => @as(usize, 2), + else => unreachable, + }; + + if (data.len < byte_size) { + this.terminate(ErrorCode.control_frame_is_fragmented); + terminated = true; + break; + } + + // Multibyte length quantities are expressed in network byte order + receive_body_remain = switch (byte_size) { + 8 => @as(usize, std.mem.readIntBig(u64, data[0..8].*)), + 2 => @as(usize, std.mem.readIntBig(u16, data[0..2].*)), + else => unreachable, + }; + data = data[byte_size..]; + receive_state = .need_body; + + if (receive_body_remain == 0) { + // this is an error + // the server should've set length to zero + this.terminate(ErrorCode.invalid_control_frame); + terminated = true; + break; + } + }, + + .ping => { + const ping_len = @minimum(data.len, @minimum(receive_body_remain, 125)); + this.ping_len = @truncate(u8, ping_len); + + if (ping_len > 0) { + @memcpy(&this.ping_frame_bytes + 6, data.ptr, ping_len); + data = data[ping_len..]; + } + + receive_state = .need_header; + receive_body_remain = 0; + receiving_type = last_receive_data_type; + + if (data.len == 0) break; + }, + .pong => { + const pong_len = @minimum(data.len, @minimum(receive_body_remain, this.ping_frame_bytes.len)); + data = data[pong_len..]; + receive_state = .need_header; + receiving_type = last_receive_data_type; + if (data.len == 0) break; + }, + .need_body => { + if (receive_body_remain == 0 and data.len > 0) { + this.terminate(ErrorCode.expected_control_frame); + terminated = true; + break; + } + if (data.len == 0) return; + + const to_consume = @minimum(receive_body_remain, data.len); + + const consumed = this.consume(data[0..to_consume], receive_body_remain, last_receive_data_type, is_final); + if (consumed == 0 and last_receive_data_type == .Text) { + this.terminate(ErrorCode.invalid_utf8); + terminated = true; + break; + } + + receive_body_remain -= consumed; + data = data[to_consume..]; + if (receive_body_remain == 0) { + receive_state = .need_header; + is_fragmented = false; + } + + if (data.len == 0) break; + }, + + .close => { + // closing frame data is text only. + _ = this.consume(data[0..receive_body_remain], receive_body_remain, .Text, true); + this.sendClose(); + terminated = true; + break; + }, + .fail => { + this.terminate(ErrorCode.unexpected_control_frame); + terminated = true; + break; + }, + } + } + } + + pub fn sendClose(this: *HTTPClient) void { + this.sendCloseWithBody(this.tcp, 1001, null, null, 0); + } + + fn enqueueEncodedBytesMaybeFinal( + this: *HTTPClient, + socket: Socket, + bytes: []const u8, + is_closing: bool, + ) bool { + // fast path: no backpressure, no queue, just send the bytes. + if (this.send_len == 0) { + const wrote = socket.write(bytes, !is_closing); + const expected = @intCast(c_int, bytes.len); + if (wrote == expected) { + return true; + } + + if (wrote < 0) { + this.terminate(ErrorCode.failed_to_write); + return false; + } + + _ = this.copyToSendBuffer(bytes[bytes.len - @intCast(usize, wrote) ..], false, is_closing, false); + return true; + } + + return this.copyToSendBuffer(bytes, true, is_closing, false); + } + + fn copyToSendBuffer(this: *HTTPClient, bytes: []const u8, do_write: bool, is_closing: bool) bool { + return this.sendData(.{ .raw = bytes }, do_write, is_closing); + } + + fn sendData(this: *HTTPClient, bytes: Copy, do_write: bool, is_closing: bool) bool { + var content_byte_len: usize = 0; + const write_len = bytes.len(&content_byte_len); + std.debug.assert(write_len > 0); + + this.send_len += write_len; + var out_buf: []const u8 = ""; + const send_end = this.send_off + this.send_len; + var ring_buffer = &this.send_overflow_buffer; + + if (send_end <= ring_buffer.capacity) { + bytes.copy(this.globalThis, ring_buffer.items.ptr[send_end - write_len .. send_end], content_byte_len); + ring_buffer.items.len += write_len; + + out_buf = ring_buffer.items[this.send_off..send_end]; + } else if (send_end <= body_buf_len) { + var buf = this.getBody(); + bytes.copy(this.globalThis, buf[send_end - write_len ..][0..write_len], content_byte_len); + out_buf = buf[this.send_off..send_end]; + } else { + if (this.send_body_buf) |send_body| { + // transfer existing send buffer to overflow buffer + ring_buffer.ensureTotalCapacityPrecise(bun.default_allocator, this.send_len) catch { + this.terminate(ErrorCode.failed_to_allocate_memory); + return false; + }; + const off = this.send_len - write_len; + @memcpy(ring_buffer.items.ptr, send_body + this.send_off, off); + send_body.release(); + bytes.copy(this.globalThis, (ring_buffer.items.ptr + off)[0..write_len], content_byte_len); + ring_buffer.items.len = this.send_len; + this.send_body_buf = null; + this.send_off = 0; + } else if (send_end <= ring_buffer.capacity) { + bytes.copy(this.globalThis, (ring_buffer.items.ptr + (send_end - write_len))[0..write_len], content_byte_len); + ring_buffer.items.len += write_len; + // can we treat it as a ring buffer without re-allocating the array? + } else if (send_end - this.send_off <= ring_buffer.capacity) { + std.mem.copyBackwards(u8, ring_buffer.items[0..this.send_len], ring_buffer.items[this.send_off..send_end]); + bytes.copy(this.globalThis, ring_buffer.items.ptr[this.send_len - write_len .. this.send_len], content_byte_len); + ring_buffer.items.len = this.send_len; + this.send_off = 0; + } else { + // we need to re-allocate the array + ring_buffer.ensureTotalCapacity(bun.default_allocator, this.send_len) catch { + this.terminate(ErrorCode.failed_to_allocate_memory); + return false; + }; + + const data_to_copy = ring_buffer.items[this.send_off..][0..(this.send_len - write_len)]; + const position_in_slice = ring_buffer.items[0 .. this.send_len - write_len]; + if (data_to_copy.len > 0 and data_to_copy.ptr != position_in_slice.ptr) + std.mem.copyBackwards( + u8, + position_in_slice, + data_to_copy, + ); + + ring_buffer.items.len = this.send_len; + bytes.copy(this.globalThis, ring_buffer.items[this.send_len - write_len ..], content_byte_len); + this.send_off = 0; + } + + out_buf = ring_buffer.items[this.send_off..send_end]; + } + + if (do_write) { + if (comptime Environment.allow_assert) { + std.debug.assert(!this.tcp.isShutdown()); + std.debug.assert(!this.tcp.isClosed()); + std.debug.assert(this.tcp.isEstablished()); + } + return this.sendBuffer(out_buf, is_closing); + } + + return true; + } + + fn sendBuffer(this: *HTTPClient, out_buf: []const u8, is_closing: bool) bool { + std.debug.assert(out_buf.len > 0); + const wrote = this.tcp.write(out_buf, !is_closing); + const expected = @intCast(c_int, out_buf.len); + if (wrote == expected) { + this.clearSendBuffers(false); + return true; + } + + if (wrote < 0) { + this.terminate(ErrorCode.failed_to_write); + return false; + } + + this.send_len -= @intCast(usize, wrote); + this.send_off += @intCast(usize, wrote); + return true; + } + + fn enqueueEncodedBytes(this: *HTTPClient, socket: Socket, bytes: []const u8) bool { + return this.enqueueEncodedBytesMaybeFinal(socket, bytes, false); + } + + fn sendPong(this: *HTTPClient, socket: Socket) bool { + if (socket.isClosed() or socket.isShutdown()) { + this.dispatchClose(); + return false; + } + + var header = @bitCast(WebsocketHeader, @as(u16, 0)); + header.fin = true; + header.opcode = .Pong; + + var to_mask = &this.ping_frame_bytes[6..][0..this.ping_len]; + + header.mask = to_mask.len > 0; + header.len = @truncate(u7, this.ping_len); + this.ping_frame_bytes[0..2].* = @bitCast(u16, header); + + if (to_mask.len > 0) { + Mask.from(this.globalThis).fill(&this.ping_frame_bytes[2..6], &to_mask, &to_mask); + return this.enqueueEncodedBytes(socket, this.ping_frame_bytes[0 .. 6 + @as(usize, this.ping_len)]); + } else { + return this.enqueueEncodedBytes(socket, &this.ping_frame_bytes[0..2]); + } + } + + fn sendCloseWithBody( + this: *HTTPClient, + socket: Socket, + code: u16, + body: ?*[125]u8, + body_len: usize, + ) void { + if (socket.isClosed() or socket.isShutdown()) { + this.dispatchClose(); + this.clearData(); + return; + } + + socket.shutdownRead(); + var final_body_bytes: [128 + 8]u8 = undefined; + var header = @bitCast(WebsocketHeader, @as(u16, 0)); + header.fin = true; + header.opcode = .Close; + header.mask = true; + header.len = body_len + 2; + final_body_bytes[0..2].* = @bitCast([2]u8, @bitCast(u16, header)); + var mask_buf = &final_body_bytes[2..6]; + std.mem.writeIntSliceBig(u16, &final_body_bytes[6..8], code); + + if (body) |data| { + if (body_len > 0) @memcpy(&final_body_bytes[8..], data, body_len); + } + + // we must mask the code + var slice = final_body_bytes[0..(8 + body_len)]; + Mask.from(this.globalThis).fill(mask_buf, slice[6..], slice[6..]); + + if (this.enqueueEncodedBytesMaybeFinal(socket, slice, true)) { + this.dispatchClose(); + this.clearData(); + } + } + + pub fn handleEnd(this: *HTTPClient, socket: Socket) void { + std.debug.assert(socket.socket == this.tcp.socket); + this.terminate(ErrorCode.ended); + } + + pub fn handleWritable( + this: *HTTPClient, + socket: Socket, + ) void { + std.debug.assert(socket.socket == this.tcp.socket); + if (this.send_len == 0) + return; + + var send_buf: []const u8 = undefined; + if (this.send_body_buf) |send_body_buf| { + send_buf = send_body_buf[this.send_off..][0..this.send_len]; + std.debug.assert(this.send_overflow_buffer.items.len == 0); + } else { + send_buf = this.send_overflow_buffer.items[this.send_off..][0..this.send_len]; + } + std.debug.assert(send_buf.len == this.send_len); + _ = this.sendBuffer(send_buf, false); + } + pub fn handleTimeout( + this: *HTTPClient, + _: Socket, + ) void { + this.terminate(ErrorCode.timeout); + } + pub fn handleConnectError(this: *HTTPClient, _: Socket, _: c_int) void { + this.terminate(ErrorCode.failed_to_connect); + } + + pub fn writeBinaryData( + this: *HTTPClient, + ptr: [*]const u8, + len: usize, + ) callconv(.C) void { + if (this.tcp.isClosed() or this.tcp.isShutdown()) { + this.dispatchClose(); + return; + } + + if (len == 0) + return; + + const slice = ptr[0..len]; + const bytes = Copy{ .bytes = slice }; + // fast path: small frame, no backpressure, attempt to send without allocating + const frame_size = WebsocketHeader.frameSizeIncludingMask(len); + if (this.send_len == 0 and frame_size < 128 + 4) { + var inline_buf: [128 + 4]u8 = undefined; + bytes.copy(this.globalThis, &inline_buf[0..frame_size], slice.len); + _ = this.enqueueEncodedBytes(this.socket, inline_buf[0..frame_size]); + return; + } + + _ = this.sendData(bytes, true, false); + } + pub fn writeString( + this: *HTTPClient, + str_: *const JSC.ZigString, + ) callconv(.C) void { + const str = str_.*; + if (this.tcp.isClosed() or this.tcp.isShutdown()) { + this.dispatchClose(); + return; + } + + if (str.len == 0) { + return; + } + + { + var inline_buf: [128 + 4]u8 = undefined; + + // fast path: small frame, no backpressure, attempt to send without allocating + if (!str.is16Bit() and str.len < 128 + 4) { + const bytes = Copy{ .latin1 = str.slice() }; + const frame_size = WebsocketHeader.frameSizeIncludingMask(str.len); + if (this.send_len == 0 and frame_size < 128 + 4) { + bytes.copy(this.globalThis, &inline_buf[0..frame_size], str.len); + _ = this.enqueueEncodedBytes(this.socket, inline_buf[0..frame_size]); + return; + } + // max length of a utf16 -> utf8 conversion is 4 times the length of the utf16 string + } else if ((str.len * 4) < (128 + 4) and this.send_len == 0) { + const bytes = Copy{ .utf16 = str.slice() }; + var byte_len: usize = 0; + const frame_size = bytes.len(&byte_len); + std.debug.assert(frame_size <= 128 + 4); + bytes.copy(this.globalThis, &inline_buf[0..frame_size], byte_len); + _ = this.enqueueEncodedBytes(this.socket, inline_buf[0..frame_size]); + return; + } + } + + _ = this.sendData( + if (str.is16Bit()) + Copy{ .utf16 = str.utf16SliceAligned() } + else + Copy{ .latin1 = str.slice() }, + true, + false, + ); + } + + pub fn close(this: *HTTPClient, code: u16, reason: ?*const JSC.ZigString) callconv(.C) void { + if (this.tcp.isClosed() or this.tcp.isShutdown()) + return; + + var close_reason_buf: [128]u8 = undefined; + if (reason) |str| { + inner: { + var fixed_buffer = std.heap.FixedBufferAllocator.init(&close_reason_buf); + const allocator = fixed_buffer.get(); + const wrote = std.fmt.allocPrint(allocator, "{}", str.*) catch break :inner; + this.sendCloseWithBody(this.tcp, code, wrote.ptr[0..125], wrote.len); + return; + } + } + + this.sendCloseWithBody(this.tcp, code, null, 0); + } + + pub fn init( + outgoing: *anyopaque, + input_socket: *anyopaque, + socket_ctx: *anyopaque, + globalThis: *JSC.JSGlobalObject, + ) callconv(.C) *anyopaque { + var tcp = @ptrCast(*uws.Socket, input_socket); + var ctx = @ptrCast(*uws.us_socket_context_t, socket_ctx); + + return @ptrCast( + *anyopaque, + Socket.adopt(tcp, ctx, HTTPClient{ + .outgoing_websocket = outgoing, + .globalThis = globalThis, + }), + ); + } + + pub fn finalize(this: *HTTPClient) callconv(.C) void { + this.clearData(); + + if (this.tcp.isClosed()) + return; + + this.tcp.close(0, null); + this.outgoing_websocket = null; + } + + pub const Export = shim.exportFunctions(.{ + .writeBinaryData = writeBinaryData, + .writeString = writeString, + .close = close, + .register = register, + .init = init, + .finalize = finalize, + }); + + comptime { + if (!JSC.is_bindgen) { + @export(writeBinaryData, .{ .name = Export[0].symbol_name }); + @export(writeString, .{ .name = Export[1].symbol_name }); + @export(close, .{ .name = Export[2].symbol_name }); + @export(register, .{ .name = Export[3].symbol_name }); + @export(init, .{ .name = Export[4].symbol_name }); + @export(finalize, .{ .name = Export[5].symbol_name }); + } + } + }; +} + pub const WebSocketHTTPClient = NewHTTPUpgradeClient(false); pub const WebSocketHTTPSClient = NewHTTPUpgradeClient(true); +pub const WebSocketClient = NewWebSocketClient(false); +pub const WebSocketClientTLS = NewWebSocketClient(true); diff --git a/src/javascript/jsc/bindings/ScriptExecutionContext.cpp b/src/javascript/jsc/bindings/ScriptExecutionContext.cpp index b89e0645f..9c6735993 100644 --- a/src/javascript/jsc/bindings/ScriptExecutionContext.cpp +++ b/src/javascript/jsc/bindings/ScriptExecutionContext.cpp @@ -63,18 +63,18 @@ static uWS::WebSocketContext<SSL, false, WebCore::WebSocket*>* registerWebSocket auto* opts = ctx->getExt(); /* Maximum message size we can receive */ - static unsigned int maxPayloadLength = 128 * 1024 * 1024; + unsigned int maxPayloadLength = 16 * 1024; /* 2 minutes timeout is good */ - static unsigned short idleTimeout = 120; + unsigned short idleTimeout = 120; /* 64kb backpressure is probably good */ - static unsigned int maxBackpressure = 128 * 1024 * 1024; - static bool closeOnBackpressureLimit = false; + unsigned int maxBackpressure = 64 * 1024; + bool closeOnBackpressureLimit = false; /* This one depends on kernel timeouts and is a bad default */ - static bool resetIdleTimeoutOnSend = false; + bool resetIdleTimeoutOnSend = false; /* A good default, esp. for newcomers */ - static bool sendPingsAutomatically = true; + bool sendPingsAutomatically = false; /* Maximum socket lifetime in seconds before forced closure (defaults to disabled) */ - static unsigned short maxLifetime = 0; + unsigned short maxLifetime = 0; opts->maxPayloadLength = maxPayloadLength; opts->maxBackpressure = maxBackpressure; diff --git a/src/javascript/jsc/bindings/bindings.zig b/src/javascript/jsc/bindings/bindings.zig index 841fa7a38..1b6524826 100644 --- a/src/javascript/jsc/bindings/bindings.zig +++ b/src/javascript/jsc/bindings/bindings.zig @@ -210,6 +210,10 @@ pub const ZigString = extern struct { return JSC.JSValue.fromRef(slice_).getZigString(ctx.ptr()); } + pub fn from16Slice(slice_: []const u16) ZigString { + return from16(slice_.ptr, slice_.len); + } + pub fn from16(slice_: [*]const u16, len: usize) ZigString { var str = init(@ptrCast([*]const u8, slice_)[0..len]); str.markUTF16(); diff --git a/src/javascript/jsc/bindings/exports.zig b/src/javascript/jsc/bindings/exports.zig index 0c9c2a677..c70388186 100644 --- a/src/javascript/jsc/bindings/exports.zig +++ b/src/javascript/jsc/bindings/exports.zig @@ -189,6 +189,8 @@ pub const JSArrayBufferSink = JSC.WebCore.ArrayBufferSink.JSSink; // WebSocket pub const WebSocketHTTPClient = @import("../../../http/websocket_http_client.zig").WebSocketHTTPClient; pub const WebSocketHTTSPClient = @import("../../../http/websocket_http_client.zig").WebSocketHTTPSClient; +pub const WebSocketClient = @import("../../../http/websocket_http_client.zig").WebSocketClient; +pub const WebSocketClientTLS = @import("../../../http/websocket_http_client.zig").WebSocketClientTLS; pub fn Errorable(comptime Type: type) type { return extern struct { @@ -2510,6 +2512,8 @@ pub const Formatter = ZigConsoleClient.Formatter; comptime { WebSocketHTTPClient.shim.ref(); WebSocketHTTSPClient.shim.ref(); + WebSocketClient.shim.ref(); + WebSocketClientTLS.shim.ref(); if (!is_bindgen) { _ = Process.getTitle; diff --git a/src/javascript/jsc/bindings/webcore/WebSocket.cpp b/src/javascript/jsc/bindings/webcore/WebSocket.cpp index 9b996565c..3c9f3a373 100644 --- a/src/javascript/jsc/bindings/webcore/WebSocket.cpp +++ b/src/javascript/jsc/bindings/webcore/WebSocket.cpp @@ -420,8 +420,6 @@ ExceptionOr<void> WebSocket::send(const String& message) if (utf8.length() > 0) this->sendWebSocketData<false>(utf8.data(), utf8.length()); - delete utf8; - return {}; } @@ -490,31 +488,34 @@ void WebSocket::sendWebSocketData(const char* baseAddress, size_t length) if constexpr (isBinary) opCode = uWS::OpCode::BINARY; - switch (m_connectedWebSocketKind) { - case ConnectedWebSocketKind::Client: { - this->m_connectedWebSocket.client->send({ baseAddress, length }, opCode); - this->m_bufferedAmount = this->m_connectedWebSocket.client->getBufferedAmount(); - break; - } - case ConnectedWebSocketKind::ClientSSL: { - this->m_connectedWebSocket.clientSSL->send({ baseAddress, length }, opCode); - this->m_bufferedAmount = this->m_connectedWebSocket.clientSSL->getBufferedAmount(); - break; - } - case ConnectedWebSocketKind::Server: { - this->m_connectedWebSocket.server->send({ baseAddress, length }, opCode); - this->m_bufferedAmount = this->m_connectedWebSocket.server->getBufferedAmount(); - break; - } - case ConnectedWebSocketKind::ServerSSL: { - this->m_connectedWebSocket.serverSSL->send({ baseAddress, length }, opCode); - this->m_bufferedAmount = this->m_connectedWebSocket.serverSSL->getBufferedAmount(); - break; - } - default: { - RELEASE_ASSERT_NOT_REACHED(); - } - } + this->m_connectedWebSocket.client->cork( + [&]() { + switch (m_connectedWebSocketKind) { + case ConnectedWebSocketKind::Client: { + this->m_connectedWebSocket.client->send({ baseAddress, length }, opCode); + this->m_bufferedAmount = this->m_connectedWebSocket.client->getBufferedAmount(); + break; + } + case ConnectedWebSocketKind::ClientSSL: { + this->m_connectedWebSocket.clientSSL->send({ baseAddress, length }, opCode); + this->m_bufferedAmount = this->m_connectedWebSocket.clientSSL->getBufferedAmount(); + break; + } + case ConnectedWebSocketKind::Server: { + this->m_connectedWebSocket.server->send({ baseAddress, length }, opCode); + this->m_bufferedAmount = this->m_connectedWebSocket.server->getBufferedAmount(); + break; + } + case ConnectedWebSocketKind::ServerSSL: { + this->m_connectedWebSocket.serverSSL->send({ baseAddress, length }, opCode); + this->m_bufferedAmount = this->m_connectedWebSocket.serverSSL->getBufferedAmount(); + break; + } + default: { + RELEASE_ASSERT_NOT_REACHED(); + } + } + }); } ExceptionOr<void> WebSocket::close(std::optional<unsigned short> optionalCode, const String& reason) @@ -856,16 +857,18 @@ void WebSocket::didConnect(us_socket_t* socket, char* bufferedData, size_t buffe this->didConnect(); } -void WebSocket::didFailToConnect(int32_t code) +void WebSocket::didFailWithErrorCode(int32_t code) { m_state = CLOSED; // this means we already handled it - if (this->m_upgradeClient == nullptr) { + if (this->m_upgradeClient == nullptr && this->m_connectedWebSocketKind == ConnectedWebSocketKind::None) { return; } this->m_upgradeClient = nullptr; + this->m_connectedWebSocketKind = ConnectedWebSocketKind::None; + this->m_connectedWebSocket.client = nullptr; switch (code) { // cancel @@ -982,7 +985,7 @@ extern "C" void WebSocket__didConnect(WebCore::WebSocket* webSocket, us_socket_t { webSocket->didConnect(socket, bufferedData, len); } -extern "C" void WebSocket__didFailToConnect(WebCore::WebSocket* webSocket, int32_t errorCode) +extern "C" void WebSocket__didFailWithErrorCode(WebCore::WebSocket* webSocket, int32_t errorCode) { - webSocket->didFailToConnect(errorCode); + webSocket->didFailWithErrorCode(errorCode); }
\ No newline at end of file diff --git a/src/javascript/jsc/bindings/webcore/WebSocket.h b/src/javascript/jsc/bindings/webcore/WebSocket.h index 351e8c6b6..51b885c3b 100644 --- a/src/javascript/jsc/bindings/webcore/WebSocket.h +++ b/src/javascript/jsc/bindings/webcore/WebSocket.h @@ -97,7 +97,7 @@ public: void didConnect(); void didClose(unsigned unhandledBufferedAmount, unsigned short code, const String& reason); void didConnect(us_socket_t* socket, char* bufferedData, size_t bufferedDataSize); - void didFailToConnect(int32_t code); + void didFailWithErrorCode(int32_t code); void didReceiveMessage(String&& message); void didReceiveData(const char* data, size_t length); diff --git a/src/javascript/jsc/rare_data.zig b/src/javascript/jsc/rare_data.zig index 3020dfc16..f82418da9 100644 --- a/src/javascript/jsc/rare_data.zig +++ b/src/javascript/jsc/rare_data.zig @@ -7,12 +7,14 @@ const Syscall = @import("./node/syscall.zig"); const JSC = @import("javascript_core"); const std = @import("std"); const BoringSSL = @import("boringssl"); -boring_ssl_engine: ?*BoringSSL.ENGINE = null, +const WebSocketClientMask = @import("../../http/websocket_http_client.zig").Mask; +boring_ssl_engine: ?*BoringSSL.ENGINE = null, editor_context: EditorContext = EditorContext{}, stderr_store: ?*Blob.Store = null, stdin_store: ?*Blob.Store = null, stdout_store: ?*Blob.Store = null, +websocket_mask: WebSocketClientMask = WebSocketClientMask{}, // TODO: make this per JSGlobalObject instead of global // This does not handle ShadowRealm correctly! |