diff options
Diffstat (limited to 'src/http')
-rw-r--r-- | src/http/method.zig | 88 | ||||
-rw-r--r-- | src/http/websocket.zig | 8 | ||||
-rw-r--r-- | src/http/websocket_http_client.zig | 206 |
3 files changed, 216 insertions, 86 deletions
diff --git a/src/http/method.zig b/src/http/method.zig index 4a3d45133..d2668f1b7 100644 --- a/src/http/method.zig +++ b/src/http/method.zig @@ -11,15 +11,42 @@ const C = bun.C; const std = @import("std"); pub const Method = enum { + ACL, + BIND, + CHECKOUT, + CONNECT, + COPY, + DELETE, GET, HEAD, + LINK, + LOCK, + @"M-SEARCH", + MERGE, + MKACTIVITY, + MKCALENDAR, + MKCOL, + MOVE, + NOTIFY, + OPTIONS, PATCH, - PUT, POST, - OPTIONS, - CONNECT, + PROPFIND, + PROPPATCH, + PURGE, + PUT, + /// https://httpwg.org/http-extensions/draft-ietf-httpbis-safe-method-w-body.html + QUERY, + REBIND, + REPORT, + SEARCH, + SOURCE, + SUBSCRIBE, TRACE, - DELETE, + UNBIND, + UNLINK, + UNLOCK, + UNSUBSCRIBE, const with_body: std.enums.EnumSet(Method) = brk: { var values = std.enums.EnumSet(Method).initFull(); @@ -47,24 +74,77 @@ pub const Method = enum { } const Map = bun.ComptimeStringMap(Method, .{ + .{ "ACL", Method.ACL }, + .{ "BIND", Method.BIND }, + .{ "CHECKOUT", Method.CHECKOUT }, .{ "CONNECT", Method.CONNECT }, + .{ "COPY", Method.COPY }, .{ "DELETE", Method.DELETE }, .{ "GET", Method.GET }, .{ "HEAD", Method.HEAD }, + .{ "LINK", Method.LINK }, + .{ "LOCK", Method.LOCK }, + .{ "M-SEARCH", Method.@"M-SEARCH" }, + .{ "MERGE", Method.MERGE }, + .{ "MKACTIVITY", Method.MKACTIVITY }, + .{ "MKCALENDAR", Method.MKCALENDAR }, + .{ "MKCOL", Method.MKCOL }, + .{ "MOVE", Method.MOVE }, + .{ "NOTIFY", Method.NOTIFY }, .{ "OPTIONS", Method.OPTIONS }, .{ "PATCH", Method.PATCH }, .{ "POST", Method.POST }, + .{ "PROPFIND", Method.PROPFIND }, + .{ "PROPPATCH", Method.PROPPATCH }, + .{ "PURGE", Method.PURGE }, .{ "PUT", Method.PUT }, + .{ "QUERY", Method.QUERY }, + .{ "REBIND", Method.REBIND }, + .{ "REPORT", Method.REPORT }, + .{ "SEARCH", Method.SEARCH }, + .{ "SOURCE", Method.SOURCE }, + .{ "SUBSCRIBE", Method.SUBSCRIBE }, .{ "TRACE", Method.TRACE }, + .{ "UNBIND", Method.UNBIND }, + .{ "UNLINK", Method.UNLINK }, + .{ "UNLOCK", Method.UNLOCK }, + .{ "UNSUBSCRIBE", Method.UNSUBSCRIBE }, + + .{ "acl", Method.ACL }, + .{ "bind", Method.BIND }, + .{ "checkout", Method.CHECKOUT }, .{ "connect", Method.CONNECT }, + .{ "copy", Method.COPY }, .{ "delete", Method.DELETE }, .{ "get", Method.GET }, .{ "head", Method.HEAD }, + .{ "link", Method.LINK }, + .{ "lock", Method.LOCK }, + .{ "m-search", Method.@"M-SEARCH" }, + .{ "merge", Method.MERGE }, + .{ "mkactivity", Method.MKACTIVITY }, + .{ "mkcalendar", Method.MKCALENDAR }, + .{ "mkcol", Method.MKCOL }, + .{ "move", Method.MOVE }, + .{ "notify", Method.NOTIFY }, .{ "options", Method.OPTIONS }, .{ "patch", Method.PATCH }, .{ "post", Method.POST }, + .{ "propfind", Method.PROPFIND }, + .{ "proppatch", Method.PROPPATCH }, + .{ "purge", Method.PURGE }, .{ "put", Method.PUT }, + .{ "query", Method.QUERY }, + .{ "rebind", Method.REBIND }, + .{ "report", Method.REPORT }, + .{ "search", Method.SEARCH }, + .{ "source", Method.SOURCE }, + .{ "subscribe", Method.SUBSCRIBE }, .{ "trace", Method.TRACE }, + .{ "unbind", Method.UNBIND }, + .{ "unlink", Method.UNLINK }, + .{ "unlock", Method.UNLOCK }, + .{ "unsubscribe", Method.UNSUBSCRIBE }, }); pub fn which(str: []const u8) ?Method { diff --git a/src/http/websocket.zig b/src/http/websocket.zig index 98410d57c..48a4cebf5 100644 --- a/src/http/websocket.zig +++ b/src/http/websocket.zig @@ -34,7 +34,7 @@ pub const Opcode = enum(u4) { ResF = 0xF, pub fn isControl(opcode: Opcode) bool { - return @enumToInt(opcode) & 0x8 != 0; + return @intFromEnum(opcode) & 0x8 != 0; } }; @@ -261,7 +261,7 @@ pub const Websocket = struct { } pub fn read(self: *Websocket) !WebsocketDataFrame { - @memset(&self.buf, 0, self.buf.len); + @memset(&self.buf, 0); // Read and retry if we hit the end of the stream buffer var start = try self.stream.read(&self.buf); @@ -274,7 +274,7 @@ pub const Websocket = struct { } pub fn eatAt(self: *Websocket, offset: usize, _len: usize) []u8 { - const len = std.math.min(self.read_stream.buffer.len, _len); + const len = @min(self.read_stream.buffer.len, _len); self.read_stream.pos = len; return self.read_stream.buffer[offset..len]; } @@ -292,7 +292,7 @@ pub const Websocket = struct { // header.rsv1 = header_bytes[0] & 0x40 == 0x40; // header.rsv2 = header_bytes[0] & 0x20; // header.rsv3 = header_bytes[0] & 0x10; - header.opcode = @intToEnum(Opcode, @truncate(u4, header_bytes[0])); + header.opcode = @enumFromInt(Opcode, @truncate(u4, header_bytes[0])); header.mask = header_bytes[1] & 0x80 == 0x80; header.len = @truncate(u7, header_bytes[1]); diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig index e1bd42984..ae8e40763 100644 --- a/src/http/websocket_http_client.zig +++ b/src/http/websocket_http_client.zig @@ -60,7 +60,7 @@ fn buildRequestBody( extra_headers: NonUTF8Headers, ) std.mem.Allocator.Error![]u8 { const allocator = vm.allocator; - const input_rand_buf = vm.rareData().nextUUID(); + const input_rand_buf = vm.rareData().nextUUID().bytes; const temp_buf_size = comptime std.base64.standard.Encoder.calcSize(16); var encoded_buf: [temp_buf_size]u8 = undefined; const accept_key = std.base64.standard.Encoder.encode(&encoded_buf, &input_rand_buf); @@ -77,9 +77,9 @@ fn buildRequestBody( }; if (client_protocol.len > 0) - client_protocol_hash.* = std.hash.Wyhash.hash(0, static_headers[1].value); + client_protocol_hash.* = bun.hash(static_headers[1].value); - const headers_ = static_headers[0 .. 1 + @as(usize, @boolToInt(client_protocol.len > 0))]; + const headers_ = static_headers[0 .. 1 + @as(usize, @intFromBool(client_protocol.len > 0))]; const pathname_ = pathname.slice(); const host_ = host.slice(); @@ -145,6 +145,17 @@ const CppWebSocket = opaque { pub const didCloseWithErrorCode = WebSocket__didCloseWithErrorCode; pub const didReceiveText = WebSocket__didReceiveText; pub const didReceiveBytes = WebSocket__didReceiveBytes; + extern fn WebSocket__incrementPendingActivity(websocket_context: *CppWebSocket) void; + extern fn WebSocket__decrementPendingActivity(websocket_context: *CppWebSocket) void; + pub fn ref(this: *CppWebSocket) void { + JSC.markBinding(@src()); + WebSocket__incrementPendingActivity(this); + } + + pub fn unref(this: *CppWebSocket) void { + JSC.markBinding(@src()); + WebSocket__decrementPendingActivity(this); + } }; const body_buf_len = 16384 - 16; @@ -163,8 +174,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { to_send: []const u8 = "", read_length: usize = 0, headers_buf: [128]PicoHTTP.Header = undefined, - body_buf: ?*BodyBuf = null, - body_written: usize = 0, + body: std.ArrayListUnmanaged(u8) = .{}, websocket_protocol: u64 = 0, hostname: [:0]const u8 = "", poll_ref: JSC.PollRef = .{}, @@ -280,16 +290,13 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { this.poll_ref.unrefOnNextTick(JSC.VirtualMachine.get()); this.clearInput(); - if (this.body_buf) |buf| { - this.body_buf = null; - buf.release(); - } + this.body.clearAndFree(bun.default_allocator); } pub fn cancel(this: *HTTPClient) callconv(.C) void { this.clearData(); if (!this.tcp.isEstablished()) { - _ = uws.us_socket_close_connecting(comptime @as(c_int, @boolToInt(ssl)), this.tcp.socket); + _ = uws.us_socket_close_connecting(comptime @as(c_int, @intFromBool(ssl)), this.tcp.socket); } else { this.tcp.close(0, null); } @@ -355,14 +362,6 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { this.to_send = this.input_body_buf[@intCast(usize, wrote)..]; } - fn getBody(this: *HTTPClient) *BodyBufBytes { - if (this.body_buf == null) { - this.body_buf = BodyBufPool.get(bun.default_allocator); - } - - return &this.body_buf.?.data; - } - pub fn handleData(this: *HTTPClient, socket: Socket, data: []const u8) void { log("onData", .{}); std.debug.assert(socket.socket == this.tcp.socket); @@ -374,43 +373,37 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { if (comptime Environment.allow_assert) std.debug.assert(!socket.isShutdown()); - var body = this.getBody(); - var remain = body[this.body_written..]; - const is_first = this.body_written == 0; + var body = data; + if (this.body.items.len > 0) { + this.body.appendSlice(bun.default_allocator, data) catch @panic("out of memory"); + body = this.body.items; + } + + const is_first = this.body.items.len == 0; if (is_first) { // fail early if we receive a non-101 status code - if (!strings.hasPrefixComptime(data, "HTTP/1.1 101 ")) { + if (!strings.hasPrefixComptime(body, "HTTP/1.1 101 ")) { this.terminate(ErrorCode.expected_101_status_code); return; } } - const to_write = remain[0..@min(remain.len, data.len)]; - if (data.len > 0 and to_write.len > 0) { - @memcpy(remain.ptr, data.ptr, to_write.len); - this.body_written += to_write.len; - } - - const overflow = data[to_write.len..]; - - const available_to_read = body[0..this.body_written]; - const response = PicoHTTP.Response.parse(available_to_read, &this.headers_buf) catch |err| { + const response = PicoHTTP.Response.parse(body, &this.headers_buf) catch |err| { switch (err) { error.Malformed_HTTP_Response => { this.terminate(ErrorCode.invalid_response); return; }, error.ShortRead => { - if (overflow.len > 0) { - this.terminate(ErrorCode.headers_too_large); - return; + if (this.body.items.len == 0) { + this.body.appendSlice(bun.default_allocator, data) catch @panic("out of memory"); } return; }, } }; - this.processResponse(response, available_to_read[@intCast(usize, response.bytes_read)..]); + this.processResponse(response, body[@intCast(usize, response.bytes_read)..]); } pub fn handleEnd(this: *HTTPClient, socket: Socket) void { @@ -420,8 +413,6 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { } pub fn processResponse(this: *HTTPClient, response: PicoHTTP.Response, remain_buf: []const u8) void { - std.debug.assert(this.body_written > 0); - var upgrade_header = PicoHTTP.Header{ .name = "", .value = "" }; var connection_header = PicoHTTP.Header{ .name = "", .value = "" }; var websocket_accept_header = PicoHTTP.Header{ .name = "", .value = "" }; @@ -465,7 +456,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { }, "Sec-WebSocket-Protocol".len => { if (strings.eqlCaseInsensitiveASCII(header.name, "Sec-WebSocket-Protocol", false)) { - if (this.websocket_protocol == 0 or std.hash.Wyhash.hash(0, header.value) != this.websocket_protocol) { + if (this.websocket_protocol == 0 or bun.hash(header.value) != this.websocket_protocol) { this.terminate(ErrorCode.mismatch_client_protocol); return; } @@ -524,7 +515,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { this.terminate(ErrorCode.invalid_response); return; }; - if (remain_buf.len > 0) @memcpy(overflow.ptr, remain_buf.ptr, remain_buf.len); + @memcpy(overflow, remain_buf); } this.clearData(); @@ -757,7 +748,7 @@ const Copy = union(enum) { return WebsocketHeader.frameSizeIncludingMask(byte_len.*); }, .latin1 => { - byte_len.* = this.latin1.len; + byte_len.* = strings.elementLengthLatin1IntoUTF8([]const u8, this.latin1); return WebsocketHeader.frameSizeIncludingMask(byte_len.*); }, .bytes => { @@ -775,7 +766,7 @@ const Copy = union(enum) { if (this == .raw) { std.debug.assert(buf.len >= this.raw.len); std.debug.assert(buf.ptr != this.raw.ptr); - @memcpy(buf.ptr, this.raw.ptr, this.raw.len); + @memcpy(buf[0..this.raw.len], this.raw); return; } @@ -821,7 +812,10 @@ const Copy = union(enum) { .latin1 => |latin1| { const encode_into_result = strings.copyLatin1IntoUTF8(to_mask, []const u8, latin1); std.debug.assert(@as(usize, encode_into_result.written) == content_byte_len); + + // latin1 can contain non-ascii std.debug.assert(@as(usize, encode_into_result.read) == latin1.len); + header.len = WebsocketHeader.packLength(encode_into_result.written); header.opcode = Opcode.Text; var fib = std.io.fixedBufferStream(buf); @@ -863,6 +857,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { globalThis: *JSC.JSGlobalObject, poll_ref: JSC.PollRef = JSC.PollRef.init(), + initial_data_handler: ?*InitialDataHandler = null, + pub const name = if (ssl) "WebSocketClientTLS" else "WebSocketClient"; pub const shim = JSC.Shimmer("Bun", name, @This()); @@ -914,7 +910,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { return; if (!this.tcp.isEstablished()) { - _ = uws.us_socket_close_connecting(comptime @as(c_int, @boolToInt(ssl)), this.tcp.socket); + _ = uws.us_socket_close_connecting(comptime @as(c_int, @intFromBool(ssl)), this.tcp.socket); } else { this.tcp.close(0, null); } @@ -924,6 +920,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { JSC.markBinding(@src()); if (this.outgoing_websocket) |ws| { this.outgoing_websocket = null; + log("fail ({s})", .{@tagName(code)}); + ws.didCloseWithErrorCode(code); } @@ -934,7 +932,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { _ = socket; _ = ssl_error; JSC.markBinding(@src()); - log("WebSocket.onHandshake({d})", .{success}); + log("onHandshake({d})", .{success}); JSC.markBinding(@src()); if (success == 0) { if (this.outgoing_websocket) |ws| { @@ -1027,7 +1025,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { std.debug.assert(data_.len > 0); var writable = this.receive_buffer.writableWithSize(data_.len) catch unreachable; - @memcpy(writable.ptr, data_.ptr, data_.len); + @memcpy(writable[0..data_.len], data_); this.receive_buffer.update(data_.len); if (left_in_fragment >= data_.len and left_in_fragment - data_.len - this.receive_pending_chunk_len == 0) { @@ -1041,6 +1039,24 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { } pub fn handleData(this: *WebSocket, socket: Socket, data_: []const u8) void { + // Due to scheduling, it is possible for the websocket onData + // handler to run with additional data before the microtask queue is + // drained. + if (this.initial_data_handler) |initial_handler| { + // This calls `handleData` + // We deliberately do not set this.initial_data_handler to null here, that's done in handleWithoutDeinit. + // We do not free the memory here since the lifetime is managed by the microtask queue (it should free when called from there) + initial_handler.handleWithoutDeinit(); + + // handleWithoutDeinit is supposed to clear the handler from WebSocket* + // to prevent an infinite loop + std.debug.assert(this.initial_data_handler == null); + + // If we disconnected for any reason in the re-entrant case, we should just ignore the data + if (this.outgoing_websocket == null or this.tcp.isShutdown() or this.tcp.isClosed()) + return; + } + var data = data_; var receive_state = this.receive_state; var terminated = false; @@ -1138,6 +1154,30 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { terminated = true; break; } + + // Handle when the payload length is 0, but it is a message + // + // This should become + // + // - ArrayBuffer(0) + // - "" + // - Buffer(0) (etc) + // + if (receive_body_remain == 0 and receive_state == .need_body and is_final) { + _ = this.consume( + "", + receive_body_remain, + last_receive_data_type, + is_final, + ); + + // Return to the header state to read the next frame + receive_state = .need_header; + is_fragmented = false; + + // Bail out if there's nothing left to read + if (data.len == 0) break; + } }, .need_mask => { this.terminate(.unexpected_mask_from_server); @@ -1177,10 +1217,10 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { .ping => { const ping_len = @min(data.len, @min(receive_body_remain, 125)); - this.ping_len = @truncate(u8, ping_len); + this.ping_len = ping_len; if (ping_len > 0) { - @memcpy(this.ping_frame_bytes[6..], data.ptr, ping_len); + @memcpy(this.ping_frame_bytes[6..][0..ping_len], data[0..ping_len]); data = data[ping_len..]; } @@ -1198,6 +1238,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { if (data.len == 0) break; }, .need_body => { + // Empty messages are valid, but we handle that earlier in the flow. if (receive_body_remain == 0 and data.len > 0) { this.terminate(ErrorCode.expected_control_frame); terminated = true; @@ -1379,7 +1420,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { std.mem.writeIntSliceBig(u16, final_body_bytes[6..8], code); if (body) |data| { - if (body_len > 0) @memcpy(final_body_bytes[8..], data, body_len); + if (body_len > 0) @memcpy(final_body_bytes[8..][0..body_len], data[0..body_len]); } // we must mask the code @@ -1431,9 +1472,6 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { return; } - if (len == 0) - return; - const slice = ptr[0..len]; const bytes = Copy{ .bytes = slice }; // fast path: small frame, no backpressure, attempt to send without allocating @@ -1457,9 +1495,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { return; } - if (str.len == 0) { - return; - } + // Note: 0 is valid { var inline_buf: [stack_frame_size]u8 = undefined; @@ -1467,9 +1503,10 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { // fast path: small frame, no backpressure, attempt to send without allocating if (!str.is16Bit() and str.len < stack_frame_size) { const bytes = Copy{ .latin1 = str.slice() }; - const frame_size = WebsocketHeader.frameSizeIncludingMask(str.len); + var byte_len: usize = 0; + const frame_size = bytes.len(&byte_len); if (!this.hasBackpressure() and frame_size < stack_frame_size) { - bytes.copy(this.globalThis, inline_buf[0..frame_size], str.len); + bytes.copy(this.globalThis, inline_buf[0..frame_size], byte_len); _ = this.enqueueEncodedBytes(this.tcp, inline_buf[0..frame_size]); return; } @@ -1521,6 +1558,33 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { this.sendCloseWithBody(this.tcp, code, null, 0); } + const InitialDataHandler = struct { + adopted: ?*WebSocket, + ws: *CppWebSocket, + slice: []u8, + + pub const Handle = JSC.AnyTask.New(@This(), handle); + + pub fn handleWithoutDeinit(this: *@This()) void { + var this_socket = this.adopted orelse return; + this.adopted = null; + this_socket.initial_data_handler = null; + var ws = this.ws; + defer ws.unref(); + + if (this_socket.outgoing_websocket != null) + this_socket.handleData(this_socket.tcp, this.slice); + } + + pub fn handle(this: *@This()) void { + defer { + bun.default_allocator.free(this.slice); + bun.default_allocator.destroy(this); + } + this.handleWithoutDeinit(); + } + }; + pub fn init( outgoing: *CppWebSocket, input_socket: *anyopaque, @@ -1550,33 +1614,19 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { var buffered_slice: []u8 = buffered_data[0..buffered_data_len]; if (buffered_slice.len > 0) { - const InitialDataHandler = struct { - adopted: *WebSocket, - slice: []u8, - task: JSC.AnyTask = undefined, - - pub const Handle = JSC.AnyTask.New(@This(), handle); - - pub fn handle(this: *@This()) void { - defer { - bun.default_allocator.free(this.slice); - bun.default_allocator.destroy(this); - } - - this.adopted.receive_buffer.ensureUnusedCapacity(this.slice.len) catch return; - var writable = this.adopted.receive_buffer.writableSlice(0); - @memcpy(writable.ptr, this.slice.ptr, this.slice.len); - - this.adopted.handleData(this.adopted.tcp, writable); - } - }; var initial_data = bun.default_allocator.create(InitialDataHandler) catch unreachable; initial_data.* = .{ .adopted = adopted, .slice = buffered_slice, + .ws = outgoing, }; - initial_data.task = InitialDataHandler.Handle.init(initial_data); - globalThis.bunVM().eventLoop().enqueueTask(JSC.Task.init(&initial_data.task)); + + // Use a higher-priority callback for the initial onData handler + globalThis.queueMicrotaskCallback(initial_data, InitialDataHandler.handle); + + // We need to ref the outgoing websocket so that it doesn't get finalized + // before the initial data handler is called + outgoing.ref(); } return @ptrCast( *anyopaque, |