diff options
Diffstat (limited to 'src/http/websocket_http_client.zig')
-rw-r--r-- | src/http/websocket_http_client.zig | 451 |
1 files changed, 451 insertions, 0 deletions
diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig new file mode 100644 index 000000000..5e795bb4a --- /dev/null +++ b/src/http/websocket_http_client.zig @@ -0,0 +1,451 @@ +// This code is based on https://github.com/frmdstryr/zhp/blob/a4b5700c289c3619647206144e10fb414113a888/src/websocket.zig +// Thank you @frmdstryr. +const std = @import("std"); +const native_endian = @import("builtin").target.cpu.arch.endian(); + +const tcp = std.x.net.tcp; +const ip = std.x.net.ip; + +const IPv4 = std.x.os.IPv4; +const IPv6 = std.x.os.IPv6; +const os = std.os; +const bun = @import("../global.zig"); +const string = bun.string; +const Output = bun.Output; +const Global = bun.Global; +const Environment = bun.Environment; +const strings = bun.strings; +const MutableString = bun.MutableString; +const stringZ = bun.stringZ; +const default_allocator = bun.default_allocator; +const C = bun.C; + +const uws = @import("uws"); +const JSC = @import("javascript_core"); +const PicoHTTP = @import("picohttp"); +const ObjectPool = @import("../pool.zig").ObjectPool; + +fn buildRequestBody(vm: *JSC.VirtualMachine, pathname: *const JSC.ZigString, host: *const JSC.ZigString, client_protocol: *const JSC.ZigString, client_protocol_hash: *u64) ![]u8 { + const allocator = vm.allocator(); + var input_rand_buf: [16]u8 = undefined; + std.crypto.random.bytes(&input_rand_buf); + const temp_buf_size = comptime std.base64.standard.Encoder.calcSize(16); + var encoded_buf: [temp_buf_size]u8 = undefined; + const accept_key = std.base64.standard.Encoder.encode(&encoded_buf, &input_rand_buf); + + var headers = [_]PicoHTTP.Header{ + .{ + .name = "Sec-WebSocket-Key", + .value = accept_key, + }, + .{ + .name = "Sec-WebSocket-Protocol", + .value = client_protocol.slice(), + }, + }; + + if (client_protocol.len > 0) + client_protocol_hash.* = std.hash.Wyhash.hash(0, headers[1].value); + + var headers_: []PicoHTTP.Header = headers[0 .. 1 + @as(usize, @boolToInt(client_protocol.len > 0))]; + + return try std.fmt.allocPrint(allocator, + \\GET {} HTTP/1.1\r + \\Host: {}\r + \\Pragma: no-cache\r + \\Cache-Control: no-cache\r + \\Connection: Upgrade\r + \\Upgrade: websocket\r + \\Sec-WebSocket-Version: 13\r + \\{s} + \\\r + \\ + , .{ + pathname.*, + host.*, + PicoHTTP.Headers{ .headers = headers_ }, + }); +} + +const ErrorCode = enum(i32) { + cancel, + invalid_response, + expected_101_status_code, + missing_upgrade_header, + missing_connection_header, + missing_websocket_accept_header, + invalid_upgrade_header, + invalid_connection_header, + invalid_websocket_version, + mismatch_websocket_accept_header, + missing_client_protocol, + mismatch_client_protocol, + timeout, + closed, + failed_to_write, + failed_to_connect, + headers_too_large, + ended, +}; +extern fn WebSocket__didConnect( + websocket_context: *anyopaque, + socket: *uws.Socket, + buffered_data: ?[*]u8, + buffered_len: usize, +) void; +extern fn WebSocket__didFailToConnect(websocket_context: *anyopaque, reason: ErrorCode) void; + +const BodyBufBytes = [16384 - 16]u8; + +const BodyBufPool = ObjectPool(BodyBufBytes, null, true, 4); +const BodyBuf = BodyBufPool.Node; + +pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { + return struct { + pub const Socket = uws.NewSocketHandler(ssl); + socket: Socket, + outgoing_websocket: *anyopaque, + input_body_buf: []u8 = &[_]u8{}, + client_protocol: []const u8 = "", + to_send: []const u8 = "", + read_length: usize = 0, + headers_buf: [128]PicoHTTP.Header = undefined, + body_buf: ?*BodyBuf = null, + body_written: usize = 0, + websocket_protocol: u64 = 0, + + const basename = (if (ssl) "SecureWebSocket" else "WebSocket") ++ "UpgradeClient"; + + pub const shim = JSC.Shimmer("Bun", basename, @This()); + + const HTTPClient = @This(); + + pub fn register(global: *JSC.JSGlobalObject, loop_: *anyopaque, ctx_: *anyopaque) void { + var vm = global.bunVM(); + var loop = @ptrCast(*uws.Loop, loop_); + var ctx: *uws.us_socket_context_t = @ptrCast(*uws.us_socket_context_t, ctx_); + + if (vm.uws_event_loop) |other| { + std.debug.assert(other == loop); + } + + vm.uws_event_loop = loop; + + Socket.configure(ctx, HTTPClient, handleOpen, handleClose, handleData, handleWritable, handleTimeout, handleConnectError, handleEnd); + } + + pub fn connect( + global: *JSC.JSGlobalObject, + socket_ctx: *anyopaque, + websocket: *anyopaque, + host: *const JSC.ZigString, + port: u16, + pathname: *const JSC.ZigString, + client_protocol: *const JSC.ZigString, + outgoing_usocket: **anyopaque, + ) ?*HTTPClient { + std.debug.assert(global.bunVM().uws_event_loop != null); + + var client_protocol_hash: u64 = 0; + var body = buildRequestBody(global.bunVM(), pathname, host, client_protocol, &client_protocol_hash) catch return null; + var client: HTTPClient = HTTPClient{ + .socket = undefined, + .outgoing_websocket = websocket, + .input_body_buf = body, + .websocket_protocol = client_protocol_hash, + }; + var host_ = host.toSlice(bun.default_allocator); + defer host_.deinit(); + + if (Socket.connect(host_.slice(), port, @ptrCast(*uws.us_socket_context_t, socket_ctx), HTTPClient, client, "socket")) |out| { + outgoing_usocket.* = out.socket.socket; + out.socket.timeout(120); + return out; + } + + client.clearData(); + + return null; + } + + pub fn clearInput(this: *HTTPClient) void { + if (this.input_body_buf.len > 0) bun.default_allocator.free(this.input_body_buf); + this.input_body_buf.len = 0; + } + pub fn clearData(this: *HTTPClient) void { + this.clearInput(); + if (this.body_buf) |buf| { + this.body_buf = null; + buf.release(); + } + } + pub fn cancel(this: *HTTPClient) void { + this.clearData(); + + if (!this.socket.isEstablished()) { + _ = uws.us_socket_close_connecting(comptime @as(c_int, @boolToInt(ssl)), this.socket); + } else { + this.socket.close(0, null); + } + } + + pub fn fail(this: *HTTPClient, code: ErrorCode) void { + WebSocket__didFailToConnect(this.outgoing_websocket, code); + this.cancel(); + } + + pub fn handleClose(this: *HTTPClient, _: Socket, _: c_int, _: ?*anyopaque) void { + this.clearData(); + WebSocket__didFailToConnect(this.outgoing_websocket, ErrorCode.closed); + } + + pub fn terminate(this: *HTTPClient, code: ErrorCode) void { + this.fail(code); + if (this.socket.isClosed() == 0) + this.socket.close(0, null); + } + + pub fn handleOpen(this: *HTTPClient, socket: Socket) void { + std.debug.assert(socket.socket == this.socket.socket); + + std.debug.assert(this.input_body_buf.len > 0); + std.debug.assert(this.to_send.len == 0); + + const wrote = socket.write(this.input_body_buf, false); + if (wrote < 0) { + this.terminate(ErrorCode.failed_to_write); + return; + } + + this.to_send = this.input_body_buf[@intCast(usize, wrote)..]; + } + + fn getBody(this: *HTTPClient) *BodyBufBytes { + if (this.body_buf == null) { + this.body_buf = BodyBufPool.get(bun.default_allocator); + } + + return &this.body_buf.?.data; + } + + pub fn handleData(this: *HTTPClient, socket: Socket, data: []const u8) void { + std.debug.assert(socket.socket == this.socket.socket); + + if (comptime Environment.allow_assert) + std.debug.assert(socket.isShutdown() == 0); + + var body = this.getBody(); + var remain = body[this.body_written..]; + const is_first = this.body_written == 0; + if (is_first and data.len >= "101 ".len) { + // fail early if we receive a non-101 status code + if (!strings.eqlComptimeIgnoreLen(data[0.."101 ".len], "101 ")) { + this.terminate(ErrorCode.expected_101_status_code); + return; + } + } + + const to_write = remain[0..@minimum(remain.len, data.len)]; + if (data.len > 0 and to_write.len > 0) { + @memcpy(remain.ptr, data, to_write.len); + this.body_written += to_write.len; + } + + const overflow = data[to_write.len..]; + + const available_to_read = body[0..this.body_written]; + const response = PicoHTTP.Response.parse(available_to_read, &this.headers_buf) catch |err| { + switch (err) { + error.Malformed_HTTP_Response => { + this.terminate(ErrorCode.invalid_response); + return; + }, + error.ShortRead => { + if (overflow.len > 0) { + this.terminate(ErrorCode.headers_too_large); + return; + } + return; + }, + } + }; + + var buffered_body_data = body[@minimum(@intCast(usize, response.bytes_read), body.len)..]; + buffered_body_data = buffered_body_data[0..@minimum(buffered_body_data.len, this.body_written)]; + + this.processResponse(response, buffered_body_data, overflow); + } + + pub fn handleEnd(this: *HTTPClient, socket: Socket) void { + std.debug.assert(socket.socket == this.socket.socket); + this.terminate(ErrorCode.ended); + } + + pub fn processResponse(this: *HTTPClient, response: PicoHTTP.Response, remain_buf: []const u8, overflow_buf: []const u8) void { + std.debug.assert(this.body_written > 0); + + var upgrade_header = PicoHTTP.Header{ .name = "", .value = "" }; + var connection_header = PicoHTTP.Header{ .name = "", .value = "" }; + var websocket_accept_header = PicoHTTP.Header{ .name = "", .value = "" }; + var visited_protocol = this.websocket_protocol == 0; + var visited_version = false; + + if (remain_buf.len > 0) { + std.debug.assert(overflow_buf.len == 0); + } + + for (response.headers) |header| { + switch (header.name.len) { + "Connection".len => { + if (connection_header.name.len == 0 and strings.eqlCaseInsensitiveASCII(header.name, "Connection", false)) { + connection_header = header; + if (visited_protocol and upgrade_header.len > 0 and connection_header.len > 0 and websocket_accept_header.len > 0 and visited_version) { + break; + } + } + }, + "Upgrade".len => { + if (upgrade_header.name.len == 0 and strings.eqlCaseInsensitiveASCII(header.name, "Upgrade", false)) { + upgrade_header = header; + if (visited_protocol and upgrade_header.len > 0 and connection_header.len > 0 and websocket_accept_header.len > 0 and visited_version) { + break; + } + } + }, + "Sec-WebSocket-Version".len => { + if (!visited_version and strings.eqlCaseInsensitiveASCII(header.name, "Sec-WebSocket-Version", false)) { + visited_version = true; + if (!strings.eqlComptime(header.value, "13", false)) { + this.terminate(ErrorCode.invalid_websocket_version); + return; + } + } + }, + "Sec-WebSocket-Accept".len => { + if (websocket_accept_header.name.len == 0 and strings.eqlCaseInsensitiveASCII(header.name, "Sec-WebSocket-Accept", false)) { + websocket_accept_header = header; + if (visited_protocol and upgrade_header.len > 0 and connection_header.len > 0 and websocket_accept_header.len > 0 and visited_version) { + break; + } + } + }, + "Sec-WebSocket-Protocol".len => { + if (strings.eqlCaseInsensitiveASCII(header.name, "Sec-WebSocket-Protocol", false)) { + if (this.websocket_protocol == 0 or std.hash.Wyhash.hash(0, header.value) != this.websocket_protocol) { + this.terminate(ErrorCode.mismatch_client_protocol); + return; + } + visited_protocol = true; + + if (visited_protocol and upgrade_header.len > 0 and connection_header.len > 0 and websocket_accept_header.len > 0 and visited_version) { + break; + } + } + }, + else => {}, + } + } + + if (!visited_version) { + this.terminate(ErrorCode.invalid_websocket_version); + return; + } + + if (@minimum(upgrade_header.name.len, upgrade_header.value.len) == 0) { + this.terminate(ErrorCode.missing_upgrade_header); + return; + } + + if (@minimum(connection_header.name.len, connection_header.value.len) == 0) { + this.terminate(ErrorCode.missing_connection_header); + return; + } + + if (@minimum(websocket_accept_header.name.len, websocket_accept_header.value.len) == 0) { + this.terminate(ErrorCode.missing_websocket_accept_header); + return; + } + + if (!visited_protocol) { + this.terminate(ErrorCode.mismatch_client_protocol); + return; + } + + if (strings.eqlComptime(connection_header.value, "Upgrade")) { + this.terminate(ErrorCode.invalid_connection_header); + return; + } + + if (!strings.eqlComptime(upgrade_header.value, "websocket")) { + this.terminate(ErrorCode.invalid_upgrade_header); + return; + } + + // TODO: check websocket_accept_header.value + + const overflow_len = overflow_buf.len + remain_buf.len; + var overflow: []u8 = &.{}; + if (overflow_len > 0) { + overflow = bun.default_allocator.alloc(u8, overflow_len) catch { + this.terminate(ErrorCode.invalid_response); + return; + }; + if (remain_buf.len > 0) @memcpy(overflow.ptr, remain_buf.ptr, remain_buf.len); + if (overflow_buf.len > 0) @memcpy(overflow.ptr + remain_buf.len, overflow_buf.ptr, remain_buf.len); + } + + this.clearData(); + WebSocket__didConnect(this.outgoing_websocket, this.socket.socket, overflow.ptr, overflow.len); + } + + pub fn handleWritable( + this: *HTTPClient, + socket: Socket, + ) void { + std.debug.assert(socket.socket == this.socket.socket); + + if (this.to_send.len == 0) + return; + + const wrote = socket.write(this.to_send, false); + if (wrote < 0) { + this.terminate(ErrorCode.failed_to_write); + return; + } + std.debug.assert(@intCast(usize, wrote) >= this.to_send.len); + this.to_send = this.to_send[@minimum(@intCast(usize, wrote), this.to_send.len)..]; + } + pub fn handleTimeout( + this: *HTTPClient, + _: Socket, + ) void { + this.terminate(ErrorCode.timeout); + } + pub fn handleConnectError(this: *HTTPClient, _: Socket, _: c_int) void { + this.terminate(ErrorCode.failed_to_connect); + } + + const Exports = shim.exportFunctions(.{ + .connect = connect, + .cancel = cancel, + .register = register, + }); + + comptime { + if (!JSC.is_bindgen) { + @export(connect, .{ + .name = Exports[0].symbol_name, + }); + @export(cancel, .{ + .name = Exports[1].symbol_name, + }); + @export(register, .{ + .name = Exports[2].symbol_name, + }); + } + } + }; +} + +pub const WebSocketUpgradeClient = NewHTTPUpgradeClient(false); +pub const SecureWebSocketUpgradeClient = NewHTTPUpgradeClient(true); |