├── .gitignore
├── thread_pool.cpp
├── test_yield.cpp
├── priority_thread_pool.hpp
├── README.md
├── test.cpp
├── priority_thread_pool.cpp
└── thread_pool.hpp

/.gitignore:
--------------------------------------------------------------------------------
*.swp
a.out
--------------------------------------------------------------------------------
/thread_pool.cpp:
--------------------------------------------------------------------------------
#include "thread_pool.hpp"

using namespace std;

thread_pool::thread_pool(unsigned int n) : base_thread_pool(n) {
    init_mutex.unlock();
}

thread_pool::~thread_pool() {
    wait();
}

std::optional<std::future<void>> thread_pool::get_task() {
    optional<future<void>> ret;
    lock_guard<mutex> lk(task_mutex);
    if(!tasks.empty()) {
        ret = std::move(tasks.front());
        tasks.pop();
    }
    return ret;
}

void thread_pool::handle_task(std::future<void> f) {
    f.get();
}
--------------------------------------------------------------------------------
/test_yield.cpp:
--------------------------------------------------------------------------------
#include "priority_thread_pool.hpp"

#include <signal.h>
#include <unistd.h>

#include <functional>
#include <iostream>
#include <vector>

using namespace std;

static void core_dump(int sigid)
{
    kill(getpid(), SIGSEGV);
}

const int num_threads = 3;

void test_void_void(){
    priority_thread_pool p(num_threads);
    vector<future<void>> f;

    std::function<void()> low_prio = []() {
        long long i = 0;
        int j = 0;
        for(;;) {
            ++i;
            if(i % 10000000000 == 0) {
                cout << "L" << endl;
                i = 0;
                j++;
                if (j == 100) break;
                priority_thread_pool::yield();
            }
        }
        cout << "L done" << endl;
    };

    std::function<void()> high_prio = []() {
        long long i = 0;
        int j = 0;
        for(;;) {
            ++i;
            if(i % 10000000000 == 0) {
                cout << "H" << endl;
                i = 0;
                j++;
                if (j == 100) break;
                priority_thread_pool::yield();
            }
        }
        cout << "H done" << endl;
    };

    // Saturate the threads with low-priority tasks.
    for(int i = 0;i < num_threads * 2;i++)
        f.emplace_back(p.async(0, low_prio));

    usleep(2000000);

    // Push the high-priority tasks.
    for(int i = 0;i < num_threads;i++)
        f.emplace_back(p.async(10, high_prio));
    cout << "HIGH PRIO PUSHED" << endl;
}

int main(){
    //signal(SIGINT, core_dump);
    test_void_void();
    return 0;
}
--------------------------------------------------------------------------------
/priority_thread_pool.hpp:
--------------------------------------------------------------------------------
#ifndef PRIORITY_THREAD_POOL_HPP
#define PRIORITY_THREAD_POOL_HPP

#include "thread_pool.hpp"

#include <ucontext.h>

#include <memory>
#include <queue>
#include <vector>

class priority_task {
public:
    priority_task(std::function<void()> work, int priority = 0) :
        work(work), priority(priority) {}
    priority_task(const priority_task&) = delete;
    ~priority_task();

    bool operator<(const priority_task& t) const;

    bool run();
    void pause();

private:
    std::function<void()> work;

    void* volatile work_stack = nullptr;

    // Used to pause; returns to the scheduler.
    ucontext_t pause_context;

    // Returns to the running task.
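    // (pause() saves the running task's state here; a later call to run()
    // jumps back into this context to resume the task where it yielded.)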
    ucontext_t work_context;

    int priority;
    volatile bool started = false;
    volatile bool paused = false;
    volatile bool done = false;

    static void _run(void);
};

class priority_thread_pool : public base_thread_pool<std::shared_ptr<priority_task>>{
public:
    priority_thread_pool(unsigned int n);
    virtual ~priority_thread_pool();

    /**
     * Pushes a new task to the queue.
     *
     * @param priority the task's priority; higher values run first
     * @param f the function to call when executing the task
     * @param args the arguments to pass to the function
     *
     * @return the future used to wait on the task and get the result
     */
    template <typename Fn, typename... Args>
    auto async(int priority, Fn f, Args... args){
        auto p = package(f, args...);
        return std::move(add_task(priority, std::move(p)));
    }

    /// Called by tasks of this thread pool to yield.
    static void yield();

protected:
    virtual std::optional<std::shared_ptr<priority_task>> get_task() override;
    virtual void handle_task(std::shared_ptr<priority_task>) override;

    auto add_task(int priority, auto p) {
        auto t = std::shared_ptr<priority_task>(new priority_task(p.first, priority));
        task_mutex.lock();
        tasks.emplace(t);
        task_mutex.unlock();
        return std::move(p.second);
    }

private:
    // Order tasks by the priority of the pointed-to task, not by pointer value.
    struct task_less {
        bool operator()(const std::shared_ptr<priority_task>& a,
                        const std::shared_ptr<priority_task>& b) const {
            return *a < *b;
        }
    };

    std::priority_queue<std::shared_ptr<priority_task>,
                        std::vector<std::shared_ptr<priority_task>>,
                        task_less> tasks;
};

#endif
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
thread_pool
===========

Simple thread pool using only standard library components. Also includes a priority thread pool class.

Requires C++20 and concepts; currently GCC 10.0 or later is sufficient. Compile with `-std=c++20 -fconcepts`.

The priority thread pool is only supported on POSIX-like systems, but the normal pool is still easy to use elsewhere: just don't compile priority_thread_pool.cpp or include its header.

For plain C++11, use `8bdfb9b`. `5ea01d0` was the last commit to support C++14 and earlier. For C++17, use `e3be25` and compile with `-std=c++17 -fconcepts`.

The priority pool has the same API as described below, except that it takes an int priority as its first parameter, e.g. `pool.async(5, func, arg1, arg2)` submits a task with priority 5.
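
For example, a long-running task can call `priority_thread_pool::yield()` periodically so that higher-priority tasks waiting in the queue get a chance to run. A minimal sketch in the spirit of test_yield.cpp (the thread count, priorities, and chunking below are arbitrary):
```c++
#include "priority_thread_pool.hpp"

#include <functional>

int main(){
    priority_thread_pool pool(4); // 4 worker threads.

    std::function<void()> work = [](){
        for(int chunk = 0;chunk < 10;chunk++){
            // ... do one chunk of work here ...
            priority_thread_pool::yield(); // Let higher-priority tasks run.
        }
    };

    auto low = pool.async(0, work);   // Priority 0 (low).
    auto high = pool.async(10, work); // Priority 10 (high).

    low.get();
    high.get();
    return 0;
}
```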

An example that computes the primality of the integers from 2 to 9,999 with the plain pool, using 8 threads:
```c++
#include "thread_pool.hpp"

#include <iostream>
#include <list>
#include <utility>

using namespace std;

// Return the integer argument and a boolean representing its primality.
// We need to return the integer because the loop that retrieves results
// doesn't know which integer corresponds to which future.
pair<int, bool> is_prime(int n){
    for(int i = 2;i < n;i++)
        if(n % i == 0)
            return make_pair(n, false);
    return make_pair(n, true);
}

int main(){
    thread_pool pool(8); // Construct a thread pool with 8 threads.
    list<future<pair<int, bool>>> results;
    for(int i = 2;i < 10000;i++){
        // Add a task to the queue.
        results.push_back(pool.async(is_prime, i));
    }

    for(auto i = results.begin();i != results.end();i++){
        pair<int, bool> result = i->get(); // Get the pair from the future<...>.
        cout << result.first << ": " << (result.second ? "is prime" : "is composite") << endl;
    }
    return 0;
}
```

`thread_pool::async` is a templated method that accepts any callable (e.g. a `std::function`) and the arguments to pass to it. It returns a `std::future<Ret>`, where `Ret` is the return type of the callable.

To submit a task: `future<Ret> fut(pool.async(func, args...));`

To wait for `fut` to complete and get the result: `Ret result = fut.get();`
--------------------------------------------------------------------------------
/test.cpp:
--------------------------------------------------------------------------------
#include "thread_pool.hpp"

#include <signal.h>
#include <unistd.h>

#include <future>
#include <iostream>
#include <vector>

using namespace std;

static void core_dump(int sigid)
{
    kill(getpid(), SIGSEGV);
}

const int num_threads = 8;

void func_void_void(){
    cout << "void" << endl;
}

void test_void_void(){
    thread_pool p(num_threads);
    vector<future<void>> f;
    for(int i = 10000000;i < 20000000;i++)
        f.emplace_back(p.async(func_void_void));
    for(auto i = f.begin();i != f.end();i++)
        i->get();
}

void func_void_int(int n){
    bool prime = true;
    if(n < 2)
        prime = false;
    for(int i = 2;i < n;i++)
        if(n % i == 0)
            prime = false;
    cout << n << ": " << (prime ? "Prime" : "Composite") << endl;
}

void test_void_int(){
    thread_pool p(num_threads);
    vector<future<void>> f;
    for(int i = 10000000;i < 20000000;i++)
        f.emplace_back(p.async(func_void_int, i));
    for(auto i = f.begin();i != f.end();i++)
        i->get();
}

int func_int_void(){
    return 42;
}

void test_int_void(){
    thread_pool p(num_threads);
    vector<future<int>> f;
    for(int i = 100000;i < 1000000;i++)
        f.emplace_back(p.async(func_int_void));
    for(auto i = f.begin();i != f.end();i++)
        cout << i->get() << endl;
}

bool func_bool_int(int n){
    if(n < 2)
        return false;
    for(int i = 2;i < n;i++)
        if(n % i == 0)
            return false;
    return true;
}

void test_bool_int(){
    thread_pool p(num_threads);
    vector<future<bool>> f;
    int n = 100000;
    int max = 10 * n;
    for(int i = n;i < max;i++)
        f.emplace_back(p.async(func_bool_int, i));
    for(auto i = f.begin();i != f.end();i++, n++)
        cout << n << ": " << i->get() << endl;
}

int func_int_int_int(int a, int b) {
    return a * b;
}

void test_void_int_int() {
    thread_pool p(num_threads);
    vector<future<int>> f;
    for(int i = 1;i < 100;i++) {
        for(int j = 1;j < 100;j++) {
            f.emplace_back(p.async(func_int_int_int, i, j));
        }
    }
    for(auto i = f.begin();i != f.end();i++)
        i->get();
}

int main(){
    //signal(SIGINT, core_dump);
    test_bool_int();
    test_void_void();
    return 0;
}
--------------------------------------------------------------------------------
/priority_thread_pool.cpp:
--------------------------------------------------------------------------------
#include "priority_thread_pool.hpp"

#include <cassert>
#include <cstdlib>

#include <map>
#include <mutex>

using namespace std;

// Size of the stacks used by the task executor contexts.
constexpr std::size_t STACK_SIZE = 1024 * 8;

// Maps each worker thread to the task it is currently running.
static map<thread::id, shared_ptr<priority_task>> cur_tasks;
static mutex cur_tasks_mutex;

priority_task::~priority_task() {
    assert(done);
    free(work_stack);

    // Fail fast.
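    // (Null the freed stack pointers so that any accidental later use of
    // this task faults immediately instead of touching freed memory.)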
    work_stack = nullptr;
    work_context.uc_stack.ss_sp = nullptr;
}

bool priority_task::operator<(const priority_task& t) const {
    return priority < t.priority;
}

/**
 * Actually runs the task, in a forked context.
 */
void priority_task::_run(void) {
    shared_ptr<priority_task> t;
    {
        lock_guard<mutex> lk(cur_tasks_mutex);
        auto it = cur_tasks.find(this_thread::get_id());
        assert(it != cur_tasks.end());
        t = it->second;
    }
    t->work();
    t->done = true;
}

/**
 * Starts or resumes the forked context and returns whether it is finished.
 */
bool priority_task::run() {
    paused = false;

    // This is where we'll resume when yield is called.
    getcontext(&pause_context);
    if(!started) {
        // Create the context which will execute the work function.
        getcontext(&work_context);
        work_stack = malloc(STACK_SIZE);
        work_context.uc_stack.ss_size = STACK_SIZE;
        work_context.uc_stack.ss_sp = work_stack;
        work_context.uc_stack.ss_flags = 0;
        work_context.uc_link = &pause_context;
        makecontext(&work_context, &_run, 0);

        started = true;
        setcontext(&work_context);
        __builtin_unreachable();
    }
    else {
        // done will be true after work_context returns.
        if(done) {
            return true;
        }
        // paused will be true if we got here via setcontext(&pause_context)
        // in priority_task::pause().
        else if(!paused) {
            setcontext(&work_context);
            __builtin_unreachable();
        }
        // Effectively, this is the case wherein we're not done (the work
        // function hasn't returned) and we're paused. So we return false,
        // signifying to priority_thread_pool::handle_task that the task
        // needs to be added back so it can be resumed later.
        else {
            return false;
        }
    }
}

/**
 * Pauses the work context and resumes the context saved in ::run().
 */
void priority_task::pause() {
    paused = true;

    // We will resume here when priority_task::run is called again.
    getcontext(&work_context);

    // paused will be false if we're being resumed by a later call to
    // priority_task::run (i.e. by its setcontext(&work_context)).
    if(paused) {
        // Jump back into priority_task::run() to return to the scheduler.
        setcontext(&pause_context);
    }
    // else fall through and return to the running work context.
}

priority_thread_pool::priority_thread_pool(unsigned int n) : base_thread_pool(n) {
    init_mutex.unlock();
}

priority_thread_pool::~priority_thread_pool() {
    wait();
}

/**
 * Yields the task the current thread is running.
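 *
 * Must be called from within a task executing on this pool. The calling
 * task is paused; handle_task then pushes it back onto the priority queue,
 * and a worker thread resumes it later (possibly after higher-priority
 * tasks have run).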
 */
void priority_thread_pool::yield() {
    cur_tasks_mutex.lock();
    auto it = cur_tasks.find(std::this_thread::get_id());
    assert(it != cur_tasks.end());
    auto task = it->second;
    cur_tasks_mutex.unlock();
    task->pause();
}

optional<shared_ptr<priority_task>> priority_thread_pool::get_task() {
    optional<shared_ptr<priority_task>> ret;
    lock_guard<mutex> lk(task_mutex);
    if(!tasks.empty()) {
        ret = tasks.top();
        tasks.pop();
    }
    return ret;
}

void priority_thread_pool::handle_task(shared_ptr<priority_task> t) {
    auto id = this_thread::get_id();
    {
        lock_guard<mutex> lk(cur_tasks_mutex);
        // Keep the emplace outside the assert so it still runs with NDEBUG.
        bool inserted = cur_tasks.emplace(id, t).second;
        assert(inserted);
    }
    bool finished = t->run();
    {
        lock_guard<mutex> lk(cur_tasks_mutex);
        cur_tasks.erase(id);
    }
    // finished is true when the task has finished executing. If it hasn't,
    // add it back to the heap so it can be resumed later.
    if(!finished) {
        lock_guard<mutex> lk(task_mutex);
        tasks.emplace(t);
    }
}
--------------------------------------------------------------------------------
/thread_pool.hpp:
--------------------------------------------------------------------------------
#ifndef THREAD_POOL_HPP
#define THREAD_POOL_HPP

#include <concepts>
#include <functional>
#include <future>
#include <list>
#include <mutex>
#include <optional>
#include <queue>
#include <thread>
#include <type_traits>
#include <utility>

template <typename Task>
class base_thread_pool{
protected:
    /**
     * Wraps a task in an executor function and wraps the promise which will
     * receive the return value of the function.
     *
     * @param f the function to call when executing the task
     * @param args the arguments to pass to the function
     *
     * @return a pair of the executor function and the future used to wait on
     *         the task and get the result
     */
    template <typename Fn, typename... Args,
              typename Ret = std::invoke_result_t<Fn, Args...>>
    requires std::invocable<Fn, Args...>
    std::pair<std::function<void()>, std::future<Ret>>
    package(Fn f, Args... args){
        std::promise<Ret> *p = new std::promise<Ret>;

        // Create a function to package as a task.
        auto task_wrapper = std::bind([p, f{std::move(f)}](Args... args){
            if constexpr (std::is_same<Ret, void>::value) {
                f(std::move(args)...);
                p->set_value();
            } else {
                p->set_value(std::move(f(std::move(args)...)));
            }
        }, std::move(args)...);

        // Create a function to package as a future for the user to wait on.
        auto ret_wrapper = [p]() -> Ret{
            if constexpr (std::is_same<Ret, void>::value) {
                p->get_future().get();
                delete p;
            } else {
                auto temp = std::move(p->get_future().get());
                delete p;
                return std::move(temp);
            }
        };
        return make_pair(task_wrapper, std::async(std::launch::deferred, ret_wrapper));
    }

    /**
     * Constructs a thread pool with `num_threads` threads.
     */
    base_thread_pool(unsigned int num_threads) : num_threads(num_threads){
        init_mutex.lock();
        init_threads();
    }

    /**
     * Destructs the thread pool, waiting on tasks to finish.
     */
    virtual ~base_thread_pool(){
        wait();
    }

    /**
     * Manages thread execution. This is the function the worker threads
     * actually run. It pulls a task out of the queue and executes it.
     */
    void thread_func(){
        // Don't call get_task until the derived class is fully constructed;
        // its constructor unlocks init_mutex.
        init_mutex.lock();
        init_mutex.unlock();
        for(;;){
            auto task = get_task();

            // If there's nothing to do and we're not ready to join, just
            // yield.
            if(!task && !join){
                std::this_thread::yield();
                continue;
            }
            // If there are tasks waiting, run one.
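            // (For the priority pool, handle_task may push the task back
            // onto the queue if it yields before finishing.)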
            else if(task){
                handle_task(std::move(*task));
            }
            // If there are no tasks and we're ready to join, then exit the
            // function (effectively joining).
            else if(join){
                return;
            }
        }
    }

    /**
     * Creates the threads for the thread pool.
     */
    void init_threads(){
        task_mutex.lock();
        for(unsigned int i = 0;i < num_threads;i++){
            auto f = std::bind(&base_thread_pool::thread_func, this);
            threads.push_back(std::move(std::thread(f)));
        }
        task_mutex.unlock();
    }

    /**
     * Waits for the threads to exit. Leaves the thread pool in an unusable
     * state. Used by destructors.
     */
    void wait() {
        task_mutex.lock();
        join = true;
        task_mutex.unlock();
        while(threads.size() > 0) {
            auto &t = threads.back();
            t.join();
            threads.pop_back();
        }
    }

    /**
     * Returns the next task if there is one, or an empty optional if there
     * isn't.
     */
    virtual std::optional<Task> get_task() = 0;

    /**
     * Executes a task.
     */
    virtual void handle_task(Task) = 0;

    /// Locked in the base constructor; must be unlocked by derived-class
    /// constructors once construction is complete.
    std::mutex init_mutex;
    std::mutex task_mutex;
private:
    bool join = false;
    unsigned int num_threads;
    std::list<std::thread> threads;
};

class thread_pool : public base_thread_pool<std::future<void>>{
public:
    thread_pool(unsigned int);
    virtual ~thread_pool();

    /**
     * Pushes a new task to the queue.
     *
     * @param f the function to call when executing the task
     * @param args the arguments to pass to the function
     *
     * @return the future used to wait on the task and get the result
     */
    template <typename Fn, typename... Args>
    auto async(Fn f, Args... args){
        auto p = package(f, args...);
        return std::move(add_task(std::move(p)));
    }

protected:
    virtual std::optional<std::future<void>> get_task() override;
    virtual void handle_task(std::future<void>) override;

    auto add_task(auto p) {
        auto t = std::async(std::launch::deferred, p.first);
        task_mutex.lock();
        tasks.emplace(std::move(t));
        task_mutex.unlock();
        return std::move(p.second);
    }
private:
    std::queue<std::future<void>> tasks;
};

#endif
--------------------------------------------------------------------------------