diff options
author | 2023-10-15 04:56:05 -0700 | |
---|---|---|
committer | 2023-10-15 04:56:05 -0700 | |
commit | ab035d6e82f2c4d25831e709707099ef768dba79 (patch) | |
tree | f3ed14c51ea7c9d8e968104083ea34727bf9f0cc | |
parent | ad7e90ae1bb840def28308c9c8e87ee292305be6 (diff) | |
download | bun-ab035d6e82f2c4d25831e709707099ef768dba79.tar.gz bun-ab035d6e82f2c4d25831e709707099ef768dba79.tar.zst bun-ab035d6e82f2c4d25831e709707099ef768dba79.zip |
Further
-rw-r--r-- | src/sql/postgres.zig | 50 |
1 files changed, 30 insertions, 20 deletions
diff --git a/src/sql/postgres.zig b/src/sql/postgres.zig index df6a976b6..4352f7db3 100644 --- a/src/sql/postgres.zig +++ b/src/sql/postgres.zig @@ -100,13 +100,17 @@ pub const protocol = struct { this.offset.* += count; } pub fn ensureCapacity(this: StackReader, count: usize) bool { - return this.buffer.len > (this.offset.* + count); + return this.buffer.len >= (this.offset.* + count); } pub fn read(this: StackReader, count: usize) anyerror!Data { const offset = this.offset.*; + if (!this.ensureCapacity(count)) { + return error.ShortRead; + } + this.skip(count); return Data{ - .temporary = this.buffer[offset .. offset + count], + .temporary = this.buffer[offset..this.offset.*], }; } pub fn readZ(this: StackReader) anyerror!Data { @@ -355,9 +359,12 @@ pub const protocol = struct { pub inline fn eatMessage(this: @This(), comptime msg_: anytype) anyerror!void { const msg = msg_[1..]; + try this.ensureCapacity(msg.len); + var input = try readFn(this.wrapped, msg.len); defer input.deinit(); if (bun.strings.eqlComptime(input.slice(), msg)) return; + return error.InvalidMessage; } pub fn skip(this: @This(), count: usize) anyerror!void { @@ -729,10 +736,6 @@ pub const protocol = struct { const length = try reader.length(); std.debug.assert(length >= 4); - if (length != 5) { - return error.InvalidMessageLength; - } - const status = try reader.int(u8); this.* = .{ .status = @enumFromInt(status), @@ -933,7 +936,7 @@ pub const protocol = struct { writer: NewWriter(Context), ) !void { const parameters = this.params; - const count: usize = @sizeOf((u32)) + @sizeOf(u16) + (parameters.len * @sizeOf(u32)); + const count: usize = @sizeOf((u32)) + @sizeOf(u16) + (parameters.len * @sizeOf(u32)) + zCount(this.name) + zCount(this.query); const header = [_]u8{ 'P', } ++ toBytes(Int32(count)); @@ -1078,7 +1081,7 @@ pub const protocol = struct { const database = this.database.slice(); const options = this.options.slice(); - const count: usize = @sizeOf((int32)) + @sizeOf((int32)) + zCount("user", user) + zCount("database", database) + zCount("client_encoding", "UTF8") + zCount("", options) + 1; + const count: usize = @sizeOf((int32)) + @sizeOf((int32)) + zFieldCount("user", user) + zFieldCount("database", database) + zFieldCount("client_encoding", "UTF8") + zFieldCount("", options) + 1; const header = toBytes(Int32(@as(u32, @truncate(count)))); try writer.write(&header); @@ -1109,20 +1112,20 @@ pub const protocol = struct { pub const write = writeWrap(@This(), writeInternal).write; }; - fn zCount(prefix: []const u8, slice: []const u8) usize { - if (slice.len > 0) { - return slice.len + 1 + prefix.len + 1; - } + fn zCount(slice: []const u8) usize { + return if (slice.len > 0) slice.len + 1 else 0; + } - if (prefix.len > 0) { - return prefix.len + 1; + fn zFieldCount(prefix: []const u8, slice: []const u8) usize { + if (slice.len > 0) { + return zCount(prefix) + zCount(slice); } - return 0; + return zCount(prefix); } pub const Execute = struct { - max_rows: int32 = 0, + max_rows: int32 = std.math.maxInt(int32), p: PortalOrPreparedStatement, pub fn writeInternal( @@ -1152,13 +1155,14 @@ pub const protocol = struct { writer: NewWriter(Context), ) !void { const message = this.p.slice(); - const count: usize = @sizeOf((u32)) + @sizeOf((u32)) + message.len + 1; + const count: usize = @sizeOf((u32)) + @sizeOf((u32)) + message.len + 2; const header = [_]u8{ 'D', } ++ toBytes(Int32(count)); try writer.write(&header); try writer.write(&[_]u8{ this.p.tag(), + 0, }); try writer.string(message); } @@ -1825,7 +1829,7 @@ pub const PostgresRequest = struct { iter = JSC.JSArrayIterator.init(values_array, globalObject); - debug("Bind: {s} ({d})", .{ name, iter.len }); + debug("Bind: {} ({d})", .{ bun.strings.QuotedFormatter{ .text = name }, iter.len }); while (iter.next()) |value| { if (value.isUndefinedOrNull()) { @@ -1900,7 +1904,7 @@ pub const PostgresRequest = struct { .query = query, }; try q.writeInternal(Context, writer); - debug("Parse: {s}", .{query}); + debug("Parse: {}", .{bun.strings.QuotedFormatter{ .text = query }}); } { @@ -1910,7 +1914,7 @@ pub const PostgresRequest = struct { }, }; try d.writeInternal(Context, writer); - debug("Describe", .{}); + debug("Describe: {}", .{bun.strings.QuotedFormatter{ .text = name }}); } } @@ -2146,6 +2150,7 @@ pub const PostgresSQLConnection = struct { PostgresRequest.onData(this, protocol.StackReader, reader) catch |err| { if (err == error.ShortRead) { this.read_buffer.head = 0; + this.last_message_start = 0; this.read_buffer.byte_list.len = 0; this.read_buffer.write(bun.default_allocator, data[offset..]) catch @panic("failed to write to read buffer"); } else { @@ -2558,6 +2563,9 @@ pub const PostgresSQLConnection = struct { try this.backend_parameters.insert(parameter_status.name.slice(), parameter_status.value.slice()); }, .ReadyForQuery => { + var ready_for_query: protocol.ReadyForQuery = undefined; + try ready_for_query.decodeInternal(Context, reader); + if (this.pending_disconnect) { this.disconnect(); return; @@ -2770,6 +2778,8 @@ const Signature = struct { var fields = std.ArrayList(int32).init(bun.default_allocator); var name = try std.ArrayList(u8).initCapacity(bun.default_allocator, query.len); + name.appendSliceAssumeCapacity(query); + errdefer { fields.deinit(); name.deinit(); |