diff options
author | 2023-06-25 02:58:49 -0700 | |
---|---|---|
committer | 2023-06-25 02:58:49 -0700 | |
commit | bc7719fc2800e339fa9a75a124023e92bae2ec56 (patch) | |
tree | 11b86421a69156a4585b2849ffaf5b75a13444c6 /src/string_immutable.zig | |
parent | ff635551436123022ba3980b39580d53973c80a2 (diff) | |
download | bun-bc7719fc2800e339fa9a75a124023e92bae2ec56.tar.gz bun-bc7719fc2800e339fa9a75a124023e92bae2ec56.tar.zst bun-bc7719fc2800e339fa9a75a124023e92bae2ec56.zip |
Reliability bugfix for `WebSocket` (#3394)
* Rewrite elementLengthLatin1IntoUTF8
* Update SIMDUTF
* Make `elementLengthLatin1IntoUTF8` faster
---------
Co-authored-by: Jarred Sumner <709451+Jarred-Sumner@users.noreply.github.com>
Diffstat (limited to 'src/string_immutable.zig')
-rw-r--r-- | src/string_immutable.zig | 71 |
1 files changed, 27 insertions, 44 deletions
diff --git a/src/string_immutable.zig b/src/string_immutable.zig index d075b38f5..3931648b8 100644 --- a/src/string_immutable.zig +++ b/src/string_immutable.zig @@ -1946,61 +1946,44 @@ pub fn replaceLatin1WithUTF8(buf_: []u8) void { } pub fn elementLengthLatin1IntoUTF8(comptime Type: type, latin1_: Type) usize { + // https://zig.godbolt.org/z/zzYexPPs9 + var latin1 = latin1_; + const input_len = latin1.len; var total_non_ascii_count: usize = 0; - const latin1_last = latin1.ptr + latin1.len; - if (latin1.ptr != latin1_last) { - - // reference the pointer directly because it improves codegen - var ptr = latin1.ptr; - - if (comptime Environment.enableSIMD) { - const wrapped_len = latin1.len - (latin1.len % ascii_vector_size); - const latin1_vec_end = ptr + wrapped_len; - while (ptr != latin1_vec_end) { - const vec: AsciiVector = ptr[0..ascii_vector_size].*; - const cmp = vec & @splat(ascii_vector_size, @as(u8, 0x80)); - total_non_ascii_count += @reduce(.Add, cmp); - ptr += ascii_vector_size; - } - } else { - while (@intFromPtr(ptr + 8) < @intFromPtr(latin1_last)) { - if (comptime Environment.allow_assert) std.debug.assert(@intFromPtr(ptr) <= @intFromPtr(latin1_last) and @intFromPtr(ptr) >= @intFromPtr(latin1_.ptr)); - const bytes = @bitCast(u64, ptr[0..8].*) & 0x8080808080808080; - total_non_ascii_count += @popCount(bytes); - ptr += 8; - } + // This is about 30% faster on large input compared to auto-vectorization + if (comptime Environment.enableSIMD) { + const end = latin1.ptr + (latin1.len - (latin1.len % ascii_vector_size)); + while (latin1.ptr != end) { + const vec: AsciiVector = latin1[0..ascii_vector_size].*; + + // Shifting a unsigned 8 bit integer to the right by 7 bits always produces a value of 0 or 1. + const cmp = vec >> @splat( + ascii_vector_size, + @as(u8, 7), + ); - if (@intFromPtr(ptr + 4) < @intFromPtr(latin1_last)) { - if (comptime Environment.allow_assert) std.debug.assert(@intFromPtr(ptr) <= @intFromPtr(latin1_last) and @intFromPtr(ptr) >= @intFromPtr(latin1_.ptr)); - const bytes = @bitCast(u32, ptr[0..4].*) & 0x80808080; - total_non_ascii_count += @popCount(bytes); - ptr += 4; - } + // Anding that value rather than converting it into a @Vector(16, u1) produces better code from LLVM. + const mask = cmp & @splat( + ascii_vector_size, + @as(u8, 1), + ); - if (@intFromPtr(ptr + 2) < @intFromPtr(latin1_last)) { - if (comptime Environment.allow_assert) std.debug.assert(@intFromPtr(ptr) <= @intFromPtr(latin1_last) and @intFromPtr(ptr) >= @intFromPtr(latin1_.ptr)); - const bytes = @bitCast(u16, ptr[0..2].*) & 0x8080; - total_non_ascii_count += @popCount(bytes); - ptr += 2; - } + total_non_ascii_count += @as(usize, @reduce(.Add, mask)); + latin1 = latin1[ascii_vector_size..]; } - while (ptr != latin1_last) { - if (comptime Environment.allow_assert) std.debug.assert(@intFromPtr(ptr) < @intFromPtr(latin1_last)); - - total_non_ascii_count += @as(usize, @intFromBool(ptr[0] > 127)); - ptr += 1; - } + // an important hint to the compiler to not auto-vectorize the loop below + if (latin1.len >= ascii_vector_size) unreachable; + } - // assert we never go out of bounds - if (comptime Environment.allow_assert) std.debug.assert(@intFromPtr(ptr) <= @intFromPtr(latin1_last) and @intFromPtr(ptr) >= @intFromPtr(latin1_.ptr)); + for (latin1) |c| { + total_non_ascii_count += @as(usize, @intFromBool(c > 127)); } // each non-ascii latin1 character becomes 2 UTF8 characters - // since latin1_.len is the original length, we only need to add up the number of non-ascii characters to get the final count - return latin1_.len + total_non_ascii_count; + return input_len + total_non_ascii_count; } const JSC = @import("root").bun.JSC; |