├── .gitignore ├── .DS_Store ├── Cargo.toml ├── LICENSE-MIT ├── src ├── conv_fft │ ├── good_size.rs │ ├── processor │ │ ├── mod.rs │ │ ├── complex.rs │ │ └── real.rs │ ├── padding.rs │ ├── mod.rs │ └── tests.rs ├── padding │ ├── dim.rs │ ├── half_dim.rs │ └── mod.rs ├── lib.rs ├── conv │ ├── mod.rs │ └── tests.rs └── dilation │ └── mod.rs ├── .vscode └── launch.json ├── benches └── with_torch.rs ├── LICENSE-APACHE └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | notebook 4 | .DS_Store 5 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TYPEmber/ndarray-conv/HEAD/.DS_Store -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "ndarray-conv" 3 | version = "0.6.0" 4 | edition = "2021" 5 | license = "MIT OR Apache-2.0" 6 | keywords = ["convolution", "ndarray", "FFT"] 7 | description = "N-Dimension convolution (with FFT) lib for ndarray." 
8 | repository = "https://github.com/TYPEmber/ndarray-conv.git" 9 | 10 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 11 | 12 | [dependencies] 13 | ndarray = {version = ">=0.17", features = ["rayon"]} 14 | num = "0.4" 15 | rustfft = "6.4" 16 | realfft = "3.5" 17 | thiserror = "2.0" 18 | castaway = "0.2" 19 | 20 | [dev-dependencies] 21 | tch = {version = "0.20.0", features = ["download-libtorch"]} 22 | criterion = { version = "0.6", features = ["html_reports"] } 23 | fft-convolver = "0.2" 24 | # fftconvolve = "0.1" 25 | convolutions-rs = "0.3" 26 | ndarray-vision = "0.5" 27 | ndarray-rand = ">=0.16" 28 | 29 | [[bench]] 30 | name = "with_torch" 31 | harness = false -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 - 2024 Yupei Tian (TYPEmber), 4 | Wenhui Han (hwharper), 5 | and ndarray-conv developers 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | -------------------------------------------------------------------------------- /src/conv_fft/good_size.rs: -------------------------------------------------------------------------------- 1 | //! Provides functions for determining good FFT sizes. 2 | //! 3 | //! This module implements strategies for finding FFT sizes that 4 | //! are efficient for the `rustfft` library. 5 | 6 | fn good_size_cc(n: usize) -> usize { 7 | let mut best_fac = n.next_power_of_two(); 8 | 9 | loop { 10 | let new_fac = best_fac / 4 * 3; 11 | match new_fac.cmp(&n) { 12 | std::cmp::Ordering::Less => break, 13 | std::cmp::Ordering::Equal => return n, 14 | std::cmp::Ordering::Greater => { 15 | best_fac = new_fac; 16 | } 17 | } 18 | } 19 | loop { 20 | let new_fac = best_fac / 6 * 5; 21 | match new_fac.cmp(&n) { 22 | std::cmp::Ordering::Less => break, 23 | std::cmp::Ordering::Equal => return n, 24 | std::cmp::Ordering::Greater => { 25 | best_fac = new_fac; 26 | } 27 | } 28 | } 29 | 30 | best_fac 31 | } 32 | 33 | /// Computes efficient FFT sizes for each dimension of an array. 34 | /// 35 | /// This function takes an array of dimensions and returns an array of the same 36 | /// size where each element is a "good" size for FFT calculations, as determined 37 | /// by the `good_size_cc` function. 38 | /// 39 | /// Note: FFT transforms can be very slow for lengths with large prime factors; rounding each dimension up to a nearby 5-smooth size (of the form 2^a * 3^b * 5^c) avoids that cost.
40 | pub fn compute(size: &[usize; N]) -> [usize; N] { 41 | std::array::from_fn(|i| good_size_cc(size[i])) 42 | } 43 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "type": "lldb", 9 | "request": "launch", 10 | "name": "Debug unit tests in library 'ndarray_conv'", 11 | "cargo": { 12 | "args": [ 13 | "test", 14 | "--no-run", 15 | "--lib", 16 | "--package=ndarray-conv" 17 | ], 18 | "filter": { 19 | "name": "ndarray_conv", 20 | "kind": "lib" 21 | } 22 | }, 23 | "args": [], 24 | "cwd": "${workspaceFolder}" 25 | }, 26 | { 27 | "type": "lldb", 28 | "request": "launch", 29 | "name": "Debug executable 'ndarray-conv'", 30 | "cargo": { 31 | "args": [ 32 | "build", 33 | "--bin=ndarray-conv", 34 | "--package=ndarray-conv" 35 | ], 36 | "filter": { 37 | "name": "ndarray-conv", 38 | "kind": "bin" 39 | } 40 | }, 41 | "args": [], 42 | "cwd": "${workspaceFolder}" 43 | }, 44 | { 45 | "type": "lldb", 46 | "request": "launch", 47 | "name": "Debug unit tests in executable 'ndarray-conv'", 48 | "cargo": { 49 | "args": [ 50 | "test", 51 | "--no-run", 52 | "--bin=ndarray-conv", 53 | "--package=ndarray-conv" 54 | ], 55 | "filter": { 56 | "name": "ndarray-conv", 57 | "kind": "bin" 58 | } 59 | }, 60 | "args": [], 61 | "cwd": "${workspaceFolder}" 62 | } 63 | ] 64 | } -------------------------------------------------------------------------------- /src/conv_fft/processor/mod.rs: -------------------------------------------------------------------------------- 1 | //! Provides FFT processor implementations for convolution operations. 2 | //! 3 | //! 
This module contains traits and implementations for performing forward and backward FFT transforms 4 | //! on real and complex-valued arrays. These processors are used internally by the FFT-accelerated 5 | //! convolution methods. 6 | 7 | use std::marker::PhantomData; 8 | 9 | use ndarray::{Array, ArrayBase, DataMut, Dim, IntoDimension, Ix, RemoveAxis}; 10 | use num::Complex; 11 | use rustfft::FftNum; 12 | 13 | pub mod complex; 14 | pub mod real; 15 | 16 | /// Marker trait for numeric types that can be used with ConvFftNum. 17 | /// 18 | /// This trait is implemented for both integer and floating-point types that implement `FftNum`. 19 | /// 20 | /// # Important Note 21 | /// 22 | /// While this trait is implemented for integer types (i8, i16, i32, i64, i128, isize), 23 | /// **integer FFT operations have known accuracy issues** and should NOT be used in production. 24 | /// Only lengths of 2 or 4 work correctly for 1D arrays; other lengths produce incorrect results. 25 | /// 26 | /// **Always use f32 or f64 for FFT operations.** 27 | pub trait ConvFftNum: FftNum {} 28 | 29 | macro_rules! impl_conv_fft_num { 30 | ($($t:ty),*) => { 31 | $(impl ConvFftNum for $t {})* 32 | }; 33 | } 34 | 35 | impl_conv_fft_num!(i8, i16, i32, i64, i128, isize, f32, f64); 36 | 37 | /// Returns a processor instance for the given input element type. 38 | /// 39 | /// This function is a convenience wrapper around `GetProcessor::get_processor()`. 40 | /// 41 | /// # Type Parameters 42 | /// 43 | /// * `T`: The FFT numeric type (typically f32 or f64) 44 | /// * `InElem`: The input element type (`T` for real, `Complex` for complex) 45 | pub fn get>() -> impl Processor { 46 | InElem::get_processor() 47 | } 48 | 49 | /// Trait for FFT processors that can perform forward and backward transforms. 50 | /// 51 | /// This trait defines the interface for performing FFT operations on N-dimensional arrays. 52 | /// Implementations exist for both real-valued and complex-valued inputs. 
53 | pub trait Processor> { 54 | /// Performs a forward FFT transform. 55 | /// 56 | /// Converts the input array from the spatial/time domain to the frequency domain. 57 | /// 58 | /// # Arguments 59 | /// 60 | /// * `input`: A mutable reference to the input array 61 | /// 62 | /// # Returns 63 | /// 64 | /// An array of complex values representing the frequency domain. 65 | fn forward, const N: usize>( 66 | &mut self, 67 | input: &mut ArrayBase>, 68 | ) -> Array, Dim<[Ix; N]>> 69 | where 70 | Dim<[Ix; N]>: RemoveAxis, 71 | [Ix; N]: IntoDimension>; 72 | 73 | /// Performs a backward (inverse) FFT transform. 74 | /// 75 | /// Converts the input array from the frequency domain back to the spatial/time domain. 76 | /// 77 | /// # Arguments 78 | /// 79 | /// * `input`: A mutable reference to the frequency domain array 80 | /// 81 | /// # Returns 82 | /// 83 | /// An array in the spatial/time domain with the same element type as the original input. 84 | fn backward( 85 | &mut self, 86 | input: &mut Array, Dim<[Ix; N]>>, 87 | ) -> Array> 88 | where 89 | Dim<[Ix; N]>: RemoveAxis, 90 | [Ix; N]: IntoDimension>; 91 | } 92 | 93 | /// Trait for types that can provide a processor instance. 94 | /// 95 | /// This trait is implemented for real and complex numeric types, allowing them to 96 | /// automatically select the appropriate FFT processor implementation. 97 | pub trait GetProcessor 98 | where 99 | InElem: GetProcessor, 100 | { 101 | /// Returns a processor instance appropriate for this type. 
102 | fn get_processor() -> impl Processor; 103 | } 104 | 105 | impl GetProcessor for T { 106 | fn get_processor() -> impl Processor { 107 | real::Processor::::default() 108 | } 109 | } 110 | 111 | impl GetProcessor> for Complex { 112 | fn get_processor() -> impl Processor> { 113 | complex::Processor::::default() 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /src/padding/dim.rs: -------------------------------------------------------------------------------- 1 | //! Provides padding functions for individual dimensions. 2 | //! 3 | //! This module contains functions for applying padding to a specific 4 | //! dimension of an array, including constant, replicate, reflect, and 5 | //! circular padding modes. These functions are used internally by the 6 | //! `PaddingExt` trait to implement padding in N-dimensional arrays. 7 | 8 | use ndarray::{ArrayBase, DataMut, Dim, Ix, RemoveAxis}; 9 | use num::traits::NumAssign; 10 | 11 | use super::half_dim; 12 | 13 | /// Applies constant padding to a specific dimension of the input array. 14 | /// 15 | /// This function pads the front and back of a given dimension with a 16 | /// specified constant value. 17 | /// 18 | /// # Type Parameters 19 | /// 20 | /// * `N`: The number of dimensions. 21 | /// * `T`: The numeric type of the array elements. 22 | /// * `S`: The data storage type of the array. 23 | /// * `D`: The dimension type of the input data. 24 | /// * `DO`: The dimension type of the output data. 25 | /// 26 | /// # Arguments 27 | /// 28 | /// * `input_dim`: The dimensions of the original array. 29 | /// * `buffer`: A mutable reference to the array to be padded. 30 | /// * `dim`: The dimension to pad. 31 | /// * `padding`: An array containing two `usize` values representing padding for front and back, respectively. 32 | /// * `constant`: The constant value to pad with. 
33 | #[inline] 34 | pub fn constant( 35 | input_dim: D, 36 | buffer: &mut ArrayBase, 37 | dim: usize, 38 | padding: [usize; 2], 39 | constant: T, 40 | ) where 41 | T: NumAssign + Copy, 42 | S: DataMut, 43 | D: RemoveAxis, 44 | DO: RemoveAxis, 45 | Dim<[Ix; N]>: RemoveAxis, 46 | { 47 | half_dim::constant_front(buffer, dim, padding, constant); 48 | half_dim::constant_back(input_dim, buffer, dim, padding, constant); 49 | } 50 | 51 | /// Applies replicate padding to a specific dimension of the input array. 52 | /// 53 | /// This function pads the front and back of a given dimension by replicating 54 | /// the edge values. 55 | /// 56 | /// # Type Parameters 57 | /// 58 | /// * `N`: The number of dimensions. 59 | /// * `T`: The numeric type of the array elements. 60 | /// * `S`: The data storage type of the array. 61 | /// * `D`: The dimension type of the input data. 62 | /// * `DO`: The dimension type of the output data. 63 | /// 64 | /// # Arguments 65 | /// 66 | /// * `input_dim`: The dimensions of the original array. 67 | /// * `buffer`: A mutable reference to the array to be padded. 68 | /// * `dim`: The dimension to pad. 69 | /// * `padding`: An array containing two `usize` values representing padding for front and back, respectively. 70 | #[inline] 71 | pub fn replicate( 72 | input_dim: D, 73 | buffer: &mut ArrayBase, 74 | dim: usize, 75 | padding: [usize; 2], 76 | ) where 77 | T: NumAssign + Copy, 78 | S: DataMut, 79 | D: RemoveAxis, 80 | DO: RemoveAxis, 81 | Dim<[Ix; N]>: RemoveAxis, 82 | { 83 | half_dim::replicate_front(buffer, dim, padding); 84 | half_dim::replicate_back(input_dim, buffer, dim, padding); 85 | } 86 | 87 | /// Applies reflect padding to a specific dimension of the input array. 88 | /// 89 | /// This function pads the front and back of a given dimension by 90 | /// reflecting the array at the boundaries. 91 | /// 92 | /// # Type Parameters 93 | /// 94 | /// * `N`: The number of dimensions. 95 | /// * `T`: The numeric type of the array elements. 
96 | /// * `S`: The data storage type of the array. 97 | /// * `D`: The dimension type of the input data. 98 | /// * `DO`: The dimension type of the output data. 99 | /// 100 | /// # Arguments 101 | /// 102 | /// * `input_dim`: The dimensions of the original array. 103 | /// * `buffer`: A mutable reference to the array to be padded. 104 | /// * `dim`: The dimension to pad. 105 | /// * `padding`: An array containing two `usize` values representing padding for front and back, respectively. 106 | #[inline] 107 | pub fn reflect( 108 | input_dim: D, 109 | buffer: &mut ArrayBase, 110 | dim: usize, 111 | padding: [usize; 2], 112 | ) where 113 | T: NumAssign + Copy, 114 | S: DataMut, 115 | D: RemoveAxis, 116 | DO: RemoveAxis, 117 | Dim<[Ix; N]>: RemoveAxis, 118 | { 119 | half_dim::reflect_front(buffer, dim, padding); 120 | half_dim::reflect_back(input_dim, buffer, dim, padding); 121 | } 122 | 123 | /// Applies circular padding to a specific dimension of the input array. 124 | /// 125 | /// This function pads the front and back of a given dimension by 126 | /// wrapping the data around the boundaries. 127 | /// 128 | /// # Type Parameters 129 | /// 130 | /// * `N`: The number of dimensions. 131 | /// * `T`: The numeric type of the array elements. 132 | /// * `S`: The data storage type of the array. 133 | /// * `D`: The dimension type of the input data. 134 | /// * `DO`: The dimension type of the output data. 135 | /// 136 | /// # Arguments 137 | /// 138 | /// * `input_dim`: The dimensions of the original array. 139 | /// * `buffer`: A mutable reference to the array to be padded. 140 | /// * `dim`: The dimension to pad. 141 | /// * `padding`: An array containing two `usize` values representing padding for front and back, respectively. 
142 | #[inline] 143 | pub fn circular( 144 | input_dim: D, 145 | buffer: &mut ArrayBase, 146 | dim: usize, 147 | padding: [usize; 2], 148 | ) where 149 | T: NumAssign + Copy, 150 | S: DataMut, 151 | D: RemoveAxis, 152 | DO: RemoveAxis, 153 | Dim<[Ix; N]>: RemoveAxis, 154 | { 155 | half_dim::circular_front(buffer, dim, padding); 156 | half_dim::circular_back(input_dim, buffer, dim, padding); 157 | } 158 | -------------------------------------------------------------------------------- /src/conv_fft/padding.rs: -------------------------------------------------------------------------------- 1 | //! Provides padding functions for FFT-based convolutions. 2 | //! 3 | //! These functions handle padding of input data and kernels to 4 | //! appropriate sizes for efficient FFT calculations. Padding is 5 | //! crucial for correctly implementing convolution via FFT. 6 | 7 | use ndarray::{ 8 | Array, ArrayBase, Data, Dim, IntoDimension, Ix, RemoveAxis, SliceArg, SliceInfo, SliceInfoElem, 9 | }; 10 | use num::traits::NumAssign; 11 | 12 | use crate::{dilation::KernelWithDilation, padding::PaddingExt, ExplicitPadding, PaddingMode}; 13 | 14 | /// Pads the input data for FFT-based convolution. 15 | /// 16 | /// This function takes the input data, padding mode, explicit padding, and desired FFT size 17 | /// and returns a new array with the appropriate padding applied. The padding is applied 18 | /// to each dimension according to the specified `padding_mode` and `explicit_padding`. 19 | /// 20 | /// # Arguments 21 | /// 22 | /// * `data`: The input data array. 23 | /// * `padding_mode`: The padding mode to use (e.g., `PaddingMode::Zeros`, `PaddingMode::Reflect`). 24 | /// * `explicit_padding`: An array specifying the padding for each dimension. 25 | /// * `fft_size`: The desired size for FFT calculations. The output array will have these dimensions. 26 | /// 27 | /// # Returns 28 | /// 29 | /// A new array with the padded data, ready for FFT transformation. 
30 | pub fn data( 31 | data: &ArrayBase>, 32 | padding_mode: PaddingMode, 33 | explicit_padding: ExplicitPadding, 34 | fft_size: [usize; N], 35 | ) -> Array> 36 | where 37 | T: NumAssign + Copy, 38 | S: Data, 39 | Dim<[Ix; N]>: RemoveAxis, 40 | [Ix; N]: IntoDimension>, 41 | // the key question is how to prove 42 | // , Dim<[Ix; N]>> as SliceArg>>::OutDim 43 | // is Dim<[Ix; N]> 44 | SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: 45 | SliceArg, OutDim = Dim<[Ix; N]>>, 46 | { 47 | let mut buffer: Array> = Array::from_elem(fft_size, T::zero()); 48 | 49 | let raw_dim = data.raw_dim(); 50 | let mut buffer_slice = buffer.slice_mut(unsafe { 51 | SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice { 52 | start: 0, 53 | end: Some((explicit_padding[i][0] + raw_dim[i] + explicit_padding[i][1]) as isize), 54 | step: 1, 55 | })) 56 | .unwrap() 57 | }); 58 | 59 | data.padding_in(&mut buffer_slice, padding_mode, explicit_padding); 60 | 61 | buffer 62 | } 63 | 64 | /// Pads the kernel for FFT-based convolution. 65 | /// 66 | /// This function takes the kernel, expands it with dilations, and pads it with zeros to the 67 | /// desired FFT size, preparing it for FFT transformation. The kernel is also reversed 68 | /// in each dimension as required for convolution via FFT. 69 | /// 70 | /// # Arguments 71 | /// 72 | /// * `kwd`: The kernel with dilation information. 73 | /// * `fft_size`: The desired size for FFT calculations. The output array will have these dimensions. 74 | /// 75 | /// # Returns 76 | /// 77 | /// A new array with the padded and reversed kernel, ready for FFT transformation. 
78 | pub fn kernel<'a, T, S, const N: usize>( 79 | kwd: KernelWithDilation<'a, S, N>, 80 | fft_size: [usize; N], 81 | ) -> Array> 82 | where 83 | T: NumAssign + Copy + 'a, 84 | S: Data, 85 | [Ix; N]: IntoDimension>, 86 | Dim<[Ix; N]>: RemoveAxis, 87 | SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: 88 | SliceArg, OutDim = Dim<[Ix; N]>>, 89 | { 90 | let mut buffer: Array> = Array::from_elem(fft_size, T::zero()); 91 | 92 | let kernel = kwd.kernel; 93 | 94 | let kernel_raw_dim = kernel.raw_dim(); 95 | let kernel_raw_dim_with_dilation: [usize; N] = 96 | std::array::from_fn(|i| kernel_raw_dim[i] * kwd.dilation[i] - kwd.dilation[i] + 1); 97 | 98 | let mut buffer_slice = buffer.slice_mut(unsafe { 99 | SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice { 100 | start: 0, 101 | end: Some(kernel_raw_dim_with_dilation[i] as isize), 102 | // use negative stride to make kernel reverse 103 | step: (kwd.dilation[i] as isize) * if kwd.reverse { 1 } else { -1 }, 104 | })) 105 | .unwrap() 106 | }); 107 | 108 | buffer_slice.zip_mut_with(kernel, |b, &k| *b = k); 109 | 110 | buffer 111 | } 112 | 113 | #[cfg(test)] 114 | mod tests { 115 | use crate::{ 116 | dilation::{IntoKernelWithDilation, WithDilation}, 117 | BorderType, ConvMode, 118 | }; 119 | use ndarray::prelude::*; 120 | 121 | use super::*; 122 | 123 | #[test] 124 | fn data_padding() { 125 | let arr = array![[1, 2], [3, 4]]; 126 | let kernel = array![[1, 1, 1], [1, 1, 1], [1, 1, 1]]; 127 | let kernel = kernel.into_kernel_with_dilation(); 128 | 129 | let explicit_conv = ConvMode::Full.unfold(&kernel); 130 | let explicit_padding = explicit_conv.padding; 131 | 132 | let arr_padded = data( 133 | &arr, 134 | PaddingMode::Custom([BorderType::Const(7), BorderType::Const(8)]), 135 | // PaddingMode::Const(7), 136 | explicit_padding, 137 | [8, 8], 138 | ); 139 | 140 | assert_eq!( 141 | arr_padded, 142 | array![ 143 | [8, 8, 7, 7, 8, 8, 0, 0], 144 | [8, 8, 7, 7, 8, 8, 0, 0], 145 | [8, 8, 1, 2, 8, 8, 0, 0], 146 | [8, 8, 
3, 4, 8, 8, 0, 0], 147 | [8, 8, 7, 7, 8, 8, 0, 0], 148 | [8, 8, 7, 7, 8, 8, 0, 0], 149 | [0, 0, 0, 0, 0, 0, 0, 0], 150 | [0, 0, 0, 0, 0, 0, 0, 0] 151 | ] 152 | ); 153 | } 154 | 155 | #[test] 156 | fn kernel_padding() { 157 | let _arr = array![[1, 2], [3, 4]]; 158 | let kernel = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; 159 | let kernel = kernel.with_dilation([2, 3]).into_kernel_with_dilation(); 160 | 161 | let explicit_conv = ConvMode::Full.unfold(&kernel); 162 | let _explicit_padding = explicit_conv.padding; 163 | 164 | let kernel_padded = super::kernel(kernel, [8, 8]); 165 | 166 | dbg!(&kernel_padded); 167 | 168 | assert_eq!( 169 | kernel_padded, 170 | array![ 171 | [1, 0, 0, 2, 0, 0, 3, 0], 172 | [0, 0, 0, 0, 0, 0, 0, 0], 173 | [4, 0, 0, 5, 0, 0, 6, 0], 174 | [0, 0, 0, 0, 0, 0, 0, 0], 175 | [7, 0, 0, 8, 0, 0, 9, 0], 176 | [0, 0, 0, 0, 0, 0, 0, 0], 177 | [0, 0, 0, 0, 0, 0, 0, 0], 178 | [0, 0, 0, 0, 0, 0, 0, 0] 179 | ] 180 | ); 181 | } 182 | } 183 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! `ndarray-conv` provides N-dimensional convolution operations for `ndarray` arrays. 2 | //! 3 | //! This crate extends the `ndarray` library with both standard and 4 | //! FFT-accelerated convolution methods. 5 | //! 6 | //! # Getting Started 7 | //! 8 | //! To start performing convolutions, you'll interact with the following: 9 | //! 10 | //! 1. **Input Arrays:** Use `ndarray`'s [`Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html) 11 | //! or [`ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html) 12 | //! as your input data and convolution kernel. 13 | //! 2. **Convolution Methods:** Call `array.conv(...)` or `array.conv_fft(...)`. 14 | //! These methods are added to `ArrayBase` types via the traits 15 | //! [`ConvExt::conv`] and [`ConvFFTExt::conv_fft`]. 16 | //! 3. 
**Convolution Mode:** [`ConvMode`] specifies the size of the output. 17 | //! 4. **Padding Mode:** [`PaddingMode`] specifies how to handle array boundaries. 18 | //! 19 | //! # Basic Example: 20 | //! 21 | //! Here's a simple example of how to perform a 2D convolution using `ndarray-conv`: 22 | //! 23 | //! ```rust 24 | //! use ndarray::prelude::*; 25 | //! use ndarray_conv::{ConvExt, ConvFFTExt, ConvMode, PaddingMode}; 26 | //! 27 | //! // Input data 28 | //! let input = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; 29 | //! 30 | //! // Convolution kernel 31 | //! let kernel = array![[1, 1], [1, 1]]; 32 | //! 33 | //! // Perform standard convolution with "same" output size and zero padding 34 | //! let output = input.conv( 35 | //! &kernel, 36 | //! ConvMode::Same, 37 | //! PaddingMode::Zeros, 38 | //! ).unwrap(); 39 | //! 40 | //! println!("Standard Convolution Output:\n{:?}", output); 41 | //! 42 | //! // Perform FFT-accelerated convolution with "same" output size and zero padding 43 | //! let output_fft = input.map(|&x| x as f32).conv_fft( 44 | //! &kernel.map(|&x| x as f32), 45 | //! ConvMode::Same, 46 | //! PaddingMode::Zeros, 47 | //! ).unwrap(); 48 | //! 49 | //! println!("FFT Convolution Output:\n{:?}", output_fft); 50 | //! ``` 51 | //! 52 | //! # Choosing a convolution method 53 | //! 54 | //! * Use [`ConvExt::conv`] for standard convolution 55 | //! * Use [`ConvFFTExt::conv_fft`] for FFT accelerated convolution. 56 | //! FFT accelerated convolution is generally faster for larger kernels, but 57 | //! standard convolution may be faster for smaller kernels. 58 | //! 59 | //! # Key Structs, Enums and Traits 60 | //! 61 | //! * [`ConvMode`]: Specifies how to determine the size of the convolution output (e.g., `Full`, `Same`, `Valid`). 62 | //! * [`PaddingMode`]: Specifies how to handle array boundaries (e.g., `Zeros`, `Reflect`, `Replicate`). 
You can also use `PaddingMode::Custom` or `PaddingMode::Explicit` to combine different [`BorderType`] strategies for each dimension or for each side of each dimension. 63 | //! * [`BorderType`]: Used with [`PaddingMode`] for `Custom` and `Explicit`, specifies the padding strategy (e.g., `Zeros`, `Reflect`, `Replicate`, `Circular`). 64 | //! * [`ConvExt`]: The trait that adds the `conv` method, extending `ndarray` arrays with standard convolution functionality. 65 | //! * [`ConvFFTExt`]: The trait that adds the `conv_fft` method, extending `ndarray` arrays with FFT-accelerated convolution functionality. 66 | 67 | mod conv; 68 | mod conv_fft; 69 | mod dilation; 70 | mod padding; 71 | 72 | pub(crate) use padding::ExplicitPadding; 73 | 74 | pub use conv::ConvExt; 75 | pub use conv_fft::{ 76 | get_processor as get_fft_processor, ConvFFTExt, GetProcessor, Processor as FftProcessor, 77 | }; 78 | pub use dilation::{ReverseKernel, WithDilation}; 79 | 80 | /// Specifies the convolution mode, which determines the output size. 81 | #[derive(Debug, Clone, Copy)] 82 | pub enum ConvMode { 83 | /// The output has the largest size, including all positions where 84 | /// the kernel and input overlap at least partially. 85 | Full, 86 | /// The output has the same size as the input. 87 | Same, 88 | /// The output has the smallest size, including only positions 89 | /// where the kernel and input fully overlap. 90 | Valid, 91 | /// Specifies custom padding and strides. 92 | Custom { 93 | /// The padding to use for each dimension. 94 | padding: [usize; N], 95 | /// The strides to use for each dimension. 96 | strides: [usize; N], 97 | }, 98 | /// Specifies explicit padding and strides. 99 | Explicit { 100 | /// The padding to use for each side of each dimension. 101 | padding: [[usize; 2]; N], 102 | /// The strides to use for each dimension. 103 | strides: [usize; N], 104 | }, 105 | } 106 | /// Specifies the padding mode, which determines how to handle borders. 
107 | /// 108 | /// The padding mode can be either a single `BorderType` applied on all sides, 109 | /// a `BorderType` for each dimension (`Custom`), or a `BorderType` 110 | /// for each side of each dimension (`Explicit`). 111 | #[derive(Debug, Clone, Copy)] 112 | pub enum PaddingMode { 113 | /// Pads with zeros. 114 | Zeros, 115 | /// Pads with a constant value. 116 | Const(T), 117 | /// Reflects the input at the borders. 118 | Reflect, 119 | /// Replicates the edge values. 120 | Replicate, 121 | /// Treats the input as a circular buffer. 122 | Circular, 123 | /// Specifies a different `BorderType` for each dimension. 124 | Custom([BorderType; N]), 125 | /// Specifies a different `BorderType` for each side of each dimension. 126 | Explicit([[BorderType; 2]; N]), 127 | } 128 | 129 | /// Used with [`PaddingMode`]. Specifies the padding mode for a single dimension 130 | /// or a single side of a dimension. 131 | #[derive(Debug, Clone, Copy)] 132 | pub enum BorderType { 133 | /// Pads with zeros. 134 | Zeros, 135 | /// Pads with a constant value. 136 | Const(T), 137 | /// Reflects the input at the borders. 138 | Reflect, 139 | /// Replicates the edge values. 140 | Replicate, 141 | /// Treats the input as a circular buffer. 142 | Circular, 143 | } 144 | 145 | use thiserror::Error; 146 | 147 | /// Error type for convolution operations. 148 | #[derive(Error, Debug)] 149 | pub enum Error { 150 | /// Indicates that the input data array has a dimension with zero size. 151 | #[error("Data shape shouldn't have ZERO. {0:?}")] 152 | DataShape(ndarray::Dim<[ndarray::Ix; N]>), 153 | /// Indicates that the kernel array has a dimension with zero size. 154 | #[error("Kernel shape shouldn't have ZERO. {0:?}")] 155 | KernelShape(ndarray::Dim<[ndarray::Ix; N]>), 156 | /// Indicates that the shape of the kernel with dilation is not compatible with the chosen `ConvMode`.
157 | #[error("ConvMode {0:?} does not match KernelWithDilation Size {1:?}")] 158 | MismatchShape(ConvMode, [ndarray::Ix; N]), 159 | } 160 | -------------------------------------------------------------------------------- /benches/with_torch.rs: -------------------------------------------------------------------------------- 1 | use criterion::{criterion_group, criterion_main, Criterion}; 2 | 3 | use ndarray::prelude::*; 4 | use ndarray_conv::*; 5 | use ndarray_rand::{rand_distr::Uniform, RandomExt}; 6 | use ndarray_vision::processing::ConvolutionExt; 7 | use num::Complex; 8 | 9 | /// Benchmark for 1D convolution using `conv_fft` with various libraries. 10 | fn criterion_benchmark(c: &mut Criterion) { 11 | let x = Array::random(5000, Uniform::new(0f32, 1.).unwrap()); 12 | let k = Array::random(31, Uniform::new(0f32, 1.).unwrap()); 13 | 14 | let x_crs = x.to_shape((1, 1, 5000)).unwrap().to_owned(); 15 | let k_crs = k.to_shape((1, 1, 1, 31)).unwrap().to_owned(); 16 | 17 | let tensor = tch::Tensor::from_slice(x.as_slice().unwrap()) 18 | .to_dtype(tch::Kind::Float, false, true) 19 | .reshape([1, 1, 5000]); 20 | let kernel = tch::Tensor::from_slice(k.as_slice().unwrap()) 21 | .to_dtype(tch::Kind::Float, false, true) 22 | .reshape([1, 1, 31]); 23 | 24 | // for (a, b) in x 25 | // .conv_fft(&k, ConvMode::Same, PaddingMode::Zeros) 26 | // .unwrap() 27 | // .iter() 28 | // .zip( 29 | // tensor 30 | // .conv1d_padding::(&kernel, None, 1, "same", 1, 1) 31 | // .reshape(5000) 32 | // .iter::() 33 | // .unwrap(), 34 | // ) 35 | // { 36 | // // need to div kernel size 37 | // assert!((*a as f64 - b).abs() < 1e-5); 38 | // } 39 | 40 | let mut fft_processor = get_fft_processor(); 41 | 42 | /// Benchmark for 1D convolution using `conv_fft`. 43 | c.bench_function("fft_1d", |b| { 44 | b.iter(|| x.conv_fft(&k, ConvMode::Same, PaddingMode::Zeros)) 45 | }); 46 | 47 | /// Benchmark for 1D convolution using `conv_fft_with_processor`. 
48 | c.bench_function("fft_with_processor_1d", |b| { 49 | b.iter(|| { 50 | x.conv_fft_with_processor(&k, ConvMode::Same, PaddingMode::Zeros, &mut fft_processor) 51 | }) 52 | }); 53 | 54 | c.bench_function("torch_1d", |b| { 55 | b.iter(|| tensor.conv1d_padding::(&kernel, None, 1, "same", 1, 1)) 56 | }); 57 | 58 | // c.bench_function("convolution_rs_1d", |b| { 59 | // b.iter(|| { 60 | // convolutions_rs::convolutions::ConvolutionLayer::new_tf( 61 | // k_crs.clone(), 62 | // None, 63 | // 1, 64 | // convolutions_rs::Padding::Same, 65 | // ) 66 | // .convolve(&x_crs) 67 | // }); 68 | // }); 69 | 70 | // c.bench_function("fftconvolve_1d", |b| { 71 | // b.iter(|| fftconvolve::fftconvolve(&x, &k, fftconvolve::Mode::Same)) 72 | // }); 73 | 74 | let x = Array::random((200, 5000), Uniform::new(0f32, 1.).unwrap()); 75 | let k = Array::random((11, 31), Uniform::new(0f32, 1.).unwrap()); 76 | 77 | let x_crs = x.to_shape((1, 200, 5000)).unwrap().to_owned(); 78 | let k_crs = k.to_shape((1, 1, 11, 31)).unwrap().to_owned(); 79 | 80 | let x_nvs = x.to_shape((200, 5000, 1)).unwrap().to_owned(); 81 | let k_nvs = k.to_shape((11, 31, 1)).unwrap().to_owned(); 82 | 83 | let tensor = tch::Tensor::from_slice(x.as_slice().unwrap()) 84 | .to_dtype(tch::Kind::Float, false, true) 85 | .reshape([1, 1, 200, 5000]); 86 | let kernel = tch::Tensor::from_slice(k.as_slice().unwrap()) 87 | .to_dtype(tch::Kind::Float, false, true) 88 | .reshape([1, 1, 11, 31]); 89 | 90 | let mut fft_processor = get_fft_processor(); 91 | 92 | /// Benchmark for 2D convolution using `conv_fft`. 93 | c.bench_function("fft_2d", |b| { 94 | b.iter(|| x.conv_fft(&k, ConvMode::Same, PaddingMode::Zeros)) 95 | }); 96 | 97 | /// Benchmark for 2D convolution using `conv_fft_with_processor`. 
98 | c.bench_function("fft_with_processor_2d", |b| {
99 | b.iter(|| {
100 | x.conv_fft_with_processor(&k, ConvMode::Same, PaddingMode::Zeros, &mut fft_processor)
101 | })
102 | });
103 | 
104 | c.bench_function("torch_2d", |b| {
105 | b.iter(|| tensor.conv2d_padding::(&kernel, None, 1, "same", 1, 1))
106 | });
107 | 
108 | // c.bench_function("ndarray_vision_2d", |b| {
109 | // b.iter(|| x_nvs.conv2d_with_padding(k_nvs.clone(), &ndarray_vision::core::ZeroPadding))
110 | // });
111 | 
112 | // c.bench_function("convolution_rs_2d", |b| {
113 | // b.iter(|| {
114 | // convolutions_rs::convolutions::ConvolutionLayer::new_tf(
115 | // k_crs.clone(),
116 | // None,
117 | // 1,
118 | // convolutions_rs::Padding::Same,
119 | // )
120 | // .convolve(&x_crs)
121 | // });
122 | // });
123 | 
124 | // c.bench_function("fftconvolve_2d", |b| {
125 | // b.iter(|| fftconvolve::fftconvolve(&x, &k, fftconvolve::Mode::Same))
126 | // });
127 | 
// --- 3D setup: fresh random input/kernel and per-library reshapes ---
128 | let x = Array::random((10, 100, 200), Uniform::new(0f32, 1.).unwrap());
129 | let k = Array::random((5, 11, 31), Uniform::new(0f32, 1.).unwrap());
130 | 
131 | let x_crs = x.to_shape((10, 100, 200)).unwrap().to_owned();
132 | let k_crs = k.to_shape((1, 5, 11, 31)).unwrap().to_owned();
133 | 
134 | let tensor = tch::Tensor::from_slice(x.as_slice().unwrap())
135 | .to_dtype(tch::Kind::Float, false, true)
136 | .reshape([1, 1, 10, 100, 200]);
137 | let kernel = tch::Tensor::from_slice(k.as_slice().unwrap())
138 | .to_dtype(tch::Kind::Float, false, true)
139 | .reshape([1, 1, 5, 11, 31]);
140 | 
// Re-create the processor so cached FFT plans from the 2D runs don't leak
// into the 3D measurements.
141 | let mut fft_processor = get_fft_processor();
142 | 
// Benchmark for 3D convolution using `conv_fft`.
// (was `///`: a doc comment on a statement inside a fn body is not attached
// to any item and fires the `unused_doc_comments` lint — use a line comment)
143 | c.bench_function("fft_3d", |b| {
144 | b.iter(|| x.conv_fft(&k, ConvMode::Same, PaddingMode::Zeros))
145 | });
146 | 
147 | 
// Benchmark for 3D convolution using `conv_fft_with_processor`.
// (was `///`: same `unused_doc_comments` issue as above)
148 | 
149 | c.bench_function("fft_with_processor_3d", |b| { 150 | b.iter(|| { 151 | x.conv_fft_with_processor(&k, ConvMode::Same, PaddingMode::Zeros, &mut fft_processor) 152 | }) 153 | }); 154 | 155 | c.bench_function("torch_3d", |b| { 156 | b.iter(|| tensor.conv3d_padding::(&kernel, None, 1, "same", 1, 1)) 157 | }); 158 | 159 | // c.bench_function("convolution_rs_3d", |b| { 160 | // b.iter(|| { 161 | // convolutions_rs::convolutions::ConvolutionLayer::new_tf( 162 | // k_crs.clone(), 163 | // None, 164 | // 1, 165 | // convolutions_rs::Padding::Same, 166 | // ) 167 | // .convolve(&x_crs) 168 | // }); 169 | // }); 170 | 171 | // c.bench_function("fftconvolve_3d", |b| { 172 | // b.iter(|| fftconvolve::fftconvolve(&x, &k, fftconvolve::Mode::Same)) 173 | // }); 174 | } 175 | 176 | criterion_group!(benches, criterion_benchmark); 177 | criterion_main!(benches); 178 | -------------------------------------------------------------------------------- /src/conv/mod.rs: -------------------------------------------------------------------------------- 1 | //! Provides convolution operations for `ndarray` arrays. 2 | //! Includes standard convolution and related utilities. 3 | 4 | use ndarray::{ 5 | Array, ArrayBase, ArrayView, Data, Dim, Dimension, IntoDimension, Ix, RawData, RemoveAxis, 6 | SliceArg, SliceInfo, SliceInfoElem, 7 | }; 8 | use num::traits::NumAssign; 9 | 10 | use crate::{ 11 | dilation::{IntoKernelWithDilation, KernelWithDilation}, 12 | padding::PaddingExt, 13 | ConvMode, PaddingMode, 14 | }; 15 | 16 | #[cfg(test)] 17 | mod tests; 18 | 19 | /// Represents explicit convolution parameters after unfolding from `ConvMode`. 20 | /// 21 | /// This struct holds padding and strides information used directly 22 | /// by the convolution algorithm. 
23 | pub struct ExplicitConv {
/// Per-axis padding as `[front, back]` element counts.
24 | pub padding: [[usize; 2]; N],
/// Per-axis stride of the sliding convolution window.
25 | pub strides: [usize; N],
26 | }
27 | 
28 | impl ConvMode {
/// Resolves this mode into concrete padding/stride values for the given
/// (possibly dilated) kernel.
29 | pub(crate) fn unfold(self, kernel: &KernelWithDilation) -> ExplicitConv
30 | where
31 | S: ndarray::RawData,
32 | Dim<[Ix; N]>: Dimension,
33 | {
34 | let kernel_dim = kernel.kernel.raw_dim();
// Effective kernel extent per axis once dilation is applied.
35 | let kernel_dim: [usize; N] = std::array::from_fn(|i|
36 | // k + (k - 1) * (d - 1)
37 | kernel_dim[i] * kernel.dilation[i] - kernel.dilation[i] + 1);
38 | 
39 | match self {
// Full: every partial overlap contributes — pad both sides by k - 1.
40 | ConvMode::Full => ExplicitConv {
41 | padding: std::array::from_fn(|i| [kernel_dim[i] - 1; 2]),
42 | strides: [1; N],
43 | },
// Same: output matches input size. For an even kernel the total
// padding (k - 1) is odd, so the extra element is placed on the
// front side of the axis.
44 | ConvMode::Same => ExplicitConv {
45 | padding: std::array::from_fn(|i| {
46 | let k_size = kernel_dim[i];
47 | if k_size.is_multiple_of(2) {
48 | [(k_size - 1) / 2 + 1, (k_size - 1) / 2]
49 | } else {
50 | [(k_size - 1) / 2; 2]
51 | }
52 | }),
53 | strides: [1; N],
54 | },
// Valid: no padding — the kernel must fit entirely inside the input.
55 | ConvMode::Valid => ExplicitConv {
56 | padding: [[0; 2]; N],
57 | strides: [1; N],
58 | },
// Custom: one symmetric padding amount per axis, caller-chosen strides.
59 | ConvMode::Custom { padding, strides } => ExplicitConv {
60 | padding: padding.map(|pad| [pad; 2]),
61 | strides,
62 | },
// Explicit: caller supplies [front, back] padding per axis verbatim.
63 | ConvMode::Explicit { padding, strides } => ExplicitConv { padding, strides },
64 | }
65 | }
66 | }
67 | 
68 | /// Extends `ndarray`'s `ArrayBase` with convolution operations.
69 | ///
70 | /// This trait adds the `conv` method to `ArrayBase`, enabling
71 | /// standard convolution operations on N-dimensional arrays.
72 | ///
73 | /// # Type Parameters
74 | ///
75 | /// * `T`: The numeric type of the array elements.
76 | /// * `S`: The data storage type of the input array.
77 | /// * `SK`: The data storage type of the kernel array.
78 | pub trait ConvExt<'a, T, S, SK, const N: usize>
79 | where
80 | T: NumAssign + Copy,
81 | S: RawData,
82 | SK: RawData,
83 | {
84 | /// Performs a standard convolution operation.
85 | ///
86 | /// This method convolves the input array with a given kernel,
87 | /// using the specified convolution mode and padding.
88 | /// 89 | /// # Arguments 90 | /// 91 | /// * `kernel`: The convolution kernel. Can be a reference to an array, or an array with dilation settings created using `with_dilation()`. 92 | /// * `conv_mode`: The convolution mode (`Full`, `Same`, `Valid`, `Custom`, `Explicit`). 93 | /// * `padding_mode`: The padding mode (`Zeros`, `Const`, `Reflect`, `Replicate`, `Circular`, `Custom`, `Explicit`). 94 | /// 95 | /// # Returns 96 | /// 97 | /// Returns `Ok(Array>)` containing the convolution result, or an `Err(Error)` if the operation fails 98 | /// (e.g., due to incompatible shapes or zero-sized dimensions). 99 | /// 100 | /// # Example 101 | /// 102 | /// ```rust 103 | /// use ndarray::array; 104 | /// use ndarray_conv::{ConvExt, ConvMode, PaddingMode}; 105 | /// 106 | /// let input = array![[1, 2, 3], [4, 5, 6]]; 107 | /// let kernel = array![[1, 1], [1, 1]]; 108 | /// let result = input.conv(&kernel, ConvMode::Same, PaddingMode::Zeros).unwrap(); 109 | /// ``` 110 | fn conv( 111 | &self, 112 | kernel: impl IntoKernelWithDilation<'a, SK, N>, 113 | conv_mode: ConvMode, 114 | padding_mode: PaddingMode, 115 | ) -> Result>, crate::Error>; 116 | } 117 | 118 | impl<'a, T, S, SK, const N: usize> ConvExt<'a, T, S, SK, N> for ArrayBase> 119 | where 120 | T: NumAssign + Copy + 'a, 121 | S: Data + 'a, 122 | SK: Data + 'a, 123 | Dim<[Ix; N]>: RemoveAxis, 124 | [Ix; N]: IntoDimension>, 125 | SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: 126 | SliceArg, OutDim = Dim<[Ix; N]>>, 127 | { 128 | fn conv( 129 | &self, 130 | kernel: impl IntoKernelWithDilation<'a, SK, N>, 131 | conv_mode: ConvMode, 132 | padding_mode: PaddingMode, 133 | ) -> Result>, crate::Error> { 134 | let kwd = kernel.into_kernel_with_dilation(); 135 | 136 | let self_raw_dim = self.raw_dim(); 137 | if self.shape().iter().product::() == 0 { 138 | return Err(crate::Error::DataShape(self_raw_dim)); 139 | } 140 | 141 | let kernel_raw_dim = kwd.kernel.raw_dim(); 142 | if kwd.kernel.shape().iter().product::() 
== 0 { 143 | return Err(crate::Error::KernelShape(kernel_raw_dim)); 144 | } 145 | 146 | let kernel_raw_dim_with_dilation: [usize; N] = 147 | std::array::from_fn(|i| kernel_raw_dim[i] * kwd.dilation[i] - kwd.dilation[i] + 1); 148 | 149 | let cm = conv_mode.unfold(&kwd); 150 | let pds = self.padding(padding_mode, cm.padding); 151 | 152 | let pds_raw_dim = pds.raw_dim(); 153 | if !(0..N).all(|i| kernel_raw_dim_with_dilation[i] <= pds_raw_dim[i]) { 154 | return Err(crate::Error::MismatchShape( 155 | conv_mode, 156 | kernel_raw_dim_with_dilation, 157 | )); 158 | } 159 | 160 | let offset_list = kwd.gen_offset_list(pds.strides()); 161 | 162 | let output_shape: [usize; N] = std::array::from_fn(|i| { 163 | (cm.padding[i][0] + cm.padding[i][1] + self_raw_dim[i] 164 | - kernel_raw_dim_with_dilation[i]) 165 | / cm.strides[i] 166 | + 1 167 | }); 168 | let mut ret = Array::zeros(output_shape); 169 | 170 | let shape: [usize; N] = std::array::from_fn(|i| ret.raw_dim()[i]); 171 | let strides: [usize; N] = 172 | std::array::from_fn(|i| cm.strides[i] * pds.strides()[i] as usize); 173 | 174 | // dbg!(&offset_list); 175 | // dbg!(strides); 176 | 177 | unsafe { 178 | // use raw pointer to improve performance. 
179 | let p: *mut T = ret.as_mut_ptr(); 180 | 181 | // use ArrayView's iter without handle strides 182 | let view = ArrayView::from_shape( 183 | ndarray::ShapeBuilder::strides(shape, strides), 184 | pds.as_slice().unwrap(), 185 | ) 186 | .unwrap(); 187 | 188 | view.iter().enumerate().for_each(|(i, cur)| { 189 | let mut tmp_res = T::zero(); 190 | 191 | offset_list.iter().for_each(|(tmp_offset, tmp_kernel)| { 192 | tmp_res += *(cur as *const T).offset(*tmp_offset) * *tmp_kernel 193 | }); 194 | 195 | *p.add(i) = tmp_res; 196 | }); 197 | } 198 | 199 | Ok(ret) 200 | } 201 | } 202 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/padding/half_dim.rs: -------------------------------------------------------------------------------- 1 | //! Provides functions for padding on a single side of a dimension. 2 | //! 3 | //! This module contains functions for applying padding to the front 4 | //! or back of a specific dimension of an array. These functions are 5 | //! used internally by the `dim` module and implement the core logic 6 | //! for different padding modes like constant, replicate, reflect and 7 | //! circular. 8 | 9 | use ndarray::{ArrayBase, Axis, DataMut, Dim, Ix, RemoveAxis}; 10 | use num::traits::NumAssign; 11 | 12 | /// Applies constant padding to the front of a given dimension of an array. 13 | /// 14 | /// This function modifies the input array by padding the front of the 15 | /// specified dimension with a constant value. 16 | /// 17 | /// # Type Parameters 18 | /// 19 | /// * `T`: The numeric type of the array elements. 
20 | /// * `S`: The data storage type of the array. 21 | /// * `D`: The dimension type of the array. 22 | /// 23 | /// # Arguments 24 | /// 25 | /// * `buffer`: A mutable reference to the array to be padded. 26 | /// * `dim`: The dimension to pad. 27 | /// * `padding`: An array containing two `usize` values representing padding for front and back, respectively. 28 | /// * `constant`: The constant value to pad with. 29 | #[inline] 30 | pub fn constant_front( 31 | buffer: &mut ArrayBase, 32 | dim: usize, 33 | padding: [usize; 2], 34 | constant: T, 35 | ) where 36 | T: NumAssign + Copy, 37 | S: DataMut, 38 | D: RemoveAxis, 39 | { 40 | for j in 0..padding[0] { 41 | unsafe { 42 | let buffer_mut = (buffer as *const _ as *mut ArrayBase) 43 | .as_mut() 44 | .unwrap(); 45 | 46 | buffer_mut.index_axis_mut(Axis(dim), j).fill(constant); 47 | } 48 | } 49 | } 50 | 51 | /// Applies constant padding to the back of a given dimension of an array. 52 | /// 53 | /// This function modifies the input array by padding the back of the 54 | /// specified dimension with a constant value. 55 | /// 56 | /// # Type Parameters 57 | /// 58 | /// * `N`: The number of dimensions. 59 | /// * `T`: The numeric type of the array elements. 60 | /// * `S`: The data storage type of the array. 61 | /// * `D`: The dimension type of the original data. 62 | /// * `DO`: The dimension type of the output data. 63 | /// 64 | /// # Arguments 65 | /// 66 | /// * `input_dim`: The dimensions of the original array. 67 | /// * `buffer`: A mutable reference to the array to be padded. 68 | /// * `dim`: The dimension to pad. 69 | /// * `padding`: An array containing two `usize` values representing padding for front and back, respectively. 70 | /// * `constant`: The constant value to pad with. 
71 | #[inline]
72 | pub fn constant_back(
73 | input_dim: D,
74 | buffer: &mut ArrayBase,
75 | dim: usize,
76 | padding: [usize; 2],
77 | constant: T,
78 | ) where
79 | T: NumAssign + Copy,
80 | S: DataMut,
81 | D: RemoveAxis,
82 | DO: RemoveAxis,
83 | Dim<[Ix; N]>: RemoveAxis,
84 | {
// Back padding occupies indices [input_dim[dim] + padding[0], axis end).
85 | for j in input_dim[dim] + padding[0]..buffer.raw_dim()[dim] {
86 | unsafe {
// NOTE(review): `buffer` (a `&mut`) is round-tripped through
// `*const` -> `*mut` to manufacture a second mutable handle. No other
// borrow of `buffer` is live here, but this cast pattern is exactly what
// Miri flags — confirm soundness or replace with a direct
// `buffer.index_axis_mut(...)` call, which needs no unsafe.
87 | let buffer_mut = (buffer as *const _ as *mut ArrayBase)
88 | .as_mut()
89 | .unwrap();
90 | 
91 | buffer_mut.index_axis_mut(Axis(dim), j).fill(constant);
92 | }
93 | }
94 | }
95 | 
96 | /// Applies replicate padding to the front of a given dimension of an array.
97 | ///
98 | /// This function modifies the input array by padding the front of the
99 | /// specified dimension by replicating the edge values.
100 | ///
101 | /// # Type Parameters
102 | ///
103 | /// * `T`: The numeric type of the array elements.
104 | /// * `S`: The data storage type of the array.
105 | /// * `D`: The dimension type of the array.
106 | ///
107 | /// # Arguments
108 | ///
109 | /// * `buffer`: A mutable reference to the array to be padded.
110 | /// * `dim`: The dimension to pad.
111 | /// * `padding`: An array containing two `usize` values representing padding for front and back, respectively.
112 | #[inline]
113 | pub fn replicate_front(buffer: &mut ArrayBase, dim: usize, padding: [usize; 2])
114 | where
115 | T: NumAssign + Copy,
116 | S: DataMut,
117 | D: RemoveAxis,
118 | {
// `border` is the first non-padding row along `dim`; it is copied into
// every front-padding slot [0, padding[0]).
119 | let border = buffer.index_axis(Axis(dim), padding[0]);
120 | for j in 0..padding[0] {
121 | unsafe {
// NOTE(review): a mutable alias of `buffer` is created below while the
// shared view `border` into the same array is still live. The written
// rows (j < padding[0]) and the source row (padding[0]) never overlap,
// so no data race on actual elements, but coexisting `&`/`&mut` to one
// allocation is formally questionable — verify with Miri.
122 | let buffer_mut = (buffer as *const _ as *mut ArrayBase)
123 | .as_mut()
124 | .unwrap();
125 | 
126 | buffer_mut.index_axis_mut(Axis(dim), j).assign(&border);
127 | }
128 | }
129 | }
130 | 
131 | /// Applies replicate padding to the back of a given dimension of an array.
132 | ///
133 | /// This function modifies the input array by padding the back of the
134 | /// specified dimension by replicating the edge values.
135 | /// 136 | /// # Type Parameters 137 | /// 138 | /// * `N`: The number of dimensions. 139 | /// * `T`: The numeric type of the array elements. 140 | /// * `S`: The data storage type of the array. 141 | /// * `D`: The dimension type of the original data. 142 | /// * `DO`: The dimension type of the output data. 143 | /// 144 | /// # Arguments 145 | /// 146 | /// * `input_dim`: The dimensions of the original array. 147 | /// * `buffer`: A mutable reference to the array to be padded. 148 | /// * `dim`: The dimension to pad. 149 | /// * `padding`: An array containing two `usize` values representing padding for front and back, respectively. 150 | #[inline] 151 | pub fn replicate_back( 152 | input_dim: D, 153 | buffer: &mut ArrayBase, 154 | dim: usize, 155 | padding: [usize; 2], 156 | ) where 157 | T: NumAssign + Copy, 158 | S: DataMut, 159 | D: RemoveAxis, 160 | DO: RemoveAxis, 161 | Dim<[Ix; N]>: RemoveAxis, 162 | { 163 | let border = buffer.index_axis(Axis(dim), buffer.raw_dim()[dim] - padding[1] - 1); 164 | for j in input_dim[dim] + padding[0]..buffer.raw_dim()[dim] { 165 | unsafe { 166 | let buffer_mut = (buffer as *const _ as *mut ArrayBase) 167 | .as_mut() 168 | .unwrap(); 169 | 170 | buffer_mut.index_axis_mut(Axis(dim), j).assign(&border); 171 | } 172 | } 173 | } 174 | 175 | /// Applies reflect padding to the front of a given dimension of an array. 176 | /// 177 | /// This function modifies the input array by padding the front of the 178 | /// specified dimension by reflecting the array at the boundaries. 179 | /// 180 | /// # Type Parameters 181 | /// 182 | /// * `T`: The numeric type of the array elements. 183 | /// * `S`: The data storage type of the array. 184 | /// * `D`: The dimension type of the array. 185 | /// 186 | /// # Arguments 187 | /// 188 | /// * `buffer`: A mutable reference to the array to be padded. 189 | /// * `dim`: The dimension to pad. 
190 | /// * `padding`: An array containing two `usize` values representing padding for front and back, respectively. 191 | #[inline] 192 | pub fn reflect_front(buffer: &mut ArrayBase, dim: usize, padding: [usize; 2]) 193 | where 194 | T: NumAssign + Copy, 195 | S: DataMut, 196 | D: RemoveAxis, 197 | { 198 | let border_index = padding[0]; 199 | for j in 0..padding[0] { 200 | let reflect_j = (border_index - j) + border_index; 201 | unsafe { 202 | let output_mut = (buffer as *const _ as *mut ArrayBase) 203 | .as_mut() 204 | .unwrap(); 205 | 206 | output_mut 207 | .index_axis_mut(Axis(dim), j) 208 | .assign(&buffer.index_axis(Axis(dim), reflect_j)); 209 | } 210 | } 211 | } 212 | 213 | /// Applies reflect padding to the back of a given dimension of an array. 214 | /// 215 | /// This function modifies the input array by padding the back of the 216 | /// specified dimension by reflecting the array at the boundaries. 217 | /// 218 | /// # Type Parameters 219 | /// 220 | /// * `N`: The number of dimensions. 221 | /// * `T`: The numeric type of the array elements. 222 | /// * `S`: The data storage type of the array. 223 | /// * `D`: The dimension type of the original data. 224 | /// * `DO`: The dimension type of the output data. 225 | /// 226 | /// # Arguments 227 | /// 228 | /// * `input_dim`: The dimensions of the original array. 229 | /// * `buffer`: A mutable reference to the array to be padded. 230 | /// * `dim`: The dimension to pad. 231 | /// * `padding`: An array containing two `usize` values representing padding for front and back, respectively. 
232 | #[inline] 233 | pub fn reflect_back( 234 | input_dim: D, 235 | buffer: &mut ArrayBase, 236 | dim: usize, 237 | padding: [usize; 2], 238 | ) where 239 | T: NumAssign + Copy, 240 | S: DataMut, 241 | D: RemoveAxis, 242 | DO: RemoveAxis, 243 | Dim<[Ix; N]>: RemoveAxis, 244 | { 245 | let border_index = buffer.raw_dim()[dim] - padding[1] - 1; 246 | for j in input_dim[dim] + padding[0]..buffer.raw_dim()[dim] { 247 | let reflect_j = border_index - (j - border_index); 248 | unsafe { 249 | let output_mut = (buffer as *const _ as *mut ArrayBase) 250 | .as_mut() 251 | .unwrap(); 252 | 253 | output_mut 254 | .index_axis_mut(Axis(dim), j) 255 | .assign(&buffer.index_axis(Axis(dim), reflect_j)); 256 | } 257 | } 258 | } 259 | 260 | /// Applies circular padding to the front of a given dimension of an array. 261 | /// 262 | /// This function modifies the input array by padding the front of the 263 | /// specified dimension by wrapping the data around the boundaries. 264 | /// 265 | /// # Type Parameters 266 | /// 267 | /// * `T`: The numeric type of the array elements. 268 | /// * `S`: The data storage type of the array. 269 | /// * `D`: The dimension type of the array. 270 | /// 271 | /// # Arguments 272 | /// 273 | /// * `buffer`: A mutable reference to the array to be padded. 274 | /// * `dim`: The dimension to pad. 275 | /// * `padding`: An array containing two `usize` values representing padding for front and back, respectively. 
276 | #[inline] 277 | pub fn circular_front(buffer: &mut ArrayBase, dim: usize, padding: [usize; 2]) 278 | where 279 | T: NumAssign + Copy, 280 | S: DataMut, 281 | D: RemoveAxis, 282 | { 283 | let border_index = padding[0]; 284 | for j in 0..padding[0] { 285 | let circular_j = buffer.raw_dim()[dim] - padding[1] - (border_index - j); 286 | unsafe { 287 | let output_mut = (buffer as *const _ as *mut ArrayBase) 288 | .as_mut() 289 | .unwrap(); 290 | 291 | output_mut 292 | .index_axis_mut(Axis(dim), j) 293 | .assign(&buffer.index_axis(Axis(dim), circular_j)); 294 | } 295 | } 296 | } 297 | 298 | /// Applies circular padding to the back of a given dimension of an array. 299 | /// 300 | /// This function modifies the input array by padding the back of the 301 | /// specified dimension by wrapping the data around the boundaries. 302 | /// 303 | /// # Type Parameters 304 | /// 305 | /// * `N`: The number of dimensions. 306 | /// * `T`: The numeric type of the array elements. 307 | /// * `S`: The data storage type of the array. 308 | /// * `D`: The dimension type of the original data. 309 | /// * `DO`: The dimension type of the output data. 310 | /// 311 | /// # Arguments 312 | /// 313 | /// * `input_dim`: The dimensions of the original array. 314 | /// * `buffer`: A mutable reference to the array to be padded. 315 | /// * `dim`: The dimension to pad. 316 | /// * `padding`: An array containing two `usize` values representing padding for front and back, respectively. 
317 | #[inline] 318 | pub fn circular_back( 319 | input_dim: D, 320 | buffer: &mut ArrayBase, 321 | dim: usize, 322 | padding: [usize; 2], 323 | ) where 324 | T: NumAssign + Copy, 325 | S: DataMut, 326 | D: RemoveAxis, 327 | DO: RemoveAxis, 328 | Dim<[Ix; N]>: RemoveAxis, 329 | { 330 | let border_index = buffer.raw_dim()[dim] - padding[1] - 1; 331 | for j in input_dim[dim] + padding[0]..buffer.raw_dim()[dim] { 332 | let circular_j = padding[0] + (j - border_index - 1); 333 | unsafe { 334 | let output_mut = (buffer as *const _ as *mut ArrayBase) 335 | .as_mut() 336 | .unwrap(); 337 | 338 | output_mut 339 | .index_axis_mut(Axis(dim), j) 340 | .assign(&buffer.index_axis(Axis(dim), circular_j)); 341 | } 342 | } 343 | } 344 | -------------------------------------------------------------------------------- /src/conv_fft/processor/complex.rs: -------------------------------------------------------------------------------- 1 | use ndarray::ArrayViewMut; 2 | 3 | use super::Processor as ProcessorTrait; 4 | use super::*; 5 | 6 | /// Complex-valued FFT processor backed by `rustfft`. 7 | /// 8 | /// Plans complex-to-complex FFTs and provides helpers that operate on 9 | /// n-dimensional `ndarray::Array` values. The implementation always 10 | /// performs FFTs along the last axis (which is kept contiguous), then 11 | /// permutes axes so the next axis becomes the last. This keeps the heavy 12 | /// FFT work on contiguous memory and avoids many small allocations. 13 | pub struct Processor { 14 | cp: rustfft::FftPlanner, 15 | _phantom: PhantomData>, 16 | } 17 | 18 | impl Default for Processor { 19 | fn default() -> Self { 20 | Self { 21 | cp: rustfft::FftPlanner::new(), 22 | _phantom: Default::default(), 23 | } 24 | } 25 | } 26 | 27 | impl Processor { 28 | /// Perform an N-D complex FFT along all axes. 29 | /// 30 | /// The function expects an array whose last axis is contiguous. 
It 31 | /// performs an out-of-place FFT on that axis into an uninitialized 32 | /// `output` buffer, then rotates axes and repeats so each axis becomes 33 | /// the last once. `direction` controls forward vs inverse transform. 34 | pub fn internal>, const N: usize>( 35 | &mut self, 36 | input: &mut ArrayBase>, 37 | direction: rustfft::FftDirection, 38 | ) -> Array, Dim<[Ix; N]>> 39 | where 40 | Dim<[Ix; N]>: RemoveAxis, 41 | [Ix; N]: IntoDimension>, 42 | { 43 | // Ensure we always run FFTs along the last axis (which is contiguous), 44 | // then permute the array so the next axis becomes the last. This 45 | // avoids creating many small slices and keeps the heavy FFT work on 46 | // contiguous memory. 47 | let output = Array::uninit(input.raw_dim()); 48 | let mut output = unsafe { output.assume_init() }; 49 | 50 | let mut buffer = input.view_mut(); 51 | 52 | // axes permutation helper: rotate so the next dimension becomes the 53 | // last (contiguous) axis before each subsequent stage. 54 | let mut axes: [usize; N] = std::array::from_fn(|i| i); 55 | 56 | match direction { 57 | rustfft::FftDirection::Forward => axes.rotate_right(1), 58 | rustfft::FftDirection::Inverse => axes.rotate_left(1), 59 | }; 60 | 61 | // perform FFT on last axis, then permute and repeat for remaining axes 62 | for i in 0..N { 63 | let fft = self.cp.plan_fft(output.shape()[N - 1], direction); 64 | let mut scratch = 65 | vec![Complex::new(T::zero(), T::zero()); fft.get_outofplace_scratch_len()]; 66 | 67 | fft.process_outofplace_with_scratch( 68 | buffer.as_slice_mut().unwrap(), 69 | output.as_slice_mut().unwrap(), 70 | &mut scratch, 71 | ); 72 | 73 | // permute axes so the next axis becomes the last (contiguous) 74 | if i != N - 1 { 75 | output = output.permuted_axes(axes); 76 | 77 | // reshape 78 | buffer = 79 | unsafe { ArrayViewMut::from_shape_ptr(output.raw_dim(), buffer.as_mut_ptr()) }; 80 | buffer.zip_mut_with(&output, |transpose, &origin| { 81 | *transpose = origin; 82 | }); 83 | 84 | 
// continuous 85 | output = 86 | Array::from_shape_vec(output.raw_dim(), output.into_raw_vec_and_offset().0) 87 | .unwrap(); 88 | } 89 | } 90 | 91 | output 92 | } 93 | } 94 | 95 | impl ProcessorTrait> for Processor { 96 | fn forward>, const N: usize>( 97 | &mut self, 98 | input: &mut ArrayBase>, 99 | ) -> Array, Dim<[Ix; N]>> 100 | where 101 | Dim<[Ix; N]>: RemoveAxis, 102 | [Ix; N]: IntoDimension>, 103 | { 104 | self.internal(input, rustfft::FftDirection::Forward) 105 | } 106 | 107 | fn backward( 108 | &mut self, 109 | input: &mut Array, Dim<[Ix; N]>>, 110 | ) -> Array, Dim<[Ix; N]>> 111 | where 112 | Dim<[Ix; N]>: RemoveAxis, 113 | [Ix; N]: IntoDimension>, 114 | { 115 | let mut output = self.internal(input, rustfft::FftDirection::Inverse); 116 | let len = Complex::new(T::from_usize(output.len()).unwrap(), T::zero()); 117 | output.map_mut(|x| *x = *x / len); 118 | output 119 | // self.backward_internal(input, None) 120 | } 121 | } 122 | 123 | #[cfg(test)] 124 | mod tests { 125 | use super::*; 126 | use ndarray::array; 127 | 128 | // ===== Basic Roundtrip Tests ===== 129 | 130 | mod roundtrip { 131 | use super::*; 132 | 133 | #[test] 134 | fn test_1d() { 135 | let mut proc = Processor::::default(); 136 | let original = array![ 137 | Complex::new(1.0f32, 0.5), 138 | Complex::new(2.0, -0.25), 139 | Complex::new(3.0, 1.25), 140 | Complex::new(4.0, -0.75) 141 | ]; 142 | let mut input = original.clone(); 143 | let mut freq = proc.forward(&mut input); 144 | let recon = proc.backward(&mut freq); 145 | 146 | for (orig, recon) in original.iter().zip(recon.iter()) { 147 | assert!( 148 | (orig.re - recon.re).abs() < 1e-6 && (orig.im - recon.im).abs() < 1e-6, 149 | "1D roundtrip failed. 
Original: {:?}, Reconstructed: {:?}", 150 | orig, 151 | recon 152 | ); 153 | } 154 | } 155 | 156 | #[test] 157 | fn test_2d() { 158 | let mut proc = Processor::::default(); 159 | let original = array![ 160 | [Complex::new(1.0f32, 0.5), Complex::new(2.0, -1.0)], 161 | [Complex::new(3.0, 1.5), Complex::new(4.0, -0.5)] 162 | ]; 163 | let mut input = original.clone(); 164 | let mut freq = proc.forward(&mut input); 165 | let recon = proc.backward(&mut freq); 166 | 167 | for (orig, recon) in original.iter().zip(recon.iter()) { 168 | assert!( 169 | (orig.re - recon.re).abs() < 1e-6 && (orig.im - recon.im).abs() < 1e-6, 170 | "2D roundtrip failed. Original: {:?}, Reconstructed: {:?}", 171 | orig, 172 | recon 173 | ); 174 | } 175 | } 176 | 177 | #[test] 178 | fn test_3d() { 179 | let mut proc = Processor::::default(); 180 | let original = array![ 181 | [ 182 | [Complex::new(1.0f32, 0.125), Complex::new(2.0, -0.25)], 183 | [Complex::new(3.0, 0.375), Complex::new(4.0, -0.5)] 184 | ], 185 | [ 186 | [Complex::new(5.0, 0.625), Complex::new(6.0, -0.75)], 187 | [Complex::new(7.0, 0.875), Complex::new(8.0, -1.0)] 188 | ] 189 | ]; 190 | let mut input = original.clone(); 191 | let mut freq = proc.forward(&mut input); 192 | let recon = proc.backward(&mut freq); 193 | 194 | for (orig, recon) in original.iter().zip(recon.iter()) { 195 | assert!( 196 | (orig.re - recon.re).abs() < 1e-6 && (orig.im - recon.im).abs() < 1e-6, 197 | "3D roundtrip failed. 
Original: {:?}, Reconstructed: {:?}", 198 | orig, 199 | recon 200 | ); 201 | } 202 | } 203 | 204 | #[test] 205 | fn different_sizes() { 206 | let test_cases = vec![ 207 | array![Complex::new(1.0f32, 0.5), Complex::new(2.0, -0.25)], 208 | array![ 209 | Complex::new(1.0f32, 0.75), 210 | Complex::new(2.0, 1.0), 211 | Complex::new(3.0, -1.0) 212 | ], 213 | array![ 214 | Complex::new(1.0f32, 0.25), 215 | Complex::new(2.0, -0.5), 216 | Complex::new(3.0, 0.75), 217 | Complex::new(4.0, -1.0), 218 | Complex::new(5.0, 1.25) 219 | ], 220 | ]; 221 | 222 | for (i, original) in test_cases.into_iter().enumerate() { 223 | let mut proc = Processor::::default(); 224 | let mut input = original.clone(); 225 | let mut freq = proc.forward(&mut input); 226 | let recon = proc.backward(&mut freq); 227 | 228 | for (orig, recon) in original.iter().zip(recon.iter()) { 229 | assert!( 230 | (orig.re - recon.re).abs() < 1e-6 && (orig.im - recon.im).abs() < 1e-6, 231 | "Size test case {} failed. Original: {:?}, Reconstructed: {:?}", 232 | i, 233 | orig, 234 | recon 235 | ); 236 | } 237 | } 238 | } 239 | } 240 | 241 | // ===== Complex Value Tests ===== 242 | 243 | mod complex_values { 244 | use super::*; 245 | 246 | #[test] 247 | fn large_imaginary_parts() { 248 | let mut proc = Processor::::default(); 249 | let original = array![ 250 | Complex::new(1.0f32, 3.0), 251 | Complex::new(2.0, -2.5), 252 | Complex::new(0.5, 4.0), 253 | Complex::new(-1.0, 2.0) 254 | ]; 255 | let mut input = original.clone(); 256 | let mut freq = proc.forward(&mut input); 257 | let recon = proc.backward(&mut freq); 258 | 259 | for (orig, recon) in original.iter().zip(recon.iter()) { 260 | assert!( 261 | (orig.re - recon.re).abs() < 1e-6 && (orig.im - recon.im).abs() < 1e-6, 262 | "Large imaginary parts roundtrip failed. 
Original: {:?}, Reconstructed: {:?}", 263 | orig, 264 | recon 265 | ); 266 | } 267 | } 268 | 269 | #[test] 270 | fn pure_imaginary() { 271 | let mut proc = Processor::::default(); 272 | // Test with pure imaginary numbers (re = 0, im != 0) 273 | let original = array![ 274 | Complex::new(0.0f32, 1.0), 275 | Complex::new(0.0, 2.0), 276 | Complex::new(0.0, -1.5), 277 | Complex::new(0.0, 3.0) 278 | ]; 279 | let mut input = original.clone(); 280 | let mut freq = proc.forward(&mut input); 281 | let recon = proc.backward(&mut freq); 282 | 283 | for (orig, recon) in original.iter().zip(recon.iter()) { 284 | assert!( 285 | (orig.re - recon.re).abs() < 1e-6 && (orig.im - recon.im).abs() < 1e-6, 286 | "Pure imaginary roundtrip failed. Original: {:?}, Reconstructed: {:?}", 287 | orig, 288 | recon 289 | ); 290 | } 291 | } 292 | 293 | #[test] 294 | fn mixed_signs() { 295 | let mut proc = Processor::::default(); 296 | // Test with various combinations of positive/negative real and imaginary parts 297 | let original = array![ 298 | [Complex::new(1.0f32, 2.0), Complex::new(-1.0, 2.0)], 299 | [Complex::new(1.0, -2.0), Complex::new(-1.0, -2.0)] 300 | ]; 301 | let mut input = original.clone(); 302 | let mut freq = proc.forward(&mut input); 303 | let recon = proc.backward(&mut freq); 304 | 305 | for (orig, recon) in original.iter().zip(recon.iter()) { 306 | assert!( 307 | (orig.re - recon.re).abs() < 1e-6 && (orig.im - recon.im).abs() < 1e-6, 308 | "Mixed signs roundtrip failed. Original: {:?}, Reconstructed: {:?}", 309 | orig, 310 | recon 311 | ); 312 | } 313 | } 314 | } 315 | } 316 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ndarray-conv 2 | 3 | ndarray-conv is a crate that provides a N-Dimension convolutions (with FFT acceleration) library in pure Rust. 
4 | 5 | Inspired by 6 | 7 | ndarray-vision (https://github.com/rust-cv/ndarray-vision) 8 | 9 | convolutions-rs (https://github.com/Conzel/convolutions-rs#readme) 10 | 11 | pocketfft (https://github.com/mreineck/pocketfft) 12 | 13 | ## Roadmap 14 | 15 | - [x] basic conv for N dimension `Array`/`ArrayView` 16 | - [x] conv with FFT acceleration for N dimension `Array`/`ArrayView` 17 | - [x] impl `ConvMode` and `PaddingMode` 18 | - [x] `ConvMode`: Full Same Valid Custom Explicit 19 | - [x] `PaddingMode`: Zeros Const Reflect Replicate Circular Custom Explicit 20 | - [x] conv with strides 21 | - [x] kernel with dilation 22 | - [x] handle input size error 23 | - [x] explicit error type 24 | - [x] bench with similar libs 25 | - [x] support `Complex` 26 | - [ ] conv with GPU acceleration for N dimension `Array`/`ArrayView` via `wgpu` 27 | 28 | ## Examples 29 | 30 | ```rust 31 | use ndarray_conv::*; 32 | 33 | x_nd.conv( 34 | &k_n, 35 | ConvMode::Full, 36 | PaddingMode::Circular, 37 | ); 38 | 39 | // for cross-correlation 40 | x_nd.conv( 41 | k_n.no_reverse(), 42 | ConvMode::Full, 43 | PaddingMode::Circular, 44 | ); 45 | 46 | x_1d.view().conv_fft( 47 | &k_1d, 48 | ConvMode::Same, 49 | PaddingMode::Explicit([[BorderType::Replicate, BorderType::Reflect]]), 50 | ); 51 | 52 | x_2d.conv_fft( 53 | k_2d.with_dilation(2), 54 | ConvMode::Same, 55 | PaddingMode::Custom([BorderType::Reflect, BorderType::Circular]), 56 | ); 57 | 58 | // avoid loss of accuracy for fft ver 59 | // convert Integer to Float before calculate. 
60 | x_3d.map(|&x| x as f32) 61 | .conv_fft( 62 | &kernel.map(|&x| x as f32), 63 | ConvMode::Same, 64 | PaddingMode::Zeros, 65 | ) 66 | .unwrap() 67 | .map(|x| x.round() as i32); 68 | ``` 69 | 70 | ```rust 71 | // Example for thin wrapper 72 | use ndarray::{ 73 | array, Array, ArrayView, Dim, IntoDimension, Ix, RemoveAxis, SliceArg, SliceInfo, SliceInfoElem, 74 | }; 75 | use ndarray_conv::*; 76 | 77 | pub fn fftconvolve<'a, T, const N: usize>( 78 | in1: impl Into>>, 79 | in2: impl Into>>, 80 | ) -> Array> 81 | where 82 | T: num::traits::NumAssign + rustfft::FftNum, 83 | Dim<[Ix; N]>: RemoveAxis, 84 | [Ix; N]: IntoDimension>, 85 | SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: 86 | SliceArg, OutDim = Dim<[Ix; N]>>, 87 | { 88 | in1.into() 89 | .conv_fft(&in2.into(), ConvMode::Full, PaddingMode::Zeros) 90 | .unwrap() 91 | } 92 | 93 | fn test() { 94 | let o0 = fftconvolve(&[1., 2.], &array![1., 3., 7.]); 95 | let o1 = fftconvolve(&vec![1., 2.], &[1., 3., 7.]); 96 | } 97 | ``` 98 | 99 | ## Benchmark 100 | 101 | ```rust 102 | let x = Array::random(5000, Uniform::new(0f32, 1.)); 103 | let k = Array::random(31, Uniform::new(0f32, 1.)); 104 | 105 | fft_1d time: [76.621 µs 76.649 µs 76.681 µs] 106 | fft_with_processor_1d time: [34.563 µs 34.790 µs 35.125 µs] 107 | torch_1d time: [45.542 µs 45.658 µs 45.775 µs] 108 | fftconvolve_1d time: [161.52 µs 162.28 µs 163.05 µs] 109 | 110 | --------------------------------------------------------------- 111 | 112 | let x = Array::random((200, 5000), Uniform::new(0f32, 1.)); 113 | let k = Array::random((11, 31), Uniform::new(0f32, 1.)); 114 | 115 | fft_2d time: [16.022 ms 16.046 ms 16.071 ms] 116 | fft_with_processor_2d time: [15.949 ms 15.977 ms 16.010 ms] 117 | torch_2d time: [109.76 ms 111.62 ms 113.79 ms] 118 | ndarray_vision_2d time: [429.47 ms 429.64 ms 429.82 ms] 119 | fftconvolve_2d time: [56.273 ms 56.342 ms 56.420 ms] 120 | 121 | --------------------------------------------------------------- 122 | 123 | let x = 
Array::random((10, 100, 200), Uniform::new(0f32, 1.)); 124 | let k = Array::random((5, 11, 31), Uniform::new(0f32, 1.)); 125 | 126 | fft_3d time: [4.6476 ms 4.6651 ms 4.6826 ms] 127 | fft_with_processor_3d time: [4.6393 ms 4.6575 ms 4.6754 ms] 128 | torch_3d time: [160.73 ms 161.12 ms 161.56 ms] 129 | fftconvolve_3d time: [11.991 ms 12.009 ms 12.031 ms] 130 | ``` 131 | 132 | ## Versions 133 | - 0.6.0 - Dependency update: update ndarray from 0.16 to >=0.17. 134 | - 0.5.3 - Bug fix: error handling for kernel shape in convolution operation. Export GetProcessor as public trait for external consumer 135 | - 0.5.2 - Doc update. 136 | - 0.5.1 - Add support for Complex. Complete unit tests. Improve performance. 137 | - 0.5.0 - **[breaking change]** Add `ReverseKernel` trait for cross-correlation, make `conv` & `conv_fft` calculate mathematical convolution. 138 | - 0.4.2 - Remove `Debug` trait on `T`. 139 | - 0.4.1 - Doc update. 140 | - 0.4.0 - Dependency update: update ndarray from 0.15 to 0.16. 141 | - 0.3.4 - Bug fix: fix unsafe type cast in circular padding. 142 | - 0.3.3 - Bug fix: correct conv_fft's output shape. 143 | - 0.3.2 - Improve performance, by modifying `good_fft_size` and `transpose`. 144 | - 0.3.1 - Impl basic error type. Fix some bugs. 145 | - 0.3.0 - update to N-Dimension convolution. 146 | - 0.2.0 - finished `conv_2d` & `conv_2d_fft`. 147 | 148 | ## Frequently Asked Questions (FAQ) 149 | 150 | This FAQ addresses common questions about the `ndarray-conv` crate, a Rust library for N-dimensional convolutions using the `ndarray` ecosystem. 151 | 152 | ### 1. What is `ndarray-conv`? 153 | 154 | `ndarray-conv` is a Rust crate that provides N-dimensional convolution operations for the `ndarray` crate. It offers both standard and FFT-accelerated convolutions, giving you efficient tools for image processing, signal processing, and other applications that rely on convolutions. 155 | 156 | ### 2. What are the main features of `ndarray-conv`? 
157 | 158 | * **N-Dimensional Convolutions:** Supports convolutions on arrays with any number of dimensions. 159 | * **Standard and FFT-Accelerated:** Offers both `conv` (standard) and `conv_fft` (FFT-based) methods. 160 | * **Flexible Convolution Modes:** `ConvMode` (Full, Same, Valid, Custom, Explicit) to control output size. 161 | * **Various Padding Modes:** `PaddingMode` (Zeros, Const, Reflect, Replicate, Circular, Custom, Explicit) to handle boundary conditions. 162 | * **Strides and Dilation:** Supports strided convolutions and dilated kernels using the `with_dilation()` method. 163 | * **Performance Optimization:** Uses FFTs for larger kernels and optimized low-level operations for efficiency. The `conv_fft_with_processor` allows to reuse an `FftProcessor` for improved performance on repeated calls. 164 | * **Integration with `ndarray`:** Seamlessly works with `ndarray` `Array` and `ArrayView` types. 165 | 166 | ### 3. When should I use `conv_fft` vs. `conv`? 167 | 168 | * **`conv_fft` (FFT-accelerated):** Generally faster for larger kernels (e.g., larger than 11x11) because the computational complexity of FFTs grows more slowly than direct convolution as the kernel size increases. 169 | * **`conv` (Standard):** Might be faster for very small kernels (e.g., 3x3, 5x5) due to the overhead associated with FFT calculations. 170 | 171 | It's a good idea to benchmark both methods with your specific kernel sizes and data dimensions to determine the best choice. 172 | 173 | ### 4. How do I choose the right `ConvMode`? 174 | 175 | * **`Full`:** The output contains all positions where the kernel and input overlap at least partially. This results in the largest output size. 176 | * **`Same`:** The output has the same size as the input. This is achieved by padding the input appropriately. 177 | * **`Valid`:** The output contains only positions where the kernel and input fully overlap. This results in the smallest output size. 
178 | * **`Custom`:** You specify the padding for all the dimensions and strides. 179 | * **`Explicit`:** You specify the explicit padding for each side of each dimension, and the strides. 180 | 181 | The best choice depends on the desired output size and how you want to handle boundary conditions. 182 | 183 | ### 5. How do I handle border effects with `PaddingMode`? 184 | 185 | `PaddingMode` determines how the input is padded before the convolution. 186 | 187 | * **`Zeros`:** Pads with zeros. 188 | * **`Const(value)`:** Pads with a constant value. 189 | * **`Reflect`:** Reflects the input at the borders. 190 | * **`Replicate`:** Replicates the edge values. 191 | * **`Circular`:** Treats the input as a circular buffer, wrapping around at the borders. 192 | * **`Custom`:** You provide an array of `BorderType` enums, one for each dimension, to specify different padding behavior for each dimension. 193 | * **`Explicit`:** You provide an array with arrays of `BorderType` enums, one for each side of each dimension, to specify different padding behavior for each dimension. 194 | 195 | Choose the `PaddingMode` that best suits your application's requirements for handling edges. 196 | 197 | ### 6. What is dilation, and how do I use it? 198 | 199 | Dilation expands the *receptive field* of a kernel without increasing the number of its parameters. It does this by inserting spaces (usually zeros) between the original kernel elements. A dilation factor of `d` means that `d-1` zeros are inserted between each kernel element. 200 | 201 | * Use the `with_dilation()` method on an `ndarray` `Array` or `ArrayView` representing your kernel to create a dilated kernel. 202 | * Pass the dilated kernel to the `conv` or `conv_fft` methods. 
203 | 204 | **Example:** 205 | 206 | ```rust 207 | let kernel = ndarray::array![[1, 2, 3], [4, 5, 6]]; 208 | let dilated_kernel = kernel.with_dilation(2); // Dilate by a factor of 2 in both dimensions 209 | // dilated_kernel will effectively be: [[1, 0, 2, 0, 3], [4, 0, 5, 0, 6]] 210 | ``` 211 | 212 | **Why Use Dilation?** 213 | 214 | * **Increased Receptive Field:** Captures information from a wider area of the input without increasing the parameter count. 215 | * **Computational Efficiency:** More efficient than using very large standard kernels to achieve the same receptive field. 216 | * **Multi-Scale Feature Extraction:** Enables extracting features at different scales by using varying dilation rates. 217 | 218 | **Applications:** 219 | 220 | * Semantic segmentation 221 | * Object detection 222 | * Image generation 223 | * Audio processing 224 | * Time-series analysis 225 | 226 | ### 7. How can I improve the performance of repeated convolutions? 227 | 228 | * **Use `conv_fft_with_processor`:** If you're performing multiple FFT-based convolutions, create an `FftProcessor` and reuse it with the `conv_fft_with_processor` method. This avoids recomputing FFT plans and reallocating scratch buffers. 229 | * **Convert to `f32` or `f64`:** For FFT convolutions, ensure your input and kernel data are `f32` (for `Rfft32`) or `f64` (for `Rfft64`). This avoids unnecessary type conversions. 230 | 231 | ### 8. How do I install `ndarray-conv`? 232 | 233 | Add the following to your `Cargo.toml` file: 234 | 235 | ```toml 236 | ndarray-conv = "0.6.0" # Use the latest version 237 | ``` 238 | 239 | ### 9. Are there any limitations to be aware of? 240 | 241 | * **FFT Overhead:** For very small kernels, FFT-based convolutions might be slower than standard convolutions due to the overhead of FFT calculations. 242 | * **Memory Usage:** FFT operations might require additional memory for intermediate buffers. 
243 | * **`conv_fft` requires floating point:** The input and kernel must be floating point types (`f32` or `f64`) for FFT-based convolutions. 244 | 245 | ### 10. How do I convert integer arrays to floating-point for use with `conv_fft`? 246 | 247 | Use the `.map(|&x| x as f32)` or `.map(|&x| x as f64)` methods to convert an integer `ndarray` to `f32` or `f64`, respectively. 248 | 249 | **Example:** 250 | 251 | ```rust 252 | let int_array = ndarray::Array::from_shape_vec((2, 3), vec![1, 2, 3, 4, 5, 6]).unwrap(); 253 | let float_array = int_array.map(|&x| x as f32); 254 | ``` 255 | 256 | ### 11. Where can I find examples and documentation? 257 | 258 | * **README:** The project's README file on GitHub contains basic examples and usage instructions. 259 | * **Rust Docs:** Once published to crates.io, you can find detailed API documentation on docs.rs. 260 | * **Test Cases:** The `tests` modules within the source code provide further examples of how to use the library. 261 | 262 | ### 12. How does `ndarray-conv` compare to other convolution libraries? 263 | 264 | The `ndarray-conv` project includes benchmarks comparing its performance to libraries like `tch` (LibTorch/PyTorch), `ndarray-vision`, and `fftconvolve`. `ndarray-conv` is generally competitive and often outperforms these other libraries, especially when using `conv_fft_with_processor` for repeated convolutions. 265 | -------------------------------------------------------------------------------- /src/conv_fft/mod.rs: -------------------------------------------------------------------------------- 1 | //! Provides FFT-accelerated convolution operations. 2 | //! 3 | //! This module offers the `ConvFFTExt` trait, which extends `ndarray` 4 | //! with FFT-based convolution methods. 
5 | 6 | use ndarray::{ 7 | Array, ArrayBase, Data, Dim, IntoDimension, Ix, RawData, RemoveAxis, SliceArg, SliceInfo, 8 | SliceInfoElem, 9 | }; 10 | use num::traits::NumAssign; 11 | use rustfft::FftNum; 12 | 13 | use crate::{dilation::IntoKernelWithDilation, ConvMode, PaddingMode}; 14 | 15 | mod good_size; 16 | mod padding; 17 | mod processor; 18 | 19 | // pub use fft::Processor; 20 | pub use processor::{get as get_processor, GetProcessor, Processor}; 21 | 22 | // /// Represents a "baked" convolution operation. 23 | // /// 24 | // /// This struct holds pre-computed data for performing FFT-accelerated 25 | // /// convolutions, including the FFT size, FFT processor, scratch space, 26 | // /// and padding information. It's designed to optimize repeated 27 | // /// convolutions with the same kernel and settings. 28 | // pub struct Baked 29 | // where 30 | // T: NumAssign + Debug + Copy, 31 | // SK: RawData, 32 | // { 33 | // fft_size: [usize; N], 34 | // fft_processor: impl Processor, 35 | // scratch: Vec>, 36 | // cm: ExplicitConv, 37 | // padding_mode: PaddingMode, 38 | // kernel_raw_dim_with_dilation: [usize; N], 39 | // pds_raw_dim: [usize; N], 40 | // kernel_pd: Array>, 41 | // _sk_hint: PhantomData, 42 | // } 43 | 44 | /// Extends `ndarray`'s `ArrayBase` with FFT-accelerated convolution operations. 45 | /// 46 | /// This trait adds the `conv_fft` and `conv_fft_with_processor` methods to `ArrayBase`, 47 | /// enabling efficient FFT-based convolutions on N-dimensional arrays. 48 | /// 49 | /// # Type Parameters 50 | /// 51 | /// * `T`: The numeric type used internally for FFT operations. Must be a floating-point type that implements `FftNum`. 52 | /// * `InElem`: The element type of the input arrays. Can be real (`T`) or complex (`Complex`). 53 | /// * `S`: The data storage type of the input array. 54 | /// * `SK`: The data storage type of the kernel array. 
55 | /// 56 | /// # Methods 57 | /// 58 | /// * `conv_fft`: Performs an FFT-accelerated convolution with default settings. 59 | /// * `conv_fft_with_processor`: Performs an FFT-accelerated convolution using a provided `Processor` instance, allowing for reuse of FFT plans across multiple convolutions for better performance. 60 | /// 61 | /// # Example 62 | /// 63 | /// ```rust 64 | /// use ndarray::prelude::*; 65 | /// use ndarray_conv::{ConvFFTExt, ConvMode, PaddingMode}; 66 | /// 67 | /// let arr = array![[1., 2.], [3., 4.]]; 68 | /// let kernel = array![[1., 0.], [0., 1.]]; 69 | /// let result = arr.conv_fft(&kernel, ConvMode::Same, PaddingMode::Zeros).unwrap(); 70 | /// ``` 71 | /// 72 | /// # Notes 73 | /// 74 | /// FFT-based convolutions are generally faster for larger kernels but may have higher overhead for smaller kernels. 75 | /// Use standard convolution (`ConvExt::conv`) for small kernels or when working with integer types. 76 | /// 77 | /// # Performance Tips 78 | /// 79 | /// For repeated convolutions with different data but the same kernel and settings, consider using 80 | /// `conv_fft_with_processor` to reuse the FFT planner and avoid redundant setup overhead. 81 | pub trait ConvFFTExt<'a, T, InElem, S, SK, const N: usize> 82 | where 83 | T: NumAssign + Copy + FftNum, 84 | InElem: processor::GetProcessor + Copy + NumAssign, 85 | S: RawData, 86 | SK: RawData, 87 | { 88 | /// Performs an FFT-accelerated convolution operation. 89 | /// 90 | /// This method convolves the input array with a given kernel using FFT, 91 | /// which is typically faster for larger kernels. 92 | /// 93 | /// # Arguments 94 | /// 95 | /// * `kernel`: The convolution kernel. Can be a reference to an array, or an array with dilation settings. 96 | /// * `conv_mode`: The convolution mode (`Full`, `Same`, `Valid`, `Custom`, `Explicit`). 97 | /// * `padding_mode`: The padding mode (`Zeros`, `Const`, `Reflect`, `Replicate`, `Circular`, `Custom`, `Explicit`). 
98 | /// 99 | /// # Returns 100 | /// 101 | /// Returns `Ok(Array>)` containing the convolution result, or an `Err(Error)` if the operation fails. 102 | /// 103 | /// # Example 104 | /// 105 | /// ```rust 106 | /// use ndarray::array; 107 | /// use ndarray_conv::{ConvFFTExt, ConvMode, PaddingMode}; 108 | /// 109 | /// let input = array![[1.0, 2.0], [3.0, 4.0]]; 110 | /// let kernel = array![[1.0, 0.0], [0.0, 1.0]]; 111 | /// let result = input.conv_fft(&kernel, ConvMode::Same, PaddingMode::Zeros).unwrap(); 112 | /// ``` 113 | fn conv_fft( 114 | &self, 115 | kernel: impl IntoKernelWithDilation<'a, SK, N>, 116 | conv_mode: ConvMode, 117 | padding_mode: PaddingMode, 118 | ) -> Result>, crate::Error>; 119 | 120 | /// Performs an FFT-accelerated convolution using a provided processor. 121 | /// 122 | /// This method is useful when performing multiple convolutions, as it allows 123 | /// reusing the FFT planner and avoiding redundant initialization overhead. 124 | /// 125 | /// # Arguments 126 | /// 127 | /// * `kernel`: The convolution kernel. 128 | /// * `conv_mode`: The convolution mode. 129 | /// * `padding_mode`: The padding mode. 130 | /// * `fft_processor`: A mutable reference to an FFT processor instance. 131 | /// 132 | /// # Returns 133 | /// 134 | /// Returns `Ok(Array>)` containing the convolution result, or an `Err(Error)` if the operation fails. 
135 | /// 136 | /// # Example 137 | /// 138 | /// ```rust 139 | /// use ndarray::array; 140 | /// use ndarray_conv::{ConvFFTExt, ConvMode, PaddingMode, get_fft_processor}; 141 | /// 142 | /// let input1 = array![[1.0, 2.0], [3.0, 4.0]]; 143 | /// let input2 = array![[5.0, 6.0], [7.0, 8.0]]; 144 | /// let kernel = array![[1.0, 0.0], [0.0, 1.0]]; 145 | /// 146 | /// // Reuse the same processor for multiple convolutions 147 | /// let mut proc = get_fft_processor::(); 148 | /// let result1 = input1.conv_fft_with_processor(&kernel, ConvMode::Same, PaddingMode::Zeros, &mut proc).unwrap(); 149 | /// let result2 = input2.conv_fft_with_processor(&kernel, ConvMode::Same, PaddingMode::Zeros, &mut proc).unwrap(); 150 | /// ``` 151 | fn conv_fft_with_processor( 152 | &self, 153 | kernel: impl IntoKernelWithDilation<'a, SK, N>, 154 | conv_mode: ConvMode, 155 | padding_mode: PaddingMode, 156 | fft_processor: &mut impl Processor, 157 | ) -> Result>, crate::Error>; 158 | 159 | // fn conv_fft_bake( 160 | // &self, 161 | // kernel: impl IntoKernelWithDilation<'a, SK, N>, 162 | // conv_mode: ConvMode, 163 | // padding_mode: PaddingMode, 164 | // ) -> Result, crate::Error>; 165 | 166 | // fn conv_fft_with_baked(&self, baked: &mut Baked) -> Array>; 167 | } 168 | 169 | impl<'a, T, InElem, S, SK, const N: usize> ConvFFTExt<'a, T, InElem, S, SK, N> 170 | for ArrayBase> 171 | where 172 | T: NumAssign + FftNum, 173 | InElem: processor::GetProcessor + NumAssign + Copy + 'a, 174 | S: Data + 'a, 175 | SK: Data + 'a, 176 | [Ix; N]: IntoDimension>, 177 | SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: 178 | SliceArg, OutDim = Dim<[Ix; N]>>, 179 | Dim<[Ix; N]>: RemoveAxis, 180 | { 181 | // fn conv_fft_bake( 182 | // &self, 183 | // kernel: impl IntoKernelWithDilation<'a, SK, N>, 184 | // conv_mode: ConvMode, 185 | // padding_mode: PaddingMode, 186 | // ) -> Result, crate::Error> { 187 | // let mut fft_processor = Processor::default(); 188 | 189 | // let kwd = 
kernel.into_kernel_with_dilation(); 190 | 191 | // let data_raw_dim = self.raw_dim(); 192 | // if self.shape().iter().product::() == 0 { 193 | // return Err(crate::Error::DataShape(data_raw_dim)); 194 | // } 195 | 196 | // let kernel_raw_dim = kwd.kernel.raw_dim(); 197 | // if kwd.kernel.shape().iter().product::() == 0 { 198 | // return Err(crate::Error::DataShape(kernel_raw_dim)); 199 | // } 200 | 201 | // let kernel_raw_dim_with_dilation: [usize; N] = 202 | // std::array::from_fn(|i| kernel_raw_dim[i] * kwd.dilation[i] - kwd.dilation[i] + 1); 203 | 204 | // let cm = conv_mode.unfold(&kwd); 205 | 206 | // let pds_raw_dim: [usize; N] = 207 | // std::array::from_fn(|i| (data_raw_dim[i] + cm.padding[i][0] + cm.padding[i][1])); 208 | // if !(0..N).all(|i| kernel_raw_dim_with_dilation[i] <= pds_raw_dim[i]) { 209 | // return Err(crate::Error::MismatchShape( 210 | // conv_mode, 211 | // kernel_raw_dim_with_dilation, 212 | // )); 213 | // } 214 | 215 | // let fft_size = good_size::compute::(&std::array::from_fn(|i| { 216 | // pds_raw_dim[i].max(kernel_raw_dim_with_dilation[i]) 217 | // })); 218 | 219 | // let scratch = fft_processor.get_scratch(fft_size); 220 | 221 | // let kernel_pd = padding::kernel(kwd, fft_size); 222 | 223 | // Ok(Baked { 224 | // fft_size, 225 | // fft_processor, 226 | // scratch, 227 | // cm, 228 | // padding_mode, 229 | // kernel_raw_dim_with_dilation, 230 | // pds_raw_dim, 231 | // kernel_pd, 232 | // _sk_hint: PhantomData, 233 | // }) 234 | // } 235 | 236 | // fn conv_fft_with_baked(&self, baked: &mut Baked) -> Array> { 237 | // let Baked { 238 | // scratch, 239 | // fft_processor, 240 | // fft_size, 241 | // cm, 242 | // padding_mode, 243 | // kernel_pd, 244 | // kernel_raw_dim_with_dilation, 245 | // pds_raw_dim, 246 | // _sk_hint, 247 | // } = baked; 248 | 249 | // let mut data_pd = padding::data(self, *padding_mode, cm.padding, *fft_size); 250 | 251 | // let mut data_pd_fft = fft_processor.forward_with_scratch(&mut data_pd, scratch); 252 | // 
let kernel_pd_fft = fft_processor.forward_with_scratch(kernel_pd, scratch); 253 | 254 | // data_pd_fft.zip_mut_with(&kernel_pd_fft, |d, k| *d *= *k); 255 | // // let mul_spec = data_pd_fft * kernel_pd_fft; 256 | 257 | // let output = fft_processor.backward(data_pd_fft); 258 | 259 | // output.slice_move(unsafe { 260 | // SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice { 261 | // start: kernel_raw_dim_with_dilation[i] as isize - 1, 262 | // end: Some((pds_raw_dim[i]) as isize), 263 | // step: cm.strides[i] as isize, 264 | // })) 265 | // .unwrap() 266 | // }) 267 | // } 268 | 269 | fn conv_fft( 270 | &self, 271 | kernel: impl IntoKernelWithDilation<'a, SK, N>, 272 | conv_mode: ConvMode, 273 | padding_mode: PaddingMode, 274 | ) -> Result>, crate::Error> { 275 | let mut p = InElem::get_processor(); 276 | self.conv_fft_with_processor(kernel, conv_mode, padding_mode, &mut p) 277 | } 278 | 279 | fn conv_fft_with_processor( 280 | &self, 281 | kernel: impl IntoKernelWithDilation<'a, SK, N>, 282 | conv_mode: ConvMode, 283 | padding_mode: PaddingMode, 284 | fft_processor: &mut impl Processor, 285 | ) -> Result>, crate::Error> { 286 | let kwd = kernel.into_kernel_with_dilation(); 287 | 288 | let data_raw_dim = self.raw_dim(); 289 | if self.shape().iter().product::() == 0 { 290 | return Err(crate::Error::DataShape(data_raw_dim)); 291 | } 292 | 293 | let kernel_raw_dim = kwd.kernel.raw_dim(); 294 | if kwd.kernel.shape().iter().product::() == 0 { 295 | return Err(crate::Error::DataShape(kernel_raw_dim)); 296 | } 297 | 298 | let kernel_raw_dim_with_dilation: [usize; N] = 299 | std::array::from_fn(|i| kernel_raw_dim[i] * kwd.dilation[i] - kwd.dilation[i] + 1); 300 | 301 | let cm = conv_mode.unfold(&kwd); 302 | 303 | let pds_raw_dim: [usize; N] = 304 | std::array::from_fn(|i| data_raw_dim[i] + cm.padding[i][0] + cm.padding[i][1]); 305 | if !(0..N).all(|i| kernel_raw_dim_with_dilation[i] <= pds_raw_dim[i]) { 306 | return Err(crate::Error::MismatchShape( 307 | conv_mode, 
308 | kernel_raw_dim_with_dilation, 309 | )); 310 | } 311 | 312 | let fft_size = good_size::compute::(&std::array::from_fn(|i| { 313 | pds_raw_dim[i].max(kernel_raw_dim_with_dilation[i]) 314 | })); 315 | 316 | let mut data_pd = padding::data(self, padding_mode, cm.padding, fft_size); 317 | let mut kernel_pd = padding::kernel(kwd, fft_size); 318 | 319 | let mut data_pd_fft = fft_processor.forward(&mut data_pd); 320 | let kernel_pd_fft = fft_processor.forward(&mut kernel_pd); 321 | 322 | data_pd_fft.zip_mut_with(&kernel_pd_fft, |d, k| *d *= *k); 323 | // let mul_spec = data_pd_fft * kernel_pd_fft; 324 | 325 | let output = fft_processor.backward(&mut data_pd_fft); 326 | 327 | let output = output.slice_move(unsafe { 328 | SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice { 329 | start: kernel_raw_dim_with_dilation[i] as isize - 1, 330 | end: Some((pds_raw_dim[i]) as isize), 331 | step: cm.strides[i] as isize, 332 | })) 333 | .unwrap() 334 | }); 335 | 336 | Ok(output) 337 | } 338 | } 339 | 340 | #[cfg(test)] 341 | mod tests; 342 | -------------------------------------------------------------------------------- /src/dilation/mod.rs: -------------------------------------------------------------------------------- 1 | //! Provides functionality for kernel dilation. 2 | 3 | use ndarray::{ 4 | ArrayBase, Data, Dim, Dimension, IntoDimension, Ix, RawData, SliceArg, SliceInfo, SliceInfoElem, 5 | }; 6 | 7 | /// Represents a kernel along with its dilation factors for each dimension. 
8 | pub struct KernelWithDilation<'a, S: RawData, const N: usize> { 9 | pub(crate) kernel: &'a ArrayBase>, 10 | pub(crate) dilation: [usize; N], 11 | pub(crate) reverse: bool, 12 | } 13 | 14 | impl<'a, S: RawData, const N: usize, T> KernelWithDilation<'a, S, N> 15 | where 16 | T: num::traits::NumAssign + Copy, 17 | S: Data, 18 | Dim<[Ix; N]>: Dimension, 19 | SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: 20 | SliceArg, OutDim = Dim<[Ix; N]>>, 21 | { 22 | /// Generates a list of offsets and corresponding kernel values for efficient convolution. 23 | /// 24 | /// This method calculates the offsets into the input array that need to be accessed 25 | /// during the convolution operation, taking into account the kernel's dilation. 26 | /// It filters out elements where the kernel value is zero to optimize the computation. 27 | /// 28 | /// # Arguments 29 | /// 30 | /// * `pds_strides`: The strides of the padded input array. 31 | /// 32 | /// # Returns 33 | /// A `Vec` of tuples, where each tuple contains an offset and the corresponding kernel value. 34 | pub fn gen_offset_list(&self, pds_strides: &[isize]) -> Vec<(isize, T)> { 35 | let buffer_slice = self.kernel.slice(unsafe { 36 | SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice { 37 | start: 0, 38 | end: Some(self.kernel.raw_dim()[i] as isize), 39 | step: if self.reverse { -1 } else { 1 }, 40 | })) 41 | .unwrap() 42 | }); 43 | 44 | let strides: [isize; N] = 45 | std::array::from_fn(|i| self.dilation[i] as isize * pds_strides[i]); 46 | 47 | buffer_slice 48 | .indexed_iter() 49 | .filter(|(_, v)| **v != T::zero()) 50 | .map(|(index, v)| { 51 | let index = index.into_dimension(); 52 | ( 53 | (0..N) 54 | .map(|n| index[n] as isize * strides[n]) 55 | .sum::(), 56 | *v, 57 | ) 58 | }) 59 | .collect() 60 | } 61 | } 62 | 63 | /// Trait for converting a value into a dilation array. 
64 | pub trait IntoDilation { 65 | fn into_dilation(self) -> [usize; N]; 66 | } 67 | 68 | impl IntoDilation for usize { 69 | #[inline] 70 | fn into_dilation(self) -> [usize; N] { 71 | [self; N] 72 | } 73 | } 74 | 75 | impl IntoDilation for [usize; N] { 76 | #[inline] 77 | fn into_dilation(self) -> [usize; N] { 78 | self 79 | } 80 | } 81 | 82 | /// Trait for adding dilation information to a kernel. 83 | /// 84 | /// Dilation is a parameter that controls the spacing between kernel elements 85 | /// during convolution. A dilation of 1 means no spacing (standard convolution), 86 | /// while larger values insert gaps between kernel elements. 87 | /// 88 | /// # Example 89 | /// 90 | /// ```rust 91 | /// use ndarray::array; 92 | /// use ndarray_conv::{WithDilation, ConvExt, ConvMode, PaddingMode}; 93 | /// 94 | /// let input = array![1, 2, 3, 4, 5]; 95 | /// let kernel = array![1, 1, 1]; 96 | /// 97 | /// // Standard convolution (dilation = 1) 98 | /// let result1 = input.conv(&kernel, ConvMode::Same, PaddingMode::Zeros).unwrap(); 99 | /// 100 | /// // Dilated convolution (dilation = 2) 101 | /// let result2 = input.conv(kernel.with_dilation(2), ConvMode::Same, PaddingMode::Zeros).unwrap(); 102 | /// ``` 103 | pub trait WithDilation { 104 | /// Adds dilation information to the kernel. 105 | /// 106 | /// # Arguments 107 | /// 108 | /// * `dilation`: The dilation factor(s). Can be a single value (applied to all dimensions) 109 | /// or an array of values (one per dimension). 110 | /// 111 | /// # Returns 112 | /// 113 | /// A `KernelWithDilation` instance containing the kernel and dilation information. 
114 | fn with_dilation(&self, dilation: impl IntoDilation) -> KernelWithDilation<'_, S, N>; 115 | } 116 | 117 | impl WithDilation for ArrayBase> { 118 | #[inline] 119 | fn with_dilation(&self, dilation: impl IntoDilation) -> KernelWithDilation<'_, S, N> { 120 | KernelWithDilation { 121 | kernel: self, 122 | dilation: dilation.into_dilation(), 123 | reverse: true, 124 | } 125 | } 126 | } 127 | 128 | /// Trait for controlling kernel reversal behavior in convolution operations. 129 | /// 130 | /// In standard convolution, the kernel is reversed (flipped) along all axes. 131 | /// This trait allows you to control whether the kernel should be reversed or not. 132 | /// 133 | /// # Convolution vs Cross-Correlation 134 | /// 135 | /// * **Convolution** (default, `reverse()`): The kernel is reversed, which is the mathematical definition of convolution. 136 | /// * **Cross-correlation** (`no_reverse()`): The kernel is NOT reversed. This is commonly used in machine learning frameworks. 137 | /// 138 | /// # Example 139 | /// 140 | /// ```rust 141 | /// use ndarray::array; 142 | /// use ndarray_conv::{WithDilation, ReverseKernel, ConvExt, ConvMode, PaddingMode}; 143 | /// 144 | /// let input = array![1, 2, 3, 4, 5]; 145 | /// let kernel = array![1, 2, 3]; 146 | /// 147 | /// // Standard convolution (kernel is reversed) 148 | /// let result1 = input.conv(&kernel, ConvMode::Same, PaddingMode::Zeros).unwrap(); 149 | /// // Equivalent to: 150 | /// let result1_explicit = input.conv(kernel.reverse(), ConvMode::Same, PaddingMode::Zeros).unwrap(); 151 | /// 152 | /// // Cross-correlation (kernel is NOT reversed) 153 | /// let result2 = input.conv(kernel.no_reverse(), ConvMode::Same, PaddingMode::Zeros).unwrap(); 154 | /// ``` 155 | pub trait ReverseKernel<'a, S: RawData, const N: usize> { 156 | /// Explicitly enables kernel reversal (standard convolution). 157 | /// 158 | /// This is the default behavior, so calling this method is usually not necessary. 
159 | fn reverse(self) -> KernelWithDilation<'a, S, N>; 160 | 161 | /// Disables kernel reversal (cross-correlation). 162 | /// 163 | /// Use this when you want the kernel to be applied without flipping, 164 | /// which is common in machine learning applications. 165 | fn no_reverse(self) -> KernelWithDilation<'a, S, N>; 166 | } 167 | 168 | impl<'a, S: RawData, K, const N: usize> ReverseKernel<'a, S, N> for K 169 | where 170 | K: IntoKernelWithDilation<'a, S, N>, 171 | { 172 | #[inline] 173 | fn reverse(self) -> KernelWithDilation<'a, S, N> { 174 | let mut kwd = self.into_kernel_with_dilation(); 175 | 176 | kwd.reverse = true; 177 | 178 | kwd 179 | } 180 | 181 | #[inline] 182 | fn no_reverse(self) -> KernelWithDilation<'a, S, N> { 183 | let mut kwd = self.into_kernel_with_dilation(); 184 | 185 | kwd.reverse = false; 186 | 187 | kwd 188 | } 189 | } 190 | 191 | /// Trait for converting a reference to a `KernelWithDilation`. 192 | pub trait IntoKernelWithDilation<'a, S: RawData, const N: usize> { 193 | fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N>; 194 | } 195 | 196 | impl<'a, S: RawData, const N: usize> IntoKernelWithDilation<'a, S, N> 197 | for &'a ArrayBase> 198 | { 199 | #[inline] 200 | fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N> { 201 | self.with_dilation(1) 202 | } 203 | } 204 | 205 | impl<'a, S: RawData, const N: usize> IntoKernelWithDilation<'a, S, N> 206 | for KernelWithDilation<'a, S, N> 207 | { 208 | #[inline] 209 | fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N> { 210 | self 211 | } 212 | } 213 | 214 | #[cfg(test)] 215 | mod tests { 216 | use ndarray::array; 217 | 218 | use super::*; 219 | 220 | // ===== Trait Implementation Tests ===== 221 | 222 | mod trait_implementation { 223 | use super::*; 224 | 225 | #[test] 226 | fn check_trait_impl() { 227 | fn conv_example<'a, S: RawData + 'a, const N: usize>( 228 | kernel: impl IntoKernelWithDilation<'a, S, N>, 229 | ) { 230 | let _ = 
kernel.into_kernel_with_dilation(); 231 | } 232 | 233 | let kernel = array![1, 0, 1]; 234 | conv_example(&kernel); 235 | 236 | let kernel = array![1, 0, 1]; 237 | conv_example(kernel.with_dilation(2)); 238 | 239 | let kernel = array![[1, 0, 1], [0, 1, 0]]; 240 | conv_example(kernel.with_dilation([1, 2])); 241 | 242 | // for convolution (default) 243 | conv_example(&kernel); 244 | // for convolution (explicit) 245 | conv_example(kernel.reverse()); 246 | // for cross-correlation 247 | conv_example(kernel.with_dilation(2).no_reverse()); 248 | } 249 | } 250 | 251 | // ===== Basic API Tests ===== 252 | 253 | mod basic_api { 254 | use super::*; 255 | 256 | #[test] 257 | fn dilation_and_reverse_settings() { 258 | let kernel = array![1, 2, 3]; 259 | 260 | // Test dilation is set correctly for different dimensions 261 | assert_eq!(kernel.with_dilation(2).dilation, [2]); 262 | assert_eq!(array![[1, 2]].with_dilation([2, 3]).dilation, [2, 3]); 263 | assert_eq!(array![[[1]]].with_dilation([1, 2, 3]).dilation, [1, 2, 3]); 264 | 265 | // Test reverse behavior (default is true, can be toggled) 266 | assert!(kernel.with_dilation(1).reverse); 267 | assert!(!kernel.with_dilation(1).no_reverse().reverse); 268 | assert!(kernel.with_dilation(1).no_reverse().reverse().reverse); 269 | } 270 | } 271 | 272 | // ===== Offset Generation Tests ===== 273 | 274 | mod offset_generation { 275 | use super::*; 276 | 277 | #[test] 278 | fn gen_offset_1d_no_dilation() { 279 | let kernel = array![1.0, 2.0, 3.0]; 280 | let kwd = kernel.with_dilation(1); 281 | 282 | // Stride = 1 for 1D 283 | let offsets = kwd.gen_offset_list(&[1]); 284 | 285 | // Should have 3 offsets (all kernel elements) 286 | assert_eq!(offsets.len(), 3); 287 | 288 | // With reverse=true, kernel is reversed: [3, 2, 1] 289 | // Offsets: [0, 1, 2] * stride[1] = [0, 1, 2] 290 | assert_eq!(offsets[0], (0, 3.0)); 291 | assert_eq!(offsets[1], (1, 2.0)); 292 | assert_eq!(offsets[2], (2, 1.0)); 293 | } 294 | 295 | #[test] 296 | fn 
gen_offset_1d_with_dilation() { 297 | let kernel = array![1.0, 2.0, 3.0]; 298 | let kwd = kernel.with_dilation(2); 299 | 300 | // Stride = 1, but dilation = 2 301 | let offsets = kwd.gen_offset_list(&[1]); 302 | 303 | assert_eq!(offsets.len(), 3); 304 | 305 | // Effective kernel: [1, 0, 2, 0, 3] 306 | // With reverse, indices with dilation: [0*2, 1*2, 2*2] = [0, 2, 4] 307 | // But reversed: [3, 2, 1] at positions [0, 2, 4] 308 | assert_eq!(offsets[0], (0, 3.0)); 309 | assert_eq!(offsets[1], (2, 2.0)); 310 | assert_eq!(offsets[2], (4, 1.0)); 311 | } 312 | 313 | #[test] 314 | fn gen_offset_1d_no_reverse() { 315 | let kernel = array![1.0, 2.0, 3.0]; 316 | let kwd = kernel.with_dilation(2).no_reverse(); 317 | 318 | let offsets = kwd.gen_offset_list(&[1]); 319 | 320 | assert_eq!(offsets.len(), 3); 321 | 322 | // No reverse: [1, 2, 3] at positions [0, 2, 4] 323 | assert_eq!(offsets[0], (0, 1.0)); 324 | assert_eq!(offsets[1], (2, 2.0)); 325 | assert_eq!(offsets[2], (4, 3.0)); 326 | } 327 | 328 | #[test] 329 | fn gen_offset_2d_no_dilation() { 330 | let kernel = array![[1.0, 2.0], [3.0, 4.0]]; 331 | let kwd = kernel.with_dilation(1); 332 | 333 | // Strides for 2D: [row_stride, col_stride] 334 | let offsets = kwd.gen_offset_list(&[10, 1]); 335 | 336 | assert_eq!(offsets.len(), 4); 337 | 338 | // With reverse, kernel becomes [[4, 3], [2, 1]] 339 | // Flattened in row-major order with reversed indices: 340 | // (0,0)=4 at offset 0, (0,1)=3 at offset 1, (1,0)=2 at offset 10, (1,1)=1 at offset 11 341 | assert_eq!(offsets[0], (0, 4.0)); 342 | assert_eq!(offsets[1], (1, 3.0)); 343 | assert_eq!(offsets[2], (10, 2.0)); 344 | assert_eq!(offsets[3], (11, 1.0)); 345 | } 346 | 347 | #[test] 348 | fn gen_offset_2d_with_dilation() { 349 | let kernel = array![[1.0, 2.0], [3.0, 4.0]]; 350 | let kwd = kernel.with_dilation([2, 3]); 351 | 352 | let offsets = kwd.gen_offset_list(&[10, 1]); 353 | 354 | assert_eq!(offsets.len(), 4); 355 | 356 | // Dilation [2, 3] means: 357 | // - row spacing = 2 
(kernel rows are 0 and 2*10=20 apart) 358 | // - col spacing = 3 (kernel cols are 0 and 3*1=3 apart) 359 | // With reverse, kernel [[4,3],[2,1]] at effective positions: 360 | // (0,0)=4 at 0, (0,3)=3 at 3, (2,0)=2 at 20, (2,3)=1 at 23 361 | assert_eq!(offsets[0], (0, 4.0)); 362 | assert_eq!(offsets[1], (3, 3.0)); 363 | assert_eq!(offsets[2], (20, 2.0)); 364 | assert_eq!(offsets[3], (23, 1.0)); 365 | } 366 | 367 | #[test] 368 | fn gen_offset_filters_zeros() { 369 | let kernel = array![1.0, 0.0, 2.0, 0.0, 3.0]; 370 | let kwd = kernel.with_dilation(1); 371 | 372 | let offsets = kwd.gen_offset_list(&[1]); 373 | 374 | // Should only have 3 offsets (non-zero elements) 375 | assert_eq!(offsets.len(), 3); 376 | } 377 | } 378 | 379 | // ===== Edge Cases ===== 380 | 381 | mod edge_cases { 382 | use super::*; 383 | 384 | #[test] 385 | fn single_element_kernel() { 386 | let kernel = array![42.0]; 387 | let kwd = kernel.with_dilation(5); 388 | 389 | assert_eq!(kwd.dilation, [5]); 390 | 391 | let offsets = kwd.gen_offset_list(&[1]); 392 | assert_eq!(offsets.len(), 1); 393 | assert_eq!(offsets[0], (0, 42.0)); 394 | } 395 | 396 | #[test] 397 | fn all_zeros_kernel() { 398 | let kernel = array![0.0, 0.0, 0.0]; 399 | let kwd = kernel.with_dilation(2); 400 | 401 | let offsets = kwd.gen_offset_list(&[1]); 402 | // Should filter out all zeros 403 | assert_eq!(offsets.len(), 0); 404 | } 405 | 406 | #[test] 407 | fn large_dilation_value() { 408 | let kernel = array![1, 2]; 409 | let kwd = kernel.with_dilation(100); 410 | 411 | assert_eq!(kwd.dilation, [100]); 412 | // Effective size: 2 + (2-1)*99 = 101 413 | } 414 | 415 | #[test] 416 | fn asymmetric_2d_dilation() { 417 | let kernel = array![[1, 2, 3], [4, 5, 6]]; 418 | let kwd = kernel.with_dilation([1, 5]); 419 | 420 | assert_eq!(kwd.dilation, [1, 5]); 421 | // dim 0: no dilation (keeps 2 rows) 422 | // dim 1: dilation=5 (3 + (3-1)*4 = 11 effective cols) 423 | } 424 | } 425 | 426 | // ===== Integration Tests ===== 427 | 428 | mod 
integration_with_padding { 429 | use super::*; 430 | 431 | #[test] 432 | fn effective_kernel_size_calculation() { 433 | // This tests the concept used in padding calculations 434 | let kernel = array![1, 2, 3]; 435 | 436 | // No dilation 437 | let kwd1 = kernel.with_dilation(1); 438 | let effective_size_1 = kernel.len() + (kernel.len() - 1) * (kwd1.dilation[0] - 1); 439 | assert_eq!(effective_size_1, 3); 440 | 441 | // Dilation = 2 442 | let kwd2 = kernel.with_dilation(2); 443 | let effective_size_2 = kernel.len() + (kernel.len() - 1) * (kwd2.dilation[0] - 1); 444 | assert_eq!(effective_size_2, 5); 445 | 446 | // Dilation = 3 447 | let kwd3 = kernel.with_dilation(3); 448 | let effective_size_3 = kernel.len() + (kernel.len() - 1) * (kwd3.dilation[0] - 1); 449 | assert_eq!(effective_size_3, 7); 450 | } 451 | } 452 | } 453 | -------------------------------------------------------------------------------- /src/conv_fft/processor/real.rs: -------------------------------------------------------------------------------- 1 | use ndarray::ArrayViewMut; 2 | 3 | use super::*; 4 | 5 | use super::Processor as ProcessorTrait; 6 | 7 | pub struct Processor { 8 | rp: realfft::RealFftPlanner, 9 | rp_origin_len: usize, 10 | cp: rustfft::FftPlanner, 11 | } 12 | 13 | impl Default for Processor { 14 | fn default() -> Self { 15 | Self { 16 | rp: Default::default(), 17 | rp_origin_len: Default::default(), 18 | cp: rustfft::FftPlanner::new(), 19 | } 20 | } 21 | } 22 | 23 | impl ProcessorTrait for Processor { 24 | /// Performs a forward FFT on the given input array. 25 | /// 26 | /// This computes a real-to-complex FFT on the last axis (contiguous), 27 | /// producing a complex-valued array where the last axis length is 28 | /// `rp.complex_len()`. Remaining axes are transformed with complex 29 | /// FFTs. All scratch buffers are allocated locally and reused where 30 | /// possible to avoid extra allocations. 
31 | fn forward, const N: usize>( 32 | &mut self, 33 | input: &mut ArrayBase>, 34 | ) -> Array, Dim<[Ix; N]>> 35 | where 36 | Dim<[Ix; N]>: RemoveAxis, 37 | [Ix; N]: IntoDimension>, 38 | { 39 | // Do real->complex on the last (contiguous) axis into an 40 | // uninitialized `output` buffer, then permute and run complex-to- 41 | // complex FFTs on the remaining axes while swapping two buffers 42 | // to avoid repeated allocations and copies. 43 | let raw_dim: [usize; N] = std::array::from_fn(|i| input.raw_dim()[i]); 44 | let rp = self.rp.plan_fft_forward(raw_dim[N - 1]); 45 | self.rp_origin_len = rp.len(); 46 | 47 | let mut output_shape = raw_dim; 48 | output_shape[N - 1] = rp.complex_len(); 49 | 50 | // let output = Array::uninit(output_shape); 51 | // let buffer = Array::uninit(output_shape); 52 | // let mut output = unsafe { output.assume_init() }; 53 | // let mut buffer = unsafe { buffer.assume_init() }; 54 | 55 | let mut output = Array::zeros(output_shape); 56 | let mut buffer = Array::zeros(output_shape); 57 | 58 | let mut scratch = vec![Complex::new(T::zero(), T::zero()); rp.get_scratch_len()]; 59 | for (mut input_row, mut output_row) in input.rows_mut().into_iter().zip(output.rows_mut()) { 60 | rp.process_with_scratch( 61 | input_row.as_slice_mut().unwrap(), 62 | output_row.as_slice_mut().unwrap(), 63 | &mut scratch, 64 | ) 65 | .unwrap(); 66 | } 67 | 68 | // axes permutation helper: rotate right so we make the next axis the 69 | // last (contiguous) axis on each iteration. 
70 | let mut axes: [usize; N] = std::array::from_fn(|i| i); 71 | axes.rotate_right(1); 72 | 73 | // perform FFT on last axis, then permute and repeat for remaining axes 74 | for _ in 0..N - 1 { 75 | output = output.permuted_axes(axes); 76 | 77 | // reshape 78 | buffer = Array::from_shape_vec(output.raw_dim(), buffer.into_raw_vec_and_offset().0) 79 | .unwrap(); 80 | buffer.zip_mut_with(&output, |transpose, &origin| { 81 | *transpose = origin; 82 | }); 83 | 84 | // contiguous 85 | output = Array::from_shape_vec(output.raw_dim(), output.into_raw_vec_and_offset().0) 86 | .unwrap(); 87 | 88 | let fft = self 89 | .cp 90 | .plan_fft(output.shape()[N - 1], rustfft::FftDirection::Forward); 91 | let mut scratch = 92 | vec![Complex::new(T::zero(), T::zero()); fft.get_outofplace_scratch_len()]; 93 | 94 | fft.process_outofplace_with_scratch( 95 | buffer.as_slice_mut().unwrap(), 96 | output.as_slice_mut().unwrap(), 97 | &mut scratch, 98 | ); 99 | } 100 | 101 | output 102 | } 103 | 104 | /// Performs an inverse FFT on the given input array. 105 | /// 106 | /// This performs inverse complex-to-complex FFTs on the axes other 107 | /// than the last, then finishes with a complex-to-real inverse on the 108 | /// last axis (turning complex frequency data back into real samples). 109 | /// Like `forward`, scratch buffers are local and reused when possible. 110 | fn backward( 111 | &mut self, 112 | input: &mut Array, Dim<[Ix; N]>>, 113 | ) -> Array> 114 | where 115 | Dim<[Ix; N]>: RemoveAxis, 116 | [Ix; N]: IntoDimension>, 117 | { 118 | // Reverse the forward flow: perform inverse complex FFTs on the last 119 | // axis (for each remaining axis), permuting and copying into a 120 | // temporary buffer to maintain contiguous layout, then finally run 121 | // the complex->real inverse on the last axis. 
122 | let raw_dim: [usize; N] = std::array::from_fn(|i| input.raw_dim()[i]); 123 | 124 | // one temporary buffer used per iteration; allocated to raw_dim and 125 | // re-shaped as necessary to reuse its allocation. 126 | let buffer = Array::uninit(raw_dim); 127 | let mut buffer = unsafe { buffer.assume_init() }; 128 | 129 | // axes permutation helper: rotate left to undo the right rotations 130 | // performed by forward. 131 | let mut axes: [usize; N] = std::array::from_fn(|i| i); 132 | axes.rotate_left(1); 133 | 134 | // work on a mutable view of the input so we can copy into it 135 | let mut input = input.view_mut(); 136 | 137 | for _ in 0..N - 1 { 138 | let fft = self.cp.plan_fft_inverse(buffer.shape()[N - 1]); 139 | let mut scratch = 140 | vec![Complex::new(T::zero(), T::zero()); fft.get_outofplace_scratch_len()]; 141 | 142 | // contiguous 143 | buffer = Array::from_shape_vec(buffer.raw_dim(), buffer.into_raw_vec_and_offset().0) 144 | .unwrap(); 145 | 146 | // out-of-place inverse FFT from `input` into `buffer` 147 | fft.process_outofplace_with_scratch( 148 | input.as_slice_mut().unwrap(), 149 | buffer.as_slice_mut().unwrap(), 150 | &mut scratch, 151 | ); 152 | 153 | // permute `buffer` so the next axis becomes the last, then copy 154 | // its contents back into `input` (which is arranged to be 155 | // contiguous for the next iteration). 
156 | buffer = buffer.permuted_axes(axes); 157 | input = unsafe { ArrayViewMut::from_shape_ptr(buffer.raw_dim(), input.as_mut_ptr()) }; 158 | input.zip_mut_with(&buffer, |dst, &src| *dst = src); 159 | } 160 | 161 | // now inverse real FFT on the last axis 162 | let rp = self.rp.plan_fft_inverse(self.rp_origin_len); 163 | 164 | let mut output_shape = input.raw_dim(); 165 | output_shape[N - 1] = self.rp_origin_len; 166 | let mut output = Array::zeros(output_shape); 167 | 168 | let mut scratch = vec![Complex::new(T::zero(), T::zero()); rp.get_scratch_len()]; 169 | for (mut input_row, mut output_row) in input.rows_mut().into_iter().zip(output.rows_mut()) { 170 | // no need to check result 171 | // large input sizes may cause slight numerical issues 172 | let _ = rp.process_with_scratch( 173 | input_row.as_slice_mut().unwrap(), 174 | output_row.as_slice_mut().unwrap(), 175 | &mut scratch, 176 | ); 177 | } 178 | 179 | let len = T::from_usize(output.len()).unwrap(); 180 | output.map_mut(|x| *x = x.div(len)); 181 | output 182 | } 183 | } 184 | 185 | #[cfg(test)] 186 | mod tests { 187 | use super::*; 188 | use ndarray::array; 189 | 190 | // ===== 1D Roundtrip Tests ===== 191 | 192 | mod roundtrip_1d { 193 | use super::*; 194 | 195 | #[test] 196 | fn basic() { 197 | let original = array![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]; 198 | let mut input = original.clone(); 199 | 200 | let mut p = Processor { 201 | rp: realfft::RealFftPlanner::new(), 202 | rp_origin_len: 0, 203 | cp: rustfft::FftPlanner::new(), 204 | }; 205 | 206 | let mut freq = p.forward(&mut input); 207 | let reconstructed = p.backward(&mut freq); 208 | 209 | for (orig, recon) in original.iter().zip(reconstructed.iter()) { 210 | assert!( 211 | (orig - recon).abs() < 1e-10, 212 | "1D Forward->Backward failed. 
Original: {}, Reconstructed: {}", 213 | orig, 214 | recon 215 | ); 216 | } 217 | } 218 | 219 | #[test] 220 | fn different_sizes() { 221 | // Test various 1D sizes to catch edge cases 222 | let test_cases = vec![ 223 | array![1.0f64, 2.0], 224 | array![1.0, 2.0, 3.0], 225 | array![1.0, 2.0, 3.0, 4.0, 5.0], 226 | ]; 227 | 228 | for (i, original) in test_cases.into_iter().enumerate() { 229 | let mut input = original.clone(); 230 | let mut p = Processor { 231 | rp: realfft::RealFftPlanner::new(), 232 | rp_origin_len: 0, 233 | cp: rustfft::FftPlanner::new(), 234 | }; 235 | 236 | let mut freq = p.forward(&mut input); 237 | let reconstructed = p.backward(&mut freq); 238 | 239 | for (orig, recon) in original.iter().zip(reconstructed.iter()) { 240 | assert!( 241 | (orig - recon).abs() < 1e-10, 242 | "1D Test case {} failed. Original: {}, Reconstructed: {}", 243 | i, 244 | orig, 245 | recon 246 | ); 247 | } 248 | } 249 | } 250 | } 251 | 252 | // ===== 2D Roundtrip Tests ===== 253 | 254 | mod roundtrip_2d { 255 | use super::*; 256 | 257 | #[test] 258 | fn basic() { 259 | let original = array![[1.0f64, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]; 260 | let mut input = original.clone(); 261 | 262 | let mut p = Processor { 263 | rp: realfft::RealFftPlanner::new(), 264 | rp_origin_len: 0, 265 | cp: rustfft::FftPlanner::new(), 266 | }; 267 | 268 | let mut freq = p.forward(&mut input); 269 | let reconstructed = p.backward(&mut freq); 270 | 271 | for (orig, recon) in original.iter().zip(reconstructed.iter()) { 272 | assert!( 273 | (orig - recon).abs() < 1e-10, 274 | "2D Forward->Backward failed. 
Original: {}, Reconstructed: {}", 275 | orig, 276 | recon 277 | ); 278 | } 279 | } 280 | 281 | #[test] 282 | fn different_sizes() { 283 | // Test 2x2 284 | let original = array![[1.0f64, 2.0], [3.0, 4.0]]; 285 | let mut input = original.clone(); 286 | let mut p = Processor { 287 | rp: realfft::RealFftPlanner::new(), 288 | rp_origin_len: 0, 289 | cp: rustfft::FftPlanner::new(), 290 | }; 291 | 292 | let mut freq = p.forward(&mut input); 293 | let reconstructed = p.backward(&mut freq); 294 | 295 | for (orig, recon) in original.iter().zip(reconstructed.iter()) { 296 | assert!( 297 | (orig - recon).abs() < 1e-10, 298 | "2D (2x2) test failed. Original: {}, Reconstructed: {}", 299 | orig, 300 | recon 301 | ); 302 | } 303 | 304 | // Test 3x3 305 | let original = array![[1.0f64, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]; 306 | let mut input = original.clone(); 307 | let mut p = Processor { 308 | rp: realfft::RealFftPlanner::new(), 309 | rp_origin_len: 0, 310 | cp: rustfft::FftPlanner::new(), 311 | }; 312 | 313 | let mut freq = p.forward(&mut input); 314 | let reconstructed = p.backward(&mut freq); 315 | 316 | for (orig, recon) in original.iter().zip(reconstructed.iter()) { 317 | assert!( 318 | (orig - recon).abs() < 1e-10, 319 | "2D (3x3) test failed. 
Original: {}, Reconstructed: {}", 320 | orig, 321 | recon 322 | ); 323 | } 324 | } 325 | 326 | #[test] 327 | fn large_array() { 328 | use ndarray_rand::rand_distr::Uniform; 329 | use ndarray_rand::RandomExt; 330 | 331 | // Test large array that might trigger edge cases 332 | let original = Array::random((200, 5000), Uniform::new(0f32, 1f32).unwrap()); 333 | let mut input = original.clone(); 334 | 335 | let mut p = Processor { 336 | rp: realfft::RealFftPlanner::new(), 337 | rp_origin_len: 0, 338 | cp: rustfft::FftPlanner::new(), 339 | }; 340 | 341 | let mut freq = p.forward(&mut input); 342 | let reconstructed = p.backward(&mut freq); 343 | 344 | // Check a sample of values 345 | let sample_indices = vec![(0, 0), (0, 100), (100, 0), (100, 2500), (199, 4999)]; 346 | for &(i, j) in &sample_indices { 347 | let orig = original[[i, j]]; 348 | let recon = reconstructed[[i, j]]; 349 | assert!( 350 | (orig - recon).abs() < 1e-6, 351 | "Large 2D test failed at ({}, {}). Original: {}, Reconstructed: {}, Diff: {}", 352 | i, 353 | j, 354 | orig, 355 | recon, 356 | (orig - recon).abs() 357 | ); 358 | } 359 | 360 | // Check overall statistics 361 | let max_diff = original 362 | .iter() 363 | .zip(reconstructed.iter()) 364 | .map(|(o, r)| (o - r).abs()) 365 | .fold(0.0f32, |acc, x| acc.max(x)); 366 | 367 | assert!( 368 | max_diff < 1e-6, 369 | "Maximum reconstruction error {} exceeds tolerance", 370 | max_diff 371 | ); 372 | } 373 | } 374 | 375 | // ===== 3D Roundtrip Tests ===== 376 | 377 | mod roundtrip_3d { 378 | use super::*; 379 | 380 | #[test] 381 | fn basic() { 382 | let original = array![ 383 | [[1.0f64, 2.0, 3.0], [4.0, 5.0, 6.0]], 384 | [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], 385 | ]; 386 | let mut input = original.clone(); 387 | 388 | let mut p = Processor { 389 | rp: realfft::RealFftPlanner::new(), 390 | rp_origin_len: 0, 391 | cp: rustfft::FftPlanner::new(), 392 | }; 393 | 394 | let mut a_fft = p.forward(&mut input); 395 | let reconstructed = p.backward(&mut a_fft); 
396 | 397 | for (orig, recon) in original.iter().zip(reconstructed.iter()) { 398 | assert!( 399 | (orig - recon).abs() < 1e-10, 400 | "3D Forward->Backward failed. Original: {}, Reconstructed: {}, Diff: {}", 401 | orig, 402 | recon, 403 | (orig - recon).abs() 404 | ); 405 | } 406 | } 407 | } 408 | 409 | // ===== Low-Level FFT API Tests ===== 410 | 411 | mod fft_api { 412 | use super::*; 413 | use rustfft::num_complex::Complex; 414 | 415 | #[test] 416 | fn manual_complex_fft_roundtrip() { 417 | // Test using rustfft directly to verify FFT planner usage 418 | let mut arr = array![[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4],] 419 | .map(|&v| Complex::new(v as f32, 0.0)); 420 | let mut fft = rustfft::FftPlanner::new(); 421 | 422 | // Forward FFT 423 | let row_forward = fft.plan_fft_forward(arr.shape()[1]); 424 | for mut row in arr.rows_mut() { 425 | row_forward.process(row.as_slice_mut().unwrap()); 426 | } 427 | 428 | // Transpose 429 | let mut arr = Array::from_shape_vec( 430 | [arr.shape()[1], arr.shape()[0]], 431 | arr.permuted_axes([1, 0]).iter().copied().collect(), 432 | ) 433 | .unwrap(); 434 | 435 | let row_forward = fft.plan_fft_forward(arr.shape()[1]); 436 | for mut row in arr.rows_mut() { 437 | row_forward.process(row.as_slice_mut().unwrap()); 438 | } 439 | 440 | arr /= Complex::new(16.0, 0.0); 441 | 442 | // Backward FFT 443 | let row_backward = fft.plan_fft_inverse(arr.shape()[1]); 444 | for mut row in arr.rows_mut() { 445 | row_backward.process(row.as_slice_mut().unwrap()); 446 | } 447 | 448 | // Transpose back 449 | let mut arr = Array::from_shape_vec( 450 | [arr.shape()[1], arr.shape()[0]], 451 | arr.permuted_axes([1, 0]).iter().copied().collect(), 452 | ) 453 | .unwrap(); 454 | 455 | let row_backward = fft.plan_fft_inverse(arr.shape()[1]); 456 | for mut row in arr.rows_mut() { 457 | row_backward.process(row.as_slice_mut().unwrap()); 458 | } 459 | 460 | // Verify reconstruction (should be close to original [[1,2,3,4], ...]) 461 | for val in 
arr.iter() { 462 | let expected_re = val.re.round(); 463 | assert!( 464 | (val.re - expected_re).abs() < 1e-5, 465 | "FFT roundtrip failed. Got {}, expected approximately {}", 466 | val.re, 467 | expected_re 468 | ); 469 | assert!( 470 | val.im.abs() < 1e-5, 471 | "Imaginary part should be near zero, got {}", 472 | val.im 473 | ); 474 | } 475 | } 476 | } 477 | } 478 | -------------------------------------------------------------------------------- /src/conv/tests.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | use crate::{dilation::WithDilation, ReverseKernel}; 3 | use ndarray::prelude::*; 4 | use num::traits::FromPrimitive; 5 | 6 | // ===== Helper Functions ===== 7 | 8 | fn assert_eq_tch(res: Array>, res_tch: tch::Tensor) 9 | where 10 | T: PartialEq + FromPrimitive + std::fmt::Debug, 11 | Dim<[usize; N]>: Dimension, 12 | { 13 | let tch_res = Array::from_iter(res_tch.reshape(res.len() as i64).iter::().unwrap()) 14 | .to_shape(res.shape()) 15 | .unwrap() 16 | .map(|v| T::from_f64(v.round()).unwrap()) 17 | .into_dimensionality::>() 18 | .unwrap(); 19 | 20 | assert_eq!( 21 | res, tch_res, 22 | "Conv result doesn't match LibTorch.\nGot: {:?}\nExpected: {:?}", 23 | res, tch_res 24 | ); 25 | } 26 | 27 | fn get_tch_shape(shape: &[usize]) -> Vec { 28 | std::iter::repeat_n(1, 2) 29 | .chain(shape.iter().map(|v| *v as i64)) 30 | .collect::>() 31 | } 32 | 33 | // ===== Verification Against LibTorch ===== 34 | // These tests establish Conv as the trusted reference implementation 35 | 36 | mod vs_torch { 37 | use super::*; 38 | 39 | // ----- Full Mode ----- 40 | 41 | mod full_mode { 42 | use super::*; 43 | 44 | #[test] 45 | fn test_1d() { 46 | let arr = array![1, 2, 3, 4, 5]; 47 | let kernel = array![1, 2, 1]; 48 | 49 | let res = arr 50 | .conv(&kernel, ConvMode::Full, PaddingMode::Zeros) 51 | .unwrap(); 52 | 53 | let tensor = tch::Tensor::from_slice(arr.as_slice().unwrap()) 54 | .to_dtype(tch::Kind::Float, false, true) 
55 | .reshape(get_tch_shape(arr.shape())); 56 | let kernel_tensor = tch::Tensor::from_slice(kernel.as_slice().unwrap()) 57 | .to_dtype(tch::Kind::Float, false, true) 58 | .reshape(get_tch_shape(kernel.shape())); 59 | 60 | // Full mode: padding = kernel_size - 1 = 3 - 1 = 2 61 | let res_tch = tensor 62 | .f_conv1d::(&kernel_tensor, None, 1, 2, 1, 1) 63 | .unwrap(); 64 | 65 | assert_eq_tch(res, res_tch); 66 | } 67 | 68 | #[test] 69 | fn test_2d() { 70 | let arr = array![[1, 2], [3, 4]]; 71 | let kernel = array![[1, 1], [1, 1]]; 72 | 73 | let res = arr 74 | .conv(&kernel, ConvMode::Full, PaddingMode::Zeros) 75 | .unwrap(); 76 | 77 | let tensor = tch::Tensor::from_slice(arr.as_slice().unwrap()) 78 | .to_dtype(tch::Kind::Float, false, true) 79 | .reshape(get_tch_shape(arr.shape())); 80 | let kernel_tensor = tch::Tensor::from_slice(kernel.as_slice().unwrap()) 81 | .to_dtype(tch::Kind::Float, false, true) 82 | .reshape(get_tch_shape(kernel.shape())); 83 | 84 | // Full mode: padding = kernel_size - 1 = [2 - 1, 2 - 1] = [1, 1] 85 | let res_tch = tensor 86 | .f_conv2d::(&kernel_tensor, None, 1, 1, 1, 1) 87 | .unwrap(); 88 | 89 | assert_eq_tch(res, res_tch); 90 | } 91 | 92 | #[test] 93 | fn test_3d() { 94 | let arr = array![[[1, 2]], [[3, 4]]]; 95 | let kernel = array![[[1, 1]], [[1, 1]]]; 96 | 97 | let res = arr 98 | .conv(&kernel, ConvMode::Full, PaddingMode::Zeros) 99 | .unwrap(); 100 | 101 | let tensor = tch::Tensor::from_slice(arr.as_slice().unwrap()) 102 | .to_dtype(tch::Kind::Float, false, true) 103 | .reshape(get_tch_shape(arr.shape())); 104 | let kernel_tensor = tch::Tensor::from_slice(kernel.as_slice().unwrap()) 105 | .to_dtype(tch::Kind::Float, false, true) 106 | .reshape(get_tch_shape(kernel.shape())); 107 | 108 | // Full mode: padding = kernel_size - 1 = [2-1, 1-1, 2-1] = [1, 0, 1] 109 | let res_tch = tensor 110 | .f_conv3d::(&kernel_tensor, None, 1, [1, 0, 1], 1, 1) 111 | .unwrap(); 112 | 113 | assert_eq_tch(res, res_tch); 114 | } 115 | } 116 | 117 | // ----- 
Same Mode ----- 118 | 119 | mod same_mode { 120 | use super::*; 121 | 122 | #[test] 123 | fn test_1d() { 124 | let arr = array![1, 2, 3, 4, 5]; 125 | let kernel = array![1, 2, 1]; 126 | 127 | let res = arr 128 | .conv(&kernel, ConvMode::Same, PaddingMode::Zeros) 129 | .unwrap(); 130 | 131 | let tensor = tch::Tensor::from_slice(arr.as_slice().unwrap()) 132 | .to_dtype(tch::Kind::Float, false, true) 133 | .reshape(get_tch_shape(arr.shape())); 134 | let kernel_tensor = tch::Tensor::from_slice(kernel.as_slice().unwrap()) 135 | .to_dtype(tch::Kind::Float, false, true) 136 | .reshape(get_tch_shape(kernel.shape())); 137 | 138 | let res_tch = tensor 139 | .f_conv1d_padding::(&kernel_tensor, None, 1, "same", 1, 1) 140 | .unwrap(); 141 | 142 | assert_eq_tch(res, res_tch); 143 | } 144 | 145 | #[test] 146 | fn test_2d() { 147 | let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; 148 | let kernel = array![[1, 0, -1], [2, 0, -2], [1, 0, -1]]; 149 | 150 | // Default behavior: kernel is reversed 151 | let res = arr 152 | .conv(&kernel, ConvMode::Same, PaddingMode::Zeros) 153 | .unwrap(); 154 | 155 | let tensor = tch::Tensor::from_slice(arr.as_slice().unwrap()) 156 | .to_dtype(tch::Kind::Float, false, true) 157 | .reshape(get_tch_shape(arr.shape())); 158 | 159 | // PyTorch doesn't reverse kernel, so we need to reverse it first 160 | let kernel_reversed = kernel 161 | .as_slice() 162 | .unwrap() 163 | .iter() 164 | .copied() 165 | .rev() 166 | .collect::>(); 167 | let kernel_tensor = tch::Tensor::from_slice(&kernel_reversed) 168 | .to_dtype(tch::Kind::Float, false, true) 169 | .reshape(get_tch_shape(kernel.shape())); 170 | 171 | let res_tch = tensor 172 | .f_conv2d_padding::(&kernel_tensor, None, 1, "same", 1, 1) 173 | .unwrap(); 174 | 175 | assert_eq_tch(res, res_tch); 176 | } 177 | 178 | #[test] 179 | fn test_3d() { 180 | // Use a simpler 3D array with symmetric kernel 181 | let arr = array![[[1, 2, 3]], [[4, 5, 6]]]; 182 | let kernel = array![[[1, 1, 1]]]; 183 | 184 | let res = 
arr 185 | .conv(&kernel, ConvMode::Same, PaddingMode::Zeros) 186 | .unwrap(); 187 | 188 | let tensor = tch::Tensor::from_slice(arr.as_slice().unwrap()) 189 | .to_dtype(tch::Kind::Float, false, true) 190 | .reshape(get_tch_shape(arr.shape())); 191 | 192 | let kernel_tensor = tch::Tensor::from_slice(kernel.as_slice().unwrap()) 193 | .to_dtype(tch::Kind::Float, false, true) 194 | .reshape(get_tch_shape(kernel.shape())); 195 | 196 | // Same mode with kernel [1, 1, 3]: padding = [0, 0, 1] 197 | let res_tch = tensor 198 | .f_conv3d_padding::(&kernel_tensor, None, 1, "same", 1, 1) 199 | .unwrap(); 200 | 201 | assert_eq_tch(res, res_tch); 202 | } 203 | } 204 | 205 | // ----- Valid Mode ----- 206 | 207 | mod valid_mode { 208 | use super::*; 209 | 210 | #[test] 211 | fn test_1d() { 212 | let arr = array![1, 2, 3, 4, 5]; 213 | let kernel = array![1, 2, 1]; 214 | 215 | let res = arr 216 | .conv(&kernel, ConvMode::Valid, PaddingMode::Zeros) 217 | .unwrap(); 218 | 219 | let tensor = tch::Tensor::from_slice(arr.as_slice().unwrap()) 220 | .to_dtype(tch::Kind::Float, false, true) 221 | .reshape(get_tch_shape(arr.shape())); 222 | let kernel_tensor = tch::Tensor::from_slice(kernel.as_slice().unwrap()) 223 | .to_dtype(tch::Kind::Float, false, true) 224 | .reshape(get_tch_shape(kernel.shape())); 225 | 226 | let res_tch = tensor 227 | .f_conv1d_padding::(&kernel_tensor, None, 1, "valid", 1, 1) 228 | .unwrap(); 229 | 230 | assert_eq_tch(res, res_tch); 231 | } 232 | 233 | #[test] 234 | fn test_2d() { 235 | let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; 236 | let kernel = array![[1, 1], [1, 1]]; 237 | 238 | let res = arr 239 | .conv(&kernel, ConvMode::Valid, PaddingMode::Zeros) 240 | .unwrap(); 241 | 242 | let tensor = tch::Tensor::from_slice(arr.as_slice().unwrap()) 243 | .to_dtype(tch::Kind::Float, false, true) 244 | .reshape(get_tch_shape(arr.shape())); 245 | let kernel_tensor = tch::Tensor::from_slice(kernel.as_slice().unwrap()) 246 | .to_dtype(tch::Kind::Float, false, true) 247 | 
.reshape(get_tch_shape(kernel.shape())); 248 | 249 | let res_tch = tensor 250 | .f_conv2d_padding::(&kernel_tensor, None, 1, "valid", 1, 1) 251 | .unwrap(); 252 | 253 | assert_eq_tch(res, res_tch); 254 | } 255 | 256 | #[test] 257 | fn test_3d() { 258 | let arr = array![[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]; 259 | let kernel = array![[[1, 1]], [[1, 1]]]; 260 | 261 | let res = arr 262 | .conv(&kernel, ConvMode::Valid, PaddingMode::Zeros) 263 | .unwrap(); 264 | 265 | let tensor = tch::Tensor::from_slice(arr.as_slice().unwrap()) 266 | .to_dtype(tch::Kind::Float, false, true) 267 | .reshape(get_tch_shape(arr.shape())); 268 | let kernel_tensor = tch::Tensor::from_slice(kernel.as_slice().unwrap()) 269 | .to_dtype(tch::Kind::Float, false, true) 270 | .reshape(get_tch_shape(kernel.shape())); 271 | 272 | let res_tch = tensor 273 | .f_conv3d_padding::(&kernel_tensor, None, 1, "valid", 1, 1) 274 | .unwrap(); 275 | 276 | assert_eq_tch(res, res_tch); 277 | } 278 | } 279 | 280 | // ----- With Stride ----- 281 | 282 | mod with_stride { 283 | use super::*; 284 | 285 | #[test] 286 | fn stride_2_1d() { 287 | let arr = array![1, 2, 3, 4, 5, 6]; 288 | let kernel = array![1, 1, 1]; 289 | 290 | let res = arr 291 | .conv( 292 | &kernel, 293 | ConvMode::Custom { 294 | padding: [1], 295 | strides: [2], 296 | }, 297 | PaddingMode::Zeros, 298 | ) 299 | .unwrap(); 300 | 301 | let tensor = tch::Tensor::from_slice(arr.as_slice().unwrap()) 302 | .to_dtype(tch::Kind::Float, false, true) 303 | .reshape(get_tch_shape(arr.shape())); 304 | let kernel_tensor = tch::Tensor::from_slice(kernel.as_slice().unwrap()) 305 | .to_dtype(tch::Kind::Float, false, true) 306 | .reshape(get_tch_shape(kernel.shape())); 307 | 308 | let res_tch = tensor 309 | .f_conv1d::(&kernel_tensor, None, 2, 1, 1, 1) 310 | .unwrap(); 311 | 312 | assert_eq_tch(res, res_tch); 313 | } 314 | 315 | #[test] 316 | fn stride_2_2d() { 317 | let arr = array![[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]; 318 | let kernel = 
array![[1, 1], [1, 1]]; 319 | 320 | let res = arr 321 | .conv( 322 | &kernel, 323 | ConvMode::Custom { 324 | padding: [1, 1], 325 | strides: [2, 2], 326 | }, 327 | PaddingMode::Zeros, 328 | ) 329 | .unwrap(); 330 | 331 | let tensor = tch::Tensor::from_slice(arr.as_slice().unwrap()) 332 | .to_dtype(tch::Kind::Float, false, true) 333 | .reshape(get_tch_shape(arr.shape())); 334 | let kernel_tensor = tch::Tensor::from_slice(kernel.as_slice().unwrap()) 335 | .to_dtype(tch::Kind::Float, false, true) 336 | .reshape(get_tch_shape(kernel.shape())); 337 | 338 | let res_tch = tensor 339 | .f_conv2d::(&kernel_tensor, None, [2, 2], 1, 1, 1) 340 | .unwrap(); 341 | 342 | assert_eq_tch(res, res_tch); 343 | } 344 | 345 | #[test] 346 | fn stride_3_1d() { 347 | let arr = array![1, 2, 3, 4, 5, 6, 7, 8, 9]; 348 | let kernel = array![1, 2, 1]; 349 | 350 | let res = arr 351 | .conv( 352 | &kernel, 353 | ConvMode::Custom { 354 | padding: [2], 355 | strides: [3], 356 | }, 357 | PaddingMode::Zeros, 358 | ) 359 | .unwrap(); 360 | 361 | let tensor = tch::Tensor::from_slice(arr.as_slice().unwrap()) 362 | .to_dtype(tch::Kind::Float, false, true) 363 | .reshape(get_tch_shape(arr.shape())); 364 | let kernel_tensor = tch::Tensor::from_slice(kernel.as_slice().unwrap()) 365 | .to_dtype(tch::Kind::Float, false, true) 366 | .reshape(get_tch_shape(kernel.shape())); 367 | 368 | let res_tch = tensor 369 | .f_conv1d::(&kernel_tensor, None, 3, 2, 1, 1) 370 | .unwrap(); 371 | 372 | assert_eq_tch(res, res_tch); 373 | } 374 | } 375 | 376 | // ----- With Dilation ----- 377 | 378 | mod with_dilation { 379 | use super::*; 380 | 381 | #[test] 382 | fn dilation_2_1d() { 383 | let arr = array![1, 2, 3, 4, 5, 6]; 384 | let kernel = array![1, 1, 2]; 385 | 386 | let res = arr 387 | .conv( 388 | kernel.with_dilation(2).no_reverse(), 389 | ConvMode::Custom { 390 | padding: [4], 391 | strides: [2], 392 | }, 393 | PaddingMode::Zeros, 394 | ) 395 | .unwrap(); 396 | 397 | let tensor = 
tch::Tensor::from_slice(arr.as_slice().unwrap()) 398 | .to_dtype(tch::Kind::Float, false, true) 399 | .reshape(get_tch_shape(arr.shape())); 400 | let kernel_tensor = tch::Tensor::from_slice(kernel.as_slice().unwrap()) 401 | .to_dtype(tch::Kind::Float, false, true) 402 | .reshape(get_tch_shape(kernel.shape())); 403 | 404 | let res_tch = tensor 405 | .f_conv1d::(&kernel_tensor, None, 2, 4, 2, 1) 406 | .unwrap(); 407 | 408 | assert_eq_tch(res, res_tch); 409 | } 410 | 411 | #[test] 412 | fn dilation_2_2d() { 413 | let arr = array![[1, 1, 1], [1, 1, 1], [1, 1, 2]]; 414 | let kernel = array![[2, 1, 1], [1, 1, 1]]; 415 | 416 | let res = arr 417 | .conv( 418 | kernel.with_dilation(2).no_reverse(), 419 | ConvMode::Same, 420 | PaddingMode::Zeros, 421 | ) 422 | .unwrap(); 423 | 424 | let tensor = tch::Tensor::from_slice(arr.as_slice().unwrap()) 425 | .to_dtype(tch::Kind::Float, false, true) 426 | .reshape(get_tch_shape(arr.shape())); 427 | let kernel_tensor = tch::Tensor::from_slice(kernel.as_slice().unwrap()) 428 | .to_dtype(tch::Kind::Float, false, true) 429 | .reshape(get_tch_shape(kernel.shape())); 430 | 431 | let res_tch = tensor 432 | .f_conv2d_padding::(&kernel_tensor, None, 1, "same", 2, 1) 433 | .unwrap(); 434 | 435 | assert_eq_tch(res, res_tch); 436 | } 437 | 438 | #[test] 439 | fn dilation_2_3d() { 440 | let arr = array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]]; 441 | let kernel = array![ 442 | [[1, 1, 1], [1, 1, 1], [1, 1, 1]], 443 | [[1, 1, 1], [1, 1, 1], [1, 1, 1]], 444 | ]; 445 | 446 | let res = arr 447 | .conv( 448 | kernel.with_dilation(2).no_reverse(), 449 | ConvMode::Custom { 450 | padding: [2, 2, 2], 451 | strides: [1, 2, 1], 452 | }, 453 | PaddingMode::Zeros, 454 | ) 455 | .unwrap(); 456 | 457 | let tensor = tch::Tensor::from_slice(arr.as_slice().unwrap()) 458 | .to_dtype(tch::Kind::Float, false, true) 459 | .reshape(get_tch_shape(arr.shape())); 460 | let kernel_tensor = tch::Tensor::from_slice(kernel.as_slice().unwrap()) 461 | .to_dtype(tch::Kind::Float, 
false, true) 462 | .reshape(get_tch_shape(kernel.shape())); 463 | 464 | let res_tch = tensor 465 | .f_conv3d::(&kernel_tensor, None, [1, 2, 1], 2, 2, 1) 466 | .unwrap(); 467 | 468 | assert_eq_tch(res, res_tch); 469 | } 470 | } 471 | 472 | // ----- Kernel Reversal ----- 473 | 474 | mod kernel_reverse { 475 | use super::*; 476 | 477 | #[test] 478 | fn with_reverse() { 479 | let arr = array![1, 2, 3, 4, 5, 6]; 480 | let kernel = array![1, 1, 2]; 481 | 482 | let res = arr 483 | .conv( 484 | kernel.with_dilation(2), 485 | ConvMode::Custom { 486 | padding: [4], 487 | strides: [2], 488 | }, 489 | PaddingMode::Zeros, 490 | ) 491 | .unwrap(); 492 | 493 | let tensor = tch::Tensor::from_slice(arr.as_slice().unwrap()) 494 | .to_dtype(tch::Kind::Float, false, true) 495 | .reshape(get_tch_shape(arr.shape())); 496 | 497 | // Reverse kernel for torch (it expects non-reversed) 498 | let kernel_tensor = tch::Tensor::from_slice( 499 | &kernel 500 | .as_slice() 501 | .unwrap() 502 | .iter() 503 | .copied() 504 | .rev() 505 | .collect::>(), 506 | ) 507 | .to_dtype(tch::Kind::Float, false, true) 508 | .reshape(get_tch_shape(kernel.shape())); 509 | 510 | let res_tch = tensor 511 | .f_conv1d::(&kernel_tensor, None, 2, 4, 2, 1) 512 | .unwrap(); 513 | 514 | assert_eq_tch(res, res_tch); 515 | } 516 | 517 | #[test] 518 | fn no_reverse() { 519 | let arr = array![1, 2, 3, 4, 5, 6]; 520 | let kernel = array![1, 1, 2]; 521 | 522 | let res = arr 523 | .conv( 524 | kernel.with_dilation(2).no_reverse(), 525 | ConvMode::Custom { 526 | padding: [4], 527 | strides: [2], 528 | }, 529 | PaddingMode::Zeros, 530 | ) 531 | .unwrap(); 532 | 533 | let tensor = tch::Tensor::from_slice(arr.as_slice().unwrap()) 534 | .to_dtype(tch::Kind::Float, false, true) 535 | .reshape(get_tch_shape(arr.shape())); 536 | let kernel_tensor = tch::Tensor::from_slice(kernel.as_slice().unwrap()) 537 | .to_dtype(tch::Kind::Float, false, true) 538 | .reshape(get_tch_shape(kernel.shape())); 539 | 540 | let res_tch = tensor 541 | 
.f_conv1d::(&kernel_tensor, None, 2, 4, 2, 1) 542 | .unwrap(); 543 | 544 | assert_eq_tch(res, res_tch); 545 | } 546 | } 547 | } 548 | 549 | // ===== Edge Cases ===== 550 | // Quick regression tests without external dependencies 551 | 552 | mod edge_cases { 553 | use super::*; 554 | 555 | #[test] 556 | fn single_element_array() { 557 | let arr = array![42]; 558 | let kernel = array![2]; 559 | let res = arr 560 | .conv(&kernel, ConvMode::Same, PaddingMode::Zeros) 561 | .unwrap(); 562 | assert_eq!(res, array![84]); 563 | } 564 | 565 | #[test] 566 | fn single_element_kernel() { 567 | let arr = array![1, 2, 3, 4]; 568 | let kernel = array![3]; 569 | let res = arr 570 | .conv(&kernel, ConvMode::Same, PaddingMode::Zeros) 571 | .unwrap(); 572 | assert_eq!(res, array![3, 6, 9, 12]); 573 | } 574 | 575 | #[test] 576 | fn identity_kernel() { 577 | let arr = array![[1, 2], [3, 4]]; 578 | let kernel = array![[1]]; 579 | let res = arr 580 | .conv(&kernel, ConvMode::Same, PaddingMode::Zeros) 581 | .unwrap(); 582 | assert_eq!(res, arr); 583 | } 584 | } 585 | -------------------------------------------------------------------------------- /src/conv_fft/tests.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | use crate::{dilation::WithDilation, ConvExt, ReverseKernel}; 3 | use ndarray::prelude::*; 4 | use rustfft::num_complex::Complex; 5 | 6 | // ===== Verification Against Conv ===== 7 | // Conv FFT results should match Conv (the trusted baseline) 8 | 9 | mod vs_conv { 10 | use num::complex::ComplexFloat; 11 | 12 | use super::*; 13 | 14 | // Tolerance constants 15 | const TOLERANCE_F32: f32 = 1e-5; 16 | const TOLERANCE_F64: f64 = 1e-9; 17 | 18 | /// Compare f32 FFT results with Conv baseline 19 | fn assert_fft_matches_conv_f32( 20 | fft: Array>, 21 | conv: Array>, 22 | ) where 23 | Dim<[usize; N]>: Dimension, 24 | { 25 | assert_eq!( 26 | fft.shape(), 27 | conv.shape(), 28 | "Shape mismatch: FFT {:?} vs Conv {:?}", 29 | 
fft.shape(), 30 | conv.shape() 31 | ); 32 | 33 | fft.iter() 34 | .zip(conv.iter()) 35 | .enumerate() 36 | .for_each(|(idx, (fft_val, conv_val))| { 37 | let diff = (fft_val.round() - *conv_val as f32).abs(); 38 | assert!( 39 | diff < TOLERANCE_F32, 40 | "Mismatch at index {}: FFT={:.6}, Conv={}, diff={:.6}", 41 | idx, 42 | fft_val, 43 | conv_val, 44 | diff 45 | ); 46 | }); 47 | } 48 | 49 | /// Compare f64 FFT results with Conv baseline 50 | fn assert_fft_matches_conv_f64( 51 | fft: Array>, 52 | conv: Array>, 53 | ) where 54 | Dim<[usize; N]>: Dimension, 55 | { 56 | assert_eq!( 57 | fft.shape(), 58 | conv.shape(), 59 | "Shape mismatch: FFT {:?} vs Conv {:?}", 60 | fft.shape(), 61 | conv.shape() 62 | ); 63 | 64 | fft.iter() 65 | .zip(conv.iter()) 66 | .enumerate() 67 | .for_each(|(idx, (fft_val, conv_val))| { 68 | let diff = (fft_val.round() - *conv_val as f64).abs(); 69 | assert!( 70 | diff < TOLERANCE_F64, 71 | "Mismatch at index {}: FFT={:.10}, Conv={}, diff={:.10}", 72 | idx, 73 | fft_val, 74 | conv_val, 75 | diff 76 | ); 77 | }); 78 | } 79 | 80 | /// Compare Complex FFT results with another Complex result 81 | fn assert_fft_matches_conv_complex( 82 | fft: Array, Dim<[usize; N]>>, 83 | conv: Array, Dim<[usize; N]>>, 84 | ) where 85 | Dim<[usize; N]>: Dimension, 86 | { 87 | assert_eq!( 88 | fft.shape(), 89 | conv.shape(), 90 | "Shape mismatch: FFT {:?} vs Conv {:?}", 91 | fft.shape(), 92 | conv.shape() 93 | ); 94 | 95 | fft.iter() 96 | .zip(conv.iter()) 97 | .enumerate() 98 | .for_each(|(idx, (fft_val, conv_val))| { 99 | let diff = (fft_val - conv_val).abs(); 100 | assert!( 101 | diff < TOLERANCE_F32, 102 | "Mismatch at index {}: FFT={:.6}+{:.6}i, Conv={:.6}+{:.6}i, diff={:.6}", 103 | idx, 104 | fft_val.re, 105 | fft_val.im, 106 | conv_val.re, 107 | conv_val.im, 108 | diff 109 | ); 110 | }); 111 | } 112 | 113 | /// Compare Complex FFT results with another Complex result 114 | fn assert_fft_matches_conv_complex_f64( 115 | fft: Array, Dim<[usize; N]>>, 116 | conv: 
Array, Dim<[usize; N]>>, 117 | ) where 118 | Dim<[usize; N]>: Dimension, 119 | { 120 | assert_eq!( 121 | fft.shape(), 122 | conv.shape(), 123 | "Shape mismatch: FFT {:?} vs Conv {:?}", 124 | fft.shape(), 125 | conv.shape() 126 | ); 127 | 128 | fft.iter() 129 | .zip(conv.iter()) 130 | .enumerate() 131 | .for_each(|(idx, (fft_val, conv_val))| { 132 | let diff = (fft_val - conv_val).abs(); 133 | assert!( 134 | diff < TOLERANCE_F64, 135 | "Mismatch at index {}: FFT={:.10}+{:.10}i, Conv={:.10}+{:.10}i, diff={:.10}", 136 | idx, 137 | fft_val.re, 138 | fft_val.im, 139 | conv_val.re, 140 | conv_val.im, 141 | diff 142 | ); 143 | }); 144 | } 145 | 146 | // ----- 1D Tests ----- 147 | 148 | mod one_d { 149 | use super::*; 150 | 151 | #[test] 152 | fn same_mode_f32() { 153 | let arr = array![1, 2, 3, 4, 5, 6]; 154 | let kernel = array![1, 1, 1, 1]; 155 | 156 | let conv_result = arr 157 | .conv(kernel.with_dilation(2), ConvMode::Same, PaddingMode::Zeros) 158 | .unwrap(); 159 | 160 | let fft_result = arr 161 | .map(|&x| x as f32) 162 | .conv_fft( 163 | kernel.map(|&x| x as f32).with_dilation(2), 164 | ConvMode::Same, 165 | PaddingMode::Zeros, 166 | ) 167 | .unwrap(); 168 | 169 | assert_fft_matches_conv_f32(fft_result, conv_result); 170 | } 171 | 172 | #[test] 173 | fn same_mode_complex() { 174 | // Test with actual complex numbers (non-zero imaginary parts) 175 | let arr_complex = array![ 176 | Complex::new(1.0, 0.5), 177 | Complex::new(2.0, -0.3), 178 | Complex::new(3.0, 0.8), 179 | Complex::new(4.0, -0.2), 180 | Complex::new(5.0, 0.6), 181 | Complex::new(6.0, -0.4), 182 | ]; 183 | let kernel_complex = array![ 184 | Complex::new(1.0, 0.1), 185 | Complex::new(1.0, -0.1), 186 | Complex::new(1.0, 0.2), 187 | Complex::new(1.0, -0.2), 188 | ]; 189 | 190 | let conv_result = arr_complex 191 | .conv_fft( 192 | kernel_complex.with_dilation(2), 193 | ConvMode::Same, 194 | PaddingMode::Zeros, 195 | ) 196 | .unwrap(); 197 | 198 | let fft_result = arr_complex 199 | .conv_fft( 200 | 
kernel_complex.with_dilation(2), 201 | ConvMode::Same, 202 | PaddingMode::Zeros, 203 | ) 204 | .unwrap(); 205 | 206 | assert_fft_matches_conv_complex(fft_result, conv_result); 207 | } 208 | 209 | #[test] 210 | fn circular_padding() { 211 | let arr: Array1 = array![ 212 | 0.0, 0.1, 0.3, 0.4, 0.0, 0.1, 0.3, 0.4, 0.0, 0.1, 0.3, 0.4, 0.0, 0.1, 0.3, 0.4 213 | ]; 214 | let kernel: Array1 = array![0.1, 0.3, 0.6, 0.3, 0.1]; 215 | 216 | let conv_result = arr 217 | .conv(&kernel, ConvMode::Same, PaddingMode::Circular) 218 | .unwrap(); 219 | 220 | let fft_result = arr 221 | .conv_fft(&kernel, ConvMode::Same, PaddingMode::Circular) 222 | .unwrap(); 223 | 224 | conv_result 225 | .iter() 226 | .zip(fft_result.iter()) 227 | .enumerate() 228 | .for_each(|(idx, (conv_val, fft_val))| { 229 | assert!( 230 | (conv_val - fft_val).abs() < 1e-6, 231 | "Mismatch at index {}: Conv={:.6}, FFT={:.6}", 232 | idx, 233 | conv_val, 234 | fft_val 235 | ); 236 | }); 237 | } 238 | 239 | #[test] 240 | fn full_mode() { 241 | let arr = array![1, 2, 3, 4, 5]; 242 | let kernel = array![1, 2, 1]; 243 | 244 | let conv_result = arr 245 | .conv(&kernel, ConvMode::Full, PaddingMode::Zeros) 246 | .unwrap(); 247 | 248 | let fft_result = arr 249 | .map(|&x| x as f64) 250 | .conv_fft( 251 | &kernel.map(|&x| x as f64), 252 | ConvMode::Full, 253 | PaddingMode::Zeros, 254 | ) 255 | .unwrap(); 256 | 257 | assert_fft_matches_conv_f64(fft_result, conv_result); 258 | } 259 | 260 | #[test] 261 | fn valid_mode() { 262 | let arr = array![1, 2, 3, 4, 5, 6]; 263 | let kernel = array![1, 1, 1]; 264 | 265 | let conv_result = arr 266 | .conv(&kernel, ConvMode::Valid, PaddingMode::Zeros) 267 | .unwrap(); 268 | 269 | let fft_result = arr 270 | .map(|&x| x as f32) 271 | .conv_fft( 272 | &kernel.map(|&x| x as f32), 273 | ConvMode::Valid, 274 | PaddingMode::Zeros, 275 | ) 276 | .unwrap(); 277 | 278 | assert_fft_matches_conv_f32(fft_result, conv_result); 279 | } 280 | } 281 | 282 | // ----- 2D Tests ----- 283 | 284 | mod two_d { 285 
| use super::*; 286 | 287 | #[test] 288 | fn same_mode_f32() { 289 | let arr = array![[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]; 290 | let kernel = array![[1, 0], [3, 1]]; 291 | 292 | let conv_result = arr 293 | .conv(&kernel, ConvMode::Same, PaddingMode::Replicate) 294 | .unwrap(); 295 | 296 | let fft_result = arr 297 | .map(|&x| x as f64) 298 | .conv_fft( 299 | &kernel.map(|&x| x as f64), 300 | ConvMode::Same, 301 | PaddingMode::Replicate, 302 | ) 303 | .unwrap(); 304 | 305 | assert_fft_matches_conv_f64(fft_result, conv_result); 306 | } 307 | 308 | #[test] 309 | fn custom_mode_with_dilation() { 310 | let arr = array![[1, 2], [3, 4]]; 311 | let kernel = array![[1, 0], [3, 1]]; 312 | 313 | let conv_result = arr 314 | .conv( 315 | kernel.with_dilation(2).no_reverse(), 316 | ConvMode::Custom { 317 | padding: [3, 3], 318 | strides: [2, 2], 319 | }, 320 | PaddingMode::Replicate, 321 | ) 322 | .unwrap(); 323 | 324 | let fft_result_f64 = arr 325 | .map(|&x| x as f64) 326 | .conv_fft( 327 | kernel.map(|&x| x as f64).with_dilation(2).no_reverse(), 328 | ConvMode::Custom { 329 | padding: [3, 3], 330 | strides: [2, 2], 331 | }, 332 | PaddingMode::Replicate, 333 | ) 334 | .unwrap(); 335 | 336 | assert_fft_matches_conv_f64(fft_result_f64, conv_result); 337 | } 338 | 339 | #[test] 340 | fn custom_mode_complex() { 341 | // Test with actual complex numbers (non-zero imaginary parts) 342 | let arr_complex = array![ 343 | [Complex::new(1.0, 0.2), Complex::new(2.0, -0.3)], 344 | [Complex::new(3.0, 0.5), Complex::new(4.0, -0.1)] 345 | ]; 346 | let kernel_complex = array![ 347 | [Complex::new(1.0, 0.1), Complex::new(0.0, 0.2)], 348 | [Complex::new(3.0, -0.2), Complex::new(1.0, 0.15)] 349 | ]; 350 | // Baseline via direct convolution (bug fix: was `.conv_fft`, so the test compared FFT output to itself) 351 | let conv_result = arr_complex 352 | .conv( 353 | kernel_complex.with_dilation(2).no_reverse(), 354 | ConvMode::Custom { 355 | padding: [3, 3], 356 | strides: [2, 2], 357 | }, 358 | PaddingMode::Replicate, 359 | ) 360 | .unwrap(); 361 | 362 | let fft_result = arr_complex
363 | .conv_fft( 364 | kernel_complex.with_dilation(2).no_reverse(), 365 | ConvMode::Custom { 366 | padding: [3, 3], 367 | strides: [2, 2], 368 | }, 369 | PaddingMode::Replicate, 370 | ) 371 | .unwrap(); 372 | 373 | assert_fft_matches_conv_complex_f64(fft_result, conv_result); 374 | } 375 | 376 | #[test] 377 | fn full_mode() { 378 | let arr = array![[1, 2], [3, 4]]; 379 | let kernel = array![[1, 1], [1, 1]]; 380 | 381 | let conv_result = arr 382 | .conv(&kernel, ConvMode::Full, PaddingMode::Zeros) 383 | .unwrap(); 384 | 385 | let fft_result = arr 386 | .map(|&x| x as f32) 387 | .conv_fft( 388 | &kernel.map(|&x| x as f32), 389 | ConvMode::Full, 390 | PaddingMode::Zeros, 391 | ) 392 | .unwrap(); 393 | 394 | assert_fft_matches_conv_f32(fft_result, conv_result); 395 | } 396 | 397 | #[test] 398 | fn valid_mode() { 399 | let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; 400 | let kernel = array![[1, 1], [1, 1]]; 401 | 402 | let conv_result = arr 403 | .conv(&kernel, ConvMode::Valid, PaddingMode::Zeros) 404 | .unwrap(); 405 | 406 | let fft_result = arr 407 | .map(|&x| x as f64) 408 | .conv_fft( 409 | &kernel.map(|&x| x as f64), 410 | ConvMode::Valid, 411 | PaddingMode::Zeros, 412 | ) 413 | .unwrap(); 414 | 415 | assert_fft_matches_conv_f64(fft_result, conv_result); 416 | } 417 | } 418 | 419 | // ----- 3D Tests ----- 420 | 421 | mod three_d { 422 | use super::*; 423 | 424 | #[test] 425 | fn same_mode_f32() { 426 | let arr = array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]]; 427 | let kernel = array![ 428 | [[1, 1, 1], [1, 1, 1], [1, 1, 1]], 429 | [[1, 1, 1], [1, 1, 1], [1, 1, 1]], 430 | ]; 431 | 432 | let conv_result = arr 433 | .conv(&kernel, ConvMode::Same, PaddingMode::Zeros) 434 | .unwrap(); 435 | 436 | let fft_result = arr 437 | .map(|&x| x as f32) 438 | .conv_fft( 439 | &kernel.map(|&x| x as f32), 440 | ConvMode::Same, 441 | PaddingMode::Zeros, 442 | ) 443 | .unwrap(); 444 | 445 | assert_fft_matches_conv_f32(fft_result, conv_result); 446 | } 447 | 448 | #[test] 449 | fn 
same_mode_complex() { 450 | // Test with actual complex numbers (non-zero imaginary parts) 451 | let arr_complex = array![ 452 | [ 453 | [Complex::new(1.0, 0.3), Complex::new(2.0, -0.2)], 454 | [Complex::new(3.0, 0.5), Complex::new(4.0, -0.4)] 455 | ], 456 | [ 457 | [Complex::new(5.0, 0.1), Complex::new(6.0, -0.3)], 458 | [Complex::new(7.0, 0.6), Complex::new(8.0, -0.1)] 459 | ] 460 | ]; 461 | let kernel_complex = array![ 462 | [ 463 | [ 464 | Complex::new(1.0, 0.05), 465 | Complex::new(1.0, -0.05), 466 | Complex::new(1.0, 0.1) 467 | ], 468 | [ 469 | Complex::new(1.0, -0.1), 470 | Complex::new(1.0, 0.15), 471 | Complex::new(1.0, -0.15) 472 | ], 473 | [ 474 | Complex::new(1.0, 0.2), 475 | Complex::new(1.0, -0.2), 476 | Complex::new(1.0, 0.05) 477 | ] 478 | ], 479 | [ 480 | [ 481 | Complex::new(1.0, -0.05), 482 | Complex::new(1.0, 0.1), 483 | Complex::new(1.0, -0.1) 484 | ], 485 | [ 486 | Complex::new(1.0, 0.15), 487 | Complex::new(1.0, -0.15), 488 | Complex::new(1.0, 0.2) 489 | ], 490 | [ 491 | Complex::new(1.0, -0.2), 492 | Complex::new(1.0, 0.05), 493 | Complex::new(1.0, -0.05) 494 | ] 495 | ], 496 | ]; 497 | // Baseline via direct convolution (bug fix: was `.conv_fft`, so the test compared FFT output to itself) 498 | let conv_result = arr_complex 499 | .conv(&kernel_complex, ConvMode::Same, PaddingMode::Zeros) 500 | .unwrap(); 501 | 502 | let fft_result = arr_complex 503 | .conv_fft(&kernel_complex, ConvMode::Same, PaddingMode::Zeros) 504 | .unwrap(); 505 | // f64 helper: values reach ~1e2 here, so the f32 1e-5 absolute tolerance would be borderline against FFT round-off 506 | assert_fft_matches_conv_complex_f64(fft_result, conv_result); 507 | } 508 | 509 | #[test] 510 | fn full_mode() { 511 | let arr = array![[[1, 2]], [[3, 4]]]; 512 | let kernel = array![[[1, 1]], [[1, 1]]]; 513 | 514 | let conv_result = arr 515 | .conv(&kernel, ConvMode::Full, PaddingMode::Zeros) 516 | .unwrap(); 517 | 518 | let fft_result = arr 519 | .map(|&x| x as f64) 520 | .conv_fft( 521 | &kernel.map(|&x| x as f64), 522 | ConvMode::Full, 523 | PaddingMode::Zeros, 524 | ) 525 | .unwrap(); 526 | 527 | assert_fft_matches_conv_f64(fft_result, conv_result); 528 | } 529 | 530 | #[test] 531 | fn valid_mode() {
532 | let arr = array![[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]; 533 | let kernel = array![[[1, 1]], [[1, 1]]]; 534 | 535 | let conv_result = arr 536 | .conv(&kernel, ConvMode::Valid, PaddingMode::Zeros) 537 | .unwrap(); 538 | 539 | let fft_result = arr 540 | .map(|&x| x as f32) 541 | .conv_fft( 542 | &kernel.map(|&x| x as f32), 543 | ConvMode::Valid, 544 | PaddingMode::Zeros, 545 | ) 546 | .unwrap(); 547 | 548 | assert_fft_matches_conv_f32(fft_result, conv_result); 549 | } 550 | } 551 | 552 | // ----- Different Padding Modes ----- 553 | 554 | mod padding_modes { 555 | use super::*; 556 | 557 | #[test] 558 | fn replicate_2d() { 559 | let arr = array![[1, 2, 3], [4, 5, 6]]; 560 | let kernel = array![[1, 1], [1, 1]]; 561 | 562 | let conv_result = arr 563 | .conv(&kernel, ConvMode::Same, PaddingMode::Replicate) 564 | .unwrap(); 565 | 566 | let fft_result = arr 567 | .map(|&x| x as f32) 568 | .conv_fft( 569 | &kernel.map(|&x| x as f32), 570 | ConvMode::Same, 571 | PaddingMode::Replicate, 572 | ) 573 | .unwrap(); 574 | 575 | assert_fft_matches_conv_f32(fft_result, conv_result); 576 | } 577 | 578 | #[test] 579 | fn zeros_2d() { 580 | let arr = array![[1, 2, 3], [4, 5, 6]]; 581 | let kernel = array![[1, 1], [1, 1]]; 582 | 583 | let conv_result = arr 584 | .conv(&kernel, ConvMode::Same, PaddingMode::Zeros) 585 | .unwrap(); 586 | 587 | let fft_result = arr 588 | .map(|&x| x as f64) 589 | .conv_fft( 590 | &kernel.map(|&x| x as f64), 591 | ConvMode::Same, 592 | PaddingMode::Zeros, 593 | ) 594 | .unwrap(); 595 | 596 | assert_fft_matches_conv_f64(fft_result, conv_result); 597 | } 598 | 599 | #[test] 600 | fn const_padding_2d() { 601 | let arr = array![[1, 2], [3, 4]]; 602 | let kernel = array![[1, 1], [1, 1]]; 603 | 604 | let conv_result = arr 605 | .conv(&kernel, ConvMode::Full, PaddingMode::Const(7)) 606 | .unwrap(); 607 | 608 | let fft_result = arr 609 | .map(|&x| x as f32) 610 | .conv_fft( 611 | &kernel.map(|&x| x as f32), 612 | ConvMode::Full, 613 | 
PaddingMode::Const(7.0), 614 | ) 615 | .unwrap(); 616 | 617 | assert_fft_matches_conv_f32(fft_result, conv_result); 618 | } 619 | } 620 | } 621 | 622 | // ===== Edge Cases ===== 623 | 624 | mod edge_cases { 625 | use super::*; 626 | 627 | #[test] 628 | fn single_element() { 629 | let arr = array![42]; 630 | let kernel = array![2]; 631 | 632 | let conv_result = arr 633 | .conv(&kernel, ConvMode::Same, PaddingMode::Zeros) 634 | .unwrap(); 635 | 636 | let fft_result = arr 637 | .map(|&x| x as f32) 638 | .conv_fft( 639 | &kernel.map(|&x| x as f32), 640 | ConvMode::Same, 641 | PaddingMode::Zeros, 642 | ) 643 | .unwrap(); 644 | 645 | assert_eq!(fft_result.map(|x| x.round() as i32), conv_result); 646 | } 647 | 648 | #[test] 649 | fn large_array_2d() { 650 | // Test with a larger array to ensure FFT is actually used 651 | let arr = Array::from_shape_fn((50, 50), |(i, j)| ((i + j) % 10) as i32); 652 | let kernel = array![[1, 2, 1], [2, 4, 2], [1, 2, 1]]; 653 | 654 | let conv_result = arr 655 | .conv(&kernel, ConvMode::Same, PaddingMode::Zeros) 656 | .unwrap(); 657 | 658 | let fft_result = arr 659 | .map(|&x| x as f64) 660 | .conv_fft( 661 | &kernel.map(|&x| x as f64), 662 | ConvMode::Same, 663 | PaddingMode::Zeros, 664 | ) 665 | .unwrap(); 666 | 667 | // Check a sample of points 668 | for i in 0..5 { 669 | for j in 0..5 { 670 | let diff = (fft_result[[i, j]].round() - conv_result[[i, j]] as f64).abs(); 671 | assert!( 672 | diff < 1e-8, 673 | "Mismatch at [{}, {}]: FFT={:.6}, Conv={}", 674 | i, 675 | j, 676 | fft_result[[i, j]], 677 | conv_result[[i, j]] 678 | ); 679 | } 680 | } 681 | } 682 | } 683 | -------------------------------------------------------------------------------- /src/padding/mod.rs: -------------------------------------------------------------------------------- 1 | //! Provides padding functionality for ndarray arrays. 2 | //! 3 | //! This module defines the `PaddingExt` trait, which extends the `ArrayBase` 4 | //! 
struct from the `ndarray` crate with methods for padding arrays using 5 | //! different padding modes. It also provides helper functions for 6 | //! applying specific types of padding. 7 | 8 | use super::{BorderType, PaddingMode}; 9 | 10 | use ndarray::{ 11 | Array, ArrayBase, Data, DataMut, Dim, IntoDimension, Ix, RemoveAxis, SliceArg, SliceInfo, 12 | SliceInfoElem, 13 | }; 14 | use num::traits::NumAssign; 15 | 16 | pub(crate) mod dim; 17 | mod half_dim; 18 | 19 | /// Represents explicit padding sizes for each dimension. 20 | pub type ExplicitPadding = [[usize; 2]; N]; 21 | 22 | /// Extends `ndarray`'s `ArrayBase` with padding operations. 23 | /// 24 | /// This trait provides the `padding` and `padding_in` methods for adding 25 | /// padding to an array using various modes, like zero padding, constant 26 | /// padding, replication, reflection, and circular padding. 27 | /// 28 | /// # Type Parameters 29 | /// 30 | /// * `N`: The number of dimensions of the array. 31 | /// * `T`: The numeric type of the array elements. 32 | /// * `Output`: The type of the padded array returned by `padding`, typically an `Array>`. 33 | pub trait PaddingExt { 34 | /// Returns a new array with the specified padding applied. 35 | /// 36 | /// This method creates a new array with the dimensions and padding specified by 37 | /// `mode` and `padding_size`. It calls the `padding_in` method internally to handle the padding itself. 38 | /// 39 | /// # Arguments 40 | /// 41 | /// * `mode`: The padding mode (`Zeros`, `Const`, `Reflect`, `Replicate`, `Circular`, `Custom`, `Explicit`). 42 | /// * `padding_size`: An array representing the padding sizes for each dimension in the form `[[front, back]; N]`. 43 | /// 44 | /// # Returns 45 | /// A new `Array` with the padded data. 46 | fn padding(&self, mode: PaddingMode, padding_size: ExplicitPadding) -> Output; 47 | 48 | /// Modifies the buffer in-place by applying padding using the specified mode. 
49 | /// 50 | /// This method directly modifies the provided buffer by adding padding to its content. 51 | /// 52 | /// # Type Parameters 53 | /// 54 | /// * `SO`: The data storage type of the output buffer. 55 | /// * `DO`: The dimension type of the output buffer. 56 | /// 57 | /// # Arguments 58 | /// 59 | /// * `buffer`: A mutable reference to the array to be padded. 60 | /// * `mode`: The padding mode (`Zeros`, `Const`, `Reflect`, `Replicate`, `Circular`, `Custom`, `Explicit`). 61 | /// * `padding_size`: An array representing the padding sizes for each dimension in the form `[[front, back]; N]`. 62 | fn padding_in, DO: RemoveAxis>( 63 | &self, 64 | buffer: &mut ArrayBase, 65 | mode: PaddingMode, 66 | padding_size: ExplicitPadding, 67 | ) where 68 | T: NumAssign + Copy, 69 | [Ix; N]: IntoDimension>, 70 | SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg>, 71 | Dim<[Ix; N]>: RemoveAxis, 72 | SliceInfo<[SliceInfoElem; N], DO, DO>: SliceArg; 73 | } 74 | 75 | impl PaddingExt>> for ArrayBase 76 | where 77 | T: NumAssign + Copy, 78 | S: Data, 79 | [Ix; N]: IntoDimension>, 80 | SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg>, 81 | Dim<[Ix; N]>: RemoveAxis, 82 | D: RemoveAxis + IntoDimension, 83 | { 84 | fn padding( 85 | &self, 86 | mode: PaddingMode, 87 | explicit_padding: ExplicitPadding, 88 | ) -> Array> { 89 | let c = match mode { 90 | PaddingMode::Const(c) => c, 91 | _ => T::zero(), 92 | }; 93 | 94 | let raw_dim = self.raw_dim(); 95 | 96 | let output_dim = 97 | std::array::from_fn(|i| raw_dim[i] + explicit_padding[i][0] + explicit_padding[i][1]); 98 | 99 | let mut output: Array> = Array::from_elem(output_dim, c); 100 | 101 | padding_const(self, &mut output, explicit_padding); 102 | 103 | match mode { 104 | PaddingMode::Replicate => padding_replicate(self, &mut output, explicit_padding), 105 | PaddingMode::Reflect => padding_reflect(self, &mut output, explicit_padding), 106 | PaddingMode::Circular => padding_circular(self, &mut 
output, explicit_padding), 107 | PaddingMode::Custom(borders) => { 108 | padding_custom(self, &mut output, explicit_padding, borders) 109 | } 110 | PaddingMode::Explicit(borders) => { 111 | padding_explicit(self, &mut output, explicit_padding, borders) 112 | } 113 | _ => {} 114 | }; 115 | 116 | output 117 | } 118 | 119 | fn padding_in( 120 | &self, 121 | buffer: &mut ArrayBase, 122 | mode: PaddingMode, 123 | explicit_padding: ExplicitPadding, 124 | ) where 125 | T: NumAssign + Copy, 126 | S: Data, 127 | SO: DataMut, 128 | [Ix; N]: IntoDimension>, 129 | SliceInfo<[SliceInfoElem; N], DO, DO>: SliceArg, 130 | Dim<[Ix; N]>: RemoveAxis, 131 | DO: RemoveAxis, 132 | { 133 | padding_const(self, buffer, explicit_padding); 134 | 135 | match mode { 136 | PaddingMode::Const(c) => { 137 | explicit_padding 138 | .iter() 139 | .enumerate() 140 | .for_each(|(dim, &explicit_padding)| { 141 | dim::constant(self.raw_dim(), buffer, dim, explicit_padding, c); 142 | }) 143 | } 144 | PaddingMode::Replicate => padding_replicate(self, buffer, explicit_padding), 145 | PaddingMode::Reflect => padding_reflect(self, buffer, explicit_padding), 146 | PaddingMode::Circular => padding_circular(self, buffer, explicit_padding), 147 | PaddingMode::Custom(borders) => padding_custom(self, buffer, explicit_padding, borders), 148 | PaddingMode::Explicit(borders) => { 149 | padding_explicit(self, buffer, explicit_padding, borders) 150 | } 151 | _ => {} 152 | }; 153 | } 154 | } 155 | 156 | /// Applies padding using a constant value to the specified slice of the output array. 157 | /// 158 | /// This function copies the input array to a specific slice of the output array, leaving the rest of the 159 | /// output array with the default padding value, which is typically a zero or a constant, depending on the padding mode. 160 | /// 161 | /// # Type Parameters 162 | /// 163 | /// * `N`: The number of dimensions of the array. 164 | /// * `T`: The numeric type of the array elements. 
165 | /// * `S`: The data storage type of the input array. 166 | /// * `D`: The dimension type of the input array. 167 | /// * `SO`: The data storage type of the output array. 168 | /// * `DO`: The dimension type of the output array. 169 | /// 170 | /// # Arguments 171 | /// 172 | /// * `input`: The input array to pad. 173 | /// * `output`: A mutable reference to the array where the padded result will be stored. 174 | /// * `explicit_padding`: An array representing the padding sizes for each dimension in the form `[[front, back]; N]`. 175 | pub(crate) fn padding_const( 176 | input: &ArrayBase, 177 | output: &mut ArrayBase, 178 | explicit_padding: ExplicitPadding, 179 | ) where 180 | T: NumAssign + Copy, 181 | S: Data, 182 | SO: DataMut, 183 | [Ix; N]: IntoDimension>, 184 | // SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg>, 185 | SliceInfo<[SliceInfoElem; N], DO, DO>: SliceArg, 186 | Dim<[Ix; N]>: RemoveAxis, 187 | D: RemoveAxis, 188 | DO: RemoveAxis, 189 | { 190 | let mut output_slice = output.slice_mut(unsafe { 191 | SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice { 192 | start: explicit_padding[i][0] as isize, 193 | end: Some((explicit_padding[i][0] + input.raw_dim()[i]) as isize), 194 | step: 1, 195 | })) 196 | .unwrap() 197 | }); 198 | 199 | output_slice.assign(input); 200 | } 201 | 202 | /// Applies replicate padding to the specified slice of the output array. 203 | /// 204 | /// This function uses the `dim::replicate` function to add replicate padding 205 | /// to each dimension of the output array. 206 | /// 207 | /// # Type Parameters 208 | /// 209 | /// * `N`: The number of dimensions of the array. 210 | /// * `T`: The numeric type of the array elements. 211 | /// * `S`: The data storage type of the input array. 212 | /// * `D`: The dimension type of the input array. 213 | /// * `SO`: The data storage type of the output array. 214 | /// * `DO`: The dimension type of the output array. 
215 | /// 216 | /// # Arguments 217 | /// 218 | /// * `input`: The input array to pad. 219 | /// * `output`: A mutable reference to the array where the padded result will be stored. 220 | /// * `explicit_padding`: An array representing the padding sizes for each dimension in the form `[[front, back]; N]`. 221 | fn padding_replicate( 222 | input: &ArrayBase, 223 | output: &mut ArrayBase, 224 | explicit_padding: ExplicitPadding, 225 | ) where 226 | T: NumAssign + Copy, 227 | S: Data, 228 | SO: DataMut, 229 | [Ix; N]: IntoDimension>, 230 | SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg>, 231 | Dim<[Ix; N]>: RemoveAxis, 232 | D: RemoveAxis + IntoDimension, 233 | DO: RemoveAxis, 234 | { 235 | explicit_padding 236 | .iter() 237 | .enumerate() 238 | .for_each(|(dim, &explicit_padding)| { 239 | dim::replicate(input.raw_dim(), output, dim, explicit_padding); 240 | }); 241 | } 242 | 243 | /// Applies reflect padding to the specified slice of the output array. 244 | /// 245 | /// This function uses the `dim::reflect` function to add reflect padding 246 | /// to each dimension of the output array. 247 | /// 248 | /// # Type Parameters 249 | /// 250 | /// * `N`: The number of dimensions of the array. 251 | /// * `T`: The numeric type of the array elements. 252 | /// * `S`: The data storage type of the input array. 253 | /// * `D`: The dimension type of the input array. 254 | /// * `SO`: The data storage type of the output array. 255 | /// * `DO`: The dimension type of the output array. 256 | /// 257 | /// # Arguments 258 | /// 259 | /// * `input`: The input array to pad. 260 | /// * `output`: A mutable reference to the array where the padded result will be stored. 261 | /// * `explicit_padding`: An array representing the padding sizes for each dimension in the form `[[front, back]; N]`. 
262 | fn padding_reflect( 263 | input: &ArrayBase, 264 | output: &mut ArrayBase, 265 | explicit_padding: ExplicitPadding, 266 | ) where 267 | T: NumAssign + Copy, 268 | S: Data, 269 | SO: DataMut, 270 | [Ix; N]: IntoDimension>, 271 | SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg>, 272 | Dim<[Ix; N]>: RemoveAxis, 273 | D: RemoveAxis, 274 | DO: RemoveAxis, 275 | { 276 | explicit_padding 277 | .iter() 278 | .enumerate() 279 | .for_each(|(dim, &explicit_padding)| { 280 | dim::reflect(input.raw_dim(), output, dim, explicit_padding); 281 | }); 282 | } 283 | 284 | /// Applies circular padding to the specified slice of the output array. 285 | /// 286 | /// This function uses the `dim::circular` function to add circular padding 287 | /// to each dimension of the output array. 288 | /// 289 | /// # Type Parameters 290 | /// 291 | /// * `N`: The number of dimensions of the array. 292 | /// * `T`: The numeric type of the array elements. 293 | /// * `S`: The data storage type of the input array. 294 | /// * `D`: The dimension type of the input array. 295 | /// * `SO`: The data storage type of the output array. 296 | /// * `DO`: The dimension type of the output array. 297 | /// 298 | /// # Arguments 299 | /// 300 | /// * `input`: The input array to pad. 301 | /// * `output`: A mutable reference to the array where the padded result will be stored. 302 | /// * `explicit_padding`: An array representing the padding sizes for each dimension in the form `[[front, back]; N]`. 
303 | fn padding_circular( 304 | input: &ArrayBase, 305 | output: &mut ArrayBase, 306 | explicit_padding: ExplicitPadding, 307 | ) where 308 | T: NumAssign + Copy, 309 | S: Data, 310 | SO: DataMut, 311 | [Ix; N]: IntoDimension>, 312 | SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg>, 313 | Dim<[Ix; N]>: RemoveAxis, 314 | D: RemoveAxis, 315 | DO: RemoveAxis, 316 | { 317 | explicit_padding 318 | .iter() 319 | .enumerate() 320 | .for_each(|(dim, &explicit_padding)| { 321 | dim::circular(input.raw_dim(), output, dim, explicit_padding); 322 | }); 323 | } 324 | 325 | /// Applies custom padding to the specified slice of the output array using `BorderType` for each dimension. 326 | /// 327 | /// This function uses the `dim::constant`, `dim::reflect`, `dim::replicate`, 328 | /// or `dim::circular` function based on the corresponding `BorderType` specified in the `borders` argument, 329 | /// to add padding to each dimension of the output array. 330 | /// 331 | /// # Type Parameters 332 | /// 333 | /// * `N`: The number of dimensions of the array. 334 | /// * `T`: The numeric type of the array elements. 335 | /// * `S`: The data storage type of the input array. 336 | /// * `D`: The dimension type of the input array. 337 | /// * `SO`: The data storage type of the output array. 338 | /// * `DO`: The dimension type of the output array. 339 | /// 340 | /// # Arguments 341 | /// 342 | /// * `input`: The input array to pad. 343 | /// * `output`: A mutable reference to the array where the padded result will be stored. 344 | /// * `explicit_padding`: An array representing the padding sizes for each dimension in the form `[[front, back]; N]`. 345 | /// * `borders`: An array containing a `BorderType` enum for each dimension. 
346 | fn padding_custom( 347 | input: &ArrayBase, 348 | output: &mut ArrayBase, 349 | explicit_padding: ExplicitPadding, 350 | borders: [BorderType; N], 351 | ) where 352 | T: NumAssign + Copy, 353 | S: Data, 354 | SO: DataMut, 355 | [Ix; N]: IntoDimension>, 356 | SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg>, 357 | Dim<[Ix; N]>: RemoveAxis, 358 | D: RemoveAxis, 359 | DO: RemoveAxis, 360 | { 361 | explicit_padding 362 | .iter() 363 | .zip(borders.iter()) 364 | .enumerate() 365 | .for_each(|(dim, (&explicit_padding, border))| match border { 366 | BorderType::Zeros => { 367 | dim::constant(input.raw_dim(), output, dim, explicit_padding, T::zero()) 368 | } 369 | BorderType::Const(c) => { 370 | dim::constant(input.raw_dim(), output, dim, explicit_padding, *c) 371 | } 372 | BorderType::Reflect => dim::reflect(input.raw_dim(), output, dim, explicit_padding), 373 | BorderType::Replicate => dim::replicate(input.raw_dim(), output, dim, explicit_padding), 374 | BorderType::Circular => dim::circular(input.raw_dim(), output, dim, explicit_padding), 375 | }); 376 | } 377 | 378 | /// Applies explicit padding to the specified slice of the output array using `BorderType` for each side of each dimension. 379 | /// 380 | /// This function uses the `half_dim::constant_front`, `half_dim::constant_back`, 381 | /// `half_dim::reflect_front`, `half_dim::reflect_back`, `half_dim::replicate_front`, 382 | /// `half_dim::replicate_back`, `half_dim::circular_front`, and `half_dim::circular_back` 383 | /// functions based on the corresponding `BorderType` specified in the `borders` argument, 384 | /// to add padding to each dimension of the output array. 385 | /// 386 | /// # Type Parameters 387 | /// 388 | /// * `N`: The number of dimensions of the array. 389 | /// * `T`: The numeric type of the array elements. 390 | /// * `S`: The data storage type of the input array. 391 | /// * `D`: The dimension type of the input array. 
392 | /// * `SO`: The data storage type of the output array. 393 | /// * `DO`: The dimension type of the output array. 394 | /// 395 | /// # Arguments 396 | /// 397 | /// * `input`: The input array to pad. 398 | /// * `output`: A mutable reference to the array where the padded result will be stored. 399 | /// * `explicit_padding`: An array representing the padding sizes for each dimension in the form `[[front, back]; N]`. 400 | /// * `borders`: An array containing an array of two `BorderType` enums for each dimension. 401 | fn padding_explicit( 402 | input: &ArrayBase, 403 | output: &mut ArrayBase, 404 | explicit_padding: ExplicitPadding, 405 | borders: [[BorderType; 2]; N], 406 | ) where 407 | T: NumAssign + Copy, 408 | S: Data, 409 | SO: DataMut, 410 | [Ix; N]: IntoDimension>, 411 | SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg>, 412 | Dim<[Ix; N]>: RemoveAxis, 413 | D: RemoveAxis, 414 | DO: RemoveAxis, 415 | { 416 | explicit_padding 417 | .iter() 418 | .zip(borders.iter()) 419 | .enumerate() 420 | .for_each(|(dim, (&explicit_padding, border))| { 421 | match border[0] { 422 | BorderType::Zeros => { 423 | half_dim::constant_front(output, dim, explicit_padding, T::zero()) 424 | } 425 | BorderType::Const(c) => half_dim::constant_front(output, dim, explicit_padding, c), 426 | BorderType::Reflect => half_dim::reflect_front(output, dim, explicit_padding), 427 | BorderType::Replicate => half_dim::replicate_front(output, dim, explicit_padding), 428 | BorderType::Circular => half_dim::circular_front(output, dim, explicit_padding), 429 | } 430 | match border[1] { 431 | BorderType::Zeros => half_dim::constant_back( 432 | input.raw_dim(), 433 | output, 434 | dim, 435 | explicit_padding, 436 | T::zero(), 437 | ), 438 | BorderType::Const(c) => { 439 | half_dim::constant_back(input.raw_dim(), output, dim, explicit_padding, c) 440 | } 441 | BorderType::Reflect => { 442 | half_dim::reflect_back(input.raw_dim(), output, dim, explicit_padding) 443 | } 444 | 
BorderType::Replicate => { 445 | half_dim::replicate_back(input.raw_dim(), output, dim, explicit_padding) 446 | } 447 | BorderType::Circular => { 448 | half_dim::circular_back(input.raw_dim(), output, dim, explicit_padding) 449 | } 450 | } 451 | }); 452 | } 453 | 454 | #[cfg(test)] 455 | mod tests { 456 | use ndarray::prelude::*; 457 | 458 | use super::*; 459 | use crate::dilation::IntoKernelWithDilation; 460 | use crate::ConvMode; 461 | 462 | // ===== Basic Padding Tests ===== 463 | 464 | mod zeros_padding { 465 | use super::*; 466 | 467 | #[test] 468 | fn test_1d() { 469 | let arr = array![1, 2, 3]; 470 | let explicit_padding = [[1, 1]]; 471 | let padded = arr.padding(PaddingMode::Zeros, explicit_padding); 472 | assert_eq!(padded, array![0, 1, 2, 3, 0]); 473 | } 474 | 475 | #[test] 476 | fn test_2d() { 477 | let arr = array![[1, 2], [3, 4]]; 478 | let explicit_padding = [[1, 1], [1, 1]]; 479 | let padded = arr.padding(PaddingMode::Zeros, explicit_padding); 480 | assert_eq!( 481 | padded, 482 | array![[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]] 483 | ); 484 | } 485 | 486 | #[test] 487 | fn test_3d() { 488 | let arr = array![[[1, 2]], [[3, 4]]]; 489 | let explicit_padding = [[1, 0], [0, 1], [1, 0]]; 490 | let padded = arr.padding(PaddingMode::Zeros, explicit_padding); 491 | // Shape: [2, 1, 2] -> [3, 2, 3] 492 | // dim 0: padding [1, 0] => add 1 layer before 493 | // dim 1: padding [0, 1] => add 1 layer after 494 | // dim 2: padding [1, 0] => add 1 column before each row 495 | assert_eq!( 496 | padded, 497 | array![ 498 | [[0, 0, 0], [0, 0, 0]], // padded layer at front (dim 0) 499 | [[0, 1, 2], [0, 0, 0]], // original [[[1, 2]]] with padding 500 | [[0, 3, 4], [0, 0, 0]] // original [[[3, 4]]] with padding 501 | ] 502 | ); 503 | } 504 | 505 | #[test] 506 | fn test_asymmetric_padding() { 507 | let arr = array![1, 2, 3]; 508 | let explicit_padding = [[2, 1]]; 509 | let padded = arr.padding(PaddingMode::Zeros, explicit_padding); 510 | assert_eq!(padded, 
array![0, 0, 1, 2, 3, 0]); 511 | } 512 | } 513 | 514 | mod const_padding { 515 | use super::*; 516 | 517 | #[test] 518 | fn test_1d() { 519 | let arr = array![1, 2, 3]; 520 | let explicit_padding = [[1, 1]]; 521 | let padded = arr.padding(PaddingMode::Const(7), explicit_padding); 522 | assert_eq!(padded, array![7, 1, 2, 3, 7]); 523 | } 524 | 525 | #[test] 526 | fn test_2d() { 527 | let arr = array![[1, 2], [3, 4]]; 528 | let explicit_padding = [[1, 1], [1, 1]]; 529 | let padded = arr.padding(PaddingMode::Const(9), explicit_padding); 530 | assert_eq!( 531 | padded, 532 | array![[9, 9, 9, 9], [9, 1, 2, 9], [9, 3, 4, 9], [9, 9, 9, 9]] 533 | ); 534 | } 535 | } 536 | 537 | mod replicate_padding { 538 | use super::*; 539 | 540 | #[test] 541 | fn test_1d() { 542 | let arr = array![1, 2, 3]; 543 | let explicit_padding = [[1, 2]]; 544 | let padded = arr.padding(PaddingMode::Replicate, explicit_padding); 545 | assert_eq!(padded, array![1, 1, 2, 3, 3, 3]); 546 | } 547 | 548 | #[test] 549 | fn test_2d() { 550 | let arr = array![[1, 2], [3, 4]]; 551 | let explicit_padding = [[1, 1], [1, 1]]; 552 | let padded = arr.padding(PaddingMode::Replicate, explicit_padding); 553 | assert_eq!( 554 | padded, 555 | array![[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]] 556 | ); 557 | } 558 | 559 | #[test] 560 | fn test_large_padding() { 561 | let arr = array![1, 2]; 562 | let explicit_padding = [[3, 3]]; 563 | let padded = arr.padding(PaddingMode::Replicate, explicit_padding); 564 | assert_eq!(padded, array![1, 1, 1, 1, 2, 2, 2, 2]); 565 | } 566 | } 567 | 568 | mod reflect_padding { 569 | use super::*; 570 | 571 | #[test] 572 | fn test_1d() { 573 | let arr = array![1, 2, 3, 4]; 574 | let explicit_padding = [[2, 2]]; 575 | let padded = arr.padding(PaddingMode::Reflect, explicit_padding); 576 | assert_eq!(padded, array![3, 2, 1, 2, 3, 4, 3, 2]); 577 | } 578 | 579 | #[test] 580 | fn test_2d() { 581 | let arr = array![[1, 2, 3], [4, 5, 6]]; 582 | let explicit_padding = [[1, 1], [1, 1]]; 
583 | let padded = arr.padding(PaddingMode::Reflect, explicit_padding); 584 | assert_eq!( 585 | padded, 586 | array![ 587 | [5, 4, 5, 6, 5], 588 | [2, 1, 2, 3, 2], 589 | [5, 4, 5, 6, 5], 590 | [2, 1, 2, 3, 2] 591 | ] 592 | ); 593 | } 594 | } 595 | 596 | mod circular_padding { 597 | use super::*; 598 | 599 | #[test] 600 | fn test_1d() { 601 | let arr = array![1, 2, 3, 4]; 602 | let explicit_padding = [[2, 2]]; 603 | let padded = arr.padding(PaddingMode::Circular, explicit_padding); 604 | assert_eq!(padded, array![3, 4, 1, 2, 3, 4, 1, 2]); 605 | } 606 | 607 | #[test] 608 | fn test_2d() { 609 | let arr = array![[1, 2], [3, 4]]; 610 | let explicit_padding = [[1, 1], [1, 1]]; 611 | let padded = arr.padding(PaddingMode::Circular, explicit_padding); 612 | assert_eq!( 613 | padded, 614 | array![[4, 3, 4, 3], [2, 1, 2, 1], [4, 3, 4, 3], [2, 1, 2, 1]] 615 | ); 616 | } 617 | 618 | #[test] 619 | fn test_type_cast_safety() { 620 | // Regression test for issue with type casting in circular padding 621 | let arr = array![1u8, 2, 3]; 622 | let explicit_padding = [[1, 1]]; 623 | let padded = arr.padding(PaddingMode::Circular, explicit_padding); 624 | assert_eq!(padded, array![3u8, 1, 2, 3, 1]); 625 | } 626 | } 627 | 628 | mod custom_padding { 629 | use super::*; 630 | 631 | #[test] 632 | fn test_per_dimension() { 633 | let arr = array![[1, 2], [3, 4]]; 634 | let kernel = array![[1, 1, 1], [1, 1, 1], [1, 1, 1]]; 635 | let kernel = kernel.into_kernel_with_dilation(); 636 | 637 | let explicit_conv = ConvMode::Full.unfold(&kernel); 638 | let explicit_padding = explicit_conv.padding; 639 | 640 | let arr_padded = arr.padding( 641 | PaddingMode::Custom([BorderType::Replicate, BorderType::Circular]), 642 | explicit_padding, 643 | ); 644 | assert_eq!( 645 | arr_padded, 646 | array![ 647 | [1, 2, 1, 2, 1, 2], 648 | [1, 2, 1, 2, 1, 2], 649 | [1, 2, 1, 2, 1, 2], 650 | [3, 4, 3, 4, 3, 4], 651 | [3, 4, 3, 4, 3, 4], 652 | [3, 4, 3, 4, 3, 4] 653 | ] 654 | ); 655 | } 656 | 657 | #[test] 658 | fn 
test_mixed_types() { 659 | let arr = array![[1, 2], [3, 4]]; 660 | let kernel = array![[1, 1, 1], [1, 1, 1], [1, 1, 1]]; 661 | let kernel = kernel.into_kernel_with_dilation(); 662 | 663 | let explicit_conv = ConvMode::Full.unfold(&kernel); 664 | let explicit_padding = explicit_conv.padding; 665 | 666 | let arr_padded = arr.padding( 667 | PaddingMode::Custom([BorderType::Reflect, BorderType::Const(7)]), 668 | explicit_padding, 669 | ); 670 | assert_eq!( 671 | arr_padded, 672 | array![ 673 | [7, 7, 0, 0, 7, 7], 674 | [7, 7, 3, 4, 7, 7], 675 | [7, 7, 1, 2, 7, 7], 676 | [7, 7, 3, 4, 7, 7], 677 | [7, 7, 1, 2, 7, 7], 678 | [7, 7, 3, 4, 7, 7] 679 | ] 680 | ); 681 | } 682 | } 683 | 684 | mod explicit_padding { 685 | use super::*; 686 | 687 | #[test] 688 | fn test_per_side() { 689 | let arr = array![1, 2, 3]; 690 | let explicit_padding = [[1, 2]]; 691 | 692 | // Use different BorderType for each side 693 | let padded = arr.padding( 694 | PaddingMode::Explicit([[BorderType::Const(7), BorderType::Const(9)]]), 695 | explicit_padding, 696 | ); 697 | assert_eq!(padded, array![7, 1, 2, 3, 9, 9]); 698 | } 699 | } 700 | 701 | // ===== Edge Cases ===== 702 | 703 | mod edge_cases { 704 | use super::*; 705 | 706 | #[test] 707 | fn test_zero_padding() { 708 | let arr = array![1, 2, 3]; 709 | let explicit_padding = [[0, 0]]; 710 | let padded = arr.padding(PaddingMode::Zeros, explicit_padding); 711 | assert_eq!(padded, arr); 712 | } 713 | 714 | #[test] 715 | fn test_single_element() { 716 | let arr = array![42]; 717 | let explicit_padding = [[2, 2]]; 718 | let padded = arr.padding(PaddingMode::Replicate, explicit_padding); 719 | assert_eq!(padded, array![42, 42, 42, 42, 42]); 720 | } 721 | 722 | #[test] 723 | fn test_large_array() { 724 | let arr = Array::from_shape_fn((100, 100), |(i, j)| (i + j) as i32); 725 | let explicit_padding = [[5, 5], [5, 5]]; 726 | let padded = arr.padding(PaddingMode::Zeros, explicit_padding); 727 | 728 | // Verify shape 729 | assert_eq!(padded.shape(), &[110, 
110]); 730 | 731 | // Verify padding is zeros 732 | // Top padding 733 | for i in 0..5 { 734 | for j in 0..110 { 735 | assert_eq!(padded[[i, j]], 0); 736 | } 737 | } 738 | // Bottom padding 739 | for i in 105..110 { 740 | for j in 0..110 { 741 | assert_eq!(padded[[i, j]], 0); 742 | } 743 | } 744 | // Left and right padding (middle rows) 745 | for i in 5..105 { 746 | for j in 0..5 { 747 | assert_eq!(padded[[i, j]], 0); 748 | } 749 | for j in 105..110 { 750 | assert_eq!(padded[[i, j]], 0); 751 | } 752 | } 753 | 754 | // Verify original data is preserved 755 | assert_eq!(padded[[5, 5]], arr[[0, 0]]); // top-left 756 | assert_eq!(padded[[54, 54]], arr[[49, 49]]); // middle 757 | assert_eq!(padded[[104, 104]], arr[[99, 99]]); // bottom-right 758 | } 759 | } 760 | 761 | // ===== Torch Verification Tests ===== 762 | 763 | #[test] 764 | fn aligned_with_libtorch() { 765 | // Test all padding modes against torch for 3D 766 | let arr = array![[[1, 2, 3], [3, 4, 5]], [[5, 6, 7], [7, 8, 9]]]; 767 | let kernel = array![ 768 | [[1, 1, 1], [1, 1, 1], [1, 1, 1]], 769 | [[1, 1, 1], [1, 1, 1], [1, 1, 1]], 770 | [[1, 1, 1], [1, 1, 1], [1, 1, 1]] 771 | ]; 772 | let explicit_conv = ConvMode::Same.unfold(&kernel.into_kernel_with_dilation()); 773 | let explicit_padding = explicit_conv.padding; 774 | check(&arr, PaddingMode::Zeros, explicit_padding); 775 | check(&arr, PaddingMode::Const(7), explicit_padding); 776 | check(&arr, PaddingMode::Replicate, explicit_padding); 777 | check(&arr, PaddingMode::Reflect, explicit_padding); 778 | check(&arr, PaddingMode::Circular, explicit_padding); 779 | 780 | // Test all padding modes against torch for 2D 781 | let arr = array![[1, 2], [3, 4]]; 782 | let kernel = array![[1, 1], [1, 1]]; 783 | let explicit_conv = ConvMode::Full.unfold(&kernel.into_kernel_with_dilation()); 784 | let explicit_padding = explicit_conv.padding; 785 | check(&arr, PaddingMode::Zeros, explicit_padding); 786 | check(&arr, PaddingMode::Const(7), explicit_padding); 787 | 
check(&arr, PaddingMode::Replicate, explicit_padding); 788 | check(&arr, PaddingMode::Reflect, explicit_padding); 789 | check(&arr, PaddingMode::Circular, explicit_padding); 790 | 791 | // Test all padding modes against torch for 1D 792 | let arr = array![1, 2, 3]; 793 | let kernel = array![1, 1, 1, 1]; 794 | let explicit_conv = ConvMode::Same.unfold(&kernel.into_kernel_with_dilation()); 795 | let explicit_padding = explicit_conv.padding; 796 | check(&arr, PaddingMode::Zeros, explicit_padding); 797 | check(&arr, PaddingMode::Const(7), explicit_padding); 798 | check(&arr, PaddingMode::Replicate, explicit_padding); 799 | check(&arr, PaddingMode::Reflect, explicit_padding); 800 | check(&arr, PaddingMode::Circular, explicit_padding); 801 | } 802 | 803 | fn check( 804 | arr: &Array>, 805 | padding_mode: PaddingMode, 806 | explicit_padding: ExplicitPadding, 807 | ) where 808 | T: num::traits::NumAssign + Copy + tch::kind::Element + std::fmt::Debug, 809 | Dim<[Ix; N]>: Dimension, 810 | [Ix; N]: IntoDimension>, 811 | SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg>, 812 | Dim<[Ix; N]>: RemoveAxis, 813 | f64: std::convert::From, 814 | T: num::traits::FromPrimitive, 815 | { 816 | let ndarray_result = arr.padding(padding_mode, explicit_padding); 817 | dbg!(&ndarray_result); 818 | 819 | let shape = [1, 1] 820 | .iter() 821 | .chain(arr.shape()) 822 | .map(|s| *s as i64) 823 | .collect::>(); 824 | let tensor = tch::Tensor::from_slice(arr.as_slice().unwrap()) 825 | .reshape(shape) 826 | .totype(tch::Kind::Float); 827 | 828 | let (mode, value) = match padding_mode { 829 | PaddingMode::Zeros => ("constant", Some(0.0)), 830 | PaddingMode::Const(c) => ("constant", Some(f64::from(c))), 831 | PaddingMode::Replicate => ("replicate", None), 832 | PaddingMode::Reflect => ("reflect", None), 833 | PaddingMode::Circular => ("circular", None), 834 | _ => unreachable!(), 835 | }; 836 | 837 | let tensor_result = tensor 838 | .f_pad( 839 | explicit_padding 840 | .into_iter() 
841 | .flatten() 842 | .map(|p| p as i64) 843 | .collect::>(), 844 | mode, 845 | value, 846 | ) 847 | .unwrap(); 848 | 849 | dbg!(&tensor_result); 850 | tensor_result.print(); 851 | 852 | assert_eq!( 853 | ndarray_result.into_raw_vec_and_offset().0, 854 | tensor_result 855 | .reshape(tensor_result.size().iter().product::()) 856 | .iter::() 857 | .unwrap() 858 | .map(|v| T::from_f64(v).unwrap()) 859 | .collect::>() 860 | ); 861 | } 862 | } 863 | --------------------------------------------------------------------------------