author    Jarred Sumner <jarred@jarredsumner.com> 2022-04-20 20:38:39 -0700
committer Jarred Sumner <jarred@jarredsumner.com> 2022-04-20 20:38:39 -0700
commit   748cd82187d9f7246cc7d8a32fa1a5c5593bfc48 (patch)
tree     ad42fb0ae4459a627d8e62d978269dc0af875ff9
parent   64ae0c77143d1220fac6c8a5441023b070efcc06 (diff)
download bun-748cd82187d9f7246cc7d8a32fa1a5c5593bfc48.tar.gz
         bun-748cd82187d9f7246cc7d8a32fa1a5c5593bfc48.tar.zst
         bun-748cd82187d9f7246cc7d8a32fa1a5c5593bfc48.zip
[misc] Implement generic parallel for loop
-rw-r--r--   src/thread_pool.zig   155
1 file changed, 149 insertions(+), 6 deletions(-)
diff --git a/src/thread_pool.zig b/src/thread_pool.zig
index 9a7111553..378adaecb 100644
--- a/src/thread_pool.zig
+++ b/src/thread_pool.zig
@@ -170,14 +170,14 @@ pub fn ConcurrentFunction(
thread_pool: *ThreadPool,
states: []Routine = undefined,
batch: Batch = .{},
+ allocator: std.mem.Allocator,
pub fn init(allocator: std.mem.Allocator, thread_pool: *ThreadPool, count: usize) !Runner {
- return .{
+ return Runner{
.allocator = allocator,
.thread_pool = thread_pool,
.states = try allocator.alloc(Routine, count),
.batch = .{},
- .i = 0,
};
}
@@ -188,7 +188,7 @@ pub fn ConcurrentFunction(
this.batch.push(Batch.from(&this.states[this.batch.len].task));
}
- pub fn go(this: *@This()) void {
+ pub fn run(this: *@This()) void {
this.thread_pool.schedule(this.batch);
}
@@ -208,13 +208,156 @@ pub fn ConcurrentFunction(
};
}
-/// kind of like a goroutine but with worse DX and without the small stack sizes
-pub fn go(
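+/// Allocate a runner that can batch up to `count` invocations of `Function`.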
+pub fn runner(
this: *ThreadPool,
allocator: std.mem.Allocator,
comptime Function: anytype,
+ count: usize,
) !ConcurrentFunction(Function) {
- return try ConcurrentFunction(Function).init(allocator, this);
+ return try ConcurrentFunction(Function).init(allocator, this, count);
+}
+
+/// Loop over an array of values and invoke `Run` on each one across the thread pool.
+/// **Blocks the calling thread** until all tasks are completed.
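+///
+/// Illustrative call shape (a sketch; `pool`, `Ctx`, and `Ctx.each` are
+/// hypothetical names, not part of this commit):
+///
+///     var ctx = Ctx{};
+///     try pool.doAndWait(allocator, null, &ctx, Ctx.each, values);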
+pub fn doAndWait(
+ this: *ThreadPool,
+ allocator: std.mem.Allocator,
+ wg: ?*WaitGroup,
+ ctx: anytype,
+ comptime Run: anytype,
+ values: anytype,
+) !void {
+ return try Do(this, allocator, wg, true, @TypeOf(ctx), ctx, Run, @TypeOf(values), values);
+}
+
+/// Loop over an array of values and invoke `Run` on each one across the thread pool.
+/// Returns without waiting for the tasks to complete.
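+///
+/// Fire-and-forget sketch (same hypothetical names as the blocking variant above):
+///
+///     try pool.do(allocator, &ctx, Ctx.each, values);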
+pub fn do(
+ this: *ThreadPool,
+ allocator: std.mem.Allocator,
+ ctx: anytype,
+ comptime Run: anytype,
+ values: anytype,
+) !void {
+ return try Do(this, allocator, null, false, @TypeOf(ctx), ctx, Run, @TypeOf(values), values);
+}
+
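+/// Shared implementation behind `do` and `doAndWait`: splits `values` into
+/// per-worker chunks and schedules one pool task per chunk. When `block` is
+/// true, completion is tracked with a WaitGroup, either the caller's `wg` or
+/// one allocated here.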
+pub fn Do(
+ this: *ThreadPool,
+ allocator: std.mem.Allocator,
+ wg: ?*WaitGroup,
+ comptime block: bool,
+ comptime Context: type,
+ ctx: Context,
+ comptime Function: anytype,
+ comptime ValuesType: type,
+ values: ValuesType,
+) !void {
+ if (values.len == 0)
+ return;
+ var allocated_wait_group: ?*WaitGroup = null;
+ defer {
+ if (comptime block) {
+ if (allocated_wait_group) |group| {
+ group.deinit();
+ allocator.destroy(group);
+ }
+ }
+ }
+
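+    // In blocking mode, pick the WaitGroup to use: the caller's `wg` if
+    // provided, otherwise one allocated on the spot (freed by the defer above).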
+ const WaitGroupType = comptime if (block) *WaitGroup else void;
+ var wait_group: WaitGroupType = undefined;
+
+ if (comptime block) {
+ if (wg) |wg_| {
+ wait_group = wg_;
+ } else {
+ allocated_wait_group = try allocator.create(WaitGroup);
+ try allocated_wait_group.?.init();
+ wait_group = allocated_wait_group.?;
+ }
+ }
+
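+    // Each task receives a chunk of `values` plus the chunk's starting offset,
+    // so `Function` still sees every element's global index (i + j).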
+ const WaitContext = struct {
+ wait_group: WaitGroupType = undefined,
+ ctx: Context,
+ };
+
+ const Runner = struct {
+ pub fn call(ctx_: WaitContext, values_: ValuesType, i: usize) void {
+ for (values_) |v, j| {
+ Function(ctx_.ctx, v, i + j);
+ }
+ if (comptime block) ctx_.wait_group.finish();
+ }
+ };
+
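+    // Chunk size is ceil(len / max_threads), clamped to at least 1. For the
+    // test below (20 values, 12 threads) that is ceil(20/12) = 2 values per
+    // chunk, dispatched as ceil(20/2) = 10 tasks.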
+ const tasks_per_worker = @maximum(try std.math.divCeil(u32, @intCast(u32, values.len), this.max_threads), 1);
+    // divCeil (not truncating division) so a trailing partial chunk still
+    // gets a slot in `states`.
+    const count = try std.math.divCeil(u32, @intCast(u32, values.len), tasks_per_worker);
+ var runny = try runner(this, allocator, Runner.call, count);
+ defer runny.deinit();
+
+ var i: usize = 0;
+ const context_ = WaitContext{
+ .ctx = ctx,
+ .wait_group = if (comptime block) wait_group else void{},
+ };
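+    // Carve `values` into chunks and queue one task per chunk.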
+ var remain = values;
+ while (remain.len > 0) {
+ var slice = remain[0..@minimum(remain.len, tasks_per_worker)];
+
+ runny.call(.{
+ context_,
+ slice,
+ i,
+ });
+ i += slice.len;
+ remain = remain[slice.len..];
+ if (comptime block) wait_group.counter += 1;
+ }
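+    // Hand the whole batch to the pool in a single schedule() call, then
+    // block on the wait group if requested.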
+ runny.run();
+ if (comptime block)
+ wait_group.wait();
+}
+
+test "parallel for loop" {
+ Output.initTest();
+ var thread_pool = ThreadPool.init(.{ .max_threads = 12 });
+ var sleepy_time: u32 = 100;
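+    // 20 identical tasks; each sleeps 100ns before bumping the shared counter.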
+ var huge_array = &[_]u32{
+ sleepy_time,
+ sleepy_time,
+ sleepy_time,
+ sleepy_time,
+ sleepy_time,
+ sleepy_time,
+ sleepy_time,
+ sleepy_time,
+ sleepy_time,
+ sleepy_time,
+ sleepy_time,
+ sleepy_time,
+ sleepy_time,
+ sleepy_time,
+ sleepy_time,
+ sleepy_time,
+ sleepy_time,
+ sleepy_time,
+ sleepy_time,
+ sleepy_time,
+ };
+ const Runner = struct {
+ completed: usize = 0,
+ total: usize = 0,
+ pub fn run(ctx: *@This(), value: u32, _: usize) void {
+ std.time.sleep(value);
+ ctx.completed += 1;
+ std.debug.assert(ctx.completed <= ctx.total);
+ }
+ };
+ var runny = try std.heap.page_allocator.create(Runner);
+ runny.* = .{ .total = huge_array.len };
+ try thread_pool.doAndWait(std.heap.page_allocator, null, runny, Runner.run, std.mem.span(huge_array));
+ try std.testing.expectEqual(huge_array.len, runny.completed);
}
/// Schedule a batch of tasks to be executed by some thread on the thread pool.