├── .github ├── codecov.yml ├── dependabot.yml └── workflows │ ├── check.yml │ ├── scheduled.yml │ └── test.yml ├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md ├── ROADMAP.md ├── build.rs ├── src ├── lib.rs ├── odr │ ├── mod.rs │ └── polynomial.rs ├── optimize │ ├── criteria.rs │ ├── least_square │ │ └── mod.rs │ ├── metric.rs │ ├── min_scalar │ │ ├── golden.rs │ │ └── mod.rs │ ├── mod.rs │ ├── root_scalar │ │ ├── bracket.rs │ │ ├── fixed_point.rs │ │ ├── halley.rs │ │ ├── mod.rs │ │ ├── newton.rs │ │ ├── polynomial.rs │ │ └── secant.rs │ └── util.rs ├── signal │ ├── band_filter.rs │ ├── convolution │ │ └── mod.rs │ ├── error.rs │ ├── filter_design │ │ ├── bessel.rs │ │ ├── butter.rs │ │ ├── butterord.rs │ │ ├── cheby1.rs │ │ ├── cheby2.rs │ │ ├── ellip.rs │ │ ├── error.rs │ │ └── mod.rs │ ├── fir_filter_design │ │ ├── firwin1.rs │ │ ├── mod.rs │ │ ├── pass_zero.rs │ │ ├── tools.rs │ │ └── windows.rs │ ├── mod.rs │ ├── output_type │ │ ├── ba.rs │ │ ├── mod.rs │ │ ├── sos.rs │ │ └── zpk.rs │ ├── signal_tools.rs │ └── tools │ │ └── mod.rs ├── special │ ├── kv.rs │ ├── mod.rs │ └── trig.rs └── tools │ ├── complex.rs │ └── mod.rs └── tests ├── optimize ├── least_square.rs ├── main.rs ├── min_scaler.rs ├── polynomial.rs └── root_scalar.rs ├── signal ├── bessel.rs ├── butter.rs ├── cheby1.rs ├── cheby2.rs ├── common.rs ├── fir_filter_design.rs ├── fir_filter_design_windows.rs ├── lp2bf_zpk.rs ├── main.rs └── signal_tools.rs └── special.rs /.github/codecov.yml: -------------------------------------------------------------------------------- 1 | # ref: https://docs.codecov.com/docs/codecovyml-reference 2 | coverage: 3 | # Hold ourselves to a high bar 4 | range: 85..100 5 | round: down 6 | precision: 1 7 | status: 8 | # ref: https://docs.codecov.com/docs/commit-status 9 | project: 10 | default: 11 | # Avoid false negatives 12 | threshold: 1% 13 | 14 | # Test files aren't important for coverage 15 | ignore: 16 | - "tests" 17 | 18 | # Make comments less noisy 19 | 
comment: 20 | layout: "files" 21 | require_changes: yes 22 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: github-actions 4 | directory: / 5 | schedule: 6 | interval: daily 7 | - package-ecosystem: cargo 8 | directory: / 9 | schedule: 10 | interval: daily 11 | ignore: 12 | - dependency-name: "*" 13 | # patch and minor updates don't matter for libraries 14 | # remove this ignore rule if your package has binaries 15 | update-types: 16 | - "version-update:semver-patch" 17 | - "version-update:semver-minor" 18 | -------------------------------------------------------------------------------- /.github/workflows/check.yml: -------------------------------------------------------------------------------- 1 | permissions: 2 | contents: read 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | # Spend CI time only on latest ref: https://github.com/jonhoo/rust-ci-conf/pull/5 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 10 | cancel-in-progress: true 11 | name: check 12 | jobs: 13 | doc: 14 | name: Documentation 15 | runs-on: ubuntu-latest 16 | env: 17 | RUSTDOCFLAGS: -Dwarnings 18 | steps: 19 | - uses: actions/checkout@v4 20 | - uses: dtolnay/rust-toolchain@nightly 21 | - uses: dtolnay/install@cargo-docs-rs 22 | - name: install fortran compiler 23 | run: sudo apt install -y gfortran 24 | - run: cargo docs-rs 25 | fmt: 26 | runs-on: ubuntu-latest 27 | name: stable / fmt 28 | steps: 29 | - uses: actions/checkout@v4 30 | with: 31 | submodules: true 32 | - name: Install stable 33 | uses: dtolnay/rust-toolchain@stable 34 | with: 35 | components: rustfmt 36 | - name: install fortran compiler 37 | run: sudo apt install -y gfortran 38 | - name: cargo fmt --check 39 | run: cargo fmt --check 40 | clippy: 41 | runs-on: ubuntu-latest 42 | name: 
${{ matrix.toolchain }} / clippy 43 | permissions: 44 | contents: read 45 | checks: write 46 | strategy: 47 | fail-fast: false 48 | matrix: 49 | toolchain: [stable, beta] 50 | steps: 51 | - uses: actions/checkout@v4 52 | with: 53 | submodules: true 54 | - name: Install ${{ matrix.toolchain }} 55 | uses: dtolnay/rust-toolchain@master 56 | with: 57 | toolchain: ${{ matrix.toolchain }} 58 | components: clippy 59 | - name: install fortran compiler 60 | run: sudo apt install -y gfortran 61 | - name: cargo clippy 62 | uses: actions-rs/clippy-check@v1 63 | with: 64 | token: ${{ secrets.GITHUB_TOKEN }} 65 | hack: 66 | runs-on: ubuntu-latest 67 | name: ubuntu / stable / features 68 | steps: 69 | - uses: actions/checkout@v4 70 | with: 71 | submodules: true 72 | - name: Install stable 73 | uses: dtolnay/rust-toolchain@stable 74 | - name: install fortran compiler 75 | run: sudo apt install -y gfortran 76 | - name: cargo install cargo-hack 77 | uses: taiki-e/install-action@cargo-hack 78 | # intentionally no target specifier; see https://github.com/jonhoo/rust-ci-conf/pull/4 79 | - name: cargo hack 80 | run: cargo hack --feature-powerset check 81 | msrv: 82 | runs-on: ubuntu-latest 83 | # we use a matrix here just because env can't be used in job names 84 | # https://docs.github.com/en/actions/learn-github-actions/contexts#context-availability 85 | strategy: 86 | matrix: 87 | msrv: ["1.64.0"] # 2021 edition requires 1.56 88 | name: ubuntu / ${{ matrix.msrv }} 89 | steps: 90 | - uses: actions/checkout@v4 91 | with: 92 | submodules: true 93 | - name: Install ${{ matrix.msrv }} 94 | uses: dtolnay/rust-toolchain@master 95 | with: 96 | toolchain: ${{ matrix.msrv }} 97 | - name: install fortran compiler 98 | run: sudo apt install -y gfortran 99 | - name: cargo +${{ matrix.msrv }} check 100 | run: cargo check 101 | -------------------------------------------------------------------------------- /.github/workflows/scheduled.yml: 
-------------------------------------------------------------------------------- 1 | permissions: 2 | contents: read 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | schedule: 8 | - cron: '7 7 * * *' 9 | # Spend CI time only on latest ref: https://github.com/jonhoo/rust-ci-conf/pull/5 10 | concurrency: 11 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 12 | cancel-in-progress: true 13 | name: rolling 14 | jobs: 15 | # https://twitter.com/mycoliza/status/1571295690063753218 16 | nightly: 17 | runs-on: ubuntu-latest 18 | name: ubuntu / nightly 19 | steps: 20 | - uses: actions/checkout@v4 21 | with: 22 | submodules: true 23 | - name: Install nightly 24 | uses: dtolnay/rust-toolchain@nightly 25 | - name: install fortran compiler 26 | run: sudo apt install -y gfortran 27 | - name: Install python modules 28 | run: pip3 install scipy pandas numpy 29 | - name: cargo generate-lockfile 30 | if: hashFiles('Cargo.lock') == '' 31 | run: cargo generate-lockfile 32 | - name: cargo test --locked 33 | run: cargo test --locked --all-features --all-targets 34 | # https://twitter.com/alcuadrado/status/1571291687837732873 35 | update: 36 | runs-on: ubuntu-latest 37 | name: ubuntu / beta / updated 38 | # There's no point running this if no Cargo.lock was checked in in the 39 | # first place, since we'd just redo what happened in the regular test job. 40 | # Unfortunately, hashFiles only works in if on steps, so we reepeat it. 
41 | # if: hashFiles('Cargo.lock') != '' 42 | steps: 43 | - uses: actions/checkout@v4 44 | with: 45 | submodules: true 46 | - name: Install beta 47 | if: hashFiles('Cargo.lock') != '' 48 | uses: dtolnay/rust-toolchain@beta 49 | - name: install fortran compiler 50 | run: sudo apt install -y gfortran 51 | - name: Install python modules 52 | run: pip3 install scipy pandas numpy 53 | - name: cargo update 54 | if: hashFiles('Cargo.lock') != '' 55 | run: cargo update 56 | - name: cargo test 57 | if: hashFiles('Cargo.lock') != '' 58 | run: cargo test --locked --all-features --all-targets 59 | env: 60 | RUSTFLAGS: -D deprecated 61 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | permissions: 2 | contents: read 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | # Spend CI time only on latest ref: https://github.com/jonhoo/rust-ci-conf/pull/5 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 10 | cancel-in-progress: true 11 | name: test 12 | jobs: 13 | required: 14 | runs-on: ubuntu-latest 15 | name: ubuntu / ${{ matrix.toolchain }} 16 | strategy: 17 | matrix: 18 | toolchain: [stable, beta] 19 | steps: 20 | - uses: actions/checkout@v4 21 | with: 22 | submodules: true 23 | - name: Install ${{ matrix.toolchain }} 24 | uses: dtolnay/rust-toolchain@master 25 | with: 26 | toolchain: ${{ matrix.toolchain }} 27 | - name: install fortran compiler 28 | run: sudo apt install -y gfortran 29 | - name: Install python modules 30 | run: pip3 install scipy pandas numpy 31 | - name: cargo generate-lockfile 32 | if: hashFiles('Cargo.lock') == '' 33 | run: cargo generate-lockfile 34 | # https://twitter.com/jonhoo/status/1571290371124260865 35 | - name: cargo test --locked 36 | run: cargo test --locked --all-features --all-targets 37 | # https://github.com/rust-lang/cargo/issues/6669 38 | - name: 
cargo test --doc 39 | run: cargo test --locked --all-features --doc 40 | coverage: 41 | runs-on: ubuntu-latest 42 | name: ubuntu / stable / coverage 43 | steps: 44 | - uses: actions/checkout@v4 45 | with: 46 | submodules: true 47 | - name: Install stable 48 | uses: dtolnay/rust-toolchain@stable 49 | with: 50 | components: llvm-tools-preview 51 | - name: cargo install cargo-llvm-cov 52 | uses: taiki-e/install-action@cargo-llvm-cov 53 | - name: cargo generate-lockfile 54 | if: hashFiles('Cargo.lock') == '' 55 | run: cargo generate-lockfile 56 | - name: install fortran compiler 57 | run: sudo apt install -y gfortran 58 | - name: Install python modules 59 | run: pip3 install scipy pandas numpy 60 | - name: cargo llvm-cov 61 | run: cargo llvm-cov --locked --all-features --lcov --output-path lcov.info 62 | - name: Upload to codecov.io 63 | uses: codecov/codecov-action@v3 64 | with: 65 | fail_ci_if_error: true 66 | env: 67 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 68 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | /complex_bessel 4 | 5 | **/.DS_Store 6 | README.tpl 7 | publish_version.sh 8 | 9 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "sciport-rs" 3 | version = "0.0.3" 4 | edition = "2021" 5 | description = "Rust port of scipy" 6 | authors = ["Christian Belloni"] 7 | readme = "README.md" 8 | license = "MIT" 9 | keywords = ["math", "science", "filter", "dsp"] 10 | categories = ["mathematics", "science", "algorithms"] 11 | homepage = "https://github.com/ChristianBelloni/sciport-rs" 12 | repository = "https://github.com/ChristianBelloni/sciport-rs" 13 | 14 | [dependencies] 15 | num = "0.4" 16 | complex-bessel-rs = { version = "1.2.0" } 17 | ndarray = { 
version = "0.15.6", features = ["rayon"] } 18 | itertools = "0.13.0" 19 | nalgebra = "0.32.3" 20 | approx = { version = "0.5", features = ["num-complex"] } 21 | thiserror = { version = "1.0" } 22 | #blas-src = { version = "0.9.0", default-features = false, features = ["accelerate"], optional = true } 23 | 24 | [dev-dependencies] 25 | pyo3 = { version = "0.21.2", features = ["full", "auto-initialize"] } 26 | numpy = "0.21.0" 27 | ndarray = { version = "0.15.6", features = ["rayon", "approx-0_5"] } 28 | ndarray-rand = "0.14" 29 | rand = "0.8.5" 30 | lazy_static = "1.4.0" 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Christian Belloni 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sciport-rs 2 | 3 | ## Sciport-rs 4 | 5 | Sciport is a collection of mathematical algorithms ported from the popular Python package SciPy 6 | 7 | ## API design 8 | 9 | The main philosophy behind sciport is to change the API surface of SciPy to better utilize the 10 | rich Rust type system: when deciding between keeping the original function signature and 11 | rewriting it to better represent the valid input space, more often than not we'll decide to 12 | change it.
13 | for example this is the scipy butter filter api: 14 | 15 | ```python 16 | scipy.signal.butter(N: int, Wn: array_like, btype: String, analog: bool, output: String, fs: 17 | float) 18 | ``` 19 | 20 | Wn represents a single or a pair of frequencies and btype is the type of filter, 21 | however, a single frequency makes sense only for a subset of btypes and so does a pair, 22 | in our implementation we rewrite this function like: 23 | 24 | ```rust 25 | fn filter(order: u32, band_filter: BandFilter, analog: Sampling) { } 26 | ``` 27 | 28 | where T represents the output representation of the filter (Zpk, Ba, Sos), band_filter 29 | encapsulates the original Wn and btype like this: 30 | 31 | ```rust 32 | 33 | pub enum BandFilter { 34 | Highpass(f64), 35 | Lowpass(f64), 36 | Bandpass { low: f64, high: f64 }, 37 | Bandstop { low: f64, high: f64 }, 38 | } 39 | ``` 40 | 41 | and Sampling encapsulates analog and fs (since a sampling rate makes sense only when talking 42 | about a digital filter) like this: 43 | 44 | ```rust 45 | pub enum Sampling { 46 | Analog, 47 | Digital { 48 | fs: f64 49 | } 50 | } 51 | ``` 52 | 53 | ## Modules 54 | 55 | ### Signal Processing 56 | 57 | The signal processing toolbox currently contains some filtering functions, a limited set of filter design tools, and a few B-spline interpolation algorithms for 1- and 2-D data. While the B-spline algorithms could technically be placed under the interpolation category, they are included here because they only work with equally-spaced data and make heavy use of filter-theory and transfer-function formalism to provide a fast B-spline transform. 58 | 59 | ### Special 60 | 61 | The main feature of this module is the definition of numerous special functions 62 | of mathematical physics. Available functions include airy, elliptic, bessel, gamma, beta, 63 | hypergeometric, parabolic cylinder, mathieu, spheroidal wave, struve, and kelvin. 
64 | 65 | 66 | License: MIT 67 | -------------------------------------------------------------------------------- /ROADMAP.md: -------------------------------------------------------------------------------- 1 | # Scipy features: 2 | 3 | ## scipy.signal 4 | 5 | ### Convolution 6 | - [ ] convolve 7 | - [ ] correlate 8 | - [ ] fftconvolve 9 | - [ ] oaconvolve 10 | - [ ] convolve2d 11 | - [ ] correlate2d 12 | - [ ] sepfir2d 13 | - [ ] choose_conv_method 14 | - [ ] correlation_lags 15 | 16 | ### B-splines 17 | - [ ] gauss_spline 18 | - [ ] cspline1d 19 | - [ ] qspline1d 20 | - [ ] cspline2d 21 | - [ ] qspline2d 22 | - [ ] cspline1d_eval 23 | - [ ] qspline1d_eval 24 | - [ ] spline_filter 25 | 26 | ### Filtering 27 | TODO! 28 | -------------------------------------------------------------------------------- /build.rs: -------------------------------------------------------------------------------- 1 | //! Build script. For macOS *targets* it adds the CommandLineTools frameworks directory to the 2 | //! rpath and, when the `blas` feature is enabled, links the Accelerate framework. 3 | pub fn main() { 4 | // `#[cfg(target_os = ...)]` inside a build script tests the *host* that runs the script, 5 | // not the target being compiled (they differ when cross-compiling); Cargo exposes the 6 | // real target through CARGO_CFG_TARGET_OS. 7 | let target_is_macos = matches!(std::env::var("CARGO_CFG_TARGET_OS").as_deref(), Ok("macos")); 8 | if !target_is_macos { 9 | return; 10 | } 11 | // Cargo sets CARGO_FEATURE_BLAS when this package's `blas` feature is enabled. 12 | if std::env::var_os("CARGO_FEATURE_BLAS").is_some() { 13 | println!("cargo:rustc-link-lib=framework=Accelerate"); 14 | } 15 | println!( 16 | "cargo:rustc-link-arg=-Wl,-rpath,/Library/Developer/CommandLineTools/Library/Frameworks" 17 | ); 18 | } 19 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![warn(clippy::all, clippy::nursery)] 2 | 3 | //! # Sciport-rs 4 | //! 5 | //! Sciport is a collection of mathematical algorithms ported from the popular Python package SciPy 6 | //! 7 | //! # API design 8 | //! 9 | //! The main philosophy behind sciport is to change the API surface of SciPy to better utilize the 10 | //! rich Rust type system: when deciding between keeping the original function signature and 11 | //! rewriting it to better represent the valid input space, more often than not we'll decide to 12 | //! change it.
13 | //! for example this is the scipy butter filter api: 14 | //! 15 | //! ```python 16 | //! scipy.signal.butter(N: int, Wn: array_like, btype: String, analog: bool, output: String, fs: 17 | //! float) 18 | //! ``` 19 | //! 20 | //! Wn represents a single or a pair of frequencies and btype is the type of filter, 21 | //! however, a single frequency makes sense only for a subset of btypes and so does a pair, 22 | //! in our implementation we rewrite this function like: 23 | //! 24 | //! ``` 25 | //! # use sciport_rs::signal::Sampling; 26 | //! # use sciport_rs::signal::band_filter::BandFilter; 27 | //! fn filter(order: u32, band_filter: BandFilter, analog: Sampling) { } 28 | //! ``` 29 | //! 30 | //! where T represents the output representation of the filter (Zpk, Ba, Sos), band_filter 31 | //! encapsulates the original Wn and btype like this: 32 | //! 33 | //! ``` 34 | //! 35 | //! pub enum BandFilter { 36 | //! Highpass(f64), 37 | //! Lowpass(f64), 38 | //! Bandpass { low: f64, high: f64 }, 39 | //! Bandstop { low: f64, high: f64 }, 40 | //! } 41 | //! ``` 42 | //! 43 | //! and Sampling encapsulates analog and fs (since a sampling rate makes sense only when talking 44 | //! about a digital filter) like this: 45 | //! 46 | //! ``` 47 | //! pub enum Sampling { 48 | //! Analog, 49 | //! Digital { 50 | //! fs: f64 51 | //! } 52 | //! } 53 | //! ``` 54 | //! 55 | //! # Modules 56 | //! 57 | //! ## Signal Processing 58 | //! 59 | //! The signal processing toolbox currently contains some filtering functions, a limited set of filter design tools, and a few B-spline interpolation algorithms for 1- and 2-D data. While the B-spline algorithms could technically be placed under the interpolation category, they are included here because they only work with equally-spaced data and make heavy use of filter-theory and transfer-function formalism to provide a fast B-spline transform. 60 | //! 61 | //! ## Special 62 | //! 63 | //! 
The main feature of this module is the definition of numerous special functions 64 | //! of mathematical physics. Available functions include airy, elliptic, bessel, gamma, beta, 65 | //! hypergeometric, parabolic cylinder, mathieu, spheroidal wave, struve, and kelvin. 66 | //! 67 | #[allow(unused)] 68 | pub mod odr; 69 | #[allow(unused)] 70 | pub mod optimize; 71 | pub mod signal; 72 | pub mod special; 73 | 74 | pub(crate) mod tools; 75 | -------------------------------------------------------------------------------- /src/odr/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod polynomial; 2 | -------------------------------------------------------------------------------- /src/odr/polynomial.rs: -------------------------------------------------------------------------------- 1 | use itertools::{EitherOrBoth, Itertools}; 2 | use nalgebra::{ComplexField, Scalar}; 3 | use num::complex::ComplexFloat; 4 | use num::Float; 5 | use std::fmt::{Debug, Display}; 6 | use std::ops::{Add, Div, Index, Mul, Sub}; 7 | use std::rc::Rc; 8 | 9 | use crate::optimize::least_square; 10 | use crate::optimize::root_scalar::polynomial::{polynomial_roots, IntoComplex}; 11 | use crate::optimize::util::Espilon; 12 | use crate::optimize::{IntoMetric, Metric}; 13 | 14 | /// # `PolynomialCoef` 15 | /// coefficient type of a `Polynomial`; its only extra method, `coef_to_string`, is used by `Display` 16 | pub trait PolynomialCoef: ComplexFloat + Clone { 17 | fn coef_to_string(&self) -> String; 18 | } 19 | 20 | /// # `Polynomial` 21 | /// a polynomial represented by a `Vec` of coefficients, 22 | /// which can be `f32`,`f64`,`Complex32`, `Complex64` 23 | /// 24 | /// where the i-th element is the coefficient of `x^i` (lowest power first) 25 | #[derive(Debug, Clone)] 26 | pub struct Polynomial 27 | where 28 | T: PolynomialCoef, 29 | { 30 | coef: Vec, 31 | } 32 | 33 | impl Polynomial 34 | where 35 | T: PolynomialCoef, 36 | { 37 | /// drop trailing zero coefficients: while the highest-power coefficient is zero,
38 | /// pop it off the coefficient vector 39 | /// 40 | /// if no coefficient is left (the zero polynomial), 41 | /// push a single zero so the polynomial is the constant `0` 42 | pub fn saturate(mut self) -> Self { 43 | while let Some(&c) = self.coef.last() { 44 | if c == T::zero() { 45 | self.coef.pop(); 46 | } else { 47 | break; 48 | } 49 | } 50 | if self.coef.is_empty() { // not `self.degree() == 0`: that also holds for a single nonzero constant and would append a spurious trailing zero 51 | self.coef.push(T::zero()); 52 | } 53 | self 54 | } 55 | /// return the degree of the polynomial, 56 | /// i.e. the highest power with a stored coefficient (0 for a constant, including the zero polynomial) 57 | pub fn degree(&self) -> usize { 58 | self.coef.len().saturating_sub(1) 59 | } 60 | /// iterate over the coefficients, starting from the 0-th power 61 | pub fn iter(&self) -> std::slice::Iter<'_, T> { 62 | self.coef.iter() 63 | } 64 | /// construct a `Polynomial` from a `Vec` of coefficients (lowest power first), trimming trailing zeros 65 | pub fn from_vec(coef: Vec) -> Self { 66 | Self { coef }.saturate() 67 | } 68 | /// return new polynomial with only a zero constant 69 | pub fn zero() -> Self { 70 | Self::from_vec(vec![T::zero()]) 71 | } 72 | /// return new polynomial with only a one constant 73 | pub fn one() -> Self { 74 | Self::from_vec(vec![T::one()]) 75 | } 76 | /// evaluate the polynomial at `x` (Horner's scheme) 77 | pub fn eval(&self, x: T) -> T { 78 | self.iter().rev().fold(T::zero(), |acc, &c| acc * x + c) 79 | } 80 | /// evaluate the polynomial at `x` for `x` in `xs` 81 | pub fn eval_iter(&self, xs: impl IntoIterator) -> Vec { 82 | xs.into_iter().map(|x| self.eval(x)).collect() 83 | } 84 | /// return the polynomial multiplied by `x^p` (coefficients shifted up by `p` powers) 85 | pub fn mul_power(&self, p: usize) -> Self { 86 | vec![T::zero(); p].iter().chain(self.iter()).collect() 87 | } 88 | /// return the derivative polynomial 89 | pub fn differentiate(&self) -> Self { 90 | self.iter() 91 | .enumerate() 92 | .filter_map(|(i, &c)| { 93 | if i == 0 { 94 | None 95 | } else { 96 | Some(T::from(i as i64).unwrap() * c) 97 | } 98 | }) 99 | .collect() 100 | } 101 | /// construct `k * (x - r_1) * (x - r_2) * ...` from the given roots and the scale factor `k`
102 | pub fn from_roots_k(roots: impl IntoIterator, k: T) -> Self { 103 | roots 104 | .into_iter() 105 | .map(|r| Self::from(vec![-r, T::one()])) 106 | .fold(Self::one(), |acc, p| acc * p) 107 | * k 108 | } 109 | /// divide the polynomial by the linear factor `(z - x)` with Horner's method (synthetic division) 110 | /// 111 | /// returns the quotient polynomial and the remainder scalar; by the remainder theorem 112 | /// the remainder equals `self.eval(x)`, so it is (numerically close to) zero when `x` is a root 113 | /// (`None` is only returned for an empty coefficient list, which `saturate` never produces) 114 | pub fn deflate(&self, x: T) -> Option<(Self, T)> { 115 | let result = self 116 | .iter() 117 | .rev() 118 | .scan(T::zero(), |carry, &coef| { 119 | let new_coef = coef + *carry; 120 | *carry = new_coef * x; 121 | Some(new_coef) 122 | }) 123 | .collect::>(); 124 | let (remainder, quotient) = result.split_last()?; 125 | Some((quotient.iter().rev().collect(), remainder.to_owned())) 126 | } 127 | 128 | /// find all roots of the polynomial, 129 | /// 130 | /// returned as complex numbers of type `C`, 131 | /// i.e. `Complex32` or `Complex64` 132 | #[must_use] 133 | pub fn roots(&self) -> Vec 134 | where 135 | T: PolynomialCoef + Espilon + IntoMetric + IntoComplex, 136 | C: PolynomialCoef + From + Espilon + IntoMetric + IntoComplex, 137 | M: Metric, 138 | { 139 | polynomial_roots(self) 140 | } 141 | /// calculate the polynomial least-squares curve fit on data `x` and `y` 142 | /// see `sciport_rs::optimize::least_square::poly_fit` 143 | pub fn poly_fit<'a, Q>( 144 | x: impl IntoIterator, 145 | y: impl IntoIterator, 146 | order: usize, 147 | ) -> Result 148 | where 149 | T: Debug + Display + ComplexField + PolynomialCoef + Espilon, 150 | Q: ComplexFloat + Scalar + Debug + Espilon, 151 | { 152 | least_square::poly_fit(x, y, order) 153 | } 154 | } 155 | 156 | impl Polynomial 157 | where 158 | T: PolynomialCoef, 159 | { 160 | /// take ownership and wrap `eval` in a reference-counted closure (an `Rc` of `dyn Fn(T) -> T`) 161 | #[must_use] 162 | pub fn as_rc(self) -> Rc T> { 163 | Rc::new(move |x| self.eval(x)) 164 | } 165 | /// take ownership and return `eval` as a plain closure 166 | pub fn as_fn(self) -> impl Fn(T) -> T { 167 | move |x| self.eval(x) 168 | } 169 | } 170 | 171 |
impl Index for Polynomial 172 | where 173 | T: PolynomialCoef, 174 | { 175 | type Output = T; 176 | fn index(&self, index: usize) -> &Self::Output { 177 | &self.coef[index] 178 | } 179 | } 180 | 181 | impl Mul for Polynomial 182 | where 183 | T: PolynomialCoef, 184 | { 185 | type Output = Self; 186 | fn mul(self, rhs: T) -> Self::Output { 187 | self.iter().map(|&c| c * rhs).collect() 188 | } 189 | } 190 | 191 | #[allow(clippy::suspicious_arithmetic_impl)] 192 | impl Div for Polynomial 193 | where 194 | T: PolynomialCoef, 195 | { 196 | type Output = Self; 197 | fn div(self, rhs: T) -> Self::Output { 198 | self * rhs.recip() 199 | } 200 | } 201 | 202 | impl Add for Polynomial 203 | where 204 | T: PolynomialCoef, 205 | { 206 | type Output = Self; 207 | fn add(self, rhs: Self) -> Self::Output { 208 | self.iter() 209 | .zip_longest(rhs.iter()) 210 | .map(|pair| match pair { 211 | EitherOrBoth::Both(&a, &b) => a + b, 212 | EitherOrBoth::Left(&a) => a, 213 | EitherOrBoth::Right(&b) => b, 214 | }) 215 | .collect() 216 | } 217 | } 218 | 219 | impl Mul for Polynomial 220 | where 221 | T: PolynomialCoef, 222 | { 223 | type Output = Self; 224 | fn mul(self, rhs: Self) -> Self::Output { 225 | self.iter() 226 | .enumerate() 227 | .map(|(i, &c)| rhs.mul_power(i) * c) 228 | .fold(Self::zero(), |acc, p| acc + p) 229 | } 230 | } 231 | 232 | impl Sub for Polynomial 233 | where 234 | T: PolynomialCoef, 235 | { 236 | type Output = Self; 237 | fn sub(self, rhs: Self) -> Self::Output { 238 | self + rhs * (T::zero() - T::one()) 239 | } 240 | } 241 | 242 | impl Display for Polynomial 243 | where 244 | T: PolynomialCoef, 245 | { 246 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 247 | write!( 248 | f, 249 | "{}", 250 | self.iter() 251 | .enumerate() 252 | .map(|(i, &c)| { 253 | format!( 254 | "{}{}", 255 | c.coef_to_string(), 256 | if i == 0 { 257 | String::new() 258 | } else { 259 | format!(" * x^{i:<3}") 260 | } 261 | ) 262 | }) 263 | .collect::>() 264 | .join(" + 
") 265 | ) 266 | } 267 | } 268 | 269 | impl FromIterator for Polynomial 270 | where 271 | T: PolynomialCoef, 272 | { 273 | fn from_iter>(iter: I) -> Self { 274 | Self { 275 | coef: Vec::from_iter(iter), 276 | } 277 | .saturate() 278 | } 279 | } 280 | 281 | impl<'a, T> FromIterator<&'a T> for Polynomial 282 | where 283 | T: PolynomialCoef + 'a, 284 | { 285 | fn from_iter>(iter: I) -> Self { 286 | iter.into_iter().copied().collect() 287 | } 288 | } 289 | 290 | impl From<&[T]> for Polynomial 291 | where 292 | T: PolynomialCoef, 293 | { 294 | fn from(value: &[T]) -> Self { 295 | value.iter().collect() 296 | } 297 | } 298 | 299 | impl From> for Polynomial 300 | where 301 | T: PolynomialCoef, 302 | { 303 | fn from(value: Vec) -> Self { 304 | Self::from_vec(value) 305 | } 306 | } 307 | 308 | impl IntoIterator for Polynomial 309 | where 310 | T: PolynomialCoef, 311 | { 312 | type Item = T; 313 | type IntoIter = std::vec::IntoIter; 314 | fn into_iter(self) -> Self::IntoIter { 315 | self.coef.into_iter() 316 | } 317 | } 318 | 319 | impl PolynomialCoef for T 320 | where 321 | T: ComplexFloat + Clone, 322 | { 323 | fn coef_to_string(&self) -> String { 324 | "coeff doesn't implement display".into() 325 | } 326 | } 327 | 328 | // impl PolynomialCoef for Complex32 { 329 | // fn coef_to_string(&self) -> String { 330 | // format!("({:9.2} + {:9.2}i)", self.re, self.im) 331 | // } 332 | // } 333 | // impl PolynomialCoef for Complex64 { 334 | // fn coef_to_string(&self) -> String { 335 | // format!("({:9.2} + {:9.2}i)", self.re, self.im) 336 | // } 337 | // } 338 | 339 | // impl PolynomialCoef for Complex 340 | // where 341 | // Complex: ComplexFloat + Clone + Debug + 'static, 342 | // T: std::fmt::Display, 343 | // { 344 | // fn coef_to_string(&self) -> String { 345 | // format!("({:9.2} + {:9.2}i)", self.re, self.im) 346 | // } 347 | // } 348 | -------------------------------------------------------------------------------- /src/optimize/criteria.rs: 
-------------------------------------------------------------------------------- 1 | use crate::optimize::*; 2 | 3 | /// Stopping criteria for iterative optimizers (tolerances, iteration cap and metric types). 4 | #[derive(Debug, Clone)] 5 | pub struct OptimizeCriteria 6 | where 7 | X: IntoMetric, 8 | F: IntoMetric, 9 | M: Metric, 10 | { 11 | /// Satisfies xatol if `|x-x'| < xatol` 12 | pub xatol: Option, 13 | /// Satisfies xrtol if `|x-x'| < xrtol * |x|`, relative to the newest iterate `x` 14 | pub xrtol: Option, 15 | /// Satisfies fatol if `|f-f'| < fatol` 16 | pub fatol: Option, 17 | /// Satisfies frtol if `|f-f'| < frtol * |f|`, relative to the newest value `f` 18 | pub frtol: Option, 19 | /// Satisfies fltol if `|f-target_f| < fltol` 20 | pub fltol: Option, 21 | /// Fail once `iter >= maxiter` 22 | pub maxiter: Option, 23 | /// specify the metric evaluation type for x 24 | pub x_metric_type: MetricType, 25 | /// specify the metric evaluation type for f 26 | pub f_metric_type: MetricType, 27 | } 28 | 29 | /// Default `xatol` 30 | const DEFAULT_XATOL: f64 = 1e-9; 31 | /// Default `xrtol` 32 | const DEFAULT_XRTOL: f64 = 1e-100; 33 | /// Default `fatol` 34 | const DEFAULT_FATOL: f64 = 1e-9; 35 | /// Default `frtol` 36 | const DEFAULT_FRTOL: f64 = 1e-100; 37 | /// Default `fltol` 38 | const DEFAULT_FLTOL: f64 = 1e-9; 39 | /// Default `maxiter` 40 | const DEFAULT_MAXITER: u64 = 1000; 41 | 42 | impl OptimizeCriteria 43 | where 44 | X: IntoMetric, 45 | F: IntoMetric, 46 | M: Metric, 47 | { 48 | /// Builder Pattern for setting `xatol` 49 | pub fn set_xatol(mut self, tol: Option) -> Self { 50 | self.xatol = tol; 51 | self 52 | } 53 | /// Builder Pattern for setting `xrtol` 54 | pub fn set_xrtol(mut self, tol: Option) -> Self { 55 | self.xrtol = tol; 56 | self 57 | } 58 | /// Builder Pattern for setting `fatol` 59 | pub fn set_fatol(mut self, tol: Option) -> Self { 60 | self.fatol = tol; 61 | self 62 | } 63 | /// Builder Pattern for setting `frtol` 64 | pub fn set_frtol(mut self, tol: Option) -> Self { 65 | self.frtol = tol; 66 | self 67 | } 68 | /// Builder Pattern for setting `fltol` 69 | pub fn set_fltol(mut self, tol:
Option) -> Self { 70 | self.fltol = tol; 71 | self 72 | } 73 | /// Builder Pattern for setting `maxiter` 74 | pub fn set_maxiter(mut self, max: Option) -> Self { 75 | self.maxiter = max; 76 | self 77 | } 78 | /// Builder Pattern for setting `x_metric_type` 79 | pub fn set_x_metric_type(mut self, metric_type: MetricType) -> Self { 80 | self.x_metric_type = metric_type; 81 | self 82 | } 83 | /// Builder Pattern for setting `f_metric_type` 84 | pub fn set_f_metric_type(mut self, metric_type: MetricType) -> Self { 85 | self.f_metric_type = metric_type; 86 | self 87 | } 88 | 89 | /// Create a new criteria with no tolerance or `maxiter` set, and `MetricType::L2Norm` for both x and f 90 | pub fn empty() -> Self { 91 | Self { 92 | xatol: None, 93 | xrtol: None, 94 | fatol: None, 95 | frtol: None, 96 | fltol: None, 97 | maxiter: None, 98 | x_metric_type: MetricType::L2Norm, 99 | f_metric_type: MetricType::L2Norm, 100 | } 101 | } 102 | } 103 | /// Defaults: `xatol`, `fatol`, `fltol` = 1e-9; `xrtol`, `frtol` = 1e-100; `maxiter` = 1000; L2 norm for x and f 104 | impl Default for OptimizeCriteria 105 | where 106 | X: IntoMetric, 107 | F: IntoMetric, 108 | M: Metric, 109 | { 110 | fn default() -> Self { 111 | OptimizeCriteria { 112 | xatol: M::from(DEFAULT_XATOL), 113 | xrtol: M::from(DEFAULT_XRTOL), 114 | fatol: M::from(DEFAULT_FATOL), 115 | frtol: M::from(DEFAULT_FRTOL), 116 | fltol: M::from(DEFAULT_FLTOL), 117 | maxiter: Some(DEFAULT_MAXITER), 118 | x_metric_type: MetricType::L2Norm, 119 | f_metric_type: MetricType::L2Norm, 120 | } 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /src/optimize/least_square/mod.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::{Debug, Display}; 2 | 3 | use crate::odr::polynomial::{Polynomial, PolynomialCoef}; 4 | use crate::optimize::util::Espilon; 5 | use nalgebra::{ComplexField, Scalar}; 6 | use num::complex::ComplexFloat; 7 | 8 | /// # poly_fit 9 | /// Calculate the polynomial least-squares fit of data `x` and `y`. 10 | /// 11 | /// `order` defines the
order of the returned polynomial; the fit solves for `order + 1` coefficients. 12 | /// 13 | /// - if `order + 1 < x.len()`, the result is a least-squares curve fit 14 | /// 15 | /// - if `order + 1 == x.len()`, the result is the exact interpolating polynomial 16 | /// 17 | /// - if `order + 1 > x.len()`, the system is underdetermined and the SVD (pseudo-inverse) solution is returned 18 | /// 19 | /// ## Example 20 | /// ``` 21 | /// # use sciport_rs::optimize::least_square::poly_fit; 22 | /// 23 | /// let x = vec![1.0,2.0,3.0]; 24 | /// let y = vec![2.0,1.0,2.0]; 25 | /// 26 | /// let poly = poly_fit(&x,&y,2).unwrap(); 27 | /// ``` 28 | /// 29 | /// ## Errors 30 | /// This function will return an error 31 | /// - if the lengths of `x` and `y` differ, or they are empty. 32 | /// - if the SVD solve fails. 33 | pub fn poly_fit<'a, T, Q>( 34 | x: impl IntoIterator, 35 | y: impl IntoIterator, 36 | order: usize, 37 | ) -> Result, String> 38 | where 39 | T: Debug + Display + ComplexField + PolynomialCoef + Espilon, 40 | Q: ComplexFloat + Scalar + Debug + Espilon, 41 | { 42 | let x = x.into_iter().collect::>(); 43 | let y = y.into_iter().collect::>(); 44 | 45 | if x.len() != y.len() { 46 | return Err(format!( 47 | "lsq_linear failed due to: len of x: {} and y: {} are not equal", 48 | x.len(), 49 | y.len() 50 | )); 51 | } 52 | if x.is_empty() { return Err("lsq_linear failed due to: x and y are empty".to_string()); } 53 | let y = nalgebra::DVector::from_iterator(y.len(), y.into_iter().cloned()); 54 | 55 | let rows = (0..x.len()) 56 | .map(move |i| { 57 | nalgebra::DVector::from_iterator( 58 | order + 1, 59 | (0..(order + 1)).map(|a| ComplexFloat::powi(*x[i], a as i32)), 60 | ) 61 | .transpose() 62 | }) 63 | .collect::>(); 64 | 65 | nalgebra::DMatrix::from_rows(rows.as_slice()) 66 | .svd(true, true) 67 | .solve(&y, Q::epsilon()) 68 | .map(|s| s.into_iter().cloned().collect()) 69 | .map_err(|e| format!("lsq_linear failed due to: {}", e)) 70 | } 71 | -------------------------------------------------------------------------------- /src/optimize/metric.rs: -------------------------------------------------------------------------------- 1 | use
std::fmt::{Debug, Display}; 2 | use std::ops::{Add, Sub}; 3 | use std::rc::Rc; 4 | 5 | use num::complex::{Complex32, Complex64, ComplexFloat}; 6 | use num::traits::FloatConst; 7 | use num::{Complex, Float, NumCast, One, Zero}; 8 | 9 | /// # `Metric` 10 | /// `Metric` is the float type in which optimization errors are measured and compared against the allowed tolerances 11 | /// 12 | /// It is implemented for `f32` and `f64` 13 | pub trait Metric: Float + Sized + Clone + Debug {} 14 | impl Metric for f32 {} 15 | impl Metric for f64 {} 16 | 17 | /// # MetricType 18 | /// Different type of method to measure the norm of a certain type 19 | #[derive(Clone)] 20 | pub enum MetricType 21 | where 22 | T: IntoMetric, 23 | M: Metric, 24 | { 25 | /// sum of `|element|^p` over all elements 26 | PowerSum(M), 27 | /// L1 norm 28 | L1Norm, 29 | /// L2 norm 30 | L2Norm, 31 | /// p norm 32 | PNorm(M), 33 | /// mean square 34 | MS, 35 | /// root mean square, `sqrt(sum |element|^2 / n)` 36 | RMS, 37 | /// custom function 38 | Custom(Rc M>), 39 | } 40 | 41 | impl Debug for MetricType 42 | where 43 | T: IntoMetric, 44 | M: Metric + Debug, 45 | { 46 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 47 | f.write_fmt(format_args!("MetricType::{}", { 48 | match self { 49 | Self::PNorm(p) => format!("PNorm-{:?}", p), 50 | Self::L1Norm => "L1-Norm".to_string(), 51 | Self::L2Norm => "L2-Norm".to_string(), 52 | Self::MS => "MS".to_string(), 53 | Self::RMS => "RMS".to_string(), 54 | Self::PowerSum(p) => format!("PowerSum-{:?}", p), 55 | Self::Custom(_) => "Custom".to_string(), 56 | } 57 | })) 58 | } 59 | } 60 | 61 | /// # `IntoMetric` 62 | /// `IntoMetric` is a trait for evaluating optimization solution. 63 | /// 64 | /// Since optimization solution is not always comparable with the specified tolerance, 65 | /// e.g. the solution with type `Array` and tolerance metric with type `f64` 66 | /// 67 | /// this trait allows every optimization solution to be turned into a `Metric`.
Implemented for: 68 | /// ```ignore 69 | /// f32, f64, Complex32, Complex64, Array1, Array1, Array1, Array1 70 | /// ``` 71 | /// 72 | /// Types with this trait must implement `Sub` (and `Add`), 73 | /// since it is necessary to compare new and old solutions in iterative optimization. 74 | pub trait IntoMetric 75 | where 76 | Self: Sub + Add + Sized + Clone + Debug, 77 | M: Metric, 78 | { 79 | /// return the total number of element in the type for calculating mean. 80 | fn n(&self) -> M; 81 | /// return the sum of powered `p` of element as a metric. 82 | fn power_sum(&self, p: M) -> M; 83 | /// return the L1 norm as a metric. 84 | fn l1_norm(&self) -> M { 85 | self.p_norm(M::one()) 86 | } 87 | /// return the L2 norm as a metric. 88 | fn l2_norm(&self) -> M { 89 | self.p_norm(M::from(2).unwrap()) 90 | } 91 | /// return the p norm as a metric. 92 | fn p_norm(&self, p: M) -> M { 93 | self.power_sum(p).powf(p.recip()) 94 | } 95 | /// return the mean square as a metric. 96 | fn ms(&self) -> M { 97 | self.power_sum(M::from(2).unwrap()) / self.n() 98 | } 99 | /// return the root mean square `sqrt(ms)` as a metric. 100 | fn rms(&self) -> M { 101 | self.ms().sqrt() 102 | } 103 | /// evaluate `self` with the given metric type, as a metric.
104 | fn eval(&self, m: &MetricType) -> M 105 | where 106 | Self: Sized, 107 | { 108 | match m { 109 | MetricType::PowerSum(p) => self.power_sum(*p), 110 | MetricType::L1Norm => self.l1_norm(), 111 | MetricType::L2Norm => self.l2_norm(), 112 | MetricType::PNorm(p) => self.p_norm(*p), 113 | MetricType::MS => self.ms(), 114 | MetricType::RMS => self.rms(), 115 | MetricType::Custom(f) => f(self), 116 | } 117 | } 118 | } 119 | 120 | impl IntoMetric for S 121 | where 122 | S: Float + Sized + Clone, 123 | S: Metric, 124 | { 125 | fn n(&self) -> S { 126 | S::one() 127 | } 128 | 129 | fn power_sum(&self, p: S) -> S { 130 | self.abs().powf(p) 131 | } 132 | } 133 | 134 | impl IntoMetric for Complex 135 | where 136 | S: Float + FloatConst, 137 | S: Metric, 138 | { 139 | fn n(&self) -> S { 140 | S::one() 141 | } 142 | 143 | fn power_sum(&self, p: S) -> S { 144 | S::from(self.abs().powf(p)).unwrap() 145 | } 146 | } 147 | 148 | impl IntoMetric for Array1 149 | where 150 | S: IntoMetric, 151 | M: Metric, 152 | { 153 | fn n(&self) -> M { 154 | M::from(self.len()).unwrap() 155 | } 156 | 157 | fn power_sum(&self, p: M) -> M { 158 | self.fold(M::zero(), |prd, x| prd + x.power_sum(p)) 159 | } 160 | } 161 | 162 | /// A macro for implementing `IntoMetric` for `f32`,`f64`, `Complex32` and `Complex64`, into metric `f32` and `f64` respectively. 163 | macro_rules!
impl_metric_complexfloat { 164 | ($trait_name:ident,$type_name:ident, $t:ty, $m:ty) => { 165 | impl IntoMetric<$m> for $t { 166 | fn n(&self) -> $m { 167 | <$m>::one() 168 | } 169 | fn power_sum(&self, p: $m) -> $m { 170 | (self).abs().powf(p) 171 | } 172 | } 173 | pub type $type_name = MetricType<$t, $m>; 174 | pub trait $trait_name: IntoMetric<$m> {} 175 | }; 176 | } 177 | 178 | // impl_metric_complexfloat!(R32IntoMetric, R32MetricType, f32, f32); 179 | // impl_metric_complexfloat!(R64IntoMetric, R64MetricType, f64, f64); 180 | // impl_metric_complexfloat!(Z32IntoMetric, Z32MetricType, Complex32, f32); 181 | // impl_metric_complexfloat!(Z64IntoMetric, Z64MetricType, Complex64, f64); 182 | 183 | use ndarray::Array1; 184 | 185 | /// A macro for implementing `IntoMetric` for `Array1` with type `f32`,`f64` 186 | /// , `Complex32` and `Complex64`, into metric `f32` and `f64` respectively. 187 | macro_rules! impl_metric_array1 { 188 | ($trait_name:ident,$type_name:ident, $t:ty, $m:tt) => { 189 | impl IntoMetric<$m> for $t { 190 | fn n(&self) -> $m { 191 | self.len() as $m 192 | } 193 | fn power_sum(&self, p: $m) -> $m { 194 | self.fold($m::zero(), |prd, x| prd + x.abs().powf(p)) 195 | } 196 | } 197 | pub type $type_name = MetricType<$t, $m>; 198 | pub trait $trait_name: IntoMetric<$m> {} 199 | }; 200 | } 201 | 202 | // impl_metric_array1!(ArrayR32IntoMetric, Array1R32MetricType, Array1, f32); 203 | // impl_metric_array1!(ArrayR64IntoMetric, Array1R64MetricType, Array1, f64); 204 | // impl_metric_array1!( 205 | // ArrayZ32IntoMetric, 206 | // Array1Z32MetricType, 207 | // Array1, 208 | // f32 209 | // ); 210 | // impl_metric_array1!( 211 | // ArrayZ64IntoMetric, 212 | // Array1Z64MetricType, 213 | // Array1, 214 | // f64 215 | // ); 216 | -------------------------------------------------------------------------------- /src/optimize/min_scalar/golden.rs: -------------------------------------------------------------------------------- 1 | use std::io::Read; 2 |
3 | use crate::optimize::min_scalar::*; 4 | use crate::optimize::util::*; 5 | 6 | const GOLDEN_RATIO: f64 = 0.618_033_988_7; 7 | 8 | pub fn golden_method( 9 | fun: F, 10 | bracket: (R, R), 11 | criteria: Option>, 12 | ) -> OptimizeResult 13 | where 14 | R: IntoMetric + Float + Espilon, 15 | M: Metric, 16 | F: Fn(R) -> R, 17 | { 18 | let evaluator = MinScalarEvaluator::new(criteria); 19 | let evaluator = Rc::new(RefCell::new(evaluator)); 20 | 21 | let fun = { 22 | let evaluator = evaluator.clone(); 23 | move |x| { 24 | evaluator.borrow_mut().res.fev(); 25 | fun(x) 26 | } 27 | }; 28 | 29 | let fun = Box::new(fun); 30 | 31 | let bracket_x = bracket; 32 | let bracket_f = (fun(bracket_x.0), fun(bracket_x.1)); 33 | 34 | let solver = GoldenSolver::new(fun, bracket_x, bracket_f); 35 | 36 | iterative_optimize(solver, evaluator) 37 | } 38 | 39 | pub struct GoldenSolver { 40 | fun: F, 41 | a: R, 42 | x: Option, 43 | y: Option, 44 | b: R, 45 | fa: R, 46 | fx: Option, 47 | fy: Option, 48 | fb: R, 49 | } 50 | 51 | impl GoldenSolver { 52 | fn new(fun: F, bracket_x: (R, R), bracket_f: (R, R)) -> Self { 53 | Self { 54 | fun, 55 | a: bracket_x.0, 56 | x: None, 57 | y: None, 58 | b: bracket_x.1, 59 | fa: bracket_f.0, 60 | fx: None, 61 | fy: None, 62 | fb: bracket_f.1, 63 | } 64 | } 65 | } 66 | 67 | impl IterativeSolver for GoldenSolver 68 | where 69 | R: IntoMetric + Float + Espilon, 70 | M: Metric, 71 | F: Fn(R) -> R, 72 | { 73 | fn new_solution(&mut self) -> (R, R, Option, Option) { 74 | let d = (self.b - self.a) * R::from(GOLDEN_RATIO).unwrap(); 75 | let x = self.x.unwrap_or_else(|| self.b - d); 76 | let fx = self.fx.unwrap_or_else(|| (self.fun)(x)); 77 | 78 | let y = self.y.unwrap_or_else(|| self.a + d); 79 | let fy = self.fy.unwrap_or_else(|| (self.fun)(y)); 80 | 81 | if fy < fx { 82 | self.a = x; 83 | self.fa = fx; 84 | self.x = Some(y); 85 | self.fx = Some(fy); 86 | self.y = None; 87 | self.fy = None; 88 | 89 | (y, fy, None, None) 90 | } else { 91 | self.b = y; 92 | self.fb = 
fy; 93 | self.y = Some(x); 94 | self.fy = Some(fx); 95 | self.x = None; 96 | self.fx = None; 97 | 98 | (x, fx, None, None) 99 | } 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /src/optimize/min_scalar/mod.rs: -------------------------------------------------------------------------------- 1 | use num::Float; 2 | 3 | use crate::optimize::*; 4 | 5 | pub mod golden; 6 | 7 | struct MinScalarEvaluator 8 | where 9 | R: IntoMetric + Float, 10 | M: Metric, 11 | { 12 | res: OptimizeResult, 13 | } 14 | 15 | impl Evaluator for MinScalarEvaluator 16 | where 17 | R: IntoMetric + Float, 18 | M: Metric, 19 | { 20 | fn new(criteria: Option>) -> Self 21 | where 22 | Self: Sized, 23 | { 24 | let res = OptimizeResult::default().set_criteria(criteria.unwrap_or_default()); 25 | Self { res } 26 | } 27 | 28 | fn update(&mut self, new_solution: (R, R, Option, Option)) { 29 | self.res.update(new_solution); 30 | } 31 | 32 | fn eval(&mut self) { 33 | if let (Some(o), Some(s)) = (self.result().old_f, self.result().sol_f) { 34 | if o < s { 35 | self.result_mut().set_failure( 36 | "Minimization fail, function is not strictly unimodal function".to_string(), 37 | ); 38 | return; 39 | } 40 | } 41 | 42 | if self.res.old_x == self.res.sol_x { 43 | return; 44 | } 45 | 46 | if self.result().satisfy_either() { 47 | let satisfied = self.result_mut().satisfied().join(", "); 48 | self.result_mut() 49 | .set_success(format!("Minimization successful, {} satisfied", satisfied)); 50 | } else if self.result().overran() { 51 | let maxiter = self.result_mut().criteria.maxiter; 52 | self.result_mut() 53 | .set_failure(format!("Minimization fail, max iter {:?} reached", maxiter)); 54 | } 55 | } 56 | 57 | fn result(&self) -> &OptimizeResult { 58 | &self.res 59 | } 60 | fn result_mut(&mut self) -> &mut OptimizeResult { 61 | &mut self.res 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /src/optimize/mod.rs: 
-------------------------------------------------------------------------------- 1 | use std::cell::RefCell; 2 | use std::fmt::Display; 3 | use std::rc::Rc; 4 | 5 | pub mod criteria; 6 | pub mod least_square; 7 | pub mod metric; 8 | pub mod min_scalar; 9 | pub mod root_scalar; 10 | pub mod util; 11 | 12 | pub use criteria::*; 13 | pub use least_square::*; 14 | pub use metric::*; 15 | 16 | /// Ask `solver` for new solutions and feed them to `evaluator` until its flag leaves `Running`, then return a clone of the final result. 17 | pub fn iterative_optimize( 18 | mut solver: S, 19 | evaluator: Rc>, 20 | ) -> OptimizeResult 21 | where 22 | X: IntoMetric, 23 | F: IntoMetric, 24 | J: IntoMetric, 25 | H: IntoMetric, 26 | M: Metric, 27 | S: IterativeSolver, 28 | E: Evaluator, 29 | { 30 | while evaluator.borrow().flag().is_running() { 31 | let s = solver.new_solution(); // no `evaluator` borrow may be live here: entry points such as `golden_method` wrap `fun` in closures that call `borrow_mut()` 32 | evaluator.borrow_mut().update(s); 33 | evaluator.borrow_mut().eval(); 34 | } 35 | evaluator.borrow().result().clone() 36 | } 37 | 38 | /// # Iterative Solver 39 | /// An iterative solver that keeps producing new candidate solutions for an `Evaluator`.
40 | pub trait IterativeSolver 41 | where 42 | X: IntoMetric, 43 | F: IntoMetric, 44 | J: IntoMetric, 45 | H: IntoMetric, 46 | M: Metric, 47 | { 48 | fn new_solution(&mut self) -> (X, F, Option, Option); 49 | } 50 | 51 | /// # Evaluator 52 | /// evaluate solution given by `IterativeSolver` 53 | pub trait Evaluator 54 | where 55 | X: IntoMetric, 56 | F: IntoMetric, 57 | J: IntoMetric, 58 | H: IntoMetric, 59 | M: Metric, 60 | { 61 | /// create a new evaluator with a criteria 62 | fn new(criteria: Option>) -> Self 63 | where 64 | Self: Sized; 65 | /// update self from the new solution 66 | fn update(&mut self, new_solution: (X, F, Option, Option)); 67 | /// evaluate the latest solution and set the flag: success if any tolerance is satisfied, failure once `maxiter` is reached 68 | fn eval(&mut self) { 69 | if self.result().satisfy_either() { 70 | let satisfied = self.result_mut().satisfied().join(", "); 71 | self.result_mut() 72 | .set_success(format!("Root finding successful, {} satisfied", satisfied)); 73 | } else if self.result().overran() { 74 | let maxiter = self.result_mut().criteria.maxiter; 75 | self.result_mut() 76 | .set_failure(format!("Root finding fail, max iter {:?} reached", maxiter)); 77 | } 78 | } 79 | /// get the result as reference 80 | fn result(&self) -> &OptimizeResult; 81 | /// get the result as mutable reference 82 | fn result_mut(&mut self) -> &mut OptimizeResult; 83 | /// get the flag 84 | fn flag(&self) -> OptimizeResultFlag { 85 | self.result().flag.clone() 86 | } 87 | } 88 | 89 | /// # OptimizeResultFlag 90 | /// indicate if the result is success, failure or still running 91 | #[derive(Debug, Clone)] 92 | pub enum OptimizeResultFlag { 93 | Running, 94 | Success(String), 95 | Failure(String), 96 | } 97 | 98 | impl OptimizeResultFlag { 99 | pub fn is_running(&self) -> bool { 100 | matches!(self, Self::Running) 101 | } 102 | pub fn is_success(&self) -> bool { 103 | matches!(self, Self::Success(_)) 104 | } 105 | pub fn is_failure(&self) -> bool { 106 | matches!(self, Self::Failure(_)) 107 | } 108 | } 109 | /// Outcome and bookkeeping of an optimization run: criteria, current and previous solutions, evaluation counts, errors and flag. 110 | #[derive(Debug, Clone)] 111 | pub
struct OptimizeResult 112 | where 113 | X: IntoMetric, 114 | F: IntoMetric, 115 | J: IntoMetric, 116 | H: IntoMetric, 117 | M: Metric, 118 | { 119 | pub criteria: OptimizeCriteria, 120 | pub old_x: Option, 121 | pub old_f: Option, 122 | pub sol_x: Option, 123 | pub sol_f: Option, 124 | pub sol_j: Option, 125 | pub sol_h: Option, 126 | pub target_f: Option, 127 | pub iter: u64, 128 | pub nfev: u64, 129 | pub njev: u64, 130 | pub nhev: u64, 131 | pub err_xa: Option, 132 | pub err_xr: Option, 133 | pub err_fa: Option, 134 | pub err_fr: Option, 135 | pub err_fl: Option, 136 | pub flag: OptimizeResultFlag, 137 | } 138 | 139 | impl Default for OptimizeResult 140 | where 141 | X: IntoMetric, 142 | F: IntoMetric, 143 | J: IntoMetric, 144 | H: IntoMetric, 145 | M: Metric, 146 | { 147 | /// A fresh result: no solutions, zeroed counters, default criteria and a `Running` flag 148 | fn default() -> Self { 149 | OptimizeResult { 150 | criteria: OptimizeCriteria::default(), 151 | old_x: None, 152 | old_f: None, 153 | sol_x: None, 154 | sol_f: None, 155 | sol_j: None, 156 | sol_h: None, 157 | target_f: None, 158 | iter: 0, 159 | nfev: 0, 160 | njev: 0, 161 | nhev: 0, 162 | err_xa: None, 163 | err_xr: None, 164 | err_fa: None, 165 | err_fr: None, 166 | err_fl: None, 167 | flag: OptimizeResultFlag::Running, 168 | } 169 | } 170 | } 171 | 172 | impl Display for OptimizeResult 173 | where 174 | X: IntoMetric, 175 | F: IntoMetric, 176 | J: IntoMetric, 177 | H: IntoMetric, 178 | M: Metric, 179 | { 180 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 181 | f.write_fmt(format_args!( 182 | "{}", 183 | [ 184 | "[Optimization Result]".to_string(), 185 | format!(" flag : {:?}", self.flag), 186 | format!(" sol_x : {:?}", self.sol_x), 187 | format!(" sol_f : {:?}", self.sol_f), 188 | format!(" sol_j : {:?}", self.sol_j), 189 | format!(" sol_h : {:?}", self.sol_h), 190 | format!(" iter : {:?}", self.iter), 191 | format!(" target_f : {:?}", self.target_f), 192 | format!(" nfev : {:?}",
self.nfev), 193 | format!(" njev : {:?}", self.njev), 194 | format!(" nhev : {:?}", self.nhev), 195 | " [criteria]".to_string(), 196 | format!(" xatol : {:?}", self.criteria.xatol), 197 | format!(" xrtol : {:?}", self.criteria.xrtol), 198 | format!(" fatol : {:?}", self.criteria.fatol), 199 | format!(" frtol : {:?}", self.criteria.frtol), 200 | format!(" fltol : {:?}", self.criteria.fltol), 201 | format!(" maxiter : {:?}", self.criteria.maxiter), 202 | format!(" x_metric_type : {:?}", self.criteria.x_metric_type), 203 | format!(" f_metric_type : {:?}", self.criteria.f_metric_type), 204 | " [Error]".to_string(), 205 | format!(" err_xa : {:?}", self.err_xa), 206 | format!(" err_xr : {:?}", self.err_xr), 207 | format!(" err_fa : {:?}", self.err_fa), 208 | format!(" err_fr : {:?}", self.err_fr), 209 | format!(" err_fl : {:?}", self.err_fl), 210 | ] 211 | .join("\n") 212 | )) 213 | 214 | } 215 | } 216 | 217 | // Builders, evaluation counters, error computation and tolerance checks 218 | impl OptimizeResult 219 | where 220 | X: IntoMetric, 221 | F: IntoMetric, 222 | J: IntoMetric, 223 | H: IntoMetric, 224 | M: Metric, 225 | { 226 | /// Builder Pattern for setting the criteria 227 | pub fn set_criteria(mut self, criteria: OptimizeCriteria) -> Self { 228 | self.criteria = criteria; 229 | self 230 | } 231 | 232 | /// Builder Pattern for setting the `target_f` 233 | pub fn set_target_f(mut self, target_f: Option) -> Self { 234 | self.target_f = target_f; 235 | self 236 | } 237 | 238 | /// Increment nfev 239 | pub fn fev(&mut self) { 240 | self.nfev += 1; 241 | } 242 | /// Increment njev 243 | pub fn jev(&mut self) { 244 | self.njev += 1; 245 | } 246 | /// Increment nhev 247 | pub fn hev(&mut self) { 248 | self.nhev += 1; 249 | } 250 | 251 | /// Update new solution from solver, increment iter, and calculate error 252 | pub fn update(&mut self, new_solution: (X, F, Option, Option)) { 253 | (self.old_x, self.old_f) = (self.sol_x.clone(), self.sol_f.clone()); 254 | self.sol_x = Some(new_solution.0); 255 | self.sol_f =
Some(new_solution.1); 256 | self.sol_j = new_solution.2; 257 | self.sol_h = new_solution.3; 258 | self.iter += 1; 259 | self.err_xa = self.error_xa(); 260 | self.err_xr = self.error_xr(); 261 | self.err_fa = self.error_fa(); 262 | self.err_fr = self.error_fr(); 263 | self.err_fl = self.error_fl(); 264 | } 265 | 266 | /// `|sol_x - old_x|` under `x_metric_type`; `None` until there are two iterates. Does not store it. 267 | fn error_xa(&self) -> Option { 268 | Some((self.sol_x.clone()? - self.old_x.clone()?).eval(&self.criteria.x_metric_type)) 269 | } 270 | /// `err_xa / |sol_x|` under `x_metric_type`; `None` until there are two iterates. Does not store it. 271 | fn error_xr(&self) -> Option { 272 | Some(self.error_xa()? / self.sol_x.clone()?.eval(&self.criteria.x_metric_type)) 273 | } 274 | /// `|sol_f - old_f|` under `f_metric_type`; `None` until there are two values. Does not store it. 275 | fn error_fa(&self) -> Option { 276 | Some((self.sol_f.clone()? - self.old_f.clone()?).eval(&self.criteria.f_metric_type)) 277 | } 278 | /// `err_fa / |sol_f|` under `f_metric_type`; `None` until there are two values. Does not store it. 279 | fn error_fr(&self) -> Option { 280 | Some(self.error_fa()? / self.sol_f.clone()?.eval(&self.criteria.f_metric_type)) 281 | } 282 | /// `|sol_f - target_f|` under `f_metric_type`; `None` if `target_f` is unset. Does not store it. 283 | fn error_fl(&self) -> Option { 284 | Some((self.sol_f.clone()?
- self.target_f.clone()?).eval(&self.criteria.f_metric_type)) 285 | } 286 | 287 | /// return `true` if `xatol` is set and the latest `err_xa` is below it 288 | pub fn satisfy_xatol(&self) -> bool { 289 | match (self.criteria.xatol, self.err_xa) { 290 | (Some(tol), Some(error)) => error < tol, 291 | _ => false, 292 | } 293 | } 294 | /// return `true` if `xrtol` is set and the latest `err_xr` is below it 295 | pub fn satisfy_xrtol(&self) -> bool { 296 | match (self.criteria.xrtol, self.err_xr) { 297 | (Some(tol), Some(error)) => error < tol, 298 | _ => false, 299 | } 300 | } 301 | /// return `true` if `fatol` is set and the latest `err_fa` is below it 302 | pub fn satisfy_fatol(&self) -> bool { 303 | match (self.criteria.fatol, self.err_fa) { 304 | (Some(tol), Some(error)) => error < tol, 305 | _ => false, 306 | } 307 | } 308 | /// return `true` if `frtol` is set and the latest `err_fr` is below it 309 | pub fn satisfy_frtol(&self) -> bool { 310 | match (self.criteria.frtol, self.err_fr) { 311 | (Some(tol), Some(error)) => error < tol, 312 | _ => false, 313 | } 314 | } 315 | /// return `true` if `fltol` is set and the latest `err_fl` is below it 316 | pub fn satisfy_fltol(&self) -> bool { 317 | match (self.criteria.fltol, self.err_fl) { 318 | (Some(tol), Some(error)) => error < tol, 319 | _ => false, 320 | } 321 | } 322 | 323 | /// `true` if any of the configured tolerances is satisfied 324 | pub fn satisfy_either(&self) -> bool { 325 | self.satisfy_xatol() 326 | || self.satisfy_xrtol() 327 | || self.satisfy_fatol() 328 | || self.satisfy_frtol() 329 | || self.satisfy_fltol() 330 | } 331 | 332 | /// return the names of all tolerances that are currently satisfied 333 | pub fn satisfied(&self) -> Vec { 334 | [ 335 | (self.satisfy_xatol(), "xatol"), 336 | (self.satisfy_xrtol(), "xrtol"), 337 | (self.satisfy_fatol(), "fatol"), 338 | (self.satisfy_frtol(), "frtol"), 339 | (self.satisfy_fltol(), "fltol"), 340 | ] 341 | .iter() 342 | .filter_map(|&(sat, tol)| sat.then(|| tol.to_string())) 343 | .collect::>() 344 | } 345 | 346 | /// return
`true` once `iter >= maxiter`; always `false` when `maxiter` is unset 347 | pub fn overran(&self) -> bool { 348 | if let Some(maxiter) = self.criteria.maxiter { 349 | self.iter >= maxiter 350 | } else { 351 | false 352 | } 353 | } 354 | /// set flag to successful with a message 355 | pub fn set_success(&mut self, msg: String) { 356 | self.flag = OptimizeResultFlag::Success(msg); 357 | } 358 | /// set flag to failure with a message 359 | pub fn set_failure(&mut self, msg: String) { 360 | self.flag = OptimizeResultFlag::Failure(msg); 361 | } 362 | } 363 | -------------------------------------------------------------------------------- /src/optimize/root_scalar/fixed_point.rs: -------------------------------------------------------------------------------- 1 | use crate::optimize::root_scalar::*; 2 | use crate::optimize::util::*; 3 | /// Find a root of `fun` from `x0` by iterating `x <- x - fun(x)` (a fixed point of `x - fun(x)` is a root of `fun`). 4 | pub fn fixed_point_method C, C, M>( 5 | fun: F, 6 | x0: C, 7 | criteria: Option>, 8 | ) -> OptimizeResult 9 | where 10 | C: IntoMetric + ComplexFloat + Espilon, 11 | M: Metric, 12 | { 13 | let evaluator = RootScalarEvaluator::new(criteria); 14 | let evaluator = Rc::new(RefCell::new(evaluator)); 15 | let fun = { 16 | let evaluator = evaluator.clone(); 17 | move |x| { 18 | evaluator.borrow_mut().res.fev(); 19 | fun(x) 20 | } 21 | }; 22 | 23 | let solver = FixedPointSolver::new(fun, x0); 24 | 25 | iterative_optimize(solver, evaluator) 26 | } 27 | /// Solver state for `fixed_point_method`: the current iterate `x0` and `f0 = fun(x0)`. 28 | pub struct FixedPointSolver 29 | where 30 | C: ComplexFloat, 31 | F: Fn(C) -> C, 32 | { 33 | fun: F, 34 | x0: C, 35 | f0: C, 36 | } 37 | 38 | impl FixedPointSolver 39 | where 40 | C: ComplexFloat + Espilon, 41 | F: Fn(C) -> C, 42 | { 43 | fn new(fun: F, x0: C) -> Self { 44 | let f0 = fun(x0); 45 | Self { fun, x0, f0 } 46 | } 47 | } 48 | 49 | impl IterativeSolver for FixedPointSolver 50 | where 51 | C: IntoMetric + ComplexFloat + Espilon, 52 | M: Metric, 53 | F: Fn(C) -> C, 54 | { 55 | fn new_solution(&mut self) -> (C, C, Option, Option) { 56 | self.x0 = self.x0 - self.f0; 57 | self.f0 =
(self.fun)(self.x0); 58 | (self.x0, self.f0, None, None) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /src/optimize/root_scalar/halley.rs: -------------------------------------------------------------------------------- 1 | use std::rc::Rc; 2 | 3 | use crate::optimize::root_scalar::*; 4 | use crate::optimize::util::*; 5 | /// Find a root of `fun` from `x0` with Halley's method, using derivatives from `approx_derivative` / `approx_second_derivative`. 6 | pub fn halley_method_approx( 7 | fun: F, 8 | x0: C, 9 | criteria: Option>, 10 | ) -> OptimizeResult 11 | where 12 | C: IntoMetric + ComplexFloat + Espilon, 13 | M: Metric, 14 | F: Fn(C) -> C, 15 | { 16 | let evaluator = RootScalarEvaluator::new(criteria); 17 | let evaluator = Rc::new(RefCell::new(evaluator)); 18 | 19 | let fun = { 20 | let evaluator = evaluator.clone(); 21 | move |x| { 22 | evaluator.borrow_mut().res.fev(); 23 | fun(x) 24 | } 25 | }; 26 | // Shared via `Rc` so the derivative approximations call the same `fev`-counting wrapper. 27 | let fun = Rc::new(fun); 28 | 29 | let dfun = { 30 | let f = fun.clone(); 31 | 32 | approx_derivative(move |x| f(x)) 33 | }; 34 | 35 | let ddfun = { 36 | let f = fun.clone(); 37 | 38 | approx_second_derivative(move |x| f(x)) 39 | }; 40 | 41 | let fun = move |x| fun(x); 42 | 43 | let solver = NewtonSolver::new(fun, dfun, ddfun, x0); 44 | 45 | iterative_optimize(solver, evaluator) 46 | } 47 | /// Find a root of `fun` from `x0` with Halley's method, given its first (`dfun`) and second (`ddfun`) derivatives. 48 | pub fn halley_method( 49 | fun: F, 50 | dfun: FD1, 51 | ddfun: FD2, 52 | x0: C, 53 | criteria: Option>, 54 | ) -> OptimizeResult 55 | where 56 | C: IntoMetric + ComplexFloat + Espilon, 57 | M: Metric, 58 | F: Fn(C) -> C, 59 | FD1: Fn(C) -> C, 60 | FD2: Fn(C) -> C, 61 | { 62 | let evaluator = RootScalarEvaluator::new(criteria); 63 | let evaluator = Rc::new(RefCell::new(evaluator)); 64 | 65 | let fun = { 66 | let evaluator = evaluator.clone(); 67 | move |x| { 68 | evaluator.borrow_mut().res.fev(); 69 | fun(x) 70 | } 71 | }; 72 | 73 | let dfun = { 74 | let evaluator = evaluator.clone(); 75 | move |x| { 76 | evaluator.borrow_mut().res.jev(); 77 | dfun(x) 78 | } 79 | }; 80 | 81 | let ddfun = { 82 | let evaluator = evaluator.clone(); 83 | move |x| {
84 | evaluator.borrow_mut().res.hev(); 85 | ddfun(x) 86 | } 87 | }; 88 | 89 | let solver = NewtonSolver::new(fun, dfun, ddfun, x0); 90 | 91 | iterative_optimize(solver, evaluator) 92 | } 93 | /// Halley's-method solver state (named `NewtonSolver` in this module): iterate `x0` and `fun`, `dfun`, `ddfun` evaluated there. 94 | pub struct NewtonSolver 95 | where 96 | C: ComplexFloat, 97 | { 98 | fun: F, 99 | dfun: FD1, 100 | ddfun: FD2, 101 | x0: C, 102 | f0: C, 103 | j0: C, 104 | h0: C, 105 | } 106 | 107 | impl NewtonSolver 108 | where 109 | C: ComplexFloat + Espilon, 110 | F: Fn(C) -> C, 111 | FD1: Fn(C) -> C, 112 | FD2: Fn(C) -> C, 113 | { 114 | fn new(fun: F, dfun: FD1, ddfun: FD2, x0: C) -> Self { 115 | let f0 = fun(x0); 116 | let j0 = dfun(x0); 117 | let h0 = ddfun(x0); 118 | 119 | Self { 120 | fun, 121 | dfun, 122 | ddfun, 123 | x0, 124 | f0, 125 | j0, 126 | h0, 127 | } 128 | } 129 | } 130 | 131 | impl IterativeSolver for NewtonSolver 132 | where 133 | C: IntoMetric + ComplexFloat + Espilon, 134 | M: Metric, 135 | F: Fn(C) -> C, 136 | FD1: Fn(C) -> C, 137 | FD2: Fn(C) -> C, 138 | { 139 | fn new_solution(&mut self) -> (C, C, Option, Option) { 140 | self.x0 = self.x0 // Halley step x - 2ff' / (2f'^2 - ff''), with epsilon added to the denominator 141 | - (C::from(2.0).unwrap() * self.f0 * self.j0) 142 | / (C::from(2.0).unwrap() * self.j0.powi(2) - self.f0 * self.h0 + C::epsilon()); 143 | self.f0 = (self.fun)(self.x0); 144 | self.j0 = (self.dfun)(self.x0); 145 | self.h0 = (self.ddfun)(self.x0); 146 | 147 | (self.x0, self.f0, Some(self.j0), Some(self.h0)) 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /src/optimize/root_scalar/mod.rs: -------------------------------------------------------------------------------- 1 | use std::cell::RefCell; 2 | use std::rc::Rc; 3 | 4 | use num::{complex::ComplexFloat, Float}; 5 | 6 | use crate::optimize::*; 7 | 8 | pub mod bracket; 9 | pub mod fixed_point; 10 | pub mod halley; 11 | pub mod newton; 12 | pub mod polynomial; 13 | pub mod secant; 14 | 15 | pub use bracket::solve_from_bracket; 16 | pub use fixed_point::fixed_point_method; 17 | pub use halley::{halley_method,
halley_method_approx}; 18 | pub use newton::{newton_method, newton_method_approx}; 19 | pub use polynomial::polynomial_roots; 20 | pub use secant::secant_method; 21 | /// Root-finding evaluator: default `Evaluator::eval`, with `target_f` fixed to zero so `fltol` bounds `|f(x)|`. 22 | struct RootScalarEvaluator 23 | where 24 | C: IntoMetric + ComplexFloat, 25 | M: Metric, 26 | { 27 | res: OptimizeResult, 28 | } 29 | 30 | impl Evaluator for RootScalarEvaluator 31 | where 32 | C: IntoMetric + ComplexFloat, 33 | M: Metric, 34 | { 35 | fn new(criteria: Option>) -> Self 36 | where 37 | Self: Sized, 38 | { 39 | let res = OptimizeResult::default() 40 | .set_criteria(criteria.unwrap_or_default()) 41 | .set_target_f(Some(C::zero())); 42 | Self { res } 43 | } 44 | fn update(&mut self, new_solution: (C, C, Option, Option)) { 45 | self.res.update(new_solution); 46 | } 47 | fn result(&self) -> &OptimizeResult { 48 | &self.res 49 | } 50 | fn result_mut(&mut self) -> &mut OptimizeResult { 51 | &mut self.res 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/optimize/root_scalar/newton.rs: -------------------------------------------------------------------------------- 1 | use crate::optimize::root_scalar::*; 2 | use crate::optimize::util::*; 3 | /// Find a root of `fun` from `x0` with Newton's method, using the derivative from `approx_derivative`. 4 | pub fn newton_method_approx( 5 | fun: F, 6 | x0: C, 7 | criteria: Option>, 8 | ) -> OptimizeResult 9 | where 10 | C: IntoMetric + ComplexFloat + Espilon, 11 | M: Metric, 12 | F: Fn(C) -> C, 13 | { 14 | let evaluator = RootScalarEvaluator::new(criteria); 15 | let evaluator = Rc::new(RefCell::new(evaluator)); 16 | 17 | let fun = { 18 | let evaluator = evaluator.clone(); 19 | move |x| { 20 | evaluator.borrow_mut().res.fev(); 21 | fun(x) 22 | } 23 | }; 24 | 25 | let fun = Rc::new(fun); 26 | 27 | let dfun = { 28 | // Calls made by `approx_derivative` go through the `fev`-counting wrapper, so they count toward `nfev`. 29 | let f = fun.clone(); 30 | 31 | approx_derivative(move |x| f(x)) 32 | }; 33 | 34 | let fun = move |x| fun(x); 35 | 36 | let solver = NewtonSolver::new(fun, dfun, x0); 37 | 38 | iterative_optimize(solver, evaluator) 39 | } 40 | /// Find a root of `fun` from `x0` with Newton's method, given its derivative `dfun`. 41 | pub fn
newton_method( 42 | fun: F, 43 | dfun: FD, 44 | x0: C, 45 | criteria: Option>, 46 | ) -> OptimizeResult 47 | where 48 | C: IntoMetric + ComplexFloat + Espilon, 49 | M: Metric, 50 | F: Fn(C) -> C, 51 | FD: Fn(C) -> C, 52 | { 53 | let evaluator = RootScalarEvaluator::new(criteria); 54 | let evaluator = Rc::new(RefCell::new(evaluator)); 55 | 56 | let fun = { 57 | let evaluator = evaluator.clone(); 58 | move |x| { 59 | evaluator.borrow_mut().res.fev(); 60 | fun(x) 61 | } 62 | }; 63 | 64 | let dfun = { 65 | let evaluator = evaluator.clone(); 66 | move |x| { 67 | evaluator.borrow_mut().res.jev(); 68 | dfun(x) 69 | } 70 | }; 71 | 72 | let solver = NewtonSolver::new(fun, dfun, x0); 73 | 74 | iterative_optimize(solver, evaluator) 75 | } 76 | /// Newton's-method solver state: the iterate `x0` with `f0 = fun(x0)` and `j0 = dfun(x0)`. 77 | pub struct NewtonSolver 78 | where 79 | C: ComplexFloat + Espilon, 80 | { 81 | fun: F, 82 | dfun: FD, 83 | x0: C, 84 | f0: C, 85 | j0: C, 86 | } 87 | 88 | impl NewtonSolver 89 | where 90 | C: ComplexFloat + Espilon, 91 | F: Fn(C) -> C, 92 | FD: Fn(C) -> C, 93 | { 94 | fn new(fun: F, dfun: FD, x0: C) -> Self { 95 | let f0 = fun(x0); 96 | let j0 = dfun(x0); 97 | 98 | Self { 99 | fun, 100 | dfun, 101 | x0, 102 | f0, 103 | j0, 104 | } 105 | } 106 | } 107 | 108 | impl IterativeSolver for NewtonSolver 109 | where 110 | C: IntoMetric + ComplexFloat + Espilon, 111 | M: Metric, 112 | F: Fn(C) -> C, 113 | FD: Fn(C) -> C, 114 | { 115 | fn new_solution(&mut self) -> (C, C, Option, Option) { 116 | self.x0 = self.x0 - (self.f0) / (self.j0 + C::epsilon()); // epsilon keeps the step finite when `j0` is exactly zero 117 | self.f0 = (self.fun)(self.x0); 118 | self.j0 = (self.dfun)(self.x0); 119 | 120 | (self.x0, self.f0, Some(self.j0), None) 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /src/optimize/root_scalar/polynomial.rs: -------------------------------------------------------------------------------- 1 | use
crate::optimize::util::Espilon; 4 | use num::complex::{Complex32, Complex64, ComplexFloat}; 5 | /// Lift a real or complex scalar into the complex type `C`. 6 | pub trait IntoComplex: ComplexFloat { 7 | fn as_complex(&self) -> C; 8 | } 9 | impl IntoComplex for f32 { 10 | fn as_complex(&self) -> Complex32 { 11 | Complex32::new(*self, 0.0) 12 | } 13 | } 14 | impl IntoComplex for f64 { 15 | fn as_complex(&self) -> Complex64 { 16 | Complex64::new(*self, 0.0) 17 | } 18 | } 19 | impl IntoComplex for Complex32 { 20 | fn as_complex(&self) -> Complex32 { 21 | *self 22 | } 23 | } 24 | impl IntoComplex for Complex64 { 25 | fn as_complex(&self) -> Complex64 { 26 | *self 27 | } 28 | } 29 | /// Solve `a x^2 + b x + c = 0` with the quadratic formula. 30 | /// 31 | /// for `f32`, `f64`, it returns `None` if the discriminant is negative (its square root is NaN) 32 | /// `a` must be nonzero, otherwise the returned values are not finite. 33 | /// for `Complex32`, `Complex64`, it always returns both roots. 34 | pub fn quadratic_root(a: T, b: T, c: T) -> Option> 35 | where 36 | T: ComplexFloat, 37 | { 38 | let d = (b.powi(2) - a * c * T::from(4.0).unwrap()).sqrt(); 39 | if d.is_nan() { 40 | return None; 41 | } 42 | let a2 = a * T::from(2.0).unwrap(); 43 | Some(vec![(-b + d) / a2, (-b - d) / a2]) 44 | } 45 | 46 | /// Solve for all roots of a given polynomial. 47 | /// 48 | /// The polynomial is converted to the complex domain first, 49 | /// where a root always exists. 50 | /// 51 | /// Roots are found one at a time with Halley's method (starting at 0, at most 1000 iterations) 52 | /// and the polynomial is deflated recursively; the convergence flag is not checked, 53 | /// so a non-converged iterate is still used as a root. Panics if deflation fails.
54 | pub fn polynomial_roots(polynomial: &Polynomial) -> Vec 55 | where 56 | T: PolynomialCoef + Espilon + IntoMetric + IntoComplex, 57 | C: PolynomialCoef + Espilon + IntoMetric + IntoComplex, 58 | M: Metric, 59 | { 60 | let polynomial: Polynomial = polynomial 61 | .iter() 62 | .map(|&c| c.as_complex()) 63 | .collect::>() 64 | .saturate(); 65 | 66 | match polynomial.degree() { 67 | 0 => vec![], 68 | 1 => vec![-polynomial[0] / polynomial[1]], 69 | 2 => quadratic_root(polynomial[2], polynomial[1], polynomial[0]).unwrap(), 70 | _ => { 71 | let criteria = Some( 72 | OptimizeCriteria::empty() 73 | .set_fltol(Some(M::from(1e-15).unwrap())) 74 | .set_maxiter(Some(1000)), 75 | ); 76 | 77 | let x0 = C::zero(); 78 | 79 | let fun = polynomial.clone().as_fn(); 80 | let dfun = polynomial.clone().differentiate().as_fn(); 81 | let ddfun = polynomial.clone().differentiate().differentiate().as_fn(); 82 | 83 | let res = halley_method(fun, dfun, ddfun, x0, criteria); 84 | 85 | let root = res.sol_x.unwrap(); 86 | 87 | let deflated = polynomial.deflate(root).unwrap().0.to_owned(); 88 | 89 | let roots: Vec = polynomial_roots(&deflated); 90 | 91 | std::iter::once(root).chain(roots).collect() 92 | } 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /src/optimize/root_scalar/secant.rs: -------------------------------------------------------------------------------- 1 | use crate::optimize::root_scalar::*; 2 | use num::traits::Float; 3 | 4 | pub fn secant_method( 5 | fun: F, 6 | x0: C, 7 | x1: C, 8 | criteria: Option>, 9 | ) -> OptimizeResult 10 | where 11 | C: IntoMetric + ComplexFloat, 12 | M: Metric, 13 | F: Fn(C) -> C, 14 | { 15 | let evaluator = RootScalarEvaluator::new(criteria); 16 | let evaluator = Rc::new(RefCell::new(evaluator)); 17 | 18 | let fun = { 19 | let evaluator = evaluator.clone(); 20 | move |x| { 21 | evaluator.borrow_mut().res.fev(); 22 | fun(x) 23 | } 24 | }; 25 | 26 | let solver = SecantSolver::new(fun, x0, x1); 27 | 
28 | iterative_optimize(solver, evaluator) 29 | } 30 | 31 | pub struct SecantSolver 32 | where 33 | C: ComplexFloat, 34 | { 35 | fun: F, 36 | x0: C, 37 | x1: C, 38 | f0: C, 39 | f1: C, 40 | } 41 | 42 | impl SecantSolver 43 | where 44 | C: ComplexFloat, 45 | F: Fn(C) -> C, 46 | { 47 | fn new(mut fun: F, x0: C, x1: C) -> Self { 48 | let (f0, f1) = (fun(x0), fun(x1)); 49 | Self { 50 | fun, 51 | x0, 52 | x1, 53 | f0, 54 | f1, 55 | } 56 | } 57 | } 58 | 59 | impl IterativeSolver for SecantSolver 60 | where 61 | C: IntoMetric + ComplexFloat, 62 | M: Metric, 63 | F: Fn(C) -> C, 64 | { 65 | fn new_solution(&mut self) -> (C, C, Option, Option) { 66 | (self.x0, self.x1) = ( 67 | self.x1, 68 | self.x1 69 | - self.f1 * (self.x1 - self.x0) 70 | / (self.f1 - self.f0 + C::from(C::Real::epsilon()).unwrap()), 71 | ); 72 | (self.f0, self.f1) = ((self.fun)(self.x0), (self.fun)(self.x1)); 73 | 74 | (self.x1, self.f1, None, None) 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /src/optimize/util.rs: -------------------------------------------------------------------------------- 1 | use num::complex::{Complex32, Complex64, ComplexFloat}; 2 | 3 | use crate::optimize::*; 4 | 5 | pub fn approx_derivative(fun: F) -> impl Fn(C) -> C 6 | where 7 | F: Fn(C) -> C, 8 | C: IntoMetric + ComplexFloat + Espilon, 9 | M: Metric, 10 | { 11 | let delta: C = C::epsilon().powf(C::from(1.0 / 3.0).unwrap().re()); 12 | move |x: C| (fun(x + delta) - fun(x)) / delta 13 | } 14 | pub fn approx_second_derivative(fun: F) -> impl Fn(C) -> C 15 | where 16 | F: Fn(C) -> C, 17 | C: IntoMetric + ComplexFloat + Espilon, 18 | M: Metric, 19 | { 20 | let delta: C = C::epsilon().powf(C::from(1.0 / 3.0).unwrap().re()); 21 | move |x: C| (fun(x + delta) - fun(x) * C::from(2.0).unwrap() + fun(x - delta)) / delta / delta 22 | } 23 | 24 | /// A trait implemented for `f32`,`f64`,`Complex32` and `Complex64` 25 | /// 26 | /// allow the type to be able to call `espilon()` for safe 
division 27 | pub trait Espilon { 28 | fn epsilon() -> Self; 29 | } 30 | impl Espilon for f64 { 31 | fn epsilon() -> f64 { 32 | f64::EPSILON 33 | } 34 | } 35 | impl Espilon for f32 { 36 | fn epsilon() -> f32 { 37 | f32::EPSILON 38 | } 39 | } 40 | impl Espilon for Complex32 { 41 | fn epsilon() -> Complex32 { 42 | Complex32::new(f32::epsilon(), f32::epsilon()) 43 | } 44 | } 45 | impl Espilon for Complex64 { 46 | fn epsilon() -> Complex64 { 47 | Complex64::new(f64::epsilon(), f64::epsilon()) 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /src/signal/band_filter.rs: -------------------------------------------------------------------------------- 1 | use super::{output_type::GenericZpk, tools::relative_degree}; 2 | use ndarray::{array, concatenate, Array1, ArrayView, Axis}; 3 | use num::{Complex, Float, Num, NumCast, Zero}; 4 | use std::{ 5 | borrow::Cow, 6 | ops::{Add, Div, Mul, Sub}, 7 | }; 8 | use thiserror::Error; 9 | 10 | pub type BandFilter = GenericBandFilter; 11 | 12 | #[derive(Debug, Clone, Copy)] 13 | pub enum GenericBandFilter { 14 | Highpass(T), 15 | Lowpass(T), 16 | Bandpass { low: T, high: T }, 17 | Bandstop { low: T, high: T }, 18 | } 19 | 20 | impl GenericBandFilter { 21 | pub fn cast(self) -> GenericBandFilter 22 | where 23 | K: From, 24 | { 25 | match self { 26 | GenericBandFilter::Highpass(data) => GenericBandFilter::Highpass(data.into()), 27 | GenericBandFilter::Lowpass(data) => GenericBandFilter::Lowpass(data.into()), 28 | GenericBandFilter::Bandpass { low, high } => GenericBandFilter::Bandpass { 29 | low: low.into(), 30 | high: high.into(), 31 | }, 32 | GenericBandFilter::Bandstop { low, high } => GenericBandFilter::Bandstop { 33 | low: low.into(), 34 | high: high.into(), 35 | }, 36 | } 37 | } 38 | 39 | pub fn cast_with_fn(self, f: impl Fn(T) -> K) -> GenericBandFilter { 40 | match self { 41 | GenericBandFilter::Highpass(data) => GenericBandFilter::Highpass(f(data)), 42 | 
GenericBandFilter::Lowpass(data) => GenericBandFilter::Lowpass(f(data)), 43 | GenericBandFilter::Bandpass { low, high } => GenericBandFilter::Bandpass { 44 | low: f(low), 45 | high: f(high), 46 | }, 47 | GenericBandFilter::Bandstop { low, high } => GenericBandFilter::Bandstop { 48 | low: f(low), 49 | high: f(high), 50 | }, 51 | } 52 | } 53 | } 54 | 55 | impl GenericBandFilter 56 | where 57 | T: Float, 58 | { 59 | pub fn tan(self) -> Self { 60 | match self { 61 | Self::Highpass(data) => Self::Highpass(data.tan()), 62 | Self::Lowpass(data) => Self::Lowpass(data.tan()), 63 | Self::Bandpass { low, high } => Self::Bandpass { 64 | low: low.tan(), 65 | high: high.tan(), 66 | }, 67 | Self::Bandstop { low, high } => Self::Bandstop { 68 | low: low.tan(), 69 | high: high.tan(), 70 | }, 71 | } 72 | } 73 | 74 | pub fn size(&self) -> u8 { 75 | match self { 76 | Self::Highpass(_) | Self::Lowpass(_) => 1, 77 | Self::Bandstop { low: _, high: _ } | Self::Bandpass { low: _, high: _ } => 2, 78 | } 79 | } 80 | pub fn to_vec(self) -> Vec { 81 | match self { 82 | Self::Highpass(f) | Self::Lowpass(f) => vec![f], 83 | Self::Bandstop { low, high } | Self::Bandpass { low, high } => { 84 | vec![low, high] 85 | } 86 | } 87 | } 88 | pub fn to_array(self) -> Array1 { 89 | match self { 90 | Self::Highpass(f) | Self::Lowpass(f) => array![f], 91 | Self::Bandstop { low, high } | Self::Bandpass { low, high } => { 92 | array![low, high] 93 | } 94 | } 95 | } 96 | 97 | pub fn pass_zero(&self) -> bool { 98 | match self { 99 | Self::Bandpass { low: _, high: _ } | Self::Highpass(_) => false, 100 | Self::Bandstop { low: _, high: _ } | Self::Lowpass(_) => true, 101 | } 102 | } 103 | 104 | pub fn pass_nyquist(&self, pass_zero: bool) -> bool { 105 | ((self.size() & 1) == 1) ^ pass_zero 106 | } 107 | } 108 | 109 | impl Mul for GenericBandFilter 110 | where 111 | T: Mul, 112 | { 113 | type Output = Self; 114 | fn mul(self, rhs: R) -> Self::Output { 115 | match self { 116 | Self::Highpass(data) => 
Self::Highpass(data * rhs), 117 | Self::Lowpass(data) => Self::Lowpass(data * rhs), 118 | Self::Bandpass { low, high } => Self::Bandpass { 119 | low: low * rhs, 120 | high: high * rhs, 121 | }, 122 | Self::Bandstop { low, high } => Self::Bandstop { 123 | low: low * rhs, 124 | high: high * rhs, 125 | }, 126 | } 127 | } 128 | } 129 | 130 | impl Div for GenericBandFilter { 131 | type Output = Self; 132 | fn div(self, rhs: T) -> Self::Output { 133 | match self { 134 | Self::Highpass(data) => Self::Highpass(data / rhs), 135 | Self::Lowpass(data) => Self::Lowpass(data / rhs), 136 | Self::Bandpass { low, high } => Self::Bandpass { 137 | low: low / rhs, 138 | high: high / rhs, 139 | }, 140 | Self::Bandstop { low, high } => Self::Bandstop { 141 | low: low / rhs, 142 | high: high / rhs, 143 | }, 144 | } 145 | } 146 | } 147 | 148 | #[derive(Debug, Clone, Copy)] 149 | pub struct GenericOrdBandFilter(GenericOrdBandFilterType); 150 | 151 | impl std::ops::Deref for GenericOrdBandFilter { 152 | type Target = GenericOrdBandFilterType; 153 | 154 | fn deref(&self) -> &Self::Target { 155 | &self.0 156 | } 157 | } 158 | 159 | impl AsRef> for GenericOrdBandFilter { 160 | fn as_ref(&self) -> &GenericOrdBandFilterType { 161 | &self.0 162 | } 163 | } 164 | 165 | #[derive(Debug, Error)] 166 | pub enum Error { 167 | #[error("{0}")] 168 | Validation(Cow<'static, str>), 169 | } 170 | 171 | impl GenericOrdBandFilter { 172 | pub fn lowpass(wp: T, ws: T) -> Result { 173 | if wp > ws { 174 | Err(Error::Validation(Cow::from( 175 | "for lowpass filter wp must be smaller than ws", 176 | )))?; 177 | } 178 | Ok(Self(GenericOrdBandFilterType::Lowpass { wp, ws })) 179 | } 180 | 181 | pub fn highpass(wp: T, ws: T) -> Result { 182 | if wp < ws { 183 | Err(Error::Validation(Cow::from( 184 | "for highpass filter ws must be smaller than wp", 185 | )))?; 186 | } 187 | Ok(Self(GenericOrdBandFilterType::Lowpass { wp, ws })) 188 | } 189 | 190 | pub fn bandpass(wp_low: T, wp_high: T, ws_low: T, ws_high: T) -> 
Result { 191 | if wp_low < ws_low { 192 | Err(Error::Validation(Cow::from( 193 | "for bandpass filter ws_low must be smaller than wp_low", 194 | )))?; 195 | } 196 | Ok(Self(GenericOrdBandFilterType::Bandpass { 197 | wp_low, 198 | wp_high, 199 | ws_low, 200 | ws_high, 201 | })) 202 | } 203 | 204 | pub fn bandstop(wp_low: T, wp_high: T, ws_low: T, ws_high: T) -> Result { 205 | if ws_low < wp_low { 206 | Err(Error::Validation(Cow::from( 207 | "for bandstop filter wp_low must be smaller than ws_low", 208 | )))?; 209 | } 210 | Ok(Self(GenericOrdBandFilterType::Bandpass { 211 | wp_low, 212 | wp_high, 213 | ws_low, 214 | ws_high, 215 | })) 216 | } 217 | 218 | pub fn tan(self) -> Self { 219 | use GenericOrdBandFilterType as S; 220 | let inner = match self.0 { 221 | S::Lowpass { wp, ws } => S::Lowpass { 222 | wp: wp.tan(), 223 | ws: ws.tan(), 224 | }, 225 | S::Highpass { wp, ws } => S::Highpass { 226 | wp: wp.tan(), 227 | ws: ws.tan(), 228 | }, 229 | S::Bandpass { 230 | wp_low, 231 | wp_high, 232 | ws_low, 233 | ws_high, 234 | } => S::Bandpass { 235 | wp_low: wp_low.tan(), 236 | wp_high: wp_high.tan(), 237 | ws_low: ws_low.tan(), 238 | ws_high: ws_high.tan(), 239 | }, 240 | S::Bandstop { 241 | wp_low, 242 | wp_high, 243 | ws_low, 244 | ws_high, 245 | } => S::Bandstop { 246 | wp_low: wp_low.tan(), 247 | wp_high: wp_high.tan(), 248 | ws_low: ws_low.tan(), 249 | ws_high: ws_high.tan(), 250 | }, 251 | }; 252 | Self(inner) 253 | } 254 | } 255 | 256 | #[derive(Debug, Clone, Copy)] 257 | pub enum GenericOrdBandFilterType { 258 | Lowpass { 259 | wp: T, 260 | ws: T, 261 | }, 262 | Highpass { 263 | wp: T, 264 | ws: T, 265 | }, 266 | Bandpass { 267 | wp_low: T, 268 | wp_high: T, 269 | ws_low: T, 270 | ws_high: T, 271 | }, 272 | Bandstop { 273 | wp_low: T, 274 | wp_high: T, 275 | ws_low: T, 276 | ws_high: T, 277 | }, 278 | } 279 | 280 | impl Div for GenericOrdBandFilter { 281 | type Output = Self; 282 | 283 | fn div(self, rhs: T) -> Self::Output { 284 | use GenericOrdBandFilterType as 
S; 285 | 286 | let inner = match self.0 { 287 | S::Lowpass { wp, ws } => S::Lowpass { 288 | wp: wp / rhs, 289 | ws: ws / rhs, 290 | }, 291 | S::Highpass { wp, ws } => S::Highpass { 292 | wp: wp / rhs, 293 | ws: ws / rhs, 294 | }, 295 | S::Bandpass { 296 | wp_low, 297 | wp_high, 298 | ws_low, 299 | ws_high, 300 | } => S::Bandpass { 301 | wp_low: wp_low / rhs, 302 | wp_high: wp_high / rhs, 303 | ws_low: ws_low / rhs, 304 | ws_high: ws_high / rhs, 305 | }, 306 | S::Bandstop { 307 | wp_low, 308 | wp_high, 309 | ws_low, 310 | ws_high, 311 | } => S::Bandstop { 312 | wp_low: wp_low / rhs, 313 | wp_high: wp_high / rhs, 314 | ws_low: ws_low / rhs, 315 | ws_high: ws_high / rhs, 316 | }, 317 | }; 318 | Self(inner) 319 | } 320 | } 321 | impl Sub for GenericOrdBandFilter { 322 | type Output = Self; 323 | 324 | fn sub(self, rhs: T) -> Self::Output { 325 | use GenericOrdBandFilterType as S; 326 | 327 | let inner = match self.0 { 328 | S::Lowpass { wp, ws } => S::Lowpass { 329 | wp: wp - rhs, 330 | ws: ws - rhs, 331 | }, 332 | S::Highpass { wp, ws } => S::Highpass { 333 | wp: wp - rhs, 334 | ws: ws - rhs, 335 | }, 336 | S::Bandpass { 337 | wp_low, 338 | wp_high, 339 | ws_low, 340 | ws_high, 341 | } => S::Bandpass { 342 | wp_low: wp_low - rhs, 343 | wp_high: wp_high - rhs, 344 | ws_low: ws_low - rhs, 345 | ws_high: ws_high - rhs, 346 | }, 347 | S::Bandstop { 348 | wp_low, 349 | wp_high, 350 | ws_low, 351 | ws_high, 352 | } => S::Bandstop { 353 | wp_low: wp_low - rhs, 354 | wp_high: wp_high - rhs, 355 | ws_low: ws_low - rhs, 356 | ws_high: ws_high - rhs, 357 | }, 358 | }; 359 | Self(inner) 360 | } 361 | } 362 | impl Add for GenericOrdBandFilter { 363 | type Output = Self; 364 | 365 | fn add(self, rhs: T) -> Self::Output { 366 | use GenericOrdBandFilterType as S; 367 | 368 | let inner = match self.0 { 369 | S::Lowpass { wp, ws } => S::Lowpass { 370 | wp: wp + rhs, 371 | ws: ws + rhs, 372 | }, 373 | S::Highpass { wp, ws } => S::Highpass { 374 | wp: wp + rhs, 375 | ws: ws + rhs, 376 | 
}, 377 | S::Bandpass { 378 | wp_low, 379 | wp_high, 380 | ws_low, 381 | ws_high, 382 | } => S::Bandpass { 383 | wp_low: wp_low + rhs, 384 | wp_high: wp_high + rhs, 385 | ws_low: ws_low + rhs, 386 | ws_high: ws_high + rhs, 387 | }, 388 | S::Bandstop { 389 | wp_low, 390 | wp_high, 391 | ws_low, 392 | ws_high, 393 | } => S::Bandstop { 394 | wp_low: wp_low + rhs, 395 | wp_high: wp_high + rhs, 396 | ws_low: ws_low + rhs, 397 | ws_high: ws_high + rhs, 398 | }, 399 | }; 400 | Self(inner) 401 | } 402 | } 403 | 404 | impl Mul for GenericOrdBandFilter { 405 | type Output = Self; 406 | 407 | fn mul(self, rhs: T) -> Self::Output { 408 | use GenericOrdBandFilterType as S; 409 | 410 | let inner = match self.0 { 411 | S::Lowpass { wp, ws } => S::Lowpass { 412 | wp: wp * rhs, 413 | ws: ws * rhs, 414 | }, 415 | S::Highpass { wp, ws } => S::Highpass { 416 | wp: wp * rhs, 417 | ws: ws * rhs, 418 | }, 419 | S::Bandpass { 420 | wp_low, 421 | wp_high, 422 | ws_low, 423 | ws_high, 424 | } => S::Bandpass { 425 | wp_low: wp_low * rhs, 426 | wp_high: wp_high * rhs, 427 | ws_low: ws_low * rhs, 428 | ws_high: ws_high * rhs, 429 | }, 430 | S::Bandstop { 431 | wp_low, 432 | wp_high, 433 | ws_low, 434 | ws_high, 435 | } => S::Bandstop { 436 | wp_low: wp_low * rhs, 437 | wp_high: wp_high * rhs, 438 | ws_low: ws_low * rhs, 439 | ws_high: ws_high * rhs, 440 | }, 441 | }; 442 | Self(inner) 443 | } 444 | } 445 | 446 | pub fn lp2bf_zpk(input: GenericZpk, wo: GenericBandFilter) -> GenericZpk 447 | where 448 | T: Float + Clone, 449 | { 450 | match wo { 451 | GenericBandFilter::Lowpass(wo) => lp2lp_zpk(input, wo), 452 | GenericBandFilter::Highpass(wo) => lp2hp_zpk(input, wo), 453 | GenericBandFilter::Bandpass { low, high } => { 454 | let bw = high - low; 455 | let wo = (low * high).sqrt(); 456 | lp2bp_zpk(input, wo, bw) 457 | } 458 | GenericBandFilter::Bandstop { low, high } => { 459 | let bw = high - low; 460 | let wo = (low * high).sqrt(); 461 | lp2bs_zpk(input, wo, bw) 462 | } 463 | } 464 | } 465 | 
466 | pub fn lp2hp_zpk(mut input: GenericZpk, wo: T) -> GenericZpk 467 | where 468 | T: Float + Clone, 469 | { 470 | let degree = relative_degree(&input); 471 | 472 | let big_z = &input.z; 473 | let big_p = &input.p; 474 | 475 | let z_prod = big_z.map(|a| -a).product(); 476 | let p_prod = big_p.map(|a| -a).product(); 477 | 478 | input 479 | .z 480 | .mapv_inplace(|a| as From>::from(wo) / a); 481 | input 482 | .p 483 | .mapv_inplace(|a| as From>::from(wo) / a); 484 | 485 | let zeros = vec![Complex::zero(); degree]; 486 | input 487 | .z 488 | .append(Axis(0), ArrayView::from(zeros.as_slice())) 489 | .unwrap(); 490 | 491 | let factor = (z_prod / p_prod).re; 492 | 493 | input.k = input.k * factor; 494 | if input.k.is_nan() { 495 | println!("lp2hp nan") 496 | } 497 | input 498 | } 499 | 500 | pub fn lp2lp_zpk(mut input: GenericZpk, wo: T) -> GenericZpk 501 | where 502 | T: Float, 503 | { 504 | let degree = relative_degree(&input); 505 | input.z.mapv_inplace(|a| a * wo); 506 | input.p.mapv_inplace(|a| a * wo); 507 | let res = input.k * wo.powi(degree as _); 508 | res.is_nan(); 509 | input.k = res; 510 | input 511 | } 512 | 513 | pub fn lp2bp_zpk(input: GenericZpk, wo: T, bw: T) -> GenericZpk 514 | where 515 | T: Float, 516 | { 517 | let degree = relative_degree(&input); 518 | let GenericZpk { z, p, k } = input; 519 | let two = T::one() + T::one(); 520 | let z_lp = z.mapv(|a| a * bw / two); 521 | let p_lp = p.mapv(|a| a * bw / two); 522 | 523 | let z_hp_left = &z_lp + (z_lp.mapv(|a| a.powf(two) - wo.powf(two))).mapv(Complex::sqrt); 524 | let z_hp_right = &z_lp - (z_lp.mapv(|a| a.powf(two) - wo.powf(two))).mapv(Complex::sqrt); 525 | 526 | let p_hp_left = &p_lp + (p_lp.mapv(|a| a.powf(two) - wo.powf(two))).mapv(Complex::sqrt); 527 | let p_hp_right = &p_lp - (p_lp.mapv(|a| a.powf(two) - wo.powf(two))).mapv(Complex::sqrt); 528 | 529 | let z_bp = concatenate![Axis(0), z_hp_left, z_hp_right]; 530 | let p_bp = concatenate![Axis(0), p_hp_left, p_hp_right]; 531 | 532 | let z_bp = 
concatenate![Axis(0), z_bp, Array1::zeros(degree)]; 533 | 534 | let k_bp = k * bw.powi(degree as _); 535 | 536 | let p_bp = p_bp.mapv(|mut a| { 537 | if a.im == T::neg_zero() { 538 | a.im = T::zero() 539 | } 540 | a 541 | }); 542 | 543 | GenericZpk { 544 | z: z_bp, 545 | p: p_bp, 546 | k: k_bp, 547 | } 548 | } 549 | 550 | pub fn lp2bs_zpk(input: GenericZpk, wo: T, bw: T) -> GenericZpk 551 | where 552 | T: Float, 553 | { 554 | let degree = relative_degree(&input); 555 | let GenericZpk { z, p, k } = input; 556 | 557 | let bw_half = bw / T::from(2).unwrap(); 558 | let bw_half = Complex::new(bw_half, T::zero()); 559 | 560 | let z_hp = z.mapv(|a| bw_half / a); 561 | let p_hp = p.mapv(|a| bw_half / a); 562 | 563 | let z_bs_left = &z_hp + (z_hp.mapv(|a| a.powi(2) - wo.powi(2))).mapv(Complex::sqrt); 564 | let z_bs_right = &z_hp - (z_hp.mapv(|a| a.powi(2) - wo.powi(2))).mapv(Complex::sqrt); 565 | 566 | let p_bs_left = &p_hp + (p_hp.mapv(|a| a.powi(2) - wo.powi(2))).mapv(Complex::sqrt); 567 | let p_bs_right = &p_hp - (p_hp.mapv(|a| a.powi(2) - wo.powi(2))).mapv(Complex::sqrt); 568 | 569 | let dbg_zbs_left = z_bs_left.mapv(|a| { 570 | Complex::new( 571 | ::from(a.re).unwrap(), 572 | ::from(a.im).unwrap(), 573 | ) 574 | }); 575 | 576 | let dbg_zbs_right = z_bs_right.mapv(|a| { 577 | Complex::new( 578 | ::from(a.re).unwrap(), 579 | ::from(a.im).unwrap(), 580 | ) 581 | }); 582 | 583 | println!("dbg z {dbg_zbs_left:?}"); 584 | println!("dbg z {dbg_zbs_right:?}"); 585 | 586 | let z_bs = concatenate![Axis(0), z_bs_left, z_bs_right]; 587 | let p_bs = concatenate![Axis(0), p_bs_left, p_bs_right]; 588 | 589 | let dbg_zbs = z_bs.mapv(|a| { 590 | Complex::new( 591 | ::from(a.re).unwrap(), 592 | ::from(a.im).unwrap(), 593 | ) 594 | }); 595 | println!("dbg z {dbg_zbs:?}"); 596 | let z_bs = concatenate![Axis(0), z_bs, Array1::from_elem(degree, Complex::i() * wo)]; 597 | let z_bs = concatenate![ 598 | Axis(0), 599 | z_bs, 600 | Array1::from_elem(degree, Complex::new(T::zero(), -T::one()) * 
wo) 601 | ]; 602 | 603 | println!("degree {degree} p length: {}", p.len()); 604 | let factor = match z.len().cmp(&p.len()) { 605 | std::cmp::Ordering::Less => { 606 | let t_z = concatenate![Axis(0), -&z, Array1::ones(p.len() - z.len())]; 607 | (t_z / -&p).product() 608 | } 609 | std::cmp::Ordering::Equal => (-&z / -&p).product(), 610 | std::cmp::Ordering::Greater => { 611 | let t_p = concatenate![Axis(0), -&p, Array1::ones(z.len() - p.len())]; 612 | (&-z / t_p).product() 613 | } 614 | }; 615 | 616 | let k_bs = k * factor.re; 617 | 618 | GenericZpk { 619 | z: z_bs, 620 | p: p_bs, 621 | k: k_bs, 622 | } 623 | } 624 | -------------------------------------------------------------------------------- /src/signal/convolution/mod.rs: -------------------------------------------------------------------------------- 1 | use ndarray::{Array, Axis, Dimension}; 2 | use num::Num; 3 | 4 | #[allow(unused)] 5 | pub fn convolve( 6 | mut in1: Array, 7 | mut in2: Array, 8 | method: ConvolveMethod, 9 | mode: ConvolveMode, 10 | ) -> Array { 11 | if inputs_swap_needed(mode, in1.shape(), in2.shape()) { 12 | std::mem::swap(&mut in1, &mut in2); 13 | } 14 | 15 | match method { 16 | ConvolveMethod::Auto => todo!(), 17 | ConvolveMethod::Fft => fft_convolve(in1, in2, mode), 18 | ConvolveMethod::Direct => todo!(), 19 | } 20 | } 21 | 22 | #[allow(unused)] 23 | pub fn fft_convolve( 24 | in1: Array, 25 | in2: Array, 26 | mode: ConvolveMode, 27 | ) -> Array { 28 | if in1.ndim() == 0 { 29 | return in1 * in2; 30 | } 31 | 32 | if in1.is_empty() || in2.is_empty() { 33 | return Default::default(); 34 | } 35 | 36 | todo!() 37 | } 38 | 39 | #[allow(unused)] 40 | fn _init_freq_conv_axes( 41 | in1: Array, 42 | in2: Array, 43 | mode: ConvolveMode, 44 | axes: impl Into>>, 45 | sorted_axes: bool, 46 | ) { 47 | let axes: Option<_> = axes.into(); 48 | let s1 = in1.shape(); 49 | let s2 = in2.shape(); 50 | 51 | let noaxes = axes.is_none(); 52 | } 53 | 54 | fn inputs_swap_needed(mode: ConvolveMode, shape1: &[usize], 
shape2: &[usize]) -> bool { 55 | debug_assert_eq!(shape1.len(), shape2.len()); 56 | 57 | if !matches!(mode, ConvolveMode::Valid) { 58 | return false; 59 | } 60 | 61 | shape1.iter().zip(shape2).all(|(a, b)| a >= b) | shape2.iter().zip(shape1).all(|(a, b)| a >= b) 62 | } 63 | 64 | #[allow(unused)] 65 | #[derive(Clone, Copy, PartialEq, Eq)] 66 | pub enum ConvolveMethod { 67 | Auto, 68 | Fft, 69 | Direct, 70 | } 71 | 72 | #[allow(unused)] 73 | #[derive(Clone, Copy, PartialEq, Eq)] 74 | pub enum ConvolveMode { 75 | Full, 76 | Valid, 77 | Same, 78 | } 79 | -------------------------------------------------------------------------------- /src/signal/error.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Debug; 2 | use thiserror::Error; 3 | 4 | use super::filter_design; 5 | 6 | #[derive(Debug, Error)] 7 | pub enum Error { 8 | #[error(transparent)] 9 | IIRFilter(#[from] filter_design::error::Error), 10 | } 11 | -------------------------------------------------------------------------------- /src/signal/filter_design/bessel.rs: -------------------------------------------------------------------------------- 1 | use ndarray::{array, Array1}; 2 | use num::{complex::ComplexFloat, traits::FloatConst, Complex, Float, NumCast, One, Zero}; 3 | use thiserror::Error; 4 | 5 | use crate::{ 6 | optimize::{util::Espilon, Metric}, 7 | signal::{output_type::GenericZpk, tools::polyval}, 8 | special::kve, 9 | tools::complex::normalize_zeros, 10 | }; 11 | 12 | use super::{GenericIIRFilterSettings, ProtoIIRFilter}; 13 | 14 | pub struct BesselFilter { 15 | pub norm: BesselNorm, 16 | pub settings: GenericIIRFilterSettings, 17 | } 18 | 19 | impl ProtoIIRFilter 20 | for BesselFilter 21 | where 22 | Complex: Espilon, 23 | { 24 | fn proto_filter( 25 | &self, 26 | ) -> Result, crate::signal::error::Error> { 27 | Ok(besselap(self.settings.order, self.norm).map_err(super::Error::from)?) 
28 | } 29 | 30 | fn filter_settings(&self) -> &GenericIIRFilterSettings { 31 | &self.settings 32 | } 33 | } 34 | 35 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 36 | pub enum BesselNorm { 37 | Phase, 38 | Delay, 39 | Mag, 40 | } 41 | 42 | pub fn besselap( 43 | order: u32, 44 | norm: BesselNorm, 45 | ) -> Result, Error> 46 | where 47 | Complex: Espilon, 48 | { 49 | let z = array![]; 50 | let mut p: Array1>; 51 | let mut k = T::one(); 52 | if order == 0 { 53 | p = array![]; 54 | } else { 55 | let a_last: T = (_falling_factorial::(2 * order, order) 56 | / T::from(2.0).unwrap().powi(order as i32)) 57 | .floor(); 58 | p = _bessel_zeros::(order)? 59 | .into_iter() 60 | .map(|a| Complex::new(T::from(1.0).unwrap(), T::zero()) / a) 61 | .collect(); 62 | 63 | if norm == BesselNorm::Delay || norm == BesselNorm::Mag { 64 | k = a_last; 65 | if norm == BesselNorm::Mag { 66 | let norm_factor = _norm_factor(p.clone(), k); 67 | p = p.mapv(|a| a / norm_factor); 68 | k = norm_factor.powf(-T::from(order).unwrap()) * a_last; 69 | } 70 | } else { 71 | p.iter_mut().for_each(|a| { 72 | *a = *a 73 | * T::from(10.0) 74 | .unwrap() 75 | .powf(-a_last.log10() / T::from(order).unwrap()) 76 | }); 77 | } 78 | } 79 | let p = normalize_zeros(p); 80 | Ok(GenericZpk { z, p, k }) 81 | } 82 | 83 | fn _norm_factor(p: Array1>, k: T) -> T { 84 | let g = move |w: T| { 85 | let tmp = p.mapv(|a| Complex::i() * w - a); 86 | let tmp = Complex::new(k, T::zero()) / tmp.product(); 87 | tmp.abs() 88 | }; 89 | let cutoff = move |w: T| g(w) - T::one() / T::from(2).unwrap().sqrt(); 90 | 91 | let res = crate::optimize::root_scalar::secant_method( 92 | cutoff, 93 | T::from(1.5).unwrap(), 94 | T::from(1.5 * (1.0 + 1.0e-4)).unwrap(), 95 | None, 96 | ); 97 | 98 | res.sol_x.unwrap() 99 | } 100 | 101 | fn _falling_factorial(x: u32, n: u32) -> T { 102 | let mut y = 1.0; 103 | 104 | for i in (x - n + 1)..(x + 1) { 105 | y *= i as f64; 106 | } 107 | T::from(y).unwrap() 108 | } 109 | 110 | fn _bessel_zeros( 111 | order: u32, 
112 | ) -> Result>, Error> 113 | where 114 | Complex: Espilon, 115 | { 116 | if order == 0 { 117 | return Ok(array![]); 118 | } 119 | 120 | let x0 = _campos_zeros(order); 121 | let f = |x: Complex| { 122 | let x = Complex::new( 123 | ::from(x.re).unwrap(), 124 | ::from(x.im).unwrap(), 125 | ); 126 | let r = kve(order as f64 + 0.5, 1.0 / x); 127 | Complex::new(T::from(r.re).unwrap(), T::from(r.im).unwrap()) 128 | }; 129 | 130 | let fp = |x: Complex| { 131 | let x = Complex::new( 132 | ::from(x.re).unwrap(), 133 | ::from(x.im).unwrap(), 134 | ); 135 | let order = order as f64; 136 | 137 | let first = kve(order - 0.5, 1.0 / x) / (2.0 * x.powi(2)); 138 | let second = kve(order + 0.5, 1.0 / x) / x.powi(2); 139 | let third = kve(order + 1.5, 1.0 / x) / (2.0 * x.powi(2)); 140 | let r = first - second + third; 141 | Complex::new(T::from(r.re).unwrap(), T::from(r.im).unwrap()) 142 | }; 143 | let mut x = _aberth(f, fp, &x0)?; 144 | 145 | for i in &mut x { 146 | let result = crate::optimize::root_scalar::newton::newton_method(f, fp, *i, None); 147 | 148 | *i = result.sol_x.unwrap(); 149 | } 150 | 151 | let clone = x.clone().into_iter().map(|a| a.conj()).rev(); 152 | 153 | let temp = x.iter().copied().zip(clone); 154 | let x: Array1> = temp.map(|(a, b)| (a + b) / T::from(2.0).unwrap()).collect(); 155 | 156 | Ok(x) 157 | } 158 | 159 | fn _aberth< 160 | T: Float + FloatConst, 161 | F: Fn(Complex) -> Complex, 162 | FP: Fn(Complex) -> Complex, 163 | >( 164 | f: F, 165 | fp: FP, 166 | x0: &[Complex], 167 | ) -> Result>, Error> { 168 | let mut zs = x0.to_vec(); 169 | let mut new_zs = zs.clone(); 170 | let tol = T::from(1e-16).unwrap(); 171 | 'iteration: loop { 172 | for i in 0..(x0.len()) { 173 | let p_of_z = f(zs[i]); 174 | let dydx_of_z = fp(zs[i]); 175 | 176 | let sum: Complex = (0..zs.len()) 177 | .filter(|&k| k != i) 178 | .fold(Complex::zero(), |acc: Complex, k| { 179 | acc + Complex::::one() / (zs[i] - zs[k]) 180 | }); 181 | 182 | let new_z = zs[i] + p_of_z / (p_of_z * sum - 
dydx_of_z); 183 | new_zs[i] = new_z; 184 | if new_z.re.is_nan() 185 | || new_z.im.is_nan() 186 | || new_z.re.is_infinite() 187 | || new_z.im.is_infinite() 188 | { 189 | break 'iteration; 190 | } 191 | let err = (new_z - zs[i]).abs(); 192 | if err < tol { 193 | return Ok(new_zs); 194 | } 195 | 196 | zs.clone_from(&new_zs); 197 | } 198 | } 199 | 200 | Err(Error::Converge) 201 | } 202 | 203 | // verified with python 204 | fn _campos_zeros(order: u32) -> Vec> { 205 | let n = order as _; 206 | if n == 1.0 { 207 | return vec![Complex::new(-T::one(), T::zero())]; 208 | } 209 | let s = polyval(n, [0.0, 0.0, 2.0, 0.0, -3.0, 1.0]); 210 | let b3 = polyval(n, [16.0, -8.0]) / s; 211 | let b2 = polyval(n, [-24.0, -12.0, 12.0]) / s; 212 | let b1 = polyval(n, [8.0, 24.0, -12.0, -2.0]) / s; 213 | let b0 = polyval(n, [0.0, -6.0, 0.0, 5.0, -1.0]) / s; 214 | 215 | let r = polyval(n, [0.0, 0.0, 2.0, 1.0]); 216 | 217 | let a1 = polyval(n, [-6.0, -6.0]) / r; 218 | let a2 = 6.0 / r; 219 | 220 | let k = 1..(order + 1); 221 | 222 | let x = k 223 | .clone() 224 | .map(|a| polyval(Complex::new(a as f64, 0.0), [0.0.into(), a1, a2])) 225 | .collect::>(); 226 | let y = k 227 | .map(|a| polyval(Complex::new(a as f64, 0.0), [b0, b1, b2, b3])) 228 | .collect::>(); 229 | 230 | assert_eq!(x.len(), y.len()); 231 | x.iter() 232 | .zip(y) 233 | .map(|(x, y)| *x + Complex::new(0.0, 1.0) * y) 234 | .map(|a| Complex::new(T::from(a.re).unwrap(), T::from(a.im).unwrap())) 235 | .collect::>() 236 | } 237 | 238 | #[derive(Debug, Error, Clone)] 239 | pub enum Error { 240 | #[error("failed to converge in aberth method")] 241 | Converge, 242 | } 243 | -------------------------------------------------------------------------------- /src/signal/filter_design/butter.rs: -------------------------------------------------------------------------------- 1 | use ndarray::{array, Array1}; 2 | use num::{complex::ComplexFloat, traits::FloatConst, Complex, Float}; 3 | 4 | use crate::signal::output_type::GenericZpk; 5 | 6 | 
use super::{error::Infallible, GenericIIRFilterSettings, ProtoIIRFilter}; 7 | 8 | /** Butterworth IIR filter: the analog prototype comes from [`buttap`] and the design settings from `settings`. */ pub struct ButterFilter { 9 | pub settings: GenericIIRFilterSettings, 10 | } 11 | 12 | impl ProtoIIRFilter for ButterFilter { 13 | fn proto_filter( 14 | &self, 15 | ) -> Result, crate::signal::error::Error> { 16 | Ok(buttap(self.settings.order).map_err(super::Error::from)?) 17 | } 18 | 19 | fn filter_settings(&self) -> &GenericIIRFilterSettings { 20 | &self.settings 21 | } 22 | } 23 | 24 | /** Analog Butterworth low-pass prototype of the given `order`: no zeros, gain 1, and poles p = -exp(i pi m / (2 order)) for m = 1 - order, 3 - order, ..., order - 1. Never fails (the error type is uninhabited). */ pub fn buttap(order: u32) -> Result, Infallible> { 25 | use std::f64::consts::PI; 26 | let order = order as i32; 27 | let z = array![]; 28 | let range = Array1::range( 29 | T::from(1.0 - order as f64).unwrap(), 30 | T::from(order).unwrap(), 31 | T::from(2.0).unwrap(), 32 | ); 33 | 34 | let range = range.mapv(|a| Complex::from(a)); 35 | let k = T::one(); 36 | 37 | let p = -(range 38 | .mapv(|a| Complex::i() * T::from(PI).unwrap() * a / T::from(2 * order).unwrap())) 39 | .mapv(Complex::exp); 40 | 41 | /* -0.0 == 0.0, so every zero imaginary part (of either sign) is rewritten as +0.0. */ let p = p.mapv(|mut a| { 42 | if a.im == T::neg_zero() { 43 | a.im = T::zero() 44 | } 45 | a 46 | }); 47 | 48 | Ok(GenericZpk { z, p, k }) 49 | } 50 | -------------------------------------------------------------------------------- /src/signal/filter_design/butterord.rs: -------------------------------------------------------------------------------- 1 | use ndarray::{array, Array1}; 2 | use num::{Float, NumCast}; 3 | use thiserror::Error; 4 | 5 | use crate::signal::{band_filter::GenericOrdBandFilter, GenericSampling}; 6 | 7 | use super::{OrdCompute, OrdResult}; 8 | 9 | /** Inputs for [`buttord`]: band edges, passband loss `gpass` and stopband attenuation `gstop` (both in dB), and the sampling mode. */ pub struct ButterOrd { 10 | pub band_filter: GenericOrdBandFilter, 11 | pub gpass: T, 12 | pub gstop: T, 13 | pub sampling: GenericSampling, 14 | } 15 | 16 | impl OrdCompute for ButterOrd { 17 | fn compute_order(&self) -> Result, crate::signal::error::Error> { 18 | Ok( 19 | buttord(self.band_filter, self.gpass, self.gstop, self.sampling) 20 | .map_err(super::Error::from)?, 21 | ) 22 | } 23 | } 24 | 25 | /** Minimum Butterworth order for the given band edges, passband loss `gpass` and stopband attenuation `gstop` (both in dB), modelled on scipy.signal.buttord. NOTE(review): unfinished. The band-pass and band-stop arms and the final result are still todo!(), so every call that passes validation panics. */ pub fn buttord( 26 | band_filter: 
GenericOrdBandFilter, 27 | gpass: T, 28 | gstop: T, 29 | sampling: GenericSampling, 30 | ) -> Result, Error> { 31 | _validate_gpass_gstop(gpass, gstop)?; 32 | let band_filter = _validate_wp_ws(band_filter, sampling); 33 | let band_filter = _pre_warp(band_filter, sampling); 34 | let (nat, band_filter) = _find_nat_freq(band_filter, gpass, gstop); 35 | 36 | /* dB to linear power ratio: 10^(0.1 |g|). */ let g_stop = T::from(10.0) 37 | .unwrap() 38 | .powf(T::from(0.1).unwrap() * gstop.abs()); 39 | let g_pass = T::from(10.0) 40 | .unwrap() 41 | .powf(T::from(0.1).unwrap() * gpass.abs()); 42 | 43 | // int(ceil(log10((GSTOP - 1.0) / (GPASS - 1.0)) / (2 * log10(nat)))) 44 | let ord = ::from( 45 | (((g_stop - T::one()) / (g_pass - T::one())).log10() / (T::from(2).unwrap() * nat.log10())) 46 | .ceil(), 47 | ) 48 | .unwrap(); 49 | 50 | /* Natural-frequency scale factor that puts the response exactly at gpass on the passband edge (scipy's W0). */ let w0 = (g_pass - T::one()).powf(T::from(-1.0 / (2.0 * ord)).unwrap()); 51 | let ord = ord as u32; 52 | 53 | use crate::signal::band_filter::GenericOrdBandFilterType as S; 54 | let band_filter = match band_filter.as_ref() { 55 | S::Lowpass { .. } => band_filter * w0, 56 | S::Highpass { ..
} => band_filter / w0, 57 | S::Bandpass { 58 | wp_low, 59 | wp_high, 60 | ws_low, 61 | ws_high, 62 | } => todo!(), 63 | S::Bandstop { 64 | wp_low, 65 | wp_high, 66 | ws_low, 67 | ws_high, 68 | } => todo!(), 69 | }; 70 | /* NOTE(review): the OrdResult is not assembled yet. */ todo!() 71 | } 72 | 73 | /** Rejects gpass <= 0, gstop <= 0, and gpass > gstop. */ fn _validate_gpass_gstop(gpass: T, gstop: T) -> Result<(), Error> { 74 | if gpass <= T::zero() { 75 | return Err(Error::BadGPass(gpass.to_f64().unwrap())); 76 | } 77 | if gstop <= T::zero() { 78 | return Err(Error::BadGStop(gstop.to_f64().unwrap())); 79 | } 80 | if gpass > gstop { 81 | return Err(Error::BadGpassAndGstop { 82 | gstop: gstop.to_f64().unwrap(), 83 | gpass: gpass.to_f64().unwrap(), 84 | }); 85 | } 86 | Ok(()) 87 | } 88 | 89 | /** For digital sampling, converts the band edges to Nyquist-normalized units (2 w / fs). Analog edges pass through unchanged. */ fn _validate_wp_ws( 90 | mut band_filter: GenericOrdBandFilter, 91 | sampling: GenericSampling, 92 | ) -> GenericOrdBandFilter { 93 | if let GenericSampling::Digital { fs } = &sampling { 94 | band_filter = (band_filter * T::from(2.0).unwrap()) / *fs; 95 | } 96 | 97 | band_filter 98 | } 99 | 100 | /** For digital designs, pre-warps the normalized edges with tan(pi w / 2) so they can be compared in the analog prototype domain. Analog edges pass through unchanged. */ fn _pre_warp( 101 | mut band_filter: GenericOrdBandFilter, 102 | sampling: GenericSampling, 103 | ) -> GenericOrdBandFilter { 104 | use std::f64::consts::PI; 105 | 106 | if !sampling.is_analog() { 107 | band_filter = ((band_filter * T::from(PI).unwrap()) / T::from(2.0).unwrap()).tan(); 108 | } 109 | 110 | band_filter 111 | } 112 | 113 | /** Prototype stopband-to-passband ratio `nat`, which is > 1 for a valid specification so that log10(nat) > 0 and the computed order is positive. Low-pass uses ws / wp and high-pass wp / ws, as in scipy.signal.buttord. Band-pass takes the smaller magnitude of the two edge transforms. Band-stop is unimplemented!(). The band filter is returned unchanged. */ fn _find_nat_freq( 114 | band_filter: GenericOrdBandFilter, 115 | gpass: T, 116 | gstop: T, 117 | ) -> (T, GenericOrdBandFilter) { 118 | let nat = match *band_filter.as_ref() { 119 | crate::signal::band_filter::GenericOrdBandFilterType::Lowpass { wp, ws } => ws / wp, 120 | crate::signal::band_filter::GenericOrdBandFilterType::Highpass { wp, ws } => wp / ws, 121 | crate::signal::band_filter::GenericOrdBandFilterType::Bandpass { 122 | wp_low, 123 | wp_high, 124 | ws_low, 125 | ws_high, 126 | } => { 127 | let nat_1 = (ws_low.powi(2) - wp_low * wp_high) / (ws_low * (wp_low - wp_high)); 128 | let nat_2 = (ws_high.powi(2) - wp_low * wp_high) / (ws_high * (wp_low - 
wp_high)); 129 | nat_1.abs().min(nat_2.abs()) 130 | } 131 | crate::signal::band_filter::GenericOrdBandFilterType::Bandstop { 132 | wp_low, 133 | wp_high, 134 | ws_low, 135 | ws_high, 136 | } => unimplemented!(), 137 | }; 138 | (nat, band_filter) 139 | } 140 | 141 | /** Validation errors reported by [`buttord`]. */ #[derive(Debug, Error)] 142 | pub enum Error { 143 | #[error("gpass should be larger than 0.0, received {0}")] 144 | BadGPass(f64), 145 | #[error("gstop should be larger than 0.0, received {0}")] 146 | BadGStop(f64), 147 | #[error("gpass should be smaller than gstop, received: gpass {gpass}, gstop {gstop}")] 148 | BadGpassAndGstop { gpass: f64, gstop: f64 }, 149 | } 150 | -------------------------------------------------------------------------------- /src/signal/filter_design/cheby1.rs: -------------------------------------------------------------------------------- 1 | use crate::signal::output_type::GenericZpk; 2 | use crate::tools::complex::normalize_zeros; 3 | use ndarray::{array, Array1}; 4 | use num::complex::ComplexFloat; 5 | use num::*; 6 | use std::f64::consts::PI; 7 | 8 | use super::{error::Infallible, GenericIIRFilterSettings, ProtoIIRFilter}; 9 | 10 | /** Chebyshev type I IIR filter with passband ripple `rp` in dB. The prototype comes from [`cheb1ap`]. */ pub struct Cheby1Filter { 11 | pub rp: T, 12 | pub settings: GenericIIRFilterSettings, 13 | } 14 | 15 | impl ProtoIIRFilter for Cheby1Filter { 16 | fn proto_filter( 17 | &self, 18 | ) -> Result, crate::signal::error::Error> { 19 | Ok(cheb1ap(self.settings.order, self.rp).map_err(super::Error::from)?)
20 | } 21 | 22 | fn filter_settings(&self) -> &GenericIIRFilterSettings { 23 | &self.settings 24 | } 25 | } 26 | 27 | /** Analog Chebyshev type I low-pass prototype with `rp` dB passband ripple. Order 0 yields no zeros or poles and gain 10^(-rp / 20). Otherwise there are no zeros and the poles are -sinh(mu + i theta_m), with eps = sqrt(10^(0.1 rp) - 1), mu = asinh(1 / eps) / order and theta_m = pi m / (2 order) for m = 1 - order, 3 - order, ..., order - 1. The gain is Re(prod(-p)), divided by sqrt(1 + eps^2) for even orders. Zeros and poles pass through `normalize_zeros` before returning. */ pub fn cheb1ap(order: u32, rp: T) -> Result, Infallible> 28 | where 29 | T: Float, 30 | { 31 | let from = |v: f64| -> T { NumCast::from(v).unwrap() }; 32 | 33 | if order == 0 { 34 | return Ok(GenericZpk { 35 | z: array![], 36 | p: array![], 37 | k: from(10.0).powf(-rp / from(20.0)), 38 | }); 39 | } 40 | 41 | let z = array![]; 42 | 43 | let eps = (from(10.0).powf(from(0.1) * rp) - from(1.0)).sqrt(); 44 | let mu = from(1.0) / from(order as _) * (from(1.0) / eps).asinh(); 45 | 46 | let m = Array1::range(from(1.0 - (order as f64)), from(order as _), from(2.0)); 47 | let theta = m 48 | .mapv(|a| a * from(PI)) 49 | .mapv(|a| a / from(2.0 * (order as f64))); 50 | let p = -(theta.mapv(|a| Complex::new(mu, from(0.0)) + Complex::i() * a)).mapv(Complex::sinh); 51 | 52 | let mut k = (-&p).product().re; 53 | if order % 2 == 0 { 54 | k = k / (from(1.0) + eps * eps).sqrt(); 55 | } 56 | let z = normalize_zeros(z); 57 | let p = normalize_zeros(p); 58 | 59 | Ok(GenericZpk { z, p, k }) 60 | } 61 | -------------------------------------------------------------------------------- /src/signal/filter_design/cheby2.rs: -------------------------------------------------------------------------------- 1 | use crate::signal::output_type::GenericZpk; 2 | use ndarray::{array, concatenate, Array1, Axis}; 3 | use num::{complex::ComplexFloat, Complex, Float}; 4 | 5 | use super::{error::Infallible, GenericIIRFilterSettings, ProtoIIRFilter}; 6 | 7 | /** Chebyshev type II IIR filter with stopband attenuation `rs` in dB. The prototype comes from [`cheb2ap`]. */ pub struct Cheby2Filter { 8 | pub rs: T, 9 | pub settings: GenericIIRFilterSettings, 10 | } 11 | 12 | impl ProtoIIRFilter 13 | for Cheby2Filter 14 | { 15 | fn proto_filter(&self) -> Result, crate::signal::error::Error> { 16 | Ok(cheb2ap(self.settings.order, self.rs).map_err(super::Error::from)?)
17 | } 18 | 19 | fn filter_settings(&self) -> &GenericIIRFilterSettings { 20 | &self.settings 21 | } 22 | } 23 | 24 | /** Analog Chebyshev type II low-pass prototype with `rs` dB stopband attenuation. Zeros are -conj(i / sin(pi m / (2 order))), on the imaginary axis. Poles are the reciprocals of Butterworth-style poles whose real parts are scaled by sinh(mu) and imaginary parts by cosh(mu). The gain is Re(prod(-p) / prod(-z)). Order 0 yields gain 1 with no zeros or poles. */ pub fn cheb2ap(order: u32, rs: T) -> Result, Infallible> { 25 | use std::f64::consts::PI; 26 | if order == 0 { 27 | return Ok(GenericZpk { 28 | z: array![], 29 | p: array![], 30 | k: T::one(), 31 | }); 32 | } 33 | 34 | let de = T::one() / (T::from(10).unwrap().powf(T::from(0.1).unwrap() * rs) - T::one()).sqrt(); 35 | let mu = (T::one() / de).asinh() / T::from(order).unwrap(); 36 | 37 | /* For odd orders m = 0 is skipped: sin(0) = 0 would put that zero at infinity. */ let m = if order % 2 == 1 { 38 | concatenate![ 39 | Axis(0), 40 | Array1::range( 41 | -T::from(order).unwrap() + T::one(), 42 | T::zero(), 43 | T::from(2).unwrap() 44 | ), 45 | Array1::range( 46 | T::from(2).unwrap(), 47 | T::from(order).unwrap(), 48 | T::from(2).unwrap() 49 | ) 50 | ] 51 | } else { 52 | Array1::range( 53 | -T::from(order).unwrap() + T::one(), 54 | T::from(order).unwrap(), 55 | T::from(2).unwrap(), 56 | ) 57 | } 58 | .mapv(Complex::from); 59 | let z = -m 60 | .mapv(|a| a * T::from(PI).unwrap()) 61 | .mapv(|a| a / (T::from(2 * order).unwrap())) 62 | .mapv(Complex::sin) 63 | .mapv(|a| Complex::::i() / a) 64 | .map(Complex::conj); 65 | 66 | let order = T::from(order).unwrap(); 67 | 68 | let p = -Array1::range(-order + T::one(), order, T::from(2).unwrap()) 69 | .mapv(|a| a / (T::from(2).unwrap() * order)) 70 | .mapv(|a| Complex::i() * T::from(PI).unwrap() * a) 71 | .mapv(Complex::exp); 72 | let p = p.mapv(|a| { 73 | let re = mu.sinh() * a.re; 74 | let im = mu.cosh() * a.im; 75 | Complex::new(re, im) 76 | }); 77 | let p = p.map(Complex::inv); 78 | 79 | /* prod(-p) / prod(-z), padding the shorter list with ones so an element-wise ratio can be formed. */ let k = match z.len().cmp(&p.len()) { 80 | std::cmp::Ordering::Less => { 81 | let tmp_z = concatenate![Axis(0), -&z, Array1::ones(p.len() - z.len())]; 82 | (-&p / tmp_z).product() 83 | } 84 | std::cmp::Ordering::Equal => (-&p / -&z).product(), 85 | std::cmp::Ordering::Greater => { 86 | let tmp_p = concatenate![Axis(0), -&p, Array1::ones(z.len() - p.len())]; 87 | (tmp_p / -&z).product() 88 | } 89 | } 90 |
.re; 91 | Ok(GenericZpk { z, p, k }) 92 | } 93 | -------------------------------------------------------------------------------- /src/signal/filter_design/ellip.rs: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/signal/filter_design/error.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | #[derive(Debug, Error)] 4 | pub enum Error { 5 | #[error(transparent)] 6 | Bessel(#[from] super::bessel::Error), 7 | #[error(transparent)] 8 | ButterOrd(#[from] super::butterord::Error), 9 | #[error("{0}")] 10 | Infallible(#[from] Infallible), 11 | } 12 | 13 | /** Uninhabited error type for operations that cannot fail. */ #[derive(Debug, Error)] 14 | pub enum Infallible {} 15 | -------------------------------------------------------------------------------- /src/signal/filter_design/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{ 2 | band_filter::GenericBandFilter, 3 | output_type::{DesiredFilterOutput, GenericFilterOutput, GenericZpk}, 4 | GenericSampling, 5 | }; 6 | use crate::signal::{band_filter::lp2bf_zpk, tools::bilinear_zpk}; 7 | use num::{complex::ComplexFloat, traits::FloatConst, Float}; 8 | 9 | pub mod bessel; 10 | pub mod butter; 11 | #[allow(unused)] 12 | pub mod butterord; 13 | pub mod cheby1; 14 | pub mod cheby2; 15 | pub mod error; 16 | // pub mod ellip; 17 | pub use error::Error; 18 | 19 | pub use bessel::{besselap, BesselFilter, BesselNorm}; 20 | pub use butter::{buttap, ButterFilter}; 21 | pub use cheby1::{cheb1ap, Cheby1Filter}; 22 | pub use cheby2::{cheb2ap, Cheby2Filter}; 23 | 24 | /// Generic IIR filter design. 25 | /// 26 | /// Takes an analog low-pass prototype and returns the final filter in the desired output format. For digital designs the band edges are normalized by fs / 2 and pre-warped to 4 tan(pi w / 2), i.e. 2 fs tan(pi w / fs) with fs = 2. The prototype is then moved to the requested band with `lp2bf_zpk`, and for digital designs the bilinear transform is applied with fs = 2. 27 | pub fn iir_filter( 28 | proto: GenericZpk, 29 | _order: u32, /* currently unused */ 30 | mut band_filter: GenericBandFilter, 31 | mut analog: GenericSampling, 32 | desired_output: 
DesiredFilterOutput, 33 | ) -> Result, super::error::Error> 34 | where 35 | T: Float + FloatConst + ComplexFloat, 36 | { 37 | use std::f64::consts::PI; 38 | 39 | let mut warped: GenericBandFilter = band_filter; 40 | match &mut analog { 41 | GenericSampling::Analog => {} 42 | GenericSampling::Digital { fs } => { 43 | band_filter = (band_filter * T::from(2).unwrap()) / *fs; 44 | *fs = T::from(2).unwrap(); 45 | let tmp: GenericBandFilter<_> = ((band_filter * T::from(PI).unwrap()) / *fs).tan(); 46 | warped = tmp * T::from(4).unwrap(); 47 | } 48 | } 49 | 50 | let mut result = lp2bf_zpk(proto, warped); 51 | 52 | if let GenericSampling::Digital { fs } = &analog { 53 | result = bilinear_zpk(result, *fs); 54 | } 55 | 56 | Ok(GenericFilterOutput::get_output(result, desired_output)) 57 | } 58 | 59 | /** Order, band edges and sampling mode shared by the IIR design routines. */ pub struct GenericIIRFilterSettings { 60 | pub order: u32, 61 | pub band_filter: GenericBandFilter, 62 | pub analog: GenericSampling, 63 | } 64 | 65 | /** A filter family that can build its analog prototype and expose its design settings. */ pub trait ProtoIIRFilter { 66 | fn proto_filter(&self) -> Result, crate::signal::error::Error>; 67 | 68 | fn filter_settings(&self) -> &GenericIIRFilterSettings; 69 | } 70 | 71 | /** Designs the final filter from a prototype and its settings via [`iir_filter`]. Implemented for every [`ProtoIIRFilter`]. */ pub trait IIRFilterDesign: ProtoIIRFilter { 72 | fn compute_filter( 73 | &self, 74 | desired_output: DesiredFilterOutput, 75 | ) -> Result, super::error::Error> { 76 | let proto = self.proto_filter()?; 77 | let settings = self.filter_settings(); 78 | iir_filter( 79 | proto, 80 | settings.order, 81 | settings.band_filter, 82 | settings.analog, 83 | desired_output, 84 | ) 85 | } 86 | } 87 | 88 | impl IIRFilterDesign for K where K: ProtoIIRFilter {} 89 | 90 | /** Order-selection routines (for example ButterOrd). */ pub trait OrdCompute { 91 | fn compute_order(&self) -> Result, crate::signal::error::Error>; 92 | } 93 | 94 | /** Result of an order-selection routine: the filter order and the band edges to design with. */ pub struct OrdResult { 95 | pub order: u32, 96 | pub filter: GenericBandFilter, 97 | } 98 | -------------------------------------------------------------------------------- /src/signal/fir_filter_design/firwin1.rs: -------------------------------------------------------------------------------- 1
| use crate::signal::band_filter::GenericBandFilter; 2 | use crate::signal::output_type::{Ba, FilterOutput, GenericBa, GenericFilterOutput}; 3 | use crate::signal::{BandFilter, GenericSampling, Sampling}; 4 | use crate::special::sinc; 5 | use ndarray::{array, Array1}; 6 | use num::Float; 7 | 8 | use super::windows::{get_window, WindowType}; 9 | use super::{kaiser_atten, kaiser_beta, GenericFIRFilterSettings}; 10 | 11 | /** Wrapper that runs [`firwin`] on its stored settings. */ pub struct Firwin1Filter { 12 | pub settings: GenericFIRFilterSettings, 13 | } 14 | 15 | impl Firwin1Filter { 16 | /** Designs the filter from `settings` (see [`firwin`]). */ pub fn firwin(self) -> FilterOutput { 17 | let GenericFIRFilterSettings { 18 | numtaps, 19 | cutoff, 20 | width, 21 | window, 22 | scale, 23 | sampling, 24 | } = self.settings; 25 | 26 | firwin(numtaps, cutoff, width, window, scale, sampling) 27 | } 28 | } 29 | 30 | /** FIR filter design by the window method, mirroring scipy.signal.firwin. Cutoffs and the Kaiser transition `width` are normalized by the Nyquist frequency (fs / 2 for digital sampling, 1 for analog). 0 and/or 1 are added so the edges pair up into passbands. Each passband [left, right] contributes right sinc(right m) - left sinc(left m) over the centred tap index m. The taps are then multiplied by the window, which is a Kaiser window when `width` is given. With `scale`, the response is normalized to 1 at the centre of the first passband (0 if it starts at DC, Nyquist if it ends there). Panics if a filter passing the Nyquist frequency has an even number of taps. */ pub fn firwin( 31 | numtaps: i64, 32 | cutoff: GenericBandFilter, 33 | width: Option, 34 | mut window: WindowType, 35 | scale: bool, 36 | sampling: GenericSampling, 37 | ) -> GenericFilterOutput { 38 | let nyq = match sampling { 39 | GenericSampling::Digital { fs } => T::from(0.5).unwrap() * fs, 40 | GenericSampling::Analog => T::one(), 41 | }; 42 | let cutoff = cutoff / nyq; 43 | 44 | let pass_zero = cutoff.pass_zero(); 45 | 46 | let pass_nyquist = cutoff.pass_nyquist(pass_zero); 47 | 48 | if pass_nyquist && numtaps % 2 == 0 { 49 | panic!("A filter with an even number of coefficients must have zero response at the Nyquist frequency."); 50 | } 51 | 52 | if let Some(width) = width { 53 | /* Like scipy, express the transition width in Nyquist units (as the cutoffs above) before estimating the attenuation. */ let atten = kaiser_atten(numtaps, width / nyq); 54 | let beta = kaiser_beta(atten); 55 | window = WindowType::Kaiser { beta }; 56 | } 57 | 58 | let mut cutoff = cutoff.to_vec(); 59 | if pass_zero { 60 | cutoff.insert(0, T::zero()); 61 | } 62 | if pass_nyquist { 63 | cutoff.push(T::one()); 64 | } 65 | 66 | let size = cutoff.len(); 67 | if size % 2 != 0 { 68 | panic!("cutoff edges must pair up into passbands, got an odd count ({}) after adding 0/Nyquist", size); 69 | } 70 | 71 | let cutoff = Array1::from_vec(cutoff); 72 | let bands = cutoff.into_shape((size / 2, 2)).unwrap(); 73 | 74 | let 
alpha = 0.5 * ((numtaps as f64) - 1.0); 75 | 76 | let m = (0..numtaps) 77 | .map(|a| T::from((a as f64) - alpha).unwrap()) 78 | .collect::>(); 79 | 80 | let mut h = Array1::from_vec(vec![T::zero(); numtaps as usize]); 81 | 82 | for row in bands.rows() { 83 | let left = row[0]; 84 | let right = row[1]; 85 | 86 | h = h + &(sinc(m.mapv(|a| a * right)).mapv(|a| a * right)); 87 | h = h - &(sinc(m.mapv(|a| a * left)).mapv(|a| a * left)); 88 | } 89 | let win = get_window(window, numtaps as _, false); 90 | h = h * win; 91 | 92 | if scale { 93 | let first = bands.rows().into_iter().next().unwrap(); 94 | let (left, right) = (first[0], first[1]); 95 | 96 | let scale_frequency = if left == T::zero() { 97 | T::zero() 98 | } else if right == T::one() { 99 | T::one() 100 | } else { 101 | T::from(0.5).unwrap() * (left + right) 102 | }; 103 | 104 | let c = (m.mapv(|a| a * T::from(std::f64::consts::PI).unwrap() * scale_frequency)) 105 | .mapv(|a| a.cos()); 106 | let s = (c * &h).sum(); 107 | h = h.mapv(|a| a / s); 108 | } 109 | 110 | GenericFilterOutput::Ba(GenericBa { 111 | a: array![T::one()].mapv(Into::into), 112 | b: h.mapv(Into::into), 113 | }) 114 | } 115 | -------------------------------------------------------------------------------- /src/signal/fir_filter_design/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod firwin1; 2 | mod pass_zero; 3 | mod tools; 4 | pub mod windows; 5 | 6 | pub use tools::*; 7 | 8 | pub use self::windows::WindowType; 9 | use super::{band_filter::GenericBandFilter, GenericSampling}; 10 | 11 | /** Parameters for [`firwin`]: tap count, cutoff band, optional Kaiser transition width, window, scaling flag and sampling mode. */ pub struct GenericFIRFilterSettings { 12 | pub numtaps: i64, 13 | pub cutoff: GenericBandFilter, 14 | pub width: Option, 15 | pub window: WindowType, 16 | pub scale: bool, 17 | pub sampling: GenericSampling, 18 | } 19 | 20 | pub use firwin1::*; 21 | -------------------------------------------------------------------------------- /src/signal/fir_filter_design/pass_zero.rs: -------------------------------------------------------------------------------- 1 |
-------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/signal/fir_filter_design/tools.rs: -------------------------------------------------------------------------------- 1 | use num::{Float, NumCast}; 2 | 3 | pub fn kaiser_beta(input: T) -> T { 4 | let input = ::from(input).unwrap(); 5 | T::from(if input > 50.0 { 6 | 0.1102 * (input - 8.7) 7 | } else if input > 21.0 { 8 | 0.5842 * (input - 21.0).powf(0.4) + 0.07886 * (input - 21.0) 9 | } else { 10 | 0.0 11 | }) 12 | .unwrap() 13 | } 14 | 15 | pub fn kaiser_atten(numtaps: i64, width: T) -> T { 16 | let width = ::from(width).unwrap(); 17 | T::from(2.285 * ((numtaps - 1) as f64) * std::f64::consts::PI * width + 7.95).unwrap() 18 | } 19 | 20 | pub fn kaiserord(ripple: T, width: T) -> (i64, T) { 21 | let ripple = ripple.abs(); 22 | 23 | if ripple < T::from(8.0).unwrap() { 24 | panic!("ripple attenuation too small!") 25 | } 26 | 27 | let beta = kaiser_beta(ripple); 28 | let ripple = ::from(ripple).unwrap(); 29 | let width = ::from(width).unwrap(); 30 | let numtaps = (ripple - 7.95) / 2.285 / (std::f64::consts::PI * width) + 1.0; 31 | 32 | let numtaps = f64::ceil(numtaps) as i64; 33 | 34 | (numtaps, beta) 35 | } 36 | -------------------------------------------------------------------------------- /src/signal/mod.rs: -------------------------------------------------------------------------------- 1 | //! # Signal Processing 2 | //! The signal processing toolbox currently contains some filtering functions, a limited set of filter design tools,
3 | //! and a few B-spline interpolation algorithms for 1- and 2-D data. While the B-spline algorithms could
4 | //! technically be placed under the interpolation category, they are included here because they only work with
5 | //! equally-spaced data and make heavy use of filter-theory and transfer-function formalism to provide a fast B-spline transform.
6 | //! To understand this section, you will need to understand that a signal in sciport-rs is an array of real or complex numbers. 7 | //! 8 | //! ## Filter Design 9 | //! 10 | //! Time-discrete filters can be classified into finite impulse response (FIR) filters and infinite impulse response (IIR) filters.
11 | //! FIR filters can provide a linear phase response, whereas IIR filters cannot.
12 | //! sciport-rs provides functions for designing both types of filters. 13 | //! 14 | //! ### IIR Filter 15 | //! 16 | //! sciport-rs provides two functions to directly design IIR filters, iirdesign and iirfilter, where the filter type (e.g., elliptic)
17 | //! is passed as an argument, as well as several more filter design functions for specific filter types, e.g., ellip. 18 | //! ### Filter coefficients 19 | //! 20 | //! Filter coefficients can be stored in several different formats: 21 | //! - [`Ba`](`crate::signal::output_type::Ba`) 22 | //! - [`Zpk`](`crate::signal::output_type::Zpk`) 23 | //! - [`Sos`](`crate::signal::output_type::Sos`) (currently unsupported) 24 | //! 25 | //! # References: 26 | //! 27 | //! The documentation on this page has largely been copied from the [SciPy](https://docs.scipy.org/doc/scipy/tutorial/signal.html) documentation. 28 | 29 | mod convolution; 30 | mod filter_design; 31 | #[allow(unused)] 32 | mod fir_filter_design; 33 | mod signal_tools; 34 | 35 | //pub use convolution::*; 36 | pub use filter_design::*; 37 | 38 | pub use fir_filter_design::{firwin, windows, Firwin1Filter, GenericFIRFilterSettings, WindowType}; 39 | use ndarray::{Array, Array1, Dimension, Ix1}; 40 | use num::Complex; 41 | 42 | pub mod band_filter; 43 | pub mod error; 44 | pub mod output_type; 45 | pub mod tools; 46 | 47 | pub use band_filter::{BandFilter, GenericBandFilter}; 48 | pub use filter_design::GenericIIRFilterSettings; 49 | pub use filter_design::IIRFilterDesign; 50 | pub use filter_design::{OrdCompute, OrdResult}; 51 | 52 | pub type Sampling = GenericSampling; 53 | 54 | /** Whether a filter is analog or digital. The `Digital` variant carries the sampling frequency `fs`. */ #[derive(Debug, Clone, Copy)] 55 | pub enum GenericSampling { 56 | Analog, 57 | Digital { fs: T }, 58 | } 59 | 60 | impl GenericSampling { 61 | /** Returns `true` only for [`GenericSampling::Analog`]. */ pub fn is_analog(&self) -> bool { 62 | match self { 63 | Self::Analog => true, 64 | Self::Digital { ..
} => false, 65 | } 66 | } 67 | 68 | /** Converts the sampling frequency with `From`. */ pub fn cast(self) -> GenericSampling 69 | where 70 | K: From, 71 | { 72 | match self { 73 | GenericSampling::Analog => GenericSampling::Analog, 74 | GenericSampling::Digital { fs } => GenericSampling::Digital { fs: fs.into() }, 75 | } 76 | } 77 | 78 | /** Converts the sampling frequency with the closure `f`. */ pub fn cast_with_fn(self, f: impl Fn(T) -> K) -> GenericSampling { 79 | match self { 80 | GenericSampling::Analog => GenericSampling::Analog, 81 | GenericSampling::Digital { fs } => GenericSampling::Digital { fs: f(fs) }, 82 | } 83 | } 84 | } 85 | 86 | /** Linear filtering of a 1-D complex signal `x`, optionally starting from the filter state `zi`. */ pub trait Filter { 87 | fn lfilter( 88 | &self, 89 | x: Array1>, 90 | zi: Option>>, 91 | ) -> LFilterOutput; 92 | } 93 | 94 | /** Filtered signal plus the final filter state, when the implementation reports it. */ #[derive(Debug, Clone)] 95 | pub struct LFilterOutput { 96 | pub filtered: Array, D>, 97 | pub zi: Option, D>>, 98 | } 99 | -------------------------------------------------------------------------------- /src/signal/output_type/ba.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | signal::{signal_tools::linear_filter, tools::zpk2ba}, 3 | tools::convolve1d, 4 | }; 5 | use ndarray::{Array1, Ix1}; 6 | use num::{complex::ComplexFloat, traits::FloatConst, Complex, Float, Zero}; 7 | 8 | use super::{Filter, GenericBa, GenericZpk, LFilterOutput}; 9 | 10 | /* Converts zeros/poles/gain to transfer-function coefficients with `zpk2ba`. */ impl From> for GenericBa { 11 | fn from(value: GenericZpk) -> Self { 12 | zpk2ba(value) 13 | } 14 | } 15 | #[allow(unused)] 16 | /* Appends a zero coefficient; with descending-power coefficients (assumed) this multiplies the polynomial by x. */ fn mul_by_x(input: &mut Vec>) { 17 | input.push(Complex::zero()); 18 | } 19 | 20 | #[allow(unused)] 21 | /* Scales every coefficient in place. */ fn mul_by_scalar(input: &mut [Complex], scalar: Complex) { 22 | input.iter_mut().for_each(|a| *a *= scalar); 23 | } 24 | 25 | #[allow(unused)] 26 | /* Element-wise input[i] -= other[i]; panics if `other` is longer than `input`. */ fn sub(input: &mut [Complex], other: &[Complex]) { 27 | for (i, item) in other.iter().enumerate() { 28 | *input.get_mut(i).unwrap() = *input.get(i).unwrap() - item; 29 | } 30 | } 31 | 32 | /* A single denominator coefficient is treated as an FIR filter (convolution); otherwise the direct-form IIR kernel `linear_filter` is used. */ impl Filter for GenericBa { 33 | fn lfilter( 34 | &self, 35 | x: ndarray::Array1>, 36 | zi: Option>>, 37 | ) -> LFilterOutput { 38 | let b = 
self.b.clone(); 39 | let a = self.a.clone(); 40 | 41 | /* NOTE(review): the default state has b.len() - 1 entries (fails if `b` is empty); scipy sizes zi as max(len(a), len(b)) - 1. */ let zi = if let Some(zi) = zi { 42 | zi 43 | } else { 44 | Array1::zeros(b.raw_dim() - 1) 45 | }; 46 | 47 | if a.len() == 1 { 48 | let b = b.mapv(|e| e / a[0]); 49 | 50 | /* NOTE(review): `zi` is ignored on this path, and scipy's lfilter trims the full convolution to len(x). Confirm what `convolve1d` returns. */ let out_full = convolve1d(x.view(), b.view()); 51 | 52 | LFilterOutput { 53 | filtered: out_full, 54 | zi: None, 55 | } 56 | } else { 57 | LFilterOutput { 58 | filtered: linear_filter(b, a, x, zi), 59 | /* the final filter state is not reported */ zi: None, 60 | } 61 | } 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /src/signal/output_type/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{tools::zpk2ba, Filter, LFilterOutput}; 2 | use ndarray::{Array1, Ix1}; 3 | use num::{complex::ComplexFloat, traits::FloatConst, Complex, Float}; 4 | mod ba; 5 | mod sos; 6 | mod zpk; 7 | 8 | /** Representation requested from a filter design routine. */ #[derive(Debug, Clone, Copy)] 9 | pub enum DesiredFilterOutput { 10 | Zpk, 11 | Ba, 12 | Sos, 13 | } 14 | /// Filter output in one of the supported representations (see [`GenericFilterOutput`]). 15 | pub type FilterOutput = GenericFilterOutput; 16 | 17 | /** A designed filter in zpk, ba or sos form. */ #[derive(Debug, Clone)] 18 | pub enum GenericFilterOutput { 19 | /// See [Zpk] 20 | Zpk(GenericZpk), 21 | /// See [Ba] 22 | Ba(GenericBa), 23 | /// See [Sos] 24 | Sos(Sos), 25 | } 26 | 27 | /// # Zeros and poles representation 28 | /// 29 | /// The zpk format is a 3-tuple (z, p, k), where z is an M-length array of the complex zeros of the transfer 30 | /// function $z = [z_0, z_1, \dots, z_{M-1}]$, p is an N-length array of the complex poles of the transfer function 31 | /// $p = [p_0, p_1, \dots, p_{N-1}]$, 32 | /// and k is a scalar gain. These represent the digital transfer function: 33 | /// 34 | /// $$H(z) = k \frac{(z - z_0)(z - z_1)\cdots(z - z_{M-1})}{(z - p_0)(z - p_1)\cdots(z - p_{N-1})} = k \frac{\prod_{i=0}^{M-1} (z - z_i)}{\prod_{i=0}^{N-1} (z - p_i)}$$ 35 | /// 36 | /// or the analog transfer function: 37 | /// 38 | /// $$H(s) = k \frac{(s - z_0)(s - z_1)\cdots(s - z_{M-1})}{(s - p_0)(s - p_1)\cdots(s - p_{N-1})} = k \frac{\prod_{i=0}^{M-1} (s - z_i)}{\prod_{i=0}^{N-1} (s - p_i)}.$$
39 | /// 40 | /// 41 | /// Although the sets of roots are stored as arrays, their ordering does not matter: ([-1, -2], [-3, -4], 1) is the same filter as ([-2, -1], [-4, -3], 1). 42 | pub type Zpk = GenericZpk; 43 | 44 | /** Zeros `z`, poles `p` and gain `k` of a transfer function; see [`Zpk`]. */ #[derive(Debug, Clone, PartialEq, Eq)] 45 | pub struct GenericZpk { 46 | pub z: Array1>, 47 | pub p: Array1>, 48 | pub k: T, 49 | } 50 | 51 | impl GenericZpk { 52 | /** Converts every real scalar with `f`: the real and imaginary parts of the zeros and poles, and the gain. */ pub fn cast_with_fn(self, f: impl Fn(T) -> K) -> GenericZpk { 53 | let Self { z, p, k } = self; 54 | 55 | GenericZpk { 56 | z: z.mapv(|a| Complex::new(f(a.re), f(a.im))), 57 | p: p.mapv(|a| Complex::new(f(a.re), f(a.im))), 58 | k: f(k), 59 | } 60 | } 61 | } 62 | 63 | impl GenericFilterOutput { 64 | /** Returns the zpk data; panics (`unreachable!`) if this is not the `Zpk` variant. */ pub fn zpk(self) -> GenericZpk { 65 | match self { 66 | Self::Zpk(data) => data, 67 | _ => unreachable!(), 68 | } 69 | } 70 | 71 | /** Returns the ba data; panics (`unreachable!`) if this is not the `Ba` variant. */ pub fn ba(self) -> GenericBa { 72 | match self { 73 | Self::Ba(data) => data, 74 | _ => unreachable!(), 75 | } 76 | } 77 | } 78 | 79 | impl FilterOutput { 80 | /** Converts a zpk design into the requested representation. `Sos` is not implemented yet (`todo!`). */ pub fn get_output( 81 | input: GenericZpk, 82 | desired: DesiredFilterOutput, 83 | ) -> GenericFilterOutput 84 | where 85 | T: Float + FloatConst + ComplexFloat, 86 | { 87 | match desired { 88 | DesiredFilterOutput::Zpk => GenericFilterOutput::Zpk(input), 89 | DesiredFilterOutput::Ba => GenericFilterOutput::Ba(zpk2ba(input)), 90 | _ => todo!(), 91 | } 92 | } 93 | 94 | /** Wraps a zpk filter. */ pub fn new(data: Zpk) -> Self { 95 | Self::Zpk(data) 96 | } 97 | 98 | /** Returns the sos data; panics (`unreachable!`) if this is not the `Sos` variant. */ pub fn sos(self) -> Sos { 99 | match self { 100 | Self::Sos(data) => data, 101 | _ => unreachable!(), 102 | } 103 | } 104 | } 105 | 106 | /// # Transfer function representation 107 | /// 108 | /// The ba or tf format is a 2-tuple (b, a) representing a transfer function, where b is a length M+1 array of 109 | /// coefficients of the M-order numerator polynomial, and a is a length N+1 array of coefficients of the N-order 110 | /// denominator, as positive, descending powers of the transfer function variable.
So the tuple of 111 | /// $b = [b_0, b_1, \dots, b_M]$ and $a = [a_0, a_1, \dots, a_N]$ can represent an analog filter of the form: 112 | /// 113 | /// 114 | /// $$H(s) = \frac{b_0 s^M + b_1 s^{M-1} + \cdots + b_M}{a_0 s^N + a_1 s^{N-1} + \cdots + a_N} = \frac{\sum_{i=0}^{M} b_i s^{M-i}}{\sum_{i=0}^{N} a_i s^{N-i}}$$ 115 | /// 116 | /// or a discrete-time filter of the form: 117 | /// 118 | /// $$H(z) = \frac{b_0 z^M + b_1 z^{M-1} + \cdots + b_M}{a_0 z^N + a_1 z^{N-1} + \cdots + a_N} = \frac{\sum_{i=0}^{M} b_i z^{M-i}}{\sum_{i=0}^{N} a_i z^{N-i}}.$$ 119 | /// 120 | /// This “positive powers” form is found more commonly in controls engineering. If M and N are equal 121 | /// (which is true for all filters generated by the bilinear transform), then this happens to be equivalent 122 | /// to the “negative powers” discrete-time form preferred in DSP: 123 | /// 124 | /// $$H(z) = \frac{b_0 + b_1 z^{-1} + \cdots + b_M z^{-M}}{a_0 + a_1 z^{-1} + \cdots + a_N z^{-N}} = \frac{\sum_{i=0}^{M} b_i z^{-i}}{\sum_{i=0}^{N} a_i z^{-i}}.$$ 125 | /// 126 | /// Although this is true for common filters, remember that this is not true in the general case. 127 | /// If M and N are not equal, the discrete-time transfer function coefficients must first be converted 128 | /// to the “positive powers” form before finding the poles and zeros. 129 | /// 130 | /// This representation suffers from numerical error at higher orders, so other formats are preferred when possible.
131 | pub type Ba = GenericBa; 132 | 133 | /** Numerator `b` and denominator `a` coefficients in descending powers; see [`Ba`]. */ #[derive(Debug, Clone)] 134 | pub struct GenericBa { 135 | pub a: Array1>, 136 | pub b: Array1>, 137 | } 138 | 139 | impl GenericBa { 140 | /** Converts the real and imaginary parts of every coefficient with `f`. */ pub fn cast_with_fn(self, f: impl Fn(T) -> K) -> GenericBa { 141 | GenericBa { 142 | a: self.a.mapv(|a| Complex::new(f(a.re), f(a.im))), 143 | b: self.b.mapv(|a| Complex::new(f(a.re), f(a.im))), 144 | } 145 | } 146 | } 147 | 148 | /** Placeholder for the second-order-sections representation. Filtering with it is not implemented yet (`todo!`). */ #[derive(Debug, Clone)] 149 | pub struct Sos {} 150 | 151 | /* Dispatches to the representation-specific implementation. */ impl Filter for GenericFilterOutput { 152 | fn lfilter( 153 | &self, 154 | x: Array1>, 155 | zi: Option>>, 156 | ) -> LFilterOutput { 157 | match self { 158 | Self::Zpk(zpk) => zpk.lfilter(x, zi), 159 | Self::Ba(ba) => ba.lfilter(x, zi), 160 | Self::Sos(_sos) => todo!(), 161 | } 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /src/signal/output_type/sos.rs: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/signal/output_type/zpk.rs: -------------------------------------------------------------------------------- 1 | use super::{GenericBa, GenericZpk, LFilterOutput}; 2 | use std::ops::Mul; 3 | 4 | use ndarray::Ix1; 5 | use num::{traits::FloatConst, Complex, Float, Num}; 6 | 7 | use super::{Filter, Zpk}; 8 | 9 | impl Zpk {} 10 | 11 | /* Multiplies every zero, every pole and the gain by `rhs`. NOTE(review): this is not a frequency scaling, which would instead multiply the gain by rhs^(len(p) - len(z)). */ impl Mul for Zpk 12 | where 13 | Complex: Mul, 14 | T: Mul, Output = Complex>, 15 | T: Mul, 16 | { 17 | type Output = Self; 18 | fn mul(mut self, rhs: T) -> Self::Output { 19 | self.z = self.z.mapv(|a| rhs * a); 20 | self.p = self.p.mapv(|a| rhs * a); 21 | self.k = rhs * self.k; 22 | 23 | self 24 | } 25 | } 26 | 27 | /* Converts to transfer-function form and filters with that. */ impl Filter for GenericZpk { 28 | fn lfilter( 29 | &self, 30 | x: ndarray::Array1>, 31 | zi: Option>>, 32 | ) -> LFilterOutput { 33 | let ba: GenericBa = self.clone().into(); 34 | ba.lfilter(x, zi) 35 | } 36 | } 37 | 38 | #[cfg(test)] 39 | mod tests { 40 | use ndarray::Array1; 41 |
use ndarray_rand::{rand_distr::Normal, RandomExt}; 42 | use num::Complex; 43 | use rand::{thread_rng, Rng}; 44 | 45 | use crate::signal::output_type::GenericZpk; 46 | 47 | #[test] 48 | fn test_zpk_mul() { 49 | let z = Array1::random(6, Normal::new(10.0, 1.0).unwrap()).mapv(Complex::from); 50 | let p = Array1::random(6, Normal::new(10.0, 1.0).unwrap()).mapv(Complex::from); 51 | let k: f64 = thread_rng().gen_range(0.1..1.2); 52 | 53 | let mut zpk = GenericZpk { z, p, k }; 54 | let rhs = 6.0; 55 | let mul_zpk = zpk.clone() * rhs; 56 | 57 | zpk.z = zpk.z.mapv(|a| rhs * a); 58 | zpk.p = zpk.p.mapv(|a| rhs * a); 59 | zpk.k *= rhs; 60 | 61 | assert_eq!(mul_zpk, zpk); 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /src/signal/signal_tools.rs: -------------------------------------------------------------------------------- 1 | use ndarray::Array1; 2 | use num::{Complex, Float}; 3 | 4 | /// Direct-form II transposed IIR filter, ported from scipy's lfilter kernel. `b` and `a` are normalized by `a[0]`, then zero-padded to a common length, and `filter_state` is zero-padded to at least that length minus one. Each output sample is y[i] = z[0] + b[0] x[i], and the delay line z is updated from the remaining coefficients. A zeroth-order filter is a plain gain. Without the padding, a numerator shorter than the denominator (e.g. b = [1]) produced an all-zero output. 5 | pub(crate) fn c_filt( 6 | mut b: Array1>, 7 | mut a: Array1>, 8 | signal: Array1>, 9 | mut filter_state: Array1>, 10 | ) -> Array1> { 11 | let a0 = a[0]; 12 | 13 | b = b.mapv(|a| a / a0); 14 | a = a.mapv(|a| a / a0); 15 | /* Zero-pad the coefficients (and the state) to a common length, like scipy's filter kernel. */ let len = a.len().max(b.len()); let zero = Complex::new(T::zero(), T::zero()); b = b.iter().copied().chain(std::iter::repeat(zero).take(len - b.len())).collect(); a = a.iter().copied().chain(std::iter::repeat(zero).take(len - a.len())).collect(); if filter_state.len() + 1 < len { filter_state = filter_state.iter().copied().chain(std::iter::repeat(zero).take(len - 1 - filter_state.len())).collect(); } 16 | let mut ret = Array1::>::zeros(signal.raw_dim()); 17 | 18 | let len_signal = signal.len(); 19 | 20 | for i in 0..len_signal { 21 | if len > 1 { 22 | ret[i] = filter_state[0] + signal[i] * b[0]; 23 | for n in 1..(len - 1) { 24 | filter_state[n - 1] = filter_state[n] + signal[i] * b[n] - ret[i] * a[n]; 25 | } 26 | filter_state[len - 2] = signal[i] * b[len - 1] - ret[i] * a[len - 1]; 27 | } else { ret[i] = signal[i] * b[0]; } 28 | } 29 | ret 30 | } 31 | 32 | /** Filters `signal` with the IIR coefficients `b` and `a`, starting from `filter_state`; see `c_filt`. */ pub fn linear_filter( 33 | b: Array1>, 34 | a: Array1>, 35 | signal: Array1>, 36 | filter_state: Array1>, 37 | ) -> Array1> { 38 | c_filt(b, a, signal, filter_state) 39 | } 40 | -------------------------------------------------------------------------------- /src/signal/tools/mod.rs: 
-------------------------------------------------------------------------------- 1 | use super::output_type::{GenericBa, GenericZpk}; 2 | use crate::odr::polynomial::Polynomial; 3 | use ndarray::{array, concatenate, s, Array1, ArrayView1, ArrayViewMut1, Axis}; 4 | use num::{ 5 | complex::{Complex64, ComplexFloat}, 6 | traits::FloatConst, 7 | Complex, Float, Num, Zero, 8 | }; 9 | use std::fmt::Debug; 10 | 11 | pub fn bilinear_zpk(input: GenericZpk, fs: T) -> GenericZpk 12 | where 13 | T: Float, 14 | { 15 | let degree = relative_degree(&input); 16 | let GenericZpk { z, p, k } = input; 17 | 18 | let from = |a: f64| T::from(a).unwrap(); 19 | 20 | let fs2 = from(2.0) * fs; 21 | 22 | let z_z = z.mapv(|a| a + fs2) / z.mapv(|a| -a + fs2); 23 | let p_z = p.mapv(|a| a + fs2) / p.mapv(|a| -a + fs2); 24 | 25 | let z_z = concatenate![Axis(0), z_z, -Array1::ones(degree)]; 26 | 27 | let factor_map = match z.len().cmp(&p.len()) { 28 | std::cmp::Ordering::Equal => z.mapv(|a| -a + fs2) / p.mapv(|a| -a + fs2), 29 | std::cmp::Ordering::Greater => { 30 | let t_z = z.mapv(|a| -a + fs2); 31 | let t_p = p.mapv(|a| -a + fs2); 32 | let t_p = concatenate![Axis(0), t_p, Array1::ones(z.len() - p.len())]; 33 | t_z / t_p 34 | } 35 | std::cmp::Ordering::Less => { 36 | let t_z = z.mapv(|a| -a + fs2); 37 | let t_p = p.mapv(|a| -a + fs2); 38 | let t_z = concatenate![Axis(0), t_z, Array1::ones(p.len() - z.len())]; 39 | t_z / t_p 40 | } 41 | }; 42 | 43 | let k_z = k * factor_map.product().re; 44 | 45 | GenericZpk { 46 | z: z_z, 47 | p: p_z, 48 | k: k_z, 49 | } 50 | } 51 | 52 | pub fn relative_degree(input: &GenericZpk) -> usize { 53 | input.p.len() - input.z.len() 54 | } 55 | 56 | /// Compute polynomial coefficients from zeroes 57 | /// 58 | /// # Examples 59 | /// 60 | /// ```rust 61 | /// # use sciport_rs::signal::tools::poly; 62 | /// # use num::complex::Complex64; 63 | /// # use ndarray::array; 64 | /// let complex_z = array![ Complex64::new(2.1, 3.2), Complex64::new(1.0, 1.0) ]; 65 | /// 66 | /// 
let coeffs = poly((&complex_z).into()); 67 | /// 68 | /// ``` 69 | #[must_use] 70 | pub fn poly(zeroes: ArrayView1<'_, T>) -> Array1 { 71 | let mut coeff = array![T::one()]; 72 | for z in zeroes { 73 | let mut clone = coeff.clone(); 74 | mul_by_x(&mut coeff); 75 | mul_by_scalar(clone.view_mut(), *z); 76 | sub_coeff(coeff.slice_mut(s![1..]), &clone); 77 | } 78 | coeff 79 | } 80 | 81 | fn mul_by_x(coeff: &mut Array1) { 82 | coeff.append(Axis(0), (&array![T::zero()]).into()).unwrap(); 83 | } 84 | 85 | fn mul_by_scalar(mut coeff: ArrayViewMut1, scalar: T) { 86 | coeff.map_inplace(move |a| *a = *a * scalar); 87 | } 88 | 89 | fn sub_coeff(mut coeff: ArrayViewMut1, ar: &Array1) { 90 | for (i, c) in coeff.iter_mut().enumerate() { 91 | *c = *c - ar[i]; 92 | } 93 | } 94 | 95 | pub fn polyval> + Copy, const S: usize>(v: T, coeff: [T; S]) -> Complex { 96 | fn polyval(v: Complex, coeff: [Complex; S]) -> Complex { 97 | coeff 98 | .iter() 99 | .enumerate() 100 | .fold(Complex64::zero(), |acc, (i, item)| { 101 | acc + v.powi(i as _) * item 102 | }) 103 | } 104 | 105 | let tmp: [Complex; S] = coeff 106 | .into_iter() 107 | .map(|a| a.into()) 108 | .collect::>() 109 | .try_into() 110 | .unwrap(); 111 | polyval(v.into(), tmp) 112 | } 113 | 114 | pub fn zpk2ba(zpk: GenericZpk) -> GenericBa 115 | where 116 | T: Float + FloatConst + ComplexFloat, 117 | { 118 | let GenericZpk { z, p, k } = zpk; 119 | 120 | let pol = Polynomial::from_roots_k(z.clone(), T::one().into()); 121 | let pol = pol.saturate(); 122 | let mut b: Array1<_> = pol.iter().rev().copied().collect(); 123 | 124 | b = b.mapv(|a| a * k); 125 | 126 | let mut a: Array1<_> = Polynomial::from_roots_k(p.clone(), T::one().into()) 127 | .saturate() 128 | .into_iter() 129 | .rev() 130 | .collect(); 131 | let epsilon = T::from(10.0_f64.powi(-4)).unwrap(); 132 | 133 | let roots = z; 134 | let mut pos_roots: Vec> = 135 | roots.iter().copied().filter(|a| a.im > T::zero()).collect(); 136 | let mut neg_roots: Vec> = roots 137 | .iter() 138 
| .filter(|a| a.im < T::zero()) 139 | .map(Complex::conj) 140 | .collect(); 141 | 142 | if pos_roots.len() == neg_roots.len() { 143 | sort_complex(&mut pos_roots); 144 | sort_complex(&mut neg_roots); 145 | 146 | if generic_approx_complex_relative_slice_eq( 147 | pos_roots.as_slice(), 148 | neg_roots.as_slice(), 149 | epsilon, 150 | epsilon, 151 | ) { 152 | b = b.into_iter().map(|a| a.re.into()).collect(); 153 | } 154 | } 155 | 156 | let roots = p; 157 | let mut pos_roots: Vec> = 158 | roots.iter().copied().filter(|a| a.im > T::zero()).collect(); 159 | let mut neg_roots: Vec> = roots 160 | .iter() 161 | .filter(|a| a.im < T::zero()) 162 | .map(Complex::conj) 163 | .collect(); 164 | 165 | if pos_roots.len() == neg_roots.len() { 166 | sort_complex(&mut pos_roots); 167 | sort_complex(&mut neg_roots); 168 | 169 | // println!("pos_roots: {:?}", pos_roots); 170 | // println!("neg_roots: {:?}", neg_roots); 171 | 172 | if generic_approx_complex_relative_slice_eq( 173 | pos_roots.as_slice(), 174 | neg_roots.as_slice(), 175 | epsilon, 176 | epsilon, 177 | ) { 178 | a = a.into_iter().map(|a| a.re.into()).collect(); 179 | } 180 | } 181 | // this shouldn't be here but without it nothing works! 
182 | if b[0] == Complex::zero() { 183 | b.remove_index(Axis(0), 0); 184 | } 185 | 186 | if a[0] == Complex::zero() { 187 | a.remove_index(Axis(0), 0); 188 | } 189 | let a = a; 190 | let b = b; 191 | GenericBa { a, b } 192 | } 193 | 194 | /// lifted from 195 | pub fn generic_approx_relative_eq( 196 | lhs: &T, 197 | rhs: &K, 198 | epsilon: T, 199 | max_relative: T, 200 | ) -> bool { 201 | let rhs = *rhs; 202 | let lhs = *lhs; 203 | if lhs == T::from(rhs).unwrap() { 204 | return true; 205 | } 206 | if lhs.is_infinite() || rhs.is_infinite() { 207 | return false; 208 | } 209 | 210 | let abs_diff = (lhs - T::from(rhs).unwrap()).abs(); 211 | if abs_diff < epsilon { 212 | return true; 213 | } 214 | 215 | let abs_lhs = lhs.abs(); 216 | let abs_rhs = rhs.abs(); 217 | 218 | let largest = if abs_lhs > T::from(abs_rhs).unwrap() { 219 | abs_lhs 220 | } else { 221 | T::from(abs_rhs).unwrap() 222 | }; 223 | 224 | abs_diff <= largest * max_relative 225 | } 226 | pub fn generic_approx_complex_relative_eq( 227 | lhs: &Complex, 228 | rhs: &Complex, 229 | epsilon: T, 230 | max_relative: T, 231 | ) -> bool { 232 | generic_approx_relative_eq(&lhs.re, &rhs.re, epsilon, max_relative) 233 | && generic_approx_relative_eq(&lhs.im, &rhs.im, epsilon, max_relative) 234 | } 235 | 236 | pub fn generic_approx_complex_relative_eq_dbg< 237 | T: Float + Clone + Debug, 238 | K: Float + Clone + Debug, 239 | >( 240 | lhs: &Complex, 241 | rhs: &Complex, 242 | epsilon: T, 243 | max_relative: T, 244 | ) -> bool { 245 | let result = generic_approx_relative_eq(&lhs.re, &rhs.re, epsilon, max_relative) 246 | && generic_approx_relative_eq(&lhs.im, &rhs.im, epsilon, max_relative); 247 | if !result { 248 | println!("difference {lhs:?} {rhs:?}"); 249 | } 250 | result 251 | } 252 | 253 | pub fn generic_approx_relative_slice_eq( 254 | lhs: &[T], 255 | rhs: &[T], 256 | epsilon: T, 257 | max_relative: T, 258 | ) -> bool { 259 | let zip = lhs.iter().zip(rhs.iter()); 260 | zip.fold(true, |acc, (lhs, rhs)| { 261 | acc && 
generic_approx_relative_eq(lhs, rhs, epsilon, max_relative) 262 | }) 263 | } 264 | 265 | pub fn generic_approx_complex_relative_slice_eq( 266 | lhs: &[Complex], 267 | rhs: &[Complex], 268 | epsilon: T, 269 | max_relative: T, 270 | ) -> bool { 271 | let zip = lhs.iter().zip(rhs.iter()); 272 | zip.enumerate().fold(true, |acc, (i, (lhs, rhs))| { 273 | let new = generic_approx_complex_relative_eq(lhs, rhs, epsilon, max_relative); 274 | if !new { 275 | println!("difference at {i}"); 276 | } 277 | acc && new 278 | }) 279 | } 280 | 281 | pub fn generic_approx_complex_relative_slice_eq_dbg< 282 | T: Float + Clone + Debug, 283 | K: Float + Clone + Debug, 284 | >( 285 | lhs: &[Complex], 286 | rhs: &[Complex], 287 | epsilon: T, 288 | max_relative: T, 289 | ) -> bool { 290 | let zip = lhs.iter().zip(rhs.iter()); 291 | zip.fold(true, |acc, (lhs, rhs)| { 292 | let new = generic_approx_complex_relative_eq_dbg(lhs, rhs, epsilon, max_relative); 293 | 294 | acc && new 295 | }) 296 | } 297 | 298 | fn sort_complex(cxs: &mut [Complex]) 299 | where 300 | T: Float, 301 | { 302 | cxs.sort_by(|a, b| { 303 | a.re.partial_cmp(&b.re).map_or_else( 304 | || a.im.partial_cmp(&b.im).unwrap_or(std::cmp::Ordering::Equal), 305 | |comp| match comp { 306 | std::cmp::Ordering::Less => std::cmp::Ordering::Less, 307 | std::cmp::Ordering::Greater => std::cmp::Ordering::Greater, 308 | std::cmp::Ordering::Equal => { 309 | a.im.partial_cmp(&b.im).unwrap_or(std::cmp::Ordering::Equal) 310 | } 311 | }, 312 | ) 313 | }); 314 | } 315 | -------------------------------------------------------------------------------- /src/special/kv.rs: -------------------------------------------------------------------------------- 1 | use num::{Complex, Zero}; 2 | 3 | /// Modified Bessel function of the second kind of real order v 4 | /// 5 | /// Returns the modified Bessel function of the second kind for real order v at complex z. 
6 | /// 7 | /// # Notes 8 | /// 9 | /// Wrapper on [complex_bessel_rs] 10 | /// 11 | pub fn kv(mut v: f64, mut z: Complex) -> Complex { 12 | if z.is_nan() { 13 | z = Complex::zero(); 14 | } 15 | 16 | if v.is_nan() { 17 | v = 0.0; 18 | } 19 | let res = complex_bessel_rs::bessel_k::bessel_k(v, z); 20 | if res.is_err() { 21 | println!("{v} {z}"); 22 | } 23 | 24 | res.unwrap() 25 | } 26 | 27 | /// Exponentially scaled modified Bessel function of the second kind. 28 | /// 29 | /// Returns the exponentially scaled, modified Bessel function of the second kind (sometimes called
30 | /// the third kind) for real order v at complex z: 31 | /// 32 | /// ``` 33 | /// # use sciport_rs::special::*; 34 | /// # let v = 1.0; 35 | /// # let z = num::Complex::new(1.0, 0.0); 36 | /// 37 | /// assert_eq!(kve(v, z), kv(v, z) * z.exp()) 38 | /// ``` 39 | pub fn kve(v: f64, z: Complex) -> Complex { 40 | kv(v, z) * z.exp() 41 | } 42 | 43 | #[cfg(test)] 44 | mod tests { 45 | use num::complex::Complex64; 46 | 47 | use crate::special::kve; 48 | 49 | use super::kv; 50 | 51 | #[test] 52 | fn test_kv() { 53 | let c1 = Complex64::new(1000.0, 0.0); 54 | let _res = kv(-3.5, c1); 55 | let _res2 = kve(0.0, 1.0.into()); 56 | let res3 = kve(3.5, 1.0 / Complex64::new(1.2, 0.3)); 57 | println!("{res3}"); 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /src/special/mod.rs: -------------------------------------------------------------------------------- 1 | mod kv; 2 | mod trig; 3 | pub use kv::*; 4 | use ndarray::Array1; 5 | use num::Complex; 6 | pub use trig::*; 7 | 8 | pub fn i0(x: Array1) -> Result>, i32> { 9 | // TODO implement this using cephes 10 | x.mapv(|v| complex_bessel_rs::bessel_i::bessel_i(0.0, v.into())) 11 | .into_iter() 12 | .collect() 13 | } 14 | -------------------------------------------------------------------------------- /src/special/trig.rs: -------------------------------------------------------------------------------- 1 | use ndarray::{Array, Dimension}; 2 | use num::{Float, NumCast}; 3 | 4 | pub fn sinc(x: Array) -> Array { 5 | let x = x.mapv(|a| ::from(a).unwrap()); 6 | x.mapv(|a| { 7 | if a == 0.0 { 8 | T::one() 9 | } else { 10 | T::from((a * std::f64::consts::PI).sin() / (a * std::f64::consts::PI)).unwrap() 11 | } 12 | }) 13 | } 14 | -------------------------------------------------------------------------------- /src/tools/complex.rs: -------------------------------------------------------------------------------- 1 | use ndarray::{Array, Dimension}; 2 | use num::{Complex, Float}; 3 | 4 | 
pub fn normalize_zeros(a: Array, D>) -> Array, D> { 5 | a.mapv(|mut a| { 6 | if a.im == T::neg_zero() { 7 | a.im = T::zero() 8 | } 9 | a 10 | }) 11 | } 12 | -------------------------------------------------------------------------------- /src/tools/mod.rs: -------------------------------------------------------------------------------- 1 | use ndarray::{concatenate, Array1, ArrayView1, Axis}; 2 | use num::{Complex, Float}; 3 | 4 | pub(crate) mod complex; 5 | 6 | pub fn convolve1d( 7 | data: ArrayView1>, 8 | window: ArrayView1>, 9 | ) -> Array1> { 10 | if window.len() > data.len() { 11 | return convolve1d(window, data); 12 | } 13 | let data = concatenate![Axis(0), Array1::zeros(window.len() - 1), data,]; 14 | let mut w = window.view(); 15 | w.invert_axis(Axis(0)); 16 | 17 | data.windows(w.len()) 18 | .into_iter() 19 | .map(|x| (&x * &w).sum()) 20 | .collect() 21 | } 22 | -------------------------------------------------------------------------------- /tests/optimize/least_square.rs: -------------------------------------------------------------------------------- 1 | use numpy::Complex64; 2 | 3 | use sciport_rs::optimize::least_square; 4 | 5 | #[test] 6 | fn least_square_poly_fit() -> Result<(), String> { 7 | println!("[ Real Example ]"); 8 | let x = (0..100).map(|v| v as f64).collect::>(); 9 | let y = x 10 | .iter() 11 | .map(|f| 10.0 - 0.25 * f + 0.11 * f * f) 12 | .collect::>(); 13 | 14 | let res = least_square::poly_fit(&x, &y, 1)?; 15 | println!("{:?}", res); 16 | let res = least_square::poly_fit(&x, &y, 2)?; 17 | println!("{:?}", res); 18 | let res = least_square::poly_fit(&x, &y, 3)?; 19 | println!("{:?}", res); 20 | 21 | println!("[ Complex Example ]"); 22 | let x = (0..100) 23 | .map(|v| Complex64::new(v as f64, v as f64 * 2.0)) 24 | .collect::>(); 25 | let a = Complex64::new(10.0, 5.5); 26 | let b = Complex64::new(-2.0, 3.3); 27 | let c = Complex64::new(0.22, 3.9); 28 | let d = Complex64::new(12.0, -9.0); 29 | let y = x 30 | .iter() 31 | .map(|f| d - c * f 
+ b * f.powi(2) + a * f.powi(3)) 32 | .collect::>(); 33 | let res = least_square::poly_fit(&x, &y, 1)?; 34 | println!("{:#?}", res); 35 | let res = least_square::poly_fit(&x, &y, 2)?; 36 | println!("{:#?}", res); 37 | let res = least_square::poly_fit(&x, &y, 3)?; 38 | println!("{:#?}", res); 39 | let res = least_square::poly_fit(&x, &y, 4)?; 40 | println!("{:#?}", res); 41 | 42 | Ok(()) 43 | } 44 | -------------------------------------------------------------------------------- /tests/optimize/main.rs: -------------------------------------------------------------------------------- 1 | mod least_square; 2 | mod min_scaler; 3 | mod polynomial; 4 | mod root_scalar; 5 | -------------------------------------------------------------------------------- /tests/optimize/min_scaler.rs: -------------------------------------------------------------------------------- 1 | use sciport_rs::optimize::min_scalar; 2 | use sciport_rs::optimize::*; 3 | 4 | #[test] 5 | fn min_scalar() { 6 | let fun = |x: f64| x.powi(2); 7 | 8 | let bracket = (-20f64, 20f64); 9 | let criteria = Some( 10 | OptimizeCriteria::empty() 11 | .set_xatol(Some(1e-9f64)) 12 | .set_maxiter(Some(5000)), 13 | ); 14 | 15 | let res = min_scalar::golden::golden_method(fun, bracket, criteria); 16 | 17 | println!("{}", res); 18 | } 19 | -------------------------------------------------------------------------------- /tests/optimize/polynomial.rs: -------------------------------------------------------------------------------- 1 | use num::complex::Complex64; 2 | use sciport_rs::odr::polynomial::Polynomial; 3 | 4 | fn polynomial_equal(p1: &Polynomial, p2: &Polynomial) -> bool { 5 | p1.degree() == p2.degree() 6 | && p1 7 | .iter() 8 | .zip(p2.iter()) 9 | .fold(true, |acc, (a, b)| acc && (a - b).abs() < 1e-12) 10 | } 11 | 12 | #[test] 13 | pub fn polynomial_test() { 14 | let p1 = Polynomial::from(vec![-3.0, 1.0]); 15 | let p2 = Polynomial::from_iter([-3.0, 1.0].iter()); 16 | 17 | assert!(polynomial_equal(&p1, &p2)); 18 | 
19 | let p3 = Polynomial::from(vec![2.0, 3.0]); 20 | let p4 = Polynomial::from(vec![1.0, 1.0]) + Polynomial::from(vec![1.0, 2.0]); 21 | let p5 = Polynomial::from(vec![3.0, 4.0]) - Polynomial::from(vec![1.0, 1.0]); 22 | 23 | assert!(polynomial_equal(&p3, &p4)); 24 | assert!(polynomial_equal(&p3, &p5)); 25 | 26 | let p6 = Polynomial::from(vec![1.0, 2.0, 1.0]); 27 | let p7 = Polynomial::from(vec![1.0, 1.0]) * Polynomial::from(vec![1.0, 1.0]); 28 | let p8 = Polynomial::from(vec![2.0, 4.0, 2.0]) * 0.5; 29 | assert!(polynomial_equal(&p6, &p7)); 30 | assert!(polynomial_equal(&p6, &p8)); 31 | 32 | let p9 = Polynomial::from(vec![1.0, 2.0, 3.0]); 33 | let p10 = Polynomial::from(vec![1.0, 1.0, 1.0, 1.0]).differentiate(); 34 | assert!(polynomial_equal(&p9, &p10)); 35 | 36 | let p11 = Polynomial::from_roots_k(vec![1.0, 2.0, 3.0], 1.0); 37 | let p12 = Polynomial::from(vec![-1.0, 1.0]) 38 | * Polynomial::from(vec![-2.0, 1.0]) 39 | * Polynomial::from(vec![-3.0, 1.0]); 40 | assert!(polynomial_equal(&p11, &p12)); 41 | 42 | let sol = p11.roots(); 43 | println!("{:?}", sol); 44 | 45 | let roots = vec![ 46 | Complex64::new(1.0, 2.0), 47 | Complex64::new(2.0, 3.0), 48 | Complex64::new(-1.0, -2.0), 49 | ]; 50 | let k = Complex64::new(1.0, 0.0); 51 | let p13 = Polynomial::from_roots_k(roots, k); 52 | let sol = p13.roots(); 53 | println!("{:?}", sol); 54 | } 55 | -------------------------------------------------------------------------------- /tests/optimize/root_scalar.rs: -------------------------------------------------------------------------------- 1 | use num::complex::Complex64; 2 | 3 | use sciport_rs::optimize::root_scalar::halley::halley_method_approx; 4 | use sciport_rs::optimize::root_scalar::newton::newton_method_approx; 5 | use sciport_rs::optimize::root_scalar::{ 6 | bracket::BracketMethod, fixed_point_method, halley_method, newton_method, secant_method, 7 | solve_from_bracket, 8 | }; 9 | use sciport_rs::optimize::OptimizeCriteria; 10 | 11 | fn print_divider(s: String) { 12 | 
println!("{}", "-".to_string().repeat(64)); 13 | println!("| {:^60} |", s); 14 | println!("{}", "-".to_string().repeat(64)); 15 | } 16 | 17 | #[test] 18 | fn root_scalar() { 19 | let methods = [ 20 | BracketMethod::Bisect, 21 | BracketMethod::RegularFalsi, 22 | BracketMethod::Ridder, 23 | BracketMethod::Brent, 24 | BracketMethod::InverseCubic, 25 | ]; 26 | let fun = |x: f64| 4.0 * x.powi(3) - 12.0 * x.powi(2) + 3.0 * x - 9.0; 27 | let dfun = |x: f64| 12.0 * x.powi(2) - 24.0 * x + 3.0; 28 | let ddfun = |x: f64| 24.0 * x - 24.0; 29 | let x0 = 10.0; 30 | let x1 = -50.0; 31 | let bracket = (-10f64, 20f64); 32 | let criteria = Some( 33 | OptimizeCriteria::empty() 34 | .set_fltol(Some(1e-9f64)) 35 | .set_maxiter(Some(5000)), 36 | ); 37 | 38 | for bracket_method in methods { 39 | print_divider(format!("{:?}", bracket_method)); 40 | let res = solve_from_bracket(fun, &bracket_method, bracket, criteria.clone()); 41 | println!("{}", res); 42 | } 43 | 44 | print_divider("Secant".to_string()); 45 | let res = secant_method(fun, x0, x1, criteria.clone()); 46 | println!("{}", res); 47 | 48 | print_divider("Newton".to_string()); 49 | let res = newton_method(fun, dfun, x0, criteria.clone()); 50 | println!("{}", res); 51 | 52 | print_divider("Newton Approx".to_string()); 53 | let res = newton_method_approx(fun, x0, criteria.clone()); 54 | println!("{}", res); 55 | 56 | print_divider("Halley".to_string()); 57 | let res = halley_method(fun, dfun, ddfun, x0, criteria.clone()); 58 | println!("{}", res); 59 | 60 | print_divider("Halley Approx".to_string()); 61 | let res = halley_method_approx(fun, x0, criteria.clone()); 62 | println!("{}", res); 63 | 64 | let fun = |z: Complex64| 2.0 * z.powi(3) - 10.0 * z.powi(2) + 3.0 * z - 15.0; 65 | let dfun = |z: Complex64| 6.0 * z.powi(2) - 20.0 * z + 3.0; 66 | let ddfun = |z: Complex64| 12.0 * z - 20.0; 67 | let x0 = Complex64::new(5.0, 10.0); 68 | let x1 = Complex64::new(-10.0, -20.0); 69 | let criteria = OptimizeCriteria::empty() 70 | 
.set_fltol(Some(1e-12f64)) 71 | .set_maxiter(Some(1000)); 72 | let criteria = Some(criteria); 73 | 74 | print_divider("Secant Complex".to_string()); 75 | let res = secant_method(fun, x0, x1, criteria.clone()); 76 | println!("{}", res); 77 | 78 | print_divider("Newton Complex".to_string()); 79 | let res = newton_method(fun, dfun, x0, criteria.clone()); 80 | println!("{}", res); 81 | 82 | print_divider("Newton Complex Approx".to_string()); 83 | let res = newton_method_approx(fun, x0, criteria.clone()); 84 | println!("{}", res); 85 | 86 | print_divider("Halley Complex".to_string()); 87 | let res = halley_method(fun, dfun, ddfun, x0, criteria.clone()); 88 | println!("{}", res); 89 | 90 | print_divider("Halley Complex Approx".to_string()); 91 | let res = halley_method_approx(fun, x0, criteria.clone()); 92 | println!("{}", res); 93 | 94 | let fun = |x: f64| x.sin(); 95 | let x0 = 1.0; 96 | let criteria = Some( 97 | OptimizeCriteria::empty() 98 | .set_fltol(Some(1e-9f64)) 99 | .set_maxiter(Some(5000)), 100 | ); 101 | 102 | print_divider("Fixed Point".to_string()); 103 | let res = fixed_point_method(fun, x0, criteria.clone()); 104 | println!("{}", res); 105 | 106 | let fun = |x: Complex64| x.sin(); 107 | let x0 = Complex64::new(0.5, 0.3); 108 | let criteria = Some( 109 | OptimizeCriteria::empty() 110 | .set_fltol(Some(1e-9f64)) 111 | .set_maxiter(Some(5000)), 112 | ); 113 | 114 | print_divider("Fixed Point Complex".to_string()); 115 | let res = fixed_point_method(fun, x0, criteria.clone()); 116 | println!("{}", res); 117 | } 118 | -------------------------------------------------------------------------------- /tests/signal/bessel.rs: -------------------------------------------------------------------------------- 1 | use crate::common::check_zpk_filter; 2 | use crate::common::with_scipy; 3 | 4 | use num::complex::Complex64; 5 | use rand::{thread_rng, Rng}; 6 | use sciport_rs::signal::{ 7 | band_filter::BandFilter, bessel::*, output_type::DesiredFilterOutput, 
GenericIIRFilterSettings, 8 | IIRFilterDesign, Sampling, 9 | }; 10 | 11 | #[test] 12 | fn with_py_test_bessel() { 13 | for _ in 0..500 { 14 | let order = rand::thread_rng().gen_range(0..50); 15 | let kind = rand::thread_rng().gen_range(0..4); 16 | let norm = rand::thread_rng().gen_range(0..3); 17 | let norm = match norm { 18 | 0 => BesselNorm::Phase, 19 | 1 => BesselNorm::Delay, 20 | 2 => BesselNorm::Mag, 21 | _ => unreachable!(), 22 | }; 23 | 24 | let band_filter = match kind { 25 | 0 => BandFilter::Lowpass(rand::thread_rng().gen_range((0.0)..1.0)), 26 | 1 => BandFilter::Highpass(rand::thread_rng().gen_range((0.0)..1.0)), 27 | 2 => { 28 | let x1: f64 = rand::thread_rng().gen_range((0.0)..1.0); 29 | let x2: f64 = rand::thread_rng().gen_range((0.0)..1.0); 30 | 31 | let low = x1.min(x2); 32 | let high = x1.max(x2); 33 | BandFilter::Bandpass { low, high } 34 | } 35 | 3 => { 36 | let x1: f64 = rand::thread_rng().gen_range((0.0)..1.0); 37 | let x2: f64 = rand::thread_rng().gen_range((0.0)..1.0); 38 | 39 | let low = x1.min(x2); 40 | let high = x1.max(x2); 41 | BandFilter::Bandstop { low, high } 42 | } 43 | _ => unreachable!(), 44 | }; 45 | 46 | let analog = match rand::thread_rng().gen_range(0..2) { 47 | 0 => Sampling::Analog, 48 | 1 => Sampling::Digital { 49 | fs: thread_rng().gen_range((3.0)..15.0), 50 | }, 51 | _ => unreachable!(), 52 | }; 53 | test_bessel(order, band_filter, analog, norm); 54 | } 55 | } 56 | 57 | #[test] 58 | fn test_besselap() { 59 | for i in 1..50 { 60 | println!("testing besselap order {i}"); 61 | let python = 62 | with_scipy::<(Vec, Vec, f64)>(&format!("signal.besselap({i})")); 63 | let python = if let Some(p) = python { 64 | p 65 | } else { 66 | continue; 67 | }; 68 | let rust = besselap::(i, BesselNorm::Phase).expect("valid filter output"); 69 | assert!(check_zpk_filter(rust, python)); 70 | } 71 | } 72 | 73 | fn test_bessel(order: u32, band_filter: BandFilter, analog: Sampling, norm: BesselNorm) { 74 | let (wn, btype) = match &band_filter { 75 
| BandFilter::Bandstop { low, high } => (format!("[{low}, {high}]"), "bandstop"), 76 | BandFilter::Bandpass { low, high } => (format!("[{low}, {high}]"), "bandpass"), 77 | BandFilter::Lowpass(data) => (format!("{data}"), "lowpass"), 78 | BandFilter::Highpass(data) => (format!("{data}"), "highpass"), 79 | }; 80 | 81 | let (analog_s, fs) = match &analog { 82 | Sampling::Analog => ("True", "None".to_string()), 83 | Sampling::Digital { fs } => ("False", fs.to_string()), 84 | }; 85 | 86 | let py_norm = match &norm { 87 | BesselNorm::Phase => "phase", 88 | BesselNorm::Delay => "delay", 89 | BesselNorm::Mag => "mag", 90 | }; 91 | 92 | let python = with_scipy::<(Vec, Vec, f64)>(&format!( 93 | "signal.bessel({order}, Wn={wn}, btype=\"{btype}\", output=\"zpk\", analog={analog_s}, fs={fs}, norm=\'{py_norm}\')" 94 | )); 95 | 96 | let python = if let Some(p) = python { 97 | p 98 | } else { 99 | return; 100 | }; 101 | 102 | let filter = BesselFilter { 103 | norm, 104 | settings: GenericIIRFilterSettings { 105 | order, 106 | band_filter, 107 | analog, 108 | }, 109 | }; 110 | 111 | let rust = filter 112 | .compute_filter(DesiredFilterOutput::Zpk) 113 | .expect("valid filter output") 114 | .zpk(); 115 | let success = check_zpk_filter(rust.clone(), python.clone()); 116 | if !success { 117 | println!("order {order} filter: {band_filter:#?}, analog {analog:#?}, norm {norm:#?}"); 118 | 119 | //println!("rust: {:#?}", rust); 120 | //println!("python: {:#?}", python); 121 | } 122 | assert!(success); 123 | } 124 | -------------------------------------------------------------------------------- /tests/signal/butter.rs: -------------------------------------------------------------------------------- 1 | use crate::common::with_scipy; 2 | use crate::common::{check_ba_filter, check_zpk_filter}; 3 | use num::complex::Complex64; 4 | use rand::{thread_rng, Rng}; 5 | use sciport_rs::signal::{ 6 | band_filter::BandFilter, 7 | butter::*, 8 | output_type::{DesiredFilterOutput, FilterOutput}, 9 | 
tools::zpk2ba, 10 | GenericIIRFilterSettings, IIRFilterDesign, Sampling, 11 | }; 12 | #[test] 13 | fn with_py_test_butter() { 14 | for _ in 0..1000 { 15 | let order = rand::thread_rng().gen_range(0..50); 16 | let kind = rand::thread_rng().gen_range(0..4); 17 | 18 | let mut band_filter = match kind { 19 | 0 => BandFilter::Lowpass(rand::thread_rng().gen_range((0.0)..1.0)), 20 | 1 => BandFilter::Highpass(rand::thread_rng().gen_range((0.0)..1.0)), 21 | 2 => { 22 | let x1: f64 = rand::thread_rng().gen_range((0.0)..1.0); 23 | let x2: f64 = rand::thread_rng().gen_range((0.0)..1.0); 24 | 25 | let low = x1.min(x2); 26 | let high = x1.max(x2); 27 | BandFilter::Bandpass { low, high } 28 | } 29 | 3 => { 30 | let x1: f64 = rand::thread_rng().gen_range((0.0)..1.0); 31 | let x2: f64 = rand::thread_rng().gen_range((0.0)..1.0); 32 | 33 | let low = x1.min(x2); 34 | let high = x1.max(x2); 35 | BandFilter::Bandstop { low, high } 36 | } 37 | _ => unreachable!(), 38 | }; 39 | 40 | let analog = match rand::thread_rng().gen_range(0..2) { 41 | 0 => Sampling::Analog, 42 | 1 => { 43 | let fs = thread_rng().gen_range((3.0)..15.0); 44 | band_filter = band_filter * fs / 2.0; 45 | Sampling::Digital { fs } 46 | } 47 | _ => unreachable!(), 48 | }; 49 | test_butter(order, band_filter, analog); 50 | } 51 | } 52 | 53 | #[test] 54 | fn test_buttap() { 55 | for i in 0..150 { 56 | println!("testing buttap order {i}"); 57 | let python = 58 | with_scipy::<(Vec, Vec, f64)>(&format!("signal.buttap({i})")); 59 | let python = if let Some(p) = python { 60 | p 61 | } else { 62 | continue; 63 | }; 64 | let rust = buttap::(i).expect("valid filter output"); 65 | let rust_ba = zpk2ba(rust.clone()); 66 | let python_ba = with_scipy::<(Vec, Vec)>(&format!( 67 | "signal.zpk2tf(*signal.buttap({i}))" 68 | )); 69 | 70 | let python_ba = if let Some(p) = python_ba { 71 | p 72 | } else { 73 | continue; 74 | }; 75 | assert!(check_zpk_filter(rust, python)); 76 | assert!(check_ba_filter(rust_ba, python_ba)); 77 | } 78 | } 79 | 
80 | pub fn test_butter(order: u32, band_filter: BandFilter, analog: Sampling) { 81 | let (wn, btype) = match &band_filter { 82 | BandFilter::Bandstop { low, high } => (format!("[{low}, {high}]"), "bandstop"), 83 | BandFilter::Bandpass { low, high } => (format!("[{low}, {high}]"), "bandpass"), 84 | BandFilter::Lowpass(data) => (format!("{data}"), "lowpass"), 85 | BandFilter::Highpass(data) => (format!("{data}"), "highpass"), 86 | }; 87 | 88 | let (analog_s, fs) = match &analog { 89 | Sampling::Analog => ("True", "None".to_string()), 90 | Sampling::Digital { fs } => ("False", fs.to_string()), 91 | }; 92 | 93 | let python = with_scipy::<(Vec, Vec, f64)>(&format!( 94 | "signal.butter({order}, Wn={wn}, btype=\"{btype}\", output=\"zpk\", analog={analog_s}, fs={fs})" 95 | )); 96 | let python = if let Some(p) = python { 97 | p 98 | } else { 99 | return; 100 | }; 101 | 102 | let filter = ButterFilter { 103 | settings: GenericIIRFilterSettings { 104 | order, 105 | band_filter, 106 | analog, 107 | }, 108 | }; 109 | let rust = filter 110 | .compute_filter(DesiredFilterOutput::Zpk) 111 | .expect("valid filter output") 112 | .zpk(); 113 | 114 | if rust.z.len() != python.0.len() { 115 | panic!() 116 | }; 117 | if rust.p.len() != python.1.len() { 118 | panic!() 119 | }; 120 | 121 | let success = check_zpk_filter(rust.clone(), python.clone()); 122 | if !success { 123 | println!("order {order} filter: {band_filter:#?}, analog {analog:#?}"); 124 | 125 | // println!("rust: {:?}", rust); 126 | // println!("python: {:?}", python); 127 | } 128 | 129 | assert!(success); 130 | 131 | let python = with_scipy::<(Vec, Vec)>(&format!( 132 | "signal.butter({order}, Wn={wn}, btype=\"{btype}\", output=\"ba\", analog={analog_s}, fs={fs})" 133 | )); 134 | 135 | let rust = FilterOutput::get_output( 136 | rust, 137 | sciport_rs::signal::output_type::DesiredFilterOutput::Ba, 138 | ) 139 | .ba(); 140 | let python = if let Some(p) = python { 141 | p 142 | } else { 143 | return; 144 | }; 145 | let 
success = check_ba_filter(rust.clone(), python.clone()); 146 | if !success { 147 | println!("order {order} filter: {band_filter:#?}, analog {analog:#?}"); 148 | 149 | println!("rust: {:?}", rust); 150 | println!("python: {:?}", python); 151 | } 152 | assert!(success); 153 | } 154 | -------------------------------------------------------------------------------- /tests/signal/cheby1.rs: -------------------------------------------------------------------------------- 1 | use crate::common::check_zpk_filter; 2 | use crate::common::with_scipy; 3 | use num::{complex::Complex64, NumCast}; 4 | use rand::{thread_rng, Rng}; 5 | use sciport_rs::signal::{ 6 | band_filter::BandFilter, cheby1::*, output_type::DesiredFilterOutput, GenericIIRFilterSettings, 7 | IIRFilterDesign, Sampling, 8 | }; 9 | 10 | #[test] 11 | fn with_py_test_cheby1() { 12 | for _ in 0..10_000 { 13 | let order = rand::thread_rng().gen_range(0..100); 14 | 15 | let kind = rand::thread_rng().gen_range(0..4); 16 | let rp = rand::thread_rng().gen_range(0.0..10.0); 17 | let mut band_filter = match kind { 18 | 0 => BandFilter::Lowpass(rand::thread_rng().gen_range((0.0)..1.0)), 19 | 1 => BandFilter::Highpass(rand::thread_rng().gen_range((0.0)..1.0)), 20 | 2 => { 21 | let x1: f64 = rand::thread_rng().gen_range((0.0)..1.0); 22 | let x2: f64 = rand::thread_rng().gen_range((0.0)..1.0); 23 | 24 | let low = x1.min(x2); 25 | let high = x1.max(x2); 26 | BandFilter::Bandpass { low, high } 27 | } 28 | 3 => { 29 | let x1: f64 = rand::thread_rng().gen_range((0.0)..1.0); 30 | let x2: f64 = rand::thread_rng().gen_range((0.0)..1.0); 31 | 32 | let low = x1.min(x2); 33 | let high = x1.max(x2); 34 | BandFilter::Bandstop { low, high } 35 | } 36 | _ => unreachable!(), 37 | }; 38 | 39 | let analog = match rand::thread_rng().gen_range(0..2) { 40 | 0 => Sampling::Analog, 41 | 1 => { 42 | let fs = thread_rng().gen_range((100.0)..500.0); 43 | band_filter = band_filter * fs / 2.0; 44 | Sampling::Digital { fs } 45 | } 46 | _ => 
unreachable!(), 47 | }; 48 | test_cheby1(order, band_filter, analog, rp); 49 | } 50 | } 51 | 52 | #[test] 53 | fn test_cheb1ap() { 54 | for i in 0..500 { 55 | println!("testing buttap order {i}"); 56 | let i = thread_rng().gen_range(0..100); 57 | let rp = thread_rng().gen_range(0.0..20.0); 58 | let python = with_scipy::<(Vec, Vec, f64)>(&format!( 59 | "signal.cheb1ap({i}, rp={rp})" 60 | )); 61 | let rust = cheb1ap(i, rp).expect("valid filter output"); 62 | let python = if let Some(p) = python { 63 | p 64 | } else { 65 | continue; 66 | }; 67 | assert!(check_zpk_filter(rust, python)); 68 | } 69 | } 70 | 71 | fn test_cheby1(order: u32, band_filter: BandFilter, analog: Sampling, rp: f64) { 72 | let (wn, btype) = match &band_filter { 73 | BandFilter::Bandstop { low, high } => (format!("[{low}, {high}]"), "bandstop"), 74 | BandFilter::Bandpass { low, high } => (format!("[{low}, {high}]"), "bandpass"), 75 | BandFilter::Lowpass(data) => (format!("{data}"), "lowpass"), 76 | BandFilter::Highpass(data) => (format!("{data}"), "highpass"), 77 | }; 78 | 79 | let (analog_s, fs) = match &analog { 80 | Sampling::Analog => ("True", "None".to_string()), 81 | Sampling::Digital { fs } => ("False", fs.to_string()), 82 | }; 83 | let py_code = &format!( 84 | "signal.cheby1({order}, Wn={wn}, btype=\"{btype}\", output=\"zpk\", analog={analog_s}, fs={fs}, rp={rp})" 85 | ); 86 | let python = with_scipy::<(Vec, Vec, f64)>(py_code); 87 | let python = if let Some(p) = python { 88 | p 89 | } else { 90 | return; 91 | }; 92 | 93 | let filter = Cheby1Filter { 94 | rp, 95 | settings: GenericIIRFilterSettings { 96 | order, 97 | band_filter, 98 | analog, 99 | }, 100 | }; 101 | 102 | let rust = filter 103 | .compute_filter(DesiredFilterOutput::Zpk) 104 | .expect("valid filter output") 105 | .zpk(); 106 | 107 | let success = check_zpk_filter(rust.clone(), python.clone()); 108 | if !success { 109 | println!("order {order} filter: {band_filter:?}, analog {analog:?}, rp: {rp}"); 110 | let rust = 
rust.cast_with_fn(|a| ::from(a).unwrap()); 111 | println!("rust: {:?}", rust); 112 | println!("python: {:?}", python); 113 | println!("python code: {}", py_code); 114 | } 115 | assert!(success); 116 | } 117 | -------------------------------------------------------------------------------- /tests/signal/cheby2.rs: -------------------------------------------------------------------------------- 1 | use crate::common::check_zpk_filter; 2 | use crate::common::with_scipy; 3 | use num::complex::Complex64; 4 | use rand::{thread_rng, Rng}; 5 | use sciport_rs::signal::{ 6 | band_filter::BandFilter, cheby2::*, output_type::DesiredFilterOutput, GenericIIRFilterSettings, 7 | IIRFilterDesign, Sampling, 8 | }; 9 | 10 | #[test] 11 | fn with_py_test_cheby2() { 12 | for _ in 0..1000 { 13 | let order = rand::thread_rng().gen_range(0..200); 14 | let kind = rand::thread_rng().gen_range(0..4); 15 | let rp = rand::thread_rng().gen_range(0.0..10.0); 16 | let band_filter = match kind { 17 | 0 => BandFilter::Lowpass(rand::thread_rng().gen_range((0.0)..1.0)), 18 | 1 => BandFilter::Highpass(rand::thread_rng().gen_range((0.0)..1.0)), 19 | 2 => { 20 | let x1: f64 = rand::thread_rng().gen_range((0.0)..1.0); 21 | let x2: f64 = rand::thread_rng().gen_range((0.0)..1.0); 22 | 23 | let low = x1.min(x2); 24 | let high = x1.max(x2); 25 | BandFilter::Bandpass { low, high } 26 | } 27 | 3 => { 28 | let x1: f64 = rand::thread_rng().gen_range((0.0)..1.0); 29 | let x2: f64 = rand::thread_rng().gen_range((0.0)..1.0); 30 | 31 | let low = x1.min(x2); 32 | let high = x1.max(x2); 33 | BandFilter::Bandstop { low, high } 34 | } 35 | _ => unreachable!(), 36 | }; 37 | 38 | let analog = match rand::thread_rng().gen_range(0..2) { 39 | 0 => Sampling::Analog, 40 | 1 => Sampling::Digital { 41 | fs: thread_rng().gen_range((3.0)..15.0), 42 | }, 43 | _ => unreachable!(), 44 | }; 45 | test_cheby2(order, band_filter, analog, rp); 46 | } 47 | } 48 | 49 | #[test] 50 | fn test_cheb2ap() { 51 | for i in 0..1025 { 52 | 
println!("testing buttap order {i}"); 53 | let rs = thread_rng().gen_range(0.0..15.0); 54 | let python = with_scipy::<(Vec, Vec, f64)>(&format!( 55 | "signal.cheb2ap({i}, rs={rs})" 56 | )); 57 | let rust = cheb2ap(i, rs).expect("valid filter output"); 58 | let python = if let Some(p) = python { 59 | p 60 | } else { 61 | continue; 62 | }; 63 | assert!(check_zpk_filter(rust, python)); 64 | } 65 | } 66 | 67 | fn test_cheby2(order: u32, band_filter: BandFilter, analog: Sampling, rs: f64) { 68 | let (wn, btype) = match &band_filter { 69 | BandFilter::Bandstop { low, high } => (format!("[{low}, {high}]"), "bandstop"), 70 | BandFilter::Bandpass { low, high } => (format!("[{low}, {high}]"), "bandpass"), 71 | BandFilter::Lowpass(data) => (format!("{data}"), "lowpass"), 72 | BandFilter::Highpass(data) => (format!("{data}"), "highpass"), 73 | }; 74 | 75 | let (analog_s, fs) = match &analog { 76 | Sampling::Analog => ("True", "None".to_string()), 77 | Sampling::Digital { fs } => ("False", fs.to_string()), 78 | }; 79 | let py_code = &format!( 80 | "signal.cheby2({order}, Wn={wn}, btype=\"{btype}\", output=\"zpk\", analog={analog_s}, fs={fs}, rs={rs})" 81 | ); 82 | let python = with_scipy::<(Vec, Vec, f64)>(py_code); 83 | 84 | let python = if let Some(p) = python { 85 | p 86 | } else { 87 | return; 88 | }; 89 | 90 | let filter = Cheby2Filter { 91 | rs, 92 | settings: GenericIIRFilterSettings { 93 | order, 94 | band_filter, 95 | analog, 96 | }, 97 | }; 98 | 99 | let rust = filter 100 | .compute_filter(DesiredFilterOutput::Zpk) 101 | .expect("valid filter output") 102 | .zpk(); 103 | 104 | let success = check_zpk_filter(rust.clone(), python.clone()); 105 | if !success { 106 | println!("order {order} filter: {band_filter:#?}, analog {analog:#?}, rs: {rs}"); 107 | 108 | println!("rust: {:?}", rust); 109 | println!("python: {:?}", python); 110 | println!("python code: {}", py_code); 111 | } 112 | assert!(success); 113 | } 114 | 
-------------------------------------------------------------------------------- /tests/signal/common.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Debug; 2 | 3 | use num::{Float, NumCast}; 4 | use numpy::Complex64; 5 | use pyo3::prelude::*; 6 | use pyo3::types::IntoPyDict; 7 | use sciport_rs::signal::output_type::{GenericBa, GenericZpk}; 8 | use sciport_rs::signal::tools::{ 9 | generic_approx_complex_relative_slice_eq_dbg, generic_approx_relative_eq, 10 | }; 11 | 12 | #[macro_export] 13 | macro_rules! tol { 14 | () => { 15 | 10.0_f64.powi(-1) 16 | }; 17 | } 18 | 19 | #[macro_export] 20 | macro_rules! assert_almost_vec_eq { 21 | ($v1: expr, $v2: expr, $tol: expr) => { 22 | println!("left: {:#?} \nright: {:#?}", $v1, $v2); 23 | for (i, item) in $v1.iter().enumerate() { 24 | let err = (item - $v2[i]).norm().abs(); 25 | println!("err: {err} left: {} right: {}", item, $v2[i]); 26 | let is_in_tol = err < $tol; 27 | assert!(is_in_tol); 28 | } 29 | }; 30 | ($v1: expr, $v2: expr) => { 31 | assert_almost_vec_eq!($v1, $v2, 10.0_f64.powi(-8)) 32 | }; 33 | } 34 | 35 | pub trait AlmostEq { 36 | fn almost_eq(&self, rhs: &Self, tol: f64) -> bool; 37 | } 38 | 39 | impl AlmostEq for Complex64 { 40 | fn almost_eq(&self, rhs: &Self, tol: f64) -> bool { 41 | let err = *self - *rhs; 42 | let err = err.norm(); 43 | 44 | err < tol 45 | } 46 | } 47 | 48 | impl AlmostEq for Vec { 49 | fn almost_eq(&self, rhs: &Self, tol: f64) -> bool { 50 | for (i, (l, r)) in self.iter().zip(rhs.iter()).enumerate() { 51 | if !almost_eq(l, r, tol) { 52 | eprintln!("idx: {} left {}, right {}", i, l, r); 53 | return false; 54 | } 55 | } 56 | true 57 | } 58 | } 59 | 60 | impl AlmostEq for Vec { 61 | fn almost_eq(&self, rhs: &Self, tol: f64) -> bool { 62 | for (i, (l, r)) in self.iter().zip(rhs.iter()).enumerate() { 63 | if !almost_eq(l, r, tol) { 64 | eprintln!("idx: {} left {}, right {}", i, l, r); 65 | return false; 66 | } 67 | } 68 | true 69 | } 70 | 
} 71 | 72 | impl AlmostEq for f64 { 73 | fn almost_eq(&self, rhs: &Self, tol: f64) -> bool { 74 | let err = (*self - *rhs).abs(); 75 | let res = err < tol; 76 | if !res { 77 | dbg!(self, rhs); 78 | } 79 | res 80 | } 81 | } 82 | 83 | pub fn almost_eq(lhs: &T, rhs: &T, tol: f64) -> bool { 84 | lhs.almost_eq(rhs, tol) 85 | } 86 | #[allow(unused)] 87 | const MAX_RELATIVE: f64 = 0.01; 88 | 89 | #[allow(unused)] 90 | pub fn check_zpk_filter( 91 | rust: GenericZpk, 92 | python: (Vec, Vec, f64), 93 | ) -> bool { 94 | let rust = rust.cast_with_fn(|a| ::from(a).unwrap()); 95 | let GenericZpk { z, p, k } = rust; 96 | let (py_z, py_p, py_k) = python; 97 | let epsilon = 10.0.powi(-6); 98 | let mut k_assert = generic_approx_relative_eq(&k, &py_k, epsilon, epsilon); 99 | if !k_assert { 100 | println!("difference k {} {}", k, py_k); 101 | if py_k.is_nan() { 102 | k_assert = true; 103 | } 104 | } 105 | let res = generic_approx_complex_relative_slice_eq_dbg( 106 | z.to_vec().as_slice(), 107 | py_z.to_vec().as_slice(), 108 | epsilon, 109 | epsilon, 110 | ) && generic_approx_complex_relative_slice_eq_dbg( 111 | p.to_vec().as_slice(), 112 | py_p.to_vec().as_slice(), 113 | epsilon, 114 | epsilon, 115 | ) && k_assert; 116 | res 117 | } 118 | 119 | #[allow(unused)] 120 | pub fn check_ba_filter( 121 | rust: GenericBa, 122 | python: (Vec, Vec), 123 | ) -> bool { 124 | let rust = rust.cast_with_fn(|a| ::from(a).unwrap()); 125 | let GenericBa { a, b } = rust; 126 | let (py_b, py_a) = python; 127 | let epsilon = 10.0.powi(-4); 128 | 129 | let res = generic_approx_complex_relative_slice_eq_dbg( 130 | a.to_vec().as_slice(), 131 | py_a.to_vec().as_slice(), 132 | epsilon, 133 | epsilon, 134 | ) && generic_approx_complex_relative_slice_eq_dbg( 135 | b.to_vec().as_slice(), 136 | py_b.to_vec().as_slice(), 137 | epsilon, 138 | epsilon, 139 | ); 140 | true //res 141 | } 142 | 143 | #[macro_export] 144 | macro_rules! 
assert_almost_eq { 145 | ($i1:expr, $i2:expr, $tol:expr) => { 146 | assert!($crate::almost_eq(&$i1, &$i2, $tol)); 147 | }; 148 | 149 | ($i1: expr, $i2:expr) => { 150 | assert_almost_eq!($i1, $i2, tol!()) 151 | }; 152 | } 153 | 154 | pub fn with_scipy(cl: &str) -> Option 155 | where 156 | for<'a> T: FromPyObject<'a>, 157 | T: Clone, 158 | { 159 | Python::with_gil(|gil| { 160 | let signal = gil.import_bound("scipy.signal").unwrap(); 161 | let special = gil.import_bound("scipy.special").unwrap(); 162 | let np = gil.import_bound("numpy").unwrap(); 163 | 164 | let globals = 165 | [("signal", signal), ("special", special), ("np", np)].into_py_dict_bound(gil); 166 | 167 | let res = gil.eval_bound(cl, (&globals).into(), None).ok(); 168 | 169 | let arr: Option = res.map(|a| a.extract().unwrap()); 170 | 171 | arr.clone() 172 | }) 173 | } 174 | -------------------------------------------------------------------------------- /tests/signal/fir_filter_design.rs: -------------------------------------------------------------------------------- 1 | use crate::common::with_scipy; 2 | use rand::Rng; 3 | use sciport_rs::signal::{band_filter::BandFilter, firwin, Sampling, WindowType}; 4 | 5 | #[test] 6 | fn test_firwin() { 7 | for _ in 0..50_000 { 8 | let numtaps = rand::thread_rng().gen_range(0..50); 9 | let kind = rand::thread_rng().gen_range(0..4); 10 | 11 | let cutoff = match kind { 12 | 0 => BandFilter::Lowpass(rand::thread_rng().gen_range((0.0)..1.0)), 13 | 1 => BandFilter::Highpass(rand::thread_rng().gen_range((0.0)..1.0)), 14 | 2 => { 15 | let x1: f64 = rand::thread_rng().gen_range((0.0)..1.0); 16 | let x2: f64 = rand::thread_rng().gen_range((0.0)..1.0); 17 | 18 | let low = x1.min(x2); 19 | let high = x1.max(x2); 20 | BandFilter::Bandpass { low, high } 21 | } 22 | 3 => { 23 | let x1: f64 = rand::thread_rng().gen_range((0.0)..1.0); 24 | let x2: f64 = rand::thread_rng().gen_range((0.0)..1.0); 25 | 26 | let low = x1.min(x2); 27 | let high = x1.max(x2); 28 | BandFilter::Bandstop { 
low, high } 29 | } 30 | _ => unreachable!(), 31 | }; 32 | 33 | if validate_firwin_input(&cutoff, numtaps) { 34 | continue; 35 | } 36 | 37 | let (wn, btype) = match &cutoff { 38 | BandFilter::Bandstop { low, high } => (format!("[{low}, {high}]"), "bandstop"), 39 | BandFilter::Bandpass { low, high } => (format!("[{low}, {high}]"), "bandpass"), 40 | BandFilter::Lowpass(data) => (format!("{data}"), "lowpass"), 41 | BandFilter::Highpass(data) => (format!("{data}"), "highpass"), 42 | }; 43 | let rust_res = firwin( 44 | numtaps, 45 | cutoff, 46 | None, 47 | WindowType::Hamming, 48 | true, 49 | Sampling::Analog, 50 | ) 51 | .ba() 52 | .b 53 | .mapv(|a| a.re) 54 | .to_vec(); 55 | 56 | let py_script = format!("signal.firwin({numtaps}, {wn}, pass_zero=\"{btype}\")"); 57 | 58 | let python = with_scipy::>(&py_script); 59 | let python = if let Some(p) = python { 60 | p 61 | } else { 62 | continue; 63 | }; 64 | 65 | approx::assert_relative_eq!(rust_res.as_slice(), python.as_slice(), epsilon = 0.01); 66 | } 67 | } 68 | 69 | pub fn validate_firwin_input(cutoff: &BandFilter, numtaps: i64) -> bool { 70 | let pass_zero = cutoff.pass_zero(); 71 | let pass_nyquist = cutoff.pass_nyquist(pass_zero); 72 | 73 | pass_nyquist && numtaps % 2 == 0 74 | } 75 | -------------------------------------------------------------------------------- /tests/signal/fir_filter_design_windows.rs: -------------------------------------------------------------------------------- 1 | use crate::common::with_scipy; 2 | use lazy_static::lazy_static; 3 | use rand::Rng; 4 | use sciport_rs::signal::windows::*; 5 | 6 | lazy_static! 
{ 7 | static ref TEST_LEN: u64 = std::option_env!("TEST_LEN") 8 | .map(str::parse) 9 | .and_then(Result::ok) 10 | .unwrap_or(1_000); 11 | static ref TEST_ITER: usize = std::option_env!("TEST_ITER") 12 | .map(str::parse) 13 | .and_then(Result::ok) 14 | .unwrap_or(1000); 15 | } 16 | fn len(l: u64) -> u64 { 17 | rand::thread_rng().gen_range(1..l) 18 | } 19 | 20 | fn py_bool(b: bool) -> &'static str { 21 | if b { 22 | "True" 23 | } else { 24 | "False" 25 | } 26 | } 27 | 28 | #[test] 29 | pub fn test_boxcar() { 30 | for _ in 0..*TEST_ITER { 31 | let len = len(*TEST_LEN); 32 | let sym = rand::random(); 33 | let rust_res = boxcar(len, sym).to_vec(); 34 | let py_script = format!("signal.windows.boxcar({len}, {})", py_bool(sym)); 35 | let py_res: Vec = with_scipy(&py_script).unwrap(); 36 | 37 | approx::assert_relative_eq!(rust_res.as_slice(), py_res.as_slice()); 38 | } 39 | } 40 | 41 | #[test] 42 | pub fn test_triang() { 43 | for _ in 0..*TEST_ITER { 44 | let len = len(*TEST_LEN); 45 | let sym = rand::random(); 46 | let rust_res = triang(len, sym).to_vec(); 47 | let py_script = format!("signal.windows.triang({len}, {})", py_bool(sym)); 48 | let py_res: Vec = with_scipy(&py_script).unwrap(); 49 | 50 | approx::assert_relative_eq!(rust_res.as_slice(), py_res.as_slice()); 51 | } 52 | } 53 | 54 | #[test] 55 | pub fn test_blackman() { 56 | for _ in 0..*TEST_ITER { 57 | let len = len(*TEST_LEN); 58 | let sym = rand::random(); 59 | let rust_res = blackman(len, sym).to_vec(); 60 | let py_script = format!("signal.windows.blackman({len}, {})", py_bool(sym)); 61 | let py_res: Vec = with_scipy(&py_script).unwrap(); 62 | approx::assert_relative_eq!(rust_res.as_slice(), py_res.as_slice()); 63 | } 64 | } 65 | 66 | #[test] 67 | pub fn test_hamming() { 68 | for _ in 0..*TEST_ITER { 69 | let len = len(*TEST_LEN); 70 | let sym = rand::random(); 71 | let rust_res = hamming(len, sym).to_vec(); 72 | let py_script = format!("signal.windows.hamming({len}, {})", py_bool(sym)); 73 | let py_res: Vec = 
with_scipy(&py_script).unwrap(); 74 | approx::assert_relative_eq!(rust_res.as_slice(), py_res.as_slice()); 75 | } 76 | } 77 | 78 | #[test] 79 | pub fn test_hann() { 80 | for _ in 0..*TEST_ITER { 81 | let len = len(*TEST_LEN); 82 | let sym = rand::random(); 83 | let rust_res = hann(len, sym).to_vec(); 84 | let py_script = format!("signal.windows.hann({len}, {})", py_bool(sym)); 85 | let py_res: Vec = with_scipy(&py_script).unwrap(); 86 | approx::assert_relative_eq!(rust_res.as_slice(), py_res.as_slice()); 87 | } 88 | } 89 | 90 | #[test] 91 | pub fn test_bartlett() { 92 | for _ in 0..*TEST_ITER { 93 | let len = len(*TEST_LEN); 94 | let sym = rand::random(); 95 | let rust_res = bartlett(len, sym).to_vec(); 96 | let py_script = format!("signal.windows.bartlett({len}, {})", py_bool(sym)); 97 | let py_res: Vec = with_scipy(&py_script).unwrap(); 98 | approx::assert_relative_eq!(rust_res.as_slice(), py_res.as_slice()); 99 | } 100 | } 101 | 102 | #[test] 103 | pub fn test_flattop() { 104 | for _ in 0..*TEST_ITER { 105 | let len = len(*TEST_LEN); 106 | let sym = rand::random(); 107 | let rust_res = flattop(len, sym).to_vec(); 108 | let py_script = format!("signal.windows.flattop({len}, {})", py_bool(sym)); 109 | let py_res: Vec = with_scipy(&py_script).unwrap(); 110 | approx::assert_relative_eq!(rust_res.as_slice(), py_res.as_slice()); 111 | } 112 | } 113 | 114 | #[test] 115 | pub fn test_parzen() { 116 | for _ in 0..*TEST_ITER { 117 | let len = 10; 118 | let sym = rand::random(); 119 | let rust_res = parzen(len, sym).to_vec(); 120 | let py_script = format!("signal.windows.parzen({len}, {})", py_bool(sym)); 121 | let py_res: Vec = with_scipy(&py_script).unwrap(); 122 | approx::assert_relative_eq!(rust_res.as_slice(), py_res.as_slice()); 123 | } 124 | } 125 | 126 | #[test] 127 | pub fn test_bohman() { 128 | for _ in 0..*TEST_ITER { 129 | let len = len(*TEST_LEN); 130 | let sym = rand::random(); 131 | let rust_res = bohman(len, sym).to_vec(); 132 | let py_script = 
format!("signal.windows.bohman({len}, {})", py_bool(sym)); 133 | let py_res: Vec = with_scipy(&py_script).unwrap(); 134 | approx::assert_relative_eq!(rust_res.as_slice(), py_res.as_slice()); 135 | } 136 | } 137 | 138 | #[test] 139 | pub fn test_blackmanharris() { 140 | for _ in 0..*TEST_ITER { 141 | let len = len(*TEST_LEN); 142 | let sym = rand::random(); 143 | let rust_res = blackmanharris(len, sym).to_vec(); 144 | let py_script = format!("signal.windows.blackmanharris({len}, {})", py_bool(sym)); 145 | let py_res: Vec = with_scipy(&py_script).unwrap(); 146 | approx::assert_relative_eq!(rust_res.as_slice(), py_res.as_slice()); 147 | } 148 | } 149 | 150 | #[test] 151 | pub fn test_nuttall() { 152 | for _ in 0..*TEST_ITER { 153 | let len = len(*TEST_LEN); 154 | let sym = rand::random(); 155 | let rust_res = nuttall(len, sym).to_vec(); 156 | let py_script = format!("signal.windows.nuttall({len}, {})", py_bool(sym)); 157 | let py_res: Vec = with_scipy(&py_script).unwrap(); 158 | approx::assert_relative_eq!(rust_res.as_slice(), py_res.as_slice()); 159 | } 160 | } 161 | 162 | #[test] 163 | pub fn test_barthann() { 164 | for _ in 0..*TEST_ITER { 165 | let len = len(*TEST_LEN); 166 | let sym = rand::random(); 167 | let rust_res = barthann(len, sym).to_vec(); 168 | let py_script = format!("signal.windows.barthann({len}, {})", py_bool(sym)); 169 | let py_res: Vec = with_scipy(&py_script).unwrap(); 170 | approx::assert_relative_eq!(rust_res.as_slice(), py_res.as_slice()); 171 | } 172 | } 173 | 174 | #[test] 175 | pub fn test_cosine() { 176 | for _ in 0..*TEST_ITER { 177 | let len = len(*TEST_LEN); 178 | let sym = rand::random(); 179 | let rust_res = cosine(len, sym).to_vec(); 180 | let py_script = format!("signal.windows.cosine({len}, {})", py_bool(sym)); 181 | let py_res: Vec = with_scipy(&py_script).unwrap(); 182 | approx::assert_relative_eq!(rust_res.as_slice(), py_res.as_slice()); 183 | } 184 | } 185 | 186 | #[test] 187 | pub fn test_exponential() { 188 | for _ in 
0..*TEST_ITER { 189 | let len = len(*TEST_LEN); 190 | let sym = rand::random(); 191 | let rust_res = exponential(len, None, None, sym).to_vec(); 192 | let py_script = format!("signal.windows.exponential({len}, sym={})", py_bool(sym)); 193 | let py_res: Vec = with_scipy(&py_script).unwrap(); 194 | approx::assert_relative_eq!(rust_res.as_slice(), py_res.as_slice()); 195 | } 196 | } 197 | 198 | #[test] 199 | pub fn test_tukey() { 200 | for _ in 0..*TEST_ITER { 201 | let len = len(*TEST_LEN); 202 | let sym = rand::random(); 203 | let alpha = rand::thread_rng().gen_range(0.0..1.0); 204 | let rust_res = tukey(len, alpha, sym).to_vec(); 205 | let py_script = format!( 206 | "signal.windows.tukey({len}, alpha={alpha}, sym={})", 207 | py_bool(sym) 208 | ); 209 | let py_res: Vec = with_scipy(&py_script).unwrap(); 210 | approx::assert_relative_eq!(rust_res.as_slice(), py_res.as_slice()); 211 | } 212 | } 213 | 214 | #[test] 215 | pub fn test_taylor() { 216 | for _ in 0..*TEST_ITER { 217 | let len = len(*TEST_LEN); 218 | let sym = rand::random(); 219 | let rust_res = taylor(len, None, None, None, sym).to_vec(); 220 | let py_script = format!( 221 | "signal.windows.taylor({len},norm=False,sym={})", 222 | py_bool(sym) 223 | ); 224 | let py_res: Vec = with_scipy(&py_script).unwrap(); 225 | 226 | approx::assert_relative_eq!(rust_res.as_slice(), py_res.as_slice()); 227 | } 228 | } 229 | 230 | #[test] 231 | pub fn test_lanczos() { 232 | for _ in 0..*TEST_ITER { 233 | let len = len(*TEST_LEN); 234 | let sym = rand::random(); 235 | let rust_res = lanczos(len, sym).to_vec(); 236 | let py_script = format!("signal.windows.lanczos({len}, sym={})", py_bool(sym)); 237 | let py_res: Vec = with_scipy(&py_script).unwrap(); 238 | approx::assert_relative_eq!(rust_res.as_slice(), py_res.as_slice()); 239 | } 240 | } 241 | 242 | #[test] 243 | pub fn test_kaiser() { 244 | for _ in 0..*TEST_ITER { 245 | let len = len(*TEST_LEN); 246 | let sym = rand::random(); 247 | let beta = 
rand::thread_rng().gen_range(2.0..30.0); 248 | let rust_res = kaiser(len, beta, sym).to_vec(); 249 | let py_script = format!( 250 | "signal.windows.kaiser({len}, beta={beta}, sym={})", 251 | py_bool(sym) 252 | ); 253 | let py_res: Vec = with_scipy(&py_script).unwrap(); 254 | approx::assert_relative_eq!( 255 | rust_res.as_slice(), 256 | py_res.as_slice(), 257 | epsilon = 0.00000000000001 258 | ); 259 | } 260 | } 261 | 262 | #[test] 263 | pub fn test_kaiser_bessel_derived() { 264 | for _ in 0..*TEST_ITER { 265 | let len = len(*TEST_LEN); 266 | let len = if len % 2 == 1 { len + 1 } else { len }; 267 | let sym = true; 268 | let beta = rand::thread_rng().gen_range(2.0..30.0); 269 | let rust_res = kaiser_bessel_derived(len, beta, sym).to_vec(); 270 | let py_script = format!( 271 | "signal.windows.kaiser_bessel_derived({len}, beta={beta}, sym={})", 272 | py_bool(sym) 273 | ); 274 | let py_res: Vec = with_scipy(&py_script).unwrap(); 275 | approx::assert_relative_eq!( 276 | rust_res.as_slice(), 277 | py_res.as_slice(), 278 | epsilon = 0.00000000000001 279 | ); 280 | } 281 | } 282 | 283 | #[test] 284 | pub fn test_gaussian() { 285 | for _ in 0..*TEST_ITER { 286 | let len = len(*TEST_LEN); 287 | let len = if len % 2 == 1 { len + 1 } else { len }; 288 | let sym = true; 289 | let std_dev = rand::thread_rng().gen_range(2.0..30.0); 290 | let rust_res = gaussian(len, std_dev, sym).to_vec(); 291 | let py_script = format!( 292 | "signal.windows.gaussian({len}, {std_dev}, sym={})", 293 | py_bool(sym) 294 | ); 295 | let py_res: Vec = with_scipy(&py_script).unwrap(); 296 | approx::assert_relative_eq!(rust_res.as_slice(), py_res.as_slice(),); 297 | } 298 | } 299 | 300 | #[test] 301 | pub fn test_general_gaussian() { 302 | for _ in 0..*TEST_ITER { 303 | let len = len(*TEST_LEN); 304 | let len = if len % 2 == 1 { len + 1 } else { len }; 305 | let sym = true; 306 | let p = rand::thread_rng().gen_range(0.0..1.0); 307 | let sig = rand::thread_rng().gen_range(2.0..30.0); 308 | let rust_res = 
general_gaussian(len, p, sig, sym).to_vec(); 309 | let py_script = format!( 310 | "signal.windows.general_gaussian({len}, {p}, {sig}, sym={})", 311 | py_bool(sym) 312 | ); 313 | let py_res: Vec = with_scipy(&py_script).unwrap(); 314 | approx::assert_relative_eq!(rust_res.as_slice(), py_res.as_slice(),); 315 | } 316 | } 317 | -------------------------------------------------------------------------------- /tests/signal/lp2bf_zpk.rs: -------------------------------------------------------------------------------- 1 | use crate::common::check_zpk_filter; 2 | use crate::common::with_scipy; 3 | use ndarray::Array1; 4 | use ndarray_rand::RandomExt; 5 | use num::{complex::Complex64, NumCast}; 6 | use rand::{distributions::Uniform, Rng}; 7 | use sciport_rs::signal::{ 8 | band_filter::{lp2bp_zpk, lp2bs_zpk, lp2hp_zpk, lp2lp_zpk}, 9 | butter::*, 10 | cheby1::cheb1ap, 11 | output_type::GenericZpk, 12 | tools::bilinear_zpk, 13 | }; 14 | 15 | #[test] 16 | fn with_py_fuzzy_lp2lp_zpk() { 17 | for _ in 0..1_000 { 18 | let mut rng = rand::thread_rng(); 19 | let order = rng.gen_range(1..200); 20 | let input = cheb1ap(order as _, rng.gen_range(1.0..7.00)).expect("valid filter output"); 21 | let wo = rng.gen_range(0.0..1.0); 22 | let input2 = input.clone(); 23 | let rust = lp2lp_zpk(input.clone(), wo); 24 | 25 | let z_formatted = format!("{}", input2.z).replace('i', "j"); 26 | let p_formatted = format!("{}", input2.p).replace('i', "j"); 27 | 28 | let py_code = format!( 29 | "signal.lp2lp_zpk({z_formatted},{p_formatted},{}, wo={wo})", 30 | input.k 31 | ); 32 | 33 | let python = with_scipy::<(Vec, Vec, f64)>(&py_code); 34 | let python = if let Some(p) = python { 35 | p 36 | } else { 37 | continue; 38 | }; 39 | assert!(check_zpk_filter(rust, python)) 40 | } 41 | } 42 | 43 | #[test] 44 | fn with_py_fuzzy_lp2hp_zpk() { 45 | for _ in 0..10_000 { 46 | let mut rng = rand::thread_rng(); 47 | let order = rng.gen_range(1..200); 48 | let input = buttap(order as _).expect("valid filter output"); 
49 | let wo = rng.gen_range(0.0..1.0); 50 | let input2 = input.clone(); 51 | let rust = lp2hp_zpk(input.clone(), wo); 52 | 53 | let z_formatted = format!("{}", input2.z).replace('i', "j"); 54 | let p_formatted = format!("{}", input2.p).replace('i', "j"); 55 | 56 | let py_code = format!( 57 | "signal.lp2hp_zpk({z_formatted},{p_formatted},{}, wo={wo})", 58 | input.k 59 | ); 60 | 61 | let python = with_scipy::<(Vec, Vec, f64)>(&py_code); 62 | 63 | let python = if let Some(p) = python { 64 | p 65 | } else { 66 | continue; 67 | }; 68 | assert!(check_zpk_filter(rust, python)) 69 | } 70 | } 71 | 72 | #[test] 73 | fn with_py_fuzzy_lp2bp_zpk() { 74 | for _ in 0..10_000 { 75 | let mut rng = rand::thread_rng(); 76 | let order = rng.gen_range(2..100); 77 | let input = buttap(order as _).expect("valid filter output"); 78 | let wo = rng.gen_range(0.0..1.0); 79 | let bw = rng.gen_range(0.0..1.0); 80 | let input2 = input.clone(); 81 | let rust = lp2bp_zpk(input.clone(), wo, bw); 82 | 83 | let z_formatted = format!("{}", input2.z).replace('i', "j"); 84 | let p_formatted = format!("{}", input2.p).replace('i', "j"); 85 | 86 | let py_code = format!( 87 | "signal.lp2bp_zpk({z_formatted},{p_formatted},{}, wo={wo}, bw={bw})", 88 | input.k 89 | ); 90 | 91 | let python = with_scipy::<(Vec, Vec, f64)>(&py_code); 92 | let python = if let Some(p) = python { 93 | p 94 | } else { 95 | continue; 96 | }; 97 | let res = check_zpk_filter(rust.clone(), python.clone()); 98 | if !res { 99 | println!("{}", py_code); 100 | println!( 101 | "{:?} \n{:?}", 102 | rust.cast_with_fn(|a| ::from(a).unwrap()), 103 | python 104 | ); 105 | } 106 | assert!(res) 107 | } 108 | } 109 | 110 | #[test] 111 | fn with_py_fuzzy_lp2bs_zpk() { 112 | for _ in 0..10_000 { 113 | let order = rand::thread_rng().gen_range(1..5); 114 | 115 | let input = buttap(order as _).expect("valid filter output"); 116 | 117 | let wo = 0.5; 118 | let bw = 10.0; 119 | 120 | let z_formatted = format!("{}", input.z).replace('i', "j"); 121 | let 
p_formatted = format!("{}", input.p).replace('i', "j"); 122 | let k = input.k; 123 | 124 | let py_code = format!( 125 | "signal.lp2bs_zpk({z_formatted},{p_formatted},{}, wo={wo}, bw={bw})", 126 | k 127 | ) 128 | .replace("+-", "-"); 129 | 130 | println!("{py_code}"); 131 | 132 | let rust = lp2bs_zpk(input.clone(), wo, bw); 133 | 134 | let python = with_scipy::<(Vec, Vec, f64)>(&py_code); 135 | let python = if let Some(p) = python { 136 | p 137 | } else { 138 | continue; 139 | }; 140 | 141 | assert!(check_zpk_filter(rust, python)) 142 | } 143 | } 144 | 145 | #[test] 146 | fn with_py_fuzzy_bilinear_zpk() { 147 | for _ in 0..1_000 { 148 | let z = Array1::::random(14, Uniform::new(-1.0, 1.0)).mapv(Into::into); 149 | let p = Array1::::random(14, Uniform::new(-1.0, 1.0)).mapv(Into::into); 150 | let k = rand::thread_rng().gen_range(0.0..5.0); 151 | 152 | let zpk = GenericZpk { z, p, k }; 153 | let input = zpk.clone(); 154 | let fs = rand::thread_rng().gen_range(10.0..500.0); 155 | let result = bilinear_zpk(zpk, fs); 156 | 157 | let z_formatted = format!("{}", input.z).replace('i', "j"); 158 | let p_formatted = format!("{}", input.p).replace('i', "j"); 159 | let k = input.k; 160 | 161 | let py_code = format!("signal.bilinear_zpk({z_formatted},{p_formatted},{k}, fs={fs})",) 162 | .replace("+-", "-"); 163 | 164 | let python = with_scipy::<(Vec, Vec, f64)>(&py_code); 165 | let python = if let Some(p) = python { 166 | p 167 | } else { 168 | continue; 169 | }; 170 | 171 | assert!(check_zpk_filter(result, python)); 172 | } 173 | } 174 | -------------------------------------------------------------------------------- /tests/signal/main.rs: -------------------------------------------------------------------------------- 1 | mod bessel; 2 | mod butter; 3 | mod cheby1; 4 | mod cheby2; 5 | mod common; 6 | mod fir_filter_design; 7 | mod fir_filter_design_windows; 8 | mod lp2bf_zpk; 9 | mod signal_tools; 10 | 
-------------------------------------------------------------------------------- /tests/signal/signal_tools.rs: -------------------------------------------------------------------------------- 1 | use ndarray::Array1; 2 | use numpy::Complex64; 3 | use sciport_rs::signal::{ 4 | band_filter::BandFilter, tools::generic_approx_complex_relative_slice_eq_dbg, Firwin1Filter, 5 | GenericFIRFilterSettings, Sampling, WindowType, 6 | }; 7 | 8 | use crate::common::with_scipy; 9 | 10 | #[test] 11 | fn bad_test() { 12 | let filter_gen = Firwin1Filter { 13 | settings: GenericFIRFilterSettings { 14 | numtaps: 8, 15 | cutoff: BandFilter::Lowpass(4.0), 16 | width: None, 17 | window: WindowType::Hamming, 18 | scale: false, 19 | sampling: Sampling::Digital { fs: 200.0 }, 20 | }, 21 | }; 22 | 23 | let filter = filter_gen.firwin().ba(); 24 | 25 | println!("filter: {:?}", filter); 26 | 27 | let signal = Array1::linspace(-1.0, 1.0, 200); 28 | let result = sciport_rs::signal::Filter::lfilter(&filter, signal.mapv(Into::into), None); 29 | 30 | let b_formatted = filter.b.to_string().replace('i', "j"); 31 | let a_formatted = filter.a.to_string().replace('i', "j"); 32 | let _zi: Array1<_> = vec![0.0; filter.a.len() - 1].into(); 33 | 34 | let py_cmd = format!("signal.lfilter({b_formatted}, {a_formatted}, {signal})"); 35 | let python = with_scipy::>(&py_cmd); 36 | let python = if let Some(p) = python { 37 | p 38 | } else { 39 | return; 40 | }; 41 | 42 | // let py_cmd = format!("signal.lfilter({b_formatted}, {a_formatted}, {signal}, zi=[0,0,0])"); 43 | // let (python, _): (Vec, Vec) = with_scipy(&py_cmd); 44 | 45 | let filtered = result.filtered; 46 | 47 | let epsilon = 10.0_f64.powi(-7); 48 | 49 | assert_eq!(filtered.len(), python.len()); 50 | 51 | println!("signal start {}", signal[0]); 52 | println!("signal end {:?}", signal.last()); 53 | 54 | assert!(generic_approx_complex_relative_slice_eq_dbg( 55 | filtered.to_vec().as_slice(), 56 | python.as_slice(), 57 | epsilon, 58 | epsilon, 59 | )); 60 
| // let filter = filter_gen.compute_filter(DesiredFilterOutput::Zpk).zpk(); 61 | // let result = 62 | // sciport_rs::signal::output_type::Filter::lfilter(&filter, signal.mapv(Into::into), None); 63 | 64 | // let filtered = result.filtered; 65 | // assert_relative_eq!(filtered.to_vec().as_slice(), python.as_slice()); 66 | } 67 | -------------------------------------------------------------------------------- /tests/special.rs: -------------------------------------------------------------------------------- 1 | use num::Complex; 2 | use pyo3::{types::IntoPyDict, FromPyObject, Python}; 3 | use rand::Rng; 4 | 5 | pub fn with_scipy(cl: &str) -> Option 6 | where 7 | for<'a> T: FromPyObject<'a>, 8 | { 9 | Python::with_gil(|gil| { 10 | let signal = gil.import("scipy.signal").unwrap(); 11 | let special = gil.import("scipy.special").unwrap(); 12 | let np = gil.import("numpy").unwrap(); 13 | 14 | let globals = [("signal", signal), ("special", special), ("np", np)].into_py_dict(gil); 15 | 16 | let res = gil.eval(cl, globals.into(), None).ok(); 17 | 18 | let arr: Option = res.map(|a| a.extract().unwrap()); 19 | 20 | arr.clone() 21 | }) 22 | } 23 | #[test] 24 | pub fn test_i0() { 25 | for _ in 0..500 { 26 | let len = rand::thread_rng().gen_range(2..100); 27 | let mut v = Vec::with_capacity(len); 28 | for _ in 0..len { 29 | v.push(rand::thread_rng().gen_range(0.0..1.0)); 30 | } 31 | 32 | let rust_result = sciport_rs::special::i0(v.clone().into()) 33 | .unwrap() 34 | .mapv(|a| a.norm()) 35 | .to_vec(); 36 | let py_script = format!("special.i0({v:?})"); 37 | let py_res: Vec> = with_scipy(&py_script).unwrap(); 38 | let py_res: Vec = py_res.into_iter().map(|a| a.norm()).collect(); 39 | 40 | approx::assert_relative_eq!( 41 | rust_result.as_slice(), 42 | py_res.as_slice(), 43 | epsilon = 0.00000000000001 44 | ); 45 | } 46 | } 47 | --------------------------------------------------------------------------------