diff options
Diffstat (limited to 'src/sql/postgres.zig')
-rw-r--r-- | src/sql/postgres.zig | 387 |
1 files changed, 230 insertions, 157 deletions
diff --git a/src/sql/postgres.zig b/src/sql/postgres.zig index bcd8225de..203324445 100644 --- a/src/sql/postgres.zig +++ b/src/sql/postgres.zig @@ -18,7 +18,7 @@ const Data = union(enum) { } pub fn deinit(this: *@This()) void { - switch (this) { + switch (this.*) { .owned => this.owned.deinitWithAllocator(bun.default_allocator), .temporary => {}, .empty => {}, @@ -36,7 +36,7 @@ const Data = union(enum) { pub fn sliceZ(this: @This()) [:0]const u8 { return switch (this) { .owned => this.owned.slice()[0..this.owned.len :0], - .temporary => this.temporary, + .temporary => this.temporary[0..this.temporary.len :0], .empty => "", }; } @@ -70,6 +70,10 @@ pub const protocol = struct { this.message_start.* = this.offset.*; } + pub fn ensureLength(this: @This(), length: usize) bool { + return this.buffer.len >= (this.offset.* + length); + } + pub fn init(buffer: []const u8, offset: *usize, message_start: *usize) protocol.NewReader(StackReader) { return .{ .wrapped = .{ @@ -145,7 +149,7 @@ pub const protocol = struct { } pub fn @"f64"(this: @This(), value: f64) !void { - try this.write(std.mem.asBytes(&@byteSwap(value))); + try this.write(std.mem.asBytes(&@byteSwap(@as(u64, @bitCast(value))))); } pub fn @"i16"(this: @This(), value: i16) !void { @@ -169,7 +173,7 @@ pub const protocol = struct { } pub fn @"null"(this: @This()) !void { - try this.i32(0xFFFFFFFF); + try this.i32(-1); } pub fn String(this: @This(), value: bun.String) !void { @@ -268,8 +272,8 @@ pub const protocol = struct { R: String, pub fn deinit(this: *FieldMessage) void { - switch (this) { - inline else => |message| { + switch (this.*) { + inline else => |*message| { message.deref(); }, } @@ -280,7 +284,7 @@ pub const protocol = struct { while (true) { const field_int = try reader.int(u8); if (field_int == 0) break; - const field: FieldType = @intFromEnum(field_int); + const field: FieldType = @enumFromInt(field_int); var message = try reader.readZ(); defer message.deinit(); @@ -294,24 +298,24 @@ pub const protocol = struct { pub fn init(tag: FieldType, message: []const u8) !FieldMessage { return switch (tag) { - .S => String.create(message), - .V => String.create(message), - .C => String.create(message), - .M => String.create(message), - .D => String.create(message), - .H => String.create(message), - .P => String.create(message), - .p => String.create(message), - .q => String.create(message), - .W => String.create(message), - .s => String.create(message), - .t => String.create(message), - .c => String.create(message), - .d => String.create(message), - .n => String.create(message), - .F => String.create(message), - .L => String.create(message), - .R => String.create(message), + .S => FieldMessage{ .S = String.create(message) }, + .V => FieldMessage{ .V = String.create(message) }, + .C => FieldMessage{ .C = String.create(message) }, + .M => FieldMessage{ .M = String.create(message) }, + .D => FieldMessage{ .D = String.create(message) }, + .H => FieldMessage{ .H = String.create(message) }, + .P => FieldMessage{ .P = String.create(message) }, + .p => FieldMessage{ .p = String.create(message) }, + .q => FieldMessage{ .q = String.create(message) }, + .W => FieldMessage{ .W = String.create(message) }, + .s => FieldMessage{ .s = String.create(message) }, + .t => FieldMessage{ .t = String.create(message) }, + .c => FieldMessage{ .c = String.create(message) }, + .d => FieldMessage{ .d = String.create(message) }, + .n => FieldMessage{ .n = String.create(message) }, + .F => FieldMessage{ .F = String.create(message) }, + .L => FieldMessage{ .L = String.create(message) }, + .R => FieldMessage{ .R = String.create(message) }, else => error.UnknownFieldType, }; } @@ -345,11 +349,11 @@ pub const protocol = struct { return try readFn(this.wrapped, count); } - pub inline fn eatMessage(this: @This(), comptime msg_: []const u8) anyerror!void { + pub inline fn eatMessage(this: @This(), comptime msg_: anytype) anyerror!void { const msg = msg_[1..]; var input = try readFn(this.wrapped, msg.len); defer input.deinit(); - if (bun.strings.eqlLong(input.slice(), msg)) return; + if (bun.strings.eqlComptime(input.slice(), msg)) return; } pub fn skip(this: @This(), count: usize) anyerror!void { @@ -373,7 +377,10 @@ pub const protocol = struct { pub fn int(this: @This(), comptime Int: type) !Int { var data = try this.read(@sizeOf((Int))); defer data.deinit(); - return @byteSwap(@as(Int, data.slice()[0..@sizeOf(Int)].*)); + if (comptime Int == u8) { + return @as(Int, data.slice()[0]); + } + return @byteSwap(@as(Int, @bitCast(data.slice()[0..@sizeOf(Int)].*))); } pub fn peekInt(this: @This(), comptime Int: type) ?Int { @@ -417,7 +424,7 @@ pub const protocol = struct { } pub fn NewReader(comptime Context: type) type { - return NewReaderWrap(Context, Context.markMessageStart, Context.skip, Context.peek, Context.ensureLength, Context.read, Context.readZ); + return NewReaderWrap(Context, Context.markMessageStart, Context.peek, Context.skip, Context.ensureLength, Context.read, Context.readZ); } pub fn NewWriter(comptime Context: type) type { @@ -425,7 +432,7 @@ pub const protocol = struct { } comptime { - if (@import("builtin").cpu.arch.endian() != .little) { + if (@import("builtin").cpu.arch.endian() != .Little) { @compileError("Postgres protocol implementation assumes little endian"); } } @@ -463,7 +470,7 @@ pub const protocol = struct { SSPI: struct {}, SASL: struct { mechanisms: Data, - data: Data, + data: Data = .{ .empty = {} }, }, SASLContinue: struct { data: Data, @@ -495,7 +502,9 @@ pub const protocol = struct { }, 5 => { if (message_length != 12) return error.InvalidMessageLength; - try reader.expectInt(u32, 5); + if (!try reader.expectInt(u32, 5)) { + return error.InvalidMessage; + } var salt_data = try reader.bytes(4); defer salt_data.deinit(); this.* = .{ @@ -513,8 +522,8 @@ pub const protocol = struct { 8 => { if (message_length < 9) return error.InvalidMessageLength; - const remaining = message_length -| (8 - 1); - const bytes = try reader.read(remaining); + const remaining: usize = @intCast(@max(message_length -| (8 - 1), 0)); + const bytes = try reader.read(@intCast(remaining)); this.* = .{ .GSSContinue = .{ .data = bytes, @@ -530,7 +539,7 @@ pub const protocol = struct { 10 => { if (message_length < 9) return error.InvalidMessageLength; - const remaining = message_length -| (8 - 1); + const remaining: usize = @intCast(@max(message_length -| (8 - 1), 0)); const bytes = try reader.read(remaining); this.* = .{ .SASL = .{ @@ -541,7 +550,8 @@ pub const protocol = struct { 11 => { if (message_length < 9) return error.InvalidMessageLength; - const remaining = message_length -| (8 - 1); + const remaining: usize = @intCast(@max(message_length -| (8 - 1), 0)); + const bytes = try reader.read(remaining); this.* = .{ .SASLContinue = .{ @@ -552,7 +562,8 @@ pub const protocol = struct { 12 => { if (message_length < 9) return error.InvalidMessageLength; - const remaining = message_length -| (8 - 1); + const remaining: usize = @intCast(@max(message_length -| (8 - 1), 0)); + const bytes = try reader.read(remaining); this.* = .{ .SASLFinal = .{ @@ -598,11 +609,13 @@ pub const protocol = struct { pub const decode = decoderWrap(BackendKeyData, decodeInternal).decode; pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - try reader.expectInt(u32, 12); + if (!try reader.expectInt(u32, 12)) { + return error.InvalidMessage; + } this.* = .{ - .process_id = try reader.i32(), - .secret_key = try reader.i32(), + .process_id = @bitCast(try reader.i32()), + .secret_key = @bitCast(try reader.i32()), }; } }; @@ -611,7 +624,7 @@ pub const protocol = struct { messages: std.ArrayListUnmanaged(FieldMessage) = .{}, pub fn deinit(this: *ErrorResponse) void { - for (this.messages.items) |message| { + for (this.messages.items) |*message| { message.deinit(); } this.messages.deinit(bun.default_allocator); @@ -688,7 +701,7 @@ pub const protocol = struct { pub const Terminate = [_]u8{'X'} ++ toBytes(Int32(4)); fn Int32(value: anytype) [4]u8 { - return @byteSwap(@as(i32, @intCast(value))); + return @bitCast(@byteSwap(@as(i32, @intCast(value)))); } const toBytes = std.mem.toBytes; @@ -739,11 +752,11 @@ pub const protocol = struct { }; pub const DataRow = struct { - pub fn decode(context: anyopaque, comptime ContextType: type, reader: NewReader(ContextType), comptime forEach: fn (@TypeOf(context), index: i16, bytes: ?*Data) anyerror!bool) anyerror!void { + pub fn decode(context: anytype, comptime ContextType: type, reader: NewReader(ContextType), comptime forEach: fn (@TypeOf(context), index: u32, bytes: ?*Data) anyerror!bool) anyerror!void { var remaining_bytes = try reader.length(); remaining_bytes -|= 4; - var remaining_fields = try reader.i16(); + var remaining_fields: usize = @intCast(@max(try reader.i16(), 0)); for (0..remaining_fields) |index| { const byte_length = try reader.i32(); @@ -751,14 +764,14 @@ pub const protocol = struct { 0 => break, else => { var bytes = try reader.bytes(@intCast(byte_length)); - if (!try forEach(context, index, &bytes)) break; + if (!try forEach(context, @intCast(index), &bytes)) break; }, -1 => { - if (!try forEach(context, index, null)) break; + if (!try forEach(context, @intCast(index), null)) break; }, - std.math.minInt(i32)...-1 => { + std.math.minInt(i32)...-2 => { return error.InvalidMessageLength; }, } @@ -779,7 +792,7 @@ pub const protocol = struct { } pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - const name = try reader.readZ(); + var name = try reader.readZ(); errdefer { name.deinit(); } @@ -787,7 +800,7 @@ pub const protocol = struct { .table_oid = try reader.i32(), .column_index = try reader.i16(), .type_oid = @truncate(try reader.i32()), - .name = try name.toOwned(), + .name = .{ .owned = try name.toOwned() }, }; try reader.skip(12); @@ -799,8 +812,8 @@ pub const protocol = struct { pub const RowDescription = struct { fields: []const FieldDescription = &[_]FieldDescription{}, pub fn deinit(this: *@This()) void { - for (this.fields) |field| { - field.deinit(); + for (this.fields) |*field| { + @constCast(field).deinit(); } bun.default_allocator.free(this.fields); @@ -810,8 +823,11 @@ pub const protocol = struct { var remaining_bytes = try reader.length(); remaining_bytes -|= 4; - const field_count = try reader.i16(); - var fields = try bun.default_allocator.alloc(FieldDescription, field_count); + const field_count: usize = @intCast(@max(try reader.i16(), 0)); + var fields = try bun.default_allocator.alloc( + FieldDescription, + field_count, + ); var remaining = fields; errdefer { for (fields[0 .. field_count - remaining.len]) |*field| { @@ -821,7 +837,7 @@ pub const protocol = struct { bun.default_allocator.free(fields); } while (remaining.len > 0) { - remaining[0] = try FieldDescription.decodeInternal(Container, reader); + try remaining[0].decodeInternal(Container, reader); remaining = remaining[1..]; } this.* = .{ @@ -840,11 +856,11 @@ pub const protocol = struct { remaining_bytes -|= 4; const count = try reader.i16(); - var parameters = try bun.default_allocator.alloc(i32, count); + var parameters = try bun.default_allocator.alloc(i32, @intCast(@max(count, 0))); - var data = try reader.read(count * @sizeOf((i32))); + var data = try reader.read(@as(usize, @intCast(@max(count, 0))) * @sizeOf((i32))); defer data.deinit(); - const input_params: []align(1) const i32 = @ptrCast(data.slice()); + const input_params: []align(1) const i32 = toInt32Slice(i32, data.slice()); for (input_params, parameters) |src, *dest| { dest.* = @byteSwap(src); } @@ -857,6 +873,11 @@ pub const protocol = struct { pub const decode = decoderWrap(ParameterDescription, decodeInternal).decode; }; + // workaround for zig compiler TODO + fn toInt32Slice(comptime Int: type, slice: []const u8) []align(1) const Int { + return @as([*]align(1) const Int, @ptrCast(slice.ptr))[0 .. slice.len / @sizeOf((Int))]; + } + pub const NotificationResponse = struct { pid: i32 = 0, channel: bun.ByteList = .{}, @@ -884,6 +905,10 @@ pub const protocol = struct { pub const CommandComplete = struct { command_tag: Data = .{ .empty = {} }, + pub fn deinit(this: *@This()) void { + this.command_tag.deinit(); + } + pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { const length = try reader.length(); std.debug.assert(length >= 4); @@ -919,7 +944,7 @@ pub const protocol = struct { try writer.write(&header); try writer.string(this.name); try writer.string(this.query); - try writer.i16(@truncate(parameters.len)); + try writer.i16(@intCast(parameters.len)); for (parameters) |parameter| { try writer.i32(parameter); } @@ -960,7 +985,7 @@ pub const protocol = struct { pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { const length = try reader.length(); - const data = try reader.read(length -| 5); + const data = try reader.read(@intCast(length -| 5)); this.* = .{ .data = data, }; @@ -1059,7 +1084,7 @@ pub const protocol = struct { const count: usize = @sizeOf((i32)) + @sizeOf((i32)) + user.len + 1 + database.len + 1 + options.len + 1; - const header = toBytes(Int32(@truncate(count))); + const header = toBytes(Int32(@as(u32, @truncate(count)))); try writer.write(&header); try writer.i32(196608); if (user.len > 0) @@ -1088,7 +1113,7 @@ pub const protocol = struct { comptime Context: type, writer: NewWriter(Context), ) !void { - const message = this.message.slice(); + const message = this.p.slice(); const count: usize = @sizeOf((u32)) + @sizeOf((u32)) + message.len + 1; const header = [_]u8{ 'E', @@ -1109,8 +1134,8 @@ pub const protocol = struct { comptime Context: type, writer: NewWriter(Context), ) !void { - const message = this.message.slice(); - const count: u32 = @sizeOf((u32)) + @sizeOf((u32)) + message.len + 1; + const message = this.p.slice(); + const count: usize = @sizeOf((u32)) + @sizeOf((u32)) + message.len + 1; const header = [_]u8{ 'D', } ++ toBytes(Int32(count)); @@ -1187,7 +1212,7 @@ pub const protocol = struct { pub const NoticeResponse = struct { messages: std.ArrayListUnmanaged(FieldMessage) = .{}, pub fn deinit(this: *NoticeResponse) void { - for (this.messages.items) |message| { + for (this.messages.items) |*message| { message.deinit(); } this.messages.deinit(bun.default_allocator); @@ -1492,8 +1517,10 @@ pub const PostgresSQLContext = struct { pub fn init(globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) callconv(.C) JSC.JSValue { var ctx = &globalObject.bunVM().rareData().postgresql_context; - ctx.onQueryResolveFn.create(callframe.argument(0), globalObject); - ctx.onQueryRejectFn.create(callframe.argument(1), globalObject); + ctx.onQueryResolveFn.set(globalObject, callframe.argument(0)); + ctx.onQueryRejectFn.set(globalObject, callframe.argument(1)); + + return .undefined; } comptime { @@ -1543,7 +1570,7 @@ pub const PostgresSQLQuery = struct { if (ref_count == 1) { this.deinit(); - bun.default_allocator.free(this); + bun.default_allocator.destroy(this); } } @@ -1559,7 +1586,9 @@ pub const PostgresSQLQuery = struct { } this.deref(); - JSC.VirtualMachine.get().rareData().postgresql_context.onQueryResolveFn.get().?.callWithThis( + var vm = JSC.VirtualMachine.get(); + // TODO: error handling + _ = vm.rareData().postgresql_context.onQueryResolveFn.get().?.callWithThis( globalObject, this.thisValue, &[_]JSC.JSValue{ @@ -1581,13 +1610,18 @@ pub const PostgresSQLQuery = struct { b.allocate(bun.default_allocator) catch {}; for (err.messages.items) |msg| { - _ = b.append(msg); + var str = switch (msg) { + inline else => |m| m.toUTF8(bun.default_allocator), + }; + defer str.deinit(); + _ = b.append(str.slice()); _ = b.append("\n"); } const instance = globalObject.createSyntaxErrorInstance("Postgres error occurred\n{s}", .{b.allocatedSlice()}); b.deinit(bun.default_allocator); this.deref(); + // TODO: error handling _ = JSC.VirtualMachine.get().rareData().postgresql_context.onQueryRejectFn.get().?.callWithThis( globalObject, this.thisValue, @@ -1605,6 +1639,7 @@ pub const PostgresSQLQuery = struct { const pending_value = PostgresSQLQuery.pendingValueGetCached(this.thisValue) orelse JSC.JSValue.undefined; this.deref(); + // TODO: error handling _ = JSC.VirtualMachine.get().rareData().postgresql_context.onQueryResolveFn.get().?.callWithThis( globalObject, this.thisValue, @@ -1616,10 +1651,15 @@ pub const PostgresSQLQuery = struct { pub fn constructor(globalThis: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) callconv(.C) ?*PostgresSQLQuery { _ = callframe; - globalThis.throw("PostgresSQLQuery cannot be constructed directly"); + globalThis.throw("PostgresSQLQuery cannot be constructed directly", .{}); return null; } + pub fn estimatedSize(this: *PostgresSQLQuery) callconv(.C) usize { + _ = this; + return @sizeOf(PostgresSQLQuery); + } + pub fn call(globalThis: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) callconv(.C) JSC.JSValue { const arguments = callframe.arguments(2).slice(); const query = arguments[0]; @@ -1631,7 +1671,7 @@ pub const PostgresSQLQuery = struct { } if (values.jsType() != .Array) { - globalThis.throwTypeError("values must be an array", .{}); + globalThis.throw("values must be an array", .{}); return .zero; } @@ -1640,7 +1680,7 @@ pub const PostgresSQLQuery = struct { return .zero; }; - const this_value = JSC.Codegen.JSPostgresSQLQuery.toJS(globalThis); + const this_value = ptr.toJS(globalThis); this_value.ensureStillAlive(); PostgresSQLQuery.bindingSetCached(this_value, globalThis, values); @@ -1655,7 +1695,7 @@ pub const PostgresSQLQuery = struct { } pub fn push(this: *PostgresSQLQuery, globalThis: *JSC.JSGlobalObject, value: JSC.JSValue) void { - var pending_value = PostgresSQLQuery.pendingValueGetCached(this.thisValue); + var pending_value = PostgresSQLQuery.pendingValueGetCached(this.thisValue) orelse JSC.JSValue.zero; if (pending_value.isEmptyOrUndefinedOrNull()) { pending_value = JSC.JSValue.createEmptyArray(globalThis, 0); PostgresSQLQuery.pendingValueSetCached(this.thisValue, globalThis, pending_value); @@ -1668,7 +1708,7 @@ pub const PostgresSQLQuery = struct { var arguments_ = callframe.arguments(2); const arguments = arguments_.slice(); var connection = arguments[0].as(PostgresSQLConnection) orelse { - globalObject.throwTypeError("connection must be a PostgresSQLConnection"); + globalObject.throw("connection must be a PostgresSQLConnection", .{}); return .zero; }; var query = arguments[1]; @@ -1744,7 +1784,7 @@ pub const PostgresRequest = struct { var iter = JSC.JSArrayIterator.init(values_array, globalObject); - try writer.i16(@truncate(iter.len)); + try writer.i16(@intCast(iter.len)); while (iter.next()) |value| { if (value.isUndefinedOrNull()) { @@ -1752,7 +1792,8 @@ pub const PostgresRequest = struct { continue; } - const tag = types.Tag.fromJS(globalObject, value); + const tag = try types.Tag.fromJS(globalObject, value); + switch (tag) { .bytea, .number => { try writer.i16(0); @@ -1763,11 +1804,11 @@ pub const PostgresRequest = struct { } } - try writer.i16(@truncate(iter.len)); + try writer.i16(@intCast(iter.len)); iter = JSC.JSArrayIterator.init(values_array, globalObject); - debug("Bind: {s} ({d})", .{ name, values_array, iter.len }); + debug("Bind: {s} ({d})", .{ name, iter.len }); while (iter.next()) |value| { if (value.isUndefinedOrNull()) { @@ -1777,7 +1818,7 @@ pub const PostgresRequest = struct { continue; } - const tag = types.Tag.fromJS(globalObject, value); + const tag = try types.Tag.fromJS(globalObject, value); switch (tag) { .number => { debug(" -> {s}", .{@tagName(tag)}); @@ -1811,7 +1852,7 @@ pub const PostgresRequest = struct { if (value.asArrayBuffer(globalObject)) |buf| { bytes = buf.byteSlice(); } - try writer.i32(@truncate(bytes.len)); + try writer.i32(@intCast(bytes.len)); debug(" -> {s}: {d}", .{ @tagName(tag), bytes.len }); try writer.bytes(bytes); @@ -1820,12 +1861,12 @@ pub const PostgresRequest = struct { debug(" -> string", .{}); // TODO: check if this leaks var str = value.toBunString(globalObject); - try writer.str(str); + try writer.String(str); }, } } - try writer.pwrite(&@byteSwap(std.mem.toBytes(@as(i32, @intCast(writer.offset())))), length_offset); + try writer.pwrite(&std.mem.toBytes(@byteSwap(@as(i32, @intCast(writer.offset())))), length_offset); } pub fn writeQuery( @@ -1836,15 +1877,13 @@ pub const PostgresRequest = struct { writer: protocol.NewWriter(Context), ) !void { { - var query_str = query.toUTF8(bun.default_allocator); - defer query_str.deinit(); var q = protocol.Parse{ .name = name, - .paramters = params, + .params = params, .query = query, }; try q.writeInternal(Context, writer); - debug("Parse: {s}", .{query_str}); + debug("Parse: {s}", .{query}); } { @@ -1867,12 +1906,12 @@ pub const PostgresRequest = struct { ) !Signature { var query_ = query.toUTF8(bun.default_allocator); defer query_.deinit(); - var signature = try Signature.generate(globalObject, query_, array_value); + var signature = try Signature.generate(globalObject, query_.slice(), array_value); errdefer { signature.deinit(); } - try writeQuery(query, signature.name, signature.fields, Context, writer); + try writeQuery(query_.slice(), signature.name, signature.fields, Context, writer); try writeBind(signature.name, bun.String.empty, globalObject, array_value, Context, writer); var exec = protocol.Execute{ .p = .{ @@ -1881,7 +1920,7 @@ pub const PostgresRequest = struct { }; try exec.writeInternal(Context, writer); - try writer.write(protocol.Flush); + try writer.write(&protocol.Flush); return signature; } @@ -1901,7 +1940,7 @@ pub const PostgresRequest = struct { }; try exec.writeInternal(Context, writer); - try writer.write(protocol.Flush); + try writer.write(&protocol.Flush); } pub fn onData( @@ -1915,7 +1954,7 @@ pub const PostgresRequest = struct { switch (try reader.int(u8)) { 'D' => try connection.on(.DataRow, Context, reader), 'd' => try connection.on(.CopyData, Context, reader), - 'S' => try connection.on(.ParameterStatus, Context), + 'S' => try connection.on(.ParameterStatus, Context, reader), 'Z' => try connection.on(.ReadyForQuery, Context, reader), 'C' => try connection.on(.CommandComplete, Context, reader), '2' => try connection.on(.BindComplete, Context, reader), @@ -1937,8 +1976,8 @@ pub const PostgresRequest = struct { else => |c| { debug("Unknown message: {d}", .{c}); - const to_skip = try reader.length(); - try reader.skip(to_skip); + const to_skip = try reader.length() -| 1; + try reader.skip(@intCast(@max(to_skip, 0))); }, } } @@ -2009,7 +2048,7 @@ pub const PostgresSQLConnection = struct { .connected => { const on_connect = this.on_connect.swap(); if (on_connect == .zero) return; - on_connect.callWithThis( + _ = on_connect.callWithThis( this.globalObject, this.js_value, &[_]JSC.JSValue{ @@ -2038,15 +2077,15 @@ pub const PostgresSQLConnection = struct { pub fn fail(this: *PostgresSQLConnection, message: []const u8, err: anyerror) void { defer this.updateHasPendingActivity(); if (this.status == .failed) return; - debug("failed: {s}: {s}", .{ message, err }); + debug("failed: {s}: {s}", .{ message, @errorName(err) }); this.status = .failed; if (!this.socket.isClosed()) this.socket.close(); const on_close = this.on_close.swap(); if (on_close == .zero) return; const instance = this.globalObject.createErrorInstance("{s}", .{message}); - instance.put(this.globalObject, &JSC.ZigString.init("code"), bun.String.init(@errorName(err))); - on_close.callWithThis( + instance.put(this.globalObject, &JSC.ZigString.init("code"), bun.String.init(@errorName(err)).toJSConst(this.globalObject)); + _ = on_close.callWithThis( this.globalObject, this.js_value, &[_]JSC.JSValue{ @@ -2167,6 +2206,8 @@ pub const PostgresSQLConnection = struct { .options = options, .options_buf = options_buf, .socket = undefined, + .requests = PostgresRequest.Queue.init(bun.default_allocator), + .statements = PreparedStatementsMap{}, }; ptr.socket = socket: { @@ -2175,22 +2216,22 @@ pub const PostgresSQLConnection = struct { if (tls_object.isEmptyOrUndefinedOrNull()) { var ctx = vm.rareData().postgresql_context.tcp orelse brk: { var ctx_ = uws.us_create_bun_socket_context(0, vm.event_loop_handle, @sizeOf(*PostgresSQLConnection), uws.us_bun_socket_context_options_t{}).?; - uws.NewSocketHandler(false).configure(ctx_, false, *PostgresSQLConnection, SocketHandler(false)); + uws.NewSocketHandler(false).configure(ctx_, false, PostgresSQLConnection, SocketHandler(false)); vm.rareData().postgresql_context.tcp = ctx_; break :brk ctx_; }; break :socket Socket{ - .SocketTCP = uws.SocketTCP.connectAnon(hostname.slice(), port, ctx, *PostgresSQLConnection, ptr) orelse { + .SocketTCP = uws.SocketTCP.connectAnon(hostname.slice(), port, ctx, ptr) orelse { globalObject.throwError(error.ConnectionFailed, "failed to connect to postgresql"); ptr.deinit(); - return .zero; + return null; }, }; } else { // TODO: globalObject.throwTODO("TLS is not supported yet"); ptr.deinit(); - return .zero; + return null; } }; @@ -2205,34 +2246,40 @@ pub const PostgresSQLConnection = struct { return Socket{ .SocketTLS = s }; } - return Socket{ .Socket = s }; + return Socket{ .SocketTCP = s }; } pub fn onOpen(this: *PostgresSQLConnection, socket: SocketType) void { this.onOpen(_socket(socket)); } pub fn onClose(this: *PostgresSQLConnection, socket: SocketType, _: i32, _: ?*anyopaque) void { - this.onClose(_socket(socket)); + _ = socket; + this.onClose(); } pub fn onEnd(this: *PostgresSQLConnection, socket: SocketType) void { - this.onClose(_socket(socket)); + _ = socket; + this.onClose(); } pub fn onConnectError(this: *PostgresSQLConnection, socket: SocketType, _: i32) void { - this.onClose(_socket(socket)); + _ = socket; + this.onClose(); } pub fn onTimeout(this: *PostgresSQLConnection, socket: SocketType) void { - this.onTimeout(_socket(socket)); + _ = socket; + this.onTimeout(); } pub fn onData(this: *PostgresSQLConnection, socket: SocketType, data: []const u8) void { - this.onData(_socket(socket), data); + _ = socket; + this.onData(data); } pub fn onWritable(this: *PostgresSQLConnection, socket: SocketType) void { - this.onDrain(_socket(socket)); + _ = socket; + this.onDrain(); } }; } @@ -2242,14 +2289,16 @@ pub const PostgresSQLConnection = struct { this.ref_count += 1; } - pub fn doRef(this: *@This(), _: *JSC.JSGlobalObject, _: *JSC.CallFrame) callconv(.C) void { + pub fn doRef(this: *@This(), _: *JSC.JSGlobalObject, _: *JSC.CallFrame) callconv(.C) JSC.JSValue { this.poll_ref.ref(this.globalObject.bunVM()); this.updateHasPendingActivity(); + return .undefined; } - pub fn doUnref(this: *@This(), _: *JSC.JSGlobalObject, _: *JSC.CallFrame) callconv(.C) void { + pub fn doUnref(this: *@This(), _: *JSC.JSGlobalObject, _: *JSC.CallFrame) callconv(.C) JSC.JSValue { this.poll_ref.unref(this.globalObject.bunVM()); this.updateHasPendingActivity(); + return .undefined; } pub fn deref(this: *@This()) void { @@ -2262,16 +2311,18 @@ pub const PostgresSQLConnection = struct { } } - pub fn doClose(this: *@This(), globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) void { + pub fn doClose(this: *@This(), globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) callconv(.C) JSC.JSValue { _ = globalObject; this.disconnect(); this.write_buffer.deinit(bun.default_allocator); + + return .undefined; } pub fn deinit(this: *@This()) void { var iter = this.statements.valueIterator(); - while (iter.next()) |stmt| { - stmt.connection = null; + while (iter.next()) |stmt_ptr| { + var stmt = stmt_ptr.*; stmt.deref(); } this.statements.deinit(bun.default_allocator); @@ -2279,6 +2330,7 @@ pub const PostgresSQLConnection = struct { this.read_buffer.deinit(bun.default_allocator); this.on_close.deinit(); this.on_connect.deinit(); + this.backend_parameters.deinit(); bun.default_allocator.free(this.options_buf); bun.default_allocator.destroy(this); } @@ -2286,8 +2338,8 @@ pub const PostgresSQLConnection = struct { pub fn disconnect(this: *@This()) void { if (this.status == .connected) { this.status = .disconnected; - this.poll_ref.deactivate(this.globalObject.bunVM().event_loop_handle); - this.socket.close(0, null); + this.poll_ref.disable(); + this.socket.close(); } } @@ -2328,9 +2380,11 @@ pub const PostgresSQLConnection = struct { connection: *PostgresSQLConnection, pub fn markMessageStart(this: Reader) void { - this.last_message_start = this.connection.read_buffer.head; + this.connection.last_message_start = this.connection.read_buffer.head; } + pub const ensureLength = ensureCapacity; + pub fn peek(this: Reader) []const u8 { return this.connection.read_buffer.remaining(); } @@ -2377,12 +2431,10 @@ pub const PostgresSQLConnection = struct { globalObject: *JSC.JSGlobalObject, fields: []const protocol.FieldDescription, - pub fn put(this: *const CellPutter, index_: i16, optional_bytes: ?*Data) anyerror!bool { - const index: u32 = @intCast(index_); - + pub fn put(this: *const CellPutter, index: u32, optional_bytes: ?*Data) anyerror!bool { const putDirectOffset = JSC.JSObject.putDirectOffset; var bytes_ = optional_bytes orelse { - putDirectOffset(this.vm, this.object, index, JSC.JSValue.jsNull()); + putDirectOffset(this.object, this.vm, index, JSC.JSValue.jsNull()); return true; }; defer bytes_.deinit(); @@ -2392,40 +2444,40 @@ pub const PostgresSQLConnection = struct { .number => { switch (bytes.len) { 0 => { - putDirectOffset(this.vm, index, JSC.JSValue.jsNull()); + putDirectOffset(this.object, this.vm, index, JSC.JSValue.jsNull()); }, 2 => { - putDirectOffset(this.vm, index, JSC.JSValue.jsNumber(@bitCast(@as(i16, @bitCast(bytes[0..2]))))); + putDirectOffset(this.object, this.vm, index, JSC.JSValue.jsNumber(@as(i32, @as(i16, @bitCast(bytes[0..2].*))))); }, 4 => { - putDirectOffset(this.vm, index, JSC.JSValue.jsNumber(@bitCast(@as(i32, @bitCast(bytes[0..4]))))); + putDirectOffset(this.object, this.vm, index, JSC.JSValue.jsNumber(@as(i32, @bitCast(bytes[0..4].*)))); }, else => { var eight: usize = 0; @memcpy(@as(*[8]u8, @ptrCast(&eight))[0..bytes.len], bytes[0..@min(8, bytes.len)]); eight = @byteSwap(eight); - putDirectOffset(this.vm, index, JSC.JSValue.jsNumber(@bitCast(@as(f64, eight)))); + putDirectOffset(this.object, this.vm, index, JSC.JSValue.jsNumber(@as(f64, @bitCast(eight)))); }, } }, .json => { - var str = bun.String.fromUTF8(bytes.slice()); + var str = bun.String.fromUTF8(bytes); defer str.deref(); - putDirectOffset(this.vm, index, str.toJSForParseJSON(this.globalObject)); + putDirectOffset(this.object, this.vm, index, str.toJSForParseJSON(this.globalObject)); }, .boolean => { - putDirectOffset(this.vm, index, JSC.JSValue.jsBoolean(bytes.len > 0 and bytes[0] == 't')); + putDirectOffset(this.object, this.vm, index, JSC.JSValue.jsBoolean(bytes.len > 0 and bytes[0] == 't')); }, .time, .datetime, .date => { - putDirectOffset(this.vm, index, JSC.JSValue.fromDateString(bytes_.sliceZ())); + putDirectOffset(this.object, this.vm, index, JSC.JSValue.fromDateString(this.globalObject, bytes_.sliceZ())); }, .bytea => { - putDirectOffset(this.vm, index, JSC.JSValue.createBuffer(this.globalObject, bytes, null)); + putDirectOffset(this.object, this.vm, index, JSC.ArrayBuffer.createBuffer(this.globalObject, bytes)); }, else => { - var str = bun.String.fromUTF8(bytes.slice()); + var str = bun.String.fromUTF8(bytes); defer str.deref(); - putDirectOffset(this.vm, index, str.toJS(this.globalObject)); + putDirectOffset(this.object, this.vm, index, str.toJS(this.globalObject)); }, } return true; @@ -2433,7 +2485,7 @@ pub const PostgresSQLConnection = struct { }; pub fn on(this: *PostgresSQLConnection, comptime MessageType: @Type(.EnumLiteral), comptime Context: type, reader: protocol.NewReader(Context)) !void { - debug("on({s})", .{@typeName(MessageType)}); + debug("on({s})", .{@tagName(MessageType)}); if (comptime MessageType != .ReadyForQuery) { this.is_ready_for_query = false; } @@ -2445,7 +2497,7 @@ pub const PostgresSQLConnection = struct { var structure = statement.structure(this.globalObject); std.debug.assert(!structure.isEmptyOrUndefinedOrNull()); - var row = JSC.JSObject.uninitialized(structure.asCell(), this.globalObject); + var row = JSC.JSObject.uninitialized(this.globalObject, structure); row.ensureStillAlive(); var putter = CellPutter{ .object = row, @@ -2459,7 +2511,7 @@ pub const PostgresSQLConnection = struct { reader, CellPutter.put, ); - request.push(row); + request.push(this.globalObject, row); }, .CopyData => { var copy_data: protocol.CopyData = undefined; @@ -2495,7 +2547,7 @@ pub const PostgresSQLConnection = struct { } debug("-> {s}", .{cmd.command_tag.slice()}); _ = this.requests.discard(1); - request.onSuccess(cmd.command_tag.slice(), this, this.globalObject); + request.onSuccess(cmd.command_tag.slice(), this.globalObject); }, .BindComplete => { try reader.eatMessage(protocol.BindComplete); @@ -2530,7 +2582,7 @@ pub const PostgresSQLConnection = struct { try reader.eatMessage(protocol.NoData); var request = this.current() orelse return error.ExpectedRequest; _ = this.requests.discard(1); - request.onNoData(this, this.globalObject); + request.onNoData(this.globalObject); }, .BackendKeyData => { try this.backend_key_data.decodeInternal(Context, reader); @@ -2543,20 +2595,20 @@ pub const PostgresSQLConnection = struct { } var request = this.current() orelse return error.ExpectedRequest; _ = this.requests.discard(1); - request.onError(err, this, this.globalObject); + request.onError(err, this.globalObject); }, .PortalSuspended => { - try reader.eatMessage(protocol.PortalSuspended); - var request = this.current() orelse return error.ExpectedRequest; - _ = request; - _ = this.requests.discard(1); + // try reader.eatMessage(&protocol.PortalSuspended); + // var request = this.current() orelse return error.ExpectedRequest; + // _ = request; + // _ = this.requests.discard(1); debug("TODO PortalSuspended", .{}); }, .CloseComplete => { try reader.eatMessage(protocol.CloseComplete); var request = this.current() orelse return error.ExpectedRequest; _ = this.requests.discard(1); - request.onSuccess("CLOSECOMPLETE", this, this.globalObject); + request.onSuccess("CLOSECOMPLETE", this.globalObject); }, .CopyInResponse => { debug("TODO CopyInResponse", .{}); @@ -2572,7 +2624,7 @@ pub const PostgresSQLConnection = struct { try reader.eatMessage(protocol.EmptyQueryResponse); var request = this.current() orelse return error.ExpectedRequest; _ = this.requests.discard(1); - request.onSuccess("", this, this.globalObject); + request.onSuccess("", this.globalObject); }, .CopyOutResponse => { debug("TODO CopyOutResponse", .{}); @@ -2583,8 +2635,29 @@ pub const PostgresSQLConnection = struct { .CopyBothResponse => { debug("TODO CopyBothResponse", .{}); }, + else => @compileError("Unknown message type: " ++ @tagName(MessageType)), } } + + pub fn doFlush(this: *PostgresSQLConnection, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) callconv(.C) JSC.JSValue { + _ = callframe; + _ = globalObject; + _ = this; + + return .undefined; + } + + pub fn createQuery(this: *PostgresSQLConnection, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) callconv(.C) JSC.JSValue { + _ = callframe; + _ = globalObject; + _ = this; + + return .undefined; + } + + pub fn getConnected(this: *PostgresSQLConnection, _: *JSC.JSGlobalObject) callconv(.C) JSC.JSValue { + return JSC.JSValue.jsBoolean(this.status == Status.connected); + } }; pub const PostgresSQLStatement = struct { @@ -2611,7 +2684,7 @@ pub const PostgresSQLStatement = struct { std.debug.assert(this.ref_count == 0); for (this.fields) |*field| { - field.deinit(); + @constCast(field).deinit(); } bun.default_allocator.free(this.fields); bun.default_allocator.free(this.parameters); @@ -2630,7 +2703,7 @@ pub const PostgresSQLStatement = struct { bun.default_allocator.free(names); } for (this.fields, names) |*field, *name| { - name.* = String.createAtomIfPossible(field.name); + name.* = String.createAtomIfPossible(field.name.slice()); } var structure_ = JSC.JSObject.createStructure( globalObject, @@ -2675,23 +2748,23 @@ const Signature = struct { while (iter.next()) |value| { if (value.isUndefinedOrNull()) { - try fields.append(@byteSwap(-1)); - try name.append(".null"); + try fields.append(@byteSwap(@as(i32, -1))); + try name.appendSlice(".null"); continue; } - const tag = types.Tag.fromJS(globalObject, value); + const tag = try types.Tag.fromJS(globalObject, value); try fields.append(@byteSwap(@intFromEnum(tag))); switch (tag) { - .number => try name.append(".number"), - .json => try name.append(".json"), - .boolean => try name.append(".boolean"), - .date => try name.append(".date"), - .datetime => try name.append(".datetime"), - .time => try name.append(".time"), - .bytea => try name.append(".bytea"), - .bigint => try name.append(".bigint"), - else => try name.append(".string"), + .number => try name.appendSlice(".number"), + .json => try name.appendSlice(".json"), + .boolean => try name.appendSlice(".boolean"), + .date => try name.appendSlice(".date"), + .datetime => try name.appendSlice(".datetime"), + .time => try name.appendSlice(".time"), + .bytea => try name.appendSlice(".bytea"), + .bigint => try name.appendSlice(".bigint"), + else => try name.appendSlice(".string"), } } |