├── .gitignore ├── Cargo.toml ├── async_dag_tools ├── Cargo.toml └── src │ └── bin │ ├── pre-commit.rs │ └── install-pre-commit-hook.rs ├── async_dag ├── Cargo.toml ├── examples │ ├── fib.rs │ └── tree.rs └── src │ ├── graph │ ├── infallible.rs │ ├── error.rs │ └── runner.rs │ ├── any.rs │ ├── curry.rs │ ├── task │ └── infallible.rs │ ├── task.rs │ ├── lib.rs │ ├── tuple.rs │ └── graph.rs └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = [ 3 | "async_dag", 4 | "async_dag_tools" 5 | ] 6 | -------------------------------------------------------------------------------- /async_dag_tools/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "async_dag_tools" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | anyhow = "1" 10 | -------------------------------------------------------------------------------- /async_dag/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "async_dag" 3 | version = "0.1.2" 4 | edition = "2021" 5 | description = "An async task scheduling utility."
6 | license = "MIT OR Apache-2.0" 7 | repository = "https://github.com/chubei-oppen/async_dag" 8 | keywords = ["async", "dag", "scheduling"] 9 | categories = ["algorithms", "asynchronous", "concurrency"] 10 | readme = "../README.md" 11 | 12 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 13 | 14 | [dependencies] 15 | daggy = "0.8.0" 16 | dyn-clone = "1.0.5" 17 | futures = "0.3.21" 18 | seq-macro = "0.3.0" 19 | -------------------------------------------------------------------------------- /async_dag_tools/src/bin/pre-commit.rs: -------------------------------------------------------------------------------- 1 | use std::process::Command; 2 | 3 | use anyhow::*; 4 | 5 | fn run(command: &mut Command) -> Result<()> { 6 | let status = command.status()?; 7 | if !status.success() { 8 | bail!("{:?} failed with status {}", command, status); 9 | } 10 | Ok(()) 11 | } 12 | 13 | fn main() -> Result<()> { 14 | run(Command::new("cargo").args(["fmt", "--check"]))?; 15 | run(Command::new("cargo").args(["clippy", "--", "-D", "warnings"]))?; 16 | run(Command::new("cargo").args(["doc"]))?; 17 | run(Command::new("cargo").args(["test"]))?; 18 | run(Command::new("cargo") 19 | .args(["sync-readme", "-c"]) 20 | .current_dir("async_dag"))?; 21 | Ok(()) 22 | } 23 | -------------------------------------------------------------------------------- /async_dag/examples/fib.rs: -------------------------------------------------------------------------------- 1 | use async_dag::*; 2 | use futures::executor::block_on; 3 | 4 | const N: usize = 44; 5 | 6 | async fn sum(lhs: i32, rhs: i32) -> i32 { 7 | lhs + rhs 8 | } 9 | 10 | fn main() { 11 | let mut graph = Graph::new(); 12 | let mut first = graph.add_task(|| async { 1 }); 13 | let mut second = graph.add_task(|| async { 1 }); 14 | for _ in 0..N { 15 | let next = graph.add_child_task(first, sum, 0).unwrap(); 16 | graph.update_dependency(second, next, 1).unwrap(); 17 | 18 | first = second; 19 | second
= next; 20 | } 21 | block_on(graph.run()); 22 | let result = graph.get_value::(second).unwrap(); 23 | println!("The {}th fibonacci number is {}", N + 2, result); 24 | } 25 | -------------------------------------------------------------------------------- /async_dag/examples/tree.rs: -------------------------------------------------------------------------------- 1 | use async_dag::*; 2 | use futures::executor::block_on; 3 | 4 | async fn sum(lhs: i32, rhs: i32) -> i32 { 5 | lhs + rhs 6 | } 7 | 8 | fn add_parent_task(graph: &mut Graph, depth: u8, child: NodeIndex) { 9 | if depth == 0 { 10 | graph.add_parent_task(|| async { 1i32 }, child, 0).unwrap(); 11 | graph.add_parent_task(|| async { 1i32 }, child, 1).unwrap(); 12 | } else { 13 | let lhs = graph.add_parent_task(sum, child, 0).unwrap(); 14 | add_parent_task(graph, depth - 1, lhs); 15 | let rhs = graph.add_parent_task(sum, child, 1).unwrap(); 16 | add_parent_task(graph, depth - 1, rhs); 17 | } 18 | } 19 | 20 | fn main() { 21 | let mut graph = Graph::new(); 22 | let root = graph.add_task(sum); 23 | add_parent_task(&mut graph, 10, root); 24 | block_on(graph.run()); 25 | let result = graph.get_value::(root).unwrap(); 26 | println!("Result: {}", result); 27 | } 28 | -------------------------------------------------------------------------------- /async_dag_tools/src/bin/install-pre-commit-hook.rs: -------------------------------------------------------------------------------- 1 | use anyhow::*; 2 | use std::{ffi::OsString, fs::copy, path::PathBuf, process::Command}; 3 | 4 | fn main() -> Result<()> { 5 | let mut result_path: OsString = "./.git/hooks/pre-commit".into(); 6 | if cfg!(windows) { 7 | result_path.push(".exe"); 8 | } 9 | let result_path: PathBuf = result_path.into(); 10 | 11 | if !result_path.exists() { 12 | let target_name = "pre-commit"; 13 | let mut command = Command::new("cargo"); 14 | command.args(["build", "--bin", target_name, "--release"]); 15 | let status = command.status()?; 16 | if !status.success()
{ 17 | bail!("{:?} failed with status {}", command, status); 18 | } 19 | 20 | let mut executable_file_name: String = target_name.into(); 21 | if cfg!(windows) { 22 | executable_file_name.push_str(".exe"); 23 | } 24 | 25 | copy( 26 | format!("./target/release/{}", executable_file_name), 27 | result_path, 28 | )?; 29 | } else { 30 | bail!("{:?} already exists", result_path); 31 | } 32 | Ok(()) 33 | } 34 | -------------------------------------------------------------------------------- /async_dag/src/graph/infallible.rs: -------------------------------------------------------------------------------- 1 | use super::Edge; 2 | use super::NodeIndex; 3 | use super::TryGraph; 4 | use crate::any::IntoAny; 5 | use crate::error::ErrorWithTask; 6 | use crate::task::IntoInfallibleTask; 7 | use std::convert::Infallible; 8 | 9 | /// A [`TryGraph`] with infallible tasks. 10 | pub type Graph<'a> = TryGraph<'a, Infallible>; 11 | 12 | impl<'a> Graph<'a> { 13 | /// Adds an infallible task. See [`TryGraph::add_try_task`]. 14 | pub fn add_task>( 15 | &mut self, 16 | task: T, 17 | ) -> NodeIndex { 18 | self.add_task_impl(task.into_task()) 19 | } 20 | 21 | /// Adds an infallible task and sets it as `child`'s dependency at `index`. 22 | /// 23 | /// See [`TryGraph::add_parent_try_task`]. 24 | pub fn add_parent_task>( 25 | &mut self, 26 | task: T, 27 | child: NodeIndex, 28 | index: Edge, 29 | ) -> Result> { 30 | self.add_parent_task_impl::(task.into_task(), child, index) 31 | } 32 | 33 | /// Adds an infallible task and sets its dependency at `index` to `parent`. 34 | /// 35 | /// See [`TryGraph::add_child_try_task`]. 36 | pub fn add_child_task>( 37 | &mut self, 38 | parent: NodeIndex, 39 | task: T, 40 | index: Edge, 41 | ) -> Result> { 42 | self.add_child_task_impl::(parent, task.into_task(), index) 43 | } 44 | 45 | /// Infallible version of [`TryGraph::run`].
46 | pub async fn run(&mut self) { 47 | self.try_run().await.unwrap(); 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /async_dag/src/graph/error.rs: -------------------------------------------------------------------------------- 1 | //! The error types. 2 | 3 | use super::NodeIndex; 4 | use crate::any::TypeInfo; 5 | use crate::tuple::TupleIndex; 6 | 7 | /// Errors that can happen during graph construction. 8 | #[derive(Debug)] 9 | #[allow(variant_size_differences)] 10 | pub enum Error { 11 | /// The specified dependent node has started running its task and can't have its dependency modified. 12 | HasStarted(NodeIndex), 13 | /// The specified dependency index is greater than or equal to the dependent node's task's number of inputs. 14 | OutOfRange(TupleIndex), 15 | /// The dependent node's task has `input` type at specified index, but the depended-on node's task has a different `output` type. 16 | TypeMismatch { 17 | /// The input type for the child. 18 | input: TypeInfo, 19 | /// The output type from the parent. 20 | output: TypeInfo, 21 | }, 22 | /// Adding the specified dependency would have caused the graph to cycle. 23 | WouldCycle, 24 | } 25 | 26 | impl std::fmt::Display for Error { 27 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 28 | match self { 29 | Self::HasStarted(index) => f.debug_tuple("Error::HasStarted").field(index).finish(), 30 | Self::OutOfRange(index) => f.debug_tuple("Error::OutOfRange").field(index).finish(), 31 | Self::TypeMismatch { input, output } => f 32 | .debug_struct("Error::TypeMismatch") 33 | .field("input", input) 34 | .field("output", output) 35 | .finish(), 36 | Self::WouldCycle => f.debug_tuple("Error::WouldCycle").finish(), 37 | } 38 | } 39 | } 40 | 41 | impl std::error::Error for Error {} 42 | 43 | /// An [`Error`] and a [`TryTask`](crate::task::TryTask). 44 | #[derive(Debug)] 45 | pub struct ErrorWithTask { 46 | /// The error.
47 | pub error: Error, 48 | /// The task. 49 | pub task: T, 50 | } 51 | 52 | impl std::fmt::Display for ErrorWithTask { 53 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 54 | f.debug_struct("ErrorWithTask") 55 | .field("error", &self.error) 56 | .field("task", &self.task) 57 | .finish() 58 | } 59 | } 60 | 61 | impl std::error::Error for ErrorWithTask {} 62 | -------------------------------------------------------------------------------- /async_dag/src/any.rs: -------------------------------------------------------------------------------- 1 | use dyn_clone::DynClone; 2 | use std::{ 3 | any::{type_name, Any, TypeId}, 4 | hash::Hash, 5 | }; 6 | 7 | /// Conversion to [`Any`] to work around [#65991](https://github.com/rust-lang/rust/issues/65991). 8 | /// Implemented for anything that's `'static` and [`Clone`]. 9 | pub trait IntoAny: DynClone + Any { 10 | /// The conversion. 11 | fn into_any(self: Box) -> Box; 12 | } 13 | 14 | dyn_clone::clone_trait_object!(IntoAny); 15 | 16 | impl IntoAny for T { 17 | fn into_any(self: Box) -> Box { 18 | self 19 | } 20 | } 21 | 22 | pub fn downcast(value: Box) -> Result> { 23 | if (*value).type_id() != TypeId::of::() { 24 | return Err(value); 25 | } 26 | let value = value.into_any(); 27 | // We've checked the type id. 28 | Ok(*Box::::downcast::(value).unwrap()) 29 | } 30 | 31 | /// A [`TypeId`] and the type's name. 32 | #[derive(Debug, Clone, Copy)] 33 | pub struct TypeInfo { 34 | id: TypeId, 35 | name: &'static str, 36 | } 37 | 38 | impl TypeInfo { 39 | /// Gets the [`TypeId`]. 40 | pub fn id(&self) -> TypeId { 41 | self.id 42 | } 43 | 44 | /// Gets the type name. 45 | pub fn name(&self) -> &'static str { 46 | self.name 47 | } 48 | 49 | /// Returns the [`TypeInfo`] of the type this generic function has been 50 | /// instantiated with.
51 | pub fn of() -> Self { 52 | TypeInfo { 53 | id: TypeId::of::(), 54 | name: type_name::(), 55 | } 56 | } 57 | } 58 | 59 | impl Hash for TypeInfo { 60 | fn hash(&self, state: &mut H) { 61 | self.id.hash(state) 62 | } 63 | } 64 | 65 | impl PartialEq for TypeInfo { 66 | fn eq(&self, other: &TypeInfo) -> bool { 67 | self.id.eq(&other.id) 68 | } 69 | } 70 | 71 | impl Eq for TypeInfo {} 72 | 73 | impl PartialOrd for TypeInfo { 74 | fn partial_cmp(&self, other: &TypeInfo) -> Option { 75 | Some(self.cmp(other)) 76 | } 77 | } 78 | 79 | impl Ord for TypeInfo { 80 | fn cmp(&self, other: &Self) -> std::cmp::Ordering { 81 | self.id.cmp(&other.id) 82 | } 83 | } 84 | 85 | /// A [`Box`]ed [`IntoAny`]. 86 | pub type DynAny = Box; 87 | 88 | impl std::fmt::Debug for DynAny { 89 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 90 | f.debug_struct("DynAny").finish_non_exhaustive() 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /async_dag/src/curry.rs: -------------------------------------------------------------------------------- 1 | use crate::any::DynAny; 2 | use crate::any::IntoAny; 3 | use crate::any::TypeInfo; 4 | use crate::task::TryTask; 5 | use crate::tuple::InsertResult; 6 | use crate::tuple::TakeError; 7 | use crate::tuple::Tuple; 8 | use crate::tuple::TupleIndex; 9 | use crate::tuple::TupleOption; 10 | use futures::future::BoxFuture; 11 | use futures::FutureExt; 12 | use futures::TryFutureExt; 13 | 14 | pub type TaskFuture<'a, Err> = BoxFuture<'a, Result>; 15 | 16 | /// [`Curry`] describes the process of currying and finally calling. 17 | pub trait Curry<'a, Err> { 18 | /// The number of inputs of the original task. 19 | fn num_inputs(&self) -> TupleIndex; 20 | 21 | /// Returns the [`TypeInfo`] of the input at `index`, [`None`] if `index` is out of range. 22 | fn input_type_info(&self, index: TupleIndex) -> Option; 23 | 24 | /// Returns the [`TypeInfo`] of the successful output.
25 | fn output_type_info(&self) -> TypeInfo; 26 | 27 | /// Returns `true` if all of the inner task's inputs have been populated, i.e. it is ready to run. 28 | fn ready(&self) -> bool; 29 | 30 | /// Inserts an input into the inner task, i.e. currying. 31 | /// 32 | /// `self` is unchanged on error. 33 | fn curry(&mut self, index: TupleIndex, value: DynAny) -> InsertResult; 34 | 35 | /// Consumes the inner task and inputs and returns a future of the output value. 36 | fn call(self: Box) -> Result, TakeError>; 37 | } 38 | 39 | /// [`CurriedTask`] holds a task and its inputs and tracks if all inputs are ready. 40 | pub struct CurriedTask<'a, Err, T: TryTask<'a, Err = Err>> { 41 | task: T, 42 | inputs: ::Option, 43 | } 44 | 45 | impl<'a, Err, T: TryTask<'a, Err = Err>> CurriedTask<'a, Err, T> { 46 | /// Creates a [`CurriedTask`] from a task and no inputs. 47 | pub fn new(task: T) -> Self { 48 | CurriedTask { 49 | task, 50 | inputs: Default::default(), 51 | } 52 | } 53 | } 54 | 55 | fn make_any(t: T) -> DynAny { 56 | Box::new(t) 57 | } 58 | 59 | impl<'a, Err, T: TryTask<'a, Err = Err>> Curry<'a, Err> for CurriedTask<'a, Err, T> { 60 | fn num_inputs(&self) -> TupleIndex { 61 | T::Inputs::LEN 62 | } 63 | 64 | fn input_type_info(&self, index: TupleIndex) -> Option { 65 | T::Inputs::type_info(index) 66 | } 67 | 68 | fn output_type_info(&self) -> TypeInfo { 69 | TypeInfo::of::() 70 | } 71 | 72 | fn ready(&self) -> bool { 73 | self.inputs.first_none().is_none() 74 | } 75 | 76 | fn curry(&mut self, index: TupleIndex, value: DynAny) -> InsertResult { 77 | self.inputs.insert(index, value) 78 | } 79 | 80 | fn call(self: Box) -> Result, TakeError> { 81 | let CurriedTask { task, mut inputs } = *self; 82 | let inputs = inputs.take()?; 83 | let future = task.run(inputs); 84 | let future = future.map_ok(make_any); 85 | Ok(future.boxed()) 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /async_dag/src/task/infallible.rs:
-------------------------------------------------------------------------------- 1 | use super::TryTask; 2 | use crate::any::IntoAny; 3 | use futures::future::FutureExt; 4 | use futures::future::Map; 5 | use seq_macro::seq; 6 | use std::any::type_name; 7 | use std::convert::Infallible; 8 | use std::future::Future; 9 | use std::marker::PhantomData; 10 | 11 | /// Conversion to an [`Infallible`] [`TryTask`]. 12 | pub trait IntoInfallibleTask<'a, Args, Ok> { 13 | /// The [`TryTask`] type. 14 | type Task: TryTask<'a, Ok = Ok, Err = Infallible> + 'a; 15 | 16 | /// The conversion. 17 | fn into_task(self) -> Self::Task; 18 | } 19 | 20 | /// An [`Infallible`] [`TryTask`] for types that implement [`FnOnce`]. 21 | pub struct InfallibleFnOnceTask { 22 | function: Fn, 23 | ok: PhantomData, 24 | fut: PhantomData, 25 | args: PhantomData, 26 | } 27 | 28 | impl InfallibleFnOnceTask { 29 | fn new(function: Fn) -> Self { 30 | InfallibleFnOnceTask { 31 | function, 32 | ok: Default::default(), 33 | fut: Default::default(), 34 | args: Default::default(), 35 | } 36 | } 37 | } 38 | 39 | impl std::fmt::Debug for InfallibleFnOnceTask { 40 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 41 | f.write_str(&format!( 42 | "InfallibleFnOnceTask{} -> impl Future {{ ... }}", 43 | type_name::(), 44 | type_name::(), 45 | )) 46 | } 47 | } 48 | 49 | macro_rules!
task_impl { 50 | ($N:literal) => { 51 | seq!(i in 0..$N { 52 | impl<'a, Fn, Ok, Fut, #(I~i,)*> IntoInfallibleTask<'a, (#(I~i,)*), Ok> for Fn 53 | where 54 | Fn: FnOnce(#(I~i,)*) -> Fut + 'a, 55 | Ok: IntoAny, 56 | Fut: Future + Send + 'a, 57 | #( 58 | I~i: IntoAny, 59 | )* 60 | { 61 | type Task = InfallibleFnOnceTask; 62 | 63 | fn into_task(self) -> Self::Task { 64 | InfallibleFnOnceTask::new(self) 65 | } 66 | } 67 | 68 | impl<'a, Fn, Ok, Fut, #(I~i,)*> TryTask<'a> for InfallibleFnOnceTask 69 | where 70 | Fn: FnOnce(#(I~i,)*) -> Fut, 71 | Ok: IntoAny, 72 | Fut: Future + Send + 'a, 73 | #( 74 | I~i: IntoAny, 75 | )* 76 | { 77 | type Inputs = (#(I~i,)*); 78 | type Ok = Ok; 79 | type Err = Infallible; 80 | type Future = Map Result>; 81 | fn run(self, (#(v~i,)*): Self::Inputs) -> Self::Future { 82 | (self.function)(#(v~i,)*).map(Ok) 83 | } 84 | } 85 | }); 86 | }; 87 | } 88 | 89 | seq!(N in 0..=12 { 90 | task_impl!(N); 91 | }); 92 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Async DAG 2 | 3 | [![Crate](https://img.shields.io/crates/v/async_dag.svg)](https://crates.io/crates/async_dag) 4 | ![Crates.io](https://img.shields.io/crates/l/async_dag.svg) 5 | [![API](https://docs.rs/async_dag/badge.svg)](https://docs.rs/async_dag) 6 | 7 | 8 | 9 | `async_dag` is an async task scheduling utility. 10 | 11 | When async tasks and their dependencies can be described by a [DAG](https://en.wikipedia.org/wiki/Directed_acyclic_graph), 12 | this crate ensures the tasks are run at maximum possible parallelism.
13 | 14 | # Example 15 | 16 | Say there are several tasks which either produce an `i32` or sum two `i32`s, 17 | and their dependency relationships are described by the following graph, 18 | 19 | ```text 20 | 7 21 | / \ 22 | 3 \ 23 | / \ \ 24 | 1 2 4 25 | ``` 26 | 27 | which means there are three tasks producing values `1`, `2` and `4`, 28 | a task summing `1` and `2` to get `3`, 29 | and a task summing `3` and `4` to get the final output, `7`. 30 | 31 | A casual developer may write 32 | 33 | ```rust 34 | let _3 = sum(_1.await, _2.await).await; 35 | let _7 = sum(_3, _4.await).await; 36 | ``` 37 | 38 | The above code is inefficient because every task only begins after the previous one completes. 39 | 40 | A better version would be 41 | 42 | ```rust 43 | let (_1, _2, _4) = join!(_1, _2, _4).await; 44 | let _3 = sum(_1, _2).await; 45 | let _7 = sum(_3, _4).await; 46 | ``` 47 | 48 | where `_1`, `_2` and `_4` run in parallel. 49 | 50 | However, the above scheduling is still not optimal 51 | because the summing of `_1` and `_2` can run in parallel with `_4`. 52 | 53 | To achieve maximum parallelism, one has to write something like 54 | 55 | ```rust 56 | let _1_2 = join!(_1, _2); 57 | let (_3, _4) = select! { 58 | _3 = _1_2 => { 59 | (_3, _4.await) 60 | } 61 | _4 = _4 => { 62 | let (_1, _2) = _1_2.await; 63 | (sum(_1, _2).await, _4) 64 | } 65 | } 66 | let _7 = sum(_3, _4).await; 67 | ``` 68 | 69 | The code is quite obscure 70 | and the manual scheduling quickly becomes tiring, 71 | if possible at all, with a few more tasks and dependencies. 72 | 73 | With `async_dag`, one can write 74 | 75 | ```rust 76 | use async_dag::Graph; 77 | 78 | async fn sum(lhs: i32, rhs: i32) -> i32 { lhs + rhs } 79 | 80 | async fn run() { 81 | let mut graph = Graph::new(); 82 | // The closures are not run yet. 83 | let _1 = graph.add_task(|| async { 1 } ); 84 | let _2 = graph.add_task(|| async { 2 } ); 85 | let _4 = graph.add_task(|| async { 4 } ); 86 | 87 | // Sets `_1` as `_3`'s first parameter.
88 | let _3 = graph.add_child_task(_1, sum, 0).unwrap(); 89 | // Sets `_2` as `_3`'s second parameter. 90 | graph.update_dependency(_2, _3, 1).unwrap(); 91 | 92 | // Sets `_3` as `_7`'s first parameter. 93 | let _7 = graph.add_child_task(_3, sum, 0).unwrap(); 94 | // Sets `_4` as `_7`'s second parameter. 95 | graph.update_dependency(_4, _7, 1).unwrap(); 96 | 97 | // Runs all the tasks with maximum possible parallelism. 98 | graph.run().await; 99 | 100 | assert_eq!(graph.get_value::(_7).unwrap(), 7); 101 | } 102 | 103 | use futures::executor::block_on; 104 | block_on(run()); 105 | 106 | ``` 107 | 108 | # Fail-fast graphs 109 | 110 | `TryGraph` can be used if the user wants a fail-fast strategy with fallible tasks. 111 | 112 | It aborts running futures when any one of them completes with an `Err`. 113 | 114 | 115 | 116 | # Dev 117 | 118 | pre-commit hook setup: `cargo run --bin install-pre-commit-hook`. 119 | -------------------------------------------------------------------------------- /async_dag/src/task.rs: -------------------------------------------------------------------------------- 1 | use crate::any::IntoAny; 2 | use crate::tuple::Tuple; 3 | use seq_macro::seq; 4 | use std::any::type_name; 5 | use std::future::Future; 6 | use std::marker::PhantomData; 7 | 8 | /// An async task. 9 | pub trait TryTask<'a>: std::fmt::Debug { 10 | /// Tuple of inputs. 11 | type Inputs: Tuple; 12 | 13 | /// Successful output. 14 | type Ok: IntoAny; 15 | 16 | /// Error output. 17 | type Err: 'a; 18 | 19 | /// Output future. 20 | type Future: Future> + Send + 'a; 21 | 22 | /// Runs the task and gets a future. 23 | fn run(self, inputs: Self::Inputs) -> Self::Future; 24 | } 25 | 26 | /// Conversion to a [`TryTask`]. 27 | pub trait IntoTryTask<'a, Args, Ok, Err> { 28 | /// The [`TryTask`] type. 29 | type Task: TryTask<'a, Ok = Ok, Err = Err> + 'a; 30 | 31 | /// The conversion.
32 | fn into_task(self) -> Self::Task; 33 | } 34 | 35 | /// A [`TryTask`] for types that implement [`FnOnce`]. 36 | pub struct FnOnceTask { 37 | function: Fn, 38 | ok: PhantomData, 39 | err: PhantomData, 40 | fut: PhantomData, 41 | args: PhantomData, 42 | } 43 | 44 | impl FnOnceTask { 45 | fn new(function: Fn) -> Self { 46 | FnOnceTask { 47 | function, 48 | ok: Default::default(), 49 | err: Default::default(), 50 | fut: Default::default(), 51 | args: Default::default(), 52 | } 53 | } 54 | } 55 | 56 | impl std::fmt::Debug for FnOnceTask { 57 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 58 | f.write_str(&format!( 59 | "FnOnceTask{} -> impl Future {{ ... }}", 60 | type_name::(), 61 | type_name::(), 62 | type_name::(), 63 | )) 64 | } 65 | } 66 | 67 | macro_rules! task_impl { 68 | ($N:literal) => { 69 | seq!(i in 0..$N { 70 | impl<'a, Fn, Ok, Err, Fut, #(I~i,)*> IntoTryTask<'a, (#(I~i,)*), Ok, Err> for Fn 71 | where 72 | Fn: FnOnce(#(I~i,)*) -> Fut + 'a, 73 | Ok: IntoAny, 74 | Err: 'a, 75 | Fut: Future> + Send + 'a, 76 | #( 77 | I~i: IntoAny, 78 | )* 79 | { 80 | type Task = FnOnceTask; 81 | 82 | fn into_task(self) -> Self::Task { 83 | FnOnceTask::new(self) 84 | } 85 | } 86 | 87 | impl<'a, Fn, Ok, Err, Fut, #(I~i,)*> TryTask<'a> for FnOnceTask 88 | where 89 | Fn: FnOnce(#(I~i,)*) -> Fut, 90 | Ok: IntoAny, 91 | Err: 'a, 92 | Fut: Future> + Send + 'a, 93 | #( 94 | I~i: IntoAny, 95 | )* 96 | { 97 | type Inputs = (#(I~i,)*); 98 | type Ok = Ok; 99 | type Err = Err; 100 | type Future = Fut; 101 | fn run(self, (#(v~i,)*): Self::Inputs) -> Self::Future { 102 | (self.function)(#(v~i,)*) 103 | } 104 | } 105 | }); 106 | }; 107 | } 108 | 109 | seq!(N in 0..=12 { 110 | task_impl!(N); 111 | }); 112 | 113 | mod infallible; 114 | 115 | pub use infallible::*; 116 | -------------------------------------------------------------------------------- /async_dag/src/lib.rs: -------------------------------------------------------------------------------- 1 | //!
`async_dag` is an async task scheduling utility. 2 | //! 3 | //! When async tasks and their dependencies can be described by a [DAG](https://en.wikipedia.org/wiki/Directed_acyclic_graph), 4 | //! this crate ensures the tasks are run at maximum possible parallelism. 5 | //! 6 | //! # Example 7 | //! 8 | //! Say there are several tasks which either produce an `i32` or sum two `i32`s, 9 | //! and their dependency relationships are described by the following graph, 10 | //! 11 | //! ```text 12 | //! 7 13 | //! / \ 14 | //! 3 \ 15 | //! / \ \ 16 | //! 1 2 4 17 | //! ``` 18 | //! 19 | //! which means there are three tasks producing values `1`, `2` and `4`, 20 | //! a task summing `1` and `2` to get `3`, 21 | //! and a task summing `3` and `4` to get the final output, `7`. 22 | //! 23 | //! A casual developer may write 24 | //! 25 | //! ```ignore 26 | //! let _3 = sum(_1.await, _2.await).await; 27 | //! let _7 = sum(_3, _4.await).await; 28 | //! ``` 29 | //! 30 | //! The above code is inefficient because every task only begins after the previous one completes. 31 | //! 32 | //! A better version would be 33 | //! 34 | //! ```ignore 35 | //! let (_1, _2, _4) = join!(_1, _2, _4).await; 36 | //! let _3 = sum(_1, _2).await; 37 | //! let _7 = sum(_3, _4).await; 38 | //! ``` 39 | //! 40 | //! where `_1`, `_2` and `_4` run in parallel. 41 | //! 42 | //! However, the above scheduling is still not optimal 43 | //! because the summing of `_1` and `_2` can run in parallel with `_4`. 44 | //! 45 | //! To achieve maximum parallelism, one has to write something like 46 | //! 47 | //! ```ignore 48 | //! let _1_2 = join!(_1, _2); 49 | //! let (_3, _4) = select! { 50 | //! _3 = _1_2 => { 51 | //! (_3, _4.await) 52 | //! } 53 | //! _4 = _4 => { 54 | //! let (_1, _2) = _1_2.await; 55 | //! (sum(_1, _2).await, _4) 56 | //! } 57 | //! } 58 | //! let _7 = sum(_3, _4).await; 59 | //! ``` 60 | //! 61 | //! The code is quite obscure 62 | //! and the manual scheduling quickly becomes tiring, 63 | //!
if possible at all, with a few more tasks and dependencies. 64 | //! 65 | //! With `async_dag`, one can write 66 | //! 67 | //! ``` 68 | //! use async_dag::Graph; 69 | //! 70 | //! async fn sum(lhs: i32, rhs: i32) -> i32 { lhs + rhs } 71 | //! 72 | //! async fn run() { 73 | //! let mut graph = Graph::new(); 74 | //! // The closures are not run yet. 75 | //! let _1 = graph.add_task(|| async { 1 } ); 76 | //! let _2 = graph.add_task(|| async { 2 } ); 77 | //! let _4 = graph.add_task(|| async { 4 } ); 78 | //! 79 | //! // Sets `_1` as `_3`'s first parameter. 80 | //! let _3 = graph.add_child_task(_1, sum, 0).unwrap(); 81 | //! // Sets `_2` as `_3`'s second parameter. 82 | //! graph.update_dependency(_2, _3, 1).unwrap(); 83 | //! 84 | //! // Sets `_3` as `_7`'s first parameter. 85 | //! let _7 = graph.add_child_task(_3, sum, 0).unwrap(); 86 | //! // Sets `_4` as `_7`'s second parameter. 87 | //! graph.update_dependency(_4, _7, 1).unwrap(); 88 | //! 89 | //! // Runs all the tasks with maximum possible parallelism. 90 | //! graph.run().await; 91 | //! 92 | //! assert_eq!(graph.get_value::(_7).unwrap(), 7); 93 | //! } 94 | //! 95 | //! use futures::executor::block_on; 96 | //! block_on(run()); 97 | //! 98 | //! ``` 99 | //! 100 | //! # Fail-fast graphs 101 | //! 102 | //! `TryGraph` can be used if the user wants a fail-fast strategy with fallible tasks. 103 | //! 104 | //! It aborts running futures when any one of them completes with an `Err`.
105 | 106 | #![deny(warnings)] 107 | #![warn( 108 | elided_lifetimes_in_paths, 109 | explicit_outlives_requirements, 110 | keyword_idents, 111 | macro_use_extern_crate, 112 | meta_variable_misuse, 113 | missing_abi, 114 | missing_docs, 115 | missing_debug_implementations, 116 | non_ascii_idents, 117 | noop_method_call, 118 | trivial_casts, 119 | trivial_numeric_casts, 120 | unsafe_code, 121 | unsafe_op_in_unsafe_fn, 122 | unused_import_braces, 123 | unused_lifetimes, 124 | unused_qualifications, 125 | unused_results, 126 | variant_size_differences 127 | )] 128 | 129 | mod any; 130 | mod curry; 131 | mod graph; 132 | mod task; 133 | mod tuple; 134 | 135 | pub use any::IntoAny; 136 | pub use any::TypeInfo; 137 | pub use curry::Curry; 138 | pub use graph::*; 139 | pub use task::{IntoInfallibleTask, IntoTryTask, TryTask}; 140 | -------------------------------------------------------------------------------- /async_dag/src/graph/runner.rs: -------------------------------------------------------------------------------- 1 | use crate::any::DynAny; 2 | use crate::any::TypeInfo; 3 | use crate::curry::TaskFuture; 4 | use crate::graph::Edge; 5 | use crate::graph::Node; 6 | use crate::graph::NodeIndex; 7 | use daggy::petgraph::visit::EdgeRef; 8 | use daggy::petgraph::visit::IntoEdgesDirected; 9 | use daggy::petgraph::Direction; 10 | use daggy::Dag; 11 | use futures::future::select_all; 12 | use futures::FutureExt; 13 | use std::future::Future; 14 | use std::mem::swap; 15 | use std::task::Poll; 16 | 17 | struct RunningNode<'a, Err> { 18 | index: NodeIndex, 19 | future: TaskFuture<'a, Err>, 20 | } 21 | 22 | impl<'a, Err> Future for RunningNode<'a, Err> { 23 | type Output = (NodeIndex, Result); 24 | 25 | fn poll( 26 | mut self: std::pin::Pin<&mut Self>, 27 | cx: &mut std::task::Context<'_>, 28 | ) -> Poll { 29 | match self.future.poll_unpin(cx) { 30 | Poll::Pending => Poll::Pending, 31 | Poll::Ready(output) => Poll::Ready((self.index, output)), 32
| } 33 | } 34 | } 35 | 36 | // Puts `node` into the running state if it contains a ready [`Curry`], doesn't change it otherwise. 37 | fn call_node<'a, Err>(node: &mut Node<'a, Err>) -> Option> { 38 | // Make a placeholder and swap `node` out. 39 | let mut owned_node = Node::Running(TypeInfo::of::<()>()); 40 | swap(node, &mut owned_node); 41 | 42 | if let Node::Curry(curry) = owned_node { 43 | if curry.ready() { 44 | *node = Node::Running(curry.output_type_info()); 45 | Some(curry.call().unwrap()) 46 | } else { 47 | *node = Node::Curry(curry); 48 | None 49 | } 50 | } else { 51 | *node = owned_node; 52 | None 53 | } 54 | } 55 | 56 | /// The async DAG driver algorithm. 57 | pub struct Runner<'task, 'graph, Err> { 58 | // We only modify node weights inside `node_graph`, don't change its structure. 59 | node_graph: &'graph mut Dag, Edge>, 60 | // `edge_graph` has the same structure as `node_graph`, 61 | // so we can access connection information and modify node weights simultaneously. 62 | edge_graph: Dag<(), Edge>, 63 | running: Vec>, 64 | } 65 | 66 | impl<'task, 'graph, Err> Runner<'task, 'graph, Err> { 67 | /// Creates a new runner from a [Graph]. 68 | /// 69 | /// The `graph` must have been type checked. 70 | /// If dropped before running completes, some tasks will be cancelled and forever lost. 71 | pub fn new(graph: &'graph mut Dag, Edge>) -> Self { 72 | let mut running = vec![]; 73 | 74 | for index in 0..graph.node_count() { 75 | let index = NodeIndex::new(index); 76 | let node = graph.node_weight_mut(index).unwrap(); 77 | if let Some(future) = call_node(node) { 78 | running.push(RunningNode { index, future }); 79 | } 80 | } 81 | 82 | let edge_graph = graph.map(|_, _| (), |_, edge| *edge); 83 | 84 | Self { 85 | node_graph: graph, 86 | edge_graph, 87 | running, 88 | } 89 | } 90 | 91 | /// Runs the algorithm. 92 | /// 93 | /// If the returned future is dropped before completion or client error happens, 94 | /// some tasks will be cancelled and forever lost.
95 | pub async fn run(&mut self) -> Result<(), Err> { 96 | while !self.running.is_empty() { 97 | self.step().await?; 98 | } 99 | Ok(()) 100 | } 101 | 102 | /// Polls until one running node is completed. 103 | /// 104 | /// Curries dependent nodes and returns early on error. 105 | async fn step(&mut self) -> Result<(), Err> { 106 | // Swap out `self.running` for `select_all`. 107 | let mut running = vec![]; 108 | swap(&mut self.running, &mut running); 109 | 110 | // If client error happens, return early and drop running futures. 111 | let ((node_index, result), _, running) = select_all(running).await; 112 | let output = result?; 113 | 114 | // Assign back to `self.running`. 115 | self.running = running; 116 | 117 | // Traverse outgoing edges of completed node. 118 | for edge in self 119 | .edge_graph 120 | .edges_directed(node_index, Direction::Outgoing) 121 | { 122 | let child_index = edge.target(); 123 | let child_node = self.node_graph.node_weight_mut(child_index).unwrap(); 124 | 125 | if let Node::Curry(curry) = child_node { 126 | let input_index = *edge.weight(); 127 | curry.curry(input_index, output.clone()).unwrap(); 128 | } 129 | 130 | if let Some(future) = call_node(child_node) { 131 | self.running.push(RunningNode { 132 | index: child_index, 133 | future, 134 | }); 135 | } 136 | } 137 | 138 | let node = self.node_graph.node_weight_mut(node_index).unwrap(); 139 | // It must be `Running`. 140 | let type_info = match node { 141 | Node::Running(type_info) => *type_info, 142 | _ => panic!("Expecting running state"), 143 | }; 144 | *node = Node::Value { 145 | value: output, 146 | type_info, 147 | }; 148 | 149 | Ok(()) 150 | } 151 | } 152 | -------------------------------------------------------------------------------- /async_dag/src/tuple.rs: -------------------------------------------------------------------------------- 1 | //! Utility structs and traits for manipulating tuples and tuple of [`Option`]s.
2 | 3 | use crate::any::DynAny; 4 | use crate::any::TypeInfo; 5 | use seq_macro::seq; 6 | use std::any::{type_name, Any, TypeId}; 7 | 8 | /// Type used for indexing a [`TupleOption`]. 9 | pub type TupleIndex = u8; 10 | 11 | /// The error that can happen when inserting to a [`TupleOption`]. 12 | #[derive(Debug)] 13 | pub struct InsertError { 14 | /// The error kind. 15 | pub kind: InsertErrorKind, 16 | /// The value that was inserted when this error happens. 17 | pub value: Box, 18 | } 19 | 20 | /// The [`InsertError`] kind. 21 | #[derive(Debug)] 22 | pub enum InsertErrorKind { 23 | /// The inserted value's type is not the expected one. 24 | TypeMismatch { 25 | /// The expected type's [`TypeId`]. 26 | expected: TypeId, 27 | /// The expected type's name. 28 | expected_name: &'static str, 29 | }, 30 | /// The inserting index is out of range. 31 | OutOfRange, 32 | } 33 | 34 | impl std::fmt::Display for InsertError { 35 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 36 | f.debug_struct("InsertError") 37 | .field("kind", &self.kind) 38 | .field("value", &self.value.type_id()) 39 | .finish() 40 | } 41 | } 42 | 43 | impl std::error::Error for InsertError {} 44 | 45 | /// The result of inserting to a [`TupleOption`]. 46 | pub type InsertResult = Result<(), InsertError>; 47 | 48 | /// The error that can happen when taking from [`TupleOption`]. 49 | #[derive(Debug)] 50 | pub struct TakeError { 51 | /// The first missing input's index. 52 | pub index: TupleIndex, 53 | } 54 | 55 | impl std::fmt::Display for TakeError { 56 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 57 | f.debug_struct("TakeError") 58 | .field("index", &self.index) 59 | .finish() 60 | } 61 | } 62 | 63 | impl std::error::Error for TakeError {} 64 | 65 | /// Implemented for all [`Sized`] + `'static` tuple of [`Option`]s. 66 | pub trait TupleOption: Default { 67 | /// Returns index of the first element that is [`None`]. 
68 | fn first_none(&self) -> Option; 69 | 70 | /// Inserts `value` at `index`. 71 | /// 72 | /// `self` is unchanged on error. 73 | fn insert(&mut self, index: TupleIndex, value: DynAny) -> InsertResult; 74 | 75 | /// Takes the values out. 76 | /// 77 | /// `self` is unchanged on error. 78 | fn take(&mut self) -> Result; 79 | } 80 | 81 | /// Implemented for all [`Sized`] + `'static` tuples. 82 | pub trait Tuple: Sized { 83 | /// The corresponding tuple of [`Option`]s. 84 | type Option: TupleOption; 85 | 86 | /// Length of the tuple. 87 | const LEN: TupleIndex; 88 | 89 | /// [`TypeId`] and name of the type at `index`. 90 | /// 91 | /// Returns [`None`] if `index` is out of range. 92 | fn type_info(index: TupleIndex) -> Option; 93 | } 94 | 95 | macro_rules! tupl_impl { 96 | ($N:literal) => { 97 | seq!(i in 0..$N { 98 | impl<#(T~i: Any,)*> TupleOption<(#(T~i,)*)> for (#(Option,)*) { 99 | fn first_none(&self) -> Option { 100 | #( 101 | if self.i.is_none() { 102 | return Some(i); 103 | } 104 | )* 105 | None 106 | } 107 | 108 | fn insert(&mut self, index: TupleIndex, value: DynAny) -> InsertResult { 109 | #[allow(clippy::match_single_binding)] 110 | match index { 111 | #( 112 | i => match Box::::downcast::(value.into_any()) { 113 | Ok(t) => { 114 | self.i = Some(*t); 115 | Ok(()) 116 | } 117 | Err(value) => Err(InsertError { 118 | kind: InsertErrorKind::TypeMismatch { 119 | expected: TypeId::of::(), 120 | expected_name: type_name::(), 121 | }, 122 | value, 123 | }), 124 | }, 125 | )* 126 | _ => Err(InsertError { 127 | kind: InsertErrorKind::OutOfRange, 128 | value: value.into_any(), 129 | }), 130 | } 131 | } 132 | 133 | fn take(&mut self) -> Result<(#(T~i,)*), TakeError> { 134 | match self.first_none() { 135 | Some(index) => Err(TakeError { index }), 136 | None => Ok((#(self.i.take().unwrap(),)*)), 137 | } 138 | } 139 | } 140 | }); 141 | 142 | seq!(i in 0..$N { 143 | impl<#(T~i: Any,)*> Tuple for (#(T~i,)*) { 144 | type Option = (#(Option,)*); 145 | 146 | const LEN: 
TupleIndex = $N; 147 | 148 | fn type_info(index: TupleIndex) -> Option { 149 | #[allow(clippy::match_single_binding)] 150 | match index { 151 | #( 152 | i => Some(TypeInfo::of::()), 153 | )* 154 | _ => None, 155 | } 156 | } 157 | } 158 | }); 159 | }; 160 | } 161 | 162 | seq!(N in 0..=12 { 163 | #( 164 | tupl_impl!(N); 165 | )* 166 | }); 167 | 168 | #[cfg(test)] 169 | mod tests { 170 | use super::*; 171 | 172 | #[test] 173 | fn test_mismatch_type_name() { 174 | let mut option: (Option,) = (None,); 175 | let error = option.insert(0, Box::new(0.0f32)).unwrap_err(); 176 | let expected_name = match error.kind { 177 | InsertErrorKind::TypeMismatch { expected_name, .. } => expected_name, 178 | _ => panic!("Expecting TypeMismatch"), 179 | }; 180 | assert!(expected_name.contains("i32")); 181 | } 182 | } 183 | -------------------------------------------------------------------------------- /async_dag/src/graph.rs: -------------------------------------------------------------------------------- 1 | pub mod error; 2 | mod runner; 3 | 4 | use crate::any::downcast; 5 | use crate::any::DynAny; 6 | use crate::any::IntoAny; 7 | use crate::any::TypeInfo; 8 | use crate::curry::CurriedTask; 9 | use crate::curry::Curry; 10 | use crate::task::IntoTryTask; 11 | use crate::task::TryTask; 12 | use crate::tuple::Tuple; 13 | use crate::tuple::TupleIndex; 14 | use daggy::EdgeIndex; 15 | use error::Error; 16 | use error::ErrorWithTask; 17 | use runner::Runner; 18 | use std::any::type_name; 19 | use std::collections::HashMap; 20 | 21 | /// A [`Box`]ed [`Curry`]. 22 | type DynCurry<'a, Err> = Box + 'a>; 23 | 24 | impl<'a, Err> std::fmt::Debug for DynCurry<'a, Err> { 25 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 26 | f.debug_struct(&format!("Curry<{}>", type_name::())) 27 | .finish_non_exhaustive() 28 | } 29 | } 30 | 31 | /// Node type. 
32 | /// 33 | /// A node is either a [`Curry`], running (with a certain output type), 34 | /// or the [`Curry`]'s awaited successful calling output. 35 | #[derive(Debug)] 36 | pub enum Node<'a, Err> { 37 | /// A [`Curry`]. 38 | Curry(DynCurry<'a, Err>), 39 | /// A running node. 40 | /// 41 | /// The [`Curry`] is called and result future is stored elsewhere and perhaps running. 42 | Running(TypeInfo), 43 | /// A successful output from a completed [`TryTask`](crate::task::TryTask). 44 | Value { 45 | /// The output value. 46 | value: DynAny, 47 | /// The output type. 48 | type_info: TypeInfo, 49 | }, 50 | } 51 | 52 | impl<'a, Err> Node<'a, Err> { 53 | /// Converts this [`Node`] into a concrete type. 54 | /// 55 | /// Returns `self` on failure. 56 | pub fn downcast(self) -> Result { 57 | if let Self::Value { value, type_info } = self { 58 | match downcast(value) { 59 | Ok(value) => Ok(value), 60 | Err(value) => Err(Node::Value { value, type_info }), 61 | } 62 | } else { 63 | Err(self) 64 | } 65 | } 66 | } 67 | 68 | /// Node identifier. 69 | pub type NodeIndex = daggy::NodeIndex; 70 | 71 | /// Edge type. 72 | /// 73 | /// An edge connects parent node's task's output to child node's task's input. 74 | /// Its value is the input index. 75 | pub type Edge = TupleIndex; 76 | 77 | /// An async task DAG. 78 | #[derive(Debug, Default)] 79 | pub struct TryGraph<'a, Err: 'a> { 80 | dag: daggy::Dag, Edge>, 81 | dependencies: HashMap<(NodeIndex, Edge), EdgeIndex>, 82 | } 83 | 84 | impl<'a, Err: 'a> TryGraph<'a, Err> { 85 | /// Creates an empty [`TryGraph`]. 86 | pub fn new() -> Self { 87 | Self { 88 | dag: Default::default(), 89 | dependencies: Default::default(), 90 | } 91 | } 92 | 93 | /// Converts `self` into an iterator of [`Node`]s. 94 | /// 95 | /// Client can use this method and previous returned [`NodeIndex`]s to retrive the graph running result. 
96 | pub fn into_nodes(self) -> impl Iterator> { 97 | self.dag 98 | .into_graph() 99 | .into_nodes_edges() 100 | .0 101 | .into_iter() 102 | .map(|node| node.weight) 103 | } 104 | 105 | /// Gets the output value of `node`. 106 | /// 107 | /// Returns [`None`] if the `node`'s task hasn't done running or the type does not match. 108 | /// 109 | /// **Panics** if `node` does not exist within the graph. 110 | pub fn get_value(&self, node: NodeIndex) -> Option { 111 | match self.dag.node_weight(node).unwrap() { 112 | Node::Value { value, .. } => downcast(value.clone()).ok(), 113 | _ => None, 114 | } 115 | } 116 | 117 | /// Adds a task without specifying its dependencies. 118 | /// 119 | /// Returns the [`NodeIndex`] representing this task. 120 | /// 121 | /// **Panics** if the graph is at the maximum number of nodes for its index type. 122 | pub fn add_try_task>( 123 | &mut self, 124 | task: T, 125 | ) -> NodeIndex { 126 | self.add_task_impl(task.into_task()) 127 | } 128 | 129 | fn add_task_impl + 'a>(&mut self, task: T) -> NodeIndex { 130 | self.dag.add_node(Self::make_node(task)) 131 | } 132 | 133 | /// Adds a task and set it as `child`'s dependency at `index`. 134 | /// 135 | /// Returns the [`NodeIndex`] representing the added task. 136 | /// 137 | /// If child already has a dependency at `index`, it will be removed. But the depended node won't. 138 | /// 139 | /// This is more efficient than [`TryGraph::add_task`] then [`TryGraph::update_dependency`]. 140 | /// 141 | /// **Panics** if the graph is at the maximum number of nodes for its index type. 142 | /// 143 | /// **Panics** if `child` does not exist within the graph. 
144 | pub fn add_parent_try_task>( 145 | &mut self, 146 | task: T, 147 | child: NodeIndex, 148 | index: Edge, 149 | ) -> Result> { 150 | self.add_parent_task_impl::(task.into_task(), child, index) 151 | } 152 | 153 | fn add_parent_task_impl + 'a>( 154 | &mut self, 155 | task: T, 156 | child: NodeIndex, 157 | index: Edge, 158 | ) -> Result> { 159 | if let Err(error) = self.type_check(child, index, TypeInfo::of::()) { 160 | return Err(ErrorWithTask { error, task }); 161 | } 162 | #[allow(unused_results)] 163 | { 164 | self.remove_dependency(child, index); 165 | } 166 | let (edge, node) = self.dag.add_parent(child, index, Self::make_node(task)); 167 | assert!(self.dependencies.insert((child, index), edge).is_none()); 168 | Ok(node) 169 | } 170 | 171 | /// Adds a task and set it's dependency at `index` as `parent`. 172 | /// 173 | /// Returns the [`NodeIndex`] representing the added task. 174 | /// 175 | /// This is more efficient than [`TryGraph::add_task`] then [`TryGraph::update_dependency`]. 176 | /// 177 | /// **Panics** if the graph is at the maximum number of nodes for its index type. 178 | /// 179 | /// **Panics** if `parent` does not exist within the graph. 
180 | pub fn add_child_try_task>( 181 | &mut self, 182 | parent: NodeIndex, 183 | task: T, 184 | index: Edge, 185 | ) -> Result> { 186 | self.add_child_task_impl::(parent, task.into_task(), index) 187 | } 188 | 189 | fn add_child_task_impl + 'a>( 190 | &mut self, 191 | parent: NodeIndex, 192 | task: T, 193 | index: Edge, 194 | ) -> Result> { 195 | let input_type_info = match T::Inputs::type_info(index) { 196 | Some(type_info) => type_info, 197 | None => { 198 | return Err(ErrorWithTask { 199 | error: Error::OutOfRange(T::Inputs::LEN), 200 | task, 201 | }) 202 | } 203 | }; 204 | let output_type_info = self.output_type_info(parent); 205 | if let Err(error) = check_type_equality(input_type_info, output_type_info) { 206 | return Err(ErrorWithTask { error, task }); 207 | } 208 | let (edge, node) = self.dag.add_child(parent, index, Self::make_node(task)); 209 | assert!(self.dependencies.insert((node, index), edge).is_none()); 210 | Ok(node) 211 | } 212 | 213 | /// Sets `parent` as `child`'s dependency at `index`. 214 | /// 215 | /// If child already has a dependency at `index`, it will be removed. But the depended node won't. 216 | /// 217 | /// **Panics** if either `parent` or `child` does not exist within the graph. 218 | /// 219 | /// **Panics** if the graph is at the maximum number of edges for its index type. 220 | pub fn update_dependency( 221 | &mut self, 222 | parent: NodeIndex, 223 | child: NodeIndex, 224 | index: Edge, 225 | ) -> Result<(), Error> { 226 | self.type_check(child, index, self.output_type_info(parent))?; 227 | #[allow(unused_results)] 228 | { 229 | self.remove_dependency(child, index); 230 | } 231 | let edge = self 232 | .dag 233 | .add_edge(parent, child, index) 234 | .map_err(|_| Error::WouldCycle)?; 235 | assert!(self.dependencies.insert((child, index), edge).is_none()); 236 | Ok(()) 237 | } 238 | 239 | /// Remove `child`'s dependency at `index` if it has one. 
240 | /// 241 | /// Returns `true` if `child` has a dependency at `index` before removing. 242 | pub fn remove_dependency(&mut self, child: NodeIndex, index: Edge) -> bool { 243 | let edge = self.dependencies.remove(&(child, index)); 244 | if let Some(edge) = edge { 245 | assert!(self.dag.remove_edge(edge).is_some()); 246 | true 247 | } else { 248 | false 249 | } 250 | } 251 | 252 | /// Progresses the whole task graph as much as possible, but aborts on first error. 253 | /// 254 | /// If the returned future is dropped before completion, or an error occurs, some tasks will be cancelled and forever lost. 255 | /// Corresponding [`Node`] will be set to [`Node::Running`]. 256 | pub async fn try_run(&mut self) -> Result<(), Err> { 257 | let mut runner = Runner::new(&mut self.dag); 258 | runner.run().await 259 | } 260 | 261 | fn type_check( 262 | &self, 263 | child: NodeIndex, 264 | index: Edge, 265 | output_type_info: TypeInfo, 266 | ) -> Result<(), Error> { 267 | let node = self.dag.node_weight(child).unwrap(); 268 | let curry = match node { 269 | Node::Curry(curry) => curry, 270 | _ => return Err(Error::HasStarted(child)), 271 | }; 272 | let input_type_info = curry 273 | .input_type_info(index) 274 | .ok_or_else(|| Error::OutOfRange(curry.num_inputs()))?; 275 | check_type_equality(input_type_info, output_type_info)?; 276 | Ok(()) 277 | } 278 | 279 | fn make_node + 'a>(task: T) -> Node<'a, Err> { 280 | let curry = CurriedTask::new(task); 281 | Node::Curry(Box::new(curry)) 282 | } 283 | 284 | fn output_type_info(&self, index: NodeIndex) -> TypeInfo { 285 | let node = self.dag.node_weight(index).unwrap(); 286 | match node { 287 | Node::Curry(curry) => curry.output_type_info(), 288 | Node::Running(type_info) => *type_info, 289 | Node::Value { type_info, .. 
} => *type_info, 290 | } 291 | } 292 | } 293 | 294 | fn check_type_equality(input: TypeInfo, output: TypeInfo) -> Result<(), Error> { 295 | if input != output { 296 | Err(Error::TypeMismatch { input, output }) 297 | } else { 298 | Ok(()) 299 | } 300 | } 301 | 302 | mod infallible; 303 | 304 | pub use infallible::*; 305 | 306 | #[cfg(test)] 307 | mod tests { 308 | use super::*; 309 | use futures::executor::block_on; 310 | use std::any::TypeId; 311 | 312 | #[test] 313 | fn test_diamond_shape_graph() { 314 | let mut graph = Graph::new(); 315 | 316 | let root = graph.add_task(|lhs: i32, rhs: i32| async move { lhs + rhs }); 317 | let lhs = graph 318 | .add_parent_task(|v: i32| async move { v }, root, 0) 319 | .unwrap(); 320 | let rhs = graph 321 | .add_parent_task(|v: i32| async move { v }, root, 1) 322 | .unwrap(); 323 | let input = graph.add_parent_task(|| async move { 1 }, lhs, 0).unwrap(); 324 | graph.update_dependency(input, rhs, 0).unwrap(); 325 | 326 | block_on(graph.run()); 327 | 328 | let result = graph.get_value::(root).unwrap(); 329 | assert_eq!(result, 2); 330 | } 331 | 332 | #[test] 333 | fn test_client_error() { 334 | let mut graph = TryGraph::new(); 335 | let _ = graph.add_try_task::<_, (), _>(|| async { Err(()) }); 336 | block_on(graph.try_run()).unwrap_err(); 337 | } 338 | 339 | #[test] 340 | fn test_has_started_check() { 341 | let mut graph = Graph::new(); 342 | let root = graph.add_task(|_: ()| async { () }); 343 | let parent = graph.add_parent_task(|| async { () }, root, 0).unwrap(); 344 | block_on(graph.run()); 345 | let error = graph.update_dependency(parent, root, 0).unwrap_err(); 346 | let index = match error { 347 | Error::HasStarted(index) => index, 348 | _ => panic!("Expecting has started error"), 349 | }; 350 | assert_eq!(index, root); 351 | } 352 | 353 | #[test] 354 | fn test_type_check() { 355 | let mut graph = Graph::new(); 356 | let root = graph.add_task(|_: ()| async { () }); 357 | 358 | let error = graph.type_check(root, 1, 
TypeInfo::of::<()>()).unwrap_err(); 359 | let len = match error { 360 | Error::OutOfRange(len) => len, 361 | _ => panic!("Expecting out of range error"), 362 | }; 363 | assert_eq!(len, 1); 364 | 365 | let error = graph 366 | .type_check(root, 0, TypeInfo::of::()) 367 | .unwrap_err(); 368 | let (input, output) = match error { 369 | Error::TypeMismatch { input, output } => (input, output), 370 | _ => panic!("Expecting type mismatch error"), 371 | }; 372 | assert_eq!(input.id(), TypeId::of::<()>()); 373 | assert_eq!(output.id(), TypeId::of::()); 374 | // Name is not guaranteed, but these asserts should be ok... 375 | assert!(input.name().contains("()")); 376 | assert!(output.name().contains("i32")); 377 | } 378 | 379 | #[test] 380 | fn test_cycle_check() { 381 | let mut graph = Graph::new(); 382 | let root = graph.add_task(|_: ()| async { () }); 383 | let parent = graph 384 | .add_parent_task(|_: ()| async { () }, root, 0) 385 | .unwrap(); 386 | let error = graph.update_dependency(root, parent, 0).unwrap_err(); 387 | match error { 388 | Error::WouldCycle => (), 389 | _ => panic!("Expecting would cycle error"), 390 | } 391 | } 392 | 393 | #[test] 394 | fn test_remove_dependency() { 395 | let mut graph = Graph::new(); 396 | let root = graph.add_task(|_: ()| async { () }); 397 | assert!(!graph.remove_dependency(root, 0)); 398 | let _ = graph.add_parent_task(|| async { () }, root, 0).unwrap(); 399 | assert!(graph.remove_dependency(root, 0)); 400 | } 401 | 402 | #[test] 403 | fn test_update_dependency() { 404 | let mut graph = Graph::new(); 405 | let root = graph.add_task(|_: ()| async { () }); 406 | let parent = graph.add_parent_task(|| async { () }, root, 0).unwrap(); 407 | graph.update_dependency(parent, root, 0).unwrap(); 408 | graph.update_dependency(parent, root, 0).unwrap(); 409 | } 410 | } 411 | --------------------------------------------------------------------------------