diff options
| author | 2022-11-19 04:57:32 -0800 | |
|---|---|---|
| committer | 2022-11-19 04:57:32 -0800 | |
| commit | d6d04cab2415b662e1a1a9ce937fa42bfb33d823 (patch) | |
| tree | cceeaad9bd7c4065c6a09d77cebaab91ef183caf /src | |
| parent | bb95f90a62f3bbccd63288876870d1b107f510c3 (diff) | |
| download | bun-d6d04cab2415b662e1a1a9ce937fa42bfb33d823.tar.gz bun-d6d04cab2415b662e1a1a9ce937fa42bfb33d823.tar.zst bun-d6d04cab2415b662e1a1a9ce937fa42bfb33d823.zip | |
Fix GC crash with `WebSocket` uncovered thx to `BUN_GARBAGE_COLLECTOR_LEVEL`
Diffstat (limited to 'src')
| -rw-r--r-- | src/bun.js/bindings/webcore/WebSocket.cpp | 68 | ||||
| -rw-r--r-- | src/bun.js/bindings/webcore/WebSocket.h | 20 |
2 files changed, 63 insertions, 25 deletions
diff --git a/src/bun.js/bindings/webcore/WebSocket.cpp b/src/bun.js/bindings/webcore/WebSocket.cpp index 820f0e804..2b685ef95 100644 --- a/src/bun.js/bindings/webcore/WebSocket.cpp +++ b/src/bun.js/bindings/webcore/WebSocket.cpp @@ -157,11 +157,12 @@ WebSocket::WebSocket(ScriptExecutionContext& context) , m_subprotocol(emptyString()) , m_extensions(emptyString()) { + m_state = CONNECTING; + m_hasPendingActivity.store(true); } WebSocket::~WebSocket() { - if (m_upgradeClient != nullptr) { void* upgradeClient = m_upgradeClient; if (m_isSecure) { @@ -275,6 +276,7 @@ ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& pr if (!m_url.isValid()) { // context.addConsoleMessage(MessageSource::JS, MessageLevel::Error, ); m_state = CLOSED; + updateHasPendingActivity(); return Exception { SyntaxError, makeString("Invalid url for WebSocket "_s, m_url.stringCenterEllipsizedToLength()) }; } @@ -283,11 +285,13 @@ ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& pr if (!m_url.protocolIs("ws"_s) && !is_secure) { // context.addConsoleMessage(MessageSource::JS, MessageLevel::Error, ); m_state = CLOSED; + updateHasPendingActivity(); return Exception { SyntaxError, makeString("Wrong url scheme for WebSocket "_s, m_url.stringCenterEllipsizedToLength()) }; } if (m_url.hasFragmentIdentifier()) { // context.addConsoleMessage(MessageSource::JS, MessageLevel::Error, ); m_state = CLOSED; + updateHasPendingActivity(); return Exception { SyntaxError, makeString("URL has fragment component "_s, m_url.stringCenterEllipsizedToLength()) }; } @@ -326,6 +330,7 @@ ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& pr if (!isValidProtocolString(protocol)) { // context.addConsoleMessage(MessageSource::JS, MessageLevel::Error, ); m_state = CLOSED; + updateHasPendingActivity(); return Exception { SyntaxError, makeString("Wrong protocol for WebSocket '"_s, encodeProtocolString(protocol), "'"_s) }; } } @@ -334,6 +339,7 @@ ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& pr if (!visited.add(protocol).isNewEntry) { // context.addConsoleMessage(MessageSource::JS, MessageLevel::Error, ); m_state = CLOSED; + updateHasPendingActivity(); return Exception { SyntaxError, makeString("WebSocket protocols contain duplicates:"_s, encodeProtocolString(protocol), "'"_s) }; } } @@ -366,22 +372,22 @@ ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& pr } m_isSecure = is_secure; + this->incPendingActivityCount(); + if (is_secure) { us_socket_context_t* ctx = scriptExecutionContext()->webSocketContext<true>(); RELEASE_ASSERT(ctx); - this->m_pendingActivityCount++; this->m_upgradeClient = Bun__WebSocketHTTPSClient__connect(scriptExecutionContext()->jsGlobalObject(), ctx, this, &host, port, &path, &clientProtocolString); } else { us_socket_context_t* ctx = scriptExecutionContext()->webSocketContext<false>(); RELEASE_ASSERT(ctx); - this->m_pendingActivityCount++; this->m_upgradeClient = Bun__WebSocketHTTPClient__connect(scriptExecutionContext()->jsGlobalObject(), ctx, this, &host, port, &path, &clientProtocolString); } if (this->m_upgradeClient == nullptr) { // context.addConsoleMessage(MessageSource::JS, MessageLevel::Error, ); m_state = CLOSED; - this->m_pendingActivityCount--; + this->decPendingActivityCount(); return Exception { SyntaxError, "WebSocket connection failed"_s }; } @@ -399,7 +405,7 @@ ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& pr // #endif // m_pendingActivity = makePendingActivity(*this); - + updateHasPendingActivity(); return {}; } @@ -482,7 +488,6 @@ ExceptionOr<void> WebSocket::send(ArrayBufferView& arrayBufferView) void WebSocket::sendWebSocketData(const char* baseAddress, size_t length) { - switch (m_connectedWebSocketKind) { case ConnectedWebSocketKind::Client: { Bun__WebSocketClient__writeBinaryData(this->m_connectedWebSocket.client, reinterpret_cast<const unsigned char*>(baseAddress), length); @@ -512,7 +517,6 @@ void WebSocket::sendWebSocketData(const char* baseAddress, size_t length) void WebSocket::sendWebSocketString(const String& message) { - switch (m_connectedWebSocketKind) { case ConnectedWebSocketKind::Client: { auto zigStr = Zig::toZigString(message); @@ -542,11 +546,11 @@ void WebSocket::sendWebSocketString(const String& message) RELEASE_ASSERT_NOT_REACHED(); } } + updateHasPendingActivity(); } ExceptionOr<void> WebSocket::close(std::optional<unsigned short> optionalCode, const String& reason) { - int code = optionalCode ? optionalCode.value() : static_cast<int>(0); if (code == 0) LOG(Network, "WebSocket %p close() without code and reason", this); @@ -573,6 +577,7 @@ ExceptionOr<void> WebSocket::close(std::optional<unsigned short> optionalCode, c Bun__WebSocketHTTPClient__cancel(upgradeClient); } } + updateHasPendingActivity(); return {}; } m_state = CLOSING; @@ -580,12 +585,14 @@ ExceptionOr<void> WebSocket::close(std::optional<unsigned short> optionalCode, c case ConnectedWebSocketKind::Client: { ZigString reasonZigStr = Zig::toZigString(reason); Bun__WebSocketClient__close(this->m_connectedWebSocket.client, code, &reasonZigStr); + updateHasPendingActivity(); // this->m_bufferedAmount = this->m_connectedWebSocket.client->getBufferedAmount(); break; } case ConnectedWebSocketKind::ClientSSL: { ZigString reasonZigStr = Zig::toZigString(reason); Bun__WebSocketClientTLS__close(this->m_connectedWebSocket.clientSSL, code, &reasonZigStr); + updateHasPendingActivity(); // this->m_bufferedAmount = this->m_connectedWebSocket.clientSSL->getBufferedAmount(); break; } @@ -604,7 +611,7 @@ ExceptionOr<void> WebSocket::close(std::optional<unsigned short> optionalCode, c } } this->m_connectedWebSocketKind = ConnectedWebSocketKind::None; - + updateHasPendingActivity(); return {}; } @@ -715,7 +722,6 @@ ScriptExecutionContext* WebSocket::scriptExecutionContext() const void WebSocket::didConnect() { // from new WebSocket() -> connect() - this->m_pendingActivityCount--; LOG(Network, "WebSocket %p didConnect()", this); // queueTaskKeepingObjectAlive(*this, TaskSource::WebSocket, [this] { @@ -730,10 +736,12 @@ void WebSocket::didConnect() if (auto* context = scriptExecutionContext()) { if (this->hasEventListeners("open"_s)) { + this->incPendingActivityCount(); // the main reason for dispatching on a separate tick is to handle when you haven't yet attached an event listener dispatchEvent(Event::create(eventNames().openEvent, Event::CanBubble::No, Event::IsCancelable::No)); + this->decPendingActivityCount(); } else { - this->m_pendingActivityCount++; + this->incPendingActivityCount(); context->postTask([this, protectedThis = Ref { *this }](ScriptExecutionContext& context) { ASSERT(scriptExecutionContext()); @@ -741,7 +749,7 @@ void WebSocket::didConnect() // m_extensions = m_channel->extensions(); protectedThis->dispatchEvent(Event::create(eventNames().openEvent, Event::CanBubble::No, Event::IsCancelable::No)); // }); - protectedThis->m_pendingActivityCount--; + protectedThis->decPendingActivityCount(); }); } } @@ -768,11 +776,11 @@ void WebSocket::didReceiveMessage(String&& message) } if (auto* context = scriptExecutionContext()) { - this->m_pendingActivityCount++; + this->incPendingActivityCount(); context->postTask([this, message_ = WTFMove(message), protectedThis = Ref { *this }](ScriptExecutionContext& context) { ASSERT(scriptExecutionContext()); protectedThis->dispatchEvent(MessageEvent::create(message_, protectedThis->m_url.string())); - protectedThis->m_pendingActivityCount--; + protectedThis->decPendingActivityCount(); }); } @@ -805,11 +813,11 @@ void WebSocket::didReceiveBinaryData(Vector<uint8_t>&& binaryData) if (auto* context = scriptExecutionContext()) { auto arrayBuffer = JSC::ArrayBuffer::create(binaryData.data(), binaryData.size()); - this->m_pendingActivityCount++; + this->incPendingActivityCount(); context->postTask([this, buffer = WTFMove(arrayBuffer), protectedThis = Ref { *this }](ScriptExecutionContext& context) { ASSERT(scriptExecutionContext()); protectedThis->dispatchEvent(MessageEvent::create(buffer, m_url.string())); - protectedThis->m_pendingActivityCount--; + protectedThis->decPendingActivityCount(); }); } @@ -827,10 +835,10 @@ void WebSocket::didReceiveMessageError(unsigned short code, WTF::StringImpl::Sta return; m_state = CLOSED; if (auto* context = scriptExecutionContext()) { - this->m_pendingActivityCount++; + this->incPendingActivityCount(); // https://html.spec.whatwg.org/multipage/web-sockets.html#feedback-from-the-protocol:concept-websocket-closed, we should synchronously fire a close event. dispatchEvent(CloseEvent::create(code < 1002, code, WTF::String(reason))); - this->m_pendingActivityCount--; + this->decPendingActivityCount(); } } @@ -849,6 +857,7 @@ void WebSocket::didStartClosingHandshake() if (m_state == CLOSED) return; m_state = CLOSING; + updateHasPendingActivity(); // }); } @@ -878,16 +887,19 @@ void WebSocket::didClose(unsigned unhandledBufferedAmount, unsigned short code, this->m_upgradeClient = nullptr; if (this->hasEventListeners("close"_s)) { + this->incPendingActivityCount(); this->dispatchEvent(CloseEvent::create(wasClean, code, reason)); + this->decPendingActivityCount(); + return; } if (auto* context = scriptExecutionContext()) { - this->m_pendingActivityCount++; + this->incPendingActivityCount(); context->postTask([this, code, wasClean, reason, protectedThis = Ref { *this }](ScriptExecutionContext& context) { ASSERT(scriptExecutionContext()); protectedThis->dispatchEvent(CloseEvent::create(wasClean, code, reason)); - protectedThis->m_pendingActivityCount--; + protectedThis->decPendingActivityCount(); }); } @@ -909,11 +921,11 @@ void WebSocket::dispatchErrorEventIfNeeded() m_dispatchedErrorEvent = true; if (auto* context = scriptExecutionContext()) { - this->m_pendingActivityCount++; + this->incPendingActivityCount(); context->postTask([this, protectedThis = Ref { *this }](ScriptExecutionContext& context) { ASSERT(scriptExecutionContext()); protectedThis->dispatchEvent(Event::create(eventNames().errorEvent, Event::CanBubble::No, Event::IsCancelable::No)); - protectedThis->m_pendingActivityCount--; + protectedThis->decPendingActivityCount(); }); } } @@ -936,10 +948,10 @@ void WebSocket::didConnect(us_socket_t* socket, char* bufferedData, size_t buffe void WebSocket::didFailWithErrorCode(int32_t code) { // from new WebSocket() -> connect() + if (m_state == CLOSED) return; - this->m_pendingActivityCount = this->m_pendingActivityCount > 0 ? this->m_pendingActivityCount - 1 : 0; this->m_upgradeClient = nullptr; this->m_connectedWebSocketKind = ConnectedWebSocketKind::None; this->m_connectedWebSocket.client = nullptr; @@ -1109,7 +1121,17 @@ void WebSocket::didFailWithErrorCode(int32_t code) } m_state = CLOSED; + scriptExecutionContext()->postTask([this, protectedThis = Ref { *this }](ScriptExecutionContext& context) { + protectedThis->decPendingActivityCount(); + }); } +void WebSocket::updateHasPendingActivity() +{ + std::atomic_thread_fence(std::memory_order_acquire); + m_hasPendingActivity.store( + !(m_state == CLOSED && m_pendingActivityCount == 0)); +} + } // namespace WebCore extern "C" void WebSocket__didConnect(WebCore::WebSocket* webSocket, us_socket_t* socket, char* bufferedData, size_t len) diff --git a/src/bun.js/bindings/webcore/WebSocket.h b/src/bun.js/bindings/webcore/WebSocket.h index 796268d7a..82de58333 100644 --- a/src/bun.js/bindings/webcore/WebSocket.h +++ b/src/bun.js/bindings/webcore/WebSocket.h @@ -66,7 +66,6 @@ public: OPEN = 1, CLOSING = 2, CLOSED = 3, - }; ExceptionOr<void> connect(const String& url); @@ -103,9 +102,10 @@ public: void didReceiveData(const char* data, size_t length); void didReceiveBinaryData(Vector<uint8_t>&&); + void updateHasPendingActivity(); bool hasPendingActivity() const { - return m_state == State::OPEN || m_state == State::CLOSING || m_pendingActivityCount > 0; + return m_hasPendingActivity.load(); } private: @@ -119,6 +119,8 @@ private: ClientSSL, }; + std::atomic<bool> m_hasPendingActivity { true }; + explicit WebSocket(ScriptExecutionContext&); explicit WebSocket(ScriptExecutionContext&, const String& url); @@ -142,6 +144,20 @@ private: void sendWebSocketString(const String& message); void sendWebSocketData(const char* data, size_t length); + void incPendingActivityCount() + { + m_pendingActivityCount++; + ref(); + updateHasPendingActivity(); + } + + void decPendingActivityCount() + { + m_pendingActivityCount--; + deref(); + updateHasPendingActivity(); + } + void failAsynchronously(); enum class BinaryType { Blob, |
