diff options
author | 2023-10-17 14:36:56 -0700 | |
---|---|---|
committer | 2023-10-17 14:36:56 -0700 | |
commit | afab26762043a40f38954868ebdccce85be7f95c (patch) | |
tree | 89e27598435d96e3b6777a9c9f7526e1992a89c8 | |
parent | ac36f5c278197026b2a442d8ac0f18da6d77f9a1 (diff) | |
download | bun-afab26762043a40f38954868ebdccce85be7f95c.tar.gz bun-afab26762043a40f38954868ebdccce85be7f95c.tar.zst bun-afab26762043a40f38954868ebdccce85be7f95c.zip |
Fix `Host` header excluding port in WebSocket upgradefix-websocket-upgrade
-rw-r--r-- | src/http/websocket_http_client.zig | 16 | ||||
-rw-r--r-- | src/url.zig | 11 | ||||
-rw-r--r-- | test/js/web/websocket/websocket-upgrade.test.ts | 37 |
3 files changed, 54 insertions, 10 deletions
diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig index f275af0c6..9516934f1 100644 --- a/src/http/websocket_http_client.zig +++ b/src/http/websocket_http_client.zig @@ -21,6 +21,7 @@ const ObjectPool = @import("../pool.zig").ObjectPool; const WebsocketHeader = @import("./websocket.zig").WebsocketHeader; const WebsocketDataFrame = @import("./websocket.zig").WebsocketDataFrame; const Opcode = @import("./websocket.zig").Opcode; +const ZigURL = @import("../url.zig").URL; const log = Output.scoped(.WebSocketClient, false); @@ -54,7 +55,9 @@ const NonUTF8Headers = struct { fn buildRequestBody( vm: *JSC.VirtualMachine, pathname: *const JSC.ZigString, + is_https: bool, host: *const JSC.ZigString, + port: u16, client_protocol: *const JSC.ZigString, client_protocol_hash: *u64, extra_headers: NonUTF8Headers, @@ -86,22 +89,25 @@ fn buildRequestBody( host_.deinit(); } + const host_fmt = ZigURL.HostFormatter{ + .is_https = is_https, + .host = host_.slice(), + .port = port, + }; const headers_ = static_headers[0 .. 1 + @as(usize, @intFromBool(client_protocol.len > 0))]; const pico_headers = PicoHTTP.Headers{ .headers = headers_ }; return try std.fmt.allocPrint( allocator, "GET {s} HTTP/1.1\r\n" ++ - "Host: {s}\r\n" ++ - "Pragma: no-cache\r\n" ++ - "Cache-Control: no-cache\r\n" ++ + "Host: {any}\r\n" ++ "Connection: Upgrade\r\n" ++ "Upgrade: websocket\r\n" ++ "Sec-WebSocket-Version: 13\r\n" ++ "{s}" ++ "{s}" ++ "\r\n", - .{ pathname_.slice(), host_.slice(), pico_headers, extra_headers }, + .{ pathname_.slice(), host_fmt, pico_headers, extra_headers }, ); } @@ -242,7 +248,9 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { var body = buildRequestBody( global.bunVM(), pathname, + ssl, host, + port, client_protocol, &client_protocol_hash, NonUTF8Headers.init(header_names, header_values, header_count), diff --git a/src/url.zig b/src/url.zig index 11959636e..90bc002ba 100644 --- a/src/url.zig +++ b/src/url.zig @@ -103,7 +103,7 @@ pub const URL = struct { pub const HostFormatter = struct { host: string, - port: string = "", + port: ?u16 = null, is_https: bool = false, pub fn format(formatter: HostFormatter, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { @@ -114,11 +114,10 @@ pub const URL = struct { try writer.writeAll(formatter.host); - const is_port_optional = (formatter.is_https and (formatter.port.len == 0 or strings.eqlComptime(formatter.port, "443"))) or - (!formatter.is_https and (formatter.port.len == 0 or strings.eqlComptime(formatter.port, "80"))); + const is_port_optional = formatter.port == null or (formatter.is_https and formatter.port == 443) or + (!formatter.is_https and formatter.port == 80); if (!is_port_optional) { - try writer.writeAll(":"); - try writer.writeAll(formatter.port); + try writer.print(":{d}", .{formatter.port.?}); return; } } @@ -126,7 +125,7 @@ pub const URL = struct { pub fn displayHost(this: *const URL) HostFormatter { return HostFormatter{ .host = if (this.host.len > 0) this.host else this.displayHostname(), - .port = this.port, + .port = if (this.port.len > 0) this.getPort() else null, .is_https = this.isHTTPS(), }; } diff --git a/test/js/web/websocket/websocket-upgrade.test.ts b/test/js/web/websocket/websocket-upgrade.test.ts new file mode 100644 index 000000000..1b6e2f5d7 --- /dev/null +++ b/test/js/web/websocket/websocket-upgrade.test.ts @@ -0,0 +1,37 @@ +import { serve } from "bun"; +import { describe, test, expect } from "bun:test"; + +describe("WebSocket upgrade", () => { + test("should send correct upgrade headers", async () => { + const server = serve({ + hostname: "localhost", + port: 0, + fetch(request, server) { + expect(server.upgrade(request)).toBeTrue(); + const { headers } = request; + expect(headers.get("connection")).toBe("upgrade"); + expect(headers.get("upgrade")).toBe("websocket"); + expect(headers.get("sec-websocket-version")).toBe("13"); + expect(headers.get("sec-websocket-key")).toBeString(); + expect(headers.get("host")).toBe(`localhost:${server.port}`); + return; + // FIXME: types gets annoyed if this is not here + return new Response(); + }, + websocket: { + open(ws) { + // FIXME: double-free issue + // ws.close(); + server.stop(); + }, + message(ws, message) {}, + }, + }); + await new Promise((resolve, reject) => { + const ws = new WebSocket(`ws://localhost:${server.port}/`); + ws.addEventListener("open", resolve); + ws.addEventListener("error", reject); + ws.addEventListener("close", reject); + }); + }); +}); |