├── .clang-format
├── .gitignore
├── CMakeLists.txt
├── LICENSE
├── README.md
├── main.cpp
└── src
    ├── lockfree_ring_buffer.h
    ├── tasks.cpp
    └── tasks.h

/.clang-format:
--------------------------------------------------------------------------------
---
Language: Cpp
# BasedOnStyle: LLVM
AccessModifierOffset: -2
AlignAfterOpenBracket: Align
AlignConsecutiveAssignments: true
AlignConsecutiveDeclarations: true
AlignEscapedNewlinesLeft: false
AlignOperands: true
AlignTrailingComments: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: false
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: false
AlwaysBreakTemplateDeclarations: false
BinPackArguments: true
BinPackParameters: true
BraceWrapping:
  AfterClass: false
  AfterControlStatement: false
  AfterEnum: false
  AfterFunction: false
  AfterNamespace: false
  AfterObjCDeclaration: false
  AfterStruct: false
  AfterUnion: false
  BeforeCatch: false
  BeforeElse: false
  IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: true
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: true
ColumnLimit: 0
CommentPragmas: '^ IWYU pragma:'
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ExperimentalAutoDetectBinPacking: false
ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
FixNamespaceComments: true
IncludeCategories:
  - Regex: '^"(llvm|llvm-c|clang|clang-c)/'
    Priority: 2
  - Regex: '^(<|"(gtest|isl|json)/)'
    Priority: 3
  - Regex: '.*'
    Priority: 1
IncludeIsMainRegex: '$'
IndentCaseLabels: false
IndentPPDirectives: AfterHash
IndentWidth: 4
IndentWrappedFunctionNames: false
JavaScriptQuotes: Leave
JavaScriptWrapImports: true
KeepEmptyLinesAtTheStartOfBlocks: true
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 4
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: true
PenaltyBreakBeforeFirstCallParameter: 19
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 60
PointerAlignment: Left
ReflowComments: true
SortIncludes: false
SpaceAfterCStyleCast: false
SpaceAfterTemplateKeyword: true
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: true
SpacesInParentheses: true
SpacesInSquareBrackets: true
Standard: Cpp11
TabWidth: 4
UseTab: ForIndentation
--------------------------------------------------------------------------------
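This configuration is what gives the sources below their look: tab indentation, attached braces, left-aligned pointers, and spaces inside parentheses and square brackets. As a quick illustration (this snippet is editorial and not part of the repository), code formatted under it renders like so:

// Illustrative only - not part of the repository. Note the spaces inside
// parentheses and square brackets (SpacesInParentheses,
// SpacesInSquareBrackets) and the attached braces (BreakBeforeBraces: Attach).
int sum_clamped( int* values, size_t count, int limit ) {
    int total = 0;
    for ( size_t i = 0; i != count; i++ ) {
        total += values[ i ];
    }
    return total < limit ? total : limit;
}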
/.gitignore:
--------------------------------------------------------------------------------
build
CMakeLists.txt.user
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.5)

project(tasks LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

set(SOURCES main.cpp)
set(SOURCES ${SOURCES} src/tasks.cpp)
set(SOURCES ${SOURCES} src/tasks.h)
set(SOURCES ${SOURCES} src/lockfree_ring_buffer.h)

add_executable(tasks ${SOURCES})

include(GNUInstallDirs)

install(TARGETS tasks
        RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Tim Gfrerer

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pal_tasks

Code for the C++20 coroutines-based job system discussed in: https://poniesandlight.co.uk/reflect/coroutines_job_system/
--------------------------------------------------------------------------------
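For orientation before diving into the sources: the public surface is `Scheduler`, `TaskList`, `Task` and the `suspend_task` awaiter, all declared in `src/tasks.h`. A minimal sketch of how they combine, condensed from the `main.cpp` below (the coroutine body here is hypothetical):

#include "src/tasks.h"
#include <iostream>

int main() {
    // -1: one worker thread per cpu, minus one (see Scheduler::create)
    Scheduler* scheduler = Scheduler::create( -1 );

    TaskList tasks{};
    tasks.add_task( []() -> Task {
        std::cout << "hello from a task" << std::endl;
        co_await suspend_task(); // park this coroutine, let other tasks run
        std::cout << "task resumed" << std::endl;
        co_return;
    }() );

    scheduler->wait_for_task_list( tasks ); // blocks until all tasks complete
    delete scheduler;
    return 0;
}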
/main.cpp:
--------------------------------------------------------------------------------
#include "src/tasks.h"
#include <iostream>

#include <chrono>
#include <cstdlib>
#include <thread>

int main() {

    // Create a scheduler - the parameter controls how many worker threads
    // it spawns:
    //  0 ... no worker threads, just the main thread
    //  n ... n worker threads
    // -n ... as many worker threads as there are cpus, minus n
    //        (e.g. -1 on an 8-cpu machine yields 7 workers)
    Scheduler* scheduler = Scheduler::create( -1 );

    if ( false ) {

        TaskList tasks{};
        auto     task_generator = []( int i ) -> Task {
            std::cout << "doing some work: " << i++ << std::endl;

            // put this coroutine back on the scheduler
            co_await suspend_task();

            // we have resumed this coroutine from the scheduler
            std::cout << "resuming work: " << i++ << std::endl;

            // work complete - the explicit co_return marks the end of
            // the coroutine body.
            co_return;
        };

        // add many more tasks
        for ( int i = 0; i != 5; i++ ) {
            tasks.add_task( task_generator( i ) );
        }

        // Execute all tasks we find on the task list
        scheduler->wait_for_task_list( tasks );
    }

    // ----------------------------------------------------------------------

    srand( 0xdeadbeef );

    std::cout << "MAIN thread is: " << std::hex << std::this_thread::get_id() << std::endl;

    if ( true ) {
        /*
         * This test checks whether we can issue tasks from within our
         * current task system.
         */

        TaskList another_task_list{};
        auto     coro_generator = []( int i, Scheduler* sched ) -> Task {
            std::cout << "first level coroutine: " << std::dec << i++ << " on thread: " << std::hex << std::this_thread::get_id() << std::endl
                      << std::flush;

            std::this_thread::sleep_for( std::chrono::microseconds( rand() % 55000 ) );

            auto inner_coro_generator = []( int i, int j ) -> Task {
                std::cout << "\t executing inner coroutine: " << std::dec << i << ":" << j++ << " on thread: " << std::hex << std::this_thread::get_id() << std::endl
                          << std::flush;

                std::this_thread::sleep_for( std::chrono::microseconds( rand() % 40000 ) );
                // this yields control back to the await_suspend method, and to our scheduler
                co_await suspend_task();

                std::this_thread::sleep_for( std::chrono::microseconds( rand() % 33000 ) );
                std::cout << "\t executing inner coroutine: " << std::dec << i << ":" << j++ << " on thread: " << std::hex << std::this_thread::get_id() << std::endl;
                co_return;
            };

            int num_tasks = rand() % 30; // int, so that the loop below compares like with like

            // Create a task list for tasks which are spun off from within this task
            TaskList inner_task_list{};

            for ( int j = 0; j != num_tasks; j++ ) {
                inner_task_list.add_task( inner_coro_generator( i, j * 10 ) );
            }

            std::this_thread::sleep_for( std::chrono::nanoseconds( rand() % 40000000 ) );

            // Suspend this task
            co_await suspend_task();

            // ----------| invariant: we are back after resuming.

            std::cout << "executing first level coroutine: " << std::dec << i << " on thread: " << std::hex << std::this_thread::get_id() << std::endl;

            // Execute, and wait for, the tasks that we spin out from this task
            sched->wait_for_task_list( inner_task_list );

            // Suspend this task again
            co_await suspend_task();

            // ----------| invariant: we are back after resuming.

            std::cout << "finished first level coroutine: " << std::dec << i << " on thread: " << std::hex << std::this_thread::get_id() << std::endl;
            co_return;
        };

        for ( int i = 0; i != 20; i++ ) {
            another_task_list.add_task( coro_generator( i * 10, scheduler ) );
        }

        std::cout << "main program starts wait for task list." << std::endl
                  << std::flush;

        scheduler->wait_for_task_list( another_task_list );
    }
    std::cout << "Back with main program."
              << std::endl
              << std::flush;

    delete scheduler;

    return 0;
}
--------------------------------------------------------------------------------
/src/lockfree_ring_buffer.h:
--------------------------------------------------------------------------------
#ifndef _LOCK_FREE_RING_BUFFER_H_
#define _LOCK_FREE_RING_BUFFER_H_

/*
 * Based on the work of Brian Watling for libfiber - see
 * <https://github.com/brianwatling/libfiber>
 *
 * libfiber is:
 *
 * Copyright (c) 2012-2015, Brian Watling and other contributors
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */

#include <atomic>
#include <cassert>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <thread>
#include <vector>

static_assert( std::atomic_size_t::is_always_lock_free, "atomic_size_t must be always lock free" );

inline uint32_t next_power_of_2( uint32_t v ) {
    v--;
    v |= v >> 1;
    v |= v >> 2;
    v |= v >> 4;
    v |= v >> 8;
    v |= v >> 16;
    return ++v;
}
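// A quick worked example of the bit-smearing trick above (editorial note,
// values checked by hand): for v = 33, the decrement gives 32 = 0b100000;
// the shifted ORs smear the top set bit into every lower position, giving
// 0b111111 = 63; the final increment yields 64, the next power of two.
// Values that are already powers of two map to themselves:
// next_power_of_2( 32 ) decrements to 31, smears to 31, and increments
// back to 32.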
class lockfree_ring_buffer_t {
    // high and low are generally used together; there is no point in putting
    // them on separate cache lines
    std::atomic_size_t m_high;
    char               _cache_padding1[ 64 - sizeof( std::atomic_size_t ) ];
    std::atomic_size_t m_low;
    char               _cache_padding2[ 64 - sizeof( std::atomic_size_t ) ];
    uint32_t           m_capacity;
    uint32_t           m_power_of_2_mod;
    // the buffer elements themselves live on the heap, managed by the vector
    std::vector<void*> buffer;

    lockfree_ring_buffer_t( const lockfree_ring_buffer_t& )            = delete;
    lockfree_ring_buffer_t( lockfree_ring_buffer_t&& )                 = delete;
    lockfree_ring_buffer_t& operator=( const lockfree_ring_buffer_t& ) = delete;
    lockfree_ring_buffer_t& operator=( lockfree_ring_buffer_t&& )      = delete;

  public:
    lockfree_ring_buffer_t( uint32_t power_of_2_size )
        : m_capacity( next_power_of_2( power_of_2_size ) )
        , m_power_of_2_mod( m_capacity - 1 )
        , buffer( m_capacity, nullptr ) {
        // requested size must be non-zero, and small enough that
        // next_power_of_2 does not overflow
        assert( power_of_2_size && power_of_2_size <= ( 1u << 31 ) );
    }

    size_t size() {
        // read high first; this makes the buffer appear smaller than or
        // equal to its actual size
        const uint64_t high = this->m_high.load( std::memory_order_acquire );
        // load_load_barrier();
        const int64_t size = high - this->m_low.load( std::memory_order_acquire );
        return size >= 0 ? size : 0;
    }

    int try_push( void* in ) {
        assert( in ); // can't store NULLs; we rely on a NULL to indicate a spot in the buffer has not been written yet
        // read low first; this means the buffer will appear larger than or equal to its actual size
        const uint64_t low = this->m_low.load( std::memory_order_acquire );
        // load_load_barrier();
        uint64_t       high  = this->m_high.load( std::memory_order_acquire );
        const uint64_t index = high & this->m_power_of_2_mod;
        if ( !this->buffer[ index ] &&
             high - low < this->m_capacity &&
             this->m_high.compare_exchange_weak( high, high + 1, std::memory_order_release ) ) {
            this->buffer[ index ] = in;
            return 1;
        }
        return 0;
    }

    void push( void* in ) {
        while ( !this->try_push( in ) ) {
            if ( this->m_high - this->m_low >= this->m_capacity ) {
                // the buffer is full - we must block...
                std::this_thread::sleep_for( std::chrono::nanoseconds( 100 ) );
            }
        }
    }

    void* try_pop() {
        // read high first; this means the buffer will appear smaller than or equal to its actual size
        const uint64_t high = this->m_high.load( std::memory_order_acquire );
        // load_load_barrier();
        uint64_t       low   = this->m_low.load( std::memory_order_acquire );
        const uint64_t index = low & this->m_power_of_2_mod;
        void* const    ret   = this->buffer[ index ];
        if ( ret &&
             high > low &&
             this->m_low.compare_exchange_weak( low, low + 1, std::memory_order_release ) ) {
            this->buffer[ index ] = nullptr;
            return ret;
        }
        return nullptr;
    }

    void* pop() {
        void* ret;
        while ( !( ret = try_pop() ) ) {
            if ( this->m_high <= this->m_low ) {
                // cpu_relax(); // the buffer is empty
                std::this_thread::sleep_for( std::chrono::nanoseconds( 100 ) );
            }
        }
        return ret;
    }
    // ----------------------------------------------------------------------
    // The following methods are specialisations which are only useful in the
    // context of our job system - we know that while we set up a task list
    // no other thread competes for access, and we are therefore safe to do
    // direct accesses, and re-allocations.
    // ----------------------------------------------------------------------

    // ONLY SAFE IN A SINGLE-THREADED ENVIRONMENT
    // Dynamically grows the buffer if necessary. You must not do this once
    // more than one thread is modifying the buffer, as it is then unsafe.
    int unsafe_initial_dynamic_push( void* in ) {
        assert( m_low == 0 && "you must not unsafe-push once an item has been popped from this array" );
        if ( this->m_high - this->m_low >= this->m_capacity ) {
            this->m_capacity *= 2; // double the size
            this->buffer.resize( m_capacity, nullptr );
            this->m_power_of_2_mod = this->m_capacity - 1;
        }
        return try_push( in );
    }

    // ONLY SAFE IN A SINGLE-THREADED ENVIRONMENT
    // Iterates over all elements in the data array and applies the
    // callback function
    void unsafe_for_each( void ( *fun )( void* item, void* user_data ), void* user_data ) {
        size_t low  = this->m_low;
        size_t high = this->m_high;
        for ( size_t i = low; i != high; i++ ) {
            fun( buffer[ i ], user_data );
        }
    }
};

#endif
--------------------------------------------------------------------------------
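As a quick illustration of the buffer's contract (an editorial sketch, not part of the repository): producers hand over non-null pointers, consumers receive them in FIFO order, and the try_ variants never block:

#include "src/lockfree_ring_buffer.h"
#include <cassert>

int demo_ring_buffer() {
    lockfree_ring_buffer_t ring( 8 ); // capacity is rounded up to a power of two

    int a = 1, b = 2;
    ring.push( &a );              // blocks (sleeps) only while the ring is full
    if ( !ring.try_push( &b ) ) { // non-blocking variant: may fail under contention
        // handle the full/contended case here
    }

    void* first = ring.pop();     // FIFO: returns &a
    assert( first == &a );
    return *static_cast<int*>( first );
}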
/src/tasks.cpp:
--------------------------------------------------------------------------------
#include "tasks.h"
#include <atomic>
#include <cassert>
#include <coroutine>
#include <iostream>
#include <thread>
#include <vector>
#include "lockfree_ring_buffer.h"

using coroutine_handle_t = std::coroutine_handle<TaskPromise>;

// A channel is a thread-safe primitive used to communicate with worker
// threads - each worker thread has exactly one channel. Once a channel
// contains a payload it is blocked (you cannot push any more handles onto
// this channel). The channel becomes free and ready to receive another
// handle as soon as the worker thread has finished processing the current
// handle.
struct Channel {
    void*            handle = nullptr; // storage for the channel payload: one single handle. nullptr means the channel is free.
    std::atomic_flag flag;             // signals that the current channel is busy.

    bool try_push( coroutine_handle_t& h ) {

        if ( flag.test_and_set() ) {
            // if the current channel was already flagged,
            // we cannot add any more work.
            return false;
        }

        // --------| invariant: the current channel is available now

        handle = h.address();

        // If there is a thread blocked on this operation, we
        // unblock it here.
        flag.notify_one();

        return true;
    }

    ~Channel() {

        // Once the channel accepts a coroutine handle, it becomes the
        // owner of the handle. If there are any leftover valid handles
        // that we own when this object dips into oblivion, we must clean
        // them up first.
        //
        if ( this->handle ) {
            // std::cout << "WARNING: leftover task in channel." << std::endl;
            // std::cout << "destroying task: " << this->handle << std::endl;
            Task::from_address( this->handle ).destroy();
        }

        this->handle = nullptr;
    }
};

class task_list_o {
    lockfree_ring_buffer_t tasks;
    std::atomic_size_t     num_tasks; // number of tasks in flight; only gets decremented once a task has fully completed

  public:
    std::atomic_flag block_flag; // flag used to signal that dependent tasks have completed

    task_list_o( uint32_t capacity_hint = 32 ) // start with a capacity of 32
        : tasks( capacity_hint )
        , num_tasks( 0 ) {
    }

    ~task_list_o() {
        // If there are any tasks left on the task list, we must destroy them, as we own them.
        void* task;
        while ( ( task = this->tasks.try_pop() ) ) {
            Task::from_address( task ).destroy();
        }
    }

    // Push a suspended task back onto the end of the task list
    inline void push_task( coroutine_handle_t const& c ) {
        tasks.push( c.address() );
    }

    // Get the next task if possible; if there is no next task,
    // return an empty coroutine handle.
    // An empty coroutine handle will compare equal to nullptr.
    inline coroutine_handle_t pop_task() {
        return Task::from_address( tasks.try_pop() );
    }

    // Return the number of tasks which are in flight or waiting.
    //
    // Note this is not the same as tasks.size(), as tasks which are being
    // processed (in flight) do not show up on the task list.
    //
    // num_tasks gets decremented only once a task has fully completed.
    inline size_t get_tasks_count() {
        return num_tasks;
    }

    // Add a new task to the task list - only allowed in the setup phase,
    // while only one thread has access to the task list.
    void add_task( coroutine_handle_t& c ) {
        c.promise().p_task_list = this;
        tasks.unsafe_initial_dynamic_push( c.address() );
        num_tasks++;
    }

    void tag_all_tasks_with_scheduler( scheduler_impl* p_scheduler ) {
        tasks.unsafe_for_each(
            []( void* c, void* p_scheduler ) {
                coroutine_handle_t::from_address( c ).promise().scheduler = static_cast<scheduler_impl*>( p_scheduler );
            },
            p_scheduler );
    }

    void decrement_task_count() {
        size_t num_flags = --num_tasks;
        if ( num_flags == 0 ) {
            block_flag.clear( std::memory_order_release );
            block_flag.notify_one(); // unblock whoever is waiting on the block flag
        }
    }
};
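// Editorial note on the completion protocol above (not part of the original
// sources): num_tasks counts tasks that are queued *or* in flight, while the
// ring buffer only holds queued ones. A waiter that finds the ring empty but
// num_tasks > 0 sets block_flag and sleeps on it; the task that drives
// num_tasks to zero clears the flag and calls notify_one(), releasing the
// waiter. suspend_task::await_suspend (below) clears the flag as well, so a
// sleeping waiter is also woken whenever new work lands back on the list.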
TaskList::TaskList( uint32_t hint_capacity )
    : p_impl( new task_list_o( hint_capacity ) ) {
}

void TaskList::add_task( Task c ) {
    assert( p_impl != nullptr && "task list must be valid. Was this task list already used?" );
    p_impl->add_task( c );
}

TaskList::~TaskList() {
    // In case this task list has already been consumed, p_impl will be
    // nullptr, and this delete is a no-op. Otherwise (if a task list was
    // never used and needs cleaning up), this performs the cleanup for us.
    delete p_impl;
}

// ----------------------------------------------------------------------

class scheduler_impl {

    bool move_task_to_worker_thread( coroutine_handle_t& c );

  public:
    std::vector<Channel*>     channels; // non-owning - channels are owned by their threads
    std::vector<std::jthread> threads;

    scheduler_impl( int32_t num_worker_threads = 0 );

    ~scheduler_impl() {
        // We must unblock any threads which are currently waiting on a flag
        // signal for more work: as there is no more work coming, we must
        // artificially set the flag so that these worker threads can run to
        // completion.
        for ( auto* c : channels ) {
            if ( c ) {
                c->flag.test_and_set(); // Set the flag so that a worker blocked on it may proceed.
                c->flag.notify_one();   // Notify the worker thread (if any is waiting) that the flag has flipped -
                                        // without the notify, waiters would not learn that the flag has flipped.
            }
        }
        // We must wait until all the threads have been joined.
        // Destroying a jthread implicitly requests a stop (via its stop_token) and joins.
        threads.clear();
    }

    // Execute all tasks in the task list, then invalidate the task list object
    void wait_for_task_list( TaskList& p_t );
};

scheduler_impl::scheduler_impl( int32_t num_worker_threads ) {

    if ( num_worker_threads < 0 ) {
        // A negative value means we count backwards from the number of
        // available hardware threads.
        num_worker_threads = std::jthread::hardware_concurrency() + num_worker_threads;
    }

    assert( num_worker_threads >= 0 && "inferred number of worker threads must not be negative" );

    // std::cout << "Initializing scheduler with " << num_worker_threads << " worker threads" << std::endl;

    // Reserve memory so that we can take addresses of channels,
    // and do not have to worry about iterator validity
    channels.reserve( num_worker_threads );
    threads.reserve( num_worker_threads );

    // NOTE THAT BY DEFAULT WE DON'T HAVE ANY WORKER THREADS
    //
    for ( int i = 0; i != num_worker_threads; i++ ) {
        channels.emplace_back( new Channel() );
        threads.emplace_back(
            //
            // Thread worker implementation
            //
            []( std::stop_token stop_token, Channel* ch ) {
                while ( !stop_token.stop_requested() ) {

                    if ( ch->handle ) {
                        coroutine_handle_t::from_address( ch->handle ).resume(); // resume the coroutine
                        ch->handle = nullptr;
                        // signal that we are ready to receive new tasks
                        ch->flag.clear( std::memory_order::release );
                        continue;
                    }

                    // Wait for the flag to be set.
                    //
                    // The flag is set on any of the following:
                    //  * a new job has been placed in the channel
                    //  * the scheduler is being torn down (see ~scheduler_impl)
                    ch->flag.wait( false, std::memory_order::acquire );
                }

                // The channel is owned by the thread - when the thread goes
                // out of scope, the channel gets deleted, too.
                delete ch;
            },
            channels.back() );
    }
}
void scheduler_impl::wait_for_task_list( TaskList& p_t ) {

    if ( p_t.p_impl == nullptr ) {
        assert( false && "Task list must have been freshly initialised. Has this task list been waited for already?" );
        return;
    }

    // --------| Invariant: the TaskList is valid

    // Execute tasks in this task list until there are no more tasks left
    // to execute.

    // Before we start executing tasks we must take ownership of them
    // by tagging them, so that they know which scheduler they belong to.
    p_t.p_impl->tag_all_tasks_with_scheduler( this );

    // Distribute work, as long as there is work to distribute

    while ( p_t.p_impl->get_tasks_count() ) {

        // ----------| Invariant: there are tasks in this task list which have not yet completed

        coroutine_handle_t c = ( p_t.p_impl )->pop_task();

        if ( c == nullptr ) {

            // We could not fetch a task from the task list - this means
            // that there are tasks in progress that we must wait for.

            if ( p_t.p_impl->block_flag.test_and_set( std::memory_order::acq_rel ) ) {
                std::cout << "blocking thread " << std::this_thread::get_id() << " on [" << p_t.p_impl << "]" << std::endl;
                // Wait for the flag to be cleared - this happens when either:
                //  * the last task of the task list completes, and the task list is now empty
                //  * a suspended task is pushed back onto the list (see suspend_task::await_suspend)
                p_t.p_impl->block_flag.wait( true, std::memory_order::acquire );
                std::cout << "resuming thread " << std::this_thread::get_id() << " on [" << p_t.p_impl << "]" << std::endl;
            } else {
                std::cout << "spinning thread " << std::this_thread::get_id() << " on [" << p_t.p_impl << "]" << std::endl;
            }

            continue;
        }

        // ----------| Invariant: the current coroutine handle is valid

        // Find a free channel; if there is one, place this handle in the
        // channel, which means it will be executed on the worker thread
        // associated with that channel.

        if ( move_task_to_worker_thread( c ) ) {
            // Pushing consumes the coroutine handle - it is now owned by the
            // channel, which holds it for the worker thread.
            //
            // If we made it here, the handle was successfully offloaded to a worker thread.
            //
            // The worker thread will execute the payload, and the task will decrement the
            // counter of the current TaskList upon completion.
            continue;
        }

        // --------| Invariant: all worker threads are busy - or there are no worker threads: we must execute on this thread
        c();
    }

    // Once all tasks have completed, release the task list
    delete p_t.p_impl;    // Free the task list impl
    p_t.p_impl = nullptr; // Signal to any future users that this task list has been consumed
}

// ----------------------------------------------------------------------

inline bool scheduler_impl::move_task_to_worker_thread( coroutine_handle_t& c ) {
    // Iterate over all channels. If we can place the coroutine
    // on a channel, do so.
    for ( auto& ch : channels ) {
        if ( true == ch->try_push( c ) ) {
            return true;
        }
    }
    return false;
}

// ----------------------------------------------------------------------

Scheduler::Scheduler( int32_t num_worker_threads )
    : p_impl( new scheduler_impl( num_worker_threads ) ) {
}

void Scheduler::wait_for_task_list( TaskList& p_t ) {
    p_impl->wait_for_task_list( p_t );
}

Scheduler* Scheduler::create( int32_t num_worker_threads ) {
    return new Scheduler( num_worker_threads );
}

Scheduler::~Scheduler() {
    delete p_impl;
}

// ----------------------------------------------------------------------

void suspend_task::await_suspend( std::coroutine_handle<TaskPromise> h ) noexcept {

    // ----------| Invariant: at this point the coroutine pointed to by h
    // has been fully suspended. This is guaranteed by the C++ standard.

    auto& promise = h.promise();

    auto& task_list = promise.p_task_list;

    // Put the current coroutine to the back of the scheduler queue,
    // as it has been fully suspended at this point.

    task_list->push_task( promise.get_return_object() );

    {
        // We must unblock/awaken the scheduling thread each time we suspend
        // a coroutine, so that the scheduling thread may pick up work again,
        // in case it had been put to sleep earlier.
        promise.p_task_list->block_flag.clear( std::memory_order_release );
        promise.p_task_list->block_flag.notify_one(); // wake up the waiter, just in case
    }

    {
        // --- Eager workers ---
        //
        // Eagerly try to fetch & execute the next task from the front of the
        // scheduler queue.
        // We do this so that multiple threads can share the scheduling
        // workload.
        //
        // This could also be disabled, so that only one thread does the
        // scheduling and removes elements from the queue.

        coroutine_handle_t c = task_list->pop_task();

        if ( c ) {
            assert( !c.done() && "task must not be done" );
            c();
        }
    }

    // Note: once we drop off here, control returns to the caller of the
    // resume() that brought us here.
}

// ----------------------------------------------------------------------

void finalize_task::await_suspend( std::coroutine_handle<TaskPromise> h ) noexcept {
    // This is the last time that this coroutine will be awakened -
    // we do not suspend it any more after this.
    h.promise().p_task_list->decrement_task_count();
    h.destroy(); // are we allowed to destroy here?
}
--------------------------------------------------------------------------------
/src/tasks.h:
--------------------------------------------------------------------------------
#pragma once
#include <coroutine>
#include <cstdint>

struct TaskPromise; // fwd. decl.

struct Task : std::coroutine_handle<TaskPromise> {
    using promise_type = ::TaskPromise;
};

struct suspend_task {
    // if await_ready returns false, await_suspend will be called
    constexpr bool await_ready() noexcept {
        return false;
    }
    void await_suspend( std::coroutine_handle<TaskPromise> h ) noexcept;
    void await_resume() noexcept {}
};

struct finalize_task {
    // if await_ready returns false, await_suspend will be called
    constexpr bool await_ready() noexcept {
        return false;
    }
    void await_suspend( std::coroutine_handle<TaskPromise> h ) noexcept;
    void await_resume() noexcept {}
};
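// Editorial usage note (condensed from main.cpp): inside a coroutine that
// returns Task, writing
//
//     co_await suspend_task();
//
// parks the coroutine back on its task list and lets the scheduler hand the
// thread to other work. finalize_task is not awaited by user code - it is
// returned from TaskPromise::final_suspend (below), and its await_suspend
// runs the cleanup when a task finishes.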
class scheduler_impl; // fwd. decl., pimpl
class task_list_o;    // fwd. decl.
class TaskList;       // fwd. decl.

class Scheduler {
    scheduler_impl* p_impl = nullptr;

    Scheduler( int32_t num_worker_threads );

    Scheduler( const Scheduler& )            = delete;
    Scheduler( Scheduler&& )                 = delete; // move constructor
    Scheduler& operator=( const Scheduler& ) = delete;
    Scheduler& operator=( Scheduler&& )      = delete; // move assignment

  public:
    // Execute all tasks in the task list, then free the task list object.
    // This takes possession of the task list object, and acts as if it were
    // a blocking call.
    //
    // Once this call returns, the TaskList that was given as a parameter
    // has been consumed, and you must not re-use it.
    void wait_for_task_list( TaskList& p_t );

    // Create a scheduler with the given number of worker threads:
    //  0 ... no worker threads, just the main thread
    //  n ... n worker threads
    // -n ... as many worker threads as there are cpus, minus n
    static Scheduler* create( int32_t num_worker_threads = 0 );

    ~Scheduler();
};

class TaskList {

    task_list_o* p_impl; // owning

    TaskList( const TaskList& )            = delete;
    TaskList( TaskList&& )                 = delete; // move constructor
    TaskList& operator=( const TaskList& ) = delete;
    TaskList& operator=( TaskList&& )      = delete; // move assignment

  public:
    TaskList( uint32_t hint_capacity = 1 ); // default constructor

    ~TaskList();

    void add_task( Task c );

    friend class scheduler_impl;
};

struct TaskPromise {
    Task get_return_object() {
        return { Task::from_promise( *this ) };
    }
    std::suspend_always initial_suspend() noexcept {
        return {};
    }
    finalize_task final_suspend() noexcept {
        return {};
    }
    void return_void() {}
    void unhandled_exception() {}
    scheduler_impl* scheduler   = nullptr; // non-owning; set by the scheduler
    task_list_o*    p_task_list = nullptr; // non-owning; set by the owning task list
};
--------------------------------------------------------------------------------
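To tie the pieces together, here is the life of a single task as implemented above (an editorial summary; the coroutine body is hypothetical, everything else names real members of these sources):

// Life of a task, traced through the sources above:
//
// 1. Calling a coroutine function that returns Task allocates a frame with a
//    TaskPromise; initial_suspend() returns suspend_always, so the body does
//    not run yet.
// 2. TaskList::add_task stores the handle and sets promise.p_task_list.
// 3. Scheduler::wait_for_task_list tags every task with the scheduler, then
//    pops handles and resumes them - on a worker thread via a Channel, or
//    inline on the waiting thread.
// 4. co_await suspend_task() pushes the suspended handle back onto its task
//    list and wakes the waiter (suspend_task::await_suspend).
// 5. When the body finishes, final_suspend() returns finalize_task, whose
//    await_suspend decrements the task count and destroys the frame.

Task example_task() {        // hypothetical user coroutine
    /* ... do some work ... */
    co_await suspend_task(); // step 4: back onto the list
    /* ... do more work ... */
    co_return;               // step 5: finalize_task runs
}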