├── rustfmt.toml ├── .vscode └── settings.json ├── derive-async-local ├── README.md ├── Cargo.toml └── src │ ├── lib.rs │ └── entry.rs ├── .gitignore ├── Taskfile.yml ├── LICENSE ├── Cargo.toml ├── README.md └── src ├── runtime.rs └── lib.rs /rustfmt.toml: -------------------------------------------------------------------------------- 1 | format_code_in_doc_comments = true 2 | group_imports = "StdExternalCrate" 3 | imports_granularity = "Crate" 4 | newline_style = "Unix" 5 | normalize_comments = true 6 | tab_spaces = 2 7 | wrap_comments = false -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "editor.tabSize": 2, 3 | "editor.formatOnSave": true, 4 | "rust-analyzer.rustfmt.extraArgs": [ 5 | "+nightly" 6 | ], 7 | "yaml.schemas": { 8 | "https://taskfile.dev/schema.json": [ 9 | "**/Taskfile.yml", 10 | ] 11 | } 12 | } -------------------------------------------------------------------------------- /derive-async-local/README.md: -------------------------------------------------------------------------------- 1 | # Async Local Derives 2 | ![License](https://img.shields.io/badge/license-MIT-green.svg) 3 | [![Cargo](https://img.shields.io/crates/v/derive-async-local.svg)](https://crates.io/crates/derive-async-local) 4 | [![Documentation](https://docs.rs/derive-async-local/badge.svg)](https://docs.rs/derive-async-local) 5 | 6 | Derives for [async-local](https://crates.io/crates/async-local) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | /target/ 4 | 5 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 6 | # More information here 
https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 7 | Cargo.lock 8 | 9 | # These are backup files generated by rustfmt 10 | **/*.rs.bk 11 | 12 | 13 | # Added by cargo 14 | 15 | /target 16 | /Cargo.lock 17 | -------------------------------------------------------------------------------- /derive-async-local/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "derive-async-local" 3 | authors = ["Thomas Sieverding "] 4 | edition = "2024" 5 | version = "6.0.2" 6 | description = "Derives for async-local" 7 | readme = "./README.md" 8 | license = "MIT" 9 | repository = "https://github.com/Bajix/async-local/" 10 | rust-version = "1.85" 11 | 12 | [lib] 13 | test = false 14 | doctest = false 15 | proc-macro = true 16 | 17 | [dependencies] 18 | proc-macro2 = "1" 19 | quote = "1" 20 | syn = { version = "2", features = ["full"] } 21 | 22 | [features] 23 | rt-multi-thread = [] 24 | -------------------------------------------------------------------------------- /Taskfile.yml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | 3 | env: 4 | TARGET: x86_64-apple-darwin 5 | 6 | tasks: 7 | default: 8 | cmds: 9 | - task: test 10 | - task: clippy-tests 11 | - task: test-address-sanitizer 12 | 13 | test: 14 | cmds: 15 | - cargo test -- --nocapture 16 | 17 | test-miri: 18 | cmds: 19 | - cargo miri test -Z build-std --target $TARGET -- --nocapture 20 | env: 21 | MIRIFLAGS: -Zmiri-backtrace=full -Zmiri-disable-isolation 22 | 23 | doc: 24 | cmds: 25 | - cargo +nightly doc -p async-local --open 26 | env: 27 | RUSTDOCFLAGS: --cfg docsrs 28 | 29 | clippy-tests: 30 | cmds: 31 | - cargo clippy --tests 32 | 33 | check-tests: 34 | cmds: 35 | - cargo check --tests 36 | 37 | test-address-sanitizer: 38 | cmds: 39 | - cargo test -Z build-std --target $TARGET -- --nocapture 40 | env: 41 | RUSTFLAGS: -Z sanitizer=address 42 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Thomas Sieverding 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Thomas Sieverding "] 3 | edition = "2024" 4 | name = "async-local" 5 | version = "6.0.2" 6 | description = "For using thread locals within an async context across await points" 7 | readme = "./README.md" 8 | license = "MIT" 9 | repository = "https://github.com/Bajix/async-local/" 10 | rust-version = "1.85" 11 | 12 | [dependencies] 13 | ctor = { version = "0.4.1", default-features = false, features = [ 14 | "proc_macro", 15 | ] } 16 | derive-async-local = { version = "6.0.2", path = "./derive-async-local" } 17 | generativity = "1.1" 18 | linkme = "0.3.32" 19 | num_cpus = "1.16" 20 | tokio = { version = "1", features = ["rt", "rt-multi-thread"], optional = true } 21 | 22 | [dev-dependencies] 23 | tokio = { version = "1", features = ["macros"] } 24 | 25 | [target.'cfg(loom)'.dependencies] 26 | loom = { version = "0.7", features = [] } 27 | 28 | [lib] 29 | doctest = false 30 | bench = false 31 | 32 | [features] 33 | default = ["rt"] 34 | 35 | rt = ["tokio/rt"] 36 | 37 | # Enable Tokio multi_thread runtime flavor 38 | rt-multi-thread = [ 39 | "rt", 40 | "tokio/rt-multi-thread", 41 | "derive-async-local/rt-multi-thread", 42 | ] 43 | 44 | compat = [] 45 | 46 | [workspace] 47 | members = ["derive-async-local"] 48 | 49 | [profile.release] 50 | lto = "fat" 51 | opt-level = 3 52 | codegen-units = 1 53 | 54 | [package.metadata.docs.rs] 55 | rustdoc-args = ["--cfg", "docsrs"] 56 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Async Local 2 | 3 | ![License](https://img.shields.io/badge/license-MIT-green.svg) 4 | [![Cargo](https://img.shields.io/crates/v/async-local.svg)](https://crates.io/crates/async-local) 5 | 
[![Documentation](https://docs.rs/async-local/badge.svg)](https://docs.rs/async-local) 6 | 7 | ## Unlocking the potential of thread-locals in an async context 8 | 9 | This crate enables references to thread locals to be used in an async context across await points or within blocking threads managed by the Tokio runtime 10 | 11 | ## How it works 12 | 13 | By configuring Tokio with a barrier to rendezvous worker threads during shutdown, it can be guaranteed that no task will outlive thread local data belonging to worker threads. With this, pointers to thread locals constrained by invariant lifetimes are guaranteed to be of a valid lifetime suitable for use across await points. 14 | 15 | ## Runtime Configuration 16 | 17 | The optimization that this crate provides requires that the [async_local::main](https://docs.rs/async-local/latest/async_local/attr.main.html) or [async_local::test](https://docs.rs/async-local/latest/async_local/attr.test.html) macro be used to configure the Tokio runtime. This is enforced by a pre-main check that asserts [async_local::main](https://docs.rs/async-local/latest/async_local/attr.main.html) has been used. 18 | 19 | ## Compatibility Mode 20 | 21 | Enabling the `compat` feature flag will allow this crate to be used with any runtime configuration by disabling the performance optimization this crate provides and instead internally using `std::sync::Arc` 22 | 23 | ## Example usage 24 | 25 | ```rust 26 | #[cfg(test)] 27 | mod tests { 28 | use std::sync::atomic::{AtomicUsize, Ordering}; 29 | 30 | use async_local::{AsyncLocal, Context}; 31 | use generativity::make_guard; 32 | use tokio::task::yield_now; 33 | 34 | thread_local! 
{ 35 | static COUNTER: Context = Context::new(AtomicUsize::new(0)); 36 | } 37 | 38 | #[async_local::test] 39 | async fn it_increments() { 40 | make_guard!(guard); 41 | let counter = COUNTER.local_ref(guard); 42 | yield_now().await; 43 | counter.fetch_add(1, Ordering::SeqCst); 44 | } 45 | } 46 | ``` -------------------------------------------------------------------------------- /src/runtime.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | fmt::{self, Debug}, 3 | io, 4 | sync::{ 5 | Arc, Condvar, Mutex, 6 | atomic::{AtomicUsize, Ordering}, 7 | }, 8 | }; 9 | 10 | use linkme::distributed_slice; 11 | 12 | use crate::{BarrierContext, CONTEXT}; 13 | 14 | #[derive(Default)] 15 | struct ShutdownBarrier { 16 | guard_count: AtomicUsize, 17 | shutdown_finalized: Mutex, 18 | cvar: Condvar, 19 | } 20 | 21 | #[derive(PartialEq, Eq)] 22 | pub(crate) enum Kind { 23 | CurrentThread, 24 | #[cfg(feature = "rt-multi-thread")] 25 | MultiThread, 26 | } 27 | 28 | #[doc(hidden)] 29 | /// Builds Tokio runtime configured with a shutdown barrier 30 | pub struct Builder { 31 | kind: Kind, 32 | worker_threads: usize, 33 | inner: tokio::runtime::Builder, 34 | } 35 | 36 | impl Builder { 37 | /// Returns a new builder with the current thread scheduler selected. 38 | pub fn new_current_thread() -> Builder { 39 | Builder { 40 | kind: Kind::CurrentThread, 41 | worker_threads: 1, 42 | inner: tokio::runtime::Builder::new_current_thread(), 43 | } 44 | } 45 | 46 | /// Returns a new builder with the multi thread scheduler selected. 
47 | #[cfg(feature = "rt-multi-thread")] 48 | pub fn new_multi_thread() -> Builder { 49 | let worker_threads = std::env::var("TOKIO_WORKER_THREADS") 50 | .ok() 51 | .and_then(|worker_threads| worker_threads.parse().ok()) 52 | .unwrap_or_else(num_cpus::get); 53 | 54 | Builder { 55 | kind: Kind::MultiThread, 56 | worker_threads, 57 | inner: tokio::runtime::Builder::new_multi_thread(), 58 | } 59 | } 60 | 61 | /// Enables both I/O and time drivers. 62 | pub fn enable_all(&mut self) -> &mut Self { 63 | self.inner.enable_all(); 64 | self 65 | } 66 | 67 | /// Sets the number of worker threads the [`Runtime`] will use. 68 | /// 69 | /// This can be any number above 0 though it is advised to keep this value 70 | /// on the smaller side. 71 | /// 72 | /// This will override the value read from environment variable `TOKIO_WORKER_THREADS`. 73 | /// 74 | /// # Default 75 | /// 76 | /// The default value is the number of cores available to the system. 77 | /// 78 | /// When using the `current_thread` runtime this method has no effect. 79 | /// 80 | /// # Panics 81 | /// 82 | /// This will panic if `val` is not larger than `0`. 
83 | #[track_caller] 84 | pub fn worker_threads(&mut self, val: usize) -> &mut Self { 85 | assert!(val > 0, "Worker threads cannot be set to 0"); 86 | if self.kind.ne(&Kind::CurrentThread) { 87 | self.worker_threads = val; 88 | self.inner.worker_threads(val); 89 | } 90 | self 91 | } 92 | 93 | /// Creates a Tokio Runtime configured with a barrier that rendezvous worker threads during shutdown as to ensure tasks never outlive local data owned by worker threads 94 | pub fn build(&mut self) -> io::Result { 95 | let worker_threads = self.worker_threads; 96 | let barrier = Arc::new(ShutdownBarrier::default()); 97 | 98 | let on_thread_start = { 99 | let barrier = barrier.clone(); 100 | move || { 101 | let thread_count = barrier.guard_count.fetch_add(1, Ordering::Release); 102 | 103 | CONTEXT.with(|context| { 104 | if thread_count.ge(&worker_threads) { 105 | *context.borrow_mut() = Some(BarrierContext::PoolWorker) 106 | } else { 107 | *context.borrow_mut() = Some(BarrierContext::RuntimeWorker) 108 | } 109 | }); 110 | } 111 | }; 112 | 113 | let on_thread_stop = move || { 114 | let thread_count = barrier.guard_count.fetch_sub(1, Ordering::AcqRel); 115 | 116 | CONTEXT.with(|context| { 117 | if thread_count.eq(&1) { 118 | *barrier.shutdown_finalized.lock().unwrap() = true; 119 | barrier.cvar.notify_all(); 120 | } else if context.borrow().eq(&Some(BarrierContext::RuntimeWorker)) { 121 | let mut shutdown_finalized = barrier.shutdown_finalized.lock().unwrap(); 122 | while !*shutdown_finalized { 123 | shutdown_finalized = barrier.cvar.wait(shutdown_finalized).unwrap(); 124 | } 125 | } 126 | }); 127 | }; 128 | 129 | self 130 | .inner 131 | .on_thread_start(on_thread_start) 132 | .on_thread_stop(on_thread_stop) 133 | .build() 134 | .map(Runtime::new) 135 | } 136 | } 137 | 138 | #[doc(hidden)] 139 | pub struct Runtime(tokio::runtime::Runtime); 140 | 141 | impl Debug for Runtime { 142 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 143 | self.0.fmt(f) 144 | } 145 | } 146 
| 147 | impl Runtime { 148 | fn new(inner: tokio::runtime::Runtime) -> Self { 149 | Runtime(inner) 150 | } 151 | /// Runs a future to completion on the Tokio runtime. This is the 152 | /// runtime's entry point. 153 | /// 154 | /// This runs the given future on the current thread, blocking until it is 155 | /// complete, and yielding its resolved result. Any tasks or timers 156 | /// which the future spawns internally will be executed on the runtime. 157 | /// 158 | /// # Non-worker future 159 | /// 160 | /// Note that the future required by this function does not run as a 161 | /// worker. The expectation is that other tasks are spawned by the future here. 162 | /// Awaiting on other futures from the future provided here will not 163 | /// perform as fast as those spawned as workers. 164 | /// 165 | /// # Panics 166 | /// 167 | /// This function panics if the provided future panics, or if called within an 168 | /// asynchronous execution context. 169 | /// 170 | /// # Safety 171 | /// This is internal to async_local and is meant to be used exclusively with #[async_local::main] and #[async_local::test]. 
172 | #[track_caller] 173 | pub unsafe fn block_on(self, future: F) -> F::Output { 174 | unsafe { self.run(|handle| handle.block_on(future)) } 175 | } 176 | 177 | pub unsafe fn run(self, f: F) -> Output 178 | where 179 | F: for<'a> FnOnce(&'a tokio::runtime::Runtime) -> Output, 180 | { 181 | CONTEXT.with(|context| *context.borrow_mut() = Some(BarrierContext::Owner)); 182 | 183 | let output = f(&self.0); 184 | 185 | drop(self); 186 | 187 | CONTEXT.with(|context| *context.borrow_mut() = None::); 188 | 189 | output 190 | } 191 | } 192 | 193 | #[doc(hidden)] 194 | #[derive(Debug, PartialEq, Eq)] 195 | pub enum RuntimeContext { 196 | Main, 197 | Test, 198 | } 199 | 200 | #[doc(hidden)] 201 | #[distributed_slice] 202 | pub static RUNTIMES: [RuntimeContext]; 203 | 204 | #[cfg(not(feature = "compat"))] 205 | #[ctor::ctor] 206 | fn assert_runtime_configured() { 207 | if RUNTIMES.is_empty() { 208 | panic!( 209 | "The #[async_local::main] or #[async_local::test] macro must be used to configure the Tokio runtime for use with the `async-local` crate. 
For compatibilty with other async runtime configurations, the `compat` feature can be used to disable the optimizations this crate provides" 210 | ); 211 | } 212 | 213 | if RUNTIMES 214 | .iter() 215 | .fold(0, |acc, context| { 216 | if context.eq(&RuntimeContext::Main) { 217 | acc + 1 218 | } else { 219 | acc 220 | } 221 | }) 222 | .gt(&1) 223 | { 224 | panic!("The #[async_local::main] macro cannot be used more than once"); 225 | } 226 | } 227 | -------------------------------------------------------------------------------- /derive-async-local/src/lib.rs: -------------------------------------------------------------------------------- 1 | use proc_macro2::Span; 2 | use quote::quote; 3 | use syn::{ 4 | Data, DeriveInput, GenericArgument, PathArguments, Type, TypePath, parse::Error, 5 | parse_macro_input, 6 | }; 7 | 8 | mod entry; 9 | 10 | fn is_context(type_path: &TypePath) -> bool { 11 | let segments: Vec<_> = type_path 12 | .path 13 | .segments 14 | .iter() 15 | .map(|segment| segment.ident.to_string()) 16 | .collect(); 17 | 18 | matches!( 19 | *segments 20 | .iter() 21 | .map(String::as_ref) 22 | .collect::>() 23 | .as_slice(), 24 | ["async_local", "Context"] | ["Context"] 25 | ) 26 | } 27 | 28 | /// Derive [AsRef](https://doc.rust-lang.org/std/convert/trait.AsRef.html)<[`Context`](https://docs.rs/async-local/latest/async_local/struct.Context.html)> and [`AsContext`](https://docs.rs/async-local/latest/async_local/trait.AsContext.html) for a struct 29 | #[proc_macro_derive(AsContext)] 30 | pub fn derive_as_context(input: proc_macro::TokenStream) -> proc_macro::TokenStream { 31 | let input = parse_macro_input!(input as DeriveInput); 32 | let ident = &input.ident; 33 | let (impl_generics, ty_generics, where_clause) = &input.generics.split_for_impl(); 34 | 35 | if let Some(err) = input 36 | .generics 37 | .lifetimes() 38 | .map(|lifetime| Error::new_spanned(lifetime, "cannot derive AsContext with lifetimes")) 39 | .reduce(|mut err, other| { 40 | err.combine(other); 
41 | err 42 | }) 43 | { 44 | return err.into_compile_error().into(); 45 | } 46 | 47 | let data_struct = if let Data::Struct(data_struct) = &input.data { 48 | data_struct 49 | } else { 50 | return Error::new(Span::call_site(), "can only derive AsContext on structs") 51 | .into_compile_error() 52 | .into(); 53 | }; 54 | 55 | let path_fields: Vec<_> = data_struct 56 | .fields 57 | .iter() 58 | .filter_map(|field| { 59 | if let Type::Path(type_path) = &field.ty { 60 | Some((field, type_path)) 61 | } else { 62 | None 63 | } 64 | }) 65 | .collect(); 66 | 67 | let wrapped_context_error = path_fields 68 | .iter() 69 | .filter(|(_, type_path)| { 70 | if let Some(segment) = type_path.path.segments.last() { 71 | if let PathArguments::AngleBracketed(inner) = &segment.arguments { 72 | if let Some(GenericArgument::Type(Type::Path(type_path))) = inner.args.first() { 73 | return is_context(type_path); 74 | } 75 | } 76 | } 77 | false 78 | }) 79 | .map(|(_, type_path)| Error::new_spanned(type_path, "Context cannot be wrapped in a pointer type nor cell type and must not be invalidated nor repurposed until dropped")) 80 | .reduce(|mut err, other| { 81 | err.combine(other); 82 | err 83 | }); 84 | 85 | if let Some(err) = wrapped_context_error { 86 | return err.into_compile_error().into(); 87 | } 88 | 89 | let context_paths: Vec<_> = path_fields 90 | .iter() 91 | .filter(|(_, type_path)| is_context(type_path)) 92 | .collect(); 93 | 94 | if context_paths.len().eq(&0) { 95 | return Error::new(Span::call_site(), "struct must use Context exactly once") 96 | .into_compile_error() 97 | .into(); 98 | } 99 | 100 | if context_paths.len().gt(&1) { 101 | return context_paths 102 | .into_iter() 103 | .map(|(_, type_path)| Error::new_spanned(type_path, "Context cannot be used more than once")) 104 | .reduce(|mut err, other| { 105 | err.combine(other); 106 | err 107 | }) 108 | .unwrap() 109 | .into_compile_error() 110 | .into(); 111 | } 112 | 113 | let (field, type_path) = 
context_paths.into_iter().next().unwrap(); 114 | 115 | let context_ident = &field.ident; 116 | 117 | let ref_type = type_path.path.segments.last().and_then(|segment| { 118 | if let PathArguments::AngleBracketed(ref_type) = &segment.arguments { 119 | Some(&ref_type.args) 120 | } else { 121 | None 122 | } 123 | }); 124 | 125 | let expanded = quote!( 126 | impl #impl_generics AsRef<#type_path> for #ident #ty_generics #where_clause { 127 | fn as_ref(&self) -> &#type_path { 128 | &self.#context_ident 129 | } 130 | } 131 | 132 | unsafe impl #impl_generics async_local::AsContext for #ident #ty_generics #where_clause { 133 | type Target = #ref_type; 134 | } 135 | ); 136 | 137 | expanded.into() 138 | } 139 | 140 | /// Configures main to be executed by the selected Tokio runtime 141 | /// 142 | /// # Borrowing the runtime 143 | /// 144 | /// To borrow the runtime directly, add as a function argument 145 | /// 146 | /// ``` 147 | /// #[async_local::main(flavor = "multi_thread", worker_threads = 10)] 148 | /// fn main(runtime: &tokio::runtime::Runtime) {} 149 | /// ``` 150 | /// 151 | /// # Non-worker async function 152 | /// 153 | /// Note that the async function marked with this macro does not run as a 154 | /// worker. The expectation is that other tasks are spawned by the function here. 155 | /// Awaiting on other futures from the function provided here will not 156 | /// perform as fast as those spawned as workers. 157 | /// 158 | /// # Multi-threaded runtime 159 | /// 160 | /// To use the multi-threaded runtime, the macro can be configured using 161 | /// ``` 162 | /// #[async_local::main(flavor = "multi_thread", worker_threads = 10)] 163 | /// # async fn main() {} 164 | /// ``` 165 | /// 166 | /// The `worker_threads` option configures the number of worker threads, and 167 | /// defaults to the number of cpus on the system. This is the default flavor. 168 | /// 169 | /// Note: The multi-threaded runtime requires the `rt-multi-thread` feature 170 | /// flag. 
171 | /// 172 | /// # Current thread runtime 173 | /// 174 | /// To use the single-threaded runtime known as the `current_thread` runtime, 175 | /// the macro can be configured using 176 | /// ``` 177 | /// #[async_local::main(flavor = "current_thread")] 178 | /// # async fn main() {} 179 | /// ``` 180 | /// ## Usage 181 | /// 182 | /// ### Using the multi-thread runtime 183 | /// ```rust 184 | /// #[async_local::main] 185 | /// async fn main() { 186 | /// println!("Hello world"); 187 | /// } 188 | /// ``` 189 | /// 190 | /// ### Using current thread runtime 191 | /// 192 | /// The basic scheduler is single-threaded. 193 | /// ```rust 194 | /// #[async_local::main(flavor = "current_thread")] 195 | /// async fn main() { 196 | /// println!("Hello world"); 197 | /// } 198 | /// ``` 199 | /// 200 | /// ### Set number of worker threads 201 | /// ```rust 202 | /// #[async_local::main(worker_threads = 2)] 203 | /// async fn main() { 204 | /// println!("Hello world"); 205 | /// } 206 | /// ``` 207 | /// 208 | /// ### Configure the runtime to start with time paused 209 | /// ```rust 210 | /// #[async_local::main(flavor = "current_thread", start_paused = true)] 211 | /// async fn main() { 212 | /// println!("Hello world"); 213 | /// } 214 | /// ``` 215 | /// 216 | /// Note that `start_paused` requires the `test-util` feature to be enabled on `tokio`. 217 | #[proc_macro_attribute] 218 | pub fn main( 219 | args: proc_macro::TokenStream, 220 | item: proc_macro::TokenStream, 221 | ) -> proc_macro::TokenStream { 222 | entry::main(args.into(), item.into(), cfg!(feature = "rt-multi-thread")).into() 223 | } 224 | 225 | /// Marks async function to be executed by runtime, suitable to test environment. 
226 | /// 227 | /// # Borrowing the runtime 228 | /// 229 | /// To borrow the runtime directly, add as a function argument 230 | /// 231 | /// ``` 232 | /// #[async_local::test(flavor = "multi_thread", worker_threads = 10)] 233 | /// fn test(runtime: &tokio::runtime::Runtime) { 234 | /// runtime.block_on(async { 235 | /// assert!(true); 236 | /// }); 237 | /// } 238 | /// ``` 239 | /// 240 | /// # Multi-threaded runtime 241 | /// 242 | /// To use the multi-threaded runtime, the macro can be configured using 243 | /// ```no_run 244 | /// #[async_local::test(flavor = "multi_thread", worker_threads = 1)] 245 | /// async fn my_test() { 246 | /// assert!(true); 247 | /// } 248 | /// ``` 249 | /// 250 | /// The `worker_threads` option configures the number of worker threads, and 251 | /// defaults to the number of cpus on the system. 252 | /// 253 | /// Note: The multi-threaded runtime requires the `rt-multi-thread` feature 254 | /// flag. 255 | /// 256 | /// # Current thread runtime 257 | /// 258 | /// The default test runtime is single-threaded. Each test gets a 259 | /// separate current-thread runtime. 
260 | /// ```no_run 261 | /// #[async_local::test] 262 | /// async fn my_test() { 263 | /// assert!(true); 264 | /// } 265 | /// ``` 266 | /// 267 | /// ## Usage 268 | /// 269 | /// ### Using the multi-thread runtime 270 | /// ```no_run 271 | /// #[async_local::test(flavor = "multi_thread")] 272 | /// async fn my_test() { 273 | /// assert!(true); 274 | /// } 275 | /// ``` 276 | /// 277 | /// ### Using current thread runtime 278 | /// ```no_run 279 | /// #[async_local::test] 280 | /// async fn my_test() { 281 | /// assert!(true); 282 | /// } 283 | /// ``` 284 | /// 285 | /// ### Set number of worker threads 286 | /// ```no_run 287 | /// #[async_local::test(flavor = "multi_thread", worker_threads = 2)] 288 | /// async fn my_test() { 289 | /// assert!(true); 290 | /// } 291 | /// ``` 292 | /// 293 | /// ### Configure the runtime to start with time paused 294 | /// ```no_run 295 | /// #[async_local::test(start_paused = true)] 296 | /// async fn my_test() { 297 | /// assert!(true); 298 | /// } 299 | /// ``` 300 | /// 301 | /// Note that `start_paused` requires the `test-util` feature to be enabled on `tokio`. 
302 | #[proc_macro_attribute] 303 | pub fn test( 304 | args: proc_macro::TokenStream, 305 | item: proc_macro::TokenStream, 306 | ) -> proc_macro::TokenStream { 307 | entry::test(args.into(), item.into(), cfg!(feature = "rt-multi-thread")).into() 308 | } 309 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![cfg_attr(test, feature(exit_status_error))] 2 | #![cfg_attr(docsrs, feature(doc_cfg))] 3 | 4 | extern crate self as async_local; 5 | 6 | /// A Tokio Runtime builder that configures a barrier to rendezvous worker threads during shutdown to ensure tasks never outlive local data owned by worker threads 7 | #[doc(hidden)] 8 | #[cfg(all(not(loom), feature = "rt"))] 9 | #[path = "runtime.rs"] 10 | pub mod __runtime; 11 | 12 | #[cfg(not(feature = "compat"))] 13 | use std::ptr::addr_of; 14 | #[cfg(feature = "compat")] 15 | use std::sync::Arc; 16 | #[cfg(not(loom))] 17 | use std::thread::LocalKey; 18 | use std::{cell::RefCell, ops::Deref}; 19 | 20 | pub use derive_async_local::{AsContext, main, test}; 21 | use generativity::{Guard, Id, make_guard}; 22 | #[doc(hidden)] 23 | pub use linkme; 24 | #[cfg(loom)] 25 | use loom::thread::LocalKey; 26 | #[doc(hidden)] 27 | #[cfg(all(not(loom), feature = "rt"))] 28 | pub use tokio::pin; 29 | #[cfg(all(not(loom), feature = "rt"))] 30 | use tokio::task::{JoinHandle, spawn_blocking}; 31 | 32 | #[derive(PartialEq, Eq, Debug)] 33 | pub(crate) enum BarrierContext { 34 | Owner, 35 | /// Tokio Runtime Worker 36 | RuntimeWorker, 37 | /// Tokio Pool Worker 38 | PoolWorker, 39 | } 40 | 41 | thread_local! 
{ 42 | pub(crate) static CONTEXT: RefCell> = const { RefCell::new(None) }; 43 | } 44 | 45 | /// A wrapper type used for creating pointers to thread-locals 46 | pub struct Context( 47 | #[cfg(not(feature = "compat"))] T, 48 | #[cfg(feature = "compat")] Arc, 49 | ); 50 | 51 | impl Context 52 | where 53 | T: Sync, 54 | { 55 | /// Create a new thread-local context 56 | /// 57 | /// If the `compat` feature flag is enabled, [`Context`] will downgrade to internally using [`std::sync::Arc`] to ensure the validity of `T` 58 | /// 59 | /// # Usage 60 | /// 61 | /// Either wrap a type with [`Context`] and assign to a thread-local, or use as an unwrapped field in a struct that derives [`AsContext`] 62 | /// 63 | /// # Example 64 | /// 65 | /// ```rust 66 | /// use std::sync::atomic::{AtomicUsize, Ordering}; 67 | /// 68 | /// use async_local::{AsyncLocal, Context}; 69 | /// use generativity::make_guard; 70 | /// 71 | /// thread_local! { 72 | /// static COUNTER: Context = Context::new(AtomicUsize::new(0)); 73 | /// } 74 | /// 75 | /// #[async_local::main(flavor = "current_thread")] 76 | /// async fn main() { 77 | /// make_guard!(guard); 78 | /// let counter = COUNTER.local_ref(guard); 79 | /// 80 | /// let _count = counter 81 | /// .with_blocking(|counter| counter.fetch_add(1, Ordering::Relaxed)) 82 | /// .await 83 | /// .unwrap(); 84 | /// } 85 | /// ``` 86 | pub fn new(inner: T) -> Context { 87 | #[cfg(not(feature = "compat"))] 88 | { 89 | Context(inner) 90 | } 91 | #[cfg(feature = "compat")] 92 | { 93 | Context(Arc::new(inner)) 94 | } 95 | } 96 | 97 | /// Construct [`LocalRef`] with an unbounded lifetime. 
///
  /// # Safety
  ///
  /// This lifetime must be restricted to avoid unsoundness
  pub unsafe fn local_ref<'a>(&self) -> LocalRef<'a, T> {
    // SAFETY: the caller promises (see `# Safety`) to restrict `'a`; the guard
    // is fabricated here only to carry that caller-chosen lifetime.
    unsafe { LocalRef::new(self, Guard::new(Id::new())) }
  }
}

impl<T> AsRef<Context<T>> for Context<T>
where
  T: Sync,
{
  fn as_ref(&self) -> &Context<T> {
    self
  }
}

impl<T> Deref for Context<T>
where
  T: Sync,
{
  type Target = T;

  /// Borrows the wrapped value; under the `compat` feature the value is
  /// reached through `as_ref()` on the inner field instead of a direct borrow.
  fn deref(&self) -> &Self::Target {
    #[cfg(not(feature = "compat"))]
    {
      &self.0
    }
    #[cfg(feature = "compat")]
    {
      self.0.as_ref()
    }
  }
}

/// A marker trait promising that [AsRef](https://doc.rust-lang.org/std/convert/trait.AsRef.html)<[`Context`]> is implemented in a way that can't be invalidated
///
/// # Safety
///
/// [`Context`] must not be invalidated as references may exist for the lifetime of the runtime.
pub unsafe trait AsContext: AsRef<Context<Self::Target>> {
  type Target: Sync + 'static;
}

// SAFETY: `Context<T>::as_ref` returns `self`, so the reference it hands out
// is valid for exactly as long as the `Context` itself.
unsafe impl<T> AsContext for Context<T>
where
  T: Sync,
{
  type Target = T;
}

/// A thread-safe pointer to a thread-local [`Context`] constrained by a "[generative](https://crates.io/crates/generativity)" lifetime brand that is [invariant](https://doc.rust-lang.org/nomicon/subtyping.html#variance) over the lifetime parameter and cannot be coerced into `'static`
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct LocalRef<'id, T: Sync + 'static> {
  #[cfg(not(feature = "compat"))]
  inner: *const T,
  #[cfg(feature = "compat")]
  inner: Arc<T>,
  /// Lifetime carrier
  _brand: Id<'id>,
}

impl<'id, T> LocalRef<'id, T>
where
  T: Sync + 'static,
{
  /// Builds a `LocalRef` branded with the lifetime carried by `guard`.
  ///
  /// Without `compat` this stores the raw address of the value inside
  /// `context`; with `compat` it clones the inner field instead.
  ///
  /// # Safety
  ///
  /// NOTE(review): presumably the caller must guarantee that `context`
  /// outlives every use of the returned pointer — confirm against the
  /// runtime shutdown logic, which is not visible from this block.
  unsafe fn new(context: &Context<T>, guard: Guard<'id>) -> Self {
    LocalRef {
      #[cfg(not(feature = "compat"))]
      inner: addr_of!(context.0),
      #[cfg(feature = "compat")]
      inner: context.0.clone(),
      _brand: guard.into(),
    }
  }

  /// A wrapper around [`tokio::task::spawn_blocking`](https://docs.rs/tokio/latest/tokio/task/fn.spawn_blocking.html) that safely constrains the lifetime of [`LocalRef`]
  #[cfg(all(not(loom), feature = "rt"))]
  #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
  pub fn with_blocking<F, R>(self, f: F) -> JoinHandle<R>
  where
    F: for<'a> FnOnce(LocalRef<'a, T>) -> R + Send + 'static,
    R: Send + 'static,
  {
    use std::mem::transmute;

    // SAFETY: only the lifetime brand changes, so source and target types have
    // identical layout. `f` is higher-ranked over the brand and therefore
    // cannot leak the re-branded reference out of the closure.
    let local_ref = unsafe { transmute::<LocalRef<'id, T>, LocalRef<'_, T>>(self) };

    spawn_blocking(move || f(local_ref))
  }
}

