diff options
Diffstat (limited to 'src/http/websocket_http_client.zig')
| -rw-r--r-- | src/http/websocket_http_client.zig | 576 |
1 files changed, 332 insertions, 244 deletions
diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig index 2be382dbe..8522c6038 100644 --- a/src/http/websocket_http_client.zig +++ b/src/http/websocket_http_client.zig @@ -92,6 +92,7 @@ const ErrorCode = enum(i32) { compression_unsupported, unexpected_mask_from_server, expected_control_frame, + unsupported_control_frame, unexpected_opcode, invalid_utf8, }; @@ -101,9 +102,9 @@ extern fn WebSocket__didConnect( buffered_data: ?[*]u8, buffered_len: usize, ) void; -extern fn WebSocket__didFailWithErrorCode(websocket_context: *anyopaque, reason: ErrorCode) void; +extern fn WebSocket__didCloseWithErrorCode(websocket_context: *anyopaque, reason: ErrorCode) void; extern fn WebSocket__didReceiveText(websocket_context: *anyopaque, clone: bool, text: *const JSC.ZigString) void; -extern fn WebSocket__didReceiveBytes(websocket_context: *anyopaque, bytes: []const u8) void; +extern fn WebSocket__didReceiveBytes(websocket_context: *anyopaque, bytes: [*]const u8, byte_len: usize) void; const body_buf_len = 16384 - 16; const BodyBufBytes = [body_buf_len]u8; @@ -159,7 +160,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { var client_protocol_hash: u64 = 0; var body = buildRequestBody(global.bunVM(), pathname, host, client_protocol, &client_protocol_hash) catch return null; var client: HTTPClient = HTTPClient{ - .socket = undefined, + .tcp = undefined, .outgoing_websocket = websocket, .input_body_buf = body, .websocket_protocol = client_protocol_hash, @@ -200,14 +201,14 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { pub fn fail(this: *HTTPClient, code: ErrorCode) void { JSC.markBinding(); - WebSocket__didFailWithErrorCode(this.outgoing_websocket, code); + WebSocket__didCloseWithErrorCode(this.outgoing_websocket, code); this.cancel(); } pub fn handleClose(this: *HTTPClient, _: Socket, _: c_int, _: ?*anyopaque) void { JSC.markBinding(); this.clearData(); - WebSocket__didFailWithErrorCode(this.outgoing_websocket, ErrorCode.ended); + WebSocket__didCloseWithErrorCode(this.outgoing_websocket, ErrorCode.ended); } pub fn terminate(this: *HTTPClient, code: ErrorCode) void { @@ -407,6 +408,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { this.clearData(); JSC.markBinding(); + this.tcp.timeout(0); WebSocket__didConnect(this.outgoing_websocket, this.tcp.socket, overflow.ptr, overflow.len); } @@ -500,17 +502,45 @@ pub const Mask = struct { pub fn fill(this: *Mask, mask_buf: *[4]u8, output_: []u8, input_: []const u8) void { const mask = this.get(); mask_buf.* = mask; + + const skip_mask = @bitCast(u32, mask) == 0; + if (!skip_mask) { + fillWithSkipMask(mask, output_, input_, false); + } else { + fillWithSkipMask(mask, output_, input_, true); + } + } + + fn fillWithSkipMask(mask: [4]u8, output_: []u8, input_: []const u8, comptime skip_mask: bool) void { var input = input_; var output = output_; + if (comptime Environment.isAarch64 or Environment.isX64) { if (input.len >= strings.ascii_vector_size) { - const vec: strings.AsciiVector = @as(strings.AsciiVector, mask ** (strings.ascii_vector_size / 4)); + const vec: strings.AsciiVector = brk: { + var in: [strings.ascii_vector_size]u8 = undefined; + comptime var i: usize = 0; + inline while (i < strings.ascii_vector_size) : (i += 4) { + in[i..][0..4].* = mask; + } + break :brk @as(strings.AsciiVector, in); + }; const end_ptr_wrapped_to_last_16 = input.ptr + input.len - (input.len % strings.ascii_vector_size); - while (input.ptr != end_ptr_wrapped_to_last_16) { - const input_vec: strings.AsciiVector = @as(strings.AsciiVector, input[0..strings.ascii_vector_size].*); - output.ptr[0..strings.ascii_vector_size].* = input_vec ^ vec; - output = output[strings.ascii_vector_size..]; - input = input[strings.ascii_vector_size..]; + + if (comptime skip_mask) { + while (input.ptr != end_ptr_wrapped_to_last_16) { + const input_vec: strings.AsciiVector = @as(strings.AsciiVector, input[0..strings.ascii_vector_size].*); + output.ptr[0..strings.ascii_vector_size].* = input_vec; + output = output[strings.ascii_vector_size..]; + input = input[strings.ascii_vector_size..]; + } + } else { + while (input.ptr != end_ptr_wrapped_to_last_16) { + const input_vec: strings.AsciiVector = @as(strings.AsciiVector, input[0..strings.ascii_vector_size].*); + output.ptr[0..strings.ascii_vector_size].* = input_vec ^ vec; + output = output[strings.ascii_vector_size..]; + input = input[strings.ascii_vector_size..]; + } } } @@ -518,20 +548,35 @@ pub const Mask = struct { std.debug.assert(input.len < strings.ascii_vector_size); } - while (input.len >= 4) { - const input_vec: [4]u8 = input[0..4].*; - output.ptr[0..4].* = [4]u8{ - input_vec[0] ^ mask[0], - input_vec[1] ^ mask[1], - input_vec[2] ^ mask[2], - input_vec[3] ^ mask[3], - }; - output = output[4..]; - input = input[4..]; + if (comptime !skip_mask) { + while (input.len >= 4) { + const input_vec: [4]u8 = input[0..4].*; + output.ptr[0..4].* = [4]u8{ + input_vec[0] ^ mask[0], + input_vec[1] ^ mask[1], + input_vec[2] ^ mask[2], + input_vec[3] ^ mask[3], + }; + output = output[4..]; + input = input[4..]; + } + } else { + while (input.len >= 4) { + const input_vec: [4]u8 = input[0..4].*; + output.ptr[0..4].* = input_vec; + output = output[4..]; + input = input[4..]; + } } - for (input) |c, i| { - output[i] = c ^ mask[i % 4]; + if (comptime !skip_mask) { + for (input) |c, i| { + output[i] = c ^ mask[i % 4]; + } + } else { + for (input) |c, i| { + output[i] = c; + } } } @@ -548,7 +593,8 @@ const ReceiveState = enum { extended_payload_length_16, extended_payload_length_64, ping, - closing, + pong, + close, fail, pub fn needControlFrame(this: ReceiveState) bool { @@ -629,71 +675,82 @@ const Copy = union(enum) { return WebsocketHeader.frameSizeIncludingMask(byte_len.*); }, .latin1 => { - byte_len.* = strings.elementLengthLatin1IntoUTF8([]const u8, this.latin1); + byte_len.* = this.latin1.len; return WebsocketHeader.frameSizeIncludingMask(byte_len.*); }, .bytes => { byte_len.* = this.bytes.len; return WebsocketHeader.frameSizeIncludingMask(byte_len.*); }, - .raw => return this.raw.len, + .raw => { + byte_len.* = this.raw.len; + return this.raw.len; + }, } } pub fn copy(this: @This(), globalThis: *JSC.JSGlobalObject, buf: []u8, content_byte_len: usize) void { + if (this == .raw) { + std.debug.assert(buf.len >= this.raw.len); + std.debug.assert(buf.ptr != this.raw.ptr); + @memcpy(buf.ptr, this.raw.ptr, this.raw.len); + return; + } + + const how_big_is_the_length_integer = WebsocketHeader.lengthByteCount(content_byte_len); + const how_big_is_the_mask = 4; + const mask_offset = 2 + how_big_is_the_length_integer; + const content_offset = mask_offset + how_big_is_the_mask; + + // 2 byte header + // 4 byte mask + // 0, 2, 8 byte length + var to_mask = buf[content_offset..]; + + var header = @bitCast(WebsocketHeader, @as(u16, 0)); + + // Write extended length if needed + switch (how_big_is_the_length_integer) { + 0 => {}, + 2 => std.mem.writeIntBig(u16, buf[2..][0..2], @truncate(u16, content_byte_len)), + 8 => std.mem.writeIntBig(u64, buf[2..][0..8], @truncate(u64, content_byte_len)), + else => unreachable, + } + + header.mask = true; + header.compressed = false; + header.final = true; + + std.debug.assert(WebsocketHeader.frameSizeIncludingMask(content_byte_len) == buf.len); + switch (this) { .utf16 => |utf16| { - const length_offset = 2; - const length_length = WebsocketHeader.lengthByteCount(content_byte_len); - const mask_offset = length_offset + length_length; - const content_offset = mask_offset + 4; - var to_mask = buf[content_offset..]; - const encode_into_result = strings.copyUTF16IntoUTF8(utf16, to_mask); + header.len = WebsocketHeader.packLength(content_byte_len); + const encode_into_result = strings.copyUTF16IntoUTF8(to_mask, []const u16, utf16); std.debug.assert(@as(usize, encode_into_result.written) == content_byte_len); std.debug.assert(@as(usize, encode_into_result.read) == utf16.len); - var header = @bitCast(WebsocketHeader, @as(u16, 0)); - header.len = WebsocketHeader.packLength(content_byte_len); - header.mask = true; + header.len = WebsocketHeader.packLength(encode_into_result.written); header.opcode = Opcode.Text; - header.compressed = false; - header.final = true; - header.writeHeader(std.io.fixedBufferStream(buf), content_byte_len) catch unreachable; - Mask.from(globalThis).fill(buf[mask_offset..][0..4], to_mask, to_mask); + header.writeHeader(std.io.fixedBufferStream(buf).writer(), encode_into_result.written) catch unreachable; + + Mask.from(globalThis).fill(buf[mask_offset..][0..4], to_mask[0..content_byte_len], to_mask[0..content_byte_len]); }, .latin1 => |latin1| { - const length_offset = 2; - const length_length = WebsocketHeader.lengthByteCount(content_byte_len); - const mask_offset = length_offset + length_length; - const content_offset = mask_offset + 4; - var to_mask = buf[content_offset..]; - const encode_into_result = strings.copyLatin1IntoUTF8(latin1, to_mask); + const encode_into_result = strings.copyLatin1IntoUTF8(to_mask, []const u8, latin1); std.debug.assert(@as(usize, encode_into_result.written) == content_byte_len); std.debug.assert(@as(usize, encode_into_result.read) == latin1.len); - var header = @bitCast(WebsocketHeader, @as(u16, 0)); - header.len = WebsocketHeader.packLength(content_byte_len); - header.mask = true; + header.len = WebsocketHeader.packLength(encode_into_result.written); header.opcode = Opcode.Text; - header.compressed = false; - header.final = true; - header.writeHeader(std.io.fixedBufferStream(buf), content_byte_len) catch unreachable; - Mask.from(globalThis).fill(buf[mask_offset..][0..4], to_mask, to_mask); + header.writeHeader(std.io.fixedBufferStream(buf).writer(), encode_into_result.written) catch unreachable; + Mask.from(globalThis).fill(buf[mask_offset..][0..4], to_mask[0..content_byte_len], to_mask[0..content_byte_len]); }, .bytes => |bytes| { - const length_offset = 2; - const length_length = WebsocketHeader.lengthByteCount(bytes.len); - const mask_offset = length_offset + length_length; - const content_offset = mask_offset + 4; - var to_mask = buf[content_offset..]; - var header = @bitCast(WebsocketHeader, @as(u16, 0)); header.len = WebsocketHeader.packLength(bytes.len); - header.mask = true; - header.opcode = Opcode.Text; - header.compressed = false; - header.final = true; - header.writeHeader(std.io.fixedBufferStream(buf), bytes.len) catch unreachable; - Mask.from(globalThis).fill(buf[mask_offset..][0..4], to_mask, to_mask); + header.opcode = Opcode.Binary; + header.writeHeader(std.io.fixedBufferStream(buf).writer(), bytes.len) catch unreachable; + Mask.from(globalThis).fill(buf[mask_offset..][0..4], to_mask[0..content_byte_len], bytes); }, - .raw => @memcpy(buf.ptr, this.raw.ptr, this.raw.len), + .raw => unreachable, } } }; @@ -709,7 +766,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { receive_remaining: usize = 0, receiving_type: Opcode = Opcode.ResB, - ping_frame_bytes: [128 + 6]u8 = [_]u8{0} ** 128 + 6, + ping_frame_bytes: [128 + 6]u8 = [_]u8{0} ** (128 + 6), ping_len: u8 = 0, receive_frame: usize = 0, @@ -717,9 +774,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { receive_pending_chunk_len: usize = 0, receive_body_buf: ?*BodyBuf = null, receive_overflow_buffer: std.ArrayListUnmanaged(u8) = .{}, - send_overflow_buffer: std.ArrayListUnmanaged(u8) = .{}, - send_body_buf: ?*BodyBuf = null, + send_buffer: bun.LinearFifo(u8, .Dynamic), send_len: usize = 0, send_off: usize = 0, @@ -728,10 +784,11 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { pub const name = if (ssl) "WebSocketClientTLS" else "WebSocketClient"; pub const shim = JSC.Shimmer("Bun", name, @This()); + const stack_frame_size = 1024; - const HTTPClient = @This(); + const WebSocket = @This(); - pub fn register(global: *JSC.JSGlobalObject, loop_: *anyopaque, parent: *anyopaque, ctx_: *anyopaque) callconv(.C) void { + pub fn register(global: *JSC.JSGlobalObject, loop_: *anyopaque, ctx_: *anyopaque) callconv(.C) void { var vm = global.bunVM(); var loop = @ptrCast(*uws.Loop, loop_); var ctx: *uws.us_socket_context_t = @ptrCast(*uws.us_socket_context_t, ctx_); @@ -742,10 +799,9 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { vm.uws_event_loop = loop; - Socket.configureChild( + Socket.configure( ctx, - parent, - HTTPClient, + WebSocket, null, handleClose, handleData, @@ -756,7 +812,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { ); } - pub fn clearData(this: *HTTPClient) void { + pub fn clearData(this: *WebSocket) void { this.clearReceiveBuffers(true); this.clearSendBuffers(true); this.ping_len = 0; @@ -764,7 +820,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { this.receive_remaining = 0; } - pub fn cancel(this: *HTTPClient) callconv(.C) void { + pub fn cancel(this: *WebSocket) callconv(.C) void { this.clearData(); if (this.tcp.isClosed() or this.tcp.isShutdown()) @@ -777,34 +833,34 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { } } - pub fn fail(this: *HTTPClient, code: ErrorCode) void { + pub fn fail(this: *WebSocket, code: ErrorCode) void { JSC.markBinding(); if (this.outgoing_websocket) |ws| - WebSocket__didFailWithErrorCode(ws, code); + WebSocket__didCloseWithErrorCode(ws, code); this.cancel(); } - pub fn handleClose(this: *HTTPClient, _: Socket, _: c_int, _: ?*anyopaque) void { + pub fn handleClose(this: *WebSocket, _: Socket, _: c_int, _: ?*anyopaque) void { JSC.markBinding(); this.clearData(); if (this.outgoing_websocket) |ws| - WebSocket__didFailWithErrorCode(ws, ErrorCode.ended); + WebSocket__didCloseWithErrorCode(ws, ErrorCode.ended); } - pub fn terminate(this: *HTTPClient, code: ErrorCode) void { + pub fn terminate(this: *WebSocket, code: ErrorCode) void { this.fail(code); } - fn getBody(this: *HTTPClient) *BodyBufBytes { - if (this.send_body_buf == null) { - this.send_body_buf = BodyBufPool.get(bun.default_allocator); - } + // fn getBody(this: *WebSocket) *BodyBufBytes { + // if (this.send_body_buf == null) { + // this.send_body_buf = BodyBufPool.get(bun.default_allocator); + // } - return &this.send_body_buf.?.data; - } + // return &this.send_body_buf.?.data; + // } - fn getReceiveBody(this: *HTTPClient) *BodyBufBytes { + fn getReceiveBody(this: *WebSocket) *BodyBufBytes { if (this.receive_body_buf == null) { this.receive_body_buf = BodyBufPool.get(bun.default_allocator); } @@ -812,7 +868,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { return &this.receive_body_buf.?.data; } - fn clearReceiveBuffers(this: *HTTPClient, free: bool) void { + fn clearReceiveBuffers(this: *WebSocket, free: bool) void { if (this.receive_body_buf) |receive_buf| { receive_buf.release(); this.receive_body_buf = null; @@ -825,50 +881,58 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { } } - fn clearSendBuffers(this: *HTTPClient, free: bool) void { - if (this.send_body_buf) |buf| { - buf.release(); - this.send_body_buf = null; - } + fn clearSendBuffers(this: *WebSocket, free: bool) void { + // if (this.send_body_buf) |buf| { + // buf.release(); + // this.send_body_buf = null; + // } + this.send_buffer.discard(this.send_buffer.count); if (free) { - this.send_overflow_buffer.clearAndFree(bun.default_allocator); - } else { - this.send_overflow_buffer.clearRetainingCapacity(); + this.send_buffer.deinit(); + this.send_buffer.buf.len = 0; } + this.send_off = 0; this.send_len = 0; } - fn dispatchData(this: *HTTPClient, data_: []const u8, kind: Opcode) void { + fn dispatchData(this: *WebSocket, data_: []const u8, kind: Opcode) void { + var out = this.outgoing_websocket orelse { + this.clearData(); + return; + }; switch (kind) { .Text => { // this function encodes to UTF-16 if > 127 // so we don't need to worry about latin1 non-ascii code points const utf16_bytes_ = strings.toUTF16Alloc(bun.default_allocator, data_, true) catch { this.terminate(ErrorCode.invalid_utf8); - return 0; + return; }; defer this.clearReceiveBuffers(false); var outstring = JSC.ZigString.Empty; if (utf16_bytes_) |utf16| { outstring = JSC.ZigString.from16Slice(utf16); outstring.markUTF16(); - WebSocket__didReceiveText(this.outgoing_websocket, false, &outstring); + JSC.markBinding(); + WebSocket__didReceiveText(out, false, &outstring); } else { outstring = JSC.ZigString.init(data_); - WebSocket__didReceiveText(this.outgoing_websocket, true, &outstring); + JSC.markBinding(); + WebSocket__didReceiveText(out, true, &outstring); } }, .Binary => { - WebSocket__didReceiveBytes(this.outgoing_websocket, data_); + JSC.markBinding(); + WebSocket__didReceiveBytes(out, data_.ptr, data_.len); this.clearReceiveBuffers(false); }, else => unreachable, } } - pub fn consume(this: *HTTPClient, data_: []const u8, max: usize, kind: Opcode, is_final: bool) usize { + pub fn consume(this: *WebSocket, data_: []const u8, max: usize, kind: Opcode, is_final: bool) usize { std.debug.assert(kind == .Text or kind == .Binary); std.debug.assert(data_.len <= max); @@ -898,7 +962,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { if (new_pending_chunk_len <= body_buf_len) { // if our previously-allocated buffer is too small or we don't have one, use from the pool var body = this.getReceiveBody(); - @memcpy(body + this.receive_pending_chunk_len, data_.ptr, data_.len); + @memcpy(body[this.receive_pending_chunk_len..].ptr, data_.ptr, data_.len); if (can_dispatch_data) { this.dispatchData(body[0..new_pending_chunk_len], kind); this.receive_pending_chunk_len = 0; @@ -924,7 +988,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { } } - pub fn handleData(this: *HTTPClient, socket: Socket, data_: []const u8) void { + pub fn handleData(this: *WebSocket, socket: Socket, data_: []const u8) void { var data = data_; var receive_state = this.receive_state; var terminated = false; @@ -983,7 +1047,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { is_final = false; receive_state = parseWebSocketHeader( - header_bytes[0..2], + header_bytes[0..2].*, &receiving_type, &receive_body_remain, &is_fragmented, @@ -1005,7 +1069,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { switch (receiving_type) { .Continue, .Text, .Binary, .Ping, .Pong, .Close => {}, else => { - this.terminate(ErrorCode.unexpected_opcode); + this.terminate(ErrorCode.unsupported_control_frame); terminated = true; break; }, @@ -1037,8 +1101,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { // Multibyte length quantities are expressed in network byte order receive_body_remain = switch (byte_size) { - 8 => @as(usize, std.mem.readIntBig(u64, data[0..8].*)), - 2 => @as(usize, std.mem.readIntBig(u16, data[0..2].*)), + 8 => @as(usize, std.mem.readIntBig(u64, data[0..8])), + 2 => @as(usize, std.mem.readIntBig(u16, data[0..2])), else => unreachable, }; data = data[byte_size..]; @@ -1058,7 +1122,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { this.ping_len = @truncate(u8, ping_len); if (ping_len > 0) { - @memcpy(&this.ping_frame_bytes + 6, data.ptr, ping_len); + @memcpy(this.ping_frame_bytes[6..], data.ptr, ping_len); data = data[ping_len..]; } @@ -1110,7 +1174,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { break; }, .fail => { - this.terminate(ErrorCode.unexpected_control_frame); + this.terminate(ErrorCode.unsupported_control_frame); terminated = true; break; }, @@ -1118,18 +1182,18 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { } } - pub fn sendClose(this: *HTTPClient) void { - this.sendCloseWithBody(this.tcp, 1001, null, null, 0); + pub fn sendClose(this: *WebSocket) void { + this.sendCloseWithBody(this.tcp, 1001, null, 0); } fn enqueueEncodedBytesMaybeFinal( - this: *HTTPClient, + this: *WebSocket, socket: Socket, bytes: []const u8, is_closing: bool, ) bool { // fast path: no backpressure, no queue, just send the bytes. - if (this.send_len == 0) { + if (!this.hasBackpressure()) { const wrote = socket.write(bytes, !is_closing); const expected = @intCast(c_int, bytes.len); if (wrote == expected) { @@ -1141,82 +1205,87 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { return false; } - _ = this.copyToSendBuffer(bytes[bytes.len - @intCast(usize, wrote) ..], false, is_closing, false); + _ = this.copyToSendBuffer(bytes[@intCast(usize, wrote)..], false, is_closing); return true; } - return this.copyToSendBuffer(bytes, true, is_closing, false); + return this.copyToSendBuffer(bytes, true, is_closing); } - fn copyToSendBuffer(this: *HTTPClient, bytes: []const u8, do_write: bool, is_closing: bool) bool { + fn copyToSendBuffer(this: *WebSocket, bytes: []const u8, do_write: bool, is_closing: bool) bool { return this.sendData(.{ .raw = bytes }, do_write, is_closing); } - fn sendData(this: *HTTPClient, bytes: Copy, do_write: bool, is_closing: bool) bool { + fn sendData(this: *WebSocket, bytes: Copy, do_write: bool, is_closing: bool) bool { var content_byte_len: usize = 0; const write_len = bytes.len(&content_byte_len); std.debug.assert(write_len > 0); - this.send_len += write_len; - var out_buf: []const u8 = ""; - const send_end = this.send_off + this.send_len; - var ring_buffer = &this.send_overflow_buffer; - - if (send_end <= ring_buffer.capacity) { - bytes.copy(this.globalThis, ring_buffer.items.ptr[send_end - write_len .. send_end], content_byte_len); - ring_buffer.items.len += write_len; - - out_buf = ring_buffer.items[this.send_off..send_end]; - } else if (send_end <= body_buf_len) { - var buf = this.getBody(); - bytes.copy(this.globalThis, buf[send_end - write_len ..][0..write_len], content_byte_len); - out_buf = buf[this.send_off..send_end]; - } else { - if (this.send_body_buf) |send_body| { - // transfer existing send buffer to overflow buffer - ring_buffer.ensureTotalCapacityPrecise(bun.default_allocator, this.send_len) catch { - this.terminate(ErrorCode.failed_to_allocate_memory); - return false; - }; - const off = this.send_len - write_len; - @memcpy(ring_buffer.items.ptr, send_body + this.send_off, off); - send_body.release(); - bytes.copy(this.globalThis, (ring_buffer.items.ptr + off)[0..write_len], content_byte_len); - ring_buffer.items.len = this.send_len; - this.send_body_buf = null; - this.send_off = 0; - } else if (send_end <= ring_buffer.capacity) { - bytes.copy(this.globalThis, (ring_buffer.items.ptr + (send_end - write_len))[0..write_len], content_byte_len); - ring_buffer.items.len += write_len; - // can we treat it as a ring buffer without re-allocating the array? - } else if (send_end - this.send_off <= ring_buffer.capacity) { - std.mem.copyBackwards(u8, ring_buffer.items[0..this.send_len], ring_buffer.items[this.send_off..send_end]); - bytes.copy(this.globalThis, ring_buffer.items.ptr[this.send_len - write_len .. this.send_len], content_byte_len); - ring_buffer.items.len = this.send_len; - this.send_off = 0; - } else { - // we need to re-allocate the array - ring_buffer.ensureTotalCapacity(bun.default_allocator, this.send_len) catch { - this.terminate(ErrorCode.failed_to_allocate_memory); - return false; - }; - - const data_to_copy = ring_buffer.items[this.send_off..][0..(this.send_len - write_len)]; - const position_in_slice = ring_buffer.items[0 .. this.send_len - write_len]; - if (data_to_copy.len > 0 and data_to_copy.ptr != position_in_slice.ptr) - std.mem.copyBackwards( - u8, - position_in_slice, - data_to_copy, - ); - - ring_buffer.items.len = this.send_len; - bytes.copy(this.globalThis, ring_buffer.items[this.send_len - write_len ..], content_byte_len); - this.send_off = 0; - } - - out_buf = ring_buffer.items[this.send_off..send_end]; - } + // this.send_len += write_len; + var writable = this.send_buffer.writableWithSize(write_len) catch unreachable; + bytes.copy(this.globalThis, writable[0..write_len], content_byte_len); + this.send_buffer.update(write_len); + // const send_end = this.send_off + this.send_len; + // var ring_buffer = &this.send_overflow_buffer; + // const prev_len = this.send_len - write_len; + + // if (send_end <= ring_buffer.capacity) { + // const mid = ring_buffer.capacity / 2; + // if (send_end > mid and this.send_len < mid) { + // std.mem.copyBackwards(u8, ring_buffer.items.ptr[0..prev_len], ring_buffer.items.ptr[this.send_off..ring_buffer.capacity][0..prev_len]); + // this.send_off = 0; + // ring_buffer.items.len = prev_len; + // bytes.copy(this.globalThis, ring_buffer.items.ptr[prev_len..this.send_len], content_byte_len); + // } else { + // bytes.copy(this.globalThis, ring_buffer.items.ptr[send_end - write_len .. send_end], content_byte_len); + // } + // ring_buffer.items.len += write_len; + + // out_buf = ring_buffer.items[this.send_off..][0..this.send_len]; + // } else if (send_end <= body_buf_len) { + // var buf = this.getBody(); + // bytes.copy(this.globalThis, buf[send_end - write_len ..][0..write_len], content_byte_len); + // out_buf = buf[this.send_off..][0..this.send_len]; + // } else { + // if (this.send_body_buf) |send_body| { + // var buf = &send_body.data; + // // transfer existing send buffer to overflow buffer + // ring_buffer.ensureTotalCapacityPrecise(bun.default_allocator, this.send_len) catch { + // this.terminate(ErrorCode.failed_to_allocate_memory); + // return false; + // }; + // @memcpy(ring_buffer.items.ptr, buf[this.send_off..].ptr, prev_len); + // send_body.release(); + // bytes.copy(this.globalThis, (ring_buffer.items.ptr + prev_len)[0..write_len], content_byte_len); + // ring_buffer.items.len = this.send_len; + // this.send_body_buf = null; + // this.send_off = 0; + // } else if (send_end <= ring_buffer.capacity) { + // bytes.copy(this.globalThis, (ring_buffer.items.ptr + (send_end - write_len))[0..write_len], content_byte_len); + // ring_buffer.items.len += write_len; + // } else { + // // we need to re-allocate the array + // ring_buffer.ensureTotalCapacity(bun.default_allocator, this.send_len) catch { + // this.terminate(ErrorCode.failed_to_allocate_memory); + // return false; + // }; + + // const data_to_copy = ring_buffer.items[this.send_off..][0..prev_len]; + // const position_in_slice = ring_buffer.items[0..prev_len]; + // if (data_to_copy.len > 0 and data_to_copy.ptr != position_in_slice.ptr) + // std.mem.copyBackwards( + // u8, + // position_in_slice, + // data_to_copy, + // ); + + // ring_buffer.items.len = this.send_len; + // bytes.copy(this.globalThis, ring_buffer.items[prev_len..], content_byte_len); + // this.send_off = 0; + // } + + // out_buf = ring_buffer.items[this.send_off..][0..this.send_len]; + // } if (do_write) { if (comptime Environment.allow_assert) { @@ -1224,61 +1293,66 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { std.debug.assert(!this.tcp.isClosed()); std.debug.assert(this.tcp.isEstablished()); } - return this.sendBuffer(out_buf, is_closing); + return this.sendBuffer(this.send_buffer.readableSlice(0), is_closing, !is_closing); } return true; } - fn sendBuffer(this: *HTTPClient, out_buf: []const u8, is_closing: bool) bool { + fn sendBuffer( + this: *WebSocket, + out_buf: []const u8, + is_closing: bool, + _: bool, + ) bool { std.debug.assert(out_buf.len > 0); - const wrote = this.tcp.write(out_buf, !is_closing); - const expected = @intCast(c_int, out_buf.len); - if (wrote == expected) { - this.clearSendBuffers(false); - return true; - } - + _ = is_closing; + // set msg_more to false + // it seems to improve perf by ~20% + const wrote = this.tcp.write(out_buf, false); if (wrote < 0) { this.terminate(ErrorCode.failed_to_write); return false; } + const expected = @intCast(usize, wrote); + var readable = this.send_buffer.readableSlice(0); + if (readable.ptr == out_buf.ptr) { + this.send_buffer.discard(expected); + } - this.send_len -= @intCast(usize, wrote); - this.send_off += @intCast(usize, wrote); return true; } - fn enqueueEncodedBytes(this: *HTTPClient, socket: Socket, bytes: []const u8) bool { + fn enqueueEncodedBytes(this: *WebSocket, socket: Socket, bytes: []const u8) bool { return this.enqueueEncodedBytesMaybeFinal(socket, bytes, false); } - fn sendPong(this: *HTTPClient, socket: Socket) bool { + fn sendPong(this: *WebSocket, socket: Socket) bool { if (socket.isClosed() or socket.isShutdown()) { this.dispatchClose(); return false; } var header = @bitCast(WebsocketHeader, @as(u16, 0)); - header.fin = true; + header.final = true; header.opcode = .Pong; - var to_mask = &this.ping_frame_bytes[6..][0..this.ping_len]; + var to_mask = this.ping_frame_bytes[6..][0..this.ping_len]; header.mask = to_mask.len > 0; header.len = @truncate(u7, this.ping_len); this.ping_frame_bytes[0..2].* = @bitCast(u16, header); if (to_mask.len > 0) { - Mask.from(this.globalThis).fill(&this.ping_frame_bytes[2..6], &to_mask, &to_mask); + Mask.from(this.globalThis).fill(this.ping_frame_bytes[2..6], to_mask, to_mask); return this.enqueueEncodedBytes(socket, this.ping_frame_bytes[0 .. 6 + @as(usize, this.ping_len)]); } else { - return this.enqueueEncodedBytes(socket, &this.ping_frame_bytes[0..2]); + return this.enqueueEncodedBytes(socket, this.ping_frame_bytes[0..2]); } } fn sendCloseWithBody( - this: *HTTPClient, + this: *WebSocket, socket: Socket, code: u16, body: ?*[125]u8, @@ -1293,16 +1367,16 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { socket.shutdownRead(); var final_body_bytes: [128 + 8]u8 = undefined; var header = @bitCast(WebsocketHeader, @as(u16, 0)); - header.fin = true; + header.final = true; header.opcode = .Close; header.mask = true; - header.len = body_len + 2; + header.len = @truncate(u7, body_len + 2); final_body_bytes[0..2].* = @bitCast([2]u8, @bitCast(u16, header)); - var mask_buf = &final_body_bytes[2..6]; - std.mem.writeIntSliceBig(u16, &final_body_bytes[6..8], code); + var mask_buf: *[4]u8 = final_body_bytes[2..6]; + std.mem.writeIntSliceBig(u16, final_body_bytes[6..8], code); if (body) |data| { - if (body_len > 0) @memcpy(&final_body_bytes[8..], data, body_len); + if (body_len > 0) @memcpy(final_body_bytes[8..], data, body_len); } // we must mask the code @@ -1315,41 +1389,37 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { } } - pub fn handleEnd(this: *HTTPClient, socket: Socket) void { + pub fn handleEnd(this: *WebSocket, socket: Socket) void { std.debug.assert(socket.socket == this.tcp.socket); this.terminate(ErrorCode.ended); } pub fn handleWritable( - this: *HTTPClient, + this: *WebSocket, socket: Socket, ) void { std.debug.assert(socket.socket == this.tcp.socket); - if (this.send_len == 0) + const send_buf = this.send_buffer.readableSlice(0); + if (send_buf.len == 0) return; - - var send_buf: []const u8 = undefined; - if (this.send_body_buf) |send_body_buf| { - send_buf = send_body_buf[this.send_off..][0..this.send_len]; - std.debug.assert(this.send_overflow_buffer.items.len == 0); - } else { - send_buf = this.send_overflow_buffer.items[this.send_off..][0..this.send_len]; - } - std.debug.assert(send_buf.len == this.send_len); - _ = this.sendBuffer(send_buf, false); + _ = this.sendBuffer(send_buf, false, true); } pub fn handleTimeout( - this: *HTTPClient, + this: *WebSocket, _: Socket, ) void { this.terminate(ErrorCode.timeout); } - pub fn handleConnectError(this: *HTTPClient, _: Socket, _: c_int) void { + pub fn handleConnectError(this: *WebSocket, _: Socket, _: c_int) void { this.terminate(ErrorCode.failed_to_connect); } + pub fn hasBackpressure(this: *const WebSocket) bool { + return this.send_buffer.count > 0; + } + pub fn writeBinaryData( - this: *HTTPClient, + this: *WebSocket, ptr: [*]const u8, len: usize, ) callconv(.C) void { @@ -1365,17 +1435,17 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { const bytes = Copy{ .bytes = slice }; // fast path: small frame, no backpressure, attempt to send without allocating const frame_size = WebsocketHeader.frameSizeIncludingMask(len); - if (this.send_len == 0 and frame_size < 128 + 4) { - var inline_buf: [128 + 4]u8 = undefined; - bytes.copy(this.globalThis, &inline_buf[0..frame_size], slice.len); - _ = this.enqueueEncodedBytes(this.socket, inline_buf[0..frame_size]); + if (!this.hasBackpressure() and frame_size < stack_frame_size) { + var inline_buf: [stack_frame_size]u8 = undefined; + bytes.copy(this.globalThis, inline_buf[0..frame_size], slice.len); + _ = this.enqueueEncodedBytes(this.tcp, inline_buf[0..frame_size]); return; } - _ = this.sendData(bytes, true, false); + _ = this.sendData(bytes, !this.hasBackpressure(), false); } pub fn writeString( - this: *HTTPClient, + this: *WebSocket, str_: *const JSC.ZigString, ) callconv(.C) void { const str = str_.*; @@ -1389,25 +1459,25 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { } { - var inline_buf: [128 + 4]u8 = undefined; + var inline_buf: [stack_frame_size]u8 = undefined; // fast path: small frame, no backpressure, attempt to send without allocating - if (!str.is16Bit() and str.len < 128 + 4) { + if (!str.is16Bit() and str.len < stack_frame_size) { const bytes = Copy{ .latin1 = str.slice() }; const frame_size = WebsocketHeader.frameSizeIncludingMask(str.len); - if (this.send_len == 0 and frame_size < 128 + 4) { - bytes.copy(this.globalThis, &inline_buf[0..frame_size], str.len); - _ = this.enqueueEncodedBytes(this.socket, inline_buf[0..frame_size]); + if (!this.hasBackpressure() and frame_size < stack_frame_size) { + bytes.copy(this.globalThis, inline_buf[0..frame_size], str.len); + _ = this.enqueueEncodedBytes(this.tcp, inline_buf[0..frame_size]); return; } // max length of a utf16 -> utf8 conversion is 4 times the length of the utf16 string - } else if ((str.len * 4) < (128 + 4) and this.send_len == 0) { - const bytes = Copy{ .utf16 = str.slice() }; + } else if ((str.len * 4) < (stack_frame_size) and !this.hasBackpressure()) { + const bytes = Copy{ .utf16 = str.utf16SliceAligned() }; var byte_len: usize = 0; const frame_size = bytes.len(&byte_len); - std.debug.assert(frame_size <= 128 + 4); - bytes.copy(this.globalThis, &inline_buf[0..frame_size], byte_len); - _ = this.enqueueEncodedBytes(this.socket, inline_buf[0..frame_size]); + std.debug.assert(frame_size <= stack_frame_size); + bytes.copy(this.globalThis, inline_buf[0..frame_size], byte_len); + _ = this.enqueueEncodedBytes(this.tcp, inline_buf[0..frame_size]); return; } } @@ -1417,12 +1487,18 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { Copy{ .utf16 = str.utf16SliceAligned() } else Copy{ .latin1 = str.slice() }, - true, + !this.hasBackpressure(), false, ); } - pub fn close(this: *HTTPClient, code: u16, reason: ?*const JSC.ZigString) callconv(.C) void { + fn dispatchClose(this: *WebSocket) void { + var out = this.outgoing_websocket orelse return; + JSC.markBinding(); + WebSocket__didCloseWithErrorCode(out, ErrorCode.closed); + } + + pub fn close(this: *WebSocket, code: u16, reason: ?*const JSC.ZigString) callconv(.C) void { if (this.tcp.isClosed() or this.tcp.isShutdown()) return; @@ -1430,8 +1506,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { if (reason) |str| { inner: { var fixed_buffer = std.heap.FixedBufferAllocator.init(&close_reason_buf); - const allocator = fixed_buffer.get(); - const wrote = std.fmt.allocPrint(allocator, "{}", str.*) catch break :inner; + const allocator = fixed_buffer.allocator(); + const wrote = std.fmt.allocPrint(allocator, "{}", .{str.*}) catch break :inner; this.sendCloseWithBody(this.tcp, code, wrote.ptr[0..125], wrote.len); return; } @@ -1445,20 +1521,31 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { input_socket: *anyopaque, socket_ctx: *anyopaque, globalThis: *JSC.JSGlobalObject, - ) callconv(.C) *anyopaque { + ) callconv(.C) ?*anyopaque { var tcp = @ptrCast(*uws.Socket, input_socket); var ctx = @ptrCast(*uws.us_socket_context_t, socket_ctx); - - return @ptrCast( - *anyopaque, - Socket.adopt(tcp, ctx, HTTPClient{ + var adopted = Socket.adopt( + tcp, + ctx, + WebSocket, + "tcp", + WebSocket{ + .tcp = undefined, .outgoing_websocket = outgoing, .globalThis = globalThis, - }), + .receive_overflow_buffer = .{}, + .send_buffer = bun.LinearFifo(u8, .Dynamic).init(bun.default_allocator), + }, + ) orelse return null; + adopted.send_buffer.ensureTotalCapacity(2048) catch return null; + _ = globalThis.bunVM().eventLoop().ready_tasks_count.fetchAdd(1, .Monotonic); + return @ptrCast( + *anyopaque, + adopted, ); } - pub fn finalize(this: *HTTPClient) callconv(.C) void { + pub fn finalize(this: *WebSocket) callconv(.C) void { this.clearData(); if (this.tcp.isClosed()) @@ -1466,6 +1553,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { this.tcp.close(0, null); this.outgoing_websocket = null; + _ = this.globalThis.bunVM().eventLoop().ready_tasks_count.fetchSub(1, .Monotonic); } pub const Export = shim.exportFunctions(.{ |
