diff options
author | 2023-04-11 12:46:30 -0300 | |
---|---|---|
committer | 2023-05-13 06:26:44 -0300 | |
commit | e50ce69a3b10a50cb3eb6e3876c8bfb849cb078e (patch) | |
tree | 1a5b76508b7f1cf4833feb3831281381d91e8747 | |
parent | 7f25aa9e0864e95aad72ee85d475a03aee68bfb4 (diff) | |
download | bun-ciro/ws-fetch-proper-handshake.tar.gz bun-ciro/ws-fetch-proper-handshake.tar.zst bun-ciro/ws-fetch-proper-handshake.zip |
attempt to adapt the new handshake properlyciro/ws-fetch-proper-handshake
-rw-r--r-- | src/bun.js/bindings/ScriptExecutionContext.cpp | 12 | ||||
-rw-r--r-- | src/http/websocket_http_client.zig | 19 | ||||
-rw-r--r-- | src/http_client_async.zig | 87 |
3 files changed, 81 insertions, 37 deletions
diff --git a/src/bun.js/bindings/ScriptExecutionContext.cpp b/src/bun.js/bindings/ScriptExecutionContext.cpp index 151c66495..cd22f9e1e 100644 --- a/src/bun.js/bindings/ScriptExecutionContext.cpp +++ b/src/bun.js/bindings/ScriptExecutionContext.cpp @@ -38,9 +38,9 @@ us_socket_context_t* ScriptExecutionContext::webSocketContextSSL() { if (!m_ssl_client_websockets_ctx) { us_loop_t* loop = (us_loop_t*)uws_get_loop(); - us_socket_context_options_t opts; - memset(&opts, 0, sizeof(us_socket_context_options_t)); - this->m_ssl_client_websockets_ctx = us_create_socket_context(1, loop, sizeof(size_t), opts); + us_bun_socket_context_options_t opts; + memset(&opts, 0, sizeof(us_bun_socket_context_options_t)); + this->m_ssl_client_websockets_ctx = us_create_bun_socket_context(1, loop, sizeof(size_t), opts); void** ptr = reinterpret_cast<void**>(us_socket_context_ext(1, m_ssl_client_websockets_ctx)); *ptr = this; registerHTTPContextForWebSocket<true, false>(this, m_ssl_client_websockets_ctx, loop); @@ -65,9 +65,9 @@ us_socket_context_t* ScriptExecutionContext::webSocketContextNoSSL() { if (!m_client_websockets_ctx) { us_loop_t* loop = (us_loop_t*)uws_get_loop(); - us_socket_context_options_t opts; - memset(&opts, 0, sizeof(us_socket_context_options_t)); - this->m_client_websockets_ctx = us_create_socket_context(0, loop, sizeof(size_t), opts); + us_bun_socket_context_options_t opts; + memset(&opts, 0, sizeof(us_bun_socket_context_options_t)); + this->m_client_websockets_ctx = us_create_bun_socket_context(0, loop, sizeof(size_t), opts); void** ptr = reinterpret_cast<void**>(us_socket_context_ext(0, m_client_websockets_ctx)); *ptr = this; registerHTTPContextForWebSocket<false, false>(this, m_client_websockets_ctx, loop); diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig index 89c3f70c8..626ba3d7a 100644 --- a/src/http/websocket_http_client.zig +++ b/src/http/websocket_http_client.zig @@ -128,6 +128,7 @@ const ErrorCode = enum(i32) { unsupported_control_frame, unexpected_opcode, invalid_utf8, + tls_handshake_error, }; extern fn WebSocket__didConnect( websocket_context: *anyopaque, @@ -185,6 +186,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { HTTPClient, struct { pub const onOpen = handleOpen; + pub const onHandshake = handleHandshake; pub const onClose = handleClose; pub const onData = handleData; pub const onWritable = handleWritable; @@ -304,6 +306,22 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { this.tcp.close(0, null); } + pub fn handleHandshake(this: *HTTPClient, socket: Socket, success: i32, ssl_error: uws.us_bun_verify_error_t) void { + _ = ssl_error; + log("onHandshake", .{}); + if (success == 1) { + const wrote = socket.write(this.input_body_buf, true); + if (wrote < 0) { + this.terminate(ErrorCode.failed_to_write); + return; + } + + this.to_send = this.input_body_buf[@intCast(usize, wrote)..]; + } else { + this.terminate(ErrorCode.tls_handshake_error); + } + } + pub fn handleOpen(this: *HTTPClient, socket: Socket) void { log("onOpen", .{}); std.debug.assert(socket.socket == this.tcp.socket); @@ -317,6 +335,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { bun.default_allocator.free(this.hostname); this.hostname = ""; } + return; } const wrote = socket.write(this.input_body_buf, true); diff --git a/src/http_client_async.zig b/src/http_client_async.zig index 276186fe5..50f05afde 100644 --- a/src/http_client_async.zig +++ b/src/http_client_async.zig @@ -236,9 +236,9 @@ fn NewHTTPContext(comptime ssl: bool) type { } pub fn init(this: *@This()) !void { - var opts: uws.us_socket_context_options_t = undefined; - @memset(@ptrCast([*]u8, &opts), 0, @sizeOf(uws.us_socket_context_options_t)); - this.us_socket_context = uws.us_create_socket_context(ssl_int, http_thread.loop, @sizeOf(usize), opts).?; + var opts: uws.us_bun_socket_context_options_t = undefined; + @memset(@ptrCast([*]u8, &opts), 0, @sizeOf(uws.us_bun_socket_context_options_t)); + this.us_socket_context = uws.us_create_bun_socket_context(ssl_int, http_thread.loop, @sizeOf(usize), opts).?; if (comptime ssl) { this.sslCtx().setup(); } @@ -291,7 +291,36 @@ fn NewHTTPContext(comptime ssl: bool) type { ) void { const active = ActiveSocket.from(bun.cast(**anyopaque, ptr).*); if (active.get(HTTPClient)) |client| { - return client.onOpen(comptime ssl, socket); + if (comptime ssl == false) { + return client.onOpen(comptime ssl, socket); + } else { + var native_ssl: *BoringSSL.SSL = @ptrCast(*BoringSSL.SSL, socket.getNativeHandle()); + if (!native_ssl.isInitFinished()) { + var _hostname = client.hostname orelse client.url.hostname; + if (client.http_proxy) |proxy| { + _hostname = proxy.hostname; + } + + var hostname: [:0]const u8 = ""; + var hostname_needs_free = false; + if (!strings.isIPAddress(_hostname)) { + if (_hostname.len < temp_hostname.len) { + @memcpy(&temp_hostname, _hostname.ptr, _hostname.len); + temp_hostname[_hostname.len] = 0; + hostname = temp_hostname[0.._hostname.len :0]; + } else { + hostname = bun.default_allocator.dupeZ(u8, _hostname) catch unreachable; + hostname_needs_free = true; + } + } + + defer if (hostname_needs_free) bun.default_allocator.free(hostname); + + native_ssl.configureHTTPClient(hostname); + } + // when ssl wait handshake before open + return; + } } if (active.get(PooledSocket)) |pooled| { @@ -304,6 +333,20 @@ fn NewHTTPContext(comptime ssl: bool) type { std.debug.assert(false); } } + + pub fn onHandshake(ptr: *anyopaque, socket: HTTPSocket, success: i32, _: uws.us_bun_verify_error_t) void { + log("onHandshake {d}", .{success}); + + const active = ActiveSocket.from(bun.cast(**anyopaque, ptr).*); + if (active.get(HTTPClient)) |client| { + if (success == 1) { + return client.onOpen(comptime ssl, socket); + } + // Fail with TLSHandshakeError + client.onTLSHandshakeError(comptime ssl, socket); + } + } + pub fn onClose( ptr: *anyopaque, socket: HTTPSocket, @@ -661,6 +704,7 @@ pub const HTTPThread = struct { const log = Output.scoped(.fetch, false); var temp_hostname: [8096]u8 = undefined; + pub fn onOpen( client: *HTTPClient, comptime is_ssl: bool, @@ -683,33 +727,6 @@ pub fn onOpen( return; } - if (comptime is_ssl) { - var ssl: *BoringSSL.SSL = @ptrCast(*BoringSSL.SSL, socket.getNativeHandle()); - if (!ssl.isInitFinished()) { - var _hostname = client.hostname orelse client.url.hostname; - if (client.http_proxy) |proxy| { - _hostname = proxy.hostname; - } - - var hostname: [:0]const u8 = ""; - var hostname_needs_free = false; - if (!strings.isIPAddress(_hostname)) { - if (_hostname.len < temp_hostname.len) { - @memcpy(&temp_hostname, _hostname.ptr, _hostname.len); - temp_hostname[_hostname.len] = 0; - hostname = temp_hostname[0.._hostname.len :0]; - } else { - hostname = bun.default_allocator.dupeZ(u8, _hostname) catch unreachable; - hostname_needs_free = true; - } - } - - defer if (hostname_needs_free) bun.default_allocator.free(hostname); - - ssl.configureHTTPClient(hostname); - } - } - if (client.state.request_stage == .pending) { client.onWritable(true, comptime is_ssl, socket); } @@ -769,6 +786,14 @@ pub fn onConnectError( if (client.state.stage != .done and client.state.stage != .fail) client.fail(error.ConnectionRefused); } + +pub fn onTLSHandshakeError(client: *HTTPClient, comptime is_ssl: bool, socket: NewHTTPContext(is_ssl).HTTPSocket) void { + log("onTLSHandshakeError {s}\n", .{client.url.href}); + + if (client.state.stage != .done and client.state.stage != .fail) + client.closeAndFail(error.TLSHandshakeError, is_ssl, socket); +} + pub fn onEnd( client: *HTTPClient, comptime is_ssl: bool, |