diff options
| -rw-r--r-- | src/bun.js/api/bun/socket.zig | 48 | ||||
| -rw-r--r-- | test/js/third_party/postgres/package.json | 8 | ||||
| -rw-r--r-- | test/js/third_party/postgres/postgres.test.ts | 56 | ||||
| -rw-r--r-- | test/package.json | 5 | 
4 files changed, 109 insertions, 8 deletions
| diff --git a/src/bun.js/api/bun/socket.zig b/src/bun.js/api/bun/socket.zig index 329cc40e4..12a4cffc8 100644 --- a/src/bun.js/api/bun/socket.zig +++ b/src/bun.js/api/bun/socket.zig @@ -820,6 +820,7 @@ pub const Listener = struct {          var default_data = socket_config.default_data;          var protos: ?[]const u8 = null; +        var server_name: ?[]const u8 = null;          const ssl_enabled = ssl != null;          defer if (ssl != null) ssl.?.deinit(); @@ -840,6 +841,9 @@ pub const Listener = struct {              if (ssl.?.protos) |p| {                  protos = p[0..ssl.?.protos_len];              } +            if (ssl.?.server_name) |s| { +                server_name = bun.default_allocator.dupe(u8, s[0..bun.len(s)]) catch unreachable; +            }              uws.NewSocketHandler(true).configure(                  socket_context,                  true, @@ -892,6 +896,7 @@ pub const Listener = struct {                  .socket = undefined,                  .connection = connection,                  .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p) catch unreachable) else null, +                .server_name = server_name,              };              TLSSocket.dataSetCached(tls.getThisValue(globalObject), globalObject, default_data); @@ -916,6 +921,7 @@ pub const Listener = struct {                  .socket = undefined,                  .connection = null,                  .protos = null, +                .server_name = null,              };              TCPSocket.dataSetCached(tcp.getThisValue(globalObject), globalObject, default_data); @@ -987,6 +993,7 @@ fn NewSocket(comptime ssl: bool) type {          connection: ?Listener.UnixOrHost = null,          protos: ?[]const u8,          owned_protos: bool = true, +        server_name: ?[]const u8 = null,          // TODO: switch to something that uses `visitAggregate` and have the          // `Listener` keep a list of all the sockets JSValue in there @@ -1170,7 +1177,14 @@ fn NewSocket(comptime ssl: bool) type {              if (comptime ssl) {                  var ssl_ptr: *BoringSSL.SSL = @ptrCast(*BoringSSL.SSL, socket.getNativeHandle());                  if (!ssl_ptr.isInitFinished()) { -                    if (this.connection) |connection| { +                    if (this.server_name) |server_name| { +                        const host = normalizeHost(server_name); +                        if (host.len > 0) { +                            var host__ = default_allocator.dupeZ(u8, host) catch unreachable; +                            defer default_allocator.free(host__); +                            ssl_ptr.setHostname(host__); +                        } +                    } else if (this.connection) |connection| {                          if (connection == .host) {                              const host = normalizeHost(connection.host.host);                              if (host.len > 0) { @@ -1248,6 +1262,8 @@ fn NewSocket(comptime ssl: bool) type {          pub fn onEnd(this: *This, socket: Socket) void {              JSC.markBinding(@src());              log("onEnd", .{}); +            if (this.detached) return; +              this.detached = true;              defer this.markInactive(); @@ -1815,6 +1831,11 @@ fn NewSocket(comptime ssl: bool) type {                  }              } +            if (this.server_name) |server_name| { +                this.server_name = null; +                default_allocator.free(server_name); +            } +              if (this.connection) |connection| {                  this.connection = null;                  connection.deinit(); @@ -1898,9 +1919,6 @@ fn NewSocket(comptime ssl: bool) type {              if (comptime ssl == false) {                  return JSValue.jsUndefined();              } -            if (this.detached) { -                return JSValue.jsUndefined(); -            }              if (this.handlers.is_server) {                  globalObject.throw("Cannot issue SNI from a TLS server-side socket", .{}); @@ -1919,9 +1937,20 @@ fn NewSocket(comptime ssl: bool) type {                  return .zero;              } -            const slice = server_name.getZigString(globalObject).toSlice(bun.default_allocator); -            defer slice.deinit(); -            const host = normalizeHost(slice.slice()); +            const slice = server_name.getZigString(globalObject).toOwnedSlice(bun.default_allocator) catch unreachable; +            if (this.server_name) |old| { +                this.server_name = slice; +                default_allocator.free(old); +            } else { +                this.server_name = slice; +            } + +            if (this.detached) { +                // will be attached onOpen +                return JSValue.jsUndefined(); +            } + +            const host = normalizeHost(@as([]const u8, slice));              if (host.len > 0) {                  var ssl_ptr: *BoringSSL.SSL = @ptrCast(*BoringSSL.SSL, this.socket.getNativeHandle());                  if (ssl_ptr.isInitFinished()) { @@ -2040,6 +2069,10 @@ fn NewSocket(comptime ssl: bool) type {                  .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p[0..protos_len]) catch unreachable) else null,              }; +            if (socket_config.server_name) |server_name| { +                tls.server_name = bun.default_allocator.dupe(u8, server_name[0..bun.len(server_name)]) catch unreachable; +            } +              var tls_js_value = tls.getThisValue(globalObject);              TLSSocket.dataSetCached(tls_js_value, globalObject, default_data); @@ -2061,6 +2094,7 @@ fn NewSocket(comptime ssl: bool) type {                  bun.default_allocator.destroy(tls);                  return JSValue.jsUndefined();              }; +              tls.socket = new_socket;              var raw = handlers.vm.allocator.create(TLSSocket) catch @panic("OOM"); diff --git a/test/js/third_party/postgres/package.json b/test/js/third_party/postgres/package.json new file mode 100644 index 000000000..90809f48f --- /dev/null +++ b/test/js/third_party/postgres/package.json @@ -0,0 +1,8 @@ +{ +  "name": "postgres", +  "dependencies": { +    "pg": "8.11.1",  +    "postgres": "3.3.5", +    "pg-connection-string": "2.6.1" +  } +} diff --git a/test/js/third_party/postgres/postgres.test.ts b/test/js/third_party/postgres/postgres.test.ts new file mode 100644 index 000000000..490192ae7 --- /dev/null +++ b/test/js/third_party/postgres/postgres.test.ts @@ -0,0 +1,56 @@ +import { test, expect, describe } from "bun:test"; +import { Pool } from "pg"; +import { parse } from "pg-connection-string"; +import postgres from "postgres"; + +const CONNECTION_STRING = process.env.TLS_POSTGRES_DATABASE_URL; + +const it = CONNECTION_STRING ? test : test.skip; + +describe("pg", () => { +  it("should connect using TLS", async () => { +    const pool = new Pool(parse(CONNECTION_STRING as string)); +    try { +      const { rows } = await pool.query("SELECT version()", []); +      const [{ version }] = rows; + +      expect(version).toMatch(/PostgreSQL/); +    } finally { +      pool.end(); +    } +  }); +}); + +describe("postgres", () => { +  it("should connect using TLS", async () => { +    const sql = postgres(CONNECTION_STRING as string); +    try { +      const [{ version }] = await sql`SELECT version()`; +      expect(version).toMatch(/PostgreSQL/); +    } finally { +      sql.end(); +    } +  }); + +  it("should insert, select and delete", async () => { +    const sql = postgres(CONNECTION_STRING as string); +    try { +      await sql`CREATE TABLE IF NOT EXISTS users ( +            user_id serial PRIMARY KEY, +            username VARCHAR ( 50 ) NOT NULL +        );`; + +      const [{ user_id, username }] = await sql`insert into users (username) values ('bun') returning *`; +      expect(username).toBe("bun"); + +      const [{ user_id: user_id2, username: username2 }] = await sql`select * from users where user_id = ${user_id}`; +      expect(username2).toBe("bun"); +      expect(user_id2).toBe(user_id); + +      const [{ username: username3 }] = await sql`delete from users where user_id = ${user_id} returning *`; +      expect(username3).toBe("bun"); +    } finally { +      sql.end(); +    } +  }); +}); diff --git a/test/package.json b/test/package.json index d26875126..5529d3f20 100644 --- a/test/package.json +++ b/test/package.json @@ -26,7 +26,10 @@      "undici": "5.20.0",      "vitest": "0.32.2",      "webpack": "5.88.0", -    "webpack-cli": "4.7.2" +    "webpack-cli": "4.7.2", +    "pg": "8.11.1",  +    "postgres": "3.3.5", +    "pg-connection-string": "2.6.1"    },    "private": true,    "scripts": { | 
