diff options
-rw-r--r-- | src/bun.js/api/bun/socket.zig | 48 | ||||
-rw-r--r-- | test/js/third_party/postgres/package.json | 8 | ||||
-rw-r--r-- | test/js/third_party/postgres/postgres.test.ts | 56 | ||||
-rw-r--r-- | test/package.json | 5 |
4 files changed, 109 insertions, 8 deletions
diff --git a/src/bun.js/api/bun/socket.zig b/src/bun.js/api/bun/socket.zig index 329cc40e4..12a4cffc8 100644 --- a/src/bun.js/api/bun/socket.zig +++ b/src/bun.js/api/bun/socket.zig @@ -820,6 +820,7 @@ pub const Listener = struct { var default_data = socket_config.default_data; var protos: ?[]const u8 = null; + var server_name: ?[]const u8 = null; const ssl_enabled = ssl != null; defer if (ssl != null) ssl.?.deinit(); @@ -840,6 +841,9 @@ pub const Listener = struct { if (ssl.?.protos) |p| { protos = p[0..ssl.?.protos_len]; } + if (ssl.?.server_name) |s| { + server_name = bun.default_allocator.dupe(u8, s[0..bun.len(s)]) catch unreachable; + } uws.NewSocketHandler(true).configure( socket_context, true, @@ -892,6 +896,7 @@ pub const Listener = struct { .socket = undefined, .connection = connection, .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p) catch unreachable) else null, + .server_name = server_name, }; TLSSocket.dataSetCached(tls.getThisValue(globalObject), globalObject, default_data); @@ -916,6 +921,7 @@ pub const Listener = struct { .socket = undefined, .connection = null, .protos = null, + .server_name = null, }; TCPSocket.dataSetCached(tcp.getThisValue(globalObject), globalObject, default_data); @@ -987,6 +993,7 @@ fn NewSocket(comptime ssl: bool) type { connection: ?Listener.UnixOrHost = null, protos: ?[]const u8, owned_protos: bool = true, + server_name: ?[]const u8 = null, // TODO: switch to something that uses `visitAggregate` and have the // `Listener` keep a list of all the sockets JSValue in there @@ -1170,7 +1177,14 @@ fn NewSocket(comptime ssl: bool) type { if (comptime ssl) { var ssl_ptr: *BoringSSL.SSL = @ptrCast(*BoringSSL.SSL, socket.getNativeHandle()); if (!ssl_ptr.isInitFinished()) { - if (this.connection) |connection| { + if (this.server_name) |server_name| { + const host = normalizeHost(server_name); + if (host.len > 0) { + var host__ = default_allocator.dupeZ(u8, host) catch unreachable; + defer default_allocator.free(host__); + ssl_ptr.setHostname(host__); + } + } else if (this.connection) |connection| { if (connection == .host) { const host = normalizeHost(connection.host.host); if (host.len > 0) { @@ -1248,6 +1262,8 @@ fn NewSocket(comptime ssl: bool) type { pub fn onEnd(this: *This, socket: Socket) void { JSC.markBinding(@src()); log("onEnd", .{}); + if (this.detached) return; + this.detached = true; defer this.markInactive(); @@ -1815,6 +1831,11 @@ fn NewSocket(comptime ssl: bool) type { } } + if (this.server_name) |server_name| { + this.server_name = null; + default_allocator.free(server_name); + } + if (this.connection) |connection| { this.connection = null; connection.deinit(); @@ -1898,9 +1919,6 @@ fn NewSocket(comptime ssl: bool) type { if (comptime ssl == false) { return JSValue.jsUndefined(); } - if (this.detached) { - return JSValue.jsUndefined(); - } if (this.handlers.is_server) { globalObject.throw("Cannot issue SNI from a TLS server-side socket", .{}); @@ -1919,9 +1937,20 @@ fn NewSocket(comptime ssl: bool) type { return .zero; } - const slice = server_name.getZigString(globalObject).toSlice(bun.default_allocator); - defer slice.deinit(); - const host = normalizeHost(slice.slice()); + const slice = server_name.getZigString(globalObject).toOwnedSlice(bun.default_allocator) catch unreachable; + if (this.server_name) |old| { + this.server_name = slice; + default_allocator.free(old); + } else { + this.server_name = slice; + } + + if (this.detached) { + // will be attached onOpen + return JSValue.jsUndefined(); + } + + const host = normalizeHost(@as([]const u8, slice)); if (host.len > 0) { var ssl_ptr: *BoringSSL.SSL = @ptrCast(*BoringSSL.SSL, this.socket.getNativeHandle()); if (ssl_ptr.isInitFinished()) { @@ -2040,6 +2069,10 @@ fn NewSocket(comptime ssl: bool) type { .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p[0..protos_len]) catch unreachable) else null, }; + if (socket_config.server_name) |server_name| { + tls.server_name = bun.default_allocator.dupe(u8, server_name[0..bun.len(server_name)]) catch unreachable; + } + var tls_js_value = tls.getThisValue(globalObject); TLSSocket.dataSetCached(tls_js_value, globalObject, default_data); @@ -2061,6 +2094,7 @@ fn NewSocket(comptime ssl: bool) type { bun.default_allocator.destroy(tls); return JSValue.jsUndefined(); }; + tls.socket = new_socket; var raw = handlers.vm.allocator.create(TLSSocket) catch @panic("OOM"); diff --git a/test/js/third_party/postgres/package.json b/test/js/third_party/postgres/package.json new file mode 100644 index 000000000..90809f48f --- /dev/null +++ b/test/js/third_party/postgres/package.json @@ -0,0 +1,8 @@ +{ + "name": "postgres", + "dependencies": { + "pg": "8.11.1", + "postgres": "3.3.5", + "pg-connection-string": "2.6.1" + } +} diff --git a/test/js/third_party/postgres/postgres.test.ts b/test/js/third_party/postgres/postgres.test.ts new file mode 100644 index 000000000..490192ae7 --- /dev/null +++ b/test/js/third_party/postgres/postgres.test.ts @@ -0,0 +1,56 @@ +import { test, expect, describe } from "bun:test"; +import { Pool } from "pg"; +import { parse } from "pg-connection-string"; +import postgres from "postgres"; + +const CONNECTION_STRING = process.env.TLS_POSTGRES_DATABASE_URL; + +const it = CONNECTION_STRING ? test : test.skip; + +describe("pg", () => { + it("should connect using TLS", async () => { + const pool = new Pool(parse(CONNECTION_STRING as string)); + try { + const { rows } = await pool.query("SELECT version()", []); + const [{ version }] = rows; + + expect(version).toMatch(/PostgreSQL/); + } finally { + pool.end(); + } + }); +}); + +describe("postgres", () => { + it("should connect using TLS", async () => { + const sql = postgres(CONNECTION_STRING as string); + try { + const [{ version }] = await sql`SELECT version()`; + expect(version).toMatch(/PostgreSQL/); + } finally { + sql.end(); + } + }); + + it("should insert, select and delete", async () => { + const sql = postgres(CONNECTION_STRING as string); + try { + await sql`CREATE TABLE IF NOT EXISTS users ( + user_id serial PRIMARY KEY, + username VARCHAR ( 50 ) NOT NULL + );`; + + const [{ user_id, username }] = await sql`insert into users (username) values ('bun') returning *`; + expect(username).toBe("bun"); + + const [{ user_id: user_id2, username: username2 }] = await sql`select * from users where user_id = ${user_id}`; + expect(username2).toBe("bun"); + expect(user_id2).toBe(user_id); + + const [{ username: username3 }] = await sql`delete from users where user_id = ${user_id} returning *`; + expect(username3).toBe("bun"); + } finally { + sql.end(); + } + }); +}); diff --git a/test/package.json b/test/package.json index d26875126..5529d3f20 100644 --- a/test/package.json +++ b/test/package.json @@ -26,7 +26,10 @@ "undici": "5.20.0", "vitest": "0.32.2", "webpack": "5.88.0", - "webpack-cli": "4.7.2" + "webpack-cli": "4.7.2", + "pg": "8.11.1", + "postgres": "3.3.5", + "pg-connection-string": "2.6.1" }, "private": true, "scripts": { |