Diffstat (limited to 'src/thread_pool.zig')
-rw-r--r-- | src/thread_pool.zig | 263
1 file changed, 263 insertions, 0 deletions
diff --git a/src/thread_pool.zig b/src/thread_pool.zig
index 47feae0df..da402407e 100644
--- a/src/thread_pool.zig
+++ b/src/thread_pool.zig
@@ -125,6 +125,269 @@ pub const Batch = struct {
     }
 };
 
+pub const WaitGroup = struct {
+    mutex: std.Thread.Mutex = .{},
+    counter: u32 = 0,
+    event: std.Thread.ResetEvent,
+
+    pub fn init(self: *WaitGroup) void {
+        self.* = .{
+            .mutex = .{},
+            .counter = 0,
+            .event = undefined,
+        };
+    }
+
+    pub fn deinit(self: *WaitGroup) void {
+        self.event.reset();
+        self.* = undefined;
+    }
+
+    pub fn start(self: *WaitGroup) void {
+        self.mutex.lock();
+        defer self.mutex.unlock();
+
+        self.counter += 1;
+    }
+
+    pub fn finish(self: *WaitGroup) void {
+        self.mutex.lock();
+        defer self.mutex.unlock();
+
+        self.counter -= 1;
+
+        if (self.counter == 0) {
+            self.event.set();
+        }
+    }
+
+    pub fn wait(self: *WaitGroup) void {
+        while (true) {
+            self.mutex.lock();
+
+            if (self.counter == 0) {
+                self.mutex.unlock();
+                return;
+            }
+
+            self.mutex.unlock();
+            self.event.wait();
+        }
+    }
+
+    pub fn reset(self: *WaitGroup) void {
+        self.event.reset();
+    }
+};
+
+pub fn ConcurrentFunction(
+    comptime Function: anytype,
+) type {
+    return struct {
+        const Fn = Function;
+        const Args = std.meta.ArgsTuple(@TypeOf(Fn));
+        const Runner = @This();
+        thread_pool: *ThreadPool,
+        states: []Routine = undefined,
+        batch: Batch = .{},
+        allocator: std.mem.Allocator,
+
+        pub fn init(allocator: std.mem.Allocator, thread_pool: *ThreadPool, count: usize) !Runner {
+            return Runner{
+                .allocator = allocator,
+                .thread_pool = thread_pool,
+                .states = try allocator.alloc(Routine, count),
+                .batch = .{},
+            };
+        }
+
+        pub fn call(this: *@This(), args: Args) void {
+            this.states[this.batch.len] = .{
+                .args = args,
+            };
+            this.batch.push(Batch.from(&this.states[this.batch.len].task));
+        }
+
+        pub fn run(this: *@This()) void {
+            this.thread_pool.schedule(this.batch);
+        }
+
+        pub const Routine = struct {
+            args: Args,
+            task: Task = .{ .callback = callback },
+
+            pub fn callback(task: *Task) void {
+                var routine = @fieldParentPtr(@This(), "task", task);
+                @call(.always_inline, Fn, routine.args);
+            }
+        };
+
+        pub fn deinit(this: *@This()) void {
+            this.allocator.free(this.states);
+        }
+    };
+}
+
+pub fn runner(
+    this: *ThreadPool,
+    allocator: std.mem.Allocator,
+    comptime Function: anytype,
+    count: usize,
+) !ConcurrentFunction(Function) {
+    return try ConcurrentFunction(Function).init(allocator, this, count);
+}
+
+/// Loop over an array of tasks and invoke `Run` on each one in a different thread.
+/// **Blocks the calling thread** until all tasks are completed.
+pub fn do(
+    this: *ThreadPool,
+    allocator: std.mem.Allocator,
+    wg: ?*WaitGroup,
+    ctx: anytype,
+    comptime Run: anytype,
+    values: anytype,
+) !void {
+    return try Do(this, allocator, wg, @TypeOf(ctx), ctx, Run, @TypeOf(values), values, false);
+}
+
+pub fn doPtr(
+    this: *ThreadPool,
+    allocator: std.mem.Allocator,
+    wg: ?*WaitGroup,
+    ctx: anytype,
+    comptime Run: anytype,
+    values: anytype,
+) !void {
+    return try Do(this, allocator, wg, @TypeOf(ctx), ctx, Run, @TypeOf(values), values, true);
+}
+
+pub fn Do(
+    this: *ThreadPool,
+    allocator: std.mem.Allocator,
+    wg: ?*WaitGroup,
+    comptime Context: type,
+    ctx: Context,
+    comptime Function: anytype,
+    comptime ValuesType: type,
+    values: ValuesType,
+    comptime as_ptr: bool,
+) !void {
+    if (values.len == 0)
+        return;
+    var allocated_wait_group: ?*WaitGroup = null;
+    defer {
+        if (allocated_wait_group) |group| {
+            group.deinit();
+            allocator.destroy(group);
+        }
+    }
+
+    var wait_group = wg orelse brk: {
+        allocated_wait_group = try allocator.create(WaitGroup);
+        allocated_wait_group.?.init();
+        break :brk allocated_wait_group.?;
+    };
+    const WaitContext = struct {
+        wait_group: *WaitGroup = undefined,
+        ctx: Context,
+        values: ValuesType,
+    };
+
+    const RunnerTask = struct {
+        task: Task,
+        ctx: *WaitContext,
+        i: usize = 0,
+
+        pub fn call(task: *Task) void {
+            var runner_task = @fieldParentPtr(@This(), "task", task);
+            const i = runner_task.i;
+            if (comptime as_ptr) {
+                Function(runner_task.ctx.ctx, &runner_task.ctx.values[i], i);
+            } else {
+                Function(runner_task.ctx.ctx, runner_task.ctx.values[i], i);
+            }
+
+            runner_task.ctx.wait_group.finish();
+        }
+    };
+    var wait_context = allocator.create(WaitContext) catch unreachable;
+    wait_context.* = .{
+        .ctx = ctx,
+        .wait_group = wait_group,
+        .values = values,
+    };
+    defer allocator.destroy(wait_context);
+    var tasks = allocator.alloc(RunnerTask, values.len) catch unreachable;
+    defer allocator.free(tasks);
+    var batch: Batch = undefined;
+    var offset = tasks.len - 1;
+
+    {
+        tasks[0] = .{
+            .i = offset,
+            .task = .{ .callback = RunnerTask.call },
+            .ctx = wait_context,
+        };
+        batch = Batch.from(&tasks[0].task);
+    }
+    if (tasks.len > 1) {
+        for (tasks[1..]) |*runner_task| {
+            offset -= 1;
+            runner_task.* = .{
+                .i = offset,
+                .task = .{ .callback = RunnerTask.call },
+                .ctx = wait_context,
+            };
+            batch.push(Batch.from(&runner_task.task));
+        }
+    }
+
+    wait_group.counter += @intCast(u32, values.len);
+    this.schedule(batch);
+    wait_group.wait();
+}
+
+test "parallel for loop" {
+    Output.initTest();
+    var thread_pool = ThreadPool.init(.{ .max_threads = 12 });
+    var sleepy_time: u32 = 100;
+    var huge_array = &[_]u32{
+        sleepy_time + std.rand.DefaultPrng.init(1).random().uintAtMost(u32, 20),
+        sleepy_time + std.rand.DefaultPrng.init(2).random().uintAtMost(u32, 20),
+        sleepy_time + std.rand.DefaultPrng.init(3).random().uintAtMost(u32, 20),
+        sleepy_time + std.rand.DefaultPrng.init(4).random().uintAtMost(u32, 20),
+        sleepy_time + std.rand.DefaultPrng.init(5).random().uintAtMost(u32, 20),
+        sleepy_time + std.rand.DefaultPrng.init(6).random().uintAtMost(u32, 20),
+        sleepy_time + std.rand.DefaultPrng.init(7).random().uintAtMost(u32, 20),
+        sleepy_time + std.rand.DefaultPrng.init(8).random().uintAtMost(u32, 20),
+        sleepy_time + std.rand.DefaultPrng.init(9).random().uintAtMost(u32, 20),
+        sleepy_time + std.rand.DefaultPrng.init(10).random().uintAtMost(u32, 20),
+        sleepy_time + std.rand.DefaultPrng.init(11).random().uintAtMost(u32, 20),
+        sleepy_time + std.rand.DefaultPrng.init(12).random().uintAtMost(u32, 20),
+        sleepy_time + std.rand.DefaultPrng.init(13).random().uintAtMost(u32, 20),
+        sleepy_time + std.rand.DefaultPrng.init(14).random().uintAtMost(u32, 20),
+        sleepy_time + std.rand.DefaultPrng.init(15).random().uintAtMost(u32, 20),
+        sleepy_time + std.rand.DefaultPrng.init(16).random().uintAtMost(u32, 20),
+        sleepy_time + std.rand.DefaultPrng.init(17).random().uintAtMost(u32, 20),
+        sleepy_time + std.rand.DefaultPrng.init(18).random().uintAtMost(u32, 20),
+        sleepy_time + std.rand.DefaultPrng.init(19).random().uintAtMost(u32, 20),
+        sleepy_time + std.rand.DefaultPrng.init(20).random().uintAtMost(u32, 20),
+    };
+    const Runner = struct {
+        completed: usize = 0,
+        total: usize = 0,
+        pub fn run(ctx: *@This(), value: u32, _: usize) void {
+            std.time.sleep(value);
+            ctx.completed += 1;
+            std.debug.assert(ctx.completed <= ctx.total);
+        }
+    };
+    var runny = try std.heap.page_allocator.create(Runner);
+    runny.* = .{ .total = huge_array.len };
+    try thread_pool.do(std.heap.page_allocator, null, runny, Runner.run, std.mem.span(huge_array));
+    try std.testing.expectEqual(huge_array.len, runny.completed);
+}
+
 /// Schedule a batch of tasks to be executed by some thread on the thread pool.
 pub fn schedule(self: *ThreadPool, batch: Batch) void {
     // Sanity check
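What follows are a few hedged usage sketches, not part of the commit. First, the WaitGroup protocol implied by the code above: one start() per task before it is scheduled, finish() from the worker when it completes, wait() on the producer side. A minimal sketch; scheduleSomeTask is a hypothetical stand-in for any code that enqueues a task whose callback ends with wg.finish():

    var wg: WaitGroup = undefined;
    wg.init();

    wg.start(); // one start() per task, before it is scheduled
    scheduleSomeTask(); // hypothetical; its callback must call wg.finish()

    wg.wait(); // blocks until every start() has a matching finish()
    wg.deinit();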
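runner / ConcurrentFunction cover the case where each invocation gets its own argument tuple rather than one element of a slice. A sketch under assumptions not in the diff (pool is an initialized *ThreadPool; work is a placeholder function):

    fn work(a: u32, b: u32) void {
        _ = a + b; // stand-in for real work
    }

    // With `pool: *ThreadPool` in scope:
    var concurrent = try pool.runner(std.heap.page_allocator, work, 2);
    concurrent.call(.{ 1, 2 }); // each call() fills one of the `count` slots
    concurrent.call(.{ 3, 4 });
    concurrent.run(); // schedules the batch; does NOT block

    // `states` backs the in-flight tasks, so deinit() only after they have
    // finished, e.g. by pairing this with a WaitGroup the way Do does.

Note the asymmetry with do: run() fires the batch and returns immediately, so the caller owns synchronization.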
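The test exercises the by-value path of do; doPtr is the same loop with &values[i] passed instead, which is the variant for callbacks that mutate elements in place. A sketch along the lines of the test, again assuming a pool: *ThreadPool that is not shown here:

    var values = [_]u32{ 1, 2, 3, 4 };
    const Ctx = struct {
        pub fn run(_: *@This(), value: *u32, i: usize) void {
            // Each task owns a distinct element, so a plain store is race-free.
            value.* +%= @intCast(u32, i);
        }
    };
    var ctx = Ctx{};
    // Blocks until all four tasks have finished (Do calls wait_group.wait()).
    try pool.doPtr(std.heap.page_allocator, null, &ctx, Ctx.run, values[0..]);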