diff options
Diffstat (limited to 'src/env_loader.zig')
-rw-r--r-- | src/env_loader.zig | 393 |
1 files changed, 376 insertions, 17 deletions
diff --git a/src/env_loader.zig b/src/env_loader.zig index 4e31edbfd..1a0556455 100644 --- a/src/env_loader.zig +++ b/src/env_loader.zig @@ -2,7 +2,8 @@ const std = @import("std"); const logger = @import("./logger.zig"); usingnamespace @import("./global.zig"); const CodepointIterator = @import("./string_immutable.zig").CodepointIterator; - +const Fs = @import("./fs.zig"); +const Api = @import("./api/schema.zig").Api; const Variable = struct { key: string, value: string, @@ -44,7 +45,7 @@ pub const Lexer = struct { comptime Writer: type, writer: Writer, variable: Variable, - getter: fn (ctx: *const ContextType, key: string) ?string, + comptime getter: fn (ctx: *const ContextType, key: string) ?string, ) !void { var i: usize = 0; var last_flush: usize = 0; @@ -70,7 +71,7 @@ pub const Lexer = struct { last_flush = i; const name = variable.value[start..i]; - if (getter(ctx, name)) |new_value| { + if (@call(.{ .modifier = .always_inline }, getter, .{ ctx, name })) |new_value| { if (new_value.len > 0) { try writer.writeAll(new_value); } @@ -328,27 +329,307 @@ pub const Lexer = struct { } }; -pub const Parser = struct { - pub fn parse(source: *const logger.Source, allocator: *std.mem.Allocator) Map { - var map = Map.init(allocator); +pub const Loader = struct { + map: *Map, + allocator: *std.mem.Allocator, + + @".env.local": ?logger.Source = null, + @".env.development": ?logger.Source = null, + @".env.production": ?logger.Source = null, + @".env": ?logger.Source = null, + + did_load_process: bool = false, + + const empty_string_value: string = "\"\""; + + pub fn copyForDefine( + this: *Loader, + comptime Type: type, + to: *Type, + framework_defaults: Api.StringMap, + behavior: Api.DotEnvBehavior, + prefix: string, + allocator: *std.mem.Allocator, + ) ![]u8 { + var iter = this.map.iter(); + var key_count: usize = 0; + var string_map_hashes = try allocator.alloc(u64, framework_defaults.keys.len); + defer allocator.free(string_map_hashes); + const invalid_hash = std.math.maxInt(u64) - 1; + std.mem.set(u64, string_map_hashes, invalid_hash); + + var key_buf: []u8 = ""; + // Frameworks determine an allowlist of values + + for (framework_defaults.keys) |key, i| { + if (key.len > "process.env.".len and strings.eqlComptime(key[0.."process.env.".len], "process.env.")) { + const hashable_segment = key["process.env.".len..]; + string_map_hashes[i] = std.hash.Wyhash.hash(0, hashable_segment); + } + } + + // We have to copy all the keys to prepend "process.env" :/ + var key_buf_len: usize = 0; + + if (behavior != .disable) { + if (behavior == .prefix) { + std.debug.assert(prefix.len > 0); + + while (iter.next()) |entry| { + if (strings.startsWith(entry.key_ptr.*, prefix)) { + key_buf_len += entry.key_ptr.len; + key_count += 1; + std.debug.assert(entry.key_ptr.len > 0); + } + } + } else { + while (iter.next()) |entry| { + key_buf_len += entry.key_ptr.len; + key_count += 1; + std.debug.assert(entry.key_ptr.len > 0); + } + } + + if (key_buf_len > 0) { + iter.reset(); + key_buf = try allocator.alloc(u8, key_buf_len + key_count * "process.env.".len); + errdefer allocator.free(key_buf); + var key_fixed_allocator = std.heap.FixedBufferAllocator.init(key_buf); + var key_allocator = &key_fixed_allocator.allocator; + + if (behavior == .prefix) { + while (iter.next()) |entry| { + const value: string = if (entry.value_ptr.*.len == 0) empty_string_value else entry.value_ptr.*; + + if (strings.startsWith(entry.key_ptr.*, prefix)) { + _ = try to.getOrPutValue( + std.fmt.allocPrint(key_allocator, "process.env.{s}", .{entry.key_ptr.*}) catch unreachable, + value, + ); + } else { + const hash = std.hash.Wyhash.hash(0, entry.key_ptr.*); + + std.debug.assert(hash != invalid_hash); + + if (std.mem.indexOfScalar(u64, string_map_hashes, hash)) |key_i| { + _ = try to.getOrPutValue( + framework_defaults.keys[key_i], + value, + ); + } + } + } + } else { + while (iter.next()) |entry| { + const value: string = if (entry.value_ptr.*.len == 0) empty_string_value else entry.value_ptr.*; + _ = try to.getOrPutValue( + std.fmt.allocPrint(key_allocator, "process.env.{s}", .{entry.key_ptr.*}) catch unreachable, + value, + ); + } + } + } + } + + for (framework_defaults.keys) |key, i| { + const value = framework_defaults.values[i]; + + if (value.len == 0) { + _ = try to.getOrPutValue(key, empty_string_value); + } else { + _ = try to.getOrPutValue(key, value); + } + } + + return key_buf; + } + + pub fn init(map: *Map, allocator: *std.mem.Allocator) Loader { + return Loader{ + .map = map, + .allocator = allocator, + }; + } + + pub fn loadProcess(this: *Loader) void { + if (this.did_load_process) return; + + // This is a little weird because it's evidently stored line-by-line + var source = logger.Source.initPathString("process.env", ""); + for (std.os.environ) |env| { + source.contents = std.mem.span(env); + Parser.parse(&source, this.allocator, this.map, true); + } + this.did_load_process = true; + } + + // mostly for tests + pub fn loadFromString(this: *Loader, str: string, comptime overwrite: bool) void { + var source = logger.Source.initPathString("test", str); + Parser.parse(&source, this.allocator, this.map, overwrite); + std.mem.doNotOptimizeAway(&source); + } + + // .env.local goes first + // Load .env.development if development + // Load .env.production if !development + // .env goes last + pub fn load( + this: *Loader, + fs: *Fs.FileSystem.RealFS, + dir: *Fs.FileSystem.DirEntry, + comptime development: bool, + ) !void { + const start = std.time.nanoTimestamp(); + var dir_handle: std.fs.Dir = std.fs.cwd(); + var can_auto_close = false; + + if (dir.hasComptimeQuery(".env.local")) { + try this.loadEnvFile(fs, dir_handle, ".env.local", false); + } + + if (comptime development) { + if (dir.hasComptimeQuery(".env.development")) { + try this.loadEnvFile(fs, dir_handle, ".env.development", false); + } + } else { + if (dir.hasComptimeQuery(".env.production")) { + try this.loadEnvFile(fs, dir_handle, ".env.production", false); + } + } + + if (dir.hasComptimeQuery(".env")) { + try this.loadEnvFile(fs, dir_handle, ".env", false); + } + + this.printLoaded(start); + } + + pub fn printLoaded(this: *Loader, start: i128) void { + const count = + @intCast(u8, @boolToInt(this.@".env.local" != null)) + + @intCast(u8, @boolToInt(this.@".env.development" != null)) + + @intCast(u8, @boolToInt(this.@".env.production" != null)) + + @intCast(u8, @boolToInt(this.@".env" != null)); + + if (count == 0) return; + const elapsed = @intToFloat(f64, (std.time.nanoTimestamp() - start)) / std.time.ns_per_ms; + + const all = [_]string{ + ".env.local", + ".env.development", + ".env.production", + ".env", + }; + const loaded = [_]bool{ + this.@".env.local" != null, + this.@".env.development" != null, + this.@".env.production" != null, + this.@".env" != null, + }; + + var loaded_i: u8 = 0; + Output.printElapsed(elapsed); + Output.prettyError(" <d>", .{}); + + for (loaded) |yes, i| { + if (yes) { + loaded_i += 1; + if (count == 1 or (loaded_i >= count and count > 1)) { + Output.prettyError("\"{s}\"", .{all[i]}); + } else { + Output.prettyError("\"{s}\", ", .{all[i]}); + } + } + } + Output.prettyErrorln("<r>\n", .{}); + Output.flush(); + } + + pub fn loadEnvFile(this: *Loader, fs: *Fs.FileSystem.RealFS, dir: std.fs.Dir, comptime base: string, comptime override: bool) !void { + if (@field(this, base) != null) { + return; + } + + var file = dir.openFile(base, .{ .read = true }) catch |err| { + switch (err) { + error.FileNotFound => { + // prevent retrying + @field(this, base) = logger.Source.initPathString(base, ""); + return; + }, + else => { + return err; + }, + } + }; + Fs.FileSystem.setMaxFd(file.handle); + + defer { + if (fs.needToCloseFiles()) { + file.close(); + } + } + const stat = try file.stat(); + if (stat.size == 0) { + @field(this, base) = logger.Source.initPathString(base, ""); + return; + } + + var buf = try this.allocator.allocSentinel(u8, stat.size, 0); + errdefer this.allocator.free(buf); + var contents = try file.readAll(buf); + // always sentinel + buf.ptr[contents + 1] = 0; + const source = logger.Source.initPathString(base, buf.ptr[0..contents]); + + Parser.parse( + &source, + this.allocator, + this.map, + override, + ); + + @field(this, base) = source; + } +}; + +pub const Parser = struct { + pub fn parse( + source: *const logger.Source, + allocator: *std.mem.Allocator, + map: *Map, + comptime override: bool, + ) void { var lexer = Lexer.init(source); var fbs = std.io.fixedBufferStream(&temporary_nested_value_buffer); var writer = fbs.writer(); + while (lexer.next()) |variable| { if (variable.has_nested_value) { writer.context.reset(); - lexer.eatNestedValue(Map, &map, @TypeOf(writer), writer, variable, Map.get) catch unreachable; + + lexer.eatNestedValue(Map, map, @TypeOf(writer), writer, variable, Map.get) catch unreachable; const new_value = fbs.buffer[0..fbs.pos]; if (new_value.len > 0) { - map.put(variable.key, allocator.dupe(u8, new_value) catch unreachable) catch unreachable; + if (comptime override) { + map.put(variable.key, allocator.dupe(u8, new_value) catch unreachable) catch unreachable; + } else { + var putter = map.map.getOrPut(variable.key) catch unreachable; + if (!putter.found_existing) { + putter.value_ptr.* = allocator.dupe(u8, new_value) catch unreachable; + } + } } } else { - map.put(variable.key, variable.value) catch unreachable; + if (comptime override) { + map.put(variable.key, variable.value) catch unreachable; + } else { + map.putDefault(variable.key, variable.value) catch unreachable; + } } } - - return map; } }; @@ -361,7 +642,7 @@ pub const Map = struct { return Map{ .map = HashTable.init(allocator) }; } - pub inline fn iter(this: *Map) !HashTable.Iterator { + pub inline fn iter(this: *Map) HashTable.Iterator { return this.map.iterator(); } @@ -380,9 +661,9 @@ pub const Map = struct { _ = try this.map.getOrPutValue(key, value); } - pub fn merge(this: *Map, other: *Map) !void {} - - pub fn copyPrefixed(this: *Map, other: *Map) !void {} + pub inline fn getOrPut(this: *Map, key: string, value: string) !void { + _ = try this.map.getOrPutValue(key, value); + } }; const expectString = std.testing.expectEqualStrings; @@ -422,10 +703,18 @@ test "DotEnv Loader" { \\ ; const source = logger.Source.initPathString(".env", VALID_ENV); - const map = Parser.parse(&source, std.heap.c_allocator); + var map = Map.init(std.heap.c_allocator); + Parser.parse( + &source, + std.heap.c_allocator, + &map, + true, + ); + try expectString(map.get("NESTED_VALUES_RESPECT_ESCAPING").?, "'\\$API_KEY'"); + try expectString(map.get("NESTED_VALUE").?, "'verysecure'"); try expectString(map.get("RECURSIVE_NESTED_VALUE").?, "'verysecure':verysecure"); - try expectString(map.get("NESTED_VALUES_RESPECT_ESCAPING").?, "'\\$API_KEY'"); + try expectString(map.get("API_KEY").?, "verysecure"); try expectString(map.get("process.env.WAT").?, "ABCDEFGHIJKLMNOPQRSTUVWXYZZ10239457123"); try expectString(map.get("DOUBLE-QUOTED_SHOULD_PRESERVE_NEWLINES").?, "\"\nya\n\""); @@ -438,3 +727,73 @@ test "DotEnv Loader" { try expectString(map.get("IGNORING_DOESNT_BREAK_OTHER_LINES").?, "'yes'"); try expectString(map.get("LEADING_SPACE_IN_UNQUOTED_VALUE_IS_TRIMMED").?, "yes"); } + +test "DotEnv Process" { + var map = Map.init(std.heap.c_allocator); + var process = try std.process.getEnvMap(std.heap.c_allocator); + var loader = Loader.init(&map, std.heap.c_allocator); + loader.loadProcess(); + + try expectString(loader.map.get("TMPDIR").?, process.get("TMPDIR").?); + try expect(loader.map.get("TMPDIR").?.len > 0); + + try expectString(loader.map.get("USER").?, process.get("USER").?); + try expect(loader.map.get("USER").?.len > 0); +} + +test "DotEnv Loader.copyForDefine" { + const UserDefine = std.StringArrayHashMap(string); + + var map = Map.init(std.heap.c_allocator); + var loader = Loader.init(&map, std.heap.c_allocator); + const framework_keys = [_]string{ "process.env.BACON", "process.env.HOSTNAME" }; + const framework_values = [_]string{ "true", "\"localhost\"" }; + const framework = Api.StringMap{ + .keys = std.mem.span(&framework_keys), + .values = std.mem.span(&framework_values), + }; + + const user_overrides: string = + \\BACON=false + \\HOSTNAME=example.com + \\THIS_SHOULDNT_BE_IN_DEFINES_MAP=true + \\ + ; + + const skip_user_overrides: string = + \\THIS_SHOULDNT_BE_IN_DEFINES_MAP=true + \\ + ; + + loader.loadFromString(skip_user_overrides, false); + + var user_defines = UserDefine.init(std.heap.c_allocator); + var buf = try loader.copyForDefine(UserDefine, &user_defines, framework, .disable, "", std.heap.c_allocator); + + try expect(user_defines.get("process.env.THIS_SHOULDNT_BE_IN_DEFINES_MAP") == null); + + user_defines = UserDefine.init(std.heap.c_allocator); + + loader.loadFromString(user_overrides, true); + + buf = try loader.copyForDefine( + UserDefine, + &user_defines, + framework, + Api.DotEnvBehavior.load_all, + "", + std.heap.c_allocator, + ); + + try expect(user_defines.get("process.env.BACON") != null); + try expectString(user_defines.get("process.env.BACON").?, "false"); + try expectString(user_defines.get("process.env.HOSTNAME").?, "example.com"); + try expect(user_defines.get("process.env.THIS_SHOULDNT_BE_IN_DEFINES_MAP") != null); + + user_defines = UserDefine.init(std.heap.c_allocator); + + buf = try loader.copyForDefine(UserDefine, &user_defines, framework, .prefix, "HO", std.heap.c_allocator); + + try expectString(user_defines.get("process.env.HOSTNAME").?, "example.com"); + try expect(user_defines.get("process.env.THIS_SHOULDNT_BE_IN_DEFINES_MAP") == null); +} |