├── .codecov.yml ├── .github ├── CONTRIBUTING.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── Docs.yaml │ ├── Linting.yml │ └── Tests.yml ├── .gitignore ├── CITATION.cff ├── LICENSE ├── Makefile ├── README.md ├── devtools ├── RELEASE.md ├── allowlist.txt ├── ci_scripts │ └── check_no_numpy.py ├── conda-envs │ ├── full-environment.yaml │ ├── min-deps-environment.yaml │ ├── min-ver-environment.yaml │ └── torch-only-environment.yaml └── conda-recipe │ └── meta.yaml ├── docs ├── api_reference.md ├── changelog.md ├── css │ └── custom.css ├── examples │ ├── dask_reusing_intermediaries.md │ └── large_expr_with_greedy.md ├── getting_started │ ├── backends.md │ ├── input_format.md │ ├── install.md │ ├── reusing_paths.md │ └── sharing_intermediates.md ├── img │ ├── ex_dask_reuse_graph.png │ ├── path_finding_time.png │ ├── path_found_flops.png │ └── path_found_flops_random.png ├── index.md ├── javascript │ └── config.js ├── paths │ ├── branching_path.md │ ├── custom_paths.md │ ├── dp_path.md │ ├── greedy_path.md │ ├── introduction.md │ ├── optimal_path.md │ └── random_greedy_path.md └── requirements.yml ├── mkdocs.yml ├── opt_einsum ├── __init__.py ├── _version.py ├── backends │ ├── __init__.py │ ├── cupy.py │ ├── dispatch.py │ ├── jax.py │ ├── object_arrays.py │ ├── tensorflow.py │ ├── theano.py │ └── torch.py ├── blas.py ├── contract.py ├── helpers.py ├── parser.py ├── path_random.py ├── paths.py ├── sharing.py ├── testing.py ├── tests │ ├── __init__.py │ ├── test_backends.py │ ├── test_blas.py │ ├── test_contract.py │ ├── test_edge_cases.py │ ├── test_input.py │ ├── test_parser.py │ ├── test_paths.py │ └── test_sharing.py └── typing.py ├── paper ├── paper.bib └── paper.md ├── pyproject.toml └── scripts ├── README.md └── compare_random_paths.py /.codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | ignore: 3 | - test/.* 4 | - compare/.* 5 | - test_helper.py 6 | - setup.py 7 | status: 8 | patch: false 9 | project: 10 | default: 11 | threshold: 80% 12 | comment: 13 | layout: "header" 14 | require_changes: false 15 | branches: null 16 | behavior: default 17 | flags: null 18 | paths: null 19 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | We welcome contributions from external contributors, and this document 4 | describes how to merge code changes into this `opt_einsum`. 5 | 6 | ## Getting Started 7 | 8 | * Make sure you have a [GitHub account](https://github.com/signup/free). 9 | * [Fork](https://help.github.com/articles/fork-a-repo/) this repository on GitHub. 10 | * On your local machine, 11 | [clone](https://help.github.com/articles/cloning-a-repository/) your fork of 12 | the repository. 13 | 14 | ## Making Changes 15 | 16 | * Add some really awesome code to your local fork. It's usually a [good 17 | idea](http://blog.jasonmeridth.com/posts/do-not-issue-pull-requests-from-your-master-branch/) 18 | to make changes on a 19 | [branch](https://help.github.com/articles/creating-and-deleting-branches-within-your-repository/) 20 | with the branch name relating to the feature you are going to add. 21 | * When you are ready for others to examine and comment on your new feature, 22 | navigate to your fork of `opt_einsum` on GitHub and open a [pull 23 | request](https://help.github.com/articles/using-pull-requests/) (PR). 
Note that 24 | after you launch a PR from one of your fork's branches, all 25 | subsequent commits to that branch will be added to the open pull request 26 | automatically. Each commit added to the PR will be validated for 27 | mergability, compilation and test suite compliance; the results of these tests 28 | will be visible on the PR page. 29 | * If you're providing a new feature, you must add test cases and documentation. 30 | * When the code is ready to go, make sure you run the test suite using pytest. 31 | * When you're ready to be considered for merging, check the "Ready to go" 32 | box on the PR page to let the `opt_einsum` devs know that the changes are complete. 33 | The code will not be merged until this box is checked, the continuous 34 | integration returns checkmarks, 35 | and multiple core developers give "Approved" reviews. 36 | 37 | ## Additional Resources 38 | 39 | * [General GitHub documentation](https://help.github.com/) 40 | * [PR best practices](http://codeinthehole.com/writing/pull-requests-and-other-good-practices-for-teams-using-github/) 41 | * [A guide to contributing to software packages](http://www.contribution-guide.org) 42 | * [Thinkful PR example](http://www.thinkful.com/learn/github-pull-request-tutorial/#Time-to-Submit-Your-First-PR) 43 | 44 | ## Code of Conduct 45 | 46 | ### Our Pledge 47 | 48 | In the interest of fostering an open and welcoming environment, we as 49 | contributors and maintainers pledge to making participation in our project and 50 | our community a harassment-free experience for everyone, regardless of age, body 51 | size, disability, ethnicity, gender identity and expression, level of experience, 52 | nationality, personal appearance, race, religion, or sexual identity and 53 | orientation. 54 | 55 | ### Our Standards 56 | 57 | Examples of behavior that contributes to creating a positive environment 58 | include: 59 | 60 | * Using welcoming and inclusive language 61 | * Being respectful of differing viewpoints and experiences 62 | * Gracefully accepting constructive criticism 63 | * Focusing on what is best for the community 64 | * Showing empathy towards other community members 65 | 66 | Examples of unacceptable behavior by participants include: 67 | 68 | * The use of sexualized language or imagery and unwelcome sexual attention or 69 | advances 70 | * Trolling, insulting/derogatory comments, and personal or political attacks 71 | * Public or private harassment 72 | * Publishing others' private information, such as a physical or electronic 73 | address, without explicit permission 74 | * Other conduct which could reasonably be considered inappropriate in a 75 | professional setting 76 | 77 | ### Our Responsibilities 78 | 79 | Project maintainers are responsible for clarifying the standards of acceptable 80 | behavior and are expected to take appropriate and fair corrective action in 81 | response to any instances of unacceptable behavior. 82 | 83 | Project maintainers have the right and responsibility to remove, edit, or 84 | reject comments, commits, code, wiki edits, issues, and other contributions 85 | that are not aligned to this Code of Conduct, or to ban temporarily or 86 | permanently any contributor for other behaviors that they deem inappropriate, 87 | threatening, offensive, or harmful. 88 | 89 | ### Scope 90 | 91 | This Code of Conduct applies both within project spaces and in public spaces 92 | when an individual is representing the project or its community. 
Examples of 93 | representing a project or community include using an official project e-mail 94 | address, posting via an official social media account, or acting as an appointed 95 | representative at an online or offline event. Representation of a project may be 96 | further defined and clarified by project maintainers. 97 | 98 | ### Enforcement 99 | 100 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 101 | reported by contacting the project team at [dgasmith@vt.edu]. All 102 | complaints will be reviewed and investigated and will result in a response that 103 | is deemed necessary and appropriate to the circumstances. The project team is 104 | obligated to maintain confidentiality with regard to the reporter of an incident. 105 | Further details of specific enforcement policies may be posted separately. 106 | 107 | Project maintainers who do not follow or enforce the Code of Conduct in good 108 | faith may face temporary or permanent repercussions as determined by other 109 | members of the project's leadership. 110 | 111 | ### Attribution 112 | 113 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 114 | available at [http://contributor-covenant.org/version/1/4][version] 115 | 116 | [homepage]: http://contributor-covenant.org 117 | [version]: http://contributor-covenant.org/version/1/4/ 118 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | Provide a brief description of the PR's purpose here. 3 | 4 | ## Todos 5 | Notable points that this PR has either accomplished or will accomplish. 6 | - [ ] TODO 1 7 | 8 | ## Questions 9 | - [ ] Question1 10 | 11 | ## Status 12 | - [ ] Ready to go -------------------------------------------------------------------------------- /.github/workflows/Docs.yaml: -------------------------------------------------------------------------------- 1 | name: Docs 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | branches: 8 | - main 9 | 10 | jobs: 11 | deploy: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v3 15 | 16 | - uses: actions/setup-python@v5 17 | with: 18 | python-version: 3.x 19 | 20 | - name: Install 21 | run: | 22 | pip install -r docs/requirements.yml 23 | pip install -e . 
24 | 25 | - name: Build Docs 26 | run: mkdocs build 27 | 28 | # Only deploy if main, otherwise just build for testing 29 | - name: Deploy 30 | if: endsWith(github.ref, '/main') 31 | run: mkdocs gh-deploy --force 32 | -------------------------------------------------------------------------------- /.github/workflows/Linting.yml: -------------------------------------------------------------------------------- 1 | name: Linting 2 | 3 | on: 4 | push: 5 | branches: main 6 | pull_request: 7 | branches: main 8 | 9 | jobs: 10 | mypy: 11 | name: MyPy 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | python-version: [3.12] 16 | environment: ["min-deps"] 17 | 18 | runs-on: ubuntu-latest 19 | defaults: 20 | run: 21 | shell: bash -el {0} 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | 26 | - uses: conda-incubator/setup-miniconda@v3 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | channel-priority: true 30 | activate-environment: test 31 | environment-file: devtools/conda-envs/${{ matrix.environment }}-environment.yaml 32 | 33 | - name: Environment Information 34 | run: | 35 | conda info 36 | conda list 37 | conda config --show-sources 38 | conda config --show 39 | 40 | - name: Install 41 | run: python -m pip install . --no-deps 42 | 43 | - name: MyPy 44 | run: mypy opt_einsum 45 | 46 | ruff: 47 | name: Ruff 48 | strategy: 49 | fail-fast: false 50 | matrix: 51 | python-version: [3.12] 52 | environment: ["min-deps"] 53 | 54 | runs-on: ubuntu-latest 55 | defaults: 56 | run: 57 | shell: bash -el {0} 58 | 59 | steps: 60 | - uses: actions/checkout@v2 61 | 62 | - uses: conda-incubator/setup-miniconda@v3 63 | with: 64 | python-version: ${{ matrix.python-version }} 65 | channel-priority: true 66 | activate-environment: test 67 | environment-file: devtools/conda-envs/${{ matrix.environment }}-environment.yaml 68 | 69 | - name: Environment Information 70 | run: | 71 | conda info 72 | conda list 73 | conda config --show-sources 74 | conda config --show 75 | 76 | - name: Lint 77 | run: | 78 | set -e 79 | make fmt-check 80 | -------------------------------------------------------------------------------- /.github/workflows/Tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: main 6 | pull_request: 7 | branches: main 8 | 9 | jobs: 10 | miniconda-setup: 11 | name: Env 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | include: 16 | - python-version: 3.8 17 | environment: "min-deps" 18 | - python-version: 3.12 19 | environment: "min-deps" 20 | - python-version: 3.9 21 | environment: "min-ver" 22 | - python-version: 3.11 23 | environment: "full" 24 | - python-version: 3.12 25 | environment: "torch-only" 26 | 27 | runs-on: ubuntu-latest 28 | defaults: 29 | run: 30 | shell: bash -el {0} 31 | 32 | steps: 33 | - uses: actions/checkout@v4 34 | 35 | - uses: conda-incubator/setup-miniconda@v3 36 | with: 37 | python-version: ${{ matrix.python-version }} 38 | channel-priority: true 39 | activate-environment: test 40 | environment-file: devtools/conda-envs/${{ matrix.environment }}-environment.yaml 41 | 42 | - name: Environment Information 43 | run: | 44 | conda info 45 | conda list 46 | conda config --show-sources 47 | conda config --show 48 | 49 | - name: Check no NumPy for torch-only environment 50 | if: matrix.environment == 'torch-only' 51 | run: | 52 | python devtools/ci_scripts/check_no_numpy.py 53 | 54 | - name: Install 55 | run: | 56 | python -m pip install . 
--no-deps 57 | 58 | - name: Test 59 | run: | 60 | pytest -v --cov=opt_einsum opt_einsum/ --cov-report=xml 61 | 62 | - name: Coverage 63 | run: | 64 | coverage report 65 | 66 | - uses: codecov/codecov-action@v4 67 | with: 68 | token: ${{ secrets.CODECOV_TOKEN }} 69 | files: ./coverage.xml 70 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Hatch replaces this file 2 | opt_einsum/_version.py 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Vim scratch files 12 | .swp 13 | 14 | # Distribution / packaging 15 | .Python 16 | env/ 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | *.swp 53 | 54 | # Django stuff: 55 | *.log 56 | 57 | # PyCharm junk 58 | .idea 59 | 60 | # Sphinx documentation 61 | docs/_build/ 62 | 63 | # PyBuilder 64 | target/ 65 | 66 | \.pytest_cache/ 67 | 68 | docs/source/autosummary/ 69 | site/ 70 | 71 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.1.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: Smith 5 | given-names: Daniel 6 | orcid: https://orcid.org/0000-0001-8626-0900 7 | - family-names: Gray 8 | given-names: Johnnie 9 | orcid: https://orcid.org/0000-0001-9461-3024 10 | title: "`opt_einsum` - A Python package for optimizing contraction order for einsum-like expressions" 11 | version: 3.3.0 12 | doi: 10.21105/joss.00753 13 | date-released: 2019-06-28 14 | url: "https://github.com/dgasmith/opt_einsum" 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 Daniel Smith 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .DEFAULT_GOAL := all 2 | 3 | .PHONY: install 4 | install: 5 | pip install -e . 6 | 7 | .PHONY: fmt 8 | fmt: 9 | ruff check opt_einsum --fix 10 | ruff format opt_einsum 11 | 12 | .PHONY: fmt-unsafe 13 | fmt-unsafe: 14 | ruff check opt_einsum --fix --unsafe-fixes 15 | ruff format opt_einsum 16 | 17 | .PHONY: fmt-check 18 | fmt-check: 19 | ruff check opt_einsum 20 | ruff format --check opt_einsum 21 | 22 | .PHONY: mypy 23 | mypy: 24 | mypy opt_einsum 25 | 26 | .PHONY: test 27 | test: 28 | pytest -v --cov=opt_einsum/ 29 | 30 | .PHONY: docs 31 | docs: 32 | mkdocs build 33 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Optimized Einsum 2 | 3 | [![Tests](https://github.com/dgasmith/opt_einsum/actions/workflows/Tests.yml/badge.svg)](https://github.com/dgasmith/opt_einsum/actions/workflows/Tests.yml) 4 | [![codecov](https://codecov.io/gh/dgasmith/opt_einsum/branch/master/graph/badge.svg)](https://codecov.io/gh/dgasmith/opt_einsum) 5 | [![Anaconda-Server Badge](https://anaconda.org/conda-forge/opt_einsum/badges/version.svg)](https://anaconda.org/conda-forge/opt_einsum) 6 | [![PyPI](https://img.shields.io/pypi/v/opt_einsum.svg)](https://pypi.org/project/opt-einsum/#description) 7 | [![PyPIStats](https://img.shields.io/pypi/dm/opt_einsum)](https://pypistats.org/packages/opt-einsum) 8 | [![Documentation Status](https://github.com/dgasmith/opt_einsum/actions/workflows/Docs.yaml/badge.svg)](https://dgasmith.github.io/opt_einsum/) 9 | [![DOI](https://joss.theoj.org/papers/10.21105/joss.00753/status.svg)](https://doi.org/10.21105/joss.00753) 10 | 11 | ## Optimized Einsum: A tensor contraction order optimizer 12 | 13 | Optimized einsum can significantly reduce the overall execution time of einsum-like expressions (e.g., 14 | [`np.einsum`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.einsum.html), 15 | [`dask.array.einsum`](https://docs.dask.org/en/latest/array-api.html#dask.array.einsum), 16 | [`pytorch.einsum`](https://pytorch.org/docs/stable/torch.html#torch.einsum), 17 | [`tensorflow.einsum`](https://www.tensorflow.org/api_docs/python/tf/einsum), 18 | ) 19 | by optimizing the expression's contraction order and dispatching many 20 | operations to canonical BLAS, cuBLAS, or other specialized routines. 21 | 22 | Optimized 23 | einsum is agnostic to the backend and can handle NumPy, Dask, PyTorch, 24 | Tensorflow, CuPy, Sparse, Theano, JAX, and Autograd arrays as well as potentially 25 | any library which conforms to a standard API. See the 26 | [**documentation**](https://dgasmith.github.io/opt_einsum/) for more 27 | information. 28 | 29 | ## Example usage 30 | 31 | The [`opt_einsum.contract`](https://dgasmith.github.io/opt_einsum/api_reference#opt_einsumcontract) 32 | function can often act as a drop-in replacement for `einsum` 33 | functions without further changes to the code while providing superior performance. 
34 | Here, a tensor contraction is performed with and without optimization: 35 | 36 | ```python 37 | import numpy as np 38 | from opt_einsum import contract 39 | 40 | N = 10 41 | C = np.random.rand(N, N) 42 | I = np.random.rand(N, N, N, N) 43 | 44 | %timeit np.einsum('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C) 45 | 1 loops, best of 3: 934 ms per loop 46 | 47 | %timeit contract('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C) 48 | 1000 loops, best of 3: 324 us per loop 49 | ``` 50 | 51 | In this particular example, we see a ~3000x performance improvement which is 52 | not uncommon when compared against unoptimized contractions. See the [backend 53 | examples](https://dgasmith.github.io/opt_einsum/getting_started/backends) 54 | for more information on using other backends. 55 | 56 | ## Features 57 | 58 | The algorithms found in this repository often power the `einsum` optimizations 59 | in many of the above projects. For example, the optimization of `np.einsum` 60 | has been passed upstream and most of the same features that can be found in 61 | this repository can be enabled with `np.einsum(..., optimize=True)`. However, 62 | this repository often has more up to date algorithms for complex contractions. 63 | 64 | The following capabilities are enabled by `opt_einsum`: 65 | 66 | * Inspect [detailed information](https://dgasmith.github.io/opt_einsum/paths/introduction) about the path chosen. 67 | * Perform contractions with [numerous backends](https://dgasmith.github.io/opt_einsum/getting_started/backends), including on the GPU and with libraries such as [TensorFlow](https://www.tensorflow.org) and [PyTorch](https://pytorch.org). 68 | * Generate [reusable expressions](https://dgasmith.github.io/opt_einsum/getting_started/reusing_paths), potentially with [constant tensors](https://dgasmith.github.io/opt_einsum/getting_started/reusing_paths#specifying-constants), that can be compiled for greater performance. 69 | * Use an arbitrary number of indices to find contractions for [hundreds or even thousands of tensors](https://dgasmith.github.io/opt_einsum/examples/large_expr_with_greedy). 70 | * Share [intermediate computations](https://dgasmith.github.io/opt_einsum/getting_started/sharing_intermediates) among multiple contractions. 71 | * Compute gradients of tensor contractions using [autograd](https://github.com/HIPS/autograd) or [jax](https://github.com/google/jax) 72 | 73 | Please see the [documentation](https://dgasmith.github.io/opt_einsum/index) for more features! 74 | 75 | ## Installation 76 | 77 | `opt_einsum` can either be installed via `pip install opt_einsum` or from conda `conda install opt_einsum -c conda-forge`. 78 | See the installation [documentation](https://dgasmith.github.io/opt_einsum/getting_started/install) for further methods. 79 | 80 | ## Citation 81 | 82 | If this code has benefited your research, please support us by citing: 83 | 84 | Daniel G. A. Smith and Johnnie Gray, opt_einsum - A Python package for optimizing contraction order for einsum-like expressions. *Journal of Open Source Software*, **2018**, 3(26), 753 85 | 86 | DOI: 87 | 88 | ## Contributing 89 | 90 | All contributions, bug reports, bug fixes, documentation improvements, enhancements, and ideas are welcome. 91 | 92 | A detailed overview on how to contribute can be found in the [contributing guide](https://github.com/dgasmith/opt_einsum/blob/master/.github/CONTRIBUTING.md). 
93 | -------------------------------------------------------------------------------- /devtools/RELEASE.md: -------------------------------------------------------------------------------- 1 | # Release Checklist 2 | 3 | ## Lint Static Scan 4 | 5 | Check for flake8 issues and spelling. 6 | 7 | ```shell 8 | pip install flake8-spellcheck 9 | flake8 --whitelist ./devtools/allowlist.txt 10 | ``` 11 | 12 | ## PyPI Source and Wheel 13 | 14 | ```shell 15 | conda update setuptools wheel 16 | 17 | python setup.py sdist bdist_wheel 18 | twine upload --repository-url https://test.pypi.org/legacy/ dist/ 19 | ``` 20 | 21 | ## Update conda-forge 22 | 23 | ```plaintext 24 | - Version 25 | - Zip Hash 26 | ``` 27 | -------------------------------------------------------------------------------- /devtools/allowlist.txt: -------------------------------------------------------------------------------- 1 | 0b100101 2 | 0rc1 3 | 10pt 4 | 11pt 5 | 12pt 6 | 16x16 7 | 32x32 8 | a4paper 9 | abap 10 | allclose 11 | astype 12 | autodoc 13 | autogenerated 14 | autograd 15 | Backends 16 | backendseq 17 | bbfreeze 18 | bdist 19 | blas 20 | borland 21 | bw 22 | C0301 23 | caes 24 | cba 25 | cfg 26 | cmd 27 | cmdclass 28 | cmin 29 | conda 30 | const 31 | consts 32 | crossref 33 | cupy 34 | cx 35 | datestamp 36 | DDOT 37 | Deduplicate 38 | dep 39 | dereference 40 | detailmenu 41 | df 42 | dirname 43 | documentclass 44 | dtype 45 | ein 46 | einsum 47 | eq 48 | execfile 49 | expr 50 | favicon 51 | fi 52 | FILEVERSION 53 | fmt 54 | fn 55 | fs 56 | func 57 | GEMM 58 | GEMV 59 | gh 60 | gitattributes 61 | githubs 62 | hadamard 63 | Hadamard 64 | hardlink 65 | hashable 66 | hashtable 67 | howto 68 | htaccess 69 | htbp 70 | https 71 | ico 72 | idx 73 | iij 74 | ij 75 | ik 76 | inds 77 | isort 78 | ja 79 | ja 80 | jax 81 | jieba 82 | jk 83 | jkk 84 | js 85 | letterpaper 86 | lgtm 87 | lhs 88 | libs 89 | loc 90 | lru 91 | manni 92 | mem 93 | method1 94 | method2 95 | modindex 96 | moduleauthor 97 | monokai 98 | nczeczulin 99 | ndim 100 | nl 101 | no 102 | NUM 103 | numpy 104 | numpytensordot 105 | opensearch 106 | outputless 107 | pagerefs 108 | papersize 109 | paraiso 110 | parentdir 111 | pep440 112 | perldoc 113 | pointsize 114 | prepended 115 | prodcuts 116 | PRODUCTVERSION 117 | py2exe 118 | Pygments 119 | pylint 120 | quickstart 121 | recurse 122 | ret 123 | refnames 124 | rhs 125 | ro 126 | rrt 127 | rsrc 128 | rst 129 | runtime 130 | s1 131 | s2 132 | sdist 133 | sectionauthor 134 | Ses 135 | setrlimit 136 | sig 137 | skipif 138 | sourcedist 139 | sourcelink 140 | sparsify 141 | sphinxstrong 142 | sphinxtitleref 143 | subdependencies 144 | subgraph 145 | subgraphs 146 | subst 147 | sv 148 | TDOT 149 | tempdir 150 | tensordot 151 | Tensordot 152 | test5 153 | texinfo 154 | Texinfo 155 | tf 156 | TF 157 | theano 158 | timeit 159 | titleref 160 | tmp 161 | toctree 162 | toplevel 163 | trac 164 | transpose 165 | uncomparable 166 | undoc 167 | unparsable 168 | v0 169 | VCS 170 | Vectordot 171 | versioneer 172 | versionfile 173 | x3 174 | xcode 175 | xhtml 176 | zh 177 | zh 178 | zipball 179 | -------------------------------------------------------------------------------- /devtools/ci_scripts/check_no_numpy.py: -------------------------------------------------------------------------------- 1 | try: 2 | import numpy 3 | exit(1) 4 | except ModuleNotFoundError: 5 | exit(0) -------------------------------------------------------------------------------- /devtools/conda-envs/full-environment.yaml: 
-------------------------------------------------------------------------------- 1 | name: test 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | # Base depends 6 | - python >=3.8 7 | - numpy >=1.23 8 | - nomkl 9 | 10 | # Backends 11 | - tensorflow-cpu 12 | - dask 13 | - sparse 14 | - pytorch-cpu 15 | - jax 16 | 17 | # Testing 18 | - codecov 19 | - mypy ==1.11* 20 | - pytest 21 | - pytest-cov 22 | - ruff ==0.5.* 23 | -------------------------------------------------------------------------------- /devtools/conda-envs/min-deps-environment.yaml: -------------------------------------------------------------------------------- 1 | name: test 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | # Base depends 6 | - python >=3.8 7 | 8 | # Testing 9 | - codecov 10 | - mypy ==1.11* 11 | - pytest 12 | - pytest-cov 13 | - ruff ==0.6.* 14 | -------------------------------------------------------------------------------- /devtools/conda-envs/min-ver-environment.yaml: -------------------------------------------------------------------------------- 1 | name: test 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | # Base depends 6 | - python >=3.8 7 | - numpy >=1.23 8 | - nomkl 9 | 10 | # Backends 11 | - tensorflow-cpu ==2.10.* 12 | - dask ==2021.* 13 | 14 | # Testing 15 | - codecov 16 | - mypy ==1.11* 17 | - pytest 18 | - pytest-cov 19 | - ruff ==0.5.* 20 | -------------------------------------------------------------------------------- /devtools/conda-envs/torch-only-environment.yaml: -------------------------------------------------------------------------------- 1 | name: test 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | dependencies: 6 | # Base depends 7 | - python >=3.8 8 | - pytorch::pytorch >=2.0,<3.0.0a 9 | - pytorch::cpuonly 10 | - mkl 11 | 12 | # Testing 13 | - codecov 14 | - mypy ==1.11* 15 | - pytest 16 | - pytest-cov 17 | - ruff ==0.5.* 18 | -------------------------------------------------------------------------------- /devtools/conda-recipe/meta.yaml: -------------------------------------------------------------------------------- 1 | {% set name = "opt_einsum" %} 2 | {% set version = "2.0.0" %} 3 | 4 | package: 5 | name: {{ name|lower }} 6 | version: {{ version }} 7 | 8 | source: 9 | path: ../.. 10 | 11 | build: 12 | number: 0 13 | script: python setup.py install --single-version-externally-managed --record record.txt 14 | 15 | requirements: 16 | build: 17 | - python 18 | - setuptools 19 | run: 20 | - python 21 | - numpy 22 | 23 | test: 24 | requires: 25 | - python 26 | - pytest 27 | commands: 28 | - py.test --pyargs opt_einsum 29 | 30 | about: 31 | home: http://github.com/dgasmith/opt_einsum 32 | license: MIT 33 | license_family: MIT 34 | license_file: LICENSE 35 | summary: 'A contraction optimizer for the NumPy Einsum function.' 36 | 37 | description: > 38 | Einsum is a very powerful function for contracting tensors of arbitrary dimension and index. 39 | However, it is only optimized to contract two terms at a time resulting in non-optimal scaling. 
40 | 41 | For example, let us examine the following index transformation: 42 | `M_{pqrs} = C_{pi} C_{qj} I_{ijkl} C_{rk} C_{sl}` 43 | 44 | We can then develop two separate implementations that produce the same result: 45 | ```python 46 | N = 10 47 | C = np.random.rand(N, N) 48 | I = np.random.rand(N, N, N, N) 49 | 50 | def naive(I, C): 51 | # N^8 scaling 52 | return np.einsum('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C) 53 | 54 | def optimized(I, C): 55 | # N^5 scaling 56 | K = np.einsum('pi,ijkl->pjkl', C, I) 57 | K = np.einsum('qj,pjkl->pqkl', C, K) 58 | K = np.einsum('rk,pqkl->pqrl', C, K) 59 | K = np.einsum('sl,pqrl->pqrs', C, K) 60 | return K 61 | ``` 62 | 63 | The einsum function does not consider building intermediate arrays; therefore, helping einsum out by building these intermediate arrays can result in a considerable cost savings even for small N (N=10): 64 | 65 | ```python 66 | np.allclose(naive(I, C), optimized(I, C)) 67 | True 68 | 69 | %timeit naive(I, C) 70 | 1 loops, best of 3: 934 ms per loop 71 | 72 | %timeit optimized(I, C) 73 | 1000 loops, best of 3: 527 µs per loop 74 | ``` 75 | 76 | A 2000 fold speed up for 4 extra lines of code! 77 | This contraction can be further complicated by considering that the shape of the C matrices need not be the same, in this case the ordering in which the indices are transformed matters greatly. 78 | Logic can be built that optimizes the ordering; however, this is a lot of time and effort for a single expression. 79 | The opt_einsum package is a drop in replacement for the np.einsum function and can handle all of this logic for you: 80 | 81 | ```python 82 | from opt_einsum import contract 83 | 84 | %timeit contract('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C) 85 | 1000 loops, best of 3: 324 µs per loop 86 | ``` 87 | 88 | The above will automatically find the optimal contraction order, in this case identical to that of the optimized function above, and compute the products for you. In this case, it even uses `np.dot` under the hood to exploit any vendor BLAS functionality that your NumPy build has! 
89 | dev_url: https://github.com/dgasmith/opt_einsum 90 | 91 | extra: 92 | recipe-maintainers: 93 | - dgasmith 94 | - loriab 95 | -------------------------------------------------------------------------------- /docs/api_reference.md: -------------------------------------------------------------------------------- 1 | --- 2 | toc_depth: 1 3 | --- 4 | 5 | # API Documentation 6 | 7 | ### `opt_einsum.contract` 8 | 9 | ::: opt_einsum.contract.contract 10 | 11 | 12 | ### `opt_einsum.contract_path` 13 | 14 | ::: opt_einsum.contract.contract_path 15 | 16 | 17 | ### `opt_einsum.contract_expression` 18 | 19 | ::: opt_einsum.contract.contract_expression 20 | 22 | 23 | ### `opt_einsum.contract.ContractExpression` 24 | 25 | ::: opt_einsum.contract.ContractExpression 26 | 28 | 29 | ### `opt_einsum.contract.PathInfo` 30 | 31 | ::: opt_einsum.contract.PathInfo 32 | 33 | 34 | ### `opt_einsum.get_symbol` 35 | 36 | ::: opt_einsum.parser.get_symbol 37 | 38 | 39 | ### `opt_einsum.shared_intermediates` 40 | 41 | ::: opt_einsum.sharing.shared_intermediates 42 | 43 | 44 | ### `opt_einsum.paths.optimal` 45 | 46 | ::: opt_einsum.paths.optimal 47 | 48 | 49 | ### `opt_einsum.paths.greedy` 50 | 51 | ::: opt_einsum.paths.greedy 52 | 53 | 54 | ### `opt_einsum.paths.branch` 55 | 56 | ::: opt_einsum.paths.branch 57 | 58 | 59 | ### `opt_einsum.paths.PathOptimizer` 60 | 61 | ::: opt_einsum.paths.PathOptimizer 62 | 64 | 65 | ### `opt_einsum.paths.BranchBound` 66 | 67 | ::: opt_einsum.paths.BranchBound 68 | 70 | 71 | ### `opt_einsum.path_random.RandomOptimizer` 72 | 73 | ::: opt_einsum.path_random.RandomOptimizer 74 | 76 | 77 | ### `opt_einsum.path_random.RandomGreedy` 78 | 79 | ::: opt_einsum.path_random.RandomGreedy 80 | 82 | 83 | ### `opt_einsum.paths.DynamicProgramming` 84 | 85 | ::: opt_einsum.paths.DynamicProgramming 86 | 88 | -------------------------------------------------------------------------------- /docs/changelog.md: -------------------------------------------------------------------------------- 1 | Changelog 2 | ========= 3 | 4 | ## 3.4.0 / 2024-09-26 5 | 6 | NumPy has been removed from `opt_einsum` as a dependency allowing for more flexible installs. 7 | 8 | **New Features** 9 | 10 | - [\#160](https://github.com/dgasmith/opt_einsum/pull/160) Migrates docs to MkDocs Material and GitHub pages hosting. 11 | - [\#161](https://github.com/dgasmith/opt_einsum/pull/161) Adds Python type annotations to the code base. 12 | - [\#204](https://github.com/dgasmith/opt_einsum/pull/204) Removes NumPy as a hard dependency. 13 | 14 | **Enhancements** 15 | 16 | - [\#154](https://github.com/dgasmith/opt_einsum/pull/154) Prevents an infinite recursion error when the `memory_limit` was set very low for the `dp` algorithm. 17 | - [\#155](https://github.com/dgasmith/opt_einsum/pull/155) Adds flake8 spell check to the doc strings 18 | - [\#159](https://github.com/dgasmith/opt_einsum/pull/159) Migrates to GitHub actions for CI. 19 | - [\#174](https://github.com/dgasmith/opt_einsum/pull/174) Prevents double contracts of floats in dynamic paths. 20 | - [\#196](https://github.com/dgasmith/opt_einsum/pull/196) Allows `backend=None` which is equivalent to `backend='auto'` 21 | - [\#208](https://github.com/dgasmith/opt_einsum/pull/208) Switches to `ConfigParser` insetad of `SafeConfigParser` for Python 3.12 compatability. 
22 | - [\#228](https://github.com/dgasmith/opt_einsum/pull/228) `backend='jaxlib'` is now an alias for the `jax` library 23 | - [\#237](https://github.com/dgasmith/opt_einsum/pull/237) Switches to `ruff` for formatting and linting. 24 | - [\#238](https://github.com/dgasmith/opt_einsum/pull/238) Removes `numpy`-specific keyword args from being explicitly defined in `contract` and uses `**kwargs` instead. 25 | 26 | **Bug Fixes** 27 | 28 | - [\#195](https://github.com/dgasmith/opt_einsum/pull/195) Fixes a bug where `dp` would not work for scalar-only contractions. 29 | - [\#200](https://github.com/dgasmith/opt_einsum/pull/200) Fixes a bug where `parse_einsum_input` would not correctly respect shape-only contractions. 30 | - [\#222](https://github.com/dgasmith/opt_einsum/pull/222) Fixes an erorr in `parse_einsum_input` where an output subscript specified multiple times was not correctly caught. 31 | - [\#229](https://github.com/dgasmith/opt_einsum/pull/229) Fixes a bug where empty contraction lists in `PathInfo` would cause an error. 32 | 33 | ## 3.3.0 / 2020-07-19 34 | 35 | Adds a `object` backend for optimized contractions on arbitrary Python objects. 36 | 37 | **New Features** 38 | 39 | - [\#145](https://github.com/dgasmith/opt_einsum/pull/145) Adds a `object` based backend so that `contract(backend='object')` can be used on arbitrary objects such as SymPy symbols. 40 | 41 | **Enhancements** 42 | 43 | - [\#140](https://github.com/dgasmith/opt_einsum/pull/140) Better error messages when the requested `contract` backend cannot be found. 44 | - [\#141](https://github.com/dgasmith/opt_einsum/pull/141) Adds a check with RandomOptimizers to ensure the objects are not accidentally reused for different contractions. 45 | - [\#149](https://github.com/dgasmith/opt_einsum/pull/149) Limits the `remaining` category for the `contract_path` output to only show up to 20 tensors to prevent issues with the quadratically scaling memory requirements and the number of print lines for large contractions. 46 | 47 | ## 3.2.0 / 2020-03-01 48 | 49 | Small fixes for the `dp` path and support for a new mars backend. 50 | 51 | **New Features** 52 | 53 | - [\#109](https://github.com/dgasmith/opt_einsum/pull/109) Adds mars backend support. 54 | 55 | **Enhancements** 56 | 57 | - [\#110](https://github.com/dgasmith/opt_einsum/pull/110) New `auto-hq` and `'random-greedy-128'` paths. 58 | - [\#119](https://github.com/dgasmith/opt_einsum/pull/119) Fixes several edge cases in the `dp` path. 59 | 60 | **Bug fixes** 61 | 62 | - [\#127](https://github.com/dgasmith/opt_einsum/pull/127) Fixes an issue where Python 3.6 features are required while Python 3.5 is `opt_einsum`'s stated minimum version. 63 | 64 | ## 3.1.0 / 2019-09-30 65 | 66 | Adds a new dynamic programming algorithm to the suite of paths. 67 | 68 | **New Features** 69 | 70 | - [\#102](https://github.com/dgasmith/opt_einsum/pull/102) Adds new `dp` path. 71 | 72 | ## 3.0.0 / 2019-08-10 73 | 74 | This release moves `opt_einsum` to be backend agnostic while adding support 75 | additional backends such as Jax and Autograd. Support for Python 2.7 has been dropped and Python 3.5 will become the new minimum version, a Python deprecation policy equivalent to NumPy's has been adopted. 76 | 77 | 78 | **New Features** 79 | 80 | - [\#78](https://github.com/dgasmith/opt_einsum/pull/78) A new random-optimizer has been implemented which uses Boltzmann weighting to explore alternative near-minimum paths using greedy-like schemes. 
This provides fairly large path performance enhancements with a linear path-time overhead. 81 | - [\#78](https://github.com/dgasmith/opt_einsum/pull/78) A new PathOptimizer class has been implemented to provide a framework for building new optimizers. For example, custom cost functions can now be provided in the greedy formalism, allowing custom optimizers to be built without a large amount of additional code. 82 | - [\#81](https://github.com/dgasmith/opt_einsum/pull/81) The `backend="auto"` keyword has been implemented for `contract`, allowing automatic detection of the correct backend to use based on the tensors provided in the contraction. 83 | - [\#88](https://github.com/dgasmith/opt_einsum/pull/88) Autograd and Jax support have been implemented. 84 | - [\#96](https://github.com/dgasmith/opt_einsum/pull/96) Deprecates Python 2 functionality and adds devops improvements. 85 | 86 | **Enhancements** 87 | 88 | - [\#84](https://github.com/dgasmith/opt_einsum/pull/84) The `contract_path` function can now accept shape tuples rather than full tensors. 89 | - [\#84](https://github.com/dgasmith/opt_einsum/pull/84) The `contract_path` automated path algorithm decision technology has been refactored into a standalone function. 90 | 91 | 92 | ## 2.3.0 / 2018-12-01 93 | 94 | This release primarily focuses on expanding the suite of available path 95 | technologies to provide better optimization characteristics for 4-20 tensors while 96 | decreasing the time to find paths for 50-200+ tensors. See the Path Overview documentation for more information. 97 | 98 | **New Features** 99 | 100 | - [\#60](https://github.com/dgasmith/opt_einsum/pull/60) A new `greedy` implementation has been added which is up to two orders of magnitude faster for 200 tensors. 101 | - [\#73](https://github.com/dgasmith/opt_einsum/pull/73) Adds a new `branch` path that uses `greedy` ideas to prune the `optimal` exploration space to provide a better path than `greedy` at sub-`optimal` cost. 102 | - [\#73](https://github.com/dgasmith/opt_einsum/pull/73) Adds a new `auto` keyword to the `opt_einsum.contract` `path` option. This keyword automatically chooses the best path technology that takes under 1ms to execute. 103 | 104 | **Enhancements** 105 | 106 | - [\#61](https://github.com/dgasmith/opt_einsum/pull/61) The `opt_einsum.contract` `path` keyword has been changed to `optimize` to more closely match NumPy. `path` will be deprecated in the future. 107 | - [\#61](https://github.com/dgasmith/opt_einsum/pull/61) The `opt_einsum.contract_path` function now returns an `opt_einsum.contract.PathInfo` object that can be queried for the scaling, flops, and intermediates of the path. The print representation of this object is identical to before. 108 | - [\#61](https://github.com/dgasmith/opt_einsum/pull/61) The default `memory_limit` is now unlimited, based on community feedback. 109 | - [\#66](https://github.com/dgasmith/opt_einsum/pull/66) The Torch backend will now use `tensordot` when using a version of Torch which includes this functionality. 110 | - [\#68](https://github.com/dgasmith/opt_einsum/pull/68) Indices can now be any hashable object when provided in the "Interleaved Input" syntax. 111 | - [\#74](https://github.com/dgasmith/opt_einsum/pull/74) Allows the default `transpose` operation to be overridden to take advantage of more advanced tensor transpose libraries. 112 | - [\#73](https://github.com/dgasmith/opt_einsum/pull/73) The `optimal` path is now significantly faster.
113 | - [\#81](https://github.com/dgasmith/opt_einsum/pull/81) A documentation pass for v3.0. 114 | 115 | **Bug fixes** 116 | 117 | - [\#72](https://github.com/dgasmith/opt_einsum/pull/72) Fixes the `"Interleaved Input" `_ syntax and adds documentation. 118 | 119 | ## 2.2.0 / 2018-07-29 120 | 121 | **New Features** 122 | 123 | - [\#48](https://github.com/dgasmith/opt_einsum/pull/48) Intermediates can now be shared between contractions, see here for more details. 124 | - [\#53](https://github.com/dgasmith/opt_einsum/pull/53) Intermediate caching is thread safe. 125 | 126 | **Enhancements** 127 | 128 | - [\#48](https://github.com/dgasmith/opt_einsum/pull/48) Expressions are now mapped to non-unicode index set so that unicode input is support for all backends. 129 | - [\#54](https://github.com/dgasmith/opt_einsum/pull/54) General documentation update. 130 | 131 | **Bug fixes** 132 | 133 | - [\#41](https://github.com/dgasmith/opt_einsum/pull/41) PyTorch indices are mapped back to a small a-z subset valid for PyTorch's einsum implementation. 134 | 135 | ## 2.1.3 / 2018-8-23 136 | 137 | **Bug fixes** 138 | 139 | - Fixes unicode issue for large numbers of tensors in Python 2.7. 140 | - Fixes unicode install bug in README.md. 141 | 142 | ## 2.1.2 / 2018-8-16 143 | 144 | **Bug fixes** 145 | 146 | - Ensures `versioneer.py` is in MANIFEST.in for a clean pip install. 147 | 148 | 149 | ## 2.1.1 / 2018-8-15 150 | 151 | **Bug fixes** 152 | 153 | - Corrected Markdown display on PyPi. 154 | 155 | ## 2.1.0 / 2018-8-15 156 | 157 | `opt_einsum` continues to improve its support for additional backends beyond NumPy with PyTorch. 158 | 159 | We have also published the opt_einsum package in the Journal of Open Source Software. If you use this package in your work, please consider citing us! 160 | 161 | **New features** 162 | 163 | - PyTorch backend support 164 | - Tensorflow eager-mode execution backend support 165 | 166 | **Enhancements** 167 | 168 | - Intermediate tensordot-like expressions are now ordered to avoid transposes. 169 | - CI now uses conda backend to better support GPU and tensor libraries. 170 | - Now accepts arbitrary unicode indices rather than a subset. 171 | - New auto path option which switches between optimal and greedy at four tensors. 172 | 173 | **Bug fixes** 174 | 175 | - Fixed issue where broadcast indices were incorrectly locked out of tensordot-like evaluations even after their dimension was broadcast. 176 | 177 | ## 2.0.1 / 2018-6-28 178 | 179 | **New Features** 180 | 181 | - Allows unlimited Unicode indices. 182 | - Adds a Journal of Open-Source Software paper. 183 | - Minor documentation improvements. 184 | 185 | 186 | ## 2.0.0 / 2018-5-17 187 | 188 | `opt_einsum` is a powerful tensor contraction order optimizer for NumPy and related ecosystems. 189 | 190 | **New Features** 191 | 192 | - Expressions can be precompiled so that the expression optimization need not happen multiple times. 193 | - The greedy order optimization algorithm has been tuned to be able to handle hundreds of tensors in several seconds. 194 | - Input indices can now be unicode so that expressions can have many thousands of indices. 195 | - GPU and distributed computing backends have been added such as Dask, TensorFlow, CUPy, Theano, and Sparse. 196 | 197 | **Bug Fixes** 198 | 199 | - An error affecting cases where opt_einsum mistook broadcasting operations for matrix multiply has been fixed. 200 | - Most error messages are now more expressive. 
201 | 202 | 203 | ## 1.0.0 / 2016-10-14 204 | 205 | Einsum is a very powerful function for contracting tensors of arbitrary 206 | dimension and index. However, it is only optimized to contract two terms at a 207 | time, resulting in non-optimal scaling for contractions with many terms. 208 | Opt_einsum aims to fix this by optimizing the contraction order, which can lead 209 | to arbitrarily large speed ups at the cost of additional intermediate tensors. 210 | 211 | Opt_einsum has also been integrated into the `np.einsum` function as of NumPy v1.12. 212 | 213 | **New Features** 214 | 215 | - Tensor contraction order optimizer. 216 | - `opt_einsum.contract` as a drop-in replacement for `numpy.einsum`. 217 | -------------------------------------------------------------------------------- /docs/css/custom.css: -------------------------------------------------------------------------------- 1 | div.autodoc-docstring { 2 | padding-left: 20px; 3 | margin-bottom: 30px; 4 | border-left: 5px solid rgba(230, 230, 230); 5 | } 6 | 7 | div.autodoc-members { 8 | padding-left: 20px; 9 | margin-bottom: 15px; 10 | } 11 | -------------------------------------------------------------------------------- /docs/examples/dask_reusing_intermediaries.md: -------------------------------------------------------------------------------- 1 | # Reusing Intermediaries with Dask 2 | 3 | [Dask](https://dask.pydata.org/) provides a computational framework where arrays and the computations on them are built up into a 'task graph' before computation. 4 | Since `opt_einsum` is compatible with `dask` arrays, this means that multiple contractions can be built into the same task graph, which then automatically reuses any shared arrays and contractions. 5 | 6 | For example, imagine the two expressions: 7 | 8 | ```python 9 | contraction1 = 'ab,dca,eb,cde' 10 | contraction2 = 'ab,cda,eb,cde' 11 | sizes = {l: 10 for l in 'abcde'} 12 | ``` 13 | 14 | The contraction `'ab,eb'` is shared between them and need only be done once. 15 | First, let's set up some `numpy` arrays: 16 | 17 | ```python 18 | terms1, terms2 = contraction1.split(','), contraction2.split(',') 19 | terms = set((*terms1, *terms2)) 20 | terms 21 | #> {'ab', 'cda', 'cde', 'dca', 'eb'} 22 | 23 | import numpy as np 24 | np_arrays = {s: np.random.randn(*(sizes[c] for c in s)) for s in terms} 25 | # filter the arrays needed for each expression 26 | np_ops1 = [np_arrays[s] for s in terms1] 27 | np_ops2 = [np_arrays[s] for s in terms2] 28 | ``` 29 | 30 | Typically we would compute these expressions separately: 31 | 32 | ```python 33 | oe.contract(contraction1, *np_ops1) 34 | #> array(114.78314052) 35 | 36 | oe.contract(contraction2, *np_ops2) 37 | #> array(-75.55902751) 38 | ``` 39 | 40 | 41 | However, if we use dask arrays we can combine the two operations, so let's set those up: 42 | 43 | ```python 44 | import dask.array as da 45 | da_arrays = {s: da.from_array(np_arrays[s], chunks=1000, name=s) for s in terms} 46 | da_arrays 47 | #> {'ab': dask.array, 48 | #> 'cda': dask.array, 49 | #> 'cde': dask.array, 50 | #> 'dca': dask.array, 51 | #> 'eb': dask.array} 52 | 53 | da_ops1 = [da_arrays[s] for s in terms1] 54 | da_ops2 = [da_arrays[s] for s in terms2] 55 | ``` 56 | 57 | Note `chunks` is a required argument relating to how the arrays are stored (see [array-creation](http://dask.pydata.org/en/latest/array-creation.html)). 
58 | Now we can perform the contraction: 59 | 60 | ```python 61 | # these won't be immediately evaluated 62 | dy1 = oe.contract(contraction1, *da_ops1, backend='dask') 63 | dy2 = oe.contract(contraction2, *da_ops2, backend='dask') 64 | 65 | # wrap them in delayed to combine them into the same computation 66 | from dask import delayed 67 | dy = delayed([dy1, dy2]) 68 | dy 69 | #> Delayed('list-3af82335-b75e-47d6-b800-68490fc865fd') 70 | ``` 71 | 72 | As suggested by the name `Delayed`, we have a placeholder for the result 73 | so far. When we want to *perform* the computation we can call: 74 | 75 | ```python 76 | dy.compute() 77 | #> [114.78314052155015, -75.55902750513113] 78 | ``` 79 | 80 | The above matches the canonical numpy result. The computation can even be handled by various 81 | schedulers - see [scheduling](http://dask.pydata.org/en/latest/scheduling.html). 82 | Finally, to check we are reusing intermediaries, we can view the task graph generated for the computation: 83 | 84 | ```python 85 | dy.visualize(optimize_graph=True) 86 | ``` 87 | 88 | ![Dask Reuse Graph](../img/ex_dask_reuse_graph.png) 89 | 90 | !!! note 91 | For sharing intermediates with other backends see [Sharing Intermediates](../getting_started/sharing_intermediates.md). Dask graphs are particularly useful for reusing intermediates beyond just contractions and can allow additional parallelization. 92 | -------------------------------------------------------------------------------- /docs/examples/large_expr_with_greedy.md: -------------------------------------------------------------------------------- 1 | # Large Expressions with Greedy 2 | 3 | Using the greedy method allows the contraction of hundreds of tensors. Here's 4 | an example from quantum of computing the inner product between two ['Matrix 5 | Product States'](https://en.wikipedia.org/wiki/Matrix_product_state). 6 | Graphically, if we represent each tensor as an `O`, give it 7 | the same number of 'legs' as it has indices, and join those legs when that 8 | index is summed with another tensor, we get an expression for `n` particles 9 | that looks like: 10 | 11 | ```console 12 | O-O-O-O-O-O- -O-O-O-O-O-O 13 | | | | | | | ... | | | | | | 14 | O-O-O-O-O-O- -O-O-O-O-O-O 15 | 16 | 0 1 2 3 4 5 ........... n-2 n-1 17 | ``` 18 | 19 | The meaning of this is not that important other than its a large, useful 20 | contraction. For `n=100` it involves 200 different tensors and about 300 21 | unique indices. With this many indices it can be useful to generate them with 22 | the function `opt_einsum.parser.get_symbol`. 
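As a quick sketch of how `get_symbol` behaves (the exact unicode characters returned beyond the first 52 symbols are an implementation detail and may vary between versions), each integer maps to a unique index symbol:

```python
import opt_einsum as oe

# the first 52 symbols are the familiar ASCII einsum letters
oe.get_symbol(0)    #> 'a'
oe.get_symbol(25)   #> 'z'
oe.get_symbol(26)   #> 'A'
oe.get_symbol(51)   #> 'Z'

# beyond that, unique unicode characters are returned, so an expression
# can use far more than 52 distinct indices
oe.get_symbol(300)  # a single unicode character unique to this index
```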
23 | 24 | ### Setup the string 25 | 26 | ```python 27 | import numpy as np 28 | import opt_einsum as oe 29 | 30 | n = 100 31 | phys_dim = 3 32 | bond_dim = 10 33 | 34 | # start with first site 35 | # O-- 36 | # | 37 | # O-- 38 | einsum_str = "ab,ac," 39 | 40 | for i in range(1, n - 1): 41 | # set the upper left/right, middle and lower left/right indices 42 | # --O-- 43 | # | 44 | # --O-- 45 | j = 3 * i 46 | ul, ur, m, ll, lr = (oe.get_symbol(i) 47 | for i in (j - 1, j + 2, j, j - 2, j + 1)) 48 | einsum_str += "{}{}{},{}{}{},".format(m, ul, ur, m, ll, lr) 49 | 50 | # finish with last site 51 | # --O 52 | # | 53 | # --O 54 | i = n - 1 55 | j = 3 * i 56 | ul, m, ll, = (oe.get_symbol(i) for i in (j - 1, j, j - 2)) 57 | einsum_str += "{}{},{}{}".format(m, ul, m, ll) 58 | ``` 59 | 60 | ### Generate the shapes 61 | 62 | ```python 63 | def gen_shapes(): 64 | yield (phys_dim, bond_dim) 65 | yield (phys_dim, bond_dim) 66 | for i in range(1, n - 1): 67 | yield(phys_dim, bond_dim, bond_dim) 68 | yield(phys_dim, bond_dim, bond_dim) 69 | yield (phys_dim, bond_dim) 70 | yield (phys_dim, bond_dim) 71 | 72 | shapes = tuple(gen_shapes()) 73 | ``` 74 | 75 | 76 | Let's time how long it takes to generate the expression (`'greedy'` is used by default, and we turn off the `memory_limit`): 77 | 78 | ```python 79 | %timeit expr = oe.contract_expression(einsum_str, *shapes, memory_limit=-1) 80 | #> 76.2 ms ± 1.05 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) 81 | ``` 82 | 83 | This is pretty manageable, though we might want to think about splitting the 84 | expression up if we go a lot bigger. 85 | Importantly, we can then use this repeatedly with any set of matching arrays: 86 | 87 | ```python 88 | arrays = [np.random.randn(*shp) / 4 for shp in shapes] 89 | expr(*arrays) 90 | #> array(23.23628116) 91 | 92 | arrays = [np.random.randn(*shp) / 4 for shp in shapes] 93 | expr(*arrays) 94 | #> array(-12.21091879) 95 | ``` 96 | 97 | ### Full path 98 | 99 | And if we **really** want we can generate the full contraction path info: 100 | 101 | ```python 102 | print(oe.contract_path(einsum_str, *arrays, memory_limit=-1)[1]) 103 | #> Complete contraction: ab,ac,dcf,dbe,gfi,geh,jil,jhk,mlo,mkn,por,pnq,sru,sqt,vux,vtw,yxA,ywz,BAD,BzC,EDG,ECF,HGJ,HFI,KJM,KIL,NMP,NLO,QPS,QOR,TSV,TRU,WVY,WUX,ZYÂ,ZXÁ,ÃÂÅ,ÃÁÄ,ÆÅÈ,ÆÄÇ,ÉÈË,ÉÇÊ,ÌËÎ,ÌÊÍ,ÏÎÑ,ÏÍÐ,ÒÑÔ,ÒÐÓ,ÕÔ×,ÕÓÖ,Ø×Ú,ØÖÙ,ÛÚÝ,ÛÙÜ,ÞÝà,ÞÜß,áàã,áßâ,äãæ,äâå,çæé,çåè,êéì,êèë,íìï,íëî,ðïò,ðîñ,óòõ,óñô,öõø,öô÷,ùøû,ù÷ú,üûþ,üúý,ÿþā,ÿýĀ,ĂāĄ,ĂĀă,ąĄć,ąăĆ,ĈćĊ,ĈĆĉ,ċĊč,ċĉČ,ĎčĐ,ĎČď,đĐē,đďĒ,ĔēĖ,ĔĒĕ,ėĖę,ėĕĘ,ĚęĜ,ĚĘě,ĝĜğ,ĝěĞ,ĠğĢ,ĠĞġ,ģĢĥ,ģġĤ,ĦĥĨ,ĦĤħ,ĩĨī,ĩħĪ,ĬīĮ,ĬĪĭ,įĮı,įĭİ,IJıĴ,IJİij,ĵĴķ,ĵijĶ,ĸķĺ,ĸĶĹ,ĻĺĽ,ĻĹļ,ľĽŀ,ľļĿ,ŁŀŃ,ŁĿł,ńŃņ,ńłŅ,Ňņʼn,ŇŅň,ŊʼnŌ,Ŋňŋ,ōŌŏ,ōŋŎ,ŐŏŒ,ŐŎő,œŒŕ,œőŔ,ŖŕŘ,ŖŔŗ,řŘś,řŗŚ,ŜśŞ,ŜŚŝ,şŞš,şŝŠ,ŢšŤ,ŢŠţ,ťŤŧ,ťţŦ,ŨŧŪ,ŨŦũ,ūŪŭ,ūũŬ,ŮŭŰ,ŮŬů,űŰų,űůŲ,ŴųŶ,ŴŲŵ,ŷŶŹ,ŷŵŸ,źŹż,źŸŻ,Žżſ,ŽŻž,ƀſƂ,ƀžƁ,ƃƂƅ,ƃƁƄ,Ɔƅƈ,ƆƄƇ,ƉƈƋ,ƉƇƊ,ƌƋƎ,ƌƊƍ,ƏƎƑ,ƏƍƐ,ƒƑƔ,ƒƐƓ,ƕƔƗ,ƕƓƖ,ƘƗƚ,ƘƖƙ,ƛƚƝ,ƛƙƜ,ƞƝƠ,ƞƜƟ,ơƠƣ,ơƟƢ,ƤƣƦ,ƤƢƥ,ƧƦƩ,Ƨƥƨ,ƪƩƬ,ƪƨƫ,ƭƬƯ,ƭƫƮ,ưƯƲ,ưƮƱ,ƳƲƵ,ƳƱƴ,ƶƵ,ƶƴ-> 104 | #> Naive scaling: 298 105 | #> Optimized scaling: 5 106 | #> Naive FLOP count: 1.031e+248 107 | #> Optimized FLOP count: 1.168e+06 108 | #> Theoretical speedup: 88264689284468460017580864156865782413140936705854966013600065426858041248009637246968036807489558012989638169986640870276510490846199301907401763236976204166215471281505344088317454144870323271826022036197984172898402324699098341524952317952.000 109 | #> Largest intermediate: 3.000e+02 elements 110 | #> -------------------------------------------------------------------------------- 111 | #> scaling BLAS current 
remaining 112 | #> -------------------------------------------------------------------------------- 113 | #> 4 TDOT dbe,ab->ade ac,dcf,gfi,geh,jil,jhk,mlo,mkn,por,pnq,sru,sqt,vux,vtw,yxA,ywz,BAD,BzC,EDG,ECF,HGJ,HFI,KJM,KIL,NMP,NLO,QPS,QOR,TSV,TRU,WVY,WUX,ZYÂ,ZXÁ,ÃÂÅ,ÃÁÄ,ÆÅÈ,ÆÄÇ,ÉÈË,ÉÇÊ,ÌËÎ,ÌÊÍ,ÏÎÑ,ÏÍÐ,ÒÑÔ,ÒÐÓ,ÕÔ×,ÕÓÖ,Ø×Ú,ØÖÙ,ÛÚÝ,ÛÙÜ,ÞÝà,ÞÜß,áàã,áßâ,äãæ,äâå,çæé,çåè,êéì,êèë,íìï,íëî,ðïò,ðîñ,óòõ,óñô,öõø,öô÷,ùøû,ù÷ú,üûþ,üúý,ÿþā,ÿýĀ,ĂāĄ,ĂĀă,ąĄć,ąăĆ,ĈćĊ,ĈĆĉ,ċĊč,ċĉČ,ĎčĐ,ĎČď,đĐē,đďĒ,ĔēĖ,ĔĒĕ,ėĖę,ėĕĘ,ĚęĜ,ĚĘě,ĝĜğ,ĝěĞ,ĠğĢ,ĠĞġ,ģĢĥ,ģġĤ,ĦĥĨ,ĦĤħ,ĩĨī,ĩħĪ,ĬīĮ,ĬĪĭ,įĮı,įĭİ,IJıĴ,IJİij,ĵĴķ,ĵijĶ,ĸķĺ,ĸĶĹ,ĻĺĽ,ĻĹļ,ľĽŀ,ľļĿ,ŁŀŃ,ŁĿł,ńŃņ,ńłŅ,Ňņʼn,ŇŅň,ŊʼnŌ,Ŋňŋ,ōŌŏ,ōŋŎ,ŐŏŒ,ŐŎő,œŒŕ,œőŔ,ŖŕŘ,ŖŔŗ,řŘś,řŗŚ,ŜśŞ,ŜŚŝ,şŞš,şŝŠ,ŢšŤ,ŢŠţ,ťŤŧ,ťţŦ,ŨŧŪ,ŨŦũ,ūŪŭ,ūũŬ,ŮŭŰ,ŮŬů,űŰų,űůŲ,ŴųŶ,ŴŲŵ,ŷŶŹ,ŷŵŸ,źŹż,źŸŻ,Žżſ,ŽŻž,ƀſƂ,ƀžƁ,ƃƂƅ,ƃƁƄ,Ɔƅƈ,ƆƄƇ,ƉƈƋ,ƉƇƊ,ƌƋƎ,ƌƊƍ,ƏƎƑ,ƏƍƐ,ƒƑƔ,ƒƐƓ,ƕƔƗ,ƕƓƖ,ƘƗƚ,ƘƖƙ,ƛƚƝ,ƛƙƜ,ƞƝƠ,ƞƜƟ,ơƠƣ,ơƟƢ,ƤƣƦ,ƤƢƥ,ƧƦƩ,Ƨƥƨ,ƪƩƬ,ƪƨƫ,ƭƬƯ,ƭƫƮ,ưƯƲ,ưƮƱ,ƳƲƵ,ƳƱƴ,ƶƵ,ƶƴ,ade-> 114 | #> 4 TDOT dcf,ac->adf gfi,geh,jil,jhk,mlo,mkn,por,pnq,sru,sqt,vux,vtw,yxA,ywz,BAD,BzC,EDG,ECF,HGJ,HFI,KJM,KIL,NMP,NLO,QPS,QOR,TSV,TRU,WVY,WUX,ZYÂ,ZXÁ,ÃÂÅ,ÃÁÄ,ÆÅÈ,ÆÄÇ,ÉÈË,ÉÇÊ,ÌËÎ,ÌÊÍ,ÏÎÑ,ÏÍÐ,ÒÑÔ,ÒÐÓ,ÕÔ×,ÕÓÖ,Ø×Ú,ØÖÙ,ÛÚÝ,ÛÙÜ,ÞÝà,ÞÜß,áàã,áßâ,äãæ,äâå,çæé,çåè,êéì,êèë,íìï,íëî,ðïò,ðîñ,óòõ,óñô,öõø,öô÷,ùøû,ù÷ú,üûþ,üúý,ÿþā,ÿýĀ,ĂāĄ,ĂĀă,ąĄć,ąăĆ,ĈćĊ,ĈĆĉ,ċĊč,ċĉČ,ĎčĐ,ĎČď,đĐē,đďĒ,ĔēĖ,ĔĒĕ,ėĖę,ėĕĘ,ĚęĜ,ĚĘě,ĝĜğ,ĝěĞ,ĠğĢ,ĠĞġ,ģĢĥ,ģġĤ,ĦĥĨ,ĦĤħ,ĩĨī,ĩħĪ,ĬīĮ,ĬĪĭ,įĮı,įĭİ,IJıĴ,IJİij,ĵĴķ,ĵijĶ,ĸķĺ,ĸĶĹ,ĻĺĽ,ĻĹļ,ľĽŀ,ľļĿ,ŁŀŃ,ŁĿł,ńŃņ,ńłŅ,Ňņʼn,ŇŅň,ŊʼnŌ,Ŋňŋ,ōŌŏ,ōŋŎ,ŐŏŒ,ŐŎő,œŒŕ,œőŔ,ŖŕŘ,ŖŔŗ,řŘś,řŗŚ,ŜśŞ,ŜŚŝ,şŞš,şŝŠ,ŢšŤ,ŢŠţ,ťŤŧ,ťţŦ,ŨŧŪ,ŨŦũ,ūŪŭ,ūũŬ,ŮŭŰ,ŮŬů,űŰų,űůŲ,ŴųŶ,ŴŲŵ,ŷŶŹ,ŷŵŸ,źŹż,źŸŻ,Žżſ,ŽŻž,ƀſƂ,ƀžƁ,ƃƂƅ,ƃƁƄ,Ɔƅƈ,ƆƄƇ,ƉƈƋ,ƉƇƊ,ƌƋƎ,ƌƊƍ,ƏƎƑ,ƏƍƐ,ƒƑƔ,ƒƐƓ,ƕƔƗ,ƕƓƖ,ƘƗƚ,ƘƖƙ,ƛƚƝ,ƛƙƜ,ƞƝƠ,ƞƜƟ,ơƠƣ,ơƟƢ,ƤƣƦ,ƤƢƥ,ƧƦƩ,Ƨƥƨ,ƪƩƬ,ƪƨƫ,ƭƬƯ,ƭƫƮ,ưƯƲ,ưƮƱ,ƳƲƵ,ƳƱƴ,ƶƵ,ƶƴ,ade,adf-> 115 | #> 4 GEMM ƶƵ,ƳƲƵ->ƳƶƲ gfi,geh,jil,jhk,mlo,mkn,por,pnq,sru,sqt,vux,vtw,yxA,ywz,BAD,BzC,EDG,ECF,HGJ,HFI,KJM,KIL,NMP,NLO,QPS,QOR,TSV,TRU,WVY,WUX,ZYÂ,ZXÁ,ÃÂÅ,ÃÁÄ,ÆÅÈ,ÆÄÇ,ÉÈË,ÉÇÊ,ÌËÎ,ÌÊÍ,ÏÎÑ,ÏÍÐ,ÒÑÔ,ÒÐÓ,ÕÔ×,ÕÓÖ,Ø×Ú,ØÖÙ,ÛÚÝ,ÛÙÜ,ÞÝà,ÞÜß,áàã,áßâ,äãæ,äâå,çæé,çåè,êéì,êèë,íìï,íëî,ðïò,ðîñ,óòõ,óñô,öõø,öô÷,ùøû,ù÷ú,üûþ,üúý,ÿþā,ÿýĀ,ĂāĄ,ĂĀă,ąĄć,ąăĆ,ĈćĊ,ĈĆĉ,ċĊč,ċĉČ,ĎčĐ,ĎČď,đĐē,đďĒ,ĔēĖ,ĔĒĕ,ėĖę,ėĕĘ,ĚęĜ,ĚĘě,ĝĜğ,ĝěĞ,ĠğĢ,ĠĞġ,ģĢĥ,ģġĤ,ĦĥĨ,ĦĤħ,ĩĨī,ĩħĪ,ĬīĮ,ĬĪĭ,įĮı,įĭİ,IJıĴ,IJİij,ĵĴķ,ĵijĶ,ĸķĺ,ĸĶĹ,ĻĺĽ,ĻĹļ,ľĽŀ,ľļĿ,ŁŀŃ,ŁĿł,ńŃņ,ńłŅ,Ňņʼn,ŇŅň,ŊʼnŌ,Ŋňŋ,ōŌŏ,ōŋŎ,ŐŏŒ,ŐŎő,œŒŕ,œőŔ,ŖŕŘ,ŖŔŗ,řŘś,řŗŚ,ŜśŞ,ŜŚŝ,şŞš,şŝŠ,ŢšŤ,ŢŠţ,ťŤŧ,ťţŦ,ŨŧŪ,ŨŦũ,ūŪŭ,ūũŬ,ŮŭŰ,ŮŬů,űŰų,űůŲ,ŴųŶ,ŴŲŵ,ŷŶŹ,ŷŵŸ,źŹż,źŸŻ,Žżſ,ŽŻž,ƀſƂ,ƀžƁ,ƃƂƅ,ƃƁƄ,Ɔƅƈ,ƆƄƇ,ƉƈƋ,ƉƇƊ,ƌƋƎ,ƌƊƍ,ƏƎƑ,ƏƍƐ,ƒƑƔ,ƒƐƓ,ƕƔƗ,ƕƓƖ,ƘƗƚ,ƘƖƙ,ƛƚƝ,ƛƙƜ,ƞƝƠ,ƞƜƟ,ơƠƣ,ơƟƢ,ƤƣƦ,ƤƢƥ,ƧƦƩ,Ƨƥƨ,ƪƩƬ,ƪƨƫ,ƭƬƯ,ƭƫƮ,ưƯƲ,ưƮƱ,ƳƱƴ,ƶƴ,ade,adf,ƳƶƲ-> 116 | #> 4 GEMM ƶƴ,ƳƱƴ->ƳƶƱ gfi,geh,jil,jhk,mlo,mkn,por,pnq,sru,sqt,vux,vtw,yxA,ywz,BAD,BzC,EDG,ECF,HGJ,HFI,KJM,KIL,NMP,NLO,QPS,QOR,TSV,TRU,WVY,WUX,ZYÂ,ZXÁ,ÃÂÅ,ÃÁÄ,ÆÅÈ,ÆÄÇ,ÉÈË,ÉÇÊ,ÌËÎ,ÌÊÍ,ÏÎÑ,ÏÍÐ,ÒÑÔ,ÒÐÓ,ÕÔ×,ÕÓÖ,Ø×Ú,ØÖÙ,ÛÚÝ,ÛÙÜ,ÞÝà,ÞÜß,áàã,áßâ,äãæ,äâå,çæé,çåè,êéì,êèë,íìï,íëî,ðïò,ðîñ,óòõ,óñô,öõø,öô÷,ùøû,ù÷ú,üûþ,üúý,ÿþā,ÿýĀ,ĂāĄ,ĂĀă,ąĄć,ąăĆ,ĈćĊ,ĈĆĉ,ċĊč,ċĉČ,ĎčĐ,ĎČď,đĐē,đďĒ,ĔēĖ,ĔĒĕ,ėĖę,ėĕĘ,ĚęĜ,ĚĘě,ĝĜğ,ĝěĞ,ĠğĢ,ĠĞġ,ģĢĥ,ģġĤ,ĦĥĨ,ĦĤħ,ĩĨī,ĩħĪ,ĬīĮ,ĬĪĭ,įĮı,įĭİ,IJıĴ,IJİij,ĵĴķ,ĵijĶ,ĸķĺ,ĸĶĹ,ĻĺĽ,ĻĹļ,ľĽŀ,ľļĿ,ŁŀŃ,ŁĿł,ńŃņ,ńłŅ,Ňņʼn,ŇŅň,ŊʼnŌ,Ŋňŋ,ōŌŏ,ōŋŎ,ŐŏŒ,ŐŎő,œŒŕ,œőŔ,ŖŕŘ,ŖŔŗ,řŘś,řŗŚ,ŜśŞ,ŜŚŝ,şŞš,şŝŠ,ŢšŤ,ŢŠţ,ťŤŧ,ťţŦ,ŨŧŪ,ŨŦũ,ūŪŭ,ūũŬ,ŮŭŰ,ŮŬů,űŰų,űůŲ,ŴųŶ,ŴŲŵ,ŷŶŹ,ŷŵŸ,źŹż,źŸŻ,Žżſ,ŽŻž,ƀſƂ,ƀžƁ,ƃƂƅ,ƃƁƄ,Ɔƅƈ,ƆƄƇ,ƉƈƋ,ƉƇƊ,ƌƋƎ,ƌƊƍ,ƏƎƑ,ƏƍƐ,ƒƑƔ,ƒƐƓ,ƕƔƗ,ƕƓƖ,ƘƗƚ,ƘƖƙ,ƛƚƝ,ƛƙƜ,ƞƝƠ,ƞƜƟ,ơƠƣ,ơƟƢ,ƤƣƦ,ƤƢƥ,ƧƦƩ,Ƨƥƨ,ƪƩƬ,ƪƨƫ,ƭƬƯ,ƭƫƮ,ưƯƲ,ưƮƱ,ade,adf,ƳƶƲ,ƳƶƱ-> 117 | #> 5 TDOT ade,geh->adgh 
gfi,jil,jhk,mlo,mkn,por,pnq,sru,sqt,vux,vtw,yxA,ywz,BAD,BzC,EDG,ECF,HGJ,HFI,KJM,KIL,NMP,NLO,QPS,QOR,TSV,TRU,WVY,WUX,ZYÂ,ZXÁ,ÃÂÅ,ÃÁÄ,ÆÅÈ,ÆÄÇ,ÉÈË,ÉÇÊ,ÌËÎ,ÌÊÍ,ÏÎÑ,ÏÍÐ,ÒÑÔ,ÒÐÓ,ÕÔ×,ÕÓÖ,Ø×Ú,ØÖÙ,ÛÚÝ,ÛÙÜ,ÞÝà,ÞÜß,áàã,áßâ,äãæ,äâå,çæé,çåè,êéì,êèë,íìï,íëî,ðïò,ðîñ,óòõ,óñô,öõø,öô÷,ùøû,ù÷ú,üûþ,üúý,ÿþā,ÿýĀ,ĂāĄ,ĂĀă,ąĄć,ąăĆ,ĈćĊ,ĈĆĉ,ċĊč,ċĉČ,ĎčĐ,ĎČď,đĐē,đďĒ,ĔēĖ,ĔĒĕ,ėĖę,ėĕĘ,ĚęĜ,ĚĘě,ĝĜğ,ĝěĞ,ĠğĢ,ĠĞġ,ģĢĥ,ģġĤ,ĦĥĨ,ĦĤħ,ĩĨī,ĩħĪ,ĬīĮ,ĬĪĭ,įĮı,įĭİ,IJıĴ,IJİij,ĵĴķ,ĵijĶ,ĸķĺ,ĸĶĹ,ĻĺĽ,ĻĹļ,ľĽŀ,ľļĿ,ŁŀŃ,ŁĿł,ńŃņ,ńłŅ,Ňņʼn,ŇŅň,ŊʼnŌ,Ŋňŋ,ōŌŏ,ōŋŎ,ŐŏŒ,ŐŎő,œŒŕ,œőŔ,ŖŕŘ,ŖŔŗ,řŘś,řŗŚ,ŜśŞ,ŜŚŝ,şŞš,şŝŠ,ŢšŤ,ŢŠţ,ťŤŧ,ťţŦ,ŨŧŪ,ŨŦũ,ūŪŭ,ūũŬ,ŮŭŰ,ŮŬů,űŰų,űůŲ,ŴųŶ,ŴŲŵ,ŷŶŹ,ŷŵŸ,źŹż,źŸŻ,Žżſ,ŽŻž,ƀſƂ,ƀžƁ,ƃƂƅ,ƃƁƄ,Ɔƅƈ,ƆƄƇ,ƉƈƋ,ƉƇƊ,ƌƋƎ,ƌƊƍ,ƏƎƑ,ƏƍƐ,ƒƑƔ,ƒƐƓ,ƕƔƗ,ƕƓƖ,ƘƗƚ,ƘƖƙ,ƛƚƝ,ƛƙƜ,ƞƝƠ,ƞƜƟ,ơƠƣ,ơƟƢ,ƤƣƦ,ƤƢƥ,ƧƦƩ,Ƨƥƨ,ƪƩƬ,ƪƨƫ,ƭƬƯ,ƭƫƮ,ưƯƲ,ưƮƱ,adf,ƳƶƲ,ƳƶƱ,adgh-> 118 | #> 119 | #> ... 120 | #> 121 | #> 4 TDOT Ğğ,ĠğĢ->ĠĞĢ ĠĞġ,ģĢĥ,ģġĤ,Ĥĥ,ĠĞĢ-> 122 | #> 4 GEMM ĠĞĢ,ĠĞġ->ġĢ ģĢĥ,ģġĤ,Ĥĥ,ġĢ-> 123 | #> 4 GEMM Ĥĥ,ģĢĥ->ģĢĤ ģġĤ,ġĢ,ģĢĤ-> 124 | #> 4 TDOT ģĢĤ,ģġĤ->ġĢ ġĢ,ġĢ-> 125 | #> 2 DOT ġĢ,ġĢ-> -> 126 | ``` 127 | 128 | Where we can see the speedup over a naive einsum is about `10^241`, not bad! 129 | -------------------------------------------------------------------------------- /docs/getting_started/backends.md: -------------------------------------------------------------------------------- 1 | # Backends & GPU Support 2 | 3 | `opt_einsum` is largely agnostic to the type of n-dimensional arrays (tensors) 4 | it uses, since finding the contraction path only relies on getting the shape 5 | attribute of each array supplied. 6 | It can perform the underlying tensor contractions with various 7 | libraries. In fact, any library that provides a `numpy.tensordot` and 8 | `numpy.transpose` implementation can perform most normal contractions. 9 | However, certain special functionalities such as axes reduction are reliant on a 10 | `numpy.einsum` implementation. 11 | The following is a brief overview of libraries which have been tested with 12 | `opt_einsum`: 13 | 14 | - [tensorflow](https://www.tensorflow.org/): compiled tensor expressions 15 | that can run on GPU. 16 | - [theano](http://deeplearning.net/software/theano/): compiled tensor 17 | expressions that can run on GPU. 18 | - [cupy](https://cupy.chainer.org/): numpy-like api for GPU tensors. 19 | - [dask](https://dask.pydata.org/): larger-than-memory tensor 20 | computations, distributed scheduling, and potential reuse of 21 | intermediaries. 22 | - [sparse](https://sparse.pydata.org/): sparse tensors. 23 | - [pytorch](https://pytorch.org): numpy-like api for GPU tensors. 24 | - [autograd](https://github.com/HIPS/autograd): automatic derivative 25 | computation for tensor expressions 26 | - [jax](https://github.com/google/jax): compiled GPU tensor expressions 27 | including `autograd`-like functionality 28 | 29 | !!! note 30 | For a contraction to be possible without using a backend einsum, it must 31 | satisfy the following rule: in the full expression (*including* output 32 | indices) each index must appear twice. In other words, each dimension 33 | must be either contracted with one other dimension or left alone. 34 | 35 | 36 | ## Backend agnostic contractions 37 | 38 | The automatic backend detection will be detected based on the first supplied 39 | array (default), this can be overridden by specifying the correct `backend` 40 | argument for the type of arrays supplied when calling 41 | [`opt_einsum.contract`](../api_reference.md#opt_einsum.contract.contract). 
For example, if you had a library installed 42 | called `'foo'` which provided an `numpy.ndarray` like object with a 43 | `.shape` attribute as well as `foo.tensordot` and `foo.transpose` then 44 | you could contract them with something like: 45 | 46 | ```python 47 | contract(einsum_str, *foo_arrays, backend='foo') 48 | ``` 49 | 50 | Behind the scenes `opt_einsum` will find the contraction path, perform 51 | pairwise contractions using e.g. `foo.tensordot` and finally return the canonical 52 | type those functions return. 53 | 54 | ### Dask 55 | 56 | [dask](https://dask.pydata.org/) is an example of a library which satisfies 57 | these requirements. For example: 58 | 59 | ```python 60 | import opt_einsum as oe 61 | import dask.array as da 62 | shapes = (3, 200), (200, 300), (300, 4) 63 | dxs = [da.random.normal(0, 1, shp, chunks=(100, 100)) for shp in shapes] 64 | dxs 65 | #> [dask.array, 66 | #> dask.array, 67 | #> dask.array] 68 | 69 | 70 | dy = oe.contract("ab,bc,cd", *dxs) # will infer backend='dask' 71 | dy 72 | #> dask.array 73 | 74 | dy.compute() 75 | #> array([[ 470.71404665, 2.44931372, -28.47577265, 424.37716615], 76 | #> [ 64.38328345, -287.40753131, 144.46515642, 324.88169821], 77 | #> [-142.07153553, -180.41739259, 125.0973783 , -239.16754541]]) 78 | ``` 79 | 80 | 81 | In this case, dask arrays in = dask array out, since dask arrays have a shape 82 | attribute, and `opt_einsum` can find `dask.array.tensordot` and 83 | `dask.array.transpose`. 84 | 85 | 86 | ### Sparse 87 | 88 | The [sparse](https://sparse.pydata.org/) library also fits the requirements and is 89 | supported. An example: 90 | 91 | ```python 92 | import sparse as sp 93 | shapes = (3, 200), (200, 300), (300, 4) 94 | sxs = [sp.random(shp) for shp in shapes] 95 | sxs 96 | #> [, 97 | #> , 98 | #> ] 99 | 100 | oe.contract("ab,bc,cd", *sxs) 101 | #> 102 | ``` 103 | 104 | 105 | ### Autograd 106 | 107 | The [autograd](https://github.com/HIPS/autograd) library is a drop-in for 108 | `numpy` that can automatically compute the gradients of array expressions. 109 | `opt_einsum` automatically dispatches the `autograd` arrays correctly, 110 | enabling a simple way to compute gradients of tensor contractions: 111 | 112 | ```python 113 | import numpy as np 114 | import autograd 115 | shapes = [(2, 3), (3, 4), (4, 2)] 116 | x, y, z = [np.random.rand(*s) for s in shapes] 117 | 118 | # make single arg function as autograd takes derivative of first arg 119 | def foo(xyz): 120 | return oe.contract('ij,jk,ki->', *xyz) 121 | 122 | foo([x, y, z]) 123 | #> array(4.90422159) 124 | 125 | # wrap foo with autograd to compute gradients instead 126 | dfoo = autograd.grad(foo) 127 | dx, dy, dz = dfoo(arrays) 128 | dx, dy, dz 129 | #> (array([[1.10056194, 1.25078356, 1.48211494], 130 | #> [1.38945961, 1.5572077 , 1.65234003]]), 131 | #> array([[0.41710717, 0.63202881, 0.84573502, 0.95069975], 132 | #> [0.42706777, 0.73630994, 0.99328938, 0.77415267], 133 | #> [0.40773334, 0.61693475, 0.82545726, 0.93132302]]), 134 | #> array([[0.78747828, 1.28979012], 135 | #> [1.26051133, 1.48835538], 136 | #> [0.46896666, 0.55003072], 137 | #> [1.10840828, 1.16722494]])) 138 | ``` 139 | 140 | ### Jax 141 | 142 | [jax](https://github.com/google/jax) is itself a drop-in for `autograd`, 143 | that additionally uses [XLA](https://www.tensorflow.org/xla) to compile the 144 | expressions, particularly for the GPU. 
Using it with `opt_einsum` is very 145 | simple: 146 | 147 | ```python 148 | import jax 149 | # generate a compiled version of the above function 150 | jit_foo = jax.jit(foo) 151 | jit_foo([x, y, z]) 152 | #> DeviceArray(4.9042215, dtype=float32) 153 | 154 | # generate a compiled version of the gradient function 155 | jit_dfoo = jax.jit(jax.grad(foo)) 156 | jit_dfoo([x, y, z]) 157 | #> [DeviceArray([[1.10056198, 1.25078356, 1.48211491], 158 | #> [1.38945973, 1.5572077, 1.65234005]], dtype=float32), 159 | #> DeviceArray([[0.41710716, 0.63202882, 0.84573501, 0.95069975], 160 | #> [0.42706776, 0.73630995, 0.99328935, 0.7741527 ], 161 | #> [0.40773335, 0.61693472, 0.82545722, 0.93132305]], 162 | #> dtype=float32), 163 | #> DeviceArray([[0.78747827, 1.28979015], 164 | #> [1.2605114 , 1.4883554 ], 165 | #> [0.46896666, 0.55003077], 166 | #> [1.10840821, 1.16722488]], dtype=float32)] 167 | ``` 168 | 169 | !!! note 170 | `jax` defaults to converting all arrays to single precision. This 171 | behaviour can be changed by running 172 | `from jax.config import config; config.update("jax_enable_x64", True)` 173 | **before** it has been imported and used at all. 174 | 175 | 176 | 177 | ## Special (GPU) backends for numpy arrays 178 | 179 | A particular case is if numpy arrays are required for the input and output, 180 | however, a more performant backend is required such as performing the contraction on a GPU. 181 | Unless the specified backend works on numpy arrays, this requires converting to and from the backend array type. 182 | Currently `opt_einsum` can handle this automatically for: 183 | 184 | - [tensorflow](https://www.tensorflow.org/) 185 | - [theano](http://deeplearning.net/software/theano/) 186 | - [cupy](https://cupy.chainer.org/) 187 | - [pytorch](https://pytorch.org) 188 | - [jax](https://github.com/google/jax) 189 | 190 | all of which offer GPU support. Since `tensorflow` and `theano` both require 191 | compiling the expression, this functionality is encapsulated in generating a 192 | [`opt_einsum.ContractExpression`](../api_reference.md#opt_einsum.contract.contract_expression) using 193 | [`opt_einsum.contract_expression`](../api_reference.md#opt_einsumcontract_expression), which can then be called using numpy 194 | arrays whilst specifying `backend='tensorflow'` etc. 195 | Additionally, if arrays are marked as `constant` 196 | (see [`constants-section`](./reusing_paths.md#specifying-constants)), then these arrays will be kept on the device 197 | for optimal performance. 198 | 199 | 200 | ### Theano 201 | 202 | If `theano` is installed, using it as backend is as simple as specifying 203 | `backend='theano'`: 204 | 205 | ```python 206 | shapes = (3, 200), (200, 300), (300, 4) 207 | expr = oe.contract_expression("ab,bc,cd", *shapes) 208 | expr 209 | #> 210 | 211 | import numpy as np 212 | # GPU advantage mainly for low precision numbers 213 | xs = [np.random.randn(*shp).astype(np.float32) for shp in shapes] 214 | expr(*xs, backend='theano') # might see some fluff on first run 215 | #> array([[ 129.28352 , -128.00702 , -164.62917 , -335.11682 ], 216 | #> [-462.52344 , -121.12657 , -67.847626 , 624.5457 ], 217 | #> [ 5.2838974, 36.441578 , 81.62851 , 703.1576 ]], 218 | #> dtype=float32) 219 | ``` 220 | 221 | Note that you can still supply `theano.tensor.TensorType` directly to 222 | `opt_einsum` (with `backend='theano'`), and it will return the 223 | relevant `theano` type. 
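Combining this with marking tensors as constant (see the [constants section](./reusing_paths.md#specifying-constants)) means the constant operands only need to be transferred to the device once. The following is a small, hypothetical sketch that reuses the `shapes` and `xs` arrays from the example above and treats the first two arrays as constants:

```python
# hypothetical sketch: the first two operands are supplied as actual arrays and
# marked constant, so they can be kept on the device between calls
expr_const = oe.contract_expression("ab,bc,cd", xs[0], xs[1], (300, 4), constants=[0, 1])

# each call now only supplies the remaining (300, 4) array
expr_const(xs[2], backend='theano')
```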
224 | 225 | 226 | ### Tensorflow 227 | 228 | To run the expression with **tensorflow**, you need to register a default 229 | session: 230 | 231 | ```python 232 | import tensorflow as tf 233 | sess = tf.Session() 234 | 235 | with sess.as_default(): 236 | out = expr(*xs, backend='tensorflow') 237 | 238 | out 239 | #> array([[ 129.28357 , -128.00684 , -164.62903 , -335.1167 ], 240 | #> [-462.52362 , -121.12659 , -67.84769 , 624.5455 ], 241 | #> [ 5.2839584, 36.44155 , 81.62852 , 703.15784 ]], 242 | #> dtype=float32) 243 | ``` 244 | 245 | Note that you can still supply this expression with, for example, a 246 | `tensorflow.placeholder` using `backend='tensorflow'`, and then no 247 | conversion would take place, instead you'd get a `tensorflow.Tensor` back. 248 | 249 | Version 1.9 of tensorflow also added support for eager execution of 250 | computations. If compilation of the contraction expression tensorflow graph is 251 | taking a substantial amount of time up then it can be advantageous to use this, 252 | especially since tensor contractions are quite compute-bound. This is achieved 253 | by running the following snippet: 254 | 255 | ```python 256 | import tensorflow as tf 257 | tf.enable_eager_execution() 258 | ``` 259 | 260 | After which `opt_einsum` will automatically detect eager mode if 261 | `backend='tensorflow'` is supplied to a 262 | [`opt_einsum.ContractExpression`](../api_reference.md#opt_einsum.contract.contract_expression). 263 | 264 | 265 | ### Pytorch & Cupy 266 | 267 | Both [pytorch](https://pytorch.org) and [cupy](https://cupy.chainer.org/) 268 | offer numpy-like, GPU-enabled arrays which execute eagerly rather than 269 | requiring any compilation. If they are installed, no steps are required to 270 | utilize them other than specifying the `backend` keyword: 271 | 272 | ```python 273 | expr(*xs, backend='torch') 274 | #> array([[ 129.28357 , -128.00684 , -164.62903 , -335.1167 ], 275 | #> [-462.52362 , -121.12659 , -67.84769 , 624.5455 ], 276 | #> [ 5.2839584, 36.44155 , 81.62852 , 703.15784 ]], 277 | #> dtype=float32) 278 | 279 | expr(*xs, backend='cupy') 280 | #> array([[ 129.28357 , -128.00684 , -164.62903 , -335.1167 ], 281 | #> [-462.52362 , -121.12659 , -67.84769 , 624.5455 ], 282 | #> [ 5.2839584, 36.44155 , 81.62852 , 703.15784 ]], 283 | #> dtype=float32) 284 | ``` 285 | 286 | And as with the other GPU backends, if raw `cupy` or `pytorch` arrays are 287 | supplied the returned array will be of the same type, with no conversion 288 | to or from `numpy` arrays. 289 | 290 | ### Jax 291 | 292 | [jax](https://github.com/google/jax), as introduced above, can compile tensor 293 | functions, in doing so often achieving better performance. 294 | `opt_einsum` expressions can handle this behind the scenes, 295 | so again just the `backend` keyword needs to be supplied: 296 | 297 | ```python 298 | expr(*xs, backend='jax') 299 | #> array([[ 129.28357 , -128.00684 , -164.62903 , -335.1167 ], 300 | #> [-462.52362 , -121.12659 , -67.84769 , 624.5455 ], 301 | #> [ 5.2839584, 36.44155 , 81.62852 , 703.15784 ]], 302 | #> dtype=float32) 303 | ``` 304 | 305 | 306 | ## Contracting arbitrary objects 307 | 308 | There is one more explicit backend that can handle arbitrary arrays of objects, 309 | so long the *objects themselves* just support multiplication and addition ( 310 | `__mul__` and `__add__` dunder methods respectively). 311 | Use it by supplying `backend='object'`. 
312 | 313 | For example, imagine we want to perform a contraction of arrays made up of [sympy](https://www.sympy.org) symbols: 314 | 315 | ```python 316 | import opt_einsum as oe 317 | import numpy as np 318 | import sympy 319 | 320 | # define the symbols 321 | a, b, c, d, e, f, g, h, i, j, k, l = [sympy.symbols(oe.get_symbol(i)) for i in range(12)] 322 | a * b + c * d 323 | 𝑑 324 | 325 | # define the tensors (you might explicitly specify `dtype=object`) 326 | X = np.array([[a, b], [c, d]]) 327 | Y = np.array([[e, f], [g, h]]) 328 | Z = np.array([[i, j], [k, l]]) 329 | 330 | # contract the tensors! 331 | oe.contract('uv,vw,wu->u', X, Y, Z, backend='object') 332 | # array([i*(a*e + b*g) + k*(a*f + b*h), j*(c*e + d*g) + l*(c*f + d*h)], 333 | # dtype=object) 334 | ``` 335 | 336 | There are a few things to note here: 337 | 338 | - The returned array is a `numpy.ndarray` but since it has `dtype=object` 339 | it can really hold *any* python objects 340 | - We had to explicitly use `backend='object'`, since `numpy.einsum` 341 | would have otherwise been dispatched to, which can't handle `dtype=object` 342 | (though `numpy.tensordot` in fact can) 343 | - Although an optimized pairwise contraction order is used, the looping in each 344 | single contraction is **performed in python so performance will be 345 | drastically lower than for numeric dtypes!** 346 | -------------------------------------------------------------------------------- /docs/getting_started/input_format.md: -------------------------------------------------------------------------------- 1 | # Input Format 2 | 3 | The `opt_einsum` package was originally designed as a drop-in replacement for the `np.einsum` 4 | function and supports all input formats that `np.einsum` supports. There are 5 | two styles of input accepted, a basic introduction to which can be found in the 6 | documentation for `numpy.einsum`. In addition to this, `opt_einsum` 7 | extends the allowed index labels to unicode or arbitrary hashable, comparable 8 | objects in order to handle large contractions with many indices. 9 | 10 | 11 | ## 'Equation' Input 12 | 13 | As with `numpy.einsum`, here you specify an equation as a string, 14 | followed by the array arguments: 15 | 16 | ```python 17 | import opt_einsum as oe 18 | eq = 'ijk,jkl->li' 19 | x, y = np.random.rand(2, 3, 4), np.random.rand(3, 4, 5) 20 | z = oe.contract(eq, x, y) 21 | z.shape 22 | #> (5, 2) 23 | ``` 24 | 25 | However, in addition to the standard alphabet, `opt_einsum` also supports 26 | unicode characters: 27 | 28 | ```python 29 | eq = "αβγ,βγδ->δα" 30 | oe.contract(eq, x, y).shape 31 | #> (5, 2) 32 | ``` 33 | 34 | This enables access to thousands of possible index labels. One way to access 35 | these programmatically is through the function [`get_symbols`](../api_reference.md#opt_einsumget_symbol): 36 | 37 | ```python 38 | oe.get_symbol(805) 39 | #> 'α' 40 | ``` 41 | 42 | which maps an `int` to a unicode characater. Note that as with 43 | `numpy.einsum` if the output is not specified with `->` it will default 44 | to the sorted order of all indices appearing once: 45 | 46 | ```python 47 | eq = "αβγ,βγδ" # "->αδ" is implicit 48 | oe.contract(eq, x, y).shape 49 | #> (2, 5) 50 | ``` 51 | 52 | 53 | ## 'Interleaved' Input 54 | 55 | The other input format is to 'interleave' the array arguments with their index 56 | labels ('subscripts') in pairs, optionally specifying the output indices as a 57 | final argument. 
As with `numpy.einsum`, integers are allowed as these 58 | index labels: 59 | 60 | ```python 61 | oe.contract(x, [1, 2, 3], y, [2, 3, 4], [4, 1]).shape 62 | #> (5, 2) 63 | ``` 64 | 65 | with the default output order again specified by the sorted order of indices 66 | appearing once. However, unlike `numpy.einsum`, in `opt_einsum` you can 67 | also put *anything* hashable and comparable such as `str` in the subscript list. 68 | A simple example of this syntax is: 69 | 70 | ```python 71 | x, y, z = np.ones((1, 2)), np.ones((2, 2)), np.ones((2, 1)) 72 | oe.contract(x, ('left', 'bond1'), y, ('bond1', 'bond2'), z, ('bond2', 'right'), ('left', 'right')) 73 | #> array([[4.]]) 74 | ``` 75 | 76 | The subscripts need to be hashable so that `opt_einsum` can efficiently process them, and 77 | they should also be comparable so as to allow a default sorted output. For example: 78 | 79 | ```python 80 | x = np.array([[0, 1], [2, 0]]) 81 | 82 | # original matrix 83 | oe.contract(x, (0, 1)) 84 | #> array([[0, 1], 85 | #> [2, 0]]) 86 | 87 | # the transpose 88 | oe.contract(x, (1, 0)) 89 | #> array([[0, 2], 90 | #> [1, 0]]) 91 | 92 | # original matrix, consistent behavior 93 | oe.contract(x, ('a', 'b')) 94 | #> array([[0, 1], 95 | #> [2, 0]]) 96 | 97 | # the transpose, consistent behavior 98 | >>> oe.contract(x, ('b', 'a')) 99 | #> array([[0, 2], 100 | #> [1, 0]]) 101 | 102 | # relative sequence undefined, can't determine output 103 | >>> oe.contract(x, (0, 'a')) 104 | #> TypeError: For this input type lists must contain either Ellipsis 105 | #> or hashable and comparable object (e.g. int, str) 106 | ``` 107 | 108 | -------------------------------------------------------------------------------- /docs/getting_started/install.md: -------------------------------------------------------------------------------- 1 | # Install opt_einsum 2 | 3 | You can install `opt_einsum` with `conda`, with `pip`, or by installing from source. 4 | 5 | ## Conda 6 | 7 | You can update `opt_einsum` using [`conda`](https://www.anaconda.com/download/): 8 | 9 | ```bash 10 | conda install opt_einsum -c conda-forge 11 | ``` 12 | 13 | This installs `opt_einsum` and the NumPy dependency. 14 | 15 | The `opt_einsum` package is maintained on the [conda-forge channel](https://conda-forge.github.io/). 16 | 17 | 18 | ## Pip 19 | 20 | To install `opt_einsum` with `pip` there are a few options, depending on which 21 | dependencies you would like to keep up to date: 22 | 23 | * `pip install opt_einsum` 24 | 25 | ## Install from Source 26 | 27 | To install opt_einsum from source, clone the repository from [github](https://github.com/dgasmith/opt_einsum): 28 | 29 | ```bash 30 | git clone https://github.com/dgasmith/opt_einsum.git 31 | cd opt_einsum 32 | python setup.py install 33 | ``` 34 | 35 | or use `pip` locally if you want to install all dependencies as well:: 36 | 37 | ```bash 38 | pip install -e . 39 | ``` 40 | 41 | 42 | ## Test 43 | 44 | Test `opt_einsum` with `py.test`: 45 | 46 | ```bash 47 | cd opt_einsum 48 | pytest 49 | ``` 50 | -------------------------------------------------------------------------------- /docs/getting_started/reusing_paths.md: -------------------------------------------------------------------------------- 1 | # Reusing Paths 2 | 3 | If you expect to use a particular contraction repeatedly, it can make things simpler and more efficient not to compute the path each time. 
4 | Instead, supplying [`opt_einsum.contract_expression`](../api_reference.md#opt_einsumcontract_expression) with the contraction string and the shapes of the tensors generates a [`opt_einsum.ContractExpression`](../api_reference.md#opt_einsumcontractcontractexpression) which can then be repeatedly called with any matching set of arrays. 5 | For example: 6 | 7 | ```python 8 | my_expr = oe.contract_expression("abc,cd,dbe->ea", (2, 3, 4), (4, 5), (5, 3, 6)) 9 | print(my_expr) 10 | #> <ContractExpression('abc,cd,dbe->ea')> 11 | #> 1. 'dbe,cd->bce' [GEMM] 12 | #> 2. 'bce,abc->ea' [GEMM] 13 | ``` 14 | 15 | The `ContractExpression` can be called with 3 arrays that match the original shapes without having to recompute the path: 16 | 17 | ```python 18 | x, y, z = (np.random.rand(*s) for s in [(2, 3, 4), (4, 5), (5, 3, 6)]) 19 | my_expr(x, y, z) 20 | #> array([[ 3.08331541, 4.13708916], 21 | #> [ 2.92793729, 4.57945185], 22 | #> [ 3.55679457, 5.56304115], 23 | #> [ 2.6208398 , 4.39024187], 24 | #> [ 3.66736543, 5.41450334], 25 | #> [ 3.67772272, 5.46727192]]) 26 | ``` 27 | 28 | Note that few checks are performed when calling the expression, and while it will work for a set of arrays with the same ranks as the original shapes but differing sizes, it might no longer be optimal. 29 | 30 | 31 | ## Specifying Constants 32 | 33 | Often one generates contraction expressions where some of the tensor arguments 34 | will remain *constant* across many calls. 35 | [`opt_einsum.contract_expression`](../api_reference.md#opt_einsumcontract_expression) allows you to specify the indices of 36 | these constant arguments, allowing `opt_einsum` to build and then reuse as 37 | many constant contractions as possible. 38 | 39 | Take for example the equation: 40 | 41 | ```python 42 | eq = "ij,jk,kl,lm,mn->ni" 43 | ``` 44 | 45 | where we know that *only* the first and last tensors will vary between calls. 46 | We can specify this by marking the middle three as constant - we then need to 47 | supply the actual arrays rather than just the shapes to 48 | [`opt_einsum.contract_expression`](../api_reference.md#opt_einsumcontract_expression): 49 | 50 | ```python 51 | # A B C D E 52 | shapes = [(9, 5), (5, 5), (5, 5), (5, 5), (5, 8)] 53 | 54 | # mark the middle three arrays as constant 55 | constants = [1, 2, 3] 56 | 57 | # generate the constant arrays 58 | B, C, D = [np.random.randn(*shapes[i]) for i in constants] 59 | 60 | # supplied ops are now mix of shapes and arrays 61 | ops = (9, 5), B, C, D, (5, 8) 62 | 63 | expr = oe.contract_expression(eq, *ops, constants=constants) 64 | expr 65 | #> <ContractExpression('ij,jk,kl,lm,mn->ni', constants=[1, 2, 3])> 66 | ``` 67 | 68 | The expression now only takes the remaining two arrays as arguments (the 69 | tensors with `'ij'` and `'mn'` indices), and will store as many reusable 70 | constant contractions as possible. 71 | 72 | 73 | 74 | ```python 75 | A1, E1 = np.random.rand(*shapes[0]), np.random.rand(*shapes[-1]) 76 | out1 = expr(A1, E1) 77 | out1.shape 78 | #> (8, 9) 79 | 80 | A2, E2 = np.random.rand(*shapes[0]), np.random.rand(*shapes[-1]) 81 | out2 = expr(A2, E2) 82 | out2.shape 83 | #> (8, 9) 84 | 85 | np.allclose(out1, out2) 86 | #> False 87 | 88 | print(expr) 89 | #> <ContractExpression('ij,jk,kl,lm,mn->ni', constants=[1, 2, 3])> 90 | #> 1. 'jm,mn->jn' [GEMM] 91 | #> 2. 'jn,ij->ni' [GEMM] 92 | ``` 93 | 94 | Where we can see that the expression now only has to perform 95 | two contractions to compute the output. 96 | 97 | !!!
note 98 | The constant part of an expression is lazily generated upon the first call 99 | (specific to each backend), though it can also be explicitly built by calling 100 | [`opt_einsum.contract.ContractExpression.evaluate_constants`](../api_reference.md#opt_einsumcontractcontractexpression). 101 | 102 | We can confirm the advantage of using expressions and constants by timing the 103 | following scenarios, first setting 104 | `A = np.random.rand(*shapes[0])` and `E = np.random.rand(*shapes[-1])`. 105 | 106 | ### Contract from scratch 107 | 108 | ```python 109 | %timeit oe.contract(eq, A, B, C, D, E) 110 | #> 239 µs ± 5.06 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 111 | ``` 112 | 113 | ### Contraction with an expression but no constants 114 | 115 | ```python 116 | expr_no_consts = oe.contract_expression(eq, *shapes) 117 | %timeit expr_no_consts(A, B, C, D, E) 118 | #> 76.7 µs ± 2.47 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) 119 | ``` 120 | 121 | ### Contraction with an expression and constants marked 122 | 123 | ```python 124 | %timeit expr(A, E) 125 | #> 40.8 µs ± 1.22 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) 126 | ``` 127 | 128 | Although this gives us a rough idea, of course the efficiency savings are 129 | hugely dependent on the size of the contraction and number of possible constant 130 | contractions. 131 | 132 | We also note that even if there are *no* constant contractions to perform, it 133 | can be very advantageous to specify constant tensors for particular backends. 134 | For instance, if a GPU backend is used, the constant tensors will be kept on 135 | the device rather than being transferred each time. 136 | -------------------------------------------------------------------------------- /docs/getting_started/sharing_intermediates.md: -------------------------------------------------------------------------------- 1 | # Sharing Intermediates 2 | 3 | If you want to compute multiple similar contractions with common terms, you can embed them in a [`opt_einsum.shared_intermediates`](../api_reference.md#opt_einsumshared_intermediates) context. Computations of subexpressions in this context will be memoized, and will be garbage collected when the contexts exits. 4 | 5 | For example, suppose we want to compute marginals at each point in a factor chain: 6 | 7 | ```python 8 | inputs = 'ab,bc,cd,de,ef' 9 | factors = [np.random.rand(1000, 1000) for _ in range(5)] 10 | 11 | %%timeit 12 | marginals = {output: contract('{}->{}'.format(inputs, output), *factors) 13 | for output in 'abcdef'} 14 | #> 1 loop, best of 3: 5.82 s per loop 15 | ``` 16 | 17 | To share this computation, we can perform all contractions in a shared context: 18 | 19 | ```python 20 | %%timeit 21 | with shared_intermediates(): 22 | marginals = {output: contract('{}->{}'.format(inputs, output), *factors) 23 | for output in 'abcdef'} 24 | #> 1 loop, best of 3: 1.55 s per loop 25 | ``` 26 | 27 | If it is difficult to fit your code into a context, you can instead save the sharing cache for later reuse. 
28 | 29 | ```python 30 | with shared_intermediates() as cache: # create a cache 31 | pass 32 | marginals = {} 33 | for output in 'abcdef': 34 | with shared_intermediates(cache): # reuse a common cache 35 | marginals[output] = contract('{}->{}'.format(inputs, output), *factors) 36 | del cache # garbage collect intermediates 37 | ``` 38 | 39 | Note that sharing contexts can be nested, so it is safe to to use [`opt_einsum.shared_intermediates`](../api_reference.md#opt_einsumshared_intermediates) in library code without leaking intermediates into user caches. 40 | 41 | !!! note 42 | By default a cache is thread safe, to share intermediates between threads explicitly pass the same cache to each thread. 43 | -------------------------------------------------------------------------------- /docs/img/ex_dask_reuse_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dgasmith/opt_einsum/f973f1e3265f248680f502807f2fdca13563cf1a/docs/img/ex_dask_reuse_graph.png -------------------------------------------------------------------------------- /docs/img/path_finding_time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dgasmith/opt_einsum/f973f1e3265f248680f502807f2fdca13563cf1a/docs/img/path_finding_time.png -------------------------------------------------------------------------------- /docs/img/path_found_flops.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dgasmith/opt_einsum/f973f1e3265f248680f502807f2fdca13563cf1a/docs/img/path_found_flops.png -------------------------------------------------------------------------------- /docs/img/path_found_flops_random.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dgasmith/opt_einsum/f973f1e3265f248680f502807f2fdca13563cf1a/docs/img/path_found_flops_random.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | 3 | Optimized einsum can significantly reduce the overall execution time of einsum-like 4 | expressions by optimizing the expression's contraction order and dispatching 5 | many operations to canonical BLAS, cuBLAS, or other specialized routines. 6 | Optimized einsum is agnostic to the backend and can handle NumPy, Dask, 7 | PyTorch, Tensorflow, CuPy, Sparse, Theano, JAX, and Autograd arrays as well as 8 | potentially any library which conforms to a standard API. 9 | 10 | ## Features 11 | 12 | The algorithms found in this repository often power the `einsum` optimizations 13 | in many of the above projects. For example, the optimization of `np.einsum` 14 | has been passed upstream and most of the same features that can be found in 15 | this repository can be enabled with `numpy.einsum(..., optimize=True)`. However, 16 | this repository often has more up to date algorithms for complex contractions. 17 | Several advanced features are as follows: 18 | 19 | * Inspect [detailed information](paths/introduction.md) about the path chosen. 20 | * Perform contractions with [numerous backends](getting_started/backends.md), including on the GPU and with libraries such as [TensorFlow](https://www.tensorflow.org) and [PyTorch](https://pytorch.org). 
21 | * Generate [reusable expressions](getting_started/reusing_paths.md), potentially with constant tensors, that can be compiled for greater performance. 22 | * Use an arbitrary number of indices to find contractions for [hundreds or even thousands of tensors](examples/large_expr_with_greedy.md). 23 | * Share [intermediate computations](getting_started/sharing_intermediates.md) among multiple contractions. 24 | * Compute gradients of tensor contractions using [Autograd](https://github.com/HIPS/autograd) or [JAX](https://github.com/google/jax). 25 | 26 | ## Example 27 | 28 | Take the following einsum-like expression: 29 | 30 | $$ 31 | M_{pqrs} = C_{pi} C_{qj} I_{ijkl} C_{rk} C_{sl} 32 | $$ 33 | 34 | and consider two different algorithms: 35 | 36 | ```python 37 | import numpy as np 38 | 39 | dim = 10 40 | I = np.random.rand(dim, dim, dim, dim) 41 | C = np.random.rand(dim, dim) 42 | 43 | def naive(I, C): 44 | # N^8 scaling 45 | return np.einsum('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C) 46 | 47 | def optimized(I, C): 48 | # N^5 scaling 49 | K = np.einsum('pi,ijkl->pjkl', C, I) 50 | K = np.einsum('qj,pjkl->pqkl', C, K) 51 | K = np.einsum('rk,pqkl->pqrl', C, K) 52 | K = np.einsum('sl,pqrl->pqrs', C, K) 53 | return K 54 | ``` 55 | 56 | ```python 57 | >>> np.allclose(naive(I, C), optimized(I, C)) 58 | True 59 | ``` 60 | 61 | Most einsum functions do not consider building intermediate arrays; 62 | therefore, helping einsum functions by creating these intermediate arrays can result 63 | in considerable cost savings even for small N (N=10): 64 | 65 | ```python 66 | %timeit naive(I, C) 67 | 1 loops, best of 3: 829 ms per loop 68 | 69 | %timeit optimized(I, C) 70 | 1000 loops, best of 3: 445 µs per loop 71 | ``` 72 | 73 | The index transformation is a well-known contraction that leads to 74 | straightforward intermediates. This contraction can be further 75 | complicated by considering that the shape of the C matrices need not be 76 | the same, in this case, the ordering in which the indices are transformed 77 | matters significantly. Logic can be built that optimizes the order; 78 | however, this is a lot of time and effort for a single expression. 79 | 80 | The `opt_einsum` package is a typically a drop-in replacement for `einsum` 81 | functions and can handle this logic and path finding for you: 82 | 83 | ```python 84 | from opt_einsum import contract 85 | 86 | dim = 30 87 | I = np.random.rand(dim, dim, dim, dim) 88 | C = np.random.rand(dim, dim) 89 | 90 | %timeit optimized(I, C) 91 | 10 loops, best of 3: 65.8 ms per loop 92 | 93 | %timeit contract('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C) 94 | 100 loops, best of 3: 16.2 ms per loop 95 | ``` 96 | 97 | The above will automatically find the optimal contraction order, in this case, 98 | identical to that of the optimized function above, and compute the products 99 | for you. Additionally, `contract` can use vendor BLAS with the `numpy.dot` 100 | function under the hood to exploit additional parallelism and performance. 
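As a quick check, the drop-in `contract` call produces the same result as the hand-optimized function from above:

```python
>>> np.allclose(optimized(I, C), contract('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C))
True
```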
101 | 102 | Details about the optimized contraction order can be explored: 103 | 104 | ```python 105 | >>> import opt_einsum as oe 106 | 107 | >>> path_info = oe.contract_path('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C) 108 | 109 | >>> print(path_info[0]) 110 | [(0, 2), (0, 3), (0, 2), (0, 1)] 111 | 112 | >>> print(path_info[1]) 113 | Complete contraction: pi,qj,ijkl,rk,sl->pqrs 114 | Naive scaling: 8 115 | Optimized scaling: 5 116 | Naive FLOP count: 8.000e+08 117 | Optimized FLOP count: 8.000e+05 118 | Theoretical speedup: 1000.000 119 | Largest intermediate: 1.000e+04 elements 120 | -------------------------------------------------------------------------------- 121 | scaling BLAS current remaining 122 | -------------------------------------------------------------------------------- 123 | 5 GEMM ijkl,pi->jklp qj,rk,sl,jklp->pqrs 124 | 5 GEMM jklp,qj->klpq rk,sl,klpq->pqrs 125 | 5 GEMM klpq,rk->lpqr sl,lpqr->pqrs 126 | 5 GEMM lpqr,sl->pqrs pqrs->pqrs 127 | ``` 128 | 129 | 130 | 131 | ## Citation 132 | 133 | If this code has benefited your research, please support us by citing: 134 | 135 | Daniel G. A. Smith and Johnnie Gray, opt_einsum - A Python package for optimizing contraction order for einsum-like expressions. **Journal of Open Source Software**, *2018*, 3(26), 753 136 | 137 | DOI: https://doi.org/10.21105/joss.00753 138 | -------------------------------------------------------------------------------- /docs/javascript/config.js: -------------------------------------------------------------------------------- 1 | window.MathJax = { 2 | tex: { 3 | inlineMath: [["\\(", "\\)"]], 4 | displayMath: [["\\[", "\\]"]], 5 | processEscapes: true, 6 | processEnvironments: true 7 | }, 8 | options: { 9 | ignoreHtmlClass: ".*|", 10 | processHtmlClass: "arithmatex" 11 | } 12 | }; 13 | 14 | document$.subscribe(() => { 15 | MathJax.typesetPromise() 16 | }) 17 | -------------------------------------------------------------------------------- /docs/paths/branching_path.md: -------------------------------------------------------------------------------- 1 | # The Branching Path 2 | 3 | While the `optimal` path is guaranteed to find the smallest estimate FLOP 4 | cost, it spends a lot of time exploring paths which are not likely to result in 5 | an optimal path. For instance, outer products are usually not advantageous 6 | unless absolutely necessary. Additionally, by trying a 'good' path first, it 7 | should be possible to quickly establish a threshold FLOP cost which can then be 8 | used to prune many bad paths. 9 | 10 | The **branching** strategy (provided by [`opt_einsum.paths.branch`](../api_reference.md#opt_einsumpathsbranch)) does 11 | this by taking the recursive, depth-first approach of 12 | [`opt_einsum.paths.optimal`](../api_reference.md#opt_einsumpathsoptimal), whilst also sorting potential contractions 13 | based on a heuristic cost, as in [`opt_einsum.paths.greedy`](../api_reference.md#opt_einsumpathsgreedy). 14 | 15 | There are two main flavours: 16 | 17 | - `optimize='branch-all'`: explore **all** inner products, starting with 18 | those that look best according to the cost heuristic. 19 | - `optimize='branch-2'`: similar, but at each step only explore the 20 | estimated best **two** possible contractions, leading to a maximum of 21 | 2^N paths assessed. 
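Both flavours are selected in the usual way through the `optimize` keyword. A minimal sketch, assuming `eq` and `arrays` describe an existing contraction:

```python
import opt_einsum as oe

# explore all inner products, best-first
path, path_info = oe.contract_path(eq, *arrays, optimize='branch-all')

# only follow the two most promising contractions at each step
path, path_info = oe.contract_path(eq, *arrays, optimize='branch-2')
```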
22 | 23 | In both cases, [`opt_einsum.paths.branch`](../api_reference.md#opt_einsumpathsbranch) takes an active approach to 24 | pruning paths well before they hit the best *total* FLOP count, by comparing 25 | them to the FLOP count (times some factor) achieved by the best path at the 26 | same point in the contraction. 27 | 28 | There is also `'branch-1'`, which, since it only explores a single path at 29 | each step does not really 'branch' - this is essentially the approach of 30 | `'greedy'`. 31 | In comparison, `'branch-1'` will be slower for large expressions, but for 32 | small to medium expressions it might find slightly higher quality contractions 33 | due to considering individual flop costs at each step. 34 | 35 | The default `optimize='auto'` mode of `opt_einsum` will use 36 | `'branch-all'` for 5 or 6 tensors, though it should be able to handle 37 | 12-13 tensors in a matter or seconds. Likewise, `'branch-2'` will be used for 38 | 7 or 8 tensors, though it should be able to handle 20-22 tensors in a matter of 39 | seconds. Finally, `'branch-1'` will be used by `'auto'` for expressions of 40 | up to 14 tensors. 41 | 42 | 43 | Customizing the Branching Path 44 | ------------------------------ 45 | 46 | The 'branch and bound' path can be customized by creating a custom 47 | [`opt_einsum.paths.BranchBound`](../api_reference.md#opt_einsumpathsbranchbound) instance. For example: 48 | 49 | ```python 50 | optimizer = oe.BranchBound(nbranch=3, minimize='size', cutoff_flops_factor=None) 51 | path, path_info = oe.contract_path(eq, *arrays, optimize=optimizer) 52 | ``` 53 | 54 | You could then tweak the settings (e.g. `optimizer.nbranch = 4`) and the best 55 | bound found so far will persist and be used to prune paths on the next call: 56 | 57 | ```python 58 | optimizer.nbranch = 4 59 | path, path_info = oe.contract_path(eq, *arrays, optimize=optimizer) 60 | ``` 61 | -------------------------------------------------------------------------------- /docs/paths/custom_paths.md: -------------------------------------------------------------------------------- 1 | # Custom Path Optimizers 2 | 3 | If you want to implement or just experiment with custom contaction paths then 4 | you can easily by subclassing the [`opt_einsum.paths.PathOptimizer`](../api_reference.md#opt_einsum.paths.PathOptimizer) 5 | object. For example, imagine we want to test the path that just blindly 6 | contracts the first pair of tensors again and again. We would implement this 7 | as: 8 | 9 | ```python 10 | import opt_einsum as oe 11 | 12 | class MyOptimizer(oe.paths.PathOptimizer): 13 | 14 | def __call__(self, inputs, output, size_dict, memory_limit=None): 15 | return [(0, 1)] * (len(inputs) - 1) 16 | ``` 17 | 18 | Once defined we can use this as: 19 | 20 | ```python 21 | import numpy as np 22 | 23 | # set-up a random contraction 24 | eq, shapes = oe.helpers.rand_equation(10, 3, seed=42) 25 | arrays = list(map(np.ones, shapes)) 26 | 27 | # set-up our optimizer and use it 28 | optimizer = MyOptimizer() 29 | path, path_info = oe.contract_path(eq, *arrays, optimize=optimizer) 30 | 31 | print(path) 32 | #> [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1)] 33 | 34 | print(path_info.speedup) 35 | #> 133.21363671496357 36 | ``` 37 | 38 | Note that though we still get a considerable speedup over `einsum` this is 39 | of course not a good strategy to take in general. 
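The same optimizer instance can also be passed straight to [`opt_einsum.contract`](../api_reference.md#opt_einsumcontract), which finds the path and then performs the contraction in one call:

```python
# reuse the optimizer defined above to both plan and execute the contraction
result = oe.contract(eq, *arrays, optimize=optimizer)
```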
40 | 41 | 42 | ## Custom Random Optimizers 43 | 44 | If your custom path optimizer is inherently random, then you can reuse all the 45 | machinery of the random-greedy approach. Namely: 46 | 47 | - A **max-repeats** or **max-time** approach 48 | - Minimization with respect to total flops or largest intermediate size 49 | - Parallelization using a pool-executor 50 | 51 | This is done by subclassing the 52 | [`opt_einsum.paths.RandomOptimizer`](../api_reference.md#opt_einsum.path_random.RandomOptimizer) 53 | object and implementing a 54 | `setup` method. Here's an example where we just randomly select any path 55 | (again, although we get a considerable speedup over `einsum` this is 56 | not a good strategy to take in general): 57 | 58 | ```python 59 | from opt_einsum.path_random import ssa_path_compute_cost 60 | 61 | class MyRandomOptimizer(oe.path_random.RandomOptimizer): 62 | 63 | @staticmethod 64 | def random_path(r, n, inputs, output, size_dict): 65 | """Picks a completely random contraction order. 66 | """ 67 | np.random.seed(r) 68 | ssa_path = [] 69 | remaining = set(range(n)) 70 | while len(remaining) > 1: 71 | i, j = np.random.choice(list(remaining), size=2, replace=False) 72 | remaining.add(n + len(ssa_path)) 73 | remaining.remove(i) 74 | remaining.remove(j) 75 | ssa_path.append((i, j)) 76 | cost, size = ssa_path_compute_cost(ssa_path, inputs, output, size_dict) 77 | return ssa_path, cost, size 78 | 79 | def setup(self, inputs, output, size_dict): 80 | """Prepares the function and arguments to repeatedly call. 81 | """ 82 | n = len(inputs) 83 | trial_fn = self.random_path 84 | trial_args = (n, inputs, output, size_dict) 85 | return trial_fn, trial_args 86 | ``` 87 | 88 | Which we can now instantiate using various other options: 89 | 90 | ```python 91 | optimizer = MyRandomOptimizer(max_repeats=1000, max_time=10, 92 | parallel=True, minimize='size') 93 | path, path_info = oe.contract_path(eq, *arrays, optimize=optimizer) 94 | 95 | print(path) 96 | #> [(3, 4), (1, 3), (0, 3), (3, 5), (3, 4), (3, 4), (1, 0), (0, 1), (0, 1)] 97 | 98 | print(path_info.speedup) 99 | #> 712829.9451056132 100 | ``` 101 | 102 | There are a few things to note here: 103 | 104 | 1. The core function (`MyRandomOptimizer.random_path` here), should take a 105 | trial number `r` as it first argument 106 | 2. It should return a *ssa_path* (see `opt_einsum.paths.ssa_to_linear` and 107 | `opt_einsum.paths.linear_to_ssa`) as well as a flops-cost and max-size. 108 | 3. The `setup` method prepares this function, as well as any input to it, 109 | so that the trials will look roughly like 110 | `[trial_fn(r, *trial_args) for r in range(max_repeats)]`. If you need to 111 | parse the standard arguments (into a network for example), it thus only 112 | needs to be done once per optimization 113 | 114 | More details about 115 | [`opt_einsum.paths.RandomOptimizer`](../api_reference.md#opt_einsumpath_randomrandomoptimizer) 116 | options can 117 | be found in [`RandomGreedyPathPage`](./random_greedy_path.md) section. 118 | -------------------------------------------------------------------------------- /docs/paths/dp_path.md: -------------------------------------------------------------------------------- 1 | # The Dynamic Programming Path 2 | 3 | The dynamic programming (DP) approach described in reference [1] provides an efficient 4 | way to find an asymptotically optimal contraction path by running the following steps: 5 | 6 | 1. Compute all traces, i.e. summations over indices occurring exactly in one 7 | input. 8 | 2. 
Decompose the contraction graph of inputs into disconnected subgraphs. Two 9 | inputs are connected if they share at least one summation index. 10 | 3. Find the contraction path for each of the disconnected subgraphs using a 11 | DP approach: The optimal contraction path for all sets of `n` (ranging 12 | from 1 to the number of inputs) connected tensors is found by combining 13 | sets of `m` and `n-m` tensors. 14 | 15 | Note that computing all the traces in the very beginning can never lead to a 16 | non-optimal contraction path. 17 | 18 | Contractions of disconnected subgraphs can be optimized independently, which 19 | still results in an optimal contraction path. However, the computational 20 | complexity of finding the contraction path is drastically reduced: If the 21 | subgraphs consist of `n1`, `n2`, ... inputs, the computational complexity 22 | is reduced from `O(exp(n1 + n2 + ...))` to `O(exp(n1) + exp(n2) + ...)`. 23 | 24 | The DP approach will only perform pair contractions and by default will never 25 | compute intermediate outer products as in reference [1] it is shown that this 26 | always results in an asymptotically optimal contraction path. 27 | 28 | A major optimization for DP is the cost capping strategy: The DP optimization 29 | only memorizes contractions for a subset of inputs, if the total cost for this 30 | contraction is smaller than the cost cap. The cost cap is initialized with 31 | the minimal possible cost, i.e. the product of all output dimensions, and is 32 | iteratively increased by multiplying it with the smallest dimension 33 | until a contraction path including all inputs is found. 34 | 35 | Note that the worst case scaling of DP is exponential in the number 36 | of inputs. Nevertheless, if the contraction graph is not completely random, 37 | but exhibits a certain kind of structure, it can be used for large 38 | contraction graphs and is guaranteed to find an asymptotically optimal 39 | contraction path. For this reason it is the most frequently used contraction 40 | path optimizer in the field of tensor network states. 41 | 42 | More specifically, the search is performed over connected subgraphs, which, for 43 | example, planar and tree-like graphs have far fewer of. As a rough guide, if 44 | the graph is planar, expressions with many tens of tensors are tractable, 45 | whereas if the graph is tree-like, expressions with many hundreds of tensors 46 | are tractable. 47 | 48 | 49 | [1] Robert N. C. Pfeifer, Jutho Haegeman, and Frank Verstraete Phys. Rev. E 90, 033315 (2014). https://arxiv.org/abs/1304.6112 50 | 51 | 52 | Customizing the Dynamic Programming Path 53 | ---------------------------------------- 54 | 55 | The default `optimize='dp'` approach has sensible defaults but can be 56 | customized with the [`opt_einsum.paths.DynamicProgramming`](../api_reference.md#opt_einsumpathsdynamicprogramming) object. 57 | 58 | ```python 59 | import opt_einsum as oe 60 | 61 | optimizer = oe.DynamicProgramming( 62 | minimize='size', # optimize for largest intermediate tensor size 63 | search_outer=True, # search through outer products as well 64 | cost_cap=False, # don't use cost-capping strategy 65 | ) 66 | 67 | oe.contract(eq, *arrays, optimize=optimizer) 68 | ``` 69 | 70 | !!! warning 71 | Note that searching outer products will most likely drastically slow down 72 | the optimizer on all but the smallest examples. 73 | 74 | 75 | The values that `minimize` can take are: 76 | 77 | - `'flops'`: minimize the total number of scalar operations. 
78 | - `'size'`: minimize the size of the largest intermediate. 79 | - `'write'`: minimize the combined size of all intermediate tensors - 80 | approximately speaking the amount of memory that will be written. This is 81 | relevant if you were to automatically differentiate through the 82 | contraction, which naively would require storing all intermediates. 83 | - `'combo'` - minimize `flops + alpha * write` summed over intermediates, a 84 | default ratio of `alpha=64` is used, or it can be customized with 85 | `f'combo-{alpha}'`. 86 | - `'limit'` - minimize `max(flops, alpha * write)` summed over intermediates, a 87 | default ratio of `alpha=64` is used, or it can be customized with `f'limit-{alpha}'`. 88 | 89 | The last two take into account the fact that real contraction performance can 90 | be bound by memory speed, and so favor paths with higher arithmetic 91 | intensity. The default value of `alpha=64` is reasonable for both typical 92 | CPUs and GPUs. 93 | -------------------------------------------------------------------------------- /docs/paths/greedy_path.md: -------------------------------------------------------------------------------- 1 | # The Greedy Path 2 | 3 | The `'greedy'` approach provides a very efficient strategy for finding 4 | contraction paths for expressions with large numbers of tensors. 5 | It does this by eagerly choosing contractions in three stages: 6 | 7 | 1. Eagerly compute any **Hadamard** products (in arbitrary order -- this is 8 | commutative). 9 | 2. Greedily contract pairs of remaining tensors, at each step choosing the 10 | pair that maximizes `reduced_size` -- these are generally **inner** 11 | products. 12 | 3. Greedily compute any pairwise **outer** products, at each step choosing 13 | the pair that minimizes `sum(input_sizes)`. 14 | 15 | The cost heuristic `reduced_size` is simply the size of the pair of potential 16 | tensors to be contracted, minus the size of the resulting tensor. 17 | 18 | The `greedy` algorithm has space and time complexity `O(n * k)` where `n` 19 | is the number of input tensors and `k` is the maximum number of tensors that 20 | share any dimension (excluding dimensions that occur in the output or in every 21 | tensor). As such, the algorithm scales well to very large sparse contractions 22 | of low-rank tensors, and indeed, often finds the optimal, or close to optimal 23 | path in such cases. 24 | 25 | The `greedy` functionality is provided by [`opt_einsum.paths.greedy`](../api_reference.md#opt_einsumpathsgreedy), 26 | and is selected by the default `optimize='auto'` mode of `opt_einsum` for 27 | expressions with many inputs. Expressions of up to a thousand tensors 28 | should still take well less than a second to find paths for. 29 | 30 | 31 | Optimal Scaling Misses 32 | ---------------------- 33 | 34 | The greedy algorithm, while inexpensive, can occasionally miss optimal scaling in some circumstances as seen below. The `greedy` algorithm prioritizes expressions which remove the largest indices first, in this particular case this is the incorrect choice and it is difficult for any heuristic algorithm to "see ahead" as would be needed here. 35 | 36 | It should be stressed these cases are quite rare and by default `contract` uses the `optimal` path for four and fewer inputs as the cost of evaluating the `optimal` path is similar to that of the `greedy` path. Similarly, for 5-8 inputs, `contract` uses one of the 37 | branching strategies which can find higher quality paths. 
38 | 39 | ```python 40 | M = np.random.rand(35, 37, 59) 41 | A = np.random.rand(35, 51, 59) 42 | B = np.random.rand(37, 51, 51, 59) 43 | C = np.random.rand(59, 27) 44 | 45 | path, desc = oe.contract_path('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize="greedy") 46 | print(desc) 47 | #> Complete contraction: xyf,xtf,ytpf,fr->tpr 48 | #> Naive scaling: 6 49 | #> Optimized scaling: 5 50 | #> Naive FLOP count: 2.146e+10 51 | #> Optimized FLOP count: 4.165e+08 52 | #> Theoretical speedup: 51.533 53 | #> Largest intermediate: 5.371e+06 elements 54 | #> -------------------------------------------------------------------------------- 55 | #> scaling BLAS current remaining 56 | #> -------------------------------------------------------------------------------- 57 | #> 5 False ytpf,xyf->tpfx xtf,fr,tpfx->tpr 58 | #> 4 False tpfx,xtf->tpf fr,tpf->tpr 59 | #> 4 GEMM tpf,fr->tpr tpr->tpr 60 | 61 | path, desc = oe.contract_path('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize="optimal") 62 | print(desc) 63 | #> Complete contraction: xyf,xtf,ytpf,fr->tpr 64 | #> Naive scaling: 6 65 | #> Optimized scaling: 4 66 | #> Naive FLOP count: 2.146e+10 67 | #> Optimized FLOP count: 2.744e+07 68 | #> Theoretical speedup: 782.283 69 | #> Largest intermediate: 1.535e+05 elements 70 | #> -------------------------------------------------------------------------------- 71 | #> scaling BLAS current remaining 72 | #> -------------------------------------------------------------------------------- 73 | #> 4 False xtf,xyf->tfy ytpf,fr,tfy->tpr 74 | #> 4 False tfy,ytpf->tfp fr,tfp->tpr 75 | #> 4 TDOT tfp,fr->tpr tpr->tpr 76 | ``` 77 | 78 | 79 | So we can see that the `greedy` algorithm finds a path which is about 16 80 | times slower than the `optimal` one. In such cases, it might be worth using 81 | one of the more exhaustive optimization strategies: `'optimal'`, 82 | `'branch-all'` or `branch-2` (all of which will find the optimal path in 83 | this example). 84 | 85 | 86 | Customizing the Greedy Path 87 | --------------------------- 88 | 89 | The greedy path is a local optimizer in that it only ever assesses pairs of 90 | tensors to contract, assigning each a heuristic 'cost' and then choosing the 91 | 'best' of these. Custom greedy approaches can be implemented by supplying 92 | callables to the `cost_fn` and `choose_fn` arguments of 93 | [`opt_einsum.paths.greedy`](../api_reference.md#opt_einsumpathsgreedy). 94 | -------------------------------------------------------------------------------- /docs/paths/introduction.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | Performing an optimized tensor contraction to speed up `einsum` involves two 4 | key stages: 5 | 6 | 1. Finding a pairwise contraction order, or **'path'**. 7 | 2. Performing the sequence of contractions given this path. 8 | 9 | The better the quality of path found in the first step, the quicker the actual 10 | contraction in the second step can be -- often dramatically. However, finding 11 | the *optimal* path is an NP-hard problem that can quickly become intractable, 12 | meaning that a balance must be struck between the time spent finding a path, 13 | and its quality. `opt_einsum` handles this by using several path finding 14 | algorithms, which can be manually specified using the `optimize` keyword. 
15 | These are: 16 | 17 | - The `'optimal'` strategy - an exhaustive search of all possible paths 18 | - The `'dynamic-programming'` strategy - a near-optimal search based off dynamic-programming 19 | - The `'branch'` strategy - a more restricted search of many likely paths 20 | - The `'greedy'` strategy - finds a path one step at a time using a cost 21 | heuristic 22 | 23 | By default (`optimize='auto'`), [`opt_einsum.contract`](../api_reference.md#opt_einsumcontract) will select the 24 | best of these it can while aiming to keep path finding times below around 1ms. 25 | An analysis of each of these approaches' performance can be found at the bottom of this page. 26 | 27 | For large and complex contractions, there is the `'random-greedy'` approach, 28 | which samples many (by default 32) greedy paths and can be customized to 29 | explicitly spend a maximum amount of time searching. Another preset, 30 | `'random-greedy-128'`, uses 128 paths for a more exhaustive search. 31 | See [`RandomGreedyPath`](./random_greedy_path.md) page for more details on configuring these. 32 | 33 | Finally, there is the `'auto-hq'` preset which targets a much larger search 34 | time (~1sec) in return for finding very high quality paths, dispatching to the 35 | `'optimal'`, `'dynamic-programming'` and then `'random-greedy-128'` paths 36 | depending on contraction size. 37 | 38 | If you want to find the path separately to performing the 39 | contraction, or just inspect information about the path found, you can use the 40 | function [`opt_einsum.contract_path`](../api_reference.md#opt_einsumcontract_path). 41 | 42 | 43 | ## Examining the Path 44 | 45 | As an example, consider the following expression found in a perturbation theory (one of ~5,000 such expressions): 46 | 47 | ```python 48 | 'bdik,acaj,ikab,ajac,ikbd' 49 | ``` 50 | 51 | At first, it would appear that this scales like N^7 as there are 7 unique indices; however, we can define a intermediate to reduce this scaling. 52 | 53 | ```python 54 | # (N^5 scaling) 55 | a = 'bdik,ikab,ikbd' 56 | 57 | # (N^4 scaling) 58 | result = 'acaj,ajac,a' 59 | ``` 60 | 61 | This is a single possible path to the final answer (and notably, not the most optimal) out of many possible paths. 
Now, let opt_einsum compute the optimal path: 62 | 63 | ```python 64 | import opt_einsum as oe 65 | 66 | # Take a complex string 67 | einsum_string = 'bdik,acaj,ikab,ajac,ikbd->' 68 | 69 | # Build random views to represent this contraction 70 | unique_inds = set(einsum_string) - {',', '-', '>'} 71 | index_size = [10, 17, 9, 10, 13, 16, 15, 14, 12] 72 | sizes_dict = dict(zip(unique_inds, index_size)) 73 | views = oe.helpers.build_views(einsum_string, sizes_dict) 74 | 75 | path, path_info = oe.contract_path(einsum_string, *views) 76 | 77 | print(path) 78 | #> [(0, 4), (1, 3), (0, 1), (0, 1)] 79 | 80 | print(path_info) 81 | #> Complete contraction: bdik,acaj,ikab,ajac,ikbd-> 82 | #> Naive scaling: 7 83 | #> Optimized scaling: 4 84 | #> Naive FLOP count: 2.387e+8 85 | #> Optimized FLOP count: 8.068e+4 86 | #> Theoretical speedup: 2958.354 87 | #> Largest intermediate: 1.530e+3 elements 88 | #> -------------------------------------------------------------------------------- 89 | #> scaling BLAS current remaining 90 | #> -------------------------------------------------------------------------------- 91 | #> 4 0 ikbd,bdik->ikb acaj,ikab,ajac,ikb-> 92 | #> 4 GEMV/EINSUM ikb,ikab->a acaj,ajac,a-> 93 | #> 3 0 ajac,acaj->a a,a-> 94 | #> 1 DOT a,a-> -> 95 | ``` 96 | 97 | 98 | We can then check that actually performing the contraction produces the expected result: 99 | 100 | ```python 101 | import numpy as np 102 | 103 | einsum_result = np.einsum("bdik,acaj,ikab,ajac,ikbd->", *views) 104 | contract_result = oe.contract("bdik,acaj,ikab,ajac,ikbd->", *views) 105 | 106 | np.allclose(einsum_result, contract_result) 107 | #> True 108 | ``` 109 | 110 | By contracting terms in the correct order we can see that this expression can be computed with N^4 scaling. Even with the overhead of finding the best order or 'path' and small dimensions, 111 | `opt_einsum` is roughly 3000 times faster than pure einsum for this expression. 112 | 113 | 114 | ## Format of the Path 115 | 116 | Let us look at the structure of a canonical `einsum` path found in NumPy and its optimized variant: 117 | 118 | ```python 119 | einsum_path = [(0, 1, 2, 3, 4)] 120 | opt_path = [(1, 3), (0, 2), (0, 2), (0, 1)] 121 | ``` 122 | 123 | In opt_einsum each element of the list represents a single contraction. 124 | In the above example the einsum_path would effectively compute the result as a single contraction identical to that of `einsum`, while the 125 | opt_path would perform four contractions in order to reduce the overall scaling. 126 | The first tuple in the opt_path, `(1,3)`, pops the second and fourth terms, then contracts them together to produce a new term which is then appended to the list of terms, this is continued until all terms are contracted. 
127 | An example should illuminate this:
128 | 
129 | ```console
130 | ---------------------------------------------------------------------------------
131 | scaling GEMM current remaining
132 | ---------------------------------------------------------------------------------
133 | terms = ['bdik', 'acaj', 'ikab', 'ajac', 'ikbd'] contraction = (1, 3)
134 | 3 False ajac,acaj->a bdik,ikab,ikbd,a->
135 | terms = ['bdik', 'ikab', 'ikbd', 'a'] contraction = (0, 2)
136 | 4 False ikbd,bdik->bik ikab,a,bik->
137 | terms = ['ikab', 'a', 'bik'] contraction = (0, 2)
138 | 4 False bik,ikab->a a,a->
139 | terms = ['a', 'a'] contraction = (0, 1)
140 | 1 DOT a,a-> ->
141 | ```
142 | 
143 | 
144 | A path specified in this format can be supplied directly to
145 | [`opt_einsum.contract`](../api_reference.md#opt_einsumcontract) using the `optimize` keyword:
146 | 
147 | ```python
148 | contract_result = oe.contract("bdik,acaj,ikab,ajac,ikbd->", *views, optimize=opt_path)
149 | 
150 | np.allclose(einsum_result, contract_result)
151 | #> True
152 | ```
153 | 
154 | 
155 | ## Performance Comparison
156 | 
157 | The following graphs should give some indication of the tradeoffs between path
158 | finding time and path quality. They are generated by finding paths with each
159 | possible algorithm for many randomly generated networks of `n` tensors with
160 | varying connectivity.
161 | 
162 | First we have the time to find each path as a function of the number of terms
163 | in the expression:
164 | 
165 | ![Path Finding](../img/path_finding_time.png)
166 | 
167 | Clearly the exhaustive (`'optimal'`, `'branch-all'`) and exponential
168 | (`'branch-2'`) searches eventually scale badly, but for modest numbers of
169 | terms they incur only a small overhead. The `'random-greedy'` approach is not
170 | shown here as it is simply `max_repeats` times slower than the `'greedy'`
171 | approach - at least if not parallelized.
172 | 
173 | Next we can look at the average FLOP speedup (as compared to the easiest path
174 | to find, `'greedy'`):
175 | 
176 | ![Path Found Flops](../img/path_found_flops.png)
177 | 
178 | One can see that the hierarchy of path qualities is:
179 | 
180 | 1. `'optimal'` (used by auto for `n <= 4`)
181 | 2. `'branch-all'` (used by auto for `n <= 6`)
182 | 3. `'branch-2'` (used by auto for `n <= 8`)
183 | 4. `'branch-1'` (used by auto for `n <= 14`)
184 | 5. `'greedy'` (used by auto for anything larger)
185 | 
186 | !!! note
187 |     The performance of the `'random-greedy'` approach (which is never used
188 |     automatically) can be found separately in the [`RandomGreedyPath`](./random_greedy_path.md) section.
189 | 
190 | There are a few important caveats to note with this graph. Firstly, the
191 | benefits of more advanced path finding are very dependent on the complexity of
192 | the expression. For 'simple' contractions, all the different approaches will
193 | *mostly* find the same path (as here). However, for 'tricky' contractions, there
194 | will be certain cases where the more advanced algorithms will find much better
195 | paths. As such, while this graph gives a good idea of the *relative* performance
196 | of each algorithm, the 'average speedup' is not a perfect indicator since
197 | worst-case performance might be more critical.
198 | 
199 | Note that the speedups for any of the methods as compared to a standard
200 | `einsum` or a naively chosen path (such as `path=[(0, 1), (0, 1), ...]`)
201 | are all exponentially large and not shown.
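A comparison along these lines can be reproduced with a short script. The following is only a rough sketch, not the exact benchmark used for the figures (the tensor count, regularity and seed are arbitrary choices, and the `opt_einsum.testing` helpers require `pytest` to be installed); it times each strategy on a single randomly generated contraction:

```python
import time

import numpy as np
import opt_einsum as oe
from opt_einsum.testing import rand_equation

# one randomly generated 8-tensor contraction (an arbitrary, illustrative choice)
eq, shapes = rand_equation(8, 3, seed=0)
arrays = [np.ones(shape) for shape in shapes]

for strategy in ["greedy", "branch-1", "branch-2", "branch-all", "optimal"]:
    t0 = time.perf_counter()
    _, info = oe.contract_path(eq, *arrays, optimize=strategy)
    elapsed = time.perf_counter() - t0
    print(f"{strategy:>10}: found in {elapsed:.4f}s, estimated cost {info.opt_cost:.3e}")
```

For a contraction this small every strategy finishes almost instantly and the estimated costs are typically close; the differences in both search time and path quality only start to matter as the number of tensors grows, which is what the graphs above summarize over many random networks.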
202 | 
--------------------------------------------------------------------------------
/docs/paths/optimal_path.md:
--------------------------------------------------------------------------------
1 | # The Optimal Path
2 | 
3 | The optimal path can be found by searching through every possible way to contract the tensors together; this includes all combinations with the new intermediate tensors as well.
4 | While this algorithm scales like N!, and can often become more costly to compute than the unoptimized contraction itself, it provides an excellent benchmark.
5 | The function that computes this path in opt_einsum is called [`opt_einsum.paths.optimal`](../api_reference.md#opt_einsumpathsoptimal) and works by performing a recursive, depth-first search. By keeping track of the
6 | best path found so far, in terms of total estimated FLOP count, the search can
7 | then quickly prune many paths as soon as they exceed this best.
8 | This optimal strategy is used by default with the `optimize='auto'` mode of
9 | `opt_einsum` for 4 tensors or fewer, though it can handle expressions of up to
10 | 9-10 tensors in a matter of seconds.
11 | 
12 | 
13 | Let us look at an example:
14 | 
15 | ```python
16 | Contraction: abc,dc,ac->bd
17 | ```
18 | 
19 | Build a list with tuples that have the following form:
20 | 
21 | 
22 | ```python
23 | #> iteration 0:
24 | #> "(cost, path, list of input sets remaining)"
25 | #> [ (0, [], [set(['a', 'c', 'b']), set(['d', 'c']), set(['a', 'c'])]) ]
26 | ```
27 | 
28 | Since this is iteration zero, we have the initial list of input sets.
29 | We can consider three possible combinations where we contract list positions (0, 1), (0, 2), or (1, 2) together:
30 | 
31 | ```python
32 | #> iteration 1:
33 | #> [ (9504, [(0, 1)], [set(['a', 'c']), set(['a', 'c', 'b', 'd']) ]),
34 | #> (1584, [(0, 2)], [set(['c', 'd']), set(['c', 'b']) ]),
35 | #> (864, [(1, 2)], [set(['a', 'c', 'b']), set(['a', 'c', 'd']) ])]
36 | ```
37 | 
38 | We have now run through the three possible combinations, computed the cost of the contraction up to this point, and appended the resulting indices from the contraction to the list.
39 | As all contractions only have two remaining input sets, the only possible contraction is (0, 1):
40 | 
41 | ```python
42 | #> iteration 2:
43 | #> [ (28512, [(0, 1), (0, 1)], [set(['b', 'd']) ]),
44 | #> (3168, [(0, 2), (0, 1)], [set(['b', 'd']) ]),
45 | #> (19872, [(1, 2), (0, 1)], [set(['b', 'd']) ])]
46 | ```
47 | 
48 | The final contraction cost is computed, and we choose the second path from the list as the overall cost is the lowest.
49 | 
--------------------------------------------------------------------------------
/docs/paths/random_greedy_path.md:
--------------------------------------------------------------------------------
1 | # The Random-Greedy Path
2 | 
3 | For large *and* complex contractions, the exhaustive approaches will be too slow
4 | while the greedy path might be very far from optimal. In this case you might
5 | want to consider the `'random-greedy'` path optimizer. This samples many
6 | greedy paths and selects the best one found, which can often be exponentially
7 | better than the average.
8 | 9 | ```python 10 | import opt_einsum as oe 11 | import numpy as np 12 | import math 13 | 14 | eq, shapes = oe.helpers.rand_equation(40, 5, seed=1, d_max=2) 15 | arrays = list(map(np.ones, shapes)) 16 | 17 | path_greedy = oe.contract_path(eq, *arrays, optimize='greedy')[1] 18 | print(math.log2(path_greedy.opt_cost)) 19 | #> 36.04683022558587 20 | 21 | path_rand_greedy = oe.contract_path(eq, *arrays, optimize='random-greedy')[1] 22 | print(math.log2(path_rand_greedy.opt_cost)) 23 | #> 32.203616699170865 24 | ``` 25 | 26 | So here the random-greedy approach has found a path about 27 | 16 times quicker (`= 2^(36 - 32)`). 28 | 29 | This approach works by randomly choosing from the best `n` contractions at 30 | each step, weighted by a 31 | [Boltzmann factor](https://en.wikipedia.org/wiki/Boltzmann_distribution) with 32 | respect to the contraction with the 'best' cost. As such, contractions with 33 | very similar costs will be explored with equal probability, whereas those with 34 | higher costs will be less likely, but still possible. In this way, the 35 | optimizer can randomly explore the huge space of possible paths, but in a 36 | guided manner. 37 | 38 | The following graph roughly demonstrates the potential benefits of the 39 | `'random-greedy'` algorithm, here for large randomly generated contractions, 40 | with either 8, 32 (the default), or 128 repeats: 41 | 42 | ![Path Finding](../img/path_found_flops_random.png) 43 | 44 | !!! note 45 | Bear in mind that such speed-ups are not guaranteed - it very much depends 46 | on how structured or complex your contractions are. 47 | 48 | 49 | ## Customizing the Random-Greedy Path 50 | 51 | The random-greedy optimizer can be customized by instantiating your own 52 | [`opt_einsum.paths.RandomGreedy`](../api_reference.md#opt_einsumpath_randomrandomgreedy) 53 | object. Here you can control: 54 | 55 | - `temperature` - how far to stray from the locally 'best' contractions 56 | - `rel_temperature` - whether to normalize the temperature 57 | - `nbranch` - how many contractions (branches) to consider at each step 58 | - `cost_fn` - how to cost potential contractions 59 | 60 | There are also the main 61 | [`opt_einsum.paths.RandomOptimizer`](../api_reference.md#opt_einsumpath_randomrandomoptimizer) 62 | options: 63 | 64 | - `max_repeats` - the maximum number of repeats 65 | - `max_time` - the maximum amount of time to run for (in seconds) 66 | - `minimize` - whether to minimize for total `'flops'` or `'size'` of the 67 | largest intermediate 68 | 69 | For example, here we'll create an optimizer, then change its temperature 70 | whilst reusing it. We'll also set a high `max_repeats` and instead use a 71 | maximum time to terminate the search: 72 | 73 | ```python 74 | optimizer = oe.RandomGreedy(max_time=2, max_repeats=1_000_000) 75 | 76 | for T in [1000, 100, 10, 1, 0.1]: 77 | optimizer.temperature = T 78 | path_rand_greedy = oe.contract_path(eq, *arrays, optimize=optimizer)[1] 79 | print(math.log2(optimizer.best['flops'])) 80 | 81 | #> 32.81709395639357 82 | #> 32.67625007170783 83 | #> 31.719756871539033 84 | #> 31.62043317835677 85 | #> 31.253305891247 86 | 87 | # the total number of trials so far 88 | print(len(optimizer.costs)) 89 | #> 2555 90 | ``` 91 | 92 | So we have improved a bit on the standard `'random-greedy'` (which performs 32 repeats by default). 
93 | The `optimizer` object now stores both the best path 94 | found so far - `optimizer.path` - as well as the list of flop-costs and 95 | maximum sizes found for each trial - `optimizer.costs` and 96 | `optimizer.sizes` respectively. 97 | 98 | 99 | ## Parallelizing the Random-Greedy Search 100 | 101 | Since each greedy attempt is independent, the random-greedy approach is 102 | naturally suited to parallelization. This can be automatically handled by 103 | specifying the `parallel` keyword like so: 104 | 105 | ```python 106 | # use same number of processes as cores 107 | optimizer = oe.RandomGreedy(parallel=True) 108 | 109 | # or use specific number of processes 110 | optimizer = oe.RandomGreedy(parallel=4) 111 | ``` 112 | 113 | !!! warning 114 | 115 | The pool-executor used to perform this parallelization is the 116 | `ProcessPoolExecutor` from the [`concurrent.futures` 117 | ](https://docs.python.org/3/library/concurrent.futures.html) module. 118 | 119 | For full control over the parallelization you can supply any 120 | pool-executor like object, which should have an API matching the Python 3 121 | [concurrent.futures](https://docs.python.org/3/library/concurrent.futures.html>) 122 | module: 123 | 124 | ```python 125 | from concurrent.futures import ProcessPoolExecutor 126 | 127 | pool = ProcessPoolExecutor() 128 | optimizer = oe.RandomGreedy(parallel=pool, max_repeats=128) 129 | path_rand_greedy = oe.contract_path(eq, *arrays, optimize=optimizer)[1] 130 | 131 | print(math.log2(optimizer.best['flops'])) 132 | #> 31.64992600300931 133 | ``` 134 | 135 | Other examples of such pools include: 136 | 137 | - [loky](https://loky.readthedocs.io/en/latest/) 138 | - [dask.distributed](http://distributed.dask.org/en/latest/) 139 | - [mpi4py](https://mpi4py.readthedocs.io/en/latest/) 140 | -------------------------------------------------------------------------------- /docs/requirements.yml: -------------------------------------------------------------------------------- 1 | ansi2html==1.* 2 | black 3 | devtools 4 | markdown==3.* 5 | markdown-include==0.* 6 | mkdocstrings[python]==0.29.* 7 | mkdocs==1.* 8 | mkdocs-awesome-pages-plugin==2.* 9 | mkdocs-exclude==1.* 10 | mkdocs-material==9.* 11 | pygments==2.* 12 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Optimized Einsum 2 | repo_url: https://github.com/dgasmith/opt_einsum 3 | repo_name: dgasmith/opt_einsum 4 | theme: 5 | name: material 6 | features: 7 | - navigation.instant 8 | palette: 9 | 10 | # Palette toggle for automatic mode 11 | - media: "(prefers-color-scheme)" 12 | toggle: 13 | icon: material/brightness-auto 14 | name: Switch to light mode 15 | 16 | # Palette toggle for light mode 17 | - media: "(prefers-color-scheme: light)" 18 | scheme: default 19 | toggle: 20 | icon: material/brightness-7 21 | name: Switch to dark mode 22 | 23 | # Palette toggle for dark mode 24 | - media: "(prefers-color-scheme: dark)" 25 | scheme: slate 26 | toggle: 27 | icon: material/brightness-4 28 | name: Switch to system preference 29 | 30 | 31 | plugins: 32 | - search 33 | - awesome-pages 34 | - mkdocstrings: 35 | default_handler: python 36 | handlers: 37 | python: 38 | # paths: [opt_einsum] 39 | options: 40 | docstring_style: google 41 | docstring_options: 42 | ignore_init_summary: true 43 | docstring_section_style: list 44 | filters: ["!^_"] 45 | heading_level: 1 46 | inherited_members: true 47 | 
merge_init_into_class: true 48 | parameter_headings: true 49 | preload_modules: [mkdocstrings] 50 | separate_signature: true 51 | show_root_heading: true 52 | show_root_full_path: false 53 | show_signature_annotations: true 54 | show_source: false 55 | show_symbol_type_heading: true 56 | show_symbol_type_toc: true 57 | signature_crossrefs: true 58 | summary: true 59 | unwrap_annotated: true 60 | 61 | extra_javascript: 62 | - javascript/config.js 63 | - https://polyfill.io/v3/polyfill.min.js?features=es6 64 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js 65 | 66 | extra_css: 67 | - css/custom.css 68 | 69 | markdown_extensions: 70 | - markdown.extensions.codehilite: 71 | guess_lang: false 72 | - markdown_include.include: 73 | base_path: docs 74 | - pymdownx.arithmatex: 75 | generic: true 76 | - admonition 77 | - codehilite 78 | - extra 79 | - pymdownx.extra 80 | - pymdownx.arithmatex: 81 | generic: true 82 | - toc: 83 | toc_depth: 2 84 | 85 | nav: 86 | - Overview: index.md 87 | - Getting Started: 88 | - Installing: getting_started/install.md 89 | - Input Format: getting_started/input_format.md 90 | - "Backends & GPU Support": getting_started/backends.md 91 | - Reusing Paths: getting_started/reusing_paths.md 92 | - Sharing Intermediates: getting_started/sharing_intermediates.md 93 | - Path Information: 94 | - Introduction: paths/introduction.md 95 | - Optimal Path: paths/optimal_path.md 96 | - Branching Path: paths/branching_path.md 97 | - Greedy Path: paths/greedy_path.md 98 | - Random-Greedy Path: paths/random_greedy_path.md 99 | - Dynamic Programming Path: paths/dp_path.md 100 | - Custom Path Optimizers: paths/custom_paths.md 101 | - Examples: 102 | - Reusing Intermediaries with Dask: examples/dask_reusing_intermediaries.md 103 | - Large Expressions with Greedy: examples/large_expr_with_greedy.md 104 | - API Reference: api_reference.md 105 | - Changelog: changelog.md 106 | -------------------------------------------------------------------------------- /opt_einsum/__init__.py: -------------------------------------------------------------------------------- 1 | """Main init function for opt_einsum.""" 2 | 3 | from opt_einsum import blas, helpers, path_random, paths 4 | from opt_einsum._version import __version__ 5 | from opt_einsum.contract import contract, contract_expression, contract_path 6 | from opt_einsum.parser import get_symbol 7 | from opt_einsum.path_random import RandomGreedy 8 | from opt_einsum.paths import BranchBound, DynamicProgramming 9 | from opt_einsum.sharing import shared_intermediates 10 | 11 | __all__ = [ 12 | "__version__", 13 | "blas", 14 | "helpers", 15 | "path_random", 16 | "paths", 17 | "contract", 18 | "contract_expression", 19 | "contract_path", 20 | "get_symbol", 21 | "RandomGreedy", 22 | "BranchBound", 23 | "DynamicProgramming", 24 | "shared_intermediates", 25 | ] 26 | 27 | 28 | paths.register_path_fn("random-greedy", path_random.random_greedy) 29 | paths.register_path_fn("random-greedy-128", path_random.random_greedy_128) 30 | -------------------------------------------------------------------------------- /opt_einsum/_version.py: -------------------------------------------------------------------------------- 1 | # file generated by setuptools_scm 2 | # don't change, don't track in version control 3 | TYPE_CHECKING = False 4 | if TYPE_CHECKING: 5 | from typing import Tuple, Union 6 | VERSION_TUPLE = Tuple[Union[int, str], ...] 
7 | else: 8 | VERSION_TUPLE = object 9 | 10 | version: str 11 | __version__: str 12 | __version_tuple__: VERSION_TUPLE 13 | version_tuple: VERSION_TUPLE 14 | 15 | __version__ = version = '0.0.0.dev' 16 | __version_tuple__ = version_tuple = (0, 0, 0, 'dev', '') 17 | -------------------------------------------------------------------------------- /opt_einsum/backends/__init__.py: -------------------------------------------------------------------------------- 1 | """Compute backends for opt_einsum.""" 2 | 3 | # Backends 4 | from opt_einsum.backends.cupy import to_cupy 5 | from opt_einsum.backends.dispatch import ( 6 | build_expression, 7 | evaluate_constants, 8 | get_func, 9 | has_backend, 10 | has_einsum, 11 | has_tensordot, 12 | ) 13 | from opt_einsum.backends.tensorflow import to_tensorflow 14 | from opt_einsum.backends.theano import to_theano 15 | from opt_einsum.backends.torch import to_torch 16 | 17 | __all__ = [ 18 | "get_func", 19 | "has_einsum", 20 | "has_tensordot", 21 | "build_expression", 22 | "evaluate_constants", 23 | "has_backend", 24 | "to_tensorflow", 25 | "to_theano", 26 | "to_cupy", 27 | "to_torch", 28 | ] 29 | -------------------------------------------------------------------------------- /opt_einsum/backends/cupy.py: -------------------------------------------------------------------------------- 1 | """Required functions for optimized contractions of numpy arrays using cupy.""" 2 | 3 | from opt_einsum.helpers import has_array_interface 4 | from opt_einsum.sharing import to_backend_cache_wrap 5 | 6 | __all__ = ["to_cupy", "build_expression", "evaluate_constants"] 7 | 8 | 9 | @to_backend_cache_wrap 10 | def to_cupy(array): # pragma: no cover 11 | import cupy 12 | 13 | if has_array_interface(array): 14 | return cupy.asarray(array) 15 | 16 | return array 17 | 18 | 19 | def build_expression(_, expr): # pragma: no cover 20 | """Build a cupy function based on ``arrays`` and ``expr``.""" 21 | 22 | def cupy_contract(*arrays): 23 | return expr._contract([to_cupy(x) for x in arrays], backend="cupy").get() 24 | 25 | return cupy_contract 26 | 27 | 28 | def evaluate_constants(const_arrays, expr): # pragma: no cover 29 | """Convert constant arguments to cupy arrays, and perform any possible 30 | constant contractions. 31 | """ 32 | return expr(*[to_cupy(x) for x in const_arrays], backend="cupy", evaluate_constants=True) 33 | -------------------------------------------------------------------------------- /opt_einsum/backends/dispatch.py: -------------------------------------------------------------------------------- 1 | """Handles dispatching array operations to the correct backend library, as well 2 | as converting arrays to backend formats and then potentially storing them as 3 | constants. 
4 | """ 5 | 6 | import importlib 7 | from typing import Any, Dict, Tuple 8 | 9 | from opt_einsum.backends import cupy as _cupy 10 | from opt_einsum.backends import jax as _jax 11 | from opt_einsum.backends import object_arrays 12 | from opt_einsum.backends import tensorflow as _tensorflow 13 | from opt_einsum.backends import theano as _theano 14 | from opt_einsum.backends import torch as _torch 15 | 16 | __all__ = [ 17 | "get_func", 18 | "has_einsum", 19 | "has_tensordot", 20 | "build_expression", 21 | "evaluate_constants", 22 | "has_backend", 23 | ] 24 | 25 | # known non top-level imports 26 | _aliases = { 27 | "dask": "dask.array", 28 | "theano": "theano.tensor", 29 | "torch": "opt_einsum.backends.torch", 30 | "jax": "jax.numpy", 31 | "jaxlib": "jax.numpy", 32 | "autograd": "autograd.numpy", 33 | "mars": "mars.tensor", 34 | } 35 | 36 | 37 | def _import_func(func: str, backend: str, default: Any = None) -> Any: 38 | """Try and import ``{backend}.{func}``. 39 | If library is installed and func is found, return the func; 40 | otherwise if default is provided, return default; 41 | otherwise raise an error. 42 | """ 43 | try: 44 | lib = importlib.import_module(_aliases.get(backend, backend)) 45 | return getattr(lib, func) if default is None else getattr(lib, func, default) 46 | except AttributeError: 47 | error_msg = ( 48 | "{} doesn't seem to provide the function {} - see " 49 | "https://optimized-einsum.readthedocs.io/en/latest/backends.html " 50 | "for details on which functions are required for which contractions." 51 | ) 52 | raise AttributeError(error_msg.format(backend, func)) 53 | 54 | 55 | # manually cache functions as python2 doesn't support functools.lru_cache 56 | # other libs will be added to this if needed, but pre-populate with numpy 57 | _cached_funcs: Dict[Tuple[str, str], Any] = { 58 | ("einsum", "object"): object_arrays.object_einsum, 59 | } 60 | 61 | try: 62 | import numpy as np # type: ignore 63 | 64 | _cached_funcs[("tensordot", "numpy")] = np.tensordot 65 | _cached_funcs[("transpose", "numpy")] = np.transpose 66 | _cached_funcs[("einsum", "numpy")] = np.einsum 67 | # also pre-populate with the arbitrary object backend 68 | _cached_funcs[("tensordot", "object")] = np.tensordot 69 | _cached_funcs[("transpose", "object")] = np.transpose 70 | except ModuleNotFoundError: 71 | pass 72 | 73 | 74 | def get_func(func: str, backend: str = "numpy", default: Any = None) -> Any: 75 | """Return ``{backend}.{func}``, e.g. ``numpy.einsum``, 76 | or a default func if provided. Cache result. 
77 | """ 78 | try: 79 | return _cached_funcs[func, backend] 80 | except KeyError: 81 | fn = _import_func(func, backend, default) 82 | _cached_funcs[func, backend] = fn 83 | return fn 84 | 85 | 86 | # mark libs with einsum, else try to use tensordot/transpose as much as possible 87 | _has_einsum: Dict[str, bool] = {} 88 | 89 | 90 | def has_einsum(backend: str) -> bool: 91 | """Check if ``{backend}.einsum`` exists, cache result for performance.""" 92 | try: 93 | return _has_einsum[backend] 94 | except KeyError: 95 | try: 96 | get_func("einsum", backend) 97 | _has_einsum[backend] = True 98 | except AttributeError: 99 | _has_einsum[backend] = False 100 | 101 | return _has_einsum[backend] 102 | 103 | 104 | _has_tensordot: Dict[str, bool] = {} 105 | 106 | 107 | def has_tensordot(backend: str) -> bool: 108 | """Check if ``{backend}.tensordot`` exists, cache result for performance.""" 109 | try: 110 | return _has_tensordot[backend] 111 | except KeyError: 112 | try: 113 | get_func("tensordot", backend) 114 | _has_tensordot[backend] = True 115 | except AttributeError: 116 | _has_tensordot[backend] = False 117 | 118 | return _has_tensordot[backend] 119 | 120 | 121 | # Dispatch to correct expression backend 122 | # these are the backends which support explicit to-and-from numpy conversion 123 | CONVERT_BACKENDS = { 124 | "tensorflow": _tensorflow.build_expression, 125 | "theano": _theano.build_expression, 126 | "cupy": _cupy.build_expression, 127 | "torch": _torch.build_expression, 128 | "jax": _jax.build_expression, 129 | } 130 | 131 | EVAL_CONSTS_BACKENDS = { 132 | "tensorflow": _tensorflow.evaluate_constants, 133 | "theano": _theano.evaluate_constants, 134 | "cupy": _cupy.evaluate_constants, 135 | "torch": _torch.evaluate_constants, 136 | "jax": _jax.evaluate_constants, 137 | } 138 | 139 | 140 | def build_expression(backend, arrays, expr): 141 | """Build an expression, based on ``expr`` and initial arrays ``arrays``, 142 | that evaluates using backend ``backend``. 143 | """ 144 | return CONVERT_BACKENDS[backend](arrays, expr) 145 | 146 | 147 | def evaluate_constants(backend, arrays, expr): 148 | """Convert constant arrays to the correct backend, and perform as much of 149 | the contraction of ``expr`` with these as possible. 
150 | """ 151 | return EVAL_CONSTS_BACKENDS[backend](arrays, expr) 152 | 153 | 154 | def has_backend(backend: str) -> bool: 155 | """Checks if the backend is known.""" 156 | return backend.lower() in CONVERT_BACKENDS 157 | -------------------------------------------------------------------------------- /opt_einsum/backends/jax.py: -------------------------------------------------------------------------------- 1 | """Required functions for optimized contractions of numpy arrays using jax.""" 2 | 3 | from opt_einsum.sharing import to_backend_cache_wrap 4 | 5 | __all__ = ["build_expression", "evaluate_constants"] 6 | 7 | _JAX = None 8 | 9 | 10 | def _get_jax_and_to_jax(): 11 | global _JAX 12 | if _JAX is None: 13 | import jax # type: ignore 14 | 15 | @to_backend_cache_wrap 16 | @jax.jit 17 | def to_jax(x): 18 | return x 19 | 20 | _JAX = jax, to_jax 21 | 22 | return _JAX 23 | 24 | 25 | def build_expression(_, expr): # pragma: no cover 26 | """Build a jax function based on ``arrays`` and ``expr``.""" 27 | jax, _ = _get_jax_and_to_jax() 28 | 29 | jax_expr = jax.jit(expr._contract) 30 | 31 | def jax_contract(*arrays): 32 | import numpy as np # type: ignore 33 | 34 | return np.asarray(jax_expr(arrays)) 35 | 36 | return jax_contract 37 | 38 | 39 | def evaluate_constants(const_arrays, expr): # pragma: no cover 40 | """Convert constant arguments to jax arrays, and perform any possible 41 | constant contractions. 42 | """ 43 | jax, to_jax = _get_jax_and_to_jax() 44 | 45 | return expr(*[to_jax(x) for x in const_arrays], backend="jax", evaluate_constants=True) 46 | -------------------------------------------------------------------------------- /opt_einsum/backends/object_arrays.py: -------------------------------------------------------------------------------- 1 | """Functions for performing contractions with array elements which are objects.""" 2 | 3 | import functools 4 | import operator 5 | 6 | from opt_einsum.typing import ArrayType 7 | 8 | 9 | def object_einsum(eq: str, *arrays: ArrayType) -> ArrayType: 10 | """A ``einsum`` implementation for ``numpy`` arrays with object dtype. 11 | The loop is performed in python, meaning the objects themselves need 12 | only to implement ``__mul__`` and ``__add__`` for the contraction to be 13 | computed. This may be useful when, for example, computing expressions of 14 | tensors with symbolic elements, but note it will be very slow when compared 15 | to ``numpy.einsum`` and numeric data types! 16 | 17 | Parameters 18 | ---------- 19 | eq : str 20 | The contraction string, should specify output. 21 | arrays : sequence of arrays 22 | These can be any indexable arrays as long as addition and 23 | multiplication is defined on the elements. 24 | 25 | Returns: 26 | ------- 27 | out : numpy.ndarray 28 | The output tensor, with ``dtype=object``. 
29 | """ 30 | import numpy as np # type: ignore 31 | 32 | # when called by ``opt_einsum`` we will always be given a full eq 33 | lhs, output = eq.split("->") 34 | inputs = lhs.split(",") 35 | 36 | sizes = {} 37 | for term, array in zip(inputs, arrays): 38 | for k, d in zip(term, array.shape): 39 | sizes[k] = d 40 | 41 | out_size = tuple(sizes[k] for k in output) 42 | out = np.empty(out_size, dtype=object) 43 | 44 | inner = tuple(k for k in sizes if k not in output) 45 | inner_size = tuple(sizes[k] for k in inner) 46 | 47 | for coo_o in np.ndindex(*out_size): 48 | coord = dict(zip(output, coo_o)) 49 | 50 | def gen_inner_sum(): 51 | for coo_i in np.ndindex(*inner_size): 52 | coord.update(dict(zip(inner, coo_i))) 53 | locs = (tuple(coord[k] for k in term) for term in inputs) 54 | elements = (array[loc] for array, loc in zip(arrays, locs)) 55 | yield functools.reduce(operator.mul, elements) 56 | 57 | out[coo_o] = functools.reduce(operator.add, gen_inner_sum()) 58 | 59 | return out 60 | -------------------------------------------------------------------------------- /opt_einsum/backends/tensorflow.py: -------------------------------------------------------------------------------- 1 | """Required functions for optimized contractions of numpy arrays using tensorflow.""" 2 | 3 | from opt_einsum.helpers import has_array_interface 4 | from opt_einsum.sharing import to_backend_cache_wrap 5 | 6 | __all__ = ["to_tensorflow", "build_expression", "evaluate_constants"] 7 | 8 | _CACHED_TF_DEVICE = None 9 | 10 | 11 | def _get_tensorflow_and_device(): 12 | global _CACHED_TF_DEVICE 13 | 14 | if _CACHED_TF_DEVICE is None: 15 | import tensorflow as tf # type: ignore 16 | 17 | try: 18 | eager = tf.executing_eagerly() 19 | except AttributeError: 20 | try: 21 | eager = tf.contrib.eager.in_eager_mode() 22 | except AttributeError: 23 | eager = False 24 | 25 | device = tf.test.gpu_device_name() 26 | if not device: 27 | device = "cpu" 28 | 29 | _CACHED_TF_DEVICE = tf, device, eager 30 | 31 | return _CACHED_TF_DEVICE 32 | 33 | 34 | @to_backend_cache_wrap(constants=True) 35 | def to_tensorflow(array, constant=False): 36 | """Convert a numpy array to a ``tensorflow.placeholder`` instance.""" 37 | tf, device, eager = _get_tensorflow_and_device() 38 | 39 | if eager: 40 | if has_array_interface(array): 41 | with tf.device(device): 42 | return tf.convert_to_tensor(array) 43 | 44 | return array 45 | 46 | if has_array_interface(array): 47 | if constant: 48 | return tf.convert_to_tensor(array) 49 | 50 | return tf.placeholder(array.dtype, array.shape) 51 | 52 | return array 53 | 54 | 55 | # Standard graph mode 56 | 57 | 58 | def build_expression_graph(arrays, expr): 59 | """Build a tensorflow function based on ``arrays`` and ``expr``.""" 60 | tf, _, _ = _get_tensorflow_and_device() 61 | 62 | placeholders = [to_tensorflow(array) for array in arrays] 63 | graph = expr._contract(placeholders, backend="tensorflow") 64 | 65 | def tensorflow_contract(*arrays): 66 | session = tf.get_default_session() 67 | # only want to feed placeholders - constant tensors already have values 68 | feed_dict = {p: a for p, a in zip(placeholders, arrays) if p.op.type == "Placeholder"} 69 | return session.run(graph, feed_dict=feed_dict) 70 | 71 | return tensorflow_contract 72 | 73 | 74 | def evaluate_constants_graph(const_arrays, expr): 75 | """Convert constant arguments to tensorflow constants, and perform any 76 | possible constant contractions. Requires evaluating a tensorflow graph. 
77 | """ 78 | tf, _, _ = _get_tensorflow_and_device() 79 | 80 | # compute the partial graph of new inputs 81 | const_arrays = [to_tensorflow(x, constant=True) for x in const_arrays] 82 | new_ops, new_contraction_list = expr(*const_arrays, backend="tensorflow", evaluate_constants=True) 83 | 84 | # evaluate the new inputs and convert back to tensorflow, maintaining None as non-consts 85 | session = tf.get_default_session() 86 | new_consts = iter(session.run([x for x in new_ops if x is not None])) 87 | new_ops = [None if x is None else to_tensorflow(next(new_consts), constant=True) for x in new_ops] 88 | 89 | return new_ops, new_contraction_list 90 | 91 | 92 | # Eager execution mode 93 | 94 | 95 | def build_expression_eager(_, expr): 96 | """Build a eager tensorflow function based on ``arrays`` and ``expr``.""" 97 | 98 | def tensorflow_eager_contract(*arrays): 99 | return expr._contract([to_tensorflow(x) for x in arrays], backend="tensorflow").numpy() 100 | 101 | return tensorflow_eager_contract 102 | 103 | 104 | def evaluate_constants_eager(const_arrays, expr): 105 | """Convert constant arguments to tensorflow_eager arrays, and perform any 106 | possible constant contractions. 107 | """ 108 | return expr(*[to_tensorflow(x) for x in const_arrays], backend="tensorflow", evaluate_constants=True) 109 | 110 | 111 | # Dispatch to eager or graph mode 112 | 113 | 114 | def build_expression(arrays, expr): 115 | _, _, eager = _get_tensorflow_and_device() 116 | fn = build_expression_eager if eager else build_expression_graph 117 | return fn(arrays, expr) 118 | 119 | 120 | def evaluate_constants(const_arrays, expr): 121 | _, _, eager = _get_tensorflow_and_device() 122 | fn = evaluate_constants_eager if eager else evaluate_constants_graph 123 | return fn(const_arrays, expr) 124 | -------------------------------------------------------------------------------- /opt_einsum/backends/theano.py: -------------------------------------------------------------------------------- 1 | """Required functions for optimized contractions of numpy arrays using theano.""" 2 | 3 | from opt_einsum.helpers import has_array_interface 4 | from opt_einsum.sharing import to_backend_cache_wrap 5 | 6 | __all__ = ["to_theano", "build_expression", "evaluate_constants"] 7 | 8 | 9 | @to_backend_cache_wrap(constants=True) 10 | def to_theano(array, constant=False): 11 | """Convert a numpy array to ``theano.tensor.TensorType`` instance.""" 12 | import theano # type: ignore 13 | 14 | if has_array_interface(array): 15 | if constant: 16 | return theano.tensor.constant(array) 17 | 18 | return theano.tensor.TensorType(dtype=array.dtype, broadcastable=[False] * len(array.shape))() 19 | 20 | return array 21 | 22 | 23 | def build_expression(arrays, expr): 24 | """Build a theano function based on ``arrays`` and ``expr``.""" 25 | import theano 26 | 27 | in_vars = [to_theano(array) for array in arrays] 28 | out_var = expr._contract(in_vars, backend="theano") 29 | 30 | # don't supply constants to graph 31 | graph_ins = [x for x in in_vars if not isinstance(x, theano.tensor.TensorConstant)] 32 | graph = theano.function(graph_ins, out_var) 33 | 34 | def theano_contract(*arrays): 35 | return graph(*[x for x in arrays if not isinstance(x, theano.tensor.TensorConstant)]) 36 | 37 | return theano_contract 38 | 39 | 40 | def evaluate_constants(const_arrays, expr): 41 | # compute the partial graph of new inputs 42 | const_arrays = [to_theano(x, constant=True) for x in const_arrays] 43 | new_ops, new_contraction_list = expr(*const_arrays, backend="theano", 
evaluate_constants=True) 44 | 45 | # evaluate the new inputs and convert to theano shared tensors 46 | new_ops = [None if x is None else to_theano(x.eval(), constant=True) for x in new_ops] 47 | 48 | return new_ops, new_contraction_list 49 | -------------------------------------------------------------------------------- /opt_einsum/backends/torch.py: -------------------------------------------------------------------------------- 1 | """Required functions for optimized contractions of numpy arrays using pytorch.""" 2 | 3 | from opt_einsum.helpers import has_array_interface 4 | from opt_einsum.parser import convert_to_valid_einsum_chars 5 | from opt_einsum.sharing import to_backend_cache_wrap 6 | 7 | __all__ = [ 8 | "transpose", 9 | "einsum", 10 | "tensordot", 11 | "to_torch", 12 | "build_expression", 13 | "evaluate_constants", 14 | ] 15 | 16 | _TORCH_DEVICE = None 17 | _TORCH_HAS_TENSORDOT = None 18 | 19 | _torch_symbols_base = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" 20 | 21 | 22 | def _get_torch_and_device(): 23 | global _TORCH_DEVICE 24 | global _TORCH_HAS_TENSORDOT 25 | 26 | if _TORCH_DEVICE is None: 27 | import torch # type: ignore 28 | 29 | device = "cuda" if torch.cuda.is_available() else "cpu" 30 | _TORCH_DEVICE = torch, device 31 | _TORCH_HAS_TENSORDOT = hasattr(torch, "tensordot") 32 | 33 | return _TORCH_DEVICE 34 | 35 | 36 | def transpose(a, axes): 37 | """Normal torch transpose is only valid for 2D matrices.""" 38 | return a.permute(*axes) 39 | 40 | 41 | def einsum(equation, *operands, **kwargs): 42 | """Variadic version of torch.einsum to match numpy api.""" 43 | # rename symbols to support PyTorch 0.4.1 and earlier, 44 | # which allow only symbols a-z. 45 | equation = convert_to_valid_einsum_chars(equation) 46 | 47 | torch, _ = _get_torch_and_device() 48 | return torch.einsum(equation, operands) 49 | 50 | 51 | def tensordot(x, y, axes=2): 52 | """Simple translation of tensordot syntax to einsum.""" 53 | torch, _ = _get_torch_and_device() 54 | 55 | if _TORCH_HAS_TENSORDOT: 56 | return torch.tensordot(x, y, dims=axes) 57 | 58 | xnd = x.ndimension() 59 | ynd = y.ndimension() 60 | 61 | # convert int argument to (list[int], list[int]) 62 | if isinstance(axes, int): 63 | axes = range(xnd - axes, xnd), range(axes) 64 | 65 | # convert (int, int) to (list[int], list[int]) 66 | if isinstance(axes[0], int): 67 | axes = (axes[0],), axes[1] 68 | if isinstance(axes[1], int): 69 | axes = axes[0], (axes[1],) 70 | 71 | # initialize empty indices 72 | x_ix = [None] * xnd 73 | y_ix = [None] * ynd 74 | out_ix = [] 75 | 76 | # fill in repeated indices 77 | available_ix = iter(_torch_symbols_base) 78 | for ax1, ax2 in zip(*axes): 79 | repeat = next(available_ix) 80 | x_ix[ax1] = repeat 81 | y_ix[ax2] = repeat 82 | 83 | # fill in the rest, and maintain output order 84 | for i in range(xnd): 85 | if x_ix[i] is None: 86 | leave = next(available_ix) 87 | x_ix[i] = leave 88 | out_ix.append(leave) 89 | for i in range(ynd): 90 | if y_ix[i] is None: 91 | leave = next(available_ix) 92 | y_ix[i] = leave 93 | out_ix.append(leave) 94 | 95 | # form full string and contract! 
96 | einsum_str = "{},{}->{}".format(*map("".join, (x_ix, y_ix, out_ix))) 97 | return einsum(einsum_str, x, y) 98 | 99 | 100 | @to_backend_cache_wrap 101 | def to_torch(array): 102 | torch, device = _get_torch_and_device() 103 | 104 | if has_array_interface(array): 105 | return torch.from_numpy(array).to(device) 106 | 107 | return array 108 | 109 | 110 | def build_expression(_, expr): # pragma: no cover 111 | """Build a torch function based on ``arrays`` and ``expr``.""" 112 | 113 | def torch_contract(*arrays): 114 | torch_arrays = [to_torch(x) for x in arrays] 115 | torch_out = expr._contract(torch_arrays, backend="torch") 116 | 117 | if torch_out.device.type == "cpu": 118 | return torch_out.numpy() 119 | 120 | return torch_out.cpu().numpy() 121 | 122 | return torch_contract 123 | 124 | 125 | def evaluate_constants(const_arrays, expr): 126 | """Convert constant arguments to torch, and perform any possible constant 127 | contractions. 128 | """ 129 | const_arrays = [to_torch(x) for x in const_arrays] 130 | return expr(*const_arrays, backend="torch", evaluate_constants=True) 131 | -------------------------------------------------------------------------------- /opt_einsum/blas.py: -------------------------------------------------------------------------------- 1 | """Determines if a contraction can use BLAS or not.""" 2 | 3 | from typing import List, Sequence, Tuple, Union 4 | 5 | from opt_einsum.typing import ArrayIndexType 6 | 7 | __all__ = ["can_blas"] 8 | 9 | 10 | def can_blas( 11 | inputs: List[str], 12 | result: str, 13 | idx_removed: ArrayIndexType, 14 | shapes: Union[Sequence[Tuple[int]], None] = None, 15 | ) -> Union[str, bool]: 16 | """Checks if we can use a BLAS call. 17 | 18 | Parameters 19 | ---------- 20 | inputs : list of str 21 | Specifies the subscripts for summation. 22 | result : str 23 | Resulting summation. 24 | idx_removed : set 25 | Indices that are removed in the summation 26 | shapes : sequence of tuple[int], optional 27 | If given, check also that none of the indices are broadcast dimensions. 28 | 29 | Returns: 30 | ------- 31 | type : str or bool 32 | The type of BLAS call to be used or False if none. 33 | 34 | Notes: 35 | ----- 36 | We assume several operations are not efficient such as a transposed 37 | DDOT, therefore 'ijk,jki->' should prefer einsum. These return the blas 38 | type appended with "/EINSUM" to differentiate when they can still be done 39 | with tensordot if required, e.g. when a backend has no einsum. 40 | 41 | Examples: 42 | -------- 43 | >>> can_blas(['ij', 'jk'], 'ik', set('j')) 44 | 'GEMM' 45 | 46 | >>> can_blas(['ijj', 'jk'], 'ik', set('j')) 47 | False 48 | 49 | >>> can_blas(['ab', 'cd'], 'abcd', set()) 50 | 'OUTER/EINSUM' 51 | 52 | >>> # looks like GEMM but actually 'j' is broadcast: 53 | >>> can_blas(['ij', 'jk'], 'ik', set('j'), shapes=[(4, 1), (5, 6)]) 54 | False 55 | """ 56 | # Can only do two 57 | if len(inputs) != 2: 58 | return False 59 | 60 | input_left, input_right = inputs 61 | 62 | for c in set(input_left + input_right): 63 | # can't deal with repeated indices on same input or more than 2 total 64 | nl, nr = input_left.count(c), input_right.count(c) 65 | if (nl > 1) or (nr > 1) or (nl + nr > 2): 66 | return False 67 | 68 | # can't do implicit summation or dimension collapse e.g. 
69 | # "ab,bc->c" (implicitly sum over 'a') 70 | # "ab,ca->ca" (take diagonal of 'a') 71 | if nl + nr - 1 == int(c in result): 72 | return False 73 | 74 | # check for broadcast indices e.g: 75 | # "ij,jk->ik" (but one of the 'j' dimensions is broadcast up) 76 | if shapes is not None: 77 | for c in idx_removed: 78 | if shapes[0][input_left.find(c)] != shapes[1][input_right.find(c)]: 79 | return False 80 | 81 | # Prefer einsum if not removing indices 82 | # (N.B. tensordot outer faster for large arrays?) 83 | if len(idx_removed) == 0: 84 | return "OUTER/EINSUM" 85 | 86 | # Build a few temporaries 87 | sets = [set(x) for x in inputs] 88 | keep_left = sets[0] - idx_removed 89 | keep_right = sets[1] - idx_removed 90 | rs = len(idx_removed) 91 | 92 | # DDOT 93 | if inputs[0] == inputs[1]: 94 | return "DOT" 95 | 96 | # DDOT does not make sense if you have to transpose - prefer einsum 97 | elif sets[0] == sets[1]: 98 | return "DOT/EINSUM" 99 | 100 | # GEMM no transpose 101 | if input_left[-rs:] == input_right[:rs]: 102 | return "GEMM" 103 | 104 | # GEMM transpose both 105 | elif input_left[:rs] == input_right[-rs:]: 106 | return "GEMM" 107 | 108 | # GEMM transpose right 109 | elif input_left[-rs:] == input_right[-rs:]: 110 | return "GEMM" 111 | 112 | # GEMM transpose left 113 | elif input_left[:rs] == input_right[:rs]: 114 | return "GEMM" 115 | 116 | # Einsum is faster than vectordot if we have to copy 117 | elif (len(keep_left) == 0) or (len(keep_right) == 0): 118 | return "GEMV/EINSUM" 119 | 120 | # Conventional tensordot 121 | else: 122 | return "TDOT" 123 | -------------------------------------------------------------------------------- /opt_einsum/helpers.py: -------------------------------------------------------------------------------- 1 | """Contains helper functions for opt_einsum testing scripts.""" 2 | 3 | from typing import Any, Collection, Dict, FrozenSet, Iterable, List, Tuple, overload 4 | 5 | from opt_einsum.typing import ArrayIndexType, ArrayType 6 | 7 | __all__ = ["compute_size_by_dict", "find_contraction", "flop_count"] 8 | 9 | _valid_chars = "abcdefghijklmopqABC" 10 | _sizes = [2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4] 11 | _default_dim_dict = dict(zip(_valid_chars, _sizes)) 12 | 13 | 14 | @overload 15 | def compute_size_by_dict(indices: Iterable[int], idx_dict: List[int]) -> int: ... 16 | 17 | 18 | @overload 19 | def compute_size_by_dict(indices: Collection[str], idx_dict: Dict[str, int]) -> int: ... 20 | 21 | 22 | def compute_size_by_dict(indices: Any, idx_dict: Any) -> int: 23 | """Computes the product of the elements in indices based on the dictionary 24 | idx_dict. 25 | 26 | Parameters 27 | ---------- 28 | indices : iterable 29 | Indices to base the product on. 30 | idx_dict : dictionary 31 | Dictionary of index _sizes 32 | 33 | Returns: 34 | ------- 35 | ret : int 36 | The resulting product. 37 | 38 | Examples: 39 | -------- 40 | >>> compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5}) 41 | 90 42 | 43 | """ 44 | ret = 1 45 | for i in indices: # lgtm [py/iteration-string-and-sequence] 46 | ret *= idx_dict[i] 47 | return ret 48 | 49 | 50 | def find_contraction( 51 | positions: Collection[int], 52 | input_sets: List[ArrayIndexType], 53 | output_set: ArrayIndexType, 54 | ) -> Tuple[FrozenSet[str], List[ArrayIndexType], ArrayIndexType, ArrayIndexType]: 55 | """Finds the contraction for a given set of input and output sets. 56 | 57 | Parameters 58 | ---------- 59 | positions : iterable 60 | Integer positions of terms used in the contraction. 
61 | input_sets : list 62 | List of sets that represent the lhs side of the einsum subscript 63 | output_set : set 64 | Set that represents the rhs side of the overall einsum subscript 65 | 66 | Returns: 67 | ------- 68 | new_result : set 69 | The indices of the resulting contraction 70 | remaining : list 71 | List of sets that have not been contracted, the new set is appended to 72 | the end of this list 73 | idx_removed : set 74 | Indices removed from the entire contraction 75 | idx_contraction : set 76 | The indices used in the current contraction 77 | 78 | Examples: 79 | -------- 80 | # A simple dot product test case 81 | >>> pos = (0, 1) 82 | >>> isets = [set('ab'), set('bc')] 83 | >>> oset = set('ac') 84 | >>> find_contraction(pos, isets, oset) 85 | ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'}) 86 | 87 | # A more complex case with additional terms in the contraction 88 | >>> pos = (0, 2) 89 | >>> isets = [set('abd'), set('ac'), set('bdc')] 90 | >>> oset = set('ac') 91 | >>> find_contraction(pos, isets, oset) 92 | ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'}) 93 | """ 94 | remaining = list(input_sets) 95 | inputs = (remaining.pop(i) for i in sorted(positions, reverse=True)) 96 | idx_contract = frozenset.union(*inputs) 97 | idx_remain = output_set.union(*remaining) 98 | 99 | new_result = idx_remain & idx_contract 100 | idx_removed = idx_contract - new_result 101 | remaining.append(new_result) 102 | 103 | return new_result, remaining, idx_removed, idx_contract 104 | 105 | 106 | def flop_count( 107 | idx_contraction: Collection[str], 108 | inner: bool, 109 | num_terms: int, 110 | size_dictionary: Dict[str, int], 111 | ) -> int: 112 | """Computes the number of FLOPS in the contraction. 113 | 114 | Parameters 115 | ---------- 116 | idx_contraction : iterable 117 | The indices involved in the contraction 118 | inner : bool 119 | Does this contraction require an inner product? 120 | num_terms : int 121 | The number of terms in a contraction 122 | size_dictionary : dict 123 | The size of each of the indices in idx_contraction 124 | 125 | Returns: 126 | ------- 127 | flop_count : int 128 | The total number of FLOPS required for the contraction. 129 | 130 | Examples: 131 | -------- 132 | >>> flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5}) 133 | 30 134 | 135 | >>> flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5}) 136 | 60 137 | 138 | """ 139 | overall_size = compute_size_by_dict(idx_contraction, size_dictionary) 140 | op_factor = max(1, num_terms - 1) 141 | if inner: 142 | op_factor += 1 143 | 144 | return overall_size * op_factor 145 | 146 | 147 | def has_array_interface(array: ArrayType) -> ArrayType: 148 | if hasattr(array, "__array_interface__"): 149 | return True 150 | else: 151 | return False 152 | -------------------------------------------------------------------------------- /opt_einsum/sharing.py: -------------------------------------------------------------------------------- 1 | """A module for sharing intermediates between contractions. 
2 | 3 | Copyright (c) 2018 Uber Technologies 4 | """ 5 | 6 | import contextlib 7 | import functools 8 | import numbers 9 | import threading 10 | from collections import Counter, defaultdict 11 | from typing import Any, Dict, Generator, List, Optional, Tuple, Union 12 | from typing import Counter as CounterType 13 | 14 | from opt_einsum.parser import alpha_canonicalize, parse_einsum_input 15 | from opt_einsum.typing import ArrayType 16 | 17 | CacheKeyType = Union[Tuple[str, str, int, Tuple[int, ...]], Tuple[str, int]] 18 | CacheType = Dict[CacheKeyType, ArrayType] 19 | 20 | __all__ = [ 21 | "currently_sharing", 22 | "get_sharing_cache", 23 | "shared_intermediates", 24 | "count_cached_ops", 25 | "transpose_cache_wrap", 26 | "einsum_cache_wrap", 27 | "to_backend_cache_wrap", 28 | ] 29 | 30 | _SHARING_STACK: Dict[int, List[CacheType]] = defaultdict(list) 31 | 32 | 33 | def currently_sharing() -> bool: 34 | """Check if we are currently sharing a cache -- thread specific.""" 35 | return threading.get_ident() in _SHARING_STACK 36 | 37 | 38 | def get_sharing_cache() -> CacheType: 39 | """Return the most recent sharing cache -- thread specific.""" 40 | return _SHARING_STACK[threading.get_ident()][-1] 41 | 42 | 43 | def _add_sharing_cache(cache: CacheType) -> Any: 44 | _SHARING_STACK[threading.get_ident()].append(cache) 45 | 46 | 47 | def _remove_sharing_cache() -> None: 48 | tid = threading.get_ident() 49 | _SHARING_STACK[tid].pop() 50 | if not _SHARING_STACK[tid]: 51 | del _SHARING_STACK[tid] 52 | 53 | 54 | @contextlib.contextmanager 55 | def shared_intermediates( 56 | cache: Optional[CacheType] = None, 57 | ) -> Generator[CacheType, None, None]: 58 | """Context in which contract intermediate results are shared. 59 | 60 | Note that intermediate computations will not be garbage collected until 61 | 1. this context exits, and 62 | 2. the yielded cache is garbage collected (if it was captured). 63 | 64 | **Parameters:** 65 | 66 | - **cache** - *(dict)* If specified, a user-stored dict in which intermediate results will be stored. This can be used to interleave sharing contexts. 67 | 68 | **Returns:** 69 | 70 | - **cache** - *(dict)* A dictionary in which sharing results are stored. If ignored, 71 | sharing results will be garbage collected when this context is 72 | exited. This dict can be passed to another context to resume 73 | sharing. 74 | """ 75 | if cache is None: 76 | cache = {} 77 | _add_sharing_cache(cache) 78 | try: 79 | yield cache 80 | finally: 81 | _remove_sharing_cache() 82 | 83 | 84 | def count_cached_ops(cache: CacheType) -> CounterType[str]: 85 | """Returns a counter of the types of each op in the cache. 86 | This is useful for profiling to increase sharing. 87 | """ 88 | return Counter(key[0] for key in cache.keys()) 89 | 90 | 91 | def _save_tensors(*tensors: ArrayType) -> None: 92 | """Save tensors in the cache to prevent their ids from being recycled. 93 | This is needed to prevent false cache lookups. 94 | """ 95 | cache = get_sharing_cache() 96 | for tensor in tensors: 97 | cache["tensor", id(tensor)] = tensor 98 | 99 | 100 | def _memoize(key: CacheKeyType, fn: Any, *args: Any, **kwargs: Any) -> ArrayType: 101 | """Memoize ``fn(*args, **kwargs)`` using the given ``key``. 102 | Results will be stored in the innermost ``cache`` yielded by 103 | :func:`shared_intermediates`. 
104 | """ 105 | cache = get_sharing_cache() 106 | if key in cache: 107 | return cache[key] 108 | result = fn(*args, **kwargs) 109 | cache[key] = result 110 | return result 111 | 112 | 113 | def transpose_cache_wrap(transpose: Any) -> Any: 114 | """Decorates a ``transpose()`` implementation to be memoized inside a 115 | :func:`shared_intermediates` context. 116 | """ 117 | 118 | @functools.wraps(transpose) 119 | def cached_transpose(a, axes, backend="numpy"): 120 | if not currently_sharing(): 121 | return transpose(a, axes, backend=backend) 122 | 123 | # hash by axes 124 | _save_tensors(a) 125 | axes = tuple(axes) 126 | key = "transpose", backend, id(a), axes 127 | return _memoize(key, transpose, a, axes, backend=backend) 128 | 129 | return cached_transpose 130 | 131 | 132 | def tensordot_cache_wrap(tensordot: Any) -> Any: 133 | """Decorates a ``tensordot()`` implementation to be memoized inside a 134 | :func:`shared_intermediates` context. 135 | """ 136 | 137 | @functools.wraps(tensordot) 138 | def cached_tensordot(x, y, axes=2, backend="numpy"): 139 | if not currently_sharing(): 140 | return tensordot(x, y, axes, backend=backend) 141 | 142 | # hash based on the (axes_x,axes_y) form of axes 143 | _save_tensors(x, y) 144 | if isinstance(axes, numbers.Number): 145 | axes = ( 146 | list(range(len(x.shape)))[len(x.shape) - axes :], 147 | list(range(len(y.shape)))[:axes], 148 | ) 149 | axes = tuple(axes[0]), tuple(axes[1]) 150 | key = "tensordot", backend, id(x), id(y), axes 151 | return _memoize(key, tensordot, x, y, axes, backend=backend) 152 | 153 | return cached_tensordot 154 | 155 | 156 | def einsum_cache_wrap(einsum: Any) -> Any: 157 | """Decorates an ``einsum()`` implementation to be memoized inside a 158 | :func:`shared_intermediates` context. 159 | """ 160 | 161 | @functools.wraps(einsum) 162 | def cached_einsum(*args, **kwargs): 163 | if not currently_sharing(): 164 | return einsum(*args, **kwargs) 165 | 166 | # hash modulo commutativity by computing a canonical ordering and names 167 | backend = kwargs.pop("backend", "numpy") 168 | equation = args[0] 169 | inputs, output, operands = parse_einsum_input(args) 170 | inputs = inputs.split(",") 171 | 172 | _save_tensors(*operands) 173 | 174 | # Build canonical key 175 | canonical = sorted(zip(inputs, map(id, operands)), key=lambda x: x[1]) 176 | canonical_ids = tuple(id_ for _, id_ in canonical) 177 | canonical_inputs = ",".join(input_ for input_, _ in canonical) 178 | canonical_equation = alpha_canonicalize(canonical_inputs + "->" + output) 179 | 180 | key = "einsum", backend, canonical_equation, canonical_ids 181 | return _memoize(key, einsum, equation, *operands, backend=backend) 182 | 183 | return cached_einsum 184 | 185 | 186 | def to_backend_cache_wrap(to_backend: Any = None, constants: Any = False) -> Any: 187 | """Decorates an ``to_backend()`` implementation to be memoized inside a 188 | :func:`shared_intermediates` context (e.g. ``to_cupy``, ``to_torch``). 
189 | """ 190 | # manage the case that decorator is called with args 191 | if to_backend is None: 192 | return functools.partial(to_backend_cache_wrap, constants=constants) 193 | 194 | if constants: 195 | 196 | @functools.wraps(to_backend) 197 | def cached_to_backend(array, constant=False): 198 | if not currently_sharing(): 199 | return to_backend(array, constant=constant) 200 | 201 | # hash by id 202 | key = to_backend.__name__, id(array), constant 203 | return _memoize(key, to_backend, array, constant=constant) 204 | 205 | else: 206 | 207 | @functools.wraps(to_backend) 208 | def cached_to_backend(array): 209 | if not currently_sharing(): 210 | return to_backend(array) 211 | 212 | # hash by id 213 | key = to_backend.__name__, id(array) 214 | return _memoize(key, to_backend, array) 215 | 216 | return cached_to_backend 217 | -------------------------------------------------------------------------------- /opt_einsum/testing.py: -------------------------------------------------------------------------------- 1 | """Testing routines for opt_einsum.""" 2 | 3 | import random 4 | from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload 5 | 6 | import pytest 7 | 8 | from opt_einsum.parser import get_symbol 9 | from opt_einsum.typing import ArrayType, PathType, TensorShapeType 10 | 11 | _no_collision_chars = "".join(chr(i) for i in range(7000, 7007)) 12 | _valid_chars = "abcdefghijklmnopqABC" + _no_collision_chars 13 | _sizes = [2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4, 9, 10, 2, 4, 5, 3, 2, 6] 14 | _default_dim_dict = dict(zip(_valid_chars, _sizes)) 15 | assert len(_valid_chars) == len( 16 | _sizes 17 | ), f"Valid characters and sizes must be the same length: {len(_valid_chars)} != {len(_sizes)}" 18 | 19 | 20 | def build_shapes( 21 | string: str, dimension_dict: Optional[Dict[str, int]] = None, replace_ellipsis: bool = False 22 | ) -> Tuple[TensorShapeType, ...]: 23 | """Builds random tensor shapes for testing. 24 | 25 | Parameters: 26 | string: List of tensor strings to build 27 | dimension_dict: Dictionary of index sizes, defaults to indices size of 2-7 28 | replace_ellipsis: Replace ellipsis with a string of no collision characters 29 | 30 | Returns: 31 | The resulting shapes. 32 | 33 | Examples: 34 | ```python 35 | >>> shapes = build_shapes('abbc', {'a': 2, 'b':3, 'c':5}) 36 | >>> shapes 37 | [(2, 3), (3, 3, 5), (5,)] 38 | ``` 39 | 40 | """ 41 | if dimension_dict is None: 42 | dimension_dict = _default_dim_dict 43 | 44 | if "..." in string: 45 | if replace_ellipsis is False: 46 | raise ValueError( 47 | "Ellipsis found in string but `replace_ellipsis=False`, use `replace_ellipsis=True` if this behavior is desired." 48 | ) 49 | ellipse_replace = _no_collision_chars[:3] 50 | string = string.replace("...", ellipse_replace) 51 | 52 | shapes = [] 53 | terms = string.split("->")[0].split(",") 54 | for term in terms: 55 | dims = [dimension_dict[x] for x in term] 56 | shapes.append(tuple(dims)) 57 | return tuple(shapes) 58 | 59 | 60 | def build_views( 61 | string: str, 62 | dimension_dict: Optional[Dict[str, int]] = None, 63 | array_function: Optional[Any] = None, 64 | replace_ellipsis: bool = False, 65 | ) -> Tuple[ArrayType, ...]: 66 | """Builds random numpy arrays for testing. 
67 | 68 | Parameters: 69 | string: List of tensor strings to build 70 | dimension_dict: Dictionary of index _sizes 71 | array_function: Function to build the arrays, defaults to np.random.rand 72 | replace_ellipsis: Replace ellipsis with a string of no collision characters 73 | 74 | Returns: 75 | The resulting views. 76 | 77 | Examples: 78 | ```python 79 | >>> view = build_views('abbc', {'a': 2, 'b':3, 'c':5}) 80 | >>> view[0].shape 81 | (2, 3, 3, 5) 82 | ``` 83 | 84 | """ 85 | if array_function is None: 86 | np = pytest.importorskip("numpy") 87 | array_function = np.random.rand 88 | 89 | views = [] 90 | for shape in build_shapes(string, dimension_dict=dimension_dict, replace_ellipsis=replace_ellipsis): 91 | if shape: 92 | views.append(array_function(*shape)) 93 | else: 94 | views.append(random.random()) 95 | return tuple(views) 96 | 97 | 98 | @overload 99 | def rand_equation( 100 | n: int, 101 | regularity: int, 102 | n_out: int = ..., 103 | d_min: int = ..., 104 | d_max: int = ..., 105 | seed: Optional[int] = ..., 106 | global_dim: bool = ..., 107 | *, 108 | return_size_dict: Literal[True], 109 | ) -> Tuple[str, PathType, Dict[str, int]]: ... 110 | 111 | 112 | @overload 113 | def rand_equation( 114 | n: int, 115 | regularity: int, 116 | n_out: int = ..., 117 | d_min: int = ..., 118 | d_max: int = ..., 119 | seed: Optional[int] = ..., 120 | global_dim: bool = ..., 121 | return_size_dict: Literal[False] = ..., 122 | ) -> Tuple[str, PathType]: ... 123 | 124 | 125 | def rand_equation( 126 | n: int, 127 | regularity: int, 128 | n_out: int = 0, 129 | d_min: int = 2, 130 | d_max: int = 9, 131 | seed: Optional[int] = None, 132 | global_dim: bool = False, 133 | return_size_dict: bool = False, 134 | ) -> Union[Tuple[str, PathType, Dict[str, int]], Tuple[str, PathType]]: 135 | """Generate a random contraction and shapes. 136 | 137 | Parameters: 138 | n: Number of array arguments. 139 | regularity: 'Regularity' of the contraction graph. This essentially determines how 140 | many indices each tensor shares with others on average. 141 | n_out: Number of output indices (i.e. the number of non-contracted indices). 142 | Defaults to 0, i.e., a contraction resulting in a scalar. 143 | d_min: Minimum dimension size. 144 | d_max: Maximum dimension size. 145 | seed: If not None, seed numpy's random generator with this. 146 | global_dim: Add a global, 'broadcast', dimension to every operand. 147 | return_size_dict: Return the mapping of indices to sizes. 148 | 149 | Returns: 150 | eq: The equation string. 151 | shapes: The array shapes. 152 | size_dict: The dict of index sizes, only returned if ``return_size_dict=True``. 
153 | 154 | Examples: 155 | ```python 156 | >>> eq, shapes = rand_equation(n=10, regularity=4, n_out=5, seed=42) 157 | >>> eq 158 | 'oyeqn,tmaq,skpo,vg,hxui,n,fwxmr,hitplcj,kudlgfv,rywjsb->cebda' 159 | 160 | >>> shapes 161 | [(9, 5, 4, 5, 4), 162 | (4, 4, 8, 5), 163 | (9, 4, 6, 9), 164 | (6, 6), 165 | (6, 9, 7, 8), 166 | (4,), 167 | (9, 3, 9, 4, 9), 168 | (6, 8, 4, 6, 8, 6, 3), 169 | (4, 7, 8, 8, 6, 9, 6), 170 | (9, 5, 3, 3, 9, 5)] 171 | ``` 172 | """ 173 | np = pytest.importorskip("numpy") 174 | if seed is not None: 175 | np.random.seed(seed) 176 | 177 | # total number of indices 178 | num_inds = n * regularity // 2 + n_out 179 | inputs = ["" for _ in range(n)] 180 | output = [] 181 | 182 | size_dict = {get_symbol(i): np.random.randint(d_min, d_max + 1) for i in range(num_inds)} 183 | 184 | # generate a list of indices to place either once or twice 185 | def gen(): 186 | for i, ix in enumerate(size_dict): 187 | # generate an outer index 188 | if i < n_out: 189 | output.append(ix) 190 | yield ix 191 | # generate a bond 192 | else: 193 | yield ix 194 | yield ix 195 | 196 | # add the indices randomly to the inputs 197 | for i, ix in enumerate(np.random.permutation(list(gen()))): 198 | # make sure all inputs have at least one index 199 | if i < n: 200 | inputs[i] += ix 201 | else: 202 | # don't add any traces on same op 203 | where = np.random.randint(0, n) 204 | while ix in inputs[where]: 205 | where = np.random.randint(0, n) 206 | 207 | inputs[where] += ix 208 | 209 | # possibly add the same global dim to every arg 210 | if global_dim: 211 | gdim = get_symbol(num_inds) 212 | size_dict[gdim] = np.random.randint(d_min, d_max + 1) 213 | for i in range(n): 214 | inputs[i] += gdim 215 | output += gdim 216 | 217 | # randomly transpose the output indices and form equation 218 | output = "".join(np.random.permutation(output)) # type: ignore 219 | eq = "{}->{}".format(",".join(inputs), output) 220 | 221 | # make the shapes 222 | shapes = [tuple(size_dict[ix] for ix in op) for op in inputs] 223 | 224 | ret = (eq, shapes) 225 | 226 | if return_size_dict: 227 | return ret + (size_dict,) 228 | else: 229 | return ret 230 | 231 | 232 | def build_arrays_from_tuples(path: PathType) -> List[Any]: 233 | """Build random numpy arrays from a path. 234 | 235 | Parameters: 236 | path: The path to build arrays from. 237 | 238 | Returns: 239 | The resulting arrays. 240 | """ 241 | np = pytest.importorskip("numpy") 242 | 243 | return [np.random.rand(*x) for x in path] 244 | -------------------------------------------------------------------------------- /opt_einsum/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dgasmith/opt_einsum/f973f1e3265f248680f502807f2fdca13563cf1a/opt_einsum/tests/__init__.py -------------------------------------------------------------------------------- /opt_einsum/tests/test_blas.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests the BLAS capability for the opt_einsum module. 3 | """ 4 | 5 | from typing import Any 6 | 7 | import pytest 8 | 9 | from opt_einsum import blas, contract 10 | 11 | blas_tests = [ 12 | # DOT 13 | ((["k", "k"], "", set("k")), "DOT"), # DDOT 14 | ((["ijk", "ijk"], "", set("ijk")), "DOT"), # DDOT 15 | # GEMV? 
16 | # GEMM 17 | ((["ij", "jk"], "ik", set("j")), "GEMM"), # GEMM N N 18 | ((["ijl", "jlk"], "ik", set("jl")), "GEMM"), # GEMM N N Tensor 19 | ((["ij", "kj"], "ik", set("j")), "GEMM"), # GEMM N T 20 | ((["ijl", "kjl"], "ik", set("jl")), "GEMM"), # GEMM N T Tensor 21 | ((["ji", "jk"], "ik", set("j")), "GEMM"), # GEMM T N 22 | ((["jli", "jlk"], "ik", set("jl")), "GEMM"), # GEMM T N Tensor 23 | ((["ji", "kj"], "ik", set("j")), "GEMM"), # GEMM T T 24 | ((["jli", "kjl"], "ik", set("jl")), "GEMM"), # GEMM T T Tensor 25 | # GEMM with final transpose 26 | ((["ij", "jk"], "ki", set("j")), "GEMM"), # GEMM N N 27 | ((["ijl", "jlk"], "ki", set("jl")), "GEMM"), # GEMM N N Tensor 28 | ((["ij", "kj"], "ki", set("j")), "GEMM"), # GEMM N T 29 | ((["ijl", "kjl"], "ki", set("jl")), "GEMM"), # GEMM N T Tensor 30 | ((["ji", "jk"], "ki", set("j")), "GEMM"), # GEMM T N 31 | ((["jli", "jlk"], "ki", set("jl")), "GEMM"), # GEMM T N Tensor 32 | ((["ji", "kj"], "ki", set("j")), "GEMM"), # GEMM T T 33 | ((["jli", "kjl"], "ki", set("jl")), "GEMM"), # GEMM T T Tensor 34 | # Tensor Dot (requires copy), lets not deal with this for now 35 | ((["ilj", "jlk"], "ik", set("jl")), "TDOT"), # FT GEMM N N Tensor 36 | ((["ijl", "ljk"], "ik", set("jl")), "TDOT"), # ST GEMM N N Tensor 37 | ((["ilj", "kjl"], "ik", set("jl")), "TDOT"), # FT GEMM N T Tensor 38 | ((["ijl", "klj"], "ik", set("jl")), "TDOT"), # ST GEMM N T Tensor 39 | ((["lji", "jlk"], "ik", set("jl")), "TDOT"), # FT GEMM T N Tensor 40 | ((["jli", "ljk"], "ik", set("jl")), "TDOT"), # ST GEMM T N Tensor 41 | ((["lji", "jlk"], "ik", set("jl")), "TDOT"), # FT GEMM T N Tensor 42 | ((["jli", "ljk"], "ik", set("jl")), "TDOT"), # ST GEMM T N Tensor 43 | # Tensor Dot (requires copy), lets not deal with this for now with transpose 44 | ((["ilj", "jlk"], "ik", set("lj")), "TDOT"), # FT GEMM N N Tensor 45 | ((["ijl", "ljk"], "ik", set("lj")), "TDOT"), # ST GEMM N N Tensor 46 | ((["ilj", "kjl"], "ik", set("lj")), "TDOT"), # FT GEMM N T Tensor 47 | ((["ijl", "klj"], "ik", set("lj")), "TDOT"), # ST GEMM N T Tensor 48 | ((["lji", "jlk"], "ik", set("lj")), "TDOT"), # FT GEMM T N Tensor 49 | ((["jli", "ljk"], "ik", set("lj")), "TDOT"), # ST GEMM T N Tensor 50 | ((["lji", "jlk"], "ik", set("lj")), "TDOT"), # FT GEMM T N Tensor 51 | ((["jli", "ljk"], "ik", set("lj")), "TDOT"), # ST GEMM T N Tensor 52 | # Other 53 | ((["ijk", "ikj"], "", set("ijk")), "DOT/EINSUM"), # Transpose DOT 54 | ((["i", "j"], "ij", set()), "OUTER/EINSUM"), # Outer 55 | ((["ijk", "ik"], "j", set("ik")), "GEMV/EINSUM"), # Matrix-vector 56 | ((["ijj", "jk"], "ik", set("j")), False), # Double index 57 | ((["ijk", "j"], "ij", set()), False), # Index sum 1 58 | ((["ij", "ij"], "ij", set()), False), # Index sum 2 59 | ] 60 | 61 | 62 | @pytest.mark.parametrize("inp,benchmark", blas_tests) 63 | def test_can_blas(inp: Any, benchmark: bool) -> None: 64 | result = blas.can_blas(*inp) 65 | assert result == benchmark 66 | 67 | 68 | def test_blas_out() -> None: 69 | np = pytest.importorskip("numpy") 70 | 71 | a = np.random.rand(4, 4) 72 | b = np.random.rand(4, 4) 73 | c = np.random.rand(4, 4) 74 | d = np.empty((4, 4)) 75 | 76 | contract("ij,jk->ik", a, b, out=d) 77 | np.testing.assert_allclose(d, np.dot(a, b)) 78 | assert np.allclose(d, np.dot(a, b)) 79 | 80 | contract("ij,jk,kl->il", a, b, c, out=d) 81 | np.testing.assert_allclose(d, np.dot(a, b).dot(c)) 82 | -------------------------------------------------------------------------------- /opt_einsum/tests/test_contract.py: 
-------------------------------------------------------------------------------- 1 | """
2 | Tests a series of opt_einsum contraction paths to ensure the results are the same for different paths
3 | """
4 | 
5 | from typing import Any, List
6 | 
7 | import pytest
8 | 
9 | from opt_einsum import contract, contract_expression, contract_path
10 | from opt_einsum.paths import _PATH_OPTIONS, linear_to_ssa, ssa_to_linear
11 | from opt_einsum.testing import build_views, rand_equation
12 | from opt_einsum.typing import OptimizeKind
13 | 
14 | # NumPy is required for the majority of this file
15 | np = pytest.importorskip("numpy")
16 | 
17 | 
18 | tests = [
19 | # Test scalar-like operations
20 | "a,->a",
21 | "ab,->ab",
22 | ",ab,->ab",
23 | ",,->",
24 | # Test hadamard-like products
25 | "a,ab,abc->abc",
26 | "a,b,ab->ab",
27 | # Test index-transformations
28 | "ea,fb,gc,hd,abcd->efgh",
29 | "ea,fb,abcd,gc,hd->efgh",
30 | "abcd,ea,fb,gc,hd->efgh",
31 | # Test complex contractions
32 | "acdf,jbje,gihb,hfac,gfac,gifabc,hfac",
33 | "acdf,jbje,gihb,hfac,gfac,gifabc,hfac",
34 | "cd,bdhe,aidb,hgca,gc,hgibcd,hgac",
35 | "abhe,hidj,jgba,hiab,gab",
36 | "bde,cdh,agdb,hica,ibd,hgicd,hiac",
37 | "chd,bde,agbc,hiad,hgc,hgi,hiad",
38 | "chd,bde,agbc,hiad,bdi,cgh,agdb",
39 | "bdhe,acad,hiab,agac,hibd",
40 | # Test collapse
41 | "ab,ab,c->",
42 | "ab,ab,c->c",
43 | "ab,ab,cd,cd->",
44 | "ab,ab,cd,cd->ac",
45 | "ab,ab,cd,cd->cd",
46 | "ab,ab,cd,cd,ef,ef->",
47 | # Test outer products
48 | "ab,cd,ef->abcdef",
49 | "ab,cd,ef->acdf",
50 | "ab,cd,de->abcde",
51 | "ab,cd,de->be",
52 | "ab,bcd,cd->abcd",
53 | "ab,bcd,cd->abd",
54 | # Random test cases that have previously failed
55 | "eb,cb,fb->cef",
56 | "dd,fb,be,cdb->cef",
57 | "bca,cdb,dbf,afc->",
58 | "dcc,fce,ea,dbf->ab",
59 | "fdf,cdd,ccd,afe->ae",
60 | "abcd,ad",
61 | "ed,fcd,ff,bcf->be",
62 | "baa,dcf,af,cde->be",
63 | "bd,db,eac->ace",
64 | "fff,fae,bef,def->abd",
65 | "efc,dbc,acf,fd->abe",
66 | # Inner products
67 | "ab,ab",
68 | "ab,ba",
69 | "abc,abc",
70 | "abc,bac",
71 | "abc,cba",
72 | # GEMM test cases
73 | "ab,bc",
74 | "ab,cb",
75 | "ba,bc",
76 | "ba,cb",
77 | "abcd,cd",
78 | "abcd,ab",
79 | "abcd,cdef",
80 | "abcd,cdef->feba",
81 | "abcd,efdc",
82 | # Inner then dot
83 | "aab,bc->ac",
84 | "ab,bcc->ac",
85 | "aab,bcc->ac",
86 | "baa,bcc->ac",
87 | "aab,ccb->ac",
88 | # Randomly built test cases
89 | "aab,fa,df,ecc->bde",
90 | "ecb,fef,bad,ed->ac",
91 | "bcf,bbb,fbf,fc->",
92 | "bb,ff,be->e",
93 | "bcb,bb,fc,fff->",
94 | "fbb,dfd,fc,fc->",
95 | "afd,ba,cc,dc->bf",
96 | "adb,bc,fa,cfc->d",
97 | "bbd,bda,fc,db->acf",
98 | "dba,ead,cad->bce",
99 | "aef,fbc,dca->bde",
100 | ]
101 | 
102 | 
103 | @pytest.mark.parametrize("optimize", (True, False, None))
104 | def test_contract_plain_types(optimize: OptimizeKind) -> None:
105 | expr = "ij,jk,kl->il"
106 | ops = [np.random.rand(2, 2), np.random.rand(2, 2), np.random.rand(2, 2)]
107 | 
108 | path = contract_path(expr, *ops, optimize=optimize)
109 | assert len(path) == 2
110 | 
111 | result = contract(expr, *ops, optimize=optimize)
112 | assert result.shape == (2, 2)
113 | 
114 | 
115 | @pytest.mark.parametrize("string", tests)
116 | @pytest.mark.parametrize("optimize", _PATH_OPTIONS)
117 | def test_compare(optimize: OptimizeKind, string: str) -> None:
118 | views = build_views(string)
119 | 
120 | ein = contract(string, *views, optimize=False, use_blas=False)
121 | opt = contract(string, *views, optimize=optimize, use_blas=False)
122 | assert np.allclose(ein, opt)
123 | 
124 | 
125 | @pytest.mark.parametrize("string", 
tests) 126 | def test_drop_in_replacement(string: str) -> None: 127 | views = build_views(string) 128 | opt = contract(string, *views) 129 | assert np.allclose(opt, np.einsum(string, *views)) 130 | 131 | 132 | @pytest.mark.parametrize("string", tests) 133 | @pytest.mark.parametrize("optimize", _PATH_OPTIONS) 134 | def test_compare_greek(optimize: OptimizeKind, string: str) -> None: 135 | views = build_views(string) 136 | 137 | ein = contract(string, *views, optimize=False, use_blas=False) 138 | 139 | # convert to greek 140 | string = "".join(chr(ord(c) + 848) if c not in ",->." else c for c in string) 141 | 142 | opt = contract(string, *views, optimize=optimize, use_blas=False) 143 | assert np.allclose(ein, opt) 144 | 145 | 146 | @pytest.mark.parametrize("string", tests) 147 | @pytest.mark.parametrize("optimize", _PATH_OPTIONS) 148 | def test_compare_blas(optimize: OptimizeKind, string: str) -> None: 149 | views = build_views(string) 150 | 151 | ein = contract(string, *views, optimize=False) 152 | opt = contract(string, *views, optimize=optimize) 153 | assert np.allclose(ein, opt) 154 | 155 | 156 | @pytest.mark.parametrize("string", tests) 157 | @pytest.mark.parametrize("optimize", _PATH_OPTIONS) 158 | def test_compare_blas_greek(optimize: OptimizeKind, string: str) -> None: 159 | views = build_views(string) 160 | 161 | ein = contract(string, *views, optimize=False) 162 | 163 | # convert to greek 164 | string = "".join(chr(ord(c) + 848) if c not in ",->." else c for c in string) 165 | 166 | opt = contract(string, *views, optimize=optimize) 167 | assert np.allclose(ein, opt) 168 | 169 | 170 | def test_some_non_alphabet_maintains_order() -> None: 171 | # 'c beta a' should automatically go to -> 'a c beta' 172 | string = "c" + chr(ord("b") + 848) + "a" 173 | # but beta will be temporarily replaced with 'b' for which 'cba->abc' 174 | # so check manual output kicks in: 175 | x = np.random.rand(2, 3, 4) 176 | assert np.allclose(contract(string, x), contract("cxa", x)) 177 | 178 | 179 | def test_printing(): 180 | string = "bbd,bda,fc,db->acf" 181 | views = build_views(string) 182 | 183 | ein = contract_path(string, *views) 184 | assert len(str(ein[1])) == 728 185 | 186 | 187 | @pytest.mark.parametrize("string", tests) 188 | @pytest.mark.parametrize("optimize", _PATH_OPTIONS) 189 | @pytest.mark.parametrize("use_blas", [False, True]) 190 | @pytest.mark.parametrize("out_spec", [False, True]) 191 | def test_contract_expressions(string: str, optimize: OptimizeKind, use_blas: bool, out_spec: bool) -> None: 192 | views = build_views(string) 193 | shapes = [view.shape if hasattr(view, "shape") else () for view in views] 194 | expected = contract(string, *views, optimize=False, use_blas=False) 195 | 196 | expr = contract_expression(string, *shapes, optimize=optimize, use_blas=use_blas) 197 | 198 | if out_spec and ("->" in string) and (string[-2:] != "->"): 199 | (out,) = build_views(string.split("->")[1]) 200 | expr(*views, out=out) 201 | else: 202 | out = expr(*views) 203 | 204 | assert np.allclose(out, expected) 205 | 206 | # check representations 207 | assert string in expr.__repr__() 208 | assert string in expr.__str__() 209 | 210 | 211 | def test_contract_expression_interleaved_input() -> None: 212 | x, y, z = (np.random.randn(2, 2) for _ in "xyz") 213 | expected = np.einsum(x, [0, 1], y, [1, 2], z, [2, 3], [3, 0]) 214 | xshp, yshp, zshp = ((2, 2) for _ in "xyz") 215 | expr = contract_expression(xshp, [0, 1], yshp, [1, 2], zshp, [2, 3], [3, 0]) 216 | out = expr(x, y, z) 217 | assert np.allclose(out, 
expected) 218 | 219 | 220 | @pytest.mark.parametrize( 221 | "string,constants", 222 | [ 223 | ("hbc,bdef,cdkj,ji,ikeh,lfo", [1, 2, 3, 4]), 224 | ("bdef,cdkj,ji,ikeh,hbc,lfo", [0, 1, 2, 3]), 225 | ("hbc,bdef,cdkj,ji,ikeh,lfo", [1, 2, 3, 4]), 226 | ("hbc,bdef,cdkj,ji,ikeh,lfo", [1, 2, 3, 4]), 227 | ("ijab,acd,bce,df,ef->ji", [1, 2, 3, 4]), 228 | ("ab,cd,ad,cb", [1, 3]), 229 | ("ab,bc,cd", [0, 1]), 230 | ], 231 | ) 232 | def test_contract_expression_with_constants(string: str, constants: List[int]) -> None: 233 | views = build_views(string) 234 | expected = contract(string, *views, optimize=False, use_blas=False) 235 | 236 | shapes = [view.shape if hasattr(view, "shape") else () for view in views] 237 | 238 | expr_args: List[Any] = [] 239 | ctrc_args = [] 240 | for i, (shape, view) in enumerate(zip(shapes, views)): 241 | if i in constants: 242 | expr_args.append(view) 243 | else: 244 | expr_args.append(shape) 245 | ctrc_args.append(view) 246 | 247 | expr = contract_expression(string, *expr_args, constants=constants) 248 | out = expr(*ctrc_args) 249 | assert np.allclose(expected, out) 250 | 251 | 252 | @pytest.mark.parametrize("optimize", ["greedy", "optimal"]) 253 | @pytest.mark.parametrize("n", [4, 5]) 254 | @pytest.mark.parametrize("reg", [2, 3]) 255 | @pytest.mark.parametrize("n_out", [0, 2, 4]) 256 | @pytest.mark.parametrize("global_dim", [False, True]) 257 | def test_rand_equation(optimize: OptimizeKind, n: int, reg: int, n_out: int, global_dim: bool) -> None: 258 | eq, _, size_dict = rand_equation(n, reg, n_out, d_min=2, d_max=5, seed=42, return_size_dict=True) 259 | views = build_views(eq, size_dict) 260 | 261 | expected = contract(eq, *views, optimize=False) 262 | actual = contract(eq, *views, optimize=optimize) 263 | 264 | assert np.allclose(expected, actual) 265 | 266 | 267 | @pytest.mark.parametrize("equation", tests) 268 | def test_linear_vs_ssa(equation: str) -> None: 269 | views = build_views(equation) 270 | linear_path, _ = contract_path(equation, *views) 271 | ssa_path = linear_to_ssa(linear_path) 272 | linear_path2 = ssa_to_linear(ssa_path) 273 | assert linear_path2 == linear_path 274 | 275 | 276 | def test_contract_path_supply_shapes() -> None: 277 | eq = "ab,bc,cd" 278 | shps = [(2, 3), (3, 4), (4, 5)] 279 | contract_path(eq, *shps, shapes=True) 280 | -------------------------------------------------------------------------------- /opt_einsum/tests/test_edge_cases.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tets a series of opt_einsum contraction paths to ensure the results are the same for different paths 3 | """ 4 | 5 | from typing import Any, Tuple 6 | 7 | import pytest 8 | 9 | from opt_einsum import contract, contract_expression, contract_path 10 | from opt_einsum.typing import PathType 11 | 12 | # NumPy is required for the majority of this file 13 | np = pytest.importorskip("numpy") 14 | 15 | 16 | def test_contract_expression_checks() -> None: 17 | # check optimize needed 18 | with pytest.raises(ValueError): 19 | contract_expression("ab,bc->ac", (2, 3), (3, 4), optimize=False) 20 | 21 | # check sizes are still checked 22 | with pytest.raises(ValueError): 23 | contract_expression("ab,bc->ac", (2, 3), (3, 4), (42, 42)) 24 | 25 | # check if out given 26 | out = np.empty((2, 4)) 27 | with pytest.raises(ValueError): 28 | contract_expression("ab,bc->ac", (2, 3), (3, 4), out=out) 29 | 30 | # check still get errors when wrong ranks supplied to expression 31 | expr = contract_expression("ab,bc->ac", (2, 3), (3, 4)) 32 | 33 | 
# too few arguments 34 | with pytest.raises(ValueError) as err: 35 | expr(np.random.rand(2, 3)) 36 | assert "`ContractExpression` takes exactly 2" in str(err.value) 37 | 38 | # too many arguments 39 | with pytest.raises(ValueError) as err: 40 | expr(np.random.rand(2, 3), np.random.rand(2, 3), np.random.rand(2, 3)) 41 | assert "`ContractExpression` takes exactly 2" in str(err.value) 42 | 43 | # wrong shapes 44 | with pytest.raises(ValueError) as err: 45 | expr(np.random.rand(2, 3, 4), np.random.rand(3, 4)) 46 | assert "Internal error while evaluating `ContractExpression`" in str(err.value) 47 | with pytest.raises(ValueError) as err: 48 | expr(np.random.rand(2, 4), np.random.rand(3, 4, 5)) 49 | assert "Internal error while evaluating `ContractExpression`" in str(err.value) 50 | with pytest.raises(ValueError) as err: 51 | expr(np.random.rand(2, 3), np.random.rand(3, 4), out=np.random.rand(2, 4, 6)) 52 | assert "Internal error while evaluating `ContractExpression`" in str(err.value) 53 | 54 | # should only be able to specify out 55 | with pytest.raises(TypeError) as err_type: 56 | expr(np.random.rand(2, 3), np.random.rand(3, 4), order="F") # type: ignore 57 | assert "got an unexpected keyword" in str(err_type.value) 58 | 59 | 60 | def test_broadcasting_contraction() -> None: 61 | a = np.random.rand(1, 5, 4) 62 | b = np.random.rand(4, 6) 63 | c = np.random.rand(5, 6) 64 | d = np.random.rand(10) 65 | 66 | ein_scalar = contract("ijk,kl,jl", a, b, c, optimize=False) 67 | opt_scalar = contract("ijk,kl,jl", a, b, c, optimize=True) 68 | assert np.allclose(ein_scalar, opt_scalar) 69 | 70 | result = ein_scalar * d 71 | 72 | ein = contract("ijk,kl,jl,i->i", a, b, c, d, optimize=False) 73 | opt = contract("ijk,kl,jl,i->i", a, b, c, d, optimize=True) 74 | 75 | assert np.allclose(ein, result) 76 | assert np.allclose(opt, result) 77 | 78 | 79 | def test_broadcasting_contraction2() -> None: 80 | a = np.random.rand(1, 1, 5, 4) 81 | b = np.random.rand(4, 6) 82 | c = np.random.rand(5, 6) 83 | d = np.random.rand(7, 7) 84 | 85 | ein_scalar = contract("abjk,kl,jl", a, b, c, optimize=False) 86 | opt_scalar = contract("abjk,kl,jl", a, b, c, optimize=True) 87 | assert np.allclose(ein_scalar, opt_scalar) 88 | 89 | result = ein_scalar * d 90 | 91 | ein = contract("abjk,kl,jl,ab->ab", a, b, c, d, optimize=False) 92 | opt = contract("abjk,kl,jl,ab->ab", a, b, c, d, optimize=True) 93 | 94 | assert np.allclose(ein, result) 95 | assert np.allclose(opt, result) 96 | 97 | 98 | def test_broadcasting_contraction3() -> None: 99 | a = np.random.rand(1, 5, 4) 100 | b = np.random.rand(4, 1, 6) 101 | c = np.random.rand(5, 6) 102 | d = np.random.rand(7, 7) 103 | 104 | ein = contract("ajk,kbl,jl,ab->ab", a, b, c, d, optimize=False) 105 | opt = contract("ajk,kbl,jl,ab->ab", a, b, c, d, optimize=True) 106 | 107 | assert np.allclose(ein, opt) 108 | 109 | 110 | def test_broadcasting_contraction4() -> None: 111 | a = np.arange(64).reshape(2, 4, 8) 112 | ein = contract("obk,ijk->ioj", a, a, optimize=False) 113 | opt = contract("obk,ijk->ioj", a, a, optimize=True) 114 | 115 | assert np.allclose(ein, opt) 116 | 117 | 118 | def test_can_blas_on_healed_broadcast_dimensions() -> None: 119 | expr = contract_expression("ab,bc,bd->acd", (5, 4), (1, 5), (4, 20)) 120 | # first contraction involves broadcasting 121 | assert expr.contraction_list[0][2] == "bc,ab->bca" 122 | assert expr.contraction_list[0][-1] is False 123 | # but then is healed GEMM is usable 124 | assert expr.contraction_list[1][2] == "bca,bd->acd" 125 | assert 
expr.contraction_list[1][-1] == "GEMM" 126 | 127 | 128 | def test_pathinfo_for_empty_contraction() -> None: 129 | eq = "->" 130 | arrays = (1.0,) 131 | path: PathType = [] 132 | _, info = contract_path(eq, *arrays, optimize=path) 133 | # some info is built lazily, so check repr 134 | assert repr(info) 135 | assert info.largest_intermediate == 1 136 | 137 | 138 | @pytest.mark.parametrize( 139 | "expression, operands", 140 | [ 141 | [",,->", (5, 5.0, 2.0j)], 142 | ["ab,->", ([[5, 5], [2.0, 1]], 2.0j)], 143 | ["ab,bc->ac", ([[5, 5], [2.0, 1]], [[2.0, 1], [3.0, 4]])], 144 | ["ab,->", ([[5, 5], [2.0, 1]], True)], 145 | ], 146 | ) 147 | def test_contract_with_assumed_shapes(expression: str, operands: Tuple[Any]) -> None: 148 | """Test that we can contract with assumed shapes, and that the output is correct. This is required as we need to infer intermediate shape sizes.""" 149 | 150 | benchmark = np.einsum(expression, *operands) 151 | result = contract(expression, *operands, optimize=True) 152 | assert np.allclose(benchmark, result) 153 | -------------------------------------------------------------------------------- /opt_einsum/tests/test_input.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests the input parsing for opt_einsum. Duplicates the np.einsum input tests. 3 | """ 4 | 5 | from typing import Any 6 | 7 | import pytest 8 | 9 | from opt_einsum import contract, contract_path 10 | from opt_einsum.testing import build_views 11 | 12 | np = pytest.importorskip("numpy") 13 | 14 | 15 | def test_type_errors() -> None: 16 | # subscripts must be a string 17 | with pytest.raises(TypeError): 18 | contract(0, 0) 19 | 20 | # out parameter must be an array 21 | with pytest.raises(TypeError): 22 | contract("", 0, out="test") 23 | 24 | # order parameter must be a valid order 25 | # changed in Numpy 1.19, see https://github.com/numpy/numpy/commit/35b0a051c19265f5643f6011ee11e31d30c8bc4c 26 | with pytest.raises((TypeError, ValueError)): 27 | contract("", 0, order="W") # type: ignore 28 | 29 | # casting parameter must be a valid casting 30 | with pytest.raises(ValueError): 31 | contract("", 0, casting="blah") # type: ignore 32 | 33 | # dtype parameter must be a valid dtype 34 | with pytest.raises(TypeError): 35 | contract("", 0, dtype="bad_data_type") 36 | 37 | # other keyword arguments are rejected 38 | with pytest.raises(TypeError): 39 | contract("", 0, bad_arg=0) 40 | 41 | # issue 4528 revealed a segfault with this call 42 | with pytest.raises(TypeError): 43 | contract(*(None,) * 63) 44 | 45 | # Cannot have two -> 46 | with pytest.raises(ValueError): 47 | contract("->,->", 0, 5) 48 | 49 | # Undefined symbol lhs 50 | with pytest.raises(ValueError): 51 | contract("&,a->", 0, 5) 52 | 53 | # Undefined symbol rhs 54 | with pytest.raises(ValueError): 55 | contract("a,a->&", 0, 5) 56 | 57 | with pytest.raises(ValueError): 58 | contract("a,a->&", 0, 5) 59 | 60 | # Catch ellipsis errors 61 | string = "...a->...a" 62 | views = build_views(string, replace_ellipsis=True) 63 | 64 | # Subscript list must contain Ellipsis or (hashable && comparable) object 65 | with pytest.raises(TypeError): 66 | contract(views[0], [Ellipsis, 0], [Ellipsis, ["a"]]) 67 | 68 | with pytest.raises(TypeError): 69 | contract(views[0], [Ellipsis, {}], [Ellipsis, "a"]) 70 | 71 | 72 | @pytest.mark.parametrize("contract_fn", [contract, contract_path]) 73 | def test_value_errors(contract_fn: Any) -> None: 74 | with pytest.raises(ValueError): 75 | contract_fn("") 76 | 77 | # subscripts must be a 
string
78 | with pytest.raises(TypeError):
79 | contract_fn(0, 0)
80 | 
81 | # invalid subscript character
82 | with pytest.raises(ValueError):
83 | contract_fn("i%...", [0, 0])
84 | with pytest.raises(ValueError):
85 | contract_fn("...j$", [0, 0])
86 | with pytest.raises(ValueError):
87 | contract_fn("i->&", [0, 0])
88 | 
89 | with pytest.raises(ValueError):
90 | contract_fn("")
91 | # number of operands must match count in subscripts string
92 | with pytest.raises(ValueError):
93 | contract_fn("", 0, 0)
94 | with pytest.raises(ValueError):
95 | contract_fn(",", 0, [0], [0])
96 | with pytest.raises(ValueError):
97 | contract_fn(",", [0])
98 | 
99 | # can't have more subscripts than dimensions in the operand
100 | with pytest.raises(ValueError):
101 | contract_fn("i", 0)
102 | with pytest.raises(ValueError):
103 | contract_fn("ij", [0, 0])
104 | with pytest.raises(ValueError):
105 | contract_fn("...i", 0)
106 | with pytest.raises(ValueError):
107 | contract_fn("i...j", [0, 0])
108 | with pytest.raises(ValueError):
109 | contract_fn("i...", 0)
110 | with pytest.raises(ValueError):
111 | contract_fn("ij...", [0, 0])
112 | 
113 | # invalid ellipsis
114 | with pytest.raises(ValueError):
115 | contract_fn("i..", [0, 0])
116 | with pytest.raises(ValueError):
117 | contract_fn(".i...", [0, 0])
118 | with pytest.raises(ValueError):
119 | contract_fn("j->..j", [0, 0])
120 | with pytest.raises(ValueError):
121 | contract_fn("j->.j...", [0, 0])
122 | 
123 | # invalid subscript character
124 | with pytest.raises(ValueError):
125 | contract_fn("i%...", [0, 0])
126 | with pytest.raises(ValueError):
127 | contract_fn("...j$", [0, 0])
128 | with pytest.raises(ValueError):
129 | contract_fn("i->&", [0, 0])
130 | 
131 | # output subscripts must appear in input
132 | with pytest.raises(ValueError):
133 | contract_fn("i->ij", [0, 0])
134 | 
135 | # output subscripts may only be specified once
136 | with pytest.raises(ValueError):
137 | contract_fn("ij->jij", [[0, 0], [0, 0]])
138 | 
139 | # dimensions must match when being collapsed
140 | with pytest.raises(ValueError):
141 | contract_fn("ii", np.arange(6).reshape(2, 3))
142 | with pytest.raises(ValueError):
143 | contract_fn("ii->i", np.arange(6).reshape(2, 3))
144 | 
145 | # broadcasting to new dimensions must be enabled explicitly
146 | with pytest.raises(ValueError):
147 | contract_fn("i", np.arange(6).reshape(2, 3))
148 | 
149 | with pytest.raises(TypeError):
150 | contract_fn("ij->ij", [[0, 1], [0, 1]], bad_kwarg=True)
151 | 
152 | 
153 | def test_input_formats_shapes():
154 | """
155 | Test that the shapes are the same for the bench and interleaved input formats
156 | """
157 | shape1 = (2, 3, 4)
158 | shape2 = (3, 4, 5)
159 | 
160 | bench = contract_path("abc,bcd->da", shape1, shape2, shapes=True)
161 | interleaved = contract_path(shape1, [1, 2, 3], shape2, [2, 3, 4], [4, 1], shapes=True)
162 | assert bench[0] == interleaved[0]
163 | 
164 | 
165 | @pytest.mark.parametrize(
166 | "string",
167 | [
168 | # Ellipse
169 | "...a->...",
170 | "a...->...",
171 | "a...a->...a",
172 | "...,...",
173 | "a,b",
174 | "...a,...b",
175 | ],
176 | )
177 | def test_compare(string: str) -> None:
178 | views = build_views(string, replace_ellipsis=True)
179 | 
180 | ein = contract(string, *views, optimize=False)
181 | opt = contract(string, *views)
182 | assert np.allclose(ein, opt)
183 | 
184 | opt = contract(string, *views, optimize="optimal")
185 | assert np.allclose(ein, opt)
186 | 
187 | 
188 | def test_ellipse_input1() -> None:
189 | string = "...a->..." 
190 | views = build_views(string, replace_ellipsis=True) 191 | 192 | ein = contract(string, *views, optimize=False) 193 | opt = contract(views[0], [Ellipsis, 0], [Ellipsis]) 194 | assert np.allclose(ein, opt) 195 | 196 | 197 | def test_ellipse_input2() -> None: 198 | string = "...a" 199 | views = build_views(string, replace_ellipsis=True) 200 | 201 | ein = contract(string, *views, optimize=False) 202 | opt = contract(views[0], [Ellipsis, 0]) 203 | assert np.allclose(ein, opt) 204 | 205 | 206 | def test_ellipse_input3() -> None: 207 | string = "...a->...a" 208 | views = build_views(string, replace_ellipsis=True) 209 | 210 | ein = contract(string, *views, optimize=False) 211 | opt = contract(views[0], [Ellipsis, 0], [Ellipsis, 0]) 212 | assert np.allclose(ein, opt) 213 | 214 | 215 | def test_ellipse_input4() -> None: 216 | string = "...b,...a->..." 217 | views = build_views(string, replace_ellipsis=True) 218 | 219 | ein = contract(string, *views, optimize=False) 220 | opt = contract(views[0], [Ellipsis, 1], views[1], [Ellipsis, 0], [Ellipsis]) 221 | assert np.allclose(ein, opt) 222 | 223 | 224 | def test_singleton_dimension_broadcast() -> None: 225 | # singleton dimensions broadcast (gh-10343) 226 | p = np.ones((10, 2)) 227 | q = np.ones((1, 2)) 228 | 229 | ein = contract("ij,ij->j", p, q, optimize=False) 230 | opt = contract("ij,ij->j", p, q, optimize=True) 231 | assert np.allclose(ein, opt) 232 | assert np.allclose(opt, [10.0, 10.0]) 233 | 234 | p = np.ones((1, 5)) 235 | q = np.ones((5, 5)) 236 | 237 | for optimize in (True, False): 238 | res1 = (contract("...ij,...jk->...ik", p, p, optimize=optimize),) 239 | res2 = contract("...ij,...jk->...ik", p, q, optimize=optimize) 240 | assert np.allclose(res1, res2) 241 | assert np.allclose(res2, np.full((1, 5), 5)) 242 | 243 | 244 | def test_large_int_input_format() -> None: 245 | string = "ab,bc,cd" 246 | x, y, z = build_views(string) 247 | string_output = contract(string, x, y, z) 248 | int_output = contract(x, (1000, 1001), y, (1001, 1002), z, (1002, 1003)) 249 | assert np.allclose(string_output, int_output) 250 | for i in range(10): 251 | transpose_output = contract(x, (i + 1, i)) 252 | assert np.allclose(transpose_output, x.T) 253 | 254 | 255 | def test_hashable_object_input_format() -> None: 256 | string = "ab,bc,cd" 257 | x, y, z = build_views(string) 258 | string_output = contract(string, x, y, z) 259 | hash_output1 = contract(x, ("left", "bond1"), y, ("bond1", "bond2"), z, ("bond2", "right")) 260 | hash_output2 = contract( 261 | x, 262 | ("left", "bond1"), 263 | y, 264 | ("bond1", "bond2"), 265 | z, 266 | ("bond2", "right"), 267 | ("left", "right"), 268 | ) 269 | assert np.allclose(string_output, hash_output1) 270 | assert np.allclose(hash_output1, hash_output2) 271 | for i in range(1, 10): 272 | transpose_output = contract(x, ("b" * i, "a" * i)) 273 | assert np.allclose(transpose_output, x.T) 274 | -------------------------------------------------------------------------------- /opt_einsum/tests/test_parser.py: -------------------------------------------------------------------------------- 1 | """ 2 | Directly tests various parser utility functions. 
3 | """ 4 | 5 | from typing import Any, Tuple 6 | 7 | import pytest 8 | 9 | from opt_einsum.parser import get_shape, get_symbol, parse_einsum_input 10 | from opt_einsum.testing import build_arrays_from_tuples 11 | 12 | 13 | def test_get_symbol() -> None: 14 | assert get_symbol(2) == "c" 15 | assert get_symbol(200000) == "\U00031540" 16 | # Ensure we skip surrogates '[\uD800-\uDFFF]' 17 | assert get_symbol(55295) == "\ud88b" 18 | assert get_symbol(55296) == "\ue000" 19 | assert get_symbol(57343) == "\ue7ff" 20 | 21 | 22 | def test_parse_einsum_input() -> None: 23 | eq = "ab,bc,cd" 24 | ops = build_arrays_from_tuples([(2, 3), (3, 4), (4, 5)]) 25 | input_subscripts, output_subscript, operands = parse_einsum_input([eq, *ops]) 26 | assert input_subscripts == eq 27 | assert output_subscript == "ad" 28 | assert operands == ops 29 | 30 | 31 | def test_parse_einsum_input_shapes_error() -> None: 32 | eq = "ab,bc,cd" 33 | ops = build_arrays_from_tuples([(2, 3), (3, 4), (4, 5)]) 34 | 35 | with pytest.raises(ValueError): 36 | _ = parse_einsum_input([eq, *ops], shapes=True) 37 | 38 | 39 | def test_parse_einsum_input_shapes() -> None: 40 | eq = "ab,bc,cd" 41 | shapes = [(2, 3), (3, 4), (4, 5)] 42 | input_subscripts, output_subscript, operands = parse_einsum_input([eq, *shapes], shapes=True) 43 | assert input_subscripts == eq 44 | assert output_subscript == "ad" 45 | assert shapes == operands 46 | 47 | 48 | def test_parse_with_ellisis() -> None: 49 | eq = "...a,ab" 50 | shapes = [(2, 3), (3, 4)] 51 | input_subscripts, output_subscript, operands = parse_einsum_input([eq, *shapes], shapes=True) 52 | assert input_subscripts == "da,ab" 53 | assert output_subscript == "db" 54 | assert shapes == operands 55 | 56 | 57 | @pytest.mark.parametrize( 58 | "array, shape", 59 | [ 60 | [[5], (1,)], 61 | [[5, 5], (2,)], 62 | [(5, 5), (2,)], 63 | [[[[[[5, 2]]]]], (1, 1, 1, 1, 2)], 64 | [[[[[["abcdef", "b"]]]]], (1, 1, 1, 1, 2)], 65 | ["A", ()], 66 | [b"A", ()], 67 | [True, ()], 68 | [5, ()], 69 | [5.0, ()], 70 | [5.0 + 0j, ()], 71 | ], 72 | ) 73 | def test_get_shapes(array: Any, shape: Tuple[int]) -> None: 74 | assert get_shape(array) == shape 75 | -------------------------------------------------------------------------------- /opt_einsum/typing.py: -------------------------------------------------------------------------------- 1 | """Types used in the opt_einsum package.""" 2 | 3 | from collections import namedtuple 4 | from typing import Any, Callable, Collection, Dict, FrozenSet, List, Literal, Optional, Tuple, Union 5 | 6 | TensorShapeType = Tuple[int, ...] 
7 | PathType = Collection[TensorShapeType] 8 | 9 | ArrayType = Any 10 | 11 | ArrayIndexType = FrozenSet[str] 12 | ArrayShaped = namedtuple("ArrayShaped", ["shape"]) 13 | 14 | ContractionListType = List[Tuple[Any, ArrayIndexType, str, Optional[Tuple[str, ...]], Union[str, bool]]] 15 | PathSearchFunctionType = Callable[[List[ArrayIndexType], ArrayIndexType, Dict[str, int], Optional[int]], PathType] 16 | 17 | # Contract kwargs 18 | OptimizeKind = Union[ 19 | None, 20 | bool, 21 | Literal[ 22 | "optimal", "dp", "greedy", "random-greedy", "random-greedy-128", "branch-all", "branch-2", "auto", "auto-hq" 23 | ], 24 | PathType, 25 | PathSearchFunctionType, 26 | ] 27 | BackendType = Literal["auto", "object", "autograd", "cupy", "dask", "jax", "theano", "tensorflow", "torch", "libjax"] 28 | -------------------------------------------------------------------------------- /paper/paper.bib: -------------------------------------------------------------------------------- 1 | @article{NumPy, 2 | author={S. van der Walt and S. C. Colbert and G. Varoquaux}, 3 | journal={Comput. Sci. Eng.}, 4 | title={The NumPy Array: A Structure for Efficient Numerical Computation}, 5 | year={2011}, 6 | volume={13}, 7 | number={2}, 8 | pages={22-30}, 9 | doi={10.1109/MCSE.2011.37}, 10 | ISSN={1521-9615}, 11 | month={March},} 12 | 13 | @Misc{Dask, 14 | title = {Dask: Library for dynamic task scheduling}, 15 | author = {{Dask Development Team}}, 16 | year = {2016}, 17 | url = {http://dask.pydata.org}, 18 | note = {(accessed date May 9th, 2018)} 19 | } 20 | 21 | @article{Tensorflow, 22 | author = {Mart{\'{\i}}n Abadi and 23 | Ashish Agarwal and 24 | Paul Barham and 25 | Eugene Brevdo and 26 | Zhifeng Chen and 27 | Craig Citro and 28 | Gregory S. Corrado and 29 | Andy Davis and 30 | Jeffrey Dean and 31 | Matthieu Devin and 32 | Sanjay Ghemawat and 33 | Ian J. Goodfellow and 34 | Andrew Harp and 35 | Geoffrey Irving and 36 | Michael Isard and 37 | Yangqing Jia and 38 | Rafal J{\'{o}}zefowicz and 39 | Lukasz Kaiser and 40 | Manjunath Kudlur and 41 | Josh Levenberg and 42 | Dan Man{\'{e}} and 43 | Rajat Monga and 44 | Sherry Moore and 45 | Derek Gordon Murray and 46 | Chris Olah and 47 | Mike Schuster and 48 | Jonathon Shlens and 49 | Benoit Steiner and 50 | Ilya Sutskever and 51 | Kunal Talwar and 52 | Paul A. Tucker and 53 | Vincent Vanhoucke and 54 | Vijay Vasudevan and 55 | Fernanda B. Vi{\'{e}}gas and 56 | Oriol Vinyals and 57 | Pete Warden and 58 | Martin Wattenberg and 59 | Martin Wicke and 60 | Yuan Yu and 61 | Xiaoqiang Zheng}, 62 | title = {TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed 63 | Systems}, 64 | journal = {CoRR}, 65 | volume = {abs/1603.04467}, 66 | year = {2016}, 67 | url = {http://arxiv.org/abs/1603.04467}, 68 | archivePrefix = {arXiv}, 69 | eprint = {1603.04467}, 70 | timestamp = {Wed, 07 Jun 2017 14:40:20 +0200}, 71 | biburl = {https://dblp.org/rec/bib/journals/corr/AbadiABBCCCDDDG16}, 72 | bibsource = {dblp computer science bibliography, https://dblp.org} 73 | } 74 | 75 | -------------------------------------------------------------------------------- /paper/paper.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: opt\_einsum - A Python package for optimizing contraction order for einsum-like expressions 3 | tags: 4 | - array 5 | - tensors 6 | - optimization 7 | - phylogenetics 8 | - natural selection 9 | - molecular evolution 10 | authors: 11 | - name: Daniel G. A. 
Smith
12 | orcid: 0000-0001-8626-0900
13 | affiliation: "1"
14 | - name: Johnnie Gray
15 | orcid: 0000-0001-9461-3024
16 | affiliation: "2"
17 | 
18 | affiliations:
19 | - name: The Molecular Science Software Institute, Blacksburg, VA 24060
20 | index: 1
21 | - name: University College London, London, UK
22 | index: 2
23 | date: 14 May 2018
24 | bibliography: paper.bib
25 | ---
26 | 
27 | # Summary
28 | 
29 | ``einsum`` is a powerful Swiss army knife for arbitrary tensor contractions and
30 | general linear algebra found in the popular ``numpy`` [@NumPy] package. While
31 | these expressions can express most mathematical operations found in NumPy,
32 | optimizing them becomes increasingly important because naive implementations
33 | increase the overall scaling of the expressions, resulting in a dramatic
34 | increase in overall execution time. Expressions with
35 | many tensors are particularly prevalent in many-body theories such as quantum
36 | chemistry, particle physics, and nuclear physics, in addition to other fields
37 | such as machine learning. In the extreme case, matrix product state theory can
38 | have thousands of tensors, meaning that the computation cannot proceed in a
39 | naive fashion.
40 | 
41 | The canonical NumPy ``einsum`` function considers expressions as a single unit
42 | and is not able to factor these expressions into multiple smaller pieces. For
43 | example, consider the following index transformation: ``M_{pqrs} = C_{pi} C_{qj}
44 | I_{ijkl} C_{rk} C_{sl}`` with two different algorithms:
45 | 
46 | ```python
47 | import numpy as np
48 | 
49 | dim = 10
50 | I = np.random.rand(dim, dim, dim, dim)
51 | C = np.random.rand(dim, dim)
52 | 
53 | def naive(I, C):
54 | # N^8 scaling
55 | return np.einsum('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C)
56 | 
57 | def optimized(I, C):
58 | # N^5 scaling
59 | K = np.einsum('pi,ijkl->pjkl', C, I)
60 | K = np.einsum('qj,pjkl->pqkl', C, K)
61 | K = np.einsum('rk,pqkl->pqrl', C, K)
62 | K = np.einsum('sl,pqrl->pqrs', C, K)
63 | return K
64 | ```
65 | 
66 | By building intermediate arrays, the overall scaling of the contraction is
67 | reduced and considerable cost savings can be seen even for small ``N`` (``N=10``):
68 | 
69 | ```python
70 | >>> np.allclose(naive(I, C), optimized(I, C))
71 | True
72 | 
73 | %timeit naive(I, C)
74 | 1 loops, best of 3: 829 ms per loop
75 | 
76 | %timeit optimized(I, C)
77 | 1000 loops, best of 3: 445 µs per loop
78 | ```
79 | 
80 | This index transformation is a well-known contraction that leads to
81 | straightforward intermediates. The contraction can be further complicated by
82 | considering that the shapes of the C matrices need not be the same; in this case,
83 | the ordering in which the indices are transformed matters greatly. The
84 | opt_einsum package handles this logic automatically and is a drop-in
85 | replacement for the ``np.einsum`` function:
86 | 
87 | ```python
88 | from opt_einsum import contract
89 | 
90 | dim = 30
91 | I = np.random.rand(dim, dim, dim, dim)
92 | C = np.random.rand(dim, dim)
93 | 
94 | %timeit optimized(I, C)
95 | 10 loops, best of 3: 65.8 ms per loop
96 | 
97 | %timeit contract('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C)
98 | 100 loops, best of 3: 16.2 ms per loop
99 | ```
100 | 
101 | The above will automatically find the optimal contraction order, in this case
102 | identical to that of the optimized function above, and compute the products. 
103 | In this case, it uses ``np.dot`` internally to exploit any vendor BLAS 104 | functionality that the NumPy build may have. 105 | 106 | In addition, backends other than NumPy can be used to either exploit GPU 107 | computation via Tensorflow [@Tensorflow] or distributed compute capabilities 108 | via Dask [@Dask]. The core components of ``opt_einsum`` have been contributed 109 | back to the ``numpy`` library and can be found in all ``numpy.einsum`` function 110 | calls in version 1.12 or later using the ``optimize`` keyword 111 | (https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.einsum.html). 112 | 113 | The software is on GitHub (https://github.com/dgasmith/opt_einsum/tree/v2.0.0) 114 | and can be downloaded via pip or conda-forge. Further discussion of features 115 | and uses can be found at the documentation 116 | (http://optimized-einsum.readthedocs.io/en/latest/). 117 | 118 | # Acknowledgements 119 | 120 | We acknowledge additional contributions from Fabian-Robert Stöter, Robert T. 121 | McGibbon, and Nils Werner to this project. 122 | 123 | # References 124 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ['hatchling', 'hatch-fancy-pypi-readme>=22.5.0', 'hatch-vcs'] 3 | build-backend = 'hatchling.build' 4 | 5 | [project] 6 | name = 'opt_einsum' 7 | description = 'Path optimization of einsum functions.' 8 | authors = [ 9 | {name = 'Daniel Smith', email = 'dgasmith@icloud.com'}, 10 | ] 11 | license = 'MIT' 12 | classifiers = [ 13 | 'Development Status :: 5 - Production/Stable', 14 | 'Programming Language :: Python', 15 | 'Programming Language :: Python :: Implementation :: CPython', 16 | 'Programming Language :: Python :: Implementation :: PyPy', 17 | 'Programming Language :: Python :: 3', 18 | 'Programming Language :: Python :: 3 :: Only', 19 | 'Programming Language :: Python :: 3.9', 20 | 'Programming Language :: Python :: 3.10', 21 | 'Programming Language :: Python :: 3.11', 22 | 'Programming Language :: Python :: 3.12', 23 | 'Programming Language :: Python :: 3.13', 24 | 'Intended Audience :: Developers', 25 | 'Intended Audience :: Science/Research', 26 | 'License :: OSI Approved :: MIT License', 27 | 'Topic :: Software Development :: Libraries :: Python Modules', 28 | 29 | ] 30 | requires-python = '>=3.8' 31 | dependencies = [ 32 | ] 33 | dynamic = ['version', 'readme'] 34 | 35 | [tool.hatch.version] 36 | source = "vcs" 37 | path = 'opt_einsum/_version.py' 38 | 39 | [tool.hatch.metadata] 40 | allow-direct-references = true 41 | 42 | [tool.hatch.build.hooks.vcs] 43 | version-file = "opt_einsum/_version.py" 44 | 45 | [tool.hatch.metadata.hooks.fancy-pypi-readme] 46 | content-type = 'text/markdown' 47 | # construct the PyPI readme from README.md and HISTORY.md 48 | fragments = [ 49 | {path = "README.md"}, 50 | ] 51 | 52 | [tool.hatch.build.targets.sdist] 53 | exclude = [ 54 | "/.github", 55 | "/devtools", 56 | "/docs", 57 | "/paper", 58 | "/scripts" 59 | ] 60 | 61 | [tool.hatch.build.targets.wheel] 62 | packages = ["opt_einsum"] 63 | 64 | [tool.pytest.ini_options] 65 | filterwarnings = [ 66 | 'ignore::DeprecationWarning:tensorflow', 67 | 'ignore::DeprecationWarning:tensorboard', 68 | ] 69 | 70 | [tool.ruff] 71 | line-length = 120 72 | target-version = 'py38' 73 | 74 | [tool.ruff.lint] 75 | extend-select = ['RUF100', 'UP', 'C', 'D', 'I', 'N', 'NPY', 'Q', 'T', 'W'] 76 | extend-ignore = ['C901', 'D101', 
'D102', 'D103', 'D105', 'D107', 'D205', 'D415'] 77 | isort = { known-first-party = ['opt_einsum'] } 78 | mccabe = { max-complexity = 14 } 79 | pydocstyle = { convention = 'google' } 80 | 81 | [tool.ruff.lint.per-file-ignores] 82 | 'opt_einsum/tests/*' = ['D', 'T201', 'NPY002', 'ANN001', 'ANN202'] 83 | 84 | [tool.coverage.run] 85 | source = ['opt_einsum'] 86 | omit = ['*/tests/*', 'opt_einsum/_version.py'] 87 | branch = true 88 | relative_files = true 89 | 90 | [[tool.mypy.overrides]] 91 | module = "cupy.*, jax.*, numpy.*, theano.*, tensorflow.*, torch.*" 92 | ignore_missing_imports = true -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | Compare 2 | ======= 3 | 4 | This is a scratch folder to compare large numbers of contractions in different ways. 5 | -------------------------------------------------------------------------------- /scripts/compare_random_paths.py: -------------------------------------------------------------------------------- 1 | import resource 2 | import timeit 3 | from typing import Literal 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | import opt_einsum as oe 9 | 10 | rsrc = resource.RLIMIT_DATA 11 | limit = int(1e9) 12 | resource.setrlimit(rsrc, (limit, limit)) 13 | 14 | pd.set_option("display.width", 200) 15 | 16 | opt_path: Literal["optimal"] = "optimal" 17 | 18 | # Number of dimensions 19 | max_dims = 4 20 | min_dims = 2 21 | 22 | # Size of each dimension 23 | min_size = 10 24 | max_size = 20 25 | 26 | # Number of terms 27 | min_terms = 3 28 | max_terms = 5 29 | 30 | # Additional parameters 31 | max_indices = 6 32 | max_doubles = 1e7 33 | 34 | alpha = list("abcdefghijklmnopqrstuvwyxz") 35 | alpha_dict = {num: x for num, x in enumerate(alpha)} 36 | 37 | print("Maximum term size is %d" % (max_size**max_dims)) 38 | 39 | 40 | def make_term(): 41 | num_dims = np.random.randint(min_dims, max_dims + 1) 42 | term = np.random.randint(0, max_indices, num_dims) 43 | return term 44 | 45 | 46 | def get_string(term): 47 | return "".join([alpha_dict[x] for x in term]) 48 | 49 | 50 | def random_contraction(): 51 | 52 | # Compute number of terms 53 | num_terms = np.random.randint(min_terms, max_terms) 54 | 55 | # Compute size of each index 56 | index_size = np.random.randint(min_size, max_size, max_indices) 57 | 58 | # Build random terms and views 59 | int_terms = [make_term() for x in range(num_terms)] 60 | views = [np.random.rand(*index_size[s]) for s in int_terms] 61 | 62 | # Compute einsum string and return string 63 | sum_string = ",".join([get_string(s) for s in int_terms]) 64 | out_string = sum_string.replace(",", "") 65 | out_string = [x for x in alpha if out_string.count(x) == 1] 66 | 67 | # sum_string += '->' 68 | sum_string += "->" + "".join(out_string) 69 | return (sum_string, views, index_size) 70 | 71 | 72 | out = [] 73 | for x in range(200): 74 | sum_string, views, index_size = random_contraction() 75 | 76 | try: 77 | ein = np.einsum(sum_string, *views) 78 | except Exception: 79 | out.append(["Einsum failed", sum_string, index_size, 0, 0]) 80 | continue 81 | 82 | try: 83 | opt = oe.contract(sum_string, *views, path=opt_path) 84 | except Exception: 85 | out.append(["Opt_einsum failed", sum_string, index_size, 0, 0]) 86 | continue 87 | 88 | current_opt_path = oe.contract_path(sum_string, *views, optimize=opt_path)[0] 89 | if not np.allclose(ein, opt): 90 | out.append(["Comparison failed", sum_string, index_size, 0, 0]) 91 
| continue
92 | 
93 | setup = "import numpy as np; import opt_einsum as oe; \
94 | from __main__ import sum_string, views, current_opt_path"
95 | 
96 | einsum_string = "np.einsum(sum_string, *views)"
97 | contract_string = "oe.contract(sum_string, *views, path=current_opt_path)"
98 | 
99 | e_n = 1
100 | o_n = 1
101 | einsum_time = timeit.timeit(einsum_string, setup=setup, number=e_n) / e_n
102 | contract_time = timeit.timeit(contract_string, setup=setup, number=o_n) / o_n
103 | 
104 | out.append([True, sum_string, current_opt_path, einsum_time, contract_time])
105 | 
106 | df = pd.DataFrame(out)
107 | df.columns = ["Flag", "String", "Path", "Einsum time", "Opt_einsum time"]
108 | df["Ratio"] = df["Einsum time"] / df["Opt_einsum time"]
109 | 
110 | diff_flags = df["Flag"] != True  # element-wise comparison; `is not` would compare the Series object itself
111 | print("\nNumber of contract different than einsum: %d." % np.sum(diff_flags))
112 | if np.sum(diff_flags) > 0:
113 | print("Terms different than einsum")
114 | print(df[diff_flags])
115 | 
116 | print("\nDescription of speedup in relative terms:")
117 | print(df["Ratio"].describe())
118 | 
119 | print("\nNumber of contract slower than einsum: %d." % np.sum(df["Ratio"] < 0.90))
120 | tmp = df.loc[df["Ratio"] < 0.90].copy()
121 | tmp["Diff (us)"] = np.abs(tmp["Einsum time"] - tmp["Opt_einsum time"]) * 1e6
122 | tmp = tmp.sort_values("Diff (us)", ascending=False)
123 | print(tmp)
124 | 
125 | # diff_us = np.abs(tmp['Einsum time'] - tmp['Opt_einsum time'])*1e6
126 | print("\nDescription of slowdown:")
127 | print(tmp.describe())
128 | --------------------------------------------------------------------------------
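As a complementary illustration of the comparison that `scripts/compare_random_paths.py` performs by hand, the sketch below uses the package's own testing helpers (`rand_equation` and `build_views` from `opt_einsum/testing.py`, shown earlier) to time `np.einsum` against `opt_einsum.contract` on a single random contraction. It is illustrative only and not part of the repository; the equation parameters, seed, optimizer choice, and repeat count are arbitrary, and it assumes NumPy and pytest (which the testing helpers import) are installed.

```python
# Hedged sketch: compare np.einsum with opt_einsum.contract on one random
# contraction built from opt_einsum's own testing utilities.
# The equation parameters and seed below are arbitrary choices for illustration.
import timeit

import numpy as np

import opt_einsum as oe
from opt_einsum.testing import build_views, rand_equation

# Random 6-term contraction with 2 output indices and index sizes between 2 and 5.
eq, _, size_dict = rand_equation(n=6, regularity=3, n_out=2, d_min=2, d_max=5, seed=0, return_size_dict=True)
views = build_views(eq, size_dict)

# Both routes must agree before timing anything.
assert np.allclose(np.einsum(eq, *views), oe.contract(eq, *views, optimize="greedy"))

# Wall-clock comparison; absolute numbers are machine dependent.
t_einsum = timeit.timeit(lambda: np.einsum(eq, *views), number=10)
t_contract = timeit.timeit(lambda: oe.contract(eq, *views, optimize="greedy"), number=10)
print(f"einsum: {t_einsum:.4f}s  contract: {t_contract:.4f}s")
```

For larger or more irregular equations, precomputing the path once with `contract_path` (or reusing a `contract_expression`), as `compare_random_paths.py` does with `current_opt_path`, keeps the path search out of the timing loop.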