diff options
Diffstat (limited to 'src/s2n.zig')
-rw-r--r-- | src/s2n.zig | 246 |
1 files changed, 218 insertions, 28 deletions
diff --git a/src/s2n.zig b/src/s2n.zig index fd8d23cd0..3c8d13014 100644 --- a/src/s2n.zig +++ b/src/s2n.zig @@ -1,19 +1,193 @@ pub usingnamespace @import("std").zig.c_builtins; const std = @import("std"); +const Output = @import("./global.zig").Output; + +const alpn_protocols = "http/1.1"; + +pub inline fn s2nassert(value: c_int) void { + std.debug.assert(value == 0); +} + pub fn boot(allcoator: *std.mem.Allocator) void { if (booted) return; booted = true; Allocator.allocator = allcoator; + CacheStore.instance = CacheStore{ .allocator = allcoator, .map = @TypeOf(CacheStore.instance.map).init(allcoator) }; + + // Important for any website using Cloudflare. + // Do to our modifications in the library, this must be called _first_ + // Before we initialize s2n. + // It can never be changed after initialization or we risk undefined memory bugs. + if (s2n_get_highest_fully_supported_tls_version() == S2N_TLS13) { + // This conditional should always return true since we statically compile libCrypto. + _ = s2n_enable_tls13(); + + // Sadly, this TLS 1.3 implementation is slower than TLS 1.2. + // ❯ hyperfine "./fetch https://example.com" "./fetchtls13 https://example.com" + // Benchmark #1: ./fetch https://example.com + // Time (mean ± σ): 83.6 ms ± 5.4 ms [User: 15.1 ms, System: 4.7 ms] + // Range (min … max): 73.5 ms … 97.5 ms 35 runs + + // Benchmark #2: ./fetchtls13 https://example.com + // Time (mean ± σ): 94.9 ms ± 3.2 ms [User: 15.8 ms, System: 4.8 ms] + // Range (min … max): 90.7 ms … 104.6 ms 29 runs - _ = s2n_disable_atexit(); - _ = s2n_mem_set_callbacks(Allocator.initCallback, Allocator.deinitCallback, Allocator.mallocCallback, Allocator.freeCallback); + // Summary + // './fetch https://example.com' ran + // 1.14 ± 0.08 times faster than './fetchtls13 https://example.com' - _ = s2n_init(); - global_s2n_config = s2n_config_new(); - _ = s2n_config_disable_x509_verification(global_s2n_config); + } + + // We don't actually need the memory allocator, it will automatically use mimalloc...I don't know how! + // Also, the implementation + // s2nassert(s2n_mem_set_callbacks(Allocator.initCallback, Allocator.deinitCallback, Allocator.mallocCallback, Allocator.freeCallback)); + + s2nassert(s2n_disable_atexit()); + s2nassert(s2n_init()); + global_s2n_config = s2n_fetch_default_config(); + // s2nassert(s2n_config_set_verify_host_callback(global_s2n_config, verify_host_callback, null)); + // s2nassert(s2n_config_set_check_stapled_ocsp_response(global_s2n_config, 0)); + // s2nassert(s2n_config_set_cipher_preferences(global_s2n_config, "default")); + // s2nassert(s2n_config_disable_x509_verification(global_s2n_config)); + var protocol: [*c]const u8 = "http/1.1"; + var protocols = &protocol; + s2nassert(s2n_config_set_protocol_preferences(global_s2n_config, protocols, 1)); + s2nassert(s2n_config_send_max_fragment_length(global_s2n_config, S2N_TLS_MAX_FRAG_LEN_4096)); + + // s2n_config_set_ticket_decrypt_key_lifetime(global_s2n_config, 9999999); + + s2nassert( + s2n_config_set_cache_store_callback(global_s2n_config, CacheStore.store, &CacheStore.instance), + ); + s2nassert( + s2n_config_set_cache_retrieve_callback(global_s2n_config, CacheStore.retrieve, &CacheStore.instance), + ); + s2nassert( + s2n_config_set_cache_delete_callback(global_s2n_config, CacheStore.delete, &CacheStore.instance), + ); + s2nassert( + s2n_config_set_session_cache_onoff(global_s2n_config, 1), + ); + + // s2nassert(s2n_config_init_session_ticket_keys()); + // s2nassert(s2n_config_set_client_auth_type(global_s2n_config, S2N_STATUS_REQUEST_NONE)); } +pub const CacheStore = struct { + const CacheEntry = struct { + key: []u8, + value: []u8, + seconds: u64, + + pub fn init( + allocator: *std.mem.Allocator, + key: *const c_void, + size: u64, + value: *const c_void, + value_size: u64, + seconds: u64, + ) CacheEntry { + const key_bytes = keyBytes(key, size); + const value_bytes = keyBytes(key, value_size); + + var total_bytes = allocator.alloc(u8, key_bytes.len + value_bytes.len) catch unreachable; + @memcpy(total_bytes.ptr, key_bytes.ptr, key_bytes.len); + @memcpy(total_bytes[key_bytes.len..].ptr, value_bytes.ptr, value_bytes.len); + + return CacheEntry{ .key = total_bytes[0..key_bytes.len], .value = total_bytes[key_bytes.len..], .seconds = seconds }; + } + }; + + const Context = struct { + pub fn hash(this: @This(), key: u64) u64 { + return key; + } + + pub fn eql(this: @This(), a: u64, b: u64) bool { + return a == b; + } + }; + + allocator: *std.mem.Allocator, + map: std.HashMap(u64, CacheEntry, Context, 80), + + pub inline fn keyBytes(key: *const c_void, size: u64) []u8 { + const ptr = @intToPtr([*]u8, @ptrToInt(key)); + + return ptr[0..size]; + } + + inline fn hashKey(key: *const c_void, size: u64) u64 { + const bytes = keyBytes(key, size); + return std.hash.Wyhash.hash(0, bytes); + } + + pub fn retrieve( + conn: *s2n_connection, + ctx: ?*c_void, + key: *const c_void, + key_size: u64, + value: *c_void, + value_size: *u64, + ) callconv(.C) c_int { + const hash = hashKey(key, key_size); + + if (instance.map.getAdapted(hash, Context{})) |entry| { + const now = @intCast(usize, std.time.timestamp()); + if (now > entry.seconds) { + _ = instance.map.removeAdapted(hash, Context{}); + return 0; + } + + var value_bytes = keyBytes(value, value_size.*); + if (value_bytes.len < entry.value.len) return -1; + std.mem.copy(u8, value_bytes, entry.value); + value_size.* = entry.value.len; + return 0; + } + + return 0; + } + + pub fn store( + conn: *s2n_connection, + ctx: ?*c_void, + seconds: u64, + key: *const c_void, + key_size: u64, + value: *const c_void, + value_size: u64, + ) callconv(.C) c_int { + var map_entry = instance.map.getOrPutAdapted(hashKey(key, key_size), Context{}) catch unreachable; + + if (!map_entry.found_existing) { + map_entry.value_ptr.* = CacheEntry.init(instance.allocator, key, key_size, value, value_size, @intCast(usize, std.time.timestamp()) + seconds); + } + + return S2N_SUCCESS; + } + + pub fn delete( + conn: *s2n_connection, + ctx: ?*c_void, + key: *const c_void, + key_size: u64, + ) callconv(.C) c_int { + _ = instance.map.remove(hashKey(key, key_size)); + return 0; + } + + pub var instance: CacheStore = undefined; +}; + +pub fn verify_host_callback(ptr: [*c]const u8, len: usize, ctx: ?*c_void) callconv(.C) u8 { + return 1; +} + +pub extern fn s2n_enable_tls13() c_int; +pub extern fn s2n_fetch_default_config() *s2n_config; +pub extern fn s2n_get_highest_fully_supported_tls_version() c_int; pub extern fn s2n_errno_location() [*c]c_int; pub const S2N_ERR_T_OK: c_int = 0; pub const S2N_ERR_T_IO: c_int = 1; @@ -37,9 +211,9 @@ pub extern fn s2n_config_free(config: *struct_s2n_config) c_int; pub extern fn s2n_config_free_dhparams(config: *struct_s2n_config) c_int; pub extern fn s2n_config_free_cert_chain_and_key(config: *struct_s2n_config) c_int; pub const s2n_clock_time_nanoseconds = ?fn (?*c_void, [*c]u64) callconv(.C) c_int; -pub const s2n_cache_retrieve_callback = ?fn (*struct_s2n_connection, ?*c_void, ?*const c_void, u64, ?*c_void, [*c]u64) callconv(.C) c_int; -pub const s2n_cache_store_callback = ?fn (*struct_s2n_connection, ?*c_void, u64, ?*const c_void, u64, ?*const c_void, u64) callconv(.C) c_int; -pub const s2n_cache_delete_callback = ?fn (*struct_s2n_connection, ?*c_void, ?*const c_void, u64) callconv(.C) c_int; +pub const s2n_cache_retrieve_callback = ?fn (*struct_s2n_connection, ?*c_void, *const c_void, u64, *c_void, *u64) callconv(.C) c_int; +pub const s2n_cache_store_callback = ?fn (*struct_s2n_connection, ?*c_void, u64, *const c_void, u64, *const c_void, u64) callconv(.C) c_int; +pub const s2n_cache_delete_callback = ?fn (*struct_s2n_connection, ?*c_void, *const c_void, u64) callconv(.C) c_int; pub extern fn s2n_config_set_wall_clock(config: *struct_s2n_config, clock_fn: s2n_clock_time_nanoseconds, ctx: ?*c_void) c_int; pub extern fn s2n_config_set_monotonic_clock(config: *struct_s2n_config, clock_fn: s2n_clock_time_nanoseconds, ctx: ?*c_void) c_int; pub extern fn s2n_strerror(@"error": c_int, lang: [*c]const u8) [*c]const u8; @@ -164,8 +338,8 @@ pub extern fn s2n_connection_set_write_fd(conn: *struct_s2n_connection, writefd: pub extern fn s2n_connection_get_read_fd(conn: *struct_s2n_connection, readfd: [*c]c_int) c_int; pub extern fn s2n_connection_get_write_fd(conn: *struct_s2n_connection, writefd: [*c]c_int) c_int; pub extern fn s2n_connection_use_corked_io(conn: *struct_s2n_connection) c_int; -pub const s2n_recv_fn = fn (?*c_void, [*c]u8, u32) callconv(.C) c_int; -pub const s2n_send_fn = fn (?*c_void, [*c]const u8, u32) callconv(.C) c_int; +pub const s2n_recv_fn = fn (*s2n_connection, [*c]u8, u32) callconv(.C) c_int; +pub const s2n_send_fn = fn (*s2n_connection, [*c]const u8, u32) callconv(.C) c_int; pub extern fn s2n_connection_set_recv_ctx(conn: *struct_s2n_connection, ctx: ?*c_void) c_int; pub extern fn s2n_connection_set_send_ctx(conn: *struct_s2n_connection, ctx: ?*c_void) c_int; pub extern fn s2n_connection_set_recv_cb(conn: *struct_s2n_connection, recv: ?s2n_recv_fn) c_int; @@ -395,6 +569,7 @@ pub const Connection = struct { conn: *s2n_connection = undefined, fd: std.os.socket_t, node: *Pool.List.Node, + disable_shutdown: bool = false, pub const Pool = struct { pub const List = std.SinglyLinkedList(*s2n_connection); @@ -429,27 +604,40 @@ pub const Connection = struct { const errno = s2nErrorNo; - pub fn s2n_recv_function(conn: *s2n_connection, buf: *c, len: u32) c_int { - return std.os.recv(fd, buf, len, 0); - } - pub fn s2n_send_function(conn: *s2n_connection, buf: *c, len: u32) c_int { - return std.os.send(fd, buf, SOCKET_FLAGS); - } + // pub fn s2n_recv_function(conn: *s2n_connection, buf: [*c]u8, len: u32) callconv(.C) c_int { + // if (buf == null) return 0; + // var fd: c_int = 0; + // _ = s2n_connection_get_read_fd(conn, &fd); + // return @intCast(c_int, std.os.system.recvfrom(fd, buf, len, std.os.SOCK_CLOEXEC, null, null)); + // } + // pub fn s2n_send_function(conn: *s2n_connection, buf: [*c]const u8, len: u32) callconv(.C) c_int { + // if (buf == null) return 0; + // var fd: c_int = 0; + // _ = s2n_connection_get_write_fd(conn, &fd); + + // return @intCast(c_int, std.os.system.sendto(fd, buf.?, len, std.os.SOCK_CLOEXEC, null, 0)); + // } - pub fn start(this: *Connection) !void { + pub fn start(this: *Connection, server_name: [:0]const u8) !void { this.node = Pool.get(); this.conn = this.node.data; - _ = s2n_connection_set_config(this.conn, global_s2n_config); - _ = s2n_connection_set_fd(this.conn, @intCast(c_int, this.fd)); - _ = s2n_connection_set_blinding(this.conn, S2N_SELF_SERVICE_BLINDING); - _ = s2n_connection_prefer_low_latency(this.conn); - _ = s2n_connection_set_ctx(this.conn, this); - - s2n_connection_set_recv_cb(this.conn, s2n_recv_function); - s2n_connection_set_send_cb(this.conn, s2n_send_function); + s2nassert(s2n_connection_set_ctx(this.conn, this)); + s2nassert(s2n_connection_set_config(this.conn, global_s2n_config)); + s2nassert(s2n_connection_set_read_fd(this.conn, @intCast(c_int, this.fd))); + s2nassert(s2n_connection_set_write_fd(this.conn, @intCast(c_int, this.fd))); + s2nassert(s2n_connection_set_blinding(this.conn, S2N_SELF_SERVICE_BLINDING)); + // s2nassert(s2n_connection_set_dynamic_record(this.conn)); + s2nassert(s2n_set_server_name(this.conn, server_name.ptr)); + + // _ = s2n_connection_set_recv_cb(this.conn, s2n_recv_function); + // _ = s2n_connection_set_send_cb(this.conn, s2n_send_function); const rc = s2n_negotiate(this.conn, &blocked_status); + if (rc < 0) { + Output.printErrorln("Alert: {d}", .{s2n_connection_get_alert(this.conn)}); + Output.prettyErrorln("ERROR: {s}", .{s2n_strerror_debug(rc, "EN")}); + } - defer s2n_connection_free_handshake(this.conn); + defer s2nassert(s2n_connection_free_handshake(this.conn)); switch (try s2nErrorNo(rc)) { .SUCCESS => return, @@ -468,8 +656,10 @@ pub const Connection = struct { } pub fn close(this: *Connection) !void { - _ = s2n_shutdown(this.conn, &blocked_status); - Pool.put(this.node); + if (!this.disable_shutdown) { + _ = s2n_shutdown(this.conn, &blocked_status); + Pool.put(this.node); + } std.os.closeSocket(this.fd); } |