author | 2021-09-09 05:40:07 -0700
committer | 2021-09-09 05:40:07 -0700
commit | 8a02ad48a5eb1319c1bf3e9eb97e013924db875f (patch)
tree | 0d82d072026501e6e8086712764a1b9b45f87218 /src
parent | c30ec608b1484628cc28b4811b9d62e1c142b281 (diff)
download | bun-8a02ad48a5eb1319c1bf3e9eb97e013924db875f.tar.gz, bun-8a02ad48a5eb1319c1bf3e9eb97e013924db875f.tar.zst, bun-8a02ad48a5eb1319c1bf3e9eb97e013924db875f.zip
fetch!!!
Diffstat (limited to 'src')
29 files changed, 6726 insertions, 161 deletions
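Aside from vendoring iguanaTLS under src/deps, the commit's own change is the bundler.zig hunk at the top of the diff: when resolving a module under the /bun-vfs/node_modules/ prefix, scoped packages ("@scope/name") now keep two path segments instead of one. Below is a minimal standalone sketch of that slicing logic, for illustration only: it substitutes std.mem.indexOfScalar for the project's strings.indexOfChar helper, and moduleNameFromVfsPath is a hypothetical name, not part of the bundler.

```zig
const std = @import("std");

/// Extract the package name from a /bun-vfs/node_modules/ path.
/// Scoped packages ("@scope/name") keep both path segments; plain
/// packages keep only the first segment.
fn moduleNameFromVfsPath(path: []const u8) []const u8 {
    const prefix = "/bun-vfs/node_modules/";
    const module_name = path[prefix.len..];

    if (module_name[0] == '@') {
        // "@scope/pkg/index.js": find the end of "@scope", then of "pkg"
        var end = std.mem.indexOfScalar(u8, module_name, '/').? + 1;
        end += std.mem.indexOfScalar(u8, module_name[end..], '/').?;
        return module_name[0..end];
    }

    return module_name[0..std.mem.indexOfScalar(u8, module_name, '/').?];
}

test "scoped and unscoped package names" {
    try std.testing.expectEqualStrings(
        "@scope/pkg",
        moduleNameFromVfsPath("/bun-vfs/node_modules/@scope/pkg/index.js"),
    );
    try std.testing.expectEqualStrings(
        "pkg",
        moduleNameFromVfsPath("/bun-vfs/node_modules/pkg/index.js"),
    );
}
```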
diff --git a/src/bundler.zig b/src/bundler.zig index 8b13df4a2..49ece989f 100644 --- a/src/bundler.zig +++ b/src/bundler.zig @@ -1114,7 +1114,15 @@ pub fn NewBundler(cache_files: bool) type { ); var module_name = file_path.text["/bun-vfs/node_modules/".len..]; - module_name = module_name[0..strings.indexOfChar(module_name, '/').?]; + + if (module_name[0] == '@') { + var end = strings.indexOfChar(module_name, '/').? + 1; + end += strings.indexOfChar(module_name[end..], '/').?; + + module_name = module_name[0..end]; + } else { + module_name = module_name[0..strings.indexOfChar(module_name, '/').?]; + } if (NodeFallbackModules.Map.get(module_name)) |mod| { break :brk CacheEntry{ .contents = mod.code.* }; diff --git a/src/deps/iguanaTLS/.gitattributes b/src/deps/iguanaTLS/.gitattributes new file mode 100644 index 000000000..0cb064aeb --- /dev/null +++ b/src/deps/iguanaTLS/.gitattributes @@ -0,0 +1 @@ +*.zig text=auto eol=lf diff --git a/src/deps/iguanaTLS/.gitignore b/src/deps/iguanaTLS/.gitignore new file mode 100644 index 000000000..b78ba5f90 --- /dev/null +++ b/src/deps/iguanaTLS/.gitignore @@ -0,0 +1,3 @@ +/zig-cache
+deps.zig
+gyro.lock
diff --git a/src/deps/iguanaTLS/.gitmodules b/src/deps/iguanaTLS/.gitmodules new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/src/deps/iguanaTLS/.gitmodules diff --git a/src/deps/iguanaTLS/LICENSE b/src/deps/iguanaTLS/LICENSE new file mode 100644 index 000000000..f830ca857 --- /dev/null +++ b/src/deps/iguanaTLS/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Alexandros Naskos + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/src/deps/iguanaTLS/build.zig b/src/deps/iguanaTLS/build.zig new file mode 100644 index 000000000..9aed4bf0d --- /dev/null +++ b/src/deps/iguanaTLS/build.zig @@ -0,0 +1,14 @@ +const Builder = @import("std").build.Builder; + +pub fn build(b: *Builder) void { + const mode = b.standardReleaseOptions(); + const lib = b.addStaticLibrary("iguanaTLS", "src/main.zig"); + lib.setBuildMode(mode); + lib.install(); + + var main_tests = b.addTest("src/main.zig"); + main_tests.setBuildMode(mode); + + const test_step = b.step("test", "Run library tests"); + test_step.dependOn(&main_tests.step); +} diff --git a/src/deps/iguanaTLS/src/asn1.zig b/src/deps/iguanaTLS/src/asn1.zig new file mode 100644 index 000000000..8d43b38e9 --- /dev/null +++ b/src/deps/iguanaTLS/src/asn1.zig @@ -0,0 +1,631 @@ +const std = @import("std"); +const BigInt = std.math.big.int.Const; +const mem = std.mem; +const Allocator = mem.Allocator; +const ArenaAllocator = std.heap.ArenaAllocator; + +// zig fmt: off +pub const Tag = enum(u8) { + bool = 0x01, + int = 0x02, + bit_string = 0x03, + octet_string = 0x04, + @"null" = 0x05, + object_identifier = 0x06, + utf8_string = 0x0c, + printable_string = 0x13, + ia5_string = 0x16, + utc_time = 0x17, + bmp_string = 0x1e, + sequence = 0x30, + set = 0x31, + // Bogus value + context_specific = 0xff, +}; +// zig fmt: on + +pub const ObjectIdentifier = struct { + data: [16]u32, + len: u8, +}; + +pub const BitString = struct { + data: []const u8, + bit_len: usize, +}; + +pub const Value = union(Tag) { + bool: bool, + int: BigInt, + bit_string: BitString, + octet_string: []const u8, + @"null", + // @TODO Make this []u32, owned? 
+ object_identifier: ObjectIdentifier, + utf8_string: []const u8, + printable_string: []const u8, + ia5_string: []const u8, + utc_time: []const u8, + bmp_string: []const u16, + sequence: []const @This(), + set: []const @This(), + context_specific: struct { + child: *const Value, + number: u8, + }, + + pub fn deinit(self: @This(), alloc: *Allocator) void { + switch (self) { + .int => |i| alloc.free(i.limbs), + .bit_string => |bs| alloc.free(bs.data), + .octet_string, + .utf8_string, + .printable_string, + .ia5_string, + .utc_time, + => |s| alloc.free(s), + .bmp_string => |s| alloc.free(s), + .sequence, .set => |s| { + for (s) |c| { + c.deinit(alloc); + } + alloc.free(s); + }, + .context_specific => |cs| { + cs.child.deinit(alloc); + alloc.destroy(cs.child); + }, + else => {}, + } + } + + fn formatInternal( + self: Value, + comptime fmt: []const u8, + options: std.fmt.FormatOptions, + indents: usize, + writer: anytype, + ) @TypeOf(writer).Error!void { + try writer.writeByteNTimes(' ', indents); + switch (self) { + .bool => |b| try writer.print("BOOLEAN {}\n", .{b}), + .int => |i| { + try writer.writeAll("INTEGER "); + try i.format(fmt, options, writer); + try writer.writeByte('\n'); + }, + .bit_string => |bs| { + try writer.print("BIT STRING ({} bits) ", .{bs.bit_len}); + const bits_to_show = std.math.min(8 * 3, bs.bit_len); + const bytes = std.math.divCeil(usize, bits_to_show, 8) catch unreachable; + + var bit_idx: usize = 0; + var byte_idx: usize = 0; + while (byte_idx < bytes) : (byte_idx += 1) { + const byte = bs.data[byte_idx]; + var cur_bit_idx: u3 = 0; + while (bit_idx < bits_to_show) { + const mask = @as(u8, 0x80) >> cur_bit_idx; + try writer.print("{}", .{@boolToInt(byte & mask == mask)}); + cur_bit_idx += 1; + bit_idx += 1; + if (cur_bit_idx == 7) + break; + } + } + if (bits_to_show != bs.bit_len) + try writer.writeAll("..."); + try writer.writeByte('\n'); + }, + .octet_string => |s| try writer.print("OCTET STRING ({} bytes) {X}\n", .{ s.len, s }), + .@"null" => try writer.writeAll("NULL\n"), + .object_identifier => |oid| { + try writer.writeAll("OBJECT IDENTIFIER "); + var i: u8 = 0; + while (i < oid.len) : (i += 1) { + if (i != 0) try writer.writeByte('.'); + try writer.print("{}", .{oid.data[i]}); + } + try writer.writeByte('\n'); + }, + .utf8_string => |s| try writer.print("UTF8 STRING ({} bytes) {}\n", .{ s.len, s }), + .printable_string => |s| try writer.print("PRINTABLE STRING ({} bytes) {}\n", .{ s.len, s }), + .ia5_string => |s| try writer.print("IA5 STRING ({} bytes) {}\n", .{ s.len, s }), + .utc_time => |s| try writer.print("UTC TIME {}\n", .{s}), + .bmp_string => |s| try writer.print("BMP STRING ({} words) {}\n", .{ + s.len, + @ptrCast([*]const u16, s.ptr)[0 .. 
s.len * 2], + }), + .sequence => |children| { + try writer.print("SEQUENCE ({} elems)\n", .{children.len}); + for (children) |child| try child.formatInternal(fmt, options, indents + 2, writer); + }, + .set => |children| { + try writer.print("SET ({} elems)\n", .{children.len}); + for (children) |child| try child.formatInternal(fmt, options, indents + 2, writer); + }, + .context_specific => |cs| { + try writer.print("[{}]\n", .{cs.number}); + try cs.child.formatInternal(fmt, options, indents + 2, writer); + }, + } + } + + pub fn format(self: Value, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void { + try self.formatInternal(fmt, options, 0, writer); + } +}; + +/// Distinguished encoding rules +pub const der = struct { + pub fn DecodeError(comptime Reader: type) type { + return Reader.Error || error{ + OutOfMemory, + EndOfStream, + InvalidLength, + InvalidTag, + InvalidContainerLength, + DoesNotMatchSchema, + }; + } + + fn DERReaderState(comptime Reader: type) type { + return struct { + der_reader: Reader, + length: usize, + idx: usize = 0, + }; + } + + fn DERReader(comptime Reader: type) type { + const S = struct { + pub fn read(state: *DERReaderState(Reader), buffer: []u8) DecodeError(Reader)!usize { + const out_bytes = std.math.min(buffer.len, state.length - state.idx); + const res = try state.der_reader.readAll(buffer[0..out_bytes]); + state.idx += res; + return res; + } + }; + + return std.io.Reader(*DERReaderState(Reader), DecodeError(Reader), S.read); + } + + pub fn parse_schema( + schema: anytype, + captures: anytype, + der_reader: anytype, + ) !void { + const res = try parse_schema_tag_len_internal(null, null, schema, captures, der_reader); + if (res != null) return error.DoesNotMatchSchema; + } + + pub fn parse_schema_tag_len( + existing_tag_byte: ?u8, + existing_length: ?usize, + schema: anytype, + captures: anytype, + der_reader: anytype, + ) !void { + const res = try parse_schema_tag_len_internal( + existing_tag_byte, + existing_length, + schema, + captures, + der_reader, + ); + if (res != null) return error.DoesNotMatchSchema; + } + + const TagLength = struct { + tag: u8, + length: usize, + }; + + pub fn parse_schema_tag_len_internal( + existing_tag_byte: ?u8, + existing_length: ?usize, + schema: anytype, + captures: anytype, + der_reader: anytype, + ) !?TagLength { + const Reader = @TypeOf(der_reader); + + const isEnumLit = comptime std.meta.trait.is(.EnumLiteral); + comptime var tag_idx = 0; + + const has_capture = comptime isEnumLit(@TypeOf(schema[tag_idx])) and schema[tag_idx] == .capture; + if (has_capture) tag_idx += 2; + + const is_optional = comptime isEnumLit(@TypeOf(schema[tag_idx])) and schema[tag_idx] == .optional; + if (is_optional) tag_idx += 1; + + const tag_literal = schema[tag_idx]; + comptime std.debug.assert(isEnumLit(@TypeOf(tag_literal))); + + const tag_byte = existing_tag_byte orelse (der_reader.readByte() catch |err| switch (err) { + error.EndOfStream => return if (is_optional) null else error.EndOfStream, + else => |e| return e, + }); + + const length = existing_length orelse try parse_length(der_reader); + if (tag_literal == .sequence_of) { + if (tag_byte != @enumToInt(Tag.sequence)) { + if (is_optional) return TagLength{ .tag = tag_byte, .length = length }; + return error.InvalidTag; + } + + var curr_tag_length: ?TagLength = null; + const sub_schema = schema[tag_idx + 1]; + while (true) { + if (curr_tag_length == null) { + curr_tag_length = .{ + .tag = der_reader.readByte() catch |err| switch (err) { + error.EndOfStream 
=> { + curr_tag_length = null; + break; + }, + else => |e| return e, + }, + .length = try parse_length(der_reader), + }; + } + + curr_tag_length = parse_schema_tag_len_internal( + curr_tag_length.?.tag, + curr_tag_length.?.length, + sub_schema, + captures, + der_reader, + ) catch |err| switch (err) { + error.DoesNotMatchSchema => break, + else => |e| return e, + }; + } + return curr_tag_length; + } else if (tag_literal == .any) { + if (!has_capture) { + try der_reader.skipBytes(length, .{}); + return null; + } + + var reader_state = DERReaderState(Reader){ + .der_reader = der_reader, + .idx = 0, + .length = length, + }; + var reader = DERReader(@TypeOf(der_reader)){ .context = &reader_state }; + const capture_context = captures[schema[1] * 2]; + const capture_action = captures[schema[1] * 2 + 1]; + try capture_action(capture_context, tag_byte, length, reader); + + // Skip remaining bytes + try der_reader.skipBytes(reader_state.length - reader_state.idx, .{}); + return null; + } else if (tag_literal == .context_specific) { + const cs_number = schema[tag_idx + 1]; + if (tag_byte & 0xC0 == 0x80 and tag_byte - 0xa0 == cs_number) { + if (!has_capture) { + if (schema.len > tag_idx + 2) { + return try parse_schema_tag_len_internal(null, null, schema[tag_idx + 2], captures, der_reader); + } + + try der_reader.skipBytes(length, .{}); + return null; + } + + var reader_state = DERReaderState(Reader){ + .der_reader = der_reader, + .idx = 0, + .length = length, + }; + var reader = DERReader(Reader){ .context = &reader_state }; + const capture_context = captures[schema[1] * 2]; + const capture_action = captures[schema[1] * 2 + 1]; + try capture_action(capture_context, tag_byte, length, reader); + + // Skip remaining bytes + try der_reader.skipBytes(reader_state.length - reader_state.idx, .{}); + return null; + } else if (is_optional) + return TagLength{ .tag = tag_byte, .length = length } + else + return error.DoesNotMatchSchema; + } + + const schema_tag: Tag = tag_literal; + const actual_tag = std.meta.intToEnum(Tag, tag_byte) catch return error.InvalidTag; + if (actual_tag != schema_tag) { + if (is_optional) return TagLength{ .tag = tag_byte, .length = length }; + return error.DoesNotMatchSchema; + } + + const single_seq = schema_tag == .sequence and schema.len == 1; + if ((!has_capture and schema_tag != .sequence) or (!has_capture and single_seq)) { + try der_reader.skipBytes(length, .{}); + return null; + } + + if (has_capture) { + var reader_state = DERReaderState(Reader){ + .der_reader = der_reader, + .idx = 0, + .length = length, + }; + var reader = DERReader(Reader){ .context = &reader_state }; + const capture_context = captures[schema[1] * 2]; + const capture_action = captures[schema[1] * 2 + 1]; + try capture_action(capture_context, tag_byte, length, reader); + + // Skip remaining bytes + try der_reader.skipBytes(reader_state.length - reader_state.idx, .{}); + return null; + } + + var cur_tag_length: ?TagLength = null; + const sub_schemas = schema[tag_idx + 1]; + comptime var i = 0; + inline while (i < sub_schemas.len) : (i += 1) { + const curr_tag = if (cur_tag_length) |tl| tl.tag else null; + const curr_length = if (cur_tag_length) |tl| tl.length else null; + cur_tag_length = try parse_schema_tag_len_internal(curr_tag, curr_length, sub_schemas[i], captures, der_reader); + } + return cur_tag_length; + } + + pub const EncodedLength = struct { + data: [@sizeOf(usize) + 1]u8, + len: usize, + + pub fn slice(self: @This()) []const u8 { + if (self.len == 1) return self.data[0..1]; + return 
self.data[0 .. 1 + self.len]; + } + }; + + pub fn encode_length(length: usize) EncodedLength { + var enc = EncodedLength{ .data = undefined, .len = 0 }; + if (length < 128) { + enc.data[0] = @truncate(u8, length); + enc.len = 1; + } else { + const bytes_needed = @intCast(u8, std.math.divCeil( + usize, + std.math.log2_int_ceil(usize, length), + 8, + ) catch unreachable); + enc.data[0] = bytes_needed | 0x80; + mem.copy( + u8, + enc.data[1 .. bytes_needed + 1], + mem.asBytes(&length)[0..bytes_needed], + ); + if (std.builtin.target.cpu.arch.endian() != .Big) { + mem.reverse(u8, enc.data[1 .. bytes_needed + 1]); + } + enc.len = bytes_needed; + } + return enc; + } + + fn parse_int_internal(alloc: *Allocator, bytes_read: *usize, der_reader: anytype) !BigInt { + const length = try parse_length_internal(bytes_read, der_reader); + return try parse_int_with_length_internal(alloc, bytes_read, length, der_reader); + } + + pub fn parse_int(alloc: *Allocator, der_reader: anytype) !BigInt { + var bytes: usize = undefined; + return try parse_int_internal(alloc, &bytes, der_reader); + } + + pub fn parse_int_with_length(alloc: *Allocator, length: usize, der_reader: anytype) !BigInt { + var read: usize = 0; + return try parse_int_with_length_internal(alloc, &read, length, der_reader); + } + + fn parse_int_with_length_internal(alloc: *Allocator, bytes_read: *usize, length: usize, der_reader: anytype) !BigInt { + const first_byte = try der_reader.readByte(); + if (first_byte == 0x0 and length > 1) { + // Positive number with highest bit set to 1 in the rest. + const limb_count = std.math.divCeil(usize, length - 1, @sizeOf(usize)) catch unreachable; + const limbs = try alloc.alloc(usize, limb_count); + std.mem.set(usize, limbs, 0); + errdefer alloc.free(limbs); + + var limb_ptr = @ptrCast([*]u8, limbs.ptr); + try der_reader.readNoEof(limb_ptr[0 .. length - 1]); + // We always reverse because the standard library big int expects little endian. + mem.reverse(u8, limb_ptr[0 .. length - 1]); + + bytes_read.* += length; + return BigInt{ .limbs = limbs, .positive = true }; + } + std.debug.assert(length != 0); + // Write first_byte + // Twos complement + const limb_count = std.math.divCeil(usize, length, @sizeOf(usize)) catch unreachable; + const limbs = try alloc.alloc(usize, limb_count); + std.mem.set(usize, limbs, 0); + errdefer alloc.free(limbs); + + var limb_ptr = @ptrCast([*]u8, limbs.ptr); + limb_ptr[0] = first_byte & ~@as(u8, 0x80); + try der_reader.readNoEof(limb_ptr[1..length]); + + // We always reverse because the standard library big int expects little endian. 
+ mem.reverse(u8, limb_ptr[0..length]); + bytes_read.* += length; + return BigInt{ .limbs = limbs, .positive = (first_byte & 0x80) == 0x00 }; + } + + pub fn parse_length(der_reader: anytype) !usize { + var bytes: usize = 0; + return try parse_length_internal(&bytes, der_reader); + } + + fn parse_length_internal(bytes_read: *usize, der_reader: anytype) !usize { + const first_byte = try der_reader.readByte(); + bytes_read.* += 1; + if (first_byte & 0x80 == 0x00) { + // 1 byte value + return first_byte; + } + const length = @truncate(u7, first_byte); + if (length > @sizeOf(usize)) + @panic("DER length does not fit in usize"); + + var res_buf = std.mem.zeroes([@sizeOf(usize)]u8); + try der_reader.readNoEof(res_buf[0..length]); + bytes_read.* += length; + + if (std.builtin.target.cpu.arch.endian() != .Big) { + mem.reverse(u8, res_buf[0..length]); + } + return mem.bytesToValue(usize, &res_buf); + } + + fn parse_value_with_tag_byte( + tag_byte: u8, + alloc: *Allocator, + bytes_read: *usize, + der_reader: anytype, + ) DecodeError(@TypeOf(der_reader))!Value { + const tag = std.meta.intToEnum(Tag, tag_byte) catch { + // tag starts with '0b10...', this is the context specific class. + if (tag_byte & 0xC0 == 0x80) { + const length = try parse_length_internal(bytes_read, der_reader); + var cur_read_bytes: usize = 0; + var child = try alloc.create(Value); + errdefer alloc.destroy(child); + + child.* = try parse_value_internal(alloc, &cur_read_bytes, der_reader); + if (cur_read_bytes != length) + return error.InvalidContainerLength; + bytes_read.* += length; + return Value{ .context_specific = .{ .child = child, .number = tag_byte - 0xa0 } }; + } + + return error.InvalidTag; + }; + switch (tag) { + .bool => { + if ((try der_reader.readByte()) != 0x1) + return error.InvalidLength; + defer bytes_read.* += 2; + return Value{ .bool = (try der_reader.readByte()) != 0x0 }; + }, + .int => return Value{ .int = try parse_int_internal(alloc, bytes_read, der_reader) }, + .bit_string => { + const length = try parse_length_internal(bytes_read, der_reader); + const unused_bits = try der_reader.readByte(); + std.debug.assert(unused_bits < 8); + const bit_count = (length - 1) * 8 - unused_bits; + const bit_memory = try alloc.alloc(u8, std.math.divCeil(usize, bit_count, 8) catch unreachable); + errdefer alloc.free(bit_memory); + try der_reader.readNoEof(bit_memory[0 .. 
length - 1]); + + bytes_read.* += length; + return Value{ .bit_string = .{ .data = bit_memory, .bit_len = bit_count } }; + }, + .octet_string, .utf8_string, .printable_string, .utc_time, .ia5_string => { + const length = try parse_length_internal(bytes_read, der_reader); + const str_mem = try alloc.alloc(u8, length); + try der_reader.readNoEof(str_mem); + bytes_read.* += length; + return @as(Value, switch (tag) { + .octet_string => .{ .octet_string = str_mem }, + .utf8_string => .{ .utf8_string = str_mem }, + .printable_string => .{ .printable_string = str_mem }, + .utc_time => .{ .utc_time = str_mem }, + .ia5_string => .{ .ia5_string = str_mem }, + else => unreachable, + }); + }, + .@"null" => { + std.debug.assert((try parse_length_internal(bytes_read, der_reader)) == 0x00); + return .@"null"; + }, + .object_identifier => { + const length = try parse_length_internal(bytes_read, der_reader); + const first_byte = try der_reader.readByte(); + var ret = Value{ .object_identifier = .{ .data = undefined, .len = 0 } }; + ret.object_identifier.data[0] = first_byte / 40; + ret.object_identifier.data[1] = first_byte % 40; + + var out_idx: u8 = 2; + var i: usize = 0; + while (i < length - 1) { + var current_value: u32 = 0; + var current_byte = try der_reader.readByte(); + i += 1; + while (current_byte & 0x80 == 0x80) : (i += 1) { + // Increase the base of the previous bytes + current_value *= 128; + // Add the current byte in base 128 + current_value += @as(u32, current_byte & ~@as(u8, 0x80)) * 128; + current_byte = try der_reader.readByte(); + } else { + current_value += current_byte; + } + ret.object_identifier.data[out_idx] = current_value; + out_idx += 1; + } + ret.object_identifier.len = out_idx; + std.debug.assert(out_idx <= 16); + bytes_read.* += length; + return ret; + }, + .bmp_string => { + const length = try parse_length_internal(bytes_read, der_reader); + const str_mem = try alloc.alloc(u16, @divExact(length, 2)); + errdefer alloc.free(str_mem); + + for (str_mem) |*wide_char| { + wide_char.* = try der_reader.readIntBig(u16); + } + bytes_read.* += length; + return Value{ .bmp_string = str_mem }; + }, + .sequence, .set => { + const length = try parse_length_internal(bytes_read, der_reader); + var cur_read_bytes: usize = 0; + var arr = std.ArrayList(Value).init(alloc); + errdefer arr.deinit(); + + while (cur_read_bytes < length) { + (try arr.addOne()).* = try parse_value_internal(alloc, &cur_read_bytes, der_reader); + } + if (cur_read_bytes != length) + return error.InvalidContainerLength; + bytes_read.* += length; + + return @as(Value, switch (tag) { + .sequence => .{ .sequence = arr.toOwnedSlice() }, + .set => .{ .set = arr.toOwnedSlice() }, + else => unreachable, + }); + }, + .context_specific => unreachable, + } + } + + fn parse_value_internal(alloc: *Allocator, bytes_read: *usize, der_reader: anytype) DecodeError(@TypeOf(der_reader))!Value { + const tag_byte = try der_reader.readByte(); + bytes_read.* += 1; + return try parse_value_with_tag_byte(tag_byte, alloc, bytes_read, der_reader); + } + + pub fn parse_value(alloc: *Allocator, der_reader: anytype) DecodeError(@TypeOf(der_reader))!Value { + var read: usize = 0; + return try parse_value_internal(alloc, &read, der_reader); + } +}; + +test "der.parse_value" { + const github_der = @embedFile("../test/github.der"); + var fbs = std.io.fixedBufferStream(github_der); + + var arena = ArenaAllocator.init(std.testing.allocator); + defer arena.deinit(); + + _ = try der.parse_value(&arena.allocator, fbs.reader()); +} diff --git 
a/src/deps/iguanaTLS/src/ciphersuites.zig b/src/deps/iguanaTLS/src/ciphersuites.zig new file mode 100644 index 000000000..99f58edec --- /dev/null +++ b/src/deps/iguanaTLS/src/ciphersuites.zig @@ -0,0 +1,446 @@ +const std = @import("std"); +const mem = std.mem; + +const crypto = @import("crypto.zig"); +const ChaCha20Stream = crypto.ChaCha20Stream; +const Chacha20Poly1305 = std.crypto.aead.chacha_poly.ChaCha20Poly1305; +const Poly1305 = std.crypto.onetimeauth.Poly1305; +const Aes128Gcm = std.crypto.aead.aes_gcm.Aes128Gcm; + +const main = @import("main.zig"); +const RecordHeader = main.RecordHeader; + +pub const suites = struct { + pub const ECDHE_RSA_Chacha20_Poly1305 = struct { + pub const name = "ECDHE-RSA-CHACHA20-POLY1305"; + pub const tag = 0xCCA8; + pub const key_exchange = .ecdhe; + pub const hash = .sha256; + pub const prefix_data_length = 0; + pub const mac_length = 16; + + pub const Keys = struct { + client_key: [32]u8, + server_key: [32]u8, + client_iv: [12]u8, + server_iv: [12]u8, + }; + + pub const State = struct { + mac: Poly1305, + context: ChaCha20Stream.BlockVec, + buf: [64]u8, + }; + + pub fn init_state(_: [0]u8, server_seq: u64, key_data: anytype, header: RecordHeader) State { + const len = header.len() - 16; + var nonce: [12]u8 = ([1]u8{0} ** 4) ++ ([1]u8{undefined} ** 8); + mem.writeIntBig(u64, nonce[4..12], server_seq); + for (nonce) |*n, i| { + n.* ^= key_data.server_iv(@This())[i]; + } + + var additional_data: [13]u8 = undefined; + mem.writeIntBig(u64, additional_data[0..8], server_seq); + additional_data[8..11].* = header.data[0..3].*; + mem.writeIntBig(u16, additional_data[11..13], len); + + var c: [4]u32 = undefined; + c[0] = 1; + c[1] = mem.readIntLittle(u32, nonce[0..4]); + c[2] = mem.readIntLittle(u32, nonce[4..8]); + c[3] = mem.readIntLittle(u32, nonce[8..12]); + const server_key = crypto.keyToWords(key_data.server_key(@This()).*); + + return .{ + .mac = ChaCha20Stream.initPoly1305(key_data.server_key(@This()).*, nonce, additional_data), + .context = ChaCha20Stream.initContext(server_key, c), + .buf = undefined, + }; + } + + pub fn decrypt_part( + key_data: anytype, + record_length: usize, + idx: *usize, + state: *State, + encrypted: []const u8, + out: []u8, + ) void { + _ = record_length; + + std.debug.assert(encrypted.len == out.len); + ChaCha20Stream.chacha20Xor( + out, + encrypted, + crypto.keyToWords(key_data.server_key(@This()).*), + &state.context, + idx, + &state.buf, + ); + + state.mac.update(encrypted); + } + + pub fn verify_mac(reader: anytype, record_length: usize, state: *State) !void { + var poly1305_tag: [16]u8 = undefined; + reader.readNoEof(&poly1305_tag) catch |err| switch (err) { + error.EndOfStream => return error.ServerMalformedResponse, + else => |e| return e, + }; + try ChaCha20Stream.checkPoly1305(&state.mac, record_length, poly1305_tag); + } + + pub fn raw_write( + comptime buffer_size: usize, + rand: *std.rand.Random, + key_data: anytype, + writer: anytype, + prefix: [3]u8, + seq: u64, + buffer: []const u8, + ) !void { + _ = rand; + + std.debug.assert(buffer.len <= buffer_size); + try writer.writeAll(&prefix); + try writer.writeIntBig(u16, @intCast(u16, buffer.len + 16)); + + var additional_data: [13]u8 = undefined; + mem.writeIntBig(u64, additional_data[0..8], seq); + additional_data[8..11].* = prefix; + mem.writeIntBig(u16, additional_data[11..13], @intCast(u16, buffer.len)); + + var encrypted_data: [buffer_size]u8 = undefined; + var tag_data: [16]u8 = undefined; + + var nonce: [12]u8 = ([1]u8{0} ** 4) ++ ([1]u8{undefined} ** 8); + 
mem.writeIntBig(u64, nonce[4..12], seq); + for (nonce) |*n, i| { + n.* ^= key_data.client_iv(@This())[i]; + } + + Chacha20Poly1305.encrypt( + encrypted_data[0..buffer.len], + &tag_data, + buffer, + &additional_data, + nonce, + key_data.client_key(@This()).*, + ); + try writer.writeAll(encrypted_data[0..buffer.len]); + try writer.writeAll(&tag_data); + } + + pub fn check_verify_message( + key_data: anytype, + length: usize, + reader: anytype, + verify_message: [16]u8, + ) !bool { + if (length != 32) + return false; + + var msg_in: [32]u8 = undefined; + try reader.readNoEof(&msg_in); + + const additional_data: [13]u8 = ([1]u8{0} ** 8) ++ [5]u8{ 0x16, 0x03, 0x03, 0x00, 0x10 }; + var decrypted: [16]u8 = undefined; + Chacha20Poly1305.decrypt( + &decrypted, + msg_in[0..16], + msg_in[16..].*, + &additional_data, + key_data.server_iv(@This()).*, + key_data.server_key(@This()).*, + ) catch return false; + + return mem.eql(u8, &decrypted, &verify_message); + } + }; + + pub const ECDHE_RSA_AES128_GCM_SHA256 = struct { + pub const name = "ECDHE-RSA-AES128-GCM-SHA256"; + pub const tag = 0xC02F; + pub const key_exchange = .ecdhe; + pub const hash = .sha256; + pub const prefix_data_length = 8; + pub const mac_length = 16; + + pub const Keys = struct { + client_key: [16]u8, + server_key: [16]u8, + client_iv: [4]u8, + server_iv: [4]u8, + }; + + const Aes = std.crypto.core.aes.Aes128; + pub const State = struct { + aes: @typeInfo(@TypeOf(Aes.initEnc)).Fn.return_type.?, + counterInt: u128, + }; + + pub fn init_state(prefix_data: [8]u8, server_seq: u64, key_data: anytype, header: RecordHeader) State { + _ = server_seq; + _ = header; + + var iv: [12]u8 = undefined; + iv[0..4].* = key_data.server_iv(@This()).*; + iv[4..].* = prefix_data; + + var j: [16]u8 = undefined; + mem.copy(u8, j[0..12], iv[0..]); + mem.writeIntBig(u32, j[12..][0..4], 2); + + return .{ + .aes = Aes.initEnc(key_data.server_key(@This()).*), + .counterInt = mem.readInt(u128, &j, .Big), + }; + } + + pub fn decrypt_part( + key_data: anytype, + record_length: usize, + idx: *usize, + state: *State, + encrypted: []const u8, + out: []u8, + ) void { + _ = key_data; + _ = record_length; + + std.debug.assert(encrypted.len == out.len); + + crypto.ctr( + @TypeOf(state.aes), + state.aes, + out, + encrypted, + &state.counterInt, + idx, + .Big, + ); + } + + pub fn verify_mac(reader: anytype, record_length: usize, state: *State) !void { + _ = state; + _ = record_length; + // @TODO Implement this + reader.skipBytes(16, .{}) catch |err| switch (err) { + error.EndOfStream => return error.ServerMalformedResponse, + else => |e| return e, + }; + } + + pub fn check_verify_message( + key_data: anytype, + length: usize, + reader: anytype, + verify_message: [16]u8, + ) !bool { + if (length != 40) + return false; + + var iv: [12]u8 = undefined; + iv[0..4].* = key_data.server_iv(@This()).*; + try reader.readNoEof(iv[4..12]); + + var msg_in: [32]u8 = undefined; + try reader.readNoEof(&msg_in); + + const additional_data: [13]u8 = ([1]u8{0} ** 8) ++ [5]u8{ 0x16, 0x03, 0x03, 0x00, 0x10 }; + var decrypted: [16]u8 = undefined; + Aes128Gcm.decrypt( + &decrypted, + msg_in[0..16], + msg_in[16..].*, + &additional_data, + iv, + key_data.server_key(@This()).*, + ) catch return false; + + return mem.eql(u8, &decrypted, &verify_message); + } + + pub fn raw_write( + comptime buffer_size: usize, + rand: *std.rand.Random, + key_data: anytype, + writer: anytype, + prefix: [3]u8, + seq: u64, + buffer: []const u8, + ) !void { + std.debug.assert(buffer.len <= buffer_size); + var iv: [12]u8 
= undefined; + iv[0..4].* = key_data.client_iv(@This()).*; + rand.bytes(iv[4..12]); + + var additional_data: [13]u8 = undefined; + mem.writeIntBig(u64, additional_data[0..8], seq); + additional_data[8..11].* = prefix; + mem.writeIntBig(u16, additional_data[11..13], @intCast(u16, buffer.len)); + + try writer.writeAll(&prefix); + try writer.writeIntBig(u16, @intCast(u16, buffer.len + 24)); + try writer.writeAll(iv[4..12]); + + var encrypted_data: [buffer_size]u8 = undefined; + var tag_data: [16]u8 = undefined; + + Aes128Gcm.encrypt( + encrypted_data[0..buffer.len], + &tag_data, + buffer, + &additional_data, + iv, + key_data.client_key(@This()).*, + ); + try writer.writeAll(encrypted_data[0..buffer.len]); + try writer.writeAll(&tag_data); + } + }; + + pub const all = &[_]type{ ECDHE_RSA_Chacha20_Poly1305, ECDHE_RSA_AES128_GCM_SHA256 }; +}; + +fn key_field_width(comptime T: type, comptime field: anytype) ?usize { + if (!@hasField(T, @tagName(field))) + return null; + + const field_info = std.meta.fieldInfo(T, field); + if (!comptime std.meta.trait.is(.Array)(field_info.field_type) or std.meta.Elem(field_info.field_type) != u8) + @compileError("Field '" ++ field ++ "' of type '" ++ @typeName(T) ++ "' should be an array of u8."); + + return @typeInfo(field_info.field_type).Array.len; +} + +pub fn key_data_size(comptime ciphersuites: anytype) usize { + var max: usize = 0; + for (ciphersuites) |cs| { + const curr = (key_field_width(cs.Keys, .client_mac) orelse 0) + + (key_field_width(cs.Keys, .server_mac) orelse 0) + + key_field_width(cs.Keys, .client_key).? + + key_field_width(cs.Keys, .server_key).? + + key_field_width(cs.Keys, .client_iv).? + + key_field_width(cs.Keys, .server_iv).?; + if (curr > max) + max = curr; + } + return max; +} + +pub fn KeyData(comptime ciphersuites: anytype) type { + return struct { + data: [key_data_size(ciphersuites)]u8, + + pub fn client_mac(self: *@This(), comptime cs: type) *[key_field_width(cs.Keys, .client_mac) orelse 0]u8 { + return self.data[0..comptime (key_field_width(cs.Keys, .client_mac) orelse 0)]; + } + + pub fn server_mac(self: *@This(), comptime cs: type) *[key_field_width(cs.Keys, .server_mac) orelse 0]u8 { + const start = key_field_width(cs.Keys, .client_mac) orelse 0; + return self.data[start..][0..comptime (key_field_width(cs.Keys, .server_mac) orelse 0)]; + } + + pub fn client_key(self: *@This(), comptime cs: type) *[key_field_width(cs.Keys, .client_key).?]u8 { + const start = (key_field_width(cs.Keys, .client_mac) orelse 0) + + (key_field_width(cs.Keys, .server_mac) orelse 0); + return self.data[start..][0..comptime key_field_width(cs.Keys, .client_key).?]; + } + + pub fn server_key(self: *@This(), comptime cs: type) *[key_field_width(cs.Keys, .server_key).?]u8 { + const start = (key_field_width(cs.Keys, .client_mac) orelse 0) + + (key_field_width(cs.Keys, .server_mac) orelse 0) + + key_field_width(cs.Keys, .client_key).?; + return self.data[start..][0..comptime key_field_width(cs.Keys, .server_key).?]; + } + + pub fn client_iv(self: *@This(), comptime cs: type) *[key_field_width(cs.Keys, .client_iv).?]u8 { + const start = (key_field_width(cs.Keys, .client_mac) orelse 0) + + (key_field_width(cs.Keys, .server_mac) orelse 0) + + key_field_width(cs.Keys, .client_key).? 
+ + key_field_width(cs.Keys, .server_key).?; + return self.data[start..][0..comptime key_field_width(cs.Keys, .client_iv).?]; + } + + pub fn server_iv(self: *@This(), comptime cs: type) *[key_field_width(cs.Keys, .server_iv).?]u8 { + const start = (key_field_width(cs.Keys, .client_mac) orelse 0) + + (key_field_width(cs.Keys, .server_mac) orelse 0) + + key_field_width(cs.Keys, .client_key).? + + key_field_width(cs.Keys, .server_key).? + + key_field_width(cs.Keys, .client_iv).?; + return self.data[start..][0..comptime key_field_width(cs.Keys, .server_iv).?]; + } + }; +} + +pub fn key_expansion( + comptime ciphersuites: anytype, + tag: u16, + context: anytype, + comptime next_32_bytes: anytype, +) KeyData(ciphersuites) { + var res: KeyData(ciphersuites) = undefined; + inline for (ciphersuites) |cs| { + if (cs.tag == tag) { + var chunk: [32]u8 = undefined; + next_32_bytes(context, 0, &chunk); + comptime var chunk_idx = 1; + comptime var data_cursor = 0; + comptime var chunk_cursor = 0; + + const fields = .{ + .client_mac, .server_mac, + .client_key, .server_key, + .client_iv, .server_iv, + }; + inline for (fields) |field| { + if (chunk_cursor == 32) { + next_32_bytes(context, chunk_idx, &chunk); + chunk_idx += 1; + chunk_cursor = 0; + } + + const field_width = comptime (key_field_width(cs.Keys, field) orelse 0); + const first_read = comptime std.math.min(32 - chunk_cursor, field_width); + const second_read = field_width - first_read; + + res.data[data_cursor..][0..first_read].* = chunk[chunk_cursor..][0..first_read].*; + data_cursor += first_read; + chunk_cursor += first_read; + + if (second_read != 0) { + next_32_bytes(context, chunk_idx, &chunk); + chunk_idx += 1; + res.data[data_cursor..][0..second_read].* = chunk[chunk_cursor..][0..second_read].*; + data_cursor += second_read; + chunk_cursor = second_read; + comptime std.debug.assert(chunk_cursor != 32); + } + } + + return res; + } + } + unreachable; +} + +pub fn InRecordState(comptime ciphersuites: anytype) type { + var fields: [ciphersuites.len]std.builtin.TypeInfo.UnionField = undefined; + for (ciphersuites) |cs, i| { + fields[i] = .{ + .name = cs.name, + .field_type = cs.State, + .alignment = if (@sizeOf(cs.State) > 0) @alignOf(cs.State) else 0, + }; + } + return @Type(.{ + .Union = .{ + .layout = .Extern, + .tag_type = null, + .fields = &fields, + .decls = &[0]std.builtin.TypeInfo.Declaration{}, + }, + }); +} diff --git a/src/deps/iguanaTLS/src/crypto.zig b/src/deps/iguanaTLS/src/crypto.zig new file mode 100644 index 000000000..304c6d284 --- /dev/null +++ b/src/deps/iguanaTLS/src/crypto.zig @@ -0,0 +1,1016 @@ +const std = @import("std"); +const mem = std.mem; + +const Poly1305 = std.crypto.onetimeauth.Poly1305; +const Chacha20IETF = std.crypto.stream.chacha.ChaCha20IETF; + +// TODO See stdlib, this is a modified non vectorized implementation +pub const ChaCha20Stream = struct { + const math = std.math; + pub const BlockVec = [16]u32; + + pub fn initContext(key: [8]u32, d: [4]u32) BlockVec { + const c = "expand 32-byte k"; + const constant_le = comptime [4]u32{ + mem.readIntLittle(u32, c[0..4]), + mem.readIntLittle(u32, c[4..8]), + mem.readIntLittle(u32, c[8..12]), + mem.readIntLittle(u32, c[12..16]), + }; + return BlockVec{ + constant_le[0], constant_le[1], constant_le[2], constant_le[3], + key[0], key[1], key[2], key[3], + key[4], key[5], key[6], key[7], + d[0], d[1], d[2], d[3], + }; + } + + const QuarterRound = struct { + a: usize, + b: usize, + c: usize, + d: usize, + }; + + fn Rp(a: usize, b: usize, c: usize, d: usize) 
QuarterRound { + return QuarterRound{ + .a = a, + .b = b, + .c = c, + .d = d, + }; + } + + inline fn chacha20Core(x: *BlockVec, input: BlockVec) void { + x.* = input; + + const rounds = comptime [_]QuarterRound{ + Rp(0, 4, 8, 12), + Rp(1, 5, 9, 13), + Rp(2, 6, 10, 14), + Rp(3, 7, 11, 15), + Rp(0, 5, 10, 15), + Rp(1, 6, 11, 12), + Rp(2, 7, 8, 13), + Rp(3, 4, 9, 14), + }; + + comptime var j: usize = 0; + inline while (j < 20) : (j += 2) { + inline for (rounds) |r| { + x[r.a] +%= x[r.b]; + x[r.d] = math.rotl(u32, x[r.d] ^ x[r.a], @as(u32, 16)); + x[r.c] +%= x[r.d]; + x[r.b] = math.rotl(u32, x[r.b] ^ x[r.c], @as(u32, 12)); + x[r.a] +%= x[r.b]; + x[r.d] = math.rotl(u32, x[r.d] ^ x[r.a], @as(u32, 8)); + x[r.c] +%= x[r.d]; + x[r.b] = math.rotl(u32, x[r.b] ^ x[r.c], @as(u32, 7)); + } + } + } + + inline fn hashToBytes(out: *[64]u8, x: BlockVec) void { + var i: usize = 0; + while (i < 4) : (i += 1) { + mem.writeIntLittle(u32, out[16 * i + 0 ..][0..4], x[i * 4 + 0]); + mem.writeIntLittle(u32, out[16 * i + 4 ..][0..4], x[i * 4 + 1]); + mem.writeIntLittle(u32, out[16 * i + 8 ..][0..4], x[i * 4 + 2]); + mem.writeIntLittle(u32, out[16 * i + 12 ..][0..4], x[i * 4 + 3]); + } + } + + inline fn contextFeedback(x: *BlockVec, ctx: BlockVec) void { + var i: usize = 0; + while (i < 16) : (i += 1) { + x[i] +%= ctx[i]; + } + } + + pub fn initPoly1305(key: [32]u8, nonce: [12]u8, ad: [13]u8) Poly1305 { + var polyKey = [_]u8{0} ** 32; + Chacha20IETF.xor(&polyKey, &polyKey, 0, key, nonce); + var mac = Poly1305.init(&polyKey); + mac.update(&ad); + // Pad to 16 bytes from ad + mac.update(&.{ 0, 0, 0 }); + return mac; + } + + /// Call after `mac` has been updated with the whole message + pub fn checkPoly1305(mac: *Poly1305, len: usize, tag: [16]u8) !void { + if (len % 16 != 0) { + const zeros = [_]u8{0} ** 16; + const padding = 16 - (len % 16); + mac.update(zeros[0..padding]); + } + var lens: [16]u8 = undefined; + mem.writeIntLittle(u64, lens[0..8], 13); + mem.writeIntLittle(u64, lens[8..16], len); + mac.update(lens[0..]); + var computedTag: [16]u8 = undefined; + mac.final(computedTag[0..]); + + var acc: u8 = 0; + for (computedTag) |_, i| { + acc |= computedTag[i] ^ tag[i]; + } + if (acc != 0) { + return error.AuthenticationFailed; + } + } + + // TODO: Optimize this + pub fn chacha20Xor(out: []u8, in: []const u8, key: [8]u32, ctx: *BlockVec, idx: *usize, buf: *[64]u8) void { + _ = key; + + var x: BlockVec = undefined; + + var i: usize = 0; + while (i < in.len) { + if (idx.* % 64 == 0) { + if (idx.* != 0) { + ctx.*[12] += 1; + } + chacha20Core(x[0..], ctx.*); + contextFeedback(&x, ctx.*); + hashToBytes(buf, x); + } + + out[i] = in[i] ^ buf[idx.* % 64]; + + i += 1; + idx.* += 1; + } + } +}; + +pub fn keyToWords(key: [32]u8) [8]u32 { + var k: [8]u32 = undefined; + var i: usize = 0; + while (i < 8) : (i += 1) { + k[i] = mem.readIntLittle(u32, key[i * 4 ..][0..4]); + } + return k; +} + +// See std.crypto.core.modes.ctr +/// This mode creates a key stream by encrypting an incrementing counter using a block cipher, and adding it to the source material. 
+pub fn ctr( + comptime BlockCipher: anytype, + block_cipher: BlockCipher, + dst: []u8, + src: []const u8, + counterInt: *u128, + idx: *usize, + endian: std.builtin.Endian, +) void { + std.debug.assert(dst.len >= src.len); + const block_length = BlockCipher.block_length; + var cur_idx: usize = 0; + + const offset = idx.* % block_length; + if (offset != 0) { + const part_len = std.math.min(block_length - offset, src.len); + + var counter: [BlockCipher.block_length]u8 = undefined; + mem.writeInt(u128, &counter, counterInt.*, endian); + var pad = [_]u8{0} ** block_length; + mem.copy(u8, pad[offset..], src[0..part_len]); + block_cipher.xor(&pad, &pad, counter); + mem.copy(u8, dst[0..part_len], pad[offset..][0..part_len]); + cur_idx += part_len; + idx.* += part_len; + if (idx.* % block_length == 0) + counterInt.* += 1; + } + + const start_idx = cur_idx; + const remaining = src.len - cur_idx; + cur_idx = 0; + + const parallel_count = BlockCipher.block.parallel.optimal_parallel_blocks; + const wide_block_length = parallel_count * 16; + if (remaining >= wide_block_length) { + var counters: [parallel_count * 16]u8 = undefined; + while (cur_idx + wide_block_length <= remaining) : (cur_idx += wide_block_length) { + comptime var j = 0; + inline while (j < parallel_count) : (j += 1) { + mem.writeInt(u128, counters[j * 16 .. j * 16 + 16], counterInt.*, endian); + counterInt.* +%= 1; + } + block_cipher.xorWide(parallel_count, dst[start_idx..][cur_idx .. cur_idx + wide_block_length][0..wide_block_length], src[start_idx..][cur_idx .. cur_idx + wide_block_length][0..wide_block_length], counters); + idx.* += wide_block_length; + } + } + while (cur_idx + block_length <= remaining) : (cur_idx += block_length) { + var counter: [BlockCipher.block_length]u8 = undefined; + mem.writeInt(u128, &counter, counterInt.*, endian); + counterInt.* +%= 1; + block_cipher.xor(dst[start_idx..][cur_idx .. cur_idx + block_length][0..block_length], src[start_idx..][cur_idx .. cur_idx + block_length][0..block_length], counter); + idx.* += block_length; + } + if (cur_idx < remaining) { + std.debug.assert(idx.* % block_length == 0); + var counter: [BlockCipher.block_length]u8 = undefined; + mem.writeInt(u128, &counter, counterInt.*, endian); + + var pad = [_]u8{0} ** block_length; + mem.copy(u8, &pad, src[start_idx..][cur_idx..]); + block_cipher.xor(&pad, &pad, counter); + mem.copy(u8, dst[start_idx..][cur_idx..], pad[0 .. 
remaining - cur_idx]); + + idx.* += remaining - cur_idx; + if (idx.* % block_length == 0) + counterInt.* +%= 1; + } +} + +// Ported from BearSSL's ec_prime_i31 engine +pub const ecc = struct { + pub const SECP384R1 = struct { + pub const point_len = 96; + + const order = [point_len / 2]u8{ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xC7, 0x63, 0x4D, 0x81, 0xF4, 0x37, 0x2D, 0xDF, + 0x58, 0x1A, 0x0D, 0xB2, 0x48, 0xB0, 0xA7, 0x7A, + 0xEC, 0xEC, 0x19, 0x6A, 0xCC, 0xC5, 0x29, 0x73, + }; + + const P = [_]u32{ + 0x0000018C, 0x7FFFFFFF, 0x00000001, 0x00000000, + 0x7FFFFFF8, 0x7FFFFFEF, 0x7FFFFFFF, 0x7FFFFFFF, + 0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF, + 0x7FFFFFFF, 0x00000FFF, + }; + const R2 = [_]u32{ + 0x0000018C, 0x00000000, 0x00000080, 0x7FFFFE00, + 0x000001FF, 0x00000800, 0x00000000, 0x7FFFE000, + 0x00001FFF, 0x00008000, 0x00008000, 0x00000000, + 0x00000000, 0x00000000, + }; + const B = [_]u32{ + 0x0000018C, 0x6E666840, 0x070D0392, 0x5D810231, + 0x7651D50C, 0x17E218D6, 0x1B192002, 0x44EFE441, + 0x3A524E2B, 0x2719BA5F, 0x41F02209, 0x36C5643E, + 0x5813EFFE, 0x000008A5, + }; + + const base_point = [point_len]u8{ + 0xAA, 0x87, 0xCA, 0x22, 0xBE, 0x8B, 0x05, 0x37, + 0x8E, 0xB1, 0xC7, 0x1E, 0xF3, 0x20, 0xAD, 0x74, + 0x6E, 0x1D, 0x3B, 0x62, 0x8B, 0xA7, 0x9B, 0x98, + 0x59, 0xF7, 0x41, 0xE0, 0x82, 0x54, 0x2A, 0x38, + 0x55, 0x02, 0xF2, 0x5D, 0xBF, 0x55, 0x29, 0x6C, + 0x3A, 0x54, 0x5E, 0x38, 0x72, 0x76, 0x0A, 0xB7, + 0x36, 0x17, 0xDE, 0x4A, 0x96, 0x26, 0x2C, 0x6F, + 0x5D, 0x9E, 0x98, 0xBF, 0x92, 0x92, 0xDC, 0x29, + 0xF8, 0xF4, 0x1D, 0xBD, 0x28, 0x9A, 0x14, 0x7C, + 0xE9, 0xDA, 0x31, 0x13, 0xB5, 0xF0, 0xB8, 0xC0, + 0x0A, 0x60, 0xB1, 0xCE, 0x1D, 0x7E, 0x81, 0x9D, + 0x7A, 0x43, 0x1D, 0x7C, 0x90, 0xEA, 0x0E, 0x5F, + }; + + comptime { + std.debug.assert((P[0] - (P[0] >> 5) + 7) >> 2 == point_len + 1); + } + }; + + pub const SECP256R1 = struct { + pub const point_len = 64; + + const order = [point_len / 2]u8{ + 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xBC, 0xE6, 0xFA, 0xAD, 0xA7, 0x17, 0x9E, 0x84, + 0xF3, 0xB9, 0xCA, 0xC2, 0xFC, 0x63, 0x25, 0x51, + }; + + const P = [_]u32{ + 0x00000108, 0x7FFFFFFF, + 0x7FFFFFFF, 0x7FFFFFFF, + 0x00000007, 0x00000000, + 0x00000000, 0x00000040, + 0x7FFFFF80, 0x000000FF, + }; + const R2 = [_]u32{ + 0x00000108, 0x00014000, + 0x00018000, 0x00000000, + 0x7FF40000, 0x7FEFFFFF, + 0x7FF7FFFF, 0x7FAFFFFF, + 0x005FFFFF, 0x00000000, + }; + const B = [_]u32{ + 0x00000108, 0x6FEE1803, + 0x6229C4BD, 0x21B139BE, + 0x327150AA, 0x3567802E, + 0x3F7212ED, 0x012E4355, + 0x782DD38D, 0x0000000E, + }; + + const base_point = [point_len]u8{ + 0x6B, 0x17, 0xD1, 0xF2, 0xE1, 0x2C, 0x42, 0x47, + 0xF8, 0xBC, 0xE6, 0xE5, 0x63, 0xA4, 0x40, 0xF2, + 0x77, 0x03, 0x7D, 0x81, 0x2D, 0xEB, 0x33, 0xA0, + 0xF4, 0xA1, 0x39, 0x45, 0xD8, 0x98, 0xC2, 0x96, + 0x4F, 0xE3, 0x42, 0xE2, 0xFE, 0x1A, 0x7F, 0x9B, + 0x8E, 0xE7, 0xEB, 0x4A, 0x7C, 0x0F, 0x9E, 0x16, + 0x2B, 0xCE, 0x33, 0x57, 0x6B, 0x31, 0x5E, 0xCE, + 0xCB, 0xB6, 0x40, 0x68, 0x37, 0xBF, 0x51, 0xF5, + }; + + comptime { + std.debug.assert((P[0] - (P[0] >> 5) + 7) >> 2 == point_len + 1); + } + }; + + fn jacobian_len(comptime Curve: type) usize { + return @divTrunc(Curve.order.len * 8 + 61, 31); + } + + fn Jacobian(comptime Curve: type) type { + return [3][jacobian_len(Curve)]u32; + } + + fn zero_jacobian(comptime Curve: type) Jacobian(Curve) { + var result = std.mem.zeroes(Jacobian(Curve)); + 
result[0][0] = Curve.P[0]; + result[1][0] = Curve.P[0]; + result[2][0] = Curve.P[0]; + return result; + } + + pub fn scalarmult( + comptime Curve: type, + point: [Curve.point_len]u8, + k: []const u8, + ) ![Curve.point_len]u8 { + var P: Jacobian(Curve) = undefined; + var res: u32 = decode_to_jacobian(Curve, &P, point); + point_mul(Curve, &P, k); + var out: [Curve.point_len]u8 = undefined; + encode_from_jacobian(Curve, &out, P); + if (res == 0) + return error.MultiplicationFailed; + return out; + } + + pub fn KeyPair(comptime Curve: type) type { + return struct { + public_key: [Curve.point_len]u8, + secret_key: [Curve.point_len / 2]u8, + }; + } + + pub fn make_key_pair(comptime Curve: type, rand_bytes: [Curve.point_len / 2]u8) KeyPair(Curve) { + var key_bytes = rand_bytes; + comptime var mask: u8 = 0xFF; + comptime { + while (mask >= Curve.order[0]) { + mask >>= 1; + } + } + key_bytes[0] &= mask; + key_bytes[Curve.point_len / 2 - 1] |= 0x01; + + return .{ + .secret_key = key_bytes, + .public_key = scalarmult(Curve, Curve.base_point, &key_bytes) catch unreachable, + }; + } + + fn jacobian_with_one_set(comptime Curve: type, comptime fields: [2][jacobian_len(Curve)]u32) Jacobian(Curve) { + const plen = comptime (Curve.P[0] + 63) >> 5; + return fields ++ [1][jacobian_len(Curve)]u32{ + [2]u32{ Curve.P[0], 1 } ++ ([1]u32{0} ** (plen - 2)), + }; + } + + fn encode_from_jacobian(comptime Curve: type, point: *[Curve.point_len]u8, P: Jacobian(Curve)) void { + var Q = P; + const T = comptime jacobian_with_one_set(Curve, [2][jacobian_len(Curve)]u32{ undefined, undefined }); + _ = run_code(Curve, &Q, T, &code.affine); + encode_jacobian_part(point[0 .. Curve.point_len / 2], &Q[0]); + encode_jacobian_part(point[Curve.point_len / 2 ..], &Q[1]); + } + + fn point_mul(comptime Curve: type, P: *Jacobian(Curve), x: []const u8) void { + var P2 = P.*; + point_double(Curve, &P2); + var P3 = P.*; + point_add(Curve, &P3, P2); + var Q = zero_jacobian(Curve); + var qz: u32 = 1; + var xlen = x.len; + var xidx: usize = 0; + while (xlen > 0) : ({ + xlen -= 1; + xidx += 1; + }) { + var k: u3 = 6; + while (true) : (k -= 2) { + point_double(Curve, &Q); + point_double(Curve, &Q); + var T = P.*; + var U = Q; + const bits = @as(u32, x[xidx] >> k) & 3; + const bnz = NEQ(bits, 0); + CCOPY(EQ(bits, 2), mem.asBytes(&T), mem.asBytes(&P2)); + CCOPY(EQ(bits, 3), mem.asBytes(&T), mem.asBytes(&P3)); + point_add(Curve, &U, T); + CCOPY(bnz & qz, mem.asBytes(&Q), mem.asBytes(&T)); + CCOPY(bnz & ~qz, mem.asBytes(&Q), mem.asBytes(&U)); + qz &= ~bnz; + + if (k == 0) + break; + } + } + P.* = Q; + } + + inline fn point_double(comptime Curve: type, P: *Jacobian(Curve)) void { + _ = run_code(Curve, P, P.*, &code.double); + } + inline fn point_add(comptime Curve: type, P1: *Jacobian(Curve), P2: Jacobian(Curve)) void { + _ = run_code(Curve, P1, P2, &code._add); + } + + fn decode_to_jacobian( + comptime Curve: type, + out: *Jacobian(Curve), + point: [Curve.point_len]u8, + ) u32 { + out.* = zero_jacobian(Curve); + var result = decode_mod(Curve, &out.*[0], point[0 .. 
Curve.point_len / 2].*); + result &= decode_mod(Curve, &out.*[1], point[Curve.point_len / 2 ..].*); + + const zlen = comptime ((Curve.P[0] + 63) >> 5); + comptime std.debug.assert(zlen == @typeInfo(@TypeOf(Curve.R2)).Array.len); + comptime std.debug.assert(zlen == @typeInfo(@TypeOf(Curve.B)).Array.len); + + const Q = comptime jacobian_with_one_set(Curve, [2][jacobian_len(Curve)]u32{ Curve.R2, Curve.B }); + result &= ~run_code(Curve, out, Q, &code.check); + return result; + } + + const code = struct { + const P1x = 0; + const P1y = 1; + const P1z = 2; + const P2x = 3; + const P2y = 4; + const P2z = 5; + const Px = 0; + const Py = 1; + const Pz = 2; + const t1 = 6; + const t2 = 7; + const t3 = 8; + const t4 = 9; + const t5 = 10; + const t6 = 11; + const t7 = 12; + const t8 = 3; + const t9 = 4; + const t10 = 5; + fn MSET(comptime d: u16, comptime a: u16) u16 { + return 0x0000 + (d << 8) + (a << 4); + } + fn MADD(comptime d: u16, comptime a: u16) u16 { + return 0x1000 + (d << 8) + (a << 4); + } + fn MSUB(comptime d: u16, comptime a: u16) u16 { + return 0x2000 + (d << 8) + (a << 4); + } + fn MMUL(comptime d: u16, comptime a: u16, comptime b: u16) u16 { + return 0x3000 + (d << 8) + (a << 4) + b; + } + fn MINV(comptime d: u16, comptime a: u16, comptime b: u16) u16 { + return 0x4000 + (d << 8) + (a << 4) + b; + } + fn MTZ(comptime d: u16) u16 { + return 0x5000 + (d << 8); + } + const ENDCODE = 0; + + const check = [_]u16{ + // Convert x and y to Montgomery representation. + MMUL(t1, P1x, P2x), + MMUL(t2, P1y, P2x), + MSET(P1x, t1), + MSET(P1y, t2), + // Compute x^3 in t1. + MMUL(t2, P1x, P1x), + MMUL(t1, P1x, t2), + // Subtract 3*x from t1. + MSUB(t1, P1x), + MSUB(t1, P1x), + MSUB(t1, P1x), + // Add b. + MADD(t1, P2y), + // Compute y^2 in t2. + MMUL(t2, P1y, P1y), + // Compare y^2 with x^3 - 3*x + b; they must match. + MSUB(t1, t2), + MTZ(t1), + // Set z to 1 (in Montgomery representation). + MMUL(P1z, P2x, P2z), + ENDCODE, + }; + const double = [_]u16{ + // Compute z^2 (in t1). + MMUL(t1, Pz, Pz), + // Compute x-z^2 (in t2) and then x+z^2 (in t1). + MSET(t2, Px), + MSUB(t2, t1), + MADD(t1, Px), + // Compute m = 3*(x+z^2)*(x-z^2) (in t1). + MMUL(t3, t1, t2), + MSET(t1, t3), + MADD(t1, t3), + MADD(t1, t3), + // Compute s = 4*x*y^2 (in t2) and 2*y^2 (in t3). + MMUL(t3, Py, Py), + MADD(t3, t3), + MMUL(t2, Px, t3), + MADD(t2, t2), + // Compute x' = m^2 - 2*s. + MMUL(Px, t1, t1), + MSUB(Px, t2), + MSUB(Px, t2), + // Compute z' = 2*y*z. + MMUL(t4, Py, Pz), + MSET(Pz, t4), + MADD(Pz, t4), + // Compute y' = m*(s - x') - 8*y^4. Note that we already have + // 2*y^2 in t3. + MSUB(t2, Px), + MMUL(Py, t1, t2), + MMUL(t4, t3, t3), + MSUB(Py, t4), + MSUB(Py, t4), + ENDCODE, + }; + const _add = [_]u16{ + // Compute u1 = x1*z2^2 (in t1) and s1 = y1*z2^3 (in t3). + MMUL(t3, P2z, P2z), + MMUL(t1, P1x, t3), + MMUL(t4, P2z, t3), + MMUL(t3, P1y, t4), + // Compute u2 = x2*z1^2 (in t2) and s2 = y2*z1^3 (in t4). + MMUL(t4, P1z, P1z), + MMUL(t2, P2x, t4), + MMUL(t5, P1z, t4), + MMUL(t4, P2y, t5), + //Compute h = u2 - u1 (in t2) and r = s2 - s1 (in t4). + MSUB(t2, t1), + MSUB(t4, t3), + // Report cases where r = 0 through the returned flag. + MTZ(t4), + // Compute u1*h^2 (in t6) and h^3 (in t5). + MMUL(t7, t2, t2), + MMUL(t6, t1, t7), + MMUL(t5, t7, t2), + // Compute x3 = r^2 - h^3 - 2*u1*h^2. + // t1 and t7 can be used as scratch registers. + MMUL(P1x, t4, t4), + MSUB(P1x, t5), + MSUB(P1x, t6), + MSUB(P1x, t6), + //Compute y3 = r*(u1*h^2 - x3) - s1*h^3. 
+ MSUB(t6, P1x), + MMUL(P1y, t4, t6), + MMUL(t1, t5, t3), + MSUB(P1y, t1), + //Compute z3 = h*z1*z2. + MMUL(t1, P1z, P2z), + MMUL(P1z, t1, t2), + ENDCODE, + }; + const affine = [_]u16{ + // Save z*R in t1. + MSET(t1, P1z), + // Compute z^3 in t2. + MMUL(t2, P1z, P1z), + MMUL(t3, P1z, t2), + MMUL(t2, t3, P2z), + // Invert to (1/z^3) in t2. + MINV(t2, t3, t4), + // Compute y. + MSET(t3, P1y), + MMUL(P1y, t2, t3), + // Compute (1/z^2) in t3. + MMUL(t3, t2, t1), + // Compute x. + MSET(t2, P1x), + MMUL(P1x, t2, t3), + ENDCODE, + }; + }; + + fn decode_mod( + comptime Curve: type, + x: *[jacobian_len(Curve)]u32, + src: [Curve.point_len / 2]u8, + ) u32 { + const mlen = comptime ((Curve.P[0] + 31) >> 5); + const tlen = comptime std.math.max(mlen << 2, Curve.point_len / 2) + 4; + + var r: u32 = 0; + var pass: usize = 0; + while (pass < 2) : (pass += 1) { + var v: usize = 1; + var acc: u32 = 0; + var acc_len: u32 = 0; + + var u: usize = 0; + while (u < tlen) : (u += 1) { + const b = if (u < Curve.point_len / 2) + @as(u32, src[Curve.point_len / 2 - 1 - u]) + else + 0; + acc |= b << @truncate(u5, acc_len); + acc_len += 8; + if (acc_len >= 31) { + const xw = acc & 0x7FFFFFFF; + acc_len -= 31; + acc = b >> @truncate(u5, 8 - acc_len); + if (v <= mlen) { + if (pass != 0) { + x[v] = r & xw; + } else { + const cc = @bitCast(u32, CMP(xw, Curve.P[v])); + r = MUX(EQ(cc, 0), r, cc); + } + } else if (pass == 0) { + r = MUX(EQ(xw, 0), r, 1); + } + v += 1; + } + } + r >>= 1; + r |= (r << 1); + } + x[0] = Curve.P[0]; + return r & 1; + } + + fn run_code( + comptime Curve: type, + P1: *Jacobian(Curve), + P2: Jacobian(Curve), + comptime Code: []const u16, + ) u32 { + const jaclen = comptime jacobian_len(Curve); + + var t: [13][jaclen]u32 = undefined; + var result: u32 = 1; + + t[0..3].* = P1.*; + t[3..6].* = P2; + + comptime var u: usize = 0; + inline while (true) : (u += 1) { + comptime var op = Code[u]; + if (op == 0) + break; + const d = comptime (op >> 8) & 0x0F; + const a = comptime (op >> 4) & 0x0F; + const b = comptime op & 0x0F; + op >>= 12; + + switch (op) { + 0 => t[d] = t[a], + 1 => { + var ctl = add(&t[d], &t[a], 1); + ctl |= NOT(sub(&t[d], &Curve.P, 0)); + _ = sub(&t[d], &Curve.P, ctl); + }, + 2 => _ = add(&t[d], &Curve.P, sub(&t[d], &t[a], 1)), + 3 => montymul(&t[d], &t[a], &t[b], &Curve.P, 1), + 4 => { + var tp: [Curve.point_len / 2]u8 = undefined; + encode_jacobian_part(&tp, &Curve.P); + tp[Curve.point_len / 2 - 1] -= 2; + modpow(Curve, &t[d], tp, 1, &t[a], &t[b]); + }, + else => result &= ~iszero(&t[d]), + } + } + P1.* = t[0..3].*; + return result; + } + + inline fn MUL31(x: u32, y: u32) u64 { + return @as(u64, x) * @as(u64, y); + } + + inline fn MUL31_lo(x: u32, y: u32) u32 { + return (x *% y) & 0x7FFFFFFF; + } + + inline fn MUX(ctl: u32, x: u32, y: u32) u32 { + return y ^ (@bitCast(u32, -@bitCast(i32, ctl)) & (x ^ y)); + } + inline fn NOT(ctl: u32) u32 { + return ctl ^ 1; + } + inline fn NEQ(x: u32, y: u32) u32 { + const q = x ^ y; + return (q | @bitCast(u32, -@bitCast(i32, q))) >> 31; + } + inline fn EQ(x: u32, y: u32) u32 { + const q = x ^ y; + return NOT((q | @bitCast(u32, -@bitCast(i32, q))) >> 31); + } + inline fn CMP(x: u32, y: u32) i32 { + return @bitCast(i32, GT(x, y)) | -@bitCast(i32, GT(y, x)); + } + inline fn GT(x: u32, y: u32) u32 { + const z = y -% x; + return (z ^ ((x ^ y) & (x ^ z))) >> 31; + } + inline fn LT(x: u32, y: u32) u32 { + return GT(y, x); + } + inline fn GE(x: u32, y: u32) u32 { + return NOT(GT(y, x)); + } + + fn CCOPY(ctl: u32, dst: []u8, src: []const u8) void { + for (src) 
|s, i| { + dst[i] = @truncate(u8, MUX(ctl, s, dst[i])); + } + } + + inline fn set_zero(out: [*]u32, bit_len: u32) void { + out[0] = bit_len; + mem.set(u32, (out + 1)[0 .. (bit_len + 31) >> 5], 0); + } + + fn divrem(_hi: u32, _lo: u32, d: u32, r: *u32) u32 { + var hi = _hi; + var lo = _lo; + var q: u32 = 0; + const ch = EQ(hi, d); + hi = MUX(ch, 0, hi); + + var k: u5 = 31; + while (k > 0) : (k -= 1) { + const j = @truncate(u5, 32 - @as(u6, k)); + const w = (hi << j) | (lo >> k); + const ctl = GE(w, d) | (hi >> k); + const hi2 = (w -% d) >> j; + const lo2 = lo -% (d << k); + hi = MUX(ctl, hi2, hi); + lo = MUX(ctl, lo2, lo); + q |= ctl << k; + } + const cf = GE(lo, d) | hi; + q |= cf; + r.* = MUX(cf, lo -% d, lo); + return q; + } + + inline fn div(hi: u32, lo: u32, d: u32) u32 { + var r: u32 = undefined; + return divrem(hi, lo, d, &r); + } + + fn muladd_small(x: [*]u32, z: u32, m: [*]const u32) void { + var a0: u32 = undefined; + var a1: u32 = undefined; + var b0: u32 = undefined; + const mblr = @intCast(u5, m[0] & 31); + const mlen = (m[0] + 31) >> 5; + const hi = x[mlen]; + if (mblr == 0) { + a0 = x[mlen]; + mem.copyBackwards(u32, (x + 2)[0 .. mlen - 1], (x + 1)[0 .. mlen - 1]); + x[1] = z; + a1 = x[mlen]; + b0 = m[mlen]; + } else { + a0 = ((x[mlen] << (31 - mblr)) | (x[mlen - 1] >> mblr)) & 0x7FFFFFFF; + mem.copyBackwards(u32, (x + 2)[0 .. mlen - 1], (x + 1)[0 .. mlen - 1]); + x[1] = z; + a1 = ((x[mlen] << (31 - mblr)) | (x[mlen - 1] >> mblr)) & 0x7FFFFFFF; + b0 = ((m[mlen] << (31 - mblr)) | (m[mlen - 1] >> mblr)) & 0x7FFFFFFF; + } + + const g = div(a0 >> 1, a1 | (a0 << 31), b0); + const q = MUX(EQ(a0, b0), 0x7FFFFFFF, MUX(EQ(g, 0), 0, g -% 1)); + + var cc: u32 = 0; + var tb: u32 = 1; + var u: usize = 1; + while (u <= mlen) : (u += 1) { + const mw = m[u]; + const zl = MUL31(mw, q) + cc; + cc = @truncate(u32, zl >> 31); + const zw = @truncate(u32, zl) & 0x7FFFFFFF; + const xw = x[u]; + var nxw = xw -% zw; + cc += nxw >> 31; + nxw &= 0x7FFFFFFF; + x[u] = nxw; + tb = MUX(EQ(nxw, mw), tb, GT(nxw, mw)); + } + + const over = GT(cc, hi); + const under = ~over & (tb | LT(cc, hi)); + _ = add(x, m, over); + _ = sub(x, m, under); + } + + fn to_monty(x: [*]u32, m: [*]const u32) void { + const mlen = (m[0] + 31) >> 5; + var k = mlen; + while (k > 0) : (k -= 1) { + muladd_small(x, 0, m); + } + } + + fn modpow( + comptime Curve: type, + x: *[jacobian_len(Curve)]u32, + e: [Curve.point_len / 2]u8, + m0i: u32, + t1: *[jacobian_len(Curve)]u32, + t2: *[jacobian_len(Curve)]u32, + ) void { + t1.* = x.*; + to_monty(t1, &Curve.P); + set_zero(x, Curve.P[0]); + x[1] = 1; + const bitlen = comptime (Curve.point_len / 2) << 3; + var k: usize = 0; + while (k < bitlen) : (k += 1) { + const ctl = (e[Curve.point_len / 2 - 1 - (k >> 3)] >> (@truncate(u3, k & 7))) & 1; + montymul(t2, x, t1, &Curve.P, m0i); + CCOPY(ctl, mem.asBytes(x), mem.asBytes(t2)); + montymul(t2, t1, t1, &Curve.P, m0i); + t1.* = t2.*; + } + } + + fn encode_jacobian_part(dst: []u8, x: [*]const u32) void { + const xlen = (x[0] + 31) >> 5; + + var buf = @ptrToInt(dst.ptr) + dst.len; + var len: usize = dst.len; + var k: usize = 1; + var acc: u32 = 0; + var acc_len: u5 = 0; + while (len != 0) { + const w = if (k <= xlen) x[k] else 0; + k += 1; + if (acc_len == 0) { + acc = w; + acc_len = 31; + } else { + const z = acc | (w << acc_len); + acc_len -= 1; + acc = w >> (31 - acc_len); + if (len >= 4) { + buf -= 4; + len -= 4; + mem.writeIntBig(u32, @intToPtr([*]u8, buf)[0..4], z); + } else { + switch (len) { + 3 => { + @intToPtr(*u8, buf - 3).* = @truncate(u8, z 
>> 16); + @intToPtr(*u8, buf - 2).* = @truncate(u8, z >> 8); + }, + 2 => @intToPtr(*u8, buf - 2).* = @truncate(u8, z >> 8), + 1 => {}, + else => unreachable, + } + @intToPtr(*u8, buf - 1).* = @truncate(u8, z); + return; + } + } + } + } + + fn montymul( + out: [*]u32, + x: [*]const u32, + y: [*]const u32, + m: [*]const u32, + m0i: u32, + ) void { + const len = (m[0] + 31) >> 5; + const len4 = len & ~@as(usize, 3); + set_zero(out, m[0]); + var dh: u32 = 0; + var u: usize = 0; + while (u < len) : (u += 1) { + const xu = x[u + 1]; + const f = MUL31_lo(out[1] + MUL31_lo(x[u + 1], y[1]), m0i); + + var r: u64 = 0; + var v: usize = 0; + while (v < len4) : (v += 4) { + comptime var j = 1; + inline while (j <= 4) : (j += 1) { + const z = out[v + j] +% MUL31(xu, y[v + j]) +% MUL31(f, m[v + j]) +% r; + r = z >> 31; + out[v + j - 1] = @truncate(u32, z) & 0x7FFFFFFF; + } + } + while (v < len) : (v += 1) { + const z = out[v + 1] +% MUL31(xu, y[v + 1]) +% MUL31(f, m[v + 1]) +% r; + r = z >> 31; + out[v] = @truncate(u32, z) & 0x7FFFFFFF; + } + dh += @truncate(u32, r); + out[len] = dh & 0x7FFFFFFF; + dh >>= 31; + } + out[0] = m[0]; + const ctl = NEQ(dh, 0) | NOT(sub(out, m, 0)); + _ = sub(out, m, ctl); + } + + fn add(a: [*]u32, b: [*]const u32, ctl: u32) u32 { + var u: usize = 1; + var cc: u32 = 0; + const m = (a[0] + 63) >> 5; + while (u < m) : (u += 1) { + const aw = a[u]; + const bw = b[u]; + const naw = aw +% bw +% cc; + cc = naw >> 31; + a[u] = MUX(ctl, naw & 0x7FFFFFFF, aw); + } + return cc; + } + + fn sub(a: [*]u32, b: [*]const u32, ctl: u32) u32 { + var cc: u32 = 0; + const m = (a[0] + 63) >> 5; + var u: usize = 1; + while (u < m) : (u += 1) { + const aw = a[u]; + const bw = b[u]; + const naw = aw -% bw -% cc; + cc = naw >> 31; + a[u] = MUX(ctl, naw & 0x7FFFFFFF, aw); + } + return cc; + } + + fn iszero(arr: [*]const u32) u32 { + const mlen = (arr[0] + 63) >> 5; + var z: u32 = 0; + var u: usize = mlen - 1; + while (u > 0) : (u -= 1) { + z |= arr[u]; + } + return ~(z | @bitCast(u32, -@bitCast(i32, z))) >> 31; + } +}; + +test "elliptic curve functions with secp384r1 curve" { + { + // Decode to Jacobian then encode again with no operations + var P: ecc.Jacobian(ecc.SECP384R1) = undefined; + _ = ecc.decode_to_jacobian(ecc.SECP384R1, &P, ecc.SECP384R1.base_point); + var out: [96]u8 = undefined; + ecc.encode_from_jacobian(ecc.SECP384R1, &out, P); + try std.testing.expectEqual(ecc.SECP384R1.base_point, out); + + // Multiply by one, check that the result is still the base point + mem.set(u8, &out, 0); + ecc.point_mul(ecc.SECP384R1, &P, &[1]u8{1}); + ecc.encode_from_jacobian(ecc.SECP384R1, &out, P); + try std.testing.expectEqual(ecc.SECP384R1.base_point, out); + } + + { + // @TODO Remove this once std.crypto.rand works in .evented mode + var rand = blk: { + var seed: [std.rand.DefaultCsprng.secret_seed_length]u8 = undefined; + try std.os.getrandom(&seed); + break :blk &std.rand.DefaultCsprng.init(seed).random; + }; + + // Derive a shared secret from a Diffie-Hellman key exchange + var seed: [48]u8 = undefined; + rand.bytes(&seed); + const kp1 = ecc.make_key_pair(ecc.SECP384R1, seed); + rand.bytes(&seed); + const kp2 = ecc.make_key_pair(ecc.SECP384R1, seed); + + const shared1 = try ecc.scalarmult(ecc.SECP384R1, kp1.public_key, &kp2.secret_key); + const shared2 = try ecc.scalarmult(ecc.SECP384R1, kp2.public_key, &kp1.secret_key); + try std.testing.expectEqual(shared1, shared2); + } + + // @TODO Add tests with known points. 
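// [Editor's note] A minimal sketch of one such consistency check, not part of
// the original patch: it assumes make_key_pair derives public_key as
// secret_key times the curve base point, which the Diffie-Hellman check above
// already relies on.
{
    var seed: [48]u8 = undefined;
    try std.os.getrandom(&seed);
    const kp = ecc.make_key_pair(ecc.SECP384R1, seed);
    // Multiplying the encoded base point by the secret key should reproduce
    // the public key that make_key_pair computed from the same seed.
    const derived = try ecc.scalarmult(ecc.SECP384R1, ecc.SECP384R1.base_point, &kp.secret_key);
    try std.testing.expectEqual(kp.public_key, derived);
}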
+} diff --git a/src/deps/iguanaTLS/src/main.zig b/src/deps/iguanaTLS/src/main.zig new file mode 100644 index 000000000..6937e19b2 --- /dev/null +++ b/src/deps/iguanaTLS/src/main.zig @@ -0,0 +1,2216 @@ +const std = @import("std"); +const mem = std.mem; +const Allocator = mem.Allocator; +const Sha224 = std.crypto.hash.sha2.Sha224; +const Sha384 = std.crypto.hash.sha2.Sha384; +const Sha512 = std.crypto.hash.sha2.Sha512; +const Sha256 = std.crypto.hash.sha2.Sha256; +const Hmac256 = std.crypto.auth.hmac.sha2.HmacSha256; + +pub const asn1 = @import("asn1.zig"); +pub const x509 = @import("x509.zig"); +pub const crypto = @import("crypto.zig"); + +const ciphers = @import("ciphersuites.zig"); +pub const ciphersuites = ciphers.suites; + +pub const @"pcks1v1.5" = @import("pcks1-1_5.zig"); + +comptime { + std.testing.refAllDecls(x509); + std.testing.refAllDecls(asn1); + std.testing.refAllDecls(crypto); +} + +fn handshake_record_length(reader: anytype) !usize { + return try record_length(0x16, reader); +} + +pub const RecordHeader = struct { + data: [5]u8, + + pub inline fn tag(self: @This()) u8 { + return self.data[0]; + } + + pub inline fn len(self: @This()) u16 { + return mem.readIntSliceBig(u16, self.data[3..]); + } +}; + +pub fn record_header(reader: anytype) !RecordHeader { + var header: [5]u8 = undefined; + try reader.readNoEof(&header); + + if (!mem.eql(u8, header[1..3], "\x03\x03") and !mem.eql(u8, header[1..3], "\x03\x01")) + return error.ServerInvalidVersion; + + return RecordHeader{ + .data = header, + }; +} + +pub fn record_length(t: u8, reader: anytype) !usize { + try check_record_type(t, reader); + var header: [4]u8 = undefined; + try reader.readNoEof(&header); + if (!mem.eql(u8, header[0..2], "\x03\x03") and !mem.eql(u8, header[0..2], "\x03\x01")) + return error.ServerInvalidVersion; + return mem.readIntSliceBig(u16, header[2..4]); +} + +pub const ServerAlert = error{ + AlertCloseNotify, + AlertUnexpectedMessage, + AlertBadRecordMAC, + AlertDecryptionFailed, + AlertRecordOverflow, + AlertDecompressionFailure, + AlertHandshakeFailure, + AlertNoCertificate, + AlertBadCertificate, + AlertUnsupportedCertificate, + AlertCertificateRevoked, + AlertCertificateExpired, + AlertCertificateUnknown, + AlertIllegalParameter, + AlertUnknownCA, + AlertAccessDenied, + AlertDecodeError, + AlertDecryptError, + AlertExportRestriction, + AlertProtocolVersion, + AlertInsufficientSecurity, + AlertInternalError, + AlertUserCanceled, + AlertNoRenegotiation, + AlertUnsupportedExtension, +}; + +fn check_record_type( + expected: u8, + reader: anytype, +) (@TypeOf(reader).Error || ServerAlert || error{ ServerMalformedResponse, EndOfStream })!void { + const record_type = try reader.readByte(); + // Alert + if (record_type == 0x15) { + // Skip SSL version, length of record + try reader.skipBytes(4, .{}); + + const severity = try reader.readByte(); + _ = severity; + const err_num = try reader.readByte(); + return alert_byte_to_error(err_num); + } + if (record_type != expected) + return error.ServerMalformedResponse; +} + +pub fn alert_byte_to_error(b: u8) (ServerAlert || error{ServerMalformedResponse}) { + return switch (b) { + 0 => error.AlertCloseNotify, + 10 => error.AlertUnexpectedMessage, + 20 => error.AlertBadRecordMAC, + 21 => error.AlertDecryptionFailed, + 22 => error.AlertRecordOverflow, + 30 => error.AlertDecompressionFailure, + 40 => error.AlertHandshakeFailure, + 41 => error.AlertNoCertificate, + 42 => error.AlertBadCertificate, + 43 => error.AlertUnsupportedCertificate, + 44 => 
error.AlertCertificateRevoked, + 45 => error.AlertCertificateExpired, + 46 => error.AlertCertificateUnknown, + 47 => error.AlertIllegalParameter, + 48 => error.AlertUnknownCA, + 49 => error.AlertAccessDenied, + 50 => error.AlertDecodeError, + 51 => error.AlertDecryptError, + 60 => error.AlertExportRestriction, + 70 => error.AlertProtocolVersion, + 71 => error.AlertInsufficientSecurity, + 80 => error.AlertInternalError, + 90 => error.AlertUserCanceled, + 100 => error.AlertNoRenegotiation, + 110 => error.AlertUnsupportedExtension, + else => error.ServerMalformedResponse, + }; +} + +// TODO: Now that we keep all the hashes, check the ciphersuite for the hash +// type used and use it where necessary instead of hardcoding sha256 +const HashSet = struct { + sha224: Sha224, + sha256: Sha256, + sha384: Sha384, + sha512: Sha512, + + fn update(self: *@This(), buf: []const u8) void { + self.sha224.update(buf); + self.sha256.update(buf); + self.sha384.update(buf); + self.sha512.update(buf); + } +}; + +fn HashingReader(comptime Reader: anytype) type { + const State = struct { + hash_set: *HashSet, + reader: Reader, + }; + const S = struct { + pub fn read(state: State, buffer: []u8) Reader.Error!usize { + const amt = try state.reader.read(buffer); + if (amt != 0) { + state.hash_set.update(buffer[0..amt]); + } + return amt; + } + }; + return std.io.Reader(State, Reader.Error, S.read); +} + +fn make_hashing_reader(hash_set: *HashSet, reader: anytype) HashingReader(@TypeOf(reader)) { + return .{ .context = .{ .hash_set = hash_set, .reader = reader } }; +} + +fn HashingWriter(comptime Writer: anytype) type { + const State = struct { + hash_set: *HashSet, + writer: Writer, + }; + const S = struct { + pub fn write(state: State, buffer: []const u8) Writer.Error!usize { + const amt = try state.writer.write(buffer); + if (amt != 0) { + state.hash_set.update(buffer[0..amt]); + } + return amt; + } + }; + return std.io.Writer(State, Writer.Error, S.write); +} + +fn make_hashing_writer(hash_set: *HashSet, writer: anytype) HashingWriter(@TypeOf(writer)) { + return .{ .context = .{ .hash_set = hash_set, .writer = writer } }; +} + +fn CertificateReaderState(comptime Reader: type) type { + return struct { + reader: Reader, + length: usize, + idx: usize = 0, + }; +} + +fn CertificateReader(comptime Reader: type) type { + const S = struct { + pub fn read(state: *CertificateReaderState(Reader), buffer: []u8) Reader.Error!usize { + const out_bytes = std.math.min(buffer.len, state.length - state.idx); + const res = try state.reader.readAll(buffer[0..out_bytes]); + state.idx += res; + return res; + } + }; + + return std.io.Reader(*CertificateReaderState(Reader), Reader.Error, S.read); +} + +pub const CertificateVerifier = union(enum) { + none, + function: anytype, + default, +}; + +pub fn CertificateVerifierReader(comptime Reader: type) type { + return CertificateReader(HashingReader(Reader)); +} + +pub fn ClientConnectError( + comptime verifier: CertificateVerifier, + comptime Reader: type, + comptime Writer: type, + comptime has_client_certs: bool, +) type { + const Additional = error{ + ServerInvalidVersion, + ServerMalformedResponse, + EndOfStream, + ServerInvalidCipherSuite, + ServerInvalidCompressionMethod, + ServerInvalidRenegotiationData, + ServerInvalidECPointCompression, + ServerInvalidProtocol, + ServerInvalidExtension, + ServerInvalidCurve, + ServerInvalidSignature, + ServerInvalidSignatureAlgorithm, + ServerAuthenticationFailed, + ServerInvalidVerifyData, + PreMasterGenerationFailed, + OutOfMemory, + }; + const 
err_msg = "Certificate verifier function cannot be generic, use CertificateVerifierReader to get the reader argument type"; + return Reader.Error || Writer.Error || ServerAlert || Additional || switch (verifier) { + .none => error{}, + .function => |f| @typeInfo(@typeInfo(@TypeOf(f)).Fn.return_type orelse + @compileError(err_msg)).ErrorUnion.error_set || error{CertificateVerificationFailed}, + .default => error{CertificateVerificationFailed}, + } || (if (has_client_certs) error{ClientCertificateVerifyFailed} else error{}); +} + +// See http://howardhinnant.github.io/date_algorithms.html +// Timestamp in seconds, only supports A.D. dates +fn unix_timestamp_from_civil_date(year: u16, month: u8, day: u8) i64 { + var y: i64 = year; + if (month <= 2) y -= 1; + const era = @divTrunc(y, 400); + const yoe = y - era * 400; // [0, 399] + const doy = @divTrunc((153 * (month + (if (month > 2) @as(i64, -3) else 9)) + 2), 5) + day - 1; // [0, 365] + const doe = yoe * 365 + @divTrunc(yoe, 4) - @divTrunc(yoe, 100) + doy; // [0, 146096] + return (era * 146097 + doe - 719468) * 86400; +} + +fn read_der_utc_timestamp(reader: anytype) !i64 { + var buf: [17]u8 = undefined; + + const tag = try reader.readByte(); + if (tag != 0x17) + return error.CertificateVerificationFailed; + const len = try asn1.der.parse_length(reader); + if (len > 17) + return error.CertificateVerificationFailed; + + try reader.readNoEof(buf[0..len]); + const year = std.fmt.parseUnsigned(u16, buf[0..2], 10) catch + return error.CertificateVerificationFailed; + const month = std.fmt.parseUnsigned(u8, buf[2..4], 10) catch + return error.CertificateVerificationFailed; + const day = std.fmt.parseUnsigned(u8, buf[4..6], 10) catch + return error.CertificateVerificationFailed; + + var time = unix_timestamp_from_civil_date(2000 + year, month, day); + time += (std.fmt.parseUnsigned(i64, buf[6..8], 10) catch + return error.CertificateVerificationFailed) * 3600; + time += (std.fmt.parseUnsigned(i64, buf[8..10], 10) catch + return error.CertificateVerificationFailed) * 60; + + if (buf[len - 1] == 'Z') { + if (len == 13) { + time += std.fmt.parseUnsigned(u8, buf[10..12], 10) catch + return error.CertificateVerificationFailed; + } else if (len != 11) { + return error.CertificateVerificationFailed; + } + } else { + if (len == 15) { + if (buf[10] != '+' and buf[10] != '-') + return error.CertificateVerificationFailed; + + var additional = (std.fmt.parseUnsigned(i64, buf[11..13], 10) catch + return error.CertificateVerificationFailed) * 3600; + additional += (std.fmt.parseUnsigned(i64, buf[13..15], 10) catch + return error.CertificateVerificationFailed) * 60; + + time += if (buf[10] == '+') -additional else additional; + } else if (len == 17) { + if (buf[12] != '+' and buf[12] != '-') + return error.CertificateVerificationFailed; + time += std.fmt.parseUnsigned(u8, buf[10..12], 10) catch + return error.CertificateVerificationFailed; + + var additional = (std.fmt.parseUnsigned(i64, buf[13..15], 10) catch + return error.CertificateVerificationFailed) * 3600; + additional += (std.fmt.parseUnsigned(i64, buf[15..17], 10) catch + return error.CertificateVerificationFailed) * 60; + + time += if (buf[12] == '+') -additional else additional; + } else return error.CertificateVerificationFailed; + } + return time; +} + +fn check_cert_timestamp(time: i64, tag_byte: u8, length: usize, reader: anytype) !void { + _ = tag_byte; + _ = length; + if (time < (try read_der_utc_timestamp(reader))) + return error.CertificateVerificationFailed; + if (time > (try 
read_der_utc_timestamp(reader))) + return error.CertificateVerificationFailed; +} + +fn add_dn_field(state: *VerifierCaptureState, tag: u8, length: usize, reader: anytype) !void { + _ = length; + _ = tag; + + const seq_tag = try reader.readByte(); + if (seq_tag != 0x30) + return error.CertificateVerificationFailed; + const seq_length = try asn1.der.parse_length(reader); + _ = seq_length; + + const oid_tag = try reader.readByte(); + if (oid_tag != 0x06) + return error.CertificateVerificationFailed; + + const oid_length = try asn1.der.parse_length(reader); + if (oid_length == 3 and (try reader.isBytes("\x55\x04\x03"))) { + // Common name + const common_name_tag = try reader.readByte(); + if (common_name_tag != 0x04 and common_name_tag != 0x0c and common_name_tag != 0x13 and common_name_tag != 0x16) + return error.CertificateVerificationFailed; + const common_name_len = try asn1.der.parse_length(reader); + state.list.items[state.list.items.len - 1].common_name = state.fbs.buffer[state.fbs.pos .. state.fbs.pos + common_name_len]; + } +} + +fn add_cert_subject_dn(state: *VerifierCaptureState, tag: u8, length: usize, reader: anytype) !void { + state.list.items[state.list.items.len - 1].dn = state.fbs.buffer[state.fbs.pos .. state.fbs.pos + length]; + const schema = .{ + .sequence_of, + .{ + .capture, 0, .set, + }, + }; + const captures = .{ + state, add_dn_field, + }; + try asn1.der.parse_schema_tag_len(tag, length, schema, captures, reader); +} + +fn add_cert_public_key(state: *VerifierCaptureState, tag: u8, length: usize, reader: anytype) !void { + _ = tag; + _ = length; + + state.list.items[state.list.items.len - 1].public_key = x509.parse_public_key( + state.allocator, + reader, + ) catch |err| switch (err) { + error.MalformedDER => return error.CertificateVerificationFailed, + else => |e| return e, + }; +} + +fn add_cert_extensions(state: *VerifierCaptureState, tag: u8, length: usize, reader: anytype) !void { + _ = tag; + _ = length; + + const schema = .{ + .sequence_of, + .{ .capture, 0, .sequence }, + }; + const captures = .{ + state, add_cert_extension, + }; + + try asn1.der.parse_schema(schema, captures, reader); +} + +fn add_cert_extension(state: *VerifierCaptureState, tag: u8, length: usize, reader: anytype) !void { + _ = tag; + _ = length; + + const start = state.fbs.pos; + + // The happy path is allocation free + // TODO: add a preflight check to mandate a specific tag + const object_id = try asn1.der.parse_value(state.allocator, reader); + defer object_id.deinit(state.allocator); + if (object_id != .object_identifier) return error.DoesNotMatchSchema; + if (object_id.object_identifier.len != 4) + return; + + const data = object_id.object_identifier.data; + // Prefix == id-ce + if (data[0] != 2 or data[1] != 5 or data[2] != 29) + return; + + switch (data[3]) { + 17 => { + const san_tag = try reader.readByte(); + if (san_tag != @enumToInt(asn1.Tag.octet_string)) return error.DoesNotMatchSchema; + + const san_length = try asn1.der.parse_length(reader); + _ = san_length; + + const body_tag = try reader.readByte(); + if (body_tag != @enumToInt(asn1.Tag.sequence)) return error.DoesNotMatchSchema; + + const body_length = try asn1.der.parse_length(reader); + const total_read = state.fbs.pos - start; + if (total_read + body_length > length) return error.DoesNotMatchSchema; + + state.list.items[state.list.items.len - 1].raw_subject_alternative_name = state.fbs.buffer[state.fbs.pos .. 
state.fbs.pos + body_length]; + + // Validate to make sure this is iterable later + const ref = state.fbs.pos; + while (state.fbs.pos - ref < body_length) { + const choice = try reader.readByte(); + if (choice < 0x80) return error.DoesNotMatchSchema; + + const chunk_length = try asn1.der.parse_length(reader); + _ = try reader.skipBytes(chunk_length, .{}); + } + }, + else => {}, + } +} + +fn add_server_cert(state: *VerifierCaptureState, tag_byte: u8, length: usize, reader: anytype) !void { + const is_ca = state.list.items.len != 0; + + // TODO: Some way to get tag + length buffer directly in the capture callback? + const encoded_length = asn1.der.encode_length(length).slice(); + // This is not errdefered since default_cert_verifier call takes care of cleaning up all the certificate data. + // Same for the signature.data + const cert_bytes = try state.allocator.alloc(u8, length + 1 + encoded_length.len); + cert_bytes[0] = tag_byte; + mem.copy(u8, cert_bytes[1 .. 1 + encoded_length.len], encoded_length); + + try reader.readNoEof(cert_bytes[1 + encoded_length.len ..]); + (try state.list.addOne(state.allocator)).* = .{ + .is_ca = is_ca, + .bytes = cert_bytes, + .dn = undefined, + .common_name = &[0]u8{}, + .raw_subject_alternative_name = &[0]u8{}, + .public_key = x509.PublicKey.empty, + .signature = asn1.BitString{ .data = &[0]u8{}, .bit_len = 0 }, + .signature_algorithm = undefined, + }; + + const schema = .{ + .sequence, + .{ + .{ .context_specific, 0 }, // version + .{.int}, // serialNumber + .{.sequence}, // signature + .{.sequence}, // issuer + .{ .capture, 0, .sequence }, // validity + .{ .capture, 1, .sequence }, // subject + .{ .capture, 2, .sequence }, // subjectPublicKeyInfo + .{ .optional, .context_specific, 1 }, // issuerUniqueID + .{ .optional, .context_specific, 2 }, // subjectUniqueID + .{ .capture, 3, .optional, .context_specific, 3 }, // extensions + }, + }; + + const captures = .{ + std.time.timestamp(), check_cert_timestamp, + state, add_cert_subject_dn, + state, add_cert_public_key, + state, add_cert_extensions, + }; + + var fbs = std.io.fixedBufferStream(@as([]const u8, cert_bytes[1 + encoded_length.len ..])); + state.fbs = &fbs; + + asn1.der.parse_schema_tag_len(tag_byte, length, schema, captures, fbs.reader()) catch |err| switch (err) { + error.InvalidLength, + error.InvalidTag, + error.InvalidContainerLength, + error.DoesNotMatchSchema, + => return error.CertificateVerificationFailed, + else => |e| return e, + }; +} + +fn set_signature_algorithm(state: *VerifierCaptureState, tag: u8, length: usize, reader: anytype) !void { + _ = tag; + _ = length; + + const cert = &state.list.items[state.list.items.len - 1]; + cert.signature_algorithm = (try x509.get_signature_algorithm(reader)) orelse return error.CertificateVerificationFailed; +} + +fn set_signature_value(state: *VerifierCaptureState, tag: u8, length: usize, reader: anytype) !void { + _ = tag; + _ = length; + + const unused_bits = try reader.readByte(); + const bit_count = (length - 1) * 8 - unused_bits; + const signature_bytes = try state.allocator.alloc(u8, length - 1); + errdefer state.allocator.free(signature_bytes); + try reader.readNoEof(signature_bytes); + state.list.items[state.list.items.len - 1].signature = .{ + .data = signature_bytes, + .bit_len = bit_count, + }; +} + +const ServerCertificate = struct { + bytes: []const u8, + dn: []const u8, + common_name: []const u8, + raw_subject_alternative_name: []const u8, + public_key: x509.PublicKey, + signature: asn1.BitString, + signature_algorithm: 
x509.Certificate.SignatureAlgorithm, + is_ca: bool, + + const GeneralName = enum(u5) { + other_name = 0, + rfc822_name = 1, + dns_name = 2, + x400_address = 3, + directory_name = 4, + edi_party_name = 5, + uniform_resource_identifier = 6, + ip_address = 7, + registered_id = 8, + }; + + fn iterSAN(self: ServerCertificate, choice: GeneralName) NameIterator { + return .{ .cert = self, .choice = choice }; + } + + const NameIterator = struct { + cert: ServerCertificate, + choice: GeneralName, + pos: usize = 0, + + fn next(self: *NameIterator) ?[]const u8 { + while (self.pos < self.cert.raw_subject_alternative_name.len) { + const choice = self.cert.raw_subject_alternative_name[self.pos]; + std.debug.assert(choice >= 0x80); + const len = self.cert.raw_subject_alternative_name[self.pos + 1]; + const start = self.pos + 2; + const end = start + len; + self.pos = end; + if (@enumToInt(self.choice) == choice - 0x80) { + return self.cert.raw_subject_alternative_name[start..end]; + } + } + return null; + } + }; +}; + +const VerifierCaptureState = struct { + list: std.ArrayListUnmanaged(ServerCertificate), + allocator: *Allocator, + // Used in `add_server_cert` to avoid an extra allocation + fbs: *std.io.FixedBufferStream([]const u8), +}; + +// @TODO Move out of here +const ReverseSplitIterator = struct { + buffer: []const u8, + index: ?usize, + delimiter: []const u8, + + pub fn next(self: *ReverseSplitIterator) ?[]const u8 { + const end = self.index orelse return null; + const start = if (mem.lastIndexOfLinear(u8, self.buffer[0..end], self.delimiter)) |delim_start| blk: { + self.index = delim_start; + break :blk delim_start + self.delimiter.len; + } else blk: { + self.index = null; + break :blk 0; + }; + return self.buffer[start..end]; + } +}; + +fn reverse_split(buffer: []const u8, delimiter: []const u8) ReverseSplitIterator { + std.debug.assert(delimiter.len != 0); + return .{ + .index = buffer.len, + .buffer = buffer, + .delimiter = delimiter, + }; +} + +fn cert_name_matches(cert_name: []const u8, hostname: []const u8) bool { + var cert_name_split = reverse_split(cert_name, "."); + var hostname_split = reverse_split(hostname, "."); + while (true) { + const cn_part = cert_name_split.next(); + const hn_part = hostname_split.next(); + + if (cn_part) |cnp| { + if (hn_part == null and cert_name_split.index == null and mem.eql(u8, cnp, "www")) + return true + else if (hn_part) |hnp| { + if (mem.eql(u8, cnp, "*")) + continue; + if (!mem.eql(u8, cnp, hnp)) + return false; + } + } else return hn_part == null; + } +} + +pub fn default_cert_verifier( + allocator: *mem.Allocator, + reader: anytype, + certs_bytes: usize, + trusted_certificates: []const x509.Certificate, + hostname: []const u8, +) !x509.PublicKey { + var capture_state = VerifierCaptureState{ + .list = try std.ArrayListUnmanaged(ServerCertificate).initCapacity(allocator, 3), + .allocator = allocator, + .fbs = undefined, + }; + defer { + for (capture_state.list.items) |cert| { + cert.public_key.deinit(allocator); + allocator.free(cert.bytes); + allocator.free(cert.signature.data); + } + capture_state.list.deinit(allocator); + } + + const schema = .{ + .sequence, .{ + // tbsCertificate + .{ .capture, 0, .sequence }, + // signatureAlgorithm + .{ .capture, 1, .sequence }, + // signatureValue + .{ .capture, 2, .bit_string }, + }, + }; + const captures = .{ + &capture_state, add_server_cert, + &capture_state, set_signature_algorithm, + &capture_state, set_signature_value, + }; + + var bytes_read: u24 = 0; + while (bytes_read < certs_bytes) { + const 
cert_length = try reader.readIntBig(u24); + + asn1.der.parse_schema(schema, captures, reader) catch |err| switch (err) { + error.InvalidLength, + error.InvalidTag, + error.InvalidContainerLength, + error.DoesNotMatchSchema, + => return error.CertificateVerificationFailed, + else => |e| return e, + }; + + bytes_read += 3 + cert_length; + } + if (bytes_read != certs_bytes) + return error.CertificateVerificationFailed; + + const chain = capture_state.list.items; + if (chain.len == 0) return error.CertificateVerificationFailed; + // Check if the hostname matches one of the leaf certificate's names + name_matched: { + if (cert_name_matches(chain[0].common_name, hostname)) { + break :name_matched; + } + + var iter = chain[0].iterSAN(.dns_name); + while (iter.next()) |cert_name| { + if (cert_name_matches(cert_name, hostname)) { + break :name_matched; + } + } + + return error.CertificateVerificationFailed; + } + + var i: usize = 0; + while (i < chain.len - 1) : (i += 1) { + if (!try @"pcks1v1.5".certificate_verify_signature( + allocator, + chain[i].signature_algorithm, + chain[i].signature, + chain[i].bytes, + chain[i + 1].public_key, + )) { + return error.CertificateVerificationFailed; + } + } + + for (chain) |cert| { + for (trusted_certificates) |trusted| { + // Try to find an exact match to a trusted certificate + if (cert.is_ca == trusted.is_ca and mem.eql(u8, cert.dn, trusted.dn) and + cert.public_key.eql(trusted.public_key)) + { + const key = chain[0].public_key; + chain[0].public_key = x509.PublicKey.empty; + return key; + } + + if (!trusted.is_ca) + continue; + + if (try @"pcks1v1.5".certificate_verify_signature( + allocator, + cert.signature_algorithm, + cert.signature, + cert.bytes, + trusted.public_key, + )) { + const key = chain[0].public_key; + chain[0].public_key = x509.PublicKey.empty; + return key; + } + } + } + return error.CertificateVerificationFailed; +} + +pub fn extract_cert_public_key(allocator: *Allocator, reader: anytype, length: usize) !x509.PublicKey { + const CaptureState = struct { + pub_key: x509.PublicKey, + allocator: *Allocator, + }; + var capture_state = CaptureState{ + .pub_key = undefined, + .allocator = allocator, + }; + + const schema = .{ + .sequence, .{ + // tbsCertificate + .{ + .sequence, + .{ + .{ .context_specific, 0 }, // version + .{.int}, // serialNumber + .{.sequence}, // signature + .{.sequence}, // issuer + .{.sequence}, // validity + .{.sequence}, // subject + .{ .capture, 0, .sequence }, // subjectPublicKeyInfo + .{ .optional, .context_specific, 1 }, // issuerUniqueID + .{ .optional, .context_specific, 2 }, // subjectUniqueID + .{ .optional, .context_specific, 3 }, // extensions + }, + }, + // signatureAlgorithm + .{.sequence}, + // signatureValue + .{.bit_string}, + }, + }; + const captures = .{ + &capture_state, struct { + fn f(state: *CaptureState, tag: u8, _length: usize, subreader: anytype) !void { + _ = tag; + _ = _length; + + state.pub_key = x509.parse_public_key(state.allocator, subreader) catch |err| switch (err) { + error.MalformedDER => return error.ServerMalformedResponse, + else => |e| return e, + }; + } + }.f, + }; + + const cert_length = try reader.readIntBig(u24); + asn1.der.parse_schema(schema, captures, reader) catch |err| switch (err) { + error.InvalidLength, + error.InvalidTag, + error.InvalidContainerLength, + error.DoesNotMatchSchema, + => return error.ServerMalformedResponse, + else => |e| return e, + }; + errdefer capture_state.pub_key.deinit(allocator); + + try reader.skipBytes(length - cert_length - 3, .{}); + return 
capture_state.pub_key; +} + +pub const curves = struct { + pub const x25519 = struct { + pub const name = "x25519"; + const tag = 0x001D; + const pub_key_len = 32; + const Keys = std.crypto.dh.X25519.KeyPair; + + inline fn make_key_pair(rand: *std.rand.Random) Keys { + while (true) { + var seed: [32]u8 = undefined; + rand.bytes(&seed); + return std.crypto.dh.X25519.KeyPair.create(seed) catch continue; + } else unreachable; + } + + inline fn make_pre_master_secret( + key_pair: Keys, + pre_master_secret_buf: []u8, + server_public_key: *const [32]u8, + ) ![]const u8 { + pre_master_secret_buf[0..32].* = std.crypto.dh.X25519.scalarmult( + key_pair.secret_key, + server_public_key.*, + ) catch return error.PreMasterGenerationFailed; + return pre_master_secret_buf[0..32]; + } + }; + + pub const secp384r1 = struct { + pub const name = "secp384r1"; + const tag = 0x0018; + const pub_key_len = 97; + const Keys = crypto.ecc.KeyPair(crypto.ecc.SECP384R1); + + inline fn make_key_pair(rand: *std.rand.Random) Keys { + var seed: [48]u8 = undefined; + rand.bytes(&seed); + return crypto.ecc.make_key_pair(crypto.ecc.SECP384R1, seed); + } + + inline fn make_pre_master_secret( + key_pair: Keys, + pre_master_secret_buf: []u8, + server_public_key: *const [97]u8, + ) ![]const u8 { + pre_master_secret_buf[0..96].* = crypto.ecc.scalarmult( + crypto.ecc.SECP384R1, + server_public_key[1..].*, + &key_pair.secret_key, + ) catch return error.PreMasterGenerationFailed; + return pre_master_secret_buf[0..48]; + } + }; + + pub const secp256r1 = struct { + pub const name = "secp256r1"; + const tag = 0x0017; + const pub_key_len = 65; + const Keys = crypto.ecc.KeyPair(crypto.ecc.SECP256R1); + + inline fn make_key_pair(rand: *std.rand.Random) Keys { + var seed: [32]u8 = undefined; + rand.bytes(&seed); + return crypto.ecc.make_key_pair(crypto.ecc.SECP256R1, seed); + } + + inline fn make_pre_master_secret( + key_pair: Keys, + pre_master_secret_buf: []u8, + server_public_key: *const [65]u8, + ) ![]const u8 { + pre_master_secret_buf[0..64].* = crypto.ecc.scalarmult( + crypto.ecc.SECP256R1, + server_public_key[1..].*, + &key_pair.secret_key, + ) catch return error.PreMasterGenerationFailed; + return pre_master_secret_buf[0..32]; + } + }; + + pub const all = &[_]type{ x25519, secp384r1, secp256r1 }; + + fn max_pub_key_len(comptime list: anytype) usize { + var max: usize = 0; + for (list) |curve| { + if (curve.pub_key_len > max) + max = curve.pub_key_len; + } + return max; + } + + fn max_pre_master_secret_len(comptime list: anytype) usize { + var max: usize = 0; + for (list) |curve| { + const curr = @typeInfo(std.meta.fieldInfo(curve.Keys, .public_key).field_type).Array.len; + if (curr > max) + max = curr; + } + return max; + } + + fn KeyPair(comptime list: anytype) type { + var fields: [list.len]std.builtin.TypeInfo.UnionField = undefined; + for (list) |curve, i| { + fields[i] = .{ + .name = curve.name, + .field_type = curve.Keys, + .alignment = @alignOf(curve.Keys), + }; + } + return @Type(.{ + .Union = .{ + .layout = .Extern, + .tag_type = null, + .fields = &fields, + .decls = &[0]std.builtin.TypeInfo.Declaration{}, + }, + }); + } + + inline fn make_key_pair(comptime list: anytype, curve_id: u16, rand: *std.rand.Random) KeyPair(list) { + inline for (list) |curve| { + if (curve.tag == curve_id) { + return @unionInit(KeyPair(list), curve.name, curve.make_key_pair(rand)); + } + } + unreachable; + } + + inline fn make_pre_master_secret( + comptime list: anytype, + curve_id: u16, + key_pair: KeyPair(list), + pre_master_secret_buf: 
*[max_pre_master_secret_len(list)]u8, + server_public_key: [max_pub_key_len(list)]u8, + ) ![]const u8 { + inline for (list) |curve| { + if (curve.tag == curve_id) { + return try curve.make_pre_master_secret( + @field(key_pair, curve.name), + pre_master_secret_buf, + server_public_key[0..curve.pub_key_len], + ); + } + } + unreachable; + } +}; + +pub fn client_connect( + options: anytype, + hostname: []const u8, +) ClientConnectError( + options.cert_verifier, + @TypeOf(options.reader), + @TypeOf(options.writer), + @hasField(@TypeOf(options), "client_certificates"), +)!Client( + @TypeOf(options.reader), + @TypeOf(options.writer), + if (@hasField(@TypeOf(options), "ciphersuites")) + options.ciphersuites + else + ciphersuites.all, + @hasField(@TypeOf(options), "protocols"), +) { + const Options = @TypeOf(options); + if (@TypeOf(options.cert_verifier) != CertificateVerifier and + @TypeOf(options.cert_verifier) != @Type(.EnumLiteral)) + @compileError("cert_verifier should be of type CertificateVerifier"); + + if (!@hasField(Options, "temp_allocator")) + @compileError("Option tuple is missing field 'temp_allocator'"); + if (options.cert_verifier == .default) { + if (!@hasField(Options, "trusted_certificates")) + @compileError("Option tuple is missing field 'trusted_certificates' for .default cert_verifier"); + } + + const suites = if (!@hasField(Options, "ciphersuites")) + ciphersuites.all + else + options.ciphersuites; + if (suites.len == 0) + @compileError("Must provide at least one ciphersuite type."); + + const curvelist = if (!@hasField(Options, "curves")) + curves.all + else + options.curves; + if (curvelist.len == 0) + @compileError("Must provide at least one curve type."); + + const has_alpn = comptime @hasField(Options, "protocols"); + var handshake_record_hash_set = HashSet{ + .sha224 = Sha224.init(.{}), + .sha256 = Sha256.init(.{}), + .sha384 = Sha384.init(.{}), + .sha512 = Sha512.init(.{}), + }; + const reader = options.reader; + const writer = options.writer; + const hashing_reader = make_hashing_reader(&handshake_record_hash_set, reader); + const hashing_writer = make_hashing_writer(&handshake_record_hash_set, writer); + + var client_random: [32]u8 = undefined; + const rand = if (!@hasField(Options, "rand")) + std.crypto.random + else + options.rand; + + rand.bytes(&client_random); + + var server_random: [32]u8 = undefined; + const ciphersuite_bytes = 2 * suites.len + 2; + const alpn_bytes = if (has_alpn) blk: { + var sum: usize = 0; + for (options.protocols) |proto| { + sum += proto.len; + } + break :blk 6 + options.protocols.len + sum; + } else 0; + const curvelist_bytes = 2 * curvelist.len; + var protocol: if (has_alpn) []const u8 else void = undefined; + { + const client_hello_start = comptime blk: { + // TODO: We assume the compiler is running in a little endian system + var starting_part: [46]u8 = [_]u8{ + // Record header: Handshake record type, protocol version, handshake size + 0x16, 0x03, 0x01, undefined, undefined, + // Handshake message type, bytes of client hello + 0x01, undefined, undefined, undefined, + // Client version (hardcoded to TLS 1.2 even for TLS 1.3) + 0x03, + 0x03, + } ++ ([1]u8{undefined} ** 32) ++ [_]u8{ + // Session ID + 0x00, + } ++ mem.toBytes(@byteSwap(u16, ciphersuite_bytes)); + // using .* = mem.asBytes(...).* or mem.writeIntBig didn't work... + + // Same as above, couldnt achieve this with a single buffer. 
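// [Editor's note] Not part of the original patch: the undefined length fields
// in this comptime template are patched after the block with mem.writeIntBig.
// The hard-coded offsets used there (0x55 for the record length, 0x51 for the
// handshake body, 0x28 for the extensions block) are the fixed byte counts of
// the headers and of the fixed-size extensions written further down (EC point
// formats, signature algorithms, renegotiation info, SCT); the variable parts
// (hostname, ALPN protocols, cipher suites, curves) are added on top of them.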
+ // TLS_EMPTY_RENEGOTIATION_INFO_SCSV + var ciphersuite_buf: []const u8 = &[2]u8{ 0x00, 0x0f }; + for (suites) |cs| { + // Also check for properties of the ciphersuites here + if (cs.key_exchange != .ecdhe) + @compileError("Non ECDHE key exchange is not supported yet."); + if (cs.hash != .sha256) + @compileError("Non SHA256 hash algorithm is not supported yet."); + + ciphersuite_buf = ciphersuite_buf ++ mem.toBytes(@byteSwap(u16, cs.tag)); + } + + var ending_part: [13]u8 = [_]u8{ + // Compression methods (no compression) + 0x01, 0x00, + // Extensions length + undefined, undefined, + // Extension: server name + // id, length, length of entry + 0x00, 0x00, + undefined, undefined, + undefined, undefined, + // entry type, length of bytes + 0x00, undefined, + undefined, + }; + break :blk starting_part ++ ciphersuite_buf ++ ending_part; + }; + + var msg_buf = client_hello_start.ptr[0..client_hello_start.len].*; + mem.writeIntBig(u16, msg_buf[3..5], @intCast(u16, alpn_bytes + hostname.len + 0x55 + ciphersuite_bytes + curvelist_bytes)); + mem.writeIntBig(u24, msg_buf[6..9], @intCast(u24, alpn_bytes + hostname.len + 0x51 + ciphersuite_bytes + curvelist_bytes)); + mem.copy(u8, msg_buf[11..43], &client_random); + mem.writeIntBig(u16, msg_buf[48 + ciphersuite_bytes ..][0..2], @intCast(u16, alpn_bytes + hostname.len + 0x28 + curvelist_bytes)); + mem.writeIntBig(u16, msg_buf[52 + ciphersuite_bytes ..][0..2], @intCast(u16, hostname.len + 5)); + mem.writeIntBig(u16, msg_buf[54 + ciphersuite_bytes ..][0..2], @intCast(u16, hostname.len + 3)); + mem.writeIntBig(u16, msg_buf[57 + ciphersuite_bytes ..][0..2], @intCast(u16, hostname.len)); + try writer.writeAll(msg_buf[0..5]); + try hashing_writer.writeAll(msg_buf[5..]); + } + try hashing_writer.writeAll(hostname); + if (has_alpn) { + var msg_buf = [6]u8{ 0x00, 0x10, undefined, undefined, undefined, undefined }; + mem.writeIntBig(u16, msg_buf[2..4], @intCast(u16, alpn_bytes - 4)); + mem.writeIntBig(u16, msg_buf[4..6], @intCast(u16, alpn_bytes - 6)); + try hashing_writer.writeAll(&msg_buf); + for (options.protocols) |proto| { + try hashing_writer.writeByte(@intCast(u8, proto.len)); + try hashing_writer.writeAll(proto); + } + } + + // Extension: supported groups + { + var msg_buf = [6]u8{ + 0x00, 0x0A, + undefined, undefined, + undefined, undefined, + }; + + mem.writeIntBig(u16, msg_buf[2..4], @intCast(u16, curvelist_bytes + 2)); + mem.writeIntBig(u16, msg_buf[4..6], @intCast(u16, curvelist_bytes)); + try hashing_writer.writeAll(&msg_buf); + + inline for (curvelist) |curve| { + try hashing_writer.writeIntBig(u16, curve.tag); + } + } + + try hashing_writer.writeAll(&[25]u8{ + // Extension: EC point formats => uncompressed point format + 0x00, 0x0B, 0x00, 0x02, 0x01, 0x00, + // Extension: Signature algorithms + // RSA/PKCS1/SHA256, RSA/PKCS1/SHA512 + 0x00, 0x0D, 0x00, 0x06, 0x00, 0x04, + 0x04, 0x01, 0x06, 0x01, + // Extension: Renegotiation Info => new connection + 0xFF, 0x01, + 0x00, 0x01, 0x00, + // Extension: SCT (signed certificate timestamp) + 0x00, 0x12, 0x00, + 0x00, + }); + + // Read server hello + var ciphersuite: u16 = undefined; + { + const length = try handshake_record_length(reader); + if (length < 44) + return error.ServerMalformedResponse; + { + var hs_hdr_and_server_ver: [6]u8 = undefined; + try hashing_reader.readNoEof(&hs_hdr_and_server_ver); + if (hs_hdr_and_server_ver[0] != 0x02) + return error.ServerMalformedResponse; + if (!mem.eql(u8, hs_hdr_and_server_ver[4..6], "\x03\x03")) + return error.ServerInvalidVersion; + } + try 
hashing_reader.readNoEof(&server_random); + + // Just skip the session id for now + const sess_id_len = try hashing_reader.readByte(); + if (sess_id_len != 0) + try hashing_reader.skipBytes(sess_id_len, .{}); + + { + ciphersuite = try hashing_reader.readIntBig(u16); + var found = false; + inline for (suites) |cs| { + if (ciphersuite == cs.tag) { + found = true; + // TODO This segfaults stage1 + // break; + } + } + if (!found) + return error.ServerInvalidCipherSuite; + } + + // Compression method + if ((try hashing_reader.readByte()) != 0x00) + return error.ServerInvalidCompressionMethod; + + const exts_length = try hashing_reader.readIntBig(u16); + var ext_byte_idx: usize = 0; + while (ext_byte_idx < exts_length) { + var ext_tag: [2]u8 = undefined; + try hashing_reader.readNoEof(&ext_tag); + + const ext_len = try hashing_reader.readIntBig(u16); + ext_byte_idx += 4 + ext_len; + if (ext_tag[0] == 0xFF and ext_tag[1] == 0x01) { + // Renegotiation info + const renegotiation_info = try hashing_reader.readByte(); + if (ext_len != 0x01 or renegotiation_info != 0x00) + return error.ServerInvalidRenegotiationData; + } else if (ext_tag[0] == 0x00 and ext_tag[1] == 0x00) { + // Server name + if (ext_len != 0) + try hashing_reader.skipBytes(ext_len, .{}); + } else if (ext_tag[0] == 0x00 and ext_tag[1] == 0x0B) { + const format_count = try hashing_reader.readByte(); + var found_uncompressed = false; + var i: usize = 0; + while (i < format_count) : (i += 1) { + const byte = try hashing_reader.readByte(); + if (byte == 0x0) + found_uncompressed = true; + } + if (!found_uncompressed) + return error.ServerInvalidECPointCompression; + } else if (has_alpn and ext_tag[0] == 0x00 and ext_tag[1] == 0x10) { + const alpn_ext_len = try hashing_reader.readIntBig(u16); + if (alpn_ext_len != ext_len - 2) + return error.ServerMalformedResponse; + const str_len = try hashing_reader.readByte(); + var buf: [256]u8 = undefined; + try hashing_reader.readNoEof(buf[0..str_len]); + const found = for (options.protocols) |proto| { + if (mem.eql(u8, proto, buf[0..str_len])) { + protocol = proto; + break true; + } + } else false; + if (!found) + return error.ServerInvalidProtocol; + try hashing_reader.skipBytes(alpn_ext_len - str_len - 1, .{}); + } else return error.ServerInvalidExtension; + } + if (ext_byte_idx != exts_length) + return error.ServerMalformedResponse; + } + // Read server certificates + var certificate_public_key: x509.PublicKey = undefined; + { + const length = try handshake_record_length(reader); + _ = length; + { + var handshake_header: [4]u8 = undefined; + try hashing_reader.readNoEof(&handshake_header); + if (handshake_header[0] != 0x0b) + return error.ServerMalformedResponse; + } + const certs_length = try hashing_reader.readIntBig(u24); + const cert_verifier: CertificateVerifier = options.cert_verifier; + switch (cert_verifier) { + .none => certificate_public_key = try extract_cert_public_key( + options.temp_allocator, + hashing_reader, + certs_length, + ), + .function => |f| { + var reader_state = CertificateReaderState(@TypeOf(hashing_reader)){ + .reader = hashing_reader, + .length = certs_length, + }; + var cert_reader = CertificateReader(@TypeOf(hashing_reader)){ .context = &reader_state }; + certificate_public_key = try f(cert_reader); + try hashing_reader.skipBytes(reader_state.length - reader_state.idx, .{}); + }, + .default => certificate_public_key = try default_cert_verifier( + options.temp_allocator, + hashing_reader, + certs_length, + options.trusted_certificates, + hostname, + ), + } + } + 
errdefer certificate_public_key.deinit(options.temp_allocator); + // Read server ephemeral public key + var server_public_key_buf: [curves.max_pub_key_len(curvelist)]u8 = undefined; + var curve_id: u16 = undefined; + var curve_id_buf: [3]u8 = undefined; + var pub_key_len: u8 = undefined; + { + const length = try handshake_record_length(reader); + _ = length; + { + var handshake_header: [4]u8 = undefined; + try hashing_reader.readNoEof(&handshake_header); + if (handshake_header[0] != 0x0c) + return error.ServerMalformedResponse; + + try hashing_reader.readNoEof(&curve_id_buf); + if (curve_id_buf[0] != 0x03) + return error.ServerMalformedResponse; + + curve_id = mem.readIntBig(u16, curve_id_buf[1..]); + var found = false; + inline for (curvelist) |curve| { + if (curve.tag == curve_id) { + found = true; + // @TODO This break segfaults stage1 + // break; + } + } + if (!found) + return error.ServerInvalidCurve; + } + + pub_key_len = try hashing_reader.readByte(); + inline for (curvelist) |curve| { + if (curve.tag == curve_id) { + if (curve.pub_key_len != pub_key_len) + return error.ServerMalformedResponse; + // @TODO This break segfaults stage1 + // break; + } + } + + try hashing_reader.readNoEof(server_public_key_buf[0..pub_key_len]); + if (curve_id != curves.x25519.tag) { + if (server_public_key_buf[0] != 0x04) + return error.ServerMalformedResponse; + } + + // Signed public key + const signature_id = try hashing_reader.readIntBig(u16); + const signature_len = try hashing_reader.readIntBig(u16); + + var hash_buf: [64]u8 = undefined; + var hash: []const u8 = undefined; + const signature_algoritm: x509.Certificate.SignatureAlgorithm = switch (signature_id) { + // TODO: More + // RSA/PKCS1/SHA256 + 0x0401 => block: { + var sha256 = Sha256.init(.{}); + sha256.update(&client_random); + sha256.update(&server_random); + sha256.update(&curve_id_buf); + sha256.update(&[1]u8{pub_key_len}); + sha256.update(server_public_key_buf[0..pub_key_len]); + sha256.final(hash_buf[0..32]); + hash = hash_buf[0..32]; + break :block .{ .signature = .rsa, .hash = .sha256 }; + }, + // RSA/PKCS1/SHA512 + 0x0601 => block: { + var sha512 = Sha512.init(.{}); + sha512.update(&client_random); + sha512.update(&server_random); + sha512.update(&curve_id_buf); + sha512.update(&[1]u8{pub_key_len}); + sha512.update(server_public_key_buf[0..pub_key_len]); + sha512.final(hash_buf[0..64]); + hash = hash_buf[0..64]; + break :block .{ .signature = .rsa, .hash = .sha512 }; + }, + else => return error.ServerInvalidSignatureAlgorithm, + }; + const signature_bytes = try options.temp_allocator.alloc(u8, signature_len); + defer options.temp_allocator.free(signature_bytes); + try hashing_reader.readNoEof(signature_bytes); + + if (!try @"pcks1v1.5".verify_signature( + options.temp_allocator, + signature_algoritm, + .{ .data = signature_bytes, .bit_len = signature_len * 8 }, + hash, + certificate_public_key, + )) + return error.ServerInvalidSignature; + + certificate_public_key.deinit(options.temp_allocator); + certificate_public_key = x509.PublicKey.empty; + } + var client_certificate: ?*const x509.ClientCertificateChain = null; + { + const length = try handshake_record_length(reader); + const record_type = try hashing_reader.readByte(); + if (record_type == 14) { + // Server hello done + const is_bytes = try hashing_reader.isBytes("\x00\x00\x00"); + if (length != 4 or !is_bytes) + return error.ServerMalformedResponse; + } else if (record_type == 13) { + // Certificate request + const certificate_request_bytes = try 
hashing_reader.readIntBig(u24); + const hello_done_in_same_record = + if (length == certificate_request_bytes + 8) + true + else if (length != certificate_request_bytes) + false + else + return error.ServerMalformedResponse; + // TODO: For now, we are ignoring the certificate types, as they have been somewhat + // superceded by the supported_signature_algorithms field + const certificate_types_bytes = try hashing_reader.readByte(); + try hashing_reader.skipBytes(certificate_types_bytes, .{}); + + var chosen_client_certificates = std.ArrayListUnmanaged(*const x509.ClientCertificateChain){}; + defer chosen_client_certificates.deinit(options.temp_allocator); + + const signature_algorithms_bytes = try hashing_reader.readIntBig(u16); + if (@hasField(Options, "client_certificates")) { + var i: usize = 0; + while (i < signature_algorithms_bytes / 2) : (i += 1) { + var signature_algorithm: [2]u8 = undefined; + try hashing_reader.readNoEof(&signature_algorithm); + for (options.client_certificates) |*cert_chain| { + if (@enumToInt(cert_chain.signature_algorithm.hash) == signature_algorithm[0] and + @enumToInt(cert_chain.signature_algorithm.signature) == signature_algorithm[1]) + { + try chosen_client_certificates.append(options.temp_allocator, cert_chain); + } + } + } + } else { + try hashing_reader.skipBytes(signature_algorithms_bytes, .{}); + } + + const certificate_authorities_bytes = try hashing_reader.readIntBig(u16); + if (chosen_client_certificates.items.len == 0) { + try hashing_reader.skipBytes(certificate_authorities_bytes, .{}); + } else { + const dns_buf = try options.temp_allocator.alloc(u8, certificate_authorities_bytes); + defer options.temp_allocator.free(dns_buf); + + try hashing_reader.readNoEof(dns_buf); + var fbs = std.io.fixedBufferStream(dns_buf[2..]); + var fbs_reader = fbs.reader(); + + while (fbs.pos < fbs.buffer.len) { + const start_idx = fbs.pos; + const seq_tag = fbs_reader.readByte() catch return error.ServerMalformedResponse; + if (seq_tag != 0x30) + return error.ServerMalformedResponse; + + const seq_length = asn1.der.parse_length(fbs_reader) catch return error.ServerMalformedResponse; + fbs_reader.skipBytes(seq_length, .{}) catch return error.ServerMalformedResponse; + + var i: usize = 0; + while (i < chosen_client_certificates.items.len) { + const cert = chosen_client_certificates.items[i]; + var cert_idx: usize = 0; + while (cert_idx < cert.cert_len) : (cert_idx += 1) { + if (mem.eql(u8, cert.cert_issuer_dns[cert_idx], fbs.buffer[start_idx..fbs.pos])) + break; + } else { + _ = chosen_client_certificates.swapRemove(i); + continue; + } + i += 1; + } + } + if (fbs.pos != fbs.buffer.len) + return error.ServerMalformedResponse; + } + // Server hello done + if (!hello_done_in_same_record) { + const hello_done_record_len = try handshake_record_length(reader); + if (hello_done_record_len != 4) + return error.ServerMalformedResponse; + } + const hello_record_type = try hashing_reader.readByte(); + if (hello_record_type != 14) + return error.ServerMalformedResponse; + const is_bytes = try hashing_reader.isBytes("\x00\x00\x00"); + if (!is_bytes) + return error.ServerMalformedResponse; + + // Send the client certificate message + try writer.writeAll(&[3]u8{ 0x16, 0x03, 0x03 }); + if (chosen_client_certificates.items.len != 0) { + client_certificate = chosen_client_certificates.items[0]; + + const certificate_count = client_certificate.?.cert_len; + // 7 bytes for the record type tag (1), record length (3), certificate list length (3) + // 3 bytes for each certificate length + 
var total_len: u24 = 7 + 3 * @intCast(u24, certificate_count); + var i: usize = 0; + while (i < certificate_count) : (i += 1) { + total_len += @intCast(u24, client_certificate.?.raw_certs[i].len); + } + try writer.writeIntBig(u16, @intCast(u16, total_len)); + var msg_buf: [7]u8 = [1]u8{0x0b} ++ ([1]u8{undefined} ** 6); + mem.writeIntBig(u24, msg_buf[1..4], total_len - 4); + mem.writeIntBig(u24, msg_buf[4..7], total_len - 7); + try hashing_writer.writeAll(&msg_buf); + i = 0; + while (i < certificate_count) : (i += 1) { + try hashing_writer.writeIntBig(u24, @intCast(u24, client_certificate.?.raw_certs[i].len)); + try hashing_writer.writeAll(client_certificate.?.raw_certs[i]); + } + } else { + try writer.writeIntBig(u16, 7); + try hashing_writer.writeAll(&[7]u8{ 0x0b, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00 }); + } + } else return error.ServerMalformedResponse; + } + + // Generate keys for the session + const client_key_pair = curves.make_key_pair(curvelist, curve_id, rand); + + // Client key exchange + try writer.writeAll(&[3]u8{ 0x16, 0x03, 0x03 }); + try writer.writeIntBig(u16, pub_key_len + 5); + try hashing_writer.writeAll(&[5]u8{ 0x10, 0x00, 0x00, pub_key_len + 1, pub_key_len }); + + inline for (curvelist) |curve| { + if (curve.tag == curve_id) { + const actual_len = @typeInfo(std.meta.fieldInfo(curve.Keys, .public_key).field_type).Array.len; + if (pub_key_len == actual_len + 1) { + try hashing_writer.writeByte(0x04); + } else { + std.debug.assert(pub_key_len == actual_len); + } + try hashing_writer.writeAll(&@field(client_key_pair, curve.name).public_key); + break; + } + } + + // If we have a client certificate, send a certificate verify message + if (@hasField(Options, "client_certificates")) { + if (client_certificate) |client_cert| { + var current_hash_buf: [64]u8 = undefined; + var current_hash: []const u8 = undefined; + const hash_algo = client_cert.signature_algorithm.hash; + // TODO: Making this a switch statement kills stage1 + if (hash_algo == .none or hash_algo == .md5 or hash_algo == .sha1) + return error.ClientCertificateVerifyFailed + else if (hash_algo == .sha224) { + var hash_copy = handshake_record_hash_set.sha224; + hash_copy.final(current_hash_buf[0..28]); + current_hash = current_hash_buf[0..28]; + } else if (hash_algo == .sha256) { + var hash_copy = handshake_record_hash_set.sha256; + hash_copy.final(current_hash_buf[0..32]); + current_hash = current_hash_buf[0..32]; + } else if (hash_algo == .sha384) { + var hash_copy = handshake_record_hash_set.sha384; + hash_copy.final(current_hash_buf[0..48]); + current_hash = current_hash_buf[0..48]; + } else { + var hash_copy = handshake_record_hash_set.sha512; + hash_copy.final(&current_hash_buf); + current_hash = &current_hash_buf; + } + + const signed = (try @"pcks1v1.5".sign( + options.temp_allocator, + client_cert.signature_algorithm, + current_hash, + client_cert.private_key, + )) orelse return error.ClientCertificateVerifyFailed; + defer options.temp_allocator.free(signed); + + try writer.writeAll(&[3]u8{ 0x16, 0x03, 0x03 }); + try writer.writeIntBig(u16, @intCast(u16, signed.len + 8)); + var msg_buf: [8]u8 = [1]u8{0x0F} ++ ([1]u8{undefined} ** 7); + mem.writeIntBig(u24, msg_buf[1..4], @intCast(u24, signed.len + 4)); + msg_buf[4] = @enumToInt(client_cert.signature_algorithm.hash); + msg_buf[5] = @enumToInt(client_cert.signature_algorithm.signature); + mem.writeIntBig(u16, msg_buf[6..8], @intCast(u16, signed.len)); + try hashing_writer.writeAll(&msg_buf); + try hashing_writer.writeAll(signed); + } + } + + // Client encryption keys 
calculation for ECDHE_RSA cipher suites with SHA256 hash + var master_secret: [48]u8 = undefined; + var key_data: ciphers.KeyData(suites) = undefined; + { + var pre_master_secret_buf: [curves.max_pre_master_secret_len(curvelist)]u8 = undefined; + const pre_master_secret = try curves.make_pre_master_secret( + curvelist, + curve_id, + client_key_pair, + &pre_master_secret_buf, + server_public_key_buf, + ); + + const seed_len = 77; // extra len variable to workaround a bug + var seed: [seed_len]u8 = undefined; + seed[0..13].* = "master secret".*; + seed[13..45].* = client_random; + seed[45..77].* = server_random; + + var a1: [32 + seed.len]u8 = undefined; + Hmac256.create(a1[0..32], &seed, pre_master_secret); + var a2: [32 + seed.len]u8 = undefined; + Hmac256.create(a2[0..32], a1[0..32], pre_master_secret); + + a1[32..].* = seed; + a2[32..].* = seed; + + var p1: [32]u8 = undefined; + Hmac256.create(&p1, &a1, pre_master_secret); + var p2: [32]u8 = undefined; + Hmac256.create(&p2, &a2, pre_master_secret); + + master_secret[0..32].* = p1; + master_secret[32..48].* = p2[0..16].*; + + // Key expansion + seed[0..13].* = "key expansion".*; + seed[13..45].* = server_random; + seed[45..77].* = client_random; + a1[32..].* = seed; + a2[32..].* = seed; + + const KeyExpansionState = struct { + seed: *const [77]u8, + a1: *[32 + seed_len]u8, + a2: *[32 + seed_len]u8, + master_secret: *const [48]u8, + }; + + const next_32_bytes = struct { + inline fn f( + state: *KeyExpansionState, + comptime chunk_idx: comptime_int, + chunk: *[32]u8, + ) void { + if (chunk_idx == 0) { + Hmac256.create(state.a1[0..32], state.seed, state.master_secret); + Hmac256.create(chunk, state.a1, state.master_secret); + } else if (chunk_idx % 2 == 1) { + Hmac256.create(state.a2[0..32], state.a1[0..32], state.master_secret); + Hmac256.create(chunk, state.a2, state.master_secret); + } else { + Hmac256.create(state.a1[0..32], state.a2[0..32], state.master_secret); + Hmac256.create(chunk, state.a1, state.master_secret); + } + } + }.f; + var state = KeyExpansionState{ + .seed = &seed, + .a1 = &a1, + .a2 = &a2, + .master_secret = &master_secret, + }; + + key_data = ciphers.key_expansion(suites, ciphersuite, &state, next_32_bytes); + } + + // Client change cipher spec and client handshake finished + { + try writer.writeAll(&[6]u8{ + // Client change cipher spec + 0x14, 0x03, 0x03, + 0x00, 0x01, 0x01, + }); + // The message we need to encrypt is the following: + // 0x14 0x00 0x00 0x0c + // <12 bytes of verify_data> + // seed = "client finished" + SHA256(all handshake messages) + // a1 = HMAC-SHA256(key=MasterSecret, data=seed) + // p1 = HMAC-SHA256(key=MasterSecret, data=a1 + seed) + // verify_data = p1[0..12] + var verify_message: [16]u8 = undefined; + verify_message[0..4].* = "\x14\x00\x00\x0C".*; + { + var seed: [47]u8 = undefined; + seed[0..15].* = "client finished".*; + // We still need to update the hash one time, so we copy + // to get the current digest here. 
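// [Editor's note] Not part of the original patch: the server's Finished
// message, checked further below, is verified with the same PRF but with the
// label "server finished" and a handshake hash that also covers this client
// Finished message. That is why the digest is only copied here, and why
// handshake_record_hash_set.update(&verify_message) is called once the
// verify_data below has been computed.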
+ var hash_copy = handshake_record_hash_set.sha256; + hash_copy.final(seed[15..47]); + + var a1: [32 + seed.len]u8 = undefined; + Hmac256.create(a1[0..32], &seed, &master_secret); + a1[32..].* = seed; + var p1: [32]u8 = undefined; + Hmac256.create(&p1, &a1, &master_secret); + verify_message[4..16].* = p1[0..12].*; + } + handshake_record_hash_set.update(&verify_message); + + inline for (suites) |cs| { + if (cs.tag == ciphersuite) { + try cs.raw_write( + 256, + rand, + &key_data, + writer, + [3]u8{ 0x16, 0x03, 0x03 }, + 0, + &verify_message, + ); + } + } + } + + // Server change cipher spec + { + const length = try record_length(0x14, reader); + const next_byte = try reader.readByte(); + if (length != 1 or next_byte != 0x01) + return error.ServerMalformedResponse; + } + // Server handshake finished + { + const length = try handshake_record_length(reader); + + var verify_message: [16]u8 = undefined; + verify_message[0..4].* = "\x14\x00\x00\x0C".*; + { + var seed: [47]u8 = undefined; + seed[0..15].* = "server finished".*; + handshake_record_hash_set.sha256.final(seed[15..47]); + var a1: [32 + seed.len]u8 = undefined; + Hmac256.create(a1[0..32], &seed, &master_secret); + a1[32..].* = seed; + var p1: [32]u8 = undefined; + Hmac256.create(&p1, &a1, &master_secret); + verify_message[4..16].* = p1[0..12].*; + } + + inline for (suites) |cs| { + if (cs.tag == ciphersuite) { + if (!try cs.check_verify_message(&key_data, length, reader, verify_message)) + return error.ServerInvalidVerifyData; + } + } + } + + return Client(@TypeOf(reader), @TypeOf(writer), suites, has_alpn){ + .ciphersuite = ciphersuite, + .key_data = key_data, + .rand = rand, + .parent_reader = reader, + .parent_writer = writer, + .protocol = protocol, + }; +} + +pub fn Client( + comptime _Reader: type, + comptime _Writer: type, + comptime _ciphersuites: anytype, + comptime has_protocol: bool, +) type { + return struct { + const ReaderError = _Reader.Error || ServerAlert || error{ ServerMalformedResponse, ServerInvalidVersion, AuthenticationFailed }; + pub const Reader = std.io.Reader(*@This(), ReaderError, read); + pub const Writer = std.io.Writer(*@This(), _Writer.Error, write); + + const InRecordState = ciphers.InRecordState(_ciphersuites); + const ReadState = union(enum) { + none, + in_record: struct { + record_length: usize, + index: usize = 0, + state: InRecordState, + }, + }; + + ciphersuite: u16, + client_seq: u64 = 1, + server_seq: u64 = 1, + key_data: ciphers.KeyData(_ciphersuites), + read_state: ReadState = .none, + rand: *std.rand.Random, + + parent_reader: _Reader, + parent_writer: _Writer, + + protocol: if (has_protocol) []const u8 else void, + + pub fn reader(self: *@This()) Reader { + return .{ .context = self }; + } + + pub fn writer(self: *@This()) Writer { + return .{ .context = self }; + } + + pub fn read(self: *@This(), buffer: []u8) ReaderError!usize { + const buf_size = 1024; + + switch (self.read_state) { + .none => { + const header = record_header(self.parent_reader) catch |err| switch (err) { + error.EndOfStream => return 0, + else => |e| return e, + }; + + const len_overhead = inline for (_ciphersuites) |cs| { + if (self.ciphersuite == cs.tag) { + break cs.mac_length + cs.prefix_data_length; + } + } else unreachable; + + const rec_length = header.len(); + if (rec_length < len_overhead) + return error.ServerMalformedResponse; + const len = rec_length - len_overhead; + + if ((header.tag() != 0x17 and header.tag() != 0x15) or + (header.tag() == 0x15 and len != 2)) + { + return error.ServerMalformedResponse; + } + 
+ inline for (_ciphersuites) |cs| { + if (self.ciphersuite == cs.tag) { + var prefix_data: [cs.prefix_data_length]u8 = undefined; + if (cs.prefix_data_length > 0) { + self.parent_reader.readNoEof(&prefix_data) catch |err| switch (err) { + error.EndOfStream => return error.ServerMalformedResponse, + else => |e| return e, + }; + } + self.read_state = .{ .in_record = .{ + .record_length = len, + .state = @unionInit( + InRecordState, + cs.name, + cs.init_state(prefix_data, self.server_seq, &self.key_data, header), + ), + } }; + } + } + + if (header.tag() == 0x15) { + var encrypted: [2]u8 = undefined; + self.parent_reader.readNoEof(&encrypted) catch |err| switch (err) { + error.EndOfStream => return error.ServerMalformedResponse, + else => |e| return e, + }; + + var result: [2]u8 = undefined; + inline for (_ciphersuites) |cs| { + if (self.ciphersuite == cs.tag) { + // This decrypt call should always consume the whole record + cs.decrypt_part( + &self.key_data, + self.read_state.in_record.record_length, + &self.read_state.in_record.index, + &@field(self.read_state.in_record.state, cs.name), + &encrypted, + &result, + ); + std.debug.assert(self.read_state.in_record.index == self.read_state.in_record.record_length); + try cs.verify_mac( + self.parent_reader, + self.read_state.in_record.record_length, + &@field(self.read_state.in_record.state, cs.name), + ); + } + } + self.read_state = .none; + self.server_seq += 1; + // CloseNotify + if (result[1] == 0) + return 0; + return alert_byte_to_error(result[1]); + } else if (header.tag() == 0x17) { + const curr_bytes = std.math.min(std.math.min(len, buf_size), buffer.len); + // Partially decrypt the data. + var encrypted: [buf_size]u8 = undefined; + const actually_read = try self.parent_reader.read(encrypted[0..curr_bytes]); + + inline for (_ciphersuites) |cs| { + if (self.ciphersuite == cs.tag) { + cs.decrypt_part( + &self.key_data, + self.read_state.in_record.record_length, + &self.read_state.in_record.index, + &@field(self.read_state.in_record.state, cs.name), + encrypted[0..actually_read], + buffer[0..actually_read], + ); + + if (self.read_state.in_record.index == self.read_state.in_record.record_length) { + try cs.verify_mac( + self.parent_reader, + self.read_state.in_record.record_length, + &@field(self.read_state.in_record.state, cs.name), + ); + self.server_seq += 1; + self.read_state = .none; + } + } + } + return actually_read; + } else unreachable; + }, + .in_record => |*in_record| { + const curr_bytes = std.math.min(std.math.min(buf_size, buffer.len), in_record.record_length - in_record.index); + // Partially decrypt the data. 
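+ // Continuing a record that a previous read() only partially consumed:
+ // decrypt up to buf_size more bytes into the caller's buffer, and once
+ // index reaches record_length, verify the record MAC, advance
+ // server_seq and reset read_state to .none.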
+ var encrypted: [buf_size]u8 = undefined; + const actually_read = try self.parent_reader.read(encrypted[0..curr_bytes]); + + inline for (_ciphersuites) |cs| { + if (self.ciphersuite == cs.tag) { + cs.decrypt_part( + &self.key_data, + in_record.record_length, + &in_record.index, + &@field(in_record.state, cs.name), + encrypted[0..actually_read], + buffer[0..actually_read], + ); + + if (in_record.index == in_record.record_length) { + try cs.verify_mac( + self.parent_reader, + in_record.record_length, + &@field(in_record.state, cs.name), + ); + self.server_seq += 1; + self.read_state = .none; + } + } + } + return actually_read; + }, + } + } + + pub fn write(self: *@This(), buffer: []const u8) _Writer.Error!usize { + if (buffer.len == 0) return 0; + + inline for (_ciphersuites) |cs| { + if (self.ciphersuite == cs.tag) { + // @TODO Make this buffer size configurable + const curr_bytes = @truncate(u16, std.math.min(buffer.len, 1024)); + try cs.raw_write( + 1024, + self.rand, + &self.key_data, + self.parent_writer, + [3]u8{ 0x17, 0x03, 0x03 }, + self.client_seq, + buffer[0..curr_bytes], + ); + self.client_seq += 1; + return curr_bytes; + } + } + unreachable; + } + + pub fn close_notify(self: *@This()) !void { + inline for (_ciphersuites) |cs| { + if (self.ciphersuite == cs.tag) { + try cs.raw_write( + 1024, + self.rand, + &self.key_data, + self.parent_writer, + [3]u8{ 0x15, 0x03, 0x03 }, + self.client_seq, + "\x01\x00", + ); + self.client_seq += 1; + return; + } + } + unreachable; + } + }; +} + +test "HTTPS request on wikipedia main page" { + const sock = try std.net.tcpConnectToHost(std.testing.allocator, "en.wikipedia.org", 443); + defer sock.close(); + + var fbs = std.io.fixedBufferStream(@embedFile("../test/DigiCertHighAssuranceEVRootCA.crt.pem")); + var trusted_chain = try x509.CertificateChain.from_pem(std.testing.allocator, fbs.reader()); + defer trusted_chain.deinit(); + + // @TODO Remove this once std.crypto.rand works in .evented mode + var rand = blk: { + var seed: [std.rand.DefaultCsprng.secret_seed_length]u8 = undefined; + try std.os.getrandom(&seed); + break :blk &std.rand.DefaultCsprng.init(seed).random; + }; + + var client = try client_connect(.{ + .rand = rand, + .reader = sock.reader(), + .writer = sock.writer(), + .cert_verifier = .default, + .temp_allocator = std.testing.allocator, + .trusted_certificates = trusted_chain.data.items, + .ciphersuites = .{ciphersuites.ECDHE_RSA_Chacha20_Poly1305}, + .protocols = &[_][]const u8{"http/1.1"}, + .curves = .{curves.x25519}, + }, "en.wikipedia.org"); + defer client.close_notify() catch {}; + try std.testing.expectEqualStrings("http/1.1", client.protocol); + try client.writer().writeAll("GET /wiki/Main_Page HTTP/1.1\r\nHost: en.wikipedia.org\r\nAccept: */*\r\n\r\n"); + + { + const header = try client.reader().readUntilDelimiterAlloc(std.testing.allocator, '\n', std.math.maxInt(usize)); + try std.testing.expectEqualStrings("HTTP/1.1 200 OK", mem.trim(u8, header, &std.ascii.spaces)); + std.testing.allocator.free(header); + } + + // Skip the rest of the headers expect for Content-Length + var content_length: ?usize = null; + hdr_loop: while (true) { + const header = try client.reader().readUntilDelimiterAlloc(std.testing.allocator, '\n', std.math.maxInt(usize)); + defer std.testing.allocator.free(header); + + const hdr_contents = mem.trim(u8, header, &std.ascii.spaces); + if (hdr_contents.len == 0) { + break :hdr_loop; + } + + if (mem.startsWith(u8, hdr_contents, "Content-Length: ")) { + content_length = try std.fmt.parseUnsigned(usize, 
hdr_contents[16..], 10); + } + } + try std.testing.expect(content_length != null); + const html_contents = try std.testing.allocator.alloc(u8, content_length.?); + defer std.testing.allocator.free(html_contents); + + try client.reader().readNoEof(html_contents); +} + +test "HTTPS request on wikipedia alternate name" { + const sock = try std.net.tcpConnectToHost(std.testing.allocator, "en.m.wikipedia.org", 443); + defer sock.close(); + + var fbs = std.io.fixedBufferStream(@embedFile("../test/DigiCertHighAssuranceEVRootCA.crt.pem")); + var trusted_chain = try x509.CertificateChain.from_pem(std.testing.allocator, fbs.reader()); + defer trusted_chain.deinit(); + + // @TODO Remove this once std.crypto.rand works in .evented mode + var rand = blk: { + var seed: [std.rand.DefaultCsprng.secret_seed_length]u8 = undefined; + try std.os.getrandom(&seed); + break :blk &std.rand.DefaultCsprng.init(seed).random; + }; + + var client = try client_connect(.{ + .rand = rand, + .reader = sock.reader(), + .writer = sock.writer(), + .cert_verifier = .default, + .temp_allocator = std.testing.allocator, + .trusted_certificates = trusted_chain.data.items, + .ciphersuites = .{ciphersuites.ECDHE_RSA_Chacha20_Poly1305}, + .protocols = &[_][]const u8{"http/1.1"}, + .curves = .{curves.x25519}, + }, "en.m.wikipedia.org"); + defer client.close_notify() catch {}; +} + +test "HTTPS request on twitch oath2 endpoint" { + const sock = try std.net.tcpConnectToHost(std.testing.allocator, "id.twitch.tv", 443); + defer sock.close(); + + // @TODO Remove this once std.crypto.rand works in .evented mode + var rand = blk: { + var seed: [std.rand.DefaultCsprng.secret_seed_length]u8 = undefined; + try std.os.getrandom(&seed); + break :blk &std.rand.DefaultCsprng.init(seed).random; + }; + + var client = try client_connect(.{ + .rand = rand, + .temp_allocator = std.testing.allocator, + .reader = sock.reader(), + .writer = sock.writer(), + .cert_verifier = .none, + .protocols = &[_][]const u8{"http/1.1"}, + }, "id.twitch.tv"); + try std.testing.expectEqualStrings("http/1.1", client.protocol); + defer client.close_notify() catch {}; + + try client.writer().writeAll("GET /oauth2/validate HTTP/1.1\r\nHost: id.twitch.tv\r\nAccept: */*\r\n\r\n"); + var content_length: ?usize = null; + hdr_loop: while (true) { + const header = try client.reader().readUntilDelimiterAlloc(std.testing.allocator, '\n', std.math.maxInt(usize)); + defer std.testing.allocator.free(header); + + const hdr_contents = mem.trim(u8, header, &std.ascii.spaces); + if (hdr_contents.len == 0) { + break :hdr_loop; + } + + if (mem.startsWith(u8, hdr_contents, "Content-Length: ")) { + content_length = try std.fmt.parseUnsigned(usize, hdr_contents[16..], 10); + } + } + try std.testing.expect(content_length != null); + const html_contents = try std.testing.allocator.alloc(u8, content_length.?); + defer std.testing.allocator.free(html_contents); + + try client.reader().readNoEof(html_contents); +} + +test "Connecting to expired.badssl.com returns an error" { + const sock = try std.net.tcpConnectToHost(std.testing.allocator, "expired.badssl.com", 443); + defer sock.close(); + + var fbs = std.io.fixedBufferStream(@embedFile("../test/DigiCertGlobalRootCA.crt.pem")); + var trusted_chain = try x509.CertificateChain.from_pem(std.testing.allocator, fbs.reader()); + defer trusted_chain.deinit(); + + // @TODO Remove this once std.crypto.rand works in .evented mode + var rand = blk: { + var seed: [std.rand.DefaultCsprng.secret_seed_length]u8 = undefined; + try std.os.getrandom(&seed); + break 
:blk &std.rand.DefaultCsprng.init(seed).random; + }; + + if (client_connect(.{ + .rand = rand, + .reader = sock.reader(), + .writer = sock.writer(), + .cert_verifier = .default, + .temp_allocator = std.testing.allocator, + .trusted_certificates = trusted_chain.data.items, + }, "expired.badssl.com")) |_| { + return error.ExpectedVerificationFailed; + } else |err| { + try std.testing.expect(err == error.CertificateVerificationFailed); + } +} + +test "Connecting to wrong.host.badssl.com returns an error" { + const sock = try std.net.tcpConnectToHost(std.testing.allocator, "wrong.host.badssl.com", 443); + defer sock.close(); + + var fbs = std.io.fixedBufferStream(@embedFile("../test/DigiCertGlobalRootCA.crt.pem")); + var trusted_chain = try x509.CertificateChain.from_pem(std.testing.allocator, fbs.reader()); + defer trusted_chain.deinit(); + + // @TODO Remove this once std.crypto.rand works in .evented mode + var rand = blk: { + var seed: [std.rand.DefaultCsprng.secret_seed_length]u8 = undefined; + try std.os.getrandom(&seed); + break :blk &std.rand.DefaultCsprng.init(seed).random; + }; + + if (client_connect(.{ + .rand = rand, + .reader = sock.reader(), + .writer = sock.writer(), + .cert_verifier = .default, + .temp_allocator = std.testing.allocator, + .trusted_certificates = trusted_chain.data.items, + }, "wrong.host.badssl.com")) |_| { + return error.ExpectedVerificationFailed; + } else |err| { + try std.testing.expect(err == error.CertificateVerificationFailed); + } +} + +test "Connecting to self-signed.badssl.com returns an error" { + const sock = try std.net.tcpConnectToHost(std.testing.allocator, "self-signed.badssl.com", 443); + defer sock.close(); + + var fbs = std.io.fixedBufferStream(@embedFile("../test/DigiCertGlobalRootCA.crt.pem")); + var trusted_chain = try x509.CertificateChain.from_pem(std.testing.allocator, fbs.reader()); + defer trusted_chain.deinit(); + + // @TODO Remove this once std.crypto.rand works in .evented mode + var rand = blk: { + var seed: [std.rand.DefaultCsprng.secret_seed_length]u8 = undefined; + try std.os.getrandom(&seed); + break :blk &std.rand.DefaultCsprng.init(seed).random; + }; + + if (client_connect(.{ + .rand = rand, + .reader = sock.reader(), + .writer = sock.writer(), + .cert_verifier = .default, + .temp_allocator = std.testing.allocator, + .trusted_certificates = trusted_chain.data.items, + }, "self-signed.badssl.com")) |_| { + return error.ExpectedVerificationFailed; + } else |err| { + try std.testing.expect(err == error.CertificateVerificationFailed); + } +} + +test "Connecting to client.badssl.com with a client certificate" { + const sock = try std.net.tcpConnectToHost(std.testing.allocator, "client.badssl.com", 443); + defer sock.close(); + + var fbs = std.io.fixedBufferStream(@embedFile("../test/DigiCertGlobalRootCA.crt.pem")); + var trusted_chain = try x509.CertificateChain.from_pem(std.testing.allocator, fbs.reader()); + defer trusted_chain.deinit(); + + // @TODO Remove this once std.crypto.rand works in .evented mode + var rand = blk: { + var seed: [std.rand.DefaultCsprng.secret_seed_length]u8 = undefined; + try std.os.getrandom(&seed); + break :blk &std.rand.DefaultCsprng.init(seed).random; + }; + + var client_cert = try x509.ClientCertificateChain.from_pem( + std.testing.allocator, + std.io.fixedBufferStream(@embedFile("../test/badssl.com-client.pem")).reader(), + ); + defer client_cert.deinit(std.testing.allocator); + + var client = try client_connect(.{ + .rand = rand, + .reader = sock.reader(), + .writer = sock.writer(), + 
.cert_verifier = .default, + .temp_allocator = std.testing.allocator, + .trusted_certificates = trusted_chain.data.items, + .client_certificates = &[1]x509.ClientCertificateChain{client_cert}, + }, "client.badssl.com"); + defer client.close_notify() catch {}; + + try client.writer().writeAll("GET / HTTP/1.1\r\nHost: client.badssl.com\r\nAccept: */*\r\n\r\n"); + + const line = try client.reader().readUntilDelimiterAlloc(std.testing.allocator, '\n', std.math.maxInt(usize)); + defer std.testing.allocator.free(line); + try std.testing.expectEqualStrings("HTTP/1.1 200 OK\r", line); +} diff --git a/src/deps/iguanaTLS/src/pcks1-1_5.zig b/src/deps/iguanaTLS/src/pcks1-1_5.zig new file mode 100644 index 000000000..32183a2d7 --- /dev/null +++ b/src/deps/iguanaTLS/src/pcks1-1_5.zig @@ -0,0 +1,209 @@ +const std = @import("std"); +const mem = std.mem; +const Allocator = mem.Allocator; +const Sha224 = std.crypto.hash.sha2.Sha224; +const Sha384 = std.crypto.hash.sha2.Sha384; +const Sha512 = std.crypto.hash.sha2.Sha512; +const Sha256 = std.crypto.hash.sha2.Sha256; + +const x509 = @import("x509.zig"); +const SignatureAlgorithm = x509.Certificate.SignatureAlgorithm; +const asn1 = @import("asn1.zig"); + +fn rsa_perform( + allocator: *Allocator, + modulus: std.math.big.int.Const, + exponent: std.math.big.int.Const, + base: []const u8, +) !?std.math.big.int.Managed { + // @TODO Better algorithm, make it faster. + const curr_base_limbs = try allocator.alloc( + usize, + std.math.divCeil(usize, base.len, @sizeOf(usize)) catch unreachable, + ); + const curr_base_limb_bytes = @ptrCast([*]u8, curr_base_limbs)[0..base.len]; + mem.copy(u8, curr_base_limb_bytes, base); + mem.reverse(u8, curr_base_limb_bytes); + var curr_base = (std.math.big.int.Mutable{ + .limbs = curr_base_limbs, + .positive = true, + .len = curr_base_limbs.len, + }).toManaged(allocator); + defer curr_base.deinit(); + + var curr_exponent = try exponent.toManaged(allocator); + defer curr_exponent.deinit(); + var result = try std.math.big.int.Managed.initSet(allocator, @as(usize, 1)); + + // encrypted = signature ^ key.exponent MOD key.modulus + while (curr_exponent.toConst().orderAgainstScalar(0) == .gt) { + if (curr_exponent.isOdd()) { + try result.ensureMulCapacity(result.toConst(), curr_base.toConst()); + try result.mul(result.toConst(), curr_base.toConst()); + try llmod(&result, modulus); + } + try curr_base.sqr(curr_base.toConst()); + try llmod(&curr_base, modulus); + try curr_exponent.shiftRight(curr_exponent, 1); + } + + if (result.limbs.len * @sizeOf(usize) < base.len) + return null; + return result; +} + +// res = res mod N +fn llmod(res: *std.math.big.int.Managed, n: std.math.big.int.Const) !void { + var temp = try std.math.big.int.Managed.init(res.allocator); + defer temp.deinit(); + try temp.divTrunc(res, res.toConst(), n); +} + +pub fn algorithm_prefix(signature_algorithm: SignatureAlgorithm) ?[]const u8 { + return switch (signature_algorithm.hash) { + .none, .md5, .sha1 => null, + .sha224 => &[_]u8{ + 0x30, 0x2d, 0x30, 0x0d, 0x06, + 0x09, 0x60, 0x86, 0x48, 0x01, + 0x65, 0x03, 0x04, 0x02, 0x04, + 0x05, 0x00, 0x04, 0x1c, + }, + .sha256 => &[_]u8{ + 0x30, 0x31, 0x30, 0x0d, 0x06, + 0x09, 0x60, 0x86, 0x48, 0x01, + 0x65, 0x03, 0x04, 0x02, 0x01, + 0x05, 0x00, 0x04, 0x20, + }, + .sha384 => &[_]u8{ + 0x30, 0x41, 0x30, 0x0d, 0x06, + 0x09, 0x60, 0x86, 0x48, 0x01, + 0x65, 0x03, 0x04, 0x02, 0x02, + 0x05, 0x00, 0x04, 0x30, + }, + .sha512 => &[_]u8{ + 0x30, 0x51, 0x30, 0x0d, 0x06, + 0x09, 0x60, 0x86, 0x48, 0x01, + 0x65, 0x03, 0x04, 0x02, 0x03, + 0x05, 
0x00, 0x04, 0x40, + }, + }; +} + +pub fn sign( + allocator: *Allocator, + signature_algorithm: SignatureAlgorithm, + hash: []const u8, + private_key: x509.PrivateKey, +) !?[]const u8 { + // @TODO ECDSA signatures + if (signature_algorithm.signature != .rsa or private_key != .rsa) + return null; + + const signature_length = private_key.rsa.modulus.len * @sizeOf(usize); + var sig_buf = try allocator.alloc(u8, signature_length); + defer allocator.free(sig_buf); + const prefix = algorithm_prefix(signature_algorithm) orelse return null; + const first_prefix_idx = sig_buf.len - hash.len - prefix.len; + const first_hash_idx = sig_buf.len - hash.len; + + // EM = 0x00 || 0x01 || PS || 0x00 || T + sig_buf[0] = 0; + sig_buf[1] = 1; + mem.set(u8, sig_buf[2 .. first_prefix_idx - 1], 0xff); + sig_buf[first_prefix_idx - 1] = 0; + mem.copy(u8, sig_buf[first_prefix_idx..first_hash_idx], prefix); + mem.copy(u8, sig_buf[first_hash_idx..], hash); + + const modulus = std.math.big.int.Const{ .limbs = private_key.rsa.modulus, .positive = true }; + const exponent = std.math.big.int.Const{ .limbs = private_key.rsa.exponent, .positive = true }; + + var rsa_result = (try rsa_perform(allocator, modulus, exponent, sig_buf)) orelse return null; + if (rsa_result.limbs.len * @sizeOf(usize) < signature_length) { + rsa_result.deinit(); + return null; + } + + const enc_buf = @ptrCast([*]u8, rsa_result.limbs.ptr)[0..signature_length]; + mem.reverse(u8, enc_buf); + return allocator.resize( + enc_buf.ptr[0 .. rsa_result.limbs.len * @sizeOf(usize)], + signature_length, + ) catch unreachable; +} + +pub fn verify_signature( + allocator: *Allocator, + signature_algorithm: SignatureAlgorithm, + signature: asn1.BitString, + hash: []const u8, + public_key: x509.PublicKey, +) !bool { + // @TODO ECDSA algorithms + if (public_key != .rsa or signature_algorithm.signature != .rsa) return false; + const prefix = algorithm_prefix(signature_algorithm) orelse return false; + + // RSA hash verification with PKCS 1 V1_5 padding + const modulus = std.math.big.int.Const{ .limbs = public_key.rsa.modulus, .positive = true }; + const exponent = std.math.big.int.Const{ .limbs = public_key.rsa.exponent, .positive = true }; + if (modulus.bitCountAbs() != signature.bit_len) + return false; + + var rsa_result = (try rsa_perform(allocator, modulus, exponent, signature.data)) orelse return false; + defer rsa_result.deinit(); + + if (rsa_result.limbs.len * @sizeOf(usize) < signature.data.len) + return false; + + const enc_buf = @ptrCast([*]u8, rsa_result.limbs.ptr)[0..signature.data.len]; + mem.reverse(u8, enc_buf); + + if (enc_buf[0] != 0x00 or enc_buf[1] != 0x01) + return false; + if (!mem.endsWith(u8, enc_buf, hash)) + return false; + if (!mem.endsWith(u8, enc_buf[0 .. enc_buf.len - hash.len], prefix)) + return false; + if (enc_buf[enc_buf.len - hash.len - prefix.len - 1] != 0x00) + return false; + for (enc_buf[2 .. 
enc_buf.len - hash.len - prefix.len - 1]) |c| { + if (c != 0xff) return false; + } + + return true; +} + +pub fn certificate_verify_signature( + allocator: *Allocator, + signature_algorithm: x509.Certificate.SignatureAlgorithm, + signature: asn1.BitString, + bytes: []const u8, + public_key: x509.PublicKey, +) !bool { + // @TODO ECDSA algorithms + if (public_key != .rsa or signature_algorithm.signature != .rsa) return false; + + var hash_buf: [64]u8 = undefined; + var hash: []u8 = undefined; + + switch (signature_algorithm.hash) { + // Deprecated hash algos + .none, .md5, .sha1 => return false, + .sha224 => { + Sha224.hash(bytes, hash_buf[0..28], .{}); + hash = hash_buf[0..28]; + }, + .sha256 => { + Sha256.hash(bytes, hash_buf[0..32], .{}); + hash = hash_buf[0..32]; + }, + .sha384 => { + Sha384.hash(bytes, hash_buf[0..48], .{}); + hash = hash_buf[0..48]; + }, + .sha512 => { + Sha512.hash(bytes, hash_buf[0..64], .{}); + hash = &hash_buf; + }, + } + return try verify_signature(allocator, signature_algorithm, signature, hash, public_key); +} diff --git a/src/deps/iguanaTLS/src/x509.zig b/src/deps/iguanaTLS/src/x509.zig new file mode 100644 index 000000000..06f8fc258 --- /dev/null +++ b/src/deps/iguanaTLS/src/x509.zig @@ -0,0 +1,1053 @@ +const std = @import("std"); +const Allocator = std.mem.Allocator; +const mem = std.mem; +const trait = std.meta.trait; + +const asn1 = @import("asn1.zig"); + +// zig fmt: off +// http://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-8 +pub const CurveId = enum { + sect163k1, sect163r1, sect163r2, sect193r1, + sect193r2, sect233k1, sect233r1, sect239k1, + sect283k1, sect283r1, sect409k1, sect409r1, + sect571k1, sect571r1, secp160k1, secp160r1, + secp160r2, secp192k1, secp192r1, secp224k1, + secp224r1, secp256k1, secp256r1, secp384r1, + secp521r1,brainpoolP256r1, brainpoolP384r1, + brainpoolP512r1, curve25519, curve448, +}; +// zig fmt: on + +pub const PublicKey = union(enum) { + pub const empty = PublicKey{ .ec = .{ .id = undefined, .curve_point = &[0]u8{} } }; + + /// RSA public key + rsa: struct { + //Positive std.math.big.int.Const numbers. 
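+ // Raw limb slices of the DER-decoded integers; consumers such as
+ // pcks1-1_5.zig rebuild std.math.big.int.Const values from them with
+ // .positive = true.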
+ modulus: []const usize, + exponent: []const usize, + }, + /// Elliptic curve public key + ec: struct { + id: CurveId, + /// Public curve point (uncompressed format) + curve_point: []const u8, + }, + + pub fn deinit(self: @This(), alloc: *Allocator) void { + switch (self) { + .rsa => |rsa| { + alloc.free(rsa.modulus); + alloc.free(rsa.exponent); + }, + .ec => |ec| alloc.free(ec.curve_point), + } + } + + pub fn eql(self: @This(), other: @This()) bool { + if (@as(std.meta.Tag(@This()), self) != @as(std.meta.Tag(@This()), other)) + return false; + switch (self) { + .rsa => |mod_exp| return mem.eql(usize, mod_exp.exponent, other.rsa.exponent) and + mem.eql(usize, mod_exp.modulus, other.rsa.modulus), + .ec => |ec| return ec.id == other.ec.id and mem.eql(u8, ec.curve_point, other.ec.curve_point), + } + } +}; + +pub const PrivateKey = PublicKey; + +pub fn parse_public_key(allocator: *Allocator, reader: anytype) !PublicKey { + if ((try reader.readByte()) != 0x30) + return error.MalformedDER; + const seq_len = try asn1.der.parse_length(reader); + _ = seq_len; + + if ((try reader.readByte()) != 0x06) + return error.MalformedDER; + const oid_bytes = try asn1.der.parse_length(reader); + if (oid_bytes == 9) { + // @TODO This fails in async if merged with the if + if (!try reader.isBytes(&[9]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0xD, 0x1, 0x1, 0x1 })) + return error.MalformedDER; + // OID is 1.2.840.113549.1.1.1 + // RSA key + // Skip past the NULL + const null_byte = try reader.readByte(); + if (null_byte != 0x05) + return error.MalformedDER; + const null_len = try asn1.der.parse_length(reader); + if (null_len != 0x00) + return error.MalformedDER; + { + // BitString next! + if ((try reader.readByte()) != 0x03) + return error.MalformedDER; + _ = try asn1.der.parse_length(reader); + const bit_string_unused_bits = try reader.readByte(); + if (bit_string_unused_bits != 0) + return error.MalformedDER; + + if ((try reader.readByte()) != 0x30) + return error.MalformedDER; + _ = try asn1.der.parse_length(reader); + + // Modulus + if ((try reader.readByte()) != 0x02) + return error.MalformedDER; + const modulus = try asn1.der.parse_int(allocator, reader); + errdefer allocator.free(modulus.limbs); + if (!modulus.positive) return error.MalformedDER; + // Exponent + if ((try reader.readByte()) != 0x02) + return error.MalformedDER; + const exponent = try asn1.der.parse_int(allocator, reader); + errdefer allocator.free(exponent.limbs); + if (!exponent.positive) return error.MalformedDER; + return PublicKey{ + .rsa = .{ + .modulus = modulus.limbs, + .exponent = exponent.limbs, + }, + }; + } + } else if (oid_bytes == 7) { + // @TODO This fails in async if merged with the if + if (!try reader.isBytes(&[7]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01 })) + return error.MalformedDER; + // OID is 1.2.840.10045.2.1 + // Elliptical curve + // We only support named curves, for which the parameter field is an OID. 
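+ // The named curve is identified by matching its OID bytes directly:
+ // 1.3.132.0.34 -> secp384r1, 1.3.132.0.35 -> secp521r1 and
+ // 1.2.840.10045.3.1.7 -> secp256r1. The subjectPublicKey BIT STRING that
+ // follows is stored as the uncompressed curve point.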
+ const oid_tag = try reader.readByte(); + if (oid_tag != 0x06) + return error.MalformedDER; + const curve_oid_bytes = try asn1.der.parse_length(reader); + + var key: PublicKey = undefined; + if (curve_oid_bytes == 5) { + if (!try reader.isBytes(&[4]u8{ 0x2B, 0x81, 0x04, 0x00 })) + return error.MalformedDER; + // 1.3.132.0.{34, 35} + const last_byte = try reader.readByte(); + if (last_byte == 0x22) + key = .{ .ec = .{ .id = .secp384r1, .curve_point = undefined } } + else if (last_byte == 0x23) + key = .{ .ec = .{ .id = .secp521r1, .curve_point = undefined } } + else + return error.MalformedDER; + } else if (curve_oid_bytes == 8) { + if (!try reader.isBytes(&[8]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x3, 0x1, 0x7 })) + return error.MalformedDER; + key = .{ .ec = .{ .id = .secp256r1, .curve_point = undefined } }; + } else { + return error.MalformedDER; + } + + if ((try reader.readByte()) != 0x03) + return error.MalformedDER; + const byte_len = try asn1.der.parse_length(reader); + const unused_bits = try reader.readByte(); + const bit_count = (byte_len - 1) * 8 - unused_bits; + if (bit_count % 8 != 0) + return error.MalformedDER; + const bit_memory = try allocator.alloc(u8, std.math.divCeil(usize, bit_count, 8) catch unreachable); + errdefer allocator.free(bit_memory); + try reader.readNoEof(bit_memory[0 .. byte_len - 1]); + + key.ec.curve_point = bit_memory; + return key; + } + return error.MalformedDER; +} + +pub fn DecodeDERError(comptime Reader: type) type { + return Reader.Error || error{ + MalformedPEM, + MalformedDER, + EndOfStream, + OutOfMemory, + }; +} + +pub const Certificate = struct { + pub const SignatureAlgorithm = struct { + hash: enum(u8) { + none = 0, + md5 = 1, + sha1 = 2, + sha224 = 3, + sha256 = 4, + sha384 = 5, + sha512 = 6, + }, + signature: enum(u8) { + anonymous = 0, + rsa = 1, + dsa = 2, + ecdsa = 3, + }, + }; + + /// Subject distinguished name + dn: []const u8, + /// A "CA" anchor is deemed fit to verify signatures on certificates. + /// A "non-CA" anchor is accepted only for direct trust (server's certificate + /// name and key match the anchor). 
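+ /// Set by processExtension() below from the X.509 Basic Constraints
+ /// extension (OID 2.5.29.19); defaults to false when that extension is absent.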
+ is_ca: bool = false, + public_key: PublicKey, + + const CaptureState = struct { + self: *Certificate, + allocator: *Allocator, + dn_allocated: bool = false, + pk_allocated: bool = false, + }; + + fn initSubjectDn(state: *CaptureState, tag: u8, length: usize, reader: anytype) !void { + _ = tag; + + const dn_mem = try state.allocator.alloc(u8, length); + errdefer state.allocator.free(dn_mem); + try reader.readNoEof(dn_mem); + state.self.dn = dn_mem; + state.dn_allocated = true; + } + + fn processExtension(state: *CaptureState, tag: u8, length: usize, reader: anytype) !void { + _ = tag; + _ = length; + + const object_id = try asn1.der.parse_value(state.allocator, reader); + defer object_id.deinit(state.allocator); + if (object_id != .object_identifier) return error.DoesNotMatchSchema; + if (object_id.object_identifier.len != 4) + return; + + const data = object_id.object_identifier.data; + // Basic constraints extension + if (data[0] != 2 or data[1] != 5 or data[2] != 29 or data[3] != 19) + return; + + const basic_constraints = try asn1.der.parse_value(state.allocator, reader); + defer basic_constraints.deinit(state.allocator); + + switch (basic_constraints) { + .bool => state.self.is_ca = true, + .octet_string => |s| { + if (s.len != 5 or s[0] != 0x30 or s[1] != 0x03 or s[2] != 0x01 or s[3] != 0x01) + return error.DoesNotMatchSchema; + state.self.is_ca = s[4] != 0x00; + }, + else => return error.DoesNotMatchSchema, + } + } + + fn initExtensions(state: *CaptureState, tag: u8, length: usize, reader: anytype) !void { + _ = tag; + _ = length; + + const schema = .{ + .sequence_of, + .{ .capture, 0, .sequence }, + }; + const captures = .{ + state, processExtension, + }; + try asn1.der.parse_schema(schema, captures, reader); + } + + fn initPublicKeyInfo(state: *CaptureState, tag: u8, length: usize, reader: anytype) !void { + _ = tag; + _ = length; + + state.self.public_key = try parse_public_key(state.allocator, reader); + state.pk_allocated = true; + } + + /// Initialize a trusted anchor from distinguished encoding rules (DER) encoded data + pub fn create(allocator: *Allocator, der_reader: anytype) DecodeDERError(@TypeOf(der_reader))!@This() { + var self: @This() = undefined; + self.is_ca = false; + // https://tools.ietf.org/html/rfc5280#page-117 + const schema = .{ + .sequence, .{ + // tbsCertificate + .{ + .sequence, + .{ + .{ .context_specific, 0 }, // version + .{.int}, // serialNumber + .{.sequence}, // signature + .{.sequence}, // issuer + .{.sequence}, // validity, + .{ .capture, 0, .sequence }, // subject + .{ .capture, 1, .sequence }, // subjectPublicKeyInfo + .{ .optional, .context_specific, 1 }, // issuerUniqueID + .{ .optional, .context_specific, 2 }, // subjectUniqueID + .{ .capture, 2, .optional, .context_specific, 3 }, // extensions + }, + }, + // signatureAlgorithm + .{.sequence}, + // signatureValue + .{.bit_string}, + }, + }; + + var capture_state = CaptureState{ + .self = &self, + .allocator = allocator, + }; + const captures = .{ + &capture_state, initSubjectDn, + &capture_state, initPublicKeyInfo, + &capture_state, initExtensions, + }; + + errdefer { + if (capture_state.dn_allocated) + allocator.free(self.dn); + if (capture_state.pk_allocated) + self.public_key.deinit(allocator); + } + + asn1.der.parse_schema(schema, captures, der_reader) catch |err| switch (err) { + error.InvalidLength, + error.InvalidTag, + error.InvalidContainerLength, + error.DoesNotMatchSchema, + => return error.MalformedDER, + else => |e| return e, + }; + return self; + } + + pub fn deinit(self: 
@This(), alloc: *Allocator) void { + alloc.free(self.dn); + self.public_key.deinit(alloc); + } + + pub fn format(self: @This(), comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void { + _ = fmt; + _ = options; + + try writer.print( + \\CERTIFICATE + \\----------- + \\IS CA: {} + \\Subject distinguished name (encoded): + \\{X} + \\Public key: + \\ + , .{ self.is_ca, self.dn }); + + switch (self.public_key) { + .rsa => |mod_exp| { + const modulus = std.math.big.int.Const{ .positive = true, .limbs = mod_exp.modulus }; + const exponent = std.math.big.int.Const{ .positive = true, .limbs = mod_exp.exponent }; + try writer.print( + \\RSA + \\modulus: {} + \\exponent: {} + \\ + , .{ + modulus, + exponent, + }); + }, + .ec => |ec| { + try writer.print( + \\EC (Curve: {}) + \\point: {} + \\ + , .{ + ec.id, + ec.curve_point, + }); + }, + } + + try writer.writeAll( + \\----------- + \\ + ); + } +}; + +pub const CertificateChain = struct { + data: std.ArrayList(Certificate), + + pub fn from_pem(allocator: *Allocator, pem_reader: anytype) DecodeDERError(@TypeOf(pem_reader))!@This() { + var self = @This(){ .data = std.ArrayList(Certificate).init(allocator) }; + errdefer self.deinit(); + + var it = pemCertificateIterator(pem_reader); + while (try it.next()) |cert_reader| { + var buffered = std.io.bufferedReader(cert_reader); + const anchor = try Certificate.create(allocator, buffered.reader()); + errdefer anchor.deinit(allocator); + try self.data.append(anchor); + } + return self; + } + + pub fn deinit(self: @This()) void { + const alloc = self.data.allocator; + for (self.data.items) |ta| ta.deinit(alloc); + self.data.deinit(); + } +}; + +pub fn get_signature_algorithm( + reader: anytype, +) (@TypeOf(reader).Error || error{EndOfStream})!?Certificate.SignatureAlgorithm { + const oid_tag = try reader.readByte(); + if (oid_tag != 0x06) + return null; + + const oid_length = try asn1.der.parse_length(reader); + if (oid_length == 9) { + var oid_bytes: [9]u8 = undefined; + try reader.readNoEof(&oid_bytes); + + if (mem.eql(u8, &oid_bytes, &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01 })) { + // TODO: Is hash actually none here? 
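+ // 1.2.840.113549.1.1.1 is plain rsaEncryption, which does not name a
+ // digest; the .4/.5/.11/.12/.13 variants below map to md5, sha1, sha256,
+ // sha384 and sha512 WithRSAEncryption respectively.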
+ return Certificate.SignatureAlgorithm{ .signature = .rsa, .hash = .none }; + } else if (mem.eql(u8, &oid_bytes, &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x04 })) { + return Certificate.SignatureAlgorithm{ .signature = .rsa, .hash = .md5 }; + } else if (mem.eql(u8, &oid_bytes, &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x05 })) { + return Certificate.SignatureAlgorithm{ .signature = .rsa, .hash = .sha1 }; + } else if (mem.eql(u8, &oid_bytes, &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0B })) { + return Certificate.SignatureAlgorithm{ .signature = .rsa, .hash = .sha256 }; + } else if (mem.eql(u8, &oid_bytes, &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0C })) { + return Certificate.SignatureAlgorithm{ .signature = .rsa, .hash = .sha384 }; + } else if (mem.eql(u8, &oid_bytes, &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0D })) { + return Certificate.SignatureAlgorithm{ .signature = .rsa, .hash = .sha512 }; + } else { + return null; + } + return; + } else if (oid_length == 10) { + // TODO + // ECDSA + <Hash> algorithms + } + return null; +} + +pub const ClientCertificateChain = struct { + /// Number of certificates in the chain + cert_len: usize, + /// Contains the raw data of each certificate in the certificate chain + raw_certs: [*]const []const u8, + /// Issuer distinguished name in DER format of each certificate in the certificate chain + /// issuer_dn[N] is a dubslice of raw[N] + cert_issuer_dns: [*]const []const u8, + signature_algorithm: Certificate.SignatureAlgorithm, + private_key: PrivateKey, + + // TODO: Encrypted private keys, non-RSA private keys + pub fn from_pem(allocator: *Allocator, pem_reader: anytype) !@This() { + var it = PEMSectionIterator(@TypeOf(pem_reader), .{ + .section_names = &.{ + "X.509 CERTIFICATE", + "CERTIFICATE", + "RSA PRIVATE KEY", + }, + .skip_irrelevant_lines = true, + }){ .reader = pem_reader }; + + var raw_certs = std.ArrayListUnmanaged([]const u8){}; + var cert_issuer_dns = std.ArrayList([]const u8).init(allocator); + errdefer { + for (raw_certs.items) |bytes| { + allocator.free(bytes); + } + raw_certs.deinit(allocator); + cert_issuer_dns.deinit(); + } + + var signature_algorithm: Certificate.SignatureAlgorithm = undefined; + var private_key: ?PrivateKey = null; + errdefer if (private_key) |pk| { + pk.deinit(allocator); + }; + + while (try it.next()) |state_and_reader| { + switch (state_and_reader.state) { + .@"X.509 CERTIFICATE", .@"CERTIFICATE" => { + const cert_bytes = try state_and_reader.reader.readAllAlloc(allocator, std.math.maxInt(usize)); + errdefer allocator.free(cert_bytes); + try raw_certs.append(allocator, cert_bytes); + + const schema = .{ + .sequence, .{ + // tbsCertificate + .{ + .sequence, + .{ + .{ .context_specific, 0 }, // version + .{.int}, // serialNumber + .{.sequence}, // signature + .{ .capture, 0, .sequence }, // issuer + .{.sequence}, // validity + .{.sequence}, // subject + .{.sequence}, // subjectPublicKeyInfo + .{ .optional, .context_specific, 1 }, // issuerUniqueID + .{ .optional, .context_specific, 2 }, // subjectUniqueID + .{ .optional, .context_specific, 3 }, // extensions + }, + }, + // signatureAlgorithm + .{ .capture, 1, .sequence }, + // signatureValue + .{.bit_string}, + }, + }; + + var fbs = std.io.fixedBufferStream(cert_bytes); + const state = .{ + .fbs = &fbs, + .dns = &cert_issuer_dns, + .signature_algorithm = &signature_algorithm, + }; + + const captures = .{ + state, + struct { + fn capture(_state: anytype, tag: u8, length: usize, reader: 
anytype) !void { + _ = tag; + _ = reader; + + // TODO: Some way to get tag + length buffer directly in the capture callback? + const encoded_length = asn1.der.encode_length(length).slice(); + const pos = _state.fbs.pos; + const dn = _state.fbs.buffer[pos - encoded_length.len - 1 .. pos + length]; + try _state.dns.append(dn); + } + }.capture, + state, + struct { + fn capture(_state: anytype, tag: u8, length: usize, reader: anytype) !void { + _ = tag; + _ = length; + + if (_state.dns.items.len == 1) + _state.signature_algorithm.* = (try get_signature_algorithm(reader)) orelse + return error.InvalidSignatureAlgorithm; + } + }.capture, + }; + + asn1.der.parse_schema(schema, captures, fbs.reader()) catch |err| switch (err) { + error.DoesNotMatchSchema, + error.EndOfStream, + error.InvalidTag, + error.InvalidLength, + error.InvalidSignatureAlgorithm, + error.InvalidContainerLength, + => return error.InvalidCertificate, + error.OutOfMemory => return error.OutOfMemory, + }; + }, + .@"RSA PRIVATE KEY" => { + if (private_key != null) + return error.MultiplePrivateKeys; + + const schema = .{ + .sequence, .{ + .{.int}, // version + .{ .capture, 0, .int }, //modulus + .{.int}, //publicExponent + .{ .capture, 1, .int }, //privateExponent + .{.int}, // prime1 + .{.int}, //prime2 + .{.int}, //exponent1 + .{.int}, //exponent2 + .{.int}, //coefficient + .{ .optional, .any }, //otherPrimeInfos + }, + }; + + private_key = .{ .rsa = undefined }; + const state = .{ + .modulus = &private_key.?.rsa.modulus, + .exponent = &private_key.?.rsa.exponent, + .allocator = allocator, + }; + + const captures = .{ + state, + struct { + fn capture(_state: anytype, tag: u8, length: usize, reader: anytype) !void { + _ = tag; + + _state.modulus.* = (try asn1.der.parse_int_with_length( + _state.allocator, + length, + reader, + )).limbs; + } + }.capture, + state, + struct { + fn capture(_state: anytype, tag: u8, length: usize, reader: anytype) !void { + _ = tag; + + _state.exponent.* = (try asn1.der.parse_int_with_length( + _state.allocator, + length, + reader, + )).limbs; + } + }.capture, + }; + + asn1.der.parse_schema(schema, captures, state_and_reader.reader) catch |err| switch (err) { + error.DoesNotMatchSchema, + error.EndOfStream, + error.InvalidTag, + error.InvalidLength, + error.InvalidContainerLength, + => return error.InvalidPrivateKey, + error.OutOfMemory => return error.OutOfMemory, + error.MalformedPEM => return error.MalformedPEM, + }; + }, + .none, .other => unreachable, + } + } + if (private_key == null) + return error.NoPrivateKey; + + std.debug.assert(cert_issuer_dns.items.len == raw_certs.items.len); + return @This(){ + .cert_len = raw_certs.items.len, + .raw_certs = raw_certs.toOwnedSlice(allocator).ptr, + .cert_issuer_dns = cert_issuer_dns.toOwnedSlice().ptr, + .signature_algorithm = signature_algorithm, + .private_key = private_key.?, + }; + } + + pub fn deinit(self: *@This(), allocator: *Allocator) void { + for (self.raw_certs[0..self.cert_len]) |cert_bytes| { + allocator.free(cert_bytes); + } + allocator.free(self.raw_certs[0..self.cert_len]); + allocator.free(self.cert_issuer_dns[0..self.cert_len]); + self.private_key.deinit(allocator); + } +}; + +fn PEMSectionReader(comptime Reader: type, comptime options: PEMSectionIteratorOptions) type { + const Error = Reader.Error || error{MalformedPEM}; + const read = struct { + fn f(it: *PEMSectionIterator(Reader, options), buf: []u8) Error!usize { + var out_idx: usize = 0; + if (it.waiting_chars_len > 0) { + const rest_written = std.math.min(it.waiting_chars_len, 
buf.len); + while (out_idx < rest_written) : (out_idx += 1) { + buf[out_idx] = it.waiting_chars[out_idx]; + } + + it.waiting_chars_len -= rest_written; + if (it.waiting_chars_len != 0) { + std.mem.copy(u8, it.waiting_chars[0..], it.waiting_chars[rest_written..]); + } + + if (out_idx == buf.len) { + return out_idx; + } + } + if (it.state == .none) + return out_idx; + + var base64_buf: [4]u8 = undefined; + var base64_idx: usize = 0; + while (true) { + const byte = it.reader.readByte() catch |err| switch (err) { + error.EndOfStream => return out_idx, + else => |e| return e, + }; + + if (byte == '-') { + if (it.reader.isBytes("----END ") catch |err| switch (err) { + error.EndOfStream => return error.MalformedPEM, + else => |e| return e, + }) { + try it.reader.skipUntilDelimiterOrEof('\n'); + it.state = .none; + return out_idx; + } else return error.MalformedPEM; + } else if (byte == '\r') { + if ((it.reader.readByte() catch |err| switch (err) { + error.EndOfStream => return error.MalformedPEM, + else => |e| return e, + }) != '\n') + return error.MalformedPEM; + continue; + } else if (byte == '\n') + continue; + + base64_buf[base64_idx] = byte; + base64_idx += 1; + if (base64_idx == base64_buf.len) { + base64_idx = 0; + + const out_len = std.base64.standard_decoder.calcSizeForSlice(&base64_buf) catch + return error.MalformedPEM; + + const rest_chars = if (out_len > buf.len - out_idx) + out_len - (buf.len - out_idx) + else + 0; + const buf_chars = out_len - rest_chars; + + var res_buffer: [3]u8 = undefined; + std.base64.standard_decoder.decode(res_buffer[0..out_len], &base64_buf) catch + return error.MalformedPEM; + + var i: u3 = 0; + while (i < buf_chars) : (i += 1) { + buf[out_idx] = res_buffer[i]; + out_idx += 1; + } + + if (rest_chars > 0) { + mem.copy(u8, &it.waiting_chars, res_buffer[i..]); + it.waiting_chars_len = @intCast(u2, rest_chars); + } + if (out_idx == buf.len) + return out_idx; + } + } + } + }.f; + + return std.io.Reader( + *PEMSectionIterator(Reader, options), + Error, + read, + ); +} + +const PEMSectionIteratorOptions = struct { + section_names: []const []const u8, + skip_irrelevant_lines: bool = false, +}; + +fn PEMSectionIterator(comptime Reader: type, comptime options: PEMSectionIteratorOptions) type { + var biggest_name_len = 0; + + var fields: [options.section_names.len + 2]std.builtin.TypeInfo.EnumField = undefined; + fields[0] = .{ .name = "none", .value = 0 }; + fields[1] = .{ .name = "other", .value = 1 }; + for (fields[2..]) |*field, idx| { + field.name = options.section_names[idx]; + field.value = @as(u8, idx + 2); + if (field.name.len > biggest_name_len) + biggest_name_len = field.name.len; + } + + const StateEnum = @Type(.{ + .Enum = .{ + .layout = .Auto, + .tag_type = u8, + .fields = &fields, + .decls = &.{}, + .is_exhaustive = true, + }, + }); + + const _biggest_name_len = biggest_name_len; + + return struct { + pub const SectionReader = PEMSectionReader(Reader, options); + pub const StateAndName = struct { + state: StateEnum, + reader: SectionReader, + }; + pub const NextError = SectionReader.Error || error{EndOfStream}; + + reader: Reader, + // Internal state for the iterator and the current reader. 
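+ // waiting_chars buffers decoded bytes that did not fit into the caller's
+ // buffer on a previous read (base64 is decoded in 4-character / 3-byte
+ // groups) and is drained first on the next read; state records which
+ // recognized PEM section, if any, is currently being streamed.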
+ state: StateEnum = .none, + waiting_chars: [4]u8 = undefined, + waiting_chars_len: u2 = 0, + + // TODO More verification, this will accept lots of invalid PEM + // TODO Simplify code + pub fn next(self: *@This()) NextError!?StateAndName { + self.waiting_chars_len = 0; + outer_loop: while (true) { + const byte = self.reader.readByte() catch |err| switch (err) { + error.EndOfStream => if (self.state == .none) + return null + else + return error.EndOfStream, + else => |e| return e, + }; + + switch (self.state) { + .none => switch (byte) { + '#' => { + try self.reader.skipUntilDelimiterOrEof('\n'); + continue; + }, + '\r', '\n', ' ', '\t' => continue, + '-' => { + if (try self.reader.isBytes("----BEGIN ")) { + var name_char_idx: usize = 0; + var name_buf: [_biggest_name_len]u8 = undefined; + + while (true) { + const next_byte = try self.reader.readByte(); + switch (next_byte) { + '-' => { + try self.reader.skipUntilDelimiterOrEof('\n'); + const name = name_buf[0..name_char_idx]; + for (options.section_names) |sec_name, idx| { + if (mem.eql(u8, sec_name, name)) { + self.state = @intToEnum(StateEnum, @intCast(u8, idx + 2)); + return StateAndName{ + .reader = .{ .context = self }, + .state = self.state, + }; + } + } + self.state = .other; + continue :outer_loop; + }, + '\n' => return error.MalformedPEM, + else => { + if (name_char_idx == _biggest_name_len) { + try self.reader.skipUntilDelimiterOrEof('\n'); + self.state = .other; + continue :outer_loop; + } + name_buf[name_char_idx] = next_byte; + name_char_idx += 1; + }, + } + } + } else return error.MalformedPEM; + }, + else => { + if (options.skip_irrelevant_lines) { + try self.reader.skipUntilDelimiterOrEof('\n'); + continue; + } else { + return error.MalformedPEM; + } + }, + }, + else => switch (byte) { + '#' => { + try self.reader.skipUntilDelimiterOrEof('\n'); + continue; + }, + '\r', '\n', ' ', '\t' => continue, + '-' => { + if (try self.reader.isBytes("----END ")) { + try self.reader.skipUntilDelimiterOrEof('\n'); + self.state = .none; + continue; + } else return error.MalformedPEM; + }, + // TODO: Make sure the character is base64 + else => continue, + }, + } + } + } + }; +} + +fn PEMCertificateIterator(comptime Reader: type) type { + const SectionIterator = PEMSectionIterator(Reader, .{ + .section_names = &.{ "X.509 CERTIFICATE", "CERTIFICATE" }, + }); + + return struct { + pub const SectionReader = SectionIterator.SectionReader; + pub const NextError = SectionReader.Error || error{EndOfStream}; + + section_it: SectionIterator, + + pub fn next(self: *@This()) NextError!?SectionReader { + return ((try self.section_it.next()) orelse return null).reader; + } + }; +} + +/// Iterator of io.Reader that each decode one certificate from the PEM reader. +/// Readers do not have to be fully consumed until end of stream, but they must be +/// read from in order. +/// Iterator.SectionReader is the type of the io.Reader, Iterator.NextError is the error +/// set of the next() function. 
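+/// Typical usage, as in CertificateChain.from_pem above:
+///   var it = pemCertificateIterator(pem_reader);
+///   while (try it.next()) |cert_reader| { ...decode one DER certificate... }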
+pub fn pemCertificateIterator(reader: anytype) PEMCertificateIterator(@TypeOf(reader)) { + return .{ .section_it = .{ .reader = reader } }; +} + +pub const NameElement = struct { + // Encoded OID without tag + oid: asn1.ObjectIdentifier, + // Destination buffer + buf: []u8, + status: enum { + not_found, + found, + errored, + }, +}; + +const github_pem = @embedFile("../test/github.pem"); +const github_der = @embedFile("../test/github.der"); + +fn expected_pem_certificate_chain(bytes: []const u8, certs: []const []const u8) !void { + var fbs = std.io.fixedBufferStream(bytes); + + var it = pemCertificateIterator(fbs.reader()); + var idx: usize = 0; + while (try it.next()) |cert_reader| : (idx += 1) { + const result_bytes = try cert_reader.readAllAlloc(std.testing.allocator, std.math.maxInt(usize)); + defer std.testing.allocator.free(result_bytes); + try std.testing.expectEqualSlices(u8, certs[idx], result_bytes); + } + if (idx != certs.len) { + std.debug.panic("Read {} certificates, wanted {}", .{ idx, certs.len }); + } + try std.testing.expect((try it.next()) == null); +} + +fn expected_pem_certificate(bytes: []const u8, cert_bytes: []const u8) !void { + try expected_pem_certificate_chain(bytes, &[1][]const u8{cert_bytes}); +} + +test "pemCertificateIterator" { + try expected_pem_certificate(github_pem, github_der); + try expected_pem_certificate( + \\-----BEGIN BOGUS----- + \\-----END BOGUS----- + \\ + ++ + github_pem, + github_der, + ); + + try expected_pem_certificate_chain( + github_pem ++ + \\ + \\-----BEGIN BOGUS----- + \\-----END BOGUS----- + \\ + ++ github_pem, + &[2][]const u8{ github_der, github_der }, + ); + + try expected_pem_certificate_chain( + \\-----BEGIN BOGUS----- + \\-----END BOGUS----- + \\ + , + &[0][]const u8{}, + ); + + // Try reading byte by byte from a cert reader + { + var fbs = std.io.fixedBufferStream(github_pem ++ "\n# Some comment\n" ++ github_pem); + var it = pemCertificateIterator(fbs.reader()); + // Read a couple of bytes from the first reader, then skip to the next + { + const first_reader = (try it.next()) orelse return error.NoCertificate; + var first_few: [8]u8 = undefined; + const bytes = try first_reader.readAll(&first_few); + try std.testing.expectEqual(first_few.len, bytes); + try std.testing.expectEqualSlices(u8, github_der[0..bytes], &first_few); + } + + const next_reader = (try it.next()) orelse return error.NoCertificate; + var idx: usize = 0; + while (true) : (idx += 1) { + const byte = next_reader.readByte() catch |err| switch (err) { + error.EndOfStream => break, + else => |e| return e, + }; + if (github_der[idx] != byte) { + std.debug.panic("index {}: expected 0x{X}, found 0x{X}", .{ idx, github_der[idx], byte }); + } + } + try std.testing.expectEqual(github_der.len, idx); + try std.testing.expect((try it.next()) == null); + } +} + +test "CertificateChain" { + var fbs = std.io.fixedBufferStream(github_pem ++ + \\ + \\# Hellenic Academic and Research Institutions RootCA 2011 + \\-----BEGIN CERTIFICATE----- + \\MIIEMTCCAxmgAwIBAgIBADANBgkqhkiG9w0BAQUFADCBlTELMAkGA1UEBhMCR1Ix + \\RDBCBgNVBAoTO0hlbGxlbmljIEFjYWRlbWljIGFuZCBSZXNlYXJjaCBJbnN0aXR1 + \\dGlvbnMgQ2VydC4gQXV0aG9yaXR5MUAwPgYDVQQDEzdIZWxsZW5pYyBBY2FkZW1p + \\YyBhbmQgUmVzZWFyY2ggSW5zdGl0dXRpb25zIFJvb3RDQSAyMDExMB4XDTExMTIw + \\NjEzNDk1MloXDTMxMTIwMTEzNDk1MlowgZUxCzAJBgNVBAYTAkdSMUQwQgYDVQQK + \\EztIZWxsZW5pYyBBY2FkZW1pYyBhbmQgUmVzZWFyY2ggSW5zdGl0dXRpb25zIENl + \\cnQuIEF1dGhvcml0eTFAMD4GA1UEAxM3SGVsbGVuaWMgQWNhZGVtaWMgYW5kIFJl + 
\\c2VhcmNoIEluc3RpdHV0aW9ucyBSb290Q0EgMjAxMTCCASIwDQYJKoZIhvcNAQEB + \\BQADggEPADCCAQoCggEBAKlTAOMupvaO+mDYLZU++CwqVE7NuYRhlFhPjz2L5EPz + \\dYmNUeTDN9KKiE15HrcS3UN4SoqS5tdI1Q+kOilENbgH9mgdVc04UfCMJDGFr4PJ + \\fel3r+0ae50X+bOdOFAPplp5kYCvN66m0zH7tSYJnTxa71HFK9+WXesyHgLacEns + \\bgzImjeN9/E2YEsmLIKe0HjzDQ9jpFEw4fkrJxIH2Oq9GGKYsFk3fb7u8yBRQlqD + \\75O6aRXxYp2fmTmCobd0LovUxQt7L/DICto9eQqakxylKHJzkUOap9FNhYS5qXSP + \\FEDH3N6sQWRstBmbAmNtJGSPRLIl6s5ddAxjMlyNh+UCAwEAAaOBiTCBhjAPBgNV + \\HRMBAf8EBTADAQH/MAsGA1UdDwQEAwIBBjAdBgNVHQ4EFgQUppFC/RNhSiOeCKQp + \\5dgTBCPuQSUwRwYDVR0eBEAwPqA8MAWCAy5ncjAFggMuZXUwBoIELmVkdTAGggQu + \\b3JnMAWBAy5ncjAFgQMuZXUwBoEELmVkdTAGgQQub3JnMA0GCSqGSIb3DQEBBQUA + \\A4IBAQAf73lB4XtuP7KMhjdCSk4cNx6NZrokgclPEg8hwAOXhiVtXdMiKahsog2p + \\6z0GW5k6x8zDmjR/qw7IThzh+uTczQ2+vyT+bOdrwg3IBp5OjWEopmr95fZi6hg8 + \\TqBTnbI6nOulnJEWtk2C4AwFSKls9cz4y51JtPACpf1wA+2KIaWuE4ZJwzNzvoc7 + \\dIsXRSZMFpGD/md9zU1jZ/rzAxKWeAaNsWftjj++n08C9bMJL/NMh98qy5V8Acys + \\Nnq/onN694/BtZqhFLKPM58N7yLcZnuEvUUXBj08yrl3NI/K6s8/MT7jiOOASSXI + \\l7WdmplNsDz4SgCbZN2fOUvRJ9e4 + \\-----END CERTIFICATE----- + \\ + \\# ePKI Root Certification Authority + \\-----BEGIN CERTIFICATE----- + \\MIIFsDCCA5igAwIBAgIQFci9ZUdcr7iXAF7kBtK8nTANBgkqhkiG9w0BAQUFADBe + \\MQswCQYDVQQGEwJUVzEjMCEGA1UECgwaQ2h1bmdod2EgVGVsZWNvbSBDby4sIEx0 + \\ZC4xKjAoBgNVBAsMIWVQS0kgUm9vdCBDZXJ0aWZpY2F0aW9uIEF1dGhvcml0eTAe + \\Fw0wNDEyMjAwMjMxMjdaFw0zNDEyMjAwMjMxMjdaMF4xCzAJBgNVBAYTAlRXMSMw + \\IQYDVQQKDBpDaHVuZ2h3YSBUZWxlY29tIENvLiwgTHRkLjEqMCgGA1UECwwhZVBL + \\SSBSb290IENlcnRpZmljYXRpb24gQXV0aG9yaXR5MIICIjANBgkqhkiG9w0BAQEF + \\AAOCAg8AMIICCgKCAgEA4SUP7o3biDN1Z82tH306Tm2d0y8U82N0ywEhajfqhFAH + \\SyZbCUNsIZ5qyNUD9WBpj8zwIuQf5/dqIjG3LBXy4P4AakP/h2XGtRrBp0xtInAh + \\ijHyl3SJCRImHJ7K2RKilTza6We/CKBk49ZCt0Xvl/T29de1ShUCWH2YWEtgvM3X + \\DZoTM1PRYfl61dd4s5oz9wCGzh1NlDivqOx4UXCKXBCDUSH3ET00hl7lSM2XgYI1 + \\TBnsZfZrxQWh7kcT1rMhJ5QQCtkkO7q+RBNGMD+XPNjX12ruOzjjK9SXDrkb5wdJ + \\fzcq+Xd4z1TtW0ado4AOkUPB1ltfFLqfpo0kR0BZv3I4sjZsN/+Z0V0OWQqraffA + \\sgRFelQArr5T9rXn4fg8ozHSqf4hUmTFpmfwdQcGlBSBVcYn5AGPF8Fqcde+S/uU + \\WH1+ETOxQvdibBjWzwloPn9s9h6PYq2lY9sJpx8iQkEeb5mKPtf5P0B6ebClAZLS + \\nT0IFaUQAS2zMnaolQ2zepr7BxB4EW/hj8e6DyUadCrlHJhBmd8hh+iVBmoKs2pH + \\dmX2Os+PYhcZewoozRrSgx4hxyy/vv9haLdnG7t4TY3OZ+XkwY63I2binZB1NJip + \\NiuKmpS5nezMirH4JYlcWrYvjB9teSSnUmjDhDXiZo1jDiVN1Rmy5nk3pyKdVDEC + \\AwEAAaNqMGgwHQYDVR0OBBYEFB4M97Zn8uGSJglFwFU5Lnc/QkqiMAwGA1UdEwQF + \\MAMBAf8wOQYEZyoHAAQxMC8wLQIBADAJBgUrDgMCGgUAMAcGBWcqAwAABBRFsMLH + \\ClZ87lt4DJX5GFPBphzYEDANBgkqhkiG9w0BAQUFAAOCAgEACbODU1kBPpVJufGB + \\uvl2ICO1J2B01GqZNF5sAFPZn/KmsSQHRGoqxqWOeBLoR9lYGxMqXnmbnwoqZ6Yl + \\PwZpVnPDimZI+ymBV3QGypzqKOg4ZyYr8dW1P2WT+DZdjo2NQCCHGervJ8A9tDkP + \\JXtoUHRVnAxZfVo9QZQlUgjgRywVMRnVvwdVxrsStZf0X4OFunHB2WyBEXYKCrC/ + \\gpf36j36+uwtqSiUO1bd0lEursC9CBWMd1I0ltabrNMdjmEPNXubrjlpC2JgQCA2 + \\j6/7Nu4tCEoduL+bXPjqpRugc6bY+G7gMwRfaKonh+3ZwZCc7b3jajWvY9+rGNm6 + \\5ulK6lCKD2GTHuItGeIwlDWSXQ62B68ZgI9HkFFLLk3dheLSClIKF5r8GrBQAuUB + \\o2M3IUxExJtRmREOc5wGj1QupyheRDmHVi03vYVElOEMSyycw5KFNGHLD7ibSkNS + \\/jQ6fbjpKdx2qcgw+BRxgMYeNkh0IkFch4LoGHGLQYlE535YW6i4jRPpp2zDR+2z + \\Gp1iro2C6pSe3VkQw63d4k3jMdXH7OjysP6SHhYKGvzZ8/gntsm+HbRsZJB/9OTE + \\W9c3rkIO3aQab3yIVMUWbuF6aC74Or8NpDyJO3inTmODBCEIZ43ygknQW/2xzQ+D + \\hNQ+IIX3Sj0rnP0qCglN6oH4EZw= + \\-----END CERTIFICATE----- + ); + const chain = try CertificateChain.from_pem(std.testing.allocator, fbs.reader()); + defer chain.deinit(); +} diff --git a/src/deps/picohttp.zig b/src/deps/picohttp.zig index 7b081b59f..3340793ca 100644 --- a/src/deps/picohttp.zig +++ 
b/src/deps/picohttp.zig @@ -106,8 +106,9 @@ pub const Response = struct { status_code: usize, status: []const u8, headers: []const Header, + bytes_read: c_int = 0, - pub fn parse(buf: []const u8, src: []Header) !Response { + pub fn parseParts(buf: []const u8, src: []Header, offset: ?*usize) !Response { var minor_version: c_int = undefined; var status_code: c_int = undefined; var status: []const u8 = undefined; @@ -122,20 +123,29 @@ pub const Response = struct { &status.len, @ptrCast([*c]c.phr_header, src.ptr), &num_headers, - 0, + offset.?.*, ); return switch (rc) { -1 => error.BadResponse, - -2 => error.ShortRead, + -2 => brk: { + offset.?.* += buf.len; + + break :brk error.ShortRead; + }, else => |bytes_read| Response{ .minor_version = @intCast(usize, minor_version), .status_code = @intCast(usize, status_code), .status = status, .headers = src[0..num_headers], + .bytes_read = bytes_read, }, }; } + + pub fn parse(buf: []const u8, src: []Header) !Response { + return try parseParts(buf, src, 0); + } }; test "pico_http: parse response" { diff --git a/src/http_client.zig b/src/http_client.zig new file mode 100644 index 000000000..e01d6aa92 --- /dev/null +++ b/src/http_client.zig @@ -0,0 +1,444 @@ +const picohttp = @import("picohttp"); +usingnamespace @import("./global.zig"); +const std = @import("std"); +const Headers = @import("./javascript/jsc/webcore/response.zig").Headers; +const URL = @import("./query_string_map.zig").URL; +const Method = @import("./http.zig").Method; +const iguanaTLS = @import("iguanaTLS"); +const Api = @import("./api/schema.zig").Api; + +const HTTPClient = @This(); +const SOCKET_FLAGS = os.SOCK_CLOEXEC; + +fn writeRequest( + comptime Writer: type, + writer: Writer, + request: picohttp.Request, + body: string, + // header_hashes: []u64, +) !void { + try writer.writeAll(request.method); + try writer.writeAll(" "); + try writer.writeAll(request.path); + try writer.writeAll(" HTTP/1.1\r\n"); + + for (request.headers) |header, i| { + try writer.writeAll(header.name); + try writer.writeAll(": "); + try writer.writeAll(header.value); + try writer.writeAll("\r\n"); + } +} + +method: Method, +header_entries: Headers.Entries, +header_buf: string, +url: URL, +allocator: *std.mem.Allocator, + +pub fn init(allocator: *std.mem.Allocator, method: Method, url: URL, header_entries: Headers.Entries, header_buf: string) HTTPClient { + return HTTPClient{ + .allocator = allocator, + .method = method, + .url = url, + .header_entries = header_entries, + .header_buf = header_buf, + }; +} + +threadlocal var response_headers_buf: [256]picohttp.Header = undefined; +threadlocal var request_headers_buf: [256]picohttp.Header = undefined; +threadlocal var header_name_hashes: [256]u64 = undefined; +// threadlocal var resolver_cache +const tcp = std.x.net.tcp; +const ip = std.x.net.ip; + +const IPv4 = std.x.os.IPv4; +const IPv6 = std.x.os.IPv6; +const Socket = std.x.os.Socket; +const os = std.os; + +// lowercase hash header names so that we can be sure +fn hashHeaderName(name: string) u64 { + var hasher = std.hash.Wyhash.init(0); + var remain: string = name; + var buf: [32]u8 = undefined; + var buf_slice: []u8 = std.mem.span(&buf); + + while (remain.len > 0) { + var end = std.math.min(hasher.buf.len, remain.len); + + hasher.update(strings.copyLowercase(std.mem.span(remain[0..end]), buf_slice)); + remain = remain[end..]; + } + return hasher.final(); +} + +const host_header_hash = hashHeaderName("Host"); +const connection_header_hash = hashHeaderName("Connection"); + +const content_encoding_hash = 
hashHeaderName("Content-Encoding"); +const host_header_name = "Host"; +const content_length_header_name = "Content-Length"; +const content_length_header_hash = hashHeaderName("Content-Length"); +const connection_header = picohttp.Header{ .name = "Connection", .value = "close" }; +const accept_header = picohttp.Header{ .name = "Accept", .value = "*/*" }; +const accept_header_hash = hashHeaderName("Accept"); + +pub fn headerStr(this: *const HTTPClient, ptr: Api.StringPointer) string { + return this.header_buf[ptr.offset..][0..ptr.length]; +} + +pub fn buildRequest(this: *const HTTPClient, body_len: usize) picohttp.Request { + var header_count: usize = 0; + var header_entries = this.header_entries.slice(); + var header_names = header_entries.items(.name); + var header_values = header_entries.items(.value); + + for (header_names) |head, i| { + const name = this.headerStr(head); + // Hash it as lowercase + const hash = hashHeaderName(request_headers_buf[header_count].name); + + // Skip host and connection header + // we manage those + switch (hash) { + host_header_hash, + connection_header_hash, + content_length_header_hash, + accept_header_hash, + => { + continue; + }, + else => {}, + } + + request_headers_buf[header_count] = picohttp.Header{ + .name = name, + .value = this.headerStr(header_values[i]), + }; + + // header_name_hashes[header_count] = hash; + + // // ensure duplicate headers come after each other + // if (header_count > 2) { + // var head_i: usize = header_count - 1; + // while (head_i > 0) : (head_i -= 1) { + // if (header_name_hashes[head_i] == header_name_hashes[header_count]) { + // std.mem.swap(picohttp.Header, &header_name_hashes[header_count], &header_name_hashes[head_i + 1]); + // std.mem.swap(u64, &request_headers_buf[header_count], &request_headers_buf[head_i + 1]); + // break; + // } + // } + // } + header_count += 1; + } + + request_headers_buf[header_count] = connection_header; + header_count += 1; + request_headers_buf[header_count] = accept_header; + header_count += 1; + + request_headers_buf[header_count] = picohttp.Header{ + .name = host_header_name, + .value = this.url.hostname, + }; + header_count += 1; + + if (body_len > 0) { + request_headers_buf[header_count] = picohttp.Header{ + .name = host_header_name, + .value = this.url.hostname, + }; + header_count += 1; + } + + return picohttp.Request{ + .method = @tagName(this.method), + .path = this.url.path, + .minor_version = 1, + .headers = request_headers_buf[0..header_count], + }; +} + +pub fn connect( + this: *HTTPClient, +) !tcp.Client { + var client: tcp.Client = try tcp.Client.init(tcp.Domain.ip, .{ .close_on_exec = true }); + const port = this.url.getPortAuto(); + + // if (this.url.isLocalhost()) { + // try client.connect( + // try std.x.os.Socket.Address.initIPv4(try std.net.Address.resolveIp("localhost", port), port), + // ); + // } else { + // } else if (this.url.isDomainName()) { + var stream = try std.net.tcpConnectToHost(default_allocator, this.url.hostname, port); + client.socket = std.x.os.Socket.from(stream.handle); + // } + // } else if (this.url.getIPv4Address()) |ip_addr| { + // try client.connect(std.x.os.Socket.Address(ip_addr, port)); + // } else if (this.url.getIPv6Address()) |ip_addr| { + // try client.connect(std.x.os.Socket.Address.initIPv6(ip_addr, port)); + // } else { + // return error.MissingHostname; + // } + + return client; +} + +threadlocal var http_req_buf: [65436]u8 = undefined; + +pub inline fn send(this: *HTTPClient, body: []const u8, body_out_str: *MutableString) 
!picohttp.Response { + if (this.url.isHTTPS()) { + return this.sendHTTPS(body, body_out_str); + } else { + return this.sendHTTP(body, body_out_str); + } +} + +pub fn sendHTTP(this: *HTTPClient, body: []const u8, body_out_str: *MutableString) !picohttp.Response { + var client = try this.connect(); + defer { + std.os.closeSocket(client.socket.fd); + } + var request = buildRequest(this, body.len); + + var client_writer = client.writer(SOCKET_FLAGS); + { + var client_writer_buffered = std.io.bufferedWriter(client_writer); + var client_writer_buffered_writer = client_writer_buffered.writer(); + + try writeRequest(@TypeOf(&client_writer_buffered_writer), &client_writer_buffered_writer, request, body); + try client_writer_buffered_writer.writeAll("\r\n"); + try client_writer_buffered.flush(); + } + + if (body.len > 0) { + try client_writer.writeAll(body); + } + + var client_reader = client.reader(SOCKET_FLAGS); + var req_buf_len = try client_reader.readAll(&http_req_buf); + var request_buffer = http_req_buf[0..req_buf_len]; + var response: picohttp.Response = undefined; + + { + var response_length: usize = 0; + restart: while (true) { + response = picohttp.Response.parseParts(request_buffer, &response_headers_buf, &response_length) catch |err| { + switch (err) { + error.ShortRead => { + continue :restart; + }, + else => { + return err; + }, + } + }; + break :restart; + } + } + + body_out_str.reset(); + var content_length: u32 = 0; + for (response.headers) |header| { + switch (hashHeaderName(header.name)) { + content_length_header_hash => { + content_length = std.fmt.parseInt(u32, header.value, 10) catch 0; + try body_out_str.inflate(content_length); + body_out_str.list.expandToCapacity(); + }, + content_encoding_hash => { + return error.UnsupportedEncoding; + }, + else => {}, + } + } + + if (content_length > 0) { + var remaining_content_length = content_length; + var remainder = http_req_buf[@intCast(u32, response.bytes_read)..]; + remainder = remainder[0..std.math.min(remainder.len, content_length)]; + + var body_size: usize = 0; + if (remainder.len > 0) { + std.mem.copy(u8, body_out_str.list.items, remainder); + body_size = @intCast(u32, remainder.len); + remaining_content_length -= @intCast(u32, remainder.len); + } + + while (remaining_content_length > 0) { + const size = @intCast(u32, try client.read(body_out_str.list.items[body_size..], SOCKET_FLAGS)); + if (size == 0) break; + + body_size += size; + remaining_content_length -= size; + } + + body_out_str.list.items.len = body_size; + } + + return response; +} + +pub fn sendHTTPS(this: *HTTPClient, body_str: []const u8, body_out_str: *MutableString) !picohttp.Response { + var connection = try this.connect(); + + var arena = std.heap.ArenaAllocator.init(this.allocator); + defer arena.deinit(); + + var rand = blk: { + var seed: [std.rand.DefaultCsprng.secret_seed_length]u8 = undefined; + try std.os.getrandom(&seed); + break :blk &std.rand.DefaultCsprng.init(seed).random; + }; + + var client = try iguanaTLS.client_connect( + .{ + .rand = rand, + .temp_allocator = &arena.allocator, + .reader = connection.reader(SOCKET_FLAGS), + .writer = connection.writer(SOCKET_FLAGS), + .cert_verifier = .none, + .protocols = &[_][]const u8{"http/1.1"}, + }, + this.url.hostname, + ); + + defer { + client.close_notify() catch {}; + } + + var request = buildRequest(this, body_str.len); + const body = body_str; + + var client_writer = client.writer(); + { + var client_writer_buffered = std.io.bufferedWriter(client_writer); + var client_writer_buffered_writer = 
client_writer_buffered.writer(); + + try writeRequest(@TypeOf(&client_writer_buffered_writer), &client_writer_buffered_writer, request, body); + try client_writer_buffered_writer.writeAll("\r\n"); + try client_writer_buffered.flush(); + } + + if (body.len > 0) { + try client_writer.writeAll(body); + } + + var client_reader = client.reader(); + var req_buf_len = try client_reader.readAll(&http_req_buf); + var request_buffer = http_req_buf[0..req_buf_len]; + var response: picohttp.Response = undefined; + + { + var response_length: usize = 0; + restart: while (true) { + response = picohttp.Response.parseParts(request_buffer, &response_headers_buf, &response_length) catch |err| { + switch (err) { + error.ShortRead => { + continue :restart; + }, + else => { + return err; + }, + } + }; + break :restart; + } + } + + body_out_str.reset(); + var content_length: u32 = 0; + for (response.headers) |header| { + switch (hashHeaderName(header.name)) { + content_length_header_hash => { + content_length = std.fmt.parseInt(u32, header.value, 10) catch 0; + try body_out_str.inflate(content_length); + body_out_str.list.expandToCapacity(); + }, + content_encoding_hash => { + return error.UnsupportedEncoding; + }, + else => {}, + } + } + + if (content_length > 0) { + var remaining_content_length = content_length; + var remainder = http_req_buf[@intCast(u32, response.bytes_read)..]; + remainder = remainder[0..std.math.min(remainder.len, content_length)]; + + var body_size: usize = 0; + if (remainder.len > 0) { + std.mem.copy(u8, body_out_str.list.items, remainder); + body_size = @intCast(u32, remainder.len); + remaining_content_length -= @intCast(u32, remainder.len); + } + + while (remaining_content_length > 0) { + const size = @intCast(u32, try client.read( + body_out_str.list.items[body_size..], + )); + if (size == 0) break; + + body_size += size; + remaining_content_length -= size; + } + + body_out_str.list.items.len = body_size; + } + + return response; +} + +// zig test src/http_client.zig --test-filter "sendHTTP" -lc -lc++ /Users/jarred/Code/bun/src/deps/picohttpparser.o --cache-dir /Users/jarred/Code/bun/zig-cache --global-cache-dir /Users/jarred/.cache/zig --name bun --pkg-begin clap /Users/jarred/Code/bun/src/deps/zig-clap/clap.zig --pkg-end --pkg-begin picohttp /Users/jarred/Code/bun/src/deps/picohttp.zig --pkg-end --pkg-begin iguanaTLS /Users/jarred/Code/bun/src/deps/iguanaTLS/src/main.zig --pkg-end -I /Users/jarred/Code/bun/src/deps -I /Users/jarred/Code/bun/src/deps/mimalloc -I /usr/local/opt/icu4c/include -L src/deps/mimalloc -L /usr/local/opt/icu4c/lib --main-pkg-path /Users/jarred/Code/bun --enable-cache +test "sendHTTP" { + var headers = try std.heap.c_allocator.create(Headers); + headers.* = Headers{ + .entries = @TypeOf(headers.entries){}, + .buf = @TypeOf(headers.buf){}, + .used = 0, + .allocator = std.heap.c_allocator, + }; + + headers.appendHeader("X-What", "ok", true, true, false); + + var client = HTTPClient.init( + std.heap.c_allocator, + .GET, + URL.parse("http://example.com/"), + headers.entries, + headers.buf.items, + ); + var body_out_str = try MutableString.init(std.heap.c_allocator, 0); + var response = try client.sendHTTP("", &body_out_str); + try std.testing.expectEqual(response.status_code, 200); + try std.testing.expectEqual(body_out_str.list.items.len, 1256); +} + +// zig test src/http_client.zig --test-filter "sendHTTPS" -lc -lc++ /Users/jarred/Code/bun/src/deps/picohttpparser.o --cache-dir /Users/jarred/Code/bun/zig-cache --global-cache-dir /Users/jarred/.cache/zig --name 
bun --pkg-begin clap /Users/jarred/Code/bun/src/deps/zig-clap/clap.zig --pkg-end --pkg-begin picohttp /Users/jarred/Code/bun/src/deps/picohttp.zig --pkg-end --pkg-begin iguanaTLS /Users/jarred/Code/bun/src/deps/iguanaTLS/src/main.zig --pkg-end -I /Users/jarred/Code/bun/src/deps -I /Users/jarred/Code/bun/src/deps/mimalloc -I /usr/local/opt/icu4c/include -L src/deps/mimalloc -L /usr/local/opt/icu4c/lib --main-pkg-path /Users/jarred/Code/bun --enable-cache +test "sendHTTPS" { + var headers = try std.heap.c_allocator.create(Headers); + headers.* = Headers{ + .entries = @TypeOf(headers.entries){}, + .buf = @TypeOf(headers.buf){}, + .used = 0, + .allocator = std.heap.c_allocator, + }; + + headers.appendHeader("X-What", "ok", true, true, false); + + var client = HTTPClient.init( + std.heap.c_allocator, + .GET, + URL.parse("https://hookb.in/aBnOOWN677UXQ9kkQ2g3"), + headers.entries, + headers.buf.items, + ); + var body_out_str = try MutableString.init(std.heap.c_allocator, 0); + var response = try client.sendHTTPS("", &body_out_str); + try std.testing.expectEqual(response.status_code, 200); + try std.testing.expectEqual(body_out_str.list.items.len, 1256); +} diff --git a/src/javascript/jsc/base.zig b/src/javascript/jsc/base.zig index 58075dc38..3a0355503 100644 --- a/src/javascript/jsc/base.zig +++ b/src/javascript/jsc/base.zig @@ -726,11 +726,15 @@ pub fn NewClass( var static_functions = brk: { var funcs: [function_name_refs.len + 1]js.JSStaticFunction = undefined; - std.mem.set(js.JSStaticFunction, &funcs, js.JSStaticFunction{ - .name = @intToPtr([*c]const u8, 0), - .callAsFunction = null, - .attributes = js.JSPropertyAttributes.kJSPropertyAttributeNone, - },); + std.mem.set( + js.JSStaticFunction, + &funcs, + js.JSStaticFunction{ + .name = @intToPtr([*c]const u8, 0), + .callAsFunction = null, + .attributes = js.JSPropertyAttributes.kJSPropertyAttributeNone, + }, + ); break :brk funcs; }; var instance_functions = std.mem.zeroes([function_names.len]js.JSObjectRef); @@ -738,36 +742,40 @@ pub fn NewClass( var property_name_refs = std.mem.zeroes([property_names.len]js.JSStringRef); const property_name_literals = property_names; var static_properties = brk: { - var props: [property_names.len]js.JSStaticValue = undefined; - std.mem.set(js.JSStaticValue, &props, js.JSStaticValue{ - .name = @intToPtr([*c]const u8, 0), - .getProperty = null, - .setProperty = null, - .attributes = js.JSPropertyAttributes.kJSPropertyAttributeNone, - },); + var props: [property_names.len]js.JSStaticValue = undefined; + std.mem.set( + js.JSStaticValue, + &props, + js.JSStaticValue{ + .name = @intToPtr([*c]const u8, 0), + .getProperty = null, + .setProperty = null, + .attributes = js.JSPropertyAttributes.kJSPropertyAttributeNone, + }, + ); break :brk props; }; pub var ref: js.JSClassRef = null; pub var loaded = false; - pub var definition: js.JSClassDefinition =.{ - .version = 0, - .attributes = js.JSClassAttributes.kJSClassAttributeNone, - .className = name[0..:0].ptr, - .parentClass = null, - .staticValues = null, - .staticFunctions = null, - .initialize = null, - .finalize = null, - .hasProperty = null, - .getProperty = null, - .setProperty = null, - .deleteProperty = null, - .getPropertyNames = null, - .callAsFunction = null, - .callAsConstructor = null, - .hasInstance = null, - .convertToType = null, + pub var definition: js.JSClassDefinition = .{ + .version = 0, + .attributes = js.JSClassAttributes.kJSClassAttributeNone, + .className = name[0.. 
:0].ptr, + .parentClass = null, + .staticValues = null, + .staticFunctions = null, + .initialize = null, + .finalize = null, + .hasProperty = null, + .getProperty = null, + .setProperty = null, + .deleteProperty = null, + .getPropertyNames = null, + .callAsFunction = null, + .callAsConstructor = null, + .hasInstance = null, + .convertToType = null, }; const ConstructorWrapper = struct { pub fn rfn( @@ -1326,7 +1334,7 @@ pub fn NewClass( .callAsConstructor = null, .hasInstance = null, .convertToType = null, - }; + }; if (static_functions.len > 0) { std.mem.set(js.JSStaticFunction, &static_functions, std.mem.zeroes(js.JSStaticFunction)); @@ -1338,6 +1346,8 @@ pub fn NewClass( def.callAsConstructor = To.JS.Constructor(staticFunctions.constructor.rfn).rfn; } else if (comptime strings.eqlComptime(function_names[i], "finalize")) { def.finalize = To.JS.Finalize(ZigType, staticFunctions.finalize.rfn).rfn; + } else if (comptime strings.eqlComptime(function_names[i], "call")) { + def.callAsFunction = To.JS.Callback(ZigType, staticFunctions.call.rfn).rfn; } else if (comptime strings.eqlComptime(function_names[i], "callAsFunction")) { const ctxfn = @field(staticFunctions, function_names[i]).rfn; const Func: std.builtin.TypeInfo.Fn = @typeInfo(@TypeOf(ctxfn)).Fn; @@ -1379,6 +1389,8 @@ pub fn NewClass( def.callAsConstructor = To.JS.Constructor(staticFunctions.constructor).rfn; } else if (comptime strings.eqlComptime(function_names[i], "finalize")) { def.finalize = To.JS.Finalize(ZigType, staticFunctions.finalize).rfn; + } else if (comptime strings.eqlComptime(function_names[i], "call")) { + def.callAsFunction = To.JS.Callback(ZigType, staticFunctions.call).rfn; } else { var callback = To.JS.Callback( ZigType, diff --git a/src/javascript/jsc/bindings/ZigGlobalObject.cpp b/src/javascript/jsc/bindings/ZigGlobalObject.cpp index ea7e89141..8caf09662 100644 --- a/src/javascript/jsc/bindings/ZigGlobalObject.cpp +++ b/src/javascript/jsc/bindings/ZigGlobalObject.cpp @@ -165,7 +165,7 @@ void GlobalObject::setConsole(void *console) { // and any other objects available globally. void GlobalObject::installAPIGlobals(JSClassRef *globals, int count) { WTF::Vector<GlobalPropertyInfo> extraStaticGlobals; - extraStaticGlobals.reserveCapacity((size_t)count + 1); + extraStaticGlobals.reserveCapacity((size_t)count + 2); // This is not nearly a complete implementation. It's just enough to make some npm packages that // were compiled with Webpack to run without crashing in this environment. @@ -223,9 +223,7 @@ JSC::Identifier GlobalObject::moduleLoaderResolve(JSGlobalObject *globalObject, res.success = false; ZigString keyZ = toZigString(key, globalObject); ZigString referrerZ = referrer.isString() ? toZigString(referrer, globalObject) : ZigStringEmpty; - Zig__GlobalObject__resolve(&res, globalObject, &keyZ, - &referrerZ - ); + Zig__GlobalObject__resolve(&res, globalObject, &keyZ, &referrerZ); if (res.success) { return toIdentifier(res.result.value, globalObject); @@ -250,11 +248,9 @@ JSC::JSInternalPromise *GlobalObject::moduleLoaderImportModule(JSGlobalObject *g auto sourceURL = sourceOrigin.url(); ErrorableZigString resolved; auto moduleNameZ = toZigString(moduleNameValue, globalObject); - auto sourceOriginZ = sourceURL.isEmpty() ? ZigStringCwd - : toZigString(sourceURL.fileSystemPath()); + auto sourceOriginZ = sourceURL.isEmpty() ? 
ZigStringCwd : toZigString(sourceURL.fileSystemPath()); resolved.success = false; - Zig__GlobalObject__resolve(&resolved, globalObject, &moduleNameZ, &sourceOriginZ - ); + Zig__GlobalObject__resolve(&resolved, globalObject, &moduleNameZ, &sourceOriginZ); if (!resolved.success) { throwException(scope, resolved.result.err, globalObject); return promise->rejectWithCaughtException(globalObject, scope); @@ -382,8 +378,7 @@ JSC::JSInternalPromise *GlobalObject::moduleLoaderFetch(JSGlobalObject *globalOb res.result.err.code = 0; res.result.err.ptr = nullptr; - Zig__GlobalObject__fetch(&res, globalObject, &moduleKeyZig, - &source ); + Zig__GlobalObject__fetch(&res, globalObject, &moduleKeyZig, &source); if (!res.success) { throwException(scope, res.result.err, globalObject); diff --git a/src/javascript/jsc/bindings/bindings.cpp b/src/javascript/jsc/bindings/bindings.cpp index 766de863b..d5c5e057d 100644 --- a/src/javascript/jsc/bindings/bindings.cpp +++ b/src/javascript/jsc/bindings/bindings.cpp @@ -1765,4 +1765,15 @@ void WTF__URL__setQuery(WTF__URL *arg0, bWTF__StringView arg1) { void WTF__URL__setUser(WTF__URL *arg0, bWTF__StringView arg1) { arg0->setUser(*Wrap<WTF::StringView, bWTF__StringView>::unwrap(&arg1)); }; + +JSC__JSValue JSC__JSPromise__rejectedPromiseValue(JSC__JSGlobalObject *arg0, + JSC__JSValue JSValue1) { + return JSC::JSValue::encode( + JSC::JSPromise::rejectedPromise(arg0, JSC::JSValue::decode(JSValue1))); +} +JSC__JSValue JSC__JSPromise__resolvedPromiseValue(JSC__JSGlobalObject *arg0, + JSC__JSValue JSValue1) { + return JSC::JSValue::encode( + JSC::JSPromise::resolvedPromise(arg0, JSC::JSValue::decode(JSValue1))); } +}
\ No newline at end of file diff --git a/src/javascript/jsc/bindings/bindings.zig b/src/javascript/jsc/bindings/bindings.zig index 3ff10a285..4f9d5595e 100644 --- a/src/javascript/jsc/bindings/bindings.zig +++ b/src/javascript/jsc/bindings/bindings.zig @@ -438,10 +438,19 @@ pub const JSPromise = extern struct { pub fn resolvedPromise(globalThis: *JSGlobalObject, value: JSValue) *JSPromise { return cppFn("resolvedPromise", .{ globalThis, value }); } + + pub fn resolvedPromiseValue(globalThis: *JSGlobalObject, value: JSValue) JSValue { + return cppFn("resolvedPromiseValue", .{ globalThis, value }); + } + pub fn rejectedPromise(globalThis: *JSGlobalObject, value: JSValue) *JSPromise { return cppFn("rejectedPromise", .{ globalThis, value }); } + pub fn rejectedPromiseValue(globalThis: *JSGlobalObject, value: JSValue) JSValue { + return cppFn("rejectedPromiseValue", .{ globalThis, value }); + } + pub fn resolve(this: *JSPromise, globalThis: *JSGlobalObject, value: JSValue) void { cppFn("resolve", .{ this, globalThis, value }); } @@ -470,6 +479,8 @@ pub const JSPromise = extern struct { "rejectAsHandled", // "rejectException", "rejectAsHandledException", + "rejectedPromiseValue", + "resolvedPromiseValue", }; }; diff --git a/src/javascript/jsc/bindings/headers-cpp.h b/src/javascript/jsc/bindings/headers-cpp.h index 918138460..400697777 100644 --- a/src/javascript/jsc/bindings/headers-cpp.h +++ b/src/javascript/jsc/bindings/headers-cpp.h @@ -1,4 +1,4 @@ -//-- AUTOGENERATED FILE -- 1631085611 +//-- AUTOGENERATED FILE -- 1631179623 // clang-format off #pragma once diff --git a/src/javascript/jsc/bindings/headers.h b/src/javascript/jsc/bindings/headers.h index a26fecc2f..125937a59 100644 --- a/src/javascript/jsc/bindings/headers.h +++ b/src/javascript/jsc/bindings/headers.h @@ -1,4 +1,4 @@ -//-- AUTOGENERATED FILE -- 1631085611 +//-- AUTOGENERATED FILE -- 1631179623 // clang-format: off #pragma once @@ -285,9 +285,11 @@ CPP_DECL void JSC__JSPromise__reject(JSC__JSPromise* arg0, JSC__JSGlobalObject* CPP_DECL void JSC__JSPromise__rejectAsHandled(JSC__JSPromise* arg0, JSC__JSGlobalObject* arg1, JSC__JSValue JSValue2); CPP_DECL void JSC__JSPromise__rejectAsHandledException(JSC__JSPromise* arg0, JSC__JSGlobalObject* arg1, JSC__Exception* arg2); CPP_DECL JSC__JSPromise* JSC__JSPromise__rejectedPromise(JSC__JSGlobalObject* arg0, JSC__JSValue JSValue1); +CPP_DECL JSC__JSValue JSC__JSPromise__rejectedPromiseValue(JSC__JSGlobalObject* arg0, JSC__JSValue JSValue1); CPP_DECL void JSC__JSPromise__rejectWithCaughtException(JSC__JSPromise* arg0, JSC__JSGlobalObject* arg1, bJSC__ThrowScope arg2); CPP_DECL void JSC__JSPromise__resolve(JSC__JSPromise* arg0, JSC__JSGlobalObject* arg1, JSC__JSValue JSValue2); CPP_DECL JSC__JSPromise* JSC__JSPromise__resolvedPromise(JSC__JSGlobalObject* arg0, JSC__JSValue JSValue1); +CPP_DECL JSC__JSValue JSC__JSPromise__resolvedPromiseValue(JSC__JSGlobalObject* arg0, JSC__JSValue JSValue1); CPP_DECL JSC__JSValue JSC__JSPromise__result(const JSC__JSPromise* arg0, JSC__VM* arg1); CPP_DECL uint32_t JSC__JSPromise__status(const JSC__JSPromise* arg0, JSC__VM* arg1); diff --git a/src/javascript/jsc/bindings/headers.zig b/src/javascript/jsc/bindings/headers.zig index c1504a31f..f34816253 100644 --- a/src/javascript/jsc/bindings/headers.zig +++ b/src/javascript/jsc/bindings/headers.zig @@ -37,36 +37,7 @@ pub const __mbstate_t = extern union { pub const __darwin_mbstate_t = __mbstate_t; pub const __darwin_ptrdiff_t = c_long; pub const __darwin_size_t = c_ulong; -pub const __builtin_va_list = 
[*c]u8; -pub const __darwin_va_list = __builtin_va_list; -pub const __darwin_wchar_t = c_int; -pub const __darwin_rune_t = __darwin_wchar_t; -pub const __darwin_wint_t = c_int; -pub const __darwin_clock_t = c_ulong; -pub const __darwin_socklen_t = __uint32_t; -pub const __darwin_ssize_t = c_long; -pub const __darwin_time_t = c_long; -pub const __darwin_blkcnt_t = __int64_t; -pub const __darwin_blksize_t = __int32_t; -pub const __darwin_dev_t = __int32_t; -pub const __darwin_fsblkcnt_t = c_uint; -pub const __darwin_fsfilcnt_t = c_uint; -pub const __darwin_gid_t = __uint32_t; -pub const __darwin_id_t = __uint32_t; -pub const __darwin_ino64_t = __uint64_t; -pub const __darwin_ino_t = __darwin_ino64_t; -pub const __darwin_mach_port_name_t = __darwin_natural_t; -pub const __darwin_mach_port_t = __darwin_mach_port_name_t; -pub const __darwin_mode_t = __uint16_t; -pub const __darwin_off_t = __int64_t; -pub const __darwin_pid_t = __int32_t; -pub const __darwin_sigset_t = __uint32_t; -pub const __darwin_suseconds_t = __int32_t; -pub const __darwin_uid_t = __uint32_t; -pub const __darwin_useconds_t = __uint32_t; -pub const __darwin_uuid_t = [16]u8; -pub const __darwin_uuid_string_t = [37]u8; - + pub const JSC__RegExpPrototype = struct_JSC__RegExpPrototype; pub const JSC__GeneratorPrototype = struct_JSC__GeneratorPrototype; @@ -162,9 +133,11 @@ pub extern fn JSC__JSPromise__reject(arg0: [*c]JSC__JSPromise, arg1: [*c]JSC__JS pub extern fn JSC__JSPromise__rejectAsHandled(arg0: [*c]JSC__JSPromise, arg1: [*c]JSC__JSGlobalObject, JSValue2: JSC__JSValue) void; pub extern fn JSC__JSPromise__rejectAsHandledException(arg0: [*c]JSC__JSPromise, arg1: [*c]JSC__JSGlobalObject, arg2: [*c]JSC__Exception) void; pub extern fn JSC__JSPromise__rejectedPromise(arg0: [*c]JSC__JSGlobalObject, JSValue1: JSC__JSValue) [*c]JSC__JSPromise; +pub extern fn JSC__JSPromise__rejectedPromiseValue(arg0: [*c]JSC__JSGlobalObject, JSValue1: JSC__JSValue) JSC__JSValue; pub extern fn JSC__JSPromise__rejectWithCaughtException(arg0: [*c]JSC__JSPromise, arg1: [*c]JSC__JSGlobalObject, arg2: bJSC__ThrowScope) void; pub extern fn JSC__JSPromise__resolve(arg0: [*c]JSC__JSPromise, arg1: [*c]JSC__JSGlobalObject, JSValue2: JSC__JSValue) void; pub extern fn JSC__JSPromise__resolvedPromise(arg0: [*c]JSC__JSGlobalObject, JSValue1: JSC__JSValue) [*c]JSC__JSPromise; +pub extern fn JSC__JSPromise__resolvedPromiseValue(arg0: [*c]JSC__JSGlobalObject, JSValue1: JSC__JSValue) JSC__JSValue; pub extern fn JSC__JSPromise__result(arg0: [*c]const JSC__JSPromise, arg1: [*c]JSC__VM) JSC__JSValue; pub extern fn JSC__JSPromise__status(arg0: [*c]const JSC__JSPromise, arg1: [*c]JSC__VM) u32; pub extern fn JSC__JSInternalPromise__create(arg0: [*c]JSC__JSGlobalObject) [*c]JSC__JSInternalPromise; diff --git a/src/javascript/jsc/javascript.zig b/src/javascript/jsc/javascript.zig index 1d8616ff7..5349aaec1 100644 --- a/src/javascript/jsc/javascript.zig +++ b/src/javascript/jsc/javascript.zig @@ -33,6 +33,7 @@ pub const GlobalClasses = [_]type{ BuildError.Class, ResolveError.Class, Bun.Class, + Fetch.Class, }; const Blob = @import("../../blob.zig"); @@ -276,6 +277,10 @@ pub const Bun = struct { .rfn = Router.match, .ts = Router.match_type_definition, }, + .fetch = .{ + .rfn = Fetch.call, + .ts = d.ts{}, + }, .getImportedStyles = .{ .rfn = Bun.getImportedStyles, .ts = d.ts{ @@ -1348,7 +1353,6 @@ pub const EventListenerMixin = struct { // Rely on JS finalizer var fetch_event = try vm.allocator.create(FetchEvent); - fetch_event.* = FetchEvent{ .request_context = 
request_context, .request = Request{ .request_context = request_context }, diff --git a/src/javascript/jsc/webcore/response.zig b/src/javascript/jsc/webcore/response.zig index cd4dff8c8..680e8aa06 100644 --- a/src/javascript/jsc/webcore/response.zig +++ b/src/javascript/jsc/webcore/response.zig @@ -4,19 +4,34 @@ const Api = @import("../../../api/schema.zig").Api; const http = @import("../../../http.zig"); usingnamespace @import("../javascript.zig"); usingnamespace @import("../bindings/bindings.zig"); - +const ZigURL = @import("../../../query_string_map.zig").URL; +const HTTPClient = @import("../../../http_client.zig"); +const picohttp = @import("picohttp"); pub const Response = struct { pub const Class = NewClass( Response, .{ .name = "Response" }, .{ .@"constructor" = constructor, + .@"text" = .{ + .rfn = getText, + .ts = d.ts{}, + }, + .@"json" = .{ + .rfn = getJson, + .ts = d.ts{}, + }, + .@"arrayBuffer" = .{ + .rfn = getArrayBuffer, + .ts = d.ts{}, + }, }, .{ // .@"url" = .{ // .@"get" = getURL, // .ro = true, // }, + .@"ok" = .{ .@"get" = getOK, .ro = true, @@ -30,6 +45,7 @@ pub const Response = struct { allocator: *std.mem.Allocator, body: Body, + status_text: string = "", pub const Props = struct {}; @@ -41,7 +57,174 @@ pub const Response = struct { exception: js.ExceptionRef, ) js.JSValueRef { // https://developer.mozilla.org/en-US/docs/Web/API/Response/ok - return js.JSValueMakeBoolean(ctx, this.body.init.status_code >= 200 and this.body.init.status_code <= 299); + return js.JSValueMakeBoolean(ctx, this.body.init.status_code == 304 or (this.body.init.status_code >= 200 and this.body.init.status_code <= 299)); + } + + pub fn getText( + this: *Response, + ctx: js.JSContextRef, + function: js.JSObjectRef, + thisObject: js.JSObjectRef, + arguments: []const js.JSValueRef, + exception: js.ExceptionRef, + ) js.JSValueRef { + // https://developer.mozilla.org/en-US/docs/Web/API/Response/text + defer this.body.value = .Empty; + return JSPromise.resolvedPromiseValue( + VirtualMachine.vm.global, + (brk: { + switch (this.body.value) { + .Unconsumed => { + if (this.body.ptr) |_ptr| { + break :brk ZigString.init(_ptr[0..this.body.len]).toValue(VirtualMachine.vm.global); + } + + break :brk ZigString.init("").toValue(VirtualMachine.vm.global); + }, + .Empty => { + break :brk ZigString.init("").toValue(VirtualMachine.vm.global); + }, + .String => |str| { + break :brk ZigString.init(str).toValue(VirtualMachine.vm.global); + }, + .ArrayBuffer => |buffer| { + break :brk ZigString.init(buffer.ptr[buffer.offset..buffer.byte_len]).toValue(VirtualMachine.vm.global); + }, + } + }), + ).asRef(); + } + + var temp_error_buffer: [4096]u8 = undefined; + var error_arg_list: [1]js.JSObjectRef = undefined; + pub fn getJson( + this: *Response, + ctx: js.JSContextRef, + function: js.JSObjectRef, + thisObject: js.JSObjectRef, + arguments: []const js.JSValueRef, + exception: js.ExceptionRef, + ) js.JSValueRef { + defer this.body.value = .Empty; + var zig_string = ZigString.init(""); + + var js_string = (js.JSValueCreateJSONString( + ctx, + brk: { + switch (this.body.value) { + .Unconsumed => { + if (this.body.ptr) |_ptr| { + zig_string = ZigString.init(_ptr[0..this.body.len]); + break :brk zig_string.toJSStringRef(); + } + + break :brk zig_string.toJSStringRef(); + }, + .Empty => { + break :brk zig_string.toJSStringRef(); + }, + .String => |str| { + zig_string = ZigString.init(str); + break :brk zig_string.toJSStringRef(); + }, + .ArrayBuffer => |buffer| { + zig_string = 
ZigString.init(buffer.ptr[buffer.offset..buffer.byte_len]); + break :brk zig_string.toJSStringRef(); + }, + } + }, + 0, + exception, + ) orelse { + var out = std.fmt.bufPrint(&temp_error_buffer, "Invalid JSON\n\n \"{s}\"", .{zig_string.slice()[0..std.math.min(zig_string.len, 4000)]}) catch unreachable; + error_arg_list[0] = ZigString.init(out).toValueGC(VirtualMachine.vm.global).asRef(); + return JSPromise.rejectedPromiseValue( + VirtualMachine.vm.global, + JSValue.fromRef( + js.JSObjectMakeError( + ctx, + 1, + &error_arg_list, + exception, + ), + ), + ).asRef(); + }); + defer js.JSStringRelease(js_string); + + return JSPromise.resolvedPromiseValue( + VirtualMachine.vm.global, + JSValue.fromRef( + js.JSValueMakeString( + ctx, + js_string, + ), + ), + ).asRef(); + } + pub fn getArrayBuffer( + this: *Response, + ctx: js.JSContextRef, + function: js.JSObjectRef, + thisObject: js.JSObjectRef, + arguments: []const js.JSValueRef, + exception: js.ExceptionRef, + ) js.JSValueRef { + defer this.body.value = .Empty; + return JSPromise.resolvedPromiseValue( + VirtualMachine.vm.global, + JSValue.fromRef( + (brk: { + switch (this.body.value) { + .Unconsumed => { + if (this.body.ptr) |_ptr| { + break :brk js.JSObjectMakeTypedArrayWithBytesNoCopy( + ctx, + js.JSTypedArrayType.kJSTypedArrayTypeUint8Array, + _ptr, + this.body.len, + null, + null, + exception, + ); + } + + break :brk js.JSObjectMakeTypedArray( + ctx, + js.JSTypedArrayType.kJSTypedArrayTypeUint8Array, + 0, + exception, + ); + }, + .Empty => { + break :brk js.JSObjectMakeTypedArray(ctx, js.JSTypedArrayType.kJSTypedArrayTypeUint8Array, 0, exception); + }, + .String => |str| { + break :brk js.JSObjectMakeTypedArrayWithBytesNoCopy( + ctx, + js.JSTypedArrayType.kJSTypedArrayTypeUint8Array, + @intToPtr([*]u8, @ptrToInt(str.ptr)), + str.len, + null, + null, + exception, + ); + }, + .ArrayBuffer => |buffer| { + break :brk js.JSObjectMakeTypedArrayWithBytesNoCopy( + ctx, + buffer.typed_array_type, + buffer.ptr, + buffer.byte_len, + null, + null, + exception, + ); + }, + } + }), + ), + ).asRef(); } pub fn getStatus( @@ -87,7 +270,7 @@ pub const Response = struct { return http.MimeType.html.value; }, - .ArrayBuffer => { + .Unconsumed, .ArrayBuffer => { return "application/octet-stream"; }, } @@ -134,6 +317,151 @@ pub const Response = struct { } }; +pub const Fetch = struct { + const headers_string = "headers"; + const method_string = "method"; + + var fetch_body_string: MutableString = undefined; + var fetch_body_string_loaded = false; + + pub const Class = NewClass( + void, + .{ .name = "fetch" }, + .{ + .@"call" = .{ + .rfn = Fetch.call, + .ts = d.ts{}, + }, + }, + .{}, + ); + + pub fn call( + this: void, + ctx: js.JSContextRef, + function: js.JSObjectRef, + thisObject: js.JSObjectRef, + arguments: []const js.JSValueRef, + exception: js.ExceptionRef, + ) js.JSObjectRef { + if (arguments.len == 0 or arguments.len > 2) return js.JSValueMakeNull(ctx); + var http_client = HTTPClient.init(getAllocator(ctx), .GET, ZigURL{}, .{}, ""); + var headers: ?Headers = null; + var body: string = ""; + + if (!js.JSValueIsString(ctx, arguments[0])) { + return js.JSValueMakeNull(ctx); + } + + var url_zig_str = ZigString.init(""); + JSValue.fromRef(arguments[0]).toZigString( + &url_zig_str, + VirtualMachine.vm.global, + ); + var url_str = url_zig_str.slice(); + if (url_str.len == 0) return js.JSValueMakeNull(ctx); + http_client.url = ZigURL.parse(url_str); + + if (arguments.len == 2 and js.JSValueIsObject(ctx, arguments[1])) { + var array = 
js.JSObjectCopyPropertyNames(ctx, arguments[1]); + defer js.JSPropertyNameArrayRelease(array); + const count = js.JSPropertyNameArrayGetCount(array); + var i: usize = 0; + while (i < count) : (i += 1) { + var property_name_ref = js.JSPropertyNameArrayGetNameAtIndex(array, i); + switch (js.JSStringGetLength(property_name_ref)) { + "headers".len => { + if (js.JSStringIsEqualToUTF8CString(property_name_ref, "headers")) { + if (js.JSObjectGetProperty(ctx, arguments[1], property_name_ref, null)) |value| { + if (GetJSPrivateData(Headers, value)) |headers_ptr| { + headers = headers_ptr.*; + } else if (Headers.JS.headersInit(ctx, value) catch null) |headers_| { + headers = headers_; + } + } + } + }, + "body".len => { + if (js.JSStringIsEqualToUTF8CString(property_name_ref, "body")) { + if (js.JSObjectGetProperty(ctx, arguments[1], property_name_ref, null)) |value| { + var body_ = Body.extractBody(ctx, value, false, null, exception); + if (exception != null) return js.JSValueMakeNull(ctx); + switch (body_.value) { + .ArrayBuffer => |arraybuffer| { + body = arraybuffer.ptr[0..arraybuffer.byte_len]; + }, + .String => |str| { + body = str; + }, + else => {}, + } + } + } + }, + "method".len => { + if (js.JSStringIsEqualToUTF8CString(property_name_ref, "method")) { + if (js.JSObjectGetProperty(ctx, arguments[1], property_name_ref, null)) |value| { + var string_ref = js.JSValueToStringCopy(ctx, value, exception); + + if (exception != null) return js.JSValueMakeNull(ctx); + defer js.JSStringRelease(string_ref); + var method_name_buf: [16]u8 = undefined; + var method_name = method_name_buf[0..js.JSStringGetUTF8CString(string_ref, &method_name_buf, method_name_buf.len)]; + http_client.method = http.Method.which(method_name) orelse http_client.method; + } + } + }, + else => {}, + } + } + } + + if (headers) |head| { + http_client.header_entries = head.entries; + http_client.header_buf = head.buf.items; + } + + if (fetch_body_string_loaded) { + fetch_body_string.reset(); + } else { + fetch_body_string = MutableString.init(VirtualMachine.vm.allocator, 0) catch unreachable; + fetch_body_string_loaded = true; + } + + var http_response = http_client.send(body, &fetch_body_string) catch |err| { + const fetch_error = std.fmt.allocPrint(getAllocator(ctx), "Fetch error: {s}", .{@errorName(err)}) catch unreachable; + return JSPromise.rejectedPromiseValue(VirtualMachine.vm.global, ZigString.init(fetch_error).toErrorInstance(VirtualMachine.vm.global)).asRef(); + }; + + var response_headers = Headers.fromPicoHeaders(getAllocator(ctx), http_response.headers) catch unreachable; + response_headers.guard = .immutable; + var response = getAllocator(ctx).create(Response) catch unreachable; + var allocator = getAllocator(ctx); + var duped = allocator.dupeZ(u8, fetch_body_string.list.items) catch unreachable; + response.* = Response{ + .allocator = allocator, + .status_text = allocator.dupe(u8, http_response.status) catch unreachable, + .body = .{ + .init = .{ + .headers = response_headers, + .status_code = @truncate(u16, http_response.status_code), + }, + .value = .{ + .Unconsumed = 0, + }, + .ptr = duped.ptr, + .len = duped.len, + .ptr_allocator = allocator, + }, + }; + + return JSPromise.resolvedPromiseValue( + VirtualMachine.vm.global, + JSValue.fromRef(Response.Class.make(ctx, response)), + ).asRef(); + } +}; + // https://developer.mozilla.org/en-US/docs/Web/API/Headers pub const Headers = struct { pub const Kv = struct { @@ -272,6 +600,77 @@ pub const Headers = struct { return js.JSValueMakeNull(ctx); } + pub fn 
headersInit(ctx: js.JSContextRef, header_prop: js.JSObjectRef) !?Headers { + const header_keys = js.JSObjectCopyPropertyNames(ctx, header_prop); + defer js.JSPropertyNameArrayRelease(header_keys); + const total_header_count = js.JSPropertyNameArrayGetCount(header_keys); + if (total_header_count == 0) return null; + + // 2 passes through the headers + + // Pass #1: find the "real" count. + // The number of things which are strings or numbers. + // Anything else should be ignored. + // We could throw a TypeError, but ignoring silently is more JavaScript-like imo + var real_header_count: usize = 0; + var estimated_buffer_len: usize = 0; + var j: usize = 0; + while (j < total_header_count) : (j += 1) { + var key_ref = js.JSPropertyNameArrayGetNameAtIndex(header_keys, j); + var value_ref = js.JSObjectGetProperty(ctx, header_prop, key_ref, null); + + switch (js.JSValueGetType(ctx, value_ref)) { + js.JSType.kJSTypeNumber => { + const key_len = js.JSStringGetLength(key_ref); + if (key_len > 0) { + real_header_count += 1; + estimated_buffer_len += key_len; + estimated_buffer_len += std.fmt.count("{d}", .{js.JSValueToNumber(ctx, value_ref, null)}); + } + }, + js.JSType.kJSTypeString => { + const key_len = js.JSStringGetLength(key_ref); + const value_len = js.JSStringGetLength(value_ref); + if (key_len > 0 and value_len > 0) { + real_header_count += 1; + estimated_buffer_len += key_len + value_len; + } + }, + else => {}, + } + } + + if (real_header_count == 0 or estimated_buffer_len == 0) return null; + + j = 0; + var allocator = getAllocator(ctx); + var headers = Headers{ + .allocator = allocator, + .buf = try std.ArrayListUnmanaged(u8).initCapacity(allocator, estimated_buffer_len), + .entries = Headers.Entries{}, + }; + errdefer headers.deinit(); + try headers.entries.ensureTotalCapacity(allocator, real_header_count); + headers.buf.expandToCapacity(); + while (j < total_header_count) : (j += 1) { + var key_ref = js.JSPropertyNameArrayGetNameAtIndex(header_keys, j); + var value_ref = js.JSObjectGetProperty(ctx, header_prop, key_ref, null); + + switch (js.JSValueGetType(ctx, value_ref)) { + js.JSType.kJSTypeNumber => { + if (js.JSStringGetLength(key_ref) == 0) continue; + try headers.appendInit(ctx, key_ref, .kJSTypeNumber, value_ref); + }, + js.JSType.kJSTypeString => { + if (js.JSStringGetLength(value_ref) == 0 or js.JSStringGetLength(key_ref) == 0) continue; + try headers.appendInit(ctx, key_ref, .kJSTypeString, value_ref); + }, + else => {}, + } + } + return headers; + } + // https://developer.mozilla.org/en-US/docs/Web/API/Headers/Headers pub fn constructor( ctx: js.JSContextRef, @@ -283,6 +682,14 @@ pub const Headers = struct { if (arguments.len > 0 and js.JSValueIsObjectOfClass(ctx, arguments[0], Headers.Class.get().*)) { var other = castObj(arguments[0], Headers); other.clone(headers) catch unreachable; + } else if (arguments.len == 1 and js.JSValueIsObject(ctx, arguments[0])) { + headers.* = (JS.headersInit(ctx, arguments[0]) catch unreachable) orelse Headers{ + .entries = @TypeOf(headers.entries){}, + .buf = @TypeOf(headers.buf){}, + .used = 0, + .allocator = getAllocator(ctx), + .guard = Guard.none, + }; } else { headers.* = Headers{ .entries = @TypeOf(headers.entries){}, @@ -356,26 +763,25 @@ pub const Headers = struct { none, }; - // TODO: is it worth making this lazy? instead of copying all the request headers, should we just do it on get/put/iterator? 
- pub fn fromRequestCtx(allocator: *std.mem.Allocator, request: *http.RequestContext) !Headers { + pub fn fromPicoHeaders(allocator: *std.mem.Allocator, picohttp_headers: []const picohttp.Header) !Headers { var total_len: usize = 0; - for (request.request.headers) |header| { + for (picohttp_headers) |header| { total_len += header.name.len; total_len += header.value.len; } // for the null bytes - total_len += request.request.headers.len * 2; + total_len += picohttp_headers.len * 2; var headers = Headers{ .allocator = allocator, .entries = Entries{}, .buf = std.ArrayListUnmanaged(u8){}, }; - try headers.entries.ensureTotalCapacity(allocator, request.request.headers.len); + try headers.entries.ensureTotalCapacity(allocator, picohttp_headers.len); try headers.buf.ensureTotalCapacity(allocator, total_len); headers.buf.expandToCapacity(); headers.guard = Guard.request; - for (request.request.headers) |header| { + for (picohttp_headers) |header| { headers.entries.appendAssumeCapacity(Kv{ .name = headers.appendString( string, @@ -394,11 +800,14 @@ pub const Headers = struct { }); } - headers.guard = Guard.immutable; - return headers; } + // TODO: is it worth making this lazy? instead of copying all the request headers, should we just do it on get/put/iterator? + pub fn fromRequestCtx(allocator: *std.mem.Allocator, request: *http.RequestContext) !Headers { + return fromPicoHeaders(allocator, request.request.headers); + } + pub fn asStr(headers: *const Headers, ptr: Api.StringPointer) []u8 { return headers.buf.items[ptr.offset..][0..ptr.length]; } @@ -479,7 +888,7 @@ pub const Headers = struct { ), .value = headers.appendString( string, - key, + value, needs_lowercase, needs_normalize, append_null, @@ -577,6 +986,9 @@ pub const Headers = struct { pub const Body = struct { init: Init, value: Value, + ptr: ?[*]u8 = null, + len: usize = 0, + ptr_allocator: ?*std.mem.Allocator = null, pub fn deinit(this: *Body, allocator: *std.mem.Allocator) void { if (this.init.headers) |headers| { @@ -602,7 +1014,7 @@ pub const Body = struct { defer js.JSPropertyNameArrayRelease(array); const count = js.JSPropertyNameArrayGetCount(array); var i: usize = 0; - upper: while (i < count) : (i += 1) { + while (i < count) : (i += 1) { var property_name_ref = js.JSPropertyNameArrayGetNameAtIndex(array, i); switch (js.JSStringGetLength(property_name_ref)) { "headers".len => { @@ -611,73 +1023,7 @@ pub const Body = struct { if (js.JSObjectGetProperty(ctx, init_ref, property_name_ref, null)) |header_prop| { switch (js.JSValueGetType(ctx, header_prop)) { js.JSType.kJSTypeObject => { - const header_keys = js.JSObjectCopyPropertyNames(ctx, header_prop); - defer js.JSPropertyNameArrayRelease(header_keys); - const total_header_count = js.JSPropertyNameArrayGetCount(array); - if (total_header_count == 0) continue :upper; - - // 2 passes through the headers - - // Pass #1: find the "real" count. - // The number of things which are strings or numbers. - // Anything else should be ignored. 
- // We could throw a TypeError, but ignoring silently is more JavaScript-like imo - var real_header_count: usize = 0; - var estimated_buffer_len: usize = 0; - var j: usize = 0; - while (j < total_header_count) : (j += 1) { - var key_ref = js.JSPropertyNameArrayGetNameAtIndex(header_keys, j); - var value_ref = js.JSObjectGetProperty(ctx, header_prop, key_ref, null); - - switch (js.JSValueGetType(ctx, value_ref)) { - js.JSType.kJSTypeNumber => { - const key_len = js.JSStringGetLength(key_ref); - if (key_len > 0) { - real_header_count += 1; - estimated_buffer_len += key_len; - estimated_buffer_len += std.fmt.count("{d}", .{js.JSValueToNumber(ctx, value_ref, null)}); - } - }, - js.JSType.kJSTypeString => { - const key_len = js.JSStringGetLength(key_ref); - const value_len = js.JSStringGetLength(value_ref); - if (key_len > 0 and value_len > 0) { - real_header_count += 1; - estimated_buffer_len += key_len + value_len; - } - }, - else => {}, - } - } - - if (real_header_count == 0 or estimated_buffer_len == 0) continue :upper; - - j = 0; - var headers = Headers{ - .allocator = allocator, - .buf = try std.ArrayListUnmanaged(u8).initCapacity(allocator, estimated_buffer_len), - .entries = Headers.Entries{}, - }; - errdefer headers.deinit(); - try headers.entries.ensureTotalCapacity(allocator, real_header_count); - - while (j < total_header_count) : (j += 1) { - var key_ref = js.JSPropertyNameArrayGetNameAtIndex(header_keys, j); - var value_ref = js.JSObjectGetProperty(ctx, header_prop, key_ref, null); - - switch (js.JSValueGetType(ctx, value_ref)) { - js.JSType.kJSTypeNumber => { - if (js.JSStringGetLength(key_ref) == 0) continue; - try headers.appendInit(ctx, key_ref, .kJSTypeNumber, value_ref); - }, - js.JSType.kJSTypeString => { - if (js.JSStringGetLength(value_ref) == 0 or js.JSStringGetLength(key_ref) == 0) continue; - try headers.appendInit(ctx, key_ref, .kJSTypeString, value_ref); - }, - else => {}, - } - } - result.headers = headers; + result.headers = try Headers.JS.headersInit(ctx, header_prop); }, else => {}, } @@ -705,10 +1051,12 @@ pub const Body = struct { ArrayBuffer: ArrayBuffer, String: string, Empty: u0, + Unconsumed: u0, pub const Tag = enum { ArrayBuffer, String, Empty, + Unconsumed, }; pub fn length(value: *const Value) usize { @@ -719,7 +1067,7 @@ pub const Body = struct { .String => |str| { return str.len; }, - .Empty => { + else => { return 0; }, } @@ -783,6 +1131,8 @@ pub const Body = struct { } body.value = Value{ .String = str.characters8()[0..len] }; + body.ptr = @intToPtr([*]u8, @ptrToInt(body.value.String.ptr)); + body.len = body.value.String.len; return body; }, .kJSTypeObject => { @@ -807,6 +1157,8 @@ pub const Body = struct { } else |err| {} } body.value = Value{ .ArrayBuffer = buffer }; + body.ptr = buffer.ptr[buffer.offset..buffer.byte_len].ptr; + body.len = buffer.ptr[buffer.offset..buffer.byte_len].len; return body; }, } diff --git a/src/node-fallbacks/@vercel_fetch.js b/src/node-fallbacks/@vercel_fetch.js new file mode 100644 index 000000000..5ab626670 --- /dev/null +++ b/src/node-fallbacks/@vercel_fetch.js @@ -0,0 +1,31 @@ +// This is just a no-op. Intent is to prevent importing a bunch of stuff that isn't relevant. 
+module.exports = (wrapper = Bun.fetch) => { + return async function vercelFetch(url, opts = {}) { + // Convert Object bodies to JSON if they are JS objects + if ( + opts.body && + typeof opts.body === "object" && + (!("buffer" in opts.body) || + typeof opts.body.buffer !== "object" || + !(opts.body.buffer instanceof ArrayBuffer)) + ) { + opts.body = JSON.stringify(opts.body); + // Content length will automatically be set + if (!opts.headers) opts.headers = new Headers(); + + opts.headers.set("Content-Type", "application/json"); + } + + try { + return await wrapper(url, opts); + } catch (err) { + if (typeof err === "string") { + err = new Error(err); + } + + err.url = url; + err.opts = opts; + throw err; + } + }; +}; diff --git a/src/node-fallbacks/isomorphic-fetch.js b/src/node-fallbacks/isomorphic-fetch.js new file mode 100644 index 000000000..0bbe50ebf --- /dev/null +++ b/src/node-fallbacks/isomorphic-fetch.js @@ -0,0 +1 @@ +export default Bun.fetch; diff --git a/src/node-fallbacks/node-fetch.js b/src/node-fallbacks/node-fetch.js new file mode 100644 index 000000000..0bbe50ebf --- /dev/null +++ b/src/node-fallbacks/node-fetch.js @@ -0,0 +1 @@ +export default Bun.fetch; diff --git a/src/node_fallbacks.zig b/src/node_fallbacks.zig index e7882635f..def1d6126 100644 --- a/src/node_fallbacks.zig +++ b/src/node_fallbacks.zig @@ -27,6 +27,13 @@ const _url_code: string = @embedFile("./node-fallbacks/out/url.js"); const _util_code: string = @embedFile("./node-fallbacks/out/util.js"); const _zlib_code: string = @embedFile("./node-fallbacks/out/zlib.js"); +const _node_fetch_code: string = @embedFile("./node-fallbacks/out/node-fetch.js"); +const _isomorphic_fetch_code: string = @embedFile("./node-fallbacks/out/isomorphic-fetch.js"); +const _vercel_fetch_code: string = @embedFile("./node-fallbacks/out/@vercel_fetch.js"); +const node_fetch_code: *const string = &_node_fetch_code; +const isomorphic_fetch_code: *const string = &_isomorphic_fetch_code; +const vercel_fetch_code: *const string = &_vercel_fetch_code; + const assert_code: *const string = &_assert_code; const buffer_code: *const string = &_buffer_code; const console_code: *const string = &_console_code; @@ -73,6 +80,10 @@ const url_import_path = "/bun-vfs/node_modules/url/index.js"; const util_import_path = "/bun-vfs/node_modules/util/index.js"; const zlib_import_path = "/bun-vfs/node_modules/zlib/index.js"; +const node_fetch_import_path = "/bun-vfs/node_modules/node-fetch/index.js"; +const isomorphic_fetch_import_path = "/bun-vfs/node_modules/isomorphic-fetch/index.js"; +const vercel_fetch_import_path = "/bun-vfs/node_modules/@vercel/fetch/index.js"; + const assert_package_json = PackageJSON{ .name = "assert", .version = "0.0.0-polyfill", @@ -277,6 +288,34 @@ const zlib_package_json = PackageJSON{ .source = logger.Source.initPathString("/bun-vfs/node_modules/zlib/package.json", ""), }; +const node_fetch_package_json = PackageJSON{ + .name = "node-fetch", + .version = "0.0.0-polyfill", + .module_type = .cjs, + .hash = @truncate(u32, std.hash.Wyhash.hash(0, "node-fetch@0.0.0-polyfill")), + .main_fields = undefined, + .browser_map = undefined, + .source = logger.Source.initPathString("/bun-vfs/node_modules/node-fetch/package.json", ""), +}; +const isomorphic_fetch_package_json = PackageJSON{ + .name = "isomorphic-fetch", + .version = "0.0.0-polyfill", + .module_type = .cjs, + .hash = @truncate(u32, std.hash.Wyhash.hash(0, "isomorphic-fetch@0.0.0-polyfill")), + .main_fields = undefined, + .browser_map = undefined, + .source = 
logger.Source.initPathString("/bun-vfs/node_modules/isomorphic-fetch/package.json", ""), +}; +const vercel_fetch_package_json = PackageJSON{ + .name = "@vercel/fetch", + .version = "0.0.0-polyfill", + .module_type = .cjs, + .hash = @truncate(u32, std.hash.Wyhash.hash(0, "@vercel/fetch@0.0.0-polyfill")), + .main_fields = undefined, + .browser_map = undefined, + .source = logger.Source.initPathString("/bun-vfs/node_modules/@vercel/fetch/package.json", ""), +}; + pub const FallbackModule = struct { path: Fs.Path, code: *const string, @@ -392,6 +431,24 @@ pub const FallbackModule = struct { .code = zlib_code, .package_json = &zlib_package_json, }; + + pub const @"node-fetch" = FallbackModule{ + .path = Fs.Path.initWithNamespaceVirtual(node_fetch_import_path, "node", "node-fetch"), + .code = node_fetch_code, + .package_json = &node_fetch_package_json, + }; + + pub const @"isomorphic-fetch" = FallbackModule{ + .path = Fs.Path.initWithNamespaceVirtual(isomorphic_fetch_import_path, "node", "isomorphic-fetch"), + .code = isomorphic_fetch_code, + .package_json = &isomorphic_fetch_package_json, + }; + + pub const @"@vercel/fetch" = FallbackModule{ + .path = Fs.Path.initWithNamespaceVirtual(vercel_fetch_import_path, "node", "@vercel/fetch"), + .code = vercel_fetch_code, + .package_json = &vercel_fetch_package_json, + }; }; pub const Map = std.ComptimeStringMap(FallbackModule, .{ @@ -417,4 +474,8 @@ pub const Map = std.ComptimeStringMap(FallbackModule, .{ &.{ "url", FallbackModule.url }, &.{ "util", FallbackModule.util }, &.{ "zlib", FallbackModule.zlib }, + + &.{ "node-fetch", FallbackModule.@"node-fetch" }, + &.{ "isomorphic-fetch", FallbackModule.@"isomorphic-fetch" }, + &.{ "@vercel/fetch", FallbackModule.@"@vercel/fetch" }, }); diff --git a/src/query_string_map.zig b/src/query_string_map.zig index 1bac8f5ce..ddf90750f 100644 --- a/src/query_string_map.zig +++ b/src/query_string_map.zig @@ -20,6 +20,37 @@ pub const URL = struct { username: string = "", port_was_automatically_set: bool = false, + pub fn isDomainName(this: *const URL) bool { + for (this.hostname) |c, i| { + switch (c) { + '0'...'9', '.', ':' => {}, + else => { + return true; + }, + } + } + + return false; + } + + pub fn isLocalhost(this: *const URL) bool { + return this.hostname.len == 0 or strings.eqlComptime(this.hostname, "localhost") or strings.eqlComptime(this.hostname, "0.0.0.0"); + } + + pub fn getIPv4Address(this: *const URL) ?std.x.net.ip.Address.IPv4 { + return (if (this.hostname.length > 0) + std.x.os.IPv4.parse(this.hostname) + else + std.x.os.IPv4.parse(this.href)) catch return null; + } + + pub fn getIPv6Address(this: *const URL) ?std.x.net.ip.Address.IPv6 { + return (if (this.hostname.length > 0) + std.x.os.IPv6.parse(this.hostname) + else + std.x.os.IPv6.parse(this.href)) catch return null; + } + pub fn displayProtocol(this: *const URL) string { if (this.protocol.len > 0) { return this.protocol; @@ -34,6 +65,10 @@ pub const URL = struct { return "http"; } + pub inline fn isHTTPS(this: *const URL) bool { + return strings.eqlComptime(this.protocol, "https"); + } + pub fn displayHostname(this: *const URL) string { if (this.hostname.len > 0) { return this.hostname; @@ -50,6 +85,10 @@ pub const URL = struct { return std.fmt.parseInt(u16, this.port, 10) catch null; } + pub fn getPortAuto(this: *const URL) u16 { + return this.getPort() orelse (if (this.isHTTPS()) @as(u16, 443) else @as(u16, 80)); + } + pub fn hasValidPort(this: *const URL) bool { return (this.getPort() orelse 0) > 1; } diff --git a/src/string_immutable.zig 
b/src/string_immutable.zig
index 238706f93..a68059aef 100644
--- a/src/string_immutable.zig
+++ b/src/string_immutable.zig
@@ -124,8 +124,8 @@ pub const StringOrTinyString = struct {
pub fn copyLowercase(in: string, out: []u8) string {
@setRuntimeSafety(false);
- var in_slice = in;
- var out_slice = out[0..in.len];
+ var in_slice: string = in;
+ var out_slice: []u8 = out[0..in.len];
begin: while (out_slice.len > 0) {
@setRuntimeSafety(false);
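The sketch below is an editor's illustration and is not part of this commit. It shows how the Response.parseParts / error.ShortRead contract introduced in src/deps/picohttp.zig above is intended to be driven by a caller that keeps reading into a growing buffer. The function name readResponse, the reader parameter, and the buffer handling are illustrative assumptions only.

// Illustrative sketch (not in this diff): retry header parsing until picohttp
// has seen the complete status line and headers.
fn readResponse(reader: anytype, buf: []u8, headers: []picohttp.Header) !picohttp.Response {
    var total: usize = 0;
    // parseParts bumps this on error.ShortRead so the next call can tell
    // picohttpparser how much of the buffer was already scanned.
    var offset: usize = 0;

    while (true) {
        const read_len = try reader.read(buf[total..]);
        if (read_len == 0) return error.ShortRead; // connection closed before headers finished
        total += read_len;

        return picohttp.Response.parseParts(buf[0..total], headers, &offset) catch |err| switch (err) {
            error.ShortRead => continue, // headers incomplete; read more and parse again
            else => return err,
        };
    }
}

In the diff itself, sendHTTP and sendHTTPS handle error.ShortRead with a restart: loop before copying the response body out according to Content-Length.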