aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/http/websocket_http_client.zig16
-rw-r--r--src/url.zig11
-rw-r--r--test/js/web/websocket/websocket-upgrade.test.ts37
3 files changed, 54 insertions, 10 deletions
diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig
index f275af0c6..9516934f1 100644
--- a/src/http/websocket_http_client.zig
+++ b/src/http/websocket_http_client.zig
@@ -21,6 +21,7 @@ const ObjectPool = @import("../pool.zig").ObjectPool;
const WebsocketHeader = @import("./websocket.zig").WebsocketHeader;
const WebsocketDataFrame = @import("./websocket.zig").WebsocketDataFrame;
const Opcode = @import("./websocket.zig").Opcode;
+const ZigURL = @import("../url.zig").URL;
const log = Output.scoped(.WebSocketClient, false);
@@ -54,7 +55,9 @@ const NonUTF8Headers = struct {
fn buildRequestBody(
vm: *JSC.VirtualMachine,
pathname: *const JSC.ZigString,
+ is_https: bool,
host: *const JSC.ZigString,
+ port: u16,
client_protocol: *const JSC.ZigString,
client_protocol_hash: *u64,
extra_headers: NonUTF8Headers,
@@ -86,22 +89,25 @@ fn buildRequestBody(
host_.deinit();
}
+ const host_fmt = ZigURL.HostFormatter{
+ .is_https = is_https,
+ .host = host_.slice(),
+ .port = port,
+ };
const headers_ = static_headers[0 .. 1 + @as(usize, @intFromBool(client_protocol.len > 0))];
const pico_headers = PicoHTTP.Headers{ .headers = headers_ };
return try std.fmt.allocPrint(
allocator,
"GET {s} HTTP/1.1\r\n" ++
- "Host: {s}\r\n" ++
- "Pragma: no-cache\r\n" ++
- "Cache-Control: no-cache\r\n" ++
+ "Host: {any}\r\n" ++
"Connection: Upgrade\r\n" ++
"Upgrade: websocket\r\n" ++
"Sec-WebSocket-Version: 13\r\n" ++
"{s}" ++
"{s}" ++
"\r\n",
- .{ pathname_.slice(), host_.slice(), pico_headers, extra_headers },
+ .{ pathname_.slice(), host_fmt, pico_headers, extra_headers },
);
}
@@ -242,7 +248,9 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
var body = buildRequestBody(
global.bunVM(),
pathname,
+ ssl,
host,
+ port,
client_protocol,
&client_protocol_hash,
NonUTF8Headers.init(header_names, header_values, header_count),
diff --git a/src/url.zig b/src/url.zig
index 11959636e..90bc002ba 100644
--- a/src/url.zig
+++ b/src/url.zig
@@ -103,7 +103,7 @@ pub const URL = struct {
pub const HostFormatter = struct {
host: string,
- port: string = "",
+ port: ?u16 = null,
is_https: bool = false,
pub fn format(formatter: HostFormatter, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void {
@@ -114,11 +114,10 @@ pub const URL = struct {
try writer.writeAll(formatter.host);
- const is_port_optional = (formatter.is_https and (formatter.port.len == 0 or strings.eqlComptime(formatter.port, "443"))) or
- (!formatter.is_https and (formatter.port.len == 0 or strings.eqlComptime(formatter.port, "80")));
+ const is_port_optional = formatter.port == null or (formatter.is_https and formatter.port == 443) or
+ (!formatter.is_https and formatter.port == 80);
if (!is_port_optional) {
- try writer.writeAll(":");
- try writer.writeAll(formatter.port);
+ try writer.print(":{d}", .{formatter.port.?});
return;
}
}
@@ -126,7 +125,7 @@ pub const URL = struct {
pub fn displayHost(this: *const URL) HostFormatter {
return HostFormatter{
.host = if (this.host.len > 0) this.host else this.displayHostname(),
- .port = this.port,
+ .port = if (this.port.len > 0) this.getPort() else null,
.is_https = this.isHTTPS(),
};
}
diff --git a/test/js/web/websocket/websocket-upgrade.test.ts b/test/js/web/websocket/websocket-upgrade.test.ts
new file mode 100644
index 000000000..1b6e2f5d7
--- /dev/null
+++ b/test/js/web/websocket/websocket-upgrade.test.ts
@@ -0,0 +1,37 @@
+import { serve } from "bun";
+import { describe, test, expect } from "bun:test";
+
+describe("WebSocket upgrade", () => {
+ test("should send correct upgrade headers", async () => {
+ const server = serve({
+ hostname: "localhost",
+ port: 0,
+ fetch(request, server) {
+ expect(server.upgrade(request)).toBeTrue();
+ const { headers } = request;
+ expect(headers.get("connection")).toBe("upgrade");
+ expect(headers.get("upgrade")).toBe("websocket");
+ expect(headers.get("sec-websocket-version")).toBe("13");
+ expect(headers.get("sec-websocket-key")).toBeString();
+ expect(headers.get("host")).toBe(`localhost:${server.port}`);
+ return;
+ // FIXME: types gets annoyed if this is not here
+ return new Response();
+ },
+ websocket: {
+ open(ws) {
+ // FIXME: double-free issue
+ // ws.close();
+ server.stop();
+ },
+ message(ws, message) {},
+ },
+ });
+ await new Promise((resolve, reject) => {
+ const ws = new WebSocket(`ws://localhost:${server.port}/`);
+ ws.addEventListener("open", resolve);
+ ws.addEventListener("error", reject);
+ ws.addEventListener("close", reject);
+ });
+ });
+});