aboutsummaryrefslogtreecommitdiff
path: root/src/string_immutable.zig
diff options
context:
space:
mode:
authorGravatar Jarred Sumner <jarred@jarredsumner.com> 2023-06-25 02:58:49 -0700
committerGravatar GitHub <noreply@github.com> 2023-06-25 02:58:49 -0700
commitbc7719fc2800e339fa9a75a124023e92bae2ec56 (patch)
tree11b86421a69156a4585b2849ffaf5b75a13444c6 /src/string_immutable.zig
parentff635551436123022ba3980b39580d53973c80a2 (diff)
downloadbun-bc7719fc2800e339fa9a75a124023e92bae2ec56.tar.gz
bun-bc7719fc2800e339fa9a75a124023e92bae2ec56.tar.zst
bun-bc7719fc2800e339fa9a75a124023e92bae2ec56.zip
Reliability bugfix for `WebSocket` (#3394)
* Rewrite elementLengthLatin1IntoUTF8
* Update SIMDUTF
* Make `elementLengthLatin1IntoUTF8` faster

---------

Co-authored-by: Jarred Sumner <709451+Jarred-Sumner@users.noreply.github.com>
Diffstat (limited to 'src/string_immutable.zig')
-rw-r--r--src/string_immutable.zig71
1 files changed, 27 insertions, 44 deletions
diff --git a/src/string_immutable.zig b/src/string_immutable.zig
index d075b38f5..3931648b8 100644
--- a/src/string_immutable.zig
+++ b/src/string_immutable.zig
@@ -1946,61 +1946,44 @@ pub fn replaceLatin1WithUTF8(buf_: []u8) void {
}
pub fn elementLengthLatin1IntoUTF8(comptime Type: type, latin1_: Type) usize {
+ // https://zig.godbolt.org/z/zzYexPPs9
+
var latin1 = latin1_;
+ const input_len = latin1.len;
var total_non_ascii_count: usize = 0;
- const latin1_last = latin1.ptr + latin1.len;
- if (latin1.ptr != latin1_last) {
-
- // reference the pointer directly because it improves codegen
- var ptr = latin1.ptr;
-
- if (comptime Environment.enableSIMD) {
- const wrapped_len = latin1.len - (latin1.len % ascii_vector_size);
- const latin1_vec_end = ptr + wrapped_len;
- while (ptr != latin1_vec_end) {
- const vec: AsciiVector = ptr[0..ascii_vector_size].*;
- const cmp = vec & @splat(ascii_vector_size, @as(u8, 0x80));
- total_non_ascii_count += @reduce(.Add, cmp);
- ptr += ascii_vector_size;
- }
- } else {
- while (@intFromPtr(ptr + 8) < @intFromPtr(latin1_last)) {
- if (comptime Environment.allow_assert) std.debug.assert(@intFromPtr(ptr) <= @intFromPtr(latin1_last) and @intFromPtr(ptr) >= @intFromPtr(latin1_.ptr));
- const bytes = @bitCast(u64, ptr[0..8].*) & 0x8080808080808080;
- total_non_ascii_count += @popCount(bytes);
- ptr += 8;
- }
+ // This is about 30% faster on large input compared to auto-vectorization
+ if (comptime Environment.enableSIMD) {
+ const end = latin1.ptr + (latin1.len - (latin1.len % ascii_vector_size));
+ while (latin1.ptr != end) {
+ const vec: AsciiVector = latin1[0..ascii_vector_size].*;
+
+ // Shifting a unsigned 8 bit integer to the right by 7 bits always produces a value of 0 or 1.
+ const cmp = vec >> @splat(
+ ascii_vector_size,
+ @as(u8, 7),
+ );
- if (@intFromPtr(ptr + 4) < @intFromPtr(latin1_last)) {
- if (comptime Environment.allow_assert) std.debug.assert(@intFromPtr(ptr) <= @intFromPtr(latin1_last) and @intFromPtr(ptr) >= @intFromPtr(latin1_.ptr));
- const bytes = @bitCast(u32, ptr[0..4].*) & 0x80808080;
- total_non_ascii_count += @popCount(bytes);
- ptr += 4;
- }
+ // Anding that value rather than converting it into a @Vector(16, u1) produces better code from LLVM.
+ const mask = cmp & @splat(
+ ascii_vector_size,
+ @as(u8, 1),
+ );
- if (@intFromPtr(ptr + 2) < @intFromPtr(latin1_last)) {
- if (comptime Environment.allow_assert) std.debug.assert(@intFromPtr(ptr) <= @intFromPtr(latin1_last) and @intFromPtr(ptr) >= @intFromPtr(latin1_.ptr));
- const bytes = @bitCast(u16, ptr[0..2].*) & 0x8080;
- total_non_ascii_count += @popCount(bytes);
- ptr += 2;
- }
+ total_non_ascii_count += @as(usize, @reduce(.Add, mask));
+ latin1 = latin1[ascii_vector_size..];
}
- while (ptr != latin1_last) {
- if (comptime Environment.allow_assert) std.debug.assert(@intFromPtr(ptr) < @intFromPtr(latin1_last));
-
- total_non_ascii_count += @as(usize, @intFromBool(ptr[0] > 127));
- ptr += 1;
- }
+ // an important hint to the compiler to not auto-vectorize the loop below
+ if (latin1.len >= ascii_vector_size) unreachable;
+ }
- // assert we never go out of bounds
- if (comptime Environment.allow_assert) std.debug.assert(@intFromPtr(ptr) <= @intFromPtr(latin1_last) and @intFromPtr(ptr) >= @intFromPtr(latin1_.ptr));
+ for (latin1) |c| {
+ total_non_ascii_count += @as(usize, @intFromBool(c > 127));
}
// each non-ascii latin1 character becomes 2 UTF8 characters
- // since latin1_.len is the original length, we only need to add up the number of non-ascii characters to get the final count
- return latin1_.len + total_non_ascii_count;
+ return input_len + total_non_ascii_count;
}
const JSC = @import("root").bun.JSC;