diff options
Diffstat (limited to 'src/http/websocket_http_client.zig')
-rw-r--r-- | src/http/websocket_http_client.zig | 84 |
1 files changed, 44 insertions, 40 deletions
diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig index 5e795bb4a..d577fadc6 100644 --- a/src/http/websocket_http_client.zig +++ b/src/http/websocket_http_client.zig @@ -25,8 +25,8 @@ const JSC = @import("javascript_core"); const PicoHTTP = @import("picohttp"); const ObjectPool = @import("../pool.zig").ObjectPool; -fn buildRequestBody(vm: *JSC.VirtualMachine, pathname: *const JSC.ZigString, host: *const JSC.ZigString, client_protocol: *const JSC.ZigString, client_protocol_hash: *u64) ![]u8 { - const allocator = vm.allocator(); +fn buildRequestBody(vm: *JSC.VirtualMachine, pathname: *const JSC.ZigString, host: *const JSC.ZigString, client_protocol: *const JSC.ZigString, client_protocol_hash: *u64) std.mem.Allocator.Error![]u8 { + const allocator = vm.allocator; var input_rand_buf: [16]u8 = undefined; std.crypto.random.bytes(&input_rand_buf); const temp_buf_size = comptime std.base64.standard.Encoder.calcSize(16); @@ -48,23 +48,26 @@ fn buildRequestBody(vm: *JSC.VirtualMachine, pathname: *const JSC.ZigString, hos client_protocol_hash.* = std.hash.Wyhash.hash(0, headers[1].value); var headers_: []PicoHTTP.Header = headers[0 .. 1 + @as(usize, @boolToInt(client_protocol.len > 0))]; - - return try std.fmt.allocPrint(allocator, - \\GET {} HTTP/1.1\r - \\Host: {}\r - \\Pragma: no-cache\r - \\Cache-Control: no-cache\r - \\Connection: Upgrade\r - \\Upgrade: websocket\r - \\Sec-WebSocket-Version: 13\r - \\{s} - \\\r - \\ - , .{ - pathname.*, - host.*, - PicoHTTP.Headers{ .headers = headers_ }, - }); + const pathname_ = pathname.slice(); + const host_ = host.slice(); + const pico_headers = PicoHTTP.Headers{ .headers = headers_ }; + return try std.fmt.allocPrint( + allocator, + "GET {s} HTTP/1.1\r\n" ++ + "Host: {s}\r\n" ++ + "Pragma: no-cache\r\n" ++ + "Cache-Control: no-cache\r\n" ++ + "Connection: Upgrade\r\n" ++ + "Upgrade: websocket\r\n" ++ + "Sec-WebSocket-Version: 13\r\n" ++ + "{any}\n" ++ + "\r\n", + .{ + pathname_, + host_, + pico_headers, + }, + ); } const ErrorCode = enum(i32) { @@ -114,13 +117,13 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { body_written: usize = 0, websocket_protocol: u64 = 0, - const basename = (if (ssl) "SecureWebSocket" else "WebSocket") ++ "UpgradeClient"; + pub const name = if (ssl) "WebSocketHTTPSClient" else "WebSocketHTTPClient"; - pub const shim = JSC.Shimmer("Bun", basename, @This()); + pub const shim = JSC.Shimmer("Bun", name, @This()); const HTTPClient = @This(); - pub fn register(global: *JSC.JSGlobalObject, loop_: *anyopaque, ctx_: *anyopaque) void { + pub fn register(global: *JSC.JSGlobalObject, loop_: *anyopaque, ctx_: *anyopaque) callconv(.C) void { var vm = global.bunVM(); var loop = @ptrCast(*uws.Loop, loop_); var ctx: *uws.us_socket_context_t = @ptrCast(*uws.us_socket_context_t, ctx_); @@ -142,8 +145,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { port: u16, pathname: *const JSC.ZigString, client_protocol: *const JSC.ZigString, - outgoing_usocket: **anyopaque, - ) ?*HTTPClient { + ) callconv(.C) ?*HTTPClient { std.debug.assert(global.bunVM().uws_event_loop != null); var client_protocol_hash: u64 = 0; @@ -158,7 +160,6 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { defer host_.deinit(); if (Socket.connect(host_.slice(), port, @ptrCast(*uws.us_socket_context_t, socket_ctx), HTTPClient, client, "socket")) |out| { - outgoing_usocket.* = out.socket.socket; out.socket.timeout(120); return out; } @@ -179,22 +180,24 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { buf.release(); } } - pub fn cancel(this: *HTTPClient) void { + pub fn cancel(this: *HTTPClient) callconv(.C) void { this.clearData(); if (!this.socket.isEstablished()) { - _ = uws.us_socket_close_connecting(comptime @as(c_int, @boolToInt(ssl)), this.socket); + _ = uws.us_socket_close_connecting(comptime @as(c_int, @boolToInt(ssl)), this.socket.socket); } else { this.socket.close(0, null); } } pub fn fail(this: *HTTPClient, code: ErrorCode) void { + JSC.markBinding(); WebSocket__didFailToConnect(this.outgoing_websocket, code); this.cancel(); } pub fn handleClose(this: *HTTPClient, _: Socket, _: c_int, _: ?*anyopaque) void { + JSC.markBinding(); this.clearData(); WebSocket__didFailToConnect(this.outgoing_websocket, ErrorCode.closed); } @@ -205,7 +208,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { this.socket.close(0, null); } - pub fn handleOpen(this: *HTTPClient, socket: Socket) void { + pub fn handleOpen(this: *HTTPClient, socket: Socket, _: []const u8, _: c_int) void { std.debug.assert(socket.socket == this.socket.socket); std.debug.assert(this.input_body_buf.len > 0); @@ -247,7 +250,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { const to_write = remain[0..@minimum(remain.len, data.len)]; if (data.len > 0 and to_write.len > 0) { - @memcpy(remain.ptr, data, to_write.len); + @memcpy(remain.ptr, data.ptr, to_write.len); this.body_written += to_write.len; } @@ -299,7 +302,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { "Connection".len => { if (connection_header.name.len == 0 and strings.eqlCaseInsensitiveASCII(header.name, "Connection", false)) { connection_header = header; - if (visited_protocol and upgrade_header.len > 0 and connection_header.len > 0 and websocket_accept_header.len > 0 and visited_version) { + if (visited_protocol and upgrade_header.name.len > 0 and connection_header.name.len > 0 and websocket_accept_header.name.len > 0 and visited_version) { break; } } @@ -307,7 +310,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { "Upgrade".len => { if (upgrade_header.name.len == 0 and strings.eqlCaseInsensitiveASCII(header.name, "Upgrade", false)) { upgrade_header = header; - if (visited_protocol and upgrade_header.len > 0 and connection_header.len > 0 and websocket_accept_header.len > 0 and visited_version) { + if (visited_protocol and upgrade_header.name.len > 0 and connection_header.name.len > 0 and websocket_accept_header.name.len > 0 and visited_version) { break; } } @@ -315,7 +318,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { "Sec-WebSocket-Version".len => { if (!visited_version and strings.eqlCaseInsensitiveASCII(header.name, "Sec-WebSocket-Version", false)) { visited_version = true; - if (!strings.eqlComptime(header.value, "13", false)) { + if (!strings.eqlComptimeIgnoreLen(header.value, "13")) { this.terminate(ErrorCode.invalid_websocket_version); return; } @@ -324,7 +327,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { "Sec-WebSocket-Accept".len => { if (websocket_accept_header.name.len == 0 and strings.eqlCaseInsensitiveASCII(header.name, "Sec-WebSocket-Accept", false)) { websocket_accept_header = header; - if (visited_protocol and upgrade_header.len > 0 and connection_header.len > 0 and websocket_accept_header.len > 0 and visited_version) { + if (visited_protocol and upgrade_header.name.len > 0 and connection_header.name.len > 0 and websocket_accept_header.name.len > 0 and visited_version) { break; } } @@ -337,7 +340,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { } visited_protocol = true; - if (visited_protocol and upgrade_header.len > 0 and connection_header.len > 0 and websocket_accept_header.len > 0 and visited_version) { + if (visited_protocol and upgrade_header.name.len > 0 and connection_header.name.len > 0 and websocket_accept_header.name.len > 0 and visited_version) { break; } } @@ -395,6 +398,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { } this.clearData(); + JSC.markBinding(); WebSocket__didConnect(this.outgoing_websocket, this.socket.socket, overflow.ptr, overflow.len); } @@ -425,7 +429,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { this.terminate(ErrorCode.failed_to_connect); } - const Exports = shim.exportFunctions(.{ + pub const Export = shim.exportFunctions(.{ .connect = connect, .cancel = cancel, .register = register, @@ -434,18 +438,18 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { comptime { if (!JSC.is_bindgen) { @export(connect, .{ - .name = Exports[0].symbol_name, + .name = Export[0].symbol_name, }); @export(cancel, .{ - .name = Exports[1].symbol_name, + .name = Export[1].symbol_name, }); @export(register, .{ - .name = Exports[2].symbol_name, + .name = Export[2].symbol_name, }); } } }; } -pub const WebSocketUpgradeClient = NewHTTPUpgradeClient(false); -pub const SecureWebSocketUpgradeClient = NewHTTPUpgradeClient(true); +pub const WebSocketHTTPClient = NewHTTPUpgradeClient(false); +pub const WebSocketHTTPSClient = NewHTTPUpgradeClient(true); |