├── .github └── workflows │ ├── codspeed.yml │ ├── gh-pages.yml │ └── rust.yml ├── .gitignore ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── einsum-codegen ├── Cargo.toml └── src │ ├── codegen │ ├── format.rs │ ├── mod.rs │ └── ndarray │ │ ├── mod.rs │ │ └── naive.rs │ ├── lib.rs │ ├── namespace.rs │ ├── parser.rs │ ├── path.rs │ └── subscripts.rs ├── einsum-derive ├── Cargo.toml ├── README.md ├── benches │ └── einsum.rs ├── src │ └── lib.rs └── tests │ ├── cases │ ├── number_of_arguments_mismatch.rs │ └── number_of_arguments_mismatch.stderr │ └── trybuild.rs └── rust-toolchain /.github/workflows/codspeed.yml: -------------------------------------------------------------------------------- 1 | name: Codspeed 2 | 3 | on: 4 | push: 5 | branches: 6 | - "main" # or "master" 7 | pull_request: 8 | # `workflow_dispatch` allows CodSpeed to trigger backtest 9 | # performance analysis in order to generate initial data. 10 | workflow_dispatch: 11 | 12 | jobs: 13 | benchmarks: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v3 17 | 18 | - name: Setup rust toolchain, cache and cargo-codspeed binary 19 | uses: moonrepo/setup-rust@v0 20 | with: 21 | channel: stable 22 | cache-target: release 23 | bins: cargo-codspeed 24 | 25 | - name: Build the benchmark target(s) 26 | run: cargo codspeed build -p einsum-derive 27 | 28 | - name: Run the benchmarks 29 | uses: CodSpeedHQ/action@v2 30 | with: 31 | run: cargo codspeed run -p einsum-derive 32 | token: ${{ secrets.CODSPEED_TOKEN }} 33 | -------------------------------------------------------------------------------- /.github/workflows/gh-pages.yml: -------------------------------------------------------------------------------- 1 | # Based on starter workflow 2 | # https://github.com/actions/starter-workflows/blob/8217436fdee2338da2d6fd02b7c9fcff634c40e7/pages/static.yml 3 | # 4 | # Simple workflow for deploying static content to GitHub Pages 5 | name: "GitHub Pages" 6 | 7 | on: 8 | # Runs on 
pushes targeting the default branch 9 | push: 10 | branches: 11 | - main 12 | 13 | # Allows you to run this workflow manually from the Actions tab 14 | workflow_dispatch: 15 | 16 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages 17 | permissions: 18 | contents: read 19 | pages: write 20 | id-token: write 21 | 22 | # Allow one concurrent deployment 23 | concurrency: 24 | group: "pages" 25 | cancel-in-progress: true 26 | 27 | jobs: 28 | # Single deploy job since we're just deploying 29 | deploy: 30 | environment: 31 | name: github-pages 32 | url: ${{ steps.deployment.outputs.page_url }} 33 | runs-on: ubuntu-latest 34 | steps: 35 | - name: Checkout 36 | uses: actions/checkout@v3 37 | 38 | - uses: actions-rs/toolchain@v1 39 | with: 40 | toolchain: nightly 41 | override: true 42 | default: true 43 | components: rustfmt 44 | 45 | # Generate cargo-doc 46 | - uses: actions-rs/cargo@v1 47 | with: 48 | command: doc 49 | args: --no-deps 50 | 51 | # Generate benchmark report 52 | - uses: actions-rs/cargo@v1 53 | with: 54 | command: bench 55 | 56 | - name: Setup Pages 57 | uses: actions/configure-pages@v2 58 | 59 | - run: | 60 | mkdir -p pages/ 61 | mv target/doc pages/doc 62 | mv target/criterion pages/bench 63 | 64 | # Upload pages/ directory (cargo-doc output and benchmark report) 65 | - name: Upload artifact 66 | uses: actions/upload-pages-artifact@v1 67 | with: 68 | path: pages/ 69 | 70 | - name: Deploy to GitHub Pages 71 | id: deployment 72 | uses: actions/deploy-pages@v1 73 | -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: {} 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-22.04 12 | steps: 13 | - uses: actions/checkout@v1 14 | - uses: actions-rs/toolchain@v1 15 | with: 16 | toolchain: nightly 17 | components: rustfmt 18 | - uses: actions-rs/cargo@v1 19 | with:
20 | command: test 21 | toolchain: nightly 22 | 23 | check-format: 24 | runs-on: ubuntu-22.04 25 | steps: 26 | - uses: actions/checkout@v1 27 | - uses: actions-rs/cargo@v1 28 | with: 29 | command: fmt 30 | toolchain: stable 31 | args: -- --check 32 | 33 | clippy: 34 | runs-on: ubuntu-22.04 35 | steps: 36 | - uses: actions/checkout@v1 37 | - uses: actions-rs/cargo@v1 38 | with: 39 | command: clippy 40 | toolchain: stable 41 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | target/ 4 | 5 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 6 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 7 | Cargo.lock 8 | 9 | # These are backup files generated by rustfmt 10 | **/*.rs.bk 11 | 12 | # insta 13 | *.pending-snap 14 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | resolver = "2" 3 | members = [ 4 | "einsum-derive", 5 | "einsum-codegen", 6 | ] 7 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Toshiki Teramura 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | einsum-derive/README.md -------------------------------------------------------------------------------- /einsum-codegen/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "einsum-codegen" 3 | version = "0.1.0" 4 | edition = "2021" 5 | authors = ["Toshiki Teramura "] 6 | 7 | description = "Helper for generating einsum implementation using proc-macro" 8 | documentation = "https://docs.rs/einsum-codegen/" 9 | repository = "https://github.com/termoshtt/einsum-derive" 10 | keywords = ["ndarray", "matrix", "einsum"] 11 | license = "MIT OR Apache-2.0" 12 | readme = "../README.md" 13 | categories = ["algorithms", "science"] 14 | 15 | [dependencies] 16 | anyhow = "1.0.66" 17 | katexit = "0.1.2" 18 | nom = "7.1.1" 19 | 20 | # for codegen 21 | proc-macro2 = "1.0.46" 22 | quote = "1.0.21" 23 | syn = "1.0.102" 24 | 25 | [dev-dependencies] 26 | insta = "1.21.1" 27 | maplit = "1.0.2" 28 | ndarray = "0.15.6" 29 | -------------------------------------------------------------------------------- /einsum-codegen/src/codegen/format.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | io::Write, 3 | process::{Command, Stdio}, 4 | }; 5 | 6 | /// Format generated Rust code using `rustfmt` run as external process. 
7 | pub fn format_block(tt: String) -> String { 8 | let tt = format!("fn main() {{ {} }}", tt); 9 | 10 | let mut child = Command::new("rustfmt") 11 | .stdin(Stdio::piped()) 12 | .stdout(Stdio::piped()) 13 | .spawn() 14 | .expect("Failed to spawn rustfmt process"); 15 | 16 | // Write input from another thread for avoiding deadlock. 17 | // See https://doc.rust-lang.org/std/process/index.html#handling-io 18 | let mut stdin = child.stdin.take().expect("Failed to open stdin"); 19 | std::thread::spawn(move || { 20 | stdin 21 | .write_all(tt.as_bytes()) 22 | .expect("Failed to write to stdin"); 23 | }); 24 | let output = child 25 | .wait_with_output() 26 | .expect("Failed to wait output of rustfmt process"); 27 | 28 | // non-UTF8 comment should be handled in the tokenize phase, 29 | // and not be included in IR. 30 | let out = String::from_utf8(output.stdout).expect("rustfmt output contains non-UTF8 input"); 31 | 32 | let formatted_lines: Vec<&str> = out 33 | .lines() 34 | .filter_map(|line| match line { 35 | "fn main() {" | "}" => None, 36 | _ => line.strip_prefix(" "), 37 | }) 38 | .collect(); 39 | formatted_lines.join("\n") 40 | } 41 | -------------------------------------------------------------------------------- /einsum-codegen/src/codegen/mod.rs: -------------------------------------------------------------------------------- 1 | //! Generate einsum implementation 2 | 3 | mod format; 4 | pub use format::format_block; 5 | 6 | pub mod ndarray; 7 | -------------------------------------------------------------------------------- /einsum-codegen/src/codegen/ndarray/mod.rs: -------------------------------------------------------------------------------- 1 | //! 
For [ndarray](https://crates.io/crates/ndarray) crate 2 | 3 | pub mod naive; 4 | 5 | use crate::subscripts::Subscripts; 6 | use proc_macro2::TokenStream as TokenStream2; 7 | use quote::{format_ident, quote}; 8 | 9 | fn dim(n: usize) -> syn::Path { 10 | let ix = quote::format_ident!("Ix{}", n); 11 | syn::parse_quote! { ndarray::#ix } 12 | } 13 | 14 | /// Generate an einsum `fn` named after the escaped subscripts, with one `ndarray::ArrayBase` argument per input and `inner` as its body 15 | pub fn function_definition(subscripts: &Subscripts, inner: TokenStream2) -> TokenStream2 { 16 | let fn_name = format_ident!("{}", subscripts.escaped_ident()); 17 | let n = subscripts.inputs.len(); 18 | 19 | let args = &subscripts.inputs; 20 | let storages: Vec = (0..n).map(|n| quote::format_ident!("S{}", n)).collect(); 21 | let dims: Vec = subscripts 22 | .inputs 23 | .iter() 24 | .map(|ss| dim(ss.indices().len())) 25 | .collect(); 26 | 27 | let out_dim = dim(subscripts.output.indices().len()); 28 | 29 | quote! { 30 | fn #fn_name( 31 | #( #args: ndarray::ArrayBase<#storages, #dims> ),* 32 | ) -> ndarray::Array 33 | where 34 | T: ndarray::LinalgScalar, 35 | #( #storages: ndarray::Data ),* 36 | { 37 | #inner 38 | } 39 | } 40 | } 41 | 42 | #[cfg(test)] 43 | mod test { 44 | use crate::{codegen::format_block, *}; 45 | 46 | #[test] 47 | fn function_definition_snapshot() { 48 | let mut namespace = Namespace::init(); 49 | let subscripts = Subscripts::from_raw_indices(&mut namespace, "ij,jk->ik").unwrap(); 50 | let inner = quote::quote!
{ todo!() }; 51 | let tt = format_block(super::function_definition(&subscripts, inner).to_string()); 52 | insta::assert_snapshot!(tt, @r###" 53 | fn ab_bc__ac( 54 | arg0: ndarray::ArrayBase, 55 | arg1: ndarray::ArrayBase, 56 | ) -> ndarray::Array 57 | where 58 | T: ndarray::LinalgScalar, 59 | S0: ndarray::Data, 60 | S1: ndarray::Data, 61 | { 62 | todo!() 63 | } 64 | "###); 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /einsum-codegen/src/codegen/ndarray/naive.rs: -------------------------------------------------------------------------------- 1 | //! Generate einsum function with naive loop 2 | 3 | #[cfg(doc)] 4 | use super::function_definition; 5 | 6 | use crate::Subscripts; 7 | 8 | use proc_macro2::TokenStream as TokenStream2; 9 | use quote::quote; 10 | use std::collections::HashSet; 11 | 12 | fn index_ident(i: char) -> syn::Ident { 13 | quote::format_ident!("{}", i) 14 | } 15 | 16 | fn n_ident(i: char) -> syn::Ident { 17 | quote::format_ident!("n_{}", i) 18 | } 19 | 20 | fn contraction_for(indices: &[char], inner: TokenStream2) -> TokenStream2 { 21 | let mut tt = inner; 22 | for &i in indices.iter().rev() { 23 | let index = index_ident(i); 24 | let n = n_ident(i); 25 | tt = quote! { 26 | for #index in 0..#n { #tt } 27 | }; 28 | } 29 | tt 30 | } 31 | 32 | fn contraction_inner(subscripts: &Subscripts) -> TokenStream2 { 33 | let mut inner_args_tt = Vec::new(); 34 | for arg in &subscripts.inputs { 35 | let mut index = Vec::new(); 36 | for i in arg.indices() { 37 | index.push(index_ident(i)); 38 | } 39 | inner_args_tt.push(quote! { 40 | #arg[(#(#index),*)] 41 | }) 42 | } 43 | let mut inner_mul = None; 44 | for inner in inner_args_tt { 45 | match inner_mul { 46 | Some(i) => inner_mul = Some(quote!
{ #i * #inner }), 47 | None => inner_mul = Some(inner), 48 | } 49 | } 50 | // Accumulate, since contraction indices are summed; `+=` would need `AddAssign`, not implied by `LinalgScalar`. 51 | let output_ident = &subscripts.output; 52 | let mut output_indices = Vec::new(); 53 | for i in &subscripts.output.indices() { 54 | let index = index_ident(*i); 55 | output_indices.push(index); 56 | } 57 | quote! { 58 | #output_ident[(#(#output_indices),*)] = #output_ident[(#(#output_indices),*)] + #inner_mul; 59 | } 60 | } 61 | 62 | /// Generate naive contraction loop accumulating the products into the zero-initialized output 63 | /// 64 | /// ``` 65 | /// # use ndarray::Array2; 66 | /// # let arg0 = Array2::::zeros((3, 3)); 67 | /// # let arg1 = Array2::::zeros((3, 3)); 68 | /// # let mut out0 = Array2::::zeros((3, 3)); 69 | /// # let n_i = 3; 70 | /// # let n_j = 3; 71 | /// # let n_k = 3; 72 | /// for i in 0..n_i { 73 | /// for k in 0..n_k { 74 | /// for j in 0..n_j { 75 | /// out0[(i, k)] = out0[(i, k)] + arg0[(i, j)] * arg1[(j, k)]; 76 | /// } 77 | /// } 78 | /// } 79 | /// ``` 80 | /// 81 | pub fn contraction(subscripts: &Subscripts) -> TokenStream2 { 82 | let mut indices: Vec = subscripts.output.indices(); 83 | for i in subscripts.contraction_indices() { 84 | indices.push(i); 85 | } 86 | 87 | let inner = contraction_inner(subscripts); 88 | contraction_for(&indices, inner) 89 | } 90 | 91 | /// Define the index size identifiers, e.g. `n_i` 92 | pub fn define_array_size(subscripts: &Subscripts) -> TokenStream2 { 93 | let mut appeared: HashSet = HashSet::new(); 94 | let mut tt = Vec::new(); 95 | for arg in subscripts.inputs.iter() { 96 | let n_ident: Vec = arg 97 | .indices() 98 | .into_iter() 99 | .map(|i| { 100 | if appeared.contains(&i) { 101 | quote::format_ident!("_") 102 | } else { 103 | appeared.insert(i); 104 | n_ident(i) 105 | } 106 | }) 107 | .collect(); 108 | tt.push(quote! { 109 | let (#(#n_ident),*) = #arg.dim(); 110 | }); 111 | } 112 | quote!
{ #(#tt)* } 113 | } 114 | 115 | /// Generate `assert_eq!` to check the size of user input tensors 116 | pub fn array_size_asserts(subscripts: &Subscripts) -> TokenStream2 { 117 | let mut tt = Vec::new(); 118 | for arg in &subscripts.inputs { 119 | // local variable, e.g. `n_2` 120 | let n_each: Vec<_> = (0..arg.indices().len()) 121 | .map(|m| quote::format_ident!("n_{}", m)) 122 | .collect(); 123 | // size of index defined previously, e.g. `n_i` 124 | let n: Vec<_> = arg.indices().into_iter().map(n_ident).collect(); 125 | tt.push(quote! { 126 | let (#(#n_each),*) = #arg.dim(); 127 | #(assert_eq!(#n_each, #n);)* 128 | }); 129 | } 130 | quote! { #({ #tt })* } 131 | } 132 | 133 | fn define_output_array(subscripts: &Subscripts) -> TokenStream2 { 134 | // Define output array 135 | let output_ident = &subscripts.output; 136 | let mut n_output = Vec::new(); 137 | for i in subscripts.output.indices() { 138 | n_output.push(n_ident(i)); 139 | } 140 | quote! { 141 | let mut #output_ident = ndarray::Array::zeros((#(#n_output),*)); 142 | } 143 | } 144 | 145 | /// Actual component of einsum [function_definition] 146 | pub fn inner(subscripts: &Subscripts) -> TokenStream2 { 147 | let array_size = define_array_size(subscripts); 148 | let array_size_asserts = array_size_asserts(subscripts); 149 | let output_ident = &subscripts.output; 150 | let output_tt = define_output_array(subscripts); 151 | let contraction_tt = contraction(subscripts); 152 | quote!
{ 153 | #array_size 154 | #array_size_asserts 155 | #output_tt 156 | #contraction_tt 157 | #output_ident 158 | } 159 | } 160 | 161 | #[cfg(test)] 162 | mod test { 163 | use crate::{codegen::format_block, *}; 164 | 165 | #[test] 166 | fn define_array_size() { 167 | let mut namespace = Namespace::init(); 168 | let subscripts = Subscripts::from_raw_indices(&mut namespace, "ij,jk->ik").unwrap(); 169 | let tt = format_block(super::define_array_size(&subscripts).to_string()); 170 | insta::assert_snapshot!(tt, @r###" 171 | let (n_a, n_b) = arg0.dim(); 172 | let (_, n_c) = arg1.dim(); 173 | "###); 174 | } 175 | 176 | #[test] 177 | fn contraction() { 178 | let mut namespace = Namespace::init(); 179 | let subscripts = Subscripts::from_raw_indices(&mut namespace, "ij,jk->ik").unwrap(); 180 | let tt = format_block(super::contraction(&subscripts).to_string()); 181 | insta::assert_snapshot!(tt, @r###" 182 | for a in 0..n_a { 183 | for c in 0..n_c { 184 | for b in 0..n_b { 185 | out0[(a, c)] = out0[(a, c)] + arg0[(a, b)] * arg1[(b, c)]; 186 | } 187 | } 188 | } 189 | "###); 190 | } 191 | 192 | #[test] 193 | fn inner() { 194 | let mut namespace = Namespace::init(); 195 | let subscripts = Subscripts::from_raw_indices(&mut namespace, "ij,jk->ik").unwrap(); 196 | let tt = format_block(super::inner(&subscripts).to_string()); 197 | insta::assert_snapshot!(tt, @r###" 198 | let (n_a, n_b) = arg0.dim(); 199 | let (_, n_c) = arg1.dim(); 200 | { 201 | let (n_0, n_1) = arg0.dim(); 202 | assert_eq!(n_0, n_a); 203 | assert_eq!(n_1, n_b); 204 | } 205 | { 206 | let (n_0, n_1) = arg1.dim(); 207 | assert_eq!(n_0, n_b); 208 | assert_eq!(n_1, n_c); 209 | } 210 | let mut out0 = ndarray::Array::zeros((n_a, n_c)); 211 | for a in 0..n_a { 212 | for c in 0..n_c { 213 | for b in 0..n_b { 214 | out0[(a, c)] = out0[(a, c)] + arg0[(a, b)] * arg1[(b, c)]; 215 | } 216 | } 217 | } 218 | out0 219 | "###); 220 | } 221 | } 222 | -------------------------------------------------------------------------------- /einsum-codegen/src/lib.rs:
-------------------------------------------------------------------------------- 1 | #![cfg_attr( 2 | doc, 3 | feature(prelude_import, custom_inner_attributes, proc_macro_hygiene) 4 | )] 5 | #![cfg_attr(doc, katexit::katexit)] 6 | //! Helper crate for einsum algorithm 7 | //! 8 | //! Introduction to einsum 9 | //! ----------------------- 10 | //! The Einstein summation rule in theoretical physics and related fields 11 | //! including machine learning is a rule for abbreviating tensor operations. 12 | //! For example, one of the most basic tensor operations is the inner product of 13 | //! two vectors in $n$-dimensional Euclidean space $x, y \in \mathbb{R}^n$: 14 | //! $$ 15 | //! (x, y) = \sum_{i \in I} x_i y_i 16 | //! $$ 17 | //! where $I$ denotes a set of indices, i.e. $I = \\{0, 1, \ldots, n-1 \\}$. 18 | //! Another example is matrix multiplication. 19 | //! A product of three square matrices $A, B, C \in M_n(\mathbb{R})$ 20 | //! can be written element-wise: 21 | //! $$ 22 | //! ABC_{il} = \sum_{j \in J} \sum_{k \in K} a_{ij} b_{jk} c_{kl} 23 | //! $$ 24 | //! 25 | //! Many such tensor operations appear in various fields, 26 | //! and we usually define a separate function for each operation. 27 | //! For the inner product of vectors, we may define a function like 28 | //! ```ignore 29 | //! fn inner(a: Array1D, b: Array1D) -> R; 30 | //! ``` 31 | //! for matrix multiplication: 32 | //! ```ignore 33 | //! fn matmul(a: Array2D, b: Array2D) -> Array2D; 34 | //! ``` 35 | //! or taking three matrices: 36 | //! ```ignore 37 | //! fn matmul3(a: Array2D, b: Array2D, c: Array2D) -> Array2D; 38 | //! ``` 39 | //! and so on. 40 | //! 41 | //! These definitions are very similar, and in fact, 42 | //! they can all be represented in a single manner. 43 | //! These computations multiply the elements of each tensor selected by some indices, 44 | //! and sum them up along some indices. 45 | //! So we have to determine 46 | //! 47 | //!
- which indices are used for each tensor in the multiplication 48 | //! - which indices are summed up 49 | //! 50 | //! This can be done by listing the indices of each input tensor 51 | //! and applying the Einstein summation rule, i.e. summing up indices which appear more than once. 52 | //! For example, `inner` is represented by `i,i->`, `matmul` is represented by `ij,jk->ik`, 53 | //! `matmul3` is represented by `ij,jk,kl->il`, and so on, 54 | //! where `,` separates the indices of each input tensor 55 | //! and each index must be represented by a single char like `i` or `j`. 56 | //! `->` separates the indices of the input tensors from the indices of the output tensor. 57 | //! If no indices are placed after `->`, as in `i,i->`, the output tensor is 0-rank, i.e. a scalar value. 58 | //! "einsum" is an algorithm or runtime that expands such a string 59 | //! into actual tensor operations. 60 | //! 61 | //! einsum algorithm 62 | //! ----------------- 63 | //! We give an overview of the einsum algorithm to help understand the structure of this crate. 64 | //! 65 | //! ### Factorize and memoize partial summation 66 | //! Partial summation and memoization of its result reduce the number of floating point operations. 67 | //! For simplicity, both addition `+` and multiplication `*` are counted as 1 operation, 68 | //! and we do not consider fused multiply-add (FMA). 69 | //! In the above `matmul3` example, there are $\\#K \times \\#J$ additions 70 | //! and $2 \times \\#K \times \\#J$ multiplications for every index pair $(i, l)$, 71 | //! where $\\#$ denotes the number of elements in the index sets. 72 | //! Assuming all index sets have the same size, denoted by $N$, 73 | //! there are $O(N^4)$ floating point operations. 74 | //! 75 | //! When we sum up partially along `j`: 76 | //! $$ 77 | //! \sum_{k \in K} c_{kl} \left( \sum_{j \in J} a_{ij} b_{jk} \right), 78 | //! $$ 79 | //! and memoize its result as $d_{ik}$: 80 | //! $$ 81 | //! \sum_{k \in K} c_{kl} d_{ik}, 82 | //!
\text{where} \space d_{ik} = \sum_{j \in J} a_{ij} b_{jk}, 83 | //! $$ 84 | //! there are $O(N^3)$ operations for both computing $d_{ik}$ and the final summation, 85 | //! with $O(N^2)$ storage for the memoized result. 86 | //! 87 | //! When is this factorization possible? We know that the above `matmul3` example 88 | //! can also be written as the associative matrix product $ABC = A(BC) = (AB)C$, 89 | //! and partial summation along $j$ corresponds to storing $D = AB$. 90 | //! This is not always possible. Let us consider the trace of a product of two matrices 91 | //! $$ 92 | //! \text{Tr} (AB) = \sum_{i \in I} \sum_{j \in J} a_{ij} b_{ji} 93 | //! $$ 94 | //! This is written as `ij,ji->` in einsum subscript form. 95 | //! We cannot factor either $a_{ij}$ or $b_{ji}$ out of the summation 96 | //! since each of them contains both indices. 97 | //! Whether this factorization is possible or not can be determined 98 | //! from the einsum subscript form alone, and we call a subscript "reducible" 99 | //! if factorization is possible, and "irreducible" if not, 100 | //! e.g. `ij,jk,kl->il` is reducible and `ij,ji->` is irreducible. 101 | //! 102 | //! ### Subscript representation 103 | //! 104 | //! To discuss subscript factorization, we have to track which tensor is 105 | //! used as each input. 106 | //! In the above `matmul3` example, `ij,jk,kl->il` is factorized into sub-subscripts 107 | //! `ij,jk->ik` and `ik,kl->il` where `ik` in the second subscript refers to 108 | //! the output of the first subscript. The names of the tensors 109 | //! have been dropped from the sub-subscripts, 110 | //! and we need a mechanism for managing them. 111 | //! 112 | //! We introduce a subscript representation of the `matmul3` case with tensor names: 113 | //! 114 | //! ```text 115 | //! ij,jk,kl->il | a b c -> out 116 | //! ``` 117 | //! 118 | //! In this form, the factorization can be described as: 119 | //! 120 | //! ```text 121 | //! ij,jk->ik | a b -> d 122 | //! ik,kl->il | d c -> out 123 | //! ``` 124 | //! 
125 | //! To distinguish tensors given by the user from tensors created during factorization, 126 | //! we use `arg{N}` and `out{N}` identifiers: 127 | //! 128 | //! ```text 129 | //! ij,jk->ik | arg0 arg1 -> out0 130 | //! ik,kl->il | out0 arg2 -> out1 131 | //! ``` 132 | //! 133 | //! ### Summation order 134 | //! 135 | //! This factorization is not unique. 136 | //! For the `matmul3` case, for example, there are two natural ways, corresponding to $(AB)C$: 137 | //! 138 | //! ```text 139 | //! ij,jk->ik | arg0 arg1 -> out0 140 | //! ik,kl->il | out0 arg2 -> out1 141 | //! ``` 142 | //! 143 | //! and to $A(BC)$: 144 | //! 145 | //! ```text 146 | //! jk,kl->jl | arg1 arg2 -> out0 147 | //! jl,ij->il | out0 arg0 -> out1 148 | //! ``` 149 | //! 150 | //! These are different procedures, i.e. the number of floating point operations 151 | //! and the required intermediate memory differ, 152 | //! but they return the same output tensor 153 | //! (we ignore the non-associativity of floating point arithmetic in this document). 154 | //! This becomes a complicated combinatorial optimization problem 155 | //! if there are many contraction indices, 156 | //! and the objective of this crate is to (heuristically) solve this problem. 157 | //! 158 | 159 | pub mod codegen; 160 | pub mod parser; 161 | 162 | mod namespace; 163 | mod path; 164 | mod subscripts; 165 | 166 | pub use namespace::*; 167 | pub use path::*; 168 | pub use subscripts::*; 169 | -------------------------------------------------------------------------------- /einsum-codegen/src/namespace.rs: -------------------------------------------------------------------------------- 1 | use proc_macro2::TokenStream; 2 | use quote::{format_ident, ToTokens, TokenStreamExt}; 3 | use std::fmt; 4 | 5 | /// Names of tensors 6 | /// 7 | /// As the crate-level documentation explains, 8 | /// einsum factorization requires tracking the names of tensors 9 | /// in addition to subscripts, and this struct manages them. 
10 | /// This works as a simple counter: it counts how many intermediate 11 | /// tensor identifiers `out{N}` have been issued and hands out the next one (`out0`, `out1`, ...). 12 | /// 13 | #[derive(Debug, PartialEq, Eq, Clone)] 14 | pub struct Namespace { 15 | last: usize, 16 | } 17 | 18 | impl Namespace { 19 | /// Create a new namespace; the first identifier it issues is `out0` 20 | pub fn init() -> Self { 21 | Namespace { last: 0 } 22 | } 23 | 24 | /// Issue a new `Position::Out` identifier and advance the counter, so every call returns a distinct one 25 | pub fn new_ident(&mut self) -> Position { 26 | let pos = Position::Out(self.last); 27 | self.last += 1; 28 | pos 29 | } 30 | } 31 | 32 | /// Which tensor the subscript specifies 33 | #[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)] 34 | pub enum Position { 35 | /// The tensor which the user passes as the N-th argument of einsum, spelled `argN` 36 | Arg(usize), 37 | /// A tensor created by einsum, spelled `outN`; `N` is the order in which [`Namespace::new_ident`] issued it 38 | Out(usize), 39 | } 40 | 41 | impl fmt::Debug for Position { 42 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 43 | match self { // same spelling as the identifiers emitted by the `ToTokens` impl 44 | Position::Arg(n) => write!(f, "arg{}", n), 45 | Position::Out(n) => write!(f, "out{}", n), 46 | } 47 | } 48 | } 49 | 50 | impl fmt::Display for Position { 51 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 52 | fmt::Debug::fmt(self, f) 53 | } 54 | } 55 | 56 | impl ToTokens for Position { 57 | fn to_tokens(&self, tokens: &mut TokenStream) { 58 | match self { 59 | Position::Arg(n) => tokens.append(format_ident!("arg{}", n)), 60 | Position::Out(n) => tokens.append(format_ident!("out{}", n)), 61 | } 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /einsum-codegen/src/parser.rs: -------------------------------------------------------------------------------- 1 | //! Parse einsum subscripts 2 | //! 3 | //! These parsers are implemented using [nom](https://github.com/Geal/nom), 4 | //! 
and the corresponding EBNF-like schema is given in the documentation of each parser. 5 | //! 
6 | 7 | use anyhow::{bail, Error, Result}; 8 | use nom::{ 9 | bytes::complete::*, character::complete::*, combinator::*, multi::*, sequence::*, IResult, 10 | Parser, 11 | }; 12 | use std::fmt; 13 | 14 | /// index = `a` | `b` | `c` | `d` | `e` | `f` | `g` | `h` | `i` | `j` | `k` | `l` |`m` | `n` | `o` | `p` | `q` | `r` | `s` | `t` | `u` | `v` | `w` | `x` |`y` | `z`; 15 | pub fn index(input: &str) -> IResult<&str, char> { 16 | satisfy(|c| c.is_ascii_lowercase()).parse(input) 17 | } 18 | 19 | /// ellipsis = `...` 20 | pub fn ellipsis(input: &str) -> IResult<&str, &str> { 21 | tag("...").parse(input) 22 | } 23 | 24 | /// subscript = { [index] } [ [ellipsis] { [index] } ]; 25 | pub fn subscript(input: &str) -> IResult<&str, RawSubscript> { 26 | let mut indices = many0(tuple((multispace0, index)).map(|(_space, c)| c)); 27 | let (input, start) = indices(input)?; 28 | let (input, end) = opt(tuple((multispace0, ellipsis, multispace0, indices)) 29 | .map(|(_space_pre, _ellipsis, _space_post, output)| output))(input)?; 30 | if let Some(end) = end { 31 | Ok((input, RawSubscript::Ellipsis { start, end })) 32 | } else { 33 | Ok((input, RawSubscript::Indices(start))) 34 | } 35 | } 36 | 37 | #[derive(Debug, Clone, PartialEq, Eq, Hash)] 38 | pub enum RawSubscript { 39 | /// Indices without ellipsis, e.g. `ijk` 40 | Indices(Vec), 41 | /// Indices with ellipsis, e.g. 
`i...j` 42 | Ellipsis { start: Vec, end: Vec }, 43 | } 44 | 45 | impl PartialEq<[char; N]> for RawSubscript { 46 | fn eq(&self, other: &[char; N]) -> bool { 47 | match self { 48 | RawSubscript::Indices(indices) => indices.eq(other), 49 | _ => false, 50 | } 51 | } 52 | } 53 | 54 | impl fmt::Display for RawSubscript { 55 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 56 | match self { 57 | RawSubscript::Indices(indices) => { 58 | for i in indices { 59 | write!(f, "{}", i)?; 60 | } 61 | } 62 | RawSubscript::Ellipsis { start, end } => { 63 | for i in start { 64 | write!(f, "{}", i)?; 65 | } 66 | write!(f, "___")?; 67 | for i in end { 68 | write!(f, "{}", i)?; 69 | } 70 | } 71 | } 72 | Ok(()) 73 | } 74 | } 75 | 76 | /// Einsum subscripts, e.g. `ij,jk->ik` 77 | #[derive(Debug, PartialEq, Eq)] 78 | pub struct RawSubscripts { 79 | /// Input subscript, `ij` and `jk` 80 | pub inputs: Vec, 81 | /// Output subscript. This may be empty for "implicit mode". 82 | pub output: Option, 83 | } 84 | 85 | impl std::str::FromStr for RawSubscripts { 86 | type Err = Error; 87 | fn from_str(input: &str) -> Result { 88 | use nom::Finish; 89 | if let Ok((_, ss)) = subscripts(input).finish() { 90 | Ok(ss) 91 | } else { 92 | bail!("Invalid subscripts: {}", input); 93 | } 94 | } 95 | } 96 | 97 | /// subscripts = [subscript] {`,` [subscript]} \[ `->` [subscript] \] 98 | pub fn subscripts(input: &str) -> IResult<&str, RawSubscripts> { 99 | let (input, _head) = multispace0(input)?; 100 | let (input, inputs) = separated_list1(tuple((multispace0, char(','))), subscript)(input)?; 101 | let (input, output) = opt(tuple((multispace0, tag("->"), multispace0, subscript)) 102 | .map(|(_space_pre, _arrow, _space_post, output)| output))(input)?; 103 | Ok((input, RawSubscripts { inputs, output })) 104 | } 105 | 106 | #[cfg(test)] 107 | mod tests { 108 | 109 | use super::*; 110 | use nom::Finish; 111 | 112 | #[test] 113 | fn test_subscript() { 114 | let (res, out) = 
subscript("ijk").finish().unwrap(); 115 | assert_eq!(out, RawSubscript::Indices(vec!['i', 'j', 'k'])); 116 | assert_eq!(res, ""); 117 | 118 | let (res, out) = subscript("...").finish().unwrap(); 119 | assert_eq!( 120 | out, 121 | RawSubscript::Ellipsis { 122 | start: Vec::new(), 123 | end: Vec::new() 124 | } 125 | ); 126 | assert_eq!(res, ""); 127 | 128 | let (res, out) = subscript("i...").finish().unwrap(); 129 | assert_eq!( 130 | out, 131 | RawSubscript::Ellipsis { 132 | start: vec!['i'], 133 | end: Vec::new() 134 | } 135 | ); 136 | assert_eq!(res, ""); 137 | 138 | let (res, out) = subscript("...j").finish().unwrap(); 139 | assert_eq!( 140 | out, 141 | RawSubscript::Ellipsis { 142 | start: Vec::new(), 143 | end: vec!['j'], 144 | } 145 | ); 146 | assert_eq!(res, ""); 147 | 148 | let (res, out) = subscript("i...j").finish().unwrap(); 149 | assert_eq!( 150 | out, 151 | RawSubscript::Ellipsis { 152 | start: vec!['i'], 153 | end: vec!['j'], 154 | } 155 | ); 156 | assert_eq!(res, ""); 157 | } 158 | 159 | #[test] 160 | fn test_operator() { 161 | fn test(input: &str) { 162 | dbg!(input); 163 | let (_, op) = subscripts(input).finish().unwrap(); 164 | assert_eq!( 165 | op, 166 | RawSubscripts { 167 | inputs: vec![ 168 | RawSubscript::Indices(vec!['i', 'j']), 169 | RawSubscript::Indices(vec!['j', 'k']) 170 | ], 171 | output: Some(RawSubscript::Indices(vec!['i', 'k'])), 172 | } 173 | ); 174 | } 175 | test("ij,jk->ik"); 176 | 177 | // with space 178 | test(" ij,jk->ik"); 179 | test("i j,jk->ik"); 180 | test("ij ,jk->ik"); 181 | test("ij, jk->ik"); 182 | test("ij,j k->ik"); 183 | test("ij,jk ->ik"); 184 | test("ij,jk-> ik"); 185 | test("ij,jk->i k"); 186 | 187 | // implicit mode 188 | let (_, op) = subscripts("ij,jk").finish().unwrap(); 189 | assert_eq!( 190 | op, 191 | RawSubscripts { 192 | inputs: vec![ 193 | RawSubscript::Indices(vec!['i', 'j']), 194 | RawSubscript::Indices(vec!['j', 'k']) 195 | ], 196 | output: None, 197 | } 198 | ); 199 | 200 | // with ... 
201 | let (_, op) = subscripts("i...,i...->...").finish().unwrap(); 202 | assert_eq!( 203 | op, 204 | RawSubscripts { 205 | inputs: vec![ 206 | RawSubscript::Ellipsis { 207 | start: vec!['i'], 208 | end: Vec::new() 209 | }, 210 | RawSubscript::Ellipsis { 211 | start: vec!['i'], 212 | end: Vec::new() 213 | } 214 | ], 215 | output: Some(RawSubscript::Ellipsis { 216 | start: Vec::new(), 217 | end: Vec::new() 218 | }) 219 | } 220 | ); 221 | } 222 | } 223 | -------------------------------------------------------------------------------- /einsum-codegen/src/path.rs: -------------------------------------------------------------------------------- 1 | //! Execution path 2 | 3 | use crate::*; 4 | use anyhow::Result; 5 | use std::collections::BTreeSet; 6 | 7 | #[derive(Debug, Clone, PartialEq, Eq)] 8 | pub struct Path { 9 | original: Subscripts, 10 | reduced_subscripts: Vec, 11 | } 12 | 13 | impl std::ops::Deref for Path { 14 | type Target = [Subscripts]; 15 | fn deref(&self) -> &[Subscripts] { 16 | &self.reduced_subscripts 17 | } 18 | } 19 | 20 | impl Path { 21 | pub fn output(&self) -> &Subscript { 22 | &self.original.output 23 | } 24 | 25 | pub fn num_args(&self) -> usize { 26 | self.original.inputs.len() 27 | } 28 | 29 | pub fn compute_order(&self) -> usize { 30 | compute_order(&self.reduced_subscripts) 31 | } 32 | 33 | pub fn memory_order(&self) -> usize { 34 | memory_order(&self.reduced_subscripts) 35 | } 36 | 37 | pub fn brute_force(indices: &str) -> Result { 38 | let mut names = Namespace::init(); 39 | let subscripts = Subscripts::from_raw_indices(&mut names, indices)?; 40 | Ok(Path { 41 | original: subscripts.clone(), 42 | reduced_subscripts: brute_force_work(&mut names, subscripts)?, 43 | }) 44 | } 45 | } 46 | 47 | fn compute_order(ss: &[Subscripts]) -> usize { 48 | ss.iter() 49 | .map(|ss| ss.compute_order()) 50 | .max() 51 | .expect("self.0 never be empty") 52 | } 53 | 54 | fn memory_order(ss: &[Subscripts]) -> usize { 55 | ss.iter() 56 | .map(|ss| 
ss.memory_order()) 57 | .max() 58 | .expect("self.0 never be empty") 59 | } 60 | 61 | fn brute_force_work(names: &mut Namespace, subscripts: Subscripts) -> Result> { 62 | if subscripts.inputs.len() <= 2 { 63 | // Cannot be factorized anymore 64 | return Ok(vec![subscripts]); 65 | } 66 | 67 | let n = subscripts.inputs.len(); 68 | let mut subpaths = (0..2_usize.pow(n as u32)) 69 | .filter_map(|mut m| { 70 | // create combinations specifying which tensors are used 71 | let mut pos = BTreeSet::new(); 72 | for i in 0..n { 73 | if m % 2 == 1 { 74 | pos.insert(*subscripts.inputs[i].position()); 75 | } 76 | m /= 2; 77 | } 78 | // At least two tensors, and not be all 79 | if pos.len() >= 2 && pos.len() < n { 80 | Some(pos) 81 | } else { 82 | None 83 | } 84 | }) 85 | .map(|pos| { 86 | let mut names = names.clone(); 87 | let (inner, outer) = subscripts.factorize(&mut names, pos)?; 88 | let mut sub = brute_force_work(&mut names, outer)?; 89 | sub.insert(0, inner); 90 | Ok(sub) 91 | }) 92 | .collect::>>()?; 93 | subpaths.push(vec![subscripts]); 94 | Ok(subpaths 95 | .into_iter() 96 | .min_by_key(|path| (compute_order(path), memory_order(path))) 97 | .expect("subpath never be empty")) 98 | } 99 | 100 | #[cfg(test)] 101 | mod test { 102 | use super::*; 103 | 104 | #[test] 105 | fn brute_force_ab_bc() -> Result<()> { 106 | let path = Path::brute_force("ab,bc->ac")?; 107 | assert_eq!(path.len(), 1); 108 | assert_eq!(path[0].to_string(), "ab,bc->ac | arg0,arg1->out0"); 109 | Ok(()) 110 | } 111 | 112 | #[test] 113 | fn brute_force_ab_bc_cd_d() -> Result<()> { 114 | let path = Path::brute_force("ab,bc,cd,d->a")?; 115 | assert_eq!(path.len(), 3); 116 | assert_eq!(path[0].to_string(), "ab,b->a | arg2,arg3->out1"); 117 | assert_eq!(path[1].to_string(), "a,ba->b | out1,arg1->out2"); 118 | assert_eq!(path[2].to_string(), "a,ba->b | out2,arg0->out0"); 119 | Ok(()) 120 | } 121 | 122 | #[test] 123 | fn brute_force_a_a_a() -> Result<()> { 124 | let path = Path::brute_force("a,a,a->")?; 125 | 
assert_eq!(path.len(), 1); 126 | assert_eq!(path[0].to_string(), "a,a,a-> | arg0,arg1,arg2->out0"); 127 | Ok(()) 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /einsum-codegen/src/subscripts.rs: -------------------------------------------------------------------------------- 1 | //! Einsum subscripts, e.g. `ij,jk->ik` 2 | use crate::{parser::*, *}; 3 | use anyhow::Result; 4 | use proc_macro2::TokenStream; 5 | use quote::{format_ident, quote, ToTokens, TokenStreamExt}; 6 | use std::{ 7 | collections::{BTreeMap, BTreeSet}, 8 | fmt, 9 | str::FromStr, 10 | }; 11 | 12 | #[derive(Debug, Clone, PartialEq, Eq, Hash)] 13 | pub struct Subscript { 14 | raw: RawSubscript, 15 | position: Position, 16 | } 17 | 18 | impl Subscript { 19 | pub fn raw(&self) -> &RawSubscript { 20 | &self.raw 21 | } 22 | 23 | pub fn position(&self) -> &Position { 24 | &self.position 25 | } 26 | 27 | pub fn indices(&self) -> Vec { 28 | match &self.raw { 29 | RawSubscript::Indices(indices) => indices.clone(), 30 | RawSubscript::Ellipsis { start, end } => { 31 | start.iter().chain(end.iter()).cloned().collect() 32 | } 33 | } 34 | } 35 | } 36 | 37 | impl ToTokens for Subscript { 38 | fn to_tokens(&self, tokens: &mut TokenStream) { 39 | ToTokens::to_tokens(&self.position, tokens) 40 | } 41 | } 42 | 43 | #[cfg_attr(doc, katexit::katexit)] 44 | /// Einsum subscripts with tensor names, e.g. `ab,bc->ac | arg0,arg1->out0` 45 | /// 46 | /// Indices are remapped as starting from `a` to distinguish same subscripts, e.g. 
`i,i->` and `j,j->` 47 | /// 48 | /// ``` 49 | /// use einsum_codegen::{*, parser::RawSubscript}; 50 | /// 51 | /// let mut names = Namespace::init(); 52 | /// let mut ss1 = Subscripts::from_raw_indices(&mut names, "ij,jk,kl->il").unwrap(); 53 | /// 54 | /// let mut names = Namespace::init(); 55 | /// let mut ss2 = Subscripts::from_raw_indices(&mut names, "xz,zy,yw->xw").unwrap(); 56 | /// 57 | /// assert_eq!(ss1, ss2); 58 | /// assert_eq!(ss1.to_string(), "ab,bc,cd->ad | arg0,arg1,arg2->out0"); 59 | /// assert_eq!(ss2.to_string(), "ab,bc,cd->ad | arg0,arg1,arg2->out0"); 60 | /// ``` 61 | #[derive(Clone, PartialEq, Eq)] 62 | pub struct Subscripts { 63 | /// Input subscripts, e.g. `ij` and `jk` of `ij,jk->ik` 64 | pub inputs: Vec, 65 | /// Output subscript; in implicit mode it is inferred from the inputs 66 | pub output: Subscript, 67 | } 68 | 69 | // Rendered as `ij,jk->ik | arg0,arg1->out0`: the raw subscripts, then the tensor names 70 | impl fmt::Debug for Subscripts { 71 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 72 | for (n, input) in self.inputs.iter().enumerate() { 73 | write!(f, "{}", input.raw)?; 74 | if n < self.inputs.len() - 1 { // separator between inputs, none after the last one 75 | write!(f, ",")?; 76 | } 77 | } 78 | write!(f, "->{} | ", self.output.raw)?; 79 | 80 | for (n, input) in self.inputs.iter().enumerate() { 81 | write!(f, "{}", input.position)?; 82 | if n < self.inputs.len() - 1 { 83 | write!(f, ",")?; 84 | } 85 | } 86 | write!(f, "->{}", self.output.position)?; 87 | Ok(()) 88 | } 89 | } 90 | 91 | impl fmt::Display for Subscripts { 92 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 93 | fmt::Debug::fmt(self, f) 94 | } 95 | } 96 | // Emits `let outN = fn_name(arg0, arg1, ...);` where `fn_name` is derived from `escaped_ident` 97 | impl ToTokens for Subscripts { 98 | fn to_tokens(&self, tokens: &mut TokenStream) { 99 | let fn_name = format_ident!("{}", self.escaped_ident()); 100 | let args = &self.inputs; 101 | let out = &self.output; 102 | tokens.append_all(quote! 
{ 103 | let #out = #fn_name(#(#args),*); 104 | }); 105 | } 106 | } 107 | 108 | impl Subscripts { 109 | /// Returns $\alpha$ if these subscripts require $O(N^\alpha)$ floating point operations, i.e. the number of output indices plus the number of contraction indices 110 | pub fn compute_order(&self) -> usize { 111 | self.memory_order() + self.contraction_indices().len() 112 | } 113 | 114 | /// Returns $\beta$ if these subscripts require $O(N^\beta)$ memory, i.e. the number of output indices 115 | pub fn memory_order(&self) -> usize { 116 | self.output.indices().len() 117 | } 118 | 119 | /// Normalize subscripts into "explicit mode" 120 | /// 121 | /// [numpy.einsum](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html) 122 | /// has "explicit mode" including `->`, e.g. `ij,jk->ik` and 123 | /// "implicit mode" e.g. `ij,jk`. 124 | /// The output subscript is determined from input subscripts in implicit mode: 125 | /// 126 | /// > In implicit mode, the chosen subscripts are important since the axes 127 | /// > of the output are reordered alphabetically. 128 | /// > This means that `np.einsum('ij', a)` doesn’t affect a 2D array, 129 | /// > while `np.einsum('ji', a)` takes its transpose. 130 | /// > Additionally, `np.einsum('ij,jk', a, b)` returns a matrix multiplication, 131 | /// > while, `np.einsum('ij,jh', a, b)` returns the transpose of 132 | /// > the multiplication since subscript ‘h’ precedes subscript ‘i’. 
133 | /// 134 | /// ``` 135 | /// use std::str::FromStr; 136 | /// use einsum_codegen::{*, parser::*}; 137 | /// 138 | /// // Infer output subscripts for implicit mode 139 | /// let mut names = Namespace::init(); 140 | /// let raw = RawSubscripts::from_str("ab,bc").unwrap(); 141 | /// let subscripts = Subscripts::from_raw(&mut names, raw); 142 | /// assert_eq!(subscripts.to_string(), "ab,bc->ac | arg0,arg1->out0"); 143 | /// 144 | /// // Reordered alphabetically 145 | /// let mut names = Namespace::init(); // reset namespace 146 | /// let raw = RawSubscripts::from_str("ba").unwrap(); 147 | /// let subscripts = Subscripts::from_raw(&mut names, raw); 148 | /// assert_eq!(subscripts.to_string(), "ab->ba | arg0->out0"); 149 | /// ``` 150 | /// 151 | pub fn from_raw(names: &mut Namespace, raw: RawSubscripts) -> Self { 152 | let inputs = raw 153 | .inputs 154 | .iter() 155 | .enumerate() 156 | .map(|(i, indices)| Subscript { 157 | raw: indices.clone(), 158 | position: Position::Arg(i), 159 | }) 160 | .collect(); 161 | let position = names.new_ident(); // the einsum result takes the next identifier (`out0` for a fresh namespace) 162 | if let Some(output) = raw.output { 163 | let mut cand = Subscripts { 164 | inputs, 165 | output: Subscript { 166 | raw: output, 167 | position, 168 | }, 169 | }; 170 | cand.remap_indices(); 171 | return cand; 172 | } 173 | // Implicit mode: the output is every index occurring exactly once, in alphabetical (BTreeMap iteration) order 174 | let count = count_indices(&inputs); 175 | let output = Subscript { 176 | raw: RawSubscript::Indices( 177 | count 178 | .iter() 179 | .filter_map(|(key, value)| if *value == 1 { Some(*key) } else { None }) 180 | .collect(), 181 | ), 182 | position, 183 | }; 184 | let mut cand = Subscripts { inputs, output }; 185 | cand.remap_indices(); 186 | cand 187 | } 188 | /// Parse `indices` such as `ij,jk->ik` and build normalized subscripts; returns an error if the string cannot be parsed 189 | pub fn from_raw_indices(names: &mut Namespace, indices: &str) -> Result { 190 | let raw = RawSubscripts::from_str(indices)?; 191 | Ok(Self::from_raw(names, raw)) 192 | } 193 | 194 | /// Indices to be contracted: those occurring more than once across all inputs and absent from the output 195 | /// 196 | /// ``` 197 | /// use std::str::FromStr; 198 | /// use maplit::btreeset; 199 | /// use einsum_codegen::*; 
200 | /// 201 | /// let mut names = Namespace::init(); 202 | /// 203 | /// // Matrix multiplication AB 204 | /// let subscripts = Subscripts::from_raw_indices(&mut names, "ab,bc->ac").unwrap(); 205 | /// assert_eq!(subscripts.contraction_indices(), btreeset!{'b'}); 206 | /// 207 | /// // Reduce all Tr(AB) 208 | /// let subscripts = Subscripts::from_raw_indices(&mut names, "ab,ba->").unwrap(); 209 | /// assert_eq!(subscripts.contraction_indices(), btreeset!{'a', 'b'}); 210 | /// 211 | /// // Take diagonal elements 212 | /// let subscripts = Subscripts::from_raw_indices(&mut names, "aa->a").unwrap(); 213 | /// assert_eq!(subscripts.contraction_indices(), btreeset!{}); 214 | /// ``` 215 | pub fn contraction_indices(&self) -> BTreeSet { 216 | let count = count_indices(&self.inputs); 217 | let mut subscripts: BTreeSet = count 218 | .into_iter() 219 | .filter_map(|(key, value)| if value > 1 { Some(key) } else { None }) // an index summed over but occurring only once (e.g. `b` in `ab->a`) is not included 220 | .collect(); 221 | for c in &self.output.indices() { 222 | subscripts.remove(c); 223 | } 224 | subscripts 225 | } 226 | 227 | /// Factorize subscripts by evaluating the tensors in `inners` first 228 | /// 229 | /// ```text 230 | /// ab,bc,cd->ad | arg0,arg1,arg2->out0 231 | /// ``` 232 | /// 233 | /// will be factorized with `(arg0, arg1)` into 234 | /// 235 | /// ```text 236 | /// ab,bc->ac | arg0,arg1->out1 237 | /// ab,bc->ac | out1,arg2->out0 238 | /// ``` 239 | /// 240 | /// Note that the indices of `out1`, `ac` in the first step, are renamed 241 | /// into `ab` in the second step. 
242 | /// 243 | /// ``` 244 | /// use einsum_codegen::{*, parser::RawSubscript}; 245 | /// use std::str::FromStr; 246 | /// use maplit::btreeset; 247 | /// 248 | /// let mut names = Namespace::init(); 249 | /// let base = Subscripts::from_raw_indices(&mut names, "ab,bc,cd->ad").unwrap(); 250 | /// 251 | /// let (step1, step2) = base.factorize(&mut names, 252 | /// btreeset!{ Position::Arg(0), Position::Arg(1) } 253 | /// ).unwrap(); 254 | /// 255 | /// assert_eq!(step1.to_string(), "ab,bc->ac | arg0,arg1->out1"); 256 | /// assert_eq!(step2.to_string(), "ab,bc->ac | out1,arg2->out0"); 257 | /// ``` 258 | pub fn factorize( 259 | &self, 260 | names: &mut Namespace, 261 | inners: BTreeSet, 262 | ) -> Result<(Self, Self)> { 263 | let mut inner_inputs = Vec::new(); 264 | let mut outer_inputs = Vec::new(); 265 | let mut indices: BTreeMap = BTreeMap::new(); // index -> (occurrences in `inners`, occurrences in the other inputs) 266 | for input in &self.inputs { 267 | if inners.contains(&input.position) { 268 | inner_inputs.push(input.clone()); 269 | for c in input.indices() { 270 | indices 271 | .entry(c) 272 | .and_modify(|(i, _)| *i += 1) 273 | .or_insert((1, 0)); 274 | } 275 | } else { 276 | outer_inputs.push(input.clone()); 277 | for c in input.indices() { 278 | indices 279 | .entry(c) 280 | .and_modify(|(_, o)| *o += 1) 281 | .or_insert((0, 1)); 282 | } 283 | } 284 | } 285 | let out = Subscript { // intermediate tensor produced by contracting the `inners` tensors 286 | raw: RawSubscript::Indices( 287 | indices 288 | .into_iter() 289 | .filter_map(|(key, (i, o))| { // kept: indices used once in `inners`, and shared ones still needed by another input or by the final output (without the latter, `ab,ab,bc->ac` would sum `a` away in the first step) 290 | if i == 1 || (i >= 2 && (o > 0 || self.output.indices().contains(&key))) { 291 | Some(key) 292 | } else { 293 | None 294 | } 295 | }) 296 | .collect(), 297 | ), 298 | position: names.new_ident(), 299 | }; 300 | outer_inputs.insert(0, out.clone()); 301 | 302 | let mut inner = Subscripts { 303 | inputs: inner_inputs, 304 | output: out, 305 | }; 306 | let mut outer = Subscripts { 307 | inputs: outer_inputs, 308 | output: self.output.clone(), 309 | }; 310 | inner.remap_indices(); 311 | outer.remap_indices(); 312 | Ok((inner, outer)) 313 | } 314 | 315 | /// Escaped subscript for 
identifier 316 | /// 317 | /// This is not injective, e.g. `i...,j->ij` and `i,...j->ij` 318 | /// both produce the same result `i____j__ij`. 319 | /// 320 | pub fn escaped_ident(&self) -> String { // each input is followed by `_`, the output is prefixed by one more `_`, and `...` is written as `___` 321 | use std::fmt::Write; 322 | let mut out = String::new(); 323 | for input in &self.inputs { 324 | write!(out, "{}", input.raw).unwrap(); 325 | write!(out, "_").unwrap(); 326 | } 327 | write!(out, "_{}", self.output.raw).unwrap(); 328 | out 329 | } 330 | /// Rename indices to `a`, `b`, `c`, ... in order of first appearance (inputs first, then the output) so equivalent subscripts compare equal 331 | fn remap_indices(&mut self) { 332 | let mut map: BTreeMap = BTreeMap::new(); 333 | let mut update = |raw: &mut RawSubscript| match raw { 334 | RawSubscript::Indices(indices) => { 335 | for i in indices { 336 | if !map.contains_key(i) { 337 | map.insert(*i, 'a' as u32 + map.len() as u32); 338 | } 339 | *i = char::from_u32(map[i]).unwrap(); 340 | } 341 | } 342 | RawSubscript::Ellipsis { start, end } => { 343 | for i in start.iter_mut().chain(end.iter_mut()) { 344 | if !map.contains_key(i) { 345 | map.insert(*i, 'a' as u32 + map.len() as u32); 346 | } 347 | *i = char::from_u32(map[i]).unwrap(); 348 | } 349 | } 350 | }; 351 | for input in &mut self.inputs { 352 | update(&mut input.raw); 353 | } 354 | update(&mut self.output.raw) 355 | } 356 | } 357 | /// Count how many times each index occurs in `inputs`; an index repeated within one input is counted every time 358 | fn count_indices(inputs: &[Subscript]) -> BTreeMap { 359 | let mut count = BTreeMap::new(); 360 | for input in inputs { 361 | for c in input.indices() { 362 | count.entry(c).and_modify(|n| *n += 1).or_insert(1); 363 | } 364 | } 365 | count 366 | } 367 | 368 | #[cfg(test)] 369 | mod tests { 370 | use super::*; 371 | 372 | #[test] 373 | fn escaped_ident() { 374 | let mut names = Namespace::init(); 375 | 376 | let subscripts = Subscripts::from_raw_indices(&mut names, "ab,bc->ac").unwrap(); 377 | assert_eq!(subscripts.escaped_ident(), "ab_bc__ac"); 378 | 379 | // implicit mode 380 | let subscripts = Subscripts::from_raw_indices(&mut names, "ab,bc").unwrap(); 381 | assert_eq!(subscripts.escaped_ident(), "ab_bc__ac"); 382 | 383 | // output scalar 384 | let subscripts = 
Subscripts::from_raw_indices(&mut names, "a,a").unwrap(); 385 | assert_eq!(subscripts.escaped_ident(), "a_a__"); 386 | 387 | // ellipsis 388 | let subscripts = Subscripts::from_raw_indices(&mut names, "ab...,bc...->ac...").unwrap(); 389 | assert_eq!(subscripts.escaped_ident(), "ab____bc_____ac___"); 390 | } 391 | } 392 | -------------------------------------------------------------------------------- /einsum-derive/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "einsum-derive" 3 | version = "0.1.0" 4 | edition = "2021" 5 | authors = ["Toshiki Teramura "] 6 | 7 | description = "Proc-macro based einsum implementation" 8 | documentation = "https://docs.rs/einsum-derive/" 9 | repository = "https://github.com/termoshtt/einsum-derive" 10 | keywords = ["ndarray", "matrix", "einsum", "proc-macro"] 11 | license = "MIT OR Apache-2.0" 12 | readme = "../README.md" 13 | categories = ["algorithms", "science"] 14 | 15 | [lib] 16 | proc-macro = true 17 | 18 | [dependencies] 19 | proc-macro-error = "1.0.4" 20 | proc-macro2 = "1.0.46" 21 | quote = "1.0.21" 22 | syn = "1.0.102" 23 | 24 | [dev-dependencies] 25 | codspeed-criterion-compat = "2.4.0" 26 | criterion = { version = "0.4.0", features = ["html_reports"] } 27 | insta = "1.21.0" 28 | ndarray = "0.15.6" 29 | ndarray-linalg = "0.16.0" 30 | trybuild = "1.0.71" 31 | 32 | [dependencies.einsum-codegen] 33 | path = "../einsum-codegen" 34 | version = "0.1.0" 35 | 36 | [[bench]] 37 | name = "einsum" 38 | harness = false 39 | -------------------------------------------------------------------------------- /einsum-derive/README.md: -------------------------------------------------------------------------------- 1 | einsum-derive 2 | =============== 3 | Proc-macro based einsum implementation for [ndarray](https://crates.io/crates/ndarray) crate 4 | 5 | ```rust 6 | use ndarray::array; 7 | use einsum_derive::einsum; 8 | 9 | let a = array![ 10 | [1.0, 2.0], 11 | [3.0, 4.0] 
12 | ]; 13 | let b = array![ 14 | [1.0, 2.0], 15 | [3.0, 4.0] 16 | ]; 17 | let c = einsum!("ij,jk->ik", a, b); 18 | assert_eq!(c, array![ 19 | [7.0, 10.0], 20 | [15.0, 22.0] 21 | ]); 22 | ``` 23 | 24 | This proc-macro will compile the input subscripts `"ij,jk->ik"` 25 | to generate Rust code executing the corresponding operation. 26 | 27 | Status / Roadmap 28 | ----------------- 29 | - [x] [Optimal contraction by memoizing partial summation to reduce computation order.](https://github.com/termoshtt/einsum-derive/pull/18) 30 | - For example, the product of three matrices `ij,jk,kl->il` is factorized into 31 | two successive einsums `ij,jk->ik` and `ik,kl->il`. 32 | - [ ] [Call BLAS routines if possible](https://github.com/termoshtt/einsum-derive/issues/22) 33 | - [ ] [Ellipsis `...` support](https://github.com/termoshtt/einsum-derive/issues/7) 34 | 35 | Architecture 36 | ------------- 37 | | | crates.io | docs.rs | GitHub Pages | Description | 38 | |:---------------|:---------:|:-------:|:------------:|:------------| 39 | | einsum-derive | [![crate](https://img.shields.io/crates/v/einsum-derive.svg)](https://crates.io/crates/einsum-derive) | [![docs.rs](https://docs.rs/einsum-derive/badge.svg)](https://docs.rs/einsum-derive) | [![Pages](https://img.shields.io/badge/docs-main-blue)](https://termoshtt.github.io/einsum-derive/doc/einsum_derive/index.html) | proc-macro crate providing the `einsum!` macro | 40 | | einsum-codegen | [![crate](https://img.shields.io/crates/v/einsum-codegen.svg)](https://crates.io/crates/einsum-codegen) | [![docs.rs](https://docs.rs/einsum-codegen/badge.svg)](https://docs.rs/einsum-codegen) | [![Pages](https://img.shields.io/badge/docs-main-blue)](https://termoshtt.github.io/einsum-codegen/doc/einsum_codegen/index.html) | Implements the parser for einsum subscripts and generates Rust code | 41 | 42 | Benchmark 43 | ---------- 44 | 
[![bench](https://img.shields.io/badge/benchmark-main-yellow)](https://termoshtt.github.io/einsum-derive/bench/report/index.html) 45 | 46 | Benchmarks with [criterion.rs](https://github.com/bheisler/criterion.rs) run on GitHub Actions for every commit on the main branch. 47 | The code is located at [einsum-derive/benches/einsum.rs](./einsum-derive/benches/einsum.rs), and you can run it in your environment with 48 | 49 | ```shell 50 | cargo bench 51 | ``` 52 | 53 | and you will find the results in `target/criterion/report/index.html`. 54 | 55 | License 56 | -------- 57 | 58 | © 2022 Toshiki Teramura (@termoshtt) 59 | 60 | This project is licensed under either of 61 | 62 | - Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or https://www.apache.org/licenses/LICENSE-2.0) 63 | - MIT license ([LICENSE-MIT](LICENSE-MIT) or https://opensource.org/licenses/MIT) 64 | 65 | at your option. 66 | 67 | Links 68 | ------ 69 | - [numpy.einsum](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html) is a well-known einsum implementation in Python. 70 | - [opt_einsum](https://optimized-einsum.readthedocs.io/en/stable/) is an implementation for optimizing einsum computation for NumPy and other linear algebra packages. 
71 | - [oracleofnj/einsum](https://github.com/oracleofnj/einsum) is a runtime-based implementation of einsum for rust-ndarray 72 | -------------------------------------------------------------------------------- /einsum-derive/benches/einsum.rs: -------------------------------------------------------------------------------- 1 | use codspeed_criterion_compat::*; 2 | use einsum_derive::einsum; 3 | use ndarray::*; 4 | use ndarray_linalg::*; 5 | /// Benchmarks `einsum!` on the `ij,jk` and `ij,jk,kl` subscripts (no explicit `->` output) with random n x n inputs for n in {4, 8, 16, 32, 64, 128}; the inputs are cloned inside `bench.iter`, so the clone cost is part of the measured time. 6 | fn einsum_bench(c: &mut Criterion) { 7 | let mut group = c.benchmark_group("einsum"); 8 | group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); 9 | 10 | for &n in &[4, 8, 16, 32, 64, 128] { 11 | group.bench_with_input(BenchmarkId::new("ij_jk", n), &n, |bench, n| { 12 | let a: Array2 = random((*n, *n)); 13 | let b: Array2 = random((*n, *n)); 14 | bench.iter(|| { 15 | let _c = einsum!("ij,jk", a.clone(), b.clone()); 16 | }) 17 | }); 18 | 19 | group.bench_with_input(BenchmarkId::new("ij_jk_kl", n), &n, |bench, n| { 20 | let a: Array2 = random((*n, *n)); 21 | let b: Array2 = random((*n, *n)); 22 | let c: Array2 = random((*n, *n)); 23 | bench.iter(|| { 24 | let _c = einsum!("ij,jk,kl", a.clone(), b.clone(), c.clone()); 25 | }) 26 | }); 27 | } 28 | } 29 | 30 | criterion_group!(einsum, einsum_bench); 31 | criterion_main!(einsum); 32 | -------------------------------------------------------------------------------- /einsum-derive/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![doc = include_str!("../README.md")] 2 | 3 | use einsum_codegen::{codegen::ndarray::*, *}; 4 | use proc_macro::TokenStream; 5 | use proc_macro2::TokenStream as TokenStream2; 6 | use proc_macro_error::{abort_call_site, proc_macro_error}; 7 | use quote::quote; 8 | use std::collections::BTreeSet; 9 | use syn::parse::Parser; 10 | // The proc-macro entry point only converts between `proc_macro` and `proc_macro2` token streams; the expansion lives in `einsum2` so the unit tests can call it directly. 11 | /// proc-macro based einsum 12 | #[proc_macro_error] 13 | #[proc_macro] 14 | pub fn einsum(input: TokenStream) -> TokenStream { 15 |
einsum2(input.into()).into() 16 | } 17 | /// Expands `einsum!` input: splits it with `parse`, builds an execution path with `Path::brute_force` (panicking if none can be built), emits one helper fn per distinct sub-contraction (deduplicated by `escaped_ident`, so e.g. `ab_bc__ac` is defined once even when the path uses it twice), binds the arguments to `arg0`, `arg1`, ... and evaluates the path. Aborts with a call-site error if the argument count differs from `path.num_args()`. 18 | fn einsum2(input: TokenStream2) -> TokenStream2 { 19 | let (subscripts, args) = parse(input); 20 | let arg_ident: Vec<_> = (0..args.len()).map(Position::Arg).collect(); 21 | let path = Path::brute_force(&subscripts).expect("Failed to construct execution path"); 22 | let mut defined = BTreeSet::new(); 23 | let fn_defs: Vec<_> = path 24 | .iter() 25 | .filter_map(|ss| { 26 | if defined.contains(&ss.escaped_ident()) { 27 | None 28 | } else { 29 | defined.insert(ss.escaped_ident()); 30 | let inner = naive::inner(ss); 31 | Some(function_definition(ss, inner)) 32 | } 33 | }) 34 | .collect(); 35 | let out = path.output(); 36 | if path.num_args() != args.len() { 37 | abort_call_site!( 38 | "Argument number mismatch: subscripts ({}), args ({})", 39 | path.num_args(), 40 | args.len() 41 | ) 42 | } 43 | // Generated block: helper fn definitions, `let argN = <expr>;` bindings, the path's statements, then the output expression as the block's value. 44 | quote! { 45 | { 46 | #(#fn_defs)* 47 | #(let #arg_ident = #args;)* 48 | #(#path)* 49 | #out 50 | } 51 | } 52 | } 53 | /// Splits the macro input into the subscripts (the leading string literal) and the remaining comma-separated argument expressions. Panics if the input is not a comma-separated list of expressions or does not start with a string literal. 54 | fn parse(input: TokenStream2) -> (String, Vec) { 55 | let parser = syn::punctuated::Punctuated::::parse_terminated; 56 | let args = parser.parse2(input).expect("Invalid input for einsum!"); 57 | let mut iter = args.into_iter(); 58 | let subscripts = if let Some(syn::Expr::Lit(syn::ExprLit { 59 | lit: syn::Lit::Str(lit), 60 | attrs: _, 61 | })) = iter.next() 62 | { 63 | lit.value() 64 | } else { 65 | panic!("einsum!
must start with subscript string literal") 66 | }; 67 | let args = iter.collect::>(); 68 | (subscripts, args) 69 | } 70 | // NOTE(review): in the snapshots below the innermost `for b` loop assigns (`out0[(a, c)] = ...`) rather than accumulating with `+=`, so only the last `b` term survives instead of the sum over `b`; this looks like a bug in `naive::inner` -- confirm, then regenerate these snapshots. 71 | #[cfg(test)] 72 | mod test { 73 | use super::*; 74 | use einsum_codegen::codegen::format_block; 75 | use std::str::FromStr; 76 | 77 | #[test] 78 | fn test_parse() { 79 | let input = TokenStream2::from_str(r#""ab,bc->ac", x, y"#).unwrap(); 80 | let (subscripts, exprs) = parse(input); 81 | assert_eq!(subscripts, "ab,bc->ac"); 82 | assert_eq!(exprs.len(), 2); 83 | assert_eq!(exprs[0], syn::parse_str::("x").unwrap()); 84 | assert_eq!(exprs[1], syn::parse_str::("y").unwrap()); 85 | } 86 | 87 | #[test] 88 | fn einsum_ab_bc() { 89 | let input = TokenStream2::from_str(r#""ab,bc->ac", x, y"#).unwrap(); 90 | let tt = format_block(einsum2(input).to_string()); 91 | insta::assert_snapshot!(tt, @r###" 92 | { 93 | fn ab_bc__ac( 94 | arg0: ndarray::ArrayBase, 95 | arg1: ndarray::ArrayBase, 96 | ) -> ndarray::Array 97 | where 98 | T: ndarray::LinalgScalar, 99 | S0: ndarray::Data, 100 | S1: ndarray::Data, 101 | { 102 | let (n_a, n_b) = arg0.dim(); 103 | let (_, n_c) = arg1.dim(); 104 | { 105 | let (n_0, n_1) = arg0.dim(); 106 | assert_eq!(n_0, n_a); 107 | assert_eq!(n_1, n_b); 108 | } 109 | { 110 | let (n_0, n_1) = arg1.dim(); 111 | assert_eq!(n_0, n_b); 112 | assert_eq!(n_1, n_c); 113 | } 114 | let mut out0 = ndarray::Array::zeros((n_a, n_c)); 115 | for a in 0..n_a { 116 | for c in 0..n_c { 117 | for b in 0..n_b { 118 | out0[(a, c)] = arg0[(a, b)] * arg1[(b, c)]; 119 | } 120 | } 121 | } 122 | out0 123 | } 124 | let arg0 = x; 125 | let arg1 = y; 126 | let out0 = ab_bc__ac(arg0, arg1); 127 | out0 128 | } 129 | "###); 130 | } 131 | 132 | #[test] 133 | fn einsum_ab_bc_cd() { 134 | let input = TokenStream2::from_str(r#""ab,bc,cd->ad", x, y, z"#).unwrap(); 135 | let tt = format_block(einsum2(input).to_string()); 136 | insta::assert_snapshot!(tt, @r###" 137 | { 138 | fn ab_bc__ac( 139 | arg0: ndarray::ArrayBase, 140 | arg1: ndarray::ArrayBase, 141 | ) -> 
ndarray::Array 142 | where 143 | T: ndarray::LinalgScalar, 144 | S0: ndarray::Data, 145 | S1: ndarray::Data, 146 | { 147 | let (n_a, n_b) = arg0.dim(); 148 | let (_, n_c) = arg1.dim(); 149 | { 150 | let (n_0, n_1) = arg0.dim(); 151 | assert_eq!(n_0, n_a); 152 | assert_eq!(n_1, n_b); 153 | } 154 | { 155 | let (n_0, n_1) = arg1.dim(); 156 | assert_eq!(n_0, n_b); 157 | assert_eq!(n_1, n_c); 158 | } 159 | let mut out1 = ndarray::Array::zeros((n_a, n_c)); 160 | for a in 0..n_a { 161 | for c in 0..n_c { 162 | for b in 0..n_b { 163 | out1[(a, c)] = arg0[(a, b)] * arg1[(b, c)]; 164 | } 165 | } 166 | } 167 | out1 168 | } 169 | let arg0 = x; 170 | let arg1 = y; 171 | let arg2 = z; 172 | let out1 = ab_bc__ac(arg0, arg1); 173 | let out0 = ab_bc__ac(out1, arg2); 174 | out0 175 | } 176 | "###); 177 | } 178 | } 179 | -------------------------------------------------------------------------------- /einsum-derive/tests/cases/number_of_arguments_mismatch.rs: -------------------------------------------------------------------------------- 1 | use einsum_derive::einsum; 2 | use ndarray::array; 3 | 4 | fn main() { 5 | let a = array![[1.0, 2.0], [3.0, 4.0]]; 6 | let c = einsum!("ij,jk->ik", a /* needs one more arg */); 7 | } 8 | -------------------------------------------------------------------------------- /einsum-derive/tests/cases/number_of_arguments_mismatch.stderr: -------------------------------------------------------------------------------- 1 | error: Argument number mismatch: subscripts (2), args (1) 2 | --> tests/cases/number_of_arguments_mismatch.rs:6:13 3 | | 4 | 6 | let c = einsum!("ij,jk->ik", a /* needs one more arg */); 5 | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 6 | | 7 | = note: this error originates in the macro `einsum` (in Nightly builds, run with -Z macro-backtrace for more info) 8 | 9 | error: expected expression, found end of macro arguments 10 | --> tests/cases/number_of_arguments_mismatch.rs:6:13 11 | | 12 | 6 | let c = einsum!("ij,jk->ik", a /* 
needs one more arg */); 13 | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 14 | -------------------------------------------------------------------------------- /einsum-derive/tests/trybuild.rs: -------------------------------------------------------------------------------- 1 | #[test] 2 | fn trybuild() { 3 | let t = trybuild::TestCases::new(); 4 | t.compile_fail("tests/cases/number_of_arguments_mismatch.rs"); 5 | } 6 | -------------------------------------------------------------------------------- /rust-toolchain: -------------------------------------------------------------------------------- 1 | nightly 2 | --------------------------------------------------------------------------------