aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Jarred Sumner <709451+Jarred-Sumner@users.noreply.github.com> 2022-06-19 03:59:08 -0700
committerGravatar Jarred Sumner <jarred@jarredsumner.com> 2022-06-22 06:56:47 -0700
commitdda85d92c9bafd0fe86540efd0f30be3e6c08c03 (patch)
treeaf7280cb8de11bf521b195e6cfe46604050b627c
parentab888d2ebebea0d128f4151a4240180211d95f03 (diff)
downloadbun-dda85d92c9bafd0fe86540efd0f30be3e6c08c03.tar.gz
bun-dda85d92c9bafd0fe86540efd0f30be3e6c08c03.tar.zst
bun-dda85d92c9bafd0fe86540efd0f30be3e6c08c03.zip
implement a custom websocket client
-rw-r--r--src/deps/uws.zig43
-rw-r--r--src/http/websocket.zig23
-rw-r--r--src/http/websocket_http_client.zig1087
-rw-r--r--src/javascript/jsc/bindings/ScriptExecutionContext.cpp14
-rw-r--r--src/javascript/jsc/bindings/bindings.zig4
-rw-r--r--src/javascript/jsc/bindings/exports.zig4
-rw-r--r--src/javascript/jsc/bindings/webcore/WebSocket.cpp65
-rw-r--r--src/javascript/jsc/bindings/webcore/WebSocket.h2
-rw-r--r--src/javascript/jsc/rare_data.zig4
9 files changed, 1171 insertions, 75 deletions
diff --git a/src/deps/uws.zig b/src/deps/uws.zig
index 08037a890..48c93ab53 100644
--- a/src/deps/uws.zig
+++ b/src/deps/uws.zig
@@ -70,11 +70,11 @@ pub fn NewSocketHandler(comptime ssl: bool) type {
this.socket,
) > 0;
}
- pub fn isClosed(this: ThisSocket) c_int {
+ pub fn isClosed(this: ThisSocket) bool {
return us_socket_is_closed(
comptime ssl_int,
this.socket,
- );
+ ) > 0;
}
pub fn close(this: ThisSocket, code: c_int, reason: ?*anyopaque) void {
_ = us_socket_close(
@@ -191,13 +191,38 @@ pub fn NewSocketHandler(comptime ssl: bool) type {
}
};
- us_socket_context_on_open(ssl_int, ctx, SocketHandler.on_open);
- us_socket_context_on_close(ssl_int, ctx, SocketHandler.on_close);
- us_socket_context_on_data(ssl_int, ctx, SocketHandler.on_data);
- us_socket_context_on_writable(ssl_int, ctx, SocketHandler.on_writable);
- us_socket_context_on_timeout(ssl_int, ctx, SocketHandler.on_timeout);
- us_socket_context_on_connect_error(ssl_int, ctx, SocketHandler.on_connect_error);
- us_socket_context_on_end(ssl_int, ctx, SocketHandler.on_end);
+ if (comptime @typeInfo(@TypeOf(onOpen)) != .Null)
+ us_socket_context_on_open(ssl_int, ctx, SocketHandler.on_open);
+ if (comptime @typeInfo(@TypeOf(onClose)) != .Null)
+ us_socket_context_on_close(ssl_int, ctx, SocketHandler.on_close);
+ if (comptime @typeInfo(@TypeOf(onData)) != .Null)
+ us_socket_context_on_data(ssl_int, ctx, SocketHandler.on_data);
+ if (comptime @typeInfo(@TypeOf(onWritable)) != .Null)
+ us_socket_context_on_writable(ssl_int, ctx, SocketHandler.on_writable);
+ if (comptime @typeInfo(@TypeOf(onTimeout)) != .Null)
+ us_socket_context_on_timeout(ssl_int, ctx, SocketHandler.on_timeout);
+ if (comptime @typeInfo(@TypeOf(onConnectError)) != .Null)
+ us_socket_context_on_connect_error(ssl_int, ctx, SocketHandler.on_connect_error);
+ if (comptime @typeInfo(@TypeOf(onEnd)) != .Null)
+ us_socket_context_on_end(ssl_int, ctx, SocketHandler.on_end);
+ }
+
+ pub fn adopt(
+ socket: *Socket,
+ socket_ctx: *us_socket_context_t,
+ comptime Context: type,
+ comptime socket_field_name: []const u8,
+ ctx: Context,
+ ) ?*Context {
+ var adopted = Socket{ .socket = us_socket_context_adopt_socket(comptime ssl_int, socket_ctx, socket, @sizeOf(Context)) orelse return null };
+ var holder = adopted.ext(Context) orelse {
+ if (comptime bun.Environment.allow_assert) unreachable;
+ _ = us_socket_close(comptime ssl_int, socket);
+ return null;
+ };
+ holder.* = ctx;
+ @field(holder, socket_field_name) = adopted;
+ return holder;
}
};
}
diff --git a/src/http/websocket.zig b/src/http/websocket.zig
index 293c67757..6bb433ea3 100644
--- a/src/http/websocket.zig
+++ b/src/http/websocket.zig
@@ -65,13 +65,30 @@ pub const WebsocketHeader = packed struct {
}
pub fn packLength(length: usize) u7 {
- var eight: u7 = 0;
- eight = switch (length) {
+ return switch (length) {
0...126 => @truncate(u7, length),
127...0xFFFF => 126,
else => 127,
};
- return eight;
+ }
+
+ const mask_length = 4;
+ const header_length = 2;
+
+ pub fn lengthByteCount(byte_length: usize) usize {
+ return switch (byte_length) {
+ 0...126 => 0,
+ 127...0xFFFF => @sizeOf(u16),
+ else => @sizeOf(u64),
+ };
+ }
+
+ pub fn frameSize(byte_length: usize) usize {
+ return header_length + byte_length + lengthByteCount(byte_length);
+ }
+
+ pub fn frameSizeIncludingMask(byte_length: usize) usize {
+ return frameSize(byte_length) + mask_length;
}
};
diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig
index 0a141c4c2..2be382dbe 100644
--- a/src/http/websocket_http_client.zig
+++ b/src/http/websocket_http_client.zig
@@ -3,12 +3,6 @@
const std = @import("std");
const native_endian = @import("builtin").target.cpu.arch.endian();
-const tcp = std.x.net.tcp;
-const ip = std.x.net.ip;
-
-const IPv4 = std.x.os.IPv4;
-const IPv6 = std.x.os.IPv6;
-const os = std.os;
const bun = @import("../global.zig");
const string = bun.string;
const Output = bun.Output;
@@ -24,6 +18,9 @@ const uws = @import("uws");
const JSC = @import("javascript_core");
const PicoHTTP = @import("picohttp");
const ObjectPool = @import("../pool.zig").ObjectPool;
+const WebsocketHeader = @import("./websocket.zig").WebsocketHeader;
+const WebsocketDataFrame = @import("./websocket.zig").WebsocketDataFrame;
+const Opcode = @import("./websocket.zig").Opcode;
fn buildRequestBody(vm: *JSC.VirtualMachine, pathname: *const JSC.ZigString, host: *const JSC.ZigString, client_protocol: *const JSC.ZigString, client_protocol_hash: *u64) std.mem.Allocator.Error![]u8 {
const allocator = vm.allocator;
@@ -89,6 +86,14 @@ const ErrorCode = enum(i32) {
failed_to_connect,
headers_too_large,
ended,
+ failed_to_allocate_memory,
+ control_frame_is_fragmented,
+ invalid_control_frame,
+ compression_unsupported,
+ unexpected_mask_from_server,
+ expected_control_frame,
+ unexpected_opcode,
+ invalid_utf8,
};
extern fn WebSocket__didConnect(
websocket_context: *anyopaque,
@@ -96,9 +101,12 @@ extern fn WebSocket__didConnect(
buffered_data: ?[*]u8,
buffered_len: usize,
) void;
-extern fn WebSocket__didFailToConnect(websocket_context: *anyopaque, reason: ErrorCode) void;
+extern fn WebSocket__didFailWithErrorCode(websocket_context: *anyopaque, reason: ErrorCode) void;
+extern fn WebSocket__didReceiveText(websocket_context: *anyopaque, clone: bool, text: *const JSC.ZigString) void;
+extern fn WebSocket__didReceiveBytes(websocket_context: *anyopaque, bytes: []const u8) void;
-const BodyBufBytes = [16384 - 16]u8;
+const body_buf_len = 16384 - 16;
+const BodyBufBytes = [body_buf_len]u8;
const BodyBufPool = ObjectPool(BodyBufBytes, null, true, 4);
const BodyBuf = BodyBufPool.Node;
@@ -106,7 +114,7 @@ const BodyBuf = BodyBufPool.Node;
pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
return struct {
pub const Socket = uws.NewSocketHandler(ssl);
- socket: Socket,
+ tcp: Socket,
outgoing_websocket: *anyopaque,
input_body_buf: []u8 = &[_]u8{},
client_protocol: []const u8 = "",
@@ -159,8 +167,8 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
var host_ = host.toSlice(bun.default_allocator);
defer host_.deinit();
- if (Socket.connect(host_.slice(), port, @ptrCast(*uws.us_socket_context_t, socket_ctx), HTTPClient, client, "socket")) |out| {
- out.socket.timeout(120);
+ if (Socket.connect(host_.slice(), port, @ptrCast(*uws.us_socket_context_t, socket_ctx), HTTPClient, client, "tcp")) |out| {
+ out.tcp.timeout(120);
return out;
}
@@ -183,33 +191,33 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
pub fn cancel(this: *HTTPClient) callconv(.C) void {
this.clearData();
- if (!this.socket.isEstablished()) {
- _ = uws.us_socket_close_connecting(comptime @as(c_int, @boolToInt(ssl)), this.socket.socket);
+ if (!this.tcp.isEstablished()) {
+ _ = uws.us_socket_close_connecting(comptime @as(c_int, @boolToInt(ssl)), this.tcp.socket);
} else {
- this.socket.close(0, null);
+ this.tcp.close(0, null);
}
}
pub fn fail(this: *HTTPClient, code: ErrorCode) void {
JSC.markBinding();
- WebSocket__didFailToConnect(this.outgoing_websocket, code);
+ WebSocket__didFailWithErrorCode(this.outgoing_websocket, code);
this.cancel();
}
pub fn handleClose(this: *HTTPClient, _: Socket, _: c_int, _: ?*anyopaque) void {
JSC.markBinding();
this.clearData();
- WebSocket__didFailToConnect(this.outgoing_websocket, ErrorCode.ended);
+ WebSocket__didFailWithErrorCode(this.outgoing_websocket, ErrorCode.ended);
}
pub fn terminate(this: *HTTPClient, code: ErrorCode) void {
this.fail(code);
- if (this.socket.isClosed() == 0)
- this.socket.close(0, null);
+ if (!this.tcp.isClosed())
+ this.tcp.close(0, null);
}
pub fn handleOpen(this: *HTTPClient, socket: Socket) void {
- std.debug.assert(socket.socket == this.socket.socket);
+ std.debug.assert(socket.socket == this.tcp.socket);
std.debug.assert(this.input_body_buf.len > 0);
std.debug.assert(this.to_send.len == 0);
@@ -232,7 +240,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
}
pub fn handleData(this: *HTTPClient, socket: Socket, data: []const u8) void {
- std.debug.assert(socket.socket == this.socket.socket);
+ std.debug.assert(socket.socket == this.tcp.socket);
if (comptime Environment.allow_assert)
std.debug.assert(!socket.isShutdown());
@@ -280,7 +288,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
}
pub fn handleEnd(this: *HTTPClient, socket: Socket) void {
- std.debug.assert(socket.socket == this.socket.socket);
+ std.debug.assert(socket.socket == this.tcp.socket);
this.terminate(ErrorCode.ended);
}
@@ -399,14 +407,14 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
this.clearData();
JSC.markBinding();
- WebSocket__didConnect(this.outgoing_websocket, this.socket.socket, overflow.ptr, overflow.len);
+ WebSocket__didConnect(this.outgoing_websocket, this.tcp.socket, overflow.ptr, overflow.len);
}
pub fn handleWritable(
this: *HTTPClient,
socket: Socket,
) void {
- std.debug.assert(socket.socket == this.socket.socket);
+ std.debug.assert(socket.socket == this.tcp.socket);
if (this.to_send.len == 0)
return;
@@ -451,5 +459,1038 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
};
}
+pub const Mask = struct {
+ const rand_buffer_size = 64;
+ bytes: [rand_buffer_size]u8 = [_]u8{0} ** rand_buffer_size,
+ offset: usize = 0,
+ needs_reset: bool = true,
+ queued_reset: bool = false,
+
+ pub fn from(globalObject: *JSC.JSGlobalObject) *Mask {
+ return &globalObject.bunVM().rareData().websocket_mask;
+ }
+
+ pub fn get(this: *Mask) [4]u8 {
+ if (this.needs_reset) {
+ this.needs_reset = false;
+ this.offset = 0;
+ std.crypto.random.bytes(&this.bytes);
+ }
+ const offset = this.offset;
+ const wrapped = offset % rand_buffer_size;
+ var mask = this.bytes[wrapped..][0..4].*;
+ if (offset > rand_buffer_size) {
+ const wrapped2 = @truncate(u8, wrapped);
+ mask[0] +%= wrapped2;
+ mask[1] +%= wrapped2;
+ mask[2] +%= wrapped2;
+ mask[3] +%= wrapped2;
+ }
+
+ this.offset += 4;
+
+ if (!this.queued_reset and this.offset % rand_buffer_size == 0) {
+ this.queued_reset = true;
+ uws.Loop.get().?.nextTick(*Mask, this, reset);
+ }
+
+ return mask;
+ }
+
+ pub fn fill(this: *Mask, mask_buf: *[4]u8, output_: []u8, input_: []const u8) void {
+ const mask = this.get();
+ mask_buf.* = mask;
+ var input = input_;
+ var output = output_;
+ if (comptime Environment.isAarch64 or Environment.isX64) {
+ if (input.len >= strings.ascii_vector_size) {
+ const vec: strings.AsciiVector = @as(strings.AsciiVector, mask ** (strings.ascii_vector_size / 4));
+ const end_ptr_wrapped_to_last_16 = input.ptr + input.len - (input.len % strings.ascii_vector_size);
+ while (input.ptr != end_ptr_wrapped_to_last_16) {
+ const input_vec: strings.AsciiVector = @as(strings.AsciiVector, input[0..strings.ascii_vector_size].*);
+ output.ptr[0..strings.ascii_vector_size].* = input_vec ^ vec;
+ output = output[strings.ascii_vector_size..];
+ input = input[strings.ascii_vector_size..];
+ }
+ }
+
+ // hint to the compiler not to vectorize the next loop
+ std.debug.assert(input.len < strings.ascii_vector_size);
+ }
+
+ while (input.len >= 4) {
+ const input_vec: [4]u8 = input[0..4].*;
+ output.ptr[0..4].* = [4]u8{
+ input_vec[0] ^ mask[0],
+ input_vec[1] ^ mask[1],
+ input_vec[2] ^ mask[2],
+ input_vec[3] ^ mask[3],
+ };
+ output = output[4..];
+ input = input[4..];
+ }
+
+ for (input) |c, i| {
+ output[i] = c ^ mask[i % 4];
+ }
+ }
+
+ pub fn reset(this: *Mask) void {
+ this.queued_reset = false;
+ this.needs_reset = true;
+ }
+};
+
+const ReceiveState = enum {
+ need_header,
+ need_mask,
+ need_body,
+ extended_payload_length_16,
+ extended_payload_length_64,
+ ping,
+ closing,
+ fail,
+
+ pub fn needControlFrame(this: ReceiveState) bool {
+ return this != .need_body;
+ }
+};
+const DataType = enum {
+ none,
+ text,
+ binary,
+};
+
+fn parseWebSocketHeader(
+ bytes: [2]u8,
+ receiving_type: *Opcode,
+ payload_length: *usize,
+ is_fragmented: *bool,
+ is_final: *bool,
+ need_compression: *bool,
+) ReceiveState {
+ // 0 1 2 3
+ // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+ // +-+-+-+-+-------+-+-------------+-------------------------------+
+ // |F|R|R|R| opcode|M| Payload len | Extended payload length |
+ // |I|S|S|S| (4) |A| (7) | (16/64) |
+ // |N|V|V|V| |S| | (if payload len==126/127) |
+ // | |1|2|3| |K| | |
+ // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
+ // | Extended payload length continued, if payload len == 127 |
+ // + - - - - - - - - - - - - - - - +-------------------------------+
+ // | |Masking-key, if MASK set to 1 |
+ // +-------------------------------+-------------------------------+
+ // | Masking-key (continued) | Payload Data |
+ // +-------------------------------- - - - - - - - - - - - - - - - +
+ // : Payload Data continued ... :
+ // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
+ // | Payload Data continued ... |
+ // +---------------------------------------------------------------+
+ const header = @bitCast(WebsocketHeader, @bitCast(u16, bytes));
+ const payload = @as(usize, header.len);
+ payload_length.* = payload;
+ receiving_type.* = header.opcode;
+ is_fragmented.* = switch (header.opcode) {
+ .Continue => true,
+ else => false,
+ };
+ is_final.* = header.final;
+ need_compression.* = header.compressed;
+
+ if (header.mask) {
+ return .need_mask;
+ }
+
+ return switch (header.opcode) {
+ .Text, .Continue, .Binary => switch (payload) {
+ 0...125 => ReceiveState.need_body,
+ 126 => ReceiveState.extended_payload_length_16,
+ 127 => ReceiveState.extended_payload_length_64,
+ else => unreachable,
+ },
+ .Close => ReceiveState.close,
+ .Ping => ReceiveState.ping,
+ .Pong => ReceiveState.pong,
+ else => ReceiveState.fail,
+ };
+}
+
+const Copy = union(enum) {
+ utf16: []const u16,
+ latin1: []const u8,
+ bytes: []const u8,
+ raw: []const u8,
+
+ pub fn len(this: @This(), byte_len: *usize) usize {
+ switch (this) {
+ .utf16 => {
+ byte_len.* = strings.elementLengthUTF16IntoUTF8([]const u16, this.utf16);
+ return WebsocketHeader.frameSizeIncludingMask(byte_len.*);
+ },
+ .latin1 => {
+ byte_len.* = strings.elementLengthLatin1IntoUTF8([]const u8, this.latin1);
+ return WebsocketHeader.frameSizeIncludingMask(byte_len.*);
+ },
+ .bytes => {
+ byte_len.* = this.bytes.len;
+ return WebsocketHeader.frameSizeIncludingMask(byte_len.*);
+ },
+ .raw => return this.raw.len,
+ }
+ }
+
+ pub fn copy(this: @This(), globalThis: *JSC.JSGlobalObject, buf: []u8, content_byte_len: usize) void {
+ switch (this) {
+ .utf16 => |utf16| {
+ const length_offset = 2;
+ const length_length = WebsocketHeader.lengthByteCount(content_byte_len);
+ const mask_offset = length_offset + length_length;
+ const content_offset = mask_offset + 4;
+ var to_mask = buf[content_offset..];
+ const encode_into_result = strings.copyUTF16IntoUTF8(utf16, to_mask);
+ std.debug.assert(@as(usize, encode_into_result.written) == content_byte_len);
+ std.debug.assert(@as(usize, encode_into_result.read) == utf16.len);
+ var header = @bitCast(WebsocketHeader, @as(u16, 0));
+ header.len = WebsocketHeader.packLength(content_byte_len);
+ header.mask = true;
+ header.opcode = Opcode.Text;
+ header.compressed = false;
+ header.final = true;
+ header.writeHeader(std.io.fixedBufferStream(buf), content_byte_len) catch unreachable;
+ Mask.from(globalThis).fill(buf[mask_offset..][0..4], to_mask, to_mask);
+ },
+ .latin1 => |latin1| {
+ const length_offset = 2;
+ const length_length = WebsocketHeader.lengthByteCount(content_byte_len);
+ const mask_offset = length_offset + length_length;
+ const content_offset = mask_offset + 4;
+ var to_mask = buf[content_offset..];
+ const encode_into_result = strings.copyLatin1IntoUTF8(latin1, to_mask);
+ std.debug.assert(@as(usize, encode_into_result.written) == content_byte_len);
+ std.debug.assert(@as(usize, encode_into_result.read) == latin1.len);
+ var header = @bitCast(WebsocketHeader, @as(u16, 0));
+ header.len = WebsocketHeader.packLength(content_byte_len);
+ header.mask = true;
+ header.opcode = Opcode.Text;
+ header.compressed = false;
+ header.final = true;
+ header.writeHeader(std.io.fixedBufferStream(buf), content_byte_len) catch unreachable;
+ Mask.from(globalThis).fill(buf[mask_offset..][0..4], to_mask, to_mask);
+ },
+ .bytes => |bytes| {
+ const length_offset = 2;
+ const length_length = WebsocketHeader.lengthByteCount(bytes.len);
+ const mask_offset = length_offset + length_length;
+ const content_offset = mask_offset + 4;
+ var to_mask = buf[content_offset..];
+ var header = @bitCast(WebsocketHeader, @as(u16, 0));
+ header.len = WebsocketHeader.packLength(bytes.len);
+ header.mask = true;
+ header.opcode = Opcode.Text;
+ header.compressed = false;
+ header.final = true;
+ header.writeHeader(std.io.fixedBufferStream(buf), bytes.len) catch unreachable;
+ Mask.from(globalThis).fill(buf[mask_offset..][0..4], to_mask, to_mask);
+ },
+ .raw => @memcpy(buf.ptr, this.raw.ptr, this.raw.len),
+ }
+ }
+};
+
+pub fn NewWebSocketClient(comptime ssl: bool) type {
+ return struct {
+ pub const Socket = uws.NewSocketHandler(ssl);
+ tcp: Socket,
+ outgoing_websocket: ?*anyopaque,
+
+ receive_state: ReceiveState = ReceiveState.need_header,
+ receive_header: WebsocketHeader = @bitCast(WebsocketHeader, @as(u16, 0)),
+ receive_remaining: usize = 0,
+ receiving_type: Opcode = Opcode.ResB,
+
+ ping_frame_bytes: [128 + 6]u8 = [_]u8{0} ** 128 + 6,
+ ping_len: u8 = 0,
+
+ receive_frame: usize = 0,
+ receive_body_remain: usize = 0,
+ receive_pending_chunk_len: usize = 0,
+ receive_body_buf: ?*BodyBuf = null,
+ receive_overflow_buffer: std.ArrayListUnmanaged(u8) = .{},
+ send_overflow_buffer: std.ArrayListUnmanaged(u8) = .{},
+
+ send_body_buf: ?*BodyBuf = null,
+ send_len: usize = 0,
+ send_off: usize = 0,
+
+ globalThis: *JSC.JSGlobalObject,
+
+ pub const name = if (ssl) "WebSocketClientTLS" else "WebSocketClient";
+
+ pub const shim = JSC.Shimmer("Bun", name, @This());
+
+ const HTTPClient = @This();
+
+ pub fn register(global: *JSC.JSGlobalObject, loop_: *anyopaque, parent: *anyopaque, ctx_: *anyopaque) callconv(.C) void {
+ var vm = global.bunVM();
+ var loop = @ptrCast(*uws.Loop, loop_);
+ var ctx: *uws.us_socket_context_t = @ptrCast(*uws.us_socket_context_t, ctx_);
+
+ if (vm.uws_event_loop) |other| {
+ std.debug.assert(other == loop);
+ }
+
+ vm.uws_event_loop = loop;
+
+ Socket.configureChild(
+ ctx,
+ parent,
+ HTTPClient,
+ null,
+ handleClose,
+ handleData,
+ handleWritable,
+ handleTimeout,
+ handleConnectError,
+ handleEnd,
+ );
+ }
+
+ pub fn clearData(this: *HTTPClient) void {
+ this.clearReceiveBuffers(true);
+ this.clearSendBuffers(true);
+ this.ping_len = 0;
+ this.receive_pending_chunk_len = 0;
+ this.receive_remaining = 0;
+ }
+
+ pub fn cancel(this: *HTTPClient) callconv(.C) void {
+ this.clearData();
+
+ if (this.tcp.isClosed() or this.tcp.isShutdown())
+ return;
+
+ if (!this.tcp.isEstablished()) {
+ _ = uws.us_socket_close_connecting(comptime @as(c_int, @boolToInt(ssl)), this.tcp.socket);
+ } else {
+ this.tcp.close(0, null);
+ }
+ }
+
+ pub fn fail(this: *HTTPClient, code: ErrorCode) void {
+ JSC.markBinding();
+ if (this.outgoing_websocket) |ws|
+ WebSocket__didFailWithErrorCode(ws, code);
+
+ this.cancel();
+ }
+
+ pub fn handleClose(this: *HTTPClient, _: Socket, _: c_int, _: ?*anyopaque) void {
+ JSC.markBinding();
+ this.clearData();
+ if (this.outgoing_websocket) |ws|
+ WebSocket__didFailWithErrorCode(ws, ErrorCode.ended);
+ }
+
+ pub fn terminate(this: *HTTPClient, code: ErrorCode) void {
+ this.fail(code);
+ }
+
+ fn getBody(this: *HTTPClient) *BodyBufBytes {
+ if (this.send_body_buf == null) {
+ this.send_body_buf = BodyBufPool.get(bun.default_allocator);
+ }
+
+ return &this.send_body_buf.?.data;
+ }
+
+ fn getReceiveBody(this: *HTTPClient) *BodyBufBytes {
+ if (this.receive_body_buf == null) {
+ this.receive_body_buf = BodyBufPool.get(bun.default_allocator);
+ }
+
+ return &this.receive_body_buf.?.data;
+ }
+
+ fn clearReceiveBuffers(this: *HTTPClient, free: bool) void {
+ if (this.receive_body_buf) |receive_buf| {
+ receive_buf.release();
+ this.receive_body_buf = null;
+ }
+
+ if (free) {
+ this.receive_overflow_buffer.clearAndFree(bun.default_allocator);
+ } else {
+ this.receive_overflow_buffer.clearRetainingCapacity();
+ }
+ }
+
+ fn clearSendBuffers(this: *HTTPClient, free: bool) void {
+ if (this.send_body_buf) |buf| {
+ buf.release();
+ this.send_body_buf = null;
+ }
+
+ if (free) {
+ this.send_overflow_buffer.clearAndFree(bun.default_allocator);
+ } else {
+ this.send_overflow_buffer.clearRetainingCapacity();
+ }
+ this.send_off = 0;
+ this.send_len = 0;
+ }
+
+ fn dispatchData(this: *HTTPClient, data_: []const u8, kind: Opcode) void {
+ switch (kind) {
+ .Text => {
+ // this function encodes to UTF-16 if > 127
+ // so we don't need to worry about latin1 non-ascii code points
+ const utf16_bytes_ = strings.toUTF16Alloc(bun.default_allocator, data_, true) catch {
+ this.terminate(ErrorCode.invalid_utf8);
+ return 0;
+ };
+ defer this.clearReceiveBuffers(false);
+ var outstring = JSC.ZigString.Empty;
+ if (utf16_bytes_) |utf16| {
+ outstring = JSC.ZigString.from16Slice(utf16);
+ outstring.markUTF16();
+ WebSocket__didReceiveText(this.outgoing_websocket, false, &outstring);
+ } else {
+ outstring = JSC.ZigString.init(data_);
+ WebSocket__didReceiveText(this.outgoing_websocket, true, &outstring);
+ }
+ },
+ .Binary => {
+ WebSocket__didReceiveBytes(this.outgoing_websocket, data_);
+ this.clearReceiveBuffers(false);
+ },
+ else => unreachable,
+ }
+ }
+
+ pub fn consume(this: *HTTPClient, data_: []const u8, max: usize, kind: Opcode, is_final: bool) usize {
+ std.debug.assert(kind == .Text or kind == .Binary);
+ std.debug.assert(data_.len <= max);
+
+ const can_dispatch_data = is_final and data_.len == max;
+
+ // did all the data fit in the buffer?
+ // we can avoid copying & allocating a temporary buffer
+ if (can_dispatch_data and this.receive_pending_chunk_len == 0) {
+ this.dispatchData(data_, kind);
+ return data_.len;
+ }
+
+ // if we previously allocated a buffer and there's room, attempt to use that one
+ const new_pending_chunk_len = max + this.receive_pending_chunk_len;
+ if (new_pending_chunk_len <= this.receive_overflow_buffer.capacity) {
+ @memcpy(this.receive_overflow_buffer.items.ptr + this.receive_overflow_buffer.items.len, data_.ptr, data_.len);
+ this.receive_overflow_buffer.items.len += data_.len;
+ if (can_dispatch_data) {
+ this.dispatchData(this.receive_overflow_buffer.items, kind);
+ this.receive_pending_chunk_len = 0;
+ } else {
+ this.receive_pending_chunk_len = this.receive_overflow_buffer.items.len;
+ }
+ return data_.len;
+ }
+
+ if (new_pending_chunk_len <= body_buf_len) {
+ // if our previously-allocated buffer is too small or we don't have one, use from the pool
+ var body = this.getReceiveBody();
+ @memcpy(body + this.receive_pending_chunk_len, data_.ptr, data_.len);
+ if (can_dispatch_data) {
+ this.dispatchData(body[0..new_pending_chunk_len], kind);
+ this.receive_pending_chunk_len = 0;
+ } else {
+ this.receive_pending_chunk_len += data_.len;
+ }
+ return data_.len;
+ }
+
+ {
+ // we need to copy the data into a potentially large temporary buffer
+ this.receive_overflow_buffer.appendSlice(bun.default_allocator, data_) catch {
+ this.terminate(ErrorCode.failed_to_allocate_memory);
+ return 0;
+ };
+ if (can_dispatch_data) {
+ this.dispatchData(this.receive_overflow_buffer.items, kind);
+ this.receive_pending_chunk_len = 0;
+ } else {
+ this.receive_pending_chunk_len = this.receive_overflow_buffer.items.len;
+ }
+ return data_.len;
+ }
+ }
+
+ pub fn handleData(this: *HTTPClient, socket: Socket, data_: []const u8) void {
+ var data = data_;
+ var receive_state = this.receive_state;
+ var terminated = false;
+ var is_fragmented = false;
+ var receiving_type = this.receiving_type;
+ var receive_body_remain = this.receive_body_remain;
+ var is_final = false;
+ var last_receive_data_type = receiving_type;
+
+ defer {
+ if (!terminated) {
+ this.receive_state = receive_state;
+ this.receiving_type = last_receive_data_type;
+ this.receive_body_remain = receive_body_remain;
+
+ // if we receive multiple pings in a row
+ // we just send back the last one
+ if (this.ping_len > 0) {
+ _ = this.sendPong(socket);
+ this.ping_len = 0;
+ }
+ }
+ }
+
+ var header_bytes: [@sizeOf(usize)]u8 = [_]u8{0} ** @sizeOf(usize);
+ while (true) {
+ switch (receive_state) {
+ // 0 1 2 3
+ // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+ // +-+-+-+-+-------+-+-------------+-------------------------------+
+ // |F|R|R|R| opcode|M| Payload len | Extended payload length |
+ // |I|S|S|S| (4) |A| (7) | (16/64) |
+ // |N|V|V|V| |S| | (if payload len==126/127) |
+ // | |1|2|3| |K| | |
+ // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
+ // | Extended payload length continued, if payload len == 127 |
+ // + - - - - - - - - - - - - - - - +-------------------------------+
+ // | |Masking-key, if MASK set to 1 |
+ // +-------------------------------+-------------------------------+
+ // | Masking-key (continued) | Payload Data |
+ // +-------------------------------- - - - - - - - - - - - - - - - +
+ // : Payload Data continued ... :
+ // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
+ // | Payload Data continued ... |
+ // +---------------------------------------------------------------+
+ .need_header => {
+ if (data.len < 2) {
+ this.terminate(ErrorCode.control_frame_is_fragmented);
+ terminated = true;
+ break;
+ }
+
+ header_bytes[0..2].* = data[0..2].*;
+ receive_body_remain = 0;
+ var need_compression = false;
+ is_final = false;
+
+ receive_state = parseWebSocketHeader(
+ header_bytes[0..2],
+ &receiving_type,
+ &receive_body_remain,
+ &is_fragmented,
+ &is_final,
+ &need_compression,
+ );
+ if (receiving_type == .Text or receiving_type == .Binary) {
+ last_receive_data_type = receiving_type;
+ }
+ data = data[2..];
+
+ if (receiving_type.isControl() and is_fragmented) {
+ // Control frames must not be fragmented.
+ this.terminate(ErrorCode.control_frame_is_fragmented);
+ terminated = true;
+ break;
+ }
+
+ switch (receiving_type) {
+ .Continue, .Text, .Binary, .Ping, .Pong, .Close => {},
+ else => {
+ this.terminate(ErrorCode.unexpected_opcode);
+ terminated = true;
+ break;
+ },
+ }
+
+ if (need_compression) {
+ this.terminate(ErrorCode.compression_unsupported);
+ terminated = true;
+ break;
+ }
+ },
+ .need_mask => {
+ this.terminate(.unexpected_mask_from_server);
+ terminated = true;
+ break;
+ },
+ .extended_payload_length_64, .extended_payload_length_16 => |rc| {
+ const byte_size = switch (rc) {
+ .extended_payload_length_64 => @as(usize, 8),
+ .extended_payload_length_16 => @as(usize, 2),
+ else => unreachable,
+ };
+
+ if (data.len < byte_size) {
+ this.terminate(ErrorCode.control_frame_is_fragmented);
+ terminated = true;
+ break;
+ }
+
+ // Multibyte length quantities are expressed in network byte order
+ receive_body_remain = switch (byte_size) {
+ 8 => @as(usize, std.mem.readIntBig(u64, data[0..8].*)),
+ 2 => @as(usize, std.mem.readIntBig(u16, data[0..2].*)),
+ else => unreachable,
+ };
+ data = data[byte_size..];
+ receive_state = .need_body;
+
+ if (receive_body_remain == 0) {
+ // this is an error
+ // the server should've set length to zero
+ this.terminate(ErrorCode.invalid_control_frame);
+ terminated = true;
+ break;
+ }
+ },
+
+ .ping => {
+ const ping_len = @minimum(data.len, @minimum(receive_body_remain, 125));
+ this.ping_len = @truncate(u8, ping_len);
+
+ if (ping_len > 0) {
+ @memcpy(&this.ping_frame_bytes + 6, data.ptr, ping_len);
+ data = data[ping_len..];
+ }
+
+ receive_state = .need_header;
+ receive_body_remain = 0;
+ receiving_type = last_receive_data_type;
+
+ if (data.len == 0) break;
+ },
+ .pong => {
+ const pong_len = @minimum(data.len, @minimum(receive_body_remain, this.ping_frame_bytes.len));
+ data = data[pong_len..];
+ receive_state = .need_header;
+ receiving_type = last_receive_data_type;
+ if (data.len == 0) break;
+ },
+ .need_body => {
+ if (receive_body_remain == 0 and data.len > 0) {
+ this.terminate(ErrorCode.expected_control_frame);
+ terminated = true;
+ break;
+ }
+ if (data.len == 0) return;
+
+ const to_consume = @minimum(receive_body_remain, data.len);
+
+ const consumed = this.consume(data[0..to_consume], receive_body_remain, last_receive_data_type, is_final);
+ if (consumed == 0 and last_receive_data_type == .Text) {
+ this.terminate(ErrorCode.invalid_utf8);
+ terminated = true;
+ break;
+ }
+
+ receive_body_remain -= consumed;
+ data = data[to_consume..];
+ if (receive_body_remain == 0) {
+ receive_state = .need_header;
+ is_fragmented = false;
+ }
+
+ if (data.len == 0) break;
+ },
+
+ .close => {
+ // closing frame data is text only.
+ _ = this.consume(data[0..receive_body_remain], receive_body_remain, .Text, true);
+ this.sendClose();
+ terminated = true;
+ break;
+ },
+ .fail => {
+ this.terminate(ErrorCode.unexpected_control_frame);
+ terminated = true;
+ break;
+ },
+ }
+ }
+ }
+
+ pub fn sendClose(this: *HTTPClient) void {
+ this.sendCloseWithBody(this.tcp, 1001, null, null, 0);
+ }
+
+ fn enqueueEncodedBytesMaybeFinal(
+ this: *HTTPClient,
+ socket: Socket,
+ bytes: []const u8,
+ is_closing: bool,
+ ) bool {
+ // fast path: no backpressure, no queue, just send the bytes.
+ if (this.send_len == 0) {
+ const wrote = socket.write(bytes, !is_closing);
+ const expected = @intCast(c_int, bytes.len);
+ if (wrote == expected) {
+ return true;
+ }
+
+ if (wrote < 0) {
+ this.terminate(ErrorCode.failed_to_write);
+ return false;
+ }
+
+ _ = this.copyToSendBuffer(bytes[bytes.len - @intCast(usize, wrote) ..], false, is_closing, false);
+ return true;
+ }
+
+ return this.copyToSendBuffer(bytes, true, is_closing, false);
+ }
+
+ fn copyToSendBuffer(this: *HTTPClient, bytes: []const u8, do_write: bool, is_closing: bool) bool {
+ return this.sendData(.{ .raw = bytes }, do_write, is_closing);
+ }
+
+ fn sendData(this: *HTTPClient, bytes: Copy, do_write: bool, is_closing: bool) bool {
+ var content_byte_len: usize = 0;
+ const write_len = bytes.len(&content_byte_len);
+ std.debug.assert(write_len > 0);
+
+ this.send_len += write_len;
+ var out_buf: []const u8 = "";
+ const send_end = this.send_off + this.send_len;
+ var ring_buffer = &this.send_overflow_buffer;
+
+ if (send_end <= ring_buffer.capacity) {
+ bytes.copy(this.globalThis, ring_buffer.items.ptr[send_end - write_len .. send_end], content_byte_len);
+ ring_buffer.items.len += write_len;
+
+ out_buf = ring_buffer.items[this.send_off..send_end];
+ } else if (send_end <= body_buf_len) {
+ var buf = this.getBody();
+ bytes.copy(this.globalThis, buf[send_end - write_len ..][0..write_len], content_byte_len);
+ out_buf = buf[this.send_off..send_end];
+ } else {
+ if (this.send_body_buf) |send_body| {
+ // transfer existing send buffer to overflow buffer
+ ring_buffer.ensureTotalCapacityPrecise(bun.default_allocator, this.send_len) catch {
+ this.terminate(ErrorCode.failed_to_allocate_memory);
+ return false;
+ };
+ const off = this.send_len - write_len;
+ @memcpy(ring_buffer.items.ptr, send_body + this.send_off, off);
+ send_body.release();
+ bytes.copy(this.globalThis, (ring_buffer.items.ptr + off)[0..write_len], content_byte_len);
+ ring_buffer.items.len = this.send_len;
+ this.send_body_buf = null;
+ this.send_off = 0;
+ } else if (send_end <= ring_buffer.capacity) {
+ bytes.copy(this.globalThis, (ring_buffer.items.ptr + (send_end - write_len))[0..write_len], content_byte_len);
+ ring_buffer.items.len += write_len;
+ // can we treat it as a ring buffer without re-allocating the array?
+ } else if (send_end - this.send_off <= ring_buffer.capacity) {
+ std.mem.copyBackwards(u8, ring_buffer.items[0..this.send_len], ring_buffer.items[this.send_off..send_end]);
+ bytes.copy(this.globalThis, ring_buffer.items.ptr[this.send_len - write_len .. this.send_len], content_byte_len);
+ ring_buffer.items.len = this.send_len;
+ this.send_off = 0;
+ } else {
+ // we need to re-allocate the array
+ ring_buffer.ensureTotalCapacity(bun.default_allocator, this.send_len) catch {
+ this.terminate(ErrorCode.failed_to_allocate_memory);
+ return false;
+ };
+
+ const data_to_copy = ring_buffer.items[this.send_off..][0..(this.send_len - write_len)];
+ const position_in_slice = ring_buffer.items[0 .. this.send_len - write_len];
+ if (data_to_copy.len > 0 and data_to_copy.ptr != position_in_slice.ptr)
+ std.mem.copyBackwards(
+ u8,
+ position_in_slice,
+ data_to_copy,
+ );
+
+ ring_buffer.items.len = this.send_len;
+ bytes.copy(this.globalThis, ring_buffer.items[this.send_len - write_len ..], content_byte_len);
+ this.send_off = 0;
+ }
+
+ out_buf = ring_buffer.items[this.send_off..send_end];
+ }
+
+ if (do_write) {
+ if (comptime Environment.allow_assert) {
+ std.debug.assert(!this.tcp.isShutdown());
+ std.debug.assert(!this.tcp.isClosed());
+ std.debug.assert(this.tcp.isEstablished());
+ }
+ return this.sendBuffer(out_buf, is_closing);
+ }
+
+ return true;
+ }
+
+ fn sendBuffer(this: *HTTPClient, out_buf: []const u8, is_closing: bool) bool {
+ std.debug.assert(out_buf.len > 0);
+ const wrote = this.tcp.write(out_buf, !is_closing);
+ const expected = @intCast(c_int, out_buf.len);
+ if (wrote == expected) {
+ this.clearSendBuffers(false);
+ return true;
+ }
+
+ if (wrote < 0) {
+ this.terminate(ErrorCode.failed_to_write);
+ return false;
+ }
+
+ this.send_len -= @intCast(usize, wrote);
+ this.send_off += @intCast(usize, wrote);
+ return true;
+ }
+
+ fn enqueueEncodedBytes(this: *HTTPClient, socket: Socket, bytes: []const u8) bool {
+ return this.enqueueEncodedBytesMaybeFinal(socket, bytes, false);
+ }
+
+ fn sendPong(this: *HTTPClient, socket: Socket) bool {
+ if (socket.isClosed() or socket.isShutdown()) {
+ this.dispatchClose();
+ return false;
+ }
+
+ var header = @bitCast(WebsocketHeader, @as(u16, 0));
+ header.fin = true;
+ header.opcode = .Pong;
+
+ var to_mask = &this.ping_frame_bytes[6..][0..this.ping_len];
+
+ header.mask = to_mask.len > 0;
+ header.len = @truncate(u7, this.ping_len);
+ this.ping_frame_bytes[0..2].* = @bitCast(u16, header);
+
+ if (to_mask.len > 0) {
+ Mask.from(this.globalThis).fill(&this.ping_frame_bytes[2..6], &to_mask, &to_mask);
+ return this.enqueueEncodedBytes(socket, this.ping_frame_bytes[0 .. 6 + @as(usize, this.ping_len)]);
+ } else {
+ return this.enqueueEncodedBytes(socket, &this.ping_frame_bytes[0..2]);
+ }
+ }
+
+ fn sendCloseWithBody(
+ this: *HTTPClient,
+ socket: Socket,
+ code: u16,
+ body: ?*[125]u8,
+ body_len: usize,
+ ) void {
+ if (socket.isClosed() or socket.isShutdown()) {
+ this.dispatchClose();
+ this.clearData();
+ return;
+ }
+
+ socket.shutdownRead();
+ var final_body_bytes: [128 + 8]u8 = undefined;
+ var header = @bitCast(WebsocketHeader, @as(u16, 0));
+ header.fin = true;
+ header.opcode = .Close;
+ header.mask = true;
+ header.len = body_len + 2;
+ final_body_bytes[0..2].* = @bitCast([2]u8, @bitCast(u16, header));
+ var mask_buf = &final_body_bytes[2..6];
+ std.mem.writeIntSliceBig(u16, &final_body_bytes[6..8], code);
+
+ if (body) |data| {
+ if (body_len > 0) @memcpy(&final_body_bytes[8..], data, body_len);
+ }
+
+ // we must mask the code
+ var slice = final_body_bytes[0..(8 + body_len)];
+ Mask.from(this.globalThis).fill(mask_buf, slice[6..], slice[6..]);
+
+ if (this.enqueueEncodedBytesMaybeFinal(socket, slice, true)) {
+ this.dispatchClose();
+ this.clearData();
+ }
+ }
+
+ pub fn handleEnd(this: *HTTPClient, socket: Socket) void {
+ std.debug.assert(socket.socket == this.tcp.socket);
+ this.terminate(ErrorCode.ended);
+ }
+
+ pub fn handleWritable(
+ this: *HTTPClient,
+ socket: Socket,
+ ) void {
+ std.debug.assert(socket.socket == this.tcp.socket);
+ if (this.send_len == 0)
+ return;
+
+ var send_buf: []const u8 = undefined;
+ if (this.send_body_buf) |send_body_buf| {
+ send_buf = send_body_buf[this.send_off..][0..this.send_len];
+ std.debug.assert(this.send_overflow_buffer.items.len == 0);
+ } else {
+ send_buf = this.send_overflow_buffer.items[this.send_off..][0..this.send_len];
+ }
+ std.debug.assert(send_buf.len == this.send_len);
+ _ = this.sendBuffer(send_buf, false);
+ }
+ pub fn handleTimeout(
+ this: *HTTPClient,
+ _: Socket,
+ ) void {
+ this.terminate(ErrorCode.timeout);
+ }
+ pub fn handleConnectError(this: *HTTPClient, _: Socket, _: c_int) void {
+ this.terminate(ErrorCode.failed_to_connect);
+ }
+
+ pub fn writeBinaryData(
+ this: *HTTPClient,
+ ptr: [*]const u8,
+ len: usize,
+ ) callconv(.C) void {
+ if (this.tcp.isClosed() or this.tcp.isShutdown()) {
+ this.dispatchClose();
+ return;
+ }
+
+ if (len == 0)
+ return;
+
+ const slice = ptr[0..len];
+ const bytes = Copy{ .bytes = slice };
+ // fast path: small frame, no backpressure, attempt to send without allocating
+ const frame_size = WebsocketHeader.frameSizeIncludingMask(len);
+ if (this.send_len == 0 and frame_size < 128 + 4) {
+ var inline_buf: [128 + 4]u8 = undefined;
+ bytes.copy(this.globalThis, &inline_buf[0..frame_size], slice.len);
+ _ = this.enqueueEncodedBytes(this.socket, inline_buf[0..frame_size]);
+ return;
+ }
+
+ _ = this.sendData(bytes, true, false);
+ }
+ pub fn writeString(
+ this: *HTTPClient,
+ str_: *const JSC.ZigString,
+ ) callconv(.C) void {
+ const str = str_.*;
+ if (this.tcp.isClosed() or this.tcp.isShutdown()) {
+ this.dispatchClose();
+ return;
+ }
+
+ if (str.len == 0) {
+ return;
+ }
+
+ {
+ var inline_buf: [128 + 4]u8 = undefined;
+
+ // fast path: small frame, no backpressure, attempt to send without allocating
+ if (!str.is16Bit() and str.len < 128 + 4) {
+ const bytes = Copy{ .latin1 = str.slice() };
+ const frame_size = WebsocketHeader.frameSizeIncludingMask(str.len);
+ if (this.send_len == 0 and frame_size < 128 + 4) {
+ bytes.copy(this.globalThis, &inline_buf[0..frame_size], str.len);
+ _ = this.enqueueEncodedBytes(this.socket, inline_buf[0..frame_size]);
+ return;
+ }
+ // max length of a utf16 -> utf8 conversion is 4 times the length of the utf16 string
+ } else if ((str.len * 4) < (128 + 4) and this.send_len == 0) {
+ const bytes = Copy{ .utf16 = str.slice() };
+ var byte_len: usize = 0;
+ const frame_size = bytes.len(&byte_len);
+ std.debug.assert(frame_size <= 128 + 4);
+ bytes.copy(this.globalThis, &inline_buf[0..frame_size], byte_len);
+ _ = this.enqueueEncodedBytes(this.socket, inline_buf[0..frame_size]);
+ return;
+ }
+ }
+
+ _ = this.sendData(
+ if (str.is16Bit())
+ Copy{ .utf16 = str.utf16SliceAligned() }
+ else
+ Copy{ .latin1 = str.slice() },
+ true,
+ false,
+ );
+ }
+
+ pub fn close(this: *HTTPClient, code: u16, reason: ?*const JSC.ZigString) callconv(.C) void {
+ if (this.tcp.isClosed() or this.tcp.isShutdown())
+ return;
+
+ var close_reason_buf: [128]u8 = undefined;
+ if (reason) |str| {
+ inner: {
+ var fixed_buffer = std.heap.FixedBufferAllocator.init(&close_reason_buf);
+ const allocator = fixed_buffer.get();
+ const wrote = std.fmt.allocPrint(allocator, "{}", str.*) catch break :inner;
+ this.sendCloseWithBody(this.tcp, code, wrote.ptr[0..125], wrote.len);
+ return;
+ }
+ }
+
+ this.sendCloseWithBody(this.tcp, code, null, 0);
+ }
+
+ pub fn init(
+ outgoing: *anyopaque,
+ input_socket: *anyopaque,
+ socket_ctx: *anyopaque,
+ globalThis: *JSC.JSGlobalObject,
+ ) callconv(.C) *anyopaque {
+ var tcp = @ptrCast(*uws.Socket, input_socket);
+ var ctx = @ptrCast(*uws.us_socket_context_t, socket_ctx);
+
+ return @ptrCast(
+ *anyopaque,
+ Socket.adopt(tcp, ctx, HTTPClient{
+ .outgoing_websocket = outgoing,
+ .globalThis = globalThis,
+ }),
+ );
+ }
+
+ pub fn finalize(this: *HTTPClient) callconv(.C) void {
+ this.clearData();
+
+ if (this.tcp.isClosed())
+ return;
+
+ this.tcp.close(0, null);
+ this.outgoing_websocket = null;
+ }
+
+ pub const Export = shim.exportFunctions(.{
+ .writeBinaryData = writeBinaryData,
+ .writeString = writeString,
+ .close = close,
+ .register = register,
+ .init = init,
+ .finalize = finalize,
+ });
+
+ comptime {
+ if (!JSC.is_bindgen) {
+ @export(writeBinaryData, .{ .name = Export[0].symbol_name });
+ @export(writeString, .{ .name = Export[1].symbol_name });
+ @export(close, .{ .name = Export[2].symbol_name });
+ @export(register, .{ .name = Export[3].symbol_name });
+ @export(init, .{ .name = Export[4].symbol_name });
+ @export(finalize, .{ .name = Export[5].symbol_name });
+ }
+ }
+ };
+}
+
pub const WebSocketHTTPClient = NewHTTPUpgradeClient(false);
pub const WebSocketHTTPSClient = NewHTTPUpgradeClient(true);
+pub const WebSocketClient = NewWebSocketClient(false);
+pub const WebSocketClientTLS = NewWebSocketClient(true);
diff --git a/src/javascript/jsc/bindings/ScriptExecutionContext.cpp b/src/javascript/jsc/bindings/ScriptExecutionContext.cpp
index b89e0645f..9c6735993 100644
--- a/src/javascript/jsc/bindings/ScriptExecutionContext.cpp
+++ b/src/javascript/jsc/bindings/ScriptExecutionContext.cpp
@@ -63,18 +63,18 @@ static uWS::WebSocketContext<SSL, false, WebCore::WebSocket*>* registerWebSocket
auto* opts = ctx->getExt();
/* Maximum message size we can receive */
- static unsigned int maxPayloadLength = 128 * 1024 * 1024;
+ unsigned int maxPayloadLength = 16 * 1024;
/* 2 minutes timeout is good */
- static unsigned short idleTimeout = 120;
+ unsigned short idleTimeout = 120;
/* 64kb backpressure is probably good */
- static unsigned int maxBackpressure = 128 * 1024 * 1024;
- static bool closeOnBackpressureLimit = false;
+ unsigned int maxBackpressure = 64 * 1024;
+ bool closeOnBackpressureLimit = false;
/* This one depends on kernel timeouts and is a bad default */
- static bool resetIdleTimeoutOnSend = false;
+ bool resetIdleTimeoutOnSend = false;
/* A good default, esp. for newcomers */
- static bool sendPingsAutomatically = true;
+ bool sendPingsAutomatically = false;
/* Maximum socket lifetime in seconds before forced closure (defaults to disabled) */
- static unsigned short maxLifetime = 0;
+ unsigned short maxLifetime = 0;
opts->maxPayloadLength = maxPayloadLength;
opts->maxBackpressure = maxBackpressure;
diff --git a/src/javascript/jsc/bindings/bindings.zig b/src/javascript/jsc/bindings/bindings.zig
index 841fa7a38..1b6524826 100644
--- a/src/javascript/jsc/bindings/bindings.zig
+++ b/src/javascript/jsc/bindings/bindings.zig
@@ -210,6 +210,10 @@ pub const ZigString = extern struct {
return JSC.JSValue.fromRef(slice_).getZigString(ctx.ptr());
}
+ pub fn from16Slice(slice_: []const u16) ZigString {
+ return from16(slice_.ptr, slice_.len);
+ }
+
pub fn from16(slice_: [*]const u16, len: usize) ZigString {
var str = init(@ptrCast([*]const u8, slice_)[0..len]);
str.markUTF16();
diff --git a/src/javascript/jsc/bindings/exports.zig b/src/javascript/jsc/bindings/exports.zig
index 0c9c2a677..c70388186 100644
--- a/src/javascript/jsc/bindings/exports.zig
+++ b/src/javascript/jsc/bindings/exports.zig
@@ -189,6 +189,8 @@ pub const JSArrayBufferSink = JSC.WebCore.ArrayBufferSink.JSSink;
// WebSocket
pub const WebSocketHTTPClient = @import("../../../http/websocket_http_client.zig").WebSocketHTTPClient;
pub const WebSocketHTTSPClient = @import("../../../http/websocket_http_client.zig").WebSocketHTTPSClient;
+pub const WebSocketClient = @import("../../../http/websocket_http_client.zig").WebSocketClient;
+pub const WebSocketClientTLS = @import("../../../http/websocket_http_client.zig").WebSocketClientTLS;
pub fn Errorable(comptime Type: type) type {
return extern struct {
@@ -2510,6 +2512,8 @@ pub const Formatter = ZigConsoleClient.Formatter;
comptime {
WebSocketHTTPClient.shim.ref();
WebSocketHTTSPClient.shim.ref();
+ WebSocketClient.shim.ref();
+ WebSocketClientTLS.shim.ref();
if (!is_bindgen) {
_ = Process.getTitle;
diff --git a/src/javascript/jsc/bindings/webcore/WebSocket.cpp b/src/javascript/jsc/bindings/webcore/WebSocket.cpp
index 9b996565c..3c9f3a373 100644
--- a/src/javascript/jsc/bindings/webcore/WebSocket.cpp
+++ b/src/javascript/jsc/bindings/webcore/WebSocket.cpp
@@ -420,8 +420,6 @@ ExceptionOr<void> WebSocket::send(const String& message)
if (utf8.length() > 0)
this->sendWebSocketData<false>(utf8.data(), utf8.length());
- delete utf8;
-
return {};
}
@@ -490,31 +488,34 @@ void WebSocket::sendWebSocketData(const char* baseAddress, size_t length)
if constexpr (isBinary)
opCode = uWS::OpCode::BINARY;
- switch (m_connectedWebSocketKind) {
- case ConnectedWebSocketKind::Client: {
- this->m_connectedWebSocket.client->send({ baseAddress, length }, opCode);
- this->m_bufferedAmount = this->m_connectedWebSocket.client->getBufferedAmount();
- break;
- }
- case ConnectedWebSocketKind::ClientSSL: {
- this->m_connectedWebSocket.clientSSL->send({ baseAddress, length }, opCode);
- this->m_bufferedAmount = this->m_connectedWebSocket.clientSSL->getBufferedAmount();
- break;
- }
- case ConnectedWebSocketKind::Server: {
- this->m_connectedWebSocket.server->send({ baseAddress, length }, opCode);
- this->m_bufferedAmount = this->m_connectedWebSocket.server->getBufferedAmount();
- break;
- }
- case ConnectedWebSocketKind::ServerSSL: {
- this->m_connectedWebSocket.serverSSL->send({ baseAddress, length }, opCode);
- this->m_bufferedAmount = this->m_connectedWebSocket.serverSSL->getBufferedAmount();
- break;
- }
- default: {
- RELEASE_ASSERT_NOT_REACHED();
- }
- }
+ this->m_connectedWebSocket.client->cork(
+ [&]() {
+ switch (m_connectedWebSocketKind) {
+ case ConnectedWebSocketKind::Client: {
+ this->m_connectedWebSocket.client->send({ baseAddress, length }, opCode);
+ this->m_bufferedAmount = this->m_connectedWebSocket.client->getBufferedAmount();
+ break;
+ }
+ case ConnectedWebSocketKind::ClientSSL: {
+ this->m_connectedWebSocket.clientSSL->send({ baseAddress, length }, opCode);
+ this->m_bufferedAmount = this->m_connectedWebSocket.clientSSL->getBufferedAmount();
+ break;
+ }
+ case ConnectedWebSocketKind::Server: {
+ this->m_connectedWebSocket.server->send({ baseAddress, length }, opCode);
+ this->m_bufferedAmount = this->m_connectedWebSocket.server->getBufferedAmount();
+ break;
+ }
+ case ConnectedWebSocketKind::ServerSSL: {
+ this->m_connectedWebSocket.serverSSL->send({ baseAddress, length }, opCode);
+ this->m_bufferedAmount = this->m_connectedWebSocket.serverSSL->getBufferedAmount();
+ break;
+ }
+ default: {
+ RELEASE_ASSERT_NOT_REACHED();
+ }
+ }
+ });
}
ExceptionOr<void> WebSocket::close(std::optional<unsigned short> optionalCode, const String& reason)
@@ -856,16 +857,18 @@ void WebSocket::didConnect(us_socket_t* socket, char* bufferedData, size_t buffe
this->didConnect();
}
-void WebSocket::didFailToConnect(int32_t code)
+void WebSocket::didFailWithErrorCode(int32_t code)
{
m_state = CLOSED;
// this means we already handled it
- if (this->m_upgradeClient == nullptr) {
+ if (this->m_upgradeClient == nullptr && this->m_connectedWebSocketKind == ConnectedWebSocketKind::None) {
return;
}
this->m_upgradeClient = nullptr;
+ this->m_connectedWebSocketKind = ConnectedWebSocketKind::None;
+ this->m_connectedWebSocket.client = nullptr;
switch (code) {
// cancel
@@ -982,7 +985,7 @@ extern "C" void WebSocket__didConnect(WebCore::WebSocket* webSocket, us_socket_t
{
webSocket->didConnect(socket, bufferedData, len);
}
-extern "C" void WebSocket__didFailToConnect(WebCore::WebSocket* webSocket, int32_t errorCode)
+extern "C" void WebSocket__didFailWithErrorCode(WebCore::WebSocket* webSocket, int32_t errorCode)
{
- webSocket->didFailToConnect(errorCode);
+ webSocket->didFailWithErrorCode(errorCode);
} \ No newline at end of file
diff --git a/src/javascript/jsc/bindings/webcore/WebSocket.h b/src/javascript/jsc/bindings/webcore/WebSocket.h
index 351e8c6b6..51b885c3b 100644
--- a/src/javascript/jsc/bindings/webcore/WebSocket.h
+++ b/src/javascript/jsc/bindings/webcore/WebSocket.h
@@ -97,7 +97,7 @@ public:
void didConnect();
void didClose(unsigned unhandledBufferedAmount, unsigned short code, const String& reason);
void didConnect(us_socket_t* socket, char* bufferedData, size_t bufferedDataSize);
- void didFailToConnect(int32_t code);
+ void didFailWithErrorCode(int32_t code);
void didReceiveMessage(String&& message);
void didReceiveData(const char* data, size_t length);
diff --git a/src/javascript/jsc/rare_data.zig b/src/javascript/jsc/rare_data.zig
index 3020dfc16..f82418da9 100644
--- a/src/javascript/jsc/rare_data.zig
+++ b/src/javascript/jsc/rare_data.zig
@@ -7,12 +7,14 @@ const Syscall = @import("./node/syscall.zig");
const JSC = @import("javascript_core");
const std = @import("std");
const BoringSSL = @import("boringssl");
-boring_ssl_engine: ?*BoringSSL.ENGINE = null,
+const WebSocketClientMask = @import("../../http/websocket_http_client.zig").Mask;
+boring_ssl_engine: ?*BoringSSL.ENGINE = null,
editor_context: EditorContext = EditorContext{},
stderr_store: ?*Blob.Store = null,
stdin_store: ?*Blob.Store = null,
stdout_store: ?*Blob.Store = null,
+websocket_mask: WebSocketClientMask = WebSocketClientMask{},
// TODO: make this per JSGlobalObject instead of global
// This does not handle ShadowRealm correctly!