aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Ciro Spaciari <ciro.spaciari@gmail.com> 2023-07-04 19:40:26 -0300
committerGravatar GitHub <noreply@github.com> 2023-07-04 15:40:26 -0700
commit979e99940374289e21332a0a7214889648e7397b (patch)
tree5b17ef26b04df20c7a590a3f3cd28c8f52e622fd
parentc2755f770cb0d6296c82a6f6d633d62307449028 (diff)
downloadbun-979e99940374289e21332a0a7214889648e7397b.tar.gz
bun-979e99940374289e21332a0a7214889648e7397b.tar.zst
bun-979e99940374289e21332a0a7214889648e7397b.zip
[tls] fix servername (#3513)
* fix servername * add postgres tls tests * update test packages * add basic CRUD test
-rw-r--r--src/bun.js/api/bun/socket.zig48
-rw-r--r--test/js/third_party/postgres/package.json8
-rw-r--r--test/js/third_party/postgres/postgres.test.ts56
-rw-r--r--test/package.json5
4 files changed, 109 insertions, 8 deletions
diff --git a/src/bun.js/api/bun/socket.zig b/src/bun.js/api/bun/socket.zig
index 329cc40e4..12a4cffc8 100644
--- a/src/bun.js/api/bun/socket.zig
+++ b/src/bun.js/api/bun/socket.zig
@@ -820,6 +820,7 @@ pub const Listener = struct {
var default_data = socket_config.default_data;
var protos: ?[]const u8 = null;
+ var server_name: ?[]const u8 = null;
const ssl_enabled = ssl != null;
defer if (ssl != null) ssl.?.deinit();
@@ -840,6 +841,9 @@ pub const Listener = struct {
if (ssl.?.protos) |p| {
protos = p[0..ssl.?.protos_len];
}
+ if (ssl.?.server_name) |s| {
+ server_name = bun.default_allocator.dupe(u8, s[0..bun.len(s)]) catch unreachable;
+ }
uws.NewSocketHandler(true).configure(
socket_context,
true,
@@ -892,6 +896,7 @@ pub const Listener = struct {
.socket = undefined,
.connection = connection,
.protos = if (protos) |p| (bun.default_allocator.dupe(u8, p) catch unreachable) else null,
+ .server_name = server_name,
};
TLSSocket.dataSetCached(tls.getThisValue(globalObject), globalObject, default_data);
@@ -916,6 +921,7 @@ pub const Listener = struct {
.socket = undefined,
.connection = null,
.protos = null,
+ .server_name = null,
};
TCPSocket.dataSetCached(tcp.getThisValue(globalObject), globalObject, default_data);
@@ -987,6 +993,7 @@ fn NewSocket(comptime ssl: bool) type {
connection: ?Listener.UnixOrHost = null,
protos: ?[]const u8,
owned_protos: bool = true,
+ server_name: ?[]const u8 = null,
// TODO: switch to something that uses `visitAggregate` and have the
// `Listener` keep a list of all the sockets JSValue in there
@@ -1170,7 +1177,14 @@ fn NewSocket(comptime ssl: bool) type {
if (comptime ssl) {
var ssl_ptr: *BoringSSL.SSL = @ptrCast(*BoringSSL.SSL, socket.getNativeHandle());
if (!ssl_ptr.isInitFinished()) {
- if (this.connection) |connection| {
+ if (this.server_name) |server_name| {
+ const host = normalizeHost(server_name);
+ if (host.len > 0) {
+ var host__ = default_allocator.dupeZ(u8, host) catch unreachable;
+ defer default_allocator.free(host__);
+ ssl_ptr.setHostname(host__);
+ }
+ } else if (this.connection) |connection| {
if (connection == .host) {
const host = normalizeHost(connection.host.host);
if (host.len > 0) {
@@ -1248,6 +1262,8 @@ fn NewSocket(comptime ssl: bool) type {
pub fn onEnd(this: *This, socket: Socket) void {
JSC.markBinding(@src());
log("onEnd", .{});
+ if (this.detached) return;
+
this.detached = true;
defer this.markInactive();
@@ -1815,6 +1831,11 @@ fn NewSocket(comptime ssl: bool) type {
}
}
+ if (this.server_name) |server_name| {
+ this.server_name = null;
+ default_allocator.free(server_name);
+ }
+
if (this.connection) |connection| {
this.connection = null;
connection.deinit();
@@ -1898,9 +1919,6 @@ fn NewSocket(comptime ssl: bool) type {
if (comptime ssl == false) {
return JSValue.jsUndefined();
}
- if (this.detached) {
- return JSValue.jsUndefined();
- }
if (this.handlers.is_server) {
globalObject.throw("Cannot issue SNI from a TLS server-side socket", .{});
@@ -1919,9 +1937,20 @@ fn NewSocket(comptime ssl: bool) type {
return .zero;
}
- const slice = server_name.getZigString(globalObject).toSlice(bun.default_allocator);
- defer slice.deinit();
- const host = normalizeHost(slice.slice());
+ const slice = server_name.getZigString(globalObject).toOwnedSlice(bun.default_allocator) catch unreachable;
+ if (this.server_name) |old| {
+ this.server_name = slice;
+ default_allocator.free(old);
+ } else {
+ this.server_name = slice;
+ }
+
+ if (this.detached) {
+ // will be attached onOpen
+ return JSValue.jsUndefined();
+ }
+
+ const host = normalizeHost(@as([]const u8, slice));
if (host.len > 0) {
var ssl_ptr: *BoringSSL.SSL = @ptrCast(*BoringSSL.SSL, this.socket.getNativeHandle());
if (ssl_ptr.isInitFinished()) {
@@ -2040,6 +2069,10 @@ fn NewSocket(comptime ssl: bool) type {
.protos = if (protos) |p| (bun.default_allocator.dupe(u8, p[0..protos_len]) catch unreachable) else null,
};
+ if (socket_config.server_name) |server_name| {
+ tls.server_name = bun.default_allocator.dupe(u8, server_name[0..bun.len(server_name)]) catch unreachable;
+ }
+
var tls_js_value = tls.getThisValue(globalObject);
TLSSocket.dataSetCached(tls_js_value, globalObject, default_data);
@@ -2061,6 +2094,7 @@ fn NewSocket(comptime ssl: bool) type {
bun.default_allocator.destroy(tls);
return JSValue.jsUndefined();
};
+
tls.socket = new_socket;
var raw = handlers.vm.allocator.create(TLSSocket) catch @panic("OOM");
diff --git a/test/js/third_party/postgres/package.json b/test/js/third_party/postgres/package.json
new file mode 100644
index 000000000..90809f48f
--- /dev/null
+++ b/test/js/third_party/postgres/package.json
@@ -0,0 +1,8 @@
+{
+ "name": "postgres",
+ "dependencies": {
+ "pg": "8.11.1",
+ "postgres": "3.3.5",
+ "pg-connection-string": "2.6.1"
+ }
+}
diff --git a/test/js/third_party/postgres/postgres.test.ts b/test/js/third_party/postgres/postgres.test.ts
new file mode 100644
index 000000000..490192ae7
--- /dev/null
+++ b/test/js/third_party/postgres/postgres.test.ts
@@ -0,0 +1,56 @@
+import { test, expect, describe } from "bun:test";
+import { Pool } from "pg";
+import { parse } from "pg-connection-string";
+import postgres from "postgres";
+
+const CONNECTION_STRING = process.env.TLS_POSTGRES_DATABASE_URL;
+
+const it = CONNECTION_STRING ? test : test.skip;
+
+describe("pg", () => {
+ it("should connect using TLS", async () => {
+ const pool = new Pool(parse(CONNECTION_STRING as string));
+ try {
+ const { rows } = await pool.query("SELECT version()", []);
+ const [{ version }] = rows;
+
+ expect(version).toMatch(/PostgreSQL/);
+ } finally {
+ pool.end();
+ }
+ });
+});
+
+describe("postgres", () => {
+ it("should connect using TLS", async () => {
+ const sql = postgres(CONNECTION_STRING as string);
+ try {
+ const [{ version }] = await sql`SELECT version()`;
+ expect(version).toMatch(/PostgreSQL/);
+ } finally {
+ sql.end();
+ }
+ });
+
+ it("should insert, select and delete", async () => {
+ const sql = postgres(CONNECTION_STRING as string);
+ try {
+ await sql`CREATE TABLE IF NOT EXISTS users (
+ user_id serial PRIMARY KEY,
+ username VARCHAR ( 50 ) NOT NULL
+ );`;
+
+ const [{ user_id, username }] = await sql`insert into users (username) values ('bun') returning *`;
+ expect(username).toBe("bun");
+
+ const [{ user_id: user_id2, username: username2 }] = await sql`select * from users where user_id = ${user_id}`;
+ expect(username2).toBe("bun");
+ expect(user_id2).toBe(user_id);
+
+ const [{ username: username3 }] = await sql`delete from users where user_id = ${user_id} returning *`;
+ expect(username3).toBe("bun");
+ } finally {
+ sql.end();
+ }
+ });
+});
diff --git a/test/package.json b/test/package.json
index d26875126..5529d3f20 100644
--- a/test/package.json
+++ b/test/package.json
@@ -26,7 +26,10 @@
"undici": "5.20.0",
"vitest": "0.32.2",
"webpack": "5.88.0",
- "webpack-cli": "4.7.2"
+ "webpack-cli": "4.7.2",
+ "pg": "8.11.1",
+ "postgres": "3.3.5",
+ "pg-connection-string": "2.6.1"
},
"private": true,
"scripts": {