aboutsummaryrefslogtreecommitdiff
path: root/src/s2n.zig
diff options
context:
space:
mode:
Diffstat (limited to 'src/s2n.zig')
-rw-r--r--src/s2n.zig246
1 files changed, 218 insertions, 28 deletions
diff --git a/src/s2n.zig b/src/s2n.zig
index fd8d23cd0..3c8d13014 100644
--- a/src/s2n.zig
+++ b/src/s2n.zig
@@ -1,19 +1,193 @@
pub usingnamespace @import("std").zig.c_builtins;
const std = @import("std");
+const Output = @import("./global.zig").Output;
+
+const alpn_protocols = "http/1.1";
+
+pub inline fn s2nassert(value: c_int) void {
+ std.debug.assert(value == 0);
+}
+
pub fn boot(allcoator: *std.mem.Allocator) void {
if (booted) return;
booted = true;
Allocator.allocator = allcoator;
+ CacheStore.instance = CacheStore{ .allocator = allcoator, .map = @TypeOf(CacheStore.instance.map).init(allcoator) };
+
+ // Important for any website using Cloudflare.
+ // Do to our modifications in the library, this must be called _first_
+ // Before we initialize s2n.
+ // It can never be changed after initialization or we risk undefined memory bugs.
+ if (s2n_get_highest_fully_supported_tls_version() == S2N_TLS13) {
+ // This conditional should always return true since we statically compile libCrypto.
+ _ = s2n_enable_tls13();
+
+ // Sadly, this TLS 1.3 implementation is slower than TLS 1.2.
+ // ❯ hyperfine "./fetch https://example.com" "./fetchtls13 https://example.com"
+ // Benchmark #1: ./fetch https://example.com
+ // Time (mean ± σ): 83.6 ms ± 5.4 ms [User: 15.1 ms, System: 4.7 ms]
+ // Range (min … max): 73.5 ms … 97.5 ms 35 runs
+
+ // Benchmark #2: ./fetchtls13 https://example.com
+ // Time (mean ± σ): 94.9 ms ± 3.2 ms [User: 15.8 ms, System: 4.8 ms]
+ // Range (min … max): 90.7 ms … 104.6 ms 29 runs
- _ = s2n_disable_atexit();
- _ = s2n_mem_set_callbacks(Allocator.initCallback, Allocator.deinitCallback, Allocator.mallocCallback, Allocator.freeCallback);
+ // Summary
+ // './fetch https://example.com' ran
+ // 1.14 ± 0.08 times faster than './fetchtls13 https://example.com'
- _ = s2n_init();
- global_s2n_config = s2n_config_new();
- _ = s2n_config_disable_x509_verification(global_s2n_config);
+ }
+
+ // We don't actually need the memory allocator, it will automatically use mimalloc...I don't know how!
+ // Also, the implementation
+ // s2nassert(s2n_mem_set_callbacks(Allocator.initCallback, Allocator.deinitCallback, Allocator.mallocCallback, Allocator.freeCallback));
+
+ s2nassert(s2n_disable_atexit());
+ s2nassert(s2n_init());
+ global_s2n_config = s2n_fetch_default_config();
+ // s2nassert(s2n_config_set_verify_host_callback(global_s2n_config, verify_host_callback, null));
+ // s2nassert(s2n_config_set_check_stapled_ocsp_response(global_s2n_config, 0));
+ // s2nassert(s2n_config_set_cipher_preferences(global_s2n_config, "default"));
+ // s2nassert(s2n_config_disable_x509_verification(global_s2n_config));
+ var protocol: [*c]const u8 = "http/1.1";
+ var protocols = &protocol;
+ s2nassert(s2n_config_set_protocol_preferences(global_s2n_config, protocols, 1));
+ s2nassert(s2n_config_send_max_fragment_length(global_s2n_config, S2N_TLS_MAX_FRAG_LEN_4096));
+
+ // s2n_config_set_ticket_decrypt_key_lifetime(global_s2n_config, 9999999);
+
+ s2nassert(
+ s2n_config_set_cache_store_callback(global_s2n_config, CacheStore.store, &CacheStore.instance),
+ );
+ s2nassert(
+ s2n_config_set_cache_retrieve_callback(global_s2n_config, CacheStore.retrieve, &CacheStore.instance),
+ );
+ s2nassert(
+ s2n_config_set_cache_delete_callback(global_s2n_config, CacheStore.delete, &CacheStore.instance),
+ );
+ s2nassert(
+ s2n_config_set_session_cache_onoff(global_s2n_config, 1),
+ );
+
+ // s2nassert(s2n_config_init_session_ticket_keys());
+ // s2nassert(s2n_config_set_client_auth_type(global_s2n_config, S2N_STATUS_REQUEST_NONE));
}
+pub const CacheStore = struct {
+ const CacheEntry = struct {
+ key: []u8,
+ value: []u8,
+ seconds: u64,
+
+ pub fn init(
+ allocator: *std.mem.Allocator,
+ key: *const c_void,
+ size: u64,
+ value: *const c_void,
+ value_size: u64,
+ seconds: u64,
+ ) CacheEntry {
+ const key_bytes = keyBytes(key, size);
+ const value_bytes = keyBytes(key, value_size);
+
+ var total_bytes = allocator.alloc(u8, key_bytes.len + value_bytes.len) catch unreachable;
+ @memcpy(total_bytes.ptr, key_bytes.ptr, key_bytes.len);
+ @memcpy(total_bytes[key_bytes.len..].ptr, value_bytes.ptr, value_bytes.len);
+
+ return CacheEntry{ .key = total_bytes[0..key_bytes.len], .value = total_bytes[key_bytes.len..], .seconds = seconds };
+ }
+ };
+
+ const Context = struct {
+ pub fn hash(this: @This(), key: u64) u64 {
+ return key;
+ }
+
+ pub fn eql(this: @This(), a: u64, b: u64) bool {
+ return a == b;
+ }
+ };
+
+ allocator: *std.mem.Allocator,
+ map: std.HashMap(u64, CacheEntry, Context, 80),
+
+ pub inline fn keyBytes(key: *const c_void, size: u64) []u8 {
+ const ptr = @intToPtr([*]u8, @ptrToInt(key));
+
+ return ptr[0..size];
+ }
+
+ inline fn hashKey(key: *const c_void, size: u64) u64 {
+ const bytes = keyBytes(key, size);
+ return std.hash.Wyhash.hash(0, bytes);
+ }
+
+ pub fn retrieve(
+ conn: *s2n_connection,
+ ctx: ?*c_void,
+ key: *const c_void,
+ key_size: u64,
+ value: *c_void,
+ value_size: *u64,
+ ) callconv(.C) c_int {
+ const hash = hashKey(key, key_size);
+
+ if (instance.map.getAdapted(hash, Context{})) |entry| {
+ const now = @intCast(usize, std.time.timestamp());
+ if (now > entry.seconds) {
+ _ = instance.map.removeAdapted(hash, Context{});
+ return 0;
+ }
+
+ var value_bytes = keyBytes(value, value_size.*);
+ if (value_bytes.len < entry.value.len) return -1;
+ std.mem.copy(u8, value_bytes, entry.value);
+ value_size.* = entry.value.len;
+ return 0;
+ }
+
+ return 0;
+ }
+
+ pub fn store(
+ conn: *s2n_connection,
+ ctx: ?*c_void,
+ seconds: u64,
+ key: *const c_void,
+ key_size: u64,
+ value: *const c_void,
+ value_size: u64,
+ ) callconv(.C) c_int {
+ var map_entry = instance.map.getOrPutAdapted(hashKey(key, key_size), Context{}) catch unreachable;
+
+ if (!map_entry.found_existing) {
+ map_entry.value_ptr.* = CacheEntry.init(instance.allocator, key, key_size, value, value_size, @intCast(usize, std.time.timestamp()) + seconds);
+ }
+
+ return S2N_SUCCESS;
+ }
+
+ pub fn delete(
+ conn: *s2n_connection,
+ ctx: ?*c_void,
+ key: *const c_void,
+ key_size: u64,
+ ) callconv(.C) c_int {
+ _ = instance.map.remove(hashKey(key, key_size));
+ return 0;
+ }
+
+ pub var instance: CacheStore = undefined;
+};
+
+pub fn verify_host_callback(ptr: [*c]const u8, len: usize, ctx: ?*c_void) callconv(.C) u8 {
+ return 1;
+}
+
+pub extern fn s2n_enable_tls13() c_int;
+pub extern fn s2n_fetch_default_config() *s2n_config;
+pub extern fn s2n_get_highest_fully_supported_tls_version() c_int;
pub extern fn s2n_errno_location() [*c]c_int;
pub const S2N_ERR_T_OK: c_int = 0;
pub const S2N_ERR_T_IO: c_int = 1;
@@ -37,9 +211,9 @@ pub extern fn s2n_config_free(config: *struct_s2n_config) c_int;
pub extern fn s2n_config_free_dhparams(config: *struct_s2n_config) c_int;
pub extern fn s2n_config_free_cert_chain_and_key(config: *struct_s2n_config) c_int;
pub const s2n_clock_time_nanoseconds = ?fn (?*c_void, [*c]u64) callconv(.C) c_int;
-pub const s2n_cache_retrieve_callback = ?fn (*struct_s2n_connection, ?*c_void, ?*const c_void, u64, ?*c_void, [*c]u64) callconv(.C) c_int;
-pub const s2n_cache_store_callback = ?fn (*struct_s2n_connection, ?*c_void, u64, ?*const c_void, u64, ?*const c_void, u64) callconv(.C) c_int;
-pub const s2n_cache_delete_callback = ?fn (*struct_s2n_connection, ?*c_void, ?*const c_void, u64) callconv(.C) c_int;
+pub const s2n_cache_retrieve_callback = ?fn (*struct_s2n_connection, ?*c_void, *const c_void, u64, *c_void, *u64) callconv(.C) c_int;
+pub const s2n_cache_store_callback = ?fn (*struct_s2n_connection, ?*c_void, u64, *const c_void, u64, *const c_void, u64) callconv(.C) c_int;
+pub const s2n_cache_delete_callback = ?fn (*struct_s2n_connection, ?*c_void, *const c_void, u64) callconv(.C) c_int;
pub extern fn s2n_config_set_wall_clock(config: *struct_s2n_config, clock_fn: s2n_clock_time_nanoseconds, ctx: ?*c_void) c_int;
pub extern fn s2n_config_set_monotonic_clock(config: *struct_s2n_config, clock_fn: s2n_clock_time_nanoseconds, ctx: ?*c_void) c_int;
pub extern fn s2n_strerror(@"error": c_int, lang: [*c]const u8) [*c]const u8;
@@ -164,8 +338,8 @@ pub extern fn s2n_connection_set_write_fd(conn: *struct_s2n_connection, writefd:
pub extern fn s2n_connection_get_read_fd(conn: *struct_s2n_connection, readfd: [*c]c_int) c_int;
pub extern fn s2n_connection_get_write_fd(conn: *struct_s2n_connection, writefd: [*c]c_int) c_int;
pub extern fn s2n_connection_use_corked_io(conn: *struct_s2n_connection) c_int;
-pub const s2n_recv_fn = fn (?*c_void, [*c]u8, u32) callconv(.C) c_int;
-pub const s2n_send_fn = fn (?*c_void, [*c]const u8, u32) callconv(.C) c_int;
+pub const s2n_recv_fn = fn (*s2n_connection, [*c]u8, u32) callconv(.C) c_int;
+pub const s2n_send_fn = fn (*s2n_connection, [*c]const u8, u32) callconv(.C) c_int;
pub extern fn s2n_connection_set_recv_ctx(conn: *struct_s2n_connection, ctx: ?*c_void) c_int;
pub extern fn s2n_connection_set_send_ctx(conn: *struct_s2n_connection, ctx: ?*c_void) c_int;
pub extern fn s2n_connection_set_recv_cb(conn: *struct_s2n_connection, recv: ?s2n_recv_fn) c_int;
@@ -395,6 +569,7 @@ pub const Connection = struct {
conn: *s2n_connection = undefined,
fd: std.os.socket_t,
node: *Pool.List.Node,
+ disable_shutdown: bool = false,
pub const Pool = struct {
pub const List = std.SinglyLinkedList(*s2n_connection);
@@ -429,27 +604,40 @@ pub const Connection = struct {
const errno = s2nErrorNo;
- pub fn s2n_recv_function(conn: *s2n_connection, buf: *c, len: u32) c_int {
- return std.os.recv(fd, buf, len, 0);
- }
- pub fn s2n_send_function(conn: *s2n_connection, buf: *c, len: u32) c_int {
- return std.os.send(fd, buf, SOCKET_FLAGS);
- }
+ // pub fn s2n_recv_function(conn: *s2n_connection, buf: [*c]u8, len: u32) callconv(.C) c_int {
+ // if (buf == null) return 0;
+ // var fd: c_int = 0;
+ // _ = s2n_connection_get_read_fd(conn, &fd);
+ // return @intCast(c_int, std.os.system.recvfrom(fd, buf, len, std.os.SOCK_CLOEXEC, null, null));
+ // }
+ // pub fn s2n_send_function(conn: *s2n_connection, buf: [*c]const u8, len: u32) callconv(.C) c_int {
+ // if (buf == null) return 0;
+ // var fd: c_int = 0;
+ // _ = s2n_connection_get_write_fd(conn, &fd);
+
+ // return @intCast(c_int, std.os.system.sendto(fd, buf.?, len, std.os.SOCK_CLOEXEC, null, 0));
+ // }
- pub fn start(this: *Connection) !void {
+ pub fn start(this: *Connection, server_name: [:0]const u8) !void {
this.node = Pool.get();
this.conn = this.node.data;
- _ = s2n_connection_set_config(this.conn, global_s2n_config);
- _ = s2n_connection_set_fd(this.conn, @intCast(c_int, this.fd));
- _ = s2n_connection_set_blinding(this.conn, S2N_SELF_SERVICE_BLINDING);
- _ = s2n_connection_prefer_low_latency(this.conn);
- _ = s2n_connection_set_ctx(this.conn, this);
-
- s2n_connection_set_recv_cb(this.conn, s2n_recv_function);
- s2n_connection_set_send_cb(this.conn, s2n_send_function);
+ s2nassert(s2n_connection_set_ctx(this.conn, this));
+ s2nassert(s2n_connection_set_config(this.conn, global_s2n_config));
+ s2nassert(s2n_connection_set_read_fd(this.conn, @intCast(c_int, this.fd)));
+ s2nassert(s2n_connection_set_write_fd(this.conn, @intCast(c_int, this.fd)));
+ s2nassert(s2n_connection_set_blinding(this.conn, S2N_SELF_SERVICE_BLINDING));
+ // s2nassert(s2n_connection_set_dynamic_record(this.conn));
+ s2nassert(s2n_set_server_name(this.conn, server_name.ptr));
+
+ // _ = s2n_connection_set_recv_cb(this.conn, s2n_recv_function);
+ // _ = s2n_connection_set_send_cb(this.conn, s2n_send_function);
const rc = s2n_negotiate(this.conn, &blocked_status);
+ if (rc < 0) {
+ Output.printErrorln("Alert: {d}", .{s2n_connection_get_alert(this.conn)});
+ Output.prettyErrorln("ERROR: {s}", .{s2n_strerror_debug(rc, "EN")});
+ }
- defer s2n_connection_free_handshake(this.conn);
+ defer s2nassert(s2n_connection_free_handshake(this.conn));
switch (try s2nErrorNo(rc)) {
.SUCCESS => return,
@@ -468,8 +656,10 @@ pub const Connection = struct {
}
pub fn close(this: *Connection) !void {
- _ = s2n_shutdown(this.conn, &blocked_status);
- Pool.put(this.node);
+ if (!this.disable_shutdown) {
+ _ = s2n_shutdown(this.conn, &blocked_status);
+ Pool.put(this.node);
+ }
std.os.closeSocket(this.fd);
}