diff options
-rw-r--r-- | src/http/websocket_http_client.zig | 84 |
1 files changed, 51 insertions, 33 deletions
diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig index 6a62ceb87..3ce8f9118 100644 --- a/src/http/websocket_http_client.zig +++ b/src/http/websocket_http_client.zig @@ -129,15 +129,23 @@ const ErrorCode = enum(i32) { unexpected_opcode, invalid_utf8, }; -extern fn WebSocket__didConnect( - websocket_context: *anyopaque, - socket: *uws.Socket, - buffered_data: ?[*]u8, - buffered_len: usize, -) void; -extern fn WebSocket__didCloseWithErrorCode(websocket_context: *anyopaque, reason: ErrorCode) void; -extern fn WebSocket__didReceiveText(websocket_context: *anyopaque, clone: bool, text: *const JSC.ZigString) void; -extern fn WebSocket__didReceiveBytes(websocket_context: *anyopaque, bytes: [*]const u8, byte_len: usize) void; + +pub const JSWebSocket = opaque { + extern fn WebSocket__didConnect( + websocket_context: *JSWebSocket, + socket: *uws.Socket, + buffered_data: ?[*]u8, + buffered_len: usize, + ) void; + extern fn WebSocket__didCloseWithErrorCode(websocket_context: *JSWebSocket, reason: ErrorCode) void; + extern fn WebSocket__didReceiveText(websocket_context: *JSWebSocket, clone: bool, text: *const JSC.ZigString) void; + extern fn WebSocket__didReceiveBytes(websocket_context: *JSWebSocket, bytes: [*]const u8, byte_len: usize) void; + + pub const didConnect = WebSocket__didConnect; + pub const didCloseWithErrorCode = WebSocket__didCloseWithErrorCode; + pub const didReceiveText = WebSocket__didReceiveText; + pub const didReceiveBytes = WebSocket__didReceiveBytes; +}; const body_buf_len = 16384 - 16; const BodyBufBytes = [body_buf_len]u8; @@ -149,7 +157,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { return struct { pub const Socket = uws.NewSocketHandler(ssl); tcp: Socket, - outgoing_websocket: *anyopaque, + outgoing_websocket: ?*JSWebSocket, input_body_buf: []u8 = &[_]u8{}, client_protocol: []const u8 = "", to_send: []const u8 = "", @@ -202,7 +210,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { pub fn connect( global: *JSC.JSGlobalObject, socket_ctx: *anyopaque, - websocket: *anyopaque, + websocket: *JSWebSocket, host: *const JSC.ZigString, port: u16, pathname: *const JSC.ZigString, @@ -290,7 +298,11 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { pub fn fail(this: *HTTPClient, code: ErrorCode) void { log("onFail", .{}); JSC.markBinding(@src()); - WebSocket__didCloseWithErrorCode(this.outgoing_websocket, code); + if (this.outgoing_websocket) |ws| { + this.outgoing_websocket = null; + ws.didCloseWithErrorCode(code); + } + this.cancel(); } @@ -298,7 +310,10 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { log("onClose", .{}); JSC.markBinding(@src()); this.clearData(); - WebSocket__didCloseWithErrorCode(this.outgoing_websocket, ErrorCode.ended); + if (this.outgoing_websocket) |ws| { + this.outgoing_websocket = null; + ws.didCloseWithErrorCode(ErrorCode.ended); + } } pub fn terminate(this: *HTTPClient, code: ErrorCode) void { @@ -351,6 +366,10 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { pub fn handleData(this: *HTTPClient, socket: Socket, data: []const u8) void { log("onData", .{}); std.debug.assert(socket.socket == this.tcp.socket); + if (this.outgoing_websocket == null) { + this.clearData(); + return; + } if (comptime Environment.allow_assert) std.debug.assert(!socket.isShutdown()); @@ -512,7 +531,8 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { JSC.markBinding(@src()); this.tcp.timeout(0); log("onDidConnect", .{}); - WebSocket__didConnect(this.outgoing_websocket, this.tcp.socket, overflow.ptr, overflow.len); + + this.outgoing_websocket.?.didConnect(this.tcp.socket, overflow.ptr, overflow.len); } pub fn handleWritable( @@ -824,7 +844,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { return struct { pub const Socket = uws.NewSocketHandler(ssl); tcp: Socket, - outgoing_websocket: ?*anyopaque, + outgoing_websocket: ?*JSWebSocket = null, receive_state: ReceiveState = ReceiveState.need_header, receive_header: WebsocketHeader = @bitCast(WebsocketHeader, @as(u16, 0)), @@ -902,8 +922,10 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { pub fn fail(this: *WebSocket, code: ErrorCode) void { JSC.markBinding(@src()); - if (this.outgoing_websocket) |ws| - WebSocket__didCloseWithErrorCode(ws, code); + if (this.outgoing_websocket) |ws| { + this.outgoing_websocket = null; + ws.didCloseWithErrorCode(code); + } this.cancel(); } @@ -914,7 +936,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { log("WebSocket.onHandshake({d})", .{success}); if (success == 0) { if (this.outgoing_websocket) |ws| { - WebSocket__didCloseWithErrorCode(ws, ErrorCode.failed_to_connect); + this.outgoing_websocket = null; + ws.didCloseWithErrorCode(ErrorCode.failed_to_connect); } } } @@ -922,22 +945,16 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { log("onClose", .{}); JSC.markBinding(@src()); this.clearData(); - if (this.outgoing_websocket) |ws| - WebSocket__didCloseWithErrorCode(ws, ErrorCode.ended); + if (this.outgoing_websocket) |ws| { + this.outgoing_websocket = null; + ws.didCloseWithErrorCode(ErrorCode.ended); + } } pub fn terminate(this: *WebSocket, code: ErrorCode) void { this.fail(code); } - fn getReceiveBody(this: *WebSocket) *BodyBufBytes { - if (this.receive_body_buf == null) { - this.receive_body_buf = BodyBufPool.get(bun.default_allocator); - } - - return &this.receive_body_buf.?.data; - } - fn clearReceiveBuffers(this: *WebSocket, free: bool) void { this.receive_buffer.head = 0; this.receive_buffer.count = 0; @@ -978,16 +995,16 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { outstring = JSC.ZigString.from16Slice(utf16); outstring.mark(); JSC.markBinding(@src()); - WebSocket__didReceiveText(out, false, &outstring); + out.didReceiveText(false, &outstring); } else { outstring = JSC.ZigString.init(data_); JSC.markBinding(@src()); - WebSocket__didReceiveText(out, true, &outstring); + out.didReceiveText(true, &outstring); } }, .Binary => { JSC.markBinding(@src()); - WebSocket__didReceiveBytes(out, data_.ptr, data_.len); + out.didReceiveBytes(data_.ptr, data_.len); }, else => unreachable, } @@ -1480,7 +1497,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { var out = this.outgoing_websocket orelse return; this.poll_ref.unrefOnNextTick(this.globalThis.bunVM()); JSC.markBinding(@src()); - WebSocket__didCloseWithErrorCode(out, ErrorCode.closed); + this.outgoing_websocket = null; + out.didCloseWithErrorCode(ErrorCode.closed); } pub fn close(this: *WebSocket, code: u16, reason: ?*const JSC.ZigString) callconv(.C) void { @@ -1502,7 +1520,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { } pub fn init( - outgoing: *anyopaque, + outgoing: *JSWebSocket, input_socket: *anyopaque, socket_ctx: *anyopaque, globalThis: *JSC.JSGlobalObject, |