diff options
author | 2023-10-17 14:36:56 -0700 | |
---|---|---|
committer | 2023-10-17 14:36:56 -0700 | |
commit | afab26762043a40f38954868ebdccce85be7f95c (patch) | |
tree | 89e27598435d96e3b6777a9c9f7526e1992a89c8 /src | |
parent | ac36f5c278197026b2a442d8ac0f18da6d77f9a1 (diff) | |
download | bun-fix-websocket-upgrade.tar.gz bun-fix-websocket-upgrade.tar.zst bun-fix-websocket-upgrade.zip |
Fix `Host` header excluding port in WebSocket upgradefix-websocket-upgrade
Diffstat (limited to '')
-rw-r--r-- | src/http/websocket_http_client.zig | 16 | ||||
-rw-r--r-- | src/url.zig | 11 |
2 files changed, 17 insertions, 10 deletions
diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig index f275af0c6..9516934f1 100644 --- a/src/http/websocket_http_client.zig +++ b/src/http/websocket_http_client.zig @@ -21,6 +21,7 @@ const ObjectPool = @import("../pool.zig").ObjectPool; const WebsocketHeader = @import("./websocket.zig").WebsocketHeader; const WebsocketDataFrame = @import("./websocket.zig").WebsocketDataFrame; const Opcode = @import("./websocket.zig").Opcode; +const ZigURL = @import("../url.zig").URL; const log = Output.scoped(.WebSocketClient, false); @@ -54,7 +55,9 @@ const NonUTF8Headers = struct { fn buildRequestBody( vm: *JSC.VirtualMachine, pathname: *const JSC.ZigString, + is_https: bool, host: *const JSC.ZigString, + port: u16, client_protocol: *const JSC.ZigString, client_protocol_hash: *u64, extra_headers: NonUTF8Headers, @@ -86,22 +89,25 @@ fn buildRequestBody( host_.deinit(); } + const host_fmt = ZigURL.HostFormatter{ + .is_https = is_https, + .host = host_.slice(), + .port = port, + }; const headers_ = static_headers[0 .. 1 + @as(usize, @intFromBool(client_protocol.len > 0))]; const pico_headers = PicoHTTP.Headers{ .headers = headers_ }; return try std.fmt.allocPrint( allocator, "GET {s} HTTP/1.1\r\n" ++ - "Host: {s}\r\n" ++ - "Pragma: no-cache\r\n" ++ - "Cache-Control: no-cache\r\n" ++ + "Host: {any}\r\n" ++ "Connection: Upgrade\r\n" ++ "Upgrade: websocket\r\n" ++ "Sec-WebSocket-Version: 13\r\n" ++ "{s}" ++ "{s}" ++ "\r\n", - .{ pathname_.slice(), host_.slice(), pico_headers, extra_headers }, + .{ pathname_.slice(), host_fmt, pico_headers, extra_headers }, ); } @@ -242,7 +248,9 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { var body = buildRequestBody( global.bunVM(), pathname, + ssl, host, + port, client_protocol, &client_protocol_hash, NonUTF8Headers.init(header_names, header_values, header_count), diff --git a/src/url.zig b/src/url.zig index 11959636e..90bc002ba 100644 --- a/src/url.zig +++ b/src/url.zig @@ -103,7 +103,7 @@ pub const URL = struct { pub const HostFormatter = struct { host: string, - port: string = "", + port: ?u16 = null, is_https: bool = false, pub fn format(formatter: HostFormatter, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { @@ -114,11 +114,10 @@ pub const URL = struct { try writer.writeAll(formatter.host); - const is_port_optional = (formatter.is_https and (formatter.port.len == 0 or strings.eqlComptime(formatter.port, "443"))) or - (!formatter.is_https and (formatter.port.len == 0 or strings.eqlComptime(formatter.port, "80"))); + const is_port_optional = formatter.port == null or (formatter.is_https and formatter.port == 443) or + (!formatter.is_https and formatter.port == 80); if (!is_port_optional) { - try writer.writeAll(":"); - try writer.writeAll(formatter.port); + try writer.print(":{d}", .{formatter.port.?}); return; } } @@ -126,7 +125,7 @@ pub const URL = struct { pub fn displayHost(this: *const URL) HostFormatter { return HostFormatter{ .host = if (this.host.len > 0) this.host else this.displayHostname(), - .port = this.port, + .port = if (this.port.len > 0) this.getPort() else null, .is_https = this.isHTTPS(), }; } |