aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--packages/bun-types/globals.d.ts19
-rw-r--r--src/bun.js/bindings/headers-cpp.h2
-rw-r--r--src/bun.js/bindings/headers.h6
-rw-r--r--src/bun.js/bindings/webcore/JSWebSocket.cpp64
-rw-r--r--src/bun.js/bindings/webcore/WebSocket.cpp48
-rw-r--r--src/bun.js/bindings/webcore/WebSocket.h3
-rw-r--r--src/http/websocket_http_client.zig62
-rw-r--r--test/bun.js/websocket.test.js23
8 files changed, 204 insertions, 23 deletions
diff --git a/packages/bun-types/globals.d.ts b/packages/bun-types/globals.d.ts
index b0cd1783b..0679a7eb4 100644
--- a/packages/bun-types/globals.d.ts
+++ b/packages/bun-types/globals.d.ts
@@ -1656,6 +1656,25 @@ interface WebSocket extends EventTarget {
declare var WebSocket: {
prototype: WebSocket;
new (url: string | URL, protocols?: string | string[]): WebSocket;
+ new (
+ url: string | URL,
+ options: {
+ /**
+ * An object specifying connection headers
+ *
+ * This is a Bun-specific extension.
+ */
+ headers?: HeadersInit;
+ /**
+ * A string specifying the subprotocols the server is willing to accept.
+ */
+ protocol?: string;
+ /**
+ * A string array specifying the subprotocols the server is willing to accept.
+ */
+ protocols?: string[];
+ },
+ ): WebSocket;
readonly CLOSED: number;
readonly CLOSING: number;
readonly CONNECTING: number;
diff --git a/src/bun.js/bindings/headers-cpp.h b/src/bun.js/bindings/headers-cpp.h
index ef5d7718b..d49a705b7 100644
--- a/src/bun.js/bindings/headers-cpp.h
+++ b/src/bun.js/bindings/headers-cpp.h
@@ -1,4 +1,4 @@
-//-- AUTOGENERATED FILE -- 1672229965
+//-- AUTOGENERATED FILE -- 1672280340
// clang-format off
#pragma once
diff --git a/src/bun.js/bindings/headers.h b/src/bun.js/bindings/headers.h
index 353b07c89..0f121ef06 100644
--- a/src/bun.js/bindings/headers.h
+++ b/src/bun.js/bindings/headers.h
@@ -1,5 +1,5 @@
// clang-format off
-//-- AUTOGENERATED FILE -- 1672229965
+//-- AUTOGENERATED FILE -- 1672280340
#pragma once
#include <stddef.h>
@@ -648,7 +648,7 @@ ZIG_DECL JSC__JSValue FileSink__write(JSC__JSGlobalObject* arg0, JSC__CallFrame*
#ifdef __cplusplus
ZIG_DECL void Bun__WebSocketHTTPClient__cancel(WebSocketHTTPClient* arg0);
-ZIG_DECL WebSocketHTTPClient* Bun__WebSocketHTTPClient__connect(JSC__JSGlobalObject* arg0, void* arg1, void* arg2, const ZigString* arg3, uint16_t arg4, const ZigString* arg5, const ZigString* arg6);
+ZIG_DECL WebSocketHTTPClient* Bun__WebSocketHTTPClient__connect(JSC__JSGlobalObject* arg0, void* arg1, void* arg2, const ZigString* arg3, uint16_t arg4, const ZigString* arg5, const ZigString* arg6, ZigString* arg7, ZigString* arg8, size_t arg9);
ZIG_DECL void Bun__WebSocketHTTPClient__register(JSC__JSGlobalObject* arg0, void* arg1, void* arg2);
#endif
@@ -656,7 +656,7 @@ ZIG_DECL void Bun__WebSocketHTTPClient__register(JSC__JSGlobalObject* arg0, void
#ifdef __cplusplus
ZIG_DECL void Bun__WebSocketHTTPSClient__cancel(WebSocketHTTPSClient* arg0);
-ZIG_DECL WebSocketHTTPSClient* Bun__WebSocketHTTPSClient__connect(JSC__JSGlobalObject* arg0, void* arg1, void* arg2, const ZigString* arg3, uint16_t arg4, const ZigString* arg5, const ZigString* arg6);
+ZIG_DECL WebSocketHTTPSClient* Bun__WebSocketHTTPSClient__connect(JSC__JSGlobalObject* arg0, void* arg1, void* arg2, const ZigString* arg3, uint16_t arg4, const ZigString* arg5, const ZigString* arg6, ZigString* arg7, ZigString* arg8, size_t arg9);
ZIG_DECL void Bun__WebSocketHTTPSClient__register(JSC__JSGlobalObject* arg0, void* arg1, void* arg2);
#endif
diff --git a/src/bun.js/bindings/webcore/JSWebSocket.cpp b/src/bun.js/bindings/webcore/JSWebSocket.cpp
index b234dbe1b..83d7283d9 100644
--- a/src/bun.js/bindings/webcore/JSWebSocket.cpp
+++ b/src/bun.js/bindings/webcore/JSWebSocket.cpp
@@ -37,6 +37,13 @@
#include "JSDOMConvertNumbers.h"
#include "JSDOMConvertSequences.h"
#include "JSDOMConvertStrings.h"
+#include "JSDOMConvertBoolean.h"
+#include "JSDOMConvertRecord.h"
+#include "JSDOMConvertUnion.h"
+#include "JSDOMExceptionHandling.h"
+#include "JSDOMGlobalObjectInlines.h"
+#include "JSDOMIterator.h"
+#include "JSDOMOperation.h"
#include "JSDOMExceptionHandling.h"
#include "JSDOMGlobalObjectInlines.h"
#include "JSDOMOperation.h"
@@ -54,6 +61,8 @@
#include <wtf/GetPtr.h>
#include <wtf/PointerPreparations.h>
#include <wtf/URL.h>
+#include "IDLTypes.h"
+#include "FetchHeaders.h"
namespace WebCore {
using namespace JSC;
@@ -185,6 +194,54 @@ static inline EncodedJSValue constructJSWebSocket2(JSGlobalObject* lexicalGlobal
return JSValue::encode(jsValue);
}
+static inline EncodedJSValue constructJSWebSocket3(JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame, JSValue urlValue, JSValue optionsObjectValue)
+{
+ VM& vm = lexicalGlobalObject->vm();
+ auto throwScope = DECLARE_THROW_SCOPE(vm);
+ auto* globalObject = jsCast<Zig::GlobalObject*>(lexicalGlobalObject);
+ auto* context = globalObject->scriptExecutionContext();
+ if (UNLIKELY(!context))
+ return throwConstructorScriptExecutionContextUnavailableError(*lexicalGlobalObject, throwScope, "WebSocket");
+ auto url = convert<IDLUSVString>(*lexicalGlobalObject, urlValue);
+ RETURN_IF_EXCEPTION(throwScope, encodedJSValue());
+
+ Vector<String> protocols;
+
+ auto headersInit = std::optional<Converter<IDLUnion<IDLSequence<IDLSequence<IDLByteString>>, IDLRecord<IDLByteString, IDLByteString>>>::ReturnType>();
+ if (JSC::JSObject* options = optionsObjectValue.getObject()) {
+ if (JSValue headersValue = options->getIfPropertyExists(globalObject, PropertyName(Identifier::fromString(vm, "headers"_s)))) {
+ if (!headersValue.isUndefinedOrNull()) {
+ headersInit = convert<IDLUnion<IDLSequence<IDLSequence<IDLByteString>>, IDLRecord<IDLByteString, IDLByteString>>>(*lexicalGlobalObject, headersValue);
+ RETURN_IF_EXCEPTION(throwScope, encodedJSValue());
+ }
+ }
+
+ if (JSValue protocolsValue = options->getIfPropertyExists(globalObject, PropertyName(Identifier::fromString(vm, "protocols"_s)))) {
+ if (!protocolsValue.isUndefinedOrNull()) {
+ protocols = convert<IDLSequence<IDLDOMString>>(*lexicalGlobalObject, protocolsValue);
+ RETURN_IF_EXCEPTION(throwScope, encodedJSValue());
+ }
+ } else if (JSValue protocolValue = options->getIfPropertyExists(globalObject, PropertyName(Identifier::fromString(vm, "protocol"_s)))) {
+ if (!protocolValue.isUndefinedOrNull()) {
+ protocols = Vector<String> { convert<IDLDOMString>(*lexicalGlobalObject, protocolValue) };
+ RETURN_IF_EXCEPTION(throwScope, encodedJSValue());
+ }
+ }
+ }
+
+ RETURN_IF_EXCEPTION(throwScope, encodedJSValue());
+ auto object = WebSocket::create(*context, WTFMove(url), protocols, WTFMove(headersInit));
+ if constexpr (IsExceptionOr<decltype(object)>)
+ RETURN_IF_EXCEPTION(throwScope, {});
+ static_assert(TypeOrExceptionOrUnderlyingType<decltype(object)>::isRef);
+ auto jsValue = toJSNewlyCreated<IDLInterface<WebSocket>>(*lexicalGlobalObject, *globalObject, throwScope, WTFMove(object));
+ if constexpr (IsExceptionOr<decltype(object)>)
+ RETURN_IF_EXCEPTION(throwScope, {});
+ setSubclassStructureIfNeeded<WebSocket>(lexicalGlobalObject, callFrame, asObject(jsValue));
+ RETURN_IF_EXCEPTION(throwScope, {});
+ return JSValue::encode(jsValue);
+}
+
template<> EncodedJSValue JSC_HOST_CALL_ATTRIBUTES JSWebSocketDOMConstructor::construct(JSGlobalObject* lexicalGlobalObject, CallFrame* callFrame)
{
VM& vm = lexicalGlobalObject->vm();
@@ -204,7 +261,12 @@ template<> EncodedJSValue JSC_HOST_CALL_ATTRIBUTES JSWebSocketDOMConstructor::co
if (success)
RELEASE_AND_RETURN(throwScope, (constructJSWebSocket1(lexicalGlobalObject, callFrame)));
}
- RELEASE_AND_RETURN(throwScope, (constructJSWebSocket2(lexicalGlobalObject, callFrame)));
+
+ if (distinguishingArg.isString()) {
+ RELEASE_AND_RETURN(throwScope, (constructJSWebSocket2(lexicalGlobalObject, callFrame)));
+ } else if (distinguishingArg.isObject()) {
+ RELEASE_AND_RETURN(throwScope, (constructJSWebSocket3(lexicalGlobalObject, callFrame, callFrame->uncheckedArgument(0), distinguishingArg)));
+ }
}
return argsCount < 1 ? throwVMError(lexicalGlobalObject, throwScope, createNotEnoughArgumentsError(lexicalGlobalObject)) : throwVMTypeError(lexicalGlobalObject, throwScope);
}
diff --git a/src/bun.js/bindings/webcore/WebSocket.cpp b/src/bun.js/bindings/webcore/WebSocket.cpp
index 2b685ef95..015b706d9 100644
--- a/src/bun.js/bindings/webcore/WebSocket.cpp
+++ b/src/bun.js/bindings/webcore/WebSocket.cpp
@@ -197,18 +197,23 @@ WebSocket::~WebSocket()
ExceptionOr<Ref<WebSocket>> WebSocket::create(ScriptExecutionContext& context, const String& url)
{
- return create(context, url, Vector<String> {});
+ return create(context, url, Vector<String> {}, std::nullopt);
}
ExceptionOr<Ref<WebSocket>> WebSocket::create(ScriptExecutionContext& context, const String& url, const Vector<String>& protocols)
{
+ return create(context, url, protocols, std::nullopt);
+}
+
+ExceptionOr<Ref<WebSocket>> WebSocket::create(ScriptExecutionContext& context, const String& url, const Vector<String>& protocols, std::optional<FetchHeaders::Init>&& headers)
+{
if (url.isNull())
return Exception { SyntaxError };
auto socket = adoptRef(*new WebSocket(context));
// socket->suspendIfNeeded();
- auto result = socket->connect(url, protocols);
+ auto result = socket->connect(url, protocols, WTFMove(headers));
// auto result = socket->connect(url, protocols);
if (result.hasException())
@@ -224,12 +229,12 @@ ExceptionOr<Ref<WebSocket>> WebSocket::create(ScriptExecutionContext& context, c
ExceptionOr<void> WebSocket::connect(const String& url)
{
- return connect(url, Vector<String> {});
+ return connect(url, Vector<String> {}, std::nullopt);
}
ExceptionOr<void> WebSocket::connect(const String& url, const String& protocol)
{
- return connect(url, Vector<String> { 1, protocol });
+ return connect(url, Vector<String> { 1, protocol }, std::nullopt);
}
void WebSocket::failAsynchronously()
@@ -267,6 +272,11 @@ static String hostName(const URL& url, bool secure)
ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& protocols)
{
+ return connect(url, protocols, std::nullopt);
+}
+
+ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& protocols, std::optional<FetchHeaders::Init>&& headersInit)
+{
LOG(Network, "WebSocket %p connect() url='%s'", this, url.utf8().data());
m_url = URL { url };
@@ -280,9 +290,9 @@ ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& pr
return Exception { SyntaxError, makeString("Invalid url for WebSocket "_s, m_url.stringCenterEllipsizedToLength()) };
}
- bool is_secure = m_url.protocolIs("wss"_s);
+ bool is_secure = m_url.protocolIs("wss"_s) || m_url.protocolIs("https"_s);
- if (!m_url.protocolIs("ws"_s) && !is_secure) {
+ if (!m_url.protocolIs("http"_s) && !m_url.protocolIs("ws"_s) && !is_secure) {
// context.addConsoleMessage(MessageSource::JS, MessageLevel::Error, );
m_state = CLOSED;
updateHasPendingActivity();
@@ -371,19 +381,41 @@ ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& pr
port = userPort.value();
}
+ Vector<ZigString, 8> headerNames;
+ Vector<ZigString, 8> headerValues;
+
+ auto headersOrException = FetchHeaders::create(WTFMove(headersInit));
+ if (UNLIKELY(headersOrException.hasException())) {
+ m_state = CLOSED;
+ updateHasPendingActivity();
+ return headersOrException.releaseException();
+ }
+
+ auto headers = headersOrException.releaseReturnValue();
+ headerNames.reserveInitialCapacity(headers.get().internalHeaders().size());
+ headerValues.reserveInitialCapacity(headers.get().internalHeaders().size());
+ auto iterator = headers.get().createIterator();
+ while (auto value = iterator.next()) {
+ headerNames.uncheckedAppend(Zig::toZigString(value->key));
+ headerValues.uncheckedAppend(Zig::toZigString(value->value));
+ }
+
m_isSecure = is_secure;
this->incPendingActivityCount();
if (is_secure) {
us_socket_context_t* ctx = scriptExecutionContext()->webSocketContext<true>();
RELEASE_ASSERT(ctx);
- this->m_upgradeClient = Bun__WebSocketHTTPSClient__connect(scriptExecutionContext()->jsGlobalObject(), ctx, this, &host, port, &path, &clientProtocolString);
+ this->m_upgradeClient = Bun__WebSocketHTTPSClient__connect(scriptExecutionContext()->jsGlobalObject(), ctx, this, &host, port, &path, &clientProtocolString, headerNames.data(), headerValues.data(), headerNames.size());
} else {
us_socket_context_t* ctx = scriptExecutionContext()->webSocketContext<false>();
RELEASE_ASSERT(ctx);
- this->m_upgradeClient = Bun__WebSocketHTTPClient__connect(scriptExecutionContext()->jsGlobalObject(), ctx, this, &host, port, &path, &clientProtocolString);
+ this->m_upgradeClient = Bun__WebSocketHTTPClient__connect(scriptExecutionContext()->jsGlobalObject(), ctx, this, &host, port, &path, &clientProtocolString, headerNames.data(), headerValues.data(), headerNames.size());
}
+ headerValues.clear();
+ headerNames.clear();
+
if (this->m_upgradeClient == nullptr) {
// context.addConsoleMessage(MessageSource::JS, MessageLevel::Error, );
m_state = CLOSED;
diff --git a/src/bun.js/bindings/webcore/WebSocket.h b/src/bun.js/bindings/webcore/WebSocket.h
index 82de58333..59b551ea0 100644
--- a/src/bun.js/bindings/webcore/WebSocket.h
+++ b/src/bun.js/bindings/webcore/WebSocket.h
@@ -36,6 +36,7 @@
#include <wtf/URL.h>
#include <wtf/HashSet.h>
#include <wtf/Lock.h>
+#include "FetchHeaders.h"
namespace uWS {
template<bool, bool, typename>
@@ -59,6 +60,7 @@ public:
static ExceptionOr<Ref<WebSocket>> create(ScriptExecutionContext&, const String& url);
static ExceptionOr<Ref<WebSocket>> create(ScriptExecutionContext&, const String& url, const String& protocol);
static ExceptionOr<Ref<WebSocket>> create(ScriptExecutionContext&, const String& url, const Vector<String>& protocols);
+ static ExceptionOr<Ref<WebSocket>> create(ScriptExecutionContext&, const String& url, const Vector<String>& protocols, std::optional<FetchHeaders::Init>&&);
~WebSocket();
enum State {
@@ -71,6 +73,7 @@ public:
ExceptionOr<void> connect(const String& url);
ExceptionOr<void> connect(const String& url, const String& protocol);
ExceptionOr<void> connect(const String& url, const Vector<String>& protocols);
+ ExceptionOr<void> connect(const String& url, const Vector<String>& protocols, std::optional<FetchHeaders::Init>&&);
ExceptionOr<void> send(const String& message);
ExceptionOr<void> send(JSC::ArrayBuffer&);
diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig
index c908b4bff..801e11007 100644
--- a/src/http/websocket_http_client.zig
+++ b/src/http/websocket_http_client.zig
@@ -24,14 +24,48 @@ const Opcode = @import("./websocket.zig").Opcode;
const log = Output.scoped(.WebSocketClient, false);
-fn buildRequestBody(vm: *JSC.VirtualMachine, pathname: *const JSC.ZigString, host: *const JSC.ZigString, client_protocol: *const JSC.ZigString, client_protocol_hash: *u64) std.mem.Allocator.Error![]u8 {
+const NonUTF8Headers = struct {
+ names: []const JSC.ZigString,
+ values: []const JSC.ZigString,
+
+ pub fn format(self: NonUTF8Headers, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void {
+ const count = self.names.len;
+ var i: usize = 0;
+ while (i < count) : (i += 1) {
+ try std.fmt.format(writer, "{any}: {any}\r\n", .{ self.names[i], self.values[i] });
+ }
+ }
+
+ pub fn init(names: ?[*]const JSC.ZigString, values: ?[*]const JSC.ZigString, len: usize) NonUTF8Headers {
+ if (len == 0) {
+ return .{
+ .names = &[_]JSC.ZigString{},
+ .values = &[_]JSC.ZigString{},
+ };
+ }
+
+ return .{
+ .names = names.?[0..len],
+ .values = values.?[0..len],
+ };
+ }
+};
+
+fn buildRequestBody(
+ vm: *JSC.VirtualMachine,
+ pathname: *const JSC.ZigString,
+ host: *const JSC.ZigString,
+ client_protocol: *const JSC.ZigString,
+ client_protocol_hash: *u64,
+ extra_headers: NonUTF8Headers,
+) std.mem.Allocator.Error![]u8 {
const allocator = vm.allocator;
const input_rand_buf = vm.rareData().nextUUID();
const temp_buf_size = comptime std.base64.standard.Encoder.calcSize(16);
var encoded_buf: [temp_buf_size]u8 = undefined;
const accept_key = std.base64.standard.Encoder.encode(&encoded_buf, &input_rand_buf);
- var headers = [_]PicoHTTP.Header{
+ var static_headers = [_]PicoHTTP.Header{
.{
.name = "Sec-WebSocket-Key",
.value = accept_key,
@@ -43,9 +77,10 @@ fn buildRequestBody(vm: *JSC.VirtualMachine, pathname: *const JSC.ZigString, hos
};
if (client_protocol.len > 0)
- client_protocol_hash.* = std.hash.Wyhash.hash(0, headers[1].value);
+ client_protocol_hash.* = std.hash.Wyhash.hash(0, static_headers[1].value);
+
+ const headers_ = static_headers[0 .. 1 + @as(usize, @boolToInt(client_protocol.len > 0))];
- var headers_: []PicoHTTP.Header = headers[0 .. 1 + @as(usize, @boolToInt(client_protocol.len > 0))];
const pathname_ = pathname.slice();
const host_ = host.slice();
const pico_headers = PicoHTTP.Headers{ .headers = headers_ };
@@ -59,12 +94,9 @@ fn buildRequestBody(vm: *JSC.VirtualMachine, pathname: *const JSC.ZigString, hos
"Upgrade: websocket\r\n" ++
"Sec-WebSocket-Version: 13\r\n" ++
"{any}" ++
+ "{any}" ++
"\r\n",
- .{
- pathname_,
- host_,
- pico_headers,
- },
+ .{ pathname_, host_, pico_headers, extra_headers },
);
}
@@ -174,11 +206,21 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
port: u16,
pathname: *const JSC.ZigString,
client_protocol: *const JSC.ZigString,
+ header_names: ?[*]const JSC.ZigString,
+ header_values: ?[*]const JSC.ZigString,
+ header_count: usize,
) callconv(.C) ?*HTTPClient {
std.debug.assert(global.bunVM().uws_event_loop != null);
var client_protocol_hash: u64 = 0;
- var body = buildRequestBody(global.bunVM(), pathname, host, client_protocol, &client_protocol_hash) catch return null;
+ var body = buildRequestBody(
+ global.bunVM(),
+ pathname,
+ host,
+ client_protocol,
+ &client_protocol_hash,
+ NonUTF8Headers.init(header_names, header_values, header_count),
+ ) catch return null;
var client: HTTPClient = HTTPClient{
.tcp = undefined,
.outgoing_websocket = websocket,
diff --git a/test/bun.js/websocket.test.js b/test/bun.js/websocket.test.js
index ab825fa63..3680e2749 100644
--- a/test/bun.js/websocket.test.js
+++ b/test/bun.js/websocket.test.js
@@ -19,6 +19,29 @@ describe("WebSocket", () => {
await closed;
});
+ it("supports headers", (done) => {
+ const server = Bun.serve({
+ port: 8024,
+ fetch(req, server) {
+ expect(req.headers.get("X-Hello")).toBe("World");
+ expect(req.headers.get("content-type")).toBe("lolwut");
+ server.stop();
+ done();
+ return new Response();
+ },
+ websocket: {
+ open(ws) {
+ ws.close();
+ },
+ },
+ });
+ const ws = new WebSocket(`ws://${server.hostname}:${server.port}`, {
+ headers: {
+ "X-Hello": "World",
+ "content-type": "lolwut",
+ },
+ });
+ });
it("should send and receive messages", async () => {
const ws = new WebSocket(TEST_WEBSOCKET_HOST);
await new Promise((resolve, reject) => {