impl<T> Deref for LocalRef<'_, T>
where
  T: Sync,
{
  type Target = T;

  fn deref(&self) -> &Self::Target {
    #[cfg(not(feature = "compat"))]
    {
      // SAFETY: `inner` was taken from a live `Context` in `LocalRef::new`.
      // NOTE(review): validity for the whole brand lifetime relies on the
      // contract of `LocalRef::new` — confirm with the runtime code.
      unsafe { &*self.inner }
    }
    #[cfg(feature = "compat")]
    {
      self.inner.deref()
    }
  }
}

impl<T> Clone for LocalRef<'_, T>
where
  T: Sync + 'static,
{
  /// Copies the pointer (or clones the `Arc` under `compat`) and the brand.
  fn clone(&self) -> Self {
    LocalRef {
      #[cfg(not(feature = "compat"))]
      inner: self.inner,
      #[cfg(feature = "compat")]
      inner: self.inner.clone(),
      _brand: self._brand,
    }
  }
}

// SAFETY: a `LocalRef` only ever yields `&T` (see `Deref`), and `T: Sync`
// makes `&T` safe to send to and share between threads.
unsafe impl<T> Send for LocalRef<'_, T> where T: Sync {}
// SAFETY: see the `Send` impl above.
unsafe impl<T> Sync for LocalRef<'_, T> where T: Sync {}

