├── .github └── workflows │ └── rust.yml ├── .gitignore ├── CONTRIBUTING.md ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── assets ├── benchmarks_bar_plot_13_29.png ├── benchmarks_bar_plot_4_12.png ├── py_benchmarks_bar_plot_0_8.png └── py_benchmarks_bar_plot_9_28.png ├── benches ├── Makefile ├── README.md ├── bench.rs ├── benchmark.sh ├── benchmark_plots.py ├── main.c ├── py_benchmarks.py ├── requirements.txt ├── run_benches.py └── utils.py ├── codecov.yml ├── examples ├── benchmark.rs ├── fftwrb.rs ├── profile.rs └── rustfft.rs ├── hooks └── pre-commit ├── profile.sh ├── pyphastft ├── .github │ └── workflows │ │ └── CI.yml ├── .gitignore ├── Cargo.toml ├── example.py ├── pyproject.toml ├── src │ └── lib.rs └── vis_qt.py ├── rust-toolchain.toml ├── scripts └── twiddle_generator.py ├── src ├── cobra.rs ├── kernels.rs ├── lib.rs ├── options.rs ├── planner.rs ├── twiddles.rs └── utils.rs └── utilities ├── Cargo.toml └── src └── lib.rs /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | on: [ push, pull_request ] 2 | 3 | name: Build 4 | 5 | jobs: 6 | clippy: 7 | name: Clippy 8 | runs-on: ubuntu-latest 9 | steps: 10 | - name: Checkout sources 11 | uses: actions/checkout@v2 12 | 13 | - name: Install nightly toolchain with clippy available 14 | uses: actions-rs/toolchain@v1 15 | with: 16 | profile: minimal 17 | toolchain: nightly 18 | override: true 19 | components: clippy 20 | 21 | - name: Run cargo clippy 22 | uses: actions-rs/cargo@v1 23 | with: 24 | command: clippy 25 | args: -- -D warnings 26 | 27 | rustfmt: 28 | name: Format 29 | runs-on: ubuntu-latest 30 | steps: 31 | - name: Checkout sources 32 | uses: actions/checkout@v2 33 | 34 | - name: Install nightly toolchain with rustfmt available 35 | uses: actions-rs/toolchain@v1 36 | with: 37 | profile: minimal 38 | toolchain: nightly 39 | override: true 40 | components: rustfmt 41 | 42 | - name: Run cargo fmt 43 | uses: 
actions-rs/cargo@v1 44 | with: 45 | command: fmt 46 | args: --all -- --check 47 | 48 | combo: 49 | name: Test 50 | runs-on: ubuntu-latest 51 | steps: 52 | - name: Checkout sources 53 | uses: actions/checkout@v2 54 | 55 | - name: Install nightly toolchain 56 | uses: actions-rs/toolchain@v1 57 | with: 58 | profile: minimal 59 | toolchain: nightly 60 | override: true 61 | 62 | - name: Run cargo test 63 | uses: actions-rs/cargo@v1 64 | with: 65 | command: test 66 | args: --all-features 67 | 68 | coverage: 69 | runs-on: ubuntu-latest 70 | env: 71 | CARGO_TERM_COLOR: always 72 | steps: 73 | - name: Checkout sources 74 | uses: actions/checkout@v2 75 | 76 | - name: Install nightly toolchain 77 | uses: actions-rs/toolchain@v1 78 | with: 79 | profile: minimal 80 | toolchain: nightly 81 | override: true 82 | - name: Install cargo-llvm-cov 83 | uses: taiki-e/install-action@cargo-llvm-cov 84 | - name: Generate code coverage 85 | run: cargo llvm-cov --workspace --lcov --output-path lcov.info 86 | - name: Upload coverage to Codecov 87 | uses: codecov/codecov-action@v4 88 | env: 89 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 90 | with: 91 | files: lcov.info 92 | fail_ci_if_error: true 93 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | debug/ 4 | target/ 5 | 6 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 7 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 8 | Cargo.lock 9 | 10 | # These are backup files generated by rustfmt 11 | **/*.rs.bk 12 | 13 | # MSVC Windows builds of rustc generate these, which store debugging information 14 | *.pdb 15 | 16 | 17 | # Byte-compiled / optimized / DLL files 18 | __pycache__/ 19 | *.py[cod] 20 | *$py.class 21 | 22 | # C extensions 23 | *.so 24 | 25 | # 
Distribution / packaging 26 | .Python 27 | build/ 28 | develop-eggs/ 29 | dist/ 30 | downloads/ 31 | eggs/ 32 | .eggs/ 33 | lib/ 34 | lib64/ 35 | parts/ 36 | sdist/ 37 | var/ 38 | wheels/ 39 | share/python-wheels/ 40 | *.egg-info/ 41 | .installed.cfg 42 | *.egg 43 | MANIFEST 44 | 45 | # PyInstaller 46 | # Usually these files are written by a python script from a template 47 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 48 | *.manifest 49 | *.spec 50 | 51 | # Installer logs 52 | pip-log.txt 53 | pip-delete-this-directory.txt 54 | 55 | # Unit test / coverage reports 56 | htmlcov/ 57 | .tox/ 58 | .nox/ 59 | .coverage 60 | .coverage.* 61 | .cache 62 | nosetests.xml 63 | coverage.xml 64 | *.cover 65 | *.py,cover 66 | .hypothesis/ 67 | .pytest_cache/ 68 | cover/ 69 | 70 | # Translations 71 | *.mo 72 | *.pot 73 | 74 | # Django stuff: 75 | *.log 76 | local_settings.py 77 | db.sqlite3 78 | db.sqlite3-journal 79 | 80 | # Flask stuff: 81 | instance/ 82 | .webassets-cache 83 | 84 | # Scrapy stuff: 85 | .scrapy 86 | 87 | # Sphinx documentation 88 | docs/_build/ 89 | 90 | # PyBuilder 91 | .pybuilder/ 92 | target/ 93 | 94 | # Jupyter Notebook 95 | .ipynb_checkpoints 96 | 97 | # IPython 98 | profile_default/ 99 | ipython_config.py 100 | 101 | # pyenv 102 | # For a library or package, you might want to ignore these files since the code is 103 | # intended to run in multiple environments; otherwise, check them in: 104 | # .python-version 105 | 106 | # pipenv 107 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 108 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 109 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 110 | # install all needed dependencies. 111 | #Pipfile.lock 112 | 113 | # poetry 114 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
115 | # This is especially recommended for binary packages to ensure reproducibility, and is more 116 | # commonly ignored for libraries. 117 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 118 | #poetry.lock 119 | 120 | # pdm 121 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 122 | #pdm.lock 123 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 124 | # in version control. 125 | # https://pdm.fming.dev/#use-with-ide 126 | .pdm.toml 127 | 128 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 129 | __pypackages__/ 130 | 131 | # Celery stuff 132 | celerybeat-schedule 133 | celerybeat.pid 134 | 135 | # SageMath parsed files 136 | *.sage.py 137 | 138 | # Environments 139 | .env 140 | .venv 141 | env/ 142 | venv/ 143 | ENV/ 144 | env.bak/ 145 | venv.bak/ 146 | 147 | # Spyder project settings 148 | .spyderproject 149 | .spyproject 150 | 151 | # Rope project settings 152 | .ropeproject 153 | 154 | # mkdocs documentation 155 | /site 156 | 157 | # mypy 158 | .mypy_cache/ 159 | .dmypy.json 160 | dmypy.json 161 | 162 | # Pyre type checker 163 | .pyre/ 164 | 165 | # pytype static type analyzer 166 | .pytype/ 167 | 168 | # Cython debug symbols 169 | cython_debug/ 170 | 171 | # PyCharm 172 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 173 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 174 | # and can be added to the global gitignore or merged into this file. For a more nuclear 175 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
176 | .idea/ 177 | 178 | # macOS 179 | .DS_Store 180 | 181 | # Benchmarks outputs 182 | benches/benchmark-data.* 183 | benches/*benchmarks_bar_plot*.png 184 | benches/__pycache__ 185 | benches/elapsed_times.csv 186 | benches/bench_fftw 187 | 188 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # CONTRIBUTING 2 | 3 | Contributions are welcome, and they are greatly appreciated! 4 | 5 | 1. Fork the `phastft` repository 6 | 2. Clone your fork to your dev machine/environment: 7 | ```bash 8 | git clone git@github.com:<your-username>/phastft.git 9 | ``` 10 | 3. [Install Rust](https://www.rust-lang.org/tools/install) and set up [nightly](https://rust-lang.github.io/rustup/concepts/channels.html) Rust 11 | 12 | 4. Set up the git hooks in your local repo: 13 | ```bash 14 | cd PhastFT 15 | git config core.hooksPath ./hooks 16 | ``` 17 | 18 | 5. When you're done with your changes, ensure the tests pass with: 19 | ```bash 20 | cargo test --all-features 21 | ``` 22 | 23 | 6. Commit your changes and push them to GitHub 24 | 25 | 7. Submit a pull request (PR) through the [GitHub website](https://github.com/QuState/phastft/pulls). 26 | 27 | ## Pull Request Guidelines 28 | 29 | Before you submit a pull request, please check the following: 30 | - The pull request should include tests if it adds and/or changes functionalities. 
31 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "phastft" 3 | version = "0.2.1" 4 | edition = "2021" 5 | authors = ["Saveliy Yusufov", "Shnatsel"] 6 | license = "MIT OR Apache-2.0" 7 | description = "A high-performance, quantum-inspired, implementation of FFT in pure Rust" 8 | repository = "https://github.com/QuState/PhastFT" 9 | keywords = ["quantum", "fft", "discrete", "fourier", "transform"] 10 | categories = ["algorithms", "compression", "science"] 11 | exclude = ["assets", "scripts", "benches"] 12 | 13 | [dependencies] 14 | num-traits = "0.2.18" 15 | multiversion = "0.7" 16 | num-complex = { version = "0.4.6", features = ["bytemuck"], optional = true } 17 | bytemuck = { version = "1.16.0", optional = true } 18 | 19 | [features] 20 | default = [] 21 | complex-nums = ["dep:num-complex", "dep:bytemuck"] 22 | 23 | [dev-dependencies] 24 | criterion = "0.5.1" 25 | fftw = "0.8.0" 26 | rand = "0.8.5" 27 | utilities = { path = "utilities" } 28 | 29 | [[bench]] 30 | name = "bench" 31 | harness = false 32 | 33 | [profile.release] 34 | codegen-units = 1 35 | lto = true 36 | panic = "abort" 37 | 38 | [profile.profiling] 39 | inherits = "release" 40 | debug = true 41 | 42 | [package.metadata.docs.rs] 43 | all-features = true 44 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | Copyright (c) 2024 The PhastFT Developers 2 | 3 | Permission is hereby granted, free of charge, to any 4 | person obtaining a copy of this software and associated 5 | documentation files (the "Software"), to deal in the 6 | Software without restriction, including without 7 | limitation the rights to use, copy, modify, merge, 8 | publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software 10 | is furnished to do so, subject to the following 11 | conditions: 12 | 13 | The above copyright notice and this permission notice 14 | shall be included in all copies or substantial portions 15 | of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 18 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 19 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 20 | PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT 21 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 22 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 23 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 24 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 25 | DEALINGS IN THE SOFTWARE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build](https://github.com/QuState/PhastFT/actions/workflows/rust.yml/badge.svg)](https://github.com/QuState/PhastFT/actions/workflows/rust.yml) 2 | [![codecov](https://codecov.io/gh/QuState/PhastFT/graph/badge.svg?token=IM86XMURHN)](https://codecov.io/gh/QuState/PhastFT) 3 | [![unsafe forbidden](https://img.shields.io/badge/unsafe-forbidden-success.svg)](https://github.com/rust-secure-code/safety-dance/) 4 | [![](https://img.shields.io/crates/v/phastft)](https://crates.io/crates/phastft) 5 | [![](https://docs.rs/phastft/badge.svg)](https://docs.rs/phastft/) 6 | 7 | # PhastFT 8 | 9 | PhastFT is a high-performance, "quantum-inspired" Fast Fourier 10 | Transform (FFT) library written in pure Rust. 11 | 12 | ## Features 13 | 14 | - Simple implementation using the Cooley-Tukey FFT algorithm 15 | - Performance on par with other Rust FFT implementations 16 | - Zero `unsafe` code 17 | - Takes advantage of latest CPU features up to and including `AVX-512`, but performs well even without them 18 | - Selects the fastest implementation at runtime. No need for `-C target-cpu=native`! 
19 | - Optional parallelization of some steps to 2 threads (with even more planned) 20 | - 2x lower memory usage than [RustFFT](https://crates.io/crates/rustfft/) 21 | - Python bindings (via [PyO3](https://github.com/PyO3/pyo3)) 22 | 23 | ## Limitations 24 | 25 | - Only supports input with a length of `2^n` (i.e., a power of 2) -- input should be padded with zeros to the next power 26 | of 2 27 | - Requires nightly Rust compiler due to use of portable SIMD 28 | 29 | ## Planned features 30 | 31 | - Bluestein's algorithm (to handle arbitrary sized FFTs) 32 | - More multi-threading 33 | - More work on cache-optimal FFT 34 | 35 | ## How is it so fast? 36 | 37 | PhastFT is designed around the capabilities and limitations of modern hardware (that is, anything made in the last 10 38 | years or so). 39 | 40 | The two major bottlenecks in FFT are the **CPU cycles** and **memory accesses**. 41 | 42 | We picked an efficient, general-purpose FFT algorithm. Our implementation can make use of latest CPU features such as 43 | `AVX-512`, but performs well even without them. 44 | 45 | Our key insight for speeding up memory accesses is that FFT is equivalent to applying gates to all qubits in `[0, n)`. 46 | This creates the opportunity to leverage the same memory access patterns as 47 | a [high-performance quantum state simulator](https://github.com/QuState/spinoza). 48 | 49 | We also use the Cache-Optimal Bit Reversal 50 | Algorithm ([COBRA](https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf)) 51 | on large datasets and optionally run it on 2 parallel threads, accelerating it even further. 52 | 53 | All of this combined results in a fast and efficient FFT implementation competitive with 54 | the performance of existing Rust FFT crates, 55 | including [RustFFT](https://crates.io/crates/rustfft/), while using significantly less memory. 
56 | 57 | ## Quickstart 58 | 59 | ### Rust 60 | 61 | ```rust 62 | use phastft::{ 63 | planner::Direction, 64 | fft_64 65 | }; 66 | 67 | let big_n = 1 << 10; 68 | let mut reals: Vec<f64> = (1..=big_n).map(|i| i as f64).collect(); 69 | let mut imags: Vec<f64> = (1..=big_n).map(|i| i as f64).collect(); 70 | fft_64(&mut reals, &mut imags, Direction::Forward); 71 | ``` 72 | 73 | ### Python 74 | 75 | Follow the instructions at https://rustup.rs/ to install Rust, then switch to the nightly channel with 76 | 77 | ```bash 78 | rustup default nightly 79 | ``` 80 | 81 | Then you can install PhastFT itself: 82 | 83 | ```bash 84 | pip install numpy 85 | pip install git+https://github.com/QuState/PhastFT#subdirectory=pyphastft 86 | ``` 87 | 88 | ```python 89 | import numpy as np 90 | from pyphastft import fft 91 | 92 | sig_re = np.asarray(sig_re, dtype=np.float64) 93 | sig_im = np.asarray(sig_im, dtype=np.float64) 94 | 95 | fft(sig_re, sig_im) 96 | ``` 97 | 98 | ### Normalization 99 | 100 | `phastft` does not normalize outputs. Users can normalize outputs after running a forward FFT followed by an inverse 101 | FFT by scaling each element by `1/N`, where `N` is the number of data points. 102 | 103 | ### Output Order 104 | 105 | `phastft` always finishes processing input data by running 106 | a [bit-reversal permutation](https://en.wikipedia.org/wiki/Bit-reversal_permutation) on the processed data. 107 | 108 | ## Benchmarks 109 | 110 | PhastFT is benchmarked against several other FFT libraries. Scripts and instructions to reproduce benchmark results and 111 | plots are available [here](https://github.com/QuState/PhastFT/tree/main/benches#readme). 112 | 113 |

114 | PhastFT vs. RustFFT vs. FFTW3 115 | PhastFT vs. RustFFT vs. FFTW3 116 |

117 | 118 |

119 | PhastFT vs. NumPy FFT vs. pyFFTW 120 | PhastFT vs. NumPy FFT vs. pyFFTW 121 |

122 | 123 | ## Contributing 124 | 125 | Contributions to PhastFT are welcome! If you find any issues or have improvements to suggest, please open an issue or 126 | submit a pull request. Follow the contribution guidelines outlined in the CONTRIBUTING.md file. 127 | 128 | ## License 129 | 130 | PhastFT is licensed under MIT or Apache 2.0 license, at your option. 131 | 132 | ## PhastFT vs. RustFFT 133 | 134 | [RustFFT](https://crates.io/crates/rustfft/) is another excellent FFT implementation in pure Rust. 135 | RustFFT and PhastFT make different trade-offs. 136 | 137 | RustFFT made the choice to work on stable Rust compiler at the cost of `unsafe` code, 138 | while PhastFT contains no `unsafe` blocks but requires a nightly build of Rust compiler 139 | to access the Portable SIMD API. 140 | 141 | RustFFT implements multiple FFT algorithms and tries to pick the best one depending on the workload, 142 | while PhastFT has a single FFT implementation and still achieves competitive performance. 143 | 144 | PhastFT uses 2x less memory than RustFFT, which is important for processing large datasets. 145 | 146 | ## What's with the name? 147 | 148 | The name, **PhastFT**, is derived from the implementation of the 149 | [Quantum Fourier Transform](https://en.wikipedia.org/wiki/Quantum_Fourier_transform) (QFT). Namely, the 150 | [quantum circuit implementation of QFT](https://en.wikipedia.org/wiki/Quantum_Fourier_transform#Circuit_implementation) 151 | consists of the **P**hase gates and **H**adamard gates. Hence, **Ph**astFT. 
152 | -------------------------------------------------------------------------------- /assets/benchmarks_bar_plot_13_29.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuState/PhastFT/7048416f4ec5c8ccff16c40faf3eecf6fc54c02d/assets/benchmarks_bar_plot_13_29.png -------------------------------------------------------------------------------- /assets/benchmarks_bar_plot_4_12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuState/PhastFT/7048416f4ec5c8ccff16c40faf3eecf6fc54c02d/assets/benchmarks_bar_plot_4_12.png -------------------------------------------------------------------------------- /assets/py_benchmarks_bar_plot_0_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuState/PhastFT/7048416f4ec5c8ccff16c40faf3eecf6fc54c02d/assets/py_benchmarks_bar_plot_0_8.png -------------------------------------------------------------------------------- /assets/py_benchmarks_bar_plot_9_28.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuState/PhastFT/7048416f4ec5c8ccff16c40faf3eecf6fc54c02d/assets/py_benchmarks_bar_plot_9_28.png -------------------------------------------------------------------------------- /benches/Makefile: -------------------------------------------------------------------------------- 1 | CC = gcc 2 | CFLAGS = -Wall -Wextra -Werror -O3 3 | LIBS = -lfftw3 -lm 4 | 5 | bench_fftw: main.c 6 | $(CC) $(CFLAGS) -o bench_fftw main.c $(LIBS) 7 | 8 | clean: 9 | rm -f bench_fftw 10 | 11 | -------------------------------------------------------------------------------- /benches/README.md: -------------------------------------------------------------------------------- 1 | # Benchmarks and Profiling 2 | 3 | ## Run benchmarks 4 | 5 | ### Setup Environment 6 | 7 | 1. 
Clone the `PhastFT` git repository [^2]. 8 | 9 | 2. Create a virtual environment 10 | 11 | ```bash 12 | cd ~/PhastFT/benches && python3 -m venv .env && source .env/bin/activate 13 | ``` 14 | 15 | 3. Install Python dependencies[^1] 16 | 17 | ```bash 18 | pip install -r requirements.txt 19 | cd ~/PhastFT/pyphastft 20 | pip install . 21 | ``` 22 | 23 | 4. Run the `FFTW3-RB` vs. `RustFFT` vs. `PhastFT` benchmarks 24 | 25 | ```bash 26 | python run_benches.py 27 | ``` 28 | 29 | 5. Plot the results 30 | 31 | ```bash 32 | python benchmark_plots.py 33 | ``` 34 | 35 | The generated images will be saved in your working directory. 36 | 37 | 6. Run the Python benchmarks and plot the results 38 | 39 | ```bash 40 | python py_benchmarks.py 41 | ``` 42 | 43 | The generated images will be saved in your working directory. 44 | 45 | ## Benchmark Configuration 46 | 47 | ### Libraries and Packages 48 | 49 | | Library/Package | Version | Language | Benchmark Compilation Flags | 50 | |-----------------|----------------|-----------|---------------------------------------------------------------------------------| 51 | | `FFTW3` | 3.3.10-1 amd64 | C, OCaml | `-O3` | 52 | | `RustFFT` | 6.2.0 | Rust | `-C opt-level=3 --edition=2021; codegen-units = 1; lto = true; panic = "abort"` | 53 | | `PhastFT` | 0.1.0 | Rust | `-C opt-level=3 --edition=2021; codegen-units = 1; lto = true; panic = "abort"` | 54 | | `NumPy` | 1.26.4 | Python, C | `N/A` | 55 | | `pyFFTW` | 0.13.1 | Python, C | `N/A` | 56 | 57 | ### Benchmark System Configuration 58 | 59 | | | | 60 | |---------------------------|-------------------------------------------------------------------------------------------------| 61 | | **CPU** | AMD Ryzen 9 7950X (SMT off) | 62 | | L1d Cache | 512 KiB (16 instances) | 63 | | L1i Cache | 512 KiB (16 instances) | 64 | | L2 Cache | 16 MiB (16 instances) | 65 | | L3 Cache | 64 MiB (2 instances) | 66 | | | | 67 | | **Memory** | | 68 | | /0/f/0 | 64GiB System Memory | 69 | | /0/f/1 | 32GiB DIMM Synchronous 
Unbuffered (Unregistered) 6000 MHz (0.2 ns) | 70 | | /0/f/3 | 32GiB DIMM Synchronous Unbuffered (Unregistered) 6000 MHz (0.2 ns) | 71 | | | | 72 | | **OS** | Linux 7950x 6.1.0-17-amd64 #1 SMP PREEMPT_DYNAMIC Debian 6.1.69-1 (2023-12-30) x86_64 GNU/Linux | 73 | | CPU Freq Scaling Governor | Performance | 74 | | | | 75 | | **Rust** | | 76 | | Installed Toolchains | stable-x86_64-unknown-linux-gnu | 77 | | | nightly-x86_64-unknown-linux-gnu (default) | 78 | | Active Toolchain | nightly-x86_64-unknown-linux-gnu (default) | 79 | | Rustc Version | rustc 1.79.0-nightly (7f2fc33da 2024-04-22) | 80 | 81 | ## Profiling 82 | 83 | Navigate to the cloned repo: 84 | 85 | ```bash 86 | cd PhastFt 87 | ``` 88 | 89 | On linux, open access to performance monitoring, and observability operations for processes: 90 | 91 | ```bash 92 | echo -1 | sudo tee /proc/sys/kernel/perf_event_paranoid 93 | ``` 94 | 95 | Finally, run: 96 | 97 | ```bash 98 | ./profile.sh 99 | ``` 100 | 101 | [^1]: Those with macOS on Apple Silicon should 102 | consult [pyFFTW Issue #352](https://github.com/pyFFTW/pyFFTW/issues/352#issuecomment-1945444558) 103 | 104 | [^2]: This tutorial assumes you will clone `PhastFT` to `$HOME` 105 | -------------------------------------------------------------------------------- /benches/bench.rs: -------------------------------------------------------------------------------- 1 | use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; 2 | use num_traits::Float; 3 | use phastft::{ 4 | fft_32_with_opts_and_plan, fft_64_with_opts_and_plan, 5 | options::Options, 6 | planner::{Direction, Planner32, Planner64}, 7 | }; 8 | use rand::{ 9 | distributions::{Distribution, Standard}, 10 | thread_rng, Rng, 11 | }; 12 | use utilities::rustfft::num_complex::Complex; 13 | use utilities::rustfft::FftPlanner; 14 | 15 | const LENGTHS: &[usize] = &[ 16 | 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 17 | ]; 18 | 19 | fn 
generate_numbers(n: usize) -> (Vec, Vec) 20 | where 21 | Standard: Distribution, 22 | { 23 | let mut rng = thread_rng(); 24 | 25 | let samples: Vec = (&mut rng).sample_iter(Standard).take(2 * n).collect(); 26 | 27 | let mut reals = vec![T::zero(); n]; 28 | let mut imags = vec![T::zero(); n]; 29 | 30 | for ((z_re, z_im), rand_chunk) in reals 31 | .iter_mut() 32 | .zip(imags.iter_mut()) 33 | .zip(samples.chunks_exact(2)) 34 | { 35 | *z_re = rand_chunk[0]; 36 | *z_im = rand_chunk[1]; 37 | } 38 | 39 | (reals, imags) 40 | } 41 | 42 | fn generate_complex_numbers(n: usize) -> Vec> 43 | where 44 | Standard: Distribution, 45 | { 46 | let mut rng = thread_rng(); 47 | 48 | let samples: Vec = (&mut rng).sample_iter(Standard).take(2 * n).collect(); 49 | 50 | let mut signal = vec![Complex::default(); n]; 51 | 52 | for (z, rand_chunk) in signal.iter_mut().zip(samples.chunks_exact(2)) { 53 | z.re = rand_chunk[0]; 54 | z.im = rand_chunk[1]; 55 | } 56 | 57 | signal 58 | } 59 | 60 | fn benchmark_forward_f32(c: &mut Criterion) { 61 | let mut group = c.benchmark_group("Forward f32"); 62 | 63 | for n in LENGTHS.iter() { 64 | let len = 1 << n; 65 | group.throughput(Throughput::Elements(len as u64)); 66 | 67 | let id = "PhastFT FFT Forward"; 68 | let options = Options::guess_options(len); 69 | let planner = Planner32::new(len, Direction::Forward); 70 | let (mut reals, mut imags) = generate_numbers(len); 71 | 72 | group.bench_with_input(BenchmarkId::new(id, len), &len, |b, &_len| { 73 | b.iter(|| { 74 | fft_32_with_opts_and_plan( 75 | black_box(&mut reals), 76 | black_box(&mut imags), 77 | black_box(&options), 78 | black_box(&planner), 79 | ); 80 | }); 81 | }); 82 | 83 | let id = "RustFFT FFT Forward"; 84 | let mut planner = FftPlanner::::new(); 85 | let fft = planner.plan_fft_forward(len); 86 | let mut signal = generate_complex_numbers(len); 87 | 88 | group.bench_with_input(BenchmarkId::new(id, len), &len, |b, &_len| { 89 | b.iter(|| fft.process(black_box(&mut signal))); 90 | }); 91 | } 
92 | group.finish(); 93 | } 94 | 95 | fn benchmark_inverse_f32(c: &mut Criterion) { 96 | let options = Options::default(); 97 | 98 | for n in LENGTHS.iter() { 99 | let len = 1 << n; 100 | let id = format!("FFT Inverse f32 {} elements", len); 101 | let planner = Planner32::new(len, Direction::Reverse); 102 | 103 | c.bench_function(&id, |b| { 104 | let (mut reals, mut imags) = generate_numbers(len); 105 | b.iter(|| { 106 | fft_32_with_opts_and_plan( 107 | black_box(&mut reals), 108 | black_box(&mut imags), 109 | black_box(&options), 110 | black_box(&planner), 111 | ); 112 | }); 113 | }); 114 | } 115 | } 116 | 117 | fn benchmark_forward_f64(c: &mut Criterion) { 118 | let mut group = c.benchmark_group("Forward f64"); 119 | 120 | for n in LENGTHS.iter() { 121 | let len = 1 << n; 122 | let id = "PhastFT FFT Forward"; 123 | let options = Options::guess_options(len); 124 | let planner = Planner64::new(len, Direction::Forward); 125 | let (mut reals, mut imags) = generate_numbers(len); 126 | group.throughput(Throughput::Elements(len as u64)); 127 | 128 | group.bench_with_input(BenchmarkId::new(id, len), &len, |b, &_len| { 129 | b.iter(|| { 130 | fft_64_with_opts_and_plan( 131 | black_box(&mut reals), 132 | black_box(&mut imags), 133 | black_box(&options), 134 | black_box(&planner), 135 | ); 136 | }); 137 | }); 138 | 139 | let id = "RustFFT FFT Forward"; 140 | let mut planner = FftPlanner::::new(); 141 | let fft = planner.plan_fft_forward(len); 142 | let mut signal = generate_complex_numbers(len); 143 | 144 | group.bench_with_input(BenchmarkId::new(id, len), &len, |b, &_len| { 145 | b.iter(|| fft.process(black_box(&mut signal))); 146 | }); 147 | } 148 | group.finish(); 149 | } 150 | 151 | fn benchmark_inverse_f64(c: &mut Criterion) { 152 | let options = Options::default(); 153 | 154 | for n in LENGTHS.iter() { 155 | let len = 1 << n; 156 | let id = format!("FFT Inverse f64 {} elements", len); 157 | let planner = Planner64::new(len, Direction::Reverse); 158 | 159 | 
#!/usr/bin/env bash

set -Eeuo pipefail

# Two arguments are required: the smallest and largest exponent to benchmark.
# Every size 2^start .. 2^end (inclusive) is run.
if [[ "$#" -ne 2 ]]
then
    echo "Usage: $0 <n-lower-bound> <n-upper-bound>" >&2
    exit 1
fi

start=$1
end=$2
max_iters=1000 # Upper bound on the number of runs per input size

# One timestamped output tree per invocation, with one sub-directory per library.
OUTPUT_DIR=benchmark-data.$(date +"%Y.%m.%d.%H-%M-%S")
mkdir -p "$OUTPUT_DIR"/fftw3 "$OUTPUT_DIR"/rustfft "$OUTPUT_DIR"/phastft "$OUTPUT_DIR"/fftwrb
def read_file(filepath: str) -> list[float]:
    """Read one benchmark data file: one numeric timing per line.

    Blank lines (for example a trailing newline left by an interrupted run)
    are skipped instead of aborting the whole plot with a ``ValueError``.

    Args:
        filepath: path of the text file to read.

    Returns:
        The timings in file order, as floats.

    Raises:
        ValueError: if a non-blank line is not a valid number.
    """
    y = []

    with open(filepath) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            y.append(float(line))

    return y
def main():
    """Plot the most recent benchmark run, one figure per range of sizes.

    Raises:
        FileNotFoundError: if no benchmark data directory is found.
    """
    lib_names = ("rustfft", "phastft", "fftw3", "fftwrb")
    ranges = (range(4, 13), range(13, 30))

    # The data directory is the same for every library and every range, so it
    # is looked up once instead of once per (range, library) pair.
    root_folder = find_directory()
    if root_folder is None:
        raise FileNotFoundError("unable to find the benchmark data directory")

    for n_range in ranges:
        all_data = {
            lib: build_and_clean_data(root_folder, n_range, lib) for lib in lib_names
        }

        assert (
            len(all_data["rustfft"])
            == len(all_data["fftw3"])
            == len(all_data["phastft"])
            == len(all_data["fftwrb"])
        )
        plot(all_data, n_range)
((double)rand() / RAND_MAX); 43 | } 44 | 45 | // Generate complex values and fill the buffers 46 | for (int i = 0; i < num_amps; ++i) { 47 | double p_sqrt = sqrt(probs[i]); 48 | double sin_a, cos_a; 49 | 50 | double theta = angles[i]; 51 | sin_a = sin(theta); 52 | cos_a = cos(theta); 53 | 54 | double re = p_sqrt * cos_a; 55 | double im = p_sqrt * sin_a; 56 | 57 | reals[i] = re; 58 | imags[i] = im; 59 | } 60 | 61 | // Free allocated memory 62 | free(probs); 63 | free(angles); 64 | } 65 | 66 | 67 | int main(int argc, char** argv) { 68 | if (argc != 2) { 69 | fprintf(stderr, "Usage: %s \n", argv[0]); 70 | return EXIT_FAILURE; 71 | } 72 | 73 | long n = strtol(argv[1], NULL, 0); 74 | 75 | int N = 1 << n; 76 | 77 | // We don't count input mem allocation for RustFFT or PhastFT, so we omit 78 | // it from the timer here. 79 | fftw_complex* in = fftw_alloc_complex(N); 80 | 81 | uint64_t diff1; 82 | struct timespec start, end; 83 | clock_gettime(CLOCK_MONOTONIC, &start); 84 | fftw_plan p = fftw_plan_dft_1d(N, in, in, FFTW_FORWARD, FFTW_ESTIMATE | FFTW_DESTROY_INPUT); 85 | clock_gettime(CLOCK_MONOTONIC, &end); 86 | diff1 = BILLION * (end.tv_sec - start.tv_sec) + end.tv_nsec - start.tv_nsec; 87 | 88 | // Generate random complex signal using the provided function 89 | double* reals = (double*)malloc(N * sizeof(double)); 90 | double* imags = (double*)malloc(N * sizeof(double)); 91 | gen_random_signal(reals, imags, N); 92 | 93 | // Fill the FFT input array 94 | for (int i = 0; i < N; i++) { 95 | in[i][0] = reals[i]; 96 | in[i][1] = imags[i]; 97 | } 98 | free(reals); 99 | free(imags); 100 | 101 | uint64_t diff2; 102 | struct timespec start1, end1; 103 | clock_gettime(CLOCK_MONOTONIC, &start1); 104 | fftw_execute(p); 105 | clock_gettime(CLOCK_MONOTONIC, &end1); /* mark the end1 time */ 106 | diff2 = BILLION * (end1.tv_sec - start1.tv_sec) + end1.tv_nsec - start1.tv_nsec; 107 | 108 | uint64_t diff = (diff1 / 1000) + (diff2 / 1000); 109 | printf("%llu\n", (long long unsigned int) 
diff); 110 | 111 | fftw_free(in); 112 | fftw_destroy_plan(p); 113 | fftw_cleanup(); 114 | 115 | return EXIT_SUCCESS; 116 | } 117 | 118 | -------------------------------------------------------------------------------- /benches/py_benchmarks.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import time 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import pandas as pd 7 | import pyfftw 8 | 9 | pyfftw.interfaces.cache.enable() 10 | 11 | from pyphastft import fft 12 | 13 | from utils import bytes2human 14 | 15 | plt.style.use("fivethirtyeight") 16 | 17 | 18 | def gen_random_signal(dim: int) -> np.ndarray: 19 | """Generate a random, complex 1D signal""" 20 | return np.ascontiguousarray( 21 | np.random.randn(dim) + 1j * np.random.randn(dim), 22 | dtype="complex128", 23 | ) 24 | 25 | 26 | def main() -> None: 27 | with open("elapsed_times.csv", "w", newline="", encoding="utf-8") as csvfile: 28 | fieldnames = ["n", "phastft_time", "numpy_fft_time", "pyfftw_fft_time"] 29 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames) 30 | 31 | writer.writeheader() 32 | 33 | for n in range(4, 29): 34 | print(f"n = {n}") 35 | big_n = 1 << n 36 | s = gen_random_signal(big_n) 37 | 38 | a_re = np.ascontiguousarray(s.real, dtype=np.float64) 39 | a_im = np.ascontiguousarray(s.imag, dtype=np.float64) 40 | 41 | start = time.time() 42 | fft(a_re, a_im, "f") 43 | phastft_elapsed = round((time.time() - start) * 10**6) 44 | print(f"PhastFT completed in {phastft_elapsed} us") 45 | 46 | a = s.copy() 47 | 48 | start = time.time() 49 | expected = np.fft.fft(a) 50 | numpy_elapsed = round((time.time() - start) * 10**6) 51 | print(f"NumPy fft completed in {numpy_elapsed} us") 52 | 53 | actual = np.asarray( 54 | [ 55 | complex(z_re, z_im) 56 | for (z_re, z_im) in zip( 57 | a_re, 58 | a_im, 59 | ) 60 | ] 61 | ) 62 | np.testing.assert_allclose(actual, expected, rtol=1e-3, atol=0) 63 | 64 | arr = s.copy() 65 | a = 
def read_csv_to_dict(file_path: str) -> dict:
    """Load the benchmark CSV into a column-oriented dict.

    The result maps each of the four column names to a list with one entry per
    CSV row. ``n`` is always parsed as an int; a timing cell is parsed as an
    int when non-empty and stored as ``None`` when empty.
    """
    timing_columns = ("phastft_time", "numpy_fft_time", "pyfftw_fft_time")

    columns: dict[str, list] = {"n": []}
    for name in timing_columns:
        columns[name] = []

    with open(file_path, newline="", encoding="utf-8") as handle:
        for record in csv.DictReader(handle):
            columns["n"].append(int(record["n"]))
            for name in timing_columns:
                cell = record[name]
                columns[name].append(int(cell) if cell else None)

    return columns
