diff options
Diffstat (limited to 'src/http/websocket_http_client.zig')
-rw-r--r-- | src/http/websocket_http_client.zig | 42 |
1 files changed, 31 insertions, 11 deletions
diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig index 7b34dea45..7e5bb26ba 100644 --- a/src/http/websocket_http_client.zig +++ b/src/http/websocket_http_client.zig @@ -295,10 +295,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { } }; - var buffered_body_data = body[@minimum(@intCast(usize, response.bytes_read), body.len)..]; - buffered_body_data = buffered_body_data[0..@minimum(buffered_body_data.len, this.body_written)]; - - this.processResponse(response, buffered_body_data, overflow); + this.processResponse(response, available_to_read[@intCast(usize, response.bytes_read)..]); } pub fn handleEnd(this: *HTTPClient, socket: Socket) void { @@ -306,7 +303,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { this.terminate(ErrorCode.ended); } - pub fn processResponse(this: *HTTPClient, response: PicoHTTP.Response, remain_buf: []const u8, overflow_buf: []const u8) void { + pub fn processResponse(this: *HTTPClient, response: PicoHTTP.Response, remain_buf: []const u8) void { std.debug.assert(this.body_written > 0); var upgrade_header = PicoHTTP.Header{ .name = "", .value = "" }; @@ -316,10 +313,6 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { // var visited_version = false; std.debug.assert(response.status_code == 101); - if (remain_buf.len > 0) { - std.debug.assert(overflow_buf.len == 0); - } - for (response.headers) |header| { switch (header.name.len) { "Connection".len => { @@ -408,7 +401,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { // TODO: check websocket_accept_header.value - const overflow_len = overflow_buf.len + remain_buf.len; + const overflow_len = remain_buf.len; var overflow: []u8 = &.{}; if (overflow_len > 0) { overflow = bun.default_allocator.alloc(u8, overflow_len) catch { @@ -416,7 +409,6 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { return; }; if (remain_buf.len > 0) @memcpy(overflow.ptr, remain_buf.ptr, remain_buf.len); - if (overflow_buf.len > 0) @memcpy(overflow.ptr + remain_buf.len, overflow_buf.ptr, overflow_buf.len); } this.clearData(); @@ -1432,6 +1424,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { input_socket: *anyopaque, socket_ctx: *anyopaque, globalThis: *JSC.JSGlobalObject, + buffered_data: [*]u8, + buffered_data_len: usize, ) callconv(.C) ?*anyopaque { var tcp = @ptrCast(*uws.Socket, input_socket); var ctx = @ptrCast(*uws.us_socket_context_t, socket_ctx); @@ -1453,6 +1447,32 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { adopted.event_loop_ref = true; adopted.globalThis.bunVM().us_loop_reference_count +|= 1; _ = globalThis.bunVM().eventLoop().ready_tasks_count.fetchAdd(1, .Monotonic); + var buffered_slice: []u8 = buffered_data[0..buffered_data_len]; + if (buffered_slice.len > 0) { + const InitialDataHandler = struct { + adopted: *WebSocket, + slice: []u8, + + pub fn handle(this: *@This()) void { + defer { + bun.default_allocator.free(this.slice); + bun.default_allocator.destroy(this); + } + + this.adopted.receive_buffer.ensureUnusedCapacity(this.slice.len) catch return; + var writable = this.adopted.receive_buffer.writableSlice(0); + @memcpy(writable.ptr, this.slice.ptr, this.slice.len); + + this.adopted.handleData(this.adopted.tcp, writable); + } + }; + var initial_data = bun.default_allocator.create(InitialDataHandler) catch unreachable; + initial_data.* = .{ + .adopted = adopted, + .slice = buffered_slice, + }; + globalThis.bunVM().uws_event_loop.?.nextTick(*InitialDataHandler, initial_data, InitialDataHandler.handle); + } return @ptrCast( *anyopaque, adopted, |