aboutsummaryrefslogtreecommitdiff
path: root/src/bun.js/api/bun/socket.zig
diff options
context:
space:
mode:
Diffstat (limited to 'src/bun.js/api/bun/socket.zig')
-rw-r--r--src/bun.js/api/bun/socket.zig48
1 files changed, 41 insertions, 7 deletions
diff --git a/src/bun.js/api/bun/socket.zig b/src/bun.js/api/bun/socket.zig
index 329cc40e4..12a4cffc8 100644
--- a/src/bun.js/api/bun/socket.zig
+++ b/src/bun.js/api/bun/socket.zig
@@ -820,6 +820,7 @@ pub const Listener = struct {
var default_data = socket_config.default_data;
var protos: ?[]const u8 = null;
+ var server_name: ?[]const u8 = null;
const ssl_enabled = ssl != null;
defer if (ssl != null) ssl.?.deinit();
@@ -840,6 +841,9 @@ pub const Listener = struct {
if (ssl.?.protos) |p| {
protos = p[0..ssl.?.protos_len];
}
+ if (ssl.?.server_name) |s| {
+ server_name = bun.default_allocator.dupe(u8, s[0..bun.len(s)]) catch unreachable;
+ }
uws.NewSocketHandler(true).configure(
socket_context,
true,
@@ -892,6 +896,7 @@ pub const Listener = struct {
.socket = undefined,
.connection = connection,
.protos = if (protos) |p| (bun.default_allocator.dupe(u8, p) catch unreachable) else null,
+ .server_name = server_name,
};
TLSSocket.dataSetCached(tls.getThisValue(globalObject), globalObject, default_data);
@@ -916,6 +921,7 @@ pub const Listener = struct {
.socket = undefined,
.connection = null,
.protos = null,
+ .server_name = null,
};
TCPSocket.dataSetCached(tcp.getThisValue(globalObject), globalObject, default_data);
@@ -987,6 +993,7 @@ fn NewSocket(comptime ssl: bool) type {
connection: ?Listener.UnixOrHost = null,
protos: ?[]const u8,
owned_protos: bool = true,
+ server_name: ?[]const u8 = null,
// TODO: switch to something that uses `visitAggregate` and have the
// `Listener` keep a list of all the sockets JSValue in there
@@ -1170,7 +1177,14 @@ fn NewSocket(comptime ssl: bool) type {
if (comptime ssl) {
var ssl_ptr: *BoringSSL.SSL = @ptrCast(*BoringSSL.SSL, socket.getNativeHandle());
if (!ssl_ptr.isInitFinished()) {
- if (this.connection) |connection| {
+ if (this.server_name) |server_name| {
+ const host = normalizeHost(server_name);
+ if (host.len > 0) {
+ var host__ = default_allocator.dupeZ(u8, host) catch unreachable;
+ defer default_allocator.free(host__);
+ ssl_ptr.setHostname(host__);
+ }
+ } else if (this.connection) |connection| {
if (connection == .host) {
const host = normalizeHost(connection.host.host);
if (host.len > 0) {
@@ -1248,6 +1262,8 @@ fn NewSocket(comptime ssl: bool) type {
pub fn onEnd(this: *This, socket: Socket) void {
JSC.markBinding(@src());
log("onEnd", .{});
+ if (this.detached) return;
+
this.detached = true;
defer this.markInactive();
@@ -1815,6 +1831,11 @@ fn NewSocket(comptime ssl: bool) type {
}
}
+ if (this.server_name) |server_name| {
+ this.server_name = null;
+ default_allocator.free(server_name);
+ }
+
if (this.connection) |connection| {
this.connection = null;
connection.deinit();
@@ -1898,9 +1919,6 @@ fn NewSocket(comptime ssl: bool) type {
if (comptime ssl == false) {
return JSValue.jsUndefined();
}
- if (this.detached) {
- return JSValue.jsUndefined();
- }
if (this.handlers.is_server) {
globalObject.throw("Cannot issue SNI from a TLS server-side socket", .{});
@@ -1919,9 +1937,20 @@ fn NewSocket(comptime ssl: bool) type {
return .zero;
}
- const slice = server_name.getZigString(globalObject).toSlice(bun.default_allocator);
- defer slice.deinit();
- const host = normalizeHost(slice.slice());
+ const slice = server_name.getZigString(globalObject).toOwnedSlice(bun.default_allocator) catch unreachable;
+ if (this.server_name) |old| {
+ this.server_name = slice;
+ default_allocator.free(old);
+ } else {
+ this.server_name = slice;
+ }
+
+ if (this.detached) {
+ // will be attached onOpen
+ return JSValue.jsUndefined();
+ }
+
+ const host = normalizeHost(@as([]const u8, slice));
if (host.len > 0) {
var ssl_ptr: *BoringSSL.SSL = @ptrCast(*BoringSSL.SSL, this.socket.getNativeHandle());
if (ssl_ptr.isInitFinished()) {
@@ -2040,6 +2069,10 @@ fn NewSocket(comptime ssl: bool) type {
.protos = if (protos) |p| (bun.default_allocator.dupe(u8, p[0..protos_len]) catch unreachable) else null,
};
+ if (socket_config.server_name) |server_name| {
+ tls.server_name = bun.default_allocator.dupe(u8, server_name[0..bun.len(server_name)]) catch unreachable;
+ }
+
var tls_js_value = tls.getThisValue(globalObject);
TLSSocket.dataSetCached(tls_js_value, globalObject, default_data);
@@ -2061,6 +2094,7 @@ fn NewSocket(comptime ssl: bool) type {
bun.default_allocator.destroy(tls);
return JSValue.jsUndefined();
};
+
tls.socket = new_socket;
var raw = handlers.vm.allocator.create(TLSSocket) catch @panic("OOM");