aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Jarred Sumner <709451+Jarred-Sumner@users.noreply.github.com> 2023-05-21 15:23:02 -0700
committerGravatar Jarred Sumner <709451+Jarred-Sumner@users.noreply.github.com> 2023-05-21 15:23:02 -0700
commit7d682c0fe7521289dc41f7c22f269c9cd278bf9e (patch)
treed8bb3bf73e2197d1fe3dc485ff78ffad918ee334
parentb6007a860e6e66aa794f664b699023c1f7e32cf3 (diff)
downloadbun-7d682c0fe7521289dc41f7c22f269c9cd278bf9e.tar.gz
bun-7d682c0fe7521289dc41f7c22f269c9cd278bf9e.tar.zst
bun-7d682c0fe7521289dc41f7c22f269c9cd278bf9e.zip
[ws client] Make it a little more type safe
-rw-r--r--src/http/websocket_http_client.zig84
1 files changed, 51 insertions, 33 deletions
diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig
index 6a62ceb87..3ce8f9118 100644
--- a/src/http/websocket_http_client.zig
+++ b/src/http/websocket_http_client.zig
@@ -129,15 +129,23 @@ const ErrorCode = enum(i32) {
unexpected_opcode,
invalid_utf8,
};
-extern fn WebSocket__didConnect(
- websocket_context: *anyopaque,
- socket: *uws.Socket,
- buffered_data: ?[*]u8,
- buffered_len: usize,
-) void;
-extern fn WebSocket__didCloseWithErrorCode(websocket_context: *anyopaque, reason: ErrorCode) void;
-extern fn WebSocket__didReceiveText(websocket_context: *anyopaque, clone: bool, text: *const JSC.ZigString) void;
-extern fn WebSocket__didReceiveBytes(websocket_context: *anyopaque, bytes: [*]const u8, byte_len: usize) void;
+
+pub const JSWebSocket = opaque {
+ extern fn WebSocket__didConnect(
+ websocket_context: *JSWebSocket,
+ socket: *uws.Socket,
+ buffered_data: ?[*]u8,
+ buffered_len: usize,
+ ) void;
+ extern fn WebSocket__didCloseWithErrorCode(websocket_context: *JSWebSocket, reason: ErrorCode) void;
+ extern fn WebSocket__didReceiveText(websocket_context: *JSWebSocket, clone: bool, text: *const JSC.ZigString) void;
+ extern fn WebSocket__didReceiveBytes(websocket_context: *JSWebSocket, bytes: [*]const u8, byte_len: usize) void;
+
+ pub const didConnect = WebSocket__didConnect;
+ pub const didCloseWithErrorCode = WebSocket__didCloseWithErrorCode;
+ pub const didReceiveText = WebSocket__didReceiveText;
+ pub const didReceiveBytes = WebSocket__didReceiveBytes;
+};
const body_buf_len = 16384 - 16;
const BodyBufBytes = [body_buf_len]u8;
@@ -149,7 +157,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
return struct {
pub const Socket = uws.NewSocketHandler(ssl);
tcp: Socket,
- outgoing_websocket: *anyopaque,
+ outgoing_websocket: ?*JSWebSocket,
input_body_buf: []u8 = &[_]u8{},
client_protocol: []const u8 = "",
to_send: []const u8 = "",
@@ -202,7 +210,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
pub fn connect(
global: *JSC.JSGlobalObject,
socket_ctx: *anyopaque,
- websocket: *anyopaque,
+ websocket: *JSWebSocket,
host: *const JSC.ZigString,
port: u16,
pathname: *const JSC.ZigString,
@@ -290,7 +298,11 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
pub fn fail(this: *HTTPClient, code: ErrorCode) void {
log("onFail", .{});
JSC.markBinding(@src());
- WebSocket__didCloseWithErrorCode(this.outgoing_websocket, code);
+ if (this.outgoing_websocket) |ws| {
+ this.outgoing_websocket = null;
+ ws.didCloseWithErrorCode(code);
+ }
+
this.cancel();
}
@@ -298,7 +310,10 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
log("onClose", .{});
JSC.markBinding(@src());
this.clearData();
- WebSocket__didCloseWithErrorCode(this.outgoing_websocket, ErrorCode.ended);
+ if (this.outgoing_websocket) |ws| {
+ this.outgoing_websocket = null;
+ ws.didCloseWithErrorCode(ErrorCode.ended);
+ }
}
pub fn terminate(this: *HTTPClient, code: ErrorCode) void {
@@ -351,6 +366,10 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
pub fn handleData(this: *HTTPClient, socket: Socket, data: []const u8) void {
log("onData", .{});
std.debug.assert(socket.socket == this.tcp.socket);
+ if (this.outgoing_websocket == null) {
+ this.clearData();
+ return;
+ }
if (comptime Environment.allow_assert)
std.debug.assert(!socket.isShutdown());
@@ -512,7 +531,8 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
JSC.markBinding(@src());
this.tcp.timeout(0);
log("onDidConnect", .{});
- WebSocket__didConnect(this.outgoing_websocket, this.tcp.socket, overflow.ptr, overflow.len);
+
+ this.outgoing_websocket.?.didConnect(this.tcp.socket, overflow.ptr, overflow.len);
}
pub fn handleWritable(
@@ -824,7 +844,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
return struct {
pub const Socket = uws.NewSocketHandler(ssl);
tcp: Socket,
- outgoing_websocket: ?*anyopaque,
+ outgoing_websocket: ?*JSWebSocket = null,
receive_state: ReceiveState = ReceiveState.need_header,
receive_header: WebsocketHeader = @bitCast(WebsocketHeader, @as(u16, 0)),
@@ -902,8 +922,10 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
pub fn fail(this: *WebSocket, code: ErrorCode) void {
JSC.markBinding(@src());
- if (this.outgoing_websocket) |ws|
- WebSocket__didCloseWithErrorCode(ws, code);
+ if (this.outgoing_websocket) |ws| {
+ this.outgoing_websocket = null;
+ ws.didCloseWithErrorCode(code);
+ }
this.cancel();
}
@@ -914,7 +936,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
log("WebSocket.onHandshake({d})", .{success});
if (success == 0) {
if (this.outgoing_websocket) |ws| {
- WebSocket__didCloseWithErrorCode(ws, ErrorCode.failed_to_connect);
+ this.outgoing_websocket = null;
+ ws.didCloseWithErrorCode(ErrorCode.failed_to_connect);
}
}
}
@@ -922,22 +945,16 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
log("onClose", .{});
JSC.markBinding(@src());
this.clearData();
- if (this.outgoing_websocket) |ws|
- WebSocket__didCloseWithErrorCode(ws, ErrorCode.ended);
+ if (this.outgoing_websocket) |ws| {
+ this.outgoing_websocket = null;
+ ws.didCloseWithErrorCode(ErrorCode.ended);
+ }
}
pub fn terminate(this: *WebSocket, code: ErrorCode) void {
this.fail(code);
}
- fn getReceiveBody(this: *WebSocket) *BodyBufBytes {
- if (this.receive_body_buf == null) {
- this.receive_body_buf = BodyBufPool.get(bun.default_allocator);
- }
-
- return &this.receive_body_buf.?.data;
- }
-
fn clearReceiveBuffers(this: *WebSocket, free: bool) void {
this.receive_buffer.head = 0;
this.receive_buffer.count = 0;
@@ -978,16 +995,16 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
outstring = JSC.ZigString.from16Slice(utf16);
outstring.mark();
JSC.markBinding(@src());
- WebSocket__didReceiveText(out, false, &outstring);
+ out.didReceiveText(false, &outstring);
} else {
outstring = JSC.ZigString.init(data_);
JSC.markBinding(@src());
- WebSocket__didReceiveText(out, true, &outstring);
+ out.didReceiveText(true, &outstring);
}
},
.Binary => {
JSC.markBinding(@src());
- WebSocket__didReceiveBytes(out, data_.ptr, data_.len);
+ out.didReceiveBytes(data_.ptr, data_.len);
},
else => unreachable,
}
@@ -1480,7 +1497,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
var out = this.outgoing_websocket orelse return;
this.poll_ref.unrefOnNextTick(this.globalThis.bunVM());
JSC.markBinding(@src());
- WebSocket__didCloseWithErrorCode(out, ErrorCode.closed);
+ this.outgoing_websocket = null;
+ out.didCloseWithErrorCode(ErrorCode.closed);
}
pub fn close(this: *WebSocket, code: u16, reason: ?*const JSC.ZigString) callconv(.C) void {
@@ -1502,7 +1520,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
}
pub fn init(
- outgoing: *anyopaque,
+ outgoing: *JSWebSocket,
input_socket: *anyopaque,
socket_ctx: *anyopaque,
globalThis: *JSC.JSGlobalObject,