/// LocalKey extension for creating thread-safe pointers to thread-local [`Context`]
pub trait AsyncLocal<T>
where
  T: AsContext,
{
  /// A wrapper around [`tokio::task::spawn_blocking`](https://docs.rs/tokio/latest/tokio/task/fn.spawn_blocking.html) that safely constrains the lifetime of [`LocalRef`]
  #[cfg(all(not(loom), feature = "rt"))]
  #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
  fn with_blocking<F, R>(&'static self, f: F) -> JoinHandle<R>
  where
    F: for<'id> FnOnce(LocalRef<'id, T::Target>) -> R + Send + 'static,
    R: Send + 'static;

  /// Acquire a reference to the value in this TLS key and pass it to the
  /// async closure `f`, resolving to whatever `f` resolves to.
  fn with_async<F, R>(&'static self, f: F) -> impl Future<Output = R>
  where
    F: for<'a> AsyncFnMut(LocalRef<'a, T::Target>) -> R;

  /// Create a pointer to a thread local [`Context`] using a trusted lifetime carrier.
  ///
  /// # Usage
  ///
  /// Use [`generativity::make_guard`] to generate a unique [`invariant`](https://doc.rust-lang.org/nomicon/subtyping.html#variance) lifetime brand
  ///
  /// # Panics
  ///
  /// [`LocalRef`] must be created within the async context of the runtime.
  fn local_ref<'id>(&'static self, guard: Guard<'id>) -> LocalRef<'id, T::Target>;
}

