aboutsummaryrefslogtreecommitdiff
path: root/src/http/websocket_http_client.zig
diff options
context:
space:
mode:
Diffstat (limited to 'src/http/websocket_http_client.zig')
-rw-r--r--src/http/websocket_http_client.zig576
1 files changed, 332 insertions, 244 deletions
diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig
index 2be382dbe..8522c6038 100644
--- a/src/http/websocket_http_client.zig
+++ b/src/http/websocket_http_client.zig
@@ -92,6 +92,7 @@ const ErrorCode = enum(i32) {
compression_unsupported,
unexpected_mask_from_server,
expected_control_frame,
+ unsupported_control_frame,
unexpected_opcode,
invalid_utf8,
};
@@ -101,9 +102,9 @@ extern fn WebSocket__didConnect(
buffered_data: ?[*]u8,
buffered_len: usize,
) void;
-extern fn WebSocket__didFailWithErrorCode(websocket_context: *anyopaque, reason: ErrorCode) void;
+extern fn WebSocket__didCloseWithErrorCode(websocket_context: *anyopaque, reason: ErrorCode) void;
extern fn WebSocket__didReceiveText(websocket_context: *anyopaque, clone: bool, text: *const JSC.ZigString) void;
-extern fn WebSocket__didReceiveBytes(websocket_context: *anyopaque, bytes: []const u8) void;
+extern fn WebSocket__didReceiveBytes(websocket_context: *anyopaque, bytes: [*]const u8, byte_len: usize) void;
const body_buf_len = 16384 - 16;
const BodyBufBytes = [body_buf_len]u8;
@@ -159,7 +160,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
var client_protocol_hash: u64 = 0;
var body = buildRequestBody(global.bunVM(), pathname, host, client_protocol, &client_protocol_hash) catch return null;
var client: HTTPClient = HTTPClient{
- .socket = undefined,
+ .tcp = undefined,
.outgoing_websocket = websocket,
.input_body_buf = body,
.websocket_protocol = client_protocol_hash,
@@ -200,14 +201,14 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
pub fn fail(this: *HTTPClient, code: ErrorCode) void {
JSC.markBinding();
- WebSocket__didFailWithErrorCode(this.outgoing_websocket, code);
+ WebSocket__didCloseWithErrorCode(this.outgoing_websocket, code);
this.cancel();
}
pub fn handleClose(this: *HTTPClient, _: Socket, _: c_int, _: ?*anyopaque) void {
JSC.markBinding();
this.clearData();
- WebSocket__didFailWithErrorCode(this.outgoing_websocket, ErrorCode.ended);
+ WebSocket__didCloseWithErrorCode(this.outgoing_websocket, ErrorCode.ended);
}
pub fn terminate(this: *HTTPClient, code: ErrorCode) void {
@@ -407,6 +408,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
this.clearData();
JSC.markBinding();
+ this.tcp.timeout(0);
WebSocket__didConnect(this.outgoing_websocket, this.tcp.socket, overflow.ptr, overflow.len);
}
@@ -500,17 +502,45 @@ pub const Mask = struct {
pub fn fill(this: *Mask, mask_buf: *[4]u8, output_: []u8, input_: []const u8) void {
const mask = this.get();
mask_buf.* = mask;
+
+ const skip_mask = @bitCast(u32, mask) == 0;
+ if (!skip_mask) {
+ fillWithSkipMask(mask, output_, input_, false);
+ } else {
+ fillWithSkipMask(mask, output_, input_, true);
+ }
+ }
+
+ fn fillWithSkipMask(mask: [4]u8, output_: []u8, input_: []const u8, comptime skip_mask: bool) void {
var input = input_;
var output = output_;
+
if (comptime Environment.isAarch64 or Environment.isX64) {
if (input.len >= strings.ascii_vector_size) {
- const vec: strings.AsciiVector = @as(strings.AsciiVector, mask ** (strings.ascii_vector_size / 4));
+ const vec: strings.AsciiVector = brk: {
+ var in: [strings.ascii_vector_size]u8 = undefined;
+ comptime var i: usize = 0;
+ inline while (i < strings.ascii_vector_size) : (i += 4) {
+ in[i..][0..4].* = mask;
+ }
+ break :brk @as(strings.AsciiVector, in);
+ };
const end_ptr_wrapped_to_last_16 = input.ptr + input.len - (input.len % strings.ascii_vector_size);
- while (input.ptr != end_ptr_wrapped_to_last_16) {
- const input_vec: strings.AsciiVector = @as(strings.AsciiVector, input[0..strings.ascii_vector_size].*);
- output.ptr[0..strings.ascii_vector_size].* = input_vec ^ vec;
- output = output[strings.ascii_vector_size..];
- input = input[strings.ascii_vector_size..];
+
+ if (comptime skip_mask) {
+ while (input.ptr != end_ptr_wrapped_to_last_16) {
+ const input_vec: strings.AsciiVector = @as(strings.AsciiVector, input[0..strings.ascii_vector_size].*);
+ output.ptr[0..strings.ascii_vector_size].* = input_vec;
+ output = output[strings.ascii_vector_size..];
+ input = input[strings.ascii_vector_size..];
+ }
+ } else {
+ while (input.ptr != end_ptr_wrapped_to_last_16) {
+ const input_vec: strings.AsciiVector = @as(strings.AsciiVector, input[0..strings.ascii_vector_size].*);
+ output.ptr[0..strings.ascii_vector_size].* = input_vec ^ vec;
+ output = output[strings.ascii_vector_size..];
+ input = input[strings.ascii_vector_size..];
+ }
}
}
@@ -518,20 +548,35 @@ pub const Mask = struct {
std.debug.assert(input.len < strings.ascii_vector_size);
}
- while (input.len >= 4) {
- const input_vec: [4]u8 = input[0..4].*;
- output.ptr[0..4].* = [4]u8{
- input_vec[0] ^ mask[0],
- input_vec[1] ^ mask[1],
- input_vec[2] ^ mask[2],
- input_vec[3] ^ mask[3],
- };
- output = output[4..];
- input = input[4..];
+ if (comptime !skip_mask) {
+ while (input.len >= 4) {
+ const input_vec: [4]u8 = input[0..4].*;
+ output.ptr[0..4].* = [4]u8{
+ input_vec[0] ^ mask[0],
+ input_vec[1] ^ mask[1],
+ input_vec[2] ^ mask[2],
+ input_vec[3] ^ mask[3],
+ };
+ output = output[4..];
+ input = input[4..];
+ }
+ } else {
+ while (input.len >= 4) {
+ const input_vec: [4]u8 = input[0..4].*;
+ output.ptr[0..4].* = input_vec;
+ output = output[4..];
+ input = input[4..];
+ }
}
- for (input) |c, i| {
- output[i] = c ^ mask[i % 4];
+ if (comptime !skip_mask) {
+ for (input) |c, i| {
+ output[i] = c ^ mask[i % 4];
+ }
+ } else {
+ for (input) |c, i| {
+ output[i] = c;
+ }
}
}
@@ -548,7 +593,8 @@ const ReceiveState = enum {
extended_payload_length_16,
extended_payload_length_64,
ping,
- closing,
+ pong,
+ close,
fail,
pub fn needControlFrame(this: ReceiveState) bool {
@@ -629,71 +675,82 @@ const Copy = union(enum) {
return WebsocketHeader.frameSizeIncludingMask(byte_len.*);
},
.latin1 => {
- byte_len.* = strings.elementLengthLatin1IntoUTF8([]const u8, this.latin1);
+ byte_len.* = this.latin1.len;
return WebsocketHeader.frameSizeIncludingMask(byte_len.*);
},
.bytes => {
byte_len.* = this.bytes.len;
return WebsocketHeader.frameSizeIncludingMask(byte_len.*);
},
- .raw => return this.raw.len,
+ .raw => {
+ byte_len.* = this.raw.len;
+ return this.raw.len;
+ },
}
}
pub fn copy(this: @This(), globalThis: *JSC.JSGlobalObject, buf: []u8, content_byte_len: usize) void {
+ if (this == .raw) {
+ std.debug.assert(buf.len >= this.raw.len);
+ std.debug.assert(buf.ptr != this.raw.ptr);
+ @memcpy(buf.ptr, this.raw.ptr, this.raw.len);
+ return;
+ }
+
+ const how_big_is_the_length_integer = WebsocketHeader.lengthByteCount(content_byte_len);
+ const how_big_is_the_mask = 4;
+ const mask_offset = 2 + how_big_is_the_length_integer;
+ const content_offset = mask_offset + how_big_is_the_mask;
+
+ // 2 byte header
+ // 4 byte mask
+ // 0, 2, 8 byte length
+ var to_mask = buf[content_offset..];
+
+ var header = @bitCast(WebsocketHeader, @as(u16, 0));
+
+ // Write extended length if needed
+ switch (how_big_is_the_length_integer) {
+ 0 => {},
+ 2 => std.mem.writeIntBig(u16, buf[2..][0..2], @truncate(u16, content_byte_len)),
+ 8 => std.mem.writeIntBig(u64, buf[2..][0..8], @truncate(u64, content_byte_len)),
+ else => unreachable,
+ }
+
+ header.mask = true;
+ header.compressed = false;
+ header.final = true;
+
+ std.debug.assert(WebsocketHeader.frameSizeIncludingMask(content_byte_len) == buf.len);
+
switch (this) {
.utf16 => |utf16| {
- const length_offset = 2;
- const length_length = WebsocketHeader.lengthByteCount(content_byte_len);
- const mask_offset = length_offset + length_length;
- const content_offset = mask_offset + 4;
- var to_mask = buf[content_offset..];
- const encode_into_result = strings.copyUTF16IntoUTF8(utf16, to_mask);
+ header.len = WebsocketHeader.packLength(content_byte_len);
+ const encode_into_result = strings.copyUTF16IntoUTF8(to_mask, []const u16, utf16);
std.debug.assert(@as(usize, encode_into_result.written) == content_byte_len);
std.debug.assert(@as(usize, encode_into_result.read) == utf16.len);
- var header = @bitCast(WebsocketHeader, @as(u16, 0));
- header.len = WebsocketHeader.packLength(content_byte_len);
- header.mask = true;
+ header.len = WebsocketHeader.packLength(encode_into_result.written);
header.opcode = Opcode.Text;
- header.compressed = false;
- header.final = true;
- header.writeHeader(std.io.fixedBufferStream(buf), content_byte_len) catch unreachable;
- Mask.from(globalThis).fill(buf[mask_offset..][0..4], to_mask, to_mask);
+ header.writeHeader(std.io.fixedBufferStream(buf).writer(), encode_into_result.written) catch unreachable;
+
+ Mask.from(globalThis).fill(buf[mask_offset..][0..4], to_mask[0..content_byte_len], to_mask[0..content_byte_len]);
},
.latin1 => |latin1| {
- const length_offset = 2;
- const length_length = WebsocketHeader.lengthByteCount(content_byte_len);
- const mask_offset = length_offset + length_length;
- const content_offset = mask_offset + 4;
- var to_mask = buf[content_offset..];
- const encode_into_result = strings.copyLatin1IntoUTF8(latin1, to_mask);
+ const encode_into_result = strings.copyLatin1IntoUTF8(to_mask, []const u8, latin1);
std.debug.assert(@as(usize, encode_into_result.written) == content_byte_len);
std.debug.assert(@as(usize, encode_into_result.read) == latin1.len);
- var header = @bitCast(WebsocketHeader, @as(u16, 0));
- header.len = WebsocketHeader.packLength(content_byte_len);
- header.mask = true;
+ header.len = WebsocketHeader.packLength(encode_into_result.written);
header.opcode = Opcode.Text;
- header.compressed = false;
- header.final = true;
- header.writeHeader(std.io.fixedBufferStream(buf), content_byte_len) catch unreachable;
- Mask.from(globalThis).fill(buf[mask_offset..][0..4], to_mask, to_mask);
+ header.writeHeader(std.io.fixedBufferStream(buf).writer(), encode_into_result.written) catch unreachable;
+ Mask.from(globalThis).fill(buf[mask_offset..][0..4], to_mask[0..content_byte_len], to_mask[0..content_byte_len]);
},
.bytes => |bytes| {
- const length_offset = 2;
- const length_length = WebsocketHeader.lengthByteCount(bytes.len);
- const mask_offset = length_offset + length_length;
- const content_offset = mask_offset + 4;
- var to_mask = buf[content_offset..];
- var header = @bitCast(WebsocketHeader, @as(u16, 0));
header.len = WebsocketHeader.packLength(bytes.len);
- header.mask = true;
- header.opcode = Opcode.Text;
- header.compressed = false;
- header.final = true;
- header.writeHeader(std.io.fixedBufferStream(buf), bytes.len) catch unreachable;
- Mask.from(globalThis).fill(buf[mask_offset..][0..4], to_mask, to_mask);
+ header.opcode = Opcode.Binary;
+ header.writeHeader(std.io.fixedBufferStream(buf).writer(), bytes.len) catch unreachable;
+ Mask.from(globalThis).fill(buf[mask_offset..][0..4], to_mask[0..content_byte_len], bytes);
},
- .raw => @memcpy(buf.ptr, this.raw.ptr, this.raw.len),
+ .raw => unreachable,
}
}
};
@@ -709,7 +766,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
receive_remaining: usize = 0,
receiving_type: Opcode = Opcode.ResB,
- ping_frame_bytes: [128 + 6]u8 = [_]u8{0} ** 128 + 6,
+ ping_frame_bytes: [128 + 6]u8 = [_]u8{0} ** (128 + 6),
ping_len: u8 = 0,
receive_frame: usize = 0,
@@ -717,9 +774,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
receive_pending_chunk_len: usize = 0,
receive_body_buf: ?*BodyBuf = null,
receive_overflow_buffer: std.ArrayListUnmanaged(u8) = .{},
- send_overflow_buffer: std.ArrayListUnmanaged(u8) = .{},
- send_body_buf: ?*BodyBuf = null,
+ send_buffer: bun.LinearFifo(u8, .Dynamic),
send_len: usize = 0,
send_off: usize = 0,
@@ -728,10 +784,11 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
pub const name = if (ssl) "WebSocketClientTLS" else "WebSocketClient";
pub const shim = JSC.Shimmer("Bun", name, @This());
+ const stack_frame_size = 1024;
- const HTTPClient = @This();
+ const WebSocket = @This();
- pub fn register(global: *JSC.JSGlobalObject, loop_: *anyopaque, parent: *anyopaque, ctx_: *anyopaque) callconv(.C) void {
+ pub fn register(global: *JSC.JSGlobalObject, loop_: *anyopaque, ctx_: *anyopaque) callconv(.C) void {
var vm = global.bunVM();
var loop = @ptrCast(*uws.Loop, loop_);
var ctx: *uws.us_socket_context_t = @ptrCast(*uws.us_socket_context_t, ctx_);
@@ -742,10 +799,9 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
vm.uws_event_loop = loop;
- Socket.configureChild(
+ Socket.configure(
ctx,
- parent,
- HTTPClient,
+ WebSocket,
null,
handleClose,
handleData,
@@ -756,7 +812,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
);
}
- pub fn clearData(this: *HTTPClient) void {
+ pub fn clearData(this: *WebSocket) void {
this.clearReceiveBuffers(true);
this.clearSendBuffers(true);
this.ping_len = 0;
@@ -764,7 +820,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
this.receive_remaining = 0;
}
- pub fn cancel(this: *HTTPClient) callconv(.C) void {
+ pub fn cancel(this: *WebSocket) callconv(.C) void {
this.clearData();
if (this.tcp.isClosed() or this.tcp.isShutdown())
@@ -777,34 +833,34 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
}
}
- pub fn fail(this: *HTTPClient, code: ErrorCode) void {
+ pub fn fail(this: *WebSocket, code: ErrorCode) void {
JSC.markBinding();
if (this.outgoing_websocket) |ws|
- WebSocket__didFailWithErrorCode(ws, code);
+ WebSocket__didCloseWithErrorCode(ws, code);
this.cancel();
}
- pub fn handleClose(this: *HTTPClient, _: Socket, _: c_int, _: ?*anyopaque) void {
+ pub fn handleClose(this: *WebSocket, _: Socket, _: c_int, _: ?*anyopaque) void {
JSC.markBinding();
this.clearData();
if (this.outgoing_websocket) |ws|
- WebSocket__didFailWithErrorCode(ws, ErrorCode.ended);
+ WebSocket__didCloseWithErrorCode(ws, ErrorCode.ended);
}
- pub fn terminate(this: *HTTPClient, code: ErrorCode) void {
+ pub fn terminate(this: *WebSocket, code: ErrorCode) void {
this.fail(code);
}
- fn getBody(this: *HTTPClient) *BodyBufBytes {
- if (this.send_body_buf == null) {
- this.send_body_buf = BodyBufPool.get(bun.default_allocator);
- }
+ // fn getBody(this: *WebSocket) *BodyBufBytes {
+ // if (this.send_body_buf == null) {
+ // this.send_body_buf = BodyBufPool.get(bun.default_allocator);
+ // }
- return &this.send_body_buf.?.data;
- }
+ // return &this.send_body_buf.?.data;
+ // }
- fn getReceiveBody(this: *HTTPClient) *BodyBufBytes {
+ fn getReceiveBody(this: *WebSocket) *BodyBufBytes {
if (this.receive_body_buf == null) {
this.receive_body_buf = BodyBufPool.get(bun.default_allocator);
}
@@ -812,7 +868,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
return &this.receive_body_buf.?.data;
}
- fn clearReceiveBuffers(this: *HTTPClient, free: bool) void {
+ fn clearReceiveBuffers(this: *WebSocket, free: bool) void {
if (this.receive_body_buf) |receive_buf| {
receive_buf.release();
this.receive_body_buf = null;
@@ -825,50 +881,58 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
}
}
- fn clearSendBuffers(this: *HTTPClient, free: bool) void {
- if (this.send_body_buf) |buf| {
- buf.release();
- this.send_body_buf = null;
- }
+ fn clearSendBuffers(this: *WebSocket, free: bool) void {
+ // if (this.send_body_buf) |buf| {
+ // buf.release();
+ // this.send_body_buf = null;
+ // }
+ this.send_buffer.discard(this.send_buffer.count);
if (free) {
- this.send_overflow_buffer.clearAndFree(bun.default_allocator);
- } else {
- this.send_overflow_buffer.clearRetainingCapacity();
+ this.send_buffer.deinit();
+ this.send_buffer.buf.len = 0;
}
+
this.send_off = 0;
this.send_len = 0;
}
- fn dispatchData(this: *HTTPClient, data_: []const u8, kind: Opcode) void {
+ fn dispatchData(this: *WebSocket, data_: []const u8, kind: Opcode) void {
+ var out = this.outgoing_websocket orelse {
+ this.clearData();
+ return;
+ };
switch (kind) {
.Text => {
// this function encodes to UTF-16 if > 127
// so we don't need to worry about latin1 non-ascii code points
const utf16_bytes_ = strings.toUTF16Alloc(bun.default_allocator, data_, true) catch {
this.terminate(ErrorCode.invalid_utf8);
- return 0;
+ return;
};
defer this.clearReceiveBuffers(false);
var outstring = JSC.ZigString.Empty;
if (utf16_bytes_) |utf16| {
outstring = JSC.ZigString.from16Slice(utf16);
outstring.markUTF16();
- WebSocket__didReceiveText(this.outgoing_websocket, false, &outstring);
+ JSC.markBinding();
+ WebSocket__didReceiveText(out, false, &outstring);
} else {
outstring = JSC.ZigString.init(data_);
- WebSocket__didReceiveText(this.outgoing_websocket, true, &outstring);
+ JSC.markBinding();
+ WebSocket__didReceiveText(out, true, &outstring);
}
},
.Binary => {
- WebSocket__didReceiveBytes(this.outgoing_websocket, data_);
+ JSC.markBinding();
+ WebSocket__didReceiveBytes(out, data_.ptr, data_.len);
this.clearReceiveBuffers(false);
},
else => unreachable,
}
}
- pub fn consume(this: *HTTPClient, data_: []const u8, max: usize, kind: Opcode, is_final: bool) usize {
+ pub fn consume(this: *WebSocket, data_: []const u8, max: usize, kind: Opcode, is_final: bool) usize {
std.debug.assert(kind == .Text or kind == .Binary);
std.debug.assert(data_.len <= max);
@@ -898,7 +962,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
if (new_pending_chunk_len <= body_buf_len) {
// if our previously-allocated buffer is too small or we don't have one, use from the pool
var body = this.getReceiveBody();
- @memcpy(body + this.receive_pending_chunk_len, data_.ptr, data_.len);
+ @memcpy(body[this.receive_pending_chunk_len..].ptr, data_.ptr, data_.len);
if (can_dispatch_data) {
this.dispatchData(body[0..new_pending_chunk_len], kind);
this.receive_pending_chunk_len = 0;
@@ -924,7 +988,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
}
}
- pub fn handleData(this: *HTTPClient, socket: Socket, data_: []const u8) void {
+ pub fn handleData(this: *WebSocket, socket: Socket, data_: []const u8) void {
var data = data_;
var receive_state = this.receive_state;
var terminated = false;
@@ -983,7 +1047,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
is_final = false;
receive_state = parseWebSocketHeader(
- header_bytes[0..2],
+ header_bytes[0..2].*,
&receiving_type,
&receive_body_remain,
&is_fragmented,
@@ -1005,7 +1069,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
switch (receiving_type) {
.Continue, .Text, .Binary, .Ping, .Pong, .Close => {},
else => {
- this.terminate(ErrorCode.unexpected_opcode);
+ this.terminate(ErrorCode.unsupported_control_frame);
terminated = true;
break;
},
@@ -1037,8 +1101,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
// Multibyte length quantities are expressed in network byte order
receive_body_remain = switch (byte_size) {
- 8 => @as(usize, std.mem.readIntBig(u64, data[0..8].*)),
- 2 => @as(usize, std.mem.readIntBig(u16, data[0..2].*)),
+ 8 => @as(usize, std.mem.readIntBig(u64, data[0..8])),
+ 2 => @as(usize, std.mem.readIntBig(u16, data[0..2])),
else => unreachable,
};
data = data[byte_size..];
@@ -1058,7 +1122,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
this.ping_len = @truncate(u8, ping_len);
if (ping_len > 0) {
- @memcpy(&this.ping_frame_bytes + 6, data.ptr, ping_len);
+ @memcpy(this.ping_frame_bytes[6..], data.ptr, ping_len);
data = data[ping_len..];
}
@@ -1110,7 +1174,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
break;
},
.fail => {
- this.terminate(ErrorCode.unexpected_control_frame);
+ this.terminate(ErrorCode.unsupported_control_frame);
terminated = true;
break;
},
@@ -1118,18 +1182,18 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
}
}
- pub fn sendClose(this: *HTTPClient) void {
- this.sendCloseWithBody(this.tcp, 1001, null, null, 0);
+ pub fn sendClose(this: *WebSocket) void {
+ this.sendCloseWithBody(this.tcp, 1001, null, 0);
}
fn enqueueEncodedBytesMaybeFinal(
- this: *HTTPClient,
+ this: *WebSocket,
socket: Socket,
bytes: []const u8,
is_closing: bool,
) bool {
// fast path: no backpressure, no queue, just send the bytes.
- if (this.send_len == 0) {
+ if (!this.hasBackpressure()) {
const wrote = socket.write(bytes, !is_closing);
const expected = @intCast(c_int, bytes.len);
if (wrote == expected) {
@@ -1141,82 +1205,87 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
return false;
}
- _ = this.copyToSendBuffer(bytes[bytes.len - @intCast(usize, wrote) ..], false, is_closing, false);
+ _ = this.copyToSendBuffer(bytes[@intCast(usize, wrote)..], false, is_closing);
return true;
}
- return this.copyToSendBuffer(bytes, true, is_closing, false);
+ return this.copyToSendBuffer(bytes, true, is_closing);
}
- fn copyToSendBuffer(this: *HTTPClient, bytes: []const u8, do_write: bool, is_closing: bool) bool {
+ fn copyToSendBuffer(this: *WebSocket, bytes: []const u8, do_write: bool, is_closing: bool) bool {
return this.sendData(.{ .raw = bytes }, do_write, is_closing);
}
- fn sendData(this: *HTTPClient, bytes: Copy, do_write: bool, is_closing: bool) bool {
+ fn sendData(this: *WebSocket, bytes: Copy, do_write: bool, is_closing: bool) bool {
var content_byte_len: usize = 0;
const write_len = bytes.len(&content_byte_len);
std.debug.assert(write_len > 0);
- this.send_len += write_len;
- var out_buf: []const u8 = "";
- const send_end = this.send_off + this.send_len;
- var ring_buffer = &this.send_overflow_buffer;
-
- if (send_end <= ring_buffer.capacity) {
- bytes.copy(this.globalThis, ring_buffer.items.ptr[send_end - write_len .. send_end], content_byte_len);
- ring_buffer.items.len += write_len;
-
- out_buf = ring_buffer.items[this.send_off..send_end];
- } else if (send_end <= body_buf_len) {
- var buf = this.getBody();
- bytes.copy(this.globalThis, buf[send_end - write_len ..][0..write_len], content_byte_len);
- out_buf = buf[this.send_off..send_end];
- } else {
- if (this.send_body_buf) |send_body| {
- // transfer existing send buffer to overflow buffer
- ring_buffer.ensureTotalCapacityPrecise(bun.default_allocator, this.send_len) catch {
- this.terminate(ErrorCode.failed_to_allocate_memory);
- return false;
- };
- const off = this.send_len - write_len;
- @memcpy(ring_buffer.items.ptr, send_body + this.send_off, off);
- send_body.release();
- bytes.copy(this.globalThis, (ring_buffer.items.ptr + off)[0..write_len], content_byte_len);
- ring_buffer.items.len = this.send_len;
- this.send_body_buf = null;
- this.send_off = 0;
- } else if (send_end <= ring_buffer.capacity) {
- bytes.copy(this.globalThis, (ring_buffer.items.ptr + (send_end - write_len))[0..write_len], content_byte_len);
- ring_buffer.items.len += write_len;
- // can we treat it as a ring buffer without re-allocating the array?
- } else if (send_end - this.send_off <= ring_buffer.capacity) {
- std.mem.copyBackwards(u8, ring_buffer.items[0..this.send_len], ring_buffer.items[this.send_off..send_end]);
- bytes.copy(this.globalThis, ring_buffer.items.ptr[this.send_len - write_len .. this.send_len], content_byte_len);
- ring_buffer.items.len = this.send_len;
- this.send_off = 0;
- } else {
- // we need to re-allocate the array
- ring_buffer.ensureTotalCapacity(bun.default_allocator, this.send_len) catch {
- this.terminate(ErrorCode.failed_to_allocate_memory);
- return false;
- };
-
- const data_to_copy = ring_buffer.items[this.send_off..][0..(this.send_len - write_len)];
- const position_in_slice = ring_buffer.items[0 .. this.send_len - write_len];
- if (data_to_copy.len > 0 and data_to_copy.ptr != position_in_slice.ptr)
- std.mem.copyBackwards(
- u8,
- position_in_slice,
- data_to_copy,
- );
-
- ring_buffer.items.len = this.send_len;
- bytes.copy(this.globalThis, ring_buffer.items[this.send_len - write_len ..], content_byte_len);
- this.send_off = 0;
- }
-
- out_buf = ring_buffer.items[this.send_off..send_end];
- }
+ // this.send_len += write_len;
+ var writable = this.send_buffer.writableWithSize(write_len) catch unreachable;
+ bytes.copy(this.globalThis, writable[0..write_len], content_byte_len);
+ this.send_buffer.update(write_len);
+ // const send_end = this.send_off + this.send_len;
+ // var ring_buffer = &this.send_overflow_buffer;
+ // const prev_len = this.send_len - write_len;
+
+ // if (send_end <= ring_buffer.capacity) {
+ // const mid = ring_buffer.capacity / 2;
+ // if (send_end > mid and this.send_len < mid) {
+ // std.mem.copyBackwards(u8, ring_buffer.items.ptr[0..prev_len], ring_buffer.items.ptr[this.send_off..ring_buffer.capacity][0..prev_len]);
+ // this.send_off = 0;
+ // ring_buffer.items.len = prev_len;
+ // bytes.copy(this.globalThis, ring_buffer.items.ptr[prev_len..this.send_len], content_byte_len);
+ // } else {
+ // bytes.copy(this.globalThis, ring_buffer.items.ptr[send_end - write_len .. send_end], content_byte_len);
+ // }
+ // ring_buffer.items.len += write_len;
+
+ // out_buf = ring_buffer.items[this.send_off..][0..this.send_len];
+ // } else if (send_end <= body_buf_len) {
+ // var buf = this.getBody();
+ // bytes.copy(this.globalThis, buf[send_end - write_len ..][0..write_len], content_byte_len);
+ // out_buf = buf[this.send_off..][0..this.send_len];
+ // } else {
+ // if (this.send_body_buf) |send_body| {
+ // var buf = &send_body.data;
+ // // transfer existing send buffer to overflow buffer
+ // ring_buffer.ensureTotalCapacityPrecise(bun.default_allocator, this.send_len) catch {
+ // this.terminate(ErrorCode.failed_to_allocate_memory);
+ // return false;
+ // };
+ // @memcpy(ring_buffer.items.ptr, buf[this.send_off..].ptr, prev_len);
+ // send_body.release();
+ // bytes.copy(this.globalThis, (ring_buffer.items.ptr + prev_len)[0..write_len], content_byte_len);
+ // ring_buffer.items.len = this.send_len;
+ // this.send_body_buf = null;
+ // this.send_off = 0;
+ // } else if (send_end <= ring_buffer.capacity) {
+ // bytes.copy(this.globalThis, (ring_buffer.items.ptr + (send_end - write_len))[0..write_len], content_byte_len);
+ // ring_buffer.items.len += write_len;
+ // } else {
+ // // we need to re-allocate the array
+ // ring_buffer.ensureTotalCapacity(bun.default_allocator, this.send_len) catch {
+ // this.terminate(ErrorCode.failed_to_allocate_memory);
+ // return false;
+ // };
+
+ // const data_to_copy = ring_buffer.items[this.send_off..][0..prev_len];
+ // const position_in_slice = ring_buffer.items[0..prev_len];
+ // if (data_to_copy.len > 0 and data_to_copy.ptr != position_in_slice.ptr)
+ // std.mem.copyBackwards(
+ // u8,
+ // position_in_slice,
+ // data_to_copy,
+ // );
+
+ // ring_buffer.items.len = this.send_len;
+ // bytes.copy(this.globalThis, ring_buffer.items[prev_len..], content_byte_len);
+ // this.send_off = 0;
+ // }
+
+ // out_buf = ring_buffer.items[this.send_off..][0..this.send_len];
+ // }
if (do_write) {
if (comptime Environment.allow_assert) {
@@ -1224,61 +1293,66 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
std.debug.assert(!this.tcp.isClosed());
std.debug.assert(this.tcp.isEstablished());
}
- return this.sendBuffer(out_buf, is_closing);
+ return this.sendBuffer(this.send_buffer.readableSlice(0), is_closing, !is_closing);
}
return true;
}
- fn sendBuffer(this: *HTTPClient, out_buf: []const u8, is_closing: bool) bool {
+ fn sendBuffer(
+ this: *WebSocket,
+ out_buf: []const u8,
+ is_closing: bool,
+ _: bool,
+ ) bool {
std.debug.assert(out_buf.len > 0);
- const wrote = this.tcp.write(out_buf, !is_closing);
- const expected = @intCast(c_int, out_buf.len);
- if (wrote == expected) {
- this.clearSendBuffers(false);
- return true;
- }
-
+ _ = is_closing;
+ // set msg_more to false
+ // it seems to improve perf by ~20%
+ const wrote = this.tcp.write(out_buf, false);
if (wrote < 0) {
this.terminate(ErrorCode.failed_to_write);
return false;
}
+ const expected = @intCast(usize, wrote);
+ var readable = this.send_buffer.readableSlice(0);
+ if (readable.ptr == out_buf.ptr) {
+ this.send_buffer.discard(expected);
+ }
- this.send_len -= @intCast(usize, wrote);
- this.send_off += @intCast(usize, wrote);
return true;
}
- fn enqueueEncodedBytes(this: *HTTPClient, socket: Socket, bytes: []const u8) bool {
+ fn enqueueEncodedBytes(this: *WebSocket, socket: Socket, bytes: []const u8) bool {
return this.enqueueEncodedBytesMaybeFinal(socket, bytes, false);
}
- fn sendPong(this: *HTTPClient, socket: Socket) bool {
+ fn sendPong(this: *WebSocket, socket: Socket) bool {
if (socket.isClosed() or socket.isShutdown()) {
this.dispatchClose();
return false;
}
var header = @bitCast(WebsocketHeader, @as(u16, 0));
- header.fin = true;
+ header.final = true;
header.opcode = .Pong;
- var to_mask = &this.ping_frame_bytes[6..][0..this.ping_len];
+ var to_mask = this.ping_frame_bytes[6..][0..this.ping_len];
header.mask = to_mask.len > 0;
header.len = @truncate(u7, this.ping_len);
this.ping_frame_bytes[0..2].* = @bitCast(u16, header);
if (to_mask.len > 0) {
- Mask.from(this.globalThis).fill(&this.ping_frame_bytes[2..6], &to_mask, &to_mask);
+ Mask.from(this.globalThis).fill(this.ping_frame_bytes[2..6], to_mask, to_mask);
return this.enqueueEncodedBytes(socket, this.ping_frame_bytes[0 .. 6 + @as(usize, this.ping_len)]);
} else {
- return this.enqueueEncodedBytes(socket, &this.ping_frame_bytes[0..2]);
+ return this.enqueueEncodedBytes(socket, this.ping_frame_bytes[0..2]);
}
}
fn sendCloseWithBody(
- this: *HTTPClient,
+ this: *WebSocket,
socket: Socket,
code: u16,
body: ?*[125]u8,
@@ -1293,16 +1367,16 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
socket.shutdownRead();
var final_body_bytes: [128 + 8]u8 = undefined;
var header = @bitCast(WebsocketHeader, @as(u16, 0));
- header.fin = true;
+ header.final = true;
header.opcode = .Close;
header.mask = true;
- header.len = body_len + 2;
+ header.len = @truncate(u7, body_len + 2);
final_body_bytes[0..2].* = @bitCast([2]u8, @bitCast(u16, header));
- var mask_buf = &final_body_bytes[2..6];
- std.mem.writeIntSliceBig(u16, &final_body_bytes[6..8], code);
+ var mask_buf: *[4]u8 = final_body_bytes[2..6];
+ std.mem.writeIntSliceBig(u16, final_body_bytes[6..8], code);
if (body) |data| {
- if (body_len > 0) @memcpy(&final_body_bytes[8..], data, body_len);
+ if (body_len > 0) @memcpy(final_body_bytes[8..], data, body_len);
}
// we must mask the code
@@ -1315,41 +1389,37 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
}
}
- pub fn handleEnd(this: *HTTPClient, socket: Socket) void {
+ pub fn handleEnd(this: *WebSocket, socket: Socket) void {
std.debug.assert(socket.socket == this.tcp.socket);
this.terminate(ErrorCode.ended);
}
pub fn handleWritable(
- this: *HTTPClient,
+ this: *WebSocket,
socket: Socket,
) void {
std.debug.assert(socket.socket == this.tcp.socket);
- if (this.send_len == 0)
+ const send_buf = this.send_buffer.readableSlice(0);
+ if (send_buf.len == 0)
return;
-
- var send_buf: []const u8 = undefined;
- if (this.send_body_buf) |send_body_buf| {
- send_buf = send_body_buf[this.send_off..][0..this.send_len];
- std.debug.assert(this.send_overflow_buffer.items.len == 0);
- } else {
- send_buf = this.send_overflow_buffer.items[this.send_off..][0..this.send_len];
- }
- std.debug.assert(send_buf.len == this.send_len);
- _ = this.sendBuffer(send_buf, false);
+ _ = this.sendBuffer(send_buf, false, true);
}
pub fn handleTimeout(
- this: *HTTPClient,
+ this: *WebSocket,
_: Socket,
) void {
this.terminate(ErrorCode.timeout);
}
- pub fn handleConnectError(this: *HTTPClient, _: Socket, _: c_int) void {
+ pub fn handleConnectError(this: *WebSocket, _: Socket, _: c_int) void {
this.terminate(ErrorCode.failed_to_connect);
}
+ pub fn hasBackpressure(this: *const WebSocket) bool {
+ return this.send_buffer.count > 0;
+ }
+
pub fn writeBinaryData(
- this: *HTTPClient,
+ this: *WebSocket,
ptr: [*]const u8,
len: usize,
) callconv(.C) void {
@@ -1365,17 +1435,17 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
const bytes = Copy{ .bytes = slice };
// fast path: small frame, no backpressure, attempt to send without allocating
const frame_size = WebsocketHeader.frameSizeIncludingMask(len);
- if (this.send_len == 0 and frame_size < 128 + 4) {
- var inline_buf: [128 + 4]u8 = undefined;
- bytes.copy(this.globalThis, &inline_buf[0..frame_size], slice.len);
- _ = this.enqueueEncodedBytes(this.socket, inline_buf[0..frame_size]);
+ if (!this.hasBackpressure() and frame_size < stack_frame_size) {
+ var inline_buf: [stack_frame_size]u8 = undefined;
+ bytes.copy(this.globalThis, inline_buf[0..frame_size], slice.len);
+ _ = this.enqueueEncodedBytes(this.tcp, inline_buf[0..frame_size]);
return;
}
- _ = this.sendData(bytes, true, false);
+ _ = this.sendData(bytes, !this.hasBackpressure(), false);
}
pub fn writeString(
- this: *HTTPClient,
+ this: *WebSocket,
str_: *const JSC.ZigString,
) callconv(.C) void {
const str = str_.*;
@@ -1389,25 +1459,25 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
}
{
- var inline_buf: [128 + 4]u8 = undefined;
+ var inline_buf: [stack_frame_size]u8 = undefined;
// fast path: small frame, no backpressure, attempt to send without allocating
- if (!str.is16Bit() and str.len < 128 + 4) {
+ if (!str.is16Bit() and str.len < stack_frame_size) {
const bytes = Copy{ .latin1 = str.slice() };
const frame_size = WebsocketHeader.frameSizeIncludingMask(str.len);
- if (this.send_len == 0 and frame_size < 128 + 4) {
- bytes.copy(this.globalThis, &inline_buf[0..frame_size], str.len);
- _ = this.enqueueEncodedBytes(this.socket, inline_buf[0..frame_size]);
+ if (!this.hasBackpressure() and frame_size < stack_frame_size) {
+ bytes.copy(this.globalThis, inline_buf[0..frame_size], str.len);
+ _ = this.enqueueEncodedBytes(this.tcp, inline_buf[0..frame_size]);
return;
}
// max length of a utf16 -> utf8 conversion is 4 times the length of the utf16 string
- } else if ((str.len * 4) < (128 + 4) and this.send_len == 0) {
- const bytes = Copy{ .utf16 = str.slice() };
+ } else if ((str.len * 4) < (stack_frame_size) and !this.hasBackpressure()) {
+ const bytes = Copy{ .utf16 = str.utf16SliceAligned() };
var byte_len: usize = 0;
const frame_size = bytes.len(&byte_len);
- std.debug.assert(frame_size <= 128 + 4);
- bytes.copy(this.globalThis, &inline_buf[0..frame_size], byte_len);
- _ = this.enqueueEncodedBytes(this.socket, inline_buf[0..frame_size]);
+ std.debug.assert(frame_size <= stack_frame_size);
+ bytes.copy(this.globalThis, inline_buf[0..frame_size], byte_len);
+ _ = this.enqueueEncodedBytes(this.tcp, inline_buf[0..frame_size]);
return;
}
}
@@ -1417,12 +1487,18 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
Copy{ .utf16 = str.utf16SliceAligned() }
else
Copy{ .latin1 = str.slice() },
- true,
+ !this.hasBackpressure(),
false,
);
}
- pub fn close(this: *HTTPClient, code: u16, reason: ?*const JSC.ZigString) callconv(.C) void {
+ fn dispatchClose(this: *WebSocket) void {
+ var out = this.outgoing_websocket orelse return;
+ JSC.markBinding();
+ WebSocket__didCloseWithErrorCode(out, ErrorCode.closed);
+ }
+
+ pub fn close(this: *WebSocket, code: u16, reason: ?*const JSC.ZigString) callconv(.C) void {
if (this.tcp.isClosed() or this.tcp.isShutdown())
return;
@@ -1430,8 +1506,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
if (reason) |str| {
inner: {
var fixed_buffer = std.heap.FixedBufferAllocator.init(&close_reason_buf);
- const allocator = fixed_buffer.get();
- const wrote = std.fmt.allocPrint(allocator, "{}", str.*) catch break :inner;
+ const allocator = fixed_buffer.allocator();
+ const wrote = std.fmt.allocPrint(allocator, "{}", .{str.*}) catch break :inner;
this.sendCloseWithBody(this.tcp, code, wrote.ptr[0..125], wrote.len);
return;
}
@@ -1445,20 +1521,31 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
input_socket: *anyopaque,
socket_ctx: *anyopaque,
globalThis: *JSC.JSGlobalObject,
- ) callconv(.C) *anyopaque {
+ ) callconv(.C) ?*anyopaque {
var tcp = @ptrCast(*uws.Socket, input_socket);
var ctx = @ptrCast(*uws.us_socket_context_t, socket_ctx);
-
- return @ptrCast(
- *anyopaque,
- Socket.adopt(tcp, ctx, HTTPClient{
+ var adopted = Socket.adopt(
+ tcp,
+ ctx,
+ WebSocket,
+ "tcp",
+ WebSocket{
+ .tcp = undefined,
.outgoing_websocket = outgoing,
.globalThis = globalThis,
- }),
+ .receive_overflow_buffer = .{},
+ .send_buffer = bun.LinearFifo(u8, .Dynamic).init(bun.default_allocator),
+ },
+ ) orelse return null;
+ adopted.send_buffer.ensureTotalCapacity(2048) catch return null;
+ _ = globalThis.bunVM().eventLoop().ready_tasks_count.fetchAdd(1, .Monotonic);
+ return @ptrCast(
+ *anyopaque,
+ adopted,
);
}
- pub fn finalize(this: *HTTPClient) callconv(.C) void {
+ pub fn finalize(this: *WebSocket) callconv(.C) void {
this.clearData();
if (this.tcp.isClosed())
@@ -1466,6 +1553,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
this.tcp.close(0, null);
this.outgoing_websocket = null;
+ _ = this.globalThis.bunVM().eventLoop().ready_tasks_count.fetchSub(1, .Monotonic);
}
pub const Export = shim.exportFunctions(.{