diff options
-rw-r--r-- | packages/bun-types/globals.d.ts | 19 | ||||
-rw-r--r-- | src/bun.js/bindings/headers-cpp.h | 2 | ||||
-rw-r--r-- | src/bun.js/bindings/headers.h | 6 | ||||
-rw-r--r-- | src/bun.js/bindings/webcore/JSWebSocket.cpp | 64 | ||||
-rw-r--r-- | src/bun.js/bindings/webcore/WebSocket.cpp | 48 | ||||
-rw-r--r-- | src/bun.js/bindings/webcore/WebSocket.h | 3 | ||||
-rw-r--r-- | src/http/websocket_http_client.zig | 62 | ||||
-rw-r--r-- | test/bun.js/websocket.test.js | 23 |
8 files changed, 204 insertions, 23 deletions
diff --git a/packages/bun-types/globals.d.ts b/packages/bun-types/globals.d.ts index b0cd1783b..0679a7eb4 100644 --- a/packages/bun-types/globals.d.ts +++ b/packages/bun-types/globals.d.ts @@ -1656,6 +1656,25 @@ interface WebSocket extends EventTarget { declare var WebSocket: { prototype: WebSocket; new (url: string | URL, protocols?: string | string[]): WebSocket; + new ( + url: string | URL, + options: { + /** + * An object specifying connection headers + * + * This is a Bun-specific extension. + */ + headers?: HeadersInit; + /** + * A string specifying the subprotocols the server is willing to accept. + */ + protocol?: string; + /** + * A string array specifying the subprotocols the server is willing to accept. + */ + protocols?: string[]; + }, + ): WebSocket; readonly CLOSED: number; readonly CLOSING: number; readonly CONNECTING: number; diff --git a/src/bun.js/bindings/headers-cpp.h b/src/bun.js/bindings/headers-cpp.h index ef5d7718b..d49a705b7 100644 --- a/src/bun.js/bindings/headers-cpp.h +++ b/src/bun.js/bindings/headers-cpp.h @@ -1,4 +1,4 @@ -//-- AUTOGENERATED FILE -- 1672229965 +//-- AUTOGENERATED FILE -- 1672280340 // clang-format off #pragma once diff --git a/src/bun.js/bindings/headers.h b/src/bun.js/bindings/headers.h index 353b07c89..0f121ef06 100644 --- a/src/bun.js/bindings/headers.h +++ b/src/bun.js/bindings/headers.h @@ -1,5 +1,5 @@ // clang-format off -//-- AUTOGENERATED FILE -- 1672229965 +//-- AUTOGENERATED FILE -- 1672280340 #pragma once #include <stddef.h> @@ -648,7 +648,7 @@ ZIG_DECL JSC__JSValue FileSink__write(JSC__JSGlobalObject* arg0, JSC__CallFrame* #ifdef __cplusplus ZIG_DECL void Bun__WebSocketHTTPClient__cancel(WebSocketHTTPClient* arg0); -ZIG_DECL WebSocketHTTPClient* Bun__WebSocketHTTPClient__connect(JSC__JSGlobalObject* arg0, void* arg1, void* arg2, const ZigString* arg3, uint16_t arg4, const ZigString* arg5, const ZigString* arg6); +ZIG_DECL WebSocketHTTPClient* Bun__WebSocketHTTPClient__connect(JSC__JSGlobalObject* arg0, void* arg1, void* arg2, const ZigString* arg3, uint16_t arg4, const ZigString* arg5, const ZigString* arg6, ZigString* arg7, ZigString* arg8, size_t arg9); ZIG_DECL void Bun__WebSocketHTTPClient__register(JSC__JSGlobalObject* arg0, void* arg1, void* arg2); #endif @@ -656,7 +656,7 @@ ZIG_DECL void Bun__WebSocketHTTPClient__register(JSC__JSGlobalObject* arg0, void #ifdef __cplusplus ZIG_DECL void Bun__WebSocketHTTPSClient__cancel(WebSocketHTTPSClient* arg0); -ZIG_DECL WebSocketHTTPSClient* Bun__WebSocketHTTPSClient__connect(JSC__JSGlobalObject* arg0, void* arg1, void* arg2, const ZigString* arg3, uint16_t arg4, const ZigString* arg5, const ZigString* arg6); +ZIG_DECL WebSocketHTTPSClient* Bun__WebSocketHTTPSClient__connect(JSC__JSGlobalObject* arg0, void* arg1, void* arg2, const ZigString* arg3, uint16_t arg4, const ZigString* arg5, const ZigString* arg6, ZigString* arg7, ZigString* arg8, size_t arg9); ZIG_DECL void Bun__WebSocketHTTPSClient__register(JSC__JSGlobalObject* arg0, void* arg1, void* arg2); #endif diff --git a/src/bun.js/bindings/webcore/JSWebSocket.cpp b/src/bun.js/bindings/webcore/JSWebSocket.cpp index b234dbe1b..83d7283d9 100644 --- a/src/bun.js/bindings/webcore/JSWebSocket.cpp +++ b/src/bun.js/bindings/webcore/JSWebSocket.cpp @@ -37,6 +37,13 @@ #include "JSDOMConvertNumbers.h" #include "JSDOMConvertSequences.h" #include "JSDOMConvertStrings.h" +#include "JSDOMConvertBoolean.h" +#include "JSDOMConvertRecord.h" +#include "JSDOMConvertUnion.h" +#include "JSDOMExceptionHandling.h" +#include "JSDOMGlobalObjectInlines.h" +#include "JSDOMIterator.h" +#include "JSDOMOperation.h" #include "JSDOMExceptionHandling.h" #include "JSDOMGlobalObjectInlines.h" #include "JSDOMOperation.h" @@ -54,6 +61,8 @@ #include <wtf/GetPtr.h> #include <wtf/PointerPreparations.h> #include <wtf/URL.h> +#include "IDLTypes.h" +#include "FetchHeaders.h" namespace WebCore { using namespace JSC; @@ -185,6 +194,54 @@ static inline EncodedJSValue constructJSWebSocket2(JSGlobalObject* lexicalGlobal return JSValue::encode(jsValue); } +static inline EncodedJSValue constructJSWebSocket3(JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame, JSValue urlValue, JSValue optionsObjectValue) +{ + VM& vm = lexicalGlobalObject->vm(); + auto throwScope = DECLARE_THROW_SCOPE(vm); + auto* globalObject = jsCast<Zig::GlobalObject*>(lexicalGlobalObject); + auto* context = globalObject->scriptExecutionContext(); + if (UNLIKELY(!context)) + return throwConstructorScriptExecutionContextUnavailableError(*lexicalGlobalObject, throwScope, "WebSocket"); + auto url = convert<IDLUSVString>(*lexicalGlobalObject, urlValue); + RETURN_IF_EXCEPTION(throwScope, encodedJSValue()); + + Vector<String> protocols; + + auto headersInit = std::optional<Converter<IDLUnion<IDLSequence<IDLSequence<IDLByteString>>, IDLRecord<IDLByteString, IDLByteString>>>::ReturnType>(); + if (JSC::JSObject* options = optionsObjectValue.getObject()) { + if (JSValue headersValue = options->getIfPropertyExists(globalObject, PropertyName(Identifier::fromString(vm, "headers"_s)))) { + if (!headersValue.isUndefinedOrNull()) { + headersInit = convert<IDLUnion<IDLSequence<IDLSequence<IDLByteString>>, IDLRecord<IDLByteString, IDLByteString>>>(*lexicalGlobalObject, headersValue); + RETURN_IF_EXCEPTION(throwScope, encodedJSValue()); + } + } + + if (JSValue protocolsValue = options->getIfPropertyExists(globalObject, PropertyName(Identifier::fromString(vm, "protocols"_s)))) { + if (!protocolsValue.isUndefinedOrNull()) { + protocols = convert<IDLSequence<IDLDOMString>>(*lexicalGlobalObject, protocolsValue); + RETURN_IF_EXCEPTION(throwScope, encodedJSValue()); + } + } else if (JSValue protocolValue = options->getIfPropertyExists(globalObject, PropertyName(Identifier::fromString(vm, "protocol"_s)))) { + if (!protocolValue.isUndefinedOrNull()) { + protocols = Vector<String> { convert<IDLDOMString>(*lexicalGlobalObject, protocolValue) }; + RETURN_IF_EXCEPTION(throwScope, encodedJSValue()); + } + } + } + + RETURN_IF_EXCEPTION(throwScope, encodedJSValue()); + auto object = WebSocket::create(*context, WTFMove(url), protocols, WTFMove(headersInit)); + if constexpr (IsExceptionOr<decltype(object)>) + RETURN_IF_EXCEPTION(throwScope, {}); + static_assert(TypeOrExceptionOrUnderlyingType<decltype(object)>::isRef); + auto jsValue = toJSNewlyCreated<IDLInterface<WebSocket>>(*lexicalGlobalObject, *globalObject, throwScope, WTFMove(object)); + if constexpr (IsExceptionOr<decltype(object)>) + RETURN_IF_EXCEPTION(throwScope, {}); + setSubclassStructureIfNeeded<WebSocket>(lexicalGlobalObject, callFrame, asObject(jsValue)); + RETURN_IF_EXCEPTION(throwScope, {}); + return JSValue::encode(jsValue); +} + template<> EncodedJSValue JSC_HOST_CALL_ATTRIBUTES JSWebSocketDOMConstructor::construct(JSGlobalObject* lexicalGlobalObject, CallFrame* callFrame) { VM& vm = lexicalGlobalObject->vm(); @@ -204,7 +261,12 @@ template<> EncodedJSValue JSC_HOST_CALL_ATTRIBUTES JSWebSocketDOMConstructor::co if (success) RELEASE_AND_RETURN(throwScope, (constructJSWebSocket1(lexicalGlobalObject, callFrame))); } - RELEASE_AND_RETURN(throwScope, (constructJSWebSocket2(lexicalGlobalObject, callFrame))); + + if (distinguishingArg.isString()) { + RELEASE_AND_RETURN(throwScope, (constructJSWebSocket2(lexicalGlobalObject, callFrame))); + } else if (distinguishingArg.isObject()) { + RELEASE_AND_RETURN(throwScope, (constructJSWebSocket3(lexicalGlobalObject, callFrame, callFrame->uncheckedArgument(0), distinguishingArg))); + } } return argsCount < 1 ? throwVMError(lexicalGlobalObject, throwScope, createNotEnoughArgumentsError(lexicalGlobalObject)) : throwVMTypeError(lexicalGlobalObject, throwScope); } diff --git a/src/bun.js/bindings/webcore/WebSocket.cpp b/src/bun.js/bindings/webcore/WebSocket.cpp index 2b685ef95..015b706d9 100644 --- a/src/bun.js/bindings/webcore/WebSocket.cpp +++ b/src/bun.js/bindings/webcore/WebSocket.cpp @@ -197,18 +197,23 @@ WebSocket::~WebSocket() ExceptionOr<Ref<WebSocket>> WebSocket::create(ScriptExecutionContext& context, const String& url) { - return create(context, url, Vector<String> {}); + return create(context, url, Vector<String> {}, std::nullopt); } ExceptionOr<Ref<WebSocket>> WebSocket::create(ScriptExecutionContext& context, const String& url, const Vector<String>& protocols) { + return create(context, url, protocols, std::nullopt); +} + +ExceptionOr<Ref<WebSocket>> WebSocket::create(ScriptExecutionContext& context, const String& url, const Vector<String>& protocols, std::optional<FetchHeaders::Init>&& headers) +{ if (url.isNull()) return Exception { SyntaxError }; auto socket = adoptRef(*new WebSocket(context)); // socket->suspendIfNeeded(); - auto result = socket->connect(url, protocols); + auto result = socket->connect(url, protocols, WTFMove(headers)); // auto result = socket->connect(url, protocols); if (result.hasException()) @@ -224,12 +229,12 @@ ExceptionOr<Ref<WebSocket>> WebSocket::create(ScriptExecutionContext& context, c ExceptionOr<void> WebSocket::connect(const String& url) { - return connect(url, Vector<String> {}); + return connect(url, Vector<String> {}, std::nullopt); } ExceptionOr<void> WebSocket::connect(const String& url, const String& protocol) { - return connect(url, Vector<String> { 1, protocol }); + return connect(url, Vector<String> { 1, protocol }, std::nullopt); } void WebSocket::failAsynchronously() @@ -267,6 +272,11 @@ static String hostName(const URL& url, bool secure) ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& protocols) { + return connect(url, protocols, std::nullopt); +} + +ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& protocols, std::optional<FetchHeaders::Init>&& headersInit) +{ LOG(Network, "WebSocket %p connect() url='%s'", this, url.utf8().data()); m_url = URL { url }; @@ -280,9 +290,9 @@ ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& pr return Exception { SyntaxError, makeString("Invalid url for WebSocket "_s, m_url.stringCenterEllipsizedToLength()) }; } - bool is_secure = m_url.protocolIs("wss"_s); + bool is_secure = m_url.protocolIs("wss"_s) || m_url.protocolIs("https"_s); - if (!m_url.protocolIs("ws"_s) && !is_secure) { + if (!m_url.protocolIs("http"_s) && !m_url.protocolIs("ws"_s) && !is_secure) { // context.addConsoleMessage(MessageSource::JS, MessageLevel::Error, ); m_state = CLOSED; updateHasPendingActivity(); @@ -371,19 +381,41 @@ ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& pr port = userPort.value(); } + Vector<ZigString, 8> headerNames; + Vector<ZigString, 8> headerValues; + + auto headersOrException = FetchHeaders::create(WTFMove(headersInit)); + if (UNLIKELY(headersOrException.hasException())) { + m_state = CLOSED; + updateHasPendingActivity(); + return headersOrException.releaseException(); + } + + auto headers = headersOrException.releaseReturnValue(); + headerNames.reserveInitialCapacity(headers.get().internalHeaders().size()); + headerValues.reserveInitialCapacity(headers.get().internalHeaders().size()); + auto iterator = headers.get().createIterator(); + while (auto value = iterator.next()) { + headerNames.uncheckedAppend(Zig::toZigString(value->key)); + headerValues.uncheckedAppend(Zig::toZigString(value->value)); + } + m_isSecure = is_secure; this->incPendingActivityCount(); if (is_secure) { us_socket_context_t* ctx = scriptExecutionContext()->webSocketContext<true>(); RELEASE_ASSERT(ctx); - this->m_upgradeClient = Bun__WebSocketHTTPSClient__connect(scriptExecutionContext()->jsGlobalObject(), ctx, this, &host, port, &path, &clientProtocolString); + this->m_upgradeClient = Bun__WebSocketHTTPSClient__connect(scriptExecutionContext()->jsGlobalObject(), ctx, this, &host, port, &path, &clientProtocolString, headerNames.data(), headerValues.data(), headerNames.size()); } else { us_socket_context_t* ctx = scriptExecutionContext()->webSocketContext<false>(); RELEASE_ASSERT(ctx); - this->m_upgradeClient = Bun__WebSocketHTTPClient__connect(scriptExecutionContext()->jsGlobalObject(), ctx, this, &host, port, &path, &clientProtocolString); + this->m_upgradeClient = Bun__WebSocketHTTPClient__connect(scriptExecutionContext()->jsGlobalObject(), ctx, this, &host, port, &path, &clientProtocolString, headerNames.data(), headerValues.data(), headerNames.size()); } + headerValues.clear(); + headerNames.clear(); + if (this->m_upgradeClient == nullptr) { // context.addConsoleMessage(MessageSource::JS, MessageLevel::Error, ); m_state = CLOSED; diff --git a/src/bun.js/bindings/webcore/WebSocket.h b/src/bun.js/bindings/webcore/WebSocket.h index 82de58333..59b551ea0 100644 --- a/src/bun.js/bindings/webcore/WebSocket.h +++ b/src/bun.js/bindings/webcore/WebSocket.h @@ -36,6 +36,7 @@ #include <wtf/URL.h> #include <wtf/HashSet.h> #include <wtf/Lock.h> +#include "FetchHeaders.h" namespace uWS { template<bool, bool, typename> @@ -59,6 +60,7 @@ public: static ExceptionOr<Ref<WebSocket>> create(ScriptExecutionContext&, const String& url); static ExceptionOr<Ref<WebSocket>> create(ScriptExecutionContext&, const String& url, const String& protocol); static ExceptionOr<Ref<WebSocket>> create(ScriptExecutionContext&, const String& url, const Vector<String>& protocols); + static ExceptionOr<Ref<WebSocket>> create(ScriptExecutionContext&, const String& url, const Vector<String>& protocols, std::optional<FetchHeaders::Init>&&); ~WebSocket(); enum State { @@ -71,6 +73,7 @@ public: ExceptionOr<void> connect(const String& url); ExceptionOr<void> connect(const String& url, const String& protocol); ExceptionOr<void> connect(const String& url, const Vector<String>& protocols); + ExceptionOr<void> connect(const String& url, const Vector<String>& protocols, std::optional<FetchHeaders::Init>&&); ExceptionOr<void> send(const String& message); ExceptionOr<void> send(JSC::ArrayBuffer&); diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig index c908b4bff..801e11007 100644 --- a/src/http/websocket_http_client.zig +++ b/src/http/websocket_http_client.zig @@ -24,14 +24,48 @@ const Opcode = @import("./websocket.zig").Opcode; const log = Output.scoped(.WebSocketClient, false); -fn buildRequestBody(vm: *JSC.VirtualMachine, pathname: *const JSC.ZigString, host: *const JSC.ZigString, client_protocol: *const JSC.ZigString, client_protocol_hash: *u64) std.mem.Allocator.Error![]u8 { +const NonUTF8Headers = struct { + names: []const JSC.ZigString, + values: []const JSC.ZigString, + + pub fn format(self: NonUTF8Headers, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { + const count = self.names.len; + var i: usize = 0; + while (i < count) : (i += 1) { + try std.fmt.format(writer, "{any}: {any}\r\n", .{ self.names[i], self.values[i] }); + } + } + + pub fn init(names: ?[*]const JSC.ZigString, values: ?[*]const JSC.ZigString, len: usize) NonUTF8Headers { + if (len == 0) { + return .{ + .names = &[_]JSC.ZigString{}, + .values = &[_]JSC.ZigString{}, + }; + } + + return .{ + .names = names.?[0..len], + .values = values.?[0..len], + }; + } +}; + +fn buildRequestBody( + vm: *JSC.VirtualMachine, + pathname: *const JSC.ZigString, + host: *const JSC.ZigString, + client_protocol: *const JSC.ZigString, + client_protocol_hash: *u64, + extra_headers: NonUTF8Headers, +) std.mem.Allocator.Error![]u8 { const allocator = vm.allocator; const input_rand_buf = vm.rareData().nextUUID(); const temp_buf_size = comptime std.base64.standard.Encoder.calcSize(16); var encoded_buf: [temp_buf_size]u8 = undefined; const accept_key = std.base64.standard.Encoder.encode(&encoded_buf, &input_rand_buf); - var headers = [_]PicoHTTP.Header{ + var static_headers = [_]PicoHTTP.Header{ .{ .name = "Sec-WebSocket-Key", .value = accept_key, @@ -43,9 +77,10 @@ fn buildRequestBody(vm: *JSC.VirtualMachine, pathname: *const JSC.ZigString, hos }; if (client_protocol.len > 0) - client_protocol_hash.* = std.hash.Wyhash.hash(0, headers[1].value); + client_protocol_hash.* = std.hash.Wyhash.hash(0, static_headers[1].value); + + const headers_ = static_headers[0 .. 1 + @as(usize, @boolToInt(client_protocol.len > 0))]; - var headers_: []PicoHTTP.Header = headers[0 .. 1 + @as(usize, @boolToInt(client_protocol.len > 0))]; const pathname_ = pathname.slice(); const host_ = host.slice(); const pico_headers = PicoHTTP.Headers{ .headers = headers_ }; @@ -59,12 +94,9 @@ fn buildRequestBody(vm: *JSC.VirtualMachine, pathname: *const JSC.ZigString, hos "Upgrade: websocket\r\n" ++ "Sec-WebSocket-Version: 13\r\n" ++ "{any}" ++ + "{any}" ++ "\r\n", - .{ - pathname_, - host_, - pico_headers, - }, + .{ pathname_, host_, pico_headers, extra_headers }, ); } @@ -174,11 +206,21 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { port: u16, pathname: *const JSC.ZigString, client_protocol: *const JSC.ZigString, + header_names: ?[*]const JSC.ZigString, + header_values: ?[*]const JSC.ZigString, + header_count: usize, ) callconv(.C) ?*HTTPClient { std.debug.assert(global.bunVM().uws_event_loop != null); var client_protocol_hash: u64 = 0; - var body = buildRequestBody(global.bunVM(), pathname, host, client_protocol, &client_protocol_hash) catch return null; + var body = buildRequestBody( + global.bunVM(), + pathname, + host, + client_protocol, + &client_protocol_hash, + NonUTF8Headers.init(header_names, header_values, header_count), + ) catch return null; var client: HTTPClient = HTTPClient{ .tcp = undefined, .outgoing_websocket = websocket, diff --git a/test/bun.js/websocket.test.js b/test/bun.js/websocket.test.js index ab825fa63..3680e2749 100644 --- a/test/bun.js/websocket.test.js +++ b/test/bun.js/websocket.test.js @@ -19,6 +19,29 @@ describe("WebSocket", () => { await closed; }); + it("supports headers", (done) => { + const server = Bun.serve({ + port: 8024, + fetch(req, server) { + expect(req.headers.get("X-Hello")).toBe("World"); + expect(req.headers.get("content-type")).toBe("lolwut"); + server.stop(); + done(); + return new Response(); + }, + websocket: { + open(ws) { + ws.close(); + }, + }, + }); + const ws = new WebSocket(`ws://${server.hostname}:${server.port}`, { + headers: { + "X-Hello": "World", + "content-type": "lolwut", + }, + }); + }); it("should send and receive messages", async () => { const ws = new WebSocket(TEST_WEBSOCKET_HOST); await new Promise((resolve, reject) => { |