diff options
| author | 2023-07-04 19:40:26 -0300 | |
|---|---|---|
| committer | 2023-07-04 15:40:26 -0700 | |
| commit | 979e99940374289e21332a0a7214889648e7397b (patch) | |
| tree | 5b17ef26b04df20c7a590a3f3cd28c8f52e622fd /src | |
| parent | c2755f770cb0d6296c82a6f6d633d62307449028 (diff) | |
| download | bun-979e99940374289e21332a0a7214889648e7397b.tar.gz bun-979e99940374289e21332a0a7214889648e7397b.tar.zst bun-979e99940374289e21332a0a7214889648e7397b.zip | |
[tls] fix servername (#3513)
* fix servername
* add postgres tls tests
* update test packages
* add basic CRUD test
Diffstat (limited to 'src')
| -rw-r--r-- | src/bun.js/api/bun/socket.zig | 48 | 
1 files changed, 41 insertions, 7 deletions
| diff --git a/src/bun.js/api/bun/socket.zig b/src/bun.js/api/bun/socket.zig index 329cc40e4..12a4cffc8 100644 --- a/src/bun.js/api/bun/socket.zig +++ b/src/bun.js/api/bun/socket.zig @@ -820,6 +820,7 @@ pub const Listener = struct {          var default_data = socket_config.default_data;          var protos: ?[]const u8 = null; +        var server_name: ?[]const u8 = null;          const ssl_enabled = ssl != null;          defer if (ssl != null) ssl.?.deinit(); @@ -840,6 +841,9 @@ pub const Listener = struct {              if (ssl.?.protos) |p| {                  protos = p[0..ssl.?.protos_len];              } +            if (ssl.?.server_name) |s| { +                server_name = bun.default_allocator.dupe(u8, s[0..bun.len(s)]) catch unreachable; +            }              uws.NewSocketHandler(true).configure(                  socket_context,                  true, @@ -892,6 +896,7 @@ pub const Listener = struct {                  .socket = undefined,                  .connection = connection,                  .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p) catch unreachable) else null, +                .server_name = server_name,              };              TLSSocket.dataSetCached(tls.getThisValue(globalObject), globalObject, default_data); @@ -916,6 +921,7 @@ pub const Listener = struct {                  .socket = undefined,                  .connection = null,                  .protos = null, +                .server_name = null,              };              TCPSocket.dataSetCached(tcp.getThisValue(globalObject), globalObject, default_data); @@ -987,6 +993,7 @@ fn NewSocket(comptime ssl: bool) type {          connection: ?Listener.UnixOrHost = null,          protos: ?[]const u8,          owned_protos: bool = true, +        server_name: ?[]const u8 = null,          // TODO: switch to something that uses `visitAggregate` and have the          // `Listener` keep a list of all the sockets JSValue in there @@ -1170,7 +1177,14 @@ fn NewSocket(comptime ssl: bool) type {              if (comptime ssl) {                  var ssl_ptr: *BoringSSL.SSL = @ptrCast(*BoringSSL.SSL, socket.getNativeHandle());                  if (!ssl_ptr.isInitFinished()) { -                    if (this.connection) |connection| { +                    if (this.server_name) |server_name| { +                        const host = normalizeHost(server_name); +                        if (host.len > 0) { +                            var host__ = default_allocator.dupeZ(u8, host) catch unreachable; +                            defer default_allocator.free(host__); +                            ssl_ptr.setHostname(host__); +                        } +                    } else if (this.connection) |connection| {                          if (connection == .host) {                              const host = normalizeHost(connection.host.host);                              if (host.len > 0) { @@ -1248,6 +1262,8 @@ fn NewSocket(comptime ssl: bool) type {          pub fn onEnd(this: *This, socket: Socket) void {              JSC.markBinding(@src());              log("onEnd", .{}); +            if (this.detached) return; +              this.detached = true;              defer this.markInactive(); @@ -1815,6 +1831,11 @@ fn NewSocket(comptime ssl: bool) type {                  }              } +            if (this.server_name) |server_name| { +                this.server_name = null; +                default_allocator.free(server_name); +            } +              if (this.connection) |connection| {                  this.connection = null;                  connection.deinit(); @@ -1898,9 +1919,6 @@ fn NewSocket(comptime ssl: bool) type {              if (comptime ssl == false) {                  return JSValue.jsUndefined();              } -            if (this.detached) { -                return JSValue.jsUndefined(); -            }              if (this.handlers.is_server) {                  globalObject.throw("Cannot issue SNI from a TLS server-side socket", .{}); @@ -1919,9 +1937,20 @@ fn NewSocket(comptime ssl: bool) type {                  return .zero;              } -            const slice = server_name.getZigString(globalObject).toSlice(bun.default_allocator); -            defer slice.deinit(); -            const host = normalizeHost(slice.slice()); +            const slice = server_name.getZigString(globalObject).toOwnedSlice(bun.default_allocator) catch unreachable; +            if (this.server_name) |old| { +                this.server_name = slice; +                default_allocator.free(old); +            } else { +                this.server_name = slice; +            } + +            if (this.detached) { +                // will be attached onOpen +                return JSValue.jsUndefined(); +            } + +            const host = normalizeHost(@as([]const u8, slice));              if (host.len > 0) {                  var ssl_ptr: *BoringSSL.SSL = @ptrCast(*BoringSSL.SSL, this.socket.getNativeHandle());                  if (ssl_ptr.isInitFinished()) { @@ -2040,6 +2069,10 @@ fn NewSocket(comptime ssl: bool) type {                  .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p[0..protos_len]) catch unreachable) else null,              }; +            if (socket_config.server_name) |server_name| { +                tls.server_name = bun.default_allocator.dupe(u8, server_name[0..bun.len(server_name)]) catch unreachable; +            } +              var tls_js_value = tls.getThisValue(globalObject);              TLSSocket.dataSetCached(tls_js_value, globalObject, default_data); @@ -2061,6 +2094,7 @@ fn NewSocket(comptime ssl: bool) type {                  bun.default_allocator.destroy(tls);                  return JSValue.jsUndefined();              }; +              tls.socket = new_socket;              var raw = handlers.vm.allocator.create(TLSSocket) catch @panic("OOM"); | 
