├── .gitattributes ├── .gitignore ├── benchmarks ├── go │ ├── README.md │ └── qsort.go ├── rust │ ├── README.md │ ├── Cargo.toml │ ├── rayon_qsort.rs │ └── qsort.rs └── zig │ ├── README.md │ ├── build.zig │ ├── qsort.zig │ └── async.zig ├── LICENSE-MIT ├── README.md ├── src ├── thread_pool_go_based.zig └── thread_pool.zig └── blog.md /.gitattributes: -------------------------------------------------------------------------------- 1 | *.zig text=auto eol=lf -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # IDE stuff 2 | .vscode/ 3 | .history/ 4 | 5 | # build directories 6 | zig-out/ 7 | zig-cache/ 8 | target/ 9 | gobuild/ 10 | 11 | # build artifacts 12 | Cargo.lock 13 | NUL 14 | 15 | # testing 16 | main.zig 17 | perf.* -------------------------------------------------------------------------------- /benchmarks/go/README.md: -------------------------------------------------------------------------------- 1 | # Go implementations 2 | 3 | ``` 4 | go run qsort.go 5 | ``` 6 | Runs the quick sort benchmark using [Golang](https://golang.org/). With respect to Zig, this is a baseline that the thread pool should meet or exceed. If not, then this experiment didn't meet its goal in my opinion. -------------------------------------------------------------------------------- /benchmarks/rust/README.md: -------------------------------------------------------------------------------- 1 | # Rust implementations 2 | 3 | ``` 4 | cargo run --release --bin qsort 5 | ``` 6 | Runs the quick sort benchmark using [`tokio`](https://tokio.rs/). It is the primary Rust representative here, given that it also supports async. 7 | 8 | ``` 9 | cargo run --release --bin rsort 10 | ``` 11 | Runs the quick sort benchmark using [`rayon`](https://docs.rs/rayon). This is just out of curiosity. Rayon doesn't support async AFAIK, but it's still a good comparison.
-------------------------------------------------------------------------------- /benchmarks/zig/README.md: -------------------------------------------------------------------------------- 1 | # Zig implementations 2 | 3 | ``` 4 | zig build run -Drelease-fast // add -Dc flag if on posix systems 5 | ``` 6 | Runs the quick sort benchmark using the thread pool in this repo, which is written in [Ziglang](https://ziglang.org/). `async.zig` wraps the thread pool API with `async/await` syntax similar to the other languages in the benchmark. Also, I wrote two thread pools out of curiosity. You can switch which one is used by changing the path to the zig file in `build.zig`. -------------------------------------------------------------------------------- /benchmarks/rust/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "bench" 3 | version = "0.0.0" 4 | edition = "2021" 5 | authors = ["kprotty"] 6 | 7 | [[bin]] 8 | name = "qsort" 9 | path = "qsort.rs" 10 | 11 | [[bin]] 12 | name = "rsort" 13 | path = "rayon_qsort.rs" 14 | 15 | [dependencies.rayon] 16 | version = "1.5" 17 | 18 | [dependencies.tokio] 19 | version = "1" 20 | features = ["rt-multi-thread", "sync", "macros"] 21 | 22 | [profile.release] 23 | codegen-units = 1 24 | lto = true 25 | panic = "abort" 26 | -------------------------------------------------------------------------------- /benchmarks/zig/build.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | 3 | pub fn build(b: *std.build.Builder) void { 4 | const mode = b.standardReleaseOptions(); 5 | const target = b.standardTargetOptions(.{}); 6 | const link_c = b.option(bool, "c", "link libc") orelse false; 7 | 8 | const exe = b.addExecutable("qsort", "qsort.zig"); 9 | if (link_c) { 10 | exe.linkLibC(); 11 | } 12 | 13 | exe.addPackage(.{ 14 | .name = "thread_pool", 15 | .path = .{ .path = "../../src/thread_pool.zig" }, 16 | }); 17
| exe.setTarget(target); 18 | exe.setBuildMode(mode); 19 | exe.install(); 20 | 21 | const run = b.step("run", "Run the benchmark"); 22 | run.dependOn(&exe.run().step); 23 | } -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 kprotty 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | zap [![License](https://img.shields.io/badge/license-MIT-8FBD08.svg)](https://shields.io/) 2 | ==== 3 | Designing efficient task scheduling for Ziglang. 
4 | 5 | ## Goals 6 | So I originally started this project around 2019 in order to develop memory, threads, I/O, and synchronization primitives for Zig, given they were lacking at the time. Over the months, it shifted toward developing a runtime (or rather, a thread pool) that was both resource-efficient (one of Zig's implicit Zens, and one of my own) and competitive in performance with existing implementations. 7 | 8 | Here lies the result of that effort for now. There's still more experimenting to do, such as how to dispatch I/O efficiently, but I'm happy with what has come of it and wanted to share. You can find a copy of the blog post [in this repo](blog.md), the reference implementation in [src](src/thread_pool.zig), and some of my [previous attempts](zap/tree/old_branches) in their own branch. 9 | 10 | ## Benchmarks 11 | To benchmark the implementation, I wrote some quicksort implementations for similar APIs in other languages. The reasoning behind quicksort is that it's fairly practical and can also be heavy with concurrency. Try running them locally!
12 | -------------------------------------------------------------------------------- /benchmarks/go/qsort.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | func main() { 10 | arr := make([]int, 10 * 1000 * 1000) 11 | 12 | fmt.Println("filling") 13 | for i := 0; i < len(arr); i++ { 14 | arr[i] = i 15 | } 16 | 17 | fmt.Println("shuffling") 18 | shuffle(arr) 19 | 20 | fmt.Println("running") 21 | start := time.Now() 22 | quickSort(arr) 23 | 24 | fmt.Println("took", time.Since(start)) 25 | if !verify(arr) { 26 | panic("array not sorted") 27 | } 28 | } 29 | 30 | // verify reports whether arr is in ascending order; empty and single-element slices count as sorted. 31 | func verify(arr []int) bool { 32 | for i := 1; i < len(arr); i++ { 33 | if arr[i-1] > arr[i] { 34 | return false 35 | } 36 | } 37 | return true 38 | } 39 | 40 | func shuffle(arr []int) { 41 | var xs uint32 = 0xdeadbeef // 32-bit xorshift walked front to back, so the input permutation matches the Rust and Zig benchmarks 42 | for i := range arr { 43 | xs ^= xs << 13 44 | xs ^= xs >> 17 45 | xs ^= xs << 5 46 | j := int(uint(xs) % uint(i+1)) 47 | arr[i], arr[j] = arr[j], arr[i] 48 | } 49 | } 50 | 51 | func quickSort(arr []int) { 52 | if len(arr) <= 32 { 53 | insertionSort(arr) 54 | } else { 55 | var wg sync.WaitGroup 56 | mid := partition(arr) 57 | 58 | wg.Add(2) 59 | go func() { 60 | quickSort(arr[:mid]) 61 | wg.Done() 62 | }() 63 | go func() { 64 | quickSort(arr[mid:]) 65 | wg.Done() 66 | }() 67 | 68 | wg.Wait() 69 | } 70 | } 71 | 72 | func partition(arr []int) int { 73 | pivot := len(arr) - 1 74 | i := 0 75 | for j := 0; j < pivot; j++ { 76 | if arr[j] <= arr[pivot] { 77 | arr[i], arr[j] = arr[j], arr[i] 78 | i++ 79 | } 80 | } 81 | arr[i], arr[pivot] = arr[pivot], arr[i] 82 | return i 83 | } 84 | 85 | func insertionSort(arr []int) { 86 | for i := 1; i < len(arr); i++ { 87 | for n := i; n > 0 && arr[n] < arr[n - 1]; n-- { 88 | arr[n], arr[n - 1] = arr[n - 1], arr[n] 89 | } 90 | } 91 | } --------------------------------------------------------------------------------
/benchmarks/rust/rayon_qsort.rs: -------------------------------------------------------------------------------- 1 | const SIZE: usize = 10_000_000; 2 | 3 | fn main() { 4 | println!("filling"); 5 | let mut arr = (0..SIZE) 6 | .map(|i| i.try_into().unwrap()) 7 | .collect::<Vec<i32>>() 8 | .into_boxed_slice(); 9 | 10 | println!("shuffling"); 11 | shuffle(&mut arr); 12 | 13 | println!("running"); 14 | let start = std::time::Instant::now(); 15 | quick_sort(&mut arr); 16 | 17 | println!("took {:?}", start.elapsed()); 18 | assert!(verify(&arr)); 19 | } 20 | 21 | fn verify(arr: &[i32]) -> bool { 22 | arr.windows(2).all(|i| i[0] <= i[1]) 23 | } 24 | 25 | fn shuffle(arr: &mut [i32]) { 26 | let mut xs: u32 = 0xdeadbeef; 27 | for i in 0..arr.len() { 28 | xs ^= xs << 13; 29 | xs ^= xs >> 17; 30 | xs ^= xs << 5; 31 | let j = (xs as usize) % (i + 1); 32 | arr.swap(i, j); 33 | } 34 | } 35 | 36 | fn quick_sort(arr: &mut [i32]) { 37 | if arr.len() <= 32 { 38 | insertion_sort(arr); 39 | } else { 40 | let mid = partition(arr); 41 | let (low, high) = arr.split_at_mut(mid); 42 | 43 | rayon::scope(|s| { 44 | s.spawn(|_| quick_sort(low)); 45 | s.spawn(|_| quick_sort(high)); 46 | }); 47 | 48 | // Optimized version (hooks directly into the scheduler) 49 | // rayon::join( 50 | // || quick_sort(low), 51 | // || quick_sort(high) 52 | // ); 53 | } 54 | } 55 | 56 | fn partition(arr: &mut [i32]) -> usize { 57 | let pivot = arr.len() - 1; 58 | let mut i = 0; 59 | for j in 0..pivot { 60 | if arr[j] <= arr[pivot] { 61 | arr.swap(i, j); 62 | i += 1; 63 | } 64 | } 65 | arr.swap(i, pivot); 66 | i 67 | } 68 | 69 | fn insertion_sort(arr: &mut [i32]) { 70 | for i in 1..arr.len() { 71 | let mut n = i; 72 | while n > 0 && arr[n] < arr[n - 1] { 73 | arr.swap(n, n - 1); 74 | n -= 1; 75 | } 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /benchmarks/rust/qsort.rs: -------------------------------------------------------------------------------- 1 | const SIZE: usize
= 10_000_000; 2 | 3 | #[tokio::main] 4 | pub async fn main() { 5 | // `tokio::spawn` needs `'static` data, so the array is lent out through a raw pointer and reclaimed once sorted. 6 | 7 | println!("filling"); 8 | let arr = (0..SIZE) 9 | .map(|i| i.try_into().unwrap()) 10 | .collect::<Vec<i32>>() 11 | .into_boxed_slice(); 12 | 13 | // SAFETY: `arr_ptr` comes from `Box::into_raw`, so it is valid and uniquely owned until `Box::from_raw` below; 14 | // the `'static` borrow created from it is only used before that point. 15 | let arr_ptr: *mut [i32] = Box::into_raw(arr); 16 | let arr: &'static mut [i32] = unsafe { &mut *arr_ptr }; 17 | println!("shuffling"); 18 | shuffle(arr); 19 | println!("running"); 20 | let start = std::time::Instant::now(); 21 | quick_sort(arr).await; 22 | println!("took {:?}", start.elapsed()); 23 | // SAFETY: `quick_sort` awaited every task it spawned, so no borrow of the array is still in use; the Box frees it. 24 | assert!(verify(&unsafe { Box::from_raw(arr_ptr) })); 25 | } 26 | 27 | fn verify(arr: &[i32]) -> bool { 28 | arr.windows(2).all(|i| i[0] <= i[1]) 29 | } 30 | 31 | fn shuffle(arr: &mut [i32]) { 32 | let mut xs: u32 = 0xdeadbeef; 33 | for i in 0..arr.len() { 34 | xs ^= xs << 13; 35 | xs ^= xs >> 17; 36 | xs ^= xs << 5; 37 | let j = (xs as usize) % (i + 1); 38 | arr.swap(i, j); 39 | } 40 | } 41 | 42 | async fn quick_sort(arr: &'static mut [i32]) { 43 | if arr.len() <= 32 { 44 | insertion_sort(arr); 45 | } else { 46 | let mid = partition(arr); 47 | let (low, high) = arr.split_at_mut(mid); 48 | 49 | fn spawn_quick_sort(array: &'static mut [i32]) -> tokio::task::JoinHandle<()> { 50 | tokio::spawn(async move { 51 | quick_sort(array).await 52 | }) 53 | } 54 | 55 | let left = spawn_quick_sort(low); 56 | let right = spawn_quick_sort(high); 57 | 58 | left.await.unwrap(); 59 | right.await.unwrap(); 60 | } 61 | } 62 | 63 | fn partition(arr: &mut [i32]) -> usize { 64 | let pivot = arr.len() - 1; 65 | let mut i = 0; 66 | for j in 0..pivot { 67 | if arr[j] <= arr[pivot] { 68 | arr.swap(i, j); 69 | i += 1; 70 | } 71 | } 72 | arr.swap(i, pivot); 73 | i 74 | } 75 | 76 | fn insertion_sort(arr: &mut [i32]) { 77 | for i in 1..arr.len() { 78 | let mut n = i; 79 | while n > 0 && arr[n] < arr[n - 1] { 80 | arr.swap(n, n - 1); 81 | n -= 1; 82 | } 83 | } 84 | } 85 | --------------------------------------------------------------------------------
/benchmarks/zig/qsort.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const Async = @import("async.zig"); 3 | 4 | const SIZE = 10_000_000; 5 | 6 | pub fn main() void { 7 | return Async.run(asyncMain, .{}); 8 | } 9 | 10 | fn asyncMain() void { 11 | const arr = Async.allocator.alloc(i32, SIZE) catch @panic("failed to allocate array"); 12 | defer Async.allocator.free(arr); 13 | 14 | std.debug.print("filling\n", .{}); 15 | for (arr) |*item, i| { 16 | item.* = @intCast(i32, i); 17 | } 18 | 19 | std.debug.print("shuffling\n", .{}); 20 | shuffle(arr); 21 | 22 | std.debug.print("running\n", .{}); 23 | var timer = std.time.Timer.start() catch @panic("failed to create os timer"); 24 | quickSort(arr); 25 | 26 | var elapsed = @intToFloat(f64, timer.lap()); 27 | var units: []const u8 = "ns"; 28 | if (elapsed >= std.time.ns_per_s) { 29 | elapsed /= std.time.ns_per_s; 30 | units = "s"; 31 | } else if (elapsed >= std.time.ns_per_ms) { 32 | elapsed /= std.time.ns_per_ms; 33 | units = "ms"; 34 | } else if (elapsed >= std.time.ns_per_us) { 35 | elapsed /= std.time.ns_per_us; 36 | units = "us"; 37 | } 38 | 39 | std.debug.print("took {d:.2}{s}\n", .{ elapsed, units }); 40 | if (!verify(arr)) { 41 | std.debug.panic("array not sorted", .{}); 42 | } 43 | } 44 | /// Returns true when arr is in ascending order; empty and single-element slices count as sorted. 45 | fn verify(arr: []const i32) bool { 46 | var i: usize = 1; 47 | while (i < arr.len) : (i += 1) { 48 | if (arr[i - 1] > arr[i]) return false; 49 | } 50 | return true; 51 | } 52 | 53 | fn shuffle(arr: []i32) void { 54 | var xs: u32 = 0xdeadbeef; 55 | for (arr) |_, i| { 56 | xs ^= xs << 13; 57 | xs ^= xs >> 17; 58 | xs ^= xs << 5; 59 | const j = xs % (i + 1); 60 | std.mem.swap(i32, &arr[i], &arr[j]); 61 | } 62 | } 63 | 64 | fn quickSort(arr: []i32) void { 65 | if (arr.len <= 32) { 66 | insertionSort(arr); 67 | } else { 68 | const mid = partition(arr); 69 | 70 | var left = Async.spawn(quickSort, .{arr[0..mid]}); 71 | var right =
Async.spawn(quickSort, .{arr[mid..]}); 72 | 73 | left.join(); 74 | right.join(); 75 | } 76 | } 77 | 78 | fn partition(arr: []i32) usize { 79 | const pivot = arr.len - 1; 80 | var i: usize = 0; 81 | for (arr[0..pivot]) |_, j| { 82 | if (arr[j] <= arr[pivot]) { 83 | std.mem.swap(i32, &arr[j], &arr[i]); 84 | i += 1; 85 | } 86 | } 87 | std.mem.swap(i32, &arr[i], &arr[pivot]); 88 | return i; 89 | } 90 | 91 | fn insertionSort(arr: []i32) void { 92 | if (arr.len < 2) return; // partition() returns mid == 0 when the pivot is the minimum, so arr can be empty and arr[1..] would be out of bounds 93 | for (arr[1..]) |_, i| { 94 | var n = i + 1; 95 | while (n > 0 and arr[n] < arr[n - 1]) : (n -= 1) { 96 | std.mem.swap(i32, &arr[n], &arr[n - 1]); 97 | } 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /benchmarks/zig/async.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const builtin = @import("builtin"); 3 | const ThreadPool = @import("thread_pool"); 4 | 5 | /// Global thread pool which mimics other async runtimes 6 | var thread_pool: ThreadPool = undefined; 7 | 8 | /// Global allocator which mimics other async runtimes 9 | pub var allocator: std.mem.Allocator = undefined; 10 | 11 | /// Zig async wrapper around ThreadPool.Task 12 | const Task = struct { 13 | tp_task: ThreadPool.Task = .{ .callback = onSchedule }, 14 | frame: anyframe, 15 | 16 | fn onSchedule(tp_task: *ThreadPool.Task) void { 17 | const task = @fieldParentPtr(Task, "tp_task", tp_task); 18 | resume task.frame; 19 | } 20 | 21 | fn schedule(self: *Task) void { 22 | const batch = ThreadPool.Batch.from(&self.tp_task); 23 | thread_pool.schedule(batch); 24 | } 25 | }; 26 | 27 | fn ReturnTypeOf(comptime asyncFn: anytype) type { 28 | return @typeInfo(@TypeOf(asyncFn)).Fn.return_type.?; 29 | } 30 | 31 | /// Entry point for the synchronous main() to an async function. 32 | /// Initializes the global thread pool and allocator 33 | /// then calls asyncFn(...args) in the thread pool and returns the results.
34 | pub fn run(comptime asyncFn: anytype, args: anytype) ReturnTypeOf(asyncFn) { 35 | const Args = @TypeOf(args); 36 | const Wrapper = struct { 37 | fn entry(task: *Task, fn_args: Args) ReturnTypeOf(asyncFn) { 38 | // Prepare the task to resume this frame once it's scheduled. 39 | // Returns execution to after `async Wrapper.entry(&task, args)`. 40 | suspend { 41 | task.* = .{ .frame = @frame() }; 42 | } 43 | 44 | // Begin teardown of the thread pool after the entry point async fn completes. 45 | defer thread_pool.shutdown(); 46 | 47 | // Run the entry point async fn 48 | return @call(.{}, asyncFn, fn_args); 49 | } 50 | }; 51 | 52 | var task: Task = undefined; 53 | var frame = async Wrapper.entry(&task, args); 54 | 55 | // On windows, use the process heap allocator. 56 | // On posix systems, use the libc allocator. 57 | const is_windows = builtin.target.os.tag == .windows; 58 | var win_heap: if (is_windows) std.heap.HeapAllocator else void = undefined; 59 | if (is_windows) { 60 | win_heap = @TypeOf(win_heap).init(); 61 | win_heap.heap_handle = std.os.windows.kernel32.GetProcessHeap() orelse unreachable; 62 | allocator = win_heap.allocator(); // HeapAllocator hands out its std.mem.Allocator through the allocator() method 63 | } else if (builtin.link_libc) { 64 | allocator = std.heap.c_allocator; 65 | } else { 66 | @compileError("link to libc with '-Dc' as zig stdlib doesn't provide a fast, libc-less, general purpose allocator (yet)"); 67 | } 68 | 69 | const num_cpus = std.Thread.getCpuCount() catch @panic("failed to get cpu core count"); 70 | const num_threads = std.math.cast(u16, num_cpus) catch std.math.maxInt(u16); 71 | thread_pool = ThreadPool.init(.{ .max_threads = num_threads }); 72 | 73 | // Schedule the task onto the thread pool and wait for the thread pool to be shutdown() by the task. 74 | task.schedule(); 75 | thread_pool.deinit(); 76 | 77 | // At this point, all threads in the pool should not be running async tasks 78 | // so the main task/frame has been completed.
79 | return nosuspend await frame; 80 | } 81 | 82 | /// State synchronization which handles waiting for the result of a spawned async function. 83 | fn SpawnHandle(comptime T: type) type { 84 | return struct { 85 | state: std.atomic.Atomic(usize) = std.atomic.Atomic(usize).init(0), 86 | 87 | const Self = @This(); 88 | const DETACHED: usize = 0b1; 89 | const Waiter = struct { 90 | task: Task, 91 | value: T, 92 | }; 93 | 94 | /// Called by the async function to resolve the join() coroutine with the function result. 95 | /// Returns without doing anything if it was detach()ed. 96 | pub fn complete(self: *Self, value: T) void { 97 | // Prepare our waiter node with the value 98 | var waiter = Waiter{ 99 | .value = value, 100 | .task = .{ .frame = @frame() }, 101 | }; 102 | 103 | // Suspend and get ready to wait asynchronously. 104 | suspend { 105 | // Acquire barrier to ensure we see the join()'s *Waiter writes if present. 106 | // Release barrier to ensure join() and detach() see our *Waiter writes. 107 | const state = self.state.swap(@ptrToInt(&waiter), .AcqRel); 108 | 109 | // If join() or detach() were called before us. 110 | if (state != 0) { 111 | // Then fill the result value for join() & wake it up. 112 | if (state != DETACHED) { 113 | const joiner = @intToPtr(*Waiter, state); 114 | joiner.value = waiter.value; 115 | joiner.task.schedule(); 116 | } 117 | // Also wake ourselves up since there's nothing to wait for. 118 | waiter.task.schedule(); 119 | } 120 | } 121 | } 122 | 123 | /// Waits for the async fn to call complete(T) and returns the T given to complete(). 124 | pub fn join(self: *Self) T { 125 | var waiter = Waiter{ 126 | .value = undefined, // the complete() task will fill this for us 127 | .task = .{ .frame = @frame() }, 128 | }; 129 | 130 | suspend { 131 | // Acquire barrier to ensure we see the complete()'s *Waiter writes if present. 132 | // Release barrier to ensure complete() sees our *Waiter writes.
133 | if (@intToPtr(?*Waiter, self.state.swap(@ptrToInt(&waiter), .AcqRel))) |completer| { 134 | // complete() was waiting for us to consume its value. 135 | // Do so and reschedule both of us. 136 | waiter.value = completer.value; 137 | completer.task.schedule(); 138 | waiter.task.schedule(); 139 | } 140 | } 141 | 142 | // Return the waiter value which is either: 143 | // - consumed by the waiting complete() above or 144 | // - filled in by complete() when it goes to suspend 145 | return waiter.value; 146 | } 147 | /// Gives up interest in the result: a later complete() reschedules itself without handing off its value. 148 | pub fn detach(self: *Self) void { 149 | // Mark the state as detached, making a subsequent complete() no-op 150 | // Wake up the waiting complete() if it was there before us. 151 | // Acquire barrier in order to see the complete()'s *Waiter writes. 152 | if (@intToPtr(?*Waiter, self.state.swap(DETACHED, .Acquire))) |completer| { 153 | completer.task.schedule(); 154 | } 155 | } 156 | }; 157 | } 158 | 159 | /// A type-safe wrapper around SpawnHandle() for the spawn() caller. 160 | pub fn JoinHandle(comptime T: type) type { 161 | return struct { 162 | spawn_handle: *SpawnHandle(T), 163 | 164 | pub fn join(self: @This()) T { 165 | return self.spawn_handle.join(); 166 | } 167 | 168 | pub fn detach(self: @This()) void { 169 | return self.spawn_handle.detach(); 170 | } 171 | }; 172 | } 173 | 174 | /// Dynamically allocates and runs an async function concurrently to the caller. 175 | /// Returns a handle to the async function which can be used to wait for its result or detach it as a dependency. 176 | pub fn spawn(comptime asyncFn: anytype, args: anytype) JoinHandle(ReturnTypeOf(asyncFn)) { 177 | const Args = @TypeOf(args); 178 | const Result = ReturnTypeOf(asyncFn); 179 | const Wrapper = struct { 180 | fn entry(spawn_handle_ref: **SpawnHandle(Result), fn_args: Args) void { 181 | // Create the spawn handle in the @Frame() and return a reference of it to the caller.
182 | var spawn_handle = SpawnHandle(Result){}; 183 | spawn_handle_ref.* = &spawn_handle; 184 | 185 | // Reschedule the @Frame() so that it can run concurrently from the caller. 186 | // This returns execution to after `async Wrapper.entry()` down below. 187 | var task = Task{ .frame = @frame() }; 188 | suspend { 189 | task.schedule(); 190 | } 191 | 192 | // Run the async function and synchronize the result with the spawn/join handle. 193 | const result = @call(.{}, asyncFn, fn_args); 194 | spawn_handle.complete(result); 195 | 196 | // Finally, we deallocate this @Frame() since we're done with it. 197 | // The `suspend` is there as a trick to avoid a use-after-free: 198 | // 199 | // Zig async functions are appended with some code to resume an `await`er if any. 200 | // That code involves interacting with the Frame's memory which is a no-no once deallocated. 201 | // To avoid that, we first suspend. This ensures any frame interactions happen before the suspend-block. 202 | // This also means that any `await`er would block indefinitely, 203 | // but that's fine since we're using a custom method with SpawnHandle instead of await to get the value.
204 | suspend { 205 | allocator.destroy(@frame()); 206 | } 207 | } 208 | }; 209 | 210 | const frame = allocator.create(@Frame(Wrapper.entry)) catch @panic("failed to allocate coroutine"); 211 | var spawn_handle: *SpawnHandle(Result) = undefined; 212 | frame.* = async Wrapper.entry(&spawn_handle, args); 213 | return JoinHandle(Result){ .spawn_handle = spawn_handle }; 214 | } 215 | -------------------------------------------------------------------------------- /src/thread_pool_go_based.zig: -------------------------------------------------------------------------------- 1 | const builtin = @import("builtin"); 2 | const std = @import("std"); 3 | const assert = std.debug.assert; 4 | const Atomic = std.atomic.Atomic; 5 | const ThreadPool = @This(); 6 | 7 | stack_size: u32, 8 | max_threads: u16, 9 | queue: Node.Queue = .{}, 10 | join_event: Event = .{}, 11 | idle_event: Event = .{}, 12 | sync: Atomic(u32) = Atomic(u32).init(0), 13 | threads: Atomic(?*Thread) = Atomic(?*Thread).init(null), 14 | // Bit-packed view of the 32-bit `sync` word (10+10+10+1+1 bits): idle, spawned and stealing thread counts plus the shutdown flag. 15 | const Sync = packed struct { 16 | idle: u10 = 0, 17 | spawned: u10 = 0, 18 | stealing: u10 = 0, 19 | padding: u1 = 0, 20 | shutdown: bool = false, 21 | }; 22 | /// Options for init(): max_threads is required; stack_size defaults to std.Thread's default spawn stack size. 23 | pub const Config = struct { 24 | max_threads: u16, 25 | stack_size: u32 = (std.Thread.SpawnConfig{}).stack_size, 26 | }; 27 | /// Creates a pool without spawning threads. max_threads is clamped to 1..1023 because Sync counts threads in u10 fields; stack_size is raised to at least one page. 28 | pub fn init(config: Config) ThreadPool { 29 | return .{ 30 | .max_threads = std.math.min(std.math.max(1, config.max_threads), std.math.maxInt(u10)), 31 | .stack_size = std.math.max(std.mem.page_size, config.stack_size), 32 | }; 33 | } 34 | /// Blocks in join() until the last spawned thread unregisters, then invalidates the pool. 35 | pub fn deinit(self: *ThreadPool) void { 36 | self.join(); 37 | self.* = undefined; 38 | } 39 | 40 | /// A Task represents the unit of Work / Job / Execution that the ThreadPool schedules. 41 | /// The user provides a `callback` which is invoked when the *Task can run on a thread. 42 | pub const Task = struct { 43 | node: Node = .{}, 44 | callback: fn (*Task) void, 45 | }; 46 | 47 | /// An unordered collection of Tasks which can be submitted for scheduling as a group.
48 | pub const Batch = struct { 49 | len: usize = 0, 50 | head: ?*Task = null, 51 | tail: ?*Task = null, 52 | 53 | /// Create a batch from a single task. 54 | pub fn from(task: *Task) Batch { 55 | return Batch{ 56 | .len = 1, 57 | .head = task, 58 | .tail = task, 59 | }; 60 | } 61 | 62 | /// Another batch into this one, taking ownership of its tasks. 63 | pub fn push(self: *Batch, batch: Batch) void { 64 | if (batch.len == 0) return; 65 | if (self.len == 0) { 66 | self.* = batch; 67 | } else { 68 | self.tail.?.node.next = if (batch.head) |h| &h.node else null; 69 | self.tail = batch.tail; 70 | self.len += batch.len; 71 | } 72 | } 73 | }; 74 | 75 | /// Schedule a batch of tasks to be executed by some thread on the thread pool. 76 | pub noinline fn schedule(self: *ThreadPool, batch: Batch) void { 77 | // Sanity check 78 | if (batch.len == 0) { 79 | return; 80 | } 81 | 82 | // Extract out the Node's from the Tasks 83 | var list = Node.List{ 84 | .head = &batch.head.?.node, 85 | .tail = &batch.tail.?.node, 86 | }; 87 | 88 | // Push the task Nodes to the most approriate queue 89 | if (Thread.current) |thread| { 90 | thread.buffer.push(&list) catch thread.queue.push(list); 91 | } else { 92 | self.queue.push(list); 93 | } 94 | 95 | const sync = @bitCast(Sync, self.sync.load(.Monotonic)); 96 | if (sync.shutdown) return; 97 | if (sync.stealing > 0) return; 98 | if (sync.idle == 0 and sync.spawned == self.max_threads) return; 99 | return self.notify(); 100 | } 101 | 102 | noinline fn notify(self: *ThreadPool) void { 103 | var sync = @bitCast(Sync, self.sync.load(.Monotonic)); 104 | while (true) { 105 | if (sync.shutdown) return; 106 | if (sync.stealing != 0) return; 107 | 108 | var new_sync = sync; 109 | new_sync.stealing = 1; 110 | if (sync.idle > 0) { 111 | // the thread will decrement idle on its own 112 | } else if (sync.spawned < self.max_threads) { 113 | new_sync.spawned += 1; 114 | } else { 115 | return; 116 | } 117 | 118 | sync = @bitCast(Sync, 
self.sync.tryCompareAndSwap( 119 | @bitCast(u32, sync), 120 | @bitCast(u32, new_sync), 121 | .SeqCst, 122 | .Monotonic, 123 | ) orelse { 124 | if (sync.idle > 0) 125 | return self.idle_event.notify(); 126 | 127 | assert(sync.spawned < self.max_threads); 128 | const spawn_config = std.Thread.SpawnConfig{ .stack_size = self.stack_size }; 129 | const thread = std.Thread.spawn(spawn_config, Thread.run, .{self}) catch @panic("failed to spawn a thread"); 130 | thread.detach(); 131 | return; 132 | }); 133 | } 134 | } 135 | 136 | /// Marks the thread pool as shutdown 137 | pub noinline fn shutdown(self: *ThreadPool) void { 138 | var sync = @bitCast(Sync, self.sync.load(.Monotonic)); 139 | while (!sync.shutdown) { 140 | var new_sync = sync; 141 | new_sync.shutdown = true; 142 | 143 | sync = @bitCast(Sync, self.sync.tryCompareAndSwap( 144 | @bitCast(u32, sync), 145 | @bitCast(u32, new_sync), 146 | .SeqCst, 147 | .Monotonic, 148 | ) orelse { 149 | self.idle_event.shutdown(); 150 | return; 151 | }); 152 | } 153 | } 154 | 155 | noinline fn register(self: *ThreadPool, thread: *Thread) void { 156 | var threads = self.threads.load(.Monotonic); 157 | while (true) { 158 | thread.next = threads; 159 | threads = self.threads.tryCompareAndSwap( 160 | threads, 161 | thread, 162 | .Release, 163 | .Monotonic, 164 | ) orelse break; 165 | } 166 | } 167 | 168 | noinline fn unregister(self: *ThreadPool, thread: *Thread) void { 169 | const one_spawned = @bitCast(u32, Sync{ .spawned = 1 }); 170 | const sync = @bitCast(Sync, self.sync.fetchSub(one_spawned, .SeqCst)); 171 | 172 | assert(sync.spawned > 0); 173 | if (sync.spawned == 1) { 174 | self.join_event.notify(); 175 | } 176 | 177 | thread.join_event.wait(); 178 | if (thread.next) |next| { 179 | next.join_event.notify(); 180 | } 181 | } 182 | 183 | noinline fn join(self: *ThreadPool) void { 184 | self.join_event.wait(); 185 | if (self.threads.load(.Acquire)) |thread| { 186 | thread.join_event.notify(); 187 | } 188 | } 189 | 190 | const Thread 
= struct { 191 | pool: *ThreadPool, 192 | next: ?*Thread = null, 193 | stealing: bool = true, 194 | target: ?*Thread = null, 195 | join_event: Event = .{}, 196 | buffer: Node.Buffer = .{}, 197 | queue: Node.Queue = .{}, 198 | 199 | threadlocal var current: ?*Thread = null; 200 | 201 | fn run(thread_pool: *ThreadPool) void { 202 | var self = Thread{ .pool = thread_pool }; 203 | current = &self; 204 | 205 | self.pool.register(&self); 206 | defer self.pool.unregister(&self); 207 | 208 | while (true) { 209 | const node = self.poll() catch break; 210 | const task = @fieldParentPtr(Task, "node", node); 211 | (task.callback)(task); 212 | } 213 | } 214 | 215 | fn poll(self: *Thread) error{Shutdown}!*Node { 216 | defer if (self.stealing) { 217 | const one_stealing = @bitCast(u32, Sync{ .stealing = 1 }); 218 | const sync = @bitCast(Sync, self.pool.sync.fetchSub(one_stealing, .SeqCst)); 219 | 220 | // assert(sync.stealing > 0); 221 | if (sync.stealing == 0) { 222 | std.debug.print("{} resetspinning(): {}\n", .{std.Thread.getCurrentId(), sync}); 223 | unreachable; 224 | } 225 | 226 | self.stealing = false; 227 | self.pool.notify(); 228 | }; 229 | 230 | if (self.buffer.pop()) |node| 231 | return node; 232 | 233 | while (true) { 234 | if (self.buffer.consume(&self.queue)) |result| 235 | return result.node; 236 | 237 | if (self.buffer.consume(&self.pool.queue)) |result| 238 | return result.node; 239 | 240 | if (!self.stealing) blk: { 241 | var sync = @bitCast(Sync, self.pool.sync.load(.Monotonic)); 242 | if ((@as(u32, sync.stealing) * 2) >= (sync.spawned - sync.idle)) 243 | break :blk; 244 | 245 | const one_stealing = @bitCast(u32, Sync{ .stealing = 1 }); 246 | sync = @bitCast(Sync, self.pool.sync.fetchAdd(one_stealing, .SeqCst)); 247 | assert(sync.stealing < sync.spawned); 248 | self.stealing = true; 249 | } 250 | 251 | if (self.stealing) { 252 | var attempts: u8 = 4; 253 | while (attempts > 0) : (attempts -= 1) { 254 | var num_threads: u16 = @bitCast(Sync, 
self.pool.sync.load(.Monotonic)).spawned; 255 | while (num_threads > 0) : (num_threads -= 1) { 256 | const thread = self.target orelse self.pool.threads.load(.Acquire) orelse unreachable; 257 | self.target = thread.next; 258 | 259 | if (self.buffer.consume(&thread.queue)) |result| 260 | return result.node; 261 | 262 | if (self.buffer.steal(&thread.buffer)) |result| 263 | return result.node; 264 | } 265 | } 266 | } 267 | 268 | if (self.buffer.consume(&self.pool.queue)) |result| 269 | return result.node; 270 | 271 | var update = @bitCast(u32, Sync{ .idle = 1 }); 272 | if (self.stealing) { 273 | update -%= @bitCast(u32, Sync{ .stealing = 1 }); 274 | } 275 | 276 | var sync = @bitCast(Sync, self.pool.sync.fetchAdd(update, .SeqCst)); 277 | //std.debug.print("\nwait {}({}):{}\n\t\t{}\n", .{std.Thread.getCurrentId(), self.stealing, sync, @bitCast(Sync, @bitCast(u32, sync) +% update)}); 278 | assert(sync.idle < sync.spawned); 279 | if (self.stealing) assert(sync.stealing <= sync.spawned); 280 | self.stealing = false; 281 | 282 | update = @bitCast(u32, Sync{ .idle = 1 }); 283 | if (self.canSteal()) { 284 | update -%= @bitCast(u32, Sync{ .stealing = 1 }); 285 | self.stealing = true; 286 | } else { 287 | self.pool.idle_event.wait(); 288 | } 289 | 290 | sync = @bitCast(Sync, self.pool.sync.fetchSub(update, .SeqCst)); 291 | //std.debug.print("\nwake {}({}):{}\n\t\t{}\n", .{std.Thread.getCurrentId(), self.stealing, sync, @bitCast(Sync, @bitCast(u32, sync) -% update)}); 292 | assert(sync.idle <= sync.spawned); 293 | if (self.stealing) assert(sync.stealing < sync.spawned); 294 | 295 | self.stealing = !sync.shutdown; 296 | if (!self.stealing) return error.Shutdown; 297 | continue; 298 | } 299 | } 300 | 301 | fn canSteal(self: *const Thread) bool { 302 | if (self.queue.canSteal()) 303 | return true; 304 | 305 | if (self.pool.queue.canSteal()) 306 | return true; 307 | 308 | var num_threads: u16 = @bitCast(Sync, self.pool.sync.load(.Monotonic)).spawned; 309 | var threads: ?*Thread = 
null; 310 | while (num_threads > 0) : (num_threads -= 1) { 311 | const thread = threads orelse self.pool.threads.load(.Acquire) orelse unreachable; 312 | threads = thread.next; 313 | 314 | if (thread.queue.canSteal()) 315 | return true; 316 | 317 | if (thread.buffer.canSteal()) 318 | return true; 319 | } 320 | 321 | return false; 322 | } 323 | }; 324 | 325 | /// Linked list intrusive memory node and lock-free data structures to operate with it 326 | const Node = struct { 327 | next: ?*Node = null, 328 | 329 | /// A linked list of Nodes 330 | const List = struct { 331 | head: *Node, 332 | tail: *Node, 333 | }; 334 | 335 | /// An unbounded multi-producer-(non blocking)-multi-consumer queue of Node pointers. 336 | const Queue = struct { 337 | stack: Atomic(usize) = Atomic(usize).init(0), 338 | cache: ?*Node = null, 339 | 340 | const HAS_CACHE: usize = 0b01; 341 | const IS_CONSUMING: usize = 0b10; 342 | const PTR_MASK: usize = ~(HAS_CACHE | IS_CONSUMING); 343 | 344 | comptime { 345 | assert(@alignOf(Node) >= ((IS_CONSUMING | HAS_CACHE) + 1)); 346 | } 347 | 348 | noinline fn push(noalias self: *Queue, list: List) void { 349 | var stack = self.stack.load(.Monotonic); 350 | while (true) { 351 | // Attach the list to the stack (pt. 1) 352 | list.tail.next = @intToPtr(?*Node, stack & PTR_MASK); 353 | 354 | // Update the stack with the list (pt. 2). 355 | // Don't change the HAS_CACHE and IS_CONSUMING bits of the consumer. 356 | var new_stack = @ptrToInt(list.head); 357 | assert(new_stack & ~PTR_MASK == 0); 358 | new_stack |= (stack & ~PTR_MASK); 359 | 360 | // Push to the stack with a release barrier for the consumer to see the proper list links. 
361 | stack = self.stack.tryCompareAndSwap( 362 | stack, 363 | new_stack, 364 | .Release, 365 | .Monotonic, 366 | ) orelse break; 367 | } 368 | } 369 | 370 | fn canSteal(self: *const Queue) bool { 371 | const stack = self.stack.load(.Monotonic); 372 | if (stack & IS_CONSUMING != 0) return false; 373 | if (stack & (HAS_CACHE | PTR_MASK) == 0) return false; 374 | return true; 375 | } 376 | 377 | fn tryAcquireConsumer(self: *Queue) error{Empty, Contended}!?*Node { 378 | var stack = self.stack.load(.Monotonic); 379 | while (true) { 380 | if (stack & IS_CONSUMING != 0) 381 | return error.Contended; // The queue already has a consumer. 382 | if (stack & (HAS_CACHE | PTR_MASK) == 0) 383 | return error.Empty; // The queue is empty when there's nothing cached and nothing in the stack. 384 | 385 | // When we acquire the consumer, also consume the pushed stack if the cache is empty. 386 | var new_stack = stack | HAS_CACHE | IS_CONSUMING; 387 | if (stack & HAS_CACHE == 0) { 388 | assert(stack & PTR_MASK != 0); 389 | new_stack &= ~PTR_MASK; 390 | } 391 | 392 | // Acquire barrier on getting the consumer to see cache/Node updates done by previous consumers 393 | // and to ensure our cache/Node updates in pop() happen after that of previous consumers. 394 | stack = self.stack.tryCompareAndSwap( 395 | stack, 396 | new_stack, 397 | .Acquire, 398 | .Monotonic, 399 | ) orelse return self.cache orelse @intToPtr(*Node, stack & PTR_MASK); 400 | } 401 | } 402 | 403 | fn releaseConsumer(noalias self: *Queue, noalias consumer: ?*Node) void { 404 | // Stop consuming and remove the HAS_CACHE bit as well if the consumer's cache is empty. 405 | // When HAS_CACHE bit is zeroed, the next consumer will acquire the pushed stack nodes. 
406 | var remove = IS_CONSUMING; 407 | if (consumer == null) 408 | remove |= HAS_CACHE; 409 | 410 | // Release the consumer with a release barrier to ensure cache/node accesses 411 | // happen before the consumer was released and before the next consumer starts using the cache. 412 | self.cache = consumer; 413 | const stack = self.stack.fetchSub(remove, .Release); 414 | assert(stack & remove != 0); 415 | } 416 | 417 | fn pop(noalias self: *Queue, noalias consumer_ref: *?*Node) ?*Node { 418 | // Check the consumer cache (fast path) 419 | if (consumer_ref.*) |node| { 420 | consumer_ref.* = node.next; 421 | return node; 422 | } 423 | 424 | // Load the stack to see if there was anything pushed that we could grab. 425 | var stack = self.stack.load(.Monotonic); 426 | assert(stack & IS_CONSUMING != 0); 427 | if (stack & PTR_MASK == 0) { 428 | return null; 429 | } 430 | 431 | // Nodes have been pushed to the stack, grab then with an Acquire barrier to see the Node links. 432 | stack = self.stack.swap(HAS_CACHE | IS_CONSUMING, .Acquire); 433 | assert(stack & IS_CONSUMING != 0); 434 | assert(stack & PTR_MASK != 0); 435 | 436 | const node = @intToPtr(*Node, stack & PTR_MASK); 437 | consumer_ref.* = node.next; 438 | return node; 439 | } 440 | }; 441 | 442 | /// A bounded single-producer, multi-consumer ring buffer for node pointers. 
443 | const Buffer = struct { 444 | head: Atomic(Index) = Atomic(Index).init(0), 445 | tail: Atomic(Index) = Atomic(Index).init(0), 446 | array: [capacity]Atomic(*Node) = undefined, 447 | 448 | const Index = u32; 449 | const capacity = 256; // Appears to be a pretty good trade-off in space vs contended throughput 450 | comptime { 451 | assert(std.math.maxInt(Index) >= capacity); 452 | assert(std.math.isPowerOfTwo(capacity)); 453 | } 454 | 455 | noinline fn push(noalias self: *Buffer, noalias list: *List) error{Overflow}!void { 456 | var head = self.head.load(.Monotonic); 457 | var tail = self.tail.loadUnchecked(); // we're the only thread that can change this 458 | 459 | while (true) { 460 | var size = tail -% head; 461 | assert(size <= capacity); 462 | 463 | // Push nodes from the list to the buffer if it's not empty.. 464 | if (size < capacity) { 465 | var nodes: ?*Node = list.head; 466 | while (size < capacity) : (size += 1) { 467 | const node = nodes orelse break; 468 | nodes = node.next; 469 | 470 | // Array written atomically with weakest ordering since it could be getting atomically read by steal(). 471 | self.array[tail % capacity].store(node, .Unordered); 472 | tail +%= 1; 473 | } 474 | 475 | // Release barrier synchronizes with Acquire loads for steal()ers to see the array writes. 476 | self.tail.store(tail, .Release); 477 | 478 | // Update the list with the nodes we pushed to the buffer and try again if there's more. 479 | list.head = nodes orelse return; 480 | std.atomic.spinLoopHint(); 481 | head = self.head.load(.Monotonic); 482 | continue; 483 | } 484 | 485 | // Try to steal/overflow half of the tasks in the buffer to make room for future push()es. 486 | // Migrating half amortizes the cost of stealing while requiring future pops to still use the buffer. 487 | // Acquire barrier to ensure the linked list creation after the steal only happens after we succesfully steal. 
488 | var migrate = size / 2; 489 | head = self.head.tryCompareAndSwap( 490 | head, 491 | head +% migrate, 492 | .Acquire, 493 | .Monotonic, 494 | ) orelse { 495 | // Link the migrated Nodes together 496 | const first = self.array[head % capacity].loadUnchecked(); 497 | while (migrate > 0) : (migrate -= 1) { 498 | const prev = self.array[head % capacity].loadUnchecked(); 499 | head +%= 1; 500 | prev.next = self.array[head % capacity].loadUnchecked(); 501 | } 502 | 503 | // Append the list that was supposed to be pushed to the end of the migrated Nodes 504 | const last = self.array[(head -% 1) % capacity].loadUnchecked(); 505 | last.next = list.head; 506 | list.tail.next = null; 507 | 508 | // Return the migrated nodes + the original list as overflowed 509 | list.head = first; 510 | return error.Overflow; 511 | }; 512 | } 513 | } 514 | 515 | fn pop(self: *Buffer) ?*Node { 516 | var head = self.head.load(.Monotonic); 517 | var tail = self.tail.loadUnchecked(); // we're the only thread that can change this 518 | 519 | while (true) { 520 | // Quick sanity check and return null when not empty 521 | var size = tail -% head; 522 | assert(size <= capacity); 523 | if (size == 0) { 524 | return null; 525 | } 526 | 527 | // On x86, a fetchAdd ("lock xadd") can be faster than a tryCompareAndSwap ("lock cmpxchg"). 528 | // If the increment makes the head go past the tail, it means the queue was emptied before we incremented so revert. 529 | // Acquire barrier to ensure that any writes we do to the popped Node only happen after the head increment. 
530 | if (comptime builtin.target.cpu.arch.isX86()) { 531 | head = self.head.fetchAdd(1, .Acquire); 532 | if (head == tail) { 533 | self.head.store(head, .Monotonic); 534 | return null; 535 | } 536 | 537 | size = tail -% head; 538 | assert(size <= capacity); 539 | return self.array[head % capacity].loadUnchecked(); 540 | } 541 | 542 | // Dequeue with an acquire barrier to ensure any writes done to the Node 543 | // only happen after we succesfully claim it from the array. 544 | head = self.head.tryCompareAndSwap( 545 | head, 546 | head +% 1, 547 | .Acquire, 548 | .Monotonic, 549 | ) orelse return self.array[head % capacity].loadUnchecked(); 550 | } 551 | } 552 | 553 | const Stole = struct { 554 | node: *Node, 555 | pushed: bool, 556 | }; 557 | 558 | fn canSteal(self: *const Buffer) bool { 559 | while (true) : (std.atomic.spinLoopHint()) { 560 | const head = self.head.load(.Acquire); 561 | const tail = self.tail.load(.Acquire); 562 | 563 | // On x86, the target buffer thread uses fetchAdd to increment the head which can go over if it's zero. 564 | // Account for that here by understanding that it's empty here. 565 | if (comptime builtin.target.cpu.arch.isX86()) { 566 | if (head == tail +% 1) { 567 | return false; 568 | } 569 | } 570 | 571 | const size = tail -% head; 572 | if (size > capacity) { 573 | continue; 574 | } 575 | 576 | assert(size <= capacity); 577 | return size != 0; 578 | } 579 | } 580 | 581 | fn consume(noalias self: *Buffer, noalias queue: *Queue) ?Stole { 582 | var consumer = queue.tryAcquireConsumer() catch return null; 583 | defer queue.releaseConsumer(consumer); 584 | 585 | const head = self.head.load(.Monotonic); 586 | const tail = self.tail.loadUnchecked(); // we're the only thread that can change this 587 | 588 | const size = tail -% head; 589 | assert(size <= capacity); 590 | assert(size == 0); // we should only be consuming if our array is empty 591 | 592 | // Pop nodes from the queue and push them to our array. 
593 | // Atomic stores to the array as steal() threads may be atomically reading from it. 594 | var pushed: Index = 0; 595 | while (pushed < capacity) : (pushed += 1) { 596 | const node = queue.pop(&consumer) orelse break; 597 | self.array[(tail +% pushed) % capacity].store(node, .Unordered); 598 | } 599 | 600 | // We will be returning one node that we stole from the queue. 601 | // Get an extra, and if that's not possible, take one from our array. 602 | const node = queue.pop(&consumer) orelse blk: { 603 | if (pushed == 0) return null; 604 | pushed -= 1; 605 | break :blk self.array[(tail +% pushed) % capacity].loadUnchecked(); 606 | }; 607 | 608 | // Update the array tail with the nodes we pushed to it. 609 | // Release barrier to synchronize with Acquire barrier in steal()'s to see the written array Nodes. 610 | if (pushed > 0) self.tail.store(tail +% pushed, .Release); 611 | return Stole{ 612 | .node = node, 613 | .pushed = pushed > 0, 614 | }; 615 | } 616 | 617 | fn steal(noalias self: *Buffer, noalias buffer: *Buffer) ?Stole { 618 | const head = self.head.load(.Monotonic); 619 | const tail = self.tail.loadUnchecked(); // we're the only thread that can change this 620 | 621 | const size = tail -% head; 622 | assert(size <= capacity); 623 | assert(size == 0); // we should only be stealing if our array is empty 624 | 625 | while (true) : (std.atomic.spinLoopHint()) { 626 | const buffer_head = buffer.head.load(.Acquire); 627 | const buffer_tail = buffer.tail.load(.Acquire); 628 | 629 | // On x86, the target buffer thread uses fetchAdd to increment the head which can go over if it's zero. 630 | // Account for that here by understanding that it's empty here. 631 | if (comptime builtin.target.cpu.arch.isX86()) { 632 | if (buffer_head == buffer_tail +% 1) { 633 | return null; 634 | } 635 | } 636 | 637 | // Overly large size indicates the the tail was updated a lot after the head was loaded. 638 | // Reload both and try again. 
639 | const buffer_size = buffer_tail -% buffer_head; 640 | if (buffer_size > capacity) { 641 | continue; 642 | } 643 | 644 | // Try to steal half (divCeil) to amortize the cost of stealing from other threads. 645 | const steal_size = buffer_size - (buffer_size / 2); 646 | if (steal_size == 0) { 647 | return null; 648 | } 649 | 650 | // Copy the nodes we will steal from the target's array to our own. 651 | // Atomically load from the target buffer array as it may be pushing and atomically storing to it. 652 | // Atomic store to our array as other steal() threads may be atomically loading from it as above. 653 | var i: Index = 0; 654 | while (i < steal_size) : (i += 1) { 655 | const node = buffer.array[(buffer_head +% i) % capacity].load(.Unordered); 656 | self.array[(tail +% i) % capacity].store(node, .Unordered); 657 | } 658 | 659 | // Try to commit the steal from the target buffer using: 660 | // - an Acquire barrier to ensure that we only interact with the stolen Nodes after the steal was committed. 661 | // - a Release barrier to ensure that the Nodes are copied above prior to the committing of the steal 662 | // because if they're copied after the steal, the could be getting rewritten by the target's push(). 663 | _ = buffer.head.compareAndSwap( 664 | buffer_head, 665 | buffer_head +% steal_size, 666 | .AcqRel, 667 | .Monotonic, 668 | ) orelse { 669 | // Pop one from the nodes we stole as we'll be returning it 670 | const pushed = steal_size - 1; 671 | const node = self.array[(tail +% pushed) % capacity].loadUnchecked(); 672 | 673 | // Update the array tail with the nodes we pushed to it. 674 | // Release barrier to synchronize with Acquire barrier in steal()'s to see the written array Nodes. 675 | if (pushed > 0) self.tail.store(tail +% pushed, .Release); 676 | return Stole{ 677 | .node = node, 678 | .pushed = pushed > 0, 679 | }; 680 | }; 681 | } 682 | } 683 | }; 684 | }; 685 | 686 | /// An event which stores 1 semaphore token and is multi-threaded safe. 
/// The event can be shutdown(), waking up all wait()ing threads and
/// making subsequent wait()'s return immediately.
///
/// The futex word cycles between EMPTY (no token, no known sleepers), WAITING (some thread
/// may be sleeping on the futex), NOTIFIED (one token is available) and SHUTDOWN.
const Event = struct {
    /// Futex word holding one of EMPTY / WAITING / NOTIFIED / SHUTDOWN.
    state: Atomic(u32) = Atomic(u32).init(EMPTY),

    const EMPTY = 0;
    const WAITING = 1;
    const NOTIFIED = 2;
    const SHUTDOWN = 3;

    /// Wait for and consume a notification
    /// or wait for the event to be shutdown entirely.
    /// Returns without consuming anything once the state is SHUTDOWN.
    noinline fn wait(self: *Event) void {
        // State to leave behind when consuming a NOTIFIED token: EMPTY on the first pass,
        // WAITING after we have slept on the futex (other sleepers may still exist).
        var acquire_with: u32 = EMPTY;
        var state = self.state.load(.Monotonic);

        while (true) {
            // If we're shutdown then exit early.
            // Acquire barrier to ensure operations before the shutdown() are seen after the wait().
            // Shutdown is rare so it's better to have an Acquire barrier here instead of on CAS failure + load which are common.
            if (state == SHUTDOWN) {
                std.atomic.fence(.Acquire);
                return;
            }

            // Consume a notification when it pops up.
            // Acquire barrier to ensure operations before the notify() appear after the wait().
            if (state == NOTIFIED) {
                state = self.state.tryCompareAndSwap(
                    state,
                    acquire_with,
                    .Acquire,
                    .Monotonic,
                ) orelse return;
                continue;
            }

            // There is no notification to consume, we should wait on the event by ensuring it's WAITING.
            if (state != WAITING) blk: {
                state = self.state.tryCompareAndSwap(
                    state,
                    WAITING,
                    .Monotonic,
                    .Monotonic,
                ) orelse break :blk;
                continue;
            }

            // Wait on the event until a notify() or shutdown().
            // If we wake up to a notification, we must acquire it with WAITING instead of EMPTY
            // since there may be other threads sleeping on the Futex who haven't been woken up yet.
            //
            // Acquiring to WAITING will make the next notify() or shutdown() wake a sleeping futex thread
            // who will either exit on SHUTDOWN or acquire with WAITING again, ensuring all threads are awoken.
            // This unfortunately results in the last notify() or shutdown() doing an extra futex wake but that's fine.
            std.Thread.Futex.wait(&self.state, WAITING, null) catch unreachable;
            state = self.state.load(.Monotonic);
            acquire_with = WAITING;
        }
    }

    /// Post a notification to the event if it doesn't have one already
    /// then wake up a waiting thread if there is one as well.
    fn notify(self: *Event) void {
        return self.wake(NOTIFIED, 1);
    }

    /// Marks the event as shutdown, making all future wait()'s return immediately.
    /// Then wakes up any threads currently waiting on the Event.
    fn shutdown(self: *Event) void {
        return self.wake(SHUTDOWN, std.math.maxInt(u32));
    }

    /// Publish `release_with` (NOTIFIED or SHUTDOWN) and, if a waiter may be sleeping,
    /// wake up to `wake_threads` threads blocked on the futex.
    noinline fn wake(self: *Event, release_with: u32, wake_threads: u32) void {
        // Update the Event with the new `release_with` state (either NOTIFIED or SHUTDOWN).
        // Release barrier so that operations performed before this wake() are visible to the
        // thread whose wait() observes the new state.
        const state = self.state.swap(release_with, .Release);

        // Only wake threads sleeping in futex if the state is WAITING.
        // Avoids unnecessary wake ups.
        if (state == WAITING) {
            std.Thread.Futex.wake(&self.state, wake_threads);
        }
    }
};

--------------------------------------------------------------------------------
/src/thread_pool.zig:
--------------------------------------------------------------------------------
const std = @import("std");
const ThreadPool = @This();

const assert = std.debug.assert;
const Atomic = std.atomic.Atomic;

/// Stack size handed to std.Thread.spawn() for every worker thread.
stack_size: u32,
/// Upper bound on the number of worker threads notifySlow() may spawn.
max_threads: u32,
/// Packed `Sync` state, updated atomically as a single u32.
sync: Atomic(u32) = Atomic(u32).init(@bitCast(u32, Sync{})),
/// Idle worker threads sleep on this event until notified or shut down.
idle_event: Event = .{},
/// Notified once the pool is shut down and no worker thread remains registered as spawned.
join_event: Event = .{},
/// Global run queue used when scheduling from a thread that is not a pool worker.
run_queue: Node.Queue = .{},
/// Intrusive stack of registered worker threads (linked through Thread.next).
threads: Atomic(?*Thread) = Atomic(?*Thread).init(null),

/// Pool-wide state packed into exactly 32 bits (14 + 14 + 1 + 1 + 2) so it can be CAS'd as one u32.
const Sync = packed struct {
    /// Tracks the number of threads not searching for Tasks
    idle: u14 = 0,
    /// Tracks the number of threads spawned
    spawned: u14 = 0,
    /// Padding bit that keeps the packed struct exactly 32 bits wide.
    unused: bool = false,
    /// Used to not miss notifications while state = waking
    notified: bool = false,
    /// The current state of the thread pool
    state: enum(u2) {
        /// A notification can be issued to wake up a sleeping thread as the "waking thread".
        pending = 0,
        /// The state was notified with a signal. A thread is woken up.
        /// The first thread to transition to `waking` becomes the "waking thread".
        signaled,
        /// There is a "waking thread" among us.
        /// No other thread should be woken up until the waking thread transitions the state.
        waking,
        /// The thread pool was terminated. Threads start decrementing `spawned` so that it can be joined.
        shutdown,
    } = .pending,
};

/// Configuration options for the thread pool.
/// TODO: add CPU core affinity?
pub const Config = struct {
    /// Defaults to the standard library's default thread stack size.
    stack_size: u32 = (std.Thread.SpawnConfig{}).stack_size,
    max_threads: u32,
};

/// Statically initialize the thread pool using the configuration.
/// `max_threads` is clamped to at least 1 and at most the largest value the packed
/// `Sync.spawned` counter can represent; larger values would overflow that counter
/// when notifySlow() spawns threads.
pub fn init(config: Config) ThreadPool {
    // Sync.spawned (and Sync.idle) are narrow bit-fields packed into a single u32,
    // so the number of worker threads must never exceed what they can hold.
    const max_spawnable: u32 = std.math.maxInt(@TypeOf((Sync{}).spawned));

    return .{
        .stack_size = std.math.max(1, config.stack_size),
        .max_threads = std.math.min(max_spawnable, std.math.max(1, config.max_threads)),
    };
}

/// Block until shutdown() has been called and every worker thread has stopped running tasks,
/// then signal the worker threads to exit and invalidate the pool.
/// Worker threads are detached; they finish their exit chain without touching the pool memory.
pub fn deinit(self: *ThreadPool) void {
    self.join();
    self.* = undefined;
}

/// A Task represents the unit of Work / Job / Execution that the ThreadPool schedules.
/// The user provides a `callback` which is invoked when the *Task can run on a thread.
pub const Task = struct {
    /// Intrusive link used by the pool's queues; owned by the pool while the task is scheduled.
    node: Node = .{},
    callback: fn (*Task) void,
};

/// An unordered collection of Tasks which can be submitted for scheduling as a group.
pub const Batch = struct {
    len: usize = 0,
    head: ?*Task = null,
    tail: ?*Task = null,

    /// Create a batch from a single task.
    pub fn from(task: *Task) Batch {
        return Batch{
            .len = 1,
            .head = task,
            .tail = task,
        };
    }

    /// Append another batch onto this one, taking ownership of its tasks.
    pub fn push(self: *Batch, batch: Batch) void {
        if (batch.len == 0) return;
        if (self.len == 0) {
            self.* = batch;
        } else {
            // Link our current tail to the head of the appended batch.
            self.tail.?.node.next = if (batch.head) |h| &h.node else null;
            self.tail = batch.tail;
            self.len += batch.len;
        }
    }
};

/// Schedule a batch of tasks to be executed by some thread on the thread pool.
/// The tail task's intrusive `next` link is reset, so Tasks may be re-scheduled after they have run.
pub fn schedule(self: *ThreadPool, batch: Batch) void {
    // Sanity check
    if (batch.len == 0) {
        return;
    }

    // Extract out the Node's from the Tasks
    var list = Node.List{
        .head = &batch.head.?.node,
        .tail = &batch.tail.?.node,
    };

    // Node.Buffer.push() walks `next` links until it finds null, so the list must be
    // null-terminated at its tail. A re-scheduled Task could still carry a stale link
    // from a previous trip through the queues.
    list.tail.next = null;

    // Push the task Nodes to the most appropriate queue
    if (Thread.current) |thread| {
        thread.run_buffer.push(&list) catch thread.run_queue.push(list);
    } else {
        self.run_queue.push(list);
    }

    // Try to notify a thread
    const is_waking = false;
    return self.notify(is_waking);
}

/// Post a notification that work is available, possibly waking or spawning a worker.
/// `is_waking` is true when called by the current "waking thread" to hand off that role.
inline fn notify(self: *ThreadPool, is_waking: bool) void {
    // Fast path to check the Sync state to avoid calling into notifySlow().
    // If we're waking, then we need to update the state regardless.
    if (!is_waking) {
        const sync = @bitCast(Sync, self.sync.load(.Monotonic));
        if (sync.notified) {
            return;
        }
    }

    return self.notifySlow(is_waking);
}

noinline fn notifySlow(self: *ThreadPool, is_waking: bool) void {
    var sync = @bitCast(Sync, self.sync.load(.Monotonic));
    while (sync.state != .shutdown) {
        const can_wake = is_waking or (sync.state == .pending);
        if (is_waking) {
            assert(sync.state == .waking);
        }

        var new_sync = sync;
        new_sync.notified = true;
        if (can_wake and sync.idle > 0) { // wake up an idle thread
            new_sync.state = .signaled;
        } else if (can_wake and sync.spawned < self.max_threads) { // spawn a new thread
            new_sync.state = .signaled;
            new_sync.spawned += 1;
        } else if (is_waking) { // no other thread to pass on "waking" status
            new_sync.state = .pending;
        } else if (sync.notified) { // nothing to update
            return;
        }

        // Release barrier synchronizes with Acquire in wait()
        // to ensure pushes to run queues happen before observing a posted notification.
        sync = @bitCast(Sync, self.sync.tryCompareAndSwap(
            @bitCast(u32, sync),
            @bitCast(u32, new_sync),
            .Release,
            .Monotonic,
        ) orelse {
            // We signaled to notify an idle thread
            if (can_wake and sync.idle > 0) {
                return self.idle_event.notify();
            }

            // We signaled to spawn a new thread.
            // NOTE(review): on spawn failure only `spawned` is rolled back; the state stays
            // `.signaled` until some running thread passes through wait(). If no thread does,
            // later notify() calls cannot wake anyone - confirm this is acceptable.
            if (can_wake and sync.spawned < self.max_threads) {
                const spawn_config = std.Thread.SpawnConfig{ .stack_size = self.stack_size };
                const thread = std.Thread.spawn(spawn_config, Thread.run, .{self}) catch return self.unregister(null);
                return thread.detach();
            }

            return;
        });
    }
}

/// Block the calling worker until a notification is available or the pool is shut down.
/// Returns whether the caller is now the "waking thread", or error.Shutdown.
noinline fn wait(self: *ThreadPool, _is_waking: bool) error{Shutdown}!bool {
    var is_idle = false;
    var is_waking = _is_waking;
    var sync = @bitCast(Sync, self.sync.load(.Monotonic));

    while (true) {
        if (sync.state == .shutdown) return error.Shutdown;
        if (is_waking) assert(sync.state == .waking);

        if (sync.notified) {
            // Consume a notification made by notify().
            var new_sync = sync;
            new_sync.notified = false;
            if (is_idle)
                new_sync.idle -= 1;
            if (sync.state == .signaled)
                new_sync.state = .waking;

            // Acquire barrier synchronizes with notify()
            // to ensure that pushes to run queue are observed after wait() returns.
            sync = @bitCast(Sync, self.sync.tryCompareAndSwap(
                @bitCast(u32, sync),
                @bitCast(u32, new_sync),
                .Acquire,
                .Monotonic,
            ) orelse {
                return is_waking or (sync.state == .signaled);
            });
        } else if (!is_idle) {
            // No notification to consume.
            // Mark this thread as idle before sleeping on the idle_event.
            var new_sync = sync;
            new_sync.idle += 1;
            if (is_waking)
                new_sync.state = .pending;

            sync = @bitCast(Sync, self.sync.tryCompareAndSwap(
                @bitCast(u32, sync),
                @bitCast(u32, new_sync),
                .Monotonic,
                .Monotonic,
            ) orelse {
                is_waking = false;
                is_idle = true;
                continue;
            });
        } else {
            // Wait for a signal by either notify() or shutdown() without wasting cpu cycles.
            // TODO: Add I/O polling here.
            self.idle_event.wait();
            sync = @bitCast(Sync, self.sync.load(.Monotonic));
        }
    }
}

/// Marks the thread pool as shutdown
pub noinline fn shutdown(self: *ThreadPool) void {
    var sync = @bitCast(Sync, self.sync.load(.Monotonic));
    while (sync.state != .shutdown) {
        var new_sync = sync;
        new_sync.notified = true;
        new_sync.state = .shutdown;
        new_sync.idle = 0;

        // Full barrier to synchronize with both wait() and notify()
        sync = @bitCast(Sync, self.sync.tryCompareAndSwap(
            @bitCast(u32, sync),
            @bitCast(u32, new_sync),
            .AcqRel,
            .Monotonic,
        ) orelse {
            // Wake up any threads sleeping on the idle_event.
            // TODO: I/O polling notification here.
            if (sync.idle > 0) self.idle_event.shutdown();

            // With no spawned threads, no unregister() will ever observe the last exit and
            // notify the join()er, so do it here or deinit() blocks forever.
            // `spawned` can no longer grow: notifySlow() stops once the state is .shutdown.
            if (sync.spawned == 0) self.join_event.notify();
            return;
        });
    }
}

fn register(noalias self: *ThreadPool, noalias thread: *Thread) void {
    // Push the thread onto the threads stack in a lock-free manner.
    var threads = self.threads.load(.Monotonic);
    while (true) {
        thread.next = threads;
        threads = self.threads.tryCompareAndSwap(
            threads,
            thread,
            .Release,
            .Monotonic,
        ) orelse break;
    }
}

fn unregister(noalias self: *ThreadPool, noalias maybe_thread: ?*Thread) void {
    // Un-spawn one thread, either due to a failed OS thread spawn or because the thread is exiting.
    const one_spawned = @bitCast(u32, Sync{ .spawned = 1 });
    const sync = @bitCast(Sync, self.sync.fetchSub(one_spawned, .Release));
    assert(sync.spawned > 0);

    // The last thread to exit must wake up the thread pool join()er
    // who will start the chain to shutdown all the threads.
    if (sync.state == .shutdown and sync.spawned == 1) {
        self.join_event.notify();
    }

    // If this is a thread pool thread, wait for a shutdown signal by the thread pool join()er.
    const thread = maybe_thread orelse return;
    thread.join_event.wait();

    // After receiving the shutdown signal, shutdown the next thread in the pool.
    // We have to do that without touching the thread pool itself since its memory is invalidated by now.
    // So just follow our .next link.
    const next_thread = thread.next orelse return;
    next_thread.join_event.notify();
}

fn join(self: *ThreadPool) void {
    // Wait for the thread pool to be shutdown() and for all threads to enter a joinable state.
    // join_event is notified either by the last exiting thread in unregister() or, when no
    // threads were spawned at shutdown time, directly by shutdown().
    self.join_event.wait();
    const sync = @bitCast(Sync, self.sync.load(.Monotonic));
    assert(sync.state == .shutdown);
    assert(sync.spawned == 0);

    // If there are threads, start off the chain by sending the first one the shutdown signal.
    // Each thread receives the shutdown signal and forwards it to the next thread, and so on.
    const thread = self.threads.load(.Acquire) orelse return;
    thread.join_event.notify();
}

const Thread = struct {
    next: ?*Thread = null,
    target: ?*Thread = null,
    join_event: Event = .{},
    run_queue: Node.Queue = .{},
    run_buffer: Node.Buffer = .{},

    threadlocal var current: ?*Thread = null;

    /// Thread entry point which runs a worker for the ThreadPool
    fn run(thread_pool: *ThreadPool) void {
        var self = Thread{};
        current = &self;

        thread_pool.register(&self);
        defer thread_pool.unregister(&self);

        var is_waking = false;
        while (true) {
            is_waking = thread_pool.wait(is_waking) catch return;

            while (self.pop(thread_pool)) |result| {
                // Hand off the "waking" role, or wake another thread if we just acquired extra work.
                if (result.pushed or is_waking)
                    thread_pool.notify(is_waking);
                is_waking = false;

                const task = @fieldParentPtr(Task, "node", result.node);
                (task.callback)(task);
            }
        }
    }

    /// Try to dequeue a Node/Task from the ThreadPool.
    /// Spurious reports of pop() returning empty are allowed.
    fn pop(noalias self: *Thread, noalias thread_pool: *ThreadPool) ?Node.Buffer.Stole {
        // Check our local buffer first
        if (self.run_buffer.pop()) |node| {
            return Node.Buffer.Stole{
                .node = node,
                .pushed = false,
            };
        }

        // Then check our local queue
        if (self.run_buffer.consume(&self.run_queue)) |stole| {
            return stole;
        }

        // Then the global queue
        if (self.run_buffer.consume(&thread_pool.run_queue)) |stole| {
            return stole;
        }

        // TODO: add optimistic I/O polling here

        // Then try work stealing from other threads
        var num_threads: u32 = @bitCast(Sync, thread_pool.sync.load(.Monotonic)).spawned;
        while (num_threads > 0) : (num_threads -= 1) {
            // Traverse the stack of registered threads on the thread pool
            const target = self.target orelse thread_pool.threads.load(.Acquire) orelse unreachable;
            self.target = target.next;

            // Try their queue first to avoid contention (the target only consumes its queue after its buffer).
            if (self.run_buffer.consume(&target.run_queue)) |stole| {
                return stole;
            }

            // Skip stealing from the buffer if we're the target.
            // We still steal from our own queue above given it may have just been locked the first time we tried.
            if (target == self) {
                continue;
            }

            // Steal from the buffer of a remote thread as a last resort
            if (self.run_buffer.steal(&target.run_buffer)) |stole| {
                return stole;
            }
        }

        return null;
    }
};

/// An event which stores 1 semaphore token and is multi-threaded safe.
/// The event can be shutdown(), waking up all wait()ing threads and
/// making subsequent wait()'s return immediately.
const Event = struct {
    // Single atomic word holding one of the four states below; wait() sleeps on it via the futex.
    state: Atomic(u32) = Atomic(u32).init(EMPTY),

    // No pending notification and no sleeper has announced itself.
    const EMPTY = 0;
    // At least one thread may be sleeping (or about to sleep) in Futex.wait() on `state`,
    // so the next wake() must also issue a Futex.wake().
    const WAITING = 1;
    // One pending notification (the single semaphore token), consumed by the next wait().
    const NOTIFIED = 2;
    // Stored by shutdown(); wait() returns immediately whenever it observes this value.
    const SHUTDOWN = 3;

    /// Wait for and consume a notification
    /// or wait for the event to be shutdown entirely.
    ///
    /// Loops on the observed state:
    ///   SHUTDOWN -> return (after an Acquire fence),
    ///   NOTIFIED -> CAS it to `acquire_with` (EMPTY, or WAITING once we have slept) and return,
    ///   EMPTY    -> CAS it to WAITING,
    ///   WAITING  -> sleep on the futex, then reload the state.
    noinline fn wait(self: *Event) void {
        // State left behind when consuming a notification. Starts as EMPTY and switches to WAITING
        // after this thread has slept, since other sleepers may still be parked on the futex.
        var acquire_with: u32 = EMPTY;
        var state = self.state.load(.Monotonic);

        while (true) {
            // If we're shutdown then exit early.
            // Acquire barrier to ensure operations before the shutdown() are seen after the wait().
            // Shutdown is rare so it's better to have an Acquire barrier here instead of on CAS failure + load which are common.
            if (state == SHUTDOWN) {
                std.atomic.fence(.Acquire);
                return;
            }

            // Consume a notification when it pops up.
            // Acquire barrier to ensure operations before the notify() appear after the wait().
            if (state == NOTIFIED) {
                state = self.state.tryCompareAndSwap(
                    state,
                    acquire_with,
                    .Acquire,
                    .Monotonic,
                ) orelse return;
                continue;
            }

            // There is no notification to consume, we should wait on the event by ensuring it's WAITING.
            // On CAS failure `state` holds the fresh value and the loop re-examines it.
            if (state != WAITING) blk: {
                state = self.state.tryCompareAndSwap(
                    state,
                    WAITING,
                    .Monotonic,
                    .Monotonic,
                ) orelse break :blk;
                continue;
            }

            // Wait on the event until a notify() or shutdown().
            // If we wake up to a notification, we must acquire it with WAITING instead of EMPTY
            // since there may be other threads sleeping on the Futex who haven't been woken up yet.
            //
            // Acquiring to WAITING will make the next notify() or shutdown() wake a sleeping futex thread
            // who will either exit on SHUTDOWN or acquire with WAITING again, ensuring all threads are awoken.
            // This unfortunately results in the last notify() or shutdown() doing an extra futex wake but that's fine.
            // No timeout is passed, so Futex.wait() cannot report a timeout here.
            std.Thread.Futex.wait(&self.state, WAITING, null) catch unreachable;
            state = self.state.load(.Monotonic);
            acquire_with = WAITING;
        }
    }

    /// Post a notification to the event if it doesn't have one already
    /// then wake up a waiting thread if there is one as well.
    fn notify(self: *Event) void {
        return self.wake(NOTIFIED, 1);
    }

    /// Marks the event as shutdown, making all future wait()'s return immediately.
    /// Then wakes up any threads currently waiting on the Event.
    fn shutdown(self: *Event) void {
        return self.wake(SHUTDOWN, std.math.maxInt(u32));
    }

    /// Shared implementation of notify() and shutdown(): publish `release_with` as the new state,
    /// then wake up to `wake_threads` futex sleepers if the previous state was WAITING.
    fn wake(self: *Event, release_with: u32, wake_threads: u32) void {
        // Update the Event to notify it with the new `release_with` state (either NOTIFIED or SHUTDOWN).
        // Release barrier to ensure any operations before this happen before the wait() in the other threads.
        //
        // NOTE(review): the swap is unconditional, so a notify() issued after shutdown() would replace
        // SHUTDOWN with NOTIFIED. Presumably callers never notify after shutdown — confirm at call sites.
        const state = self.state.swap(release_with, .Release);

        // Only wake threads sleeping in futex if the state is WAITING.
        // Avoids unnecessary wake ups.
        if (state == WAITING) {
            std.Thread.Futex.wake(&self.state, wake_threads);
        }
    }
};

/// Linked list intrusive memory node and lock-free data structures to operate with it
const Node = struct {
    // Intrusive link to the next Node in whichever list/stack currently holds this Node.
    next: ?*Node = null,

    /// A non-empty linked list of Nodes, given by its first and last Node.
    const List = struct {
        head: *Node,
        tail: *Node,
    };

    /// An unbounded multi-producer-(non blocking)-multi-consumer queue of Node pointers.
    /// Producers push lock-free; a consumer must first win a non-blocking try-lock (tryAcquireConsumer()).
const Queue = struct {
    // Top of a stack of pushed Nodes; the two low bits carry the HAS_CACHE / IS_CONSUMING flags.
    stack: Atomic(usize) = Atomic(usize).init(0),
    // Nodes a previous consumer did not finish popping. Only touched by the thread holding IS_CONSUMING.
    cache: ?*Node = null,

    // Set while `cache` holds Nodes left over from a previous consumer.
    const HAS_CACHE: usize = 0b01;
    // Set while some thread owns the consumer side (acts as a non-blocking try-lock).
    const IS_CONSUMING: usize = 0b10;
    // Bits of `stack` that hold the top Node pointer.
    const PTR_MASK: usize = ~(HAS_CACHE | IS_CONSUMING);

    comptime {
        // Node alignment must leave the two low pointer bits free for the flags above.
        assert(@alignOf(Node) >= ((IS_CONSUMING | HAS_CACHE) + 1));
    }

    /// Push a whole List onto the stack with a lock-free CAS loop.
    /// The consumer's HAS_CACHE / IS_CONSUMING bits are preserved.
    fn push(noalias self: *Queue, list: List) void {
        var stack = self.stack.load(.Monotonic);
        while (true) {
            // Attach the list to the stack (pt. 1)
            list.tail.next = @intToPtr(?*Node, stack & PTR_MASK);

            // Update the stack with the list (pt. 2).
            // Don't change the HAS_CACHE and IS_CONSUMING bits of the consumer.
            var new_stack = @ptrToInt(list.head);
            assert(new_stack & ~PTR_MASK == 0);
            new_stack |= (stack & ~PTR_MASK);

            // Push to the stack with a release barrier for the consumer to see the proper list links.
            stack = self.stack.tryCompareAndSwap(
                stack,
                new_stack,
                .Release,
                .Monotonic,
            ) orelse break;
        }
    }

    /// Try to become the queue's single consumer without blocking.
    ///
    /// Returns error.Contended if another thread is already consuming, and error.Empty if there is
    /// neither a cache nor any pushed Node. On success, returns the first Node to pop from: the
    /// previous cache if there is one, otherwise the whole pushed stack (detached in the same CAS).
    /// The caller must later hand whatever it did not pop back via releaseConsumer().
    fn tryAcquireConsumer(self: *Queue) error{Empty, Contended}!?*Node {
        var stack = self.stack.load(.Monotonic);
        while (true) {
            if (stack & IS_CONSUMING != 0)
                return error.Contended; // The queue already has a consumer.
            if (stack & (HAS_CACHE | PTR_MASK) == 0)
                return error.Empty; // The queue is empty when there's nothing cached and nothing in the stack.

            // When we acquire the consumer, also consume the pushed stack if the cache is empty.
            var new_stack = stack | HAS_CACHE | IS_CONSUMING;
            if (stack & HAS_CACHE == 0) {
                assert(stack & PTR_MASK != 0);
                new_stack &= ~PTR_MASK;
            }

            // Acquire barrier on getting the consumer to see cache/Node updates done by previous consumers
            // and to ensure our cache/Node updates in pop() happen after that of previous consumers.
            // On success `stack` still holds the pre-CAS value, so its pointer bits are the detached stack.
            stack = self.stack.tryCompareAndSwap(
                stack,
                new_stack,
                .Acquire,
                .Monotonic,
            ) orelse return self.cache orelse @intToPtr(*Node, stack & PTR_MASK);
        }
    }

    /// Give up the consumer role, stashing the un-popped remainder `consumer` in `cache`.
    fn releaseConsumer(noalias self: *Queue, noalias consumer: ?*Node) void {
        // Stop consuming and remove the HAS_CACHE bit as well if the consumer's cache is empty.
        // When HAS_CACHE bit is zeroed, the next consumer will acquire the pushed stack nodes.
        var remove = IS_CONSUMING;
        if (consumer == null)
            remove |= HAS_CACHE;

        // Release the consumer with a release barrier to ensure cache/node accesses
        // happen before the consumer was released and before the next consumer starts using the cache.
        self.cache = consumer;
        const stack = self.stack.fetchSub(remove, .Release);
        assert(stack & remove != 0);
    }

    /// Pop one Node while holding the consumer role. `consumer_ref` is the consumer's private list
    /// (initially the value returned by tryAcquireConsumer()); once it runs dry, everything pushed
    /// since is taken with a single swap. Returns null when both are empty.
    fn pop(noalias self: *Queue, noalias consumer_ref: *?*Node) ?*Node {
        // Check the consumer cache (fast path)
        if (consumer_ref.*) |node| {
            consumer_ref.* = node.next;
            return node;
        }

        // Load the stack to see if there was anything pushed that we could grab.
        // A plain load first avoids a read-modify-write when there is nothing to take.
        var stack = self.stack.load(.Monotonic);
        assert(stack & IS_CONSUMING != 0);
        if (stack & PTR_MASK == 0) {
            return null;
        }

        // Nodes have been pushed to the stack, grab them with an Acquire barrier to see the Node links.
        // Only the consumer clears the pointer bits, so the stack cannot have become empty since the load.
        stack = self.stack.swap(HAS_CACHE | IS_CONSUMING, .Acquire);
        assert(stack & IS_CONSUMING != 0);
        assert(stack & PTR_MASK != 0);

        const node = @intToPtr(*Node, stack & PTR_MASK);
        consumer_ref.* = node.next;
        return node;
    }
};

    /// A bounded single-producer, multi-consumer ring buffer for node pointers.
const Buffer = struct {
    // Index of the oldest Node; advanced by pop(), steal() (other threads) and push() on overflow.
    head: Atomic(Index) = Atomic(Index).init(0),
    // One past the newest Node; only the owning thread stores to it.
    tail: Atomic(Index) = Atomic(Index).init(0),
    // Slots are addressed modulo `capacity`; head/tail wrap around freely (`+%`, `-%`).
    array: [capacity]Atomic(*Node) = undefined,

    const Index = u32;
    const capacity = 256; // Appears to be a pretty good trade-off in space vs contended throughput
    comptime {
        assert(std.math.maxInt(Index) >= capacity);
        // Power of two so that `index % capacity` stays consistent across Index wrap-around.
        assert(std.math.isPowerOfTwo(capacity));
    }

    /// Push every Node of `list` into the buffer. Only the owning (producer) thread may call this.
    ///
    /// If the buffer fills up, half of its Nodes are migrated out, linked in front of the
    /// not-yet-pushed remainder of `list`, and handed back through `list` with error.Overflow
    /// so the caller can store them elsewhere.
    fn push(noalias self: *Buffer, noalias list: *List) error{Overflow}!void {
        var head = self.head.load(.Monotonic);
        var tail = self.tail.loadUnchecked(); // we're the only thread that can change this

        while (true) {
            var size = tail -% head;
            assert(size <= capacity);

            // Push nodes from the list to the buffer if it's not full.
            if (size < capacity) {
                var nodes: ?*Node = list.head;
                while (size < capacity) : (size += 1) {
                    const node = nodes orelse break;
                    nodes = node.next;

                    // Array written atomically with weakest ordering since it could be getting atomically read by steal().
                    self.array[tail % capacity].store(node, .Unordered);
                    tail +%= 1;
                }

                // Release barrier synchronizes with Acquire loads for steal()ers to see the array writes.
                self.tail.store(tail, .Release);

                // Update the list with the nodes we pushed to the buffer and try again if there's more.
                list.head = nodes orelse return;
                std.atomic.spinLoopHint();
                head = self.head.load(.Monotonic);
                continue;
            }

            // Try to steal/overflow half of the tasks in the buffer to make room for future push()es.
            // Migrating half amortizes the cost of stealing while requiring future pops to still use the buffer.
            // Acquire barrier to ensure the linked list creation after the steal only happens after we successfully steal.
            var migrate = size / 2;
            head = self.head.tryCompareAndSwap(
                head,
                head +% migrate,
                .Acquire,
                .Monotonic,
            ) orelse {
                // Link the migrated Nodes together
                // (the last iteration links past the migrated range; that link is overwritten below).
                const first = self.array[head % capacity].loadUnchecked();
                while (migrate > 0) : (migrate -= 1) {
                    const prev = self.array[head % capacity].loadUnchecked();
                    head +%= 1;
                    prev.next = self.array[head % capacity].loadUnchecked();
                }

                // Append the list that was supposed to be pushed to the end of the migrated Nodes
                const last = self.array[(head -% 1) % capacity].loadUnchecked();
                last.next = list.head;
                list.tail.next = null;

                // Return the migrated nodes + the original list as overflowed
                list.head = first;
                return error.Overflow;
            };
        }
    }

    /// Pop the oldest Node (from `head`). Only the owning thread calls this, but it races with
    /// steal() on `head`, hence the CAS loop.
    fn pop(self: *Buffer) ?*Node {
        var head = self.head.load(.Monotonic);
        var tail = self.tail.loadUnchecked(); // we're the only thread that can change this

        while (true) {
            // Quick sanity check and return null when empty
            var size = tail -% head;
            assert(size <= capacity);
            if (size == 0) {
                return null;
            }

            // Dequeue with an acquire barrier to ensure any writes done to the Node
            // only happen after we successfully claim it from the array.
            // On success `head` still holds the claimed (pre-CAS) index.
            head = self.head.tryCompareAndSwap(
                head,
                head +% 1,
                .Acquire,
                .Monotonic,
            ) orelse return self.array[head % capacity].loadUnchecked();
        }
    }

    /// Result of consume()/steal(): one Node to run now, and whether extra Nodes were also
    /// made available in this buffer.
    const Stole = struct {
        node: *Node,
        pushed: bool,
    };

    /// Refill this (empty) buffer from `queue` and return one extra Node to run right away.
    /// Returns null if the consumer side of `queue` is contended or the queue is empty.
    fn consume(noalias self: *Buffer, noalias queue: *Queue) ?Stole {
        var consumer = queue.tryAcquireConsumer() catch return null;
        // `consumer` is read when the scope exits, so the un-popped remainder is handed back.
        defer queue.releaseConsumer(consumer);

        const head = self.head.load(.Monotonic);
        const tail = self.tail.loadUnchecked(); // we're the only thread that can change this

        const size = tail -% head;
        assert(size <= capacity);
        assert(size == 0); // we should only be consuming if our array is empty

        // Pop nodes from the queue and push them to our array.
        // Atomic stores to the array as steal() threads may be atomically reading from it.
        var pushed: Index = 0;
        while (pushed < capacity) : (pushed += 1) {
            const node = queue.pop(&consumer) orelse break;
            self.array[(tail +% pushed) % capacity].store(node, .Unordered);
        }

        // We will be returning one node that we stole from the queue.
        // Get an extra, and if that's not possible, take one from our array.
        const node = queue.pop(&consumer) orelse blk: {
            if (pushed == 0) return null;
            pushed -= 1;
            break :blk self.array[(tail +% pushed) % capacity].loadUnchecked();
        };

        // Update the array tail with the nodes we pushed to it.
        // Release barrier to synchronize with Acquire barrier in steal()'s to see the written array Nodes.
        if (pushed > 0) self.tail.store(tail +% pushed, .Release);
        return Stole{
            .node = node,
            .pushed = pushed > 0,
        };
    }

    /// Steal half (rounded up) of `buffer`'s Nodes into this (empty) buffer and return one of them.
    /// Returns null if `buffer` is observed empty; retries while racing with its other consumers.
    fn steal(noalias self: *Buffer, noalias buffer: *Buffer) ?Stole {
        const head = self.head.load(.Monotonic);
        const tail = self.tail.loadUnchecked(); // we're the only thread that can change this

        const size = tail -% head;
        assert(size <= capacity);
        assert(size == 0); // we should only be stealing if our array is empty

        while (true) : (std.atomic.spinLoopHint()) {
            const buffer_head = buffer.head.load(.Acquire);
            const buffer_tail = buffer.tail.load(.Acquire);

            // Overly large size indicates the tail was updated a lot after the head was loaded.
            // Reload both and try again.
            const buffer_size = buffer_tail -% buffer_head;
            if (buffer_size > capacity) {
                continue;
            }

            // Try to steal half (divCeil) to amortize the cost of stealing from other threads.
            const steal_size = buffer_size - (buffer_size / 2);
            if (steal_size == 0) {
                return null;
            }

            // Copy the nodes we will steal from the target's array to our own.
            // Atomically load from the target buffer array as it may be pushing and atomically storing to it.
            // Atomic store to our array as other steal() threads may be atomically loading from it as above.
            var i: Index = 0;
            while (i < steal_size) : (i += 1) {
                const node = buffer.array[(buffer_head +% i) % capacity].load(.Unordered);
                self.array[(tail +% i) % capacity].store(node, .Unordered);
            }

            // Try to commit the steal from the target buffer using:
            // - an Acquire barrier to ensure that we only interact with the stolen Nodes after the steal was committed.
            // - a Release barrier to ensure that the Nodes are copied above prior to the committing of the steal
            //   because if they're copied after the steal, they could be getting rewritten by the target's push().
            _ = buffer.head.compareAndSwap(
                buffer_head,
                buffer_head +% steal_size,
                .AcqRel,
                .Monotonic,
            ) orelse {
                // Pop one from the nodes we stole as we'll be returning it
                const pushed = steal_size - 1;
                const node = self.array[(tail +% pushed) % capacity].loadUnchecked();

                // Update the array tail with the nodes we pushed to it.
                // Release barrier to synchronize with Acquire barrier in steal()'s to see the written array Nodes.
                if (pushed > 0) self.tail.store(tail +% pushed, .Release);
                return Stole{
                    .node = node,
                    .pushed = pushed > 0,
                };
            };
        }
    }
};
};
788 | -------------------------------------------------------------------------------- /blog.md: -------------------------------------------------------------------------------- 1 | I'd like to share what I've been working on for the past 2 years give or take. It's a thread pool that checks a bunch of boxes: lock-free, allocation-free\* (excluding spawning threads), supports batch scheduling, and dynamically spawns threads while handling thread spawn failure. 2 | 3 | To preface, this assumes you're familiar with thread synchronization patterns and manual memory management. It's also more of a letter to other people implementing schedulers than it is to benefit most programmers. So if you don't understand what's going on sometimes, that's perfectly fine. I try to explain what led to each thought and if you're just interested in how the claims above materialized, go read [the source](https://github.com/kprotty/zap/blob/blog/src/thread_pool.zig) directly. 4 | 5 | ## Thread Pools? 6 | 7 | For those unaware, a thread pool is just a group of threads that work can be dispatched to.
Having a group amortizes the costs of creating and shutting down threads, which can be expensive in comparison to the work being performed. It also prevents a task from blocking another by having another thread ready to process it. 8 | 9 | Thread pools are used everywhere from your favorite I/O event loop (Golang, Tokio, Akka, Node.js), to game logic or simulation processing (OpenMP, Intel TBB, Rayon, Bevy), and even broader applications (Linkers, Machine Learning, and more). It's a pretty well-explored abstraction, but *there's still more room for improvement.* 10 | 11 | ## Why Build Your Own? 12 | 13 | A good question. Given the abundance of solutions (I've listed some above), why not just use an existing thread pool? Aren't thread pools a solved problem? Aren't they just all the same: a group of threads? It's reasonable to have this line of thought if the processing isn't your main concern. However, I like tinkering and optimizing, and I have quite a bit of free time. Those same ingredients are what built the existing solutions in the first place. 14 | 15 | First, I'd like to set the stage. I'm very into Zig. The time is somewhere after Zig `0.5`. Andrew just recently introduced Zig's [new `async/await` semantics](https://ziglang.org/download/0.5.0/release-notes.html#Async-Functions) (I hope to do a post about this in the future) and the standard library event loop (async I/O driver) is still in its early stages. This is a chance to get Zig into the big player domain like Go and Rust for async I/O stuff. A good thread pool appears necessary. 16 | 17 | Second, **thread pools aren't a solved problem**. While the existing reference implementations are quite fast for their needs, they have, in my opinion, some inefficient design choices that I believed could be improved on.
Even between Intel TBB and Go's runtime, their implementations aren't that similar to each other and are arguably pretty ew code-wise, [TBH](https://www.howtogeek.com/447760/what-does-tbh-mean-and-how-do-you-use-it/#:~:text=%E2%80%9CTo%20Be%20Honest%E2%80%9D%20or%20%E2%80%9C,%2C%20and%20text%2Dmessage%20culture.). The former is a jungle of classes spread over different files that you have to wade through to get to the meat of scheduling. The latter has a lot of short context-lacking variable names mixed with GC/tracing stuff which distracted me when I was first understanding it. (**Update**: Go cleaned up [the scheduler](https://golang.org/src/runtime/proc.go) and it's much nicer now.) 18 | 19 | Third, good thread pools aren't always straightforward. The [META](https://www.arc.unsw.edu.au/blitz/read/explainer-what-is-a-metaquestion#:~:text=In%20essence%2C%20a%20%22meta%22,%E2%80%9Cmost%20effective%20tactics%20available%E2%80%9D.) nowadays for I/O event loops is a work-stealing, wake-throttling, I/O-sharing, mostly-cooperative task scheduler. Yeah, it's a mouthful, and yeah, each of the components carries its own implementation trade-offs, but this matrix of scheduling options will help you understand why I started with such a design. 20 | 21 | ## Resource Efficiency 22 | 23 | Zig has a certain ethos or [Zen](https://ziglang.org/documentation/master/#Zen) which attracted me to the language in the first place. That is: the focus on edge cases and utilizing the hardware's resources in a good and less wasteful way. The best example of this is program memory. Since I started out developing on a machine with relatively little memory, this is a problem I wished to address early on in the thread pool's design. 24 | 25 | When you simplify a thread pool's API, it all comes down to a function which takes a Task or Unit-of-Work and queues it up for execution on some thread: `schedule(Task)`.
Some implementations store the tasks in the thread pool itself and basically have an unbounded queue of them which heap-allocates to grow. This can be wasteful memory-wise (and add synchronization overhead), so I decided to have Tasks in my thread pool be intrusively provided. 26 | 27 | ## Intrusive Memory 28 | 29 | Intrusive data structures are, as I understand it, when you store a reference to the caller's data with the caller having more context on what that reference is. This contrasts with non-intrusive data structures which copy or move the caller's data into a container for ownership. 30 | 31 | Poor explanation, I know, but as an example you can think of a hash map as non-intrusive since it owns whatever key/value you insert into it and can change its memory internally when growing. Meanwhile, a linked list in which the caller provides the node pointers, and can only deallocate the node once it's removed from the list, is labeled as intrusive. There's a [possibly better explanation here](https://www.boost.org/doc/libs/1_55_0/doc/html/intrusive/intrusive_vs_nontrusive.html), but our thread pool tasks now look like this: 32 | 33 | ```zig 34 | pub const Task = struct { 35 | next: ?*Task = null, 36 | callback: fn (*Task) void, 37 | }; 38 | 39 | pub fn schedule(task: *Task) void { 40 | // ... 41 | } 42 | ``` 43 | 44 | To schedule a callback with some context, you would generally store the Task itself *with* the context and use [`@fieldParentPtr()`](https://ziglang.org/documentation/master/#fieldParentPtr) to convert the Task pointer back into the context pointer. If you're familiar with C, this is basically `containerof` but a bit more type-safe. It takes a pointer to a field and gives you a pointer to the parent/container struct/class.
45 | 46 | ```zig 47 | const Context = struct { 48 | value: usize, 49 | task: Task, 50 | 51 | pub fn scheduleToIncrement(this: *Context) void { 52 | this.task = Task{ .callback = onScheduled }; 53 | schedule(&this.task); 54 | } 55 | 56 | fn onScheduled(task_ptr: *Task) void { 57 | const this = @fieldParentPtr(Context, "task", task_ptr); 58 | this.value += 1; 59 | } 60 | }; 61 | ``` 62 | 63 | This is a very powerful and memory efficient way to model callbacks. It leaves the scheduler to only interact with opaque Task pointers which are effectively just linked-list nodes. Zig makes this pattern easy and common too; The standard library uses intrusive memory and `containerof` to model runtime polymorphism for Allocators by having them hold a function pointer which takes in an opaque Allocator pointer where the function's implementation uses `@fieldParentPtr` on the Allocator pointer to get its allocator-specific context. It's like a cool alternative to [vtables](https://en.wikipedia.org/wiki/Virtual_method_table). 64 | 65 | ## Scheduling and the Run Loop 66 | 67 | Now that we have the basic API down, we can actually make a single threaded implementation to understand the concept of task schedulers. 68 | 69 | ```zig 70 | stack: ?*Task = null, 71 | 72 | pub fn schedule(task: *Task) void { 73 | task.next = stack; 74 | stack = task; 75 | } 76 | 77 | pub fn run() void { 78 | while (stack) |task| { 79 | stack = task.next; 80 | (task.callback)(task); 81 | } 82 | } 83 | ``` 84 | 85 | This is effectively what most schedulers, and hence thread pools, boil down to. The main difference from this and a threaded version is that the queue of Tasks to run called `stack` here is conceptually shared between threads and multiple threads are popping from it in order to call Task callbacks. Let's make our simple example thread-safe by adding a Mutex. 
86 | 87 | ```zig 88 | lock: std.Mutex = .{}, 89 | stack: ?*Task = null, 90 | 91 | pub fn schedule(task: *Task) void { 92 | const held = lock.acquire(); 93 | defer held.release(); 94 | 95 | task.next = stack; 96 | stack = task; 97 | } 98 | 99 | fn runOnEachThread() void { 100 | while (dequeue()) |task| 101 | (task.callback)(task); 102 | } 103 | 104 | fn dequeue() ?*Task { 105 | const held = lock.acquire(); 106 | defer held.release(); 107 | 108 | const task = stack orelse return null; 109 | stack = task.next; 110 | return task; 111 | } 112 | ``` 113 | 114 | Did you spot the inefficiency here? We mustn't forget that now there are multiple threads dequeueing from `stack`. If there's only one task running and the `stack` is empty, then all the other threads are just spinning on dequeue(). To save execution resources, those threads should be put to sleep until `stack` is populated. I'll spare you the details this time, but we've boiled down the API for a multi-threaded thread pool here to this pseudocode: 115 | 116 | ## The Algorithm 117 | 118 | ```rs 119 | schedule(task): 120 | run_queue.push(task) 121 | threads.notify() 122 | 123 | join(): 124 | shutdown = true 125 | for all threads |t|: 126 | t.join() 127 | 128 | run_on_each_thread(): 129 | while not shutdown: 130 | if run_queue.pop() |task|: 131 | task.run() 132 | else: 133 | threads.wait() 134 | ``` 135 | 136 | Here is the algorithm that we will implement for our thread pool. I will refer back to this here and there and also revisit it later. For now, keep this as a reminder of what we're working towards. 137 | 138 | ## Run Queues 139 | 140 | Let's focus on the run queue first. Having a shared run queue for all threads increases how much they fight over it when dequeueing, which makes it quite the bottleneck. This fighting is known as **contention** in synchronization terms and is the primary slowdown of any sync mechanism from locks down to atomic instructions.
The fewer threads stomping over each other, the better the throughput in most cases. 141 | 142 | To help decrease contention on the shared run queue, we just give each thread its own run queue! When threads schedule(), they push to their own run queue. When they pop(), they first dequeue from their own, then try to dequeue() from others as a last resort. **This is nothing new, but is what people call work-stealing**. 143 | 144 | If the total work on the system is being pushed in by different threads, then this scales great since they're not touching each other most of the time. But when they start stealing, the contention slowdown rears its head again. This can happen a lot if there are only a few threads pushing work to their queues and the rest are just stealing. Time to investigate what we can do about that. 145 | 146 | ### Going Lock-Free: Bounded 147 | 148 | **WARNING**: here be atomics. Skip to [Notification Throttling](#Notification-Throttling) to get back into algorithm territory. 149 | 150 | The first thing we can do is to get rid of the locks on the run queues. When there's a lot of contention on a lock, the thread has to be put to sleep. This is a relatively expensive operation compared to the actual dequeue; it's a syscall for the losing thread to sleep and often a syscall for the winning thread to wake up a losing thread. We can avoid this with a few realizations. 151 | 152 | One realization is that there's only one producer to our thread-local queues while there are multiple consumers in the form of "the work stealing threads". This means we don't need to synchronize the producer side and can use lock-free SPMC (single-producer-multi-consumer) algorithms. Golang uses a good one (which I believe is borrowed from Cilk?)
that has a really efficient push() and can steal in batches, all without locks: 153 | 154 | ```rs 155 | head = 0 156 | tail = 0 157 | buffer: [N]*Task = uninit 158 | 159 | // -% is wrapping subtraction 160 | // +% is wrapping addition 161 | // `ATOMIC_CMPXCHG(): ?int` where `null` is success and `int` is failure with new value. 162 | 163 | push(task): 164 | h = ATOMIC_LOAD(&head, Relaxed) 165 | if tail -% h >= N: 166 | return Full 167 | // store to buffer must be atomic since slow steal() threads may still load(). 168 | ATOMIC_STORE(&buffer[tail % N], task, Unordered) 169 | ATOMIC_STORE(&tail, tail +% 1, Release) 170 | 171 | pop(): 172 | h = ATOMIC_LOAD(&head, Relaxed) 173 | while h != tail: 174 | h = ATOMIC_CMPXCHG(&head, h, h +% 1, Acquire) orelse: 175 | return buffer[h % N] 176 | return null 177 | 178 | steal(into): 179 | while True: 180 | h = ATOMIC_LOAD(&head, Acquire) 181 | t = ATOMIC_LOAD(&tail, Acquire) 182 | if t -% h > N: continue // preempted too long between loads 183 | if t == h: return Empty 184 | 185 | // steal half to amortize the cost of stealing. 186 | // loads from buffer must be atomic since they may be getting updated by push(). 187 | // stores to `into` buffer must be atomic since it's pushing. see push(). 188 | half = (t -% h) - ((t -% h) / 2) 189 | for i in 0..half: 190 | task = ATOMIC_LOAD(&buffer[(h +% i) % N], Unordered) 191 | ATOMIC_STORE(&into.buffer[(into.tail +% i) % N], task, Unordered) 192 | 193 | _ = ATOMIC_CMPXCHG(&head, h, h +% half, AcqRel) orelse: 194 | new_tail = into.tail +% half 195 | ATOMIC_STORE(&into.tail, new_tail -% 1, Release) 196 | return into.buffer[(new_tail -% 1) % N] 197 | ``` 198 | 199 | You can ignore the details, but just know that this algorithm is nice because it allows stealing to happen concurrently with producing. Stealing can also happen concurrently with other steal()s and pop()s without ever having to block the thread in a syscall.
Basically, we've made the serialization points (places where mutual exclusion is needed) to be the atomic operations which happen in hardware instead of the locks which serialize entire OS threads using syscalls in software. 200 | 201 | Unfortunately, this algorithm is only for a bounded array. `N` could be pretty small relative to the overall Tasks that may be queued on a given thread so we need a way to hold tasks which overflow, but without re-introducing locks. This is where other implementations stop but we can keep going lock-free with more realizations. 202 | 203 | ### Going Lock-Free: Unbounded 204 | 205 | If we refer back to the pseudo code, `run_queue.push` is always followed by `threads.notify`. And a failed `run_queue.pop` is always followed by `threads.wait`. This means that the `run_queue.pop` is allowed to spuriously see empty run queues and wait without worry as long as there's a matching notification to wake it up. 206 | 207 | This is actually a powerful realization. It means that any OS thread blocking/unblocking from syscalls done in the run queue operations can actually be omitted since `threads.wait` and `threads.notify` are already doing the blocking/unblocking! **If a run queue operation would normally block, _it just shouldn't_** since it will already block once it fails anyways. The thread can use that free time to check other thread run queues (instead of blocking) before reporting an empty dequeue. We've effectively merged thread sleep/wakeup mechanisms with run queue synchronization. 208 | 209 | We can translate this realization to each thread having a non-blocking-lock (i.e. `try_lock()`, no `lock()`) protected queue *along with the SPMC buffer*. When our thread's buffer overflows, we take/steal half of it, build a linked list from that, then lock our queue and push that linked list. Migrating half instead of 1 amortizes the cost of stealing from ourselves on push() and makes future pushes go directly to the buffer which is fast. 
210 | 211 | When we dequeue and our buffer is empty, we *try to lock* our queue and pop/refill our buffer with tasks from the queue. If both our buffer and queue are empty (or if another thread is holding our queue lock) then we steal by *try_locking* and refilling from *other* thread queues, stealing from their buffer if that doesn't work. This is in reverse order to how they dequeue to, again, avoid contention. 212 | 213 | ```rs 214 | run_queue.push(): 215 | if thread_local.buffer.push(task) == Full: 216 | migrated = thread_local.buffer.steal() + task 217 | thread_local.queue.lock_and_push(migrated) 218 | 219 | run_queue.pop(): ?*Task 220 | if thread_local.buffer.pop() |task| 221 | return task 222 | if thread_local.queue.try_lock_and_pop() |task| 223 | return task 224 | for all other threads |t|: 225 | if t.queue.try_lock_and_pop() |task| 226 | return task 227 | if t.buffer.steal(into: &thread_local.buffer) |task| 228 | return task 229 | return null 230 | ``` 231 | 232 | ---- 233 | 234 | This might have been a lot to process, but hopefully the code shows what's going on. If you're still reading ... take a minute break or something; There's still more to come. If you're attentive, you may remember that I said we should do this without locks but there's still `lock_and_push()` in the producer! Well here we go again. 235 | 236 | ---- 237 | 238 | ### Going Lock-Free: Unbounded; Season 1 pt. 2 239 | 240 | You also may have noticed that the "try_lock_" in `try_lock_and_pop` for our thread queues is just there to enforce serialization on the consumer side. There's also still only one producer. Using these assumptions, we can reduce the queues down to non-blocking-lock protected lock-free SPSC queues. This would allow the producer to operate lock-free to the consumer and remove the final blocking serialization point that is `lock_and_push()`. 
241 | 242 | Unfortunately, there don't seem to be any unbounded lock-free SPSC queues out there which are fully intrusive *and* don't use atomic read-modify-write instructions (that's generally avoided for SPSC). But that's fine! We can just use an intrusive unbounded MPSC instead. Dmitry Vyukov developed/discovered a [fast algorithm](https://www.1024cores.net/home/lock-free-algorithms/queues/intrusive-mpsc-node-based-queue) for such use-case a while back which has been well known and used everywhere from [Rust stdlib](https://doc.rust-lang.org/src/std/sync/mpsc/mpsc_queue.rs.html) to [Apple GCD](https://github.com/apple/swift-corelibs-libdispatch/blob/34f383d34450d47dd5bdfdf675fcdaa0d0ec8031/src/inline_internal.h#L1510) to [Ponylang](https://github.com/ponylang/ponyc/blob/7d38ffa91cf5f89f94daf6f195dfae3bd3395355/src/libponyrt/actor/messageq.c#L31). 243 | 244 | We can also merge the non-blocking-lock acquisition and release into the MPSC algorithm itself by having the pop-end be `ATOMIC_CMPXCHG` acquired with a sentinel value and released by storing the actual pointer after popping from the queue with the acquired pop-end. Again, here's just the nitty gritty for those interested. 
245 | 246 | ```rs 247 | stub: Task = .{ .next = null }, 248 | head: usize = 0 249 | tail: ?*Task = null 250 | 251 | CONSUMING = 1 252 | 253 | push(migrated): 254 | migrated.tail.next = null 255 | t = ATOMIC_SWAP(&tail, migrated.tail, AcqRel) 256 | prev = t orelse &stub 257 | ATOMIC_STORE(&prev.next, migrated.head, Release) 258 | 259 | try_lock(): ?*Task 260 | h = ATOMIC_LOAD(&head, Relaxed) 261 | while True: 262 | if h == 0 and ATOMIC_LOAD(&tail, Relaxed) == null: 263 | return null // Empty queue 264 | if h == CONSUMING: 265 | return null // Queue already locked 266 | 267 | h = ATOMIC_CMPXCHG(&head, h, CONSUMING, Acquire) orelse: 268 | return (h as ?*Task) orelse &stub 269 | 270 | pop(ref locked_head: *Task): ?*Task 271 | if locked_head == &stub: 272 | locked_head = ATOMIC_LOAD(&stub.next, Acquire) orelse return null 273 | 274 | if ATOMIC_LOAD(&locked_head.next, Acquire) |next| 275 | defer locked_head = next; 276 | return locked_head; 277 | 278 | // push was preempted between SWAP and STORE 279 | // its ok since we can safely return spurious empty 280 | if ATOMIC_LOAD(&tail, Acquire) != locked_head: 281 | return null 282 | 283 | // Try to Ensure theres a next node 284 | push(LinkedList.from(&stub)) 285 | 286 | // Same thing as above 287 | const next = ATOMIC_LOAD(&locked_head.next, Acquire) orelse return null 288 | defer locked_head = next; 289 | return locked_head 290 | 291 | unlock(locked_head: *Task): 292 | assert(ATOMIC_LOAD(&head, Unordered) == CONSUMING) 293 | ATOMIC_STORE(&head, locked_head as usize, Release) 294 | ``` 295 | 296 | **SIDENOTE**: A different algorithm ended up in the final thread pool since I discovered a "mostly-LIFO" version of this which performs about the same in practice. It's a [Treiber Stack](https://en.wikipedia.org/wiki/Treiber_stack) MPSC which swaps the entire stack with null for consumer. 
The idea is used fairly often in the wild (see [mimalloc: 2.4 The Thread Free List](https://www.microsoft.com/en-us/research/uploads/prod/2019/06/mimalloc-tr-v1.pdf)), but I just figured out a way to add try-lock usage to the consumer end. 297 | 298 | ```rs 299 | stack: usize = 0 300 | cache: ?*Task = null 301 | 302 | MASK = ~0b11 303 | CACHED = 0b01 304 | CONSUMING = 0b10 305 | 306 | // Classic Treiber stack push 307 | push(migrated): 308 | s = ATOMIC_LOAD(&stack, Relaxed) 309 | while True: 310 | migrated.tail.next = (s & MASK) as ?*Task 311 | new = (migrated.head as usize) | (s & ~MASK) 312 | s = ATOMIC_CMPXCHG(&stack, s, new, Release) orelse break 313 | 314 | try_lock(): ?*Task 315 | s = ATOMIC_LOAD(&stack, Relaxed) 316 | while True: 317 | if s & CONSUMING != 0: 318 | return null // Queue already locked 319 | if s & (MASK | CACHED) == 0: 320 | return null // Queue is empty 321 | 322 | // Grab consuming, but also grab the pushed stack if nothing's cached 323 | new = s | CONSUMING | CACHED 324 | if s & CACHED == 0: 325 | new &= ~MASK 326 | 327 | s = ATOMIC_CMPXCHG(&stack, s, new, Acquire) orelse: 328 | return cache orelse ((s & MASK) as *Task) 329 | 330 | pop(ref locked_stack: ?*Task): ?*Task 331 | // fast path 332 | if locked_stack |task|: 333 | locked_stack = task.next 334 | return task 335 | 336 | // quick load before the swap to avoid taking ownership of the cache line 337 | if ATOMIC_LOAD(&stack, Relaxed) & MASK == 0: 338 | return null 339 | 340 | // grab the stack in one fell swoop 341 | s = ATOMIC_SWAP(&stack, CONSUMING | CACHED, Acquire) 342 | task = ((s & MASK) as ?*Task) orelse return null 343 | locked_stack = task.next 344 | return task 345 | 346 | 347 | unlock(locked_stack: ?*Task): 348 | // remove the CACHED bit if the cache is empty 349 | // which will cause the next try_lock() to consume the stack 350 | sub = CONSUMING 351 | if locked_stack == null: 352 | sub |= CACHED 353 | 354 | cache = locked_stack 355 | ATOMIC_SUB(&stack, sub, Release) 356 | ``` 357 | 
### Going Lock-Free: Unbounded; Season 1 pt. 3 359 | 360 | Now that the entire run queue is lock-free, we've actually introduced a situation where thread A can grab the queue lock of thread B, and thread B would then see an empty queue (because its queue is currently locked) and sleep on `threads.wait`. This is expected, but the sad part is that the queue lock holder may leave some remaining Tasks after refilling its buffer even while there are sleeping threads that could process those Tasks! As a general rule, **anytime we push to the buffer in any way, even when work-stealing, follow it up with a notification**. This prevents under-utilization of threads in the pool, so we must change the algorithm to reflect this: 361 | 362 | ```rs 363 | run_on_each_thread(): 364 | while not shutdown: 365 | if run_queue.pop() |(task, pushed)|: 366 | if pushed: threads.notify() 367 | task.run() 368 | else: 369 | threads.wait() 370 | 371 | run_queue.pop(): ?(task: *Task, pushed: bool) 372 | if thread_local.buffer.pop() |task| 373 | return (task, false) 374 | if steal_queue(&thread_local.queue) |task| 375 | return (task, true) 376 | for all other threads |t|: 377 | if steal_queue(&t.queue) |task| 378 | return (task, true) 379 | if t.buffer.steal(into: &thread_local.buffer) |task| 380 | return (task, true) 381 | return null 382 | ``` 383 | 384 | ## Notification Throttling 385 | 386 | The run queue is now optimized, and so far it is the change that has improved throughput the most. The next thing to do is to optimize how threads are put to sleep and woken up through `threads.wait` and `threads.notify`. The run queue relies on `wait()` to handle spurious reports of being empty, and `notify()` is now called on every steal, so both functions have to be efficient. 387 | 388 | I mentioned before that putting a thread to sleep and waking it up are both "expensive" syscalls.
We should also not try to wake up all threads for each `notify()` as that would increase contention on the run queues (even if we're already trying hard to avoid it). The best solution that others and I have found in practice is to throttle thread wake-ups. 389 | 390 | Throttling in this case means that **when we *do* wake up a thread, we don't wake up another until the woken-up thread has actually been scheduled by the OS**. We can take this even further by requiring the woken-up thread to find Tasks before waking another. This is what Golang and Rust async executors do with great results and is what we will do as well, but in a *different* way. 391 | 392 | For context, Golang and Rust use a counter of all the threads that are stealing. They only wake up a thread if there are no threads currently stealing. So `notify()` tries to `ATOMIC_CMPXCHG()` the stealing count from 0 to 1 and wakes only if that's successful. When entering the work stealing portion, the count is incremented if some heuristics deem it OK. When leaving, the stealing count is decremented and if the last thread to exit stealing finds a Task, it will try to `notify()` again. This works for other thread pools, but is a bit awkward for us for a few reasons. 393 | 394 | We want similar throttling but have different requirements. Unlike Rust, we spawn threads lazily to support static initialization for our thread pool. Unlike Go, we don't use locks for mutual exclusion to know whether to wake up or spawn a new thread on `notify()`. We also want to allow thread spawning to fail without bringing the entire program down with a `panic()` as both Go and Rust do. Threads are a resource which, like memory, can be constrained at runtime, and we should be explicit about handling it as per Zig Zen. 395 | 396 | (**Update**: I found a way to make the Go-style system work for our thread pool after the blog was written.
Go [check out the source](https://github.com/kprotty/zap/blob/blog/src/thread_pool_go_based.zig)) 397 | 398 | I came up with a different solution which I believe is a bit friendlier to [LL/SC](https://en.wikipedia.org/wiki/Load-link/store-conditional) systems like ARM, but also solves the problems listed above. I originally called it `Counter` but have started calling it `Sync` out of simplicity. All thread coordination state is stored in a single machine word which packs the bits full of meaning (yay memory efficiency!) and is atomically transitioned through `ATOMIC_CMPXCHG`. 399 | 400 | ### Counter/Sync Algorithm 401 | 402 | ```zig 403 | enum State(u2): 404 | pending = 0b00 405 | waking = 0b01 406 | signaled = 0b10 407 | shutdown = 0b11 408 | 409 | packed struct Sync(u32): 410 | state: State 411 | notified: bool(u1) 412 | unused: bool(u1) 413 | idle_threads: u14 414 | spawned_threads: u14 415 | ``` 416 | 417 | The thicc-but-not-really `Sync` struct tracks the "pool state" which is used to control thread signaling, shutdown, and throttling. That's followed by a boolean called `notified` which helps in thread notification, `unused` which you can ignore (it's just there to pad it to `u32`), and counters for the amount of threads sleeping and the amount of threads created. You could extend `Sync`'s size from `u32` to `u64` on 64bit platforms and grow the counters, but if you need more than 16K (`1 << 14`) threads in your thread pool, you have bigger issues... 418 | 419 | In order to implement thread wakeup throttling, we introduce something called "the waking thread". To wake up a thread, the `state` is transitioned from `pending` to `signaled`. Once a thread wakes up, it consumes this signal by transitioning the state from `signaled` to `waking`. The winning thread to consume the signal now becomes the "waking thread". 420 | 421 | While there is a "waking thread", no other thread can be woken up. The waking thread will either dequeue a Task or go back to sleep. 
If it finds a Task, it must transfer its "waking" status to someone else by transitioning from `waking` to `signaled` and wake up another thread. If it doesn't find Tasks, it must transition from `waking` to `pending` before going back to sleep. 422 | 423 | This results in the same throttling found in Go and Rust. It avoids a [thundering herd](https://en.wikipedia.org/wiki/Thundering_herd_problem) of threads on `notify()`, reduces contention by limiting the number of stealing threads, and amortizes the syscall cost of actually waking up a thread: 424 | 425 | * T1 pushes Tasks to its run queue and calls `notify()` 426 | * T2 is woken up and designated as the "waking thread" 427 | * T1 pushes Tasks again but can't wake up other threads since T2 is still "waking" 428 | * T2 steals Tasks from T1 and wakes up T3 as the new "waking" thread 429 | * T3 steals from either T2 or T1 and wakes T4 as the new "waking" thread. 430 | * By the time T4 wakes up, all Tasks have been processed 431 | * T4 fails to steal Tasks, gives up the "waking thread" status, and goes back to sleep on `wait()` 432 | 433 | ### Thread Counters and Races 434 | 435 | So far, we've only talked about the `state`, but there's still `notified`, `idle_threads` and `spawned_threads`. These are here to optimize the algorithm and provide lazy/fallible thread spawning as I mentioned a while back. Let's go through all of them: 436 | 437 | First, let's check out `spawned_threads`. Since it's handled atomically with `idle_threads`, this lets us choose how we want to "wake" up a thread. If there are existing idle/sleeping threads, we should of course prefer waking up those instead of spawning new ones. But if there aren't any, we can accurately spawn more until we reach a user-set "max threads" capacity. **If spawning a thread fails, we just decrement this count**. `spawned_threads` is also used to synchronize shutdown, which is explained later. 438 | 439 | Then there's `notified`.
Even when there's a "waking" thread, we still don't want `notify()`s to be lost, as those would be missed wake-ups, which lead to CPU under-utilization. So every time we `notify()`, we also set the `notified` bit if it's not already set. Threads going to sleep can observe the `notified` bit and try to consume it. Consuming it acts like a pseudo wake-up, so the thread should recheck the run queues instead of sleeping, which applies to the "waking" thread as well. This keeps the threads on their toes by having at most one other non-waking thread searching for Tasks. For `Sync(u64)`, we could probably extend this to a counter to have more active searching threads. 440 | 441 | Finally, there's `idle_threads`. When a thread goes to sleep, it increments `idle_threads` by one, then sleeps on a semaphore or something. A non-zero idle count lets `notify()` know to transition to `signaled` and post to the theoretical semaphore. It's the notification-consuming thread's responsibility to decrement the `idle_threads` count when it transitions the state from `signaled` to `waking` or munches up the `notified` bit. Those familiar with [semaphore internals](https://code.woboq.org/userspace/glibc/nptl/sem_waitcommon.c.html#__new_sem_wait_slow) or [event counts](https://github.com/r10a/Event-Counts) will recognize this `idle_threads` algorithm. 442 | 443 | ### Shutdown Synchronization 444 | 445 | When the book of revelations comes to pass, and the thread pool is ready to reap and ascend to reclaimed memory, it must first make peace with its children to join gracefully. For thou pool must not be eager to return, else they risk the memory corruption of others. The scripture recites a particular mantra to perform the process: 446 | 447 | Transition the `state` from whatever it is to `shutdown`, then post to the semaphore if there were any `idle_threads`. This notifies the threads that the end is *among us*. `notify()` bails if it observes the state to be `shutdown`.
`wait()` bails when it observes `shutdown`, and the exiting thread then decrements `spawned_threads` in `kill_thread()`. The last thread to decrement the spawned count to zero must notify the pool that *it is time*. 448 | 449 | The thread pool can iterate over its child threads and sacrifice them to the kernel... but wait, we never explained how the thread pool keeps track of threads? Well, in keeping with the idea of intrusive memory, a thread pushes itself to a lock-free stack in the thread pool on spawn. Threads find each other by following that stack and restarting from the top when the first-born is reached. We just follow this stack as well when `spawned_threads` reaches 0 to join them. 450 | 451 | The final algorithm is as follows. Thank you for coming to my TED talk. 452 | 453 | ```rs 454 | 455 | notify(is_waking: bool): 456 | s = ATOMIC_LOAD(&sync, Relaxed) 457 | while s.state != .shutdown: 458 | new = { s | notified: true } 459 | can_wake = is_waking or s.state == .pending 460 | if can_wake and s.idle > 0: 461 | new.state = .signaled 462 | else if can_wake and s.spawned < max_spawn: 463 | new.state = .signaled 464 | new.spawned += 1 465 | else if is_waking: // nothing to wake, transition out of waking 466 | new.state = .pending 467 | else if s.notified: 468 | return // nothing to wake and already notified; nothing to update 469 | 470 | s = ATOMIC_CMPXCHG(&sync, s, new, Release) orelse: 471 | if can_wake and s.idle > 0: 472 | return thread_sema.post(1) 473 | if can_wake and s.spawned < max_spawn: 474 | return spawn_thread(run_on_each_thread) catch kill_thread(null) 475 | return 476 | 477 | wait(is_waking: bool): error{Shutdown}!bool 478 | is_idle = false 479 | s = ATOMIC_LOAD(&sync, Relaxed) 480 | while True: 481 | if s.state == .shutdown: 482 | return error.Shutdown 483 | 484 | if s.notified or !is_idle: 485 | new = { s | notified: false } 486 | if s.notified: 487 | if s.state == .signaled: 488 | new.state = .waking 489 | if is_idle: 490 | new.idle -= 1 491 | else: 492 | new.idle += 1 493 | if is_waking: 494 | new.state = .pending
495 | 496 | s = ATOMIC_CMPXCHG(&sync, s, new, Acquire) orelse: 497 | if s.notified: 498 | return is_waking or s.state == .signaled 499 | is_waking = false 500 | is_idle = true 501 | s = new 502 | continue 503 | 504 | thread_sema.wait() 505 | s = ATOMIC_LOAD(&sync, Relaxed) 506 | 507 | kill_thread(thread: ?*Thread): 508 | s = ATOMIC_SUB(&sync, Sync{ .spawned = 1 }, Release) 509 | if s.state == .shutdown and s.spawned - 1 == 0: 510 | shutdown_sema.post(1) 511 | 512 | if thread |t|: 513 | t.join_sema.wait() 514 | 515 | shutdown_and_join(): 516 | s = ATOMIC_SWAP(&sync, Sync{ .state = .shutdown }, AcqRel) 517 | if s.idle > 0: 518 | thread_sema.post(s.idle) 519 | 520 | shutdown_sema.wait() 521 | for all_threads following stack til null |t|: 522 | t.join_sema.post(1) 523 | 524 | run_on_each_thread(): 525 | atomic_stack_push(&all_threads, &thread_local) 526 | defer kill_thread(&thread_local) 527 | 528 | is_waking = false 529 | while True: 530 | is_waking = try wait(is_waking) 531 | 532 | while dequeue() |(task, pushed)|: 533 | if is_waking or pushed: 534 | notify(is_waking) 535 | is_waking = false 536 | task.run() 537 | 538 | dequeue(): ?(task: *Task, pushed: bool) 539 | if thread_local.buffer.pop() |task| 540 | return (task, false) 541 | if steal_queue(&thread_local.queue) |task| 542 | return (task, true) 543 | 544 | for all_threads following stack til null |t|: 545 | if steal_queue(&t.queue) |task| 546 | return (task, true) 547 | if t.buffer.steal(into: &thread_local.buffer) |task| 548 | return (task, true) 549 | return null 550 | ``` 551 | 552 | ## Closings 553 | 554 | I probably missed something in my explanations. If so, I urge you to read the [source](https://github.com/kprotty/zap/blob/blog/src/thread_pool.zig). It's well commented, I assure you :).
I've provided a [Zig `async` wrapper](https://github.com/kprotty/zap/blob/blog/benchmarks/zig/async.zig) to the thread pool as well as [benchmarks](https://github.com/kprotty/zap/tree/blog/benchmarks) for competing async runtimes in the repository. Feel free to run those locally, add your own, or modify the Zig one. I learned a lot by doing this, so here are some links to other articles about various [schedulers along with my own tips](https://twitter.com/kingprotty/status/1416774977836093445). *And as always, I hope you learned something.* 555 | --------------------------------------------------------------------------------