Diffstat (limited to 'src/thread_pool.zig')
-rw-r--r--  src/thread_pool.zig  263
1 file changed, 263 insertions, 0 deletions
diff --git a/src/thread_pool.zig b/src/thread_pool.zig
index 47feae0df..da402407e 100644
--- a/src/thread_pool.zig
+++ b/src/thread_pool.zig
@@ -125,6 +125,269 @@ pub const Batch = struct {
}
};
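+/// Counts in-flight tasks so a caller can block until all of them complete:
+/// `start` once per task before it is scheduled, `finish` from the task when
+/// it is done, and `wait` to block until the counter returns to zero.
+/// `reset` re-arms the event so the group can be reused for another round.
+///
+/// Illustrative protocol (sketch only; scheduling the work is up to the caller):
+///
+///     var wg: WaitGroup = undefined;
+///     wg.init();
+///     wg.start();   // before handing the task to the pool
+///     // ... the task calls wg.finish() when it completes ...
+///     wg.wait();    // blocks until every start() has a matching finish()
+///     wg.deinit();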
+pub const WaitGroup = struct {
+ mutex: std.Thread.Mutex = .{},
+ counter: u32 = 0,
+ event: std.Thread.ResetEvent,
+
+ pub fn init(self: *WaitGroup) void {
+ self.* = .{
+ .mutex = .{},
+ .counter = 0,
+ .event = .{}, // starts unset; set() fires when the counter reaches zero
+ };
+ }
+
+ pub fn deinit(self: *WaitGroup) void {
+ self.event.reset();
+ self.* = undefined;
+ }
+
+ pub fn start(self: *WaitGroup) void {
+ self.mutex.lock();
+ defer self.mutex.unlock();
+
+ self.counter += 1;
+ }
+
+ pub fn finish(self: *WaitGroup) void {
+ self.mutex.lock();
+ defer self.mutex.unlock();
+
+ self.counter -= 1;
+
+ if (self.counter == 0) {
+ self.event.set();
+ }
+ }
+
+ pub fn wait(self: *WaitGroup) void {
+ while (true) {
+ self.mutex.lock();
+
+ if (self.counter == 0) {
+ self.mutex.unlock();
+ return;
+ }
+
+ self.mutex.unlock();
+ self.event.wait();
+ }
+ }
+
+ pub fn reset(self: *WaitGroup) void {
+ self.event.reset();
+ }
+};
+
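+/// Wraps a comptime-known function so that up to `count` calls to it can be
+/// queued with `call` (each call's argument tuple is stored in `states`) and
+/// then submitted to the pool as a single batch with `run`. Call `deinit` to
+/// free the stored argument slots once the scheduled calls have finished.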
+pub fn ConcurrentFunction(
+ comptime Function: anytype,
+) type {
+ return struct {
+ const Fn = Function;
+ const Args = std.meta.ArgsTuple(@TypeOf(Fn));
+ const Runner = @This();
+ thread_pool: *ThreadPool,
+ states: []Routine = undefined,
+ batch: Batch = .{},
+ allocator: std.mem.Allocator,
+
+ pub fn init(allocator: std.mem.Allocator, thread_pool: *ThreadPool, count: usize) !Runner {
+ return Runner{
+ .allocator = allocator,
+ .thread_pool = thread_pool,
+ .states = try allocator.alloc(Routine, count),
+ .batch = .{},
+ };
+ }
+
+ pub fn call(this: *@This(), args: Args) void {
+ this.states[this.batch.len] = .{
+ .args = args,
+ };
+ this.batch.push(Batch.from(&this.states[this.batch.len].task));
+ }
+
+ pub fn run(this: *@This()) void {
+ this.thread_pool.schedule(this.batch);
+ }
+
+ pub const Routine = struct {
+ args: Args,
+ task: Task = .{ .callback = callback },
+
+ pub fn callback(task: *Task) void {
+ var routine = @fieldParentPtr(@This(), "task", task);
+ @call(.always_inline, Fn, routine.args);
+ }
+ };
+
+ pub fn deinit(this: *@This()) void {
+ this.allocator.free(this.states);
+ }
+ };
+}
+
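+/// Allocates a `ConcurrentFunction` runner for `Function` with room for
+/// `count` queued calls, bound to this pool.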
+pub fn runner(
+ this: *ThreadPool,
+ allocator: std.mem.Allocator,
+ comptime Function: anytype,
+ count: usize,
+) !ConcurrentFunction(Function) {
+ return try ConcurrentFunction(Function).init(allocator, this, count);
+}
+
+/// Invoke `Run` once per element of `values`, distributing the calls across
+/// the pool's worker threads.
+/// **Blocks the calling thread** until every invocation has completed.
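+///
+/// Illustrative call shape (the context type and callback named here are
+/// hypothetical, not part of this file):
+///
+///     // fn handle(ctx: *MyCtx, item: Item, index: usize) void
+///     try pool.do(allocator, null, my_ctx, MyCtx.handle, items);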
+pub fn do(
+ this: *ThreadPool,
+ allocator: std.mem.Allocator,
+ wg: ?*WaitGroup,
+ ctx: anytype,
+ comptime Run: anytype,
+ values: anytype,
+) !void {
+ return try Do(this, allocator, wg, @TypeOf(ctx), ctx, Run, @TypeOf(values), values, false);
+}
+
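+/// Same as `do`, except the callback receives a pointer to its element of
+/// `values`, so elements can be mutated in place.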
+pub fn doPtr(
+ this: *ThreadPool,
+ allocator: std.mem.Allocator,
+ wg: ?*WaitGroup,
+ ctx: anytype,
+ comptime Run: anytype,
+ values: anytype,
+) !void {
+ return try Do(this, allocator, wg, @TypeOf(ctx), ctx, Run, @TypeOf(values), values, true);
+}
+
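+/// Shared implementation behind `do` and `doPtr`: allocates one task per
+/// element of `values`, schedules them all as a single batch, and blocks on
+/// a `WaitGroup` (the caller's, or a temporary heap-allocated one when `wg`
+/// is null) until every task has finished.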
+pub fn Do(
+ this: *ThreadPool,
+ allocator: std.mem.Allocator,
+ wg: ?*WaitGroup,
+ comptime Context: type,
+ ctx: Context,
+ comptime Function: anytype,
+ comptime ValuesType: type,
+ values: ValuesType,
+ comptime as_ptr: bool,
+) !void {
+ if (values.len == 0)
+ return;
+ var allocated_wait_group: ?*WaitGroup = null;
+ defer {
+ if (allocated_wait_group) |group| {
+ group.deinit();
+ allocator.destroy(group);
+ }
+ }
+
+ var wait_group = wg orelse brk: {
+ allocated_wait_group = try allocator.create(WaitGroup);
+ allocated_wait_group.?.init();
+ break :brk allocated_wait_group.?;
+ };
+ const WaitContext = struct {
+ wait_group: *WaitGroup = undefined,
+ ctx: Context,
+ values: ValuesType,
+ };
+
+ const RunnerTask = struct {
+ task: Task,
+ ctx: *WaitContext,
+ i: usize = 0,
+
+ pub fn call(task: *Task) void {
+ var runner_task = @fieldParentPtr(@This(), "task", task);
+ const i = runner_task.i;
+ if (comptime as_ptr) {
+ Function(runner_task.ctx.ctx, &runner_task.ctx.values[i], i);
+ } else {
+ Function(runner_task.ctx.ctx, runner_task.ctx.values[i], i);
+ }
+
+ runner_task.ctx.wait_group.finish();
+ }
+ };
+ var wait_context = try allocator.create(WaitContext);
+ wait_context.* = .{
+ .ctx = ctx,
+ .wait_group = wait_group,
+ .values = values,
+ };
+ defer allocator.destroy(wait_context);
+ var tasks = try allocator.alloc(RunnerTask, values.len);
+ defer allocator.free(tasks);
+ var batch: Batch = undefined;
+ // Task indices are handed out from the back of `values` toward the front;
+ // every element still gets exactly one task.
+ var offset = tasks.len - 1;
+
+ {
+ tasks[0] = .{
+ .i = offset,
+ .task = .{ .callback = RunnerTask.call },
+ .ctx = wait_context,
+ };
+ batch = Batch.from(&tasks[0].task);
+ }
+ if (tasks.len > 1) {
+ for (tasks[1..]) |*runner_task| {
+ offset -= 1;
+ runner_task.* = .{
+ .i = offset,
+ .task = .{ .callback = RunnerTask.call },
+ .ctx = wait_context,
+ };
+ batch.push(Batch.from(&runner_task.task));
+ }
+ }
+
+ // Account for every task up front, before any of them can call finish().
+ wait_group.counter += @as(u32, @intCast(values.len));
+ this.schedule(batch);
+ wait_group.wait();
+}
+
+test "parallel for loop" {
+ Output.initTest();
+ var thread_pool = ThreadPool.init(.{ .max_threads = 12 });
+ var sleepy_time: u32 = 100;
+ var huge_array = &[_]u32{
+ sleepy_time + std.rand.DefaultPrng.init(1).random().uintAtMost(u32, 20),
+ sleepy_time + std.rand.DefaultPrng.init(2).random().uintAtMost(u32, 20),
+ sleepy_time + std.rand.DefaultPrng.init(3).random().uintAtMost(u32, 20),
+ sleepy_time + std.rand.DefaultPrng.init(4).random().uintAtMost(u32, 20),
+ sleepy_time + std.rand.DefaultPrng.init(5).random().uintAtMost(u32, 20),
+ sleepy_time + std.rand.DefaultPrng.init(6).random().uintAtMost(u32, 20),
+ sleepy_time + std.rand.DefaultPrng.init(7).random().uintAtMost(u32, 20),
+ sleepy_time + std.rand.DefaultPrng.init(8).random().uintAtMost(u32, 20),
+ sleepy_time + std.rand.DefaultPrng.init(9).random().uintAtMost(u32, 20),
+ sleepy_time + std.rand.DefaultPrng.init(10).random().uintAtMost(u32, 20),
+ sleepy_time + std.rand.DefaultPrng.init(11).random().uintAtMost(u32, 20),
+ sleepy_time + std.rand.DefaultPrng.init(12).random().uintAtMost(u32, 20),
+ sleepy_time + std.rand.DefaultPrng.init(13).random().uintAtMost(u32, 20),
+ sleepy_time + std.rand.DefaultPrng.init(14).random().uintAtMost(u32, 20),
+ sleepy_time + std.rand.DefaultPrng.init(15).random().uintAtMost(u32, 20),
+ sleepy_time + std.rand.DefaultPrng.init(16).random().uintAtMost(u32, 20),
+ sleepy_time + std.rand.DefaultPrng.init(17).random().uintAtMost(u32, 20),
+ sleepy_time + std.rand.DefaultPrng.init(18).random().uintAtMost(u32, 20),
+ sleepy_time + std.rand.DefaultPrng.init(19).random().uintAtMost(u32, 20),
+ sleepy_time + std.rand.DefaultPrng.init(20).random().uintAtMost(u32, 20),
+ };
+ const Runner = struct {
+ completed: usize = 0,
+ total: usize = 0,
+ pub fn run(ctx: *@This(), value: u32, _: usize) void {
+ std.time.sleep(value);
+ // run() is called concurrently from pool threads, so update the counter atomically.
+ const prev = @atomicRmw(usize, &ctx.completed, .Add, 1, .SeqCst);
+ std.debug.assert(prev < ctx.total);
+ }
+ };
+ var runny = try std.heap.page_allocator.create(Runner);
+ runny.* = .{ .total = huge_array.len };
+ try thread_pool.do(std.heap.page_allocator, null, runny, Runner.run, std.mem.span(huge_array));
+ try std.testing.expectEqual(@as(usize, huge_array.len), runny.completed);
+}
+
/// Schedule a batch of tasks to be executed by some thread on the thread pool.
pub fn schedule(self: *ThreadPool, batch: Batch) void {
// Sanity check