diff options
Diffstat (limited to 'src/deps/uws.zig')
-rw-r--r-- | src/deps/uws.zig | 319 |
1 files changed, 288 insertions, 31 deletions
diff --git a/src/deps/uws.zig b/src/deps/uws.zig index 538756b71..83edbe410 100644 --- a/src/deps/uws.zig +++ b/src/deps/uws.zig @@ -23,7 +23,7 @@ fn NativeSocketHandleType(comptime ssl: bool) type { } pub fn NewSocketHandler(comptime ssl: bool) type { return struct { - const ssl_int: i32 = @boolToInt(ssl); + const ssl_int: i32 = @intFromBool(ssl); socket: *Socket, const ThisSocket = @This(); @@ -40,6 +40,120 @@ pub fn NewSocketHandler(comptime ssl: bool) type { return us_socket_timeout(comptime ssl_int, this.socket, seconds); } + pub fn startTLS(this: ThisSocket, is_client: bool) void { + _ = us_socket_open(comptime ssl_int, this.socket, @intFromBool(is_client), null, 0); + } + + // Note: this assumes that the socket is non-TLS and will be adopted and wrapped with a new TLS context + // context ext will not be copied to the new context, new context will contain us_wrapped_socket_context_t on ext + pub fn wrapTLS( + this: ThisSocket, + options: us_bun_socket_context_options_t, + socket_ext_size: i32, + comptime deref: bool, + comptime ContextType: type, + comptime Fields: anytype, + ) ?NewSocketHandler(true) { + const TLSSocket = NewSocketHandler(true); + const SocketHandler = struct { + const alignment = if (ContextType == anyopaque) + @sizeOf(usize) + else + std.meta.alignment(ContextType); + const deref_ = deref; + const ValueType = if (deref) ContextType else *ContextType; + fn getValue(socket: *Socket) ValueType { + if (comptime ContextType == anyopaque) { + return us_socket_ext(1, socket).?; + } + + if (comptime deref_) { + return (TLSSocket{ .socket = socket }).ext(ContextType).?.*; + } + + return (TLSSocket{ .socket = socket }).ext(ContextType).?; + } + + pub fn on_open(socket: *Socket, is_client: i32, _: [*c]u8, _: i32) callconv(.C) ?*Socket { + if (comptime @hasDecl(Fields, "onCreate")) { + if (is_client == 0) { + Fields.onCreate( + TLSSocket{ .socket = socket }, + ); + } + } + Fields.onOpen( + getValue(socket), + TLSSocket{ .socket = socket }, + ); + return socket; + } + pub fn on_close(socket: *Socket, code: i32, reason: ?*anyopaque) callconv(.C) ?*Socket { + Fields.onClose( + getValue(socket), + TLSSocket{ .socket = socket }, + code, + reason, + ); + return socket; + } + pub fn on_data(socket: *Socket, buf: ?[*]u8, len: i32) callconv(.C) ?*Socket { + Fields.onData( + getValue(socket), + TLSSocket{ .socket = socket }, + buf.?[0..@intCast(usize, len)], + ); + return socket; + } + pub fn on_writable(socket: *Socket) callconv(.C) ?*Socket { + Fields.onWritable( + getValue(socket), + TLSSocket{ .socket = socket }, + ); + return socket; + } + pub fn on_timeout(socket: *Socket) callconv(.C) ?*Socket { + Fields.onTimeout( + getValue(socket), + TLSSocket{ .socket = socket }, + ); + return socket; + } + pub fn on_connect_error(socket: *Socket, code: i32) callconv(.C) ?*Socket { + Fields.onConnectError( + getValue(socket), + TLSSocket{ .socket = socket }, + code, + ); + return socket; + } + pub fn on_end(socket: *Socket) callconv(.C) ?*Socket { + Fields.onEnd( + getValue(socket), + TLSSocket{ .socket = socket }, + ); + return socket; + } + pub fn on_handshake(socket: *Socket, success: i32, verify_error: us_bun_verify_error_t, _: ?*anyopaque) callconv(.C) void { + Fields.onHandshake(getValue(socket), TLSSocket{ .socket = socket }, success, verify_error); + } + }; + + var events: us_socket_events_t = .{ + .on_open = SocketHandler.on_open, + .on_close = SocketHandler.on_close, + .on_data = SocketHandler.on_data, + .on_writable = SocketHandler.on_writable, + .on_timeout = SocketHandler.on_timeout, + .on_connect_error = SocketHandler.on_connect_error, + .on_end = SocketHandler.on_end, + .on_handshake = SocketHandler.on_handshake, + }; + + const socket = us_socket_wrap_with_tls(ssl_int, this.socket, options, events, socket_ext_size) orelse return null; + return NewSocketHandler(true).from(socket); + } + pub fn getNativeHandle(this: ThisSocket) *NativeSocketHandleType(ssl) { return @ptrCast(*NativeSocketHandleType(ssl), us_socket_get_native_handle(comptime ssl_int, this.socket).?); } @@ -49,7 +163,7 @@ pub fn NewSocketHandler(comptime ssl: bool) type { @compileError("SSL sockets do not have a file descriptor accessible this way"); } - return @intCast(i32, @ptrToInt(us_socket_get_native_handle(0, this.socket))); + return @intCast(i32, @intFromPtr(us_socket_get_native_handle(0, this.socket))); } pub fn markNeedsMoreForSendfile(this: ThisSocket) void { @@ -92,18 +206,29 @@ pub fn NewSocketHandler(comptime ssl: bool) type { data.ptr, // truncate to 31 bits since sign bit exists @intCast(i32, @truncate(u31, data.len)), - @as(i32, @boolToInt(msg_more)), + @as(i32, @intFromBool(msg_more)), + ); + } + + pub fn rawWrite(this: ThisSocket, data: []const u8, msg_more: bool) i32 { + return us_socket_raw_write( + comptime ssl_int, + this.socket, + data.ptr, + // truncate to 31 bits since sign bit exists + @intCast(i32, @truncate(u31, data.len)), + @as(i32, @intFromBool(msg_more)), ); } pub fn shutdown(this: ThisSocket) void { - debug("us_socket_shutdown({d})", .{@ptrToInt(this.socket)}); + debug("us_socket_shutdown({d})", .{@intFromPtr(this.socket)}); return us_socket_shutdown( comptime ssl_int, this.socket, ); } pub fn shutdownRead(this: ThisSocket) void { - debug("us_socket_shutdown_read({d})", .{@ptrToInt(this.socket)}); + debug("us_socket_shutdown_read({d})", .{@intFromPtr(this.socket)}); return us_socket_shutdown_read( comptime ssl_int, this.socket, @@ -122,7 +247,7 @@ pub fn NewSocketHandler(comptime ssl: bool) type { ) > 0; } pub fn close(this: ThisSocket, code: i32, reason: ?*anyopaque) void { - debug("us_socket_close({d})", .{@ptrToInt(this.socket)}); + debug("us_socket_close({d})", .{@intFromPtr(this.socket)}); _ = us_socket_close( comptime ssl_int, this.socket, @@ -241,13 +366,126 @@ pub fn NewSocketHandler(comptime ssl: bool) type { return socket_; } + pub fn unsafeConfigure( + ctx: *SocketContext, + comptime ssl_type: bool, + comptime deref: bool, + comptime ContextType: type, + comptime Fields: anytype, + ) void { + const SocketHandlerType = NewSocketHandler(ssl_type); + const ssl_type_int: i32 = @intFromBool(ssl_type); + const Type = comptime if (@TypeOf(Fields) != type) @TypeOf(Fields) else Fields; + + const SocketHandler = struct { + const alignment = if (ContextType == anyopaque) + @sizeOf(usize) + else + std.meta.alignment(ContextType); + const deref_ = deref; + const ValueType = if (deref) ContextType else *ContextType; + fn getValue(socket: *Socket) ValueType { + if (comptime ContextType == anyopaque) { + return us_socket_ext(ssl_type_int, socket).?; + } + + if (comptime deref_) { + return (SocketHandlerType{ .socket = socket }).ext(ContextType).?.*; + } + + return (SocketHandlerType{ .socket = socket }).ext(ContextType).?; + } + + pub fn on_open(socket: *Socket, is_client: i32, _: [*c]u8, _: i32) callconv(.C) ?*Socket { + if (comptime @hasDecl(Fields, "onCreate")) { + if (is_client == 0) { + Fields.onCreate( + SocketHandlerType{ .socket = socket }, + ); + } + } + Fields.onOpen( + getValue(socket), + SocketHandlerType{ .socket = socket }, + ); + return socket; + } + pub fn on_close(socket: *Socket, code: i32, reason: ?*anyopaque) callconv(.C) ?*Socket { + Fields.onClose( + getValue(socket), + SocketHandlerType{ .socket = socket }, + code, + reason, + ); + return socket; + } + pub fn on_data(socket: *Socket, buf: ?[*]u8, len: i32) callconv(.C) ?*Socket { + Fields.onData( + getValue(socket), + SocketHandlerType{ .socket = socket }, + buf.?[0..@intCast(usize, len)], + ); + return socket; + } + pub fn on_writable(socket: *Socket) callconv(.C) ?*Socket { + Fields.onWritable( + getValue(socket), + SocketHandlerType{ .socket = socket }, + ); + return socket; + } + pub fn on_timeout(socket: *Socket) callconv(.C) ?*Socket { + Fields.onTimeout( + getValue(socket), + SocketHandlerType{ .socket = socket }, + ); + return socket; + } + pub fn on_connect_error(socket: *Socket, code: i32) callconv(.C) ?*Socket { + Fields.onConnectError( + getValue(socket), + SocketHandlerType{ .socket = socket }, + code, + ); + return socket; + } + pub fn on_end(socket: *Socket) callconv(.C) ?*Socket { + Fields.onEnd( + getValue(socket), + SocketHandlerType{ .socket = socket }, + ); + return socket; + } + pub fn on_handshake(socket: *Socket, success: i32, verify_error: us_bun_verify_error_t, _: ?*anyopaque) callconv(.C) void { + Fields.onHandshake(getValue(socket), SocketHandlerType{ .socket = socket }, success, verify_error); + } + }; + + if (comptime @hasDecl(Type, "onOpen") and @typeInfo(@TypeOf(Type.onOpen)) != .Null) + us_socket_context_on_open(ssl_int, ctx, SocketHandler.on_open); + if (comptime @hasDecl(Type, "onClose") and @typeInfo(@TypeOf(Type.onClose)) != .Null) + us_socket_context_on_close(ssl_int, ctx, SocketHandler.on_close); + if (comptime @hasDecl(Type, "onData") and @typeInfo(@TypeOf(Type.onData)) != .Null) + us_socket_context_on_data(ssl_int, ctx, SocketHandler.on_data); + if (comptime @hasDecl(Type, "onWritable") and @typeInfo(@TypeOf(Type.onWritable)) != .Null) + us_socket_context_on_writable(ssl_int, ctx, SocketHandler.on_writable); + if (comptime @hasDecl(Type, "onTimeout") and @typeInfo(@TypeOf(Type.onTimeout)) != .Null) + us_socket_context_on_timeout(ssl_int, ctx, SocketHandler.on_timeout); + if (comptime @hasDecl(Type, "onConnectError") and @typeInfo(@TypeOf(Type.onConnectError)) != .Null) + us_socket_context_on_connect_error(ssl_int, ctx, SocketHandler.on_connect_error); + if (comptime @hasDecl(Type, "onEnd") and @typeInfo(@TypeOf(Type.onEnd)) != .Null) + us_socket_context_on_end(ssl_int, ctx, SocketHandler.on_end); + if (comptime @hasDecl(Type, "onHandshake") and @typeInfo(@TypeOf(Type.onHandshake)) != .Null) + us_socket_context_on_handshake(ssl_int, ctx, SocketHandler.on_handshake, null); + } + pub fn configure( ctx: *SocketContext, comptime deref: bool, comptime ContextType: type, comptime Fields: anytype, ) void { - const @"type" = comptime if (@TypeOf(Fields) != type) @TypeOf(Fields) else Fields; + const Type = comptime if (@TypeOf(Fields) != type) @TypeOf(Fields) else Fields; const SocketHandler = struct { const alignment = if (ContextType == anyopaque) @@ -333,21 +571,21 @@ pub fn NewSocketHandler(comptime ssl: bool) type { } }; - if (comptime @hasDecl(@"type", "onOpen") and @typeInfo(@TypeOf(@"type".onOpen)) != .Null) + if (comptime @hasDecl(Type, "onOpen") and @typeInfo(@TypeOf(Type.onOpen)) != .Null) us_socket_context_on_open(ssl_int, ctx, SocketHandler.on_open); - if (comptime @hasDecl(@"type", "onClose") and @typeInfo(@TypeOf(@"type".onClose)) != .Null) + if (comptime @hasDecl(Type, "onClose") and @typeInfo(@TypeOf(Type.onClose)) != .Null) us_socket_context_on_close(ssl_int, ctx, SocketHandler.on_close); - if (comptime @hasDecl(@"type", "onData") and @typeInfo(@TypeOf(@"type".onData)) != .Null) + if (comptime @hasDecl(Type, "onData") and @typeInfo(@TypeOf(Type.onData)) != .Null) us_socket_context_on_data(ssl_int, ctx, SocketHandler.on_data); - if (comptime @hasDecl(@"type", "onWritable") and @typeInfo(@TypeOf(@"type".onWritable)) != .Null) + if (comptime @hasDecl(Type, "onWritable") and @typeInfo(@TypeOf(Type.onWritable)) != .Null) us_socket_context_on_writable(ssl_int, ctx, SocketHandler.on_writable); - if (comptime @hasDecl(@"type", "onTimeout") and @typeInfo(@TypeOf(@"type".onTimeout)) != .Null) + if (comptime @hasDecl(Type, "onTimeout") and @typeInfo(@TypeOf(Type.onTimeout)) != .Null) us_socket_context_on_timeout(ssl_int, ctx, SocketHandler.on_timeout); - if (comptime @hasDecl(@"type", "onConnectError") and @typeInfo(@TypeOf(@"type".onConnectError)) != .Null) + if (comptime @hasDecl(Type, "onConnectError") and @typeInfo(@TypeOf(Type.onConnectError)) != .Null) us_socket_context_on_connect_error(ssl_int, ctx, SocketHandler.on_connect_error); - if (comptime @hasDecl(@"type", "onEnd") and @typeInfo(@TypeOf(@"type".onEnd)) != .Null) + if (comptime @hasDecl(Type, "onEnd") and @typeInfo(@TypeOf(Type.onEnd)) != .Null) us_socket_context_on_end(ssl_int, ctx, SocketHandler.on_end); - if (comptime @hasDecl(@"type", "onHandshake") and @typeInfo(@TypeOf(@"type".onHandshake)) != .Null) + if (comptime @hasDecl(Type, "onHandshake") and @typeInfo(@TypeOf(Type.onHandshake)) != .Null) us_socket_context_on_handshake(ssl_int, ctx, SocketHandler.on_handshake, null); } @@ -421,7 +659,7 @@ pub const Timer = opaque { pub const SocketContext = opaque { pub fn getNativeHandle(this: *SocketContext, comptime ssl: bool) *anyopaque { - return us_socket_context_get_native_handle(comptime @as(i32, @boolToInt(ssl)), this).?; + return us_socket_context_get_native_handle(comptime @as(i32, @intFromBool(ssl)), this).?; } fn _deinit_ssl(this: *SocketContext) void { @@ -446,8 +684,8 @@ pub const SocketContext = opaque { } pub fn close(this: *SocketContext, ssl: bool) void { - debug("us_socket_context_close({d})", .{@ptrToInt(this)}); - us_socket_context_close(@as(i32, @boolToInt(ssl)), this); + debug("us_socket_context_close({d})", .{@intFromPtr(this)}); + us_socket_context_close(@as(i32, @intFromBool(ssl)), this); } pub fn ext(this: *SocketContext, ssl: bool, comptime ContextType: type) ?*ContextType { @@ -457,7 +695,7 @@ pub const SocketContext = opaque { std.meta.alignment(ContextType); var ptr = us_socket_context_ext( - @boolToInt(ssl), + @intFromBool(ssl), this, ) orelse return null; @@ -659,6 +897,20 @@ pub const us_bun_verify_error_t = extern struct { reason: [*c]const u8 = null, }; +pub const us_socket_events_t = extern struct { + on_open: ?*const fn (*Socket, i32, [*c]u8, i32) callconv(.C) ?*Socket = null, + on_data: ?*const fn (*Socket, [*c]u8, i32) callconv(.C) ?*Socket = null, + on_writable: ?*const fn (*Socket) callconv(.C) ?*Socket = null, + on_close: ?*const fn (*Socket, i32, ?*anyopaque) callconv(.C) ?*Socket = null, + + on_timeout: ?*const fn (*Socket) callconv(.C) ?*Socket = null, + on_long_timeout: ?*const fn (*Socket) callconv(.C) ?*Socket = null, + on_end: ?*const fn (*Socket) callconv(.C) ?*Socket = null, + on_connect_error: ?*const fn (*Socket, i32) callconv(.C) ?*Socket = null, + on_handshake: ?*const fn (*Socket, i32, us_bun_verify_error_t, ?*anyopaque) callconv(.C) void = null, +}; + +pub extern fn us_socket_wrap_with_tls(ssl: i32, s: *Socket, options: us_bun_socket_context_options_t, events: us_socket_events_t, socket_ext_size: i32) ?*Socket; extern fn us_socket_verify_error(ssl: i32, context: *Socket) us_bun_verify_error_t; extern fn SocketContextimestamp(ssl: i32, context: ?*SocketContext) c_ushort; pub extern fn us_socket_context_add_server_name(ssl: i32, context: ?*SocketContext, hostname_pattern: [*c]const u8, options: us_socket_context_options_t, ?*anyopaque) void; @@ -700,7 +952,7 @@ pub const Poll = opaque { fallthrough: bool, flags: Flags, ) ?*Poll { - var poll = us_create_poll(loop, @as(i32, @boolToInt(fallthrough)), @sizeOf(Data)); + var poll = us_create_poll(loop, @as(i32, @intFromBool(fallthrough)), @sizeOf(Data)); if (comptime Data != void) { poll.data(Data).* = val; } @@ -777,11 +1029,16 @@ extern fn us_socket_ext(ssl: i32, s: ?*Socket) ?*anyopaque; extern fn us_socket_context(ssl: i32, s: ?*Socket) ?*SocketContext; extern fn us_socket_flush(ssl: i32, s: ?*Socket) void; extern fn us_socket_write(ssl: i32, s: ?*Socket, data: [*c]const u8, length: i32, msg_more: i32) i32; +extern fn us_socket_raw_write(ssl: i32, s: ?*Socket, data: [*c]const u8, length: i32, msg_more: i32) i32; extern fn us_socket_shutdown(ssl: i32, s: ?*Socket) void; extern fn us_socket_shutdown_read(ssl: i32, s: ?*Socket) void; extern fn us_socket_is_shut_down(ssl: i32, s: ?*Socket) i32; extern fn us_socket_is_closed(ssl: i32, s: ?*Socket) i32; extern fn us_socket_close(ssl: i32, s: ?*Socket, code: i32, reason: ?*anyopaque) ?*Socket; +// if a TLS socket calls this, it will start SSL instance and call open event will also do TLS handshake if required +// will have no effect if the socket is closed or is not TLS +extern fn us_socket_open(ssl: i32, s: ?*Socket, is_client: i32, ip: [*c]const u8, ip_length: i32) ?*Socket; + extern fn us_socket_local_port(ssl: i32, s: ?*Socket) i32; extern fn us_socket_remote_address(ssl: i32, s: ?*Socket, buf: [*c]u8, length: [*c]i32) void; pub const uws_app_s = opaque {}; @@ -810,7 +1067,7 @@ pub const AnyWebSocket = union(enum) { } pub fn close(this: AnyWebSocket) void { - const ssl_flag = @boolToInt(this == .ssl); + const ssl_flag = @intFromBool(this == .ssl); return uws_ws_close(ssl_flag, this.raw()); } @@ -874,7 +1131,7 @@ pub const AnyWebSocket = union(enum) { } pub fn publishWithOptions(ssl: bool, app: *anyopaque, topic: []const u8, message: []const u8, opcode: Opcode, compress: bool) bool { return uws_publish( - @boolToInt(ssl), + @intFromBool(ssl), @ptrCast(*uws_app_t, app), topic.ptr, topic.len, @@ -1073,10 +1330,10 @@ pub const Request = opaque { pub const ListenSocket = opaque { pub fn close(this: *ListenSocket, ssl: bool) void { - us_listen_socket_close(@boolToInt(ssl), this); + us_listen_socket_close(@intFromBool(ssl), this); } pub fn getLocalPort(this: *ListenSocket, ssl: bool) i32 { - return us_socket_local_port(@boolToInt(ssl), @ptrCast(*uws.Socket, this)); + return us_socket_local_port(@intFromBool(ssl), @ptrCast(*uws.Socket, this)); } }; extern fn us_listen_socket_close(ssl: i32, ls: *ListenSocket) void; @@ -1085,7 +1342,7 @@ extern fn us_socket_context_close(ssl: i32, ctx: *anyopaque) void; pub fn NewApp(comptime ssl: bool) type { return opaque { - const ssl_flag = @as(i32, @boolToInt(ssl)); + const ssl_flag = @as(i32, @intFromBool(ssl)); const ThisApp = @This(); pub fn close(this: *ThisApp) void { @@ -1428,7 +1685,7 @@ pub fn NewApp(comptime ssl: bool) type { } pub fn getNativeHandle(res: *Response) i32 { - return @intCast(i32, @ptrToInt(uws_res_get_native_handle(ssl_flag, res.downcast()))); + return @intCast(i32, @intFromPtr(uws_res_get_native_handle(ssl_flag, res.downcast()))); } pub fn onWritable( res: *Response, @@ -1880,23 +2137,23 @@ pub const State = enum(i32) { _, pub inline fn isResponsePending(this: State) bool { - return @enumToInt(this) & @enumToInt(State.HTTP_RESPONSE_PENDING) != 0; + return @intFromEnum(this) & @intFromEnum(State.HTTP_RESPONSE_PENDING) != 0; } pub inline fn isHttpEndCalled(this: State) bool { - return @enumToInt(this) & @enumToInt(State.HTTP_END_CALLED) != 0; + return @intFromEnum(this) & @intFromEnum(State.HTTP_END_CALLED) != 0; } pub inline fn isHttpWriteCalled(this: State) bool { - return @enumToInt(this) & @enumToInt(State.HTTP_WRITE_CALLED) != 0; + return @intFromEnum(this) & @intFromEnum(State.HTTP_WRITE_CALLED) != 0; } pub inline fn isHttpStatusCalled(this: State) bool { - return @enumToInt(this) & @enumToInt(State.HTTP_STATUS_CALLED) != 0; + return @intFromEnum(this) & @intFromEnum(State.HTTP_STATUS_CALLED) != 0; } pub inline fn isHttpConnectionClose(this: State) bool { - return @enumToInt(this) & @enumToInt(State.HTTP_CONNECTION_CLOSE) != 0; + return @intFromEnum(this) & @intFromEnum(State.HTTP_CONNECTION_CLOSE) != 0; } }; |