impl<T> AsyncLocal<T> for LocalKey<T>
where
  T: AsContext,
{
  #[cfg(all(not(loom), feature = "rt"))]
  #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
  fn with_blocking<F, R>(&'static self, f: F) -> JoinHandle<R>
  where
    F: for<'id> FnOnce(LocalRef<'id, T::Target>) -> R + Send + 'static,
    R: Send + 'static,
  {
    // SAFETY: the fabricated guard never leaves this function except through
    // `f`, which is higher-ranked over the brand and so cannot leak it.
    let guard = unsafe { Guard::new(Id::new()) };
    let local_ref = self.local_ref(guard);
    spawn_blocking(move || f(local_ref))
  }

  async fn with_async<F, R>(&'static self, mut f: F) -> R
  where
    F: for<'a> AsyncFnMut(LocalRef<'a, T::Target>) -> R,
  {
    make_guard!(guard);
    let local_ref = self.local_ref(guard);
    f(local_ref).await
  }

  #[track_caller]
  #[inline(always)]
  fn local_ref<'id>(&'static self, guard: Guard<'id>) -> LocalRef<'id, T::Target> {
    // Without `compat`, refuse to hand out references when no barrier context
    // is set for this thread or when running on a pool worker.
    #[cfg(not(feature = "compat"))]
    {
      if CONTEXT
        .with(|context| matches!(&*context.borrow(), None | Some(BarrierContext::PoolWorker)))
      {
        panic!(
          "LocalRef can only be created within the async context of a Tokio Runtime configured by `#[async_local::main]` or `#[async_local::test]`"
        );
      }
    }

    // SAFETY: NOTE(review): relies on the runtime configured by the
    // `async_local` macros keeping the thread-local alive while references
    // exist — confirm against the runtime module.
    self.with(|value| unsafe { LocalRef::new(value.as_ref(), guard) })
  }
}

