diff options
| author | 2022-12-28 18:39:19 -0800 | |
|---|---|---|
| committer | 2022-12-28 18:39:19 -0800 | |
| commit | ba0b5baee4ebe58f6c4e35932a9473b8eb2f2225 (patch) | |
| tree | 9c4eabc1471c6c00f21e8f9f29740f361e673565 /src | |
| parent | 384a9cda5e329c8fb44dcd9ff12d893696153a69 (diff) | |
| download | bun-ba0b5baee4ebe58f6c4e35932a9473b8eb2f2225.tar.gz bun-ba0b5baee4ebe58f6c4e35932a9473b8eb2f2225.tar.zst bun-ba0b5baee4ebe58f6c4e35932a9473b8eb2f2225.zip | |
[WebSocket] Implement `headers` support
Fixes https://github.com/oven-sh/bun/issues/1676
Diffstat (limited to 'src')
| -rw-r--r-- | src/bun.js/bindings/headers-cpp.h | 2 | ||||
| -rw-r--r-- | src/bun.js/bindings/headers.h | 6 | ||||
| -rw-r--r-- | src/bun.js/bindings/webcore/JSWebSocket.cpp | 64 | ||||
| -rw-r--r-- | src/bun.js/bindings/webcore/WebSocket.cpp | 48 | ||||
| -rw-r--r-- | src/bun.js/bindings/webcore/WebSocket.h | 3 | ||||
| -rw-r--r-- | src/http/websocket_http_client.zig | 62 | 
6 files changed, 162 insertions, 23 deletions
| diff --git a/src/bun.js/bindings/headers-cpp.h b/src/bun.js/bindings/headers-cpp.h index ef5d7718b..d49a705b7 100644 --- a/src/bun.js/bindings/headers-cpp.h +++ b/src/bun.js/bindings/headers-cpp.h @@ -1,4 +1,4 @@ -//-- AUTOGENERATED FILE -- 1672229965 +//-- AUTOGENERATED FILE -- 1672280340  // clang-format off  #pragma once diff --git a/src/bun.js/bindings/headers.h b/src/bun.js/bindings/headers.h index 353b07c89..0f121ef06 100644 --- a/src/bun.js/bindings/headers.h +++ b/src/bun.js/bindings/headers.h @@ -1,5 +1,5 @@  // clang-format off -//-- AUTOGENERATED FILE -- 1672229965 +//-- AUTOGENERATED FILE -- 1672280340  #pragma once  #include <stddef.h> @@ -648,7 +648,7 @@ ZIG_DECL JSC__JSValue FileSink__write(JSC__JSGlobalObject* arg0, JSC__CallFrame*  #ifdef __cplusplus  ZIG_DECL void Bun__WebSocketHTTPClient__cancel(WebSocketHTTPClient* arg0); -ZIG_DECL WebSocketHTTPClient* Bun__WebSocketHTTPClient__connect(JSC__JSGlobalObject* arg0, void* arg1, void* arg2, const ZigString* arg3, uint16_t arg4, const ZigString* arg5, const ZigString* arg6); +ZIG_DECL WebSocketHTTPClient* Bun__WebSocketHTTPClient__connect(JSC__JSGlobalObject* arg0, void* arg1, void* arg2, const ZigString* arg3, uint16_t arg4, const ZigString* arg5, const ZigString* arg6, ZigString* arg7, ZigString* arg8, size_t arg9);  ZIG_DECL void Bun__WebSocketHTTPClient__register(JSC__JSGlobalObject* arg0, void* arg1, void* arg2);  #endif @@ -656,7 +656,7 @@ ZIG_DECL void Bun__WebSocketHTTPClient__register(JSC__JSGlobalObject* arg0, void  #ifdef __cplusplus  ZIG_DECL void Bun__WebSocketHTTPSClient__cancel(WebSocketHTTPSClient* arg0); -ZIG_DECL WebSocketHTTPSClient* Bun__WebSocketHTTPSClient__connect(JSC__JSGlobalObject* arg0, void* arg1, void* arg2, const ZigString* arg3, uint16_t arg4, const ZigString* arg5, const ZigString* arg6); +ZIG_DECL WebSocketHTTPSClient* Bun__WebSocketHTTPSClient__connect(JSC__JSGlobalObject* arg0, void* arg1, void* arg2, const ZigString* arg3, uint16_t arg4, const ZigString* arg5, const ZigString* arg6, ZigString* arg7, ZigString* arg8, size_t arg9);  ZIG_DECL void Bun__WebSocketHTTPSClient__register(JSC__JSGlobalObject* arg0, void* arg1, void* arg2);  #endif diff --git a/src/bun.js/bindings/webcore/JSWebSocket.cpp b/src/bun.js/bindings/webcore/JSWebSocket.cpp index b234dbe1b..83d7283d9 100644 --- a/src/bun.js/bindings/webcore/JSWebSocket.cpp +++ b/src/bun.js/bindings/webcore/JSWebSocket.cpp @@ -37,6 +37,13 @@  #include "JSDOMConvertNumbers.h"  #include "JSDOMConvertSequences.h"  #include "JSDOMConvertStrings.h" +#include "JSDOMConvertBoolean.h" +#include "JSDOMConvertRecord.h" +#include "JSDOMConvertUnion.h" +#include "JSDOMExceptionHandling.h" +#include "JSDOMGlobalObjectInlines.h" +#include "JSDOMIterator.h" +#include "JSDOMOperation.h"  #include "JSDOMExceptionHandling.h"  #include "JSDOMGlobalObjectInlines.h"  #include "JSDOMOperation.h" @@ -54,6 +61,8 @@  #include <wtf/GetPtr.h>  #include <wtf/PointerPreparations.h>  #include <wtf/URL.h> +#include "IDLTypes.h" +#include "FetchHeaders.h"  namespace WebCore {  using namespace JSC; @@ -185,6 +194,54 @@ static inline EncodedJSValue constructJSWebSocket2(JSGlobalObject* lexicalGlobal      return JSValue::encode(jsValue);  } +static inline EncodedJSValue constructJSWebSocket3(JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame, JSValue urlValue, JSValue optionsObjectValue) +{ +    VM& vm = lexicalGlobalObject->vm(); +    auto throwScope = DECLARE_THROW_SCOPE(vm); +    auto* globalObject = jsCast<Zig::GlobalObject*>(lexicalGlobalObject); +    auto* context = globalObject->scriptExecutionContext(); +    if (UNLIKELY(!context)) +        return throwConstructorScriptExecutionContextUnavailableError(*lexicalGlobalObject, throwScope, "WebSocket"); +    auto url = convert<IDLUSVString>(*lexicalGlobalObject, urlValue); +    RETURN_IF_EXCEPTION(throwScope, encodedJSValue()); + +    Vector<String> protocols; + +    auto headersInit = std::optional<Converter<IDLUnion<IDLSequence<IDLSequence<IDLByteString>>, IDLRecord<IDLByteString, IDLByteString>>>::ReturnType>(); +    if (JSC::JSObject* options = optionsObjectValue.getObject()) { +        if (JSValue headersValue = options->getIfPropertyExists(globalObject, PropertyName(Identifier::fromString(vm, "headers"_s)))) { +            if (!headersValue.isUndefinedOrNull()) { +                headersInit = convert<IDLUnion<IDLSequence<IDLSequence<IDLByteString>>, IDLRecord<IDLByteString, IDLByteString>>>(*lexicalGlobalObject, headersValue); +                RETURN_IF_EXCEPTION(throwScope, encodedJSValue()); +            } +        } + +        if (JSValue protocolsValue = options->getIfPropertyExists(globalObject, PropertyName(Identifier::fromString(vm, "protocols"_s)))) { +            if (!protocolsValue.isUndefinedOrNull()) { +                protocols = convert<IDLSequence<IDLDOMString>>(*lexicalGlobalObject, protocolsValue); +                RETURN_IF_EXCEPTION(throwScope, encodedJSValue()); +            } +        } else if (JSValue protocolValue = options->getIfPropertyExists(globalObject, PropertyName(Identifier::fromString(vm, "protocol"_s)))) { +            if (!protocolValue.isUndefinedOrNull()) { +                protocols = Vector<String> { convert<IDLDOMString>(*lexicalGlobalObject, protocolValue) }; +                RETURN_IF_EXCEPTION(throwScope, encodedJSValue()); +            } +        } +    } + +    RETURN_IF_EXCEPTION(throwScope, encodedJSValue()); +    auto object = WebSocket::create(*context, WTFMove(url), protocols, WTFMove(headersInit)); +    if constexpr (IsExceptionOr<decltype(object)>) +        RETURN_IF_EXCEPTION(throwScope, {}); +    static_assert(TypeOrExceptionOrUnderlyingType<decltype(object)>::isRef); +    auto jsValue = toJSNewlyCreated<IDLInterface<WebSocket>>(*lexicalGlobalObject, *globalObject, throwScope, WTFMove(object)); +    if constexpr (IsExceptionOr<decltype(object)>) +        RETURN_IF_EXCEPTION(throwScope, {}); +    setSubclassStructureIfNeeded<WebSocket>(lexicalGlobalObject, callFrame, asObject(jsValue)); +    RETURN_IF_EXCEPTION(throwScope, {}); +    return JSValue::encode(jsValue); +} +  template<> EncodedJSValue JSC_HOST_CALL_ATTRIBUTES JSWebSocketDOMConstructor::construct(JSGlobalObject* lexicalGlobalObject, CallFrame* callFrame)  {      VM& vm = lexicalGlobalObject->vm(); @@ -204,7 +261,12 @@ template<> EncodedJSValue JSC_HOST_CALL_ATTRIBUTES JSWebSocketDOMConstructor::co              if (success)                  RELEASE_AND_RETURN(throwScope, (constructJSWebSocket1(lexicalGlobalObject, callFrame)));          } -        RELEASE_AND_RETURN(throwScope, (constructJSWebSocket2(lexicalGlobalObject, callFrame))); + +        if (distinguishingArg.isString()) { +            RELEASE_AND_RETURN(throwScope, (constructJSWebSocket2(lexicalGlobalObject, callFrame))); +        } else if (distinguishingArg.isObject()) { +            RELEASE_AND_RETURN(throwScope, (constructJSWebSocket3(lexicalGlobalObject, callFrame, callFrame->uncheckedArgument(0), distinguishingArg))); +        }      }      return argsCount < 1 ? throwVMError(lexicalGlobalObject, throwScope, createNotEnoughArgumentsError(lexicalGlobalObject)) : throwVMTypeError(lexicalGlobalObject, throwScope);  } diff --git a/src/bun.js/bindings/webcore/WebSocket.cpp b/src/bun.js/bindings/webcore/WebSocket.cpp index 2b685ef95..015b706d9 100644 --- a/src/bun.js/bindings/webcore/WebSocket.cpp +++ b/src/bun.js/bindings/webcore/WebSocket.cpp @@ -197,18 +197,23 @@ WebSocket::~WebSocket()  ExceptionOr<Ref<WebSocket>> WebSocket::create(ScriptExecutionContext& context, const String& url)  { -    return create(context, url, Vector<String> {}); +    return create(context, url, Vector<String> {}, std::nullopt);  }  ExceptionOr<Ref<WebSocket>> WebSocket::create(ScriptExecutionContext& context, const String& url, const Vector<String>& protocols)  { +    return create(context, url, protocols, std::nullopt); +} + +ExceptionOr<Ref<WebSocket>> WebSocket::create(ScriptExecutionContext& context, const String& url, const Vector<String>& protocols, std::optional<FetchHeaders::Init>&& headers) +{      if (url.isNull())          return Exception { SyntaxError };      auto socket = adoptRef(*new WebSocket(context));      // socket->suspendIfNeeded(); -    auto result = socket->connect(url, protocols); +    auto result = socket->connect(url, protocols, WTFMove(headers));      // auto result = socket->connect(url, protocols);      if (result.hasException()) @@ -224,12 +229,12 @@ ExceptionOr<Ref<WebSocket>> WebSocket::create(ScriptExecutionContext& context, c  ExceptionOr<void> WebSocket::connect(const String& url)  { -    return connect(url, Vector<String> {}); +    return connect(url, Vector<String> {}, std::nullopt);  }  ExceptionOr<void> WebSocket::connect(const String& url, const String& protocol)  { -    return connect(url, Vector<String> { 1, protocol }); +    return connect(url, Vector<String> { 1, protocol }, std::nullopt);  }  void WebSocket::failAsynchronously() @@ -267,6 +272,11 @@ static String hostName(const URL& url, bool secure)  ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& protocols)  { +    return connect(url, protocols, std::nullopt); +} + +ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& protocols, std::optional<FetchHeaders::Init>&& headersInit) +{      LOG(Network, "WebSocket %p connect() url='%s'", this, url.utf8().data());      m_url = URL { url }; @@ -280,9 +290,9 @@ ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& pr          return Exception { SyntaxError, makeString("Invalid url for WebSocket "_s, m_url.stringCenterEllipsizedToLength()) };      } -    bool is_secure = m_url.protocolIs("wss"_s); +    bool is_secure = m_url.protocolIs("wss"_s) || m_url.protocolIs("https"_s); -    if (!m_url.protocolIs("ws"_s) && !is_secure) { +    if (!m_url.protocolIs("http"_s) && !m_url.protocolIs("ws"_s) && !is_secure) {          // context.addConsoleMessage(MessageSource::JS, MessageLevel::Error, );          m_state = CLOSED;          updateHasPendingActivity(); @@ -371,19 +381,41 @@ ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& pr          port = userPort.value();      } +    Vector<ZigString, 8> headerNames; +    Vector<ZigString, 8> headerValues; + +    auto headersOrException = FetchHeaders::create(WTFMove(headersInit)); +    if (UNLIKELY(headersOrException.hasException())) { +        m_state = CLOSED; +        updateHasPendingActivity(); +        return headersOrException.releaseException(); +    } + +    auto headers = headersOrException.releaseReturnValue(); +    headerNames.reserveInitialCapacity(headers.get().internalHeaders().size()); +    headerValues.reserveInitialCapacity(headers.get().internalHeaders().size()); +    auto iterator = headers.get().createIterator(); +    while (auto value = iterator.next()) { +        headerNames.uncheckedAppend(Zig::toZigString(value->key)); +        headerValues.uncheckedAppend(Zig::toZigString(value->value)); +    } +      m_isSecure = is_secure;      this->incPendingActivityCount();      if (is_secure) {          us_socket_context_t* ctx = scriptExecutionContext()->webSocketContext<true>();          RELEASE_ASSERT(ctx); -        this->m_upgradeClient = Bun__WebSocketHTTPSClient__connect(scriptExecutionContext()->jsGlobalObject(), ctx, this, &host, port, &path, &clientProtocolString); +        this->m_upgradeClient = Bun__WebSocketHTTPSClient__connect(scriptExecutionContext()->jsGlobalObject(), ctx, this, &host, port, &path, &clientProtocolString, headerNames.data(), headerValues.data(), headerNames.size());      } else {          us_socket_context_t* ctx = scriptExecutionContext()->webSocketContext<false>();          RELEASE_ASSERT(ctx); -        this->m_upgradeClient = Bun__WebSocketHTTPClient__connect(scriptExecutionContext()->jsGlobalObject(), ctx, this, &host, port, &path, &clientProtocolString); +        this->m_upgradeClient = Bun__WebSocketHTTPClient__connect(scriptExecutionContext()->jsGlobalObject(), ctx, this, &host, port, &path, &clientProtocolString, headerNames.data(), headerValues.data(), headerNames.size());      } +    headerValues.clear(); +    headerNames.clear(); +      if (this->m_upgradeClient == nullptr) {          // context.addConsoleMessage(MessageSource::JS, MessageLevel::Error, );          m_state = CLOSED; diff --git a/src/bun.js/bindings/webcore/WebSocket.h b/src/bun.js/bindings/webcore/WebSocket.h index 82de58333..59b551ea0 100644 --- a/src/bun.js/bindings/webcore/WebSocket.h +++ b/src/bun.js/bindings/webcore/WebSocket.h @@ -36,6 +36,7 @@  #include <wtf/URL.h>  #include <wtf/HashSet.h>  #include <wtf/Lock.h> +#include "FetchHeaders.h"  namespace uWS {  template<bool, bool, typename> @@ -59,6 +60,7 @@ public:      static ExceptionOr<Ref<WebSocket>> create(ScriptExecutionContext&, const String& url);      static ExceptionOr<Ref<WebSocket>> create(ScriptExecutionContext&, const String& url, const String& protocol);      static ExceptionOr<Ref<WebSocket>> create(ScriptExecutionContext&, const String& url, const Vector<String>& protocols); +    static ExceptionOr<Ref<WebSocket>> create(ScriptExecutionContext&, const String& url, const Vector<String>& protocols, std::optional<FetchHeaders::Init>&&);      ~WebSocket();      enum State { @@ -71,6 +73,7 @@ public:      ExceptionOr<void> connect(const String& url);      ExceptionOr<void> connect(const String& url, const String& protocol);      ExceptionOr<void> connect(const String& url, const Vector<String>& protocols); +    ExceptionOr<void> connect(const String& url, const Vector<String>& protocols, std::optional<FetchHeaders::Init>&&);      ExceptionOr<void> send(const String& message);      ExceptionOr<void> send(JSC::ArrayBuffer&); diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig index c908b4bff..801e11007 100644 --- a/src/http/websocket_http_client.zig +++ b/src/http/websocket_http_client.zig @@ -24,14 +24,48 @@ const Opcode = @import("./websocket.zig").Opcode;  const log = Output.scoped(.WebSocketClient, false); -fn buildRequestBody(vm: *JSC.VirtualMachine, pathname: *const JSC.ZigString, host: *const JSC.ZigString, client_protocol: *const JSC.ZigString, client_protocol_hash: *u64) std.mem.Allocator.Error![]u8 { +const NonUTF8Headers = struct { +    names: []const JSC.ZigString, +    values: []const JSC.ZigString, + +    pub fn format(self: NonUTF8Headers, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { +        const count = self.names.len; +        var i: usize = 0; +        while (i < count) : (i += 1) { +            try std.fmt.format(writer, "{any}: {any}\r\n", .{ self.names[i], self.values[i] }); +        } +    } + +    pub fn init(names: ?[*]const JSC.ZigString, values: ?[*]const JSC.ZigString, len: usize) NonUTF8Headers { +        if (len == 0) { +            return .{ +                .names = &[_]JSC.ZigString{}, +                .values = &[_]JSC.ZigString{}, +            }; +        } + +        return .{ +            .names = names.?[0..len], +            .values = values.?[0..len], +        }; +    } +}; + +fn buildRequestBody( +    vm: *JSC.VirtualMachine, +    pathname: *const JSC.ZigString, +    host: *const JSC.ZigString, +    client_protocol: *const JSC.ZigString, +    client_protocol_hash: *u64, +    extra_headers: NonUTF8Headers, +) std.mem.Allocator.Error![]u8 {      const allocator = vm.allocator;      const input_rand_buf = vm.rareData().nextUUID();      const temp_buf_size = comptime std.base64.standard.Encoder.calcSize(16);      var encoded_buf: [temp_buf_size]u8 = undefined;      const accept_key = std.base64.standard.Encoder.encode(&encoded_buf, &input_rand_buf); -    var headers = [_]PicoHTTP.Header{ +    var static_headers = [_]PicoHTTP.Header{          .{              .name = "Sec-WebSocket-Key",              .value = accept_key, @@ -43,9 +77,10 @@ fn buildRequestBody(vm: *JSC.VirtualMachine, pathname: *const JSC.ZigString, hos      };      if (client_protocol.len > 0) -        client_protocol_hash.* = std.hash.Wyhash.hash(0, headers[1].value); +        client_protocol_hash.* = std.hash.Wyhash.hash(0, static_headers[1].value); + +    const headers_ = static_headers[0 .. 1 + @as(usize, @boolToInt(client_protocol.len > 0))]; -    var headers_: []PicoHTTP.Header = headers[0 .. 1 + @as(usize, @boolToInt(client_protocol.len > 0))];      const pathname_ = pathname.slice();      const host_ = host.slice();      const pico_headers = PicoHTTP.Headers{ .headers = headers_ }; @@ -59,12 +94,9 @@ fn buildRequestBody(vm: *JSC.VirtualMachine, pathname: *const JSC.ZigString, hos              "Upgrade: websocket\r\n" ++              "Sec-WebSocket-Version: 13\r\n" ++              "{any}" ++ +            "{any}" ++              "\r\n", -        .{ -            pathname_, -            host_, -            pico_headers, -        }, +        .{ pathname_, host_, pico_headers, extra_headers },      );  } @@ -174,11 +206,21 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {              port: u16,              pathname: *const JSC.ZigString,              client_protocol: *const JSC.ZigString, +            header_names: ?[*]const JSC.ZigString, +            header_values: ?[*]const JSC.ZigString, +            header_count: usize,          ) callconv(.C) ?*HTTPClient {              std.debug.assert(global.bunVM().uws_event_loop != null);              var client_protocol_hash: u64 = 0; -            var body = buildRequestBody(global.bunVM(), pathname, host, client_protocol, &client_protocol_hash) catch return null; +            var body = buildRequestBody( +                global.bunVM(), +                pathname, +                host, +                client_protocol, +                &client_protocol_hash, +                NonUTF8Headers.init(header_names, header_values, header_count), +            ) catch return null;              var client: HTTPClient = HTTPClient{                  .tcp = undefined,                  .outgoing_websocket = websocket, | 
