diff options
Diffstat (limited to 'src/bun.js/api/bun/socket.zig')
-rw-r--r-- | src/bun.js/api/bun/socket.zig | 48 |
1 files changed, 41 insertions, 7 deletions
diff --git a/src/bun.js/api/bun/socket.zig b/src/bun.js/api/bun/socket.zig index 329cc40e4..12a4cffc8 100644 --- a/src/bun.js/api/bun/socket.zig +++ b/src/bun.js/api/bun/socket.zig @@ -820,6 +820,7 @@ pub const Listener = struct { var default_data = socket_config.default_data; var protos: ?[]const u8 = null; + var server_name: ?[]const u8 = null; const ssl_enabled = ssl != null; defer if (ssl != null) ssl.?.deinit(); @@ -840,6 +841,9 @@ pub const Listener = struct { if (ssl.?.protos) |p| { protos = p[0..ssl.?.protos_len]; } + if (ssl.?.server_name) |s| { + server_name = bun.default_allocator.dupe(u8, s[0..bun.len(s)]) catch unreachable; + } uws.NewSocketHandler(true).configure( socket_context, true, @@ -892,6 +896,7 @@ pub const Listener = struct { .socket = undefined, .connection = connection, .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p) catch unreachable) else null, + .server_name = server_name, }; TLSSocket.dataSetCached(tls.getThisValue(globalObject), globalObject, default_data); @@ -916,6 +921,7 @@ pub const Listener = struct { .socket = undefined, .connection = null, .protos = null, + .server_name = null, }; TCPSocket.dataSetCached(tcp.getThisValue(globalObject), globalObject, default_data); @@ -987,6 +993,7 @@ fn NewSocket(comptime ssl: bool) type { connection: ?Listener.UnixOrHost = null, protos: ?[]const u8, owned_protos: bool = true, + server_name: ?[]const u8 = null, // TODO: switch to something that uses `visitAggregate` and have the // `Listener` keep a list of all the sockets JSValue in there @@ -1170,7 +1177,14 @@ fn NewSocket(comptime ssl: bool) type { if (comptime ssl) { var ssl_ptr: *BoringSSL.SSL = @ptrCast(*BoringSSL.SSL, socket.getNativeHandle()); if (!ssl_ptr.isInitFinished()) { - if (this.connection) |connection| { + if (this.server_name) |server_name| { + const host = normalizeHost(server_name); + if (host.len > 0) { + var host__ = default_allocator.dupeZ(u8, host) catch unreachable; + defer default_allocator.free(host__); + ssl_ptr.setHostname(host__); + } + } else if (this.connection) |connection| { if (connection == .host) { const host = normalizeHost(connection.host.host); if (host.len > 0) { @@ -1248,6 +1262,8 @@ fn NewSocket(comptime ssl: bool) type { pub fn onEnd(this: *This, socket: Socket) void { JSC.markBinding(@src()); log("onEnd", .{}); + if (this.detached) return; + this.detached = true; defer this.markInactive(); @@ -1815,6 +1831,11 @@ fn NewSocket(comptime ssl: bool) type { } } + if (this.server_name) |server_name| { + this.server_name = null; + default_allocator.free(server_name); + } + if (this.connection) |connection| { this.connection = null; connection.deinit(); @@ -1898,9 +1919,6 @@ fn NewSocket(comptime ssl: bool) type { if (comptime ssl == false) { return JSValue.jsUndefined(); } - if (this.detached) { - return JSValue.jsUndefined(); - } if (this.handlers.is_server) { globalObject.throw("Cannot issue SNI from a TLS server-side socket", .{}); @@ -1919,9 +1937,20 @@ fn NewSocket(comptime ssl: bool) type { return .zero; } - const slice = server_name.getZigString(globalObject).toSlice(bun.default_allocator); - defer slice.deinit(); - const host = normalizeHost(slice.slice()); + const slice = server_name.getZigString(globalObject).toOwnedSlice(bun.default_allocator) catch unreachable; + if (this.server_name) |old| { + this.server_name = slice; + default_allocator.free(old); + } else { + this.server_name = slice; + } + + if (this.detached) { + // will be attached onOpen + return JSValue.jsUndefined(); + } + + const host = normalizeHost(@as([]const u8, slice)); if (host.len > 0) { var ssl_ptr: *BoringSSL.SSL = @ptrCast(*BoringSSL.SSL, this.socket.getNativeHandle()); if (ssl_ptr.isInitFinished()) { @@ -2040,6 +2069,10 @@ fn NewSocket(comptime ssl: bool) type { .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p[0..protos_len]) catch unreachable) else null, }; + if (socket_config.server_name) |server_name| { + tls.server_name = bun.default_allocator.dupe(u8, server_name[0..bun.len(server_name)]) catch unreachable; + } + var tls_js_value = tls.getThisValue(globalObject); TLSSocket.dataSetCached(tls_js_value, globalObject, default_data); @@ -2061,6 +2094,7 @@ fn NewSocket(comptime ssl: bool) type { bun.default_allocator.destroy(tls); return JSValue.jsUndefined(); }; + tls.socket = new_socket; var raw = handlers.vm.allocator.create(TLSSocket) catch @panic("OOM"); |