#[cfg(test)]
mod tests {
  use std::sync::atomic::{AtomicUsize, Ordering};

  use generativity::make_guard;
  use tokio::task::yield_now;

  use super::*;

  thread_local! {
      static COUNTER: Context<AtomicUsize> = Context::new(AtomicUsize::new(0));
  }

  /// Exercises both `with_blocking` entry points: on the key and on a `LocalRef`.
  #[async_local::test]
  async fn with_blocking() {
    COUNTER
      .with_blocking(|counter| counter.fetch_add(1, Ordering::Relaxed))
      .await
      .unwrap();

    make_guard!(guard);
    let local_ref = COUNTER.local_ref(guard);

    local_ref
      .with_blocking(|counter| counter.fetch_add(1, Ordering::Relaxed))
      .await
      .unwrap();
  }

  /// A `LocalRef` may be held across an `.await` point.
  #[async_local::test]
  async fn ref_spans_await() {
    make_guard!(guard);
    let counter = COUNTER.local_ref(guard);
    yield_now().await;
    counter.fetch_add(1, Ordering::SeqCst);
  }

  /// A `LocalRef` may be passed into an async trait method.
  #[async_local::test]
  async fn with_async_trait() {
    struct Counter;

    trait Countable {
      async fn add_one(ref_guard: LocalRef<'_, AtomicUsize>) -> usize;
    }

    impl Countable for Counter {
      async fn add_one(counter: LocalRef<'_, AtomicUsize>) -> usize {
        yield_now().await;
        counter.fetch_add(1, Ordering::Release)
      }
    }

    make_guard!(guard);
    let counter = COUNTER.local_ref(guard);

    Counter::add_one(counter).await;
  }
}

// --------------------------------------------------------------------------------
// /derive-async-local/src/entry.rs:
// --------------------------------------------------------------------------------
/// This module is a modified version from Tokio
///
/// See https://docs.rs/tokio-macros/2.5.0/src/tokio_macros/entry.rs.html
use proc_macro2::{Span, TokenStream, TokenTree};
use
quote::{ToTokens, quote, quote_spanned}; 6 | use syn::{ 7 | Attribute, Ident, PatType, PathSegment, Signature, Visibility, braced, 8 | parse::{Parse, ParseStream, Parser}, 9 | spanned::Spanned, 10 | }; 11 | 12 | // syn::AttributeArgs does not implement syn::Parse 13 | type AttributeArgs = syn::punctuated::Punctuated; 14 | 15 | #[derive(Clone, Copy, PartialEq)] 16 | enum RuntimeFlavor { 17 | CurrentThread, 18 | Threaded, 19 | } 20 | 21 | impl RuntimeFlavor { 22 | fn from_str(s: &str) -> Result { 23 | match s { 24 | "current_thread" => Ok(RuntimeFlavor::CurrentThread), 25 | "multi_thread" => Ok(RuntimeFlavor::Threaded), 26 | "single_thread" => { 27 | Err("The single threaded runtime flavor is called `current_thread`.".to_string()) 28 | } 29 | "basic_scheduler" => Err( 30 | "The `basic_scheduler` runtime flavor has been renamed to `current_thread`.".to_string(), 31 | ), 32 | "threaded_scheduler" => Err( 33 | "The `threaded_scheduler` runtime flavor has been renamed to `multi_thread`.".to_string(), 34 | ), 35 | _ => Err(format!( 36 | "No such runtime flavor `{s}`. The runtime flavors are `current_thread` and `multi_thread`." 
37 | )), 38 | } 39 | } 40 | } 41 | 42 | struct FinalConfig { 43 | flavor: RuntimeFlavor, 44 | worker_threads: Option, 45 | start_paused: Option, 46 | borrow_runtime: Option, 47 | } 48 | 49 | impl FinalConfig { 50 | /// Config used in case of the attribute not being able to build a valid config 51 | fn error_config(input: &ItemFn) -> Self { 52 | let mut config = FinalConfig { 53 | flavor: RuntimeFlavor::CurrentThread, 54 | worker_threads: None, 55 | start_paused: None, 56 | borrow_runtime: None, 57 | }; 58 | 59 | if let Ok(Some(ident)) = get_runtime_ident(&input, false) { 60 | config.borrow_runtime = Some(ident.to_owned()); 61 | } 62 | 63 | config 64 | } 65 | } 66 | 67 | struct Configuration { 68 | rt_multi_thread_available: bool, 69 | default_flavor: RuntimeFlavor, 70 | flavor: Option, 71 | worker_threads: Option<(usize, Span)>, 72 | start_paused: Option<(bool, Span)>, 73 | borrow_runtime: Option, 74 | is_test: bool, 75 | } 76 | 77 | impl Configuration { 78 | fn new(is_test: bool, rt_multi_thread: bool) -> Self { 79 | Configuration { 80 | rt_multi_thread_available: rt_multi_thread, 81 | default_flavor: match is_test { 82 | true => RuntimeFlavor::CurrentThread, 83 | false => RuntimeFlavor::Threaded, 84 | }, 85 | flavor: None, 86 | worker_threads: None, 87 | start_paused: None, 88 | borrow_runtime: None, 89 | is_test, 90 | } 91 | } 92 | 93 | fn set_flavor(&mut self, runtime: syn::Lit, span: Span) -> Result<(), syn::Error> { 94 | if self.flavor.is_some() { 95 | return Err(syn::Error::new(span, "`flavor` set multiple times.")); 96 | } 97 | 98 | let runtime_str = parse_string(runtime, span, "flavor")?; 99 | let runtime = 100 | RuntimeFlavor::from_str(&runtime_str).map_err(|err| syn::Error::new(span, err))?; 101 | self.flavor = Some(runtime); 102 | Ok(()) 103 | } 104 | 105 | fn set_worker_threads(&mut self, worker_threads: syn::Lit, span: Span) -> Result<(), syn::Error> { 106 | if self.worker_threads.is_some() { 107 | return Err(syn::Error::new( 108 | span, 109 | 
"`worker_threads` set multiple times.", 110 | )); 111 | } 112 | 113 | let worker_threads = parse_int(worker_threads, span, "worker_threads")?; 114 | if worker_threads == 0 { 115 | return Err(syn::Error::new(span, "`worker_threads` may not be 0.")); 116 | } 117 | self.worker_threads = Some((worker_threads, span)); 118 | Ok(()) 119 | } 120 | 121 | fn set_start_paused(&mut self, start_paused: syn::Lit, span: Span) -> Result<(), syn::Error> { 122 | if self.start_paused.is_some() { 123 | return Err(syn::Error::new(span, "`start_paused` set multiple times.")); 124 | } 125 | 126 | let start_paused = parse_bool(start_paused, span, "start_paused")?; 127 | self.start_paused = Some((start_paused, span)); 128 | Ok(()) 129 | } 130 | 131 | fn set_borrow_runtime(&mut self, pat_type: &PatType) -> Result<(), syn::Error> { 132 | if self.borrow_runtime.is_some() { 133 | return Err(syn::Error::new( 134 | pat_type.span(), 135 | "attempted to borrow runtime multiple times.", 136 | )); 137 | } 138 | 139 | self.borrow_runtime = Some(pat_type.to_owned()); 140 | Ok(()) 141 | } 142 | 143 | fn macro_name(&self) -> &'static str { 144 | if self.is_test { 145 | "async_local::test" 146 | } else { 147 | "async_local::main" 148 | } 149 | } 150 | 151 | fn build(&self) -> Result { 152 | use RuntimeFlavor as F; 153 | 154 | let flavor = self.flavor.unwrap_or(self.default_flavor); 155 | let worker_threads = match (flavor, self.worker_threads) { 156 | (F::CurrentThread, Some((_, worker_threads_span))) => { 157 | let msg = format!( 158 | "The `worker_threads` option requires the `multi_thread` runtime flavor. 
Use `#[{}(flavor = \"multi_thread\")]`", 159 | self.macro_name(), 160 | ); 161 | return Err(syn::Error::new(worker_threads_span, msg)); 162 | } 163 | (F::CurrentThread, None) => None, 164 | (F::Threaded, worker_threads) if self.rt_multi_thread_available => { 165 | worker_threads.map(|(val, _span)| val) 166 | } 167 | (F::Threaded, _) => { 168 | let msg = if self.flavor.is_none() { 169 | "The default runtime flavor is `multi_thread`, but the `rt-multi-thread` feature is disabled." 170 | } else { 171 | "The runtime flavor `multi_thread` requires the `rt-multi-thread` feature." 172 | }; 173 | return Err(syn::Error::new(Span::call_site(), msg)); 174 | } 175 | }; 176 | 177 | let start_paused = match (flavor, self.start_paused) { 178 | (F::Threaded, Some((_, start_paused_span))) => { 179 | let msg = format!( 180 | "The `start_paused` option requires the `current_thread` runtime flavor. Use `#[{}(flavor = \"current_thread\")]`", 181 | self.macro_name(), 182 | ); 183 | return Err(syn::Error::new(start_paused_span, msg)); 184 | } 185 | (F::CurrentThread, Some((start_paused, _))) => Some(start_paused), 186 | (_, None) => None, 187 | }; 188 | 189 | let borrow_runtime = self.borrow_runtime.clone(); 190 | 191 | Ok(FinalConfig { 192 | flavor, 193 | worker_threads, 194 | start_paused, 195 | borrow_runtime, 196 | }) 197 | } 198 | } 199 | 200 | fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result { 201 | match int { 202 | syn::Lit::Int(lit) => match lit.base10_parse::() { 203 | Ok(value) => Ok(value), 204 | Err(e) => Err(syn::Error::new( 205 | span, 206 | format!("Failed to parse value of `{field}` as integer: {e}"), 207 | )), 208 | }, 209 | _ => Err(syn::Error::new( 210 | span, 211 | format!("Failed to parse value of `{field}` as integer."), 212 | )), 213 | } 214 | } 215 | 216 | fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result { 217 | match int { 218 | syn::Lit::Str(s) => Ok(s.value()), 219 | syn::Lit::Verbatim(s) => Ok(s.to_string()), 220 | _ => 
Err(syn::Error::new( 221 | span, 222 | format!("Failed to parse value of `{field}` as string."), 223 | )), 224 | } 225 | } 226 | 227 | fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result { 228 | match bool { 229 | syn::Lit::Bool(b) => Ok(b.value), 230 | _ => Err(syn::Error::new( 231 | span, 232 | format!("Failed to parse value of `{field}` as bool."), 233 | )), 234 | } 235 | } 236 | 237 | fn build_config( 238 | input: &ItemFn, 239 | args: AttributeArgs, 240 | is_test: bool, 241 | rt_multi_thread: bool, 242 | ) -> Result { 243 | let mut config = Configuration::new(is_test, rt_multi_thread); 244 | let macro_name = config.macro_name(); 245 | 246 | for arg in args { 247 | match arg { 248 | syn::Meta::NameValue(namevalue) => { 249 | let ident = namevalue 250 | .path 251 | .get_ident() 252 | .ok_or_else(|| syn::Error::new_spanned(&namevalue, "Must have specified ident"))? 253 | .to_string() 254 | .to_lowercase(); 255 | let lit = match &namevalue.value { 256 | syn::Expr::Lit(syn::ExprLit { lit, .. 
}) => lit, 257 | expr => return Err(syn::Error::new_spanned(expr, "Must be a literal")), 258 | }; 259 | match ident.as_str() { 260 | "worker_threads" => { 261 | config.set_worker_threads(lit.clone(), syn::spanned::Spanned::span(lit))?; 262 | } 263 | "flavor" => { 264 | config.set_flavor(lit.clone(), syn::spanned::Spanned::span(lit))?; 265 | } 266 | "start_paused" => { 267 | config.set_start_paused(lit.clone(), syn::spanned::Spanned::span(lit))?; 268 | } 269 | "core_threads" => { 270 | let msg = "Attribute `core_threads` is renamed to `worker_threads`"; 271 | return Err(syn::Error::new_spanned(namevalue, msg)); 272 | } 273 | name => { 274 | let msg = format!( 275 | "Unknown attribute {name} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`", 276 | ); 277 | return Err(syn::Error::new_spanned(namevalue, msg)); 278 | } 279 | } 280 | } 281 | syn::Meta::Path(path) => { 282 | let name = path 283 | .get_ident() 284 | .ok_or_else(|| syn::Error::new_spanned(&path, "Must have specified ident"))? 285 | .to_string() 286 | .to_lowercase(); 287 | let msg = match name.as_str() { 288 | "threaded_scheduler" | "multi_thread" => { 289 | format!("Set the runtime flavor with #[{macro_name}(flavor = \"multi_thread\")].") 290 | } 291 | "basic_scheduler" | "current_thread" | "single_threaded" => { 292 | format!("Set the runtime flavor with #[{macro_name}(flavor = \"current_thread\")].") 293 | } 294 | "flavor" | "worker_threads" | "start_paused" => { 295 | format!("The `{name}` attribute requires an argument.") 296 | } 297 | name => { 298 | format!( 299 | "Unknown attribute {name} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`." 
300 | ) 301 | } 302 | }; 303 | return Err(syn::Error::new_spanned(path, msg)); 304 | } 305 | other => { 306 | return Err(syn::Error::new_spanned( 307 | other, 308 | "Unknown attribute inside the macro", 309 | )); 310 | } 311 | } 312 | } 313 | 314 | match (get_runtime_ident(&input, true)?, &input.sig.asyncness) { 315 | (Some(pat), None) => { 316 | config.set_borrow_runtime(pat)?; 317 | } 318 | (Some(_), Some(token)) => { 319 | return Err(syn::Error::new( 320 | token.span(), 321 | "the `async` keyword cannot by used while borrowing the runtime", 322 | )); 323 | } 324 | (None, None) => { 325 | return Err(syn::Error::new_spanned( 326 | input.sig.fn_token, 327 | "the `async` keyword is missing from the function declaration", 328 | )); 329 | } 330 | _ => {} 331 | } 332 | 333 | config.build() 334 | } 335 | 336 | fn get_runtime_ident(input: &ItemFn, strict: bool) -> Result, syn::Error> { 337 | let inputs = input.sig.inputs.iter().map(|fn_arg| 338 | match fn_arg { 339 | syn::FnArg::Receiver(receiver) => Err(syn::Error::new( 340 | receiver.span(), 341 | "function cannot have receiver", 342 | )), 343 | syn::FnArg::Typed(pat_type) => { 344 | if let syn::Type::Reference(type_reference) = pat_type.ty.as_ref() { 345 | if let syn::Type::Path(type_path) = type_reference.elem.as_ref() { 346 | let segments: Vec<&PathSegment> = type_path.path.segments.iter().collect(); 347 | 348 | let runtime_segment = match segments.as_slice() { 349 | &[type_segment] if type_segment.ident.eq("Runtime") => type_segment, 350 | &[module, type_segment] 351 | if module.ident.eq("runtime") && type_segment.ident.eq(&"Runtime") => 352 | { 353 | type_segment 354 | } 355 | &[crate_path, module, type_segment] 356 | if crate_path.ident.eq("tokio") 357 | && module.ident.eq("runtime") 358 | && type_segment.ident.eq("Runtime") => 359 | { 360 | type_segment 361 | } 362 | _ => { 363 | return Err(syn::Error::new( 364 | pat_type.span(), 365 | "unsupported argument type specified", 366 | )); 367 | } 368 | }; 369 | 370 | 
return match &runtime_segment.arguments { 371 | syn::PathArguments::None => Ok((pat_type, pat_type.span())), 372 | syn::PathArguments::AngleBracketed(angle_bracketed_generic_arguments) => { 373 | let arguments_len = angle_bracketed_generic_arguments.args.len(); 374 | 375 | if arguments_len.eq(&1) { 376 | Err(syn::Error::new(pat_type.span(), format!("Runtime takes 0 generic arguments but 1 generic argument was supplied"))) 377 | } else { 378 | Err(syn::Error::new(pat_type.span(), format!("Runtime takes 0 generic arguments but {arguments_len} generic arguments were supplied"))) 379 | } 380 | }, 381 | syn::PathArguments::Parenthesized(_) => { 382 | Err(syn::Error::new(pat_type.span(), format!("Runtime cannot have parenthesized type parameters"))) 383 | }, 384 | } 385 | }; 386 | } 387 | 388 | Err(syn::Error::new( 389 | fn_arg.span(), 390 | "unsupported argument type specified", 391 | )) 392 | } 393 | } 394 | ); 395 | 396 | let mut runtime_pat = None; 397 | let mut error: Option = None; 398 | 399 | for result in inputs { 400 | match result { 401 | Ok((pat, span)) => { 402 | if runtime_pat.is_some() { 403 | let err = syn::Error::new(span, "attempted to borrow runtime multiple times"); 404 | 405 | if let Some(error) = &mut error { 406 | error.combine(err); 407 | } else { 408 | error = Some(err); 409 | } 410 | } else { 411 | if strict { 412 | runtime_pat = Some(pat); 413 | } else { 414 | return Ok(Some(pat)); 415 | } 416 | } 417 | } 418 | Err(err) => { 419 | if let Some(error) = &mut error { 420 | error.combine(err); 421 | } else { 422 | error = Some(err); 423 | } 424 | } 425 | } 426 | } 427 | 428 | if let Some(err) = error { 429 | Err(err) 430 | } else { 431 | Ok(runtime_pat) 432 | } 433 | } 434 | 435 | fn parse_knobs(mut input: ItemFn, is_test: bool, config: FinalConfig) -> TokenStream { 436 | input.sig.asyncness = None; 437 | input.sig.inputs.clear(); 438 | 439 | // If type mismatch occurs, the current rustc points to the last statement. 
440 | let (last_stmt_start_span, last_stmt_end_span) = { 441 | let mut last_stmt = input.stmts.last().cloned().unwrap_or_default().into_iter(); 442 | 443 | // `Span` on stable Rust has a limitation that only points to the first 444 | // token, not the whole tokens. We can work around this limitation by 445 | // using the first/last span of the tokens like 446 | // `syn::Error::new_spanned` does. 447 | let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span()); 448 | let end = last_stmt.last().map_or(start, |t| t.span()); 449 | (start, end) 450 | }; 451 | 452 | let crate_path = Ident::new("async_local", last_stmt_start_span).into_token_stream(); 453 | 454 | let mut rt = match config.flavor { 455 | RuntimeFlavor::CurrentThread => quote_spanned! {last_stmt_start_span=> 456 | #crate_path::__runtime::Builder::new_current_thread() 457 | }, 458 | RuntimeFlavor::Threaded => quote_spanned! {last_stmt_start_span=> 459 | #crate_path::__runtime::Builder::new_multi_thread() 460 | }, 461 | }; 462 | if let Some(v) = config.worker_threads { 463 | rt = quote_spanned! {last_stmt_start_span=> #rt.worker_threads(#v) }; 464 | } 465 | if let Some(v) = config.start_paused { 466 | rt = quote_spanned! {last_stmt_start_span=> #rt.start_paused(#v) }; 467 | } 468 | 469 | let generated_attrs = if is_test { 470 | quote! { 471 | #[::core::prelude::v1::test] 472 | } 473 | } else { 474 | quote! {} 475 | }; 476 | 477 | let ensure_configured = if !is_test { 478 | quote_spanned! { last_stmt_start_span => 479 | if module_path!().contains("::") { 480 | panic!("#[async_local::main] can only be used on the crate root main function"); 481 | } 482 | 483 | #[async_local::linkme::distributed_slice(async_local::__runtime::RUNTIMES)] 484 | #[linkme(crate = async_local::linkme)] 485 | static RUNTIME_CONTEXT: async_local::__runtime::RuntimeContext = async_local::__runtime::RuntimeContext::Main; 486 | } 487 | } else { 488 | quote_spanned! 
{ last_stmt_start_span => 489 | #[async_local::linkme::distributed_slice(async_local::__runtime::RUNTIMES)] 490 | #[linkme(crate = async_local::linkme)] 491 | static RUNTIME_CONTEXT: async_local::__runtime::RuntimeContext = async_local::__runtime::RuntimeContext::Test; 492 | } 493 | }; 494 | 495 | let body_ident = quote! { body }; 496 | 497 | let run = if config.borrow_runtime.is_some() { 498 | quote! { 499 | return unsafe { 500 | runtime.run(#body_ident) 501 | }; 502 | } 503 | } else { 504 | quote! { 505 | return unsafe { 506 | runtime.block_on(#body_ident) 507 | }; 508 | } 509 | }; 510 | 511 | // This explicit `return` is intentional. See tokio-rs/tokio#4636 512 | let last_block = quote_spanned! {last_stmt_end_span=> 513 | #[allow(clippy::expect_used, clippy::diverging_sub_expression, clippy::needless_return)] 514 | { 515 | #ensure_configured 516 | 517 | let runtime = #rt 518 | .enable_all() 519 | .build() 520 | .expect("Failed building the Runtime"); 521 | 522 | #run 523 | } 524 | }; 525 | 526 | let body = input.body(); 527 | 528 | let body = if let Some(ident) = config.borrow_runtime { 529 | quote! { 530 | let body = |#ident| #body; 531 | } 532 | } 533 | // For test functions pin the body to the stack and use `Pin<&mut dyn 534 | // Future>` to reduce the amount of `Runtime::block_on` (and related 535 | // functions) copies we generate during compilation due to the generic 536 | // parameter `F` (the future to block on). This could have an impact on 537 | // performance, but because it's only for testing it's unlikely to be very 538 | // large. 539 | // 540 | // We don't do this for the main function as it should only be used once so 541 | // there will be no benefit. 542 | else if is_test { 543 | let output_type = match &input.sig.output { 544 | // For functions with no return value syn doesn't print anything, 545 | // but that doesn't work as `Output` for our boxed `Future`, so 546 | // default to `()` (the same type as the function output). 
547 | syn::ReturnType::Default => quote! { () }, 548 | syn::ReturnType::Type(_, ret_type) => quote! { #ret_type }, 549 | }; 550 | quote! { 551 | let body = async #body; 552 | #crate_path::pin!(body); 553 | let body: ::core::pin::Pin<&mut dyn ::core::future::Future> = body; 554 | } 555 | } else { 556 | quote! { 557 | let body = async #body; 558 | } 559 | }; 560 | 561 | input.into_tokens(generated_attrs, body, last_block) 562 | } 563 | 564 | fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream { 565 | tokens.extend(error.into_compile_error()); 566 | tokens 567 | } 568 | 569 | pub(crate) fn main(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream { 570 | // If any of the steps for this macro fail, we still want to expand to an item that is as close 571 | // to the expected output as possible. This helps out IDEs such that completions and other 572 | // related features keep working. 573 | let input: ItemFn = match syn::parse2(item.clone()) { 574 | Ok(it) => it, 575 | Err(e) => return token_stream_with_error(item, e), 576 | }; 577 | 578 | let config = if input.sig.ident != "main" { 579 | Err(syn::Error::new_spanned( 580 | &input.sig.ident, 581 | "macro can only be used on the root main function", 582 | )) 583 | } else { 584 | AttributeArgs::parse_terminated 585 | .parse2(args) 586 | .and_then(|args| build_config(&input, args, false, rt_multi_thread)) 587 | }; 588 | 589 | match config { 590 | Ok(config) => parse_knobs(input, false, config), 591 | Err(e) => { 592 | let config = FinalConfig::error_config(&input); 593 | token_stream_with_error(parse_knobs(input, false, config), e) 594 | } 595 | } 596 | } 597 | 598 | // Check whether given attribute is a test attribute of forms: 599 | // * `#[test]` 600 | // * `#[core::prelude::*::test]` or `#[::core::prelude::*::test]` 601 | // * `#[std::prelude::*::test]` or `#[::std::prelude::*::test]` 602 | fn is_test_attribute(attr: &Attribute) -> bool { 603 | let path = match 
&attr.meta { 604 | syn::Meta::Path(path) => path, 605 | _ => return false, 606 | }; 607 | let candidates = [ 608 | ["core", "prelude", "*", "test"], 609 | ["std", "prelude", "*", "test"], 610 | ]; 611 | if path.leading_colon.is_none() 612 | && path.segments.len() == 1 613 | && path.segments[0].arguments.is_none() 614 | && path.segments[0].ident == "test" 615 | { 616 | return true; 617 | } else if path.segments.len() != candidates[0].len() { 618 | return false; 619 | } 620 | candidates.into_iter().any(|segments| { 621 | path 622 | .segments 623 | .iter() 624 | .zip(segments) 625 | .all(|(segment, path)| segment.arguments.is_none() && (path == "*" || segment.ident == path)) 626 | }) 627 | } 628 | 629 | pub(crate) fn test(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream { 630 | // If any of the steps for this macro fail, we still want to expand to an item that is as close 631 | // to the expected output as possible. This helps out IDEs such that completions and other 632 | // related features keep working. 
633 | let input: ItemFn = match syn::parse2(item.clone()) { 634 | Ok(it) => it, 635 | Err(e) => return token_stream_with_error(item, e), 636 | }; 637 | let config = if let Some(attr) = input.attrs().find(|attr| is_test_attribute(attr)) { 638 | let msg = "second test attribute is supplied, consider removing or changing the order of your test attributes"; 639 | Err(syn::Error::new_spanned(attr, msg)) 640 | } else { 641 | AttributeArgs::parse_terminated 642 | .parse2(args) 643 | .and_then(|args| build_config(&input, args, true, rt_multi_thread)) 644 | }; 645 | 646 | match config { 647 | Ok(config) => parse_knobs(input, true, config), 648 | Err(e) => { 649 | let config = FinalConfig::error_config(&input); 650 | token_stream_with_error(parse_knobs(input, true, config), e) 651 | } 652 | } 653 | } 654 | 655 | struct ItemFn { 656 | outer_attrs: Vec, 657 | vis: Visibility, 658 | sig: Signature, 659 | brace_token: syn::token::Brace, 660 | inner_attrs: Vec, 661 | stmts: Vec, 662 | } 663 | 664 | impl ItemFn { 665 | /// Access all attributes of the function item. 666 | fn attrs(&self) -> impl Iterator { 667 | self.outer_attrs.iter().chain(self.inner_attrs.iter()) 668 | } 669 | 670 | /// Get the body of the function item in a manner so that it can be 671 | /// conveniently used with the `quote!` macro. 672 | fn body(&self) -> Body<'_> { 673 | Body { 674 | brace_token: self.brace_token, 675 | stmts: &self.stmts, 676 | } 677 | } 678 | 679 | /// Convert our local function item into a token stream. 680 | fn into_tokens( 681 | self, 682 | generated_attrs: proc_macro2::TokenStream, 683 | body: proc_macro2::TokenStream, 684 | last_block: proc_macro2::TokenStream, 685 | ) -> TokenStream { 686 | let mut tokens = proc_macro2::TokenStream::new(); 687 | // Outer attributes are simply streamed as-is. 
688 | for attr in self.outer_attrs { 689 | attr.to_tokens(&mut tokens); 690 | } 691 | 692 | // Inner attributes require extra care, since they're not supported on 693 | // blocks (which is what we're expanded into) we instead lift them 694 | // outside of the function. This matches the behavior of `syn`. 695 | for mut attr in self.inner_attrs { 696 | attr.style = syn::AttrStyle::Outer; 697 | attr.to_tokens(&mut tokens); 698 | } 699 | 700 | // Add generated macros at the end, so macros processed later are aware of them. 701 | generated_attrs.to_tokens(&mut tokens); 702 | 703 | self.vis.to_tokens(&mut tokens); 704 | self.sig.to_tokens(&mut tokens); 705 | 706 | self.brace_token.surround(&mut tokens, |tokens| { 707 | body.to_tokens(tokens); 708 | last_block.to_tokens(tokens); 709 | }); 710 | 711 | tokens 712 | } 713 | } 714 | 715 | impl Parse for ItemFn { 716 | #[inline] 717 | fn parse(input: ParseStream<'_>) -> syn::Result { 718 | // This parse implementation has been largely lifted from `syn`, with 719 | // the exception of: 720 | // * We don't have access to the plumbing necessary to parse inner 721 | // attributes in-place. 722 | // * We do our own statements parsing to avoid recursively parsing 723 | // entire statements and only look for the parts we're interested in. 724 | 725 | let outer_attrs = input.call(Attribute::parse_outer)?; 726 | let vis: Visibility = input.parse()?; 727 | let sig: Signature = input.parse()?; 728 | 729 | let content; 730 | let brace_token = braced!(content in input); 731 | let inner_attrs = Attribute::parse_inner(&content)?; 732 | 733 | let mut buf = proc_macro2::TokenStream::new(); 734 | let mut stmts = Vec::new(); 735 | 736 | while !content.is_empty() { 737 | if let Some(semi) = content.parse::>()? { 738 | semi.to_tokens(&mut buf); 739 | stmts.push(buf); 740 | buf = proc_macro2::TokenStream::new(); 741 | continue; 742 | } 743 | 744 | // Parse a single token tree and extend our current buffer with it. 
745 | // This avoids parsing the entire content of the sub-tree. 746 | buf.extend([content.parse::()?]); 747 | } 748 | 749 | if !buf.is_empty() { 750 | stmts.push(buf); 751 | } 752 | 753 | Ok(Self { 754 | outer_attrs, 755 | vis, 756 | sig, 757 | brace_token, 758 | inner_attrs, 759 | stmts, 760 | }) 761 | } 762 | } 763 | 764 | struct Body<'a> { 765 | brace_token: syn::token::Brace, 766 | // Statements, with terminating `;`. 767 | stmts: &'a [TokenStream], 768 | } 769 | 770 | impl ToTokens for Body<'_> { 771 | fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { 772 | self.brace_token.surround(tokens, |tokens| { 773 | for stmt in self.stmts { 774 | stmt.to_tokens(tokens); 775 | } 776 | }); 777 | } 778 | } 779 | --------------------------------------------------------------------------------