diff --git a/src/mpsc.zig b/src/mpsc.zig
index 1c0b4567..68eb8f87 100644
--- a/src/mpsc.zig
+++ b/src/mpsc.zig
@@ -24,6 +24,7 @@ pub fn Pool(comptime Node: type) type {
         head: ?*Node = null,
 
         // Tracks chunks of allocated nodes, used for freeing them at deinit() time.
+        cleanup_mu: std.Thread.Mutex = .{},
         cleanup: std.ArrayListUnmanaged([*]Node) = .{},
 
         // How many nodes to allocate at once for each chunk in the pool.
@@ -72,10 +73,27 @@ pub fn Pool(comptime Node: type) type {
                     break; // Pool is empty
                 }
 
-            // Pool is empty, allocate new chunk of nodes, and track the pointer for later cleanup
+            // Pool is empty, so we need to allocate new nodes. This is the rare path where we
+            // need a lock, and only to make the pool.cleanup tracking list thread-safe.
+            pool.cleanup_mu.lock();
+
+            // Check the pool again after acquiring the lock: another thread might have already
+            // allocated nodes while we were waiting.
+            const head2 = @atomicLoad(?*Node, &pool.head, .acquire);
+            if (head2) |_| {
+                // Pool is no longer empty; release the lock and try to acquire a node again.
+                pool.cleanup_mu.unlock();
+                return pool.acquire(allocator);
+            }
+
+            // Pool still empty: allocate a new chunk of nodes and track the pointer for later
+            // cleanup. The errdefer ensures we do not leave the mutex locked if either fallible
+            // operation below fails.
+            errdefer pool.cleanup_mu.unlock();
             const new_nodes = try allocator.alloc(Node, pool.chunk_size);
             errdefer allocator.free(new_nodes);
             try pool.cleanup.append(allocator, @ptrCast(new_nodes.ptr));
+            pool.cleanup_mu.unlock();
 
             // Link all our new nodes (except the first one acquired by the caller) into a chain
             // with eachother.
@@ -311,3 +329,44 @@ test "basic" {
     try std.testing.expectEqual(queue.pop(), 3);
     try std.testing.expectEqual(queue.pop(), null);
 }
+
+test "concurrent producers" {
+    const allocator = std.testing.allocator;
+
+    var queue: Queue(u32) = undefined;
+    try queue.init(allocator, 32);
+    defer queue.deinit(allocator);
+
+    const n_jobs = 100;
+    const n_entries: u32 = 10000;
+
+    var pool: std.Thread.Pool = undefined;
+    try std.Thread.Pool.init(&pool, .{ .allocator = allocator, .n_jobs = n_jobs });
+    defer pool.deinit();
+
+    var wg: std.Thread.WaitGroup = .{};
+    for (0..n_jobs) |_| {
+        pool.spawnWg(
+            &wg,
+            struct {
+                pub fn run(q: *Queue(u32)) void {
+                    var i: u32 = 0;
+                    while (i < n_entries) : (i += 1) {
+                        q.push(allocator, i) catch unreachable;
+                    }
+                }
+            }.run,
+            .{&queue},
+        );
+    }
+
+    wg.wait();
+
+    // Drain the queue: every value the producers pushed should come back out.
+    var count: usize = 0;
+    while (queue.pop()) |_| {
+        count += 1;
+        if (count >= n_jobs * n_entries) break;
+    }
+    try std.testing.expectEqual(@as(usize, n_jobs * n_entries), count);
+}
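
Aside: the slow path added to acquire() above is the classic double-checked locking pattern: a lock-free fast path, then lock, re-check under the lock, and only then allocate. Below is a minimal, self-contained sketch of the same pattern, assuming the same Zig version as the patch (lowercase atomic orderings such as .acquire); the Lazy type and its fields are hypothetical examples, not part of mpsc.zig.

const std = @import("std");

// Hypothetical example type illustrating the pattern; not part of mpsc.zig.
const Lazy = struct {
    value: ?*u64 = null,
    mu: std.Thread.Mutex = .{},

    fn get(self: *Lazy, allocator: std.mem.Allocator) !*u64 {
        // Fast path: lock-free check.
        if (@atomicLoad(?*u64, &self.value, .acquire)) |v| return v;

        // Slow path: take the lock, then re-check, since another thread may
        // have initialized the value while we were waiting on the mutex.
        self.mu.lock();
        defer self.mu.unlock();
        if (@atomicLoad(?*u64, &self.value, .acquire)) |v| return v;

        const v = try allocator.create(u64);
        v.* = 42;
        // Publish with release ordering so fast-path readers observe a
        // fully-initialized value.
        @atomicStore(?*u64, &self.value, v, .release);
        return v;
    }
};

test "double-checked locking sketch" {
    var lazy: Lazy = .{};
    const v = try lazy.get(std.testing.allocator);
    defer std.testing.allocator.destroy(v);
    try std.testing.expectEqual(@as(u64, 42), v.*);
}

The .release store pairs with the .acquire loads, which is the same reason the patch re-reads pool.head with .acquire after taking the lock: without the re-check, every thread that raced on an empty pool would allocate its own chunk.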