From 5fa13625a1ca0ea1a3a1c5bb86d0880dcfac349f Mon Sep 17 00:00:00 2001 From: Dylan Conway <35280289+dylan-conway@users.noreply.github.com> Date: Wed, 21 Jun 2023 23:38:18 -0700 Subject: upgrade zig to `v0.11.0-dev.3737+9eb008717` (#3374) * progress * finish `@memset/@memcpy` update * Update build.zig * change `@enumToInt` to `@intFromEnum` and friends * update zig versions * it was 1 * add link to issue * add `compileError` reminder * fix merge * format * upgrade to llvm 16 * Revert "upgrade to llvm 16" This reverts commit cc930ceb1c5b4db9614a7638596948f704544ab8. --------- Co-authored-by: Jarred Sumner Co-authored-by: Jarred Sumner <709451+Jarred-Sumner@users.noreply.github.com> --- src/bun.js/api/bun.zig | 92 +++++++++++++++++++++++++------------------------- 1 file changed, 46 insertions(+), 46 deletions(-) (limited to 'src/bun.js/api/bun.zig') diff --git a/src/bun.js/api/bun.zig b/src/bun.js/api/bun.zig index 5580e8840..9df125a58 100644 --- a/src/bun.js/api/bun.zig +++ b/src/bun.js/api/bun.zig @@ -303,7 +303,7 @@ pub fn registerMacro( return js.JSValueMakeUndefined(ctx); } // TODO: make this faster - const id = @truncate(i32, @floatToInt(i64, js.JSValueToNumber(ctx, arguments[0], exception))); + const id = @truncate(i32, @intFromFloat(i64, js.JSValueToNumber(ctx, arguments[0], exception))); if (id == -1 or id == 0) { JSError(getAllocator(ctx), "Internal error registering macros: invalid id", .{}, ctx, exception); return js.JSValueMakeUndefined(ctx); @@ -523,7 +523,7 @@ pub fn getFilePath(ctx: js.JSContextRef, arguments: []const js.JSValueRef, buf: temp_strings_list[temp_strings_list_len] = out_slice; // The dots are kind of unnecessary. They'll be normalized. - if (out.len == 0 or @ptrToInt(out.ptr) == 0 or std.mem.eql(u8, out_slice, ".") or std.mem.eql(u8, out_slice, "..") or std.mem.eql(u8, out_slice, "../")) { + if (out.len == 0 or @intFromPtr(out.ptr) == 0 or std.mem.eql(u8, out_slice, ".") or std.mem.eql(u8, out_slice, "..") or std.mem.eql(u8, out_slice, "../")) { JSError(getAllocator(ctx), "Expected a file path as a string or an array of strings to be part of a file path.", .{}, ctx, exception); return null; } @@ -600,7 +600,7 @@ pub fn readFileAsStringCallback( return js.JSValueMakeUndefined(ctx); }; - if (stat.kind != .File) { + if (stat.kind != .file) { JSError(getAllocator(ctx), "Can't read a {s} as a string (\"{s}\")", .{ @tagName(stat.kind), path }, ctx, exception); return js.JSValueMakeUndefined(ctx); } @@ -641,7 +641,7 @@ pub fn readFileAsBytesCallback( return js.JSValueMakeUndefined(ctx); }; - if (stat.kind != .File) { + if (stat.kind != .file) { JSError(allocator, "Can't read a {s} as a string (\"{s}\")", .{ @tagName(stat.kind), path }, ctx, exception); return js.JSValueMakeUndefined(ctx); } @@ -1612,7 +1612,7 @@ pub const Crypto = struct { fn createCryptoError(globalThis: *JSC.JSGlobalObject, err_code: u32) JSValue { var outbuf: [128 + 1 + "BoringSSL error: ".len]u8 = undefined; - @memset(&outbuf, 0, outbuf.len); + @memset(&outbuf, 0); outbuf[0.."BoringSSL error: ".len].* = "BoringSSL error: ".*; var message_buf = outbuf["BoringSSL error: ".len..]; @@ -3171,9 +3171,9 @@ pub fn mmapFile( return JSC.C.JSObjectMakeTypedArrayWithBytesNoCopy(ctx, JSC.C.JSTypedArrayType.kJSTypedArrayTypeUint8Array, @ptrCast(?*anyopaque, map.ptr), map.len, struct { pub fn x(ptr: ?*anyopaque, size: ?*anyopaque) callconv(.C) void { - _ = JSC.Node.Syscall.munmap(@ptrCast([*]align(std.mem.page_size) u8, @alignCast(std.mem.page_size, ptr))[0..@ptrToInt(size)]); + _ = JSC.Node.Syscall.munmap(@ptrCast([*]align(std.mem.page_size) u8, @alignCast(std.mem.page_size, ptr))[0..@intFromPtr(size)]); } - }.x, @intToPtr(?*anyopaque, map.len), exception); + }.x, @ptrFromInt(?*anyopaque, map.len), exception); } pub fn getTranspilerConstructor( @@ -3401,7 +3401,7 @@ pub const Unsafe = struct { globalThis: *JSC.JSGlobalObject, value_: ?JSValue, ) JSValue { - const ret = JSValue.jsNumber(@as(i32, @enumToInt(globalThis.bunVM().aggressive_garbage_collection))); + const ret = JSValue.jsNumber(@as(i32, @intFromEnum(globalThis.bunVM().aggressive_garbage_collection))); if (value_) |value| { switch (value.coerce(i32, globalThis)) { @@ -3912,7 +3912,7 @@ pub const Timer = struct { id, Timeout.run, this.interval, - @as(i32, @boolToInt(this.kind == .setInterval)) * this.interval, + @as(i32, @intFromBool(this.kind == .setInterval)) * this.interval, ); return this_value; } @@ -4130,7 +4130,7 @@ pub const Timer = struct { }, Timeout.run, interval, - @as(i32, @boolToInt(kind == .setInterval)) * interval, + @as(i32, @intFromBool(kind == .setInterval)) * interval, ); } @@ -4318,7 +4318,7 @@ pub const FFI = struct { arguments: []const JSValue, ) JSValue { const addr = arguments[0].asPtrAddress() + if (arguments.len > 1) @intCast(usize, arguments[1].to(i32)) else @as(usize, 0); - const value = @intToPtr(*align(1) u8, addr).*; + const value = @ptrFromInt(*align(1) u8, addr).*; return JSValue.jsNumber(value); } pub fn @"u16"( @@ -4327,7 +4327,7 @@ pub const FFI = struct { arguments: []const JSValue, ) JSValue { const addr = arguments[0].asPtrAddress() + if (arguments.len > 1) @intCast(usize, arguments[1].to(i32)) else @as(usize, 0); - const value = @intToPtr(*align(1) u16, addr).*; + const value = @ptrFromInt(*align(1) u16, addr).*; return JSValue.jsNumber(value); } pub fn @"u32"( @@ -4336,7 +4336,7 @@ pub const FFI = struct { arguments: []const JSValue, ) JSValue { const addr = arguments[0].asPtrAddress() + if (arguments.len > 1) @intCast(usize, arguments[1].to(i32)) else @as(usize, 0); - const value = @intToPtr(*align(1) u32, addr).*; + const value = @ptrFromInt(*align(1) u32, addr).*; return JSValue.jsNumber(value); } pub fn ptr( @@ -4345,7 +4345,7 @@ pub const FFI = struct { arguments: []const JSValue, ) JSValue { const addr = arguments[0].asPtrAddress() + if (arguments.len > 1) @intCast(usize, arguments[1].to(i32)) else @as(usize, 0); - const value = @intToPtr(*align(1) u64, addr).*; + const value = @ptrFromInt(*align(1) u64, addr).*; return JSValue.jsNumber(value); } pub fn @"i8"( @@ -4354,7 +4354,7 @@ pub const FFI = struct { arguments: []const JSValue, ) JSValue { const addr = arguments[0].asPtrAddress() + if (arguments.len > 1) @intCast(usize, arguments[1].to(i32)) else @as(usize, 0); - const value = @intToPtr(*align(1) i8, addr).*; + const value = @ptrFromInt(*align(1) i8, addr).*; return JSValue.jsNumber(value); } pub fn @"i16"( @@ -4363,7 +4363,7 @@ pub const FFI = struct { arguments: []const JSValue, ) JSValue { const addr = arguments[0].asPtrAddress() + if (arguments.len > 1) @intCast(usize, arguments[1].to(i32)) else @as(usize, 0); - const value = @intToPtr(*align(1) i16, addr).*; + const value = @ptrFromInt(*align(1) i16, addr).*; return JSValue.jsNumber(value); } pub fn @"i32"( @@ -4372,7 +4372,7 @@ pub const FFI = struct { arguments: []const JSValue, ) JSValue { const addr = arguments[0].asPtrAddress() + if (arguments.len > 1) @intCast(usize, arguments[1].to(i32)) else @as(usize, 0); - const value = @intToPtr(*align(1) i32, addr).*; + const value = @ptrFromInt(*align(1) i32, addr).*; return JSValue.jsNumber(value); } pub fn intptr( @@ -4381,7 +4381,7 @@ pub const FFI = struct { arguments: []const JSValue, ) JSValue { const addr = arguments[0].asPtrAddress() + if (arguments.len > 1) @intCast(usize, arguments[1].to(i32)) else @as(usize, 0); - const value = @intToPtr(*align(1) i64, addr).*; + const value = @ptrFromInt(*align(1) i64, addr).*; return JSValue.jsNumber(value); } @@ -4391,7 +4391,7 @@ pub const FFI = struct { arguments: []const JSValue, ) JSValue { const addr = arguments[0].asPtrAddress() + if (arguments.len > 1) @intCast(usize, arguments[1].to(i32)) else @as(usize, 0); - const value = @intToPtr(*align(1) f32, addr).*; + const value = @ptrFromInt(*align(1) f32, addr).*; return JSValue.jsNumber(value); } @@ -4401,7 +4401,7 @@ pub const FFI = struct { arguments: []const JSValue, ) JSValue { const addr = arguments[0].asPtrAddress() + if (arguments.len > 1) @intCast(usize, arguments[1].to(i32)) else @as(usize, 0); - const value = @intToPtr(*align(1) f64, addr).*; + const value = @ptrFromInt(*align(1) f64, addr).*; return JSValue.jsNumber(value); } @@ -4411,7 +4411,7 @@ pub const FFI = struct { arguments: []const JSValue, ) JSValue { const addr = arguments[0].asPtrAddress() + if (arguments.len > 1) @intCast(usize, arguments[1].to(i32)) else @as(usize, 0); - const value = @intToPtr(*align(1) i64, addr).*; + const value = @ptrFromInt(*align(1) i64, addr).*; return JSValue.fromInt64NoTruncate(global, value); } @@ -4421,7 +4421,7 @@ pub const FFI = struct { arguments: []const JSValue, ) JSValue { const addr = arguments[0].asPtrAddress() + if (arguments.len > 1) @intCast(usize, arguments[1].to(i32)) else @as(usize, 0); - const value = @intToPtr(*align(1) u64, addr).*; + const value = @ptrFromInt(*align(1) u64, addr).*; return JSValue.fromUInt64NoTruncate(global, value); } @@ -4432,7 +4432,7 @@ pub const FFI = struct { offset: i32, ) callconv(.C) JSValue { const addr = @intCast(usize, raw_addr) + @intCast(usize, offset); - const value = @intToPtr(*align(1) u8, addr).*; + const value = @ptrFromInt(*align(1) u8, addr).*; return JSValue.jsNumber(value); } pub fn u16WithoutTypeChecks( @@ -4442,7 +4442,7 @@ pub const FFI = struct { offset: i32, ) callconv(.C) JSValue { const addr = @intCast(usize, raw_addr) + @intCast(usize, offset); - const value = @intToPtr(*align(1) u16, addr).*; + const value = @ptrFromInt(*align(1) u16, addr).*; return JSValue.jsNumber(value); } pub fn u32WithoutTypeChecks( @@ -4452,7 +4452,7 @@ pub const FFI = struct { offset: i32, ) callconv(.C) JSValue { const addr = @intCast(usize, raw_addr) + @intCast(usize, offset); - const value = @intToPtr(*align(1) u32, addr).*; + const value = @ptrFromInt(*align(1) u32, addr).*; return JSValue.jsNumber(value); } pub fn ptrWithoutTypeChecks( @@ -4462,7 +4462,7 @@ pub const FFI = struct { offset: i32, ) callconv(.C) JSValue { const addr = @intCast(usize, raw_addr) + @intCast(usize, offset); - const value = @intToPtr(*align(1) u64, addr).*; + const value = @ptrFromInt(*align(1) u64, addr).*; return JSValue.jsNumber(value); } pub fn i8WithoutTypeChecks( @@ -4472,7 +4472,7 @@ pub const FFI = struct { offset: i32, ) callconv(.C) JSValue { const addr = @intCast(usize, raw_addr) + @intCast(usize, offset); - const value = @intToPtr(*align(1) i8, addr).*; + const value = @ptrFromInt(*align(1) i8, addr).*; return JSValue.jsNumber(value); } pub fn i16WithoutTypeChecks( @@ -4482,7 +4482,7 @@ pub const FFI = struct { offset: i32, ) callconv(.C) JSValue { const addr = @intCast(usize, raw_addr) + @intCast(usize, offset); - const value = @intToPtr(*align(1) i16, addr).*; + const value = @ptrFromInt(*align(1) i16, addr).*; return JSValue.jsNumber(value); } pub fn i32WithoutTypeChecks( @@ -4492,7 +4492,7 @@ pub const FFI = struct { offset: i32, ) callconv(.C) JSValue { const addr = @intCast(usize, raw_addr) + @intCast(usize, offset); - const value = @intToPtr(*align(1) i32, addr).*; + const value = @ptrFromInt(*align(1) i32, addr).*; return JSValue.jsNumber(value); } pub fn intptrWithoutTypeChecks( @@ -4502,7 +4502,7 @@ pub const FFI = struct { offset: i32, ) callconv(.C) JSValue { const addr = @intCast(usize, raw_addr) + @intCast(usize, offset); - const value = @intToPtr(*align(1) i64, addr).*; + const value = @ptrFromInt(*align(1) i64, addr).*; return JSValue.jsNumber(value); } @@ -4513,7 +4513,7 @@ pub const FFI = struct { offset: i32, ) callconv(.C) JSValue { const addr = @intCast(usize, raw_addr) + @intCast(usize, offset); - const value = @intToPtr(*align(1) f32, addr).*; + const value = @ptrFromInt(*align(1) f32, addr).*; return JSValue.jsNumber(value); } @@ -4524,7 +4524,7 @@ pub const FFI = struct { offset: i32, ) callconv(.C) JSValue { const addr = @intCast(usize, raw_addr) + @intCast(usize, offset); - const value = @intToPtr(*align(1) f64, addr).*; + const value = @ptrFromInt(*align(1) f64, addr).*; return JSValue.jsNumber(value); } @@ -4535,7 +4535,7 @@ pub const FFI = struct { offset: i32, ) callconv(.C) JSValue { const addr = @intCast(usize, raw_addr) + @intCast(usize, offset); - const value = @intToPtr(*align(1) u64, addr).*; + const value = @ptrFromInt(*align(1) u64, addr).*; return JSValue.fromUInt64NoTruncate(global, value); } @@ -4546,7 +4546,7 @@ pub const FFI = struct { offset: i32, ) callconv(.C) JSValue { const addr = @intCast(usize, raw_addr) + @intCast(usize, offset); - const value = @intToPtr(*align(1) i64, addr).*; + const value = @ptrFromInt(*align(1) i64, addr).*; return JSValue.fromInt64NoTruncate(global, value); } @@ -4590,7 +4590,7 @@ pub const FFI = struct { _: *anyopaque, array: *JSC.JSUint8Array, ) callconv(.C) JSValue { - return JSValue.fromPtrAddress(@ptrToInt(array.ptr())); + return JSValue.fromPtrAddress(@intFromPtr(array.ptr())); } fn ptr_( @@ -4610,9 +4610,9 @@ pub const FFI = struct { return JSC.toInvalidArguments("ArrayBufferView must have a length > 0. A pointer to empty memory doesn't work", .{}, globalThis); } - var addr: usize = @ptrToInt(array_buffer.ptr); + var addr: usize = @intFromPtr(array_buffer.ptr); // const Sizes = @import("../bindings/sizes.zig"); - // std.debug.assert(addr == @ptrToInt(value.asEncoded().ptr) + Sizes.Bun_FFI_PointerOffsetToTypedArrayVector); + // std.debug.assert(addr == @intFromPtr(value.asEncoded().ptr) + Sizes.Bun_FFI_PointerOffsetToTypedArrayVector); if (byteOffset) |off| { if (!off.isEmptyOrUndefinedOrNull()) { @@ -4628,7 +4628,7 @@ pub const FFI = struct { addr += @intCast(usize, bytei64); } - if (addr > @ptrToInt(array_buffer.ptr) + @as(usize, array_buffer.byte_len)) { + if (addr > @intFromPtr(array_buffer.ptr) + @as(usize, array_buffer.byte_len)) { return JSC.toInvalidArguments("byteOffset out of bounds", .{}, globalThis); } } @@ -4720,11 +4720,11 @@ pub const FFI = struct { } const length = @intCast(usize, length_i); - return .{ .slice = @intToPtr([*]u8, addr)[0..length] }; + return .{ .slice = @ptrFromInt([*]u8, addr)[0..length] }; } } - return .{ .slice = bun.span(@intToPtr([*:0]u8, addr)) }; + return .{ .slice = bun.span(@ptrFromInt([*:0]u8, addr)) }; } fn getCPtr(value: JSValue) ?usize { @@ -4759,11 +4759,11 @@ pub const FFI = struct { var ctx: ?*anyopaque = null; if (finalizationCallback) |callback_value| { if (getCPtr(callback_value)) |callback_ptr| { - callback = @intToPtr(JSC.C.JSTypedArrayBytesDeallocator, callback_ptr); + callback = @ptrFromInt(JSC.C.JSTypedArrayBytesDeallocator, callback_ptr); if (finalizationCtxOrPtr) |ctx_value| { if (getCPtr(ctx_value)) |ctx_ptr| { - ctx = @intToPtr(*anyopaque, ctx_ptr); + ctx = @ptrFromInt(*anyopaque, ctx_ptr); } else if (!ctx_value.isUndefinedOrNull()) { return JSC.toInvalidArguments("Expected user data to be a C pointer (number or BigInt)", .{}, globalThis); } @@ -4773,7 +4773,7 @@ pub const FFI = struct { } } else if (finalizationCtxOrPtr) |callback_value| { if (getCPtr(callback_value)) |callback_ptr| { - callback = @intToPtr(JSC.C.JSTypedArrayBytesDeallocator, callback_ptr); + callback = @ptrFromInt(JSC.C.JSTypedArrayBytesDeallocator, callback_ptr); } else if (!callback_value.isEmptyOrUndefinedOrNull()) { return JSC.toInvalidArguments("Expected callback to be a C pointer (number or BigInt)", .{}, globalThis); } @@ -4801,11 +4801,11 @@ pub const FFI = struct { var ctx: ?*anyopaque = null; if (finalizationCallback) |callback_value| { if (getCPtr(callback_value)) |callback_ptr| { - callback = @intToPtr(JSC.C.JSTypedArrayBytesDeallocator, callback_ptr); + callback = @ptrFromInt(JSC.C.JSTypedArrayBytesDeallocator, callback_ptr); if (finalizationCtxOrPtr) |ctx_value| { if (getCPtr(ctx_value)) |ctx_ptr| { - ctx = @intToPtr(*anyopaque, ctx_ptr); + ctx = @ptrFromInt(*anyopaque, ctx_ptr); } else if (!ctx_value.isEmptyOrUndefinedOrNull()) { return JSC.toInvalidArguments("Expected user data to be a C pointer (number or BigInt)", .{}, globalThis); } @@ -4815,7 +4815,7 @@ pub const FFI = struct { } } else if (finalizationCtxOrPtr) |callback_value| { if (getCPtr(callback_value)) |callback_ptr| { - callback = @intToPtr(JSC.C.JSTypedArrayBytesDeallocator, callback_ptr); + callback = @ptrFromInt(JSC.C.JSTypedArrayBytesDeallocator, callback_ptr); } else if (!callback_value.isEmptyOrUndefinedOrNull()) { return JSC.toInvalidArguments("Expected callback to be a C pointer (number or BigInt)", .{}, globalThis); } -- cgit v1.2.3 From ca1fe3c602a19878e8cd3545494d6b5af7ed13c9 Mon Sep 17 00:00:00 2001 From: "Alex Lam S.L" Date: Fri, 23 Jun 2023 03:05:54 +0300 Subject: revamp dotEnv parser (#3347) - fixes `strings.indexOfAny()` - fixes OOB array access fixes #411 fixes #2823 fixes #3042 --- src/bun.js/api/bun.zig | 8 +- src/env_loader.zig | 667 ++++++++++++++++-------------------------- src/string_immutable.zig | 38 ++- test/cli/install/bunx.test.ts | 16 +- test/cli/run/env.test.ts | 82 +++++- 5 files changed, 371 insertions(+), 440 deletions(-) (limited to 'src/bun.js/api/bun.zig') diff --git a/src/bun.js/api/bun.zig b/src/bun.js/api/bun.zig index 9df125a58..034aaa81f 100644 --- a/src/bun.js/api/bun.zig +++ b/src/bun.js/api/bun.zig @@ -4935,11 +4935,11 @@ pub const EnvironmentVariables = struct { pub fn getEnvNames(globalObject: *JSC.JSGlobalObject, names: []ZigString) usize { var vm = globalObject.bunVM(); const keys = vm.bundler.env.map.map.keys(); - const max = @min(names.len, keys.len); - for (keys[0..max], 0..) |key, i| { - names[i] = ZigString.initUTF8(key); + const len = @min(names.len, keys.len); + for (keys[0..len], names[0..len]) |key, *name| { + name.* = ZigString.initUTF8(key); } - return keys.len; + return len; } pub fn getEnvValue(globalObject: *JSC.JSGlobalObject, name: ZigString) ?ZigString { var vm = globalObject.bunVM(); diff --git a/src/env_loader.zig b/src/env_loader.zig index 14e1196b6..74577e3f2 100644 --- a/src/env_loader.zig +++ b/src/env_loader.zig @@ -17,384 +17,6 @@ const Fs = @import("./fs.zig"); const URL = @import("./url.zig").URL; const Api = @import("./api/schema.zig").Api; const which = @import("./which.zig").which; -const Variable = struct { - key: string, - value: string, - has_nested_value: bool = false, -}; - -// i don't expect anyone to actually use the escape line feed character -const escLineFeed = 0x0C; -// arbitrary character that is invalid in a real text file -const implicitQuoteCharacter = 8; - -// you get 4k. I hope you don't need more than that. -threadlocal var temporary_nested_value_buffer: [4096]u8 = undefined; - -pub const Lexer = struct { - source: *const logger.Source, - iter: CodepointIterator, - cursor: CodepointIterator.Cursor = CodepointIterator.Cursor{}, - _codepoint: CodePoint = 0, - current: usize = 0, - last_non_space: usize = 0, - prev_non_space: usize = 0, - start: usize = 0, - end: usize = 0, - has_nested_value: bool = false, - has_newline_before: bool = true, - was_quoted: bool = false, - - pub inline fn codepoint(this: *const Lexer) CodePoint { - return this.cursor.c; - } - - pub inline fn step(this: *Lexer) void { - const ended = !this.iter.next(&this.cursor); - if (ended) this.cursor.c = -1; - this.current = this.cursor.i + @as(usize, @intFromBool(ended)); - } - - pub fn eatNestedValue( - _: *Lexer, - comptime ContextType: type, - ctx: *ContextType, - comptime Writer: type, - writer: Writer, - variable: Variable, - comptime getter: fn (ctx: *const ContextType, key: string) ?string, - ) !void { - var i: usize = 0; - var last_flush: usize = 0; - - top: while (i < variable.value.len) { - switch (variable.value[i]) { - '$' => { - i += 1; - const start = i; - - const curly_braces_offset = @as(usize, @intFromBool(variable.value[i] == '{')); - i += curly_braces_offset; - - while (i < variable.value.len) { - switch (variable.value[i]) { - 'a'...'z', 'A'...'Z', '0'...'9', '-', '_' => { - i += 1; - }, - '}' => { - i += curly_braces_offset; - break; - }, - else => { - break; - }, - } - } - - try writer.writeAll(variable.value[last_flush .. start - 1]); - last_flush = i; - const name = variable.value[start + curly_braces_offset .. i - curly_braces_offset]; - - if (@call(.always_inline, getter, .{ ctx, name })) |new_value| { - if (new_value.len > 0) { - try writer.writeAll(new_value); - } - } - - continue :top; - }, - '\\' => { - i += 1; - switch (variable.value[i]) { - '$' => { - i += 1; - continue; - }, - else => {}, - } - }, - else => {}, - } - i += 1; - } - - try writer.writeAll(variable.value[last_flush..]); - } - - pub fn eatValue( - lexer: *Lexer, - comptime quote: CodePoint, - ) string { - var was_quoted = false; - switch (comptime quote) { - '"', '\'' => { - lexer.step(); - was_quoted = true; - }, - - else => {}, - } - - var start = lexer.current; - var last_non_space: usize = start; - var any_spaces = false; - - while (true) { - switch (lexer.codepoint()) { - '\\' => { - lexer.step(); - // Handle Windows CRLF - - switch (lexer.codepoint()) { - '\r' => { - lexer.step(); - if (lexer.codepoint() == '\n') { - lexer.step(); - } - continue; - }, - '$' => { - lexer.step(); - continue; - }, - else => { - continue; - }, - } - }, - -1 => { - lexer.end = lexer.current; - - return lexer.source.contents[start..if (any_spaces) @min(last_non_space, lexer.source.contents.len) else lexer.source.contents.len]; - }, - '$' => { - lexer.has_nested_value = true; - }, - - '#' => { - lexer.step(); - lexer.eatComment(); - - return lexer.source.contents[start .. last_non_space + 1]; - }, - - '\n', '\r', escLineFeed => { - switch (comptime quote) { - '\'' => { - lexer.end = lexer.current; - lexer.step(); - return lexer.source.contents[start..@min(lexer.end, lexer.source.contents.len)]; - }, - implicitQuoteCharacter => { - lexer.end = lexer.current; - lexer.step(); - - return lexer.source.contents[start..@min(if (any_spaces) last_non_space + 1 else lexer.end, lexer.end)]; - }, - '"' => { - // We keep going - }, - else => {}, - } - }, - quote => { - lexer.end = lexer.current; - lexer.step(); - - lexer.was_quoted = was_quoted; - return lexer.source.contents[start..@min( - lexer.end, - lexer.source.contents.len, - )]; - }, - ' ' => { - any_spaces = true; - while (lexer.codepoint() == ' ') lexer.step(); - continue; - }, - else => {}, - } - if (lexer.codepoint() != ' ') last_non_space = lexer.current; - lexer.step(); - } - unreachable; - } - - pub fn eatComment(this: *Lexer) void { - while (true) { - switch (this.codepoint()) { - '\r' => { - this.step(); - if (this.codepoint() == '\n') { - return; - } - }, - '\n' => { - this.step(); - return; - }, - -1 => { - return; - }, - else => { - this.step(); - }, - } - } - } - - // const NEWLINE = '\n' - // const RE_INI_KEY_VAL = /^\s*([\w.-]+)\s*=\s*(.*)?\s*$/ - // const RE_NEWLINES = /\\n/g - // const NEWLINES_MATCH = /\r\n|\n|\r/ - pub fn next(this: *Lexer, comptime is_process_env: bool) ?Variable { - if (this.end == 0) this.step(); - - const start = this.start; - - this.has_newline_before = this.end == 0; - - var last_non_space = start; - restart: while (true) { - last_non_space = switch (this.codepoint()) { - ' ', '\r', '\n' => last_non_space, - else => this.current, - }; - - switch (this.codepoint()) { - 0, -1 => { - return null; - }, - '#' => { - this.step(); - - this.eatComment(); - continue :restart; - }, - '\r', '\n', 0x2028, 0x2029 => { - this.step(); - this.has_newline_before = true; - continue; - }, - - // Valid keys: - 'a'...'z', 'A'...'Z', '0'...'9', '_', '-', '.' => { - this.start = this.current; - this.step(); - var key_end: usize = 0; - while (true) { - switch (this.codepoint()) { - - // to match npm's "dotenv" behavior, we ignore lines that don't have an equals - '\r', '\n', escLineFeed => { - this.end = this.current; - this.step(); - continue :restart; - }, - 0, -1 => { - this.end = this.current; - return if (last_non_space > this.start) - Variable{ .key = this.source.contents[this.start..@min(last_non_space + 1, this.source.contents.len)], .value = "" } - else - null; - }, - 'a'...'z', 'A'...'Z', '0'...'9', '_', '-', '.' => {}, - '=' => { - this.end = this.current; - if (key_end > 0) { - this.end = key_end; - } - const key = this.source.contents[this.start..this.end]; - if (key.len == 0) return null; - this.step(); - - // we don't need to do special parsing on process-level environment variable values - // if they're quoted, we should keep them quoted. - // https://github.com/oven-sh/bun/issues/40 - if (comptime is_process_env) { - const current = this.current; - // TODO: remove this loop - // it's not as simple as just setting to the end of the string - while (this.codepoint() != -1) : (this.step()) {} - return Variable{ - .key = key, - .value = this.source.contents[current..], - // nested values are unsupported in process environment variables - .has_nested_value = false, - }; - } - - this.has_nested_value = false; - inner: while (true) { - switch (this.codepoint()) { - '"' => { - const value = this.eatValue('"'); - return Variable{ - .key = key, - .value = value, - .has_nested_value = this.has_nested_value, - }; - }, - '\'' => { - const value = this.eatValue('\''); - return Variable{ - .key = key, - .value = value, - .has_nested_value = this.has_nested_value, - }; - }, - 0, -1 => { - return Variable{ .key = key, .value = "" }; - }, - '\r', '\n', escLineFeed => { - this.step(); - return Variable{ .key = key, .value = "" }; - }, - // consume unquoted leading spaces - ' ' => { - this.step(); - while (this.codepoint() == ' ') this.step(); - continue :inner; - }, - // we treat everything else the same as if it were wrapped in single quotes - // except we don't terminate on that character - else => { - const value = this.eatValue(implicitQuoteCharacter); - return Variable{ - .key = key, - .value = value, - .has_nested_value = this.has_nested_value, - }; - }, - } - } - }, - ' ' => { - // Set key end to the last non space character - key_end = this.current; - this.step(); - while (this.codepoint() == ' ') this.step(); - continue; - }, - else => {}, - } - this.step(); - } - }, - else => {}, - } - - this.step(); - } - } - - pub fn init(source: *const logger.Source) Lexer { - return Lexer{ - .source = source, - .iter = CodepointIterator.init(source.contents), - }; - } -}; pub const Loader = struct { map: *Map, @@ -912,58 +534,263 @@ pub const Loader = struct { } }; -pub const Parser = struct { - pub fn parse( - source: *const logger.Source, - allocator: std.mem.Allocator, - map: *Map, - comptime override: bool, - comptime is_process: bool, - ) void { - var lexer = Lexer.init(source); - var fbs = std.io.fixedBufferStream(&temporary_nested_value_buffer); - var writer = fbs.writer(); - const start_count = map.map.count(); - - while (lexer.next(is_process)) |variable| { - if (variable.has_nested_value) { - writer.context.reset(); - - lexer.eatNestedValue(Map, map, @TypeOf(writer), writer, variable, Map.get_) catch unreachable; - const new_value = fbs.buffer[0..fbs.pos]; - if (new_value.len > 0) { - if (comptime override) { - map.put(variable.key, allocator.dupe(u8, new_value) catch unreachable) catch unreachable; - } else { - var putter = map.map.getOrPut(variable.key) catch unreachable; - // Allow keys defined later in the same file to override keys defined earlier - // https://github.com/oven-sh/bun/issues/1262 - if (!putter.found_existing or putter.index >= start_count) { - if (putter.found_existing and putter.value_ptr.len > 0) { - allocator.free(putter.value_ptr.*); - } +const Parser = struct { + pos: usize = 0, + src: string, + + const whitespace_chars = "\t\x0B\x0C \xA0\n\r"; + // You get 4k. I hope you don't need more than that. + threadlocal var value_buffer: [4096]u8 = undefined; + + fn skipLine(this: *Parser) void { + if (strings.indexOfAny(this.src[this.pos..], "\n\r")) |i| { + this.pos += i + 1; + } else { + this.pos = this.src.len; + } + } - putter.value_ptr.* = allocator.dupe(u8, new_value) catch unreachable; + fn skipWhitespaces(this: *Parser) void { + var i = this.pos; + while (i < this.src.len) : (i += 1) { + if (strings.indexOfChar(whitespace_chars, this.src[i]) == null) break; + } + this.pos = i; + } + + fn parseKey(this: *Parser, comptime check_export: bool) ?string { + if (comptime check_export) this.skipWhitespaces(); + const start = this.pos; + var end = start; + while (end < this.src.len) : (end += 1) { + switch (this.src[end]) { + 'a'...'z', 'A'...'Z', '0'...'9', '_', '-', '.' => continue, + else => break, + } + } + if (end < this.src.len and start < end) { + this.pos = end; + this.skipWhitespaces(); + if (this.pos < this.src.len) { + if (comptime check_export) { + if (end < this.pos and strings.eqlComptime(this.src[start..end], "export")) { + if (this.parseKey(false)) |key| return key; + } + } + switch (this.src[this.pos]) { + '=' => { + this.pos += 1; + return this.src[start..end]; + }, + ':' => { + const next = this.pos + 1; + if (next < this.src.len and strings.indexOfChar(whitespace_chars, this.src[next]) != null) { + this.pos += 2; + return this.src[start..end]; } + }, + else => {}, + } + } + } + this.pos = start; + return null; + } + + fn parseQuoted(this: *Parser, comptime quote: u8) ?string { + if (comptime Environment.allow_assert) std.debug.assert(this.src[this.pos] == quote); + const start = this.pos; + var end = start + 1; + while (end < this.src.len) : (end += 1) { + switch (this.src[end]) { + '\\' => end += 1, + quote => { + end += 1; + this.pos = end; + this.skipWhitespaces(); + if (this.pos >= this.src.len or + this.src[this.pos] == '#' or + strings.indexOfChar(this.src[end..this.pos], '\n') != null or + strings.indexOfChar(this.src[end..this.pos], '\r') != null) + { + var ptr: usize = 0; + var i = start; + while (i < end) { + switch (this.src[i]) { + '\\' => if (comptime quote == '"') { + if (comptime Environment.allow_assert) std.debug.assert(i + 1 < end); + switch (this.src[i + 1]) { + 'n' => { + value_buffer[ptr] = '\n'; + ptr += 1; + i += 1; + }, + 'r' => { + value_buffer[ptr] = '\r'; + ptr += 1; + i += 1; + }, + else => { + value_buffer[ptr] = this.src[i]; + value_buffer[ptr + 1] = this.src[i + 1]; + ptr += 2; + i += 2; + }, + } + } else { + value_buffer[ptr] = '\\'; + ptr += 1; + i += 1; + }, + '\r' => { + i += 1; + if (i >= end or this.src[i] != '\n') { + value_buffer[ptr] = '\n'; + ptr += 1; + } + }, + else => |c| { + value_buffer[ptr] = c; + ptr += 1; + i += 1; + }, + } + } + return value_buffer[0..ptr]; } + this.pos = start; + }, + else => {}, + } + } + return null; + } + + fn parseValue(this: *Parser, comptime is_process: bool) string { + const start = this.pos; + this.skipWhitespaces(); + var end = this.pos; + if (end >= this.src.len) return this.src[this.src.len..]; + switch (this.src[end]) { + inline '`', '"', '\'' => |quote| { + if (this.parseQuoted(quote)) |value| { + return if (comptime is_process) value else value[1 .. value.len - 1]; } - } else { - if (comptime override) { - map.put(variable.key, variable.value) catch unreachable; + }, + else => {}, + } + end = start; + while (end < this.src.len) : (end += 1) { + switch (this.src[end]) { + '#', '\r', '\n' => break, + else => {}, + } + } + this.pos = end; + return strings.trim(this.src[start..end], whitespace_chars); + } + + inline fn writeBackwards(ptr: usize, bytes: []const u8) usize { + const end = ptr; + const start = end - bytes.len; + bun.copy(u8, value_buffer[start..end], bytes); + return start; + } + + fn expandValue(map: *Map, value: string) ?string { + if (value.len < 2) return null; + var ptr = value_buffer.len; + var pos = value.len - 2; + var last = value.len; + while (true) : (pos -= 1) { + if (value[pos] == '$') { + if (pos > 0 and value[pos - 1] == '\\') { + ptr = writeBackwards(ptr, value[pos..last]); + pos -= 1; } else { - // Allow keys defined later in the same file to override keys defined earlier - // https://github.com/oven-sh/bun/issues/1262 - var putter = map.map.getOrPut(variable.key) catch unreachable; - if (!putter.found_existing or putter.index >= start_count) { - if (putter.found_existing and putter.value_ptr.len > 0) { - allocator.free(putter.value_ptr.*); + var end = if (value[pos + 1] == '{') pos + 2 else pos + 1; + const key_start = end; + while (end < value.len) : (end += 1) { + switch (value[end]) { + 'a'...'z', 'A'...'Z', '0'...'9', '_' => continue, + else => break, } - - putter.value_ptr.* = allocator.dupe(u8, variable.value) catch unreachable; } + const lookup_value = map.get(value[key_start..end]); + const default_value = if (strings.hasPrefixComptime(value[end..], ":-")) brk: { + end += ":-".len; + const value_start = end; + while (end < value.len) : (end += 1) { + switch (value[end]) { + '}', '\\' => break, + else => continue, + } + } + break :brk value[value_start..end]; + } else ""; + if (end < value.len and value[end] == '}') end += 1; + ptr = writeBackwards(ptr, value[end..last]); + ptr = writeBackwards(ptr, lookup_value orelse default_value); + } + last = pos; + } + if (pos == 0) { + if (last == value.len) return null; + break; + } + } + if (last > 0) ptr = writeBackwards(ptr, value[0..last]); + return value_buffer[ptr..]; + } + + fn _parse( + this: *Parser, + allocator: std.mem.Allocator, + map: *Map, + comptime override: bool, + comptime is_process: bool, + ) void { + var count = map.map.count(); + while (this.pos < this.src.len) { + const key = this.parseKey(true) orelse { + this.skipLine(); + continue; + }; + const value = this.parseValue(is_process); + var entry = map.map.getOrPut(key) catch unreachable; + if (entry.found_existing) { + if (entry.index < count) { + // Allow keys defined later in the same file to override keys defined earlier + // https://github.com/oven-sh/bun/issues/1262 + if (comptime !override) continue; + } else { + allocator.free(entry.value_ptr.*); } } + entry.value_ptr.* = allocator.dupe(u8, value) catch unreachable; } + if (comptime !is_process) { + var it = map.iter(); + while (it.next()) |entry| { + if (count > 0) { + count -= 1; + } else if (expandValue(map, entry.value_ptr.*)) |value| { + allocator.free(entry.value_ptr.*); + entry.value_ptr.* = allocator.dupe(u8, value) catch unreachable; + } + } + } + } + + pub fn parse( + source: *const logger.Source, + allocator: std.mem.Allocator, + map: *Map, + comptime override: bool, + comptime is_process: bool, + ) void { + var parser = Parser{ .src = source.contents }; + parser._parse(allocator, map, override, is_process); } }; diff --git a/src/string_immutable.zig b/src/string_immutable.zig index 0c90b03ff..0e98f3f5f 100644 --- a/src/string_immutable.zig +++ b/src/string_immutable.zig @@ -39,10 +39,40 @@ pub fn toUTF16Literal(comptime str: []const u8) []const u16 { } pub const OptionalUsize = std.meta.Int(.unsigned, @bitSizeOf(usize) - 1); -pub fn indexOfAny(self: string, comptime str: anytype) ?OptionalUsize { - inline for (str) |a| { - if (indexOfChar(self, a)) |i| { - return @intCast(OptionalUsize, i); +pub fn indexOfAny(slice: string, comptime str: anytype) ?OptionalUsize { + switch (comptime str.len) { + 0 => @compileError("str cannot be empty"), + 1 => return indexOfChar(slice, str[0]), + else => {}, + } + + var remaining = slice; + if (remaining.len == 0) return null; + + if (comptime Environment.enableSIMD) { + while (remaining.len >= ascii_vector_size) { + const vec: AsciiVector = remaining[0..ascii_vector_size].*; + var cmp = @bitCast(AsciiVectorU1, vec == @splat(ascii_vector_size, @as(u8, str[0]))); + inline for (str[1..]) |c| { + cmp |= @bitCast(AsciiVectorU1, vec == @splat(ascii_vector_size, @as(u8, c))); + } + + if (@reduce(.Max, cmp) > 0) { + const bitmask = @bitCast(AsciiVectorInt, cmp); + const first = @ctz(bitmask); + + return @intCast(OptionalUsize, first + slice.len - remaining.len); + } + + remaining = remaining[ascii_vector_size..]; + } + + if (comptime Environment.allow_assert) std.debug.assert(remaining.len < ascii_vector_size); + } + + for (remaining, 0..) |c, i| { + if (strings.indexOfChar(str, c) != null) { + return @intCast(OptionalUsize, i + slice.len - remaining.len); } } diff --git a/test/cli/install/bunx.test.ts b/test/cli/install/bunx.test.ts index 87ad2f8b4..3605f5b6b 100644 --- a/test/cli/install/bunx.test.ts +++ b/test/cli/install/bunx.test.ts @@ -1,7 +1,6 @@ -import { spawn } from "bun"; +import { file, spawn } from "bun"; import { afterEach, beforeEach, expect, it } from "bun:test"; import { bunExe, bunEnv as env } from "harness"; -import { realpathSync } from "fs"; import { mkdtemp, realpath, rm, writeFile } from "fs/promises"; import { tmpdir } from "os"; import { join } from "path"; @@ -10,7 +9,7 @@ import { readdirSorted } from "./dummy.registry"; let x_dir: string; beforeEach(async () => { - x_dir = realpathSync(await mkdtemp(join(tmpdir(), "bun-x.test"))); + x_dir = await realpath(await mkdtemp(join(tmpdir(), "bun-x.test"))); }); afterEach(async () => { await rm(x_dir, { force: true, recursive: true }); @@ -167,6 +166,16 @@ for (const entry of await decompress(Buffer.from(buffer))) { expect(stderr).toBeDefined(); const err = await new Response(stderr).text(); expect(err).toBe(""); + expect(await readdirSorted(x_dir)).toEqual([".cache", "test.js"]); + expect(await readdirSorted(join(x_dir, ".cache"))).toContain("decompress"); + expect(await readdirSorted(join(x_dir, ".cache", "decompress"))).toEqual(["4.2.1"]); + expect(await readdirSorted(join(x_dir, ".cache", "decompress", "4.2.1"))).toEqual([ + "index.js", + "license", + "package.json", + "readme.md", + ]); + expect(await file(join(x_dir, ".cache", "decompress", "4.2.1", "index.js")).text()).toContain("\nmodule.exports = "); expect(stdout).toBeDefined(); const out = await new Response(stdout).text(); expect(out.split(/\r?\n/)).toEqual([ @@ -176,7 +185,6 @@ for (const entry of await decompress(Buffer.from(buffer))) { "", ]); expect(await exited).toBe(0); - expect(await readdirSorted(x_dir)).toEqual([".cache", "test.js"]); }); it("should execute from current working directory", async () => { diff --git a/test/cli/run/env.test.ts b/test/cli/run/env.test.ts index ddc316f05..0cab610a5 100644 --- a/test/cli/run/env.test.ts +++ b/test/cli/run/env.test.ts @@ -194,6 +194,51 @@ describe("dotenv priority", () => { }); }); +test(".env colon assign", () => { + const dir = tempDirWithFiles("dotenv-colon", { + ".env": "FOO: foo", + "index.ts": "console.log(process.env.FOO);", + }); + const { stdout } = bunRun(`${dir}/index.ts`); + expect(stdout).toBe("foo"); +}); + +test(".env export assign", () => { + const dir = tempDirWithFiles("dotenv-export", { + ".env": "export FOO = foo\nexport = bar", + "index.ts": "console.log(process.env.FOO, process.env.export);", + }); + const { stdout } = bunRun(`${dir}/index.ts`); + expect(stdout).toBe("foo bar"); +}); + +test(".env value expansion", () => { + const dir = tempDirWithFiles("dotenv-expand", { + ".env": "FOO=foo\nBAR=$FOO bar\nMOO=${FOO} ${BAR:-fail} ${MOZ:-moo}", + "index.ts": "console.log([process.env.FOO, process.env.BAR, process.env.MOO].join('|'));", + }); + const { stdout } = bunRun(`${dir}/index.ts`); + expect(stdout).toBe("foo|foo bar|foo foo bar moo"); +}); + +test(".env comments", () => { + const dir = tempDirWithFiles("dotenv-comments", { + ".env": "#FOZ\nFOO = foo#FAIL\nBAR='bar' #BAZ", + "index.ts": "console.log(process.env.FOO, process.env.BAR);", + }); + const { stdout } = bunRun(`${dir}/index.ts`); + expect(stdout).toBe("foo bar"); +}); + +test(".env escaped dollar sign", () => { + const dir = tempDirWithFiles("dotenv-dollar", { + ".env": "FOO=foo\nBAR=\\$FOO", + "index.ts": "console.log(process.env.FOO, process.env.BAR);", + }); + const { stdout } = bunRun(`${dir}/index.ts`); + expect(stdout).toBe("foo $FOO"); +}); + test(".env doesnt crash with 159 bytes", () => { const dir = tempDirWithFiles("dotenv-159", { ".env": @@ -217,28 +262,49 @@ test(".env doesnt crash with 159 bytes", () => { ); }); -test.todo(".env space edgecase (issue #411)", () => { +test(".env with >768 entries", () => { + const dir = tempDirWithFiles("dotenv-many-entries", { + ".env": new Array(2000) + .fill(null) + .map((_, i) => `TEST_VAR${i}=TEST_VAL${i}`) + .join("\n"), + "index.ts": "console.log(process.env.TEST_VAR47);", + }); + const { stdout } = bunRun(`${dir}/index.ts`); + expect(stdout).toBe("TEST_VAL47"); +}); + +test(".env space edgecase (issue #411)", () => { const dir = tempDirWithFiles("dotenv-issue-411", { ".env": "VARNAME=A B", - "index.ts": "console.log('[' + process.env.VARNAME + ']'); ", + "index.ts": "console.log('[' + process.env.VARNAME + ']');", }); const { stdout } = bunRun(`${dir}/index.ts`); expect(stdout).toBe("[A B]"); }); -test.todo(".env special characters 1 (issue #2823)", () => { - const dir = tempDirWithFiles("dotenv-issue-411", { - ".env": 'A="a$t"\n', - "index.ts": "console.log('[' + process.env.A + ']'); ", +test(".env special characters 1 (issue #2823)", () => { + const dir = tempDirWithFiles("dotenv-issue-2823", { + ".env": 'A="a$t"\nC=`c\\$v`', + "index.ts": "console.log('[' + process.env.A + ']', '[' + process.env.C + ']');", }); const { stdout } = bunRun(`${dir}/index.ts`); - expect(stdout).toBe("[a$t]"); + expect(stdout).toBe("[a] [c$v]"); }); test.todo("env escaped quote (issue #2484)", () => { - const dir = tempDirWithFiles("dotenv-issue-411", { + const dir = tempDirWithFiles("env-issue-2484", { "index.ts": "console.log(process.env.VALUE, process.env.VALUE2);", }); const { stdout } = bunRun(`${dir}/index.ts`, { VALUE: `\\"`, VALUE2: `\\\\"` }); expect(stdout).toBe('\\" \\\\"'); }); + +test(".env Windows-style newline (issue #3042)", () => { + const dir = tempDirWithFiles("dotenv-issue-3042", { + ".env": "FOO=\rBAR='bar\r\rbaz'\r\nMOO=moo\r", + "index.ts": "console.log([process.env.FOO, process.env.BAR, process.env.MOO].join('|'));", + }); + const { stdout } = bunRun(`${dir}/index.ts`); + expect(stdout).toBe("|bar\n\nbaz|moo"); +}); -- cgit v1.2.3 From fdfbb18531828fc5dec329d5d9e5c828a3c83921 Mon Sep 17 00:00:00 2001 From: Jarred Sumner Date: Sun, 25 Jun 2023 16:32:27 -0700 Subject: Support reading embedded files in compiled executables (#3405) * Support reading embedded files in compiled executables * :nail_care: --------- Co-authored-by: Jarred Sumner <709451+Jarred-Sumner@users.noreply.github.com> --- src/bun.js/api/bun.zig | 3 +++ src/bun.js/api/server.zig | 22 +++++++++------------- src/bun.js/javascript.zig | 5 ++++- src/bun.js/node/node_fs.zig | 30 ++++++++++++++++++++++++++++++ src/bun.js/webcore/blob.zig | 36 ++++++++++++++++++++++++++++++++++-- src/cli/build_command.zig | 1 + src/standalone_bun.zig | 44 +++++++++++++++++++++++++++++++++++++++++++- 7 files changed, 124 insertions(+), 17 deletions(-) (limited to 'src/bun.js/api/bun.zig') diff --git a/src/bun.js/api/bun.zig b/src/bun.js/api/bun.zig index 034aaa81f..2e6381c74 100644 --- a/src/bun.js/api/bun.zig +++ b/src/bun.js/api/bun.zig @@ -896,6 +896,9 @@ pub fn createNodeFS( ) js.JSValueRef { var module = ctx.allocator().create(JSC.Node.NodeJSFS) catch unreachable; module.* = .{}; + var vm = ctx.bunVM(); + if (vm.standalone_module_graph != null) + module.node_fs.vm = vm; return module.toJS(ctx).asObjectRef(); } diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig index a56ff971f..ebfacdcc9 100644 --- a/src/bun.js/api/server.zig +++ b/src/bun.js/api/server.zig @@ -2744,19 +2744,15 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp // 1. Bun.file("foo") // 2. The content-disposition header is not present if (!has_content_disposition and content_type.category.autosetFilename()) { - if (this.blob.store()) |store| { - if (store.data == .file) { - if (store.data.file.pathlike == .path) { - const basename = std.fs.path.basename(store.data.file.pathlike.path.slice()); - if (basename.len > 0) { - var filename_buf: [1024]u8 = undefined; - - resp.writeHeader( - "content-disposition", - std.fmt.bufPrint(&filename_buf, "filename=\"{s}\"", .{basename[0..@min(basename.len, 1024 - 32)]}) catch "", - ); - } - } + if (this.blob.getFileName()) |filename| { + const basename = std.fs.path.basename(filename); + if (basename.len > 0) { + var filename_buf: [1024]u8 = undefined; + + resp.writeHeader( + "content-disposition", + std.fmt.bufPrint(&filename_buf, "filename=\"{s}\"", .{basename[0..@min(basename.len, 1024 - 32)]}) catch "", + ); } } } diff --git a/src/bun.js/javascript.zig b/src/bun.js/javascript.zig index 3baa25e22..cb1a50f1d 100644 --- a/src/bun.js/javascript.zig +++ b/src/bun.js/javascript.zig @@ -593,7 +593,10 @@ pub const VirtualMachine = struct { pub inline fn nodeFS(this: *VirtualMachine) *Node.NodeFS { return this.node_fs orelse brk: { this.node_fs = bun.default_allocator.create(Node.NodeFS) catch unreachable; - this.node_fs.?.* = Node.NodeFS{}; + this.node_fs.?.* = Node.NodeFS{ + // only used when standalone module graph is enabled + .vm = if (this.standalone_module_graph != null) this else null, + }; break :brk this.node_fs.?; }; } diff --git a/src/bun.js/node/node_fs.zig b/src/bun.js/node/node_fs.zig index 21a65251a..35c616a89 100644 --- a/src/bun.js/node/node_fs.zig +++ b/src/bun.js/node/node_fs.zig @@ -2492,6 +2492,7 @@ pub const NodeFS = struct { /// That means a stack-allocated buffer won't suffice. Instead, we re-use /// the heap allocated buffer on the NodefS struct sync_error_buf: [bun.MAX_PATH_BYTES]u8 = undefined, + vm: ?*JSC.VirtualMachine = null, pub const ReturnType = Return; @@ -3442,6 +3443,35 @@ pub const NodeFS = struct { const fd = switch (args.path) { .path => brk: { path = args.path.path.sliceZ(&this.sync_error_buf); + if (this.vm) |vm| { + if (vm.standalone_module_graph) |graph| { + if (graph.find(path)) |file| { + if (args.encoding == .buffer) { + return .{ + .result = .{ + .buffer = Buffer.fromBytes( + bun.default_allocator.dupe(u8, file.contents) catch @panic("out of memory"), + bun.default_allocator, + .Uint8Array, + ), + }, + }; + } else if (comptime string_type == .default) + .{ + .result = .{ + .string = bun.default_allocator.dupe(u8, file.contents) catch @panic("out of memory"), + }, + } + else + .{ + .result = .{ + .null_terminated = bun.default_allocator.dupeZ(u8, file.contents) catch @panic("out of memory"), + }, + }; + } + } + } + break :brk switch (Syscall.open( path, os.O.RDONLY | os.O.NOCTTY, diff --git a/src/bun.js/webcore/blob.zig b/src/bun.js/webcore/blob.zig index 1e63ea3a2..868acbb80 100644 --- a/src/bun.js/webcore/blob.zig +++ b/src/bun.js/webcore/blob.zig @@ -952,6 +952,13 @@ pub const Blob = struct { switch (path_) { .path => { const slice = path_.path.slice(); + + if (vm.standalone_module_graph) |graph| { + if (graph.find(slice)) |file| { + return file.blob(globalThis).dupe(); + } + } + var cloned = (allocator.dupeZ(u8, slice) catch unreachable)[0..slice.len]; break :brk .{ @@ -2195,6 +2202,9 @@ pub const Blob = struct { cap: SizeType = 0, allocator: std.mem.Allocator, + /// Used by standalone module graph + stored_name: bun.PathString = bun.PathString.empty, + pub fn init(bytes: []u8, allocator: std.mem.Allocator) ByteStore { return .{ .ptr = bytes.ptr, @@ -2528,17 +2538,31 @@ pub const Blob = struct { this: *Blob, globalThis: *JSC.JSGlobalObject, ) callconv(.C) JSValue { + if (this.getFileName()) |path| { + var str = bun.String.create(path); + return str.toJS(globalThis); + } + + return JSValue.undefined; + } + + pub fn getFileName( + this: *const Blob, + ) ?[]const u8 { if (this.store) |store| { if (store.data == .file) { if (store.data.file.pathlike == .path) { - return ZigString.fromUTF8(store.data.file.pathlike.path.slice()).toValueGC(globalThis); + return store.data.file.pathlike.path.slice(); } // we shouldn't return Number here. + } else if (store.data == .bytes) { + if (store.data.bytes.stored_name.slice().len > 0) + return store.data.bytes.stored_name.slice(); } } - return JSC.JSValue.jsUndefined(); + return null; } // TODO: Move this to a separate `File` object or BunFile @@ -3469,6 +3493,14 @@ pub const AnyBlob = union(enum) { InternalBlob: InternalBlob, WTFStringImpl: bun.WTF.StringImpl, + pub fn getFileName(this: *const AnyBlob) ?[]const u8 { + return switch (this.*) { + .Blob => this.Blob.getFileName(), + .WTFStringImpl => null, + .InternalBlob => null, + }; + } + pub inline fn fastSize(this: *const AnyBlob) Blob.SizeType { return switch (this.*) { .Blob => this.Blob.size, diff --git a/src/cli/build_command.zig b/src/cli/build_command.zig index 44e512996..ef99f7765 100644 --- a/src/cli/build_command.zig +++ b/src/cli/build_command.zig @@ -107,6 +107,7 @@ pub const BuildCommand = struct { // We never want to hit the filesystem for these files // This "compiled" protocol is specially handled by the module resolver. this_bundler.options.public_path = "compiled://root/"; + this_bundler.resolver.opts.public_path = "compiled://root/"; if (outfile.len == 0) { outfile = std.fs.path.basename(this_bundler.options.entry_points[0]); diff --git a/src/standalone_bun.zig b/src/standalone_bun.zig index e7363fb58..b18fe384e 100644 --- a/src/standalone_bun.zig +++ b/src/standalone_bun.zig @@ -18,6 +18,14 @@ pub const StandaloneModuleGraph = struct { return &this.files.values()[this.entry_point_id]; } + pub fn find(this: *const StandaloneModuleGraph, name: []const u8) ?*File { + if (!bun.strings.hasPrefixComptime(name, "compiled://root/")) { + return null; + } + + return this.files.getPtr(name); + } + pub const CompiledModuleGraphFile = struct { name: Schema.StringPointer = .{}, loader: bun.options.Loader = .file, @@ -30,6 +38,32 @@ pub const StandaloneModuleGraph = struct { loader: bun.options.Loader, contents: []const u8 = "", sourcemap: LazySourceMap, + blob_: ?*bun.JSC.WebCore.Blob = null, + + pub fn blob(this: *File, globalObject: *bun.JSC.JSGlobalObject) *bun.JSC.WebCore.Blob { + if (this.blob_ == null) { + var store = bun.JSC.WebCore.Blob.Store.init(@constCast(this.contents), bun.default_allocator) catch @panic("out of memory"); + // make it never free + store.ref(); + + var blob_ = bun.default_allocator.create(bun.JSC.WebCore.Blob) catch @panic("out of memory"); + blob_.* = bun.JSC.WebCore.Blob.initWithStore(store, globalObject); + blob_.allocator = bun.default_allocator; + + if (bun.HTTP.MimeType.byExtensionNoDefault(bun.strings.trimLeadingChar(std.fs.path.extension(this.name), '.'))) |mime| { + store.mime_type = mime; + blob_.content_type = mime.value; + blob_.content_type_was_set = true; + blob_.content_type_allocated = false; + } + + store.data.bytes.stored_name = bun.PathString.init(this.name); + + this.blob_ = blob_; + } + + return this.blob_.?; + } }; pub const LazySourceMap = union(enum) { @@ -152,8 +186,16 @@ pub const StandaloneModuleGraph = struct { continue; } + var dest_path = output_file.dest_path; + if (bun.strings.hasPrefixComptime(dest_path, "./")) { + dest_path = dest_path[2..]; + } + var module = CompiledModuleGraphFile{ - .name = string_builder.fmtAppendCount("{s}{s}", .{ prefix, output_file.dest_path }), + .name = string_builder.fmtAppendCount("{s}{s}", .{ + prefix, + dest_path, + }), .loader = output_file.loader, .contents = string_builder.appendCount(output_file.value.buffer.bytes), }; -- cgit v1.2.3 From a7a01bd52f20e7908f06d4de9a1814902b838a4b Mon Sep 17 00:00:00 2001 From: Ciro Spaciari Date: Mon, 3 Jul 2023 16:19:50 -0300 Subject: [tls] add socket parameter, setServername and ALPNprotocols support (#3457) * add socket parameter support * refactor #socket * add test and more fixs * some fixes * bump uws * handlers fix * more fixes * fix node net and node tls tests * fix duplicate port * fix deinit on CallbackJobs * cleanup * add setImmediate repro * add test to setImmediate * this is necessary? * fix prependOnce on native listener * try to findout the error on nodemailer CI * show error message * Update bun.lockb * prettier * Use exact versions of packages * add alpnProtocol support * update * emit error when connect fails on net.Socket * format * fix _write and cleanup * fixup * fix connect, add alpn test * fix socket.io * add socket parameter to TLSSocket * add TLSSocket socket first parameter * fixup and _start * remove flask tests * fmt --------- Co-authored-by: Jarred Sumner <709451+Jarred-Sumner@users.noreply.github.com> --- src/bun.js/api/bun.zig | 2 + src/bun.js/api/bun/socket.zig | 544 +++++++++++++++++++-- src/bun.js/api/server.zig | 48 +- src/bun.js/api/sockets.classes.ts | 16 + src/bun.js/bindings/JSSink.cpp | 2 +- src/bun.js/bindings/JSSink.h | 2 +- src/bun.js/bindings/JSSinkLookupTable.h | 2 +- src/bun.js/bindings/ZigGeneratedClasses.cpp | 218 +++++++++ src/bun.js/bindings/generated_classes.zig | 26 + src/bun.js/bindings/webcore/JSEventEmitter.cpp | 2 +- src/deps/uws | 2 +- src/deps/uws.zig | 284 ++++++++++- src/js/node/net.js | 151 ++++-- src/js/node/tls.js | 351 ++++++++++++- src/js/out/modules/node/net.js | 118 +++-- src/js/out/modules/node/tls.js | 190 ++++++- test/bun.lockb | Bin 139814 -> 140524 bytes test/js/node/net/node-net-server.test.ts | 55 --- test/js/node/tls/node-tls-connect.test.ts | 32 ++ test/js/node/tls/node-tls-server.test.ts | 55 --- test/js/third_party/nodemailer/nodemailer.test.ts | 15 + test/js/third_party/nodemailer/package.json | 6 + .../nodemailer/process-nodemailer-fixture.js | 23 + test/js/web/timers/process-setImmediate-fixture.js | 9 + test/js/web/timers/setImmediate.test.js | 27 + test/package.json | 5 +- 26 files changed, 1888 insertions(+), 297 deletions(-) create mode 100644 test/js/node/tls/node-tls-connect.test.ts create mode 100644 test/js/third_party/nodemailer/nodemailer.test.ts create mode 100644 test/js/third_party/nodemailer/package.json create mode 100644 test/js/third_party/nodemailer/process-nodemailer-fixture.js create mode 100644 test/js/web/timers/process-setImmediate-fixture.js (limited to 'src/bun.js/api/bun.zig') diff --git a/src/bun.js/api/bun.zig b/src/bun.js/api/bun.zig index 2e6381c74..1e5a5e004 100644 --- a/src/bun.js/api/bun.zig +++ b/src/bun.js/api/bun.zig @@ -3794,6 +3794,8 @@ pub const Timer = struct { result.then(globalThis, this, CallbackJob__onResolve, CallbackJob__onReject); }, } + } else { + this.deinit(); } } }; diff --git a/src/bun.js/api/bun/socket.zig b/src/bun.js/api/bun/socket.zig index 69d6611cb..329cc40e4 100644 --- a/src/bun.js/api/bun/socket.zig +++ b/src/bun.js/api/bun/socket.zig @@ -69,6 +69,11 @@ fn normalizeHost(input: anytype) @TypeOf(input) { const BinaryType = JSC.BinaryType; +const WrappedType = enum { + none, + tls, + tcp, +}; const Handlers = struct { onOpen: JSC.JSValue = .zero, onClose: JSC.JSValue = .zero, @@ -97,8 +102,8 @@ const Handlers = struct { handlers: *Handlers, socket_context: *uws.SocketContext, - pub fn exit(this: *Scope, ssl: bool) void { - this.handlers.markInactive(ssl, this.socket_context); + pub fn exit(this: *Scope, ssl: bool, wrapped: WrappedType) void { + this.handlers.markInactive(ssl, this.socket_context, wrapped); } }; @@ -123,19 +128,24 @@ const Handlers = struct { return true; } - pub fn markInactive(this: *Handlers, ssl: bool, ctx: *uws.SocketContext) void { + pub fn markInactive(this: *Handlers, ssl: bool, ctx: *uws.SocketContext, wrapped: WrappedType) void { Listener.log("markInactive", .{}); this.active_connections -= 1; - if (this.active_connections == 0 and this.is_server) { - var listen_socket: *Listener = @fieldParentPtr(Listener, "handlers", this); - // allow it to be GC'd once the last connection is closed and it's not listening anymore - if (listen_socket.listener == null) { - listen_socket.strong_self.clear(); + if (this.active_connections == 0) { + if (this.is_server) { + var listen_socket: *Listener = @fieldParentPtr(Listener, "handlers", this); + // allow it to be GC'd once the last connection is closed and it's not listening anymore + if (listen_socket.listener == null) { + listen_socket.strong_self.clear(); + } + } else { + this.unprotect(); + // will deinit when is not wrapped or when is the TCP wrapped connection + if (wrapped != .tls) { + ctx.deinit(ssl); + } + bun.default_allocator.destroy(this); } - } else if (this.active_connections == 0 and !this.is_server) { - this.unprotect(); - ctx.deinit(ssl); - bun.default_allocator.destroy(this); } } @@ -364,6 +374,7 @@ pub const Listener = struct { connection: UnixOrHost, socket_context: ?*uws.SocketContext = null, ssl: bool = false, + protos: ?[]const u8 = null, strong_data: JSC.Strong = .{}, strong_self: JSC.Strong = .{}, @@ -395,6 +406,19 @@ pub const Listener = struct { port: u16, }, + pub fn clone(this: UnixOrHost) UnixOrHost { + switch (this) { + .unix => |u| { + return .{ + .unix = (bun.default_allocator.dupe(u8, u) catch unreachable), + }; + }, + .host => |h| { + return .{ .host = .{ .host = (bun.default_allocator.dupe(u8, h.host) catch unreachable), .port = this.host.port } }; + }, + } + } + pub fn deinit(this: UnixOrHost) void { switch (this) { .unix => |u| { @@ -455,10 +479,12 @@ pub const Listener = struct { var socket_config = SocketConfig.fromJS(opts, globalObject, exception) orelse { return .zero; }; + var hostname_or_unix = socket_config.hostname_or_unix; var port = socket_config.port; var ssl = socket_config.ssl; var handlers = socket_config.handlers; + var protos: ?[]const u8 = null; const exclusive = socket_config.exclusive; handlers.is_server = true; @@ -496,6 +522,10 @@ pub const Listener = struct { }; if (ssl_enabled) { + if (ssl.?.protos) |p| { + protos = p[0..ssl.?.protos_len]; + } + uws.NewSocketHandler(true).configure( socket_context, true, @@ -593,6 +623,7 @@ pub const Listener = struct { .ssl = ssl_enabled, .socket_context = socket_context, .listener = listen_socket, + .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p) catch unreachable) else null, }; socket.handlers.protect(); @@ -649,6 +680,8 @@ pub const Listener = struct { .handlers = &listener.handlers, .this_value = .zero, .socket = socket, + .protos = listener.protos, + .owned_protos = false, }; if (listener.strong_data.get()) |default_data| { const globalObject = listener.handlers.globalObject; @@ -715,6 +748,10 @@ pub const Listener = struct { this.handlers.unprotect(); this.connection.deinit(); + if (this.protos) |protos| { + this.protos = null; + bun.default_allocator.destroy(protos); + } bun.default_allocator.destroy(this); } @@ -775,13 +812,16 @@ pub const Listener = struct { const socket_config = SocketConfig.fromJS(opts, globalObject, exception) orelse { return .zero; }; + var hostname_or_unix = socket_config.hostname_or_unix; var port = socket_config.port; var ssl = socket_config.ssl; var handlers = socket_config.handlers; var default_data = socket_config.default_data; + var protos: ?[]const u8 = null; const ssl_enabled = ssl != null; + defer if (ssl != null) ssl.?.deinit(); handlers.protect(); @@ -797,6 +837,9 @@ pub const Listener = struct { }; if (ssl_enabled) { + if (ssl.?.protos) |p| { + protos = p[0..ssl.?.protos_len]; + } uws.NewSocketHandler(true).configure( socket_context, true, @@ -848,6 +891,7 @@ pub const Listener = struct { .this_value = .zero, .socket = undefined, .connection = connection, + .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p) catch unreachable) else null, }; TLSSocket.dataSetCached(tls.getThisValue(globalObject), globalObject, default_data); @@ -871,6 +915,7 @@ pub const Listener = struct { .this_value = .zero, .socket = undefined, .connection = null, + .protos = null, }; TCPSocket.dataSetCached(tcp.getThisValue(globalObject), globalObject, default_data); @@ -898,11 +943,41 @@ fn JSSocketType(comptime ssl: bool) type { } } +fn selectALPNCallback( + _: ?*BoringSSL.SSL, + out: [*c][*c]const u8, + outlen: [*c]u8, + in: [*c]const u8, + inlen: c_uint, + arg: ?*anyopaque, +) callconv(.C) c_int { + const this = bun.cast(*TLSSocket, arg); + if (this.protos) |protos| { + if (protos.len == 0) { + return BoringSSL.SSL_TLSEXT_ERR_NOACK; + } + + const status = BoringSSL.SSL_select_next_proto(bun.cast([*c][*c]u8, out), outlen, protos.ptr, @intCast(c_uint, protos.len), in, inlen); + + // Previous versions of Node.js returned SSL_TLSEXT_ERR_NOACK if no protocol + // match was found. This would neither cause a fatal alert nor would it result + // in a useful ALPN response as part of the Server Hello message. + // We now return SSL_TLSEXT_ERR_ALERT_FATAL in that case as per Section 3.2 + // of RFC 7301, which causes a fatal no_application_protocol alert. + const expected = if (comptime BoringSSL.OPENSSL_NPN_NEGOTIATED == 1) BoringSSL.SSL_TLSEXT_ERR_OK else BoringSSL.SSL_TLSEXT_ERR_ALERT_FATAL; + + return if (status == expected) 1 else 0; + } else { + return BoringSSL.SSL_TLSEXT_ERR_NOACK; + } +} + fn NewSocket(comptime ssl: bool) type { return struct { pub const Socket = uws.NewSocketHandler(ssl); socket: Socket, detached: bool = false, + wrapped: WrappedType = .none, handlers: *Handlers, this_value: JSC.JSValue = .zero, poll_ref: JSC.PollRef = JSC.PollRef.init(), @@ -910,6 +985,8 @@ fn NewSocket(comptime ssl: bool) type { last_4: [4]u8 = .{ 0, 0, 0, 0 }, authorized: bool = false, connection: ?Listener.UnixOrHost = null, + protos: ?[]const u8, + owned_protos: bool = true, // TODO: switch to something that uses `visitAggregate` and have the // `Listener` keep a list of all the sockets JSValue in there @@ -1079,7 +1156,7 @@ fn NewSocket(comptime ssl: bool) type { var vm = this.handlers.vm; this.reffer.unref(vm); - this.handlers.markInactive(ssl, this.socket.context()); + this.handlers.markInactive(ssl, this.socket.context(), this.wrapped); this.poll_ref.unref(vm); this.has_pending_activity.store(false, .Release); } @@ -1091,25 +1168,35 @@ fn NewSocket(comptime ssl: bool) type { // Add SNI support for TLS (mongodb and others requires this) if (comptime ssl) { - if (this.connection) |connection| { - if (connection == .host) { - const host = normalizeHost(connection.host.host); - if (host.len > 0) { - var ssl_ptr: *BoringSSL.SSL = @ptrCast(*BoringSSL.SSL, socket.getNativeHandle()); - if (!ssl_ptr.isInitFinished()) { + var ssl_ptr: *BoringSSL.SSL = @ptrCast(*BoringSSL.SSL, socket.getNativeHandle()); + if (!ssl_ptr.isInitFinished()) { + if (this.connection) |connection| { + if (connection == .host) { + const host = normalizeHost(connection.host.host); + if (host.len > 0) { var host__ = default_allocator.dupeZ(u8, host) catch unreachable; defer default_allocator.free(host__); ssl_ptr.setHostname(host__); } } } + if (this.protos) |protos| { + if (this.handlers.is_server) { + BoringSSL.SSL_CTX_set_alpn_select_cb(BoringSSL.SSL_get_SSL_CTX(ssl_ptr), selectALPNCallback, bun.cast(*anyopaque, this)); + } else { + _ = BoringSSL.SSL_set_alpn_protos(ssl_ptr, protos.ptr, @intCast(c_uint, protos.len)); + } + } } } this.poll_ref.ref(this.handlers.vm); this.detached = false; this.socket = socket; - socket.ext(**anyopaque).?.* = bun.cast(**anyopaque, this); + + if (this.wrapped == .none) { + socket.ext(**anyopaque).?.* = bun.cast(**anyopaque, this); + } const handlers = this.handlers; const callback = handlers.onOpen; @@ -1174,7 +1261,7 @@ fn NewSocket(comptime ssl: bool) type { // the handlers must be kept alive for the duration of the function call // that way if we need to call the error handler, we can var scope = handlers.enter(socket.context()); - defer scope.exit(ssl); + defer scope.exit(ssl, this.wrapped); const globalObject = handlers.globalObject; const this_value = this.getThisValue(globalObject); @@ -1211,7 +1298,7 @@ fn NewSocket(comptime ssl: bool) type { // the handlers must be kept alive for the duration of the function call // that way if we need to call the error handler, we can var scope = handlers.enter(socket.context()); - defer scope.exit(ssl); + defer scope.exit(ssl, this.wrapped); const globalObject = handlers.globalObject; const this_value = this.getThisValue(globalObject); @@ -1255,7 +1342,6 @@ fn NewSocket(comptime ssl: bool) type { log("onClose", .{}); this.detached = true; defer this.markInactive(); - const handlers = this.handlers; this.poll_ref.unref(handlers.vm); @@ -1265,7 +1351,7 @@ fn NewSocket(comptime ssl: bool) type { // the handlers must be kept alive for the duration of the function call // that way if we need to call the error handler, we can var scope = handlers.enter(socket.context()); - defer scope.exit(ssl); + defer scope.exit(ssl, this.wrapped); var globalObject = handlers.globalObject; const this_value = this.getThisValue(globalObject); @@ -1295,7 +1381,7 @@ fn NewSocket(comptime ssl: bool) type { // the handlers must be kept alive for the duration of the function call // that way if we need to call the error handler, we can var scope = handlers.enter(socket.context()); - defer scope.exit(ssl); + defer scope.exit(ssl, this.wrapped); // const encoding = handlers.encoding; const result = callback.callWithThis(globalObject, this_value, &[_]JSValue{ @@ -1476,10 +1562,20 @@ fn NewSocket(comptime ssl: bool) type { } fn writeMaybeCorked(this: *This, buffer: []const u8, is_end: bool) i32 { - if (this.socket.isShutdown() or this.socket.isClosed()) { + if (this.detached or this.socket.isShutdown() or this.socket.isClosed()) { return -1; } // we don't cork yet but we might later + + if (comptime ssl) { + // TLS wrapped but in TCP mode + if (this.wrapped == .tcp) { + const res = this.socket.rawWrite(buffer, is_end); + log("write({d}, {any}) = {d}", .{ buffer.len, is_end, res }); + return res; + } + } + const res = this.socket.write(buffer, is_end); log("write({d}, {any}) = {d}", .{ buffer.len, is_end, res }); return res; @@ -1487,7 +1583,6 @@ fn NewSocket(comptime ssl: bool) type { fn writeOrEnd(this: *This, globalObject: *JSC.JSGlobalObject, args: []const JSC.JSValue, is_end: bool) WriteResult { if (args.len == 0) return .{ .success = .{} }; - if (args.ptr[0].asArrayBuffer(globalObject)) |array_buffer| { var slice = array_buffer.slice(); @@ -1681,9 +1776,6 @@ fn NewSocket(comptime ssl: bool) type { if (result.wrote == result.total) { this.socket.flush(); this.detached = true; - if (!this.socket.isClosed()) { - this.socket.close(0, null); - } this.markInactive(); } break :brk JSValue.jsNumber(result.wrote); @@ -1706,17 +1798,27 @@ fn NewSocket(comptime ssl: bool) type { pub fn finalize(this: *This) callconv(.C) void { log("finalize()", .{}); - if (this.detached) return; - this.detached = true; - if (!this.socket.isClosed()) { - this.socket.close(0, null); + if (!this.detached) { + this.detached = true; + if (!this.socket.isClosed()) { + this.socket.close(0, null); + } + this.markInactive(); + } + + this.poll_ref.unref(JSC.VirtualMachine.get()); + // need to deinit event without being attached + if (this.owned_protos) { + if (this.protos) |protos| { + this.protos = null; + default_allocator.free(protos); + } } + if (this.connection) |connection| { - connection.deinit(); this.connection = null; + connection.deinit(); } - this.markInactive(); - this.poll_ref.unref(JSC.VirtualMachine.get()); } pub fn reload(this: *This, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) callconv(.C) JSValue { @@ -1756,8 +1858,376 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsUndefined(); } + + pub fn getALPNProtocol( + this: *This, + globalObject: *JSC.JSGlobalObject, + ) callconv(.C) JSValue { + if (comptime ssl == false) { + return JSValue.jsBoolean(false); + } + + if (this.detached) { + return JSValue.jsBoolean(false); + } + + var alpn_proto: [*c]const u8 = null; + var alpn_proto_len: u32 = 0; + + var ssl_ptr: *BoringSSL.SSL = @ptrCast(*BoringSSL.SSL, this.socket.getNativeHandle()); + BoringSSL.SSL_get0_alpn_selected(ssl_ptr, &alpn_proto, &alpn_proto_len); + if (alpn_proto == null or alpn_proto_len == 0) { + return JSValue.jsBoolean(false); + } + + const slice = alpn_proto[0..alpn_proto_len]; + if (strings.eql(slice, "h2")) { + return ZigString.static("h2").toValue(globalObject); + } + if (strings.eql(slice, "http/1.1")) { + return ZigString.static("http/1.1").toValue(globalObject); + } + return ZigString.fromUTF8(slice).toValueGC(globalObject); + } + + pub fn setServername( + this: *This, + globalObject: *JSC.JSGlobalObject, + callframe: *JSC.CallFrame, + ) callconv(.C) JSValue { + if (comptime ssl == false) { + return JSValue.jsUndefined(); + } + if (this.detached) { + return JSValue.jsUndefined(); + } + + if (this.handlers.is_server) { + globalObject.throw("Cannot issue SNI from a TLS server-side socket", .{}); + return .zero; + } + + const args = callframe.arguments(1); + if (args.len < 1) { + globalObject.throw("Expected 1 argument", .{}); + return .zero; + } + + const server_name = args.ptr[0]; + if (!server_name.isString()) { + globalObject.throw("Expected \"serverName\" to be a string", .{}); + return .zero; + } + + const slice = server_name.getZigString(globalObject).toSlice(bun.default_allocator); + defer slice.deinit(); + const host = normalizeHost(slice.slice()); + if (host.len > 0) { + var ssl_ptr: *BoringSSL.SSL = @ptrCast(*BoringSSL.SSL, this.socket.getNativeHandle()); + if (ssl_ptr.isInitFinished()) { + // match node.js exceptions + globalObject.throw("Already started.", .{}); + return .zero; + } + var host__ = default_allocator.dupeZ(u8, host) catch unreachable; + defer default_allocator.free(host__); + ssl_ptr.setHostname(host__); + } + + return JSValue.jsUndefined(); + } + + pub fn open( + this: *This, + _: *JSC.JSGlobalObject, + _: *JSC.CallFrame, + ) callconv(.C) JSValue { + JSC.markBinding(@src()); + this.socket.open(!this.handlers.is_server); + return JSValue.jsUndefined(); + } + + // this invalidates the current socket returning 2 new sockets + // one for non-TLS and another for TLS + // handlers for non-TLS are preserved + pub fn wrapTLS( + this: *This, + globalObject: *JSC.JSGlobalObject, + callframe: *JSC.CallFrame, + ) callconv(.C) JSValue { + JSC.markBinding(@src()); + if (comptime ssl) { + return JSValue.jsUndefined(); + } + + if (this.detached) { + return JSValue.jsUndefined(); + } + + const args = callframe.arguments(1); + + if (args.len < 1) { + globalObject.throw("Expected 1 arguments", .{}); + return .zero; + } + + var exception: JSC.C.JSValueRef = null; + + const opts = args.ptr[0]; + if (opts.isEmptyOrUndefinedOrNull() or opts.isBoolean() or !opts.isObject()) { + globalObject.throw("Expected options object", .{}); + return .zero; + } + + var socket_obj = opts.get(globalObject, "socket") orelse { + globalObject.throw("Expected \"socket\" option", .{}); + return .zero; + }; + + var handlers = Handlers.fromJS(globalObject, socket_obj, &exception) orelse { + globalObject.throwValue(exception.?.value()); + return .zero; + }; + + var ssl_opts: ?JSC.API.ServerConfig.SSLConfig = null; + + if (opts.getTruthy(globalObject, "tls")) |tls| { + if (tls.isBoolean()) { + if (tls.toBoolean()) { + ssl_opts = JSC.API.ServerConfig.SSLConfig.zero; + } + } else { + if (JSC.API.ServerConfig.SSLConfig.inJS(globalObject, tls, &exception)) |ssl_config| { + ssl_opts = ssl_config; + } else if (exception != null) { + return .zero; + } + } + } + + if (ssl_opts == null) { + globalObject.throw("Expected \"tls\" option", .{}); + return .zero; + } + + var default_data = JSValue.zero; + if (opts.getTruthy(globalObject, "data")) |default_data_value| { + default_data = default_data_value; + default_data.ensureStillAlive(); + } + + var socket_config = ssl_opts.?; + defer socket_config.deinit(); + const options = socket_config.asUSockets(); + + const protos = socket_config.protos; + const protos_len = socket_config.protos_len; + + const ext_size = @sizeOf(WrappedSocket); + + var tls = handlers.vm.allocator.create(TLSSocket) catch @panic("OOM"); + var handlers_ptr = handlers.vm.allocator.create(Handlers) catch @panic("OOM"); + handlers_ptr.* = handlers; + handlers_ptr.is_server = this.handlers.is_server; + handlers_ptr.protect(); + + tls.* = .{ + .handlers = handlers_ptr, + .this_value = .zero, + .socket = undefined, + .connection = if (this.connection) |c| c.clone() else null, + .wrapped = .tls, + .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p[0..protos_len]) catch unreachable) else null, + }; + + var tls_js_value = tls.getThisValue(globalObject); + TLSSocket.dataSetCached(tls_js_value, globalObject, default_data); + + const TCPHandler = NewWrappedHandler(false); + + // reconfigure context to use the new wrapper handlers + Socket.unsafeConfigure(this.socket.context(), true, true, WrappedSocket, TCPHandler); + const old_context = this.socket.context(); + const TLSHandler = NewWrappedHandler(true); + const new_socket = this.socket.wrapTLS( + options, + ext_size, + true, + WrappedSocket, + TLSHandler, + ) orelse { + handlers_ptr.unprotect(); + handlers.vm.allocator.destroy(handlers_ptr); + bun.default_allocator.destroy(tls); + return JSValue.jsUndefined(); + }; + tls.socket = new_socket; + + var raw = handlers.vm.allocator.create(TLSSocket) catch @panic("OOM"); + var raw_handlers_ptr = handlers.vm.allocator.create(Handlers) catch @panic("OOM"); + this.handlers.unprotect(); + + var cloned_handlers: Handlers = .{ + .vm = globalObject.bunVM(), + .globalObject = globalObject, + .onOpen = this.handlers.onOpen, + .onClose = this.handlers.onClose, + .onData = this.handlers.onData, + .onWritable = this.handlers.onWritable, + .onTimeout = this.handlers.onTimeout, + .onConnectError = this.handlers.onConnectError, + .onEnd = this.handlers.onEnd, + .onError = this.handlers.onError, + .onHandshake = this.handlers.onHandshake, + .binary_type = this.handlers.binary_type, + }; + + raw_handlers_ptr.* = cloned_handlers; + raw_handlers_ptr.is_server = this.handlers.is_server; + raw_handlers_ptr.protect(); + raw.* = .{ + .handlers = raw_handlers_ptr, + .this_value = .zero, + .socket = new_socket, + .connection = if (this.connection) |c| c.clone() else null, + .wrapped = .tcp, + .protos = null, + }; + + var raw_js_value = raw.getThisValue(globalObject); + if (JSSocketType(ssl).dataGetCached(this.getThisValue(globalObject))) |raw_default_data| { + raw_default_data.ensureStillAlive(); + TLSSocket.dataSetCached(raw_js_value, globalObject, raw_default_data); + } + // marks both as active + raw.markActive(); + // this will keep tls alive until socket.open() is called to start TLS certificate and the handshake process + // open is not immediately called because we need to set bunSocketInternal + tls.markActive(); + + // mark both instances on socket data + new_socket.ext(WrappedSocket).?.* = .{ .tcp = raw, .tls = tls }; + + //detach and invalidate the old instance + this.detached = true; + if (this.reffer.has) { + var vm = this.handlers.vm; + this.reffer.unref(vm); + old_context.deinit(ssl); + bun.default_allocator.destroy(this.handlers); + this.poll_ref.unref(vm); + this.has_pending_activity.store(false, .Release); + } + + const array = JSC.JSValue.createEmptyArray(globalObject, 2); + array.putIndex(globalObject, 0, raw_js_value); + array.putIndex(globalObject, 1, tls_js_value); + return array; + } }; } pub const TCPSocket = NewSocket(false); pub const TLSSocket = NewSocket(true); + +pub const WrappedSocket = extern struct { + // both shares the same socket but one behaves as TLS and the other as TCP + tls: *TLSSocket, + tcp: *TLSSocket, +}; + +pub fn NewWrappedHandler(comptime tls: bool) type { + const Socket = uws.NewSocketHandler(true); + return struct { + pub fn onOpen( + this: WrappedSocket, + socket: Socket, + ) void { + // only TLS will call onOpen + if (comptime tls) { + TLSSocket.onOpen(this.tls, socket); + } + } + + pub fn onEnd( + this: WrappedSocket, + socket: Socket, + ) void { + if (comptime tls) { + TLSSocket.onEnd(this.tls, socket); + } else { + TLSSocket.onEnd(this.tcp, socket); + } + } + + pub fn onHandshake( + this: WrappedSocket, + socket: Socket, + success: i32, + ssl_error: uws.us_bun_verify_error_t, + ) void { + // only TLS will call onHandshake + if (comptime tls) { + TLSSocket.onHandshake(this.tls, socket, success, ssl_error); + } + } + + pub fn onClose( + this: WrappedSocket, + socket: Socket, + err: c_int, + data: ?*anyopaque, + ) void { + if (comptime tls) { + TLSSocket.onClose(this.tls, socket, err, data); + } else { + TLSSocket.onClose(this.tcp, socket, err, data); + } + } + + pub fn onData( + this: WrappedSocket, + socket: Socket, + data: []const u8, + ) void { + if (comptime tls) { + TLSSocket.onData(this.tls, socket, data); + } else { + TLSSocket.onData(this.tcp, socket, data); + } + } + + pub fn onWritable( + this: WrappedSocket, + socket: Socket, + ) void { + if (comptime tls) { + TLSSocket.onWritable(this.tls, socket); + } else { + TLSSocket.onWritable(this.tcp, socket); + } + } + pub fn onTimeout( + this: WrappedSocket, + socket: Socket, + ) void { + if (comptime tls) { + TLSSocket.onTimeout(this.tls, socket); + } else { + TLSSocket.onTimeout(this.tcp, socket); + } + } + + pub fn onConnectError( + this: WrappedSocket, + socket: Socket, + errno: c_int, + ) void { + if (comptime tls) { + TLSSocket.onConnectError(this.tls, socket, errno); + } else { + TLSSocket.onConnectError(this.tcp, socket, errno); + } + } + }; +} diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig index 140e62ce4..f52c08301 100644 --- a/src/bun.js/api/server.zig +++ b/src/bun.js/api/server.zig @@ -163,6 +163,8 @@ pub const ServerConfig = struct { request_cert: i32 = 0, reject_unauthorized: i32 = 0, ssl_ciphers: [*c]const u8 = null, + protos: [*c]const u8 = null, + protos_len: usize = 0, const log = Output.scoped(.SSLConfig, false); @@ -215,6 +217,7 @@ pub const ServerConfig = struct { "dh_params_file_name", "passphrase", "ssl_ciphers", + "protos", }; inline for (fields) |field| { @@ -270,6 +273,9 @@ pub const ServerConfig = struct { pub fn inJS(global: *JSC.JSGlobalObject, obj: JSC.JSValue, exception: JSC.C.ExceptionRef) ?SSLConfig { var result = zero; + var arena: @import("root").bun.ArenaAllocator = @import("root").bun.ArenaAllocator.init(bun.default_allocator); + defer arena.deinit(); + if (!obj.isObject()) { JSC.throwInvalidArguments("tls option expects an object", .{}, global, exception); return null; @@ -301,7 +307,6 @@ pub const ServerConfig = struct { var i: u32 = 0; var valid_count: u32 = 0; - var arena: @import("root").bun.ArenaAllocator = @import("root").bun.ArenaAllocator.init(bun.default_allocator); while (i < count) : (i += 1) { const item = js_obj.getIndex(global, i); if (JSC.Node.StringOrBuffer.fromJS(global, arena.allocator(), item, exception)) |sb| { @@ -317,7 +322,6 @@ pub const ServerConfig = struct { valid_count += 1; any = true; } else { - arena.deinit(); // mark and free all CA's result.cert = native_array; result.deinit(); @@ -325,7 +329,6 @@ pub const ServerConfig = struct { } } else { global.throwInvalidArguments("key argument must be an string, Buffer, TypedArray, BunFile or an array containing string, Buffer, TypedArray or BunFile", .{}); - arena.deinit(); // mark and free all keys result.key = native_array; result.deinit(); @@ -333,8 +336,6 @@ pub const ServerConfig = struct { } } - arena.deinit(); - if (valid_count == 0) { bun.default_allocator.free(native_array); } else { @@ -356,7 +357,6 @@ pub const ServerConfig = struct { } } else { const native_array = bun.default_allocator.alloc([*c]const u8, 1) catch unreachable; - var arena: @import("root").bun.ArenaAllocator = @import("root").bun.ArenaAllocator.init(bun.default_allocator); if (JSC.Node.StringOrBuffer.fromJS(global, arena.allocator(), js_obj, exception)) |sb| { const sliced = sb.slice(); if (sliced.len > 0) { @@ -369,14 +369,11 @@ pub const ServerConfig = struct { } } else { global.throwInvalidArguments("key argument must be an string, Buffer, TypedArray, BunFile or an array containing string, Buffer, TypedArray or BunFile", .{}); - arena.deinit(); // mark and free all certs result.key = native_array; result.deinit(); return null; } - - arena.deinit(); } } @@ -394,6 +391,22 @@ pub const ServerConfig = struct { } } + if (obj.getTruthy(global, "ALPNProtocols")) |protocols| { + if (JSC.Node.StringOrBuffer.fromJS(global, arena.allocator(), protocols, exception)) |sb| { + const sliced = sb.slice(); + if (sliced.len > 0) { + result.protos = bun.default_allocator.dupeZ(u8, sliced) catch unreachable; + result.protos_len = sliced.len; + } + + any = true; + } else { + global.throwInvalidArguments("ALPNProtocols argument must be an string, Buffer or TypedArray", .{}); + result.deinit(); + return null; + } + } + if (obj.getTruthy(global, "cert")) |js_obj| { if (js_obj.jsType().isArray()) { const count = js_obj.getLength(global); @@ -403,7 +416,6 @@ pub const ServerConfig = struct { var i: u32 = 0; var valid_count: u32 = 0; - var arena: @import("root").bun.ArenaAllocator = @import("root").bun.ArenaAllocator.init(bun.default_allocator); while (i < count) : (i += 1) { const item = js_obj.getIndex(global, i); if (JSC.Node.StringOrBuffer.fromJS(global, arena.allocator(), item, exception)) |sb| { @@ -419,7 +431,6 @@ pub const ServerConfig = struct { valid_count += 1; any = true; } else { - arena.deinit(); // mark and free all CA's result.cert = native_array; result.deinit(); @@ -427,7 +438,6 @@ pub const ServerConfig = struct { } } else { global.throwInvalidArguments("cert argument must be an string, Buffer, TypedArray, BunFile or an array containing string, Buffer, TypedArray or BunFile", .{}); - arena.deinit(); // mark and free all certs result.cert = native_array; result.deinit(); @@ -435,8 +445,6 @@ pub const ServerConfig = struct { } } - arena.deinit(); - if (valid_count == 0) { bun.default_allocator.free(native_array); } else { @@ -458,7 +466,6 @@ pub const ServerConfig = struct { } } else { const native_array = bun.default_allocator.alloc([*c]const u8, 1) catch unreachable; - var arena: @import("root").bun.ArenaAllocator = @import("root").bun.ArenaAllocator.init(bun.default_allocator); if (JSC.Node.StringOrBuffer.fromJS(global, arena.allocator(), js_obj, exception)) |sb| { const sliced = sb.slice(); if (sliced.len > 0) { @@ -471,14 +478,11 @@ pub const ServerConfig = struct { } } else { global.throwInvalidArguments("cert argument must be an string, Buffer, TypedArray, BunFile or an array containing string, Buffer, TypedArray or BunFile", .{}); - arena.deinit(); // mark and free all certs result.cert = native_array; result.deinit(); return null; } - - arena.deinit(); } } @@ -518,7 +522,6 @@ pub const ServerConfig = struct { var i: u32 = 0; var valid_count: u32 = 0; - var arena: @import("root").bun.ArenaAllocator = @import("root").bun.ArenaAllocator.init(bun.default_allocator); while (i < count) : (i += 1) { const item = js_obj.getIndex(global, i); if (JSC.Node.StringOrBuffer.fromJS(global, arena.allocator(), item, exception)) |sb| { @@ -534,7 +537,6 @@ pub const ServerConfig = struct { valid_count += 1; any = true; } else { - arena.deinit(); // mark and free all CA's result.cert = native_array; result.deinit(); @@ -542,7 +544,6 @@ pub const ServerConfig = struct { } } else { global.throwInvalidArguments("ca argument must be an string, Buffer, TypedArray, BunFile or an array containing string, Buffer, TypedArray or BunFile", .{}); - arena.deinit(); // mark and free all CA's result.cert = native_array; result.deinit(); @@ -550,8 +551,6 @@ pub const ServerConfig = struct { } } - arena.deinit(); - if (valid_count == 0) { bun.default_allocator.free(native_array); } else { @@ -573,7 +572,6 @@ pub const ServerConfig = struct { } } else { const native_array = bun.default_allocator.alloc([*c]const u8, 1) catch unreachable; - var arena: @import("root").bun.ArenaAllocator = @import("root").bun.ArenaAllocator.init(bun.default_allocator); if (JSC.Node.StringOrBuffer.fromJS(global, arena.allocator(), js_obj, exception)) |sb| { const sliced = sb.slice(); if (sliced.len > 0) { @@ -586,13 +584,11 @@ pub const ServerConfig = struct { } } else { JSC.throwInvalidArguments("ca argument must be an string, Buffer, TypedArray, BunFile or an array containing string, Buffer, TypedArray or BunFile", .{}, global, exception); - arena.deinit(); // mark and free all certs result.ca = native_array; result.deinit(); return null; } - arena.deinit(); } } diff --git a/src/bun.js/api/sockets.classes.ts b/src/bun.js/api/sockets.classes.ts index da07741a3..0c7847e19 100644 --- a/src/bun.js/api/sockets.classes.ts +++ b/src/bun.js/api/sockets.classes.ts @@ -15,10 +15,21 @@ function generate(ssl) { authorized: { getter: "getAuthorized", }, + alpnProtocol: { + getter: "getALPNProtocol", + }, write: { fn: "write", length: 3, }, + wrapTLS: { + fn: "wrapTLS", + length: 1, + }, + open: { + fn: "open", + length: 0, + }, end: { fn: "end", length: 3, @@ -82,6 +93,11 @@ function generate(ssl) { fn: "reload", length: 1, }, + + setServername: { + fn: "setServername", + length: 1, + }, }, finalize: true, construct: true, diff --git a/src/bun.js/bindings/JSSink.cpp b/src/bun.js/bindings/JSSink.cpp index 19bf05599..5f99d3792 100644 --- a/src/bun.js/bindings/JSSink.cpp +++ b/src/bun.js/bindings/JSSink.cpp @@ -1,6 +1,6 @@ // AUTO-GENERATED FILE. DO NOT EDIT. -// Generated by 'make generate-sink' at 2023-06-25T17:34:54.187Z +// Generated by 'make generate-sink' at 2023-07-02T16:19:51.440Z // To regenerate this file, run: // // make generate-sink diff --git a/src/bun.js/bindings/JSSink.h b/src/bun.js/bindings/JSSink.h index 9bf5554c4..41d7065dc 100644 --- a/src/bun.js/bindings/JSSink.h +++ b/src/bun.js/bindings/JSSink.h @@ -1,6 +1,6 @@ // AUTO-GENERATED FILE. DO NOT EDIT. -// Generated by 'make generate-sink' at 2023-06-25T17:34:54.186Z +// Generated by 'make generate-sink' at 2023-07-02T16:19:51.438Z // #pragma once diff --git a/src/bun.js/bindings/JSSinkLookupTable.h b/src/bun.js/bindings/JSSinkLookupTable.h index f8518bc5e..e4ed81629 100644 --- a/src/bun.js/bindings/JSSinkLookupTable.h +++ b/src/bun.js/bindings/JSSinkLookupTable.h @@ -1,4 +1,4 @@ -// Automatically generated from src/bun.js/bindings/JSSink.cpp using /Users/silas/Workspace/opensource/bun/src/bun.js/WebKit/Source/JavaScriptCore/create_hash_table. DO NOT EDIT! +// Automatically generated from src/bun.js/bindings/JSSink.cpp using /home/cirospaciari/Repos/bun/src/bun.js/WebKit/Source/JavaScriptCore/create_hash_table. DO NOT EDIT! diff --git a/src/bun.js/bindings/ZigGeneratedClasses.cpp b/src/bun.js/bindings/ZigGeneratedClasses.cpp index b7461b5f0..866970e4d 100644 --- a/src/bun.js/bindings/ZigGeneratedClasses.cpp +++ b/src/bun.js/bindings/ZigGeneratedClasses.cpp @@ -16872,6 +16872,9 @@ extern "C" void* TCPSocketClass__construct(JSC::JSGlobalObject*, JSC::CallFrame* JSC_DECLARE_CUSTOM_GETTER(jsTCPSocketConstructor); extern "C" void TCPSocketClass__finalize(void*); +extern "C" JSC::EncodedJSValue TCPSocketPrototype__getALPNProtocol(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject); +JSC_DECLARE_CUSTOM_GETTER(TCPSocketPrototype__alpnProtocolGetterWrap); + extern "C" JSC::EncodedJSValue TCPSocketPrototype__getAuthorized(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject); JSC_DECLARE_CUSTOM_GETTER(TCPSocketPrototype__authorizedGetterWrap); @@ -16896,6 +16899,9 @@ JSC_DECLARE_CUSTOM_GETTER(TCPSocketPrototype__listenerGetterWrap); extern "C" JSC::EncodedJSValue TCPSocketPrototype__getLocalPort(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject); JSC_DECLARE_CUSTOM_GETTER(TCPSocketPrototype__localPortGetterWrap); +extern "C" EncodedJSValue TCPSocketPrototype__open(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); +JSC_DECLARE_HOST_FUNCTION(TCPSocketPrototype__openCallback); + extern "C" JSC::EncodedJSValue TCPSocketPrototype__getReadyState(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject); JSC_DECLARE_CUSTOM_GETTER(TCPSocketPrototype__readyStateGetterWrap); @@ -16908,6 +16914,9 @@ JSC_DECLARE_HOST_FUNCTION(TCPSocketPrototype__reloadCallback); extern "C" JSC::EncodedJSValue TCPSocketPrototype__getRemoteAddress(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject); JSC_DECLARE_CUSTOM_GETTER(TCPSocketPrototype__remoteAddressGetterWrap); +extern "C" EncodedJSValue TCPSocketPrototype__setServername(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); +JSC_DECLARE_HOST_FUNCTION(TCPSocketPrototype__setServernameCallback); + extern "C" EncodedJSValue TCPSocketPrototype__shutdown(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); JSC_DECLARE_HOST_FUNCTION(TCPSocketPrototype__shutdownCallback); @@ -16917,12 +16926,16 @@ JSC_DECLARE_HOST_FUNCTION(TCPSocketPrototype__timeoutCallback); extern "C" EncodedJSValue TCPSocketPrototype__unref(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); JSC_DECLARE_HOST_FUNCTION(TCPSocketPrototype__unrefCallback); +extern "C" EncodedJSValue TCPSocketPrototype__wrapTLS(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); +JSC_DECLARE_HOST_FUNCTION(TCPSocketPrototype__wrapTLSCallback); + extern "C" EncodedJSValue TCPSocketPrototype__write(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); JSC_DECLARE_HOST_FUNCTION(TCPSocketPrototype__writeCallback); STATIC_ASSERT_ISO_SUBSPACE_SHARABLE(JSTCPSocketPrototype, JSTCPSocketPrototype::Base); static const HashTableValue JSTCPSocketPrototypeTableValues[] = { + { "alpnProtocol"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TCPSocketPrototype__alpnProtocolGetterWrap, 0 } }, { "authorized"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TCPSocketPrototype__authorizedGetterWrap, 0 } }, { "data"_s, static_cast(JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TCPSocketPrototype__dataGetterWrap, TCPSocketPrototype__dataSetterWrap } }, { "end"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TCPSocketPrototype__endCallback, 3 } }, @@ -16930,13 +16943,16 @@ static const HashTableValue JSTCPSocketPrototypeTableValues[] = { { "getAuthorizationError"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TCPSocketPrototype__getAuthorizationErrorCallback, 0 } }, { "listener"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TCPSocketPrototype__listenerGetterWrap, 0 } }, { "localPort"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TCPSocketPrototype__localPortGetterWrap, 0 } }, + { "open"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TCPSocketPrototype__openCallback, 0 } }, { "readyState"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TCPSocketPrototype__readyStateGetterWrap, 0 } }, { "ref"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TCPSocketPrototype__refCallback, 0 } }, { "reload"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TCPSocketPrototype__reloadCallback, 1 } }, { "remoteAddress"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TCPSocketPrototype__remoteAddressGetterWrap, 0 } }, + { "setServername"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TCPSocketPrototype__setServernameCallback, 1 } }, { "shutdown"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TCPSocketPrototype__shutdownCallback, 1 } }, { "timeout"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TCPSocketPrototype__timeoutCallback, 1 } }, { "unref"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TCPSocketPrototype__unrefCallback, 0 } }, + { "wrapTLS"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TCPSocketPrototype__wrapTLSCallback, 1 } }, { "write"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TCPSocketPrototype__writeCallback, 3 } } }; @@ -16954,6 +16970,18 @@ JSC_DEFINE_CUSTOM_GETTER(jsTCPSocketConstructor, (JSGlobalObject * lexicalGlobal return JSValue::encode(globalObject->JSTCPSocketConstructor()); } +JSC_DEFINE_CUSTOM_GETTER(TCPSocketPrototype__alpnProtocolGetterWrap, (JSGlobalObject * lexicalGlobalObject, EncodedJSValue thisValue, PropertyName attributeName)) +{ + auto& vm = lexicalGlobalObject->vm(); + Zig::GlobalObject* globalObject = reinterpret_cast(lexicalGlobalObject); + auto throwScope = DECLARE_THROW_SCOPE(vm); + JSTCPSocket* thisObject = jsCast(JSValue::decode(thisValue)); + JSC::EnsureStillAliveScope thisArg = JSC::EnsureStillAliveScope(thisObject); + JSC::EncodedJSValue result = TCPSocketPrototype__getALPNProtocol(thisObject->wrapped(), globalObject); + RETURN_IF_EXCEPTION(throwScope, {}); + RELEASE_AND_RETURN(throwScope, result); +} + JSC_DEFINE_CUSTOM_GETTER(TCPSocketPrototype__authorizedGetterWrap, (JSGlobalObject * lexicalGlobalObject, EncodedJSValue thisValue, PropertyName attributeName)) { auto& vm = lexicalGlobalObject->vm(); @@ -17113,6 +17141,33 @@ JSC_DEFINE_CUSTOM_GETTER(TCPSocketPrototype__localPortGetterWrap, (JSGlobalObjec RELEASE_AND_RETURN(throwScope, result); } +JSC_DEFINE_HOST_FUNCTION(TCPSocketPrototype__openCallback, (JSGlobalObject * lexicalGlobalObject, CallFrame* callFrame)) +{ + auto& vm = lexicalGlobalObject->vm(); + + JSTCPSocket* thisObject = jsDynamicCast(callFrame->thisValue()); + + if (UNLIKELY(!thisObject)) { + auto throwScope = DECLARE_THROW_SCOPE(vm); + return throwVMTypeError(lexicalGlobalObject, throwScope); + } + + JSC::EnsureStillAliveScope thisArg = JSC::EnsureStillAliveScope(thisObject); + +#ifdef BUN_DEBUG + /** View the file name of the JS file that called this function + * from a debugger */ + SourceOrigin sourceOrigin = callFrame->callerSourceOrigin(vm); + const char* fileName = sourceOrigin.string().utf8().data(); + static const char* lastFileName = nullptr; + if (lastFileName != fileName) { + lastFileName = fileName; + } +#endif + + return TCPSocketPrototype__open(thisObject->wrapped(), lexicalGlobalObject, callFrame); +} + JSC_DEFINE_CUSTOM_GETTER(TCPSocketPrototype__readyStateGetterWrap, (JSGlobalObject * lexicalGlobalObject, EncodedJSValue thisValue, PropertyName attributeName)) { auto& vm = lexicalGlobalObject->vm(); @@ -17210,6 +17265,33 @@ extern "C" EncodedJSValue TCPSocketPrototype__remoteAddressGetCachedValue(JSC::E return JSValue::encode(thisObject->m_remoteAddress.get()); } +JSC_DEFINE_HOST_FUNCTION(TCPSocketPrototype__setServernameCallback, (JSGlobalObject * lexicalGlobalObject, CallFrame* callFrame)) +{ + auto& vm = lexicalGlobalObject->vm(); + + JSTCPSocket* thisObject = jsDynamicCast(callFrame->thisValue()); + + if (UNLIKELY(!thisObject)) { + auto throwScope = DECLARE_THROW_SCOPE(vm); + return throwVMTypeError(lexicalGlobalObject, throwScope); + } + + JSC::EnsureStillAliveScope thisArg = JSC::EnsureStillAliveScope(thisObject); + +#ifdef BUN_DEBUG + /** View the file name of the JS file that called this function + * from a debugger */ + SourceOrigin sourceOrigin = callFrame->callerSourceOrigin(vm); + const char* fileName = sourceOrigin.string().utf8().data(); + static const char* lastFileName = nullptr; + if (lastFileName != fileName) { + lastFileName = fileName; + } +#endif + + return TCPSocketPrototype__setServername(thisObject->wrapped(), lexicalGlobalObject, callFrame); +} + JSC_DEFINE_HOST_FUNCTION(TCPSocketPrototype__shutdownCallback, (JSGlobalObject * lexicalGlobalObject, CallFrame* callFrame)) { auto& vm = lexicalGlobalObject->vm(); @@ -17291,6 +17373,33 @@ JSC_DEFINE_HOST_FUNCTION(TCPSocketPrototype__unrefCallback, (JSGlobalObject * le return TCPSocketPrototype__unref(thisObject->wrapped(), lexicalGlobalObject, callFrame); } +JSC_DEFINE_HOST_FUNCTION(TCPSocketPrototype__wrapTLSCallback, (JSGlobalObject * lexicalGlobalObject, CallFrame* callFrame)) +{ + auto& vm = lexicalGlobalObject->vm(); + + JSTCPSocket* thisObject = jsDynamicCast(callFrame->thisValue()); + + if (UNLIKELY(!thisObject)) { + auto throwScope = DECLARE_THROW_SCOPE(vm); + return throwVMTypeError(lexicalGlobalObject, throwScope); + } + + JSC::EnsureStillAliveScope thisArg = JSC::EnsureStillAliveScope(thisObject); + +#ifdef BUN_DEBUG + /** View the file name of the JS file that called this function + * from a debugger */ + SourceOrigin sourceOrigin = callFrame->callerSourceOrigin(vm); + const char* fileName = sourceOrigin.string().utf8().data(); + static const char* lastFileName = nullptr; + if (lastFileName != fileName) { + lastFileName = fileName; + } +#endif + + return TCPSocketPrototype__wrapTLS(thisObject->wrapped(), lexicalGlobalObject, callFrame); +} + JSC_DEFINE_HOST_FUNCTION(TCPSocketPrototype__writeCallback, (JSGlobalObject * lexicalGlobalObject, CallFrame* callFrame)) { auto& vm = lexicalGlobalObject->vm(); @@ -17479,6 +17588,9 @@ extern "C" void* TLSSocketClass__construct(JSC::JSGlobalObject*, JSC::CallFrame* JSC_DECLARE_CUSTOM_GETTER(jsTLSSocketConstructor); extern "C" void TLSSocketClass__finalize(void*); +extern "C" JSC::EncodedJSValue TLSSocketPrototype__getALPNProtocol(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject); +JSC_DECLARE_CUSTOM_GETTER(TLSSocketPrototype__alpnProtocolGetterWrap); + extern "C" JSC::EncodedJSValue TLSSocketPrototype__getAuthorized(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject); JSC_DECLARE_CUSTOM_GETTER(TLSSocketPrototype__authorizedGetterWrap); @@ -17503,6 +17615,9 @@ JSC_DECLARE_CUSTOM_GETTER(TLSSocketPrototype__listenerGetterWrap); extern "C" JSC::EncodedJSValue TLSSocketPrototype__getLocalPort(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject); JSC_DECLARE_CUSTOM_GETTER(TLSSocketPrototype__localPortGetterWrap); +extern "C" EncodedJSValue TLSSocketPrototype__open(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); +JSC_DECLARE_HOST_FUNCTION(TLSSocketPrototype__openCallback); + extern "C" JSC::EncodedJSValue TLSSocketPrototype__getReadyState(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject); JSC_DECLARE_CUSTOM_GETTER(TLSSocketPrototype__readyStateGetterWrap); @@ -17515,6 +17630,9 @@ JSC_DECLARE_HOST_FUNCTION(TLSSocketPrototype__reloadCallback); extern "C" JSC::EncodedJSValue TLSSocketPrototype__getRemoteAddress(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject); JSC_DECLARE_CUSTOM_GETTER(TLSSocketPrototype__remoteAddressGetterWrap); +extern "C" EncodedJSValue TLSSocketPrototype__setServername(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); +JSC_DECLARE_HOST_FUNCTION(TLSSocketPrototype__setServernameCallback); + extern "C" EncodedJSValue TLSSocketPrototype__shutdown(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); JSC_DECLARE_HOST_FUNCTION(TLSSocketPrototype__shutdownCallback); @@ -17524,12 +17642,16 @@ JSC_DECLARE_HOST_FUNCTION(TLSSocketPrototype__timeoutCallback); extern "C" EncodedJSValue TLSSocketPrototype__unref(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); JSC_DECLARE_HOST_FUNCTION(TLSSocketPrototype__unrefCallback); +extern "C" EncodedJSValue TLSSocketPrototype__wrapTLS(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); +JSC_DECLARE_HOST_FUNCTION(TLSSocketPrototype__wrapTLSCallback); + extern "C" EncodedJSValue TLSSocketPrototype__write(void* ptr, JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame); JSC_DECLARE_HOST_FUNCTION(TLSSocketPrototype__writeCallback); STATIC_ASSERT_ISO_SUBSPACE_SHARABLE(JSTLSSocketPrototype, JSTLSSocketPrototype::Base); static const HashTableValue JSTLSSocketPrototypeTableValues[] = { + { "alpnProtocol"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TLSSocketPrototype__alpnProtocolGetterWrap, 0 } }, { "authorized"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TLSSocketPrototype__authorizedGetterWrap, 0 } }, { "data"_s, static_cast(JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TLSSocketPrototype__dataGetterWrap, TLSSocketPrototype__dataSetterWrap } }, { "end"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TLSSocketPrototype__endCallback, 3 } }, @@ -17537,13 +17659,16 @@ static const HashTableValue JSTLSSocketPrototypeTableValues[] = { { "getAuthorizationError"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TLSSocketPrototype__getAuthorizationErrorCallback, 0 } }, { "listener"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TLSSocketPrototype__listenerGetterWrap, 0 } }, { "localPort"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TLSSocketPrototype__localPortGetterWrap, 0 } }, + { "open"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TLSSocketPrototype__openCallback, 0 } }, { "readyState"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TLSSocketPrototype__readyStateGetterWrap, 0 } }, { "ref"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TLSSocketPrototype__refCallback, 0 } }, { "reload"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TLSSocketPrototype__reloadCallback, 1 } }, { "remoteAddress"_s, static_cast(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, TLSSocketPrototype__remoteAddressGetterWrap, 0 } }, + { "setServername"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TLSSocketPrototype__setServernameCallback, 1 } }, { "shutdown"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TLSSocketPrototype__shutdownCallback, 1 } }, { "timeout"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TLSSocketPrototype__timeoutCallback, 1 } }, { "unref"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TLSSocketPrototype__unrefCallback, 0 } }, + { "wrapTLS"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TLSSocketPrototype__wrapTLSCallback, 1 } }, { "write"_s, static_cast(JSC::PropertyAttribute::Function | PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::NativeFunctionType, TLSSocketPrototype__writeCallback, 3 } } }; @@ -17561,6 +17686,18 @@ JSC_DEFINE_CUSTOM_GETTER(jsTLSSocketConstructor, (JSGlobalObject * lexicalGlobal return JSValue::encode(globalObject->JSTLSSocketConstructor()); } +JSC_DEFINE_CUSTOM_GETTER(TLSSocketPrototype__alpnProtocolGetterWrap, (JSGlobalObject * lexicalGlobalObject, EncodedJSValue thisValue, PropertyName attributeName)) +{ + auto& vm = lexicalGlobalObject->vm(); + Zig::GlobalObject* globalObject = reinterpret_cast(lexicalGlobalObject); + auto throwScope = DECLARE_THROW_SCOPE(vm); + JSTLSSocket* thisObject = jsCast(JSValue::decode(thisValue)); + JSC::EnsureStillAliveScope thisArg = JSC::EnsureStillAliveScope(thisObject); + JSC::EncodedJSValue result = TLSSocketPrototype__getALPNProtocol(thisObject->wrapped(), globalObject); + RETURN_IF_EXCEPTION(throwScope, {}); + RELEASE_AND_RETURN(throwScope, result); +} + JSC_DEFINE_CUSTOM_GETTER(TLSSocketPrototype__authorizedGetterWrap, (JSGlobalObject * lexicalGlobalObject, EncodedJSValue thisValue, PropertyName attributeName)) { auto& vm = lexicalGlobalObject->vm(); @@ -17720,6 +17857,33 @@ JSC_DEFINE_CUSTOM_GETTER(TLSSocketPrototype__localPortGetterWrap, (JSGlobalObjec RELEASE_AND_RETURN(throwScope, result); } +JSC_DEFINE_HOST_FUNCTION(TLSSocketPrototype__openCallback, (JSGlobalObject * lexicalGlobalObject, CallFrame* callFrame)) +{ + auto& vm = lexicalGlobalObject->vm(); + + JSTLSSocket* thisObject = jsDynamicCast(callFrame->thisValue()); + + if (UNLIKELY(!thisObject)) { + auto throwScope = DECLARE_THROW_SCOPE(vm); + return throwVMTypeError(lexicalGlobalObject, throwScope); + } + + JSC::EnsureStillAliveScope thisArg = JSC::EnsureStillAliveScope(thisObject); + +#ifdef BUN_DEBUG + /** View the file name of the JS file that called this function + * from a debugger */ + SourceOrigin sourceOrigin = callFrame->callerSourceOrigin(vm); + const char* fileName = sourceOrigin.string().utf8().data(); + static const char* lastFileName = nullptr; + if (lastFileName != fileName) { + lastFileName = fileName; + } +#endif + + return TLSSocketPrototype__open(thisObject->wrapped(), lexicalGlobalObject, callFrame); +} + JSC_DEFINE_CUSTOM_GETTER(TLSSocketPrototype__readyStateGetterWrap, (JSGlobalObject * lexicalGlobalObject, EncodedJSValue thisValue, PropertyName attributeName)) { auto& vm = lexicalGlobalObject->vm(); @@ -17817,6 +17981,33 @@ extern "C" EncodedJSValue TLSSocketPrototype__remoteAddressGetCachedValue(JSC::E return JSValue::encode(thisObject->m_remoteAddress.get()); } +JSC_DEFINE_HOST_FUNCTION(TLSSocketPrototype__setServernameCallback, (JSGlobalObject * lexicalGlobalObject, CallFrame* callFrame)) +{ + auto& vm = lexicalGlobalObject->vm(); + + JSTLSSocket* thisObject = jsDynamicCast(callFrame->thisValue()); + + if (UNLIKELY(!thisObject)) { + auto throwScope = DECLARE_THROW_SCOPE(vm); + return throwVMTypeError(lexicalGlobalObject, throwScope); + } + + JSC::EnsureStillAliveScope thisArg = JSC::EnsureStillAliveScope(thisObject); + +#ifdef BUN_DEBUG + /** View the file name of the JS file that called this function + * from a debugger */ + SourceOrigin sourceOrigin = callFrame->callerSourceOrigin(vm); + const char* fileName = sourceOrigin.string().utf8().data(); + static const char* lastFileName = nullptr; + if (lastFileName != fileName) { + lastFileName = fileName; + } +#endif + + return TLSSocketPrototype__setServername(thisObject->wrapped(), lexicalGlobalObject, callFrame); +} + JSC_DEFINE_HOST_FUNCTION(TLSSocketPrototype__shutdownCallback, (JSGlobalObject * lexicalGlobalObject, CallFrame* callFrame)) { auto& vm = lexicalGlobalObject->vm(); @@ -17898,6 +18089,33 @@ JSC_DEFINE_HOST_FUNCTION(TLSSocketPrototype__unrefCallback, (JSGlobalObject * le return TLSSocketPrototype__unref(thisObject->wrapped(), lexicalGlobalObject, callFrame); } +JSC_DEFINE_HOST_FUNCTION(TLSSocketPrototype__wrapTLSCallback, (JSGlobalObject * lexicalGlobalObject, CallFrame* callFrame)) +{ + auto& vm = lexicalGlobalObject->vm(); + + JSTLSSocket* thisObject = jsDynamicCast(callFrame->thisValue()); + + if (UNLIKELY(!thisObject)) { + auto throwScope = DECLARE_THROW_SCOPE(vm); + return throwVMTypeError(lexicalGlobalObject, throwScope); + } + + JSC::EnsureStillAliveScope thisArg = JSC::EnsureStillAliveScope(thisObject); + +#ifdef BUN_DEBUG + /** View the file name of the JS file that called this function + * from a debugger */ + SourceOrigin sourceOrigin = callFrame->callerSourceOrigin(vm); + const char* fileName = sourceOrigin.string().utf8().data(); + static const char* lastFileName = nullptr; + if (lastFileName != fileName) { + lastFileName = fileName; + } +#endif + + return TLSSocketPrototype__wrapTLS(thisObject->wrapped(), lexicalGlobalObject, callFrame); +} + JSC_DEFINE_HOST_FUNCTION(TLSSocketPrototype__writeCallback, (JSGlobalObject * lexicalGlobalObject, CallFrame* callFrame)) { auto& vm = lexicalGlobalObject->vm(); diff --git a/src/bun.js/bindings/generated_classes.zig b/src/bun.js/bindings/generated_classes.zig index 04a72d7ed..a220b6814 100644 --- a/src/bun.js/bindings/generated_classes.zig +++ b/src/bun.js/bindings/generated_classes.zig @@ -4426,6 +4426,9 @@ pub const JSTCPSocket = struct { @compileLog("TCPSocket.finalize is not a finalizer"); } + if (@TypeOf(TCPSocket.getALPNProtocol) != GetterType) + @compileLog("Expected TCPSocket.getALPNProtocol to be a getter"); + if (@TypeOf(TCPSocket.getAuthorized) != GetterType) @compileLog("Expected TCPSocket.getAuthorized to be a getter"); @@ -4446,6 +4449,8 @@ pub const JSTCPSocket = struct { if (@TypeOf(TCPSocket.getLocalPort) != GetterType) @compileLog("Expected TCPSocket.getLocalPort to be a getter"); + if (@TypeOf(TCPSocket.open) != CallbackType) + @compileLog("Expected TCPSocket.open to be a callback but received " ++ @typeName(@TypeOf(TCPSocket.open))); if (@TypeOf(TCPSocket.getReadyState) != GetterType) @compileLog("Expected TCPSocket.getReadyState to be a getter"); @@ -4456,18 +4461,23 @@ pub const JSTCPSocket = struct { if (@TypeOf(TCPSocket.getRemoteAddress) != GetterType) @compileLog("Expected TCPSocket.getRemoteAddress to be a getter"); + if (@TypeOf(TCPSocket.setServername) != CallbackType) + @compileLog("Expected TCPSocket.setServername to be a callback but received " ++ @typeName(@TypeOf(TCPSocket.setServername))); if (@TypeOf(TCPSocket.shutdown) != CallbackType) @compileLog("Expected TCPSocket.shutdown to be a callback but received " ++ @typeName(@TypeOf(TCPSocket.shutdown))); if (@TypeOf(TCPSocket.timeout) != CallbackType) @compileLog("Expected TCPSocket.timeout to be a callback but received " ++ @typeName(@TypeOf(TCPSocket.timeout))); if (@TypeOf(TCPSocket.unref) != CallbackType) @compileLog("Expected TCPSocket.unref to be a callback but received " ++ @typeName(@TypeOf(TCPSocket.unref))); + if (@TypeOf(TCPSocket.wrapTLS) != CallbackType) + @compileLog("Expected TCPSocket.wrapTLS to be a callback but received " ++ @typeName(@TypeOf(TCPSocket.wrapTLS))); if (@TypeOf(TCPSocket.write) != CallbackType) @compileLog("Expected TCPSocket.write to be a callback but received " ++ @typeName(@TypeOf(TCPSocket.write))); if (!JSC.is_bindgen) { @export(TCPSocket.end, .{ .name = "TCPSocketPrototype__end" }); @export(TCPSocket.finalize, .{ .name = "TCPSocketClass__finalize" }); @export(TCPSocket.flush, .{ .name = "TCPSocketPrototype__flush" }); + @export(TCPSocket.getALPNProtocol, .{ .name = "TCPSocketPrototype__getALPNProtocol" }); @export(TCPSocket.getAuthorizationError, .{ .name = "TCPSocketPrototype__getAuthorizationError" }); @export(TCPSocket.getAuthorized, .{ .name = "TCPSocketPrototype__getAuthorized" }); @export(TCPSocket.getData, .{ .name = "TCPSocketPrototype__getData" }); @@ -4476,12 +4486,15 @@ pub const JSTCPSocket = struct { @export(TCPSocket.getReadyState, .{ .name = "TCPSocketPrototype__getReadyState" }); @export(TCPSocket.getRemoteAddress, .{ .name = "TCPSocketPrototype__getRemoteAddress" }); @export(TCPSocket.hasPendingActivity, .{ .name = "TCPSocket__hasPendingActivity" }); + @export(TCPSocket.open, .{ .name = "TCPSocketPrototype__open" }); @export(TCPSocket.ref, .{ .name = "TCPSocketPrototype__ref" }); @export(TCPSocket.reload, .{ .name = "TCPSocketPrototype__reload" }); @export(TCPSocket.setData, .{ .name = "TCPSocketPrototype__setData" }); + @export(TCPSocket.setServername, .{ .name = "TCPSocketPrototype__setServername" }); @export(TCPSocket.shutdown, .{ .name = "TCPSocketPrototype__shutdown" }); @export(TCPSocket.timeout, .{ .name = "TCPSocketPrototype__timeout" }); @export(TCPSocket.unref, .{ .name = "TCPSocketPrototype__unref" }); + @export(TCPSocket.wrapTLS, .{ .name = "TCPSocketPrototype__wrapTLS" }); @export(TCPSocket.write, .{ .name = "TCPSocketPrototype__write" }); } } @@ -4581,6 +4594,9 @@ pub const JSTLSSocket = struct { @compileLog("TLSSocket.finalize is not a finalizer"); } + if (@TypeOf(TLSSocket.getALPNProtocol) != GetterType) + @compileLog("Expected TLSSocket.getALPNProtocol to be a getter"); + if (@TypeOf(TLSSocket.getAuthorized) != GetterType) @compileLog("Expected TLSSocket.getAuthorized to be a getter"); @@ -4601,6 +4617,8 @@ pub const JSTLSSocket = struct { if (@TypeOf(TLSSocket.getLocalPort) != GetterType) @compileLog("Expected TLSSocket.getLocalPort to be a getter"); + if (@TypeOf(TLSSocket.open) != CallbackType) + @compileLog("Expected TLSSocket.open to be a callback but received " ++ @typeName(@TypeOf(TLSSocket.open))); if (@TypeOf(TLSSocket.getReadyState) != GetterType) @compileLog("Expected TLSSocket.getReadyState to be a getter"); @@ -4611,18 +4629,23 @@ pub const JSTLSSocket = struct { if (@TypeOf(TLSSocket.getRemoteAddress) != GetterType) @compileLog("Expected TLSSocket.getRemoteAddress to be a getter"); + if (@TypeOf(TLSSocket.setServername) != CallbackType) + @compileLog("Expected TLSSocket.setServername to be a callback but received " ++ @typeName(@TypeOf(TLSSocket.setServername))); if (@TypeOf(TLSSocket.shutdown) != CallbackType) @compileLog("Expected TLSSocket.shutdown to be a callback but received " ++ @typeName(@TypeOf(TLSSocket.shutdown))); if (@TypeOf(TLSSocket.timeout) != CallbackType) @compileLog("Expected TLSSocket.timeout to be a callback but received " ++ @typeName(@TypeOf(TLSSocket.timeout))); if (@TypeOf(TLSSocket.unref) != CallbackType) @compileLog("Expected TLSSocket.unref to be a callback but received " ++ @typeName(@TypeOf(TLSSocket.unref))); + if (@TypeOf(TLSSocket.wrapTLS) != CallbackType) + @compileLog("Expected TLSSocket.wrapTLS to be a callback but received " ++ @typeName(@TypeOf(TLSSocket.wrapTLS))); if (@TypeOf(TLSSocket.write) != CallbackType) @compileLog("Expected TLSSocket.write to be a callback but received " ++ @typeName(@TypeOf(TLSSocket.write))); if (!JSC.is_bindgen) { @export(TLSSocket.end, .{ .name = "TLSSocketPrototype__end" }); @export(TLSSocket.finalize, .{ .name = "TLSSocketClass__finalize" }); @export(TLSSocket.flush, .{ .name = "TLSSocketPrototype__flush" }); + @export(TLSSocket.getALPNProtocol, .{ .name = "TLSSocketPrototype__getALPNProtocol" }); @export(TLSSocket.getAuthorizationError, .{ .name = "TLSSocketPrototype__getAuthorizationError" }); @export(TLSSocket.getAuthorized, .{ .name = "TLSSocketPrototype__getAuthorized" }); @export(TLSSocket.getData, .{ .name = "TLSSocketPrototype__getData" }); @@ -4631,12 +4654,15 @@ pub const JSTLSSocket = struct { @export(TLSSocket.getReadyState, .{ .name = "TLSSocketPrototype__getReadyState" }); @export(TLSSocket.getRemoteAddress, .{ .name = "TLSSocketPrototype__getRemoteAddress" }); @export(TLSSocket.hasPendingActivity, .{ .name = "TLSSocket__hasPendingActivity" }); + @export(TLSSocket.open, .{ .name = "TLSSocketPrototype__open" }); @export(TLSSocket.ref, .{ .name = "TLSSocketPrototype__ref" }); @export(TLSSocket.reload, .{ .name = "TLSSocketPrototype__reload" }); @export(TLSSocket.setData, .{ .name = "TLSSocketPrototype__setData" }); + @export(TLSSocket.setServername, .{ .name = "TLSSocketPrototype__setServername" }); @export(TLSSocket.shutdown, .{ .name = "TLSSocketPrototype__shutdown" }); @export(TLSSocket.timeout, .{ .name = "TLSSocketPrototype__timeout" }); @export(TLSSocket.unref, .{ .name = "TLSSocketPrototype__unref" }); + @export(TLSSocket.wrapTLS, .{ .name = "TLSSocketPrototype__wrapTLS" }); @export(TLSSocket.write, .{ .name = "TLSSocketPrototype__write" }); } } diff --git a/src/bun.js/bindings/webcore/JSEventEmitter.cpp b/src/bun.js/bindings/webcore/JSEventEmitter.cpp index 1957b404b..231ae0db4 100644 --- a/src/bun.js/bindings/webcore/JSEventEmitter.cpp +++ b/src/bun.js/bindings/webcore/JSEventEmitter.cpp @@ -149,7 +149,7 @@ static const HashTableValue JSEventEmitterPrototypeTableValues[] = { { "on"_s, static_cast(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, jsEventEmitterPrototypeFunction_addListener, 2 } }, { "once"_s, static_cast(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, jsEventEmitterPrototypeFunction_addOnceListener, 2 } }, { "prependListener"_s, static_cast(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, jsEventEmitterPrototypeFunction_prependListener, 2 } }, - { "prependOnce"_s, static_cast(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, jsEventEmitterPrototypeFunction_prependOnceListener, 2 } }, + { "prependOnceListener"_s, static_cast(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, jsEventEmitterPrototypeFunction_prependOnceListener, 2 } }, { "removeListener"_s, static_cast(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, jsEventEmitterPrototypeFunction_removeListener, 2 } }, { "off"_s, static_cast(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, jsEventEmitterPrototypeFunction_removeListener, 2 } }, { "removeAllListeners"_s, static_cast(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, jsEventEmitterPrototypeFunction_removeAllListeners, 1 } }, diff --git a/src/deps/uws b/src/deps/uws index d82c4a95d..875948226 160000 --- a/src/deps/uws +++ b/src/deps/uws @@ -1 +1 @@ -Subproject commit d82c4a95de3af01614ecb12bfff821611b4cc6b7 +Subproject commit 875948226eede72861a5170212ff6b43c4b7d7f9 diff --git a/src/deps/uws.zig b/src/deps/uws.zig index 8ebe04ac0..5dbe4f5d8 100644 --- a/src/deps/uws.zig +++ b/src/deps/uws.zig @@ -40,6 +40,129 @@ pub fn NewSocketHandler(comptime ssl: bool) type { return us_socket_timeout(comptime ssl_int, this.socket, seconds); } + pub fn open(this: ThisSocket, is_client: bool) void { + _ = us_socket_open(comptime ssl_int, this.socket, @intFromBool(is_client), null, 0); + } + + // Note: this assumes that the socket is non-TLS and will be adopted and wrapped with a new TLS context + // context ext will not be copied to the new context, new context will contain us_wrapped_socket_context_t on ext + pub fn wrapTLS( + this: ThisSocket, + options: us_bun_socket_context_options_t, + socket_ext_size: i32, + comptime deref: bool, + comptime ContextType: type, + comptime Fields: anytype, + ) ?NewSocketHandler(true) { + const Type = comptime if (@TypeOf(Fields) != type) @TypeOf(Fields) else Fields; + const TLSSocket = NewSocketHandler(true); + const SocketHandler = struct { + const alignment = if (ContextType == anyopaque) + @sizeOf(usize) + else + std.meta.alignment(ContextType); + const deref_ = deref; + const ValueType = if (deref) ContextType else *ContextType; + fn getValue(socket: *Socket) ValueType { + if (comptime ContextType == anyopaque) { + return us_socket_ext(1, socket).?; + } + + if (comptime deref_) { + return (TLSSocket{ .socket = socket }).ext(ContextType).?.*; + } + + return (TLSSocket{ .socket = socket }).ext(ContextType).?; + } + + pub fn on_open(socket: *Socket, is_client: i32, _: [*c]u8, _: i32) callconv(.C) ?*Socket { + if (comptime @hasDecl(Fields, "onCreate")) { + if (is_client == 0) { + Fields.onCreate( + TLSSocket{ .socket = socket }, + ); + } + } + Fields.onOpen( + getValue(socket), + TLSSocket{ .socket = socket }, + ); + return socket; + } + pub fn on_close(socket: *Socket, code: i32, reason: ?*anyopaque) callconv(.C) ?*Socket { + Fields.onClose( + getValue(socket), + TLSSocket{ .socket = socket }, + code, + reason, + ); + return socket; + } + pub fn on_data(socket: *Socket, buf: ?[*]u8, len: i32) callconv(.C) ?*Socket { + Fields.onData( + getValue(socket), + TLSSocket{ .socket = socket }, + buf.?[0..@intCast(usize, len)], + ); + return socket; + } + pub fn on_writable(socket: *Socket) callconv(.C) ?*Socket { + Fields.onWritable( + getValue(socket), + TLSSocket{ .socket = socket }, + ); + return socket; + } + pub fn on_timeout(socket: *Socket) callconv(.C) ?*Socket { + Fields.onTimeout( + getValue(socket), + TLSSocket{ .socket = socket }, + ); + return socket; + } + pub fn on_connect_error(socket: *Socket, code: i32) callconv(.C) ?*Socket { + Fields.onConnectError( + getValue(socket), + TLSSocket{ .socket = socket }, + code, + ); + return socket; + } + pub fn on_end(socket: *Socket) callconv(.C) ?*Socket { + Fields.onEnd( + getValue(socket), + TLSSocket{ .socket = socket }, + ); + return socket; + } + pub fn on_handshake(socket: *Socket, success: i32, verify_error: us_bun_verify_error_t, _: ?*anyopaque) callconv(.C) void { + Fields.onHandshake(getValue(socket), TLSSocket{ .socket = socket }, success, verify_error); + } + }; + + var events: us_socket_events_t = .{}; + + if (comptime @hasDecl(Type, "onOpen") and @typeInfo(@TypeOf(Type.onOpen)) != .Null) + events.on_open = SocketHandler.on_open; + if (comptime @hasDecl(Type, "onClose") and @typeInfo(@TypeOf(Type.onClose)) != .Null) + events.on_close = SocketHandler.on_close; + if (comptime @hasDecl(Type, "onData") and @typeInfo(@TypeOf(Type.onData)) != .Null) + events.on_data = SocketHandler.on_data; + if (comptime @hasDecl(Type, "onWritable") and @typeInfo(@TypeOf(Type.onWritable)) != .Null) + events.on_writable = SocketHandler.on_writable; + if (comptime @hasDecl(Type, "onTimeout") and @typeInfo(@TypeOf(Type.onTimeout)) != .Null) + events.on_timeout = SocketHandler.on_timeout; + if (comptime @hasDecl(Type, "onConnectError") and @typeInfo(@TypeOf(Type.onConnectError)) != .Null) + events.on_connect_error = SocketHandler.on_connect_error; + if (comptime @hasDecl(Type, "onEnd") and @typeInfo(@TypeOf(Type.onEnd)) != .Null) + events.on_end = SocketHandler.on_end; + if (comptime @hasDecl(Type, "onHandshake") and @typeInfo(@TypeOf(Type.onHandshake)) != .Null) + events.on_handshake = SocketHandler.on_handshake; + + const socket = us_socket_wrap_with_tls(ssl_int, this.socket, options, events, socket_ext_size) orelse return null; + return NewSocketHandler(true).from(socket); + } + pub fn getNativeHandle(this: ThisSocket) *NativeSocketHandleType(ssl) { return @ptrCast(*NativeSocketHandleType(ssl), us_socket_get_native_handle(comptime ssl_int, this.socket).?); } @@ -95,6 +218,17 @@ pub fn NewSocketHandler(comptime ssl: bool) type { @as(i32, @intFromBool(msg_more)), ); } + + pub fn rawWrite(this: ThisSocket, data: []const u8, msg_more: bool) i32 { + return us_socket_raw_write( + comptime ssl_int, + this.socket, + data.ptr, + // truncate to 31 bits since sign bit exists + @intCast(i32, @truncate(u31, data.len)), + @as(i32, @intFromBool(msg_more)), + ); + } pub fn shutdown(this: ThisSocket) void { debug("us_socket_shutdown({d})", .{@intFromPtr(this.socket)}); return us_socket_shutdown( @@ -241,13 +375,126 @@ pub fn NewSocketHandler(comptime ssl: bool) type { return socket_; } + pub fn unsafeConfigure( + ctx: *SocketContext, + comptime ssl_type: bool, + comptime deref: bool, + comptime ContextType: type, + comptime Fields: anytype, + ) void { + const SocketHandlerType = NewSocketHandler(ssl_type); + const ssl_type_int: i32 = @intFromBool(ssl_type); + const Type = comptime if (@TypeOf(Fields) != type) @TypeOf(Fields) else Fields; + + const SocketHandler = struct { + const alignment = if (ContextType == anyopaque) + @sizeOf(usize) + else + std.meta.alignment(ContextType); + const deref_ = deref; + const ValueType = if (deref) ContextType else *ContextType; + fn getValue(socket: *Socket) ValueType { + if (comptime ContextType == anyopaque) { + return us_socket_ext(ssl_type_int, socket).?; + } + + if (comptime deref_) { + return (SocketHandlerType{ .socket = socket }).ext(ContextType).?.*; + } + + return (SocketHandlerType{ .socket = socket }).ext(ContextType).?; + } + + pub fn on_open(socket: *Socket, is_client: i32, _: [*c]u8, _: i32) callconv(.C) ?*Socket { + if (comptime @hasDecl(Fields, "onCreate")) { + if (is_client == 0) { + Fields.onCreate( + SocketHandlerType{ .socket = socket }, + ); + } + } + Fields.onOpen( + getValue(socket), + SocketHandlerType{ .socket = socket }, + ); + return socket; + } + pub fn on_close(socket: *Socket, code: i32, reason: ?*anyopaque) callconv(.C) ?*Socket { + Fields.onClose( + getValue(socket), + SocketHandlerType{ .socket = socket }, + code, + reason, + ); + return socket; + } + pub fn on_data(socket: *Socket, buf: ?[*]u8, len: i32) callconv(.C) ?*Socket { + Fields.onData( + getValue(socket), + SocketHandlerType{ .socket = socket }, + buf.?[0..@intCast(usize, len)], + ); + return socket; + } + pub fn on_writable(socket: *Socket) callconv(.C) ?*Socket { + Fields.onWritable( + getValue(socket), + SocketHandlerType{ .socket = socket }, + ); + return socket; + } + pub fn on_timeout(socket: *Socket) callconv(.C) ?*Socket { + Fields.onTimeout( + getValue(socket), + SocketHandlerType{ .socket = socket }, + ); + return socket; + } + pub fn on_connect_error(socket: *Socket, code: i32) callconv(.C) ?*Socket { + Fields.onConnectError( + getValue(socket), + SocketHandlerType{ .socket = socket }, + code, + ); + return socket; + } + pub fn on_end(socket: *Socket) callconv(.C) ?*Socket { + Fields.onEnd( + getValue(socket), + SocketHandlerType{ .socket = socket }, + ); + return socket; + } + pub fn on_handshake(socket: *Socket, success: i32, verify_error: us_bun_verify_error_t, _: ?*anyopaque) callconv(.C) void { + Fields.onHandshake(getValue(socket), SocketHandlerType{ .socket = socket }, success, verify_error); + } + }; + + if (comptime @hasDecl(Type, "onOpen") and @typeInfo(@TypeOf(Type.onOpen)) != .Null) + us_socket_context_on_open(ssl_int, ctx, SocketHandler.on_open); + if (comptime @hasDecl(Type, "onClose") and @typeInfo(@TypeOf(Type.onClose)) != .Null) + us_socket_context_on_close(ssl_int, ctx, SocketHandler.on_close); + if (comptime @hasDecl(Type, "onData") and @typeInfo(@TypeOf(Type.onData)) != .Null) + us_socket_context_on_data(ssl_int, ctx, SocketHandler.on_data); + if (comptime @hasDecl(Type, "onWritable") and @typeInfo(@TypeOf(Type.onWritable)) != .Null) + us_socket_context_on_writable(ssl_int, ctx, SocketHandler.on_writable); + if (comptime @hasDecl(Type, "onTimeout") and @typeInfo(@TypeOf(Type.onTimeout)) != .Null) + us_socket_context_on_timeout(ssl_int, ctx, SocketHandler.on_timeout); + if (comptime @hasDecl(Type, "onConnectError") and @typeInfo(@TypeOf(Type.onConnectError)) != .Null) + us_socket_context_on_connect_error(ssl_int, ctx, SocketHandler.on_connect_error); + if (comptime @hasDecl(Type, "onEnd") and @typeInfo(@TypeOf(Type.onEnd)) != .Null) + us_socket_context_on_end(ssl_int, ctx, SocketHandler.on_end); + if (comptime @hasDecl(Type, "onHandshake") and @typeInfo(@TypeOf(Type.onHandshake)) != .Null) + us_socket_context_on_handshake(ssl_int, ctx, SocketHandler.on_handshake, null); + } + pub fn configure( ctx: *SocketContext, comptime deref: bool, comptime ContextType: type, comptime Fields: anytype, ) void { - const @"type" = comptime if (@TypeOf(Fields) != type) @TypeOf(Fields) else Fields; + const Type = comptime if (@TypeOf(Fields) != type) @TypeOf(Fields) else Fields; const SocketHandler = struct { const alignment = if (ContextType == anyopaque) @@ -333,21 +580,21 @@ pub fn NewSocketHandler(comptime ssl: bool) type { } }; - if (comptime @hasDecl(@"type", "onOpen") and @typeInfo(@TypeOf(@"type".onOpen)) != .Null) + if (comptime @hasDecl(Type, "onOpen") and @typeInfo(@TypeOf(Type.onOpen)) != .Null) us_socket_context_on_open(ssl_int, ctx, SocketHandler.on_open); - if (comptime @hasDecl(@"type", "onClose") and @typeInfo(@TypeOf(@"type".onClose)) != .Null) + if (comptime @hasDecl(Type, "onClose") and @typeInfo(@TypeOf(Type.onClose)) != .Null) us_socket_context_on_close(ssl_int, ctx, SocketHandler.on_close); - if (comptime @hasDecl(@"type", "onData") and @typeInfo(@TypeOf(@"type".onData)) != .Null) + if (comptime @hasDecl(Type, "onData") and @typeInfo(@TypeOf(Type.onData)) != .Null) us_socket_context_on_data(ssl_int, ctx, SocketHandler.on_data); - if (comptime @hasDecl(@"type", "onWritable") and @typeInfo(@TypeOf(@"type".onWritable)) != .Null) + if (comptime @hasDecl(Type, "onWritable") and @typeInfo(@TypeOf(Type.onWritable)) != .Null) us_socket_context_on_writable(ssl_int, ctx, SocketHandler.on_writable); - if (comptime @hasDecl(@"type", "onTimeout") and @typeInfo(@TypeOf(@"type".onTimeout)) != .Null) + if (comptime @hasDecl(Type, "onTimeout") and @typeInfo(@TypeOf(Type.onTimeout)) != .Null) us_socket_context_on_timeout(ssl_int, ctx, SocketHandler.on_timeout); - if (comptime @hasDecl(@"type", "onConnectError") and @typeInfo(@TypeOf(@"type".onConnectError)) != .Null) + if (comptime @hasDecl(Type, "onConnectError") and @typeInfo(@TypeOf(Type.onConnectError)) != .Null) us_socket_context_on_connect_error(ssl_int, ctx, SocketHandler.on_connect_error); - if (comptime @hasDecl(@"type", "onEnd") and @typeInfo(@TypeOf(@"type".onEnd)) != .Null) + if (comptime @hasDecl(Type, "onEnd") and @typeInfo(@TypeOf(Type.onEnd)) != .Null) us_socket_context_on_end(ssl_int, ctx, SocketHandler.on_end); - if (comptime @hasDecl(@"type", "onHandshake") and @typeInfo(@TypeOf(@"type".onHandshake)) != .Null) + if (comptime @hasDecl(Type, "onHandshake") and @typeInfo(@TypeOf(Type.onHandshake)) != .Null) us_socket_context_on_handshake(ssl_int, ctx, SocketHandler.on_handshake, null); } @@ -659,6 +906,20 @@ pub const us_bun_verify_error_t = extern struct { reason: [*c]const u8 = null, }; +pub const us_socket_events_t = extern struct { + on_open: ?*const fn (*Socket, i32, [*c]u8, i32) callconv(.C) ?*Socket = null, + on_data: ?*const fn (*Socket, [*c]u8, i32) callconv(.C) ?*Socket = null, + on_writable: ?*const fn (*Socket) callconv(.C) ?*Socket = null, + on_close: ?*const fn (*Socket, i32, ?*anyopaque) callconv(.C) ?*Socket = null, + + on_timeout: ?*const fn (*Socket) callconv(.C) ?*Socket = null, + on_long_timeout: ?*const fn (*Socket) callconv(.C) ?*Socket = null, + on_end: ?*const fn (*Socket) callconv(.C) ?*Socket = null, + on_connect_error: ?*const fn (*Socket, i32) callconv(.C) ?*Socket = null, + on_handshake: ?*const fn (*Socket, i32, us_bun_verify_error_t, ?*anyopaque) callconv(.C) void = null, +}; + +pub extern fn us_socket_wrap_with_tls(ssl: i32, s: *Socket, options: us_bun_socket_context_options_t, events: us_socket_events_t, socket_ext_size: i32) ?*Socket; extern fn us_socket_verify_error(ssl: i32, context: *Socket) us_bun_verify_error_t; extern fn SocketContextimestamp(ssl: i32, context: ?*SocketContext) c_ushort; pub extern fn us_socket_context_add_server_name(ssl: i32, context: ?*SocketContext, hostname_pattern: [*c]const u8, options: us_socket_context_options_t, ?*anyopaque) void; @@ -777,11 +1038,16 @@ extern fn us_socket_ext(ssl: i32, s: ?*Socket) ?*anyopaque; extern fn us_socket_context(ssl: i32, s: ?*Socket) ?*SocketContext; extern fn us_socket_flush(ssl: i32, s: ?*Socket) void; extern fn us_socket_write(ssl: i32, s: ?*Socket, data: [*c]const u8, length: i32, msg_more: i32) i32; +extern fn us_socket_raw_write(ssl: i32, s: ?*Socket, data: [*c]const u8, length: i32, msg_more: i32) i32; extern fn us_socket_shutdown(ssl: i32, s: ?*Socket) void; extern fn us_socket_shutdown_read(ssl: i32, s: ?*Socket) void; extern fn us_socket_is_shut_down(ssl: i32, s: ?*Socket) i32; extern fn us_socket_is_closed(ssl: i32, s: ?*Socket) i32; extern fn us_socket_close(ssl: i32, s: ?*Socket, code: i32, reason: ?*anyopaque) ?*Socket; +// if a TLS socket calls this, it will start SSL instance and call open event will also do TLS handshake if required +// will have no effect if the socket is closed or is not TLS +extern fn us_socket_open(ssl: i32, s: ?*Socket, is_client: i32, ip: [*c]const u8, ip_length: i32) ?*Socket; + extern fn us_socket_local_port(ssl: i32, s: ?*Socket) i32; extern fn us_socket_remote_address(ssl: i32, s: ?*Socket, buf: [*c]u8, length: [*c]i32) void; pub const uws_app_s = opaque {}; diff --git a/src/js/node/net.js b/src/js/node/net.js index 430a0dfa2..1b7742dd1 100644 --- a/src/js/node/net.js +++ b/src/js/node/net.js @@ -64,6 +64,7 @@ const bunTlsSymbol = Symbol.for("::buntls::"); const bunSocketServerHandlers = Symbol.for("::bunsocket_serverhandlers::"); const bunSocketServerConnections = Symbol.for("::bunnetserverconnections::"); const bunSocketServerOptions = Symbol.for("::bunnetserveroptions::"); +const bunSocketInternal = Symbol.for("::bunnetsocketinternal::"); var SocketClass; const Socket = (function (InternalSocket) { @@ -117,7 +118,7 @@ const Socket = (function (InternalSocket) { const self = socket.data; socket.timeout(self.timeout); socket.ref(); - self.#socket = socket; + self[bunSocketInternal] = socket; self.connecting = false; self.emit("connect", self); Socket.#Drain(socket); @@ -164,7 +165,7 @@ const Socket = (function (InternalSocket) { if (self.#closed) return; self.#closed = true; //socket cannot be used after close - self.#socket = null; + self[bunSocketInternal] = null; const queue = self.#readQueue; if (queue.isEmpty()) { if (self.push(null)) return; @@ -289,23 +290,33 @@ const Socket = (function (InternalSocket) { localAddress = "127.0.0.1"; #readQueue = createFIFO(); remotePort; - #socket; + [bunSocketInternal] = null; timeout = 0; #writeCallback; #writeChunk; #pendingRead; isServer = false; + _handle; + _parent; + _parentWrap; + #socket; constructor(options) { - const { signal, write, read, allowHalfOpen = false, ...opts } = options || {}; + const { socket, signal, write, read, allowHalfOpen = false, ...opts } = options || {}; super({ ...opts, allowHalfOpen, readable: true, writable: true, }); + this._handle = this; + this._parent = this; + this._parentWrap = this; this.#pendingRead = undefined; + if (socket instanceof Socket) { + this.#socket = socket; + } signal?.once("abort", () => this.destroy()); this.once("connect", () => this.emit("ready")); } @@ -327,7 +338,7 @@ const Socket = (function (InternalSocket) { socket.data = this; socket.timeout(this.timeout); socket.ref(); - this.#socket = socket; + this[bunSocketInternal] = socket; this.connecting = false; this.emit("connect", this); Socket.#Drain(socket); @@ -335,6 +346,7 @@ const Socket = (function (InternalSocket) { connect(port, host, connectListener) { var path; + var connection = this.#socket; if (typeof port === "string") { path = port; port = undefined; @@ -357,6 +369,7 @@ const Socket = (function (InternalSocket) { port, host, path, + socket, // TODOs localAddress, localPort, @@ -371,7 +384,11 @@ const Socket = (function (InternalSocket) { pauseOnConnect, servername, } = port; + this.servername = servername; + if (socket) { + connection = socket; + } } if (!pauseOnConnect) { @@ -399,41 +416,117 @@ const Socket = (function (InternalSocket) { } else { tls.rejectUnauthorized = rejectUnauthorized; tls.requestCert = true; + if (!connection && tls.socket) { + connection = tls.socket; + } + } + } + if (connection) { + if ( + typeof connection !== "object" || + !(connection instanceof Socket) || + typeof connection[bunTlsSymbol] === "function" + ) { + throw new TypeError("socket must be an instance of net.Socket"); } } - this.authorized = false; this.secureConnecting = true; this._secureEstablished = false; this._securePending = true; if (connectListener) this.on("secureConnect", connectListener); } else if (connectListener) this.on("connect", connectListener); - bunConnect( - path - ? { + // start using existing connection + + if (connection) { + const socket = connection[bunSocketInternal]; + + if (socket) { + const result = socket.wrapTLS({ + data: this, + tls, + socket: Socket.#Handlers, + }); + if (result) { + const [raw, tls] = result; + // replace socket + connection[bunSocketInternal] = raw; + raw.timeout(raw.timeout); + raw.connecting = false; + // set new socket + this[bunSocketInternal] = tls; + tls.timeout(tls.timeout); + tls.connecting = true; + this[bunSocketInternal] = socket; + // start tls + tls.open(); + } else { + this[bunSocketInternal] = null; + throw new Error("Invalid socket"); + } + } else { + // wait to be connected + connection.once("connect", () => { + const socket = connection[bunSocketInternal]; + if (!socket) return; + + const result = socket.wrapTLS({ data: this, - unix: path, - socket: Socket.#Handlers, tls, - } - : { - data: this, - hostname: host || "localhost", - port: port, socket: Socket.#Handlers, - tls, - }, - ); + }); + + if (result) { + const [raw, tls] = result; + // replace socket + connection[bunSocketInternal] = raw; + raw.timeout(raw.timeout); + raw.connecting = false; + // set new socket + this[bunSocketInternal] = tls; + tls.timeout(tls.timeout); + tls.connecting = true; + this[bunSocketInternal] = socket; + // start tls + tls.open(); + } else { + this[bunSocketInternal] = null; + throw new Error("Invalid socket"); + } + }); + } + } else if (path) { + // start using unix socket + bunConnect({ + data: this, + unix: path, + socket: Socket.#Handlers, + tls, + }).catch(error => { + this.emit("error", error); + }); + } else { + // default start + bunConnect({ + data: this, + hostname: host || "localhost", + port: port, + socket: Socket.#Handlers, + tls, + }).catch(error => { + this.emit("error", error); + }); + } return this; } _destroy(err, callback) { - this.#socket?.end(); + this[bunSocketInternal]?.end(); callback(err); } _final(callback) { - this.#socket?.end(); + this[bunSocketInternal]?.end(); callback(); } @@ -446,7 +539,7 @@ const Socket = (function (InternalSocket) { } get localPort() { - return this.#socket?.localPort; + return this[bunSocketInternal]?.localPort; } get pending() { @@ -472,11 +565,11 @@ const Socket = (function (InternalSocket) { } ref() { - this.#socket?.ref(); + this[bunSocketInternal]?.ref(); } get remoteAddress() { - return this.#socket?.remoteAddress; + return this[bunSocketInternal]?.remoteAddress; } get remoteFamily() { @@ -484,7 +577,7 @@ const Socket = (function (InternalSocket) { } resetAndDestroy() { - this.#socket?.end(); + this[bunSocketInternal]?.end(); } setKeepAlive(enable = false, initialDelay = 0) { @@ -498,19 +591,19 @@ const Socket = (function (InternalSocket) { } setTimeout(timeout, callback) { - this.#socket?.timeout(timeout); + this[bunSocketInternal]?.timeout(timeout); this.timeout = timeout; if (callback) this.once("timeout", callback); return this; } unref() { - this.#socket?.unref(); + this[bunSocketInternal]?.unref(); } _write(chunk, encoding, callback) { - if (typeof chunk == "string" && encoding !== "utf8") chunk = Buffer.from(chunk, encoding); - var written = this.#socket?.write(chunk); + if (typeof chunk == "string" && encoding !== "ascii") chunk = Buffer.from(chunk, encoding); + var written = this[bunSocketInternal]?.write(chunk); if (written == chunk.length) { callback(); } else if (this.#writeCallback) { diff --git a/src/js/node/tls.js b/src/js/node/tls.js index 356c25cbd..310a36620 100644 --- a/src/js/node/tls.js +++ b/src/js/node/tls.js @@ -1,9 +1,30 @@ // Hardcoded module "node:tls" -import { isTypedArray } from "util/types"; +import { isArrayBufferView, isTypedArray } from "util/types"; import net, { Server as NetServer } from "node:net"; const InternalTCPSocket = net[Symbol.for("::bunternal::")]; - +const bunSocketInternal = Symbol.for("::bunnetsocketinternal::"); + +const { RegExp, Array, String } = globalThis[Symbol.for("Bun.lazy")]("primordials"); +const SymbolReplace = Symbol.replace; +const RegExpPrototypeSymbolReplace = RegExp.prototype[SymbolReplace]; +const RegExpPrototypeExec = RegExp.prototype.exec; + +const StringPrototypeStartsWith = String.prototype.startsWith; +const StringPrototypeSlice = String.prototype.slice; +const StringPrototypeIncludes = String.prototype.includes; +const StringPrototypeSplit = String.prototype.split; +const StringPrototypeIndexOf = String.prototype.indexOf; +const StringPrototypeSubstring = String.prototype.substring; +const StringPrototypeEndsWith = String.prototype.endsWith; + +const ArrayPrototypeIncludes = Array.prototype.includes; +const ArrayPrototypeJoin = Array.prototype.join; +const ArrayPrototypeForEach = Array.prototype.forEach; +const ArrayPrototypePush = Array.prototype.push; +const ArrayPrototypeSome = Array.prototype.some; +const ArrayPrototypeReduce = Array.prototype.reduce; function parseCertString() { + // Removed since JAN 2022 Node v18.0.0+ https://github.com/nodejs/node/pull/41479 throwNotImplemented("Not implemented"); } @@ -18,6 +39,164 @@ function isValidTLSArray(obj) { } } +function unfqdn(host) { + return RegExpPrototypeSymbolReplace(/[.]$/, host, ""); +} + +function splitHost(host) { + return StringPrototypeSplit.call(RegExpPrototypeSymbolReplace(/[A-Z]/g, unfqdn(host), toLowerCase), "."); +} + +function check(hostParts, pattern, wildcards) { + // Empty strings, null, undefined, etc. never match. + if (!pattern) return false; + + const patternParts = splitHost(pattern); + + if (hostParts.length !== patternParts.length) return false; + + // Pattern has empty components, e.g. "bad..example.com". + if (ArrayPrototypeIncludes.call(patternParts, "")) return false; + + // RFC 6125 allows IDNA U-labels (Unicode) in names but we have no + // good way to detect their encoding or normalize them so we simply + // reject them. Control characters and blanks are rejected as well + // because nothing good can come from accepting them. + const isBad = s => RegExpPrototypeExec.call(/[^\u0021-\u007F]/u, s) !== null; + if (ArrayPrototypeSome.call(patternParts, isBad)) return false; + + // Check host parts from right to left first. + for (let i = hostParts.length - 1; i > 0; i -= 1) { + if (hostParts[i] !== patternParts[i]) return false; + } + + const hostSubdomain = hostParts[0]; + const patternSubdomain = patternParts[0]; + const patternSubdomainParts = StringPrototypeSplit.call(patternSubdomain, "*"); + + // Short-circuit when the subdomain does not contain a wildcard. + // RFC 6125 does not allow wildcard substitution for components + // containing IDNA A-labels (Punycode) so match those verbatim. + if (patternSubdomainParts.length === 1 || StringPrototypeIncludes.call(patternSubdomain, "xn--")) + return hostSubdomain === patternSubdomain; + + if (!wildcards) return false; + + // More than one wildcard is always wrong. + if (patternSubdomainParts.length > 2) return false; + + // *.tld wildcards are not allowed. + if (patternParts.length <= 2) return false; + + const { 0: prefix, 1: suffix } = patternSubdomainParts; + + if (prefix.length + suffix.length > hostSubdomain.length) return false; + + if (!StringPrototypeStartsWith.call(hostSubdomain, prefix)) return false; + + if (!StringPrototypeEndsWith.call(hostSubdomain, suffix)) return false; + + return true; +} + +// This pattern is used to determine the length of escaped sequences within +// the subject alt names string. It allows any valid JSON string literal. +// This MUST match the JSON specification (ECMA-404 / RFC8259) exactly. +const jsonStringPattern = + // eslint-disable-next-line no-control-regex + /^"(?:[^"\\\u0000-\u001f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*"/; + +function splitEscapedAltNames(altNames) { + const result = []; + let currentToken = ""; + let offset = 0; + while (offset !== altNames.length) { + const nextSep = StringPrototypeIndexOf.call(altNames, ", ", offset); + const nextQuote = StringPrototypeIndexOf.call(altNames, '"', offset); + if (nextQuote !== -1 && (nextSep === -1 || nextQuote < nextSep)) { + // There is a quote character and there is no separator before the quote. + currentToken += StringPrototypeSubstring.call(altNames, offset, nextQuote); + const match = RegExpPrototypeExec.call(jsonStringPattern, StringPrototypeSubstring.call(altNames, nextQuote)); + if (!match) { + let error = new SyntaxError("ERR_TLS_CERT_ALTNAME_FORMAT: Invalid subject alternative name string"); + error.name = ERR_TLS_CERT_ALTNAME_FORMAT; + throw error; + } + currentToken += JSON.parse(match[0]); + offset = nextQuote + match[0].length; + } else if (nextSep !== -1) { + // There is a separator and no quote before it. + currentToken += StringPrototypeSubstring.call(altNames, offset, nextSep); + ArrayPrototypePush.call(result, currentToken); + currentToken = ""; + offset = nextSep + 2; + } else { + currentToken += StringPrototypeSubstring.call(altNames, offset); + offset = altNames.length; + } + } + ArrayPrototypePush.call(result, currentToken); + return result; +} +function checkServerIdentity(hostname, cert) { + const subject = cert.subject; + const altNames = cert.subjectaltname; + const dnsNames = []; + const ips = []; + + hostname = "" + hostname; + + if (altNames) { + const splitAltNames = StringPrototypeIncludes.call(altNames, '"') + ? splitEscapedAltNames(altNames) + : StringPrototypeSplit.call(altNames, ", "); + ArrayPrototypeForEach.call(splitAltNames, name => { + if (StringPrototypeStartsWith.call(name, "DNS:")) { + ArrayPrototypePush.call(dnsNames, StringPrototypeSlice.call(name, 4)); + } else if (StringPrototypeStartsWith.call(name, "IP Address:")) { + ArrayPrototypePush.call(ips, canonicalizeIP(StringPrototypeSlice.call(name, 11))); + } + }); + } + + let valid = false; + let reason = "Unknown reason"; + + hostname = unfqdn(hostname); // Remove trailing dot for error messages. + + if (net.isIP(hostname)) { + valid = ArrayPrototypeIncludes.call(ips, canonicalizeIP(hostname)); + if (!valid) reason = `IP: ${hostname} is not in the cert's list: ` + ArrayPrototypeJoin.call(ips, ", "); + } else if (dnsNames.length > 0 || subject?.CN) { + const hostParts = splitHost(hostname); + const wildcard = pattern => check(hostParts, pattern, true); + + if (dnsNames.length > 0) { + valid = ArrayPrototypeSome.call(dnsNames, wildcard); + if (!valid) reason = `Host: ${hostname}. is not in the cert's altnames: ${altNames}`; + } else { + // Match against Common Name only if no supported identifiers exist. + const cn = subject.CN; + + if (ArrayIsArray(cn)) valid = ArrayPrototypeSome.call(cn, wildcard); + else if (cn) valid = wildcard(cn); + + if (!valid) reason = `Host: ${hostname}. is not cert's CN: ${cn}`; + } + } else { + reason = "Cert does not contain a DNS name"; + } + + if (!valid) { + let error = new Error(`ERR_TLS_CERT_ALTNAME_INVALID: Hostname/IP does not match certificate's altnames: ${reason}`); + error.name = "ERR_TLS_CERT_ALTNAME_INVALID"; + error.reason = reason; + error.host = host; + error.cert = cert; + return error; + } +} + var InternalSecureContext = class SecureContext { context; @@ -83,6 +262,36 @@ function createSecureContext(options) { return new SecureContext(options); } +// Translate some fields from the handle's C-friendly format into more idiomatic +// javascript object representations before passing them back to the user. Can +// be used on any cert object, but changing the name would be semver-major. +function translatePeerCertificate(c) { + if (!c) return null; + + if (c.issuerCertificate != null && c.issuerCertificate !== c) { + c.issuerCertificate = translatePeerCertificate(c.issuerCertificate); + } + if (c.infoAccess != null) { + const info = c.infoAccess; + c.infoAccess = { __proto__: null }; + + // XXX: More key validation? + RegExpPrototypeSymbolReplace(/([^\n:]*):([^\n]*)(?:\n|$)/g, info, (all, key, val) => { + if (val.charCodeAt(0) === 0x22) { + // The translatePeerCertificate function is only + // used on internally created legacy certificate + // objects, and any value that contains a quote + // will always be a valid JSON string literal, + // so this should never throw. + val = JSONParse(val); + } + if (key in c.infoAccess) ArrayPrototypePush.call(c.infoAccess[key], val); + else c.infoAccess[key] = [val]; + }); + } + return c; +} + const buntls = Symbol.for("::buntls::"); var SocketClass; @@ -107,8 +316,22 @@ const TLSSocket = (function (InternalTLSSocket) { })( class TLSSocket extends InternalTCPSocket { #secureContext; - constructor(options) { - super(options); + ALPNProtocols; + #socket; + + constructor(socket, options) { + super(socket instanceof InternalTCPSocket ? options : options || socket); + options = options || socket || {}; + if (typeof options === "object") { + const { ALPNProtocols } = options; + if (ALPNProtocols) { + convertALPNProtocols(ALPNProtocols, this); + } + if (socket instanceof InternalTCPSocket) { + this.#socket = socket; + } + } + this.#secureContext = options.secureContext || createSecureContext(options); this.authorized = false; this.secureConnecting = true; @@ -123,28 +346,52 @@ const TLSSocket = (function (InternalTLSSocket) { secureConnecting = false; _SNICallback; servername; - alpnProtocol; authorized = false; authorizationError; encrypted = true; - exportKeyingMaterial() { - throw Error("Not implented in Bun yet"); + _start() { + // some frameworks uses this _start internal implementation is suposed to start TLS handshake + // on Bun we auto start this after on_open callback and when wrapping we start it after the socket is attached to the net.Socket/tls.Socket } - setMaxSendFragment() { + + exportKeyingMaterial(length, label, context) { + //SSL_export_keying_material throw Error("Not implented in Bun yet"); } - setServername() { + setMaxSendFragment(size) { + // SSL_set_max_send_fragment throw Error("Not implented in Bun yet"); } + setServername(name) { + if (this.isServer) { + let error = new Error("ERR_TLS_SNI_FROM_SERVER: Cannot issue SNI from a TLS server-side socket"); + error.name = "ERR_TLS_SNI_FROM_SERVER"; + throw error; + } + // if the socket is detached we can't set the servername but we set this property so when open will auto set to it + this.servername = name; + this[bunSocketInternal]?.setServername(name); + } setSession() { throw Error("Not implented in Bun yet"); } getPeerCertificate() { + // need to implement peerCertificate on socket.zig + // const cert = this[bunSocketInternal]?.peerCertificate; + // if(cert) { + // return translatePeerCertificate(cert); + // } throw Error("Not implented in Bun yet"); } getCertificate() { + // need to implement certificate on socket.zig + // const cert = this[bunSocketInternal]?.certificate; + // if(cert) { + // It's not a peer cert, but the formatting is identical. + // return translatePeerCertificate(cert); + // } throw Error("Not implented in Bun yet"); } getPeerX509Certificate() { @@ -154,16 +401,17 @@ const TLSSocket = (function (InternalTLSSocket) { throw Error("Not implented in Bun yet"); } - [buntls](port, host) { - var { servername } = this; - if (servername) { - return { - serverName: typeof servername === "string" ? servername : host, - ...this.#secureContext, - }; - } + get alpnProtocol() { + return this[bunSocketInternal]?.alpnProtocol; + } - return true; + [buntls](port, host) { + return { + socket: this.#socket, + ALPNProtocols: this.ALPNProtocols, + serverName: this.servername || host || "localhost", + ...this.#secureContext, + }; } }, ); @@ -177,9 +425,12 @@ class Server extends NetServer { _rejectUnauthorized; _requestCert; servername; + ALPNProtocols; + #checkServerIdentity; constructor(options, secureConnectionListener) { super(options, secureConnectionListener); + this.#checkServerIdentity = options?.checkServerIdentity || checkServerIdentity; this.setSecureContext(options); } emit(event, args) { @@ -197,6 +448,12 @@ class Server extends NetServer { options = options.context; } if (options) { + const { ALPNProtocols } = options; + + if (ALPNProtocols) { + convertALPNProtocols(ALPNProtocols, this); + } + let key = options.key; if (key) { if (!isValidTLSArray(key)) { @@ -277,6 +534,8 @@ class Server extends NetServer { // Client always is NONE on set_verify rejectUnauthorized: isClient ? false : this._rejectUnauthorized, requestCert: isClient ? false : this._requestCert, + ALPNProtocols: this.ALPNProtocols, + checkServerIdentity: this.#checkServerIdentity, }, SocketClass, ]; @@ -296,6 +555,11 @@ const CLIENT_RENEG_LIMIT = 3, DEFAULT_MAX_VERSION = "TLSv1.3", createConnection = (port, host, connectListener) => { if (typeof port === "object") { + port.checkServerIdentity || checkServerIdentity; + const { ALPNProtocols } = port; + if (ALPNProtocols) { + convertALPNProtocols(ALPNProtocols, port); + } // port is option pass Socket options and let connect handle connection options return new TLSSocket(port).connect(port, host, connectListener); } @@ -312,7 +576,55 @@ function getCurves() { return; } -function convertALPNProtocols(protocols, out) {} +// Convert protocols array into valid OpenSSL protocols list +// ("\x06spdy/2\x08http/1.1\x08http/1.0") +function convertProtocols(protocols) { + const lens = new Array(protocols.length); + const buff = Buffer.allocUnsafe( + ArrayPrototypeReduce.call( + protocols, + (p, c, i) => { + const len = Buffer.byteLength(c); + if (len > 255) { + throw new RangeError( + "The byte length of the protocol at index " + `${i} exceeds the maximum length.`, + "<= 255", + len, + true, + ); + } + lens[i] = len; + return p + 1 + len; + }, + 0, + ), + ); + + let offset = 0; + for (let i = 0, c = protocols.length; i < c; i++) { + buff[offset++] = lens[i]; + buff.write(protocols[i], offset); + offset += lens[i]; + } + + return buff; +} + +function convertALPNProtocols(protocols, out) { + // If protocols is Array - translate it into buffer + if (Array.isArray(protocols)) { + out.ALPNProtocols = convertProtocols(protocols); + } else if (isTypedArray(protocols)) { + // Copy new buffer not to be modified by user. + out.ALPNProtocols = Buffer.from(protocols); + } else if (isArrayBufferView(protocols)) { + out.ALPNProtocols = Buffer.from( + protocols.buffer.slice(protocols.byteOffset, protocols.byteOffset + protocols.byteLength), + ); + } else if (Buffer.isBuffer(protocols)) { + out.ALPNProtocols = protocols; + } +} var exports = { [Symbol.for("CommonJS")]: 0, @@ -351,6 +663,7 @@ export { getCurves, parseCertString, SecureContext, + checkServerIdentity, Server, TLSSocket, exports as default, diff --git a/src/js/out/modules/node/net.js b/src/js/out/modules/node/net.js index 164ec6677..c34f86b04 100644 --- a/src/js/out/modules/node/net.js +++ b/src/js/out/modules/node/net.js @@ -26,7 +26,7 @@ var isIPv4 = function(s) { self.emit("listening"); }, createServer = function(options, connectionListener) { return new Server(options, connectionListener); -}, v4Seg = "(?:[0-9]|[1-9][0-9]|1[0-9][0-9]|2[0-4][0-9]|25[0-5])", v4Str = `(${v4Seg}[.]){3}${v4Seg}`, IPv4Reg = new RegExp(`^${v4Str}$`), v6Seg = "(?:[0-9a-fA-F]{1,4})", IPv6Reg = new RegExp("^(" + `(?:${v6Seg}:){7}(?:${v6Seg}|:)|` + `(?:${v6Seg}:){6}(?:${v4Str}|:${v6Seg}|:)|` + `(?:${v6Seg}:){5}(?::${v4Str}|(:${v6Seg}){1,2}|:)|` + `(?:${v6Seg}:){4}(?:(:${v6Seg}){0,1}:${v4Str}|(:${v6Seg}){1,3}|:)|` + `(?:${v6Seg}:){3}(?:(:${v6Seg}){0,2}:${v4Str}|(:${v6Seg}){1,4}|:)|` + `(?:${v6Seg}:){2}(?:(:${v6Seg}){0,3}:${v4Str}|(:${v6Seg}){1,5}|:)|` + `(?:${v6Seg}:){1}(?:(:${v6Seg}){0,4}:${v4Str}|(:${v6Seg}){1,6}|:)|` + `(?::((?::${v6Seg}){0,5}:${v4Str}|(?::${v6Seg}){1,7}|:))` + ")(%[0-9a-zA-Z-.:]{1,})?$"), { Bun, createFIFO, Object } = globalThis[Symbol.for("Bun.lazy")]("primordials"), { connect: bunConnect } = Bun, { setTimeout } = globalThis, bunTlsSymbol = Symbol.for("::buntls::"), bunSocketServerHandlers = Symbol.for("::bunsocket_serverhandlers::"), bunSocketServerConnections = Symbol.for("::bunnetserverconnections::"), bunSocketServerOptions = Symbol.for("::bunnetserveroptions::"), SocketClass, Socket = function(InternalSocket) { +}, v4Seg = "(?:[0-9]|[1-9][0-9]|1[0-9][0-9]|2[0-4][0-9]|25[0-5])", v4Str = `(${v4Seg}[.]){3}${v4Seg}`, IPv4Reg = new RegExp(`^${v4Str}$`), v6Seg = "(?:[0-9a-fA-F]{1,4})", IPv6Reg = new RegExp("^(" + `(?:${v6Seg}:){7}(?:${v6Seg}|:)|` + `(?:${v6Seg}:){6}(?:${v4Str}|:${v6Seg}|:)|` + `(?:${v6Seg}:){5}(?::${v4Str}|(:${v6Seg}){1,2}|:)|` + `(?:${v6Seg}:){4}(?:(:${v6Seg}){0,1}:${v4Str}|(:${v6Seg}){1,3}|:)|` + `(?:${v6Seg}:){3}(?:(:${v6Seg}){0,2}:${v4Str}|(:${v6Seg}){1,4}|:)|` + `(?:${v6Seg}:){2}(?:(:${v6Seg}){0,3}:${v4Str}|(:${v6Seg}){1,5}|:)|` + `(?:${v6Seg}:){1}(?:(:${v6Seg}){0,4}:${v4Str}|(:${v6Seg}){1,6}|:)|` + `(?::((?::${v6Seg}){0,5}:${v4Str}|(?::${v6Seg}){1,7}|:))` + ")(%[0-9a-zA-Z-.:]{1,})?$"), { Bun, createFIFO, Object } = globalThis[Symbol.for("Bun.lazy")]("primordials"), { connect: bunConnect } = Bun, { setTimeout } = globalThis, bunTlsSymbol = Symbol.for("::buntls::"), bunSocketServerHandlers = Symbol.for("::bunsocket_serverhandlers::"), bunSocketServerConnections = Symbol.for("::bunnetserverconnections::"), bunSocketServerOptions = Symbol.for("::bunnetserveroptions::"), bunSocketInternal = Symbol.for("::bunnetsocketinternal::"), SocketClass, Socket = function(InternalSocket) { return SocketClass = InternalSocket, Object.defineProperty(SocketClass.prototype, Symbol.toStringTag, { value: "Socket", enumerable: !1 @@ -62,7 +62,7 @@ var isIPv4 = function(s) { }, open(socket) { const self = socket.data; - socket.timeout(self.timeout), socket.ref(), self.#socket = socket, self.connecting = !1, self.emit("connect", self), Socket2.#Drain(socket); + socket.timeout(self.timeout), socket.ref(), self[bunSocketInternal] = socket, self.connecting = !1, self.emit("connect", self), Socket2.#Drain(socket); }, handshake(socket, success, verifyError) { const { data: self } = socket; @@ -87,7 +87,7 @@ var isIPv4 = function(s) { const self = socket.data; if (self.#closed) return; - self.#closed = !0, self.#socket = null; + self.#closed = !0, self[bunSocketInternal] = null; const queue = self.#readQueue; if (queue.isEmpty()) { if (self.push(null)) @@ -163,21 +163,27 @@ var isIPv4 = function(s) { localAddress = "127.0.0.1"; #readQueue = createFIFO(); remotePort; - #socket; + [bunSocketInternal] = null; timeout = 0; #writeCallback; #writeChunk; #pendingRead; isServer = !1; + _handle; + _parent; + _parentWrap; + #socket; constructor(options) { - const { signal, write, read, allowHalfOpen = !1, ...opts } = options || {}; + const { socket, signal, write, read, allowHalfOpen = !1, ...opts } = options || {}; super({ ...opts, allowHalfOpen, readable: !0, writable: !0 }); - this.#pendingRead = void 0, signal?.once("abort", () => this.destroy()), this.once("connect", () => this.emit("ready")); + if (this._handle = this, this._parent = this, this._parentWrap = this, this.#pendingRead = void 0, socket instanceof Socket2) + this.#socket = socket; + signal?.once("abort", () => this.destroy()), this.once("connect", () => this.emit("ready")); } address() { return { @@ -190,10 +196,10 @@ var isIPv4 = function(s) { return this.writableLength; } #attach(port, socket) { - this.remotePort = port, socket.data = this, socket.timeout(this.timeout), socket.ref(), this.#socket = socket, this.connecting = !1, this.emit("connect", this), Socket2.#Drain(socket); + this.remotePort = port, socket.data = this, socket.timeout(this.timeout), socket.ref(), this[bunSocketInternal] = socket, this.connecting = !1, this.emit("connect", this), Socket2.#Drain(socket); } connect(port, host, connectListener) { - var path; + var path, connection = this.#socket; if (typeof port === "string") { if (path = port, port = void 0, typeof host === "function") connectListener = host, host = void 0; @@ -207,6 +213,7 @@ var isIPv4 = function(s) { port, host, path, + socket, localAddress, localPort, family, @@ -220,7 +227,8 @@ var isIPv4 = function(s) { pauseOnConnect, servername } = port; - this.servername = servername; + if (this.servername = servername, socket) + connection = socket; } if (!pauseOnConnect) this.resume(); @@ -228,36 +236,78 @@ var isIPv4 = function(s) { const bunTLS = this[bunTlsSymbol]; var tls = void 0; if (typeof bunTLS === "function") { - if (tls = bunTLS.call(this, port, host, !0), this._requestCert = !0, this._rejectUnauthorized = rejectUnauthorized, tls) + if (tls = bunTLS.call(this, port, host, !0), this._requestCert = !0, this._rejectUnauthorized = rejectUnauthorized, tls) { if (typeof tls !== "object") tls = { rejectUnauthorized, requestCert: !0 }; - else - tls.rejectUnauthorized = rejectUnauthorized, tls.requestCert = !0; + else if (tls.rejectUnauthorized = rejectUnauthorized, tls.requestCert = !0, !connection && tls.socket) + connection = tls.socket; + } + if (connection) { + if (typeof connection !== "object" || !(connection instanceof Socket2) || typeof connection[bunTlsSymbol] === "function") + throw new TypeError("socket must be an instance of net.Socket"); + } if (this.authorized = !1, this.secureConnecting = !0, this._secureEstablished = !1, this._securePending = !0, connectListener) this.on("secureConnect", connectListener); } else if (connectListener) this.on("connect", connectListener); - return bunConnect(path ? { - data: this, - unix: path, - socket: Socket2.#Handlers, - tls - } : { - data: this, - hostname: host || "localhost", - port, - socket: Socket2.#Handlers, - tls - }), this; + if (connection) { + const socket2 = connection[bunSocketInternal]; + if (socket2) { + const result = socket2.wrapTLS({ + data: this, + tls, + socket: Socket2.#Handlers + }); + if (result) { + const [raw, tls2] = result; + connection[bunSocketInternal] = raw, raw.timeout(raw.timeout), raw.connecting = !1, this[bunSocketInternal] = tls2, tls2.timeout(tls2.timeout), tls2.connecting = !0, this[bunSocketInternal] = socket2, tls2.open(); + } else + throw this[bunSocketInternal] = null, new Error("Invalid socket"); + } else + connection.once("connect", () => { + const socket3 = connection[bunSocketInternal]; + if (!socket3) + return; + const result = socket3.wrapTLS({ + data: this, + tls, + socket: Socket2.#Handlers + }); + if (result) { + const [raw, tls2] = result; + connection[bunSocketInternal] = raw, raw.timeout(raw.timeout), raw.connecting = !1, this[bunSocketInternal] = tls2, tls2.timeout(tls2.timeout), tls2.connecting = !0, this[bunSocketInternal] = socket3, tls2.open(); + } else + throw this[bunSocketInternal] = null, new Error("Invalid socket"); + }); + } else if (path) + bunConnect({ + data: this, + unix: path, + socket: Socket2.#Handlers, + tls + }).catch((error) => { + this.emit("error", error); + }); + else + bunConnect({ + data: this, + hostname: host || "localhost", + port, + socket: Socket2.#Handlers, + tls + }).catch((error) => { + this.emit("error", error); + }); + return this; } _destroy(err, callback) { - this.#socket?.end(), callback(err); + this[bunSocketInternal]?.end(), callback(err); } _final(callback) { - this.#socket?.end(), callback(); + this[bunSocketInternal]?.end(), callback(); } get localAddress() { return "127.0.0.1"; @@ -266,7 +316,7 @@ var isIPv4 = function(s) { return "IPv4"; } get localPort() { - return this.#socket?.localPort; + return this[bunSocketInternal]?.localPort; } get pending() { return this.connecting; @@ -289,16 +339,16 @@ var isIPv4 = function(s) { return this.writable ? "writeOnly" : "closed"; } ref() { - this.#socket?.ref(); + this[bunSocketInternal]?.ref(); } get remoteAddress() { - return this.#socket?.remoteAddress; + return this[bunSocketInternal]?.remoteAddress; } get remoteFamily() { return "IPv4"; } resetAndDestroy() { - this.#socket?.end(); + this[bunSocketInternal]?.end(); } setKeepAlive(enable = !1, initialDelay = 0) { return this; @@ -307,17 +357,17 @@ var isIPv4 = function(s) { return this; } setTimeout(timeout, callback) { - if (this.#socket?.timeout(timeout), this.timeout = timeout, callback) + if (this[bunSocketInternal]?.timeout(timeout), this.timeout = timeout, callback) this.once("timeout", callback); return this; } unref() { - this.#socket?.unref(); + this[bunSocketInternal]?.unref(); } _write(chunk, encoding, callback) { - if (typeof chunk == "string" && encoding !== "utf8") + if (typeof chunk == "string" && encoding !== "ascii") chunk = Buffer.from(chunk, encoding); - var written = this.#socket?.write(chunk); + var written = this[bunSocketInternal]?.write(chunk); if (written == chunk.length) callback(); else if (this.#writeCallback) diff --git a/src/js/out/modules/node/tls.js b/src/js/out/modules/node/tls.js index 4cceadc7f..ca8a13270 100644 --- a/src/js/out/modules/node/tls.js +++ b/src/js/out/modules/node/tls.js @@ -1,4 +1,4 @@ -import {isTypedArray} from "node:util/types"; +import {isArrayBufferView, isTypedArray} from "node:util/types"; import net, {Server as NetServer} from "node:net"; var parseCertString = function() { throwNotImplemented("Not implemented"); @@ -11,18 +11,127 @@ var parseCertString = function() { return !1; return !0; } +}, unfqdn = function(host2) { + return RegExpPrototypeSymbolReplace(/[.]$/, host2, ""); +}, splitHost = function(host2) { + return StringPrototypeSplit.call(RegExpPrototypeSymbolReplace(/[A-Z]/g, unfqdn(host2), toLowerCase), "."); +}, check = function(hostParts, pattern, wildcards) { + if (!pattern) + return !1; + const patternParts = splitHost(pattern); + if (hostParts.length !== patternParts.length) + return !1; + if (ArrayPrototypeIncludes.call(patternParts, "")) + return !1; + const isBad = (s) => RegExpPrototypeExec.call(/[^\u0021-\u007F]/u, s) !== null; + if (ArrayPrototypeSome.call(patternParts, isBad)) + return !1; + for (let i = hostParts.length - 1;i > 0; i -= 1) + if (hostParts[i] !== patternParts[i]) + return !1; + const hostSubdomain = hostParts[0], patternSubdomain = patternParts[0], patternSubdomainParts = StringPrototypeSplit.call(patternSubdomain, "*"); + if (patternSubdomainParts.length === 1 || StringPrototypeIncludes.call(patternSubdomain, "xn--")) + return hostSubdomain === patternSubdomain; + if (!wildcards) + return !1; + if (patternSubdomainParts.length > 2) + return !1; + if (patternParts.length <= 2) + return !1; + const { 0: prefix, 1: suffix } = patternSubdomainParts; + if (prefix.length + suffix.length > hostSubdomain.length) + return !1; + if (!StringPrototypeStartsWith.call(hostSubdomain, prefix)) + return !1; + if (!StringPrototypeEndsWith.call(hostSubdomain, suffix)) + return !1; + return !0; +}, splitEscapedAltNames = function(altNames) { + const result = []; + let currentToken = "", offset = 0; + while (offset !== altNames.length) { + const nextSep = StringPrototypeIndexOf.call(altNames, ", ", offset), nextQuote = StringPrototypeIndexOf.call(altNames, '"', offset); + if (nextQuote !== -1 && (nextSep === -1 || nextQuote < nextSep)) { + currentToken += StringPrototypeSubstring.call(altNames, offset, nextQuote); + const match = RegExpPrototypeExec.call(jsonStringPattern, StringPrototypeSubstring.call(altNames, nextQuote)); + if (!match) { + let error = new SyntaxError("ERR_TLS_CERT_ALTNAME_FORMAT: Invalid subject alternative name string"); + throw error.name = ERR_TLS_CERT_ALTNAME_FORMAT, error; + } + currentToken += JSON.parse(match[0]), offset = nextQuote + match[0].length; + } else if (nextSep !== -1) + currentToken += StringPrototypeSubstring.call(altNames, offset, nextSep), ArrayPrototypePush.call(result, currentToken), currentToken = "", offset = nextSep + 2; + else + currentToken += StringPrototypeSubstring.call(altNames, offset), offset = altNames.length; + } + return ArrayPrototypePush.call(result, currentToken), result; +}, checkServerIdentity = function(hostname, cert) { + const { subject, subjectaltname: altNames } = cert, dnsNames = [], ips = []; + if (hostname = "" + hostname, altNames) { + const splitAltNames = StringPrototypeIncludes.call(altNames, '"') ? splitEscapedAltNames(altNames) : StringPrototypeSplit.call(altNames, ", "); + ArrayPrototypeForEach.call(splitAltNames, (name) => { + if (StringPrototypeStartsWith.call(name, "DNS:")) + ArrayPrototypePush.call(dnsNames, StringPrototypeSlice.call(name, 4)); + else if (StringPrototypeStartsWith.call(name, "IP Address:")) + ArrayPrototypePush.call(ips, canonicalizeIP(StringPrototypeSlice.call(name, 11))); + }); + } + let valid = !1, reason = "Unknown reason"; + if (hostname = unfqdn(hostname), net.isIP(hostname)) { + if (valid = ArrayPrototypeIncludes.call(ips, canonicalizeIP(hostname)), !valid) + reason = `IP: ${hostname} is not in the cert's list: ` + ArrayPrototypeJoin.call(ips, ", "); + } else if (dnsNames.length > 0 || subject?.CN) { + const hostParts = splitHost(hostname), wildcard = (pattern) => check(hostParts, pattern, !0); + if (dnsNames.length > 0) { + if (valid = ArrayPrototypeSome.call(dnsNames, wildcard), !valid) + reason = `Host: ${hostname}. is not in the cert's altnames: ${altNames}`; + } else { + const cn = subject.CN; + if (ArrayIsArray(cn)) + valid = ArrayPrototypeSome.call(cn, wildcard); + else if (cn) + valid = wildcard(cn); + if (!valid) + reason = `Host: ${hostname}. is not cert's CN: ${cn}`; + } + } else + reason = "Cert does not contain a DNS name"; + if (!valid) { + let error = new Error(`ERR_TLS_CERT_ALTNAME_INVALID: Hostname/IP does not match certificate's altnames: ${reason}`); + return error.name = "ERR_TLS_CERT_ALTNAME_INVALID", error.reason = reason, error.host = host, error.cert = cert, error; + } }, SecureContext = function(options) { return new InternalSecureContext(options); }, createSecureContext = function(options) { return new SecureContext(options); -}, createServer = function(options, connectionListener) { +}; +var createServer = function(options, connectionListener) { return new Server(options, connectionListener); }, getCiphers = function() { return DEFAULT_CIPHERS.split(":"); }, getCurves = function() { return; +}, convertProtocols = function(protocols) { + const lens = new Array(protocols.length), buff = Buffer.allocUnsafe(ArrayPrototypeReduce.call(protocols, (p, c, i) => { + const len = Buffer.byteLength(c); + if (len > 255) + throw new RangeError("The byte length of the protocol at index " + `${i} exceeds the maximum length.`, "<= 255", len, !0); + return lens[i] = len, p + 1 + len; + }, 0)); + let offset = 0; + for (let i = 0, c = protocols.length;i < c; i++) + buff[offset++] = lens[i], buff.write(protocols[i], offset), offset += lens[i]; + return buff; }, convertALPNProtocols = function(protocols, out) { -}, InternalTCPSocket = net[Symbol.for("::bunternal::")], InternalSecureContext = class SecureContext2 { + if (Array.isArray(protocols)) + out.ALPNProtocols = convertProtocols(protocols); + else if (isTypedArray(protocols)) + out.ALPNProtocols = Buffer.from(protocols); + else if (isArrayBufferView(protocols)) + out.ALPNProtocols = Buffer.from(protocols.buffer.slice(protocols.byteOffset, protocols.byteOffset + protocols.byteLength)); + else if (Buffer.isBuffer(protocols)) + out.ALPNProtocols = protocols; +}, InternalTCPSocket = net[Symbol.for("::bunternal::")], bunSocketInternal = Symbol.for("::bunnetsocketinternal::"), { RegExp, Array, String } = globalThis[Symbol.for("Bun.lazy")]("primordials"), SymbolReplace = Symbol.replace, RegExpPrototypeSymbolReplace = RegExp.prototype[SymbolReplace], RegExpPrototypeExec = RegExp.prototype.exec, StringPrototypeStartsWith = String.prototype.startsWith, StringPrototypeSlice = String.prototype.slice, StringPrototypeIncludes = String.prototype.includes, StringPrototypeSplit = String.prototype.split, StringPrototypeIndexOf = String.prototype.indexOf, StringPrototypeSubstring = String.prototype.substring, StringPrototypeEndsWith = String.prototype.endsWith, ArrayPrototypeIncludes = Array.prototype.includes, ArrayPrototypeJoin = Array.prototype.join, ArrayPrototypeForEach = Array.prototype.forEach, ArrayPrototypePush = Array.prototype.push, ArrayPrototypeSome = Array.prototype.some, ArrayPrototypeReduce = Array.prototype.reduce, jsonStringPattern = /^"(?:[^"\\\u0000-\u001f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*"/, InternalSecureContext = class SecureContext2 { context; constructor(options) { const context = {}; @@ -73,8 +182,17 @@ var parseCertString = function() { }); }(class TLSSocket2 extends InternalTCPSocket { #secureContext; - constructor(options) { - super(options); + ALPNProtocols; + #socket; + constructor(socket, options) { + super(socket instanceof InternalTCPSocket ? options : options || socket); + if (options = options || socket || {}, typeof options === "object") { + const { ALPNProtocols } = options; + if (ALPNProtocols) + convertALPNProtocols(ALPNProtocols, this); + if (socket instanceof InternalTCPSocket) + this.#socket = socket; + } this.#secureContext = options.secureContext || createSecureContext(options), this.authorized = !1, this.secureConnecting = !0, this._secureEstablished = !1, this._securePending = !0; } _secureEstablished = !1; @@ -84,19 +202,24 @@ var parseCertString = function() { secureConnecting = !1; _SNICallback; servername; - alpnProtocol; authorized = !1; authorizationError; encrypted = !0; - exportKeyingMaterial() { - throw Error("Not implented in Bun yet"); + _start() { } - setMaxSendFragment() { + exportKeyingMaterial(length, label, context) { throw Error("Not implented in Bun yet"); } - setServername() { + setMaxSendFragment(size) { throw Error("Not implented in Bun yet"); } + setServername(name) { + if (this.isServer) { + let error = new Error("ERR_TLS_SNI_FROM_SERVER: Cannot issue SNI from a TLS server-side socket"); + throw error.name = "ERR_TLS_SNI_FROM_SERVER", error; + } + this.servername = name, this[bunSocketInternal]?.setServername(name); + } setSession() { throw Error("Not implented in Bun yet"); } @@ -112,14 +235,16 @@ var parseCertString = function() { getX509Certificate() { throw Error("Not implented in Bun yet"); } - [buntls](port, host) { - var { servername } = this; - if (servername) - return { - serverName: typeof servername === "string" ? servername : host, - ...this.#secureContext - }; - return !0; + get alpnProtocol() { + return this[bunSocketInternal]?.alpnProtocol; + } + [buntls](port, host2) { + return { + socket: this.#socket, + ALPNProtocols: this.ALPNProtocols, + serverName: this.servername || host2 || "localhost", + ...this.#secureContext + }; } }); @@ -132,9 +257,11 @@ class Server extends NetServer { _rejectUnauthorized; _requestCert; servername; + ALPNProtocols; + #checkServerIdentity; constructor(options, secureConnectionListener) { super(options, secureConnectionListener); - this.setSecureContext(options); + this.#checkServerIdentity = options?.checkServerIdentity || checkServerIdentity, this.setSecureContext(options); } emit(event, args) { if (super.emit(event, args), event === "connection") @@ -146,6 +273,9 @@ class Server extends NetServer { if (options instanceof InternalSecureContext) options = options.context; if (options) { + const { ALPNProtocols } = options; + if (ALPNProtocols) + convertALPNProtocols(ALPNProtocols, this); let key = options.key; if (key) { if (!isValidTLSArray(key)) @@ -194,26 +324,33 @@ class Server extends NetServer { setTicketKeys() { throw Error("Not implented in Bun yet"); } - [buntls](port, host, isClient) { + [buntls](port, host2, isClient) { return [ { - serverName: this.servername || host || "localhost", + serverName: this.servername || host2 || "localhost", key: this.key, cert: this.cert, ca: this.ca, passphrase: this.passphrase, secureOptions: this.secureOptions, rejectUnauthorized: isClient ? !1 : this._rejectUnauthorized, - requestCert: isClient ? !1 : this._requestCert + requestCert: isClient ? !1 : this._requestCert, + ALPNProtocols: this.ALPNProtocols, + checkServerIdentity: this.#checkServerIdentity }, SocketClass ]; } } -var CLIENT_RENEG_LIMIT = 3, CLIENT_RENEG_WINDOW = 600, DEFAULT_ECDH_CURVE = "auto", DEFAULT_CIPHERS = "DHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES128-GCM-SHA256", DEFAULT_MIN_VERSION = "TLSv1.2", DEFAULT_MAX_VERSION = "TLSv1.3", createConnection = (port, host, connectListener) => { - if (typeof port === "object") - return new TLSSocket(port).connect(port, host, connectListener); - return new TLSSocket().connect(port, host, connectListener); +var CLIENT_RENEG_LIMIT = 3, CLIENT_RENEG_WINDOW = 600, DEFAULT_ECDH_CURVE = "auto", DEFAULT_CIPHERS = "DHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES128-GCM-SHA256", DEFAULT_MIN_VERSION = "TLSv1.2", DEFAULT_MAX_VERSION = "TLSv1.3", createConnection = (port, host2, connectListener) => { + if (typeof port === "object") { + port.checkServerIdentity; + const { ALPNProtocols } = port; + if (ALPNProtocols) + convertALPNProtocols(ALPNProtocols, port); + return new TLSSocket(port).connect(port, host2, connectListener); + } + return new TLSSocket().connect(port, host2, connectListener); }, connect = createConnection, exports = { [Symbol.for("CommonJS")]: 0, CLIENT_RENEG_LIMIT, @@ -244,6 +381,7 @@ export { createConnection, convertALPNProtocols, connect, + checkServerIdentity, TLSSocket, Server, SecureContext, diff --git a/test/bun.lockb b/test/bun.lockb index f30ca197a..e3a2abdfa 100755 Binary files a/test/bun.lockb and b/test/bun.lockb differ diff --git a/test/js/node/net/node-net-server.test.ts b/test/js/node/net/node-net-server.test.ts index 398959bd6..3cdaa17e1 100644 --- a/test/js/node/net/node-net-server.test.ts +++ b/test/js/node/net/node-net-server.test.ts @@ -181,61 +181,6 @@ describe("net.createServer listen", () => { ); }); - it("should listen on the correct port", done => { - const { mustCall, mustNotCall } = createCallCheckCtx(done); - - const server: Server = createServer(); - - let timeout: Timer; - const closeAndFail = () => { - clearTimeout(timeout); - server.close(); - mustNotCall()(); - }; - server.on("error", closeAndFail); - timeout = setTimeout(closeAndFail, 100); - - server.listen( - 49027, - mustCall(() => { - const address = server.address() as AddressInfo; - expect(address.address).toStrictEqual("::"); - expect(address.port).toStrictEqual(49027); - expect(address.family).toStrictEqual("IPv6"); - server.close(); - done(); - }), - ); - }); - - it("should listen on the correct port with IPV4", done => { - const { mustCall, mustNotCall } = createCallCheckCtx(done); - - const server: Server = createServer(); - - let timeout: Timer; - const closeAndFail = () => { - clearTimeout(timeout); - server.close(); - mustNotCall()(); - }; - server.on("error", closeAndFail); - timeout = setTimeout(closeAndFail, 100); - - server.listen( - 49026, - "0.0.0.0", - mustCall(() => { - const address = server.address() as AddressInfo; - expect(address.address).toStrictEqual("0.0.0.0"); - expect(address.port).toStrictEqual(49026); - expect(address.family).toStrictEqual("IPv4"); - server.close(); - done(); - }), - ); - }); - it("should listen on unix domain socket", done => { const { mustCall, mustNotCall } = createCallCheckCtx(done); diff --git a/test/js/node/tls/node-tls-connect.test.ts b/test/js/node/tls/node-tls-connect.test.ts new file mode 100644 index 000000000..791dba88a --- /dev/null +++ b/test/js/node/tls/node-tls-connect.test.ts @@ -0,0 +1,32 @@ +import { TLSSocket, connect } from "tls"; + +it("should work with alpnProtocols", done => { + try { + let socket: TLSSocket | null = connect({ + ALPNProtocols: ["http/1.1"], + host: "bun.sh", + servername: "bun.sh", + port: 443, + rejectUnauthorized: false, + }); + + const timeout = setTimeout(() => { + socket?.end(); + done("timeout"); + }, 3000); + + socket.on("error", err => { + clearTimeout(timeout); + done(err); + }); + + socket.on("secureConnect", () => { + clearTimeout(timeout); + done(socket?.alpnProtocol === "http/1.1" ? undefined : "alpnProtocol is not http/1.1"); + socket?.end(); + socket = null; + }); + } catch (err) { + done(err); + } +}); diff --git a/test/js/node/tls/node-tls-server.test.ts b/test/js/node/tls/node-tls-server.test.ts index 6879d0927..2a6101b9f 100644 --- a/test/js/node/tls/node-tls-server.test.ts +++ b/test/js/node/tls/node-tls-server.test.ts @@ -195,61 +195,6 @@ describe("tls.createServer listen", () => { ); }); - it("should listen on the correct port", done => { - const { mustCall, mustNotCall } = createCallCheckCtx(done); - - const server: Server = createServer(COMMON_CERT); - - let timeout: Timer; - const closeAndFail = () => { - clearTimeout(timeout); - server.close(); - mustNotCall()(); - }; - server.on("error", closeAndFail); - timeout = setTimeout(closeAndFail, 100); - - server.listen( - 49027, - mustCall(() => { - const address = server.address() as AddressInfo; - expect(address.address).toStrictEqual("::"); - expect(address.port).toStrictEqual(49027); - expect(address.family).toStrictEqual("IPv6"); - server.close(); - done(); - }), - ); - }); - - it("should listen on the correct port with IPV4", done => { - const { mustCall, mustNotCall } = createCallCheckCtx(done); - - const server: Server = createServer(COMMON_CERT); - - let timeout: Timer; - const closeAndFail = () => { - clearTimeout(timeout); - server.close(); - mustNotCall()(); - }; - server.on("error", closeAndFail); - timeout = setTimeout(closeAndFail, 100); - - server.listen( - 49026, - "0.0.0.0", - mustCall(() => { - const address = server.address() as AddressInfo; - expect(address.address).toStrictEqual("0.0.0.0"); - expect(address.port).toStrictEqual(49026); - expect(address.family).toStrictEqual("IPv4"); - server.close(); - done(); - }), - ); - }); - it("should listen on unix domain socket", done => { const { mustCall, mustNotCall } = createCallCheckCtx(done); diff --git a/test/js/third_party/nodemailer/nodemailer.test.ts b/test/js/third_party/nodemailer/nodemailer.test.ts new file mode 100644 index 000000000..265112608 --- /dev/null +++ b/test/js/third_party/nodemailer/nodemailer.test.ts @@ -0,0 +1,15 @@ +import { test, expect, describe } from "bun:test"; +import { bunRun } from "harness"; +import path from "path"; + +describe("nodemailer", () => { + test("basic smtp", async () => { + try { + const info = bunRun(path.join(import.meta.dir, "process-nodemailer-fixture.js")); + expect(info.stdout).toBe("true"); + expect(info.stderr || "").toBe(""); + } catch (err: any) { + expect(err?.message || err).toBe(""); + } + }, 10000); +}); diff --git a/test/js/third_party/nodemailer/package.json b/test/js/third_party/nodemailer/package.json new file mode 100644 index 000000000..08e98074f --- /dev/null +++ b/test/js/third_party/nodemailer/package.json @@ -0,0 +1,6 @@ +{ + "name": "nodemailer", + "dependencies": { + "nodemailer": "6.9.3" + } +} diff --git a/test/js/third_party/nodemailer/process-nodemailer-fixture.js b/test/js/third_party/nodemailer/process-nodemailer-fixture.js new file mode 100644 index 000000000..a54735f26 --- /dev/null +++ b/test/js/third_party/nodemailer/process-nodemailer-fixture.js @@ -0,0 +1,23 @@ +import nodemailer from "nodemailer"; +const account = await nodemailer.createTestAccount(); +const transporter = nodemailer.createTransport({ + host: account.smtp.host, + port: account.smtp.port, + secure: account.smtp.secure, + auth: { + user: account.user, // generated ethereal user + pass: account.pass, // generated ethereal password + }, +}); + +// send mail with defined transport object +let info = await transporter.sendMail({ + from: '"Fred Foo 👻" ', // sender address + to: "example@gmail.com", // list of receivers + subject: "Hello ✔", // Subject line + text: "Hello world?", // plain text body + html: "Hello world?", // html body +}); +const url = nodemailer.getTestMessageUrl(info); +console.log(typeof url === "string" && url.length > 0); +transporter.close(); diff --git a/test/js/web/timers/process-setImmediate-fixture.js b/test/js/web/timers/process-setImmediate-fixture.js new file mode 100644 index 000000000..6ffd91c8d --- /dev/null +++ b/test/js/web/timers/process-setImmediate-fixture.js @@ -0,0 +1,9 @@ +setImmediate(() => { + console.log("setImmediate"); + return { + a: 1, + b: 2, + c: 3, + d: 4, + }; +}); diff --git a/test/js/web/timers/setImmediate.test.js b/test/js/web/timers/setImmediate.test.js index 9cd6fa1c9..d00224e0f 100644 --- a/test/js/web/timers/setImmediate.test.js +++ b/test/js/web/timers/setImmediate.test.js @@ -1,4 +1,6 @@ import { it, expect } from "bun:test"; +import { bunExe, bunEnv } from "harness"; +import path from "path"; it("setImmediate", async () => { var lastID = -1; @@ -45,3 +47,28 @@ it("clearImmediate", async () => { }); expect(called).toBe(false); }); + +it("setImmediate should not keep the process alive forever", async () => { + let process = null; + const success = async () => { + process = Bun.spawn({ + cmd: [bunExe(), "run", path.join(import.meta.dir, "process-setImmediate-fixture.js")], + stdout: "ignore", + env: { + ...bunEnv, + NODE_ENV: undefined, + }, + }); + await process.exited; + process = null; + return true; + }; + + const fail = async () => { + await Bun.sleep(500); + process?.kill(); + return false; + }; + + expect(await Promise.race([success(), fail()])).toBe(true); +}); diff --git a/test/package.json b/test/package.json index 116571879..db0053874 100644 --- a/test/package.json +++ b/test/package.json @@ -16,9 +16,10 @@ "iconv-lite": "0.6.3", "jest-extended": "4.0.0", "lodash": "4.17.21", + "nodemailer": "6.9.3", "prisma": "4.15.0", - "socket.io": "4.6.1", - "socket.io-client": "4.6.1", + "socket.io": "4.7.1", + "socket.io-client": "4.7.1", "supertest": "6.1.6", "svelte": "3.55.1", "typescript": "5.0.2", -- cgit v1.2.3 From 3aaec120e7ac26b3904895d2783a08352b63201a Mon Sep 17 00:00:00 2001 From: Jarred Sumner Date: Wed, 5 Jul 2023 03:46:10 -0700 Subject: Fixes #3512 (#3526) * Fixes #3512 * Fix `clearTimeout` and `clearInterval` not cancelling jobs same-tick --------- Co-authored-by: Jarred Sumner <709451+Jarred-Sumner@users.noreply.github.com> --- packages/bun-types/timers.d.ts | 4 +- src/bun.js/api/bun.zig | 100 ++++++++++++++++------- test/js/node/timers/node-timers.test.ts | 13 +-- test/js/web/timers/setTimeout-unref-fixture-2.js | 9 ++ test/js/web/timers/setTimeout-unref-fixture-3.js | 7 ++ test/js/web/timers/setTimeout-unref-fixture-4.js | 5 ++ test/js/web/timers/setTimeout-unref-fixture-5.js | 5 ++ test/js/web/timers/setTimeout-unref-fixture.js | 12 +++ test/js/web/timers/setTimeout.test.js | 51 +++++++++++- 9 files changed, 168 insertions(+), 38 deletions(-) create mode 100644 test/js/web/timers/setTimeout-unref-fixture-2.js create mode 100644 test/js/web/timers/setTimeout-unref-fixture-3.js create mode 100644 test/js/web/timers/setTimeout-unref-fixture-4.js create mode 100644 test/js/web/timers/setTimeout-unref-fixture-5.js create mode 100644 test/js/web/timers/setTimeout-unref-fixture.js (limited to 'src/bun.js/api/bun.zig') diff --git a/packages/bun-types/timers.d.ts b/packages/bun-types/timers.d.ts index ab1e29953..0d2f3e745 100644 --- a/packages/bun-types/timers.d.ts +++ b/packages/bun-types/timers.d.ts @@ -11,8 +11,8 @@ declare module "timers" { class Timer { - ref(): void; - unref(): void; + ref(): Timer; + unref(): Timer; hasRef(): boolean; } diff --git a/src/bun.js/api/bun.zig b/src/bun.js/api/bun.zig index 1e5a5e004..fbf567446 100644 --- a/src/bun.js/api/bun.zig +++ b/src/bun.js/api/bun.zig @@ -3715,21 +3715,32 @@ pub const Timer = struct { const kind = this.kind; var map: *TimeoutMap = vm.timer.maps.get(kind); - // This doesn't deinit the timer - // Timers are deinit'd separately - // We do need to handle when the timer is cancelled after the job has been enqueued - if (kind != .setInterval) { - if (map.fetchSwapRemove(this.id) == null) { - // if the timeout was cancelled, don't run the callback - this.deinit(); - return; - } - } else { - if (!map.contains(this.id)) { - // if the interval was cancelled, don't run the callback - this.deinit(); - return; + const should_cancel_job = brk: { + // This doesn't deinit the timer + // Timers are deinit'd separately + // We do need to handle when the timer is cancelled after the job has been enqueued + if (kind != .setInterval) { + if (map.get(this.id)) |tombstone_or_timer| { + break :brk tombstone_or_timer != null; + } else { + // clearTimeout has been called + break :brk true; + } + } else { + if (map.get(this.id)) |tombstone_or_timer| { + // .refresh() was called after CallbackJob enqueued + break :brk tombstone_or_timer == null; + } } + + break :brk false; + }; + + if (should_cancel_job) { + this.deinit(); + return; + } else if (kind != .setInterval) { + _ = map.swapRemove(this.id); } var args_buf: [8]JSC.JSValue = undefined; @@ -3825,10 +3836,29 @@ pub const Timer = struct { return timer_js; } - pub fn doRef(this: *TimerObject, _: *JSC.JSGlobalObject, _: *JSC.CallFrame) callconv(.C) JSValue { + pub fn doRef(this: *TimerObject, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) callconv(.C) JSValue { + const this_value = callframe.this(); + this_value.ensureStillAlive(); if (this.ref_count > 0) this.ref_count +|= 1; - return JSValue.jsUndefined(); + + var vm = globalObject.bunVM(); + switch (this.kind) { + .setTimeout, .setImmediate, .setInterval => { + if (vm.timer.maps.get(this.kind).getPtr(this.id)) |val_| { + if (val_.*) |*val| { + val.poll_ref.ref(vm); + + if (val.did_unref_timer) { + val.did_unref_timer = false; + vm.uws_event_loop.?.num_polls += 1; + } + } + } + }, + } + + return this_value; } pub fn doRefresh(this: *TimerObject, globalThis: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) callconv(.C) JSValue { @@ -3924,20 +3954,27 @@ pub const Timer = struct { return JSValue.jsUndefined(); } - pub fn doUnref(this: *TimerObject, globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) callconv(.C) JSValue { + pub fn doUnref(this: *TimerObject, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) callconv(.C) JSValue { + const this_value = callframe.this(); + this_value.ensureStillAlive(); this.ref_count -|= 1; - if (this.ref_count == 0) { - switch (this.kind) { - .setTimeout, .setImmediate => { - _ = clearTimeout(globalObject, JSValue.jsNumber(this.id)); - }, - .setInterval => { - _ = clearInterval(globalObject, JSValue.jsNumber(this.id)); - }, - } + var vm = globalObject.bunVM(); + switch (this.kind) { + .setTimeout, .setImmediate, .setInterval => { + if (vm.timer.maps.get(this.kind).getPtr(this.id)) |val_| { + if (val_.*) |*val| { + val.poll_ref.unref(vm); + + if (!val.did_unref_timer) { + val.did_unref_timer = true; + vm.uws_event_loop.?.num_polls -= 1; + } + } + } + }, } - return JSValue.jsUndefined(); + return this_value; } pub fn hasRef(this: *TimerObject, globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) callconv(.C) JSValue { return JSValue.jsBoolean(this.ref_count > 0 and globalObject.bunVM().timer.maps.get(this.kind).contains(this.id)); @@ -3959,6 +3996,7 @@ pub const Timer = struct { callback: JSC.Strong = .{}, globalThis: *JSC.JSGlobalObject, timer: *uws.Timer, + did_unref_timer: bool = false, poll_ref: JSC.PollRef = JSC.PollRef.init(), arguments: JSC.Strong = .{}, @@ -4060,8 +4098,14 @@ pub const Timer = struct { var vm = this.globalThis.bunVM(); - this.poll_ref.unrefOnNextTick(vm); + this.poll_ref.unref(vm); + this.timer.deinit(); + if (this.did_unref_timer) { + // balance double-unrefing + vm.uws_event_loop.?.num_polls += 1; + } + this.callback.deinit(); this.arguments.deinit(); } diff --git a/test/js/node/timers/node-timers.test.ts b/test/js/node/timers/node-timers.test.ts index e6fa48010..412eabc22 100644 --- a/test/js/node/timers/node-timers.test.ts +++ b/test/js/node/timers/node-timers.test.ts @@ -1,17 +1,18 @@ import { describe, test } from "bun:test"; -import { setTimeout, clearTimeout, setInterval, setImmediate } from "node:timers"; +import { setTimeout, clearTimeout, setInterval, clearInterval, setImmediate } from "node:timers"; -for (const fn of [setTimeout, setInterval, setImmediate]) { +for (const fn of [setTimeout, setInterval]) { describe(fn.name, () => { test("unref is possible", done => { const timer = fn(() => { done(new Error("should not be called")); - }, 1); - fn(() => { + }, 1).unref(); + const other = fn(() => { + clearInterval(other); done(); }, 2); - timer.unref(); - if (fn !== setImmediate) clearTimeout(timer); + if (fn === setTimeout) clearTimeout(timer); + if (fn === setInterval) clearInterval(timer); }); }); } diff --git a/test/js/web/timers/setTimeout-unref-fixture-2.js b/test/js/web/timers/setTimeout-unref-fixture-2.js new file mode 100644 index 000000000..6a78f13cd --- /dev/null +++ b/test/js/web/timers/setTimeout-unref-fixture-2.js @@ -0,0 +1,9 @@ +setTimeout(() => { + console.log("TEST FAILED!"); +}, 100) + .ref() + .unref(); + +setTimeout(() => { + // this one should always run +}, 1); diff --git a/test/js/web/timers/setTimeout-unref-fixture-3.js b/test/js/web/timers/setTimeout-unref-fixture-3.js new file mode 100644 index 000000000..41808f5fc --- /dev/null +++ b/test/js/web/timers/setTimeout-unref-fixture-3.js @@ -0,0 +1,7 @@ +setTimeout(() => { + setTimeout(() => {}, 999_999); +}, 100).unref(); + +setTimeout(() => { + // this one should always run +}, 1); diff --git a/test/js/web/timers/setTimeout-unref-fixture-4.js b/test/js/web/timers/setTimeout-unref-fixture-4.js new file mode 100644 index 000000000..9968f3b36 --- /dev/null +++ b/test/js/web/timers/setTimeout-unref-fixture-4.js @@ -0,0 +1,5 @@ +setTimeout(() => { + console.log("TEST PASSED!"); +}, 1) + .unref() + .ref(); diff --git a/test/js/web/timers/setTimeout-unref-fixture-5.js b/test/js/web/timers/setTimeout-unref-fixture-5.js new file mode 100644 index 000000000..e5caa1be4 --- /dev/null +++ b/test/js/web/timers/setTimeout-unref-fixture-5.js @@ -0,0 +1,5 @@ +setTimeout(() => { + console.log("TEST FAILED!"); +}, 100) + .ref() + .unref(); diff --git a/test/js/web/timers/setTimeout-unref-fixture.js b/test/js/web/timers/setTimeout-unref-fixture.js new file mode 100644 index 000000000..97a0f78a2 --- /dev/null +++ b/test/js/web/timers/setTimeout-unref-fixture.js @@ -0,0 +1,12 @@ +const timer = setTimeout(() => {}, 999_999_999); +if (timer.unref() !== timer) throw new Error("Expected timer.unref() === timer"); + +var ranCount = 0; +const going2Refresh = setTimeout(() => { + if (ranCount < 1) going2Refresh.refresh(); + ranCount++; + + if (ranCount === 2) { + console.log("SUCCESS"); + } +}, 1); diff --git a/test/js/web/timers/setTimeout.test.js b/test/js/web/timers/setTimeout.test.js index dbe89dea8..eef6bbae0 100644 --- a/test/js/web/timers/setTimeout.test.js +++ b/test/js/web/timers/setTimeout.test.js @@ -1,5 +1,7 @@ +import { spawnSync } from "bun"; import { it, expect } from "bun:test"; - +import { bunEnv, bunExe } from "harness"; +import path from "node:path"; it("setTimeout", async () => { var lastID = -1; const result = await new Promise((resolve, reject) => { @@ -172,11 +174,56 @@ it.skip("order of setTimeouts", done => { Promise.resolve().then(maybeDone(() => nums.push(1))); }); +it("setTimeout -> refresh", () => { + const { exitCode, stdout } = spawnSync({ + cmd: [bunExe(), path.join(import.meta.dir, "setTimeout-unref-fixture.js")], + env: bunEnv, + }); + expect(exitCode).toBe(0); + expect(stdout.toString()).toBe("SUCCESS\n"); +}); + +it("setTimeout -> unref -> ref works", () => { + const { exitCode, stdout } = spawnSync({ + cmd: [bunExe(), path.join(import.meta.dir, "setTimeout-unref-fixture-4.js")], + env: bunEnv, + }); + expect(exitCode).toBe(0); + expect(stdout.toString()).toBe("TEST PASSED!\n"); +}); + +it("setTimeout -> ref -> unref works, even if there is another timer", () => { + const { exitCode, stdout } = spawnSync({ + cmd: [bunExe(), path.join(import.meta.dir, "setTimeout-unref-fixture-2.js")], + env: bunEnv, + }); + expect(exitCode).toBe(0); + expect(stdout.toString()).toBe(""); +}); + +it("setTimeout -> ref -> unref works", () => { + const { exitCode, stdout } = spawnSync({ + cmd: [bunExe(), path.join(import.meta.dir, "setTimeout-unref-fixture-5.js")], + env: bunEnv, + }); + expect(exitCode).toBe(0); + expect(stdout.toString()).toBe(""); +}); + +it("setTimeout -> unref doesn't keep event loop alive forever", () => { + const { exitCode, stdout } = spawnSync({ + cmd: [bunExe(), path.join(import.meta.dir, "setTimeout-unref-fixture-3.js")], + env: bunEnv, + }); + expect(exitCode).toBe(0); + expect(stdout.toString()).toBe(""); +}); + it("setTimeout should refresh N times", done => { let count = 0; let timer = setTimeout(() => { count++; - timer.refresh(); + expect(timer.refresh()).toBe(timer); }, 50); setTimeout(() => { -- cgit v1.2.3