def grouped_bar_plot(data: dict, start=0, end=1):
    """Save a grouped bar chart of timings relative to NumPy's FFT.

    Every library's timing is divided by the NumPy timing for the same input
    size, so NumPy is a constant bar of height 1. Only rows ``start`` up to
    (not including) ``end`` are drawn, and the figure is written to
    ``py_benchmarks_bar_plot_{start}_{end-1}.png``.
    """
    labels = [bytes2human(2**n * (128 / 8)) for n in data["n"]]

    numpy_times = np.asarray(data["numpy_fft_time"])
    pyfftw_times = np.asarray(data["pyfftw_fft_time"])
    phastft_times = np.asarray(data["phastft_time"])

    ratios = pd.DataFrame(
        {
            "NumPy fft": np.ones(len(labels)),
            "pyFFTW": pyfftw_times / numpy_times,
            "pyPhastFT": phastft_times / numpy_times,
        },
        index=labels,
    )

    ratios[start:end].plot(
        kind="bar", linewidth=2, rot=0, title="pyPhastFT vs. pyFFTW vs. NumPy FFT"
    )
    plt.xticks(fontsize=8, rotation=-45)
    plt.xlabel("size of input")
    plt.ylabel("Execution Time Ratio\n(relative to NumPy FFT)")
    plt.legend(loc="best")
    plt.tight_layout()
    plt.savefig(f"py_benchmarks_bar_plot_{start}_{end-1}.png", dpi=600)
def benchmark_with_stabilization(executable_name, n, max_iters, std_threshold):
    """Run one example binary repeatedly until its timings settle.

    The binary is invoked with ``n`` as its only argument and its stdout is
    parsed as an integer timing. Once more than 10 samples exist, sampling
    stops as soon as the standard deviation divided by the mean drops below
    ``std_threshold``; otherwise it stops after ``max_iters`` runs.

    Returns:
        All collected timings, in run order.
    """
    command = f"../target/release/examples/{executable_name} {n}"
    samples = []

    for _ in range(max_iters):
        samples.append(int(run_command(command)))

        # Too few samples give a meaningless spread; keep collecting.
        if len(samples) <= 10:
            continue

        relative_spread = np.std(samples) / np.mean(samples)
        if relative_spread < std_threshold:
            break

    return samples
def read_benchmark_results(output_dir, start, end):
    """Collect the best (minimum) timing per input size from ``output_dir``.

    For each exponent ``n`` in ``start..end`` (inclusive) the file
    ``size_{n}`` is read as one integer per line. Missing or empty files are
    logged as warnings and skipped.

    Returns:
        ``(sizes, times)``: parallel lists holding ``2**n`` and the minimum
        timing for every exponent that had data.
    """
    sizes = []
    times = []

    for n in range(start, end + 1):
        size_file = Path(output_dir) / f"size_{n}"

        if not size_file.exists():
            logging.warning(f"File does not exist: {size_file}")
            continue

        with open(size_file, "r") as f:
            samples = [int(line.strip()) for line in f.readlines()]

        if not samples:
            logging.warning(f"No data found in file: {size_file}")
            continue

        sizes.append(2**n)
        times.append(min(samples))

    return sizes, times
def compare_results(current_dir, previous_dir, start, end):
    """Compare the current best timings with those of an archived run.

    For each exponent ``n`` in ``start..end`` (inclusive) the minimum of
    ``current_dir/size_{n}`` is compared with the minimum of
    ``previous_dir/benchmark_output/phastft/size_{n}``.

    Returns:
        A dict mapping ``n`` to the percentage change relative to the previous
        minimum (negative means faster). Sizes whose minima are equal, or
        whose files are missing or empty, are left out; the latter two cases
        are logged as warnings.
    """
    previous_root = Path(previous_dir) / "benchmark_output" / "phastft"
    changes = {}

    for n in range(start, end + 1):
        current_file = Path(current_dir) / f"size_{n}"
        previous_file = previous_root / f"size_{n}"

        if not (current_file.exists() and previous_file.exists()):
            logging.warning(
                f"Missing files for size 2^{n}: Current file exists: {current_file.exists()}, Previous file exists: {previous_file.exists()}"
            )
            continue

        with open(current_file, "r") as cf, open(previous_file, "r") as pf:
            current_data = [int(line.strip()) for line in cf.readlines()]
            previous_data = [int(line.strip()) for line in pf.readlines()]

        if not (current_data and previous_data):
            logging.warning(
                f"Data missing in files for size 2^{n}: Current data length: {len(current_data)}, Previous data length: {len(previous_data)}"
            )
            continue

        current_min = min(current_data)
        previous_min = min(previous_data)
        if current_min != previous_min:
            changes[n] = ((current_min - previous_min) / previous_min) * 100

    return changes
def main():
    """Build, benchmark, compare with the latest archived run, plot, archive."""
    clean_build_rust()

    # Snapshot the archived runs before this run produces any new output.
    history_root = Path(HISTORY_DIR)
    history_dirs = []
    if history_root.exists():
        history_dirs = sorted(history_root.iterdir(), key=os.path.getmtime)
    latest_previous_dir = history_dirs[-1] if history_dirs else None

    # (display name, output sub-directory, example executable)
    suites = (
        ("PhastFT", "phastft", "benchmark"),
        ("RustFFT", "rustfft", "rustfft"),
        ("FFTW3 Rust bindings", "fftwrb", "fftwrb"),
    )
    for label, subdir, executable in suites:
        benchmark(label, subdir, START, END, MAX_ITERS, STD_THRESHOLD, executable)

    # Only PhastFT is compared against history.
    if not latest_previous_dir:
        logging.info("No previous results found for comparison.")
    else:
        logging.info(f"Comparing with previous results from: {latest_previous_dir}")
        changes = compare_results(
            Path(OUTPUT_DIR) / "phastft", latest_previous_dir, START, END
        )
        for n, change in changes.items():
            status = "improvement" if change < 0 else "regression"
            logging.info(f"N = 2^{n}: {abs(change):.2f}% {status}")

    plot_benchmark_results(["phastft", "rustfft", "fftwrb"], START, END, history_dirs)

    archive_current_results()
def find_directory(pattern="benchmark-data"):
    """Return the newest timestamped benchmark directory in the working directory.

    A candidate is a sub-directory of the current working directory whose name
    contains ``pattern`` and a ``YYYY.MM.DD.HH-MM-SS`` timestamp. Candidates
    are ordered by that embedded timestamp, not by name or modification time.

    Returns:
        The absolute path of the most recent candidate, or ``None`` when there
        is none.
    """
    cwd = os.getcwd()
    stamp_re = re.compile(r"\d{4}\.\d{2}\.\d{2}\.\d{2}-\d{2}-\d{2}")

    candidates = []
    for entry in os.listdir(cwd):
        if not os.path.isdir(os.path.join(cwd, entry)):
            continue
        if pattern in entry and stamp_re.search(entry):
            candidates.append(entry)

    if not candidates:
        return None

    def embedded_timestamp(name):
        return datetime.strptime(stamp_re.search(name).group(), "%Y.%m.%d.%H-%M-%S")

    # A stable sort and taking the last element keeps the original tie-breaking.
    newest = sorted(candidates, key=embedded_timestamp)[-1]
    return os.path.join(cwd, newest)
-------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | target: 90% 6 | 7 | patch: false 8 | changes: false -------------------------------------------------------------------------------- /examples/benchmark.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | use std::str::FromStr; 3 | 4 | use utilities::gen_random_signal; 5 | 6 | use phastft::fft_64_with_opts_and_plan; 7 | use phastft::options::Options; 8 | use phastft::planner::{Direction, Planner64}; 9 | 10 | fn benchmark_fft_64(n: usize) { 11 | let big_n = 1 << n; 12 | let mut reals = vec![0.0; big_n]; 13 | let mut imags = vec![0.0; big_n]; 14 | gen_random_signal(&mut reals, &mut imags); 15 | 16 | let planner = Planner64::new(reals.len(), Direction::Forward); 17 | let opts = Options::guess_options(reals.len()); 18 | 19 | let now = std::time::Instant::now(); 20 | fft_64_with_opts_and_plan(&mut reals, &mut imags, &opts, &planner); 21 | let elapsed = now.elapsed().as_nanos(); 22 | println!("{elapsed}"); 23 | } 24 | 25 | fn main() { 26 | let args: Vec = env::args().collect(); 27 | assert_eq!(args.len(), 2, "Usage {} ", args[0]); 28 | 29 | let n = usize::from_str(&args[1]).unwrap(); 30 | 31 | benchmark_fft_64(n); 32 | } 33 | -------------------------------------------------------------------------------- /examples/fftwrb.rs: -------------------------------------------------------------------------------- 1 | use std::{env, ptr::slice_from_raw_parts_mut, str::FromStr}; 2 | 3 | use fftw::{ 4 | array::AlignedVec, 5 | plan::{C2CPlan, C2CPlan64}, 6 | types::{Flag, Sign}, 7 | }; 8 | use utilities::{gen_random_signal, rustfft::num_complex::Complex}; 9 | 10 | fn benchmark_fftw(n: usize) { 11 | let big_n = 1 << n; 12 | 13 | let mut reals = vec![0.0; big_n]; 14 | let mut imags = vec![0.0; big_n]; 15 | 16 | gen_random_signal(&mut reals, &mut imags); 17 | let mut nums = AlignedVec::new(big_n); 
18 | reals 19 | .drain(..) 20 | .zip(imags.drain(..)) 21 | .zip(nums.iter_mut()) 22 | .for_each(|((re, im), z)| *z = Complex::new(re, im)); 23 | 24 | let now = std::time::Instant::now(); 25 | C2CPlan64::aligned( 26 | &[big_n], 27 | Sign::Backward, 28 | Flag::DESTROYINPUT | Flag::ESTIMATE, 29 | ) 30 | .unwrap() 31 | .c2c( 32 | // SAFETY: See above comment. 33 | unsafe { &mut *slice_from_raw_parts_mut(nums.as_mut_ptr(), big_n) }, 34 | &mut nums, 35 | ) 36 | .unwrap(); 37 | let elapsed = now.elapsed().as_nanos(); 38 | println!("{elapsed}"); 39 | } 40 | 41 | fn main() { 42 | let args: Vec = env::args().collect(); 43 | assert_eq!(args.len(), 2, "Usage {} ", args[0]); 44 | 45 | let n = usize::from_str(&args[1]).unwrap(); 46 | benchmark_fftw(n); 47 | } 48 | -------------------------------------------------------------------------------- /examples/profile.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | use std::str::FromStr; 3 | 4 | use utilities::gen_random_signal; 5 | 6 | use phastft::fft_64_with_opts_and_plan; 7 | use phastft::options::Options; 8 | use phastft::planner::{Direction, Planner64}; 9 | 10 | fn benchmark_fft_64(n: usize) { 11 | let big_n = 1 << n; 12 | let mut reals = vec![0.0; big_n]; 13 | let mut imags = vec![0.0; big_n]; 14 | gen_random_signal(&mut reals, &mut imags); 15 | 16 | let planner = Planner64::new(reals.len(), Direction::Forward); 17 | let opts = Options::guess_options(reals.len()); 18 | 19 | fft_64_with_opts_and_plan(&mut reals, &mut imags, &opts, &planner); 20 | } 21 | 22 | fn main() { 23 | let args: Vec = env::args().collect(); 24 | assert_eq!(args.len(), 2, "Usage {} ", args[0]); 25 | 26 | let n = usize::from_str(&args[1]).unwrap(); 27 | 28 | benchmark_fft_64(n); 29 | } 30 | -------------------------------------------------------------------------------- /examples/rustfft.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | use 
std::str::FromStr; 3 | 4 | use utilities::{ 5 | gen_random_signal, 6 | rustfft::{num_complex::Complex64, FftPlanner}, 7 | }; 8 | 9 | fn benchmark_rustfft(n: usize) { 10 | let big_n = 1 << n; 11 | 12 | let mut reals = vec![0.0f64; big_n]; 13 | let mut imags = vec![0.0f64; big_n]; 14 | 15 | gen_random_signal(&mut reals, &mut imags); 16 | let mut signal = vec![Complex64::default(); big_n]; 17 | reals 18 | .drain(..) 19 | .zip(imags.drain(..)) 20 | .zip(signal.iter_mut()) 21 | .for_each(|((re, im), z)| { 22 | z.re = re; 23 | z.im = im; 24 | }); 25 | 26 | let mut planner = FftPlanner::new(); 27 | let fft = planner.plan_fft_forward(signal.len()); 28 | 29 | let now = std::time::Instant::now(); 30 | fft.process(&mut signal); 31 | let elapsed = now.elapsed().as_nanos(); 32 | println!("{elapsed}"); 33 | } 34 | 35 | fn main() { 36 | let args: Vec = env::args().collect(); 37 | assert_eq!(args.len(), 2, "Usage {} ", args[0]); 38 | 39 | let n = usize::from_str(&args[1]).unwrap(); 40 | benchmark_rustfft(n); 41 | } 42 | -------------------------------------------------------------------------------- /hooks/pre-commit: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -eu 4 | 5 | if ! cargo fmt --all -- --check 6 | then 7 | echo "There are some code style issues." 8 | echo "Run cargo fmt first." 9 | exit 1 10 | fi 11 | 12 | if ! cargo clippy --all-targets --all-features --tests -- -D warnings 13 | then 14 | echo "There are some clippy issues." 15 | exit 1 16 | fi 17 | 18 | if ! cargo test --all-features 19 | then 20 | echo "There are some test issues." 
21 | exit 1 22 | fi 23 | 24 | exit 0 25 | -------------------------------------------------------------------------------- /profile.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -Eeuo pipefail 4 | 5 | if [[ "$#" -ne 1 ]] 6 | then 7 | echo "Usage: $0 " 8 | exit 1 9 | fi 10 | 11 | cargo +nightly build --profile profiling --example profile 12 | 13 | sudo perf record --call-graph=dwarf ./target/profiling/examples/profile $1 && sudo perf script -f -F +pid > processed_result.perf 14 | 15 | echo "done! results in process_result.perf" 16 | -------------------------------------------------------------------------------- /pyphastft/.github/workflows/CI.yml: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by maturin v1.4.0 2 | # To update, run 3 | # 4 | # maturin generate-ci github 5 | # 6 | name: CI 7 | 8 | on: 9 | push: 10 | branches: 11 | - main 12 | - master 13 | tags: 14 | - '*' 15 | pull_request: 16 | workflow_dispatch: 17 | 18 | permissions: 19 | contents: read 20 | 21 | jobs: 22 | linux: 23 | runs-on: ubuntu-latest 24 | strategy: 25 | matrix: 26 | target: [x86_64, x86, aarch64, armv7, s390x, ppc64le] 27 | steps: 28 | - uses: actions/checkout@v3 29 | - uses: actions/setup-python@v4 30 | with: 31 | python-version: '3.10' 32 | - name: Build wheels 33 | uses: PyO3/maturin-action@v1 34 | with: 35 | target: ${{ matrix.target }} 36 | args: --release --out dist --find-interpreter 37 | sccache: 'true' 38 | manylinux: auto 39 | - name: Upload wheels 40 | uses: actions/upload-artifact@v3 41 | with: 42 | name: wheels 43 | path: dist 44 | 45 | windows: 46 | runs-on: windows-latest 47 | strategy: 48 | matrix: 49 | target: [x64, x86] 50 | steps: 51 | - uses: actions/checkout@v3 52 | - uses: actions/setup-python@v4 53 | with: 54 | python-version: '3.10' 55 | architecture: ${{ matrix.target }} 56 | - name: Build wheels 57 | uses: 
PyO3/maturin-action@v1 58 | with: 59 | target: ${{ matrix.target }} 60 | args: --release --out dist --find-interpreter 61 | sccache: 'true' 62 | - name: Upload wheels 63 | uses: actions/upload-artifact@v3 64 | with: 65 | name: wheels 66 | path: dist 67 | 68 | macos: 69 | runs-on: macos-latest 70 | strategy: 71 | matrix: 72 | target: [x86_64, aarch64] 73 | steps: 74 | - uses: actions/checkout@v3 75 | - uses: actions/setup-python@v4 76 | with: 77 | python-version: '3.10' 78 | - name: Build wheels 79 | uses: PyO3/maturin-action@v1 80 | with: 81 | target: ${{ matrix.target }} 82 | args: --release --out dist --find-interpreter 83 | sccache: 'true' 84 | - name: Upload wheels 85 | uses: actions/upload-artifact@v3 86 | with: 87 | name: wheels 88 | path: dist 89 | 90 | sdist: 91 | runs-on: ubuntu-latest 92 | steps: 93 | - uses: actions/checkout@v3 94 | - name: Build sdist 95 | uses: PyO3/maturin-action@v1 96 | with: 97 | command: sdist 98 | args: --out dist 99 | - name: Upload sdist 100 | uses: actions/upload-artifact@v3 101 | with: 102 | name: wheels 103 | path: dist 104 | 105 | release: 106 | name: Release 107 | runs-on: ubuntu-latest 108 | if: "startsWith(github.ref, 'refs/tags/')" 109 | needs: [linux, windows, macos, sdist] 110 | steps: 111 | - uses: actions/download-artifact@v3 112 | with: 113 | name: wheels 114 | - name: Publish to PyPI 115 | uses: PyO3/maturin-action@v1 116 | env: 117 | MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} 118 | with: 119 | command: upload 120 | args: --non-interactive --skip-existing * 121 | -------------------------------------------------------------------------------- /pyphastft/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | .pytest_cache/ 6 | *.py[cod] 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | .venv/ 14 | env/ 15 | bin/ 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 
| eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | include/ 26 | man/ 27 | venv/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | pip-selfcheck.json 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | 45 | # Translations 46 | *.mo 47 | 48 | # Mr Developer 49 | .mr.developer.cfg 50 | .project 51 | .pydevproject 52 | 53 | # Rope 54 | .ropeproject 55 | 56 | # Django stuff: 57 | *.log 58 | *.pot 59 | 60 | .DS_Store 61 | 62 | # Sphinx documentation 63 | docs/_build/ 64 | 65 | # PyCharm 66 | .idea/ 67 | 68 | # VSCode 69 | .vscode/ 70 | 71 | # Pyenv 72 | .python-version 73 | -------------------------------------------------------------------------------- /pyphastft/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "pyphastft" 3 | version = "0.2.1" 4 | edition = "2021" 5 | authors = ["Saveliy Yusufov", "Shnatsel"] 6 | license = "MIT OR Apache-2.0" 7 | 8 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 9 | [lib] 10 | name = "pyphastft" 11 | crate-type = ["cdylib"] 12 | 13 | [dependencies] 14 | pyo3 = { version = "0.21.2" } 15 | numpy = "0.21.0" 16 | phastft = { path = ".." 
} 17 | -------------------------------------------------------------------------------- /pyphastft/example.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | from numpy.fft import fft 4 | 5 | # from pyphastft import fft 6 | 7 | 8 | def main(): 9 | fs = 100 # Sampling frequency (100 samples/second for this synthetic example) 10 | t_max = 6 # maximum time in "seconds" 11 | 12 | # Find the next lower power of 2 for the number of samples 13 | n_samples = 2 ** int(np.log2(t_max * fs)) 14 | 15 | t = np.linspace( 16 | 0, n_samples / fs, n_samples, endpoint=False 17 | ) # Adjusted time vector 18 | 19 | # Generate the signal 20 | s_re = 2 * np.sin(2 * np.pi * t + 1) + np.sin(2 * np.pi * 10 * t + 1) 21 | s_im = np.ascontiguousarray([0.0] * len(s_re), dtype=np.float64) 22 | 23 | # Plot the original signal 24 | plt.figure(figsize=(10, 7)) 25 | 26 | plt.subplot(2, 1, 1) 27 | plt.plot(t, s_re, label="f(x) = 2sin(x) + sin(10x)") 28 | plt.title("signal: f(x) = 2sin(x) + sin(10x)") 29 | plt.xlabel("time [seconds]") 30 | plt.ylabel("f(x)") 31 | plt.legend() 32 | 33 | # Perform FFT 34 | s_re = fft(s_re) 35 | 36 | # Plot the magnitude spectrum of the FFT result 37 | plt.subplot(2, 1, 2) 38 | plt.plot( 39 | np.abs(s_re), 40 | label="frequency spectrum", 41 | ) 42 | plt.title("Signal after FFT") 43 | plt.xlabel("frequency (in Hz)") 44 | plt.ylabel("|FFT(f(x))|") 45 | 46 | # only show up to 11 Hz as in the wiki example 47 | plt.xlim(0, 11) 48 | 49 | plt.legend() 50 | plt.tight_layout() 51 | plt.savefig("wiki_fft_example.png", dpi=600) 52 | 53 | 54 | if __name__ == "__main__": 55 | main() 56 | -------------------------------------------------------------------------------- /pyphastft/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["maturin>=1.4,<2.0"] 3 | build-backend = "maturin" 4 | 5 | [project] 6 | name = 
"pyphastft" 7 | requires-python = ">=3.8" 8 | classifiers = [ 9 | "Programming Language :: Rust", 10 | "Programming Language :: Python :: Implementation :: CPython", 11 | "Programming Language :: Python :: Implementation :: PyPy", 12 | ] 13 | dynamic = ["version"] 14 | 15 | [tool.maturin] 16 | features = ["pyo3/extension-module"] 17 | -------------------------------------------------------------------------------- /pyphastft/src/lib.rs: -------------------------------------------------------------------------------- 1 | use numpy::PyReadwriteArray1; 2 | use phastft::{fft_64 as fft_64_rs, planner::Direction}; 3 | use pyo3::prelude::*; 4 | 5 | #[pyfunction] 6 | fn fft(mut reals: PyReadwriteArray1, mut imags: PyReadwriteArray1, direction: char) { 7 | assert!(direction == 'f' || direction == 'r'); 8 | let dir = if direction == 'f' { 9 | Direction::Forward 10 | } else { 11 | Direction::Reverse 12 | }; 13 | 14 | fft_64_rs( 15 | reals.as_slice_mut().unwrap(), 16 | imags.as_slice_mut().unwrap(), 17 | dir, 18 | ); 19 | } 20 | 21 | /// A Python module implemented in Rust. 
22 | #[pymodule] 23 | fn pyphastft(_py: Python, m: &PyModule) -> PyResult<()> { 24 | m.add_function(wrap_pyfunction!(fft, m)?)?; 25 | Ok(()) 26 | } 27 | -------------------------------------------------------------------------------- /pyphastft/vis_qt.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import numpy as np 4 | import pyaudio 5 | import pyqtgraph as pg 6 | from pyphastft import fft 7 | from pyqtgraph.Qt import QtWidgets, QtCore 8 | 9 | 10 | class RealTimeAudioSpectrum(QtWidgets.QWidget): 11 | def __init__(self, parent=None): 12 | super(RealTimeAudioSpectrum, self).__init__(parent) 13 | self.n_fft_bins = 1024 # Increased FFT size for better frequency resolution 14 | self.n_display_bins = 32 # Maintain the same number of bars in the display 15 | self.sample_rate = 44100 16 | self.smoothing_factor = 0.1 # Smaller value for more smoothing 17 | self.ema_fft_data = np.zeros( 18 | self.n_display_bins 19 | ) # Adjusted to the number of display bins 20 | self.init_ui() 21 | self.init_audio_stream() 22 | 23 | def init_ui(self): 24 | self.layout = QtWidgets.QVBoxLayout(self) 25 | self.plot_widget = pg.PlotWidget() 26 | self.layout.addWidget(self.plot_widget) 27 | 28 | # Customize plot aesthetics 29 | self.plot_widget.setBackground("k") 30 | self.plot_item = self.plot_widget.getPlotItem() 31 | self.plot_item.setTitle( 32 | "Real-Time Audio Spectrum Visualizer powered by PhastFT", 33 | color="w", 34 | size="16pt", 35 | ) 36 | 37 | # Hide axis labels 38 | self.plot_item.getAxis("left").hide() 39 | self.plot_item.getAxis("bottom").hide() 40 | 41 | # Set fixed ranges for the x and y axes to prevent them from jumping 42 | self.plot_item.setXRange(0, self.sample_rate / 2, padding=0) 43 | self.plot_item.setYRange(0, 1, padding=0) 44 | 45 | self.bar_width = ( 46 | (self.sample_rate / 2) / self.n_display_bins * 0.90 47 | ) # Adjusted width for display bins 48 | 49 | # Calculate bar positions so they are centered with 
respect to their frequency values 50 | self.freqs = np.linspace( 51 | 0 + self.bar_width / 2, 52 | self.sample_rate / 2 - self.bar_width / 2, 53 | self.n_display_bins, 54 | ) 55 | 56 | self.bar_graph = pg.BarGraphItem( 57 | x=self.freqs, 58 | height=np.zeros(self.n_display_bins), 59 | width=self.bar_width, 60 | brush=pg.mkBrush("m"), 61 | ) 62 | self.plot_item.addItem(self.bar_graph) 63 | 64 | self.timer = QtCore.QTimer() 65 | self.timer.timeout.connect(self.update) 66 | self.timer.start(50) # Update interval in milliseconds 67 | 68 | def init_audio_stream(self): 69 | self.p = pyaudio.PyAudio() 70 | self.stream = self.p.open( 71 | format=pyaudio.paFloat32, 72 | channels=1, 73 | rate=self.sample_rate, 74 | input=True, 75 | frames_per_buffer=self.n_fft_bins, # This should match the FFT size 76 | stream_callback=self.audio_callback, 77 | ) 78 | self.stream.start_stream() 79 | 80 | def audio_callback(self, in_data, frame_count, time_info, status): 81 | audio_data = np.frombuffer(in_data, dtype=np.float32) 82 | reals = np.zeros(self.n_fft_bins) 83 | imags = np.zeros(self.n_fft_bins) 84 | reals[: len(audio_data)] = audio_data # Fill the reals array with audio data 85 | fft(reals, imags, direction="f") 86 | fft_magnitude = np.sqrt(reals**2 + imags**2)[: self.n_fft_bins // 2] 87 | 88 | # Aggregate or interpolate FFT data to fit into display bins 89 | new_fft_data = np.interp( 90 | np.linspace(0, len(fft_magnitude), self.n_display_bins), 91 | np.arange(len(fft_magnitude)), 92 | fft_magnitude, 93 | ) 94 | 95 | # Apply exponential moving average filter 96 | self.ema_fft_data = self.ema_fft_data * self.smoothing_factor + new_fft_data * ( 97 | 1.0 - self.smoothing_factor 98 | ) 99 | return in_data, pyaudio.paContinue 100 | 101 | def update(self): 102 | self.bar_graph.setOpts(height=self.ema_fft_data, width=self.bar_width) 103 | 104 | def closeEvent(self, event): 105 | self.stream.stop_stream() 106 | self.stream.close() 107 | self.p.terminate() 108 | event.accept() 109 | 110 | 
111 | if __name__ == "__main__": 112 | app = QtWidgets.QApplication(sys.argv) 113 | window = RealTimeAudioSpectrum() 114 | window.show() 115 | sys.exit(app.exec_()) 116 | -------------------------------------------------------------------------------- /rust-toolchain.toml: -------------------------------------------------------------------------------- 1 | [toolchain] 2 | channel = "nightly" 3 | -------------------------------------------------------------------------------- /scripts/twiddle_generator.py: -------------------------------------------------------------------------------- 1 | from matplotlib import cm 2 | 3 | 4 | def gen_twiddles(dist: int) -> list[complex]: 5 | theta = -np.pi / dist 6 | g = complex(np.cos(theta), np.sin(theta)) 7 | 8 | w = complex(1.0, 0.0) 9 | print(w) 10 | 11 | for k in range(1, dist): 12 | w *= g 13 | print(w) 14 | 15 | 16 | import matplotlib.pyplot as plt 17 | import numpy as np 18 | 19 | 20 | def cycle_colors(num_colors, colormap="viridis"): 21 | """ 22 | Generate a list of colors by cycling over a specified colormap. 23 | 24 | Parameters: 25 | - num_colors (int): Number of colors to generate. 26 | - colormap (str): Name of the Matplotlib colormap to use. 27 | 28 | Returns: 29 | - colors (list): List of color values in hexadecimal format. 
30 | """ 31 | # Get the specified colormap 32 | cmap = cm.get_cmap(colormap) 33 | 34 | # Generate equally spaced values from 0 to 1 35 | values = np.linspace(0, 1, num_colors) 36 | 37 | # Map the values to colors using the colormap 38 | colors = [cmap(value) for value in values] 39 | 40 | # Convert colors to hexadecimal format 41 | hex_colors = [ 42 | f"#{int(r * 255):02X}{int(g * 255):02X}{int(b * 255):02X}" 43 | for r, g, b, _ in colors 44 | ] 45 | 46 | return hex_colors 47 | 48 | 49 | def plot_roots_of_unity(dist: int): 50 | # Calculate the nth roots of unity 51 | theta = -np.pi / dist 52 | roots = [] 53 | g = complex(np.cos(theta), np.sin(theta)) 54 | w = complex(1.0, 0.0) 55 | roots = [w] 56 | 57 | for k in range(1, dist): 58 | w *= g 59 | roots.append(w) 60 | 61 | roots = np.asarray(roots) 62 | for r in roots: 63 | print(f"{np.round(r.real, 2)} {np.round(r.imag, 2)}") 64 | 65 | temp = cycle_colors(dist // 2) 66 | temp.reverse() 67 | all_colors = cycle_colors(dist // 2) + temp 68 | all_colors[0] = "r" 69 | 70 | # Plot the roots 71 | plt.figure(figsize=(6, 6)) 72 | plt.scatter(roots.real, roots.imag, color=all_colors, marker="o") 73 | plt.title(f"Roots of Unity (n={n})") 74 | plt.xlabel("Real Part") 75 | plt.ylabel("Imaginary Part") 76 | plt.axhline(0, color="black", linewidth=0.5) 77 | plt.axvline(0, color="black", linewidth=0.5) 78 | plt.grid(color="gray", linestyle="--", linewidth=0.5) 79 | plt.axis("equal") 80 | 81 | # Limit the axes to -1.0 to 1.0 82 | plt.xlim(-1.0, 1.0) 83 | plt.ylim(-1.0, 1.0) 84 | 85 | plt.show() 86 | 87 | 88 | # # Specify the number of roots (n) 89 | # n = 8 # You can change this to any positive integer 90 | # 91 | # # Plot the roots of unity for the specified value of n 92 | # plot_roots_of_unity(n) 93 | 94 | 95 | def main(): 96 | # gen_twiddles(4) 97 | print("===================\n") 98 | gen_twiddles(16) 99 | 100 | 101 | if __name__ == "__main__": 102 | main() 103 | 104 | # (1+0j) 105 | # (0.9238795325112867-0.3826834323650898j) 106 
| # (0.7071067811865475-0.7071067811865476j) 107 | # (0.38268343236508967-0.9238795325112867j) 108 | # (-1.2420623018332135e-16-0.9999999999999999j) 109 | # (-0.38268343236508984-0.9238795325112866j) 110 | # (-0.7071067811865476-0.7071067811865474j) 111 | # (-0.9238795325112867-0.38268343236508956j) 112 | -------------------------------------------------------------------------------- /src/cobra.rs: -------------------------------------------------------------------------------- 1 | //! This module provides several implementations of the bit reverse permutation, which is 2 | //! essential for algorithms like FFT. 3 | //! 4 | //! In practice, most FFT implementations avoid bit reversals; however this comes at a computational 5 | //! cost as well. For example, Bailey's 4 step FFT algorithm is O(N * lg(N) * lg(lg(N))). 6 | //! The original Cooley-Tukey implementation is O(N * lg(N)). The extra term in the 4-step algorithm 7 | //! comes from incorporating the bit reversals into each level of the recursion. By utilizing a 8 | //! cache-optimal bit reversal, we are able to avoid this extra cost [1]. 9 | //! 10 | //! # References 11 | //! 12 | //! [1] L. Carter and K. S. Gatlin, "Towards an optimal bit-reversal permutation program," Proceedings 39th Annual 13 | //! Symposium on Foundations of Computer Science (Cat. No.98CB36280), Palo Alto, CA, USA, 1998, pp. 544-553, doi: 14 | //! 10.1109/SFCS.1998.743505. 15 | //! keywords: {Read-write memory;Costs;Computer science;Drives;Random access memory;Argon;Registers;Read only memory;Computational modeling;Libraries} 16 | 17 | use num_traits::Float; 18 | 19 | const BLOCK_WIDTH: usize = 128; 20 | // size of the cacheline 21 | const LOG_BLOCK_WIDTH: usize = 7; // log2 of cacheline 22 | 23 | /// In-place bit reversal on a single buffer. Also referred to as "Jennifer's method" [1]. 
24 | /// 25 | /// ## References 26 | /// [1] 27 | #[inline] 28 | pub(crate) fn bit_rev(buf: &mut [T], log_n: usize) { 29 | let mut nodd: usize; 30 | let mut noddrev; // to hold bitwise negated or odd values 31 | 32 | let big_n = 1 << log_n; 33 | let halfn = big_n >> 1; // frequently used 'constants' 34 | let quartn = big_n >> 2; 35 | let nmin1 = big_n - 1; 36 | 37 | let mut forward = halfn; // variable initialisations 38 | let mut rev = 1; 39 | 40 | let mut i = quartn; 41 | while i > 0 { 42 | // start of bit reversed permutation loop, N/4 iterations 43 | 44 | // Gray code generator for even values: 45 | 46 | nodd = !i; // counting ones is easier 47 | 48 | let mut zeros = 0; 49 | while (nodd & 1) == 1 { 50 | nodd >>= 1; 51 | zeros += 1; 52 | } 53 | 54 | forward ^= 2 << zeros; // toggle one bit of forward 55 | rev ^= quartn >> zeros; // toggle one bit of rev 56 | 57 | // swap even and ~even conditionally 58 | if forward < rev { 59 | buf.swap(forward, rev); 60 | nodd = nmin1 ^ forward; // compute the bitwise negations 61 | noddrev = nmin1 ^ rev; 62 | buf.swap(nodd, noddrev); // swap bitwise-negated pairs 63 | } 64 | 65 | nodd = forward ^ 1; // compute the odd values from the even 66 | noddrev = rev ^ halfn; 67 | 68 | // swap odd unconditionally 69 | buf.swap(nodd, noddrev); 70 | i -= 1; 71 | } 72 | } 73 | 74 | /// Run in-place bit reversal on the entire state (i.e., the reals and imags buffers) 75 | /// 76 | /// ## References 77 | /// [1] 78 | #[allow(unused)] 79 | #[deprecated( 80 | since = "0.1.0", 81 | note = "Please use COBRA for a cache-optimal bit reverse permutation." 
82 | )] 83 | fn complex_bit_rev(reals: &mut [T], imags: &mut [T], log_n: usize) { 84 | let mut nodd: usize; 85 | let mut noddrev; // to hold bitwise negated or odd values 86 | 87 | let big_n = 1 << log_n; 88 | let halfn = big_n >> 1; // frequently used 'constants' 89 | let quartn = big_n >> 2; 90 | let nmin1 = big_n - 1; 91 | 92 | let mut forward = halfn; // variable initialisations 93 | let mut rev = 1; 94 | 95 | let mut i = quartn; 96 | while i > 0 { 97 | // start of bit-reversed permutation loop, N/4 iterations 98 | 99 | // Gray code generator for even values: 100 | 101 | nodd = !i; // counting ones is easier 102 | 103 | let mut zeros = 0; 104 | while (nodd & 1) == 1 { 105 | nodd >>= 1; 106 | zeros += 1; 107 | } 108 | 109 | forward ^= 2 << zeros; // toggle one bit of forward 110 | rev ^= quartn >> zeros; // toggle one bit of rev 111 | 112 | // swap even and ~even conditionally 113 | if forward < rev { 114 | reals.swap(forward, rev); 115 | imags.swap(forward, rev); 116 | nodd = nmin1 ^ forward; // compute the bitwise negations 117 | noddrev = nmin1 ^ rev; 118 | 119 | // swap bitwise-negated pairs 120 | reals.swap(nodd, noddrev); 121 | imags.swap(nodd, noddrev); 122 | } 123 | 124 | nodd = forward ^ 1; // compute the odd values from the even 125 | noddrev = rev ^ halfn; 126 | 127 | // swap odd unconditionally 128 | reals.swap(nodd, noddrev); 129 | imags.swap(nodd, noddrev); 130 | i -= 1; 131 | } 132 | } 133 | 134 | #[allow(dead_code)] 135 | #[deprecated( 136 | since = "0.1.0", 137 | note = "Naive bit reverse permutation is slow and not cache friendly. COBRA should be used instead." 
138 | )] 139 | pub(crate) fn bit_reverse_permutation(buf: &mut [T]) { 140 | let n = buf.len(); 141 | let mut j = 0; 142 | 143 | for i in 1..n { 144 | let mut bit = n >> 1; 145 | 146 | while (j & bit) != 0 { 147 | j ^= bit; 148 | bit >>= 1; 149 | } 150 | j ^= bit; 151 | 152 | if i < j { 153 | buf.swap(i, j); 154 | } 155 | } 156 | } 157 | 158 | /// Pure Rust implementation of Cache Optimal Bit-Reverse Algorithm (COBRA). 159 | /// Rewritten from a C++ implementation [3]. 160 | /// 161 | /// ## References 162 | /// [1] L. Carter and K. S. Gatlin, "Towards an optimal bit-reversal permutation program," Proceedings 39th Annual 163 | /// Symposium on Foundations of Computer Science (Cat. No.98CB36280), Palo Alto, CA, USA, 1998, pp. 544-553, doi: 164 | /// 10.1109/SFCS.1998.743505. 165 | /// [2] Christian Knauth, Boran Adas, Daniel Whitfield, Xuesong Wang, Lydia Ickler, Tim Conrad, Oliver Serang: 166 | /// Practically efficient methods for performing bit-reversed permutation in C++11 on the x86-64 architecture 167 | /// [3] 168 | #[multiversion::multiversion(targets("x86_64+avx512f+avx512bw+avx512cd+avx512dq+avx512vl", // x86_64-v4 169 | "x86_64+avx2+fma", // x86_64-v3 170 | "x86_64+sse4.2", // x86_64-v2 171 | "x86+avx512f+avx512bw+avx512cd+avx512dq+avx512vl", 172 | "x86+avx2+fma", 173 | "x86+sse4.2", 174 | "x86+sse2", 175 | ))] 176 | pub fn cobra_apply(v: &mut [T], log_n: usize) { 177 | if log_n <= 2 * LOG_BLOCK_WIDTH { 178 | bit_rev(v, log_n); 179 | return; 180 | } 181 | let num_b_bits = log_n - 2 * LOG_BLOCK_WIDTH; 182 | let b_size: usize = 1 << num_b_bits; 183 | let block_width: usize = 1 << LOG_BLOCK_WIDTH; 184 | 185 | let mut buffer = [T::default(); BLOCK_WIDTH * BLOCK_WIDTH]; 186 | 187 | for b in 0..b_size { 188 | let b_rev = b.reverse_bits() >> ((b_size - 1).leading_zeros()); 189 | 190 | // Copy block to buffer 191 | for a in 0..block_width { 192 | let a_rev = a.reverse_bits() >> ((block_width - 1).leading_zeros()); 193 | for c in 0..BLOCK_WIDTH { 194 | 
buffer[(a_rev << LOG_BLOCK_WIDTH) | c] = 195 | v[(a << num_b_bits << LOG_BLOCK_WIDTH) | (b << LOG_BLOCK_WIDTH) | c]; 196 | } 197 | } 198 | 199 | for c in 0..BLOCK_WIDTH { 200 | // NOTE: Typo in original pseudocode by Carter and Gatlin at the following line: 201 | let c_rev = c.reverse_bits() >> ((block_width - 1).leading_zeros()); 202 | 203 | for a_rev in 0..BLOCK_WIDTH { 204 | let a = a_rev.reverse_bits() >> ((block_width - 1).leading_zeros()); 205 | 206 | // To guarantee each value is swapped only one time: 207 | // index < reversed_index <--> 208 | // a b c < c' b' a' <--> 209 | // a < c' || 210 | // a <= c' && b < b' || 211 | // a <= c' && b <= b' && a' < c 212 | let index_less_than_reverse = a < c_rev 213 | || (a == c_rev && b < b_rev) 214 | || (a == c_rev && b == b_rev && a_rev < c); 215 | 216 | if index_less_than_reverse { 217 | let v_idx = (c_rev << num_b_bits << LOG_BLOCK_WIDTH) 218 | | (b_rev << LOG_BLOCK_WIDTH) 219 | | a_rev; 220 | let b_idx = (a_rev << LOG_BLOCK_WIDTH) | c; 221 | std::mem::swap(&mut v[v_idx], &mut buffer[b_idx]); 222 | } 223 | } 224 | } 225 | 226 | // Copy changes that were swapped into buffer above: 227 | for a in 0..BLOCK_WIDTH { 228 | let a_rev = a.reverse_bits() >> ((block_width - 1).leading_zeros()); 229 | for c in 0..BLOCK_WIDTH { 230 | let c_rev = c.reverse_bits() >> ((block_width - 1).leading_zeros()); 231 | let index_less_than_reverse = a < c_rev 232 | || (a == c_rev && b < b_rev) 233 | || (a == c_rev && b == b_rev && a_rev < c); 234 | 235 | if index_less_than_reverse { 236 | let v_idx = (a << num_b_bits << LOG_BLOCK_WIDTH) | (b << LOG_BLOCK_WIDTH) | c; 237 | let b_idx = (a_rev << LOG_BLOCK_WIDTH) | c; 238 | std::mem::swap(&mut v[v_idx], &mut buffer[b_idx]); 239 | } 240 | } 241 | } 242 | } 243 | } 244 | 245 | #[cfg(test)] 246 | mod tests { 247 | use super::*; 248 | 249 | /// Top down bit reverse interleaving. 
This is a very simple and well known approach that we only use for testing 250 | /// COBRA and any other bit reverse algorithms. 251 | fn top_down_bit_reverse_permutation(x: &[T]) -> Vec { 252 | if x.len() == 1 { 253 | return x.to_vec(); 254 | } 255 | 256 | let mut y = Vec::with_capacity(x.len()); 257 | let mut evens = Vec::with_capacity(x.len() >> 1); 258 | let mut odds = Vec::with_capacity(x.len() >> 1); 259 | 260 | let mut i = 1; 261 | while i < x.len() { 262 | evens.push(x[i - 1]); 263 | odds.push(x[i]); 264 | i += 2; 265 | } 266 | 267 | y.extend_from_slice(&top_down_bit_reverse_permutation(&evens)); 268 | y.extend_from_slice(&top_down_bit_reverse_permutation(&odds)); 269 | y 270 | } 271 | 272 | #[test] 273 | fn cobra() { 274 | for n in 4..23 { 275 | let big_n = 1 << n; 276 | let mut v: Vec<_> = (0..big_n).collect(); 277 | cobra_apply(&mut v, n); 278 | 279 | let x: Vec<_> = (0..big_n).collect(); 280 | assert_eq!(v, top_down_bit_reverse_permutation(&x)); 281 | } 282 | } 283 | 284 | #[test] 285 | fn bit_reversal() { 286 | let n = 3; 287 | let big_n = 1 << n; 288 | let mut buf: Vec = (0..big_n).map(f64::from).collect(); 289 | bit_rev(&mut buf, n); 290 | println!("{buf:?}"); 291 | assert_eq!(buf, vec![0.0, 4.0, 2.0, 6.0, 1.0, 5.0, 3.0, 7.0]); 292 | 293 | let n = 4; 294 | let big_n = 1 << n; 295 | let mut buf: Vec = (0..big_n).map(f64::from).collect(); 296 | bit_rev(&mut buf, n); 297 | println!("{buf:?}"); 298 | assert_eq!( 299 | buf, 300 | vec![ 301 | 0.0, 8.0, 4.0, 12.0, 2.0, 10.0, 6.0, 14.0, 1.0, 9.0, 5.0, 13.0, 3.0, 11.0, 7.0, 302 | 15.0, 303 | ] 304 | ); 305 | } 306 | 307 | #[test] 308 | fn jennifer_method() { 309 | for n in 2..24 { 310 | let big_n = 1 << n; 311 | let mut actual_re: Vec = (0..big_n).map(f64::from).collect(); 312 | let mut actual_im: Vec = (0..big_n).map(f64::from).collect(); 313 | 314 | #[allow(deprecated)] 315 | complex_bit_rev(&mut actual_re, &mut actual_im, n); 316 | 317 | let input_re: Vec = (0..big_n).map(f64::from).collect(); 318 | let 
expected_re = top_down_bit_reverse_permutation(&input_re); 319 | assert_eq!(actual_re, expected_re); 320 | 321 | let input_im: Vec = (0..big_n).map(f64::from).collect(); 322 | let expected_im = top_down_bit_reverse_permutation(&input_im); 323 | assert_eq!(actual_im, expected_im); 324 | } 325 | } 326 | 327 | #[test] 328 | fn naive_bit_reverse_permutation() { 329 | for n in 2..24 { 330 | let big_n = 1 << n; 331 | let mut actual_re: Vec = (0..big_n).map(f64::from).collect(); 332 | let mut actual_im: Vec = (0..big_n).map(f64::from).collect(); 333 | 334 | #[allow(deprecated)] 335 | bit_reverse_permutation(&mut actual_re); 336 | 337 | #[allow(deprecated)] 338 | bit_reverse_permutation(&mut actual_im); 339 | 340 | let input_re: Vec = (0..big_n).map(f64::from).collect(); 341 | let expected_re = top_down_bit_reverse_permutation(&input_re); 342 | assert_eq!(actual_re, expected_re); 343 | 344 | let input_im: Vec = (0..big_n).map(f64::from).collect(); 345 | let expected_im = top_down_bit_reverse_permutation(&input_im); 346 | assert_eq!(actual_im, expected_im); 347 | } 348 | } 349 | } 350 | -------------------------------------------------------------------------------- /src/kernels.rs: -------------------------------------------------------------------------------- 1 | use std::simd::{f32x16, f64x8}; 2 | 3 | use num_traits::Float; 4 | 5 | macro_rules! 
fft_butterfly_n_simd { 6 | ($func_name:ident, $precision:ty, $lanes:literal, $simd_vector:ty) => { 7 | #[multiversion::multiversion(targets("x86_64+avx512f+avx512bw+avx512cd+avx512dq+avx512vl", // x86_64-v4 8 | "x86_64+avx2+fma", // x86_64-v3 9 | "x86_64+sse4.2", // x86_64-v2 10 | "x86+avx512f+avx512bw+avx512cd+avx512dq+avx512vl", 11 | "x86+avx2+fma", 12 | "x86+sse4.2", 13 | "x86+sse2", 14 | ))] 15 | #[inline] 16 | pub fn $func_name( 17 | reals: &mut [$precision], 18 | imags: &mut [$precision], 19 | twiddles_re: &[$precision], 20 | twiddles_im: &[$precision], 21 | dist: usize, 22 | ) { 23 | let chunk_size = dist << 1; 24 | assert!(chunk_size >= $lanes * 2); 25 | reals 26 | .chunks_exact_mut(chunk_size) 27 | .zip(imags.chunks_exact_mut(chunk_size)) 28 | .for_each(|(reals_chunk, imags_chunk)| { 29 | let (reals_s0, reals_s1) = reals_chunk.split_at_mut(dist); 30 | let (imags_s0, imags_s1) = imags_chunk.split_at_mut(dist); 31 | 32 | reals_s0 33 | .chunks_exact_mut($lanes) 34 | .zip(reals_s1.chunks_exact_mut($lanes)) 35 | .zip(imags_s0.chunks_exact_mut($lanes)) 36 | .zip(imags_s1.chunks_exact_mut($lanes)) 37 | .zip(twiddles_re.chunks_exact($lanes)) 38 | .zip(twiddles_im.chunks_exact($lanes)) 39 | .for_each(|(((((re_s0, re_s1), im_s0), im_s1), w_re), w_im)| { 40 | let real_c0 = <$simd_vector>::from_slice(re_s0); 41 | let real_c1 = <$simd_vector>::from_slice(re_s1); 42 | let imag_c0 = <$simd_vector>::from_slice(im_s0); 43 | let imag_c1 = <$simd_vector>::from_slice(im_s1); 44 | 45 | let tw_re = <$simd_vector>::from_slice(w_re); 46 | let tw_im = <$simd_vector>::from_slice(w_im); 47 | 48 | re_s0.copy_from_slice((real_c0 + real_c1).as_array()); 49 | im_s0.copy_from_slice((imag_c0 + imag_c1).as_array()); 50 | let v_re = real_c0 - real_c1; 51 | let v_im = imag_c0 - imag_c1; 52 | re_s1.copy_from_slice((v_re * tw_re - v_im * tw_im).as_array()); 53 | im_s1.copy_from_slice((v_re * tw_im + v_im * tw_re).as_array()); 54 | }); 55 | }); 56 | } 57 | }; 58 | } 59 | 60 | 
fft_butterfly_n_simd!(fft_64_chunk_n_simd, f64, 8, f64x8); 61 | fft_butterfly_n_simd!(fft_32_chunk_n_simd, f32, 16, f32x16); 62 | 63 | #[multiversion::multiversion(targets("x86_64+avx512f+avx512bw+avx512cd+avx512dq+avx512vl", // x86_64-v4 64 | "x86_64+avx2+fma", // x86_64-v3 65 | "x86_64+sse4.2", // x86_64-v2 66 | "x86+avx512f+avx512bw+avx512cd+avx512dq+avx512vl", 67 | "x86+avx2+fma", 68 | "x86+sse4.2", 69 | "x86+sse2", 70 | ))] 71 | #[inline] 72 | pub(crate) fn fft_chunk_n( 73 | reals: &mut [T], 74 | imags: &mut [T], 75 | twiddles_re: &[T], 76 | twiddles_im: &[T], 77 | dist: usize, 78 | ) { 79 | let chunk_size = dist << 1; 80 | 81 | reals 82 | .chunks_exact_mut(chunk_size) 83 | .zip(imags.chunks_exact_mut(chunk_size)) 84 | .for_each(|(reals_chunk, imags_chunk)| { 85 | let (reals_s0, reals_s1) = reals_chunk.split_at_mut(dist); 86 | let (imags_s0, imags_s1) = imags_chunk.split_at_mut(dist); 87 | 88 | reals_s0 89 | .iter_mut() 90 | .zip(reals_s1.iter_mut()) 91 | .zip(imags_s0.iter_mut()) 92 | .zip(imags_s1.iter_mut()) 93 | .zip(twiddles_re.iter()) 94 | .zip(twiddles_im.iter()) 95 | .for_each(|(((((re_s0, re_s1), im_s0), im_s1), w_re), w_im)| { 96 | let real_c0 = *re_s0; 97 | let real_c1 = *re_s1; 98 | let imag_c0 = *im_s0; 99 | let imag_c1 = *im_s1; 100 | 101 | *re_s0 = real_c0 + real_c1; 102 | *im_s0 = imag_c0 + imag_c1; 103 | let v_re = real_c0 - real_c1; 104 | let v_im = imag_c0 - imag_c1; 105 | *re_s1 = v_re * *w_re - v_im * *w_im; 106 | *im_s1 = v_re * *w_im + v_im * *w_re; 107 | }); 108 | }); 109 | } 110 | 111 | /// `chunk_size == 4`, so hard code twiddle factors 112 | #[multiversion::multiversion(targets("x86_64+avx512f+avx512bw+avx512cd+avx512dq+avx512vl", // x86_64-v4 113 | "x86_64+avx2+fma", // x86_64-v3 114 | "x86_64+sse4.2", // x86_64-v2 115 | "x86+avx512f+avx512bw+avx512cd+avx512dq+avx512vl", 116 | "x86+avx2+fma", 117 | "x86+sse4.2", 118 | "x86+sse2", 119 | ))] 120 | #[inline] 121 | pub(crate) fn fft_chunk_4(reals: &mut [T], imags: &mut [T]) { 122 | 
const DIST: usize = 2; 123 | const CHUNK_SIZE: usize = DIST << 1; 124 | 125 | reals 126 | .chunks_exact_mut(CHUNK_SIZE) 127 | .zip(imags.chunks_exact_mut(CHUNK_SIZE)) 128 | .for_each(|(reals_chunk, imags_chunk)| { 129 | let (reals_s0, reals_s1) = reals_chunk.split_at_mut(DIST); 130 | let (imags_s0, imags_s1) = imags_chunk.split_at_mut(DIST); 131 | 132 | let real_c0 = reals_s0[0]; 133 | let real_c1 = reals_s1[0]; 134 | let imag_c0 = imags_s0[0]; 135 | let imag_c1 = imags_s1[0]; 136 | 137 | reals_s0[0] = real_c0 + real_c1; 138 | imags_s0[0] = imag_c0 + imag_c1; 139 | reals_s1[0] = real_c0 - real_c1; 140 | imags_s1[0] = imag_c0 - imag_c1; 141 | 142 | let real_c0 = reals_s0[1]; 143 | let real_c1 = reals_s1[1]; 144 | let imag_c0 = imags_s0[1]; 145 | let imag_c1 = imags_s1[1]; 146 | 147 | reals_s0[1] = real_c0 + real_c1; 148 | imags_s0[1] = imag_c0 + imag_c1; 149 | reals_s1[1] = imag_c0 - imag_c1; 150 | imags_s1[1] = -(real_c0 - real_c1); 151 | }); 152 | } 153 | 154 | /// `chunk_size == 2`, so skip phase 155 | #[multiversion::multiversion(targets("x86_64+avx512f+avx512bw+avx512cd+avx512dq+avx512vl", // x86_64-v4 156 | "x86_64+avx2+fma", // x86_64-v3 157 | "x86_64+sse4.2", // x86_64-v2 158 | "x86+avx512f+avx512bw+avx512cd+avx512dq+avx512vl", 159 | "x86+avx2+fma", 160 | "x86+sse4.2", 161 | "x86+sse2", 162 | ))] 163 | #[inline] 164 | pub(crate) fn fft_chunk_2(reals: &mut [T], imags: &mut [T]) { 165 | reals 166 | .chunks_exact_mut(2) 167 | .zip(imags.chunks_exact_mut(2)) 168 | .for_each(|(reals_chunk, imags_chunk)| { 169 | let z0_re = reals_chunk[0]; 170 | let z0_im = imags_chunk[0]; 171 | let z1_re = reals_chunk[1]; 172 | let z1_im = imags_chunk[1]; 173 | 174 | reals_chunk[0] = z0_re + z1_re; 175 | imags_chunk[0] = z0_im + z1_im; 176 | reals_chunk[1] = z0_re - z1_re; 177 | imags_chunk[1] = z0_im - z1_im; 178 | }); 179 | } 180 | -------------------------------------------------------------------------------- /src/lib.rs: 
-------------------------------------------------------------------------------- 1 | #![doc = include_str!("../README.md")] 2 | #![warn( 3 | missing_docs, 4 | clippy::complexity, 5 | clippy::perf, 6 | clippy::style, 7 | clippy::correctness, 8 | clippy::suspicious 9 | )] 10 | #![forbid(unsafe_code)] 11 | #![feature(portable_simd, avx512_target_feature)] 12 | #![feature(doc_cfg)] 13 | 14 | #[cfg(feature = "complex-nums")] 15 | use crate::utils::{combine_re_im, deinterleave_complex32, deinterleave_complex64}; 16 | #[cfg(feature = "complex-nums")] 17 | use num_complex::Complex; 18 | 19 | use crate::cobra::cobra_apply; 20 | use crate::kernels::{ 21 | fft_32_chunk_n_simd, fft_64_chunk_n_simd, fft_chunk_2, fft_chunk_4, fft_chunk_n, 22 | }; 23 | use crate::options::Options; 24 | use crate::planner::{Direction, Planner32, Planner64}; 25 | use crate::twiddles::filter_twiddles; 26 | 27 | pub mod cobra; 28 | mod kernels; 29 | pub mod options; 30 | pub mod planner; 31 | mod twiddles; 32 | mod utils; 33 | 34 | macro_rules! impl_fft_for { 35 | ($func_name:ident, $precision:ty, $planner:ty, $opts_and_plan:ident) => { 36 | /// FFT -- Decimation in Frequency. This is just the decimation-in-time algorithm, reversed. 37 | /// This call to FFT is run, in-place. 38 | /// The input should be provided in normal order, and then the modified input is bit-reversed. 39 | /// 40 | /// # Panics 41 | /// 42 | /// Panics if `reals.len() != imags.len()`, or if the input length is _not_ a power of 2. 
43 | /// 44 | /// ## References 45 | /// 46 | pub fn $func_name( 47 | reals: &mut [$precision], 48 | imags: &mut [$precision], 49 | direction: Direction, 50 | ) { 51 | assert_eq!( 52 | reals.len(), 53 | imags.len(), 54 | "real and imaginary inputs must be of equal size, but got: {} {}", 55 | reals.len(), 56 | imags.len() 57 | ); 58 | 59 | let planner = <$planner>::new(reals.len(), direction); 60 | assert!( 61 | planner.num_twiddles().is_power_of_two() 62 | && planner.num_twiddles() == reals.len() / 2 63 | ); 64 | 65 | let opts = Options::guess_options(reals.len()); 66 | 67 | $opts_and_plan(reals, imags, &opts, &planner); 68 | } 69 | }; 70 | } 71 | 72 | impl_fft_for!(fft_64, f64, Planner64, fft_64_with_opts_and_plan); 73 | impl_fft_for!(fft_32, f32, Planner32, fft_32_with_opts_and_plan); 74 | 75 | #[cfg(feature = "complex-nums")] 76 | macro_rules! impl_fft_interleaved_for { 77 | ($func_name:ident, $precision:ty, $fft_func:ident, $deinterleaving_func: ident) => { 78 | /// FFT Interleaved -- this is an alternative to [`fft_64`]/[`fft_32`] in the case where 79 | /// the input data is a array of [`Complex`]. 80 | /// 81 | /// The input should be provided in normal order, and then the modified input is 82 | /// bit-reversed. 83 | /// 84 | /// ## References 85 | /// 86 | pub fn $func_name(signal: &mut [Complex<$precision>], direction: Direction) { 87 | let (mut reals, mut imags) = $deinterleaving_func(signal); 88 | $fft_func(&mut reals, &mut imags, direction); 89 | signal.copy_from_slice(&combine_re_im(&reals, &imags)) 90 | } 91 | }; 92 | } 93 | 94 | #[doc(cfg(feature = "complex-nums"))] 95 | #[cfg(feature = "complex-nums")] 96 | impl_fft_interleaved_for!(fft_32_interleaved, f32, fft_32, deinterleave_complex32); 97 | #[doc(cfg(feature = "complex-nums"))] 98 | #[cfg(feature = "complex-nums")] 99 | impl_fft_interleaved_for!(fft_64_interleaved, f64, fft_64, deinterleave_complex64); 100 | 101 | macro_rules! 
impl_fft_with_opts_and_plan_for { 102 | ($func_name:ident, $precision:ty, $planner:ty, $simd_butterfly_kernel:ident, $lanes:literal) => { 103 | /// Same as [fft], but also accepts [`Options`] that control optimization strategies, as well as 104 | /// a [`Planner`] in the case that this FFT will need to be run multiple times. 105 | /// 106 | /// `fft` automatically guesses the best strategy for a given input, 107 | /// so you only need to call this if you are tuning performance for a specific hardware platform. 108 | /// 109 | /// In addition, `fft` automatically creates a planner to be used. In the case that you plan 110 | /// on running an FFT many times on inputs of the same size, use this function with the pre-built 111 | /// [`Planner`]. 112 | /// 113 | /// # Panics 114 | /// 115 | /// Panics if `reals.len() != imags.len()`, or if the input length is _not_ a power of 2. 116 | #[multiversion::multiversion( 117 | targets("x86_64+avx512f+avx512bw+avx512cd+avx512dq+avx512vl", // x86_64-v4 118 | "x86_64+avx2+fma", // x86_64-v3 119 | "x86_64+sse4.2", // x86_64-v2 120 | "x86+avx512f+avx512bw+avx512cd+avx512dq+avx512vl", 121 | "x86+avx2+fma", 122 | "x86+sse4.2", 123 | "x86+sse2", 124 | ))] 125 | pub fn $func_name( 126 | reals: &mut [$precision], 127 | imags: &mut [$precision], 128 | opts: &Options, 129 | planner: &$planner, 130 | ) { 131 | assert!(reals.len() == imags.len() && reals.len().is_power_of_two()); 132 | let n: usize = reals.len().ilog2() as usize; 133 | 134 | // Use references to avoid unnecessary clones 135 | let twiddles_re = &planner.twiddles_re; 136 | let twiddles_im = &planner.twiddles_im; 137 | 138 | // We shouldn't be able to execute FFT if the # of twiddles isn't equal to the distance 139 | // between pairs 140 | assert!(twiddles_re.len() == reals.len() / 2 && twiddles_im.len() == imags.len() / 2); 141 | 142 | match planner.direction { 143 | Direction::Reverse => { 144 | for z_im in imags.iter_mut() { 145 | *z_im = -*z_im; 146 | } 147 | } 148 | _ => 
(), 149 | } 150 | 151 | // 0th stage is special due to no need to filter twiddle factor 152 | let dist = 1 << (n - 1); 153 | let chunk_size = dist << 1; 154 | 155 | if chunk_size > 4 { 156 | if chunk_size >= $lanes * 2 { 157 | $simd_butterfly_kernel(reals, imags, twiddles_re, twiddles_im, dist); 158 | } else { 159 | fft_chunk_n(reals, imags, twiddles_re, twiddles_im, dist); 160 | } 161 | } 162 | else if chunk_size == 4 { 163 | fft_chunk_4(reals, imags); 164 | } 165 | else if chunk_size == 2 { 166 | fft_chunk_2(reals, imags); 167 | } 168 | 169 | let (mut filtered_twiddles_re, mut filtered_twiddles_im) = filter_twiddles(twiddles_re, twiddles_im); 170 | 171 | for t in (0..n - 1).rev() { 172 | let dist = 1 << t; 173 | let chunk_size = dist << 1; 174 | 175 | if chunk_size > 4 { 176 | if chunk_size >= $lanes * 2 { 177 | $simd_butterfly_kernel(reals, imags, &filtered_twiddles_re, &filtered_twiddles_im, dist); 178 | } else { 179 | fft_chunk_n(reals, imags, &filtered_twiddles_re, &filtered_twiddles_im, dist); 180 | } 181 | } 182 | else if chunk_size == 4 { 183 | fft_chunk_4(reals, imags); 184 | } 185 | else if chunk_size == 2 { 186 | fft_chunk_2(reals, imags); 187 | } 188 | (filtered_twiddles_re, filtered_twiddles_im) = filter_twiddles(&filtered_twiddles_re, &filtered_twiddles_im); 189 | } 190 | 191 | if opts.multithreaded_bit_reversal { 192 | std::thread::scope(|s| { 193 | s.spawn(|| cobra_apply(reals, n)); 194 | s.spawn(|| cobra_apply(imags, n)); 195 | }); 196 | } else { 197 | cobra_apply(reals, n); 198 | cobra_apply(imags, n); 199 | } 200 | 201 | match planner.direction { 202 | Direction::Reverse => { 203 | let scaling_factor = (reals.len() as $precision).recip(); 204 | for (z_re, z_im) in reals.iter_mut().zip(imags.iter_mut()) { 205 | *z_re *= scaling_factor; 206 | *z_im *= -scaling_factor; 207 | } 208 | } 209 | _ => (), 210 | } 211 | } 212 | }; 213 | } 214 | 215 | impl_fft_with_opts_and_plan_for!( 216 | fft_64_with_opts_and_plan, 217 | f64, 218 | Planner64, 219 | 
fft_64_chunk_n_simd, 220 | 8 221 | ); 222 | 223 | impl_fft_with_opts_and_plan_for!( 224 | fft_32_with_opts_and_plan, 225 | f32, 226 | Planner32, 227 | fft_32_chunk_n_simd, 228 | 16 229 | ); 230 | 231 | #[cfg(test)] 232 | mod tests { 233 | use std::ops::Range; 234 | 235 | use utilities::rustfft::num_complex::Complex; 236 | use utilities::rustfft::FftPlanner; 237 | use utilities::{assert_float_closeness, gen_random_signal, gen_random_signal_f32}; 238 | 239 | use super::*; 240 | 241 | macro_rules! non_power_of_2_planner { 242 | ($test_name:ident, $planner:ty) => { 243 | #[should_panic] 244 | #[test] 245 | fn $test_name() { 246 | let num_points = 5; 247 | 248 | // this test _should_ always fail at this stage 249 | let _ = <$planner>::new(num_points, Direction::Forward); 250 | } 251 | }; 252 | } 253 | 254 | non_power_of_2_planner!(non_power_of_2_planner_32, Planner32); 255 | non_power_of_2_planner!(non_power_of_2_planner_64, Planner64); 256 | 257 | macro_rules! wrong_num_points_in_planner { 258 | ($test_name:ident, $planner:ty, $fft_with_opts_and_plan:ident) => { 259 | // A regression test to make sure the `Planner` is compatible with fft execution. 260 | #[should_panic] 261 | #[test] 262 | fn $test_name() { 263 | let n = 16; 264 | let num_points = 1 << n; 265 | 266 | // We purposely set n = 16 and pass it to the planner. 267 | // n == 16 == 2^{4} is clearly a power of two, so the planner won't throw it out. 268 | // However, the call to `fft_with_opts_and_plan` should panic since it tests that the 269 | // size of the generated twiddle factors is half the size of the input. 270 | // In this case, we have an input of size 1024 (used for mp3), but we tell the planner the 271 | // input size is 16. 
272 | let mut planner = <$planner>::new(n, Direction::Forward); 273 | 274 | let mut reals = vec![0.0; num_points]; 275 | let mut imags = vec![0.0; num_points]; 276 | let opts = Options::guess_options(reals.len()); 277 | 278 | // this call should panic 279 | $fft_with_opts_and_plan(&mut reals, &mut imags, &opts, &mut planner); 280 | } 281 | }; 282 | } 283 | 284 | wrong_num_points_in_planner!( 285 | wrong_num_points_in_planner_32, 286 | Planner32, 287 | fft_32_with_opts_and_plan 288 | ); 289 | wrong_num_points_in_planner!( 290 | wrong_num_points_in_planner_64, 291 | Planner64, 292 | fft_64_with_opts_and_plan 293 | ); 294 | 295 | macro_rules! test_fft_correctness { 296 | ($test_name:ident, $precision:ty, $fft_type:ident, $range_start:literal, $range_end:literal) => { 297 | #[test] 298 | fn $test_name() { 299 | let range = Range { 300 | start: $range_start, 301 | end: $range_end, 302 | }; 303 | 304 | for k in range { 305 | let n: usize = 1 << k; 306 | 307 | let mut reals: Vec<$precision> = (1..=n).map(|i| i as $precision).collect(); 308 | let mut imags: Vec<$precision> = (1..=n).map(|i| i as $precision).collect(); 309 | $fft_type(&mut reals, &mut imags, Direction::Forward); 310 | 311 | let mut buffer: Vec> = (1..=n) 312 | .map(|i| Complex::new(i as $precision, i as $precision)) 313 | .collect(); 314 | 315 | let mut planner = FftPlanner::new(); 316 | let fft = planner.plan_fft_forward(buffer.len()); 317 | fft.process(&mut buffer); 318 | 319 | reals 320 | .iter() 321 | .zip(imags.iter()) 322 | .enumerate() 323 | .for_each(|(i, (z_re, z_im))| { 324 | let expect_re = buffer[i].re; 325 | let expect_im = buffer[i].im; 326 | assert_float_closeness(*z_re, expect_re, 0.01); 327 | assert_float_closeness(*z_im, expect_im, 0.01); 328 | }); 329 | } 330 | } 331 | }; 332 | } 333 | 334 | test_fft_correctness!(fft_correctness_32, f32, fft_32, 4, 9); 335 | test_fft_correctness!(fft_correctness_64, f64, fft_64, 4, 17); 336 | 337 | #[cfg(feature = "complex-nums")] 338 | #[test] 339 | fn 
fft_interleaved_correctness() { 340 | let n = 10; 341 | let big_n = 1 << n; 342 | let mut actual_signal: Vec<_> = (1..=big_n).map(|i| Complex::new(i as f64, 0.0)).collect(); 343 | let mut expected_reals: Vec<_> = (1..=big_n).map(|i| i as f64).collect(); 344 | let mut expected_imags = vec![0.0; big_n]; 345 | 346 | fft_64_interleaved(&mut actual_signal, Direction::Forward); 347 | fft_64(&mut expected_reals, &mut expected_imags, Direction::Forward); 348 | 349 | actual_signal 350 | .iter() 351 | .zip(expected_reals) 352 | .zip(expected_imags) 353 | .for_each(|((z, z_re), z_im)| { 354 | assert_float_closeness(z.re, z_re, 1e-10); 355 | assert_float_closeness(z.im, z_im, 1e-10); 356 | }); 357 | 358 | let n = 10; 359 | let big_n = 1 << n; 360 | let mut actual_signal: Vec<_> = (1..=big_n).map(|i| Complex::new(i as f32, 0.0)).collect(); 361 | let mut expected_reals: Vec<_> = (1..=big_n).map(|i| i as f32).collect(); 362 | let mut expected_imags = vec![0.0; big_n]; 363 | 364 | fft_32_interleaved(&mut actual_signal, Direction::Forward); 365 | fft_32(&mut expected_reals, &mut expected_imags, Direction::Forward); 366 | 367 | actual_signal 368 | .iter() 369 | .zip(expected_reals) 370 | .zip(expected_imags) 371 | .for_each(|((z, z_re), z_im)| { 372 | assert_float_closeness(z.re, z_re, 1e-10); 373 | assert_float_closeness(z.im, z_im, 1e-10); 374 | }); 375 | } 376 | 377 | #[test] 378 | fn fft_round_trip() { 379 | for i in 4..23 { 380 | let big_n = 1 << i; 381 | let mut reals = vec![0.0; big_n]; 382 | let mut imags = vec![0.0; big_n]; 383 | 384 | gen_random_signal(&mut reals, &mut imags); 385 | 386 | let original_reals = reals.clone(); 387 | let original_imags = imags.clone(); 388 | 389 | // Forward FFT 390 | fft_64(&mut reals, &mut imags, Direction::Forward); 391 | 392 | // Inverse FFT 393 | fft_64(&mut reals, &mut imags, Direction::Reverse); 394 | 395 | // Ensure we get back the original signal within some tolerance 396 | for ((orig_re, orig_im), (res_re, res_im)) in original_reals 
397 | .into_iter() 398 | .zip(original_imags.into_iter()) 399 | .zip(reals.into_iter().zip(imags.into_iter())) 400 | { 401 | assert_float_closeness(res_re, orig_re, 1e-6); 402 | assert_float_closeness(res_im, orig_im, 1e-6); 403 | } 404 | } 405 | } 406 | 407 | #[test] 408 | fn fft_64_with_opts_and_plan_vs_fft_64() { 409 | let num_points = 4096; 410 | 411 | let mut reals = vec![0.0; num_points]; 412 | let mut imags = vec![0.0; num_points]; 413 | gen_random_signal(&mut reals, &mut imags); 414 | 415 | let mut re = reals.clone(); 416 | let mut im = imags.clone(); 417 | 418 | let planner = Planner64::new(num_points, Direction::Forward); 419 | let opts = Options::guess_options(reals.len()); 420 | fft_64_with_opts_and_plan(&mut reals, &mut imags, &opts, &planner); 421 | 422 | fft_64(&mut re, &mut im, Direction::Forward); 423 | 424 | reals 425 | .iter() 426 | .zip(imags.iter()) 427 | .zip(re.iter()) 428 | .zip(im.iter()) 429 | .for_each(|(((r, i), z_re), z_im)| { 430 | assert_float_closeness(*r, *z_re, 1e-6); 431 | assert_float_closeness(*i, *z_im, 1e-6); 432 | }); 433 | } 434 | 435 | #[test] 436 | fn fft_32_with_opts_and_plan_vs_fft_64() { 437 | let dirs = [Direction::Forward, Direction::Reverse]; 438 | 439 | for direction in dirs { 440 | for n in 4..14 { 441 | let num_points = 1 << n; 442 | let mut reals = vec![0.0; num_points]; 443 | let mut imags = vec![0.0; num_points]; 444 | gen_random_signal_f32(&mut reals, &mut imags); 445 | 446 | let mut re = reals.clone(); 447 | let mut im = imags.clone(); 448 | 449 | let planner = Planner32::new(num_points, direction); 450 | let opts = Options::guess_options(reals.len()); 451 | fft_32_with_opts_and_plan(&mut reals, &mut imags, &opts, &planner); 452 | 453 | fft_32(&mut re, &mut im, direction); 454 | 455 | reals 456 | .iter() 457 | .zip(imags.iter()) 458 | .zip(re.iter()) 459 | .zip(im.iter()) 460 | .for_each(|(((r, i), z_re), z_im)| { 461 | assert_float_closeness(*r, *z_re, 1e-6); 462 | assert_float_closeness(*i, *z_im, 1e-6); 
463 | }); 464 | } 465 | } 466 | } 467 | } 468 | -------------------------------------------------------------------------------- /src/options.rs: -------------------------------------------------------------------------------- 1 | //! Options to tune to improve performance depending on the hardware and input size. 2 | 3 | /// Calling FFT routines without specifying options will automatically select reasonable defaults 4 | /// depending on the input size and other factors. 5 | /// 6 | /// You only need to tune these options if you are trying to squeeze maximum performance 7 | /// out of a known hardware platform that you can benchmark at varying input sizes. 8 | #[non_exhaustive] 9 | #[derive(Debug, Clone, Default)] 10 | pub struct Options { 11 | /// Whether to run the bit reversal step in 2 threads instead of one. 12 | /// This is beneficial only at large input sizes (i.e. gigabytes of data). 13 | /// The exact threshold where it starts being beneficial varies depending on the hardware. 14 | pub multithreaded_bit_reversal: bool, 15 | } 16 | 17 | impl Options { 18 | /// Attempt to guess the best settings to use for optimal FFT 19 | pub fn guess_options(input_size: usize) -> Options { 20 | let mut options = Options::default(); 21 | let n: usize = input_size.ilog2() as usize; 22 | options.multithreaded_bit_reversal = n >= 22; 23 | options 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /src/planner.rs: -------------------------------------------------------------------------------- 1 | //! The planner module provides a convenient interface for planning and executing 2 | //! a Fast Fourier Transform (FFT). Currently, the planner is responsible for 3 | //! pre-computing twiddle factors based on the input signal length, as well as the 4 | //! direction of the FFT. 
5 | use crate::twiddles::{generate_twiddles, generate_twiddles_simd_32, generate_twiddles_simd_64}; 6 | 7 | /// Reverse is for running the Inverse Fast Fourier Transform (IFFT) 8 | /// Forward is for running the regular FFT 9 | #[derive(Copy, Clone)] 10 | pub enum Direction { 11 | /// Leave the exponent term in the twiddle factor alone 12 | Forward = 1, 13 | /// Multiply the exponent term in the twiddle factor by -1 14 | Reverse = -1, 15 | } 16 | 17 | macro_rules! impl_planner_for { 18 | ($struct_name:ident, $precision:ident, $generate_twiddles_simd_fn:ident) => { 19 | /// The planner is responsible for pre-computing and storing twiddle factors for all the 20 | /// `log_2(N)` stages of the FFT. 21 | /// The amount of twiddle factors should always be a power of 2. In addition, 22 | /// the amount of twiddle factors should always be `(1/2) * N` 23 | pub struct $struct_name { 24 | /// The real components of the twiddle factors 25 | pub twiddles_re: Vec<$precision>, 26 | /// The imaginary components of the twiddle factors 27 | pub twiddles_im: Vec<$precision>, 28 | /// The direction of the FFT associated with this `Planner` 29 | pub direction: Direction, 30 | } 31 | impl $struct_name { 32 | /// Create a `Planner` for an FFT of size `num_points`. 33 | /// The twiddle factors are pre-computed based on the provided [`Direction`]. 34 | /// For `Forward`, use [`Direction::Forward`]. 35 | /// For `Reverse`, use [`Direction::Reverse`]. 36 | /// 37 | /// # Panics 38 | /// 39 | /// Panics if `num_points < 1` or if `num_points` is __not__ a power of 2. 
40 | pub fn new(num_points: usize, direction: Direction) -> Self { 41 | assert!(num_points > 0 && num_points.is_power_of_two()); 42 | let dir = match direction { 43 | Direction::Forward => Direction::Forward, 44 | Direction::Reverse => Direction::Reverse, 45 | }; 46 | 47 | if num_points <= 4 { 48 | return Self { 49 | twiddles_re: vec![], 50 | twiddles_im: vec![], 51 | direction: dir, 52 | }; 53 | } 54 | 55 | let dist = num_points >> 1; 56 | 57 | let (twiddles_re, twiddles_im) = if dist >= 8 * 2 { 58 | $generate_twiddles_simd_fn(dist, Direction::Forward) 59 | } else { 60 | generate_twiddles(dist, Direction::Forward) 61 | }; 62 | 63 | assert_eq!(twiddles_re.len(), twiddles_im.len()); 64 | 65 | Self { 66 | twiddles_re, 67 | twiddles_im, 68 | direction: dir, 69 | } 70 | } 71 | 72 | pub(crate) fn num_twiddles(&self) -> usize { 73 | assert_eq!(self.twiddles_re.len(), self.twiddles_im.len()); 74 | self.twiddles_re.len() 75 | } 76 | } 77 | }; 78 | } 79 | 80 | impl_planner_for!(Planner64, f64, generate_twiddles_simd_64); 81 | impl_planner_for!(Planner32, f32, generate_twiddles_simd_32); 82 | 83 | #[cfg(test)] 84 | mod tests { 85 | use super::*; 86 | 87 | macro_rules! test_no_twiddles { 88 | ($test_name:ident, $planner:ty) => { 89 | #[test] 90 | fn $test_name() { 91 | for num_points in [2, 4] { 92 | let planner = <$planner>::new(num_points, Direction::Forward); 93 | assert!(planner.twiddles_im.is_empty() && planner.twiddles_re.is_empty()); 94 | } 95 | } 96 | }; 97 | } 98 | 99 | test_no_twiddles!(no_twiddles_64, Planner64); 100 | test_no_twiddles!(no_twiddles_32, Planner32); 101 | } 102 | -------------------------------------------------------------------------------- /src/twiddles.rs: -------------------------------------------------------------------------------- 1 | use std::simd::{f32x8, f64x8}; 2 | 3 | use num_traits::{Float, FloatConst, One, Zero}; 4 | 5 | use crate::planner::Direction; 6 | 7 | /// This isn't really used except for testing. 
8 | /// It may be better to use this in the case where the input size is very large, 9 | /// as to free up the cache. 10 | pub(crate) struct Twiddles { 11 | st: T, 12 | ct: T, 13 | w_re_prev: T, 14 | w_im_prev: T, 15 | } 16 | 17 | impl Twiddles { 18 | /// `cache_size` is the amount of roots of unity kept pre-built at any point in time. 19 | /// `num_roots` is the total number of roots of unity that will need to be computed. 20 | /// `cache_size` can be thought of as the length of a chunk of roots of unity from 21 | /// out of the total amount (i.e., `num_roots`) 22 | #[allow(dead_code)] 23 | pub fn new(num_roots: usize) -> Self { 24 | let theta = -T::PI() / (T::from(num_roots).unwrap()); 25 | let (st, ct) = theta.sin_cos(); 26 | Self { 27 | st, 28 | ct, 29 | w_re_prev: T::one(), 30 | w_im_prev: T::zero(), 31 | } 32 | } 33 | } 34 | 35 | // TODO: generate twiddles using the first quarter chunk of twiddle factors 36 | // 1st chunk: old fashioned multiplication of complex nums 37 | // 2nd chunk: reverse the 1st chunk, swap components, and negate both components 38 | // 3rd chunk: No reversal. 
Swap the components and negate the *new* imaginary components 39 | // 4th chunk: reverse the 1st chunk, and negate the real component 40 | impl Iterator for Twiddles { 41 | type Item = (T, T); 42 | 43 | fn next(&mut self) -> Option<(T, T)> { 44 | let w_re = self.w_re_prev; 45 | let w_im = self.w_im_prev; 46 | 47 | let temp = self.w_re_prev; 48 | self.w_re_prev = temp * self.ct - self.w_im_prev * self.st; 49 | self.w_im_prev = temp * self.st + self.w_im_prev * self.ct; 50 | 51 | Some((w_re, w_im)) 52 | } 53 | } 54 | 55 | pub fn generate_twiddles( 56 | dist: usize, 57 | direction: Direction, 58 | ) -> (Vec, Vec) { 59 | let mut twiddles_re = vec![T::zero(); dist]; 60 | let mut twiddles_im = vec![T::zero(); dist]; 61 | twiddles_re[0] = T::one(); 62 | 63 | let sign = match direction { 64 | Direction::Forward => T::one(), 65 | Direction::Reverse => -T::one(), 66 | }; 67 | 68 | let angle = sign * -T::PI() / T::from(dist).unwrap(); 69 | let (st, ct) = angle.sin_cos(); 70 | let (mut w_re, mut w_im) = (T::one(), T::zero()); 71 | 72 | let mut i = 1; 73 | while i < (dist / 2) + 1 { 74 | let temp = w_re; 75 | w_re = w_re * ct - w_im * st; 76 | w_im = temp * st + w_im * ct; 77 | twiddles_re[i] = w_re; 78 | twiddles_im[i] = w_im; 79 | i += 1; 80 | } 81 | 82 | while i < dist { 83 | twiddles_re[i] = -twiddles_re[dist - i]; 84 | twiddles_im[i] = twiddles_im[dist - i]; 85 | i += 1; 86 | } 87 | 88 | (twiddles_re, twiddles_im) 89 | } 90 | 91 | macro_rules! generate_twiddles_simd { 92 | ($func_name:ident, $precision:ty, $lanes:literal, $simd_vector:ty) => { 93 | pub(crate) fn $func_name( 94 | dist: usize, 95 | direction: Direction, 96 | ) -> (Vec<$precision>, Vec<$precision>) { 97 | const CHUNK_SIZE: usize = 8; // TODO: make this a const generic? 
98 | assert!(dist >= CHUNK_SIZE * 2); 99 | assert_eq!(dist % CHUNK_SIZE, 0); 100 | let mut twiddles_re = vec![0.0; dist]; 101 | let mut twiddles_im = vec![0.0; dist]; 102 | twiddles_re[0] = 1.0; 103 | 104 | let sign = match direction { 105 | Direction::Forward => 1.0, 106 | Direction::Reverse => -1.0, 107 | }; 108 | 109 | let angle = sign * -<$precision>::PI() / dist as $precision; 110 | let (st, ct) = angle.sin_cos(); 111 | let (mut w_re, mut w_im) = (<$precision>::one(), <$precision>::zero()); 112 | 113 | let mut next_twiddle = || { 114 | let temp = w_re; 115 | w_re = w_re * ct - w_im * st; 116 | w_im = temp * st + w_im * ct; 117 | (w_re, w_im) 118 | }; 119 | 120 | let apply_symmetry_re = |input: &[$precision], output: &mut [$precision]| { 121 | let first_re = <$simd_vector>::from_slice(input); 122 | let minus_one = <$simd_vector>::splat(-1.0); 123 | let negated = (first_re * minus_one).reverse(); 124 | output.copy_from_slice(negated.as_array()); 125 | }; 126 | 127 | let apply_symmetry_im = |input: &[$precision], output: &mut [$precision]| { 128 | let mut buf: [$precision; CHUNK_SIZE] = [0.0; CHUNK_SIZE]; 129 | buf.copy_from_slice(input); 130 | buf.reverse(); 131 | output.copy_from_slice(&buf); 132 | }; 133 | 134 | // Split the twiddles into two halves. There is a cheaper way to calculate the second half 135 | let (first_half_re, second_half_re) = twiddles_re[1..].split_at_mut(dist / 2); 136 | assert_eq!(first_half_re.len(), second_half_re.len() + 1); 137 | let (first_half_im, second_half_im) = twiddles_im[1..].split_at_mut(dist / 2); 138 | assert_eq!(first_half_im.len(), second_half_im.len() + 1); 139 | 140 | first_half_re 141 | .chunks_exact_mut(CHUNK_SIZE) 142 | .zip(first_half_im.chunks_exact_mut(CHUNK_SIZE)) 143 | .zip( 144 | second_half_re[CHUNK_SIZE - 1..] 145 | .chunks_exact_mut(CHUNK_SIZE) 146 | .rev(), 147 | ) 148 | .zip( 149 | second_half_im[CHUNK_SIZE - 1..] 
150 | .chunks_exact_mut(CHUNK_SIZE) 151 | .rev(), 152 | ) 153 | .for_each( 154 | |(((first_ch_re, first_ch_im), second_ch_re), second_ch_im)| { 155 | // Calculate a chunk of the first half in a plain old scalar way 156 | first_ch_re 157 | .iter_mut() 158 | .zip(first_ch_im.iter_mut()) 159 | .for_each(|(re, im)| { 160 | (*re, *im) = next_twiddle(); 161 | }); 162 | // Calculate a chunk of the second half in a clever way by copying the first chunk 163 | // This avoids data dependencies of the regular calculation and gets vectorized. 164 | // We do it up front while the values we just calculated are still in the cache, 165 | // so we don't have to re-load them from memory later, which would be slow. 166 | apply_symmetry_re(first_ch_re, second_ch_re); 167 | apply_symmetry_im(first_ch_im, second_ch_im); 168 | }, 169 | ); 170 | 171 | // Fill in the middle that the SIMD loop didn't 172 | twiddles_re[dist / 2 - CHUNK_SIZE + 1..][..(CHUNK_SIZE * 2) - 1] 173 | .iter_mut() 174 | .zip(twiddles_im[dist / 2 - CHUNK_SIZE + 1..][..(CHUNK_SIZE * 2) - 1].iter_mut()) 175 | .for_each(|(re, im)| { 176 | (*re, *im) = next_twiddle(); 177 | }); 178 | 179 | (twiddles_re, twiddles_im) 180 | } 181 | }; 182 | } 183 | 184 | generate_twiddles_simd!(generate_twiddles_simd_64, f64, 8, f64x8); 185 | generate_twiddles_simd!(generate_twiddles_simd_32, f32, 8, f32x8); 186 | 187 | #[multiversion::multiversion( 188 | targets("x86_64+avx512f+avx512bw+avx512cd+avx512dq+avx512vl", // x86_64-v4 189 | "x86_64+avx2+fma", // x86_64-v3 190 | "x86_64+sse4.2", // x86_64-v2 191 | "x86+avx512f+avx512bw+avx512cd+avx512dq+avx512vl", 192 | "x86+avx2+fma", 193 | "x86+sse4.2", 194 | "x86+sse2", 195 | ))] 196 | #[inline] 197 | pub(crate) fn filter_twiddles(twiddles_re: &[T], twiddles_im: &[T]) -> (Vec, Vec) { 198 | assert_eq!(twiddles_re.len(), twiddles_im.len()); 199 | let dist = twiddles_re.len(); 200 | 201 | let filtered_twiddles_re: Vec = twiddles_re.chunks_exact(2).map(|chunk| chunk[0]).collect(); 202 | let 
filtered_twiddles_im: Vec = twiddles_im.chunks_exact(2).map(|chunk| chunk[0]).collect(); 203 | 204 | assert!( 205 | filtered_twiddles_re.len() == filtered_twiddles_im.len() 206 | && filtered_twiddles_re.len() == dist / 2 207 | ); 208 | 209 | (filtered_twiddles_re, filtered_twiddles_im) 210 | } 211 | 212 | #[cfg(test)] 213 | mod tests { 214 | use std::f64::consts::FRAC_1_SQRT_2; 215 | 216 | use utilities::assert_float_closeness; 217 | 218 | use super::*; 219 | 220 | // TODO(saveliy): try to use only real twiddle factors since sin is just a phase shift of cos 221 | #[test] 222 | fn twiddles_cos_only() { 223 | let n = 4; 224 | let big_n = 1 << n; 225 | 226 | let dist = big_n >> 1; 227 | 228 | let (fwd_twiddles_re, fwd_twiddles_im) = if dist >= 8 * 2 { 229 | generate_twiddles_simd_64(dist, Direction::Forward) 230 | } else { 231 | generate_twiddles(dist, Direction::Forward) 232 | }; 233 | 234 | assert!(fwd_twiddles_re.len() == dist && fwd_twiddles_im.len() == dist); 235 | 236 | for i in 0..dist { 237 | let _w_re = fwd_twiddles_re[i]; 238 | let expected_w_im = fwd_twiddles_im[i]; 239 | 240 | let actual_w_im = -fwd_twiddles_re[(i + dist / 2) % dist]; 241 | //assert_float_closeness(actual_w_im, expected_w_im, 1e-6); 242 | println!("actual: {actual_w_im} expected: {expected_w_im}"); 243 | } 244 | println!("{:?}", fwd_twiddles_re); 245 | println!("{:?}", fwd_twiddles_im); 246 | } 247 | 248 | #[test] 249 | fn twiddles_4() { 250 | const N: usize = 4; 251 | let mut twiddle_iter = Twiddles::new(N); 252 | 253 | let (w_re, w_im) = twiddle_iter.next().unwrap(); 254 | println!("{w_re} {w_im}"); 255 | assert_float_closeness(w_re, 1.0, 1e-10); 256 | assert_float_closeness(w_im, 0.0, 1e-10); 257 | 258 | let (w_re, w_im) = twiddle_iter.next().unwrap(); 259 | println!("{w_re} {w_im}"); 260 | assert_float_closeness(w_re, FRAC_1_SQRT_2, 1e-10); 261 | assert_float_closeness(w_im, -FRAC_1_SQRT_2, 1e-10); 262 | 263 | let (w_re, w_im) = twiddle_iter.next().unwrap(); 264 | println!("{w_re} 
{w_im}");
        assert_float_closeness(w_re, 0.0, 1e-10);
        assert_float_closeness(w_im, -1.0, 1e-10);

        let (w_re, w_im) = twiddle_iter.next().unwrap();
        println!("{w_re} {w_im}");
        assert_float_closeness(w_re, -FRAC_1_SQRT_2, 1e-10);
        assert_float_closeness(w_im, -FRAC_1_SQRT_2, 1e-10);
    }

    /// Generates a test that compares a SIMD twiddle generator against
    /// `generate_twiddles` for every power-of-two distance `2^4 ..= 2^24`.
    ///
    /// * `$test_name` - name of the generated `#[test]` function
    /// * `$generate_twiddles_simd` - the SIMD generator under test
    /// * `$epsilon` - element-wise tolerance passed to `assert_float_closeness`
    macro_rules! test_twiddles_simd {
        ($test_name:ident, $generate_twiddles_simd:ident, $epsilon:literal) => {
            #[test]
            fn $test_name() {
                for n in 4..25 {
                    let dist = 1 << n;

                    let (twiddles_re_ref, twiddles_im_ref) =
                        generate_twiddles(dist, Direction::Forward);
                    let (twiddles_re, twiddles_im) =
                        $generate_twiddles_simd(dist, Direction::Forward);

                    // Real parts: SIMD output vs. reference output.
                    twiddles_re
                        .iter()
                        .zip(twiddles_re_ref.iter())
                        .for_each(|(simd, reference)| {
                            assert_float_closeness(*simd, *reference, $epsilon);
                        });

                    // Imaginary parts: SIMD output vs. reference output.
                    twiddles_im
                        .iter()
                        .zip(twiddles_im_ref.iter())
                        .for_each(|(simd, reference)| {
                            assert_float_closeness(*simd, *reference, $epsilon);
                        });
                }
            }
        };
    }

    test_twiddles_simd!(twiddles_simd_32, generate_twiddles_simd_32, 1e-1);
    test_twiddles_simd!(twiddles_simd_64, generate_twiddles_simd_64, 1e-10);

    /// Checks that repeatedly halving a twiddle table with `filter_twiddles`
    /// yields the same factors as a fresh `Twiddles` iterator created for each
    /// smaller distance.
    #[test]
    fn twiddles_filter() {
        // Assume n = 28
        let n = 28;

        // distance := 2^{n} / 2 == 2^{n-1}
        let dist = 1 << (n - 1);

        let mut twiddles_iter = Twiddles::new(dist);

        let (twiddles_re, twiddles_im) = generate_twiddles(dist, Direction::Forward);

        // The full-size table must agree with the iterator before any filtering.
        for i in 0..dist {
            let (w_re, w_im) = twiddles_iter.next().unwrap();
            assert_float_closeness(twiddles_re[i], w_re, 1e-6);
            assert_float_closeness(twiddles_im[i], w_im, 1e-6);
        }

        // The full-size tables are not read again, so hand them over by move
        // instead of cloning two 2^27-element vectors.
        let (mut tw_re, mut tw_im) = (twiddles_re, twiddles_im);

        for t in (0..n - 1).rev() {
            let dist = 1 << t;
            let mut twiddles_iter = Twiddles::new(dist);

            // Don't re-compute all the twiddles.
            // Just filter them out by taking every other twiddle factor
            (tw_re, tw_im) = filter_twiddles(&tw_re, &tw_im);

            assert!(tw_re.len() == dist && tw_im.len() == dist);

            for i in 0..dist {
                let (w_re, w_im) = twiddles_iter.next().unwrap();
                assert_float_closeness(tw_re[i], w_re, 1e-6);
                assert_float_closeness(tw_im[i], w_im, 1e-6);
            }
        }
    }

    /// Generates a test asserting that each forward twiddle factor multiplied
    /// by the matching reverse twiddle factor is (approximately) `1 + 0i`, for
    /// FFT sizes `2^3 ..= 2^24`.
    ///
    /// * `$test_name` - name of the generated `#[test]` function
    /// * `$generate_twiddles_simd_fn` - SIMD generator used once `dist >= 16`;
    ///   smaller distances fall back to `generate_twiddles`
    macro_rules! forward_mul_inverse_eq_identity {
        ($test_name:ident, $generate_twiddles_simd_fn:ident) => {
            #[test]
            fn $test_name() {
                for i in 3..25 {
                    let num_points = 1 << i;
                    let dist = num_points >> 1;

                    let (fwd_twiddles_re, fwd_twiddles_im) = if dist >= 8 * 2 {
                        $generate_twiddles_simd_fn(dist, Direction::Forward)
                    } else {
                        generate_twiddles(dist, Direction::Forward)
                    };

                    assert_eq!(fwd_twiddles_re.len(), fwd_twiddles_im.len());

                    let (rev_twiddles_re, rev_twiddles_im) = if dist >= 8 * 2 {
                        $generate_twiddles_simd_fn(dist, Direction::Reverse)
                    } else {
                        generate_twiddles(dist, Direction::Reverse)
                    };

                    assert_eq!(rev_twiddles_re.len(), rev_twiddles_im.len());

                    // (a + ib) (c + id) = ac + iad + ibc - bd
                    //                   = ac - bd + i(ad + bc)
                    fwd_twiddles_re
                        .iter()
                        .zip(fwd_twiddles_im.iter())
                        .zip(rev_twiddles_re.iter())
                        .zip(rev_twiddles_im.iter())
                        .for_each(|(((a, b), c), d)| {
                            let temp_re = a * c - b * d;
                            let temp_im = a * d + b * c;
                            assert_float_closeness(temp_re, 1.0, 1e-2);
                            assert_float_closeness(temp_im, 0.0, 1e-2);
                        });
                }
            }
        };
    }

    forward_mul_inverse_eq_identity!(forward_reverse_eq_identity_64, generate_twiddles_simd_64);
    forward_mul_inverse_eq_identity!(forward_reverse_eq_identity_32, generate_twiddles_simd_32);
}
-------------------------------------------------------------------------------- /src/utils.rs: -------------------------------------------------------------------------------- 1 | //! Utility functions such as interleave/deinterleave 2 | 3 | #[cfg(feature = "complex-nums")] 4 | use num_complex::Complex; 5 | 6 | #[cfg(feature = "complex-nums")] 7 | use num_traits::Float; 8 | 9 | #[cfg(feature = "complex-nums")] 10 | use bytemuck::cast_slice; 11 | 12 | use std::simd::{prelude::Simd, simd_swizzle, SimdElement}; 13 | 14 | // We don't multiversion for AVX-512 here and keep the chunk size below AVX-512 15 | // because we haven't seen any gains from it in benchmarks. 16 | // This might be due to us running benchmarks on Zen4 which implements AVX-512 17 | // on top of 256-bit wide execution units. 18 | // 19 | // If benchmarks on "real" AVX-512 show improvement on AVX-512 20 | // without degrading AVX2 machines due to larger chunk size, 21 | // the AVX-512 specialization should be re-enabled. 22 | #[multiversion::multiversion( 23 | targets( 24 | "x86_64+avx2+fma", // x86_64-v3 25 | "x86_64+sse4.2", // x86_64-v2 26 | "x86+avx2+fma", 27 | "x86+sse4.2", 28 | "x86+sse2", 29 | ))] 30 | /// Separates data like `[1, 2, 3, 4]` into `([1, 3], [2, 4])` for any length 31 | pub(crate) fn deinterleave(input: &[T]) -> (Vec, Vec) { 32 | const CHUNK_SIZE: usize = 4; 33 | const DOUBLE_CHUNK: usize = CHUNK_SIZE * 2; 34 | 35 | let out_len = input.len() / 2; 36 | // We've benchmarked, and it turns out that this approach with zeroed memory 37 | // is faster than using uninit memory and bumping the length once in a while! 
38 | let mut out_odd = vec![T::default(); out_len]; 39 | let mut out_even = vec![T::default(); out_len]; 40 | 41 | input 42 | .chunks_exact(DOUBLE_CHUNK) 43 | .zip(out_odd.chunks_exact_mut(CHUNK_SIZE)) 44 | .zip(out_even.chunks_exact_mut(CHUNK_SIZE)) 45 | .for_each(|((in_chunk, odds), evens)| { 46 | let in_simd: Simd = Simd::from_array(in_chunk.try_into().unwrap()); 47 | // This generates *slightly* faster code than just assigning values by index. 48 | // You'd think simd::deinterleave would be appropriate, but it does something different! 49 | let result = simd_swizzle!(in_simd, [0, 2, 4, 6, 1, 3, 5, 7]); 50 | let result_arr = result.to_array(); 51 | odds.copy_from_slice(&result_arr[..CHUNK_SIZE]); 52 | evens.copy_from_slice(&result_arr[CHUNK_SIZE..]); 53 | }); 54 | 55 | // Process the remainder, too small for the vectorized loop 56 | let input_rem = input.chunks_exact(DOUBLE_CHUNK).remainder(); 57 | let odds_rem = out_odd.chunks_exact_mut(CHUNK_SIZE).into_remainder(); 58 | let evens_rem = out_even.chunks_exact_mut(CHUNK_SIZE).into_remainder(); 59 | input_rem 60 | .chunks_exact(2) 61 | .zip(odds_rem.iter_mut()) 62 | .zip(evens_rem.iter_mut()) 63 | .for_each(|((inp, odd), even)| { 64 | *odd = inp[0]; 65 | *even = inp[1]; 66 | }); 67 | 68 | (out_odd, out_even) 69 | } 70 | 71 | /// Utility function to separate a slice of [`Complex64``] 72 | /// into a single vector of Complex Number Structs. 73 | /// 74 | /// # Panics 75 | /// 76 | /// Panics if `reals.len() != imags.len()`. 77 | #[cfg(feature = "complex-nums")] 78 | pub(crate) fn deinterleave_complex64(signal: &[Complex]) -> (Vec, Vec) { 79 | let complex_t: &[f64] = cast_slice(signal); 80 | deinterleave(complex_t) 81 | } 82 | 83 | /// Utility function to separate a slice of [`Complex32``] 84 | /// into a single vector of Complex Number Structs. 85 | /// 86 | /// # Panics 87 | /// 88 | /// Panics if `reals.len() != imags.len()`. 
89 | #[cfg(feature = "complex-nums")] 90 | pub(crate) fn deinterleave_complex32(signal: &[Complex]) -> (Vec, Vec) { 91 | let complex_t: &[f32] = cast_slice(signal); 92 | deinterleave(complex_t) 93 | } 94 | 95 | /// Utility function to combine separate vectors of real and imaginary components 96 | /// into a single vector of Complex Number Structs. 97 | /// 98 | /// # Panics 99 | /// 100 | /// Panics if `reals.len() != imags.len()`. 101 | #[cfg(feature = "complex-nums")] 102 | pub(crate) fn combine_re_im(reals: &[T], imags: &[T]) -> Vec> { 103 | assert_eq!(reals.len(), imags.len()); 104 | 105 | reals 106 | .iter() 107 | .zip(imags.iter()) 108 | .map(|(z_re, z_im)| Complex::new(*z_re, *z_im)) 109 | .collect() 110 | } 111 | 112 | #[cfg(test)] 113 | mod tests { 114 | use super::*; 115 | 116 | fn gen_test_vec(len: usize) -> Vec { 117 | (0..len).collect() 118 | } 119 | 120 | /// Slow but obviously correct implementation of deinterleaving, 121 | /// to be used in tests 122 | fn deinterleave_naive(input: &[T]) -> (Vec, Vec) { 123 | input.chunks_exact(2).map(|c| (c[0], c[1])).unzip() 124 | } 125 | 126 | #[test] 127 | fn deinterleaving_correctness() { 128 | for len in [0, 1, 2, 3, 15, 16, 17, 127, 128, 129, 130, 135, 100500] { 129 | let input = gen_test_vec(len); 130 | let (naive_a, naive_b) = deinterleave_naive(&input); 131 | let (opt_a, opt_b) = deinterleave(&input); 132 | assert_eq!(naive_a, opt_a); 133 | assert_eq!(naive_b, opt_b); 134 | } 135 | } 136 | 137 | #[cfg(feature = "complex-nums")] 138 | #[test] 139 | fn test_separate_and_combine_re_im() { 140 | let complex_vec: Vec<_> = vec![ 141 | Complex::new(1.0, 2.0), 142 | Complex::new(3.0, 4.0), 143 | Complex::new(5.0, 6.0), 144 | Complex::new(7.0, 8.0), 145 | ]; 146 | 147 | let (reals, imags) = deinterleave_complex64(&complex_vec); 148 | 149 | let recombined_vec = combine_re_im(&reals, &imags); 150 | 151 | assert_eq!(complex_vec, recombined_vec); 152 | } 153 | } 154 | 
-------------------------------------------------------------------------------- /utilities/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "utilities" 3 | version = "0.2.1" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | rand = "0.8.5" 10 | rustfft = "6.2.0" -------------------------------------------------------------------------------- /utilities/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub extern crate rustfft; 2 | 3 | use std::f64::consts::PI; 4 | use std::fmt::Display; 5 | 6 | use rand::distributions::Uniform; 7 | use rand::prelude::*; 8 | use rustfft::num_traits::Float; 9 | 10 | /// Asserts that two floating-point numbers are approximately equal. 11 | /// 12 | /// # Panics 13 | /// 14 | /// Panics if `actual` and `expected` are too far from each other 15 | #[allow(dead_code)] 16 | #[track_caller] 17 | pub fn assert_float_closeness(actual: T, expected: T, epsilon: T) { 18 | if (actual - expected).abs() >= epsilon { 19 | panic!( 20 | "Assertion failed: {actual} too far from expected value {expected} (with epsilon {epsilon})", 21 | ); 22 | } 23 | } 24 | 25 | pub fn gen_random_signal_f32(reals: &mut [f32], imags: &mut [f32]) { 26 | assert!(reals.len() == imags.len() && !reals.is_empty()); 27 | let mut rng = thread_rng(); 28 | let between = Uniform::from(0.0..1.0); 29 | let angle_dist = Uniform::from(0.0..2.0 * (PI as f32)); 30 | let num_amps = reals.len(); 31 | 32 | let mut probs: Vec<_> = (0..num_amps).map(|_| between.sample(&mut rng)).collect(); 33 | 34 | let total: f32 = probs.iter().sum(); 35 | let total_recip = total.recip(); 36 | 37 | probs.iter_mut().for_each(|p| *p *= total_recip); 38 | 39 | let angles = (0..num_amps).map(|_| angle_dist.sample(&mut rng)); 40 | 41 | probs 42 | .iter() 43 | .zip(angles) 44 | .enumerate() 45 | 
.for_each(|(i, (p, a))| { 46 | let p_sqrt = p.sqrt(); 47 | let (sin_a, cos_a) = a.sin_cos(); 48 | let re = p_sqrt * cos_a; 49 | let im = p_sqrt * sin_a; 50 | reals[i] = re; 51 | imags[i] = im; 52 | }); 53 | } 54 | 55 | /// Generate a random, complex, signal in the provided buffers 56 | /// 57 | /// # Panics 58 | /// 59 | /// Panics if `reals.len() != imags.len()` 60 | pub fn gen_random_signal(reals: &mut [f64], imags: &mut [f64]) { 61 | assert!(reals.len() == imags.len() && !reals.is_empty()); 62 | let mut rng = thread_rng(); 63 | let between = Uniform::from(0.0..1.0); 64 | let angle_dist = Uniform::from(0.0..2.0 * PI); 65 | let num_amps = reals.len(); 66 | 67 | let mut probs: Vec<_> = (0..num_amps).map(|_| between.sample(&mut rng)).collect(); 68 | 69 | let total: f64 = probs.iter().sum(); 70 | let total_recip = total.recip(); 71 | 72 | probs.iter_mut().for_each(|p| *p *= total_recip); 73 | 74 | let angles = (0..num_amps).map(|_| angle_dist.sample(&mut rng)); 75 | 76 | probs 77 | .iter() 78 | .zip(angles) 79 | .enumerate() 80 | .for_each(|(i, (p, a))| { 81 | let p_sqrt = p.sqrt(); 82 | let (sin_a, cos_a) = a.sin_cos(); 83 | let re = p_sqrt * cos_a; 84 | let im = p_sqrt * sin_a; 85 | reals[i] = re; 86 | imags[i] = im; 87 | }); 88 | } 89 | 90 | #[cfg(test)] 91 | mod tests { 92 | use super::*; 93 | 94 | #[test] 95 | fn generate_random_signal() { 96 | let big_n = 1 << 25; 97 | let mut reals = vec![0.0; big_n]; 98 | let mut imags = vec![0.0; big_n]; 99 | 100 | gen_random_signal(&mut reals, &mut imags); 101 | 102 | let sum: f64 = reals 103 | .iter() 104 | .zip(imags.iter()) 105 | .map(|(re, im)| re.powi(2) + im.powi(2)) 106 | .sum(); 107 | 108 | assert_f64_closeness(sum, 1.0, 1e-6); 109 | } 110 | } 111 | --------------------------------------------------------------------------------