├── .gitignore ├── Cargo.toml ├── README.md ├── consts ├── Cargo.toml └── src │ └── lib.rs ├── core-macros ├── Cargo.toml └── src │ └── lib.rs ├── core ├── Cargo.toml └── src │ ├── derivatives │ ├── f32.rs │ ├── f64.rs │ ├── i128.rs │ ├── i16.rs │ ├── i32.rs │ ├── i64.rs │ ├── i8.rs │ ├── mod.rs │ ├── u128.rs │ ├── u16.rs │ ├── u32.rs │ ├── u64.rs │ └── u8.rs │ ├── dict.rs │ ├── lib.rs │ └── traits.rs ├── macros ├── Cargo.toml └── src │ ├── forward.rs │ ├── lib.rs │ └── reverse.rs ├── src ├── lib.rs └── main.rs └── tests ├── forward.rs ├── forward_general.rs ├── reverse.rs └── reverse_general.rs /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | Cargo.lock 3 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rust-ad" 3 | version = "0.8.0" 4 | edition = "2021" 5 | 6 | description = "Rust Auto-Differentiation." 
7 | license = "Apache-2.0" 8 | repository = "https://github.com/JonathanWoollett-Light/rust-ad" 9 | documentation = "https://docs.rs/rust-ad/" 10 | readme = "README.md" 11 | exclude = ["/src/main.rs","/tests"] 12 | 13 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 14 | 15 | [dependencies] 16 | syn = { version="1.0.82", features=["full","extra-traits"] } 17 | rust-ad-macros = { version = "0.8.0", path = "macros" } 18 | 19 | # rust-ad-core-macros = { version = "0.8.0", path = "./core-macros" } # TEMP 20 | # rust-ad-core = { version = "0.8.0", path = "./core" } # TEMP 21 | 22 | [workspace] 23 | members = ["./macros","./core","./core-macros","./consts"] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RustAD - Rust Auto-Differentiation 2 | 3 | [![Crates.io](https://img.shields.io/crates/v/rust-ad)](https://crates.io/crates/rust-ad) 4 | [![lib.rs.io](https://img.shields.io/crates/v/rust-ad?color=blue&label=lib.rs)](https://lib.rs/crates/rust-ad) 5 | [![docs](https://img.shields.io/crates/v/rust-ad?color=yellow&label=docs)](https://docs.rs/rust-ad) 6 | 7 | The restrictive, work-in-progress beginnings of a library attempting to implement auto-differentiation in Rust. 8 | 9 | **Why would I use this over other auto-differentiation libraries?** You wouldn't, not yet anyway. I'd say wait until support for ndarray is more comprehensive; then this will probably become the most convenient Rust AutoDiff library. 10 | 11 | **Be warned: it's all messy.** 12 | 13 | ## Status 14 | 15 | I'm thinking of transitioning this project to a binary where running `cargo rust-ad` performs auto-diff on any Rust code. This would allow support for [ndarray](https://docs.rs/ndarray/latest/ndarray/), [nalgebra](https://docs.rs/nalgebra/latest/nalgebra/) and any library written in Rust by performing auto-diff through all dependencies.
This would allow a fully generalized, convenient approach to auto-diff, producing high-level code which can be optimized by the compiler. 16 | 17 | This transition will occur once all of the items below are supported. 18 | 19 | *These are not ordered.* 20 | 21 | - [x] Forward Auto-differentiation 22 | - [x] Reverse Auto-differentiation 23 | - [x] Numerical primitives (e.g. `f32`, `u32` etc.) support* 24 | - [ ] `if`, `if else` and `else` support 25 | - [ ] `for`, `while` and `loop` support 26 | - [ ] `map` and `fold` 27 | 28 | *`typeof` (like C++'s [`decltype`](https://en.cppreference.com/w/cpp/language/decltype)) is not currently implemented in Rust, which makes this support more difficult. 29 | 30 | ## Application 31 | 32 | Auto-differentiation is implemented via two attribute procedural macros (`#[forward_autodiff]` and `#[reverse_autodiff]`), e.g. 33 | 34 | ```rust 35 | fn multi_test() { 36 | let (f, (der_x, der_y)) = forward!(multi, 3f32, 5f32); 37 | assert_eq!(f, 15.4f32); 38 | assert_eq!(der_x, 8f32); 39 | assert_eq!(der_y, -0.08f32); 40 | 41 | /// f = x^2 + 2x + 2/y 42 | /// δx|y=5 = 2x + 2 43 | /// δy|x=3 = -2/y^2 44 | #[forward_autodiff] 45 | fn multi(x: f32, y: f32) -> f32 { 46 | let a = x.powi(2i32); 47 | let b = x * 2f32; 48 | let c = 2f32 / y; 49 | let f = a + b + c; 50 | return f; 51 | } 52 | } 53 | ``` 54 | ```rust 55 | fn multi_test() { 56 | let (f, (der_x, der_y)) = reverse!(multi, (3f32, 5f32), (1f32)); 57 | assert_eq!(f, 15.4f32); 58 | assert_eq!(der_x, 8f32); 59 | assert_eq!(der_y, -0.08f32); 60 | 61 | /// f = x^2 + 2x + 2/y 62 | /// δx|y=5 = 2x + 2 63 | /// δy|x=3 = -2/y^2 64 | #[reverse_autodiff] 65 | fn multi(x: f32, y: f32) -> f32 { 66 | let a = x.powi(2i32); 67 | let b = x * 2f32; 68 | let c = 2f32 / y; 69 | let f = a + b + c; 70 | return f; 71 | } 72 | } 73 | ``` 74 | -------------------------------------------------------------------------------- /consts/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rust-ad-consts" 3 | version = "0.8.0" 4 | 
edition = "2021" 5 | 6 | description = "Rust Auto-Differentiation." 7 | license = "Apache-2.0" 8 | repository = "https://github.com/JonathanWoollett-Light/rust-ad" 9 | documentation = "https://docs.rs/rust-ad/" 10 | readme = "../README.md" 11 | 12 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 13 | 14 | [dependencies] 15 | const_format = "0.2.22" -------------------------------------------------------------------------------- /consts/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! **I do not recommend using this directly, please see [rust-ad](https://crates.io/crates/rust-ad).** 2 | //! 3 | //! Internal constants. 4 | //! 5 | //! Lowest-level crate of the workspace; the other internal crates depend on it. 6 | 7 | use const_format::concatcp; 8 | 9 | /// Prefix used for the derivatives of a variable (e.g. the derivative of `x` would be `__der_x`). 10 | pub const DERIVATIVE_PREFIX: &str = "__der_"; 11 | /// Prefix for external forward auto-diff functions. 12 | pub const FORWARD_PREFIX: &str = "__f_"; 13 | /// Prefix for external reverse auto-diff functions. 14 | pub const REVERSE_PREFIX: &str = "__r_"; 15 | /// Marker appended to [`FORWARD_PREFIX`] and [`REVERSE_PREFIX`] to build the prefixes of internal functions. 16 | const INTERNAL_SUFFIX: &str = "internal_"; 17 | /// Prefix for internal forward auto-diff functions (e.g. `__f_a_users_function` vs `__f_internal_powi_f32`). 18 | pub const INTERNAL_FORWARD_PREFIX: &str = concatcp!(FORWARD_PREFIX, INTERNAL_SUFFIX); 19 | /// Prefix for internal reverse auto-diff functions (e.g. `__r_a_users_function` vs `__r_internal_powi_f32`).
20 | pub const INTERNAL_REVERSE_PREFIX: &str = concatcp!(REVERSE_PREFIX, INTERNAL_SUFFIX); 21 | 22 | // const RETURN_SUFFIX: &str = "__rtn"; 23 | // pub const REVERSE_RETURN_DERIVATIVE: &str = concatcp!(DERIVATIVE_PREFIX,RETURN_SUFFIX); 24 | /// Identifier used for the derivative of a function's return value in reverse mode (NOTE(review): inferred from the name and the commented-out definition above; confirm against the macros crate). 25 | pub const REVERSE_RETURN_DERIVATIVE: &str = "r"; 26 | -------------------------------------------------------------------------------- /core-macros/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rust-ad-core-macros" 3 | version = "0.8.0" 4 | edition = "2021" 5 | 6 | description = "Rust Auto-Differentiation." 7 | license = "Apache-2.0" 8 | repository = "https://github.com/JonathanWoollett-Light/rust-ad" 9 | documentation = "https://docs.rs/rust-ad/" 10 | readme = "../README.md" 11 | 12 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 13 | 14 | [dependencies] 15 | rust-ad-consts = { version = "0.8.0", path = "../consts" } 16 | 17 | [lib] 18 | proc-macro = true -------------------------------------------------------------------------------- /core-macros/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! **I do not recommend using this directly, please see [rust-ad](https://crates.io/crates/rust-ad).** 2 | //! 3 | //! Internal proc-macro functionality. 4 | use proc_macro::{TokenStream, TokenTree}; 5 | use rust_ad_consts::{INTERNAL_FORWARD_PREFIX, INTERNAL_REVERSE_PREFIX}; 6 | 7 | /// Same result as performing `forward_derivative_macro` and `reverse_derivative_macro` consecutively. 8 | /// 9 | /// But this parses the input only once and emits both the forward (`FgdType`) and reverse (`RgdType`) statics.
10 | #[proc_macro] 11 | pub fn combined_derivative_macro(item: TokenStream) -> TokenStream { 12 | // eprintln!("\nitem:\n{:?}\n",item); 13 | let mut iter = item.into_iter(); 14 | let name = match iter.next() { 15 | Some(TokenTree::Ident(ident)) => ident, 16 | _ => panic!("No function ident"), 17 | }; 18 | let vec = iter.collect::<Vec<_>>(); 19 | assert_eq!(vec.len() % 2, 0, "Bad punctuation"); 20 | let num = (vec.len() / 2).saturating_sub(1); // One derivative format string per function argument (every `, "..."` pair after the default); saturating so input without a default reaches the "No default value" panic instead of underflowing. 21 | let mut iter = vec.chunks_exact(2); 22 | 23 | let default = match iter.next() { 24 | Some(item) => { 25 | let (punc, lit) = (&item[0], &item[1]); 26 | match (punc, lit) { 27 | (TokenTree::Punct(_), TokenTree::Literal(default)) => default, 28 | _ => panic!("Bad default value"), 29 | } 30 | } 31 | _ => panic!("No default value"), 32 | }; 33 | 34 | let iter = iter.enumerate(); 35 | let arg_fmt_str = (0..num) 36 | .map(|i| format!("args[{}],", i)) 37 | .collect::<String>(); 38 | 39 | let der_functions = iter 40 | .map(|(index, item)| { 41 | let (punc, lit) = (&item[0], &item[1]); 42 | match (punc, lit) { 43 | (TokenTree::Punct(_), TokenTree::Literal(format_str)) => format!( 44 | "\tconst f{}: DFn = |args: &[Arg]| -> String {{ compose!({},{}) }};\n", 45 | index, format_str, arg_fmt_str 46 | ), 47 | _ => panic!("Bad format strings"), 48 | } 49 | }) 50 | .collect::<String>(); 51 | let fn_fmt_str = (0..num).map(|i| format!("f{},", i)).collect::<String>(); 52 | 53 | let for_out_str = format!( 54 | "pub static {}{}: FgdType = {{\n{}\n\tfgd::<{{ {} }},{{ &[{}] }}>\n}};", 55 | INTERNAL_FORWARD_PREFIX, name, der_functions, default, fn_fmt_str 56 | ); 57 | let rev_out_str = format!( 58 | "pub static {}{}: RgdType = {{\n{}\n\trgd::<{{ {} }},{{ &[{}] }}>\n}};", 59 | INTERNAL_REVERSE_PREFIX, name, der_functions, default, fn_fmt_str 60 | ); 61 | let out_str = format!("{}\n{}", for_out_str, rev_out_str); 62 | // eprintln!("out_str: \n{}\n",out_str); 63 | out_str.parse().unwrap() 64 | } 65 | 66 | /// Generates the internal forward derivative static (`__f_internal_<name>: FgdType`) from a default value followed by one derivative format string per function argument.
67 | /// ```ignore 68 | /// pub static __f_internal_outer_test: FgdType = { 69 | /// const f0: DFn = |args: &[Arg]| -> String { compose!("{0}-{1}",args[0],args[1],) }; 70 | /// const f1: DFn = |args: &[Arg]| -> String { compose!("{0}*{1}+{0}",args[0],args[1],) }; 71 | /// fgd::<{ "0f32" },{ &[f0,f1,] }> 72 | /// }; 73 | /// ``` 74 | /// Is equivalent to 75 | /// ```ignore 76 | /// forward_derivative_macro!(outer_test,"0f32","{0}-{1}","{0}*{1}+{0}"); 77 | /// ``` 78 | #[proc_macro] 79 | pub fn forward_derivative_macro(item: TokenStream) -> TokenStream { 80 | // eprintln!("\nitem:\n{:?}\n",item); 81 | let mut iter = item.into_iter(); 82 | let name = match iter.next() { 83 | Some(TokenTree::Ident(ident)) => ident, 84 | _ => panic!("No function ident"), 85 | }; 86 | let vec = iter.collect::<Vec<_>>(); 87 | assert_eq!(vec.len() % 2, 0, "Bad punctuation"); 88 | let num = (vec.len() / 2).saturating_sub(1); // One derivative format string per function argument; saturating so input without a default reaches the "No default value" panic instead of underflowing. 89 | let mut iter = vec.chunks_exact(2); 90 | 91 | let default = match iter.next() { 92 | Some(item) => { 93 | let (punc, lit) = (&item[0], &item[1]); 94 | match (punc, lit) { 95 | (TokenTree::Punct(_), TokenTree::Literal(default)) => default, 96 | _ => panic!("Bad default value"), 97 | } 98 | } 99 | _ => panic!("No default value"), 100 | }; 101 | 102 | let iter = iter.enumerate(); 103 | let arg_fmt_str = (0..num) 104 | .map(|i| format!("args[{}],", i)) 105 | .collect::<String>(); 106 | 107 | let der_functions = iter 108 | .map(|(index, item)| { 109 | let (punc, lit) = (&item[0], &item[1]); 110 | match (punc, lit) { 111 | (TokenTree::Punct(_), TokenTree::Literal(format_str)) => format!( 112 | "\tconst f{}: DFn = |args: &[Arg]| -> String {{ compose!({},{}) }};\n", 113 | index, format_str, arg_fmt_str 114 | ), 115 | _ => panic!("Bad format strings"), 116 | } 117 | }) 118 | .collect::<String>(); 119 | let fn_fmt_str = (0..num).map(|i| format!("f{},", i)).collect::<String>(); 120 | let out_str = format!( 121 | "pub static {}{}: FgdType = {{\n{}\n\tfgd::<{{ {} }},{{ &[{}] }}>\n}};", 122 | INTERNAL_FORWARD_PREFIX,
name, der_functions, default, fn_fmt_str 123 | ); 124 | // eprintln!("out_str: \n{}\n",out_str); 125 | out_str.parse().unwrap() 126 | } 127 | /// Generates the internal reverse derivative static (`__r_internal_<name>: RgdType`); takes the same input as `forward_derivative_macro`. 128 | #[proc_macro] 129 | pub fn reverse_derivative_macro(item: TokenStream) -> TokenStream { 130 | // eprintln!("\nitem:\n{:?}\n",item); 131 | let mut iter = item.into_iter(); 132 | let name = match iter.next() { 133 | Some(TokenTree::Ident(ident)) => ident, 134 | _ => panic!("No function ident"), 135 | }; 136 | let vec = iter.collect::<Vec<_>>(); 137 | assert_eq!(vec.len() % 2, 0, "Bad punctuation"); 138 | let num = (vec.len() / 2).saturating_sub(1); // One derivative format string per function argument; saturating so input without a default reaches the "No default value" panic instead of underflowing. 139 | let mut iter = vec.chunks_exact(2); 140 | 141 | let default = match iter.next() { 142 | Some(item) => { 143 | let (punc, lit) = (&item[0], &item[1]); 144 | match (punc, lit) { 145 | (TokenTree::Punct(_), TokenTree::Literal(default)) => default, 146 | _ => panic!("Bad default value"), 147 | } 148 | } 149 | _ => panic!("No default value"), 150 | }; 151 | 152 | let iter = iter.enumerate(); 153 | let arg_fmt_str = (0..num) 154 | .map(|i| format!("args[{}],", i)) 155 | .collect::<String>(); 156 | 157 | let der_functions = iter 158 | .map(|(index, item)| { 159 | let (punc, lit) = (&item[0], &item[1]); 160 | match (punc, lit) { 161 | (TokenTree::Punct(_), TokenTree::Literal(format_str)) => format!( 162 | "\tconst f{}: DFn = |args: &[Arg]| -> String {{ compose!({},{}) }};\n", 163 | index, format_str, arg_fmt_str 164 | ), 165 | _ => panic!("Bad format strings"), 166 | } 167 | }) 168 | .collect::<String>(); 169 | let fn_fmt_str = (0..num).map(|i| format!("f{},", i)).collect::<String>(); 170 | let out_str = format!( 171 | "pub static {}{}: RgdType = {{\n{}\n\trgd::<{{ {} }},{{ &[{}] }}>\n}};", 172 | INTERNAL_REVERSE_PREFIX, name, der_functions, default, fn_fmt_str 173 | ); 174 | // eprintln!("out_str: \n{}\n",out_str); 175 | out_str.parse().unwrap() 176 | } 177 | 178 | /// `format!()` but: 179 | /// 1. only allows positional arguments e.g. `{0}`, `{1}`, etc. 180 | /// 2.
allows unused arguments. 181 | #[proc_macro] 182 | pub fn compose(item: TokenStream) -> TokenStream { 183 | // eprintln!("item: {}",item); 184 | let mut iter = item.into_iter(); 185 | let fmt_str = match iter.next() { 186 | Some(TokenTree::Literal(l)) => l.to_string(), 187 | _ => panic!("No fmt str"), 188 | }; 189 | let vec = iter.skip(1).collect::<Vec<_>>(); // skip(1) drops the comma that follows the format string. 190 | let component_iter = vec.split(|t| match t { 191 | TokenTree::Punct(p) => p.as_char() == ',', 192 | _ => false, 193 | }); 194 | let components = component_iter 195 | .map(|component_slice| { 196 | component_slice 197 | .iter() 198 | .map(|c| c.to_string()) 199 | .collect::<String>() 200 | }) 201 | .collect::<Vec<_>>(); 202 | 203 | let mut bytes_string = Vec::from(&fmt_str.as_bytes()[1..fmt_str.len() - 1]); // Strips the literal's surrounding quotes. 204 | let mut i = 0; 205 | let mut out_str = String::from("let mut temp = String::new();"); 206 | while i < bytes_string.len() { 207 | if bytes_string[i] == b'}' { 208 | // Takes the positional index between the already-removed '{' and this '}' 209 | let index_str = String::from_utf8(bytes_string.drain(0..i).collect::<Vec<_>>()) 210 | .expect("compose: utf8"); 211 | let index: usize = index_str.parse().expect("compose: parse"); 212 | out_str.push_str(&format!( 213 | "\n\ttemp.push_str(&{}.to_string());", 214 | components.get(index).expect("compose: argument index out of range") 215 | )); 216 | // Removes '}' 217 | bytes_string.remove(0); 218 | i = 0; 219 | } else if bytes_string[i] == b'{' { 220 | let segment = String::from_utf8(bytes_string.drain(0..i).collect::<Vec<_>>()) 221 | .expect("compose: utf8"); 222 | out_str.push_str(&format!("\n\ttemp.push_str(\"{}\");", segment)); 223 | // Removes '{' 224 | bytes_string.remove(0); 225 | i = 0; 226 | } else { 227 | i += 1; 228 | } 229 | } 230 | let segment = String::from_utf8(bytes_string).expect("compose: utf8"); 231 | out_str.push_str(&format!("\n\ttemp.push_str(\"{}\");", segment)); 232 | 233 | let out_str = format!("{{\n\t{}\n\ttemp\n}}", out_str); 234 | // eprintln!("out_str: {}",out_str); 235 | out_str.parse().unwrap() 236 | } 237 | 238 | // TODO Can we replace these with declarative
macros like `der!` (and then just move them into `rust-ad-core`)? 239 | /// Expands `f!(name)` to the internal forward derivative identifier `__f_internal_name` (`INTERNAL_FORWARD_PREFIX` followed by `name`); any tokens after the identifier are ignored. 240 | #[proc_macro] 241 | pub fn f(item: TokenStream) -> TokenStream { 242 | let mut items = item.into_iter(); 243 | let function_ident = match items.next() { 244 | Some(proc_macro::TokenTree::Ident(ident)) => ident, 245 | _ => panic!("Requires function identifier"), 246 | }; 247 | let call_str = format!("{}{}", INTERNAL_FORWARD_PREFIX, function_ident); 248 | call_str.parse().unwrap() 249 | } 250 | /// Expands `r!(name)` to the internal reverse derivative identifier `__r_internal_name` (`INTERNAL_REVERSE_PREFIX` followed by `name`); any tokens after the identifier are ignored. 251 | #[proc_macro] 252 | pub fn r(item: TokenStream) -> TokenStream { 253 | let mut items = item.into_iter(); 254 | let function_ident = match items.next() { 255 | Some(proc_macro::TokenTree::Ident(ident)) => ident, 256 | _ => panic!("Requires function identifier"), 257 | }; 258 | let call_str = format!("{}{}", INTERNAL_REVERSE_PREFIX, function_ident); 259 | call_str.parse().unwrap() 260 | } 261 | -------------------------------------------------------------------------------- /core/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rust-ad-core" 3 | version = "0.8.0" 4 | edition = "2021" 5 | 6 | description = "Rust Auto-Differentiation."
7 | license = "Apache-2.0" 8 | repository = "https://github.com/JonathanWoollett-Light/rust-ad" 9 | documentation = "https://docs.rs/rust-ad/" 10 | readme = "../README.md" 11 | 12 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 13 | 14 | [dependencies] 15 | syn = { version="1.0.82", features=["full","extra-traits"] } 16 | lazy_static = "1.4.0" 17 | rust-ad-core-macros = { version = "0.8.0", path = "../core-macros" } 18 | rust-ad-consts = { version = "0.8.0", path = "../consts" } 19 | quote = "1.0.10" # Just for ToTokens -------------------------------------------------------------------------------- /core/src/derivatives/f32.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | use rust_ad_core_macros::{combined_derivative_macro, compose}; 3 | 4 | // Primitive procedures 5 | // ------------------------------------------------------------------- 6 | 7 | // Derivative of [std::ops::Add]: d(a+b)/da = 1, d(a+b)/db = 1. 8 | combined_derivative_macro!(add_f32, "0f32", "1f32", "1f32"); 9 | // Derivative of [std::ops::Sub]: d(a-b)/da = 1, d(a-b)/db = -1. 10 | combined_derivative_macro!(sub_f32, "0f32", "1f32", "-1f32"); 11 | // Derivative of [std::ops::Mul]: d(a*b)/da = b, d(a*b)/db = a. 12 | combined_derivative_macro!(mul_f32, "0f32", "{1}", "{0}"); 13 | // Derivative of [std::ops::Div]: d(a/b)/da = 1/b, d(a/b)/db = -a/b^2. 14 | combined_derivative_macro!(div_f32, "0f32", "1f32/{1}", "-{0}/({1}*{1})"); 15 | 16 | // Exponent procedures 17 | // ------------------------------------------------------------------- 18 | 19 | // Derivative of [`powi`](https://doc.rust-lang.org/std/primitive.f32.html#method.powi).
20 | combined_derivative_macro!( 21 | powi_f32, 22 | "0f32", 23 | "{1} as f32 * {0}.powi({1} - 1i32)", 24 | "{0}.powi({1}) * {0}.ln()" 25 | ); 26 | // Derivative of [`powf`](https://doc.rust-lang.org/std/primitive.f32.html#method.powf). 27 | combined_derivative_macro!( 28 | powf_f32, 29 | "0f32", 30 | "{1} as f32 * {0}.powf({1} - 1f32)", 31 | "{0}.powf({1}) * {0}.ln()" 32 | ); 33 | // Derivative of [`sqrt`](https://doc.rust-lang.org/std/primitive.f32.html#method.sqrt). 34 | combined_derivative_macro!(sqrt_f32, "0f32", "1f32 / (2f32 * {0}.sqrt())"); 35 | // Derivative of [`cbrt`](https://doc.rust-lang.org/std/primitive.f32.html#method.cbrt): 1/(3*cbrt(x)^2), written with `cbrt` rather than `powf(2/3)` because `powf` returns NaN for a negative base with a fractional exponent. 36 | combined_derivative_macro!(cbrt_f32, "0f32", "1f32 / (3f32 * {0}.cbrt() * {0}.cbrt())"); 37 | // Derivative of [`exp`](https://doc.rust-lang.org/std/primitive.f32.html#method.exp). 38 | combined_derivative_macro!(exp_f32, "0f32", "{0}.exp()"); 39 | // Derivative of [`exp2`](https://doc.rust-lang.org/std/primitive.f32.html#method.exp2). 40 | combined_derivative_macro!(exp2_f32, "0f32", "{0}.exp2() * 2f32.ln()"); 41 | // Derivative of [`exp_m1`](https://doc.rust-lang.org/std/primitive.f32.html#method.exp_m1). 42 | combined_derivative_macro!(exp_m1_f32, "0f32", "{0}.exp()"); 43 | 44 | // Log procedures 45 | // ------------------------------------------------------------------- 46 | 47 | // Derivative of [`ln`](https://doc.rust-lang.org/std/primitive.f32.html#method.ln). 48 | combined_derivative_macro!(ln_f32, "0f32", "1f32 / {0}"); 49 | // Derivative of [`ln_1p`](https://doc.rust-lang.org/std/primitive.f32.html#method.ln_1p). 50 | combined_derivative_macro!(ln_1p_f32, "0f32", "1f32 / (1f32+{0})"); 51 | // Derivative of [`log`](https://doc.rust-lang.org/std/primitive.f32.html#method.log). 52 | combined_derivative_macro!( 53 | log_f32, 54 | "0f32", 55 | "1f32 / ({0}*{1}.ln())", 56 | "-{0}.ln() / ({1} *{1}.ln()*{1}.ln())" 57 | ); 58 | // Derivative of [`log10`](https://doc.rust-lang.org/std/primitive.f32.html#method.log10).
59 | combined_derivative_macro!(log10_f32, "0f32", "1f32 / ({0}*(10f32).ln())"); 60 | // Derivative of [`log2`](https://doc.rust-lang.org/std/primitive.f32.html#method.log2). 61 | combined_derivative_macro!(log2_f32, "0f32", "1f32 / ({0}*(2f32).ln())"); 62 | 63 | // Trig procedures 64 | // ------------------------------------------------------------------- 65 | 66 | // Derivative of [`acos`](https://doc.rust-lang.org/std/primitive.f32.html#method.acos): -1/sqrt(1-x^2). 67 | combined_derivative_macro!(acos_f32, "0f32", "-1f32 / (1f32-{0}*{0}).sqrt()"); 68 | // Derivative of [`acosh`](https://doc.rust-lang.org/std/primitive.f32.html#method.acosh). 69 | combined_derivative_macro!( 70 | acosh_f32, 71 | "0f32", 72 | "1f32 / ( ({0}-1f32).sqrt() * ({0}+1f32).sqrt() )" 73 | ); 74 | // Derivative of [`asin`](https://doc.rust-lang.org/std/primitive.f32.html#method.asin). 75 | combined_derivative_macro!(asin_f32, "0f32", "1f32 / (1f32-{0}*{0}).sqrt()"); 76 | // Derivative of [`asinh`](https://doc.rust-lang.org/std/primitive.f32.html#method.asinh). 77 | combined_derivative_macro!(asinh_f32, "0f32", "1f32 / ({0}*{0}+1f32).sqrt()"); 78 | // Derivative of [`atan`](https://doc.rust-lang.org/std/primitive.f32.html#method.atan). 79 | combined_derivative_macro!(atan_f32, "0f32", "1f32 / ({0}*{0}+1f32)"); 80 | // Derivative of [`sin`](https://doc.rust-lang.org/std/primitive.f32.html#method.sin). 81 | combined_derivative_macro!(sin_f32, "0f32", "{0}.cos()"); 82 | // Derivative of [`atanh`](https://doc.rust-lang.org/std/primitive.f32.html#method.atanh). 83 | combined_derivative_macro!(atanh_f32, "0f32", "1f32 / (1f32-{0}*{0})"); 84 | // Derivative of [`cos`](https://doc.rust-lang.org/std/primitive.f32.html#method.cos). 85 | combined_derivative_macro!(cos_f32, "0f32", "-({0}).sin()"); 86 | // Derivative of [`cosh`](https://doc.rust-lang.org/std/primitive.f32.html#method.cosh).
87 | combined_derivative_macro!(cosh_f32, "0f32", "{0}.sinh()"); 88 | // Derivative of [`sinh`](https://doc.rust-lang.org/std/primitive.f32.html#method.sinh). 89 | combined_derivative_macro!(sinh_f32, "0f32", "{0}.cosh()"); 90 | // Derivative of [`tan`](https://doc.rust-lang.org/std/primitive.f32.html#method.tan). 91 | combined_derivative_macro!(tan_f32, "0f32", "1f32 / ({0}.cos() * {0}.cos())"); 92 | // Derivative of [`tanh`](https://doc.rust-lang.org/std/primitive.f32.html#method.tanh) (currently disabled; note `compose!` only accepts positional placeholders such as `{0}`, not named ones). 93 | // combined_derivative_macro!(tanh_f32, "0f32","1f32 / ({0}.cosh()*{0}.cosh())"); 94 | 95 | // TODO Add atan2 (https://doc.rust-lang.org/std/primitive.f32.html#method.atan2) 96 | // TODO Add sin_cos (https://doc.rust-lang.org/std/primitive.f32.html#method.sin_cos) 97 | 98 | // Misc procedures 99 | // ------------------------------------------------------------------- 100 | 101 | // Derivative of [`abs`](https://doc.rust-lang.org/std/primitive.f32.html#method.abs). 102 | combined_derivative_macro!(abs_f32, "0f32", "{0}.signum()"); 103 | // Derivative of [`recip`](https://doc.rust-lang.org/std/primitive.f32.html#method.recip): -1/x^2. 104 | combined_derivative_macro!(recip_f32, "0f32", "-1f32 / ({0}*{0})"); 105 | 106 | // TODO For the below functions, I do not think the given derivatives are entirely accurate: `ceil`, `floor` and `round` are piecewise constant, so their true derivative is 0 almost everywhere (undefined at the jumps), while `fract` has derivative 1 almost everywhere. 107 | 108 | // Derivative of [`ceil`](https://doc.rust-lang.org/std/primitive.f32.html#method.ceil). 109 | combined_derivative_macro!(ceil_f32, "0f32", "1f32"); 110 | // Derivative of [`floor`](https://doc.rust-lang.org/std/primitive.f32.html#method.floor). 111 | combined_derivative_macro!(floor_f32, "0f32", "1f32"); 112 | // Derivative of [`fract`](https://doc.rust-lang.org/std/primitive.f32.html#method.fract). 113 | combined_derivative_macro!(fract_f32, "0f32", "1f32"); 114 | // Derivative of [`round`](https://doc.rust-lang.org/std/primitive.f32.html#method.round).
115 | combined_derivative_macro!(round_f32, "0f32", "1f32"); 116 | 117 | // TODO Add some of these procedures here: 118 | // - clamp https://doc.rust-lang.org/std/primitive.f32.html#method.clamp 119 | // - div_euclid https://doc.rust-lang.org/std/primitive.f32.html#method.div_euclid 120 | // - hypot https://doc.rust-lang.org/std/primitive.f32.html#method.hypot 121 | // - mul_add https://doc.rust-lang.org/std/primitive.f32.html#method.mul_add 122 | // - signum https://doc.rust-lang.org/std/primitive.f32.html#method.signum 123 | // - rem_euclid https://doc.rust-lang.org/std/primitive.f32.html#method.rem_euclid 124 | // - to_degrees https://doc.rust-lang.org/std/primitive.f32.html#method.to_degrees 125 | // - to_radians https://doc.rust-lang.org/std/primitive.f32.html#method.to_radians 126 | // - trunc https://doc.rust-lang.org/std/primitive.f32.html#method.trunc 127 | -------------------------------------------------------------------------------- /core/src/derivatives/f64.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | use rust_ad_core_macros::{combined_derivative_macro, compose}; 3 | 4 | // Primitive procedures 5 | // ------------------------------------------------------------------- 6 | 7 | // Derivative of [std::ops::Add]: d(a+b)/da = 1, d(a+b)/db = 1. 8 | combined_derivative_macro!(add_f64, "0f64", "1f64", "1f64"); 9 | // Derivative of [std::ops::Sub]: d(a-b)/da = 1, d(a-b)/db = -1. 10 | combined_derivative_macro!(sub_f64, "0f64", "1f64", "-1f64"); 11 | // Derivative of [std::ops::Mul]: d(a*b)/da = b, d(a*b)/db = a. 12 | combined_derivative_macro!(mul_f64, "0f64", "{1}", "{0}"); 13 | // Derivative of [std::ops::Div]: d(a/b)/da = 1/b, d(a/b)/db = -a/b^2. 14 | combined_derivative_macro!(div_f64, "0f64", "1f64/{1}", "-{0}/({1}*{1})"); 15 | 16 | // Exponent procedures 17 | // ------------------------------------------------------------------- 18 | 19 | // Derivative of [`powi`](https://doc.rust-lang.org/std/primitive.f64.html#method.powi).
20 | combined_derivative_macro!( 21 | powi_f64, 22 | "0f64", 23 | "{1} as f64 * {0}.powi({1} - 1i32)", 24 | "{0}.powi({1}) * {0}.ln()" 25 | ); 26 | // Derivative of [`powf`](https://doc.rust-lang.org/std/primitive.f64.html#method.powf). 27 | combined_derivative_macro!( 28 | powf_f64, 29 | "0f64", 30 | "{1} as f64 * {0}.powf({1} - 1f64)", 31 | "{0}.powf({1}) * {0}.ln()" 32 | ); 33 | // Derivative of [`sqrt`](https://doc.rust-lang.org/std/primitive.f64.html#method.sqrt). 34 | combined_derivative_macro!(sqrt_f64, "0f64", "1f64 / (2f64 * {0}.sqrt())"); 35 | // Derivative of [`cbrt`](https://doc.rust-lang.org/std/primitive.f64.html#method.cbrt): 1/(3*cbrt(x)^2), written with `cbrt` rather than `powf(2/3)` because `powf` returns NaN for a negative base with a fractional exponent. 36 | combined_derivative_macro!(cbrt_f64, "0f64", "1f64 / (3f64 * {0}.cbrt() * {0}.cbrt())"); 37 | // Derivative of [`exp`](https://doc.rust-lang.org/std/primitive.f64.html#method.exp). 38 | combined_derivative_macro!(exp_f64, "0f64", "{0}.exp()"); 39 | // Derivative of [`exp2`](https://doc.rust-lang.org/std/primitive.f64.html#method.exp2). 40 | combined_derivative_macro!(exp2_f64, "0f64", "{0}.exp2() * (2f64).ln()"); 41 | // Derivative of [`exp_m1`](https://doc.rust-lang.org/std/primitive.f64.html#method.exp_m1). 42 | combined_derivative_macro!(exp_m1_f64, "0f64", "{0}.exp()"); 43 | 44 | // Log procedures 45 | // ------------------------------------------------------------------- 46 | 47 | // Derivative of [`ln`](https://doc.rust-lang.org/std/primitive.f64.html#method.ln). 48 | combined_derivative_macro!(ln_f64, "0f64", "1f64 / {0}"); 49 | // Derivative of [`ln_1p`](https://doc.rust-lang.org/std/primitive.f64.html#method.ln_1p). 50 | combined_derivative_macro!(ln_1p_f64, "0f64", "1f64 / (1f64+{0})"); 51 | // Derivative of [`log`](https://doc.rust-lang.org/std/primitive.f64.html#method.log). 52 | combined_derivative_macro!( 53 | log_f64, 54 | "0f64", 55 | "1f64 / ({0}*{1}.ln())", 56 | "-{0}.ln() / ({1} *{1}.ln()*{1}.ln())" 57 | ); 58 | // Derivative of [`log10`](https://doc.rust-lang.org/std/primitive.f64.html#method.log10).
59 | combined_derivative_macro!(log10_f64, "0f64", "1f64 / ({0}*(10f64).ln())"); 60 | // Derivative of [`log2`](https://doc.rust-lang.org/std/primitive.f64.html#method.log2). 61 | combined_derivative_macro!(log2_f64, "0f64", "1f64 / ({0}*(2f64).ln())"); 62 | 63 | // Trig procedures 64 | // ------------------------------------------------------------------- 65 | 66 | // Derivative of [`acos`](https://doc.rust-lang.org/std/primitive.f64.html#method.acos): -1/sqrt(1-x^2). 67 | combined_derivative_macro!(acos_f64, "0f64", "-1f64 / (1f64-{0}*{0}).sqrt()"); 68 | // Derivative of [`acosh`](https://doc.rust-lang.org/std/primitive.f64.html#method.acosh). 69 | combined_derivative_macro!( 70 | acosh_f64, 71 | "0f64", 72 | "1f64 / ( ({0}-1f64).sqrt() * ({0}+1f64).sqrt() )" 73 | ); 74 | // Derivative of [`asin`](https://doc.rust-lang.org/std/primitive.f64.html#method.asin). 75 | combined_derivative_macro!(asin_f64, "0f64", "1f64 / (1f64-{0}*{0}).sqrt()"); 76 | // Derivative of [`asinh`](https://doc.rust-lang.org/std/primitive.f64.html#method.asinh). 77 | combined_derivative_macro!(asinh_f64, "0f64", "1f64 / ({0}*{0}+1f64).sqrt()"); 78 | // Derivative of [`atan`](https://doc.rust-lang.org/std/primitive.f64.html#method.atan). 79 | combined_derivative_macro!(atan_f64, "0f64", "1f64 / ({0}*{0}+1f64)"); 80 | // Derivative of [`sin`](https://doc.rust-lang.org/std/primitive.f64.html#method.sin). 81 | combined_derivative_macro!(sin_f64, "0f64", "{0}.cos()"); 82 | // Derivative of [`atanh`](https://doc.rust-lang.org/std/primitive.f64.html#method.atanh). 83 | combined_derivative_macro!(atanh_f64, "0f64", "1f64 / (1f64-{0}*{0})"); 84 | // Derivative of [`cos`](https://doc.rust-lang.org/std/primitive.f64.html#method.cos). 85 | combined_derivative_macro!(cos_f64, "0f64", "-({0}).sin()"); 86 | // Derivative of [`cosh`](https://doc.rust-lang.org/std/primitive.f64.html#method.cosh).
87 | combined_derivative_macro!(cosh_f64, "0f64", "{0}.sinh()"); 88 | // Derivative of [`sinh`](https://doc.rust-lang.org/std/primitive.f64.html#method.sinh). 89 | combined_derivative_macro!(sinh_f64, "0f64", "{0}.cosh()"); 90 | // Derivative of [`tan`](https://doc.rust-lang.org/std/primitive.f64.html#method.tan). 91 | combined_derivative_macro!(tan_f64, "0f64", "1f64 / ({0}.cos() * {0}.cos())"); 92 | // Derivative of [`tanh`](https://doc.rust-lang.org/std/primitive.f64.html#method.tanh) (currently disabled; note `compose!` only accepts positional placeholders such as `{0}`, not named ones). 93 | // combined_derivative_macro!(tanh_f64, "0f64", "1f64 / ({0}.cosh()*{0}.cosh())"); 94 | 95 | // TODO Add atan2 (https://doc.rust-lang.org/std/primitive.f64.html#method.atan2) 96 | // TODO Add sin_cos (https://doc.rust-lang.org/std/primitive.f64.html#method.sin_cos) 97 | 98 | // Misc procedures 99 | // ------------------------------------------------------------------- 100 | 101 | // Derivative of [`abs`](https://doc.rust-lang.org/std/primitive.f64.html#method.abs). 102 | combined_derivative_macro!(abs_f64, "0f64", "{0}.signum()"); 103 | // Derivative of [`recip`](https://doc.rust-lang.org/std/primitive.f64.html#method.recip): -1/x^2. 104 | combined_derivative_macro!(recip_f64, "0f64", "-1f64 / ({0}*{0})"); 105 | 106 | // TODO For the below functions, I do not think the given derivatives are entirely accurate: `ceil`, `floor` and `round` are piecewise constant, so their true derivative is 0 almost everywhere (undefined at the jumps), while `fract` has derivative 1 almost everywhere. 107 | 108 | // Derivative of [`ceil`](https://doc.rust-lang.org/std/primitive.f64.html#method.ceil). 109 | combined_derivative_macro!(ceil_f64, "0f64", "1f64"); 110 | // Derivative of [`floor`](https://doc.rust-lang.org/std/primitive.f64.html#method.floor). 111 | combined_derivative_macro!(floor_f64, "0f64", "1f64"); 112 | // Derivative of [`fract`](https://doc.rust-lang.org/std/primitive.f64.html#method.fract). 113 | combined_derivative_macro!(fract_f64, "0f64", "1f64"); 114 | // Derivative of [`round`](https://doc.rust-lang.org/std/primitive.f64.html#method.round).
115 | combined_derivative_macro!(round_f64, "0f32", "1f64"); 116 | 117 | // TODO Add some of these procedures here: 118 | // - clamp https://doc.rust-lang.org/std/primitive.f64.html#method.clamp 119 | // - div_eculid https://doc.rust-lang.org/std/primitive.f64.html#method.div_euclid 120 | // - hypot https://doc.rust-lang.org/std/primitive.f64.html#method.hypot 121 | // - mul_add https://doc.rust-lang.org/std/primitive.f64.html#method.mul_add 122 | // - signum https://doc.rust-lang.org/std/primitive.f64.html#method.signum 123 | // - rem_euclid https://doc.rust-lang.org/std/primitive.f64.html#method.rem_euclid 124 | // - to_degrees https://doc.rust-lang.org/std/primitive.f64.html#method.to_degrees 125 | // - to_radians https://doc.rust-lang.org/std/primitive.f64.html#method.to_radians 126 | // - trunc https://doc.rust-lang.org/std/primitive.f64.html#method.trunc 127 | -------------------------------------------------------------------------------- /core/src/derivatives/i128.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | use rust_ad_core_macros::{combined_derivative_macro, compose}; 3 | 4 | // Primitive procedures 5 | // ------------------------------------------------------------------- 6 | 7 | // Derivative of [std::ops::Add]. 8 | combined_derivative_macro!(add_i128, "0i128", "1i128", "1i128"); 9 | // Derivative of [std::ops::Sub]. 10 | combined_derivative_macro!(sub_i128, "0i128", "1i128", "-1i128"); 11 | // Derivative of [std::ops::Mul]. 12 | combined_derivative_macro!(mul_i128, "0i128", "{1}", "{0}"); 13 | // Derivative of [std::ops::Div]. 
14 | combined_derivative_macro!(div_i128, "0i128", "1i128/{1}", "-{0}/({1}*{1})");
15 |
--------------------------------------------------------------------------------
/core/src/derivatives/i16.rs:
--------------------------------------------------------------------------------
1 | use super::*;
2 | use rust_ad_core_macros::{combined_derivative_macro, compose};
3 |
4 | // Primitive procedures
5 | // -------------------------------------------------------------------
6 |
7 | // Derivative of [std::ops::Add].
8 | combined_derivative_macro!(add_i16, "0i16", "1i16", "1i16");
9 | // Derivative of [std::ops::Sub].
10 | combined_derivative_macro!(sub_i16, "0i16", "1i16", "-1i16");
11 | // Derivative of [std::ops::Mul].
12 | combined_derivative_macro!(mul_i16, "0i16", "{1}", "{0}");
13 | // Derivative of [std::ops::Div]. NOTE(review): integer `/` truncates toward zero, so these quotients only approximate the real-valued derivative.
14 | combined_derivative_macro!(div_i16, "0i16", "1i16/{1}", "-{0}/({1}*{1})");
15 |
--------------------------------------------------------------------------------
/core/src/derivatives/i32.rs:
--------------------------------------------------------------------------------
1 | use super::*;
2 | use rust_ad_core_macros::{combined_derivative_macro, compose};
3 |
4 | // Primitive procedures
5 | // -------------------------------------------------------------------
6 |
7 | // Derivative of [std::ops::Add].
8 | combined_derivative_macro!(add_i32, "0i32", "1i32", "1i32");
9 | // Derivative of [std::ops::Sub].
10 | combined_derivative_macro!(sub_i32, "0i32", "1i32", "-1i32");
11 | // Derivative of [std::ops::Mul].
12 | combined_derivative_macro!(mul_i32, "0i32", "{1}", "{0}");
13 | // Derivative of [std::ops::Div]. NOTE(review): integer `/` truncates toward zero, so these quotients only approximate the real-valued derivative.
14 | combined_derivative_macro!(div_i32, "0i32", "1i32/{1}", "-{0}/({1}*{1})");
15 |
--------------------------------------------------------------------------------
/core/src/derivatives/i64.rs:
--------------------------------------------------------------------------------
1 | use super::*;
2 | use rust_ad_core_macros::{combined_derivative_macro, compose};
3 |
4 | // Primitive procedures
5 | // -------------------------------------------------------------------
6 |
7 | // Derivative of [std::ops::Add].
8 | combined_derivative_macro!(add_i64, "0i64", "1i64", "1i64");
9 | // Derivative of [std::ops::Sub].
10 | combined_derivative_macro!(sub_i64, "0i64", "1i64", "-1i64");
11 | // Derivative of [std::ops::Mul].
12 | combined_derivative_macro!(mul_i64, "0i64", "{1}", "{0}");
13 | // Derivative of [std::ops::Div]. NOTE(review): integer `/` truncates toward zero, so these quotients only approximate the real-valued derivative.
14 | combined_derivative_macro!(div_i64, "0i64", "1i64/{1}", "-{0}/({1}*{1})");
15 |
--------------------------------------------------------------------------------
/core/src/derivatives/i8.rs:
--------------------------------------------------------------------------------
1 | use super::*;
2 | use rust_ad_core_macros::{combined_derivative_macro, compose};
3 |
4 | // Primitive procedures
5 | // -------------------------------------------------------------------
6 |
7 | // Derivative of [std::ops::Add].
8 | combined_derivative_macro!(add_i8, "0i8", "1i8", "1i8");
9 | // Derivative of [std::ops::Sub].
10 | combined_derivative_macro!(sub_i8, "0i8", "1i8", "-1i8");
11 | // Derivative of [std::ops::Mul].
12 | combined_derivative_macro!(mul_i8, "0i8", "{1}", "{0}");
13 | // Derivative of [std::ops::Div]. NOTE(review): integer `/` truncates toward zero, so these quotients only approximate the real-valued derivative.
14 | combined_derivative_macro!(div_i8, "0i8", "1i8/{1}", "-{0}/({1}*{1})");
15 |
--------------------------------------------------------------------------------
/core/src/derivatives/mod.rs:
--------------------------------------------------------------------------------
1 | use crate::*;
2 | use std::collections::HashSet;
3 |
4 | /// Derivative functions for `f32`s.
5 | pub mod f32;
6 | pub use self::f32::*;
7 | /// Derivative functions for `f64`s.
8 | pub mod f64;
9 | pub use self::f64::*;
10 | /// Derivative functions for `i8`s.
11 | pub mod i8;
12 | pub use self::i8::*;
13 | /// Derivative functions for `i16`s.
14 | pub mod i16;
15 | pub use self::i16::*;
16 | /// Derivative functions for `i32`s.
17 | pub mod i32;
18 | pub use self::i32::*;
19 | /// Derivative functions for `i64`s.
20 | pub mod i64;
21 | pub use self::i64::*;
22 | /// Derivative functions for `i128`s.
23 | pub mod i128;
24 | pub use self::i128::*;
25 | /// Derivative functions for `u8`s.
26 | pub mod u8;
27 | pub use self::u8::*;
28 | /// Derivative functions for `u16`s.
29 | pub mod u16;
30 | pub use self::u16::*;
31 | /// Derivative functions for `u32`s.
32 | pub mod u32;
33 | pub use self::u32::*;
34 | /// Derivative functions for `u64`s.
35 | pub mod u64;
36 | pub use self::u64::*;
37 | /// Derivative functions for `u128`s.
38 | pub mod u128;
39 | pub use self::u128::*;
40 | // /// Derivative functions for [ndarray](https://docs.rs/ndarray/latest/ndarray/index.html).
41 | // pub mod ndarray;
42 | // pub use self::ndarray::*;
43 |
44 | /// Signature of a forward general derivative (see `fgd`): `(local_ident, args, outer_fn_args)`; release builds also thread a `&mut HashSet` of derivatives known to be non-zero.
45 | #[cfg(debug_assertions)]
46 | pub type FgdType = fn(String, &[Arg], &[String]) -> syn::Stmt;
47 | #[cfg(not(debug_assertions))]
48 | pub type FgdType = fn(String, &[Arg], &[String], &mut HashSet) -> syn::Stmt;
49 |
50 | /// Signature of a reverse general derivative (see `rgd`); it returns `None` when no derivative statement needs to be emitted.
51 | pub type RgdType = fn(
52 |     String,
53 |     &[Arg],
54 |     &mut Vec>>,
55 |     &mut Vec>,
56 | ) -> Option;
57 |
58 | /// A single argument of a supported procedure: either a variable identifier or a numeric literal.
59 | pub enum Arg {
60 |     /// e.g. `a`
61 |     Variable(String),
62 |     /// e.g. `7.3f32`
63 |     Literal(String),
64 | }
65 | impl std::fmt::Display for Arg {
66 |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { // Both variants print their inner text unchanged.
67 |         match self {
68 |             Self::Variable(s) => write!(f, "{}", s),
69 |             Self::Literal(s) => write!(f, "{}", s),
70 |         }
71 |     }
72 | }
73 | impl TryFrom<&syn::Expr> for Arg {
74 |     type Error = String;
75 |     fn try_from(expr: &syn::Expr) -> Result { // Accepts integer/float literals and paths; anything else emits a compiler diagnostic and returns `Err`.
76 |         match expr {
77 |             syn::Expr::Lit(l) => match &l.lit {
78 |                 syn::Lit::Int(int) => Ok(Self::Literal(int.to_string())),
79 |                 syn::Lit::Float(float) => Ok(Self::Literal(float.to_string())),
80 |                 _ => { // NOTE(review): reached for non-numeric literals (e.g. strings), although the emitted message says "non-literal".
81 |                     Diagnostic::spanned(
82 |                         expr.span().unwrap(),
83 |                         proc_macro::Level::Error,
84 |                         format!("non-literal and non-path argument: {:?}", expr),
85 |                     )
86 |                     .emit();
87 |                     Err(format!("Arg::TryFrom: {:?}", expr))
88 |                 }
89 |             },
90 |             syn::Expr::Path(p) => Ok(Self::Variable(p.path.segments[0].ident.to_string())), // Only the first path segment is kept (`a::b` becomes `a`).
91 |             _ => {
92 |                 Diagnostic::spanned(
93 |                     expr.span().unwrap(),
94 |                     proc_macro::Level::Error,
95 |                     format!("non-literal and non-path argument: {:?}", expr),
96 |                 )
97 |                 .emit();
98 |                 Err(format!("Arg::TryFrom: {:?}", expr))
99 |             }
100 |         }
101 |     }
102 | }
103 |
104 | /// Derivative function type: maps a procedure's arguments to the source text of one partial derivative.
105 | pub type DFn = fn(&[Arg]) -> String;
106 |
107 | /// Splits `let <ident> = <receiver>.<method>(..);` into the bound identifier and the method-call expression; panics if `stmt` does not have that shape.
108 | pub fn lm_identifiers(stmt: &syn::Stmt) -> (String, &syn::ExprMethodCall) {
109 |     let local = stmt.local().expect("lm_identifiers: not local");
110 |     let init = &local.init;
111 |     let method_expr = init
112 |         .as_ref()
113 |         .unwrap()
114 |         .1
115 |         .method_call()
116 |         .expect("lm_identifiers: not method");
117 |
118 |     let local_ident = local
119 |         .pat
120 |         .ident()
121 |         .expect("lm_identifiers: not ident")
122 |         .ident
123 |         .to_string();
124 |     (local_ident, method_expr)
125 | }
126 |
127 | // TODO Replace `cumulative_derivative_wrt_rt` and `Type` with neater functionality.
128 | /// Gets the cumulative derivative of `expr` (a literal or a path only) with respect to `input_var`, using `out_type` for the zero literal.
129 | ///
130 | /// Panics on any other expression kind. See `cumulative_derivative_wrt` for more documentation.
131 | pub fn cumulative_derivative_wrt_rt(
132 |     expr: &syn::Expr,
133 |     input_var: &str,
134 |     function_inputs: &[String],
135 |     out_type: &Type,
136 | ) -> String {
137 |     match expr {
138 |         // Result 1: a literal is constant, so its derivative is the zero literal of `out_type`.
139 |         syn::Expr::Lit(_) => out_type.zero(),
140 |         syn::Expr::Path(path_expr) => {
141 |             // `x` is the identifier this path names (typically one operand of a binary expression); we want d(x)/d(input_var).
142 |             let x = path_expr.path.segments[0].ident.to_string();
143 |
144 |             // Result 3: `x` is the input variable itself.
145 |             if x == input_var {
146 |                 der!(input_var)
147 |             }
148 |             // Result 4: `x` is a different function input, so it contributes zero.
149 |             else if function_inputs.contains(&x) {
150 |                 out_type.zero()
151 |             }
152 |             // Result 2: `x` is an intermediate variable; refer to its accumulated derivative.
153 |             else {
154 |                 wrt!(x, input_var)
155 |             }
156 |         }
157 |         _ => panic!("cumulative_derivative_wrt: unsupported expr"),
158 |     }
159 | }
160 | /// Enum of the primitive numeric types used by some internal functionality (this will soon be removed).
161 | #[derive(PartialEq, Eq)] 162 | pub enum Type { 163 | F32, 164 | F64, 165 | U8, 166 | U16, 167 | U32, 168 | U64, 169 | U128, 170 | I8, 171 | I16, 172 | I32, 173 | I64, 174 | I128, 175 | } 176 | impl Type { 177 | pub fn zero(&self) -> String { 178 | format!("0{}", self.to_string()) 179 | } 180 | } 181 | impl ToString for Type { 182 | fn to_string(&self) -> String { 183 | match self { 184 | Self::F32 => "f32", 185 | Self::F64 => "f64", 186 | Self::U8 => "u8", 187 | Self::U16 => "u16", 188 | Self::U32 => "u32", 189 | Self::U64 => "u64", 190 | Self::U128 => "u128", 191 | Self::I8 => "i8", 192 | Self::I16 => "i16", 193 | Self::I32 => "i32", 194 | Self::I64 => "i64", 195 | Self::I128 => "i128", 196 | } 197 | .into() 198 | } 199 | } 200 | impl TryFrom<&str> for Type { 201 | type Error = &'static str; 202 | fn try_from(string: &str) -> Result { 203 | match string { 204 | "f32" => Ok(Self::F32), 205 | "f64" => Ok(Self::F64), 206 | "u8" => Ok(Self::U8), 207 | "u16" => Ok(Self::U16), 208 | "u32" => Ok(Self::U32), 209 | "u64" => Ok(Self::U64), 210 | "u128" => Ok(Self::U128), 211 | "i8" => Ok(Self::I8), 212 | "i16" => Ok(Self::I16), 213 | "i32" => Ok(Self::I32), 214 | "i64" => Ok(Self::I64), 215 | "i128" => Ok(Self::I128), 216 | _ => Err("Type::try_from unsupported type"), 217 | } 218 | } 219 | } 220 | 221 | /// Forward general derivative 222 | /// ```ignore 223 | /// static outer_test: FgdType = { 224 | /// const base_fn: DFn = |args:&[String]| -> String { format!("{0}-{1}",args[0],args[1]) }; 225 | /// const exponent_fn: DFn = |args:&[String]| -> String { format!("{0}*{1}+{0}",args[0],args[1]) }; 226 | /// fgd::<"0f32",{&[base_fn, exponent_fn]}> 227 | /// }; 228 | /// ``` 229 | /// Is equivalent to 230 | /// ```ignore 231 | /// forward_derivative_macro!(outer_test,"0f32","{0}-{1}","{0}*{1}+{0}"); 232 | /// ``` 233 | pub fn fgd( 234 | local_ident: String, 235 | args: &[Arg], 236 | outer_fn_args: &[String], 237 | #[cfg(not(debug_assertions))] non_zero_derivatives: &mut 
HashSet, 238 | ) -> syn::Stmt { 239 | assert_eq!( 240 | args.len(), 241 | TRANSLATION_FUNCTIONS.len(), 242 | "fgd args len mismatch" 243 | ); 244 | 245 | // Gets vec of derivative idents and derivative functions 246 | // TODO Put these 2 different implementations together more cleanly. 247 | // TODO Improve docs here. 248 | #[cfg(debug_assertions)] 249 | let (idents, derivatives) = outer_fn_args 250 | .iter() 251 | .map(|outer_fn_input| { 252 | let acc = args 253 | .iter() 254 | .zip(TRANSLATION_FUNCTIONS.iter()) 255 | .map(|(arg,t)| 256 | // See the docs for cumulative (these if's accomplish the same-ish thing) 257 | match arg { 258 | Arg::Literal(_) => DEFAULT.to_string(), // Since we are multiplying by `DEFAULT` (e.g. `0.`) we can simply ignore this property 259 | Arg::Variable(v) => { 260 | let a = t(args); 261 | let b = if v == outer_fn_input { 262 | der!(outer_fn_input) 263 | } else if outer_fn_args.contains(v) { 264 | DEFAULT.to_string() // Since we are multiplying by `DEFAULT` (e.g. `0.`) we can simply ignore this property 265 | } else { 266 | wrt!(arg,outer_fn_input) 267 | }; 268 | // eprintln!("a: {}, b: {}",a,b); 269 | format!("({})*{}",a,b) 270 | } 271 | }) 272 | .intersperse(String::from("+")) 273 | .collect::(); 274 | let new_der = wrt!(local_ident, outer_fn_input); 275 | (new_der, acc) 276 | }) 277 | .unzip::<_, _, Vec<_>, Vec<_>>(); 278 | #[cfg(not(debug_assertions))] 279 | let (idents, derivatives) = outer_fn_args 280 | .iter() 281 | .filter_map(|outer_fn_input| { 282 | let acc = args 283 | .iter() 284 | .zip(TRANSLATION_FUNCTIONS.iter()) 285 | .filter_map(|(arg,t)| 286 | // See the docs for cumulative (these if's accomplish the same-ish thing) 287 | // TODO Improve docs here directly 288 | match arg { 289 | Arg::Literal(_) => None, // Since we are multiplying by `DEFAULT` (e.g. 
`0.`) we can simply ignore this property 290 | Arg::Variable(v) => { 291 | let a = t(args); 292 | let b = if v == outer_fn_input { 293 | Some(der!(outer_fn_input)) 294 | } else if outer_fn_args.contains(v) { 295 | None // Since we are multiplying by `DEFAULT` (e.g. `0.`) we can simply ignore this property 296 | } else { 297 | let der = wrt!(arg,outer_fn_input); 298 | // If the derivative has not been defined, we know it would've been defined as zero 299 | non_zero_derivatives.get(&der).cloned() 300 | }; 301 | // eprintln!("a: {}, b: {}",a,b); 302 | match b { 303 | Some(acc_der) => Some(format!("({})*{}",a,acc_der)), 304 | None => None 305 | } 306 | } 307 | }) 308 | .intersperse(String::from("+")) 309 | .collect::(); 310 | match acc.is_empty() { 311 | true => None, 312 | false => { 313 | let new_der = wrt!(local_ident, outer_fn_input); 314 | // If there are some non-zero components this derivative may be non-zero and is thus worth defining 315 | non_zero_derivatives.insert(new_der.clone()); 316 | Some((new_der, acc)) 317 | } 318 | } 319 | }) 320 | .unzip::<_, _, Vec<_>, Vec<_>>(); 321 | 322 | // Equivalent to `derivatives.len()` 323 | let stmt_str = match idents.len() { 324 | 0 => unreachable!(), 325 | 1 => format!("let {} = {};", idents[0], derivatives[0]), 326 | _ => format!( 327 | "let ({}) = ({});", 328 | idents 329 | .into_iter() 330 | .intersperse(String::from(",")) 331 | .collect::(), 332 | derivatives 333 | .into_iter() 334 | .intersperse(String::from(",")) 335 | .collect::() 336 | ), 337 | }; 338 | syn::parse_str(&stmt_str).expect("fgd: parse fail") 339 | } 340 | 341 | /// Reverse General Derivative 342 | pub fn rgd( 343 | local_ident: String, 344 | args: &[Arg], 345 | component_map: &mut Vec>>, 346 | return_derivatives: &mut Vec>, 347 | ) -> Option { 348 | debug_assert_eq!( 349 | args.len(), 350 | TRANSLATION_FUNCTIONS.len(), 351 | "rgd args len mismatch" 352 | ); 353 | debug_assert_eq!(component_map.len(), return_derivatives.len()); 354 | 355 | let 
(output_idents, output_derivatives) = (0..component_map.len()) 356 | .filter_map(|index| { 357 | let (idents, derivatives) = args 358 | .iter() 359 | .zip(TRANSLATION_FUNCTIONS.iter()) 360 | .filter_map(|(arg, t)| match arg { 361 | Arg::Variable(v) => Some((v, t)), 362 | Arg::Literal(_) => None, 363 | }) 364 | .filter_map(|(arg, t)| { 365 | let rtn = rtn!(index); 366 | let der_ident = wrtn!(arg, local_ident, rtn); 367 | let wrt = wrt!(local_ident, rtn); 368 | 369 | // If component exists 370 | match return_derivatives[index].contains(&local_ident) { 371 | true => { 372 | append_insert(arg, local_ident.clone(), &mut component_map[index]); 373 | let (derivative, accumulator) = (t(args), wrt); 374 | let full_der = format!("({})*{}", derivative, accumulator); 375 | Some((der_ident, full_der)) 376 | } 377 | false => None, 378 | } 379 | }) 380 | .unzip::<_, _, Vec<_>, Vec<_>>(); 381 | // let (idents, derivatives) = (idents.into_iter().intersperse(String::from(",")).collect::(), derivatives.into_iter().intersperse(String::from(",")).collect::()); 382 | (!idents.is_empty()).then(|| (idents, derivatives)) 383 | }) 384 | .unzip::<_, _, Vec<_>, Vec<_>>(); 385 | 386 | match output_idents.len() { 387 | 0 => None, 388 | 1 => match output_idents[0].len() { 389 | 0 => unreachable!(), 390 | 1 => Some( 391 | syn::parse_str(&format!( 392 | "let {} = {};", 393 | output_idents[0][0], output_derivatives[0][0] 394 | )) 395 | .expect("fgd: 1 parse fail"), 396 | ), 397 | _ => Some( 398 | syn::parse_str(&format!( 399 | "let ({}) = ({});", 400 | output_idents[0] 401 | .iter() 402 | .cloned() 403 | .intersperse(String::from(",")) 404 | .collect::(), 405 | output_derivatives[0] 406 | .iter() 407 | .cloned() 408 | .intersperse(String::from(",")) 409 | .collect::() 410 | )) 411 | .expect("fgd: 1 parse fail"), 412 | ), 413 | }, 414 | _ => { 415 | let (output_idents, output_derivatives) = ( 416 | output_idents 417 | .into_iter() 418 | .map(|ri| { 419 | format!( 420 | "({})", 421 | ri.into_iter() 
422 | .intersperse(String::from(",")) 423 | .collect::() 424 | ) 425 | }) 426 | .intersperse(String::from(",")) 427 | .collect::(), 428 | output_derivatives 429 | .into_iter() 430 | .map(|rd| { 431 | format!( 432 | "({})", 433 | rd.into_iter() 434 | .intersperse(String::from(",")) 435 | .collect::() 436 | ) 437 | }) 438 | .intersperse(String::from(",")) 439 | .collect::(), 440 | ); 441 | let stmt_str = format!("let ({}) = ({});", output_idents, output_derivatives); 442 | Some(syn::parse_str(&stmt_str).expect("fgd: 3 parse fail")) 443 | } 444 | } 445 | } 446 | -------------------------------------------------------------------------------- /core/src/derivatives/u128.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | use rust_ad_core_macros::{combined_derivative_macro, compose}; 3 | 4 | // Primitive procedures 5 | // ------------------------------------------------------------------- 6 | 7 | // Derivative of [std::ops::Add]. 8 | combined_derivative_macro!(add_u128, "0u128", "1u128", "1u128"); 9 | // Derivative of [std::ops::Sub]. 10 | combined_derivative_macro!(sub_u128, "0u128", "1u128", "-1u128"); 11 | // Derivative of [std::ops::Mul]. 12 | combined_derivative_macro!(mul_u128, "0u128", "{1}", "{0}"); 13 | // Derivative of [std::ops::Div]. 14 | combined_derivative_macro!(div_u128, "0u128", "1u128/{1}", "-{0}/({1}*{1})"); 15 | -------------------------------------------------------------------------------- /core/src/derivatives/u16.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | use rust_ad_core_macros::{combined_derivative_macro, compose}; 3 | 4 | // Primitive procedures 5 | // ------------------------------------------------------------------- 6 | 7 | // Derivative of [std::ops::Add]. 8 | combined_derivative_macro!(add_u16, "0u16", "1u16", "1u16"); 9 | // Derivative of [std::ops::Sub]. 
10 | combined_derivative_macro!(sub_u16, "0u16", "1u16", "-1u16"); 11 | // Derivative of [std::ops::Mul]. 12 | combined_derivative_macro!(mul_u16, "0u16", "{1}", "{0}"); 13 | // Derivative of [std::ops::Div]. 14 | combined_derivative_macro!(div_u16, "0u16", "1u16/{1}", "-{0}/({1}*{1})"); 15 | -------------------------------------------------------------------------------- /core/src/derivatives/u32.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | use rust_ad_core_macros::{combined_derivative_macro, compose}; 3 | 4 | // Primitive procedures 5 | // ------------------------------------------------------------------- 6 | 7 | // Derivative of [std::ops::Add]. 8 | combined_derivative_macro!(add_u32, "0u32", "1u32", "1u32"); 9 | // Derivative of [std::ops::Sub]. 10 | combined_derivative_macro!(sub_u32, "0u32", "1u32", "-1u32"); 11 | // Derivative of [std::ops::Mul]. 12 | combined_derivative_macro!(mul_u32, "0u32", "{1}", "{0}"); 13 | // Derivative of [std::ops::Div]. 14 | combined_derivative_macro!(div_u32, "0u32", "1u32/{1}", "-{0}/({1}*{1})"); 15 | -------------------------------------------------------------------------------- /core/src/derivatives/u64.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | use rust_ad_core_macros::{combined_derivative_macro, compose}; 3 | 4 | // Primitive procedures 5 | // ------------------------------------------------------------------- 6 | 7 | // Derivative of [std::ops::Add]. 8 | combined_derivative_macro!(add_u64, "0u64", "1u64", "1u64"); 9 | // Derivative of [std::ops::Sub]. 10 | combined_derivative_macro!(sub_u64, "0u64", "1u64", "-1u64"); 11 | // Derivative of [std::ops::Mul]. 12 | combined_derivative_macro!(mul_u64, "0u64", "{1}", "{0}"); 13 | // Derivative of [std::ops::Div]. 
14 | combined_derivative_macro!(div_u64, "0u64", "1u64/{1}", "-{0}/({1}*{1})"); 15 | -------------------------------------------------------------------------------- /core/src/derivatives/u8.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | use rust_ad_core_macros::{combined_derivative_macro, compose}; 3 | 4 | // Primitive procedures 5 | // ------------------------------------------------------------------- 6 | 7 | // Derivative of [std::ops::Add]. 8 | combined_derivative_macro!(add_u8, "0u8", "1u8", "1u8"); 9 | // Derivative of [std::ops::Sub]. 10 | combined_derivative_macro!(sub_u8, "0u8", "1u8", "-1u8"); 11 | // Derivative of [std::ops::Mul]. 12 | combined_derivative_macro!(mul_u8, "0u8", "{1}", "{0}"); 13 | // Derivative of [std::ops::Div]. 14 | combined_derivative_macro!(div_u8, "0u8", "1u8/{1}", "-{0}/({1}*{1})"); 15 | -------------------------------------------------------------------------------- /core/src/dict.rs: -------------------------------------------------------------------------------- 1 | use crate::derivatives::*; 2 | use rust_ad_core_macros::{f, r}; 3 | use std::{collections::HashMap, fmt}; 4 | 5 | /// Signature information to refer to specific method. 
6 | #[derive(Hash, PartialEq, Eq, Debug)]
7 | pub struct MethodSignature {
8 |     name: String, // method name, e.g. `powi`
9 |     receiver_type: String, // type of the receiver, e.g. `f32`
10 |     input_types: Vec, // argument types, e.g. `["i32"]`
11 | }
12 | impl MethodSignature {
13 |     pub fn new(name: String, receiver_type: String, input_types: Vec) -> Self {
14 |         Self {
15 |             name,
16 |             receiver_type,
17 |             input_types,
18 |         }
19 |     }
20 | }
21 | impl From<(&'static str, &'static str, &'static [&'static str; N])>
22 |     for MethodSignature
23 | { // Enables `("powi", "f32", &["i32"]).into()` as used when building `SUPPORTED_METHODS`.
24 |     fn from(
25 |         (name, receiver_type, input_types): (
26 |             &'static str,
27 |             &'static str,
28 |             &'static [&'static str; N],
29 |         ),
30 |     ) -> Self {
31 |         Self {
32 |             name: String::from(name),
33 |             receiver_type: String::from(receiver_type),
34 |             input_types: input_types.iter().map(|s| String::from(*s)).collect(),
35 |         }
36 |     }
37 | }
38 | impl fmt::Display for MethodSignature {
39 |     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // Renders as `receiver.name(arg1,arg2)`.
40 |         write!(
41 |             f,
42 |             "{}.{}({})",
43 |             self.receiver_type,
44 |             self.name,
45 |             self.input_types
46 |                 .iter()
47 |                 .cloned()
48 |                 .intersperse(String::from(","))
49 |                 .collect::()
50 |         )
51 |     }
52 | }
53 | /// A map from method signatures to their `ProcedureOutputs` (output type and derivative transforms).
54 | type MethodMap = HashMap;
55 | /// Signature information identifying a specific free function: its name and argument types.
56 | #[derive(Hash, PartialEq, Eq, Debug)]
57 | pub struct FunctionSignature {
58 |     name: String,
59 |     input_types: Vec,
60 | }
61 | impl FunctionSignature {
62 |     pub fn new(name: String, input_types: Vec) -> Self {
63 |         Self { name, input_types }
64 |     }
65 | }
66 | impl From<(&'static str, &'static [&'static str; N])> for FunctionSignature { // Enables `("name", &["arg types"]).into()`.
67 |     fn from((name, input_types): (&'static str, &'static [&'static str; N])) -> Self {
68 |         Self {
69 |             name: String::from(name),
70 |             input_types: input_types.iter().map(|s| String::from(*s)).collect(),
71 |         }
72 |     }
73 | }
74 | impl fmt::Display for FunctionSignature {
75 |     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // Renders as `name(arg1,arg2)`.
76 |         write!(
77 |             f,
78 |             "{}({})",
79 |             self.name,
80 |             self.input_types
81 |                 .iter()
82 |                 .cloned()
83 |                 .intersperse(String::from(","))
84 |                 .collect::()
85 |         )
86 |     }
87 | }
88 | /// A map of function signatures to useful data (output type, etc.).
89 | type FunctionMap = HashMap;
90 |
91 | /// Information relating to a specific procedure: its output type and the functions that transform its statements into forward and reverse derivatives.
92 | pub struct ProcedureOutputs {
93 |     /// Output type of procedure
94 |     pub output_type: String,
95 |     /// Transformation procedure to give the forward derivative
96 |     pub forward_derivative: FgdType,
97 |     /// Transformation procedure to give the reverse derivative
98 |     pub reverse_derivative: RgdType,
99 | }
100 | impl ProcedureOutputs {
101 |     pub fn new(
102 |         output_type: &'static str,
103 |         forward_derivative: FgdType,
104 |         reverse_derivative: RgdType,
105 |     ) -> Self {
106 |         Self {
107 |             output_type: String::from(output_type),
108 |             forward_derivative,
109 |             reverse_derivative,
110 |         }
111 |     }
112 | }
113 | // TODO Why doesn't this work? NOTE(review): the supported-procedure maps use `ProcedureOutputs::new` instead; possibly fn items inside a tuple are not coerced to the `FgdType`/`RgdType` fn-pointer types (unverified).
114 | impl From<(&'static str, FgdType, RgdType)> for ProcedureOutputs {
115 |     fn from(
116 |         (output_type, forward_derivative, reverse_derivative): (&'static str, FgdType, RgdType),
117 |     ) -> Self {
118 |         Self {
119 |             output_type: String::from(output_type),
120 |             forward_derivative,
121 |             reverse_derivative,
122 |         }
123 |     }
124 | }
125 |
126 | /// Currently supported binary operations.
127 | #[derive(Hash, PartialEq, Eq, Debug)]
128 | pub enum BinOp {
129 |     Add,
130 |     Sub,
131 |     Mul,
132 |     Div,
133 | }
134 | impl TryFrom<&'static str> for BinOp { // NOTE(review): the `'static` bound means only string literals (not borrowed runtime `&str`s) can be converted.
135 |     type Error = &'static str;
136 |     fn try_from(symbol: &'static str) -> Result {
137 |         match symbol {
138 |             "+" => Ok(Self::Add),
139 |             "-" => Ok(Self::Sub),
140 |             "*" => Ok(Self::Mul),
141 |             "/" => Ok(Self::Div),
142 |             _ => Err("Unrecognized symbol"),
143 |         }
144 |     }
145 | }
146 | impl TryFrom for BinOp { // Converts a `syn::BinOp`; only `+ - * /` are supported.
147 |     type Error = &'static str;
148 |     fn try_from(op: syn::BinOp) -> Result {
149 |         match op {
150 |             syn::BinOp::Add(_) => Ok(Self::Add),
151 |             syn::BinOp::Sub(_) => Ok(Self::Sub),
152 |             syn::BinOp::Mul(_) => Ok(Self::Mul),
153 |             syn::BinOp::Div(_) => Ok(Self::Div),
154 |             _ => Err("Unrecognized syn op"),
155 |         }
156 |     }
157 | }
158 | impl fmt::Display for BinOp {
159 |     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // Writes the operator's source symbol.
160 |         match self {
161 |             Self::Add => write!(f, "+"),
162 |             Self::Sub => write!(f, "-"),
163 |             Self::Mul => write!(f, "*"),
164 |             Self::Div => write!(f, "/"),
165 |         }
166 |     }
167 | }
168 |
169 | /// Signature information identifying a specific binary operation: left-hand type, operator and right-hand type.
170 | #[derive(Hash, PartialEq, Eq, Debug)] 171 | pub struct OperationSignature { 172 | /// Left-hand-side type 173 | lhs: String, 174 | /// Operation type 175 | op: BinOp, 176 | /// Right-hand-side type 177 | rhs: String, 178 | } 179 | impl From<(&'static str, &'static str, &'static str)> for OperationSignature { 180 | fn from((lhs, op, rhs): (&'static str, &'static str, &'static str)) -> Self { 181 | Self { 182 | lhs: String::from(lhs), 183 | op: BinOp::try_from(op).expect("No symbol"), 184 | rhs: String::from(rhs), 185 | } 186 | } 187 | } 188 | impl From<(String, syn::BinOp, String)> for OperationSignature { 189 | fn from((lhs, op, rhs): (String, syn::BinOp, String)) -> Self { 190 | Self { 191 | lhs, 192 | op: BinOp::try_from(op).expect("No op"), 193 | rhs, 194 | } 195 | } 196 | } 197 | impl fmt::Display for OperationSignature { 198 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 199 | write!(f, "{}{}{}", self.lhs, self.op, self.rhs) 200 | } 201 | } 202 | /// A map of binary operation signatures to useful data (output type, etc.). 203 | type OperationMap = HashMap; 204 | 205 | // Supported methods, functions and operations. 206 | lazy_static::lazy_static! { 207 | /// Internal map of currently supported functions. 208 | pub static ref SUPPORTED_FUNCTIONS: FunctionMap = { 209 | FunctionMap::new() 210 | }; 211 | /// Internal map of currently supported methods. 
212 | pub static ref SUPPORTED_METHODS: MethodMap = { 213 | let mut map = MethodMap::new(); 214 | // f32 215 | // ---------------------------- 216 | // Exponents 217 | map.insert(("powi","f32",&["i32"]).into(),ProcedureOutputs::new("f32",f!(powi_f32),r!(powi_f32))); 218 | map.insert(("powf","f32",&["f32"]).into(),ProcedureOutputs::new("f32",f!(powf_f32),r!(powf_f32))); 219 | map.insert(("sqrt","f32",&[]).into(),ProcedureOutputs::new("f32",f!(sqrt_f32),r!(sqrt_f32))); 220 | map.insert(("cbrt","f32",&[]).into(),ProcedureOutputs::new("f32",f!(cbrt_f32),r!(cbrt_f32))); 221 | map.insert(("exp","f32",&[]).into(),ProcedureOutputs::new("f32",f!(exp_f32),r!(exp_f32))); 222 | map.insert(("exp2","f32",&[]).into(),ProcedureOutputs::new("f32",f!(exp2_f32),r!(exp2_f32))); 223 | map.insert(("exp_m1","f32",&[]).into(),ProcedureOutputs::new("f32",f!(exp_m1_f32),r!(exp_m1_f32))); 224 | // Logs 225 | map.insert(("ln","f32",&[]).into(),ProcedureOutputs::new("f32",f!(ln_f32),r!(ln_f32))); 226 | map.insert(("ln_1p","f32",&[]).into(),ProcedureOutputs::new("f32",f!(ln_1p_f32),r!(ln_1p_f32))); 227 | map.insert(("log","f32",&["f32"]).into(),ProcedureOutputs::new("f32",f!(log_f32),r!(log_f32))); 228 | map.insert(("log10","f32",&[]).into(),ProcedureOutputs::new("f32",f!(log10_f32),r!(log10_f32))); 229 | map.insert(("log2","f32",&[]).into(),ProcedureOutputs::new("f32",f!(log2_f32),r!(log2_f32))); 230 | // Trig 231 | map.insert(("acos","f32",&[]).into(),ProcedureOutputs::new("f32",f!(acos_f32),r!(acos_f32))); 232 | map.insert(("acosh","f32",&[]).into(),ProcedureOutputs::new("f32",f!(acosh_f32),r!(acosh_f32))); 233 | map.insert(("asin","f32",&[]).into(),ProcedureOutputs::new("f32",f!(asin_f32),r!(asin_f32))); 234 | map.insert(("asinh","f32",&[]).into(),ProcedureOutputs::new("f32",f!(asinh_f32),r!(asinh_f32))); 235 | map.insert(("atan","f32",&[]).into(),ProcedureOutputs::new("f32",f!(atan_f32),r!(atan_f32))); 236 | 
map.insert(("sin","f32",&[]).into(),ProcedureOutputs::new("f32",f!(sin_f32),r!(sin_f32))); 237 | map.insert(("atanh","f32",&[]).into(),ProcedureOutputs::new("f32",f!(atanh_f32),r!(atanh_f32))); 238 | map.insert(("cos","f32",&[]).into(),ProcedureOutputs::new("f32",f!(cos_f32),r!(cos_f32))); 239 | map.insert(("cosh","f32",&[]).into(),ProcedureOutputs::new("f32",f!(cosh_f32),r!(cosh_f32))); 240 | map.insert(("sinh","f32",&[]).into(),ProcedureOutputs::new("f32",f!(sinh_f32),r!(sinh_f32))); 241 | map.insert(("tan","f32",&[]).into(),ProcedureOutputs::new("f32",f!(tan_f32),r!(tan_f32))); 242 | // map.insert(("tanh","f32",&[]).into(),ProcedureOutputs::new("f32",f!(tanh_f32),r!(tanh::<{Type::F32}>))); 243 | // Misc 244 | map.insert(("abs","f32",&[]).into(),ProcedureOutputs::new("f32",f!(abs_f32),r!(abs_f32))); 245 | map.insert(("ceil","f32",&[]).into(),ProcedureOutputs::new("f32",f!(ceil_f32),r!(ceil_f32))); 246 | map.insert(("floor","f32",&[]).into(),ProcedureOutputs::new("f32",f!(floor_f32),r!(floor_f32))); 247 | map.insert(("fract","f32",&[]).into(),ProcedureOutputs::new("f32",f!(fract_f32),r!(fract_f32))); 248 | map.insert(("recip","f32",&[]).into(),ProcedureOutputs::new("f32",f!(recip_f32),r!(recip_f32))); 249 | map.insert(("round","f32",&[]).into(),ProcedureOutputs::new("f32",f!(round_f32),r!(round_f32))); 250 | 251 | // f64 252 | // ---------------------------- 253 | // Exponents 254 | map.insert(("powi","f64",&["i32"]).into(),ProcedureOutputs::new("f64",f!(powi_f64),r!(powi_f64))); 255 | map.insert(("powf","f64",&["f64"]).into(),ProcedureOutputs::new("f64",f!(powf_f64),r!(powf_f64))); 256 | map.insert(("sqrt","f64",&[]).into(),ProcedureOutputs::new("f64",f!(sqrt_f64),r!(sqrt_f64))); 257 | map.insert(("cbrt","f64",&[]).into(),ProcedureOutputs::new("f64",f!(cbrt_f64),r!(cbrt_f64))); 258 | map.insert(("exp","f64",&[]).into(),ProcedureOutputs::new("f64",f!(exp_f64),r!(exp_f64))); 259 | 
map.insert(("exp2","f64",&[]).into(),ProcedureOutputs::new("f64",f!(exp2_f64),r!(exp2_f64))); 260 | map.insert(("exp_m1","f64",&[]).into(),ProcedureOutputs::new("f64",f!(exp_m1_f64),r!(exp_m1_f64))); 261 | // Logs 262 | map.insert(("ln","f64",&[]).into(),ProcedureOutputs::new("f64",f!(ln_f64),r!(ln_f64))); 263 | map.insert(("ln_1p","f64",&[]).into(),ProcedureOutputs::new("f64",f!(ln_1p_f64),r!(ln_1p_f64))); 264 | map.insert(("log","f64",&["f64"]).into(),ProcedureOutputs::new("f64",f!(log_f64),r!(log_f64))); 265 | map.insert(("log10","f64",&[]).into(),ProcedureOutputs::new("f64",f!(log10_f64),r!(log10_f64))); 266 | map.insert(("log2","f64",&[]).into(),ProcedureOutputs::new("f64",f!(log2_f64),r!(log2_f64))); 267 | // Trig 268 | map.insert(("acos","f64",&[]).into(),ProcedureOutputs::new("f64",f!(acos_f64),r!(acos_f64))); 269 | map.insert(("acosh","f64",&[]).into(),ProcedureOutputs::new("f64",f!(acosh_f64),r!(acosh_f64))); 270 | map.insert(("asin","f64",&[]).into(),ProcedureOutputs::new("f64",f!(asin_f64),r!(asin_f64))); 271 | map.insert(("asinh","f64",&[]).into(),ProcedureOutputs::new("f64",f!(asinh_f64),r!(asinh_f64))); 272 | map.insert(("atan","f64",&[]).into(),ProcedureOutputs::new("f64",f!(atan_f64),r!(atan_f64))); 273 | map.insert(("sin","f64",&[]).into(),ProcedureOutputs::new("f64",f!(sin_f64),r!(sin_f64))); 274 | map.insert(("atanh","f64",&[]).into(),ProcedureOutputs::new("f64",f!(atanh_f64),r!(atanh_f64))); 275 | map.insert(("cos","f64",&[]).into(),ProcedureOutputs::new("f64",f!(cos_f64),r!(cos_f64))); 276 | map.insert(("cosh","f64",&[]).into(),ProcedureOutputs::new("f64",f!(cosh_f64),r!(cosh_f64))); 277 | map.insert(("sinh","f64",&[]).into(),ProcedureOutputs::new("f64",f!(sinh_f64),r!(sinh_f64))); 278 | map.insert(("tan","f64",&[]).into(),ProcedureOutputs::new("f64",f!(tan_f64),r!(tan_f64))); 279 | // map.insert(("tanh","f64",&[]).into(),ProcedureOutputs::new("f64",f!(tanh_f64),r!(tanh::<{Type::F32}>))); 280 | // Misc 281 | 
map.insert(("abs","f64",&[]).into(),ProcedureOutputs::new("f64",f!(abs_f64),r!(abs_f64))); 282 | map.insert(("ceil","f64",&[]).into(),ProcedureOutputs::new("f64",f!(ceil_f64),r!(ceil_f64))); 283 | map.insert(("floor","f64",&[]).into(),ProcedureOutputs::new("f64",f!(floor_f64),r!(floor_f64))); 284 | map.insert(("fract","f64",&[]).into(),ProcedureOutputs::new("f64",f!(fract_f64),r!(fract_f64))); 285 | map.insert(("recip","f64",&[]).into(),ProcedureOutputs::new("f64",f!(recip_f64),r!(recip_f64))); 286 | map.insert(("round","f64",&[]).into(),ProcedureOutputs::new("f64",f!(round_f64),r!(round_f64))); 287 | 288 | // Return 289 | // ------------------------------------------------------------------------------------ 290 | map 291 | }; 292 | /// Internal map of currently supported operations. 293 | pub static ref SUPPORTED_OPERATIONS: OperationMap = { 294 | let mut map = OperationMap::new(); 295 | // Primitives 296 | // ------------------------------------------------------------------------------------ 297 | // f32 arithmetics 298 | map.insert(("f32","+","f32").into(),ProcedureOutputs::new("f32",f!(add_f32),r!(add_f32))); 299 | map.insert(("f32","*","f32").into(),ProcedureOutputs::new("f32",f!(mul_f32),r!(mul_f32))); 300 | map.insert(("f32","/","f32").into(),ProcedureOutputs::new("f32",f!(div_f32),r!(div_f32))); 301 | map.insert(("f32","-","f32").into(),ProcedureOutputs::new("f32",f!(sub_f32),r!(sub_f32))); 302 | // f64 arithmetics 303 | map.insert(("f64","+","f64").into(),ProcedureOutputs::new("f64",f!(add_f64),r!(add_f64))); 304 | map.insert(("f64","*","f64").into(),ProcedureOutputs::new("f64",f!(mul_f64),r!(mul_f64))); 305 | map.insert(("f64","/","f64").into(),ProcedureOutputs::new("f64",f!(div_f64),r!(div_f64))); 306 | map.insert(("f64","-","f64").into(),ProcedureOutputs::new("f64",f!(sub_f64),r!(sub_f64))); 307 | // i8 arithmetics 308 | map.insert(("i8","+","i8").into(),ProcedureOutputs::new("i8",f!(add_i8),r!(add_i8))); 309 | 
map.insert(("i8","*","i8").into(),ProcedureOutputs::new("i8",f!(mul_i8),r!(mul_i8))); 310 | map.insert(("i8","/","i8").into(),ProcedureOutputs::new("i8",f!(div_i8),r!(div_i8))); 311 | map.insert(("i8","-","i8").into(),ProcedureOutputs::new("i8",f!(sub_i8),r!(sub_i8))); 312 | // i16 arithmetics 313 | map.insert(("i16","+","i16").into(),ProcedureOutputs::new("i16",f!(add_i16),r!(add_i16))); 314 | map.insert(("i16","*","i16").into(),ProcedureOutputs::new("i16",f!(mul_i16),r!(mul_i16))); 315 | map.insert(("i16","/","i16").into(),ProcedureOutputs::new("i16",f!(div_i16),r!(div_i16))); 316 | map.insert(("i16","-","i16").into(),ProcedureOutputs::new("i16",f!(sub_i16),r!(sub_i16))); 317 | // i32 arithmetics 318 | map.insert(("i32","+","i32").into(),ProcedureOutputs::new("i32",f!(add_i32),r!(add_i32))); 319 | map.insert(("i32","*","i32").into(),ProcedureOutputs::new("i32",f!(mul_i32),r!(mul_i32))); 320 | map.insert(("i32","/","i32").into(),ProcedureOutputs::new("i32",f!(div_i32),r!(div_i32))); 321 | map.insert(("i32","-","i32").into(),ProcedureOutputs::new("i32",f!(sub_i32),r!(sub_i32))); 322 | // i64 arithmetics 323 | map.insert(("i64","+","i64").into(),ProcedureOutputs::new("i64",f!(add_i64),r!(add_i64))); 324 | map.insert(("i64","*","i64").into(),ProcedureOutputs::new("i64",f!(mul_i64),r!(mul_i64))); 325 | map.insert(("i64","/","i64").into(),ProcedureOutputs::new("i64",f!(div_i64),r!(div_i64))); 326 | map.insert(("i64","-","i64").into(),ProcedureOutputs::new("i64",f!(sub_i64),r!(sub_i64))); 327 | // i128 arithmetics 328 | map.insert(("i128","+","i128").into(),ProcedureOutputs::new("i128",f!(add_i128),r!(add_i128))); 329 | map.insert(("i128","*","i128").into(),ProcedureOutputs::new("i128",f!(mul_i128),r!(mul_i128))); 330 | map.insert(("i128","/","i128").into(),ProcedureOutputs::new("i128",f!(div_i128),r!(div_i128))); 331 | map.insert(("i128","-","i128").into(),ProcedureOutputs::new("i128",f!(sub_i128),r!(sub_i128))); 332 | // u8 arithmetics 333 | 
map.insert(("u8","+","u8").into(),ProcedureOutputs::new("u8",f!(add_u8),r!(add_u8))); 334 | map.insert(("u8","*","u8").into(),ProcedureOutputs::new("u8",f!(mul_u8),r!(mul_u8))); 335 | map.insert(("u8","/","u8").into(),ProcedureOutputs::new("u8",f!(div_u8),r!(div_u8))); 336 | map.insert(("u8","-","u8").into(),ProcedureOutputs::new("u8",f!(sub_u8),r!(sub_u8))); 337 | // u16 arithmetics 338 | map.insert(("u16","+","u16").into(),ProcedureOutputs::new("u16",f!(add_u16),r!(add_u16))); 339 | map.insert(("u16","*","u16").into(),ProcedureOutputs::new("u16",f!(mul_u16),r!(mul_u16))); 340 | map.insert(("u16","/","u16").into(),ProcedureOutputs::new("u16",f!(div_u16),r!(div_u16))); 341 | map.insert(("u16","-","u16").into(),ProcedureOutputs::new("u16",f!(sub_u16),r!(sub_u16))); 342 | // u32 arithmetics 343 | map.insert(("u32","+","u32").into(),ProcedureOutputs::new("u32",f!(add_u32),r!(add_u32))); 344 | map.insert(("u32","*","u32").into(),ProcedureOutputs::new("u32",f!(mul_u32),r!(mul_u32))); 345 | map.insert(("u32","/","u32").into(),ProcedureOutputs::new("u32",f!(div_u32),r!(div_u32))); 346 | map.insert(("u32","-","u32").into(),ProcedureOutputs::new("u32",f!(sub_u32),r!(sub_u32))); 347 | // u64 arithmetics 348 | map.insert(("u64","+","u64").into(),ProcedureOutputs::new("u64",f!(add_u64),r!(add_u64))); 349 | map.insert(("u64","*","u64").into(),ProcedureOutputs::new("u64",f!(mul_u64),r!(mul_u64))); 350 | map.insert(("u64","/","u64").into(),ProcedureOutputs::new("u64",f!(div_u64),r!(div_u64))); 351 | map.insert(("u64","-","u64").into(),ProcedureOutputs::new("u64",f!(sub_u64),r!(sub_u64))); 352 | // u128 arithmetics 353 | map.insert(("u128","+","u128").into(),ProcedureOutputs::new("u128",f!(add_u128),r!(add_u128))); 354 | map.insert(("u128","*","u128").into(),ProcedureOutputs::new("u128",f!(mul_u128),r!(mul_u128))); 355 | map.insert(("u128","/","u128").into(),ProcedureOutputs::new("u128",f!(div_u128),r!(div_u128))); 356 | 
map.insert(("u128","-","u128").into(),ProcedureOutputs::new("u128",f!(sub_u128),r!(sub_u128))); 357 | // Return 358 | // ------------------------------------------------------------------------------------ 359 | map 360 | }; 361 | } 362 | -------------------------------------------------------------------------------- /core/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![allow(incomplete_features)] 2 | #![feature(iter_intersperse)] 3 | #![feature(adt_const_params)] 4 | #![feature(proc_macro_diagnostic)] 5 | 6 | //! **I do not recommend using this directly, please sea [rust-ad](https://crates.io/crates/rust-ad).** 7 | //! 8 | //! Internal non-proc-macro functionality. 9 | 10 | extern crate proc_macro; 11 | use proc_macro::Diagnostic; 12 | use std::collections::HashMap; 13 | use syn::spanned::Spanned; 14 | 15 | /// Functions to compute derivatives (specific function support). 16 | /// 17 | /// A function name may look like `__f_internal_powi_f32`: 18 | /// 1. `__f` represents a forward auto-diff function 19 | /// 2. `_internal_` is just the internal identifier. 20 | /// 3. `powi_` is the function being supported. 21 | /// 4. `f32` is the general type (while this doesn't technically enforce anything it will typically be the respective `Self`). 22 | pub mod derivatives; 23 | pub use derivatives::*; 24 | 25 | mod dict; 26 | pub use dict::*; 27 | 28 | /// Some utility functions used for [syn]. 29 | pub mod traits; 30 | use traits::*; 31 | 32 | /// Prefix used for flattening binary expressions in function arguments. 33 | pub const FUNCTION_PREFIX: &str = "f"; 34 | /// Prefix used for flattening binary expressions as a receiver for a method. 
35 | pub const RECEIVER_PREFIX: &str = "r"; 36 | /// Prefix used for flattening return statements; 37 | pub const RETURN_SUFFIX: &str = "rtn"; 38 | 39 | /// Insert key into map with initial value element or append to existing value 40 | pub fn append_insert(key: &str, value: String, map: &mut HashMap>) { 41 | if let Some(entry) = map.get_mut(key) { 42 | entry.push(value); 43 | } else { 44 | map.insert(String::from(key), vec![value]); 45 | } 46 | } 47 | /// Gets type of a given expression as a string 48 | /// 49 | /// e.g. for `let a = b + c`, given `b` and `c` are in `type_map` (lets say they are `f32` and `f64`) then we look for the operation `f32+f64` in supported operations, then we know the output type and we return this type. 50 | pub fn expr_type( 51 | expr: &syn::Expr, 52 | type_map: &HashMap, 53 | ) -> Result { 54 | match expr { 55 | syn::Expr::Path(path_expr) => { 56 | let var = path_expr.path.segments[0].ident.to_string(); 57 | match type_map.get(&var) { 58 | Some(ident) => Ok(ident.clone()), 59 | None => { 60 | Diagnostic::spanned( 61 | path_expr.span().unwrap(), 62 | proc_macro::Level::Error, 63 | format!("variable not found in type map ({:?})", type_map), 64 | ) 65 | .emit(); 66 | Err(String::from("expr_type")) 67 | } 68 | } 69 | } 70 | syn::Expr::Lit(lit_expr) => literal_type(lit_expr), 71 | syn::Expr::Call(call_expr) => { 72 | let function_sig = pass!(function_signature(call_expr, type_map), "expr_type"); 73 | match SUPPORTED_FUNCTIONS.get(&function_sig) { 74 | Some(out_sig) => Ok(out_sig.output_type.clone()), 75 | None => { 76 | let error = format!("expr_type: unsupported function: {}", function_sig); 77 | Diagnostic::spanned(call_expr.span().unwrap(), proc_macro::Level::Error, error) 78 | .emit(); 79 | Err(String::from("expr_type")) 80 | } 81 | } 82 | } 83 | syn::Expr::MethodCall(method_expr) => { 84 | let method_sig = pass!(method_signature(method_expr, type_map), "expr_type"); 85 | // Searches for supported function signature by function 
identifier and argument types. 86 | match SUPPORTED_METHODS.get(&method_sig) { 87 | Some(out_sig) => Ok(out_sig.output_type.clone()), 88 | None => { 89 | let error = format!("unsupported method: {}", method_sig); 90 | Diagnostic::spanned( 91 | method_expr.span().unwrap(), 92 | proc_macro::Level::Error, 93 | error, 94 | ) 95 | .emit(); 96 | Err(String::from("expr_type")) 97 | } 98 | } 99 | } 100 | syn::Expr::Binary(bin_expr) => { 101 | let operation_sig = match operation_signature(bin_expr, type_map) { 102 | Ok(types) => types, 103 | Err(e) => return Err(e), 104 | }; 105 | // I think this is cleaner than embedding a `format!` within an `.expect` 106 | match SUPPORTED_OPERATIONS.get(&operation_sig) { 107 | Some(out_sig) => Ok(out_sig.output_type.clone()), 108 | None => { 109 | Diagnostic::spanned( 110 | bin_expr.span().unwrap(), 111 | proc_macro::Level::Error, 112 | format!("expr_type: unsupported binary operation: {}", operation_sig), 113 | ) 114 | .emit(); 115 | Err(String::from("expr_type")) 116 | } 117 | } 118 | } 119 | _ => { 120 | Diagnostic::spanned( 121 | expr.span().unwrap(), 122 | proc_macro::Level::Error, 123 | "expr_type: unsupported expression type", 124 | ) 125 | .emit(); 126 | Err(String::from("expr_type")) 127 | } 128 | } 129 | } 130 | 131 | /// Gets type of literal (only supported numerical types) 132 | pub fn literal_type(expr_lit: &syn::ExprLit) -> Result { 133 | match &expr_lit.lit { 134 | syn::Lit::Float(float_lit) => { 135 | // Float literal is either f32 or f64 136 | let float_str = float_lit.to_string(); 137 | 138 | let n = float_str.len(); 139 | // If n<=3 then there is not enough space for the float type identifier, since they require 3 chars 140 | if n <= 3 { 141 | Diagnostic::spanned( 142 | expr_lit.span().unwrap(), 143 | proc_macro::Level::Error, 144 | "All literals need a type suffix e.g. 
`10.2f32` -- Bad float literal (len)", 145 | ) 146 | .emit(); 147 | return Err(String::from("literal_type")); 148 | } 149 | let float_type_str = &float_str[n - 3..n]; 150 | if !(float_type_str == "f32" || float_type_str == "f64") { 151 | Diagnostic::spanned( 152 | expr_lit.span().unwrap(), 153 | proc_macro::Level::Error, 154 | "All literals need a type suffix e.g. `10.2f32` -- Bad float literal (type)", 155 | ) 156 | .emit(); 157 | return Err(String::from("literal_type")); 158 | } 159 | Ok(String::from(float_type_str)) 160 | } 161 | syn::Lit::Int(int_lit) => { 162 | // Integer literal could be any of the numbers, `4f32`, `16u32` etc. 163 | let int_str = int_lit.to_string(); 164 | let n = int_str.len(); 165 | 166 | // Checking if `i128` or `u128` (the 4 character length type annotations) 167 | let large_type = if n > 4 { 168 | let large_int_str = &int_str[n - 4..n]; 169 | match large_int_str { 170 | "i128" | "u128" => Some(String::from(large_int_str)), 171 | _ => None, 172 | } 173 | } else { 174 | None 175 | }; 176 | // Checking if `f32` or `u16` etc. (the 3 character length type annotations) 177 | let standard_type = if n > 3 { 178 | let standard_int_str = &int_str[n - 3..n]; 179 | match standard_int_str { 180 | "u16" | "u32" | "u64" | "i16" | "i32" | "i64" | "f32" | "f64" => { 181 | Some(String::from(standard_int_str)) 182 | } 183 | _ => None, 184 | } 185 | } else { 186 | None 187 | }; 188 | // Checking `u8` or `i8` (2 character length type annotations) 189 | let short_type = if n > 2 { 190 | let short_int_str = &int_str[n - 2..n]; 191 | match short_int_str { 192 | "i8" | "u8" => Some(String::from(short_int_str)), 193 | _ => None, 194 | } 195 | } else { 196 | None 197 | }; 198 | 199 | match large_type.or(standard_type).or(short_type) { 200 | Some(int_lit_some) => Ok(int_lit_some), 201 | None => { 202 | Diagnostic::spanned( 203 | expr_lit.span().unwrap(), 204 | proc_macro::Level::Error, 205 | "All literals need a type suffix e.g. 
`10.2f32` -- Bad integer literal", 206 | ) 207 | .emit(); 208 | Err(String::from("literal_type")) 209 | } 210 | } 211 | } 212 | _ => { 213 | Diagnostic::spanned( 214 | expr_lit.span().unwrap(), 215 | proc_macro::Level::Error, 216 | "Unsupported literal (only integer and float literals are supported)", 217 | ) 218 | .emit(); 219 | Err(String::from("literal_type")) 220 | } 221 | } 222 | } 223 | 224 | /// Given an index (e.g. `1`) appends `REVERSE_JOINED_DERIVATIVE` (e.g. `der_a`). 225 | #[macro_export] 226 | macro_rules! rtn { 227 | ($a:expr) => {{ 228 | format!("{}{}", rust_ad_consts::REVERSE_RETURN_DERIVATIVE, $a) 229 | }}; 230 | } 231 | /// Given identifier string (e.g. `x`) appends `DERIVATIVE_PREFIX` (e.g. `der_a`). 232 | #[macro_export] 233 | macro_rules! der { 234 | ($a:expr) => {{ 235 | format!("{}{}", rust_ad_consts::DERIVATIVE_PREFIX, $a) 236 | }}; 237 | } 238 | /// With-Respect-To Nth 239 | /// 240 | /// wrt!(a,b,1) = δa/δb_1 241 | #[macro_export] 242 | macro_rules! wrtn { 243 | ($a:expr, $b:expr, $c: expr) => {{ 244 | format!("{}_wrt_{}_{}", $a, $b, $c) 245 | }}; 246 | } 247 | /// With-Respect-To 248 | /// 249 | /// wrt!(a,b) = δa/δb 250 | #[macro_export] 251 | macro_rules! wrt { 252 | ($a:expr,$b:expr) => {{ 253 | format!("{}_wrt_{}", $a, $b) 254 | }}; 255 | } 256 | // TODO Is there not a nice inbuilt way to do this? 257 | #[macro_export] 258 | macro_rules! pass { 259 | ($result: expr,$prefix:expr) => { 260 | match $result { 261 | Ok(res) => res, 262 | Err(err) => { 263 | return Err(format!("{}->{}", $prefix, err)); 264 | } 265 | } 266 | }; 267 | } 268 | /// Used so its easier to change return error type. 
269 | pub type PassError = String; 270 | 271 | /// Gets method signature for internal use 272 | pub fn method_signature( 273 | method_expr: &syn::ExprMethodCall, 274 | type_map: &HashMap, 275 | ) -> Result { 276 | // Gets method identifier 277 | let method_str = method_expr.method.to_string(); 278 | // Gets receiver type 279 | let receiver_type_str = pass!( 280 | expr_type(&*method_expr.receiver, type_map), 281 | "method_signature" 282 | ); 283 | // Gets argument types 284 | let arg_types_res = method_expr 285 | .args 286 | .iter() 287 | .map(|p| expr_type(p, type_map)) 288 | .collect::, _>>(); 289 | let arg_types = pass!(arg_types_res, "method_signature"); 290 | Ok(MethodSignature::new( 291 | method_str, 292 | receiver_type_str, 293 | arg_types, 294 | )) 295 | } 296 | /// Gets function signature for internal use 297 | pub fn function_signature( 298 | function_expr: &syn::ExprCall, 299 | type_map: &HashMap, 300 | ) -> Result { 301 | // Gets argument types 302 | 303 | let arg_types_res = function_expr 304 | .args 305 | .iter() 306 | .map(|arg| expr_type(arg, type_map)) 307 | .collect::, _>>(); 308 | let arg_types = pass!(arg_types_res, "function_signature"); 309 | // Gets function identifier1 310 | let func_ident_str = function_expr 311 | .func 312 | .path() 313 | .expect("function_signature: func not path") 314 | .path 315 | .segments[0] 316 | .ident 317 | .to_string(); 318 | // Create function signature 319 | Ok(FunctionSignature::new(func_ident_str, arg_types)) 320 | } 321 | 322 | /// Gets operation signature for internal use 323 | pub fn operation_signature( 324 | operation_expr: &syn::ExprBinary, 325 | type_map: &HashMap, 326 | ) -> Result { 327 | // Gets types of lhs and rhs of expression 328 | let (left, right) = ( 329 | expr_type(&*operation_expr.left, type_map), 330 | expr_type(&*operation_expr.right, type_map), 331 | ); 332 | if left.is_err() { 333 | Diagnostic::spanned( 334 | operation_expr.left.span().unwrap(), 335 | proc_macro::Level::Error, 336 | 
"operation_signature: unsupported left type", 337 | ) 338 | .emit(); 339 | } 340 | if right.is_err() { 341 | Diagnostic::spanned( 342 | operation_expr.right.span().unwrap(), 343 | proc_macro::Level::Error, 344 | "operation_signature: unsupported right type", 345 | ) 346 | .emit(); 347 | } 348 | match (left, right) { 349 | (Ok(l), Ok(r)) => Ok(OperationSignature::from((l, operation_expr.op, r))), 350 | _ => Err(String::from("operation_signature")), 351 | } 352 | } 353 | -------------------------------------------------------------------------------- /core/src/traits.rs: -------------------------------------------------------------------------------- 1 | extern crate proc_macro; 2 | // TODO Make macro to minimize code duplication here. 3 | 4 | type UnwrapResult<'a, T> = Result<&'a T, &'static str>; 5 | type UnwrapResultMut<'a, T> = Result<&'a mut T, &'static str>; 6 | pub trait Named { 7 | fn name(&self) -> &'static str; 8 | } 9 | // TODO Can't this be done with a macro? 10 | impl Named for syn::Expr { 11 | fn name(&self) -> &'static str { 12 | match self { 13 | Self::Array(_) => "Array", 14 | Self::Assign(_) => "Assign", 15 | Self::AssignOp(_) => "AssignOp", 16 | Self::Async(_) => "Async", 17 | Self::Await(_) => "Await", 18 | Self::Binary(_) => "Binary", 19 | Self::Block(_) => "Block", 20 | Self::Box(_) => "Box", 21 | Self::Break(_) => "Break", 22 | Self::Call(_) => "Call", 23 | Self::Cast(_) => "Cast", 24 | Self::Closure(_) => "Closure", 25 | Self::Continue(_) => "Continue", 26 | Self::Field(_) => "Field", 27 | Self::ForLoop(_) => "ForLoop", 28 | Self::Group(_) => "Group", 29 | Self::If(_) => "If", 30 | Self::Index(_) => "Index", 31 | Self::Let(_) => "Let", 32 | Self::Lit(_) => "Lit", 33 | Self::Loop(_) => "Loop", 34 | Self::Macro(_) => "Macro", 35 | Self::Match(_) => "Match", 36 | Self::MethodCall(_) => "MethodCall", 37 | Self::Paren(_) => "Paren", 38 | Self::Path(_) => "Path", 39 | Self::Range(_) => "Range", 40 | Self::Reference(_) => "Reference", 41 | 
Self::Repeat(_) => "Repeat", 42 | Self::Return(_) => "Return", 43 | Self::Struct(_) => "Struct", 44 | Self::Try(_) => "Try", 45 | Self::TryBlock(_) => "TryBlock", 46 | Self::Tuple(_) => "Tuple", 47 | Self::Type(_) => "Type", 48 | Self::Unary(_) => "Unary", 49 | Self::Unsafe(_) => "Unsafe", 50 | Self::Verbatim(_) => "Verbatim", 51 | Self::While(_) => "While", 52 | Self::Yield(_) => "Yield", 53 | _ => unreachable!(), // some variants omitted 54 | } 55 | } 56 | } 57 | pub trait UnwrapReturnType { 58 | fn type_(&self) -> UnwrapResult; 59 | } 60 | impl UnwrapReturnType for syn::ReturnType { 61 | fn type_(&self) -> UnwrapResult { 62 | match self { 63 | Self::Type(_, typed_) => Ok(&**typed_), 64 | _ => Err("called `ReturnType::type_()` on a non `Type` value"), 65 | } 66 | } 67 | } 68 | 69 | pub trait UnwrapType { 70 | fn path(&self) -> UnwrapResult; 71 | } 72 | impl UnwrapType for syn::Type { 73 | fn path(&self) -> UnwrapResult { 74 | match self { 75 | Self::Path(path) => Ok(path), 76 | _ => Err("called `Type::path()` on a non `Path` value"), 77 | } 78 | } 79 | } 80 | pub trait UnwrapLit { 81 | fn float(&self) -> UnwrapResult; 82 | } 83 | impl UnwrapLit for syn::Lit { 84 | fn float(&self) -> UnwrapResult { 85 | match self { 86 | Self::Float(float) => Ok(float), 87 | _ => Err("called `Lit::float()` on a non `Float` value"), 88 | } 89 | } 90 | } 91 | pub trait UnwrapTokenTree { 92 | fn ident(&self) -> UnwrapResult; 93 | fn literal(&self) -> UnwrapResult; 94 | } 95 | impl UnwrapTokenTree for proc_macro::TokenTree { 96 | fn ident(&self) -> UnwrapResult { 97 | match self { 98 | Self::Ident(local) => Ok(local), 99 | _ => Err("called `TokenTree::ident()` on a non `Ident` value"), 100 | } 101 | } 102 | fn literal(&self) -> UnwrapResult { 103 | match self { 104 | Self::Literal(lit) => Ok(lit), 105 | _ => Err("called `TokenTree::literal()` on a non `Literal` value"), 106 | } 107 | } 108 | } 109 | 110 | pub trait UnwrapStmt { 111 | fn local(&self) -> UnwrapResult; 112 | fn 
local_mut(&mut self) -> UnwrapResultMut; 113 | fn semi(&self) -> UnwrapResult; 114 | fn semi_mut(&mut self) -> UnwrapResultMut; 115 | } 116 | impl UnwrapStmt for syn::Stmt { 117 | fn local(&self) -> UnwrapResult { 118 | match self { 119 | Self::Local(local) => Ok(local), 120 | _ => Err("called `Stmt::local()` on a non `Local` value"), 121 | } 122 | } 123 | fn local_mut(&mut self) -> UnwrapResultMut { 124 | match self { 125 | Self::Local(local) => Ok(local), 126 | _ => Err("called `Stmt::local_mut()` on a non `Local` value"), 127 | } 128 | } 129 | fn semi(&self) -> UnwrapResult { 130 | match self { 131 | Self::Semi(expr, _) => Ok(expr), 132 | _ => Err("called `Stmt::semi()` on a non `Semi` value"), 133 | } 134 | } 135 | fn semi_mut(&mut self) -> UnwrapResultMut { 136 | match self { 137 | Self::Semi(expr, _) => Ok(expr), 138 | _ => Err("called `Stmt::semi_mut()` on a non `Semi` value"), 139 | } 140 | } 141 | } 142 | pub trait IsStmt { 143 | fn is_local(&self) -> bool; 144 | fn is_semi(&self) -> bool; 145 | } 146 | impl IsStmt for syn::Stmt { 147 | fn is_local(&self) -> bool { 148 | matches!(self, Self::Local(_)) 149 | } 150 | fn is_semi(&self) -> bool { 151 | matches!(self, Self::Semi(_, _)) 152 | } 153 | } 154 | pub trait UnwrapPat { 155 | fn ident_mut(&mut self) -> UnwrapResultMut; 156 | fn ident(&self) -> UnwrapResult; 157 | fn tuple_mut(&mut self) -> UnwrapResultMut; 158 | fn tuple(&self) -> UnwrapResult; 159 | } 160 | impl UnwrapPat for syn::Pat { 161 | fn ident_mut(&mut self) -> UnwrapResultMut { 162 | match self { 163 | Self::Ident(ident) => Ok(ident), 164 | _ => Err("called `Pat::ident()` on a non `Ident` value"), 165 | } 166 | } 167 | fn ident(&self) -> UnwrapResult { 168 | match self { 169 | Self::Ident(ident) => Ok(ident), 170 | _ => Err("called `Pat::ident()` on a non `Ident` value"), 171 | } 172 | } 173 | fn tuple_mut(&mut self) -> UnwrapResultMut { 174 | match self { 175 | Self::Tuple(tuple) => Ok(tuple), 176 | _ => Err("called `Pat::tuple_mut()` on a 
non `Tuple` value"), 177 | } 178 | } 179 | fn tuple(&self) -> UnwrapResult { 180 | match self { 181 | Self::Tuple(tuple) => Ok(tuple), 182 | _ => Err("called `Pat::tuple()` on a non `Tuple` value"), 183 | } 184 | } 185 | } 186 | pub trait IsExpr { 187 | fn is_binary(&self) -> bool; 188 | fn is_path(&self) -> bool; 189 | fn is_return(&self) -> bool; 190 | fn is_call(&self) -> bool; 191 | fn is_method_call(&self) -> bool; 192 | fn is_lit(&self) -> bool; 193 | } 194 | impl IsExpr for syn::Expr { 195 | fn is_binary(&self) -> bool { 196 | matches!(self, Self::Binary(_)) 197 | } 198 | fn is_path(&self) -> bool { 199 | matches!(self, Self::Path(_)) 200 | } 201 | fn is_return(&self) -> bool { 202 | matches!(self, Self::Return(_)) 203 | } 204 | fn is_call(&self) -> bool { 205 | matches!(self, Self::Call(_)) 206 | } 207 | fn is_method_call(&self) -> bool { 208 | matches!(self, Self::MethodCall(_)) 209 | } 210 | fn is_lit(&self) -> bool { 211 | matches!(self, Self::Lit(_)) 212 | } 213 | } 214 | pub trait UnwrapExpr { 215 | fn binary(&self) -> UnwrapResult; 216 | fn binary_mut(&mut self) -> UnwrapResultMut; 217 | fn block(&self) -> UnwrapResult; 218 | fn block_mut(&mut self) -> UnwrapResultMut; 219 | fn path(&self) -> UnwrapResult; 220 | fn return_(&self) -> UnwrapResult; 221 | fn return_mut(&mut self) -> UnwrapResultMut; 222 | fn call(&self) -> UnwrapResult; 223 | fn call_mut(&mut self) -> UnwrapResultMut; 224 | fn method_call(&self) -> UnwrapResult; 225 | fn method_call_mut(&mut self) -> UnwrapResultMut; 226 | fn paren(&self) -> UnwrapResult; 227 | } 228 | impl UnwrapExpr for syn::Expr { 229 | fn binary(&self) -> UnwrapResult { 230 | match self { 231 | Self::Binary(b) => Ok(b), 232 | _ => Err("called `Expr::binary()` on a non `Binary` value"), 233 | } 234 | } 235 | fn binary_mut(&mut self) -> UnwrapResultMut { 236 | match self { 237 | Self::Binary(b) => Ok(b), 238 | _ => Err("called `Expr::binary_mut()` on a non `Binary` value"), 239 | } 240 | } 241 | fn block(&self) -> 
UnwrapResult { 242 | match self { 243 | Self::Block(b) => Ok(b), 244 | _ => Err("called `Expr::block()` on a non `Block` value"), 245 | } 246 | } 247 | fn block_mut(&mut self) -> UnwrapResultMut { 248 | match self { 249 | Self::Block(b) => Ok(b), 250 | _ => Err("called `Expr::block_mut()` on a non `Block` value"), 251 | } 252 | } 253 | fn path(&self) -> UnwrapResult { 254 | match self { 255 | Self::Path(b) => Ok(b), 256 | _ => Err("called `Expr::path()` on a non `Path` value"), 257 | } 258 | } 259 | fn return_(&self) -> UnwrapResult { 260 | match self { 261 | Self::Return(b) => Ok(b), 262 | _ => Err("called `Expr::return_()` on a non `Return` value"), 263 | } 264 | } 265 | fn return_mut(&mut self) -> UnwrapResultMut { 266 | match self { 267 | Self::Return(b) => Ok(b), 268 | _ => Err("called `Expr::return_mut()` on a non `Return` value"), 269 | } 270 | } 271 | fn call(&self) -> UnwrapResult { 272 | match self { 273 | Self::Call(b) => Ok(b), 274 | _ => Err("called `Expr::call()` on a non `Call` value"), 275 | } 276 | } 277 | fn call_mut(&mut self) -> UnwrapResultMut { 278 | match self { 279 | Self::Call(b) => Ok(b), 280 | _ => Err("called `Expr::call_mut()` on a non `Call` value"), 281 | } 282 | } 283 | fn method_call(&self) -> UnwrapResult { 284 | match self { 285 | Self::MethodCall(b) => Ok(b), 286 | _ => Err("called `Expr::method_call()` on a non `MethodCall` value"), 287 | } 288 | } 289 | fn method_call_mut(&mut self) -> UnwrapResultMut { 290 | match self { 291 | Self::MethodCall(b) => Ok(b), 292 | _ => Err("called `Expr::method_call_mut()` on a non `MethodCall` value"), 293 | } 294 | } 295 | fn paren(&self) -> UnwrapResult { 296 | match self { 297 | Self::Paren(b) => Ok(b), 298 | _ => Err("called `Expr::paren()` on a non `Paren` value"), 299 | } 300 | } 301 | } 302 | pub trait UnwrapMember { 303 | fn named(&self) -> UnwrapResult; 304 | } 305 | impl UnwrapMember for syn::Member { 306 | fn named(&self) -> UnwrapResult { 307 | match self { 308 | Self::Named(i) => 
Ok(i), 309 | Self::Unnamed(_) => Err("called `Member::named()` on a non `Named` value"), 310 | } 311 | } 312 | } 313 | pub trait UnwrapFnArg { 314 | fn typed(&self) -> UnwrapResult; 315 | } 316 | impl UnwrapFnArg for syn::FnArg { 317 | fn typed(&self) -> UnwrapResult { 318 | match self { 319 | Self::Typed(i) => Ok(i), 320 | Self::Receiver(_) => Err("called `PatType::typed()` on a non `Typed` value"), 321 | } 322 | } 323 | } 324 | -------------------------------------------------------------------------------- /macros/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rust-ad-macros" 3 | version = "0.8.0" 4 | edition = "2021" 5 | 6 | description = "Rust Auto-Differentiation." 7 | license = "Apache-2.0" 8 | repository = "https://github.com/JonathanWoollett-Light/rust-ad" 9 | documentation = "https://docs.rs/rust-ad/" 10 | readme = "../README.md" 11 | 12 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 13 | 14 | [dependencies] 15 | syn = { version="1.0.82", features=["full","extra-traits"] } 16 | quote = "1.0.10" 17 | rust-ad-core = { version = "0.8.0", path = "../core" } 18 | rust-ad-consts = { version = "0.8.0", path = "../consts" } 19 | 20 | [lib] 21 | proc-macro = true -------------------------------------------------------------------------------- /macros/src/forward.rs: -------------------------------------------------------------------------------- 1 | use proc_macro::Diagnostic; 2 | use quote::ToTokens; 3 | use rust_ad_core::traits::*; 4 | use rust_ad_core::Arg; 5 | use rust_ad_core::*; 6 | use std::collections::HashMap; 7 | #[cfg(not(debug_assertions))] 8 | use std::collections::HashSet; 9 | use syn::spanned::Spanned; 10 | 11 | pub fn update_forward_return( 12 | block: &mut syn::Block, 13 | function_inputs: &[String], 14 | #[cfg(not(debug_assertions))] type_map: HashMap, 15 | #[cfg(not(debug_assertions))] non_zero_derivatives: HashSet, 16 | ) 
-> Result<(), PassError> { 17 | match block.stmts.last_mut() { 18 | Some(last_stmt) => { 19 | *last_stmt = match last_stmt { 20 | syn::Stmt::Semi(syn::Expr::Return(expr_return_opt), _) => { 21 | let expr_return = match expr_return_opt.expr.as_ref() { 22 | Some(e) => e, 23 | None => { 24 | Diagnostic::spanned( 25 | expr_return_opt.span().unwrap(), 26 | proc_macro::Level::Error, 27 | "No return expression", 28 | ) 29 | .emit(); 30 | return Err("update_forward_return: No return expression".into()); 31 | } 32 | }; 33 | match &**expr_return { 34 | syn::Expr::Tuple(expr_tuple) => { 35 | let return_idents = expr_tuple.to_token_stream().to_string(); 36 | eprintln!("return_idents: {}", return_idents); 37 | let return_str = format!("return ({},({}));", 38 | return_idents, 39 | expr_tuple.elems.iter() 40 | .map(|e| 41 | match e { 42 | syn::Expr::Path(ep) => { 43 | let ep_str = ep.to_token_stream().to_string(); 44 | format!("({})",function_inputs 45 | .iter() 46 | .map(|input| if ep_str == *input { 47 | der!(input) 48 | } else { 49 | let der = wrt!(ep_str, input); 50 | #[cfg(not(debug_assertions))] 51 | match non_zero_derivatives.contains(&der) { 52 | true => der, 53 | false => format!("0{}", type_map.get(input).unwrap()), 54 | } 55 | #[cfg(debug_assertions)] 56 | der 57 | }) 58 | .intersperse(String::from(",")) 59 | .collect::() 60 | ) 61 | } 62 | _ => panic!("update_forward_return: Unsupported inner tuple type (e.g. `return (x,y)` is supported, `return (x,(a,b))` is not supported)") 63 | }) 64 | .intersperse(String::from(",")) 65 | .collect::() 66 | ); 67 | syn::parse_str(&return_str) 68 | .expect("update_forward_return: tuple parse fail") 69 | } 70 | syn::Expr::Path(expr_path) => { 71 | let return_ident = expr_path.to_token_stream().to_string(); 72 | 73 | // The if case where `ident == input` is for when you are returning an input. 
74 | let return_str = format!( 75 | "return ({},{});", 76 | return_ident, 77 | match function_inputs.len() { 78 | 0 => String::new(), 79 | 1 => { 80 | let input = &function_inputs[0]; 81 | if return_ident == *input { 82 | der!(input) 83 | } else { 84 | let der = wrt!(return_ident, input); 85 | #[cfg(not(debug_assertions))] 86 | match non_zero_derivatives.contains(&der) { 87 | true => der, 88 | false => { 89 | format!("0{}", type_map.get(input).unwrap()) 90 | } 91 | } 92 | #[cfg(debug_assertions)] 93 | der 94 | } 95 | } 96 | _ => format!( 97 | "({})", 98 | function_inputs 99 | .iter() 100 | .map(|input| if return_ident == *input { 101 | der!(input) 102 | } else { 103 | let der = wrt!(return_ident, input); 104 | #[cfg(not(debug_assertions))] 105 | match non_zero_derivatives.contains(&der) { 106 | true => der, 107 | false => { 108 | format!("0{}", type_map.get(input).unwrap()) 109 | } 110 | } 111 | #[cfg(debug_assertions)] 112 | der 113 | }) 114 | .intersperse(String::from(",")) 115 | .collect::() 116 | ), 117 | } 118 | ); 119 | syn::parse_str(&return_str) 120 | .expect("update_forward_return: path parse fail") 121 | } 122 | _ => { 123 | Diagnostic::spanned( 124 | expr_return_opt.span().unwrap(), 125 | proc_macro::Level::Error, 126 | "Unsupported return expression", 127 | ) 128 | .emit(); 129 | return Err( 130 | "update_forward_return: unsupported return expression".into() 131 | ); 132 | } 133 | } 134 | } 135 | _ => { 136 | Diagnostic::spanned( 137 | block.span().unwrap(), 138 | proc_macro::Level::Error, 139 | "Unsupported return statement", 140 | ) 141 | .emit(); 142 | return Err("update_forward_return: Unsupported return statement".into()); 143 | } 144 | }; 145 | } 146 | _ => { 147 | Diagnostic::spanned( 148 | block.span().unwrap(), 149 | proc_macro::Level::Error, 150 | "No return statement", 151 | ) 152 | .emit(); 153 | return Err("update_forward_return: No return statement".into()); 154 | } 155 | }; 156 | Ok(()) 157 | } 158 | 159 | /// Intersperses values with 
respect to the preceding values. 160 | pub fn intersperse_succeeding_stmts( 161 | mut x: Vec, 162 | mut extra: K, 163 | f: fn(&syn::Stmt, &mut K) -> Result, PassError>, 164 | ) -> Result, PassError> { 165 | let len = x.len(); 166 | let new_len = len * 2 - 1; 167 | let mut y = Vec::with_capacity(new_len); 168 | 169 | while x.len() > 1 { 170 | y.push(x.remove(0)); 171 | let after_opt = pass!( 172 | f(y.last().unwrap(), &mut extra), 173 | "intersperse_succeeding_stmts" 174 | ); 175 | if let Some(after) = after_opt { 176 | y.push(after); 177 | } 178 | } 179 | y.push(x.remove(0)); 180 | Ok(y) 181 | } 182 | 183 | // TODO Reduce code duplication between `reverse_derivative` and `forward_derivative` 184 | pub fn forward_derivative( 185 | stmt: &syn::Stmt, 186 | #[cfg(not(debug_assertions))] (type_map, function_inputs, non_zero_derivatives): &mut ( 187 | &HashMap, 188 | &[String], 189 | &mut HashSet, 190 | ), 191 | #[cfg(debug_assertions)] (type_map, function_inputs): &mut ( 192 | &HashMap, 193 | &[String], 194 | ), 195 | ) -> Result, PassError> { 196 | if let syn::Stmt::Local(local) = stmt { 197 | let local_ident = local 198 | .pat 199 | .ident() 200 | .expect("forward_derivative: not ident") 201 | .ident 202 | .to_string(); 203 | if let Some(init) = &local.init { 204 | // eprintln!("init: {:#?}",init); 205 | if let syn::Expr::Binary(bin_expr) = &*init.1 { 206 | // Creates operation signature struct 207 | let operation_sig = pass!( 208 | operation_signature(bin_expr, type_map), 209 | "forward_derivative" 210 | ); 211 | // Looks up operation with the given lhs type and rhs type and BinOp. 
212 | let operation_out_signature = match SUPPORTED_OPERATIONS.get(&operation_sig) { 213 | Some(sig) => sig, 214 | None => { 215 | let error = format!("unsupported derivative for {}", operation_sig); 216 | Diagnostic::spanned( 217 | bin_expr.span().unwrap(), 218 | proc_macro::Level::Error, 219 | error, 220 | ) 221 | .emit(); 222 | return Err(String::from("forward_derivative")); 223 | } 224 | }; 225 | // Applies the forward derivative function for the found operation. 226 | let new_stmt = (operation_out_signature.forward_derivative)( 227 | local_ident, 228 | &[ 229 | Arg::try_from(&*bin_expr.left).expect("forward_derivative: bin left"), 230 | Arg::try_from(&*bin_expr.right).expect("forward_derivative: bin right"), 231 | ], 232 | function_inputs, 233 | #[cfg(not(debug_assertions))] 234 | non_zero_derivatives, 235 | ); 236 | return Ok(Some(new_stmt)); 237 | } else if let syn::Expr::Call(call_expr) = &*init.1 { 238 | // Create function in signature 239 | let function_in_signature = pass!( 240 | function_signature(call_expr, type_map), 241 | "forward_derivative" 242 | ); 243 | // Gets function out signature 244 | let function_out_signature = match SUPPORTED_FUNCTIONS.get(&function_in_signature) { 245 | Some(sig) => sig, 246 | None => { 247 | let error = format!("unsupported derivative for {}", function_in_signature); 248 | Diagnostic::spanned( 249 | call_expr.span().unwrap(), 250 | proc_macro::Level::Error, 251 | error, 252 | ) 253 | .emit(); 254 | return Err(String::from("forward_derivative")); 255 | } 256 | }; 257 | let args = call_expr 258 | .args 259 | .iter() 260 | .map(|a| Arg::try_from(a).expect("forward_derivative: call arg")) 261 | .collect::>(); 262 | // Gets new stmt 263 | let new_stmt = (function_out_signature.forward_derivative)( 264 | local_ident, 265 | args.as_slice(), 266 | function_inputs, 267 | #[cfg(not(debug_assertions))] 268 | non_zero_derivatives, 269 | ); 270 | 271 | return Ok(Some(new_stmt)); 272 | } else if let syn::Expr::MethodCall(method_expr) 
= &*init.1 { 273 | let method_sig = pass!( 274 | method_signature(method_expr, type_map), 275 | "forward_derivative" 276 | ); 277 | let method_out = match SUPPORTED_METHODS.get(&method_sig) { 278 | Some(sig) => sig, 279 | None => { 280 | let error = format!("unsupported derivative for {}", method_sig); 281 | Diagnostic::spanned( 282 | method_expr.span().unwrap(), 283 | proc_macro::Level::Error, 284 | error, 285 | ) 286 | .emit(); 287 | return Err(String::from("forward_derivative")); 288 | } 289 | }; 290 | let args = { 291 | let mut base = Vec::new(); 292 | let receiver = Arg::try_from(&*method_expr.receiver) 293 | .expect("forward_derivative: method receiver"); 294 | base.push(receiver); 295 | let mut args = method_expr 296 | .args 297 | .iter() 298 | .map(|a| Arg::try_from(a).expect("forward_derivative: method arg")) 299 | .collect::>(); 300 | base.append(&mut args); 301 | base 302 | }; 303 | 304 | let new_stmt = (method_out.forward_derivative)( 305 | local_ident, 306 | args.as_slice(), 307 | function_inputs, 308 | #[cfg(not(debug_assertions))] 309 | non_zero_derivatives, 310 | ); 311 | return Ok(Some(new_stmt)); 312 | } else if let syn::Expr::Path(expr_path) = &*init.1 { 313 | // Given `let x = y;` 314 | 315 | // This is `x` 316 | let out_ident = local 317 | .pat 318 | .ident() 319 | .expect("forward_derivative: not ident") 320 | .ident 321 | .to_string(); 322 | // This `y` 323 | let in_ident = expr_path.path.segments[0].ident.to_string(); 324 | // This is type of `y` 325 | let out_type = type_map 326 | .get(&in_ident) 327 | .expect("forward_derivative: return not found type"); 328 | let return_type = rust_ad_core::Type::try_from(out_type.as_str()) 329 | .expect("forward_derivative: unsupported return type"); 330 | 331 | let idents = function_inputs 332 | .iter() 333 | .map(|input| wrt!(out_ident, input)) 334 | .intersperse(String::from(",")) 335 | .collect::(); 336 | let derivatives = function_inputs 337 | .iter() 338 | .map(|input| { 339 | 
cumulative_derivative_wrt_rt(&*init.1, input, function_inputs, &return_type) 340 | }) 341 | .intersperse(String::from(",")) 342 | .collect::(); 343 | let stmt_str = format!("let ({}) = ({});", idents, derivatives); 344 | let new_stmt: syn::Stmt = 345 | syn::parse_str(&stmt_str).expect("forward_derivative: parse fail"); 346 | 347 | return Ok(Some(new_stmt)); 348 | } else if let syn::Expr::Lit(expr_lit) = &*init.1 { 349 | // Given `let x = y;` 350 | 351 | // This is `x` 352 | let out_ident = local 353 | .pat 354 | .ident() 355 | .expect("forward_derivative: not ident") 356 | .ident 357 | .to_string(); 358 | // This is type of `y` 359 | let out_type = literal_type(expr_lit).expect("forward_derivative: bad lit type"); 360 | let return_type = rust_ad_core::Type::try_from(out_type.as_str()) 361 | .expect("forward_derivative: unsupported return type"); 362 | 363 | let idents = function_inputs 364 | .iter() 365 | .map(|input| wrt!(out_ident, input)) 366 | .intersperse(String::from(",")) 367 | .collect::(); 368 | let derivatives = function_inputs 369 | .iter() 370 | .map(|input| { 371 | cumulative_derivative_wrt_rt(&*init.1, input, function_inputs, &return_type) 372 | }) 373 | .intersperse(String::from(",")) 374 | .collect::(); 375 | let stmt_str = format!("let ({}) = ({});", idents, derivatives); 376 | let new_stmt: syn::Stmt = 377 | syn::parse_str(&stmt_str).expect("forward_derivative: parse fail"); 378 | 379 | return Ok(Some(new_stmt)); 380 | } 381 | } 382 | } 383 | Ok(None) 384 | } 385 | -------------------------------------------------------------------------------- /macros/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![feature(proc_macro_span)] 2 | #![feature(iter_intersperse)] 3 | #![feature(proc_macro_diagnostic)] 4 | #![feature(string_remove_matches)] 5 | 6 | //! **I do not recommend using this directly, please sea [rust-ad](https://crates.io/crates/rust-ad).** 7 | //! 8 | //! External proc-macro functionality. 
9 | 10 | use quote::ToTokens; 11 | use rust_ad_core::traits::*; 12 | use rust_ad_core::*; 13 | 14 | extern crate proc_macro; 15 | use proc_macro::{Diagnostic, TokenStream}; 16 | use syn::spanned::Spanned; 17 | 18 | use std::collections::{HashMap, HashSet}; 19 | 20 | mod forward; 21 | use forward::*; 22 | mod reverse; 23 | use reverse::*; 24 | 25 | /// Calls forward auto-differentiation function corresponding to a given function. 26 | /// 27 | /// ``` 28 | /// fn complex_test() { 29 | /// let (f, (der_x, der_y, der_z)) = forward!(complex, 3f32, 5f32, 7f32); 30 | /// is_near(f, 10.1187260448).unwrap(); 31 | /// is_near(der_x, 6.28571428571).unwrap(); 32 | /// is_near(der_y, -0.034212882033).unwrap(); 33 | /// is_near(der_z, -0.128914606556).unwrap(); 34 | /// 35 | /// // f(x,y,z) = x^2 + 2x/z + 2/(y+z^0.5) 36 | /// // ∂x = 2(x+1/z) 37 | /// // ∂y = -2 / (y+z^0.5)^2 38 | /// // ∂z = -2x/z^2 -1/(z^0.5 * (y+z^0.5)^2) 39 | /// // Therefore: 40 | /// // f(3,5,7) = 10.1187260448... 41 | /// // ∂x| = 6.28571428571... 42 | /// // ∂y| = −0.034212882033... 43 | /// // ∂z| = −0.128914606556... 
44 | /// #[forward_autodiff] 45 | /// fn complex(x: f32, y: f32, z: f32) -> f32 { 46 | /// let a = x.powi(2i32); 47 | /// let b = x * 2f32 / z; 48 | /// let c = 2f32 / (z.sqrt() + y); 49 | /// let f = a + b + c; 50 | /// return f; 51 | /// } 52 | /// } 53 | /// ``` 54 | #[proc_macro] 55 | pub fn forward(_item: TokenStream) -> TokenStream { 56 | let mut items = _item.into_iter(); 57 | let function_ident = match items.next() { 58 | Some(proc_macro::TokenTree::Ident(ident)) => ident, 59 | _ => panic!("Requires function identifier"), 60 | }; 61 | let vec = items.collect::>(); 62 | let items = vec.chunks_exact(2); 63 | let (inputs, input_spans) = items 64 | .map(|item| { 65 | let (punc, lit) = (&item[0], &item[1]); 66 | match (punc, lit) { 67 | (proc_macro::TokenTree::Punct(_), proc_macro::TokenTree::Literal(num)) => { 68 | (format!("{}", num), num.span()) 69 | } 70 | _ => { 71 | Diagnostic::spanned( 72 | punc.span().join(lit.span()).expect("forward: join error"), 73 | proc_macro::Level::Error, 74 | "Bad statement format, this should be `,` e.g. 
`,1f32`", 75 | ) 76 | .emit(); 77 | panic!(); 78 | } 79 | } 80 | }) 81 | .unzip::<_, _, Vec<_>, Vec<_>>(); 82 | let input_derivatives = inputs 83 | .iter() 84 | .zip(input_spans.iter()) 85 | .map(|(input, span)| { 86 | match literal_type(&syn::parse_str(input).expect("forward:lit parse fail")) { 87 | Ok(lit_type) => Ok(format!("1{}", lit_type)), 88 | Err(e) => { 89 | let err = "Unsupported literal type"; 90 | Diagnostic::spanned(*span, proc_macro::Level::Error, err).emit(); 91 | Err(format!("forward: {}", e)) 92 | } 93 | } 94 | }) 95 | .collect::, _>>(); 96 | let input_derivatives = match input_derivatives { 97 | Ok(res) => res, 98 | Err(_) => panic!(), 99 | }; 100 | 101 | // let inputs_str = inputs.into_iter().collect::(); 102 | 103 | let (inputs_str, derivatives_str) = match inputs.len() { 104 | 0 => (String::new(), String::new()), 105 | 1 => (inputs[0].clone(), input_derivatives[0].clone()), 106 | _ => ( 107 | format!( 108 | "({})", 109 | inputs 110 | .into_iter() 111 | .intersperse(String::from(",")) 112 | .collect::() 113 | ), 114 | format!( 115 | "({})", 116 | input_derivatives 117 | .into_iter() 118 | .intersperse(String::from(",")) 119 | .collect::() 120 | ), 121 | ), 122 | }; 123 | 124 | let call_str = format!( 125 | "{}{}({},{})", 126 | rust_ad_consts::FORWARD_PREFIX, 127 | function_ident, 128 | inputs_str, 129 | derivatives_str 130 | ); 131 | call_str.parse().unwrap() 132 | } 133 | /// Calls reverse auto-differentiation function corresponding to a given function. 134 | /// 135 | /// Note that since this macro doesn't know the number of outputs of the `complex` you need to specify the output seed derivatives manually. In this case for 1 `f32` we do `(1f32)` (for 2 it would be `(1f32,1f32)` etc.). 
136 | /// 137 | /// ``` 138 | /// fn complex_test() { 139 | /// let (f, (der_x, der_y, der_z)) = reverse!(complex, (3f32, 5f32, 7f32), (1f32)); 140 | /// is_near(f, 10.1187260448).unwrap(); 141 | /// is_near(der_x, 6.28571428571).unwrap(); 142 | /// is_near(der_y, -0.034212882033).unwrap(); 143 | /// is_near(der_z, -0.128914606556).unwrap(); 144 | /// 145 | /// // f(x,y,z) = x^2 + 2x/z + 2/(y+z^0.5) 146 | /// // ∂x = 2(x+1/z) 147 | /// // ∂y = -2 / (y+z^0.5)^2 148 | /// // ∂z = -2x/z^2 -1/(z^0.5 * (y+z^0.5)^2) 149 | /// // Therefore: 150 | /// // f(3,5,7) = 10.1187260448... 151 | /// // ∂x| = 6.28571428571... 152 | /// // ∂y| = −0.034212882033... 153 | /// // ∂z| = −0.128914606556... 154 | /// #[reverse_autodiff] 155 | /// fn complex(x: f32, y: f32, z: f32) -> f32 { 156 | /// let a = x.powi(2i32); 157 | /// let b = x * 2f32 / z; 158 | /// let c = 2f32 / (z.sqrt() + y); 159 | /// let f = a + b + c; 160 | /// return f; 161 | /// } 162 | /// } 163 | /// ``` 164 | #[proc_macro] 165 | pub fn reverse(_item: TokenStream) -> TokenStream { 166 | let mut items = _item.into_iter(); 167 | let function_ident = match items.next() { 168 | Some(proc_macro::TokenTree::Ident(ident)) => ident, 169 | _ => panic!("Requires function identifier"), 170 | }; 171 | 172 | // Checks `,` 173 | match items.next() { 174 | Some(proc_macro::TokenTree::Punct(p)) => { 175 | if p.to_string() != "," { 176 | Diagnostic::spanned( 177 | p.span(), 178 | proc_macro::Level::Error, 179 | "This should be a comma e.g. `,`", 180 | ) 181 | .emit(); 182 | panic!(); 183 | } 184 | } 185 | Some(e) => { 186 | Diagnostic::spanned( 187 | e.span(), 188 | proc_macro::Level::Error, 189 | "This should be a comma e.g. 
`,`", 190 | ) 191 | .emit(); 192 | panic!(); 193 | } 194 | None => { 195 | return format!("{}{}()", rust_ad_consts::REVERSE_PREFIX, function_ident) 196 | .parse() 197 | .unwrap(); 198 | } 199 | } 200 | 201 | // Gets inputs tuple 202 | let inputs = match items.next() { 203 | Some(proc_macro::TokenTree::Group(inputs_group)) => { 204 | let inputs_stream = inputs_group.stream(); 205 | let inputs_vec = inputs_stream.into_iter().collect::>(); 206 | let inputs = inputs_vec 207 | .chunks(2) 208 | .map(|items| match &items[0] { 209 | proc_macro::TokenTree::Literal(num) => num.to_string(), 210 | _ => panic!("reverse: bad"), 211 | }) 212 | .collect::>(); 213 | inputs 214 | } 215 | Some(e) => { 216 | Diagnostic::spanned(e.span(), proc_macro::Level::Error, "Bad inputs").emit(); 217 | panic!(); 218 | } 219 | _ => { 220 | panic!("No inputs"); 221 | } 222 | }; 223 | 224 | // Checks `,` 225 | match items.next() { 226 | Some(proc_macro::TokenTree::Punct(p)) => { 227 | if p.to_string() != "," { 228 | Diagnostic::spanned( 229 | p.span(), 230 | proc_macro::Level::Error, 231 | "This should be a comma e.g. `,`", 232 | ) 233 | .emit(); 234 | panic!(); 235 | } 236 | } 237 | Some(e) => { 238 | Diagnostic::spanned( 239 | e.span(), 240 | proc_macro::Level::Error, 241 | "This should be a comma e.g. 
`,`", 242 | ) 243 | .emit(); 244 | panic!(); 245 | } 246 | None => { 247 | return format!("{}{}()", rust_ad_consts::REVERSE_PREFIX, function_ident) 248 | .parse() 249 | .unwrap(); 250 | } 251 | } 252 | 253 | // Gets output derivatives tuple 254 | let output_derivatives = match items.next() { 255 | Some(proc_macro::TokenTree::Group(output_derivatives_group)) => { 256 | let output_derivatives_stream = output_derivatives_group.stream(); 257 | let output_derivatives_vec = output_derivatives_stream.into_iter().collect::>(); 258 | let output_derivatives = output_derivatives_vec 259 | .chunks(2) 260 | .map(|items| match &items[0] { 261 | proc_macro::TokenTree::Literal(num) => num.to_string(), 262 | _ => panic!("reverse: bad"), 263 | }) 264 | .collect::>(); 265 | output_derivatives 266 | } 267 | Some(e) => { 268 | Diagnostic::spanned(e.span(), proc_macro::Level::Error, "Bad outputs").emit(); 269 | panic!(); 270 | } 271 | _ => { 272 | panic!("No output derivatives"); 273 | } 274 | }; 275 | 276 | let inputs_str = match inputs.len() { 277 | 0 => unreachable!(), 278 | 1 => match output_derivatives.len() { 279 | 0 => inputs[0].to_string(), 280 | 1 => format!("{},{}", inputs[0], output_derivatives[0]), 281 | _ => format!( 282 | "{},({})", 283 | inputs[0], 284 | output_derivatives 285 | .into_iter() 286 | .intersperse(String::from(",")) 287 | .collect::() 288 | ), 289 | }, 290 | _ => match output_derivatives.len() { 291 | 0 => format!( 292 | "({})", 293 | inputs 294 | .into_iter() 295 | .intersperse(String::from(",")) 296 | .collect::() 297 | ), 298 | 1 => format!( 299 | "({}),{}", 300 | inputs 301 | .into_iter() 302 | .intersperse(String::from(",")) 303 | .collect::(), 304 | output_derivatives[0] 305 | ), 306 | _ => format!( 307 | "({}),({})", 308 | inputs 309 | .into_iter() 310 | .intersperse(String::from(",")) 311 | .collect::(), 312 | output_derivatives 313 | .into_iter() 314 | .intersperse(String::from(",")) 315 | .collect::() 316 | ), 317 | }, 318 | }; 319 | 320 | let 
call_str = format!( 321 | "{}{}({})", 322 | rust_ad_consts::REVERSE_PREFIX, 323 | function_ident, 324 | inputs_str 325 | ); 326 | call_str.parse().unwrap() 327 | } 328 | 329 | /// Flattens nested binary expressions into separate variable assignments. 330 | /// 331 | /// E.g. 332 | /// ``` 333 | /// #[rust_ad::unweave] 334 | /// fn function_name(x: f32, y: f32) -> f32 { 335 | /// let v = 2f32 * x + y / 3.0f32; 336 | /// return v; 337 | /// } 338 | /// ``` 339 | /// Expands to: 340 | /// ``` 341 | /// fn function_name(x: f32, y: f32) -> f32 { 342 | /// let _v = 2f32 * x; 343 | /// let v_ = y / 3.0f32; 344 | /// let v = _v + v_; 345 | /// return v; 346 | /// } 347 | /// ``` 348 | #[proc_macro_attribute] 349 | pub fn unweave(_attr: TokenStream, item: TokenStream) -> TokenStream { 350 | let ast = syn::parse_macro_input!(item as syn::Item); 351 | 352 | // Checks item is impl. 353 | let mut ast = match ast { 354 | syn::Item::Fn(func) => func, 355 | _ => panic!("Macro must be applied to a `fn`"), 356 | }; 357 | 358 | let block = &mut ast.block; 359 | 360 | let statements = block 361 | .stmts 362 | .iter() 363 | .flat_map(unwrap_statement) 364 | .collect::>(); 365 | block.stmts = statements; 366 | 367 | let new = quote::quote! { #ast }; 368 | TokenStream::from(new) 369 | } 370 | 371 | /// Generates the forward auto-differentiation function for a given function. 372 | /// ``` 373 | /// fn complex_test() { 374 | /// let (f, (der_x, der_y, der_z)) = __f_complex((3f32, 5f32, 7f32),(1f32,1f32,1f32)); 375 | /// is_near(f, 10.1187260448).unwrap(); 376 | /// is_near(der_x, 6.28571428571).unwrap(); 377 | /// is_near(der_y, -0.034212882033).unwrap(); 378 | /// is_near(der_z, -0.128914606556).unwrap(); 379 | /// 380 | /// // f(x,y,z) = x^2 + 2x/z + 2/(y+z^0.5) 381 | /// // ∂x = 2(x+1/z) 382 | /// // ∂y = -2 / (y+z^0.5)^2 383 | /// // ∂z = -2x/z^2 -1/(z^0.5 * (y+z^0.5)^2) 384 | /// // Therefore: 385 | /// // f(3,5,7) = 10.1187260448... 386 | /// // ∂x| = 6.28571428571... 
387 | /// // ∂y| = −0.034212882033... 388 | /// // ∂z| = −0.128914606556... 389 | /// #[forward_autodiff] 390 | /// fn complex(x: f32, y: f32, z: f32) -> f32 { 391 | /// let a = x.powi(2i32); 392 | /// let b = x * 2f32 / z; 393 | /// let c = 2f32 / (z.sqrt() + y); 394 | /// let f = a + b + c; 395 | /// return f; 396 | /// } 397 | /// } 398 | /// ``` 399 | /// Like a derive macro, this new function is appended to your code, the original `function_name` function remains unedited. 400 | #[proc_macro_attribute] 401 | pub fn forward_autodiff(_attr: TokenStream, item: TokenStream) -> TokenStream { 402 | let start_item = item.clone(); 403 | let ast = syn::parse_macro_input!(item as syn::Item); 404 | // eprintln!("{:#?}",ast); 405 | 406 | // Checks item is function. 407 | let mut function = match ast { 408 | syn::Item::Fn(func) => func, 409 | _ => panic!("Only `fn` items are supported."), 410 | }; 411 | 412 | // Updates function signature 413 | // --------------------------------------------------------------------------- 414 | let function_input_identifiers = { 415 | // Updates identifier. 416 | function.sig.ident = syn::Ident::new( 417 | &format!("{}{}", rust_ad_consts::FORWARD_PREFIX, function.sig.ident), 418 | function.sig.ident.span(), 419 | ); 420 | // Gets function inputs in usable form `[(ident,type)]`. 421 | let (input_idents, input_types) = function 422 | .sig 423 | .inputs 424 | .iter() 425 | .map(|fn_arg| { 426 | let typed = fn_arg.typed().expect("forward: signature input not typed"); 427 | let mut arg_type = typed.ty.to_token_stream().to_string(); 428 | arg_type.remove_matches(" "); // Remove space separators in type 429 | 430 | let arg_ident = typed.pat.to_token_stream().to_string(); 431 | (arg_ident, arg_type) 432 | }) 433 | .unzip::<_, _, Vec<_>, Vec<_>>(); 434 | // Put existing inputs into tuple. 
435 | // ----------------------------------------------- 436 | let inputs_tuple_str = match input_idents.len() { 437 | 0 => String::new(), 438 | 1 => format!("{}:{}", input_idents[0], input_types[0]), 439 | _ => format!( 440 | "({}):({})", 441 | input_idents 442 | .iter() 443 | .cloned() 444 | .intersperse(String::from(",")) 445 | .collect::(), 446 | input_types 447 | .iter() 448 | .cloned() 449 | .intersperse(String::from(",")) 450 | .collect::() 451 | ), 452 | }; 453 | let inputs_tuple = 454 | syn::parse_str(&inputs_tuple_str).expect("forward: inputs tuple parse fail"); 455 | // Gets tuple of derivatives of inputs. 456 | // ----------------------------------------------- 457 | let derivative_input_tuple_str = match input_idents.len() { 458 | 0 => String::new(), 459 | 1 => format!("{}:{}", der!(input_idents[0]), input_types[0]), 460 | _ => format!( 461 | "({}):({})", 462 | input_idents 463 | .iter() 464 | .cloned() 465 | .map(|i| der!(i)) 466 | .intersperse(String::from(",")) 467 | .collect::(), 468 | input_types 469 | .iter() 470 | .cloned() 471 | .intersperse(String::from(",")) 472 | .collect::() 473 | ), 474 | }; 475 | 476 | let derivative_input_tuple = 477 | syn::parse_str(&derivative_input_tuple_str).expect("forward: output tuple parse fail"); 478 | // Sets new function inputs 479 | // ----------------------------------------------- 480 | let mut new_fn_inputs = syn::punctuated::Punctuated::new(); 481 | new_fn_inputs.push(inputs_tuple); 482 | new_fn_inputs.push(derivative_input_tuple); 483 | function.sig.inputs = new_fn_inputs; 484 | // Sets new function outputs 485 | update_function_outputs(&mut function.sig, input_types).expect("forward_autodiff 0"); 486 | input_idents 487 | }; 488 | 489 | // Forward autodiff 490 | // --------------------------------------------------------------------------- 491 | 492 | // Flattens statements 493 | function.block.stmts = function 494 | .block 495 | .stmts 496 | .iter() 497 | .flat_map(unwrap_statement) 498 | .collect::>(); 
499 | 500 | // Propagates types through function 501 | let type_map = propagate_types(&function).expect("forward_autodiff 1"); 502 | 503 | // In release we apply optimizations which shrink the produced code (eliminating unnecessary code) 504 | // These are not applied in debug mode so one might use debug to give a clearer view of the fundamental process. 505 | #[cfg(not(debug_assertions))] 506 | let mut non_zero_derivatives = HashSet::::new(); 507 | 508 | #[cfg(debug_assertions)] 509 | let der_info = (&type_map, function_input_identifiers.as_slice()); 510 | #[cfg(not(debug_assertions))] 511 | let der_info = ( 512 | &type_map, 513 | function_input_identifiers.as_slice(), 514 | &mut non_zero_derivatives, 515 | ); 516 | 517 | // Intersperses forward derivatives 518 | let derivative_stmts = 519 | intersperse_succeeding_stmts(function.block.stmts, der_info, forward_derivative) 520 | .expect("forward_autodiff 2"); 521 | function.block.stmts = derivative_stmts; 522 | // Updates return statement 523 | update_forward_return( 524 | &mut function.block, 525 | function_input_identifiers.as_slice(), 526 | #[cfg(not(debug_assertions))] 527 | type_map, 528 | #[cfg(not(debug_assertions))] 529 | non_zero_derivatives, 530 | ) 531 | .expect("forward_autodiff 3"); 532 | 533 | let new = quote::quote! { #function }; 534 | let new_stream = TokenStream::from(new); 535 | join_streams(start_item, new_stream) 536 | } 537 | fn join_streams(mut a: TokenStream, b: TokenStream) -> TokenStream { 538 | a.extend(b.into_iter()); 539 | a 540 | } 541 | 542 | fn update_function_outputs( 543 | function_signature: &mut syn::Signature, 544 | function_input_types: Vec, 545 | ) -> Result<(), PassError> { 546 | let function_output = &mut function_signature.output; 547 | // Updates output to include to derivatives for each output respective to each input 548 | // e.g. 
`fn(x:f32,x_:f32,y:f32,y_:f32)->(f32,f32)` => `fn(x:f32,y:f32)->((f32,(f32,f32)),(f32,(f32,f32)))` 549 | // eprintln!("function_output:\n{:#?}",function_output); 550 | let function_input_string = match function_input_types.len() { 551 | 0 => String::new(), 552 | 1 => function_input_types[0].clone(), 553 | _ => format!( 554 | "({}),", 555 | function_input_types 556 | .into_iter() 557 | .intersperse(String::from(",")) 558 | .collect::() 559 | ), 560 | }; 561 | let return_type_str = match function_output { 562 | syn::ReturnType::Type(_, return_type) => match &**return_type { 563 | syn::Type::Path(return_path) => { 564 | let return_str = return_path.to_token_stream().to_string(); 565 | format!("->({},{})", return_str, function_input_string) 566 | } 567 | syn::Type::Tuple(return_tuple) => { 568 | let return_str = return_tuple.to_token_stream().to_string(); 569 | format!( 570 | "->({},({}))", 571 | return_str, 572 | function_input_string.repeat(return_tuple.elems.len()) 573 | ) 574 | } 575 | _ => { 576 | let err = "Unsupported return type (supported types are tuples (e.g. `(f32,f32)`) or paths (e.g. `f32`))"; 577 | Diagnostic::spanned(return_type.span().unwrap(), proc_macro::Level::Error, err) 578 | .emit(); 579 | return Err(err.to_string()); 580 | } 581 | }, 582 | // TODO What does this even look like? 583 | syn::ReturnType::Default => { 584 | let err = "Unsupported return form"; 585 | Diagnostic::spanned( 586 | function_output.span().unwrap(), 587 | proc_macro::Level::Error, 588 | err, 589 | ) 590 | .emit(); 591 | return Err(err.to_string()); 592 | } 593 | }; 594 | // eprintln!("return_type_str: {}",return_type_str); 595 | *function_output = pass!( 596 | syn::parse_str(&return_type_str), 597 | "forward: failed output parse" 598 | ); 599 | Ok(()) 600 | } 601 | 602 | /// Returns a tuple of a given number of clones of a variable. 
603 | /// ``` 604 | /// let x = 2; 605 | /// assert_eq!(rust_ad::dup!(x,3),(x.clone(),x.clone(),x.clone())); 606 | /// ``` 607 | #[proc_macro] 608 | pub fn dup(_item: TokenStream) -> TokenStream { 609 | // eprintln!("what?: {:?}",_item); 610 | let vec = _item.into_iter().collect::>(); 611 | match (vec.get(0), vec.get(1), vec.get(2)) { 612 | ( 613 | Some(proc_macro::TokenTree::Ident(var)), 614 | Some(proc_macro::TokenTree::Punct(_)), 615 | Some(proc_macro::TokenTree::Literal(num)), 616 | ) => { 617 | let tuple = format!( 618 | "({})", 619 | format!("{}.clone(),", var).repeat(num.to_string().parse().unwrap()) 620 | ); 621 | tuple.parse().unwrap() 622 | } 623 | _ => panic!("Bad input"), 624 | } 625 | } 626 | 627 | /// Generates the reverse auto-differentiation function for a given function. 628 | /// ``` 629 | /// fn complex_test() { 630 | /// let (f, (der_x, der_y, der_z)) = __r_complex((3f32, 5f32, 7f32), (1f32)); 631 | /// is_near(f, 10.1187260448).unwrap(); 632 | /// is_near(der_x, 6.28571428571).unwrap(); 633 | /// is_near(der_y, -0.034212882033).unwrap(); 634 | /// is_near(der_z, -0.128914606556).unwrap(); 635 | /// 636 | /// // f(x,y,z) = x^2 + 2x/z + 2/(y+z^0.5) 637 | /// // ∂x = 2(x+1/z) 638 | /// // ∂y = -2 / (y+z^0.5)^2 639 | /// // ∂z = -2x/z^2 -1/(z^0.5 * (y+z^0.5)^2) 640 | /// // Therefore: 641 | /// // f(3,5,7) = 10.1187260448 642 | /// // ∂x| = 6.28571428571 643 | /// // ∂y| = −0.034212882033 644 | /// // ∂z| = −0.128914606556 645 | /// #[reverse_autodiff] 646 | /// fn complex(x: f32, y: f32, z: f32) -> f32 { 647 | /// let a = x.powi(2i32); 648 | /// let b = x * 2f32 / z; 649 | /// let c = 2f32 / (z.sqrt() + y); 650 | /// let f = a + b + c; 651 | /// return f; 652 | /// } 653 | /// } 654 | /// ``` 655 | /// Like a derive macro, this new function is appended to your code, the original `function_name` function remains unedited. 
656 | #[proc_macro_attribute] 657 | pub fn reverse_autodiff(_attr: TokenStream, item: TokenStream) -> TokenStream { 658 | let start_item = item.clone(); 659 | let ast = syn::parse_macro_input!(item as syn::Item); 660 | // Checks item is function. 661 | let mut function = match ast { 662 | syn::Item::Fn(func) => func, 663 | _ => panic!("Only `fn` items are supported."), 664 | }; 665 | // Unwraps nested binary expressions 666 | // --------------------------------------------------------------------------- 667 | let statements = function 668 | .block 669 | .stmts 670 | .iter() 671 | .flat_map(unwrap_statement) 672 | .collect::>(); 673 | function.block.stmts = statements; 674 | 675 | // Updates function signature 676 | // --------------------------------------------------------------------------- 677 | let (function_input_identifiers, number_of_return_elements) = { 678 | // Updates identifier. 679 | function.sig.ident = syn::Ident::new( 680 | &format!("{}{}", rust_ad_consts::REVERSE_PREFIX, function.sig.ident), 681 | function.sig.ident.span(), 682 | ); 683 | // Gets function inputs in usable form `[(ident,type)]`. 684 | let (input_idents, input_types) = function 685 | .sig 686 | .inputs 687 | .iter() 688 | .map(|fn_arg| { 689 | let typed = fn_arg.typed().expect("reverse: signature input not typed"); 690 | let mut arg_type = typed.ty.to_token_stream().to_string(); 691 | arg_type.remove_matches(" "); // Remove space separators in type 692 | 693 | let arg_ident = typed.pat.to_token_stream().to_string(); 694 | (arg_ident, arg_type) 695 | }) 696 | .unzip::<_, _, Vec<_>, Vec<_>>(); 697 | // Put existing inputs into tuple. 
698 | // ----------------------------------------------- 699 | let inputs_tuple_str = match input_idents.len() { 700 | 0 => String::new(), 701 | 1 => format!("{}:{}", input_idents[0], input_types[0]), 702 | _ => format!( 703 | "({}):({})", 704 | input_idents 705 | .iter() 706 | .cloned() 707 | .intersperse(String::from(",")) 708 | .collect::(), 709 | input_types 710 | .iter() 711 | .cloned() 712 | .intersperse(String::from(",")) 713 | .collect::() 714 | ), 715 | }; 716 | let inputs_tuple = 717 | syn::parse_str(&inputs_tuple_str).expect("reverse: inputs tuple parse fail"); 718 | // Gets tuple of derivatives of inputs. 719 | // ----------------------------------------------- 720 | let (function_output, number_of_return_elements) = match &function.sig.output { 721 | syn::ReturnType::Type(_, return_type) => ( 722 | return_type.to_token_stream().to_string(), 723 | match &**return_type { 724 | syn::Type::Tuple(type_tuple) => type_tuple.elems.len(), 725 | syn::Type::Path(_) => 1, 726 | _ => { 727 | let err = "Unsupported return type"; 728 | Diagnostic::spanned( 729 | return_type.span().unwrap(), 730 | proc_macro::Level::Error, 731 | err, 732 | ) 733 | .emit(); 734 | panic!("{}", err); 735 | } 736 | }, 737 | ), 738 | syn::ReturnType::Default => { 739 | let err = "Unsupported return form"; 740 | Diagnostic::spanned( 741 | function.sig.output.span().unwrap(), 742 | proc_macro::Level::Error, 743 | err, 744 | ) 745 | .emit(); 746 | panic!("{}", err); 747 | } 748 | }; 749 | let derivative_input_tuple_str = match number_of_return_elements { 750 | 0 => String::new(), 751 | 1 => format!("{}:{}", rtn!(0), function_output), 752 | _ => format!( 753 | "({}):{}", 754 | (0..number_of_return_elements) 755 | .map(|i| rtn!(i)) 756 | .intersperse(String::from(",")) 757 | .collect::(), 758 | function_output 759 | ), 760 | }; 761 | 762 | let derivative_input_tuple = 763 | syn::parse_str(&derivative_input_tuple_str).expect("reverse: output tuple parse fail"); 764 | // Sets new function inputs 
765 | // ----------------------------------------------- 766 | let mut new_fn_inputs = syn::punctuated::Punctuated::new(); 767 | new_fn_inputs.push(inputs_tuple); 768 | new_fn_inputs.push(derivative_input_tuple); 769 | function.sig.inputs = new_fn_inputs; 770 | // Sets new function outputs 771 | update_function_outputs(&mut function.sig, input_types).expect("reverse_autodiff 0"); 772 | (input_idents, number_of_return_elements) 773 | }; 774 | 775 | // Propagates types through function 776 | // --------------------------------------------------------------------------- 777 | let type_map = propagate_types(&function).expect("propagate_types: "); 778 | 779 | // Generates reverse mode code 780 | // --------------------------------------------------------------------------- 781 | let mut component_map = vec![HashMap::new(); number_of_return_elements]; 782 | let mut return_derivatives = vec![HashSet::new(); number_of_return_elements]; 783 | 784 | let mut rev_iter = function.block.stmts.iter().rev().peekable(); 785 | let mut reverse_derivative_stmts = Vec::new(); 786 | 787 | // In release we apply optimizations which shrink the produced code (eliminating unnecessary code) 788 | // These are not applied in debug mode so one might use debug to give a clearer view of the fundamental process. 789 | #[cfg(not(debug_assertions))] 790 | let mut non_zero_derivatives = HashSet::::new(); 791 | 792 | // For the last statement (which we presume to be a return) we skip the accumulation for the next stage since we can set the accumulative derivatives directly. 
793 | if let Some(return_stmt) = rev_iter.next() { 794 | reverse_derivative_stmts.append( 795 | &mut reverse_derivative( 796 | return_stmt, 797 | &type_map, 798 | &mut component_map, 799 | &mut return_derivatives, 800 | ) 801 | .expect("rtn der temp"), 802 | ); 803 | } 804 | // For the statement preceding the return statement 805 | let mut rest = rev_iter 806 | .flat_map(|next| { 807 | reverse_derivative(next, &type_map, &mut component_map, &mut return_derivatives) 808 | .expect("der temp") 809 | }) 810 | .collect::>(); 811 | reverse_derivative_stmts.append(&mut rest); 812 | // Collects inputs for return statement. 813 | if let Some(input_accumulation) = reverse_accumulate_inputs( 814 | &function_input_identifiers, 815 | &component_map, 816 | &type_map, 817 | &return_derivatives, 818 | ) { 819 | reverse_derivative_stmts.push(input_accumulation); 820 | } 821 | 822 | // Gets new return statement 823 | let new_return = reverse_append_derivatives( 824 | function.block.stmts.pop().unwrap(), 825 | &function_input_identifiers, 826 | ) 827 | .expect("rtn acc temp"); 828 | // Adds derivatives to block 829 | function.block.stmts.append(&mut reverse_derivative_stmts); 830 | function.block.stmts.push(new_return); 831 | 832 | let new = quote::quote! { #function }; 833 | let new_stream = TokenStream::from(new); 834 | join_streams(start_item, new_stream) 835 | } 836 | 837 | /// Unwraps nested expressions into separate variable assignments. 838 | /// 839 | /// E.g. 840 | /// ```ignore 841 | /// let a = b*c + d/e; 842 | /// ``` 843 | /// Becomes: 844 | /// ```ignore 845 | /// let _a = b*c; 846 | /// let a_ = d/e; 847 | /// let a = _a + a_; 848 | /// ``` 849 | /// 850 | /// E.g.
851 | /// ```ignore 852 | /// let a = function_one(function_two(b)+c); 853 | /// ``` 854 | /// Becomes: 855 | /// ```ignore 856 | /// let __a = function_two(b); 857 | /// let _a = __a + c; 858 | /// let a = function_one(_a); 859 | /// ``` 860 | fn unwrap_statement(stmt: &syn::Stmt) -> Vec { 861 | // eprintln!("unwrap stmt:\n{:#?}\n", stmt); 862 | 863 | let mut statements = Vec::new(); 864 | // TODO Avoid this clone. 865 | let mut base_statement = stmt.clone(); 866 | 867 | // If the statement is local variable declaration (e.g. `let ...`). 868 | if let syn::Stmt::Local(local) = stmt { 869 | let local_ident = &local 870 | .pat 871 | .ident() 872 | .unwrap_or_else(|_|panic!("unwrap_statement: non-ident local pattern (must be `let x =...;`, cannot be a tuple etc.): {{\n{:#?}\n}}",local)) 873 | .ident.to_string(); 874 | // If our statement has some initialization (e.g. `let a = 3;`). 875 | if let Some(init) = local.init.as_ref() { 876 | // eprintln!("init: {:#?}", init); 877 | 878 | // If initialization is a binary expression (e.g. `let a = b + c;`). 879 | if let syn::Expr::Binary(bin_expr) = init.1.as_ref() { 880 | // Hoist each operand that is not a literal or path into its own `let` binding. 881 | 882 | // If left is not a literal or path 883 | if !(bin_expr.left.is_lit() || bin_expr.left.is_path()) { 884 | // Creates new left statement, named `_<ident>` so it cannot collide with the right statement `<ident>_` (reusing one name would shadow the left value). 885 | let left_ident = format!("_{}", local_ident); 886 | let new_stmt_str = 887 | format!("let {} = {};", left_ident, bin_expr.left.to_token_stream()); 888 | let new_stmt: syn::Stmt = 889 | syn::parse_str(&new_stmt_str).expect("unwrap: left bad parse"); 890 | // Recurse 891 | statements.append(&mut unwrap_statement(&new_stmt)); 892 | 893 | // Updates statement to contain variable referencing new statement.
894 | let left_expr: syn::Expr = 895 | syn::parse_str(&left_ident).expect("unwrap: left parse fail"); 896 | *base_statement 897 | .local_mut() 898 | .expect("unwrap: 1a") 899 | .init 900 | .as_mut() 901 | .unwrap() 902 | .1 903 | .binary_mut() 904 | .expect("unwrap: 1b") 905 | .left = left_expr; 906 | } 907 | // If right is not a literal or path 908 | if !(bin_expr.right.is_lit() || bin_expr.right.is_path()) { 909 | // eprintln!("this should trigger: {}",right_bin_expr.to_token_stream()); 910 | // Creates new right statement, named `<ident>_` (distinct from the left statement's `_<ident>`). 911 | let right_ident = format!("{}_", local_ident); 912 | let new_stmt_str = format!( 913 | "let {} = {};", 914 | right_ident, 915 | bin_expr.right.to_token_stream() 916 | ); 917 | let new_stmt: syn::Stmt = 918 | syn::parse_str(&new_stmt_str).expect("unwrap: right bad parse"); 919 | // Recurse 920 | statements.append(&mut unwrap_statement(&new_stmt)); 921 | 922 | // Updates statement to contain variable referencing new statement. 923 | let right_expr: syn::Expr = 924 | syn::parse_str(&right_ident).expect("unwrap: right parse fail"); 925 | *base_statement 926 | .local_mut() 927 | .expect("unwrap: 2a") 928 | .init 929 | .as_mut() 930 | .unwrap() 931 | .1 932 | .binary_mut() 933 | .expect("unwrap: 2b") 934 | .right = right_expr; 935 | } 936 | } 937 | // If initialization is function call (e.g. `let a = my_function(b,c);`). 938 | else if let syn::Expr::Call(call_expr) = init.1.as_ref() { 939 | // eprintln!("call_expr: {:#?}",call_expr); 940 | 941 | // For each function argument. 942 | for (i, arg) in call_expr.args.iter().enumerate() { 943 | // eprintln!("i: {:#?}, arg: {:#?}",i,arg); 944 | 945 | // If function argument is binary expression 946 | if let syn::Expr::Binary(arg_bin_expr) = arg { 947 | // eprintln!("arg_bin_expr: {:#?}",arg_bin_expr); 948 | 949 | // Creates new function argument statement.
950 | let mut func_stmt = stmt.clone(); 951 | let func_local = func_stmt 952 | .local_mut() 953 | .expect("unwrap: function statement not local"); 954 | let func_ident = 955 | format!("{}_{}", FUNCTION_PREFIX.repeat(i + 1), local_ident); 956 | func_local 957 | .pat 958 | .ident_mut() 959 | .expect("unwrap: function not ident") 960 | .ident = 961 | syn::parse_str(&func_ident).expect("unwrap: function ident parse fail"); 962 | *func_local.init.as_mut().unwrap().1 = 963 | syn::Expr::Binary(arg_bin_expr.clone()); 964 | // Recurse 965 | statements.append(&mut unwrap_statement(&func_stmt)); 966 | 967 | // Updates statement to contain reference to new variables 968 | let arg_expr: syn::Expr = 969 | syn::parse_str(&func_ident).expect("unwrap: function parse fail"); 970 | base_statement 971 | .local_mut() 972 | .expect("unwrap: function local") 973 | .init 974 | .as_mut() 975 | .unwrap() 976 | .1 977 | .call_mut() 978 | .expect("unwrap: function call") 979 | .args[i] = arg_expr; 980 | } 981 | } 982 | } 983 | // If initialization is method call (e.g. `let a = b.my_function(c);`). 984 | else if let syn::Expr::MethodCall(method_expr) = init.1.as_ref() { 985 | // If method is call on value in parenthesis (e.g. `(x).method()`). 986 | if let syn::Expr::Paren(parenthesis) = &*method_expr.receiver { 987 | // If method is called on value which is binary expression (e.g. `(x+y).method()`). 988 | if let syn::Expr::Binary(bin_expr) = &*parenthesis.expr { 989 | // Creates new statement.
990 | let mut receiver_stmt = stmt.clone(); 991 | let receiver_local = receiver_stmt 992 | .local_mut() 993 | .expect("unwrap: receiver statement not local"); 994 | let receiver_ident = format!("{}_{}", RECEIVER_PREFIX, local_ident); 995 | receiver_local 996 | .pat 997 | .ident_mut() 998 | .expect("unwrap: receiver not ident") 999 | .ident = syn::parse_str(&receiver_ident) 1000 | .expect("unwrap: receiver ident parse fail"); 1001 | *receiver_local.init.as_mut().unwrap().1 = 1002 | syn::Expr::Binary(bin_expr.clone()); 1003 | // Recurse 1004 | statements.append(&mut unwrap_statement(&receiver_stmt)); 1005 | 1006 | // Updates statement to contain variable referencing new statement. 1007 | let receiver_expr: syn::Expr = 1008 | syn::parse_str(&receiver_ident).expect("unwrap: receiver parse fail"); 1009 | *base_statement 1010 | .local_mut() 1011 | .expect("unwrap: 3a") 1012 | .init 1013 | .as_mut() 1014 | .unwrap() 1015 | .1 1016 | .method_call_mut() 1017 | .expect("unwrap: 3b") 1018 | .receiver = receiver_expr; 1019 | } 1020 | } 1021 | for (i, arg) in method_expr.args.iter().enumerate() { 1022 | // eprintln!("i: {:#?}, arg: {:#?}",i,arg); 1023 | 1024 | // If function argument is binary expression 1025 | if let syn::Expr::Binary(arg_bin_expr) = arg { 1026 | // eprintln!("arg_bin_expr: {:#?}",arg_bin_expr); 1027 | 1028 | // Creates new function argument statement.
1029 | let mut func_stmt = stmt.clone(); 1030 | let func_local = func_stmt 1031 | .local_mut() 1032 | .expect("unwrap: method statement not local"); 1033 | let func_ident = 1034 | format!("{}_{}", FUNCTION_PREFIX.repeat(i + 1), local_ident); 1035 | func_local 1036 | .pat 1037 | .ident_mut() 1038 | .expect("unwrap: method not ident") 1039 | .ident = 1040 | syn::parse_str(&func_ident).expect("unwrap: method ident parse fail"); 1041 | *func_local.init.as_mut().unwrap().1 = 1042 | syn::Expr::Binary(arg_bin_expr.clone()); 1043 | // Recurse 1044 | statements.append(&mut unwrap_statement(&func_stmt)); 1045 | 1046 | // Updates statement to contain reference to new variables 1047 | let arg_expr: syn::Expr = 1048 | syn::parse_str(&func_ident).expect("unwrap: method parse fail"); 1049 | base_statement 1050 | .local_mut() 1051 | .expect("unwrap: method local") 1052 | .init 1053 | .as_mut() 1054 | .unwrap() 1055 | .1 1056 | .method_call_mut() 1057 | .expect("unwrap: method call") 1058 | .args[i] = arg_expr; 1059 | } 1060 | } 1061 | } else if let syn::Expr::Paren(paren_expr) = init.1.as_ref() { 1062 | base_statement 1063 | .local_mut() 1064 | .expect("unwrap: 3a") 1065 | .init 1066 | .as_mut() 1067 | .unwrap() 1068 | .1 = paren_expr.expr.clone(); 1069 | statements.append(&mut unwrap_statement(&base_statement)); 1070 | // Skips adding base statement we already added.
1071 | return statements; 1072 | } 1073 | } 1074 | } else if let syn::Stmt::Semi(syn::Expr::Return(rtn_expr), _) = stmt { 1075 | if let Some(rtn) = &rtn_expr.expr { 1076 | if let syn::Expr::Binary(_bin_expr) = &**rtn { 1077 | let new_ident = format!("_{}", RETURN_SUFFIX); 1078 | let new_stmt_str = format!("let {};", new_ident); 1079 | let mut new_stmt: syn::Stmt = 1080 | syn::parse_str(&new_stmt_str).expect("unwrap: return stmt parse fail"); 1081 | let new_local = new_stmt 1082 | .local_mut() 1083 | .expect("unwrap: return statement not local"); 1084 | new_local 1085 | .pat 1086 | .ident_mut() 1087 | .expect("unwrap: return not ident") 1088 | .ident = syn::parse_str(&new_ident).expect("unwrap: return ident parse fail"); 1089 | 1090 | // TODO Create `eq_token` some better way. 1091 | let eq_token = syn::parse_str("=").expect("unwrap: fml this is dumb"); 1092 | 1093 | new_local.init = Some((eq_token, rtn.clone())); 1094 | // Recurse 1095 | statements.append(&mut unwrap_statement(&new_stmt)); 1096 | 1097 | // Updates statement to contain variable referencing new statement. 1098 | let new_rtn_str = format!("return {};", new_ident); 1099 | let new_rtn_expr: syn::Stmt = 1100 | syn::parse_str(&new_rtn_str).expect("unwrap: return parse fail"); 1101 | base_statement = new_rtn_expr; 1102 | } 1103 | } 1104 | } 1105 | statements.push(base_statement); 1106 | // eprintln!("statements.len(): {}", statements.len()); 1107 | statements 1108 | } 1109 | 1110 | /// Gets the types of all variables in a function. 1111 | /// 1112 | /// Propagates types through variables in a function from the input types. 1113 | /// 1114 | /// Returns a hashmap of identifier->type. 1115 | /// 1116 | /// CURRENTLY DOES NOT SUPPORT PROCEDURES WHICH RETURN MULTIPLE DIFFERENT TYPES 1117 | fn propagate_types(func: &syn::ItemFn) -> Result, PassError> { 1118 | // Collects input tuples into initial `type_map`.
1119 | let input_types = func.sig.inputs 1120 | .iter() 1121 | .map(|input| match input { 1122 | syn::FnArg::Typed(pat_type) => match (&*pat_type.pat,&*pat_type.ty) { 1123 | (syn::Pat::Ident(path_ident),syn::Type::Path(path_type)) => { 1124 | let ident = path_ident.to_token_stream().to_string(); 1125 | let mut type_str = path_type.to_token_stream().to_string(); 1126 | type_str.remove_matches(" "); // Remove space separators in type 1127 | Ok(vec![(ident,type_str)]) 1128 | }, 1129 | (syn::Pat::Tuple(tuple_ident),syn::Type::Tuple(tuple_type)) => { 1130 | // NOTE(review): syn parses a bare `_` as `Pat::Wild`, not `Pat::Ident`, so the `ident_str == "_"` check below likely never matches and `_` falls into the error arm — confirm. 1131 | let input_types_vec = tuple_ident.elems.iter().zip(tuple_type.elems.iter()).map(|(i,t)| match i { 1132 | syn::Pat::Ident(ident) => { 1133 | let ident_str = ident.to_token_stream().to_string(); 1134 | if ident_str == "_" { 1135 | Ok(None) 1136 | } 1137 | else { 1138 | let mut type_str = t.to_token_stream().to_string(); 1139 | type_str.remove_matches(" "); // Remove space separators in type 1140 | Ok(Some((ident_str,type_str))) 1141 | } 1142 | } 1143 | _ => { 1144 | // eprintln!("ident i: {:#?}",i); 1145 | let err = "Non-ident tuple type. `return (a,b,)` is supported.
`return (a,(b,c))` is not supported."; 1146 | Diagnostic::spanned( 1147 | input.span().unwrap(), 1148 | proc_macro::Level::Error, 1149 | err, 1150 | ) 1151 | .emit(); 1152 | Err(err) 1153 | } 1154 | }).collect::,_>>().expect("propagate_types: tuple input error"); 1155 | let input_types_vec = input_types_vec.into_iter().flatten().collect::>(); 1156 | Ok(input_types_vec) 1157 | } 1158 | _ => { 1159 | eprintln!("pat_type.pat: \n{:#?}",pat_type.pat); 1160 | eprintln!("pat_type.ty: \n{:#?}",pat_type.ty); 1161 | 1162 | let err = "Unsupported input type combination"; 1163 | Diagnostic::spanned( 1164 | input.span().unwrap(), 1165 | proc_macro::Level::Error, 1166 | err, 1167 | ) 1168 | .emit(); 1169 | Err(err) 1170 | } 1171 | }, 1172 | syn::FnArg::Receiver(_) => { 1173 | let err = "Unsupported input type"; 1174 | Diagnostic::spanned( 1175 | input.span().unwrap(), 1176 | proc_macro::Level::Error, 1177 | err, 1178 | ) 1179 | .emit(); 1180 | Err(err) 1181 | } 1182 | }) 1183 | .collect::,_>>().expect("propagate_types: input types error"); 1184 | let mut type_map = input_types 1185 | .into_iter() 1186 | .flatten() 1187 | .collect::>(); 1188 | 1189 | // Propagates types through statements 1190 | for stmt in func.block.stmts.iter() { 1191 | if let syn::Stmt::Local(local) = stmt { 1192 | // Gets identifier/s of variable/s being defined 1193 | let var_idents = match &local.pat { 1194 | syn::Pat::Ident(pat_ident) => vec![pat_ident.ident.to_string()], 1195 | syn::Pat::Tuple(pat_tuple) => pat_tuple 1196 | .elems 1197 | .iter() 1198 | .map(|e| { 1199 | e.ident() 1200 | .expect("propagate_types: tuple not ident") 1201 | .ident 1202 | .to_string() 1203 | }) 1204 | .collect(), 1205 | _ => panic!("propagate_types: local pat not ident:\n{:#?}", local.pat), 1206 | }; 1207 | if let Some(init) = &local.init { 1208 | let output_type = match expr_type(&*init.1, &type_map) { 1209 | Ok(res) => res, 1210 | Err(e) => return Err(e), 1211 | }; 1212 | for var_ident in var_idents.into_iter() { 1213 |
type_map.insert(var_ident, output_type.clone()); 1214 | } 1215 | } 1216 | } 1217 | } 1218 | // eprintln!("final map: {:?}", map); 1219 | Ok(type_map) 1220 | } 1221 | -------------------------------------------------------------------------------- /macros/src/reverse.rs: -------------------------------------------------------------------------------- 1 | extern crate proc_macro; 2 | use proc_macro::Diagnostic; 3 | 4 | use proc_macro::Level::Error; 5 | use quote::ToTokens; 6 | use rust_ad_core::*; 7 | use std::collections::HashMap; 8 | use std::collections::HashSet; 9 | use syn::spanned::Spanned; 10 | 11 | // Given return statement outputs return statement with appended derivatives 12 | pub fn reverse_append_derivatives( 13 | stmt: syn::Stmt, 14 | function_input_identifiers: &[String], 15 | ) -> Result { 16 | const NAME: &str = "reverse_append_derivatives"; 17 | if let syn::Stmt::Semi(syn::Expr::Return(return_struct), _) = stmt { 18 | if let Some(return_expr) = &return_struct.expr { 19 | match &**return_expr { 20 | // If return expression is tuple e.g. `return (a,b);` 21 | syn::Expr::Tuple(return_tuple) => { 22 | let return_idents_res = return_tuple.elems.iter().enumerate().map(|(index,e)| match e { 23 | syn::Expr::Path(p) => { 24 | let path_ident = p.to_token_stream().to_string(); 25 | let rtn_ident = rtn!(index); 26 | 27 | let (ident,der) = (path_ident,format!("({})", 28 | function_input_identifiers 29 | .iter() 30 | .map(|input|wrt!(input,rtn_ident)) 31 | .intersperse(String::from(",")) 32 | .collect::() 33 | )); 34 | Ok((ident,der)) 35 | }, 36 | syn::Expr::Lit(l) => { 37 | let der = function_input_identifiers 38 | .iter() 39 | .map(|_|format!("0{}",literal_type(l).expect("reverse_append_derivatives: unsupported literal type"))) 40 | .intersperse(String::from(",")) 41 | .collect::(); 42 | 43 | Ok(( 44 | l.to_token_stream().to_string(), 45 | format!("({})",der) 46 | )) 47 | 48 | }, 49 | _ => { 50 | let err = "Unsupported return tuple element.
Elements in a returned tuple must be paths or literals (e.g. `return (a,b,2f32))` is supported, `return (a,(b,c))` is not supported)."; 51 | Diagnostic::spanned( 52 | return_struct.span().unwrap(), 53 | proc_macro::Level::Error, 54 | err, 55 | ) 56 | .emit(); 57 | Err(err.to_string()) 58 | } 59 | }).collect::,_>>(); 60 | let return_idents = pass!(return_idents_res, NAME); 61 | let (ident, der) = return_idents.into_iter().unzip::<_, _, Vec<_>, Vec<_>>(); 62 | 63 | let new_return_stmt_str = format!( 64 | "return (({}),({}));", 65 | ident 66 | .into_iter() 67 | .intersperse(String::from(",")) 68 | .collect::(), 69 | der.into_iter() 70 | .intersperse(String::from(",")) 71 | .collect::(), 72 | ); 73 | let new_return_stmt = pass!(syn::parse_str(&new_return_stmt_str), NAME); 74 | Ok(new_return_stmt) 75 | } 76 | // If return expression is path e.g. `return a;` 77 | syn::Expr::Path(return_path) => { 78 | let path_ident = return_path.to_token_stream().to_string(); 79 | let rtn_ident = rtn!(0); 80 | let tuple_str = match function_input_identifiers.len() { 81 | 0 => String::new(), 82 | 1 => wrt!(function_input_identifiers[0], rtn_ident), 83 | _ => format!( 84 | "({})", 85 | function_input_identifiers 86 | .iter() 87 | .map(|input| wrt!(input, rtn_ident)) 88 | .intersperse(String::from(",")) 89 | .collect::() 90 | ), 91 | }; 92 | let new_return_stmt_str = format!("return ({},{});", path_ident, tuple_str); 93 | let new_return_stmt = pass!(syn::parse_str(&new_return_stmt_str), NAME); 94 | Ok(new_return_stmt) 95 | } 96 | syn::Expr::Lit(l) => { 97 | let new_return_stmt_str = format!( 98 | "return ({},0{});", 99 | l.to_token_stream(), 100 | literal_type(l).expect("Unsupported literal type") 101 | ); 102 | let new_return_stmt = pass!(syn::parse_str(&new_return_stmt_str), NAME); 103 | Ok(new_return_stmt) 104 | } 105 | _ => { 106 | let err = "Unsupported return expression"; 107 | Diagnostic::spanned( 108 | return_struct.span().unwrap(), 109 | proc_macro::Level::Error, 110 | err, 111 |
) 112 | .emit(); 113 | Err(format!("{}: {}", NAME, err)) 114 | } 115 | } 116 | } else { 117 | let err = "No return expression"; 118 | Diagnostic::spanned(return_struct.span().unwrap(), proc_macro::Level::Error, err) 119 | .emit(); 120 | Err(format!("{}: {}", NAME, err)) 121 | } 122 | } else { 123 | let err = "Not return statement"; 124 | Diagnostic::spanned(stmt.span().unwrap(), proc_macro::Level::Error, err).emit(); 125 | Err(format!("{}: {}", NAME, err)) 126 | } 127 | } 128 | /// Builds one `let` statement defining `wrt!(input, rtn!(i))` for every function input not already recorded in `return_derivatives[i]`, as the sum of its `wrtn!` components (or a typed zero when it has none); returns `None` when there is nothing to define. 129 | pub fn reverse_accumulate_inputs( 130 | function_inputs: &[String], 131 | component_map: &[HashMap>], 132 | type_map: &HashMap, 133 | return_derivatives: &[HashSet], 134 | ) -> Option { 135 | debug_assert_eq!(component_map.len(), return_derivatives.len()); 136 | 137 | let (inputs, derivative) = (0..component_map.len()) 138 | .filter_map(|index| { 139 | let rtn = rtn!(index); 140 | let (idents, derivatives) = function_inputs 141 | .iter() 142 | .filter_map(|input| { 143 | (!return_derivatives[index].contains(input)).then(|| { 144 | let ident = wrt!(input, rtn); 145 | 146 | let sum_str = match component_map[index].get(input) { 147 | Some(component_vec) => component_vec 148 | .iter() 149 | .map(|component| wrtn!(input, component, rtn)) 150 | .intersperse(String::from("+")) 151 | .collect::(), 152 | None => format!( 153 | "0{}", 154 | type_map 155 | .get(&rtn) 156 | .expect("reverse_accumulate_inputs: missed return") 157 | ), 158 | }; 159 | (ident, sum_str) 160 | }) 161 | }) 162 | .unzip::<_, _, Vec<_>, Vec<_>>(); 163 | match idents.len() { 164 | 0 => None, 165 | _ => Some((idents, derivatives)), 166 | } 167 | }) 168 | .unzip::<_, _, Vec<_>, Vec<_>>(); 169 | 170 | let stmt_str = match inputs.len() { 171 | 0 => String::new(), 172 | 1 => match inputs[0].len() { 173 | 0 => String::new(), 174 | 1 => format!("let {} = {};", inputs[0][0], derivative[0][0]), 175 | _ => format!( 176 | "let ({}) = ({});", 177 | inputs[0] 178 | .iter() 179 | .cloned() 180 | .intersperse(String::from(",")) 181 |
.collect::(), 182 | derivative[0] 183 | .iter() 184 | .cloned() 185 | .intersperse(String::from(",")) 186 | .collect::() 187 | ), 188 | }, 189 | _ => format!( 190 | "let ({}) = ({});", 191 | inputs 192 | .into_iter() 193 | .map(|i| format!( 194 | "({})", 195 | i.into_iter() 196 | .intersperse(String::from(",")) 197 | .collect::() 198 | )) 199 | .intersperse(String::from(",")) 200 | .collect::(), 201 | derivative 202 | .into_iter() 203 | .map(|i| format!( 204 | "({})", 205 | i.into_iter() 206 | .intersperse(String::from(",")) 207 | .collect::() 208 | )) 209 | .intersperse(String::from(",")) 210 | .collect::() 211 | ), 212 | }; 213 | (!stmt_str.is_empty()).then(|| { 214 | syn::parse_str(&stmt_str) 215 | .unwrap_or_else(|_| panic!("reverse_accumulate_inputs: parse fail `{}`", stmt_str)) 216 | }) 217 | } 218 | fn reverse_accumulate_derivative( 219 | stmt: &syn::Stmt, 220 | component_map: &[HashMap>], 221 | return_derivatives: &mut Vec>, 222 | ) -> Result, PassError> { 223 | debug_assert_eq!(component_map.len(), return_derivatives.len()); 224 | const NAME: &str = "reverse_accumulate_derivative"; 225 | match stmt { 226 | // If we have a local variable declaration statement e.g. `let a;`.
227 | syn::Stmt::Local(local) => match &local.pat { 228 | syn::Pat::Ident(local_ident) => { 229 | let ident_str = local_ident.to_token_stream().to_string(); 230 | let (accumulative_derivatives, derivative_sums) = (0..component_map.len()) 231 | .filter_map(|index| { 232 | component_map[index].get(&ident_str).map(|components| { 233 | let rtn = rtn!(index); 234 | let acc = wrt!(ident_str, rtn); 235 | // Inserting here, notes that now we have a derivative for `ident_str` affecting `rtn!(index)` 236 | return_derivatives[index].insert(ident_str.clone()); 237 | ( 238 | acc, 239 | components 240 | .iter() 241 | .map(|d| wrtn!(ident_str, d, rtn)) 242 | .collect::>(), 243 | ) 244 | }) 245 | }) 246 | .unzip::<_, _, Vec<_>, Vec<_>>(); 247 | 248 | // equivalent to `derivative_sums.len()` 249 | let rtn_str = match accumulative_derivatives.len() { 250 | 0 => Ok(None), 251 | 1 => match derivative_sums[0].len() { 252 | 0 => unreachable!(), 253 | 1 => Ok(Some(format!( 254 | "let {} = {};", 255 | accumulative_derivatives[0], derivative_sums[0][0] 256 | ))), 257 | _ => Ok(Some(format!( 258 | "let {} = ({});", 259 | accumulative_derivatives[0], 260 | derivative_sums[0] 261 | .iter() 262 | .cloned() 263 | .intersperse(String::from("+")) 264 | .collect::(), 265 | ))), 266 | }, 267 | _ => Ok(Some(format!( 268 | "let ({}) = ({});", 269 | accumulative_derivatives 270 | .into_iter() 271 | .intersperse(String::from(",")) 272 | .collect::(), 273 | derivative_sums 274 | .into_iter() 275 | .map(|d| format!( 276 | "({})", 277 | d.into_iter() 278 | .intersperse(String::from("+")) 279 | .collect::() 280 | )) 281 | .intersperse(String::from(",")) 282 | .collect::(), 283 | ))), 284 | }; 285 | rtn_str.map(|res| { 286 | res.map(|opt| { 287 | syn::parse_str(&opt).expect("reverse_accumulate_derivative: parse fail") 288 | }) 289 | }) 290 | } 291 | _ => { 292 | let err = "Unsupported local declaration type. Only path declarations are supported (e.g. `let a = ...
;`)"; 293 | Diagnostic::spanned(local.span().unwrap(), proc_macro::Level::Error, err).emit(); 294 | Err(format!("{}: {}", NAME, err)) 295 | } 296 | }, 297 | _ => Ok(None), 298 | } 299 | } 300 | 301 | pub fn reverse_derivative( 302 | stmt: &syn::Stmt, 303 | type_map: &HashMap, 304 | component_map: &mut Vec>>, 305 | return_derivatives: &mut Vec>, 306 | ) -> Result, PassError> { 307 | const NAME: &str = "reverse_derivative"; 308 | match stmt { 309 | // If we have a local variable declaration e.g. `let a;` 310 | syn::Stmt::Local(local_stmt) => { 311 | let local_ident = local_stmt.pat.to_token_stream().to_string(); 312 | match &local_stmt.init { 313 | // If there is some initialization e.g. `let a = ... ;` 314 | Some((_, init)) => { 315 | match &**init { 316 | // If we have local variable declaration with a binary expression as initialization e.g. `let a = b + c;`. 317 | syn::Expr::Binary(bin_init_expr) => { 318 | // Accumulate derivatives for multiplying by components 319 | let accumulation_stmt_opt = pass!( 320 | reverse_accumulate_derivative( 321 | stmt, 322 | component_map, 323 | return_derivatives 324 | ), 325 | NAME 326 | ); 327 | let mut rtn_vec = vec![accumulation_stmt_opt]; 328 | // if let Some(accumulation_stmt) = accumulation_stmt_opt { 329 | // rtn_vec.push(accumulation_stmt); 330 | // } 331 | // Create binary operation signature (formed with the lhs type, rhs type and operation symbol (`+`, `-` etc.)). 332 | let op_sig = pass!(operation_signature(bin_init_expr, type_map), NAME); 333 | // Looks up binary operation of the formed signature in our supported operations map. 334 | match SUPPORTED_OPERATIONS.get(&op_sig) { 335 | // If we find an entry for an output signature, this means the operation is supported. 336 | // Applies the reverse derivative function for the found operation.
337 | Some(out_sig) => { 338 | rtn_vec.push((out_sig.reverse_derivative)( 339 | local_ident, 340 | &[ 341 | pass!(Arg::try_from(&*bin_init_expr.left), NAME), 342 | pass!(Arg::try_from(&*bin_init_expr.right), NAME), 343 | ], 344 | component_map, 345 | return_derivatives, 346 | )); 347 | } 348 | // If we don't find an entry, this means the operation is not supported. 349 | None => { 350 | // Since we do not support this operation and without considering it the whole process will not be accurate, we throw an error. 351 | let err = format!("Unsupported operation: {}", op_sig); 352 | Diagnostic::spanned( 353 | bin_init_expr.span().unwrap(), 354 | Error, 355 | err.clone(), 356 | ) 357 | .emit(); 358 | return Err(format!("{}: {}", NAME, err)); 359 | } 360 | } 361 | Ok(rtn_vec.into_iter().flatten().collect::>()) 362 | } 363 | // If we have local variable declaration with a function call expression as initialization e.g. `let a = f(b,c);`. 364 | syn::Expr::Call(call_init_expr) => { 365 | // Accumulate derivatives for multiplying by components 366 | let accumulation_stmt_opt = pass!( 367 | reverse_accumulate_derivative( 368 | stmt, 369 | component_map, 370 | return_derivatives 371 | ), 372 | NAME 373 | ); 374 | let mut rtn_vec = vec![accumulation_stmt_opt]; 375 | // Create function signature (formed with function identifier and argument types) 376 | let fn_sig = pass!(function_signature(call_init_expr, type_map), NAME); 377 | // Looks up function of our formed function signature in our supported functions map. 378 | match SUPPORTED_FUNCTIONS.get(&fn_sig) { 379 | // If we find an entry for an output signature, this means the function is supported. 380 | Some(out_sig) => { 381 | // Collects arguments 382 | let args = pass!( 383 | call_init_expr 384 | .args 385 | .iter() 386 | .map(Arg::try_from) 387 | .collect::, _>>(), 388 | NAME 389 | ); 390 | // Applies the reverse derivative function for the found function. 
391 | let new_stmt = (out_sig.reverse_derivative)( 392 | local_ident, 393 | args.as_slice(), 394 | component_map, 395 | return_derivatives, 396 | ); 397 | rtn_vec.push(new_stmt); 398 | } 399 | // If we don't find an entry, this means the function is not supported. 400 | None => { 401 | // Since we do not support this function and without considering it the whole process will not be accurate, we throw an error. 402 | let err = format!("Unsupported function: {}", fn_sig); 403 | Diagnostic::spanned( 404 | call_init_expr.span().unwrap(), 405 | Error, 406 | err.clone(), 407 | ) 408 | .emit(); 409 | return Err(format!("{}: {}", NAME, err)); 410 | } 411 | } 412 | Ok(rtn_vec.into_iter().flatten().collect::>()) 413 | } 414 | // If we have local variable declaration with a method call expression as initialization e.g. `let a = b.f(c);`. 415 | syn::Expr::MethodCall(method_init_expr) => { 416 | // Accumulate derivatives for multiplying by components 417 | let accumulation_stmt_opt = pass!( 418 | reverse_accumulate_derivative( 419 | stmt, 420 | component_map, 421 | return_derivatives 422 | ), 423 | NAME 424 | ); 425 | let mut rtn_vec = vec![accumulation_stmt_opt]; 426 | // Create function signature (formed with function identifier and argument types) 427 | let mt_sig = pass!(method_signature(method_init_expr, type_map), NAME); 428 | // Looks up function of our formed function signature in our supported functions map. 429 | match SUPPORTED_METHODS.get(&mt_sig) { 430 | // If we find an entry for an output signature, this means the function is supported. 431 | Some(out_sig) => { 432 | // Collects arguments 433 | let mut args = pass!( 434 | method_init_expr 435 | .args 436 | .iter() 437 | .map(Arg::try_from) 438 | .collect::, _>>(), 439 | NAME 440 | ); 441 | // Inserts receiver argument as first argument (the receiver argument is the respective `self` in `let a = b.f(c)` it would be `b`). 
442 | let receiver = 443 | pass!(Arg::try_from(&*method_init_expr.receiver), NAME); 444 | args.insert(0, receiver); 445 | // Applies the reverse derivative function for the found function. 446 | let new_stmt = (out_sig.reverse_derivative)( 447 | local_ident, 448 | args.as_slice(), 449 | component_map, 450 | return_derivatives, 451 | ); 452 | rtn_vec.push(new_stmt); 453 | } 454 | // If we don't find an entry, this means the method is not supported. 455 | None => { 456 | // Since we do not support this method and without considering it the whole process will not be accurate, we throw an error. 457 | let err = format!("Unsupported method: {}", mt_sig); 458 | Diagnostic::spanned( 459 | method_init_expr.span().unwrap(), 460 | Error, 461 | err.clone(), 462 | ) 463 | .emit(); 464 | return Err(format!("{}: {}", NAME, err)); 465 | } 466 | } 467 | Ok(rtn_vec.into_iter().flatten().collect::>()) 468 | } 469 | // If we have local variable declaration with an assignment expression as initialization e.g. `let a = b;`. 470 | syn::Expr::Path(path_init_expr) => { 471 | // Variable being assigned (e.g. `b`). 
472 | let in_ident = path_init_expr.to_token_stream().to_string(); 473 | 474 | let (ident_str, der_str) = (0..component_map.len()) 475 | .filter_map(|index| { 476 | let rtn = rtn!(index); 477 | let from_wrt = wrt!(local_ident, rtn); 478 | let to_wrt = wrt!(&in_ident, rtn); 479 | 480 | // If component exists 481 | match component_map[index] 482 | .get(&local_ident) 483 | .map(|e| e.contains(&rtn)) 484 | { 485 | Some(true) => { 486 | return_derivatives[index].insert(in_ident.clone()); 487 | Some((to_wrt, from_wrt)) 488 | } 489 | _ => None, 490 | } 491 | }) 492 | .unzip::<_, _, Vec, Vec>(); 493 | 494 | // equivalent to `der_str.len()` 495 | let stmt_str = match ident_str.len() { 496 | 0 => None, 497 | 1 => Some(format!("let {} = {};", ident_str[0], der_str[0])), 498 | _ => Some(format!( 499 | "let ({}) = ({});", 500 | ident_str 501 | .into_iter() 502 | .intersperse(String::from(",")) 503 | .collect::(), 504 | der_str 505 | .into_iter() 506 | .intersperse(String::from(",")) 507 | .collect::(), 508 | )), 509 | }; 510 | Ok(match stmt_str { 511 | Some(s) => { 512 | vec![syn::parse_str(&s).expect("blah blah blah parse fail")] 513 | } 514 | None => Vec::new(), 515 | }) 516 | } 517 | _ => Ok(Vec::new()), 518 | } 519 | } 520 | None => Ok(Vec::new()), 521 | } 522 | } 523 | // If we have a return statement e.g. `return (a,b);` 524 | // TODO If return statement we need to set the accumulative derivative of the return component as the input return derivatives. 525 | syn::Stmt::Semi(syn::Expr::Return(return_struct), _) => { 526 | match &return_struct.expr { 527 | // If there is some return expression e.g. `return (a,b);` 528 | Some(return_expr) => match &**return_expr { 529 | // If return expression is tuple e.g. 
`return (a,b);` 530 | syn::Expr::Tuple(return_tuple) => { 531 | let return_idents_res = return_tuple.elems.iter().enumerate().filter_map(|(index,e)| match e { 532 | syn::Expr::Path(p) => { 533 | let path_ident = p.to_token_stream().to_string(); 534 | let rtn_ident = rtn!(index); 535 | 536 | return_derivatives[index].insert(path_ident.clone()); 537 | 538 | let (ident,der) = (wrt!(path_ident,rtn_ident),rtn_ident); 539 | Some(Ok((ident,der))) 540 | }, 541 | syn::Expr::Lit(_) => None, 542 | _ => { 543 | let err = "Unsupported return tuple element. Elements in a returned tuple must be paths or literals (e.g. `return (a,b,2f32))` is supported, `return (a,(b,c))` is not supported)."; 544 | Diagnostic::spanned( 545 | return_struct.span().unwrap(), 546 | proc_macro::Level::Error, 547 | err, 548 | ) 549 | .emit(); 550 | Some(Err(err.to_string())) 551 | } 552 | }).collect::,_>>(); 553 | let return_idents = pass!(return_idents_res, NAME); 554 | let (ident, der) = 555 | return_idents.into_iter().unzip::<_, _, Vec<_>, Vec<_>>(); 556 | 557 | let new_return_stmt_str = format!( 558 | "let ({}) = ({});", 559 | ident 560 | .into_iter() 561 | .intersperse(String::from(",")) 562 | .collect::(), 563 | der.into_iter() 564 | .intersperse(String::from(",")) 565 | .collect::(), 566 | ); 567 | 568 | let new_return_stmt = pass!(syn::parse_str(&new_return_stmt_str), NAME); 569 | Ok(vec![new_return_stmt]) 570 | } 571 | // If return expression is path e.g. `return a;` 572 | syn::Expr::Path(return_path) => { 573 | let path_ident = return_path.to_token_stream().to_string(); 574 | let rtn_ident = rtn!(0); 575 | 576 | return_derivatives[0].insert(path_ident.clone()); 577 | 578 | let new_stmt_str = 579 | format!("let {} = {};", wrt!(path_ident, rtn_ident), rtn_ident); 580 | let new_stmt = pass!(syn::parse_str(&new_stmt_str), NAME); 581 | Ok(vec![new_stmt]) 582 | } 583 | syn::Expr::Lit(_) => Ok(Vec::new()), 584 | _ => { 585 | let err = "Unsupported return type. Only tuples (e.g. 
`return (a,b,c);`), paths (e.g. `return a;`) and literals (e.g. `return 5f32;`) are supported."; 586 | Diagnostic::spanned( 587 | return_struct.span().unwrap(), 588 | proc_macro::Level::Error, 589 | err, 590 | ) 591 | .emit(); 592 | panic!("{}", err); 593 | } 594 | }, 595 | // If there is no return expression e.g. `return;` 596 | None => Ok(Vec::new()), 597 | } 598 | } 599 | _ => Ok(Vec::new()), 600 | } 601 | } 602 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! A restrictive WIP beginnings of a library attempting to implement auto-differentiation in Rust. 2 | //! 3 | //! ### Status 4 | //! 5 | //! - [x] Forward Auto-differentiation 6 | //! - [x] Reverse Auto-differentiation 7 | //! - [x] `f32` & `f64` support* 8 | //! - [ ] limited° [ndarray](https://github.com/rust-ndarray/ndarray) support* 9 | //! - [ ] limited° [nalgebra](https://docs.rs/nalgebra/latest/nalgebra/) support* 10 | //! - [ ] `if`, `if else` and `else` support 11 | //! - [ ] `for`, `while` and `loop` support 12 | //! 13 | //! *`typeof` (e.g. [`decltype`](https://en.cppreference.com/w/cpp/language/decltype)) not being currently implemented in Rust makes support more difficult. 14 | //! 15 | //! °Support limited to the basic blas-like operations. 16 | //! 17 | //! Type/s | Support 18 | //! --- | --- 19 | //! Floats: `f32` & `f64` | `+`, `-`, `*`, `/` and most methods (e.g. `powf`). 20 | //! Integers: `u16`, `i16` etc. | `+`, `-`, `*` and `/` 21 | //! 22 | //! For the specifics of operation support see the [rust-ad-core docs](https://docs.rs/rust-ad-core/). 23 | //! 24 | //! 25 | //! ### Multi-variate output format 26 | //! 27 | //! For some code 28 | //! ``` 29 | //! assert_eq!( 30 | //! ((8f32, 16f32), ((1f32, 1f32), (2f32, 2f32))), 31 | //! rust_ad::forward!(tuple_function, 3f32, 5f32) 32 | //! ); 33 | //! #[rust_ad::forward_autodiff] 34 | //! 
fn tuple_function(x: f32, y: f32) -> (f32, f32) { 35 | //! let a1 = x + y; 36 | //! let b = 2f32 * a1; 37 | //! return (a1, b); 38 | //! } 39 | //! ``` 40 | //! - `(8f32, 16f32)` is the output value (`(a1, b)`). 41 | //! - `(1f32, 1f32)` are the derivatives of the first return element (`a1`) with respect to each input: the first is `a1`'s derivative with respect to `x` and the second is `a1`'s derivative with respect to `y`. 42 | //! - `(2f32, 2f32)` are the derivatives of the second return element (`b`) with respect to `x` and `y` respectively. 43 | //! 44 | //! 45 | //! ### Resources 46 | //! - [Automatic Differentiation in Rust](https://github.com/JonathanWoollett-Light/autodiff-book) 47 | //! - [automatic-differentiation-worked-examples](http://h2.jaguarpaw.co.uk/posts/automatic-differentiation-worked-examples/) 48 | pub use rust_ad_macros::*; 49 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code)] 2 | extern crate rust_ad; 3 | #[allow(unused_imports)] 4 | use rust_ad::{forward, forward_autodiff, reverse, reverse_autodiff, unweave}; 5 | 6 | // To run in: 7 | // - debug: `cargo rustc --bin rust-ad --profile=dev -- -Zunpretty=expanded` 8 | // - release: `cargo rustc --bin rust-ad --profile=release -- -Zunpretty=expanded` 9 | // (I think `cargo expand --bin rust-ad` just does debug) 10 | 11 | // #[forward_autodiff] 12 | // fn empty(x: f32) -> f32 { 13 | // return x; 14 | // } 15 | 16 | // #[forward_autodiff] 17 | // fn plus(x: f32) -> f32 { 18 | // return x + 1f32; 19 | // } 20 | 21 | // #[forward_autodiff] 22 | // fn quad(x: f32) -> f32 { 23 | // let a = x.powi(2i32); 24 | // let b = x * 2f32; 25 | // let c = 2f32; 26 | // let f = a + b + c; 27 | // return f; 28 | // } 29 | 30 | // #[forward_autodiff] 31 | // fn multi(x: f32, y: f32) -> f32 { 32 | // let a = x.powi(2i32); 33 | // let b = x * 2f32; 34 | // let c = 2f32 / y; 35 | // let
f = a + b + c; 36 | // return f; 37 | // } 38 | 39 | // #[forward_autodiff] 40 | // fn complex(x: f32, y: f32, z: f32) -> f32 { 41 | // let a = x.powi(2i32); 42 | // let b = x * 2f32 / z; 43 | // let c = 2f32 / (z.sqrt() + y); 44 | // let f = a + b + c; 45 | // return f; 46 | // } 47 | 48 | // #[reverse_autodiff] 49 | // fn powi_fn(x: f32, y: f32) -> f32 { 50 | // let a = x.powi(2i32); 51 | // let b = x * 2f32 * a; 52 | // let c = 2f32 / y; 53 | // let f = a + b + c; 54 | // return f; 55 | // } 56 | 57 | fn main() { 58 | // assert_eq!( 59 | // ((8f32, 16f32), ((1f32, 1f32), (2f32, 2f32))), 60 | // forward!(tuple_function, 3f32, 5f32) 61 | // ); 62 | } 63 | -------------------------------------------------------------------------------- /tests/forward.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code)] 2 | use rust_ad::*; 3 | const TOLERANCE: f32 = 0.01; 4 | pub fn is_near(a: f32, b: f32) -> Result<(), String> { // Ok when |a - b| < TOLERANCE, otherwise Err naming both values 5 | if (a - b).abs() < TOLERANCE { 6 | Ok(()) 7 | } else { 8 | Err(format!("{} is not near {}", a, b)) 9 | } 10 | } 11 | 12 | #[test] 13 | fn powi_test() { 14 | let (f, (der_x, der_y)) = forward!(powi_fn, 3f32, 5f32); 15 | assert_eq!(f, 63.4f32); 16 | assert_eq!(der_x, 60f32); 17 | assert_eq!(der_y, -0.08f32); 18 | 19 | /// Equations: 20 | /// - f = x^2 + 2x^3 + 2/y 21 | /// - ∂x|y=5 = 2x(1+3x) 22 | /// - ∂y|x=3 = -2/y^2 23 | /// Values: 24 | /// - f(3,5) = 9 + 54 + 0.4 = 63.4 25 | /// - ∂x|y=5(3) = 60 26 | /// - ∂y|x=3(5) = -0.08 27 | #[forward_autodiff] 28 | fn powi_fn(x: f32, y: f32) -> f32 { 29 | let a = x.powi(2i32); 30 | let b = x * 2f32 * a; 31 | let c = 2f32 / y; 32 | let f = a + b + c; 33 | return f; 34 | } 35 | } 36 | #[test] 37 | fn powf_test() { 38 | let (f, (der_x, der_y)) = forward!(powf_fn, 3f32, 5f32); 39 | assert_eq!(f, 63.4f32); 40 | assert_eq!(der_x, 60f32); 41 | assert_eq!(der_y, -0.08f32); 42 | 43 | /// Equations: 44 | /// - f = x^2 + 2x^3 + 2/y 45 | /// - ∂x|y=5 = 2x(1+3x) 46 | 
/// - ∂y|x=3 = -2/y^2 47 | /// Values: 48 | /// - f(3,5) = 9 + 54 + 0.4 = 63.4 49 | /// - ∂x|y=5(3) = 60 50 | /// - ∂y|x=3(5) = -0.08 51 | #[forward_autodiff] 52 | fn powf_fn(x: f32, y: f32) -> f32 { 53 | let a = x.powf(2f32); 54 | let b = x * 2f32 * a; 55 | let c = 2f32 / y; 56 | let f = a + b + c; 57 | return f; 58 | } 59 | } 60 | #[test] 61 | fn sqrt_test() { 62 | let (f, (der_x, der_y)) = forward!(sqrt_fn, 3f32, 5f32); 63 | is_near(f, 12.524355653f32).unwrap(); 64 | is_near(der_x, 5.4848275573f32).unwrap(); 65 | is_near(der_y, -0.08f32).unwrap(); 66 | 67 | /// Equations: 68 | /// - f = x^0.5 + 2x*x^0.5 + 2/y 69 | /// - ∂x|y=5 = (6x+1)/(2x^0.5) 70 | /// - ∂y|x=3 = -2/y^2 71 | /// Values: 72 | /// - f(3,5) = 12.524355653 73 | /// - ∂x|y=5(3) = 5.4848275573 74 | /// - ∂y|x=3(5) = -0.08 75 | #[forward_autodiff] 76 | fn sqrt_fn(x: f32, y: f32) -> f32 { 77 | let a = x.sqrt(); 78 | let b = x * 2f32 * a; 79 | let c = 2f32 / y; 80 | let f = a + b + c; 81 | return f; 82 | } 83 | } 84 | #[test] 85 | fn ln_test() { 86 | let (f, (der_x, der_y)) = forward!(ln_fn, 3f32, 5f32); 87 | is_near(f, 8.09028602068f32).unwrap(); 88 | is_near(der_x, 4.53055791067f32).unwrap(); 89 | is_near(der_y, -0.08f32).unwrap(); 90 | 91 | /// Equations: 92 | /// - f = ln(x) + 2x*ln(x)+ 2/y 93 | /// - ∂x|y=5 = (1/x) + 2*ln(x)+2 94 | /// - ∂y|x=3 = -2/y^2 95 | /// Values: 96 | /// - f(3,5) = 8.09028602068 97 | /// - ∂x|y=5(3) = 4.53055791067 98 | /// - ∂y|x=3(5) = -0.08 99 | #[forward_autodiff] 100 | fn ln_fn(x: f32, y: f32) -> f32 { 101 | let a = x.ln(); 102 | let b = x * 2f32 * a; 103 | let c = 2f32 / y; 104 | let f = a + b + c; 105 | return f; 106 | } 107 | } 108 | #[test] 109 | fn log_test() { 110 | let (f, (der_x, der_y)) = forward!(log_fn, 3f32, 5f32); 111 | is_near(f, 11.494737505f32).unwrap(); 112 | is_near(der_x, 6.53621343018f32).unwrap(); 113 | is_near(der_y, -0.08f32).unwrap(); 114 | 115 | /// Equations: 116 | /// - f = log2(x) + 2x*log2(x)+ 2/y 117 | /// - ∂x|y=5 = ( 2x + 2x*ln(x)+1 ) / 
(x*ln(2)) 118 | /// - ∂y|x=3 = -2/y^2 119 | /// Values: 120 | /// - f(3,5) = 11.494737505 121 | /// - ∂x|y=5(3) = 6.53621343018 122 | /// - ∂y|x=3(5) = -0.08 123 | #[forward_autodiff] 124 | fn log_fn(x: f32, y: f32) -> f32 { 125 | let a = x.log(2f32); 126 | let b = x * 2f32 * a; 127 | let c = 2f32 / y; 128 | let f = a + b + c; 129 | return f; 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /tests/forward_general.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code)] 2 | use rust_ad::*; 3 | const TOLERANCE: f32 = 0.01; 4 | pub fn is_near(a: f32, b: f32) -> Result<(), String> { // Ok when |a - b| < TOLERANCE, otherwise Err naming both values 5 | if (a - b).abs() < TOLERANCE { 6 | Ok(()) 7 | } else { 8 | Err(format!("{} is not near {}", a, b)) 9 | } 10 | } 11 | 12 | #[test] 13 | fn empty_test() { 14 | let (x, der_x) = forward!(empty, 1f32); 15 | assert_eq!(x, 1.); 16 | assert_eq!(der_x, 1.); 17 | 18 | #[forward_autodiff] 19 | fn empty(x: f32) -> f32 { 20 | return x; 21 | } 22 | } 23 | #[test] 24 | fn plus_test() { 25 | let (x, der_x) = forward!(plus, 1f32); 26 | assert_eq!(x, 2f32); 27 | assert_eq!(der_x, 1f32); 28 | 29 | #[forward_autodiff] 30 | fn plus(x: f32) -> f32 { 31 | return x + 1f32; 32 | } 33 | } 34 | #[test] 35 | fn quad_test() { 36 | let (x, der_x) = forward!(quad, 3f32); 37 | assert_eq!(x, 17f32); 38 | assert_eq!(der_x, 8f32); 39 | 40 | #[forward_autodiff] 41 | fn quad(x: f32) -> f32 { 42 | let a = x.powi(2i32); 43 | let b = x * 2f32; 44 | let c = 2f32; 45 | let f = a + b + c; 46 | return f; 47 | } 48 | } 49 | #[test] 50 | fn multi_test() { 51 | let (f, (der_x, der_y)) = forward!(multi, 3f32, 5f32); 52 | assert_eq!(f, 15.4f32); 53 | assert_eq!(der_x, 8f32); 54 | assert_eq!(der_y, -0.08f32); 55 | 56 | /// f = x^2 + 2x + 2/y 57 | /// δx|y=5 = 2x + 2 58 | /// δy|x=3 = -2/y^2 59 | #[forward_autodiff] 60 | fn multi(x: f32, y: f32) -> f32 { 61 | let a = x.powi(2i32); 62 | let b = x * 2f32; 63 | let c 
= 2f32 / y; 64 | let
f = a + b + c; 65 | return f; 66 | } 67 | } 68 | 69 | #[test] 70 | fn complex_test() { 71 | let (f, (der_x, der_y, der_z)) = forward!(complex, 3f32, 5f32, 7f32); 72 | is_near(f, 10.1187260448).unwrap(); 73 | is_near(der_x, 6.28571428571).unwrap(); 74 | is_near(der_y, -0.034212882033).unwrap(); 75 | is_near(der_z, -0.128914606556).unwrap(); 76 | 77 | // f(x,y,z) = x^2 + 2x/z + 2/(y+z^0.5) 78 | // ∂x = 2(x+1/z) 79 | // ∂y = -2 / (y+z^0.5)^2 80 | // ∂z = -2x/z^2 -1/(z^0.5 * (y+z^0.5)^2) 81 | // Therefore: 82 | // f(3,5,7) = 10.1187260448 83 | // ∂x| = 6.28571428571 84 | // ∂y| = −0.034212882033 85 | // ∂z| = −0.128914606556 86 | #[forward_autodiff] 87 | fn complex(x: f32, y: f32, z: f32) -> f32 { 88 | let a = x.powi(2i32); 89 | let b = x * 2f32 / z; 90 | let c = 2f32 / (z.sqrt() + y); 91 | let f = a + b + c; 92 | return f; 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /tests/reverse.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code)] 2 | use rust_ad::*; 3 | const TOLERANCE: f32 = 0.01; 4 | pub fn is_near(a: f32, b: f32) -> Result<(), String> { // Ok when |a - b| < TOLERANCE, otherwise Err naming both values 5 | if (a - b).abs() < TOLERANCE { 6 | Ok(()) 7 | } else { 8 | Err(format!("{} is not near {}", a, b)) 9 | } 10 | } 11 | 12 | #[test] 13 | fn powi_test() { 14 | let (f, (der_x, der_y)) = reverse!(powi_fn, (3f32, 5f32), (1f32)); 15 | assert_eq!(f, 63.4f32); 16 | assert_eq!(der_x, 60f32); 17 | assert_eq!(der_y, -0.08f32); 18 | 19 | /// Equations: 20 | /// - f = x^2 + 2x^3 + 2/y 21 | /// - ∂x|y=5 = 2x(1+3x) 22 | /// - ∂y|x=3 = -2/y^2 23 | /// Values: 24 | /// - f(3,5) = 9 + 54 + 0.4 = 63.4 25 | /// - ∂x|y=5(3) = 60 26 | /// - ∂y|x=3(5) = -0.08 27 | #[reverse_autodiff] 28 | fn powi_fn(x: f32, y: f32) -> f32 { 29 | let a = x.powi(2i32); 30 | let b = x * 2f32 * a; 31 | let c = 2f32 / y; 32 | let f = a + b + c; 33 | return f; 34 | } 35 | } 36 | #[test] 37 | fn powf_test() { 38 | let 
(f, (der_x, der_y)) = reverse!(powf_fn,
(3f32, 5f32), (1f32)); 39 | assert_eq!(f, 63.4f32); 40 | assert_eq!(der_x, 60f32); 41 | assert_eq!(der_y, -0.08f32); 42 | 43 | /// Equations: 44 | /// - f = x^2 + 2x^3 + 2/y 45 | /// - ∂x|y=5 = 2x(1+3x) 46 | /// - ∂y|x=3 = -2/y^2 47 | /// Values: 48 | /// - f(3,5) = 9 + 54 + 0.4 = 63.4 49 | /// - ∂x|y=5(3) = 60 50 | /// - ∂y|x=3(5) = -0.08 51 | #[reverse_autodiff] 52 | fn powf_fn(x: f32, y: f32) -> f32 { 53 | let a = x.powf(2f32); 54 | let b = x * 2f32 * a; 55 | let c = 2f32 / y; 56 | let f = a + b + c; 57 | return f; 58 | } 59 | } 60 | #[test] 61 | fn sqrt_test() { 62 | let (f, (der_x, der_y)) = reverse!(sqrt_fn, (3f32, 5f32), (1f32)); 63 | is_near(f, 12.524355653f32).unwrap(); 64 | is_near(der_x, 5.4848275573f32).unwrap(); 65 | is_near(der_y, -0.08f32).unwrap(); 66 | 67 | /// Equations: 68 | /// - f = x^0.5 + 2x*x^0.5 + 2/y 69 | /// - ∂x|y=5 = (6x+1)/(2x^0.5) 70 | /// - ∂y|x=3 = -2/y^2 71 | /// Values: 72 | /// - f(3,5) = 12.524355653 73 | /// - ∂x|y=5(3) = 5.4848275573 74 | /// - ∂y|x=3(5) = -0.08 75 | #[reverse_autodiff] 76 | fn sqrt_fn(x: f32, y: f32) -> f32 { 77 | let a = x.sqrt(); 78 | let b = x * 2f32 * a; 79 | let c = 2f32 / y; 80 | let f = a + b + c; 81 | return f; 82 | } 83 | } 84 | #[test] 85 | fn ln_test() { 86 | let (f, (der_x, der_y)) = reverse!(ln_fn, (3f32, 5f32), (1f32)); 87 | is_near(f, 8.09028602068f32).unwrap(); 88 | is_near(der_x, 4.53055791067f32).unwrap(); 89 | is_near(der_y, -0.08f32).unwrap(); 90 | 91 | /// Equations: 92 | /// - f = ln(x) + 2x*ln(x)+ 2/y 93 | /// - ∂x|y=5 = (1/x) + 2*ln(x)+2 94 | /// - ∂y|x=3 = -2/y^2 95 | /// Values: 96 | /// - f(3,5) = 8.09028602068 97 | /// - ∂x|y=5(3) = 4.53055791067 98 | /// - ∂y|x=3(5) = -0.08 99 | #[reverse_autodiff] 100 | fn ln_fn(x: f32, y: f32) -> f32 { 101 | let a = x.ln(); 102 | let b = x * 2f32 * a; 103 | let c = 2f32 / y; 104 | let f = a + b + c; 105 | return f; 106 | } 107 | } 108 | #[test] 109 | fn log_test() { 110 | let (f, (der_x, der_y)) = reverse!(log_fn, (3f32, 5f32), (1f32)); 111 | 
is_near(f, 11.494737505f32).unwrap(); 112 | is_near(der_x, 6.53621343018f32).unwrap(); 113 | is_near(der_y, -0.08f32).unwrap(); 114 | 115 | /// Equations: 116 | /// - f = log2(x) + 2x*log2(x)+ 2/y 117 | /// - ∂x|y=5 = ( 2x + 2x*ln(x)+1 ) / (x*ln(2)) 118 | /// - ∂y|x=3 = -2/y^2 119 | /// Values: 120 | /// - f(3,5) = 11.494737505 121 | /// - ∂x|y=5(3) = 6.53621343018 122 | /// - ∂y|x=3(5) = -0.08 123 | #[reverse_autodiff] 124 | fn log_fn(x: f32, y: f32) -> f32 { 125 | let a = x.log(2f32); 126 | let b = x * 2f32 * a; 127 | let c = 2f32 / y; 128 | let f = a + b + c; 129 | return f; 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /tests/reverse_general.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code)] 2 | use rust_ad::*; 3 | const TOLERANCE: f32 = 0.01; 4 | pub fn is_near(a: f32, b: f32) -> Result<(), String> { // Ok when |a - b| < TOLERANCE, otherwise Err naming both values 5 | if (a - b).abs() < TOLERANCE { 6 | Ok(()) 7 | } else { 8 | Err(format!("{} is not near {}", a, b)) 9 | } 10 | } 11 | 12 | #[test] 13 | fn empty_test() { 14 | let (x, der_x) = reverse!(empty, (1f32), (1f32)); 15 | assert_eq!(x, 1.); 16 | assert_eq!(der_x, 1.); 17 | 18 | #[reverse_autodiff] 19 | fn empty(x: f32) -> f32 { 20 | return x; 21 | } 22 | } 23 | #[test] 24 | fn plus_test() { 25 | let (x, der_x) = reverse!(plus, (1f32), (1f32)); 26 | assert_eq!(x, 2f32); 27 | assert_eq!(der_x, 1f32); 28 | 29 | #[reverse_autodiff] 30 | fn plus(x: f32) -> f32 { 31 | return x + 1f32; 32 | } 33 | } 34 | #[test] 35 | fn quad_test() { 36 | let (x, der_x) = reverse!(quad, (3f32), (1f32)); 37 | assert_eq!(x, 17f32); 38 | assert_eq!(der_x, 8f32); 39 | 40 | #[reverse_autodiff] 41 | fn quad(x: f32) -> f32 { 42 | let a = x.powi(2i32); 43 | let b = x * 2f32; 44 | let c = 2f32; 45 | let f = a + b + c; 46 | return f; 47 | } 48 | } 49 | #[test] 50 | fn multi_test() { 51 | let (f, (der_x, der_y)) = reverse!(multi, (3f32, 5f32), (1f32)); 52 | 
assert_eq!(f, 15.4f32); 53 |
assert_eq!(der_x, 8f32); 54 | assert_eq!(der_y, -0.08f32); 55 | 56 | #[reverse_autodiff] 57 | fn multi(x: f32, y: f32) -> f32 { 58 | let a = x.powi(2i32); 59 | let b = x * 2f32; 60 | let c = 2f32 / y; 61 | let f = a + b + c; 62 | return f; 63 | } 64 | } 65 | #[test] 66 | fn complex_test() { 67 | let (f, (der_x, der_y, der_z)) = reverse!(complex, (3f32, 5f32, 7f32), (1f32)); 68 | is_near(f, 10.1187260448).unwrap(); 69 | is_near(der_x, 6.28571428571).unwrap(); 70 | is_near(der_y, -0.034212882033).unwrap(); 71 | is_near(der_z, -0.128914606556).unwrap(); 72 | 73 | // f(x,y,z) = x^2 + 2x/z + 2/(y+z^0.5) 74 | // ∂x = 2(x+1/z) 75 | // ∂y = -2 / (y+z^0.5)^2 76 | // ∂z = -2x/z^2 -1/(z^0.5 * (y+z^0.5)^2) 77 | // Therefore: 78 | // f(3,5,7) = 10.1187260448 79 | // ∂x| = 6.28571428571 80 | // ∂y| = −0.034212882033 81 | // ∂z| = −0.128914606556 82 | #[reverse_autodiff] 83 | fn complex(x: f32, y: f32, z: f32) -> f32 { 84 | let a = x.powi(2i32); 85 | let b = x * 2f32 / z; 86 | let c = 2f32 / (z.sqrt() + y); 87 | let f = a + b + c; 88 | return f; 89 | } 90 | } 91 | --------------------------------------------------------------------------------