├── .github └── workflows │ ├── CI.yml │ ├── audit-on-push.yml │ └── scheduled-audit.yml ├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md ├── benches └── benchmarks.rs └── src ├── core.rs ├── display.rs ├── iterators.rs ├── lib.rs ├── operators.rs └── storage.rs /.github/workflows/CI.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | tags: 8 | - 'v*.*.*' 9 | pull_request: 10 | types: [ opened, synchronize, reopened ] 11 | branches: 12 | - main 13 | env: 14 | CARGO_TERM_COLOR: always 15 | 16 | jobs: 17 | test: 18 | name: test ${{ matrix.rust }} ${{ matrix.flags }} 19 | runs-on: ubuntu-latest 20 | timeout-minutes: 30 21 | strategy: 22 | fail-fast: false 23 | matrix: 24 | rust: [ "stable", "beta", "nightly", "1.65" ] # MSRV 25 | flags: [ "--no-default-features", "", "--all-features" ] 26 | exclude: 27 | # Skip because some features may require a newer toolchain than the MSRV. 28 | - rust: "1.65" # MSRV 29 | flags: "--all-features" 30 | steps: 31 | - uses: actions/checkout@v3 32 | - uses: dtolnay/rust-toolchain@master 33 | with: 34 | toolchain: ${{ matrix.rust }} 35 | - uses: Swatinem/rust-cache@v2 36 | with: 37 | cache-on-failure: true 38 | # On the MSRV toolchain only type-check; run the full test suite on every other toolchain. 39 | - name: check 40 | if: ${{ matrix.rust == '1.65' }} # MSRV 41 | run: cargo check --workspace ${{ matrix.flags }} 42 | - name: test 43 | if: ${{ matrix.rust != '1.65' }} # MSRV 44 | run: cargo test --workspace ${{ matrix.flags }} 45 | 46 | coverage: 47 | name: Code Coverage 48 | runs-on: ubuntu-latest 49 | env: 50 | LLVMCOV_VERSION: 0.5.14 51 | steps: 52 | - name: Checkout repository 53 | uses: actions/checkout@v3 54 | 55 | - name: Install Rust 56 | uses: dtolnay/rust-toolchain@stable 57 | with: 58 | # `@stable` already selects the toolchain; `override`/`profile` were 59 | # actions-rs/toolchain inputs that dtolnay/rust-toolchain does not accept. 60 | # cargo-llvm-cov needs the llvm-tools-preview component. 61 | components: llvm-tools-preview 62 | 63 | - name: Cache rust dependencies 64 | uses: Swatinem/rust-cache@v2 65 | with: 66 | shared-key: 
rust-cache-dimensionals-coverage-${{ runner.os }}-${{ hashFiles('**/Cargo.lock') }}-${{ env.LLVMCOV_VERSION }} 67 | 68 | - name: Install cargo-llvm-cov 69 | run: cargo install cargo-llvm-cov --version=${{ env.LLVMCOV_VERSION }} --locked 70 | 71 | - name: Generate code coverage 72 | run: cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info 73 | 74 | - name: Upload coverage to Codecov 75 | uses: codecov/codecov-action@v3 76 | with: 77 | token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos 78 | files: lcov.info 79 | fail_ci_if_error: false 80 | 81 | # TODO: Miri currently reports errors (including undefined behavior) in the codebase. 82 | # Fix those, then enable this job. 83 | # miri: 84 | # name: miri ${{ matrix.flags }} 85 | # runs-on: ubuntu-latest 86 | # timeout-minutes: 30 87 | # strategy: 88 | # fail-fast: false 89 | # matrix: 90 | # flags: [ "--no-default-features", "", "--all-features" ] 91 | # env: 92 | # MIRIFLAGS: -Zmiri-strict-provenance 93 | # steps: 94 | # - uses: actions/checkout@v3 95 | # - uses: dtolnay/rust-toolchain@miri 96 | # - uses: Swatinem/rust-cache@v2 97 | # with: 98 | # cache-on-failure: true 99 | # - run: cargo miri setup ${{ matrix.flags }} 100 | # - run: cargo miri test ${{ matrix.flags }} 101 | 102 | feature-checks: 103 | runs-on: ubuntu-latest 104 | timeout-minutes: 30 105 | steps: 106 | - uses: actions/checkout@v3 107 | - uses: dtolnay/rust-toolchain@stable 108 | - uses: taiki-e/install-action@cargo-hack 109 | - uses: Swatinem/rust-cache@v2 110 | with: 111 | cache-on-failure: true 112 | - name: cargo hack 113 | run: cargo hack check --feature-powerset --depth 2 114 | 115 | clippy: 116 | runs-on: ubuntu-latest 117 | timeout-minutes: 30 118 | steps: 119 | - uses: actions/checkout@v3 120 | - uses: dtolnay/rust-toolchain@clippy 121 | - uses: Swatinem/rust-cache@v2 122 | with: 123 | cache-on-failure: true 124 | - run: cargo clippy --workspace --all-targets --all-features 125 | env: 126 | RUSTFLAGS: -Dwarnings
127 | 128 | docs: 129 | runs-on: ubuntu-latest 130 | timeout-minutes: 30 131 | steps: 132 | - uses: actions/checkout@v3 133 | - uses: dtolnay/rust-toolchain@nightly 134 | - uses: Swatinem/rust-cache@v2 135 | with: 136 | cache-on-failure: true 137 | - run: cargo doc --workspace --all-features --no-deps --document-private-items 138 | env: 139 | RUSTDOCFLAGS: "--cfg docsrs -D warnings" 140 | 141 | fmt: 142 | runs-on: ubuntu-latest 143 | timeout-minutes: 30 144 | steps: 145 | - uses: actions/checkout@v3 146 | - uses: dtolnay/rust-toolchain@nightly 147 | with: 148 | components: rustfmt 149 | - run: cargo fmt --all --check -------------------------------------------------------------------------------- /.github/workflows/audit-on-push.yml: -------------------------------------------------------------------------------- 1 | name: Security audit 2 | on: 3 | push: 4 | paths: 5 | - '**/Cargo.toml' 6 | - '**/Cargo.lock' 7 | jobs: 8 | security_audit: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v3 12 | - uses: actions-rs/audit-check@v1 13 | with: 14 | token: ${{ secrets.GITHUB_TOKEN }} -------------------------------------------------------------------------------- /.github/workflows/scheduled-audit.yml: -------------------------------------------------------------------------------- 1 | name: Security audit 2 | on: 3 | schedule: 4 | - cron: '0 0 * * *' 5 | jobs: 6 | security_audit: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v3 10 | - uses: actions-rs/audit-check@v1 11 | with: 12 | token: ${{ secrets.GITHUB_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | debug/ 4 | target/ 5 | 6 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 7 | # More information here
https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 8 | Cargo.lock 9 | 10 | # These are backup files generated by rustfmt 11 | **/*.rs.bk 12 | 13 | # MSVC Windows builds of rustc generate these, which store debugging information 14 | *.pdb 15 | 16 | # IntelliJ RustRover 17 | /.idea 18 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["0xAlcibiades "] 3 | categories = ["mathematics"] 4 | description = "Rust native generic, flexible n-dimensional array." 5 | homepage = "https://github.com/warlock-labs/dimensionals" 6 | keywords = ["math", "matrix", "vector", "tensor", "array"] 7 | license = "MIT" 8 | readme = "README.md" 9 | repository = "https://github.com/warlock-labs/dimensionals" 10 | name = "dimensionals" 11 | version = "0.2.1" 12 | edition = "2021" 13 | 14 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 15 | 16 | [dependencies] 17 | num-traits = "0.2.19" 18 | 19 | [lib] 20 | 21 | [dev-dependencies] 22 | criterion = "0.5.1" 23 | 24 | [[bench]] 25 | name = "benchmarks" 26 | harness = false 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2024 Warlock Labs Inc.
2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dimensionals 2 | 3 | [![License](https://img.shields.io/crates/l/dimensionals)](https://choosealicense.com/licenses/mit/) 4 | [![Crates.io](https://img.shields.io/crates/v/dimensionals)](https://crates.io/crates/dimensionals) 5 | [![Docs](https://img.shields.io/crates/v/dimensionals?color=blue&label=docs)](https://docs.rs/dimensionals/) 6 | ![CI](https://github.com/warlock-labs/dimensionals/actions/workflows/CI.yml/badge.svg) 7 | 8 | Dimensionals is a Rust library for working with n-dimensional data. It provides a flexible and efficient multidimensional array implementation with a generic storage backend over generic number types. 
9 | 10 | ## Features 11 | 12 | - Generic over element type `T` (implementing `Num` and `Copy`), number of dimensions `N`, and storage backend `S` 13 | - Support for Scalar (0D), Vector (1D), Matrix (2D), and Tensor (ND, N > 2) types 14 | - Efficient `LinearArrayStorage` backend with support for row-major and column-major layouts 15 | - Iterators (immutable and mutable) for efficient traversal 16 | - Indexing and slicing operations 17 | - Arithmetic operations (element-wise and scalar) with operator overloading 18 | - Convenient macros for vector and matrix creation (`vector!` and `matrix!`) 19 | 20 | ## Usage 21 | 22 | Add this to your `Cargo.toml`: 23 | 24 | ```toml 25 | [dependencies] 26 | dimensionals = "0.2" 27 | ``` 28 | 29 | Here's a basic example of creating and using a matrix: 30 | 31 | ```rust 32 | use dimensionals::{matrix, Dimensional, LinearArrayStorage}; 33 | 34 | fn main() { 35 | let m: Dimensional<i32, LinearArrayStorage<i32, 2>, 2> = matrix![ 36 | [1, 2, 3], 37 | [4, 5, 6] 38 | ]; 39 | assert_eq!(m[[0, 0]], 1); 40 | assert_eq!(m[[1, 1]], 5); 41 | 42 | // Element-wise addition 43 | let m2 = &m + &m; 44 | assert_eq!(m2[[0, 0]], 2); 45 | assert_eq!(m2[[1, 1]], 10); 46 | 47 | // Scalar multiplication 48 | let m3 = &m * 2; 49 | assert_eq!(m3[[0, 0]], 2); 50 | assert_eq!(m3[[1, 1]], 10); 51 | 52 | // Iteration 53 | for &value in m.iter() { 54 | println!("{}", value); 55 | } 56 | 57 | // Matrix multiplication 58 | let m4: Dimensional<i32, LinearArrayStorage<i32, 2>, 2> = matrix![ 59 | [7, 8], 60 | [9, 10], 61 | [11, 12] 62 | ]; 63 | let product = m.dot(&m4); 64 | assert_eq!(product[[0, 0]], 58); 65 | } 66 | ``` 67 | 68 | For more examples and usage details, see the [API documentation](https://docs.rs/dimensionals). 69 | 70 | ## Core Concepts 71 | 72 | - **Element type `T`**: The type of data stored in the array (must implement `Num` and `Copy`). 73 | - **Storage backend `S`**: The underlying storage mechanism for the array (must implement `DimensionalStorage`).
74 | - **Number of dimensions `N`**: The dimensionality of the array (const generic parameter). 75 | 76 | ## Performance 77 | 78 | The `LinearArrayStorage` backend stores elements in a contiguous `Vec<T>` and supports both row-major and column-major layouts. This provides good cache locality for traversals. The storage computes strides for efficient indexing. 79 | 80 | ## Roadmap 81 | 82 | Checked items are already implemented; unchecked items are planned for future releases: 83 | 84 | - [x] Basic N-dimensional array 85 | - [x] Basic indexing 86 | - [x] Basic iterators 87 | - [x] Basic arithmetic operations 88 | - [x] Basic slicing 89 | - [x] Use safe Rust in indexing 90 | - [x] Support common arithmetic operations 91 | - [x] Use safe Rust in arithmetic operations 92 | - [ ] Move shape data into the type system for compile-time-known dimensions 93 | - [x] Matrix multiplication 94 | - [ ] Use safe Rust in iterators (currently uses unsafe code) 95 | - [ ] Add tensor macro for creating higher-dimensional arrays 96 | - [ ] Remove the need for phantom data markers 97 | - [ ] Support reshaping, appending, and removing operations 98 | - [ ] Implement comprehensive linear algebra functions 99 | - [ ] Add support for common statistical functions 100 | - [ ] Implement geometric functions like Brownian motion 101 | - [ ] Add support for GPU offloading 102 | - [ ] Implement SIMD optimizations 103 | - [ ] Support Apache Arrow or safetensors storage backend 104 | - [ ] Integrate with Polars, plotly-rs, and argmin-rs 105 | - [ ] Add parallel processing support with Rayon 106 | - [ ] Implement feature flags for optional functionality 107 | - [ ] Support no_std environments 108 | - [ ] Add WebAssembly and WebGPU support 109 | - [ ] Implement support for SVM targets 110 | 111 | ## Contributing 112 | 113 | Contributions are welcome! Please feel free to submit issues, feature requests, or pull requests on the [GitHub repository](https://github.com/warlock-labs/dimensionals).
114 | 115 | ## License 116 | 117 | This project is licensed under the [MIT License](https://choosealicense.com/licenses/mit/). 118 | 119 | ## Contact 120 | 121 | Warlock Labs - [https://github.com/warlock-labs](https://github.com/warlock-labs) 122 | 123 | Project Link: [https://github.com/warlock-labs/dimensionals](https://github.com/warlock-labs/dimensionals) -------------------------------------------------------------------------------- /benches/benchmarks.rs: -------------------------------------------------------------------------------- 1 | use criterion::{black_box, criterion_group, criterion_main, Criterion}; 2 | use dimensionals::{Dimensional, LinearArrayStorage}; 3 | 4 | // TODO: These benchmarks cover creation, indexing, and basic matrix ops; add meaningful 5 | // benchmarks for common operations used in quantitative workloads. 6 | 7 | fn bench_dimensional_array_creation_zeros(c: &mut Criterion) { 8 | let shape = [1000, 1000]; 9 | c.bench_function("dimensional_array_creation_zeros", |b| { 10 | b.iter(|| Dimensional::, 2>::zeros(shape)) 11 | }); 12 | } 13 | 14 | fn bench_dimensional_array_creation_ones(c: &mut Criterion) { 15 | let shape = [1000, 1000]; 16 | c.bench_function("dimensional_array_creation_ones", |b| { 17 | b.iter(|| Dimensional::, 2>::ones(shape)) 18 | }); 19 | } 20 | 21 | fn bench_dimensional_array_indexing(c: &mut Criterion) { 22 | let shape = [1000, 1000]; 23 | let array = Dimensional::, 2>::zeros(shape); 24 | 25 | c.bench_function("dimensional_array_indexing", |b| { 26 | b.iter(|| { 27 | for i in 0..shape[0] { 28 | for j in 0..shape[1] { 29 | black_box(array[[i, j]]); 30 | } 31 | } 32 | }) 33 | }); 34 | } 35 | 36 | fn bench_dimensional_array_mutable_indexing(c: &mut Criterion) { 37 | let shape = [1000, 1000]; 38 | let mut array = Dimensional::, 2>::zeros(shape); 39 | 40 | c.bench_function("dimensional_array_mutable_indexing", |b| { 41 | b.iter(|| { 42 | for i in 0..shape[0] { 43 | for j in 0..shape[1] { 44 | array[[i, j]] = 1.0; 45 | } 46 | } 47 | }) 48 | }); 49 | } 50 | 51 | fn
bench_matrix_multiplication(c: &mut Criterion) { 52 | let shape1 = [100, 200]; 53 | let shape2 = [200, 100]; 54 | let m1 = Dimensional::, 2>::ones(shape1); 55 | let m2 = Dimensional::, 2>::ones(shape2); 56 | 57 | c.bench_function("matrix_multiplication", |b| b.iter(|| m1.dot(&m2))); 58 | } 59 | 60 | fn bench_matrix_transpose(c: &mut Criterion) { 61 | let shape = [1000, 1000]; 62 | let m = Dimensional::, 2>::ones(shape); 63 | 64 | c.bench_function("matrix_transpose", |b| b.iter(|| m.transpose())); 65 | } 66 | 67 | fn bench_matrix_trace(c: &mut Criterion) { 68 | let shape = [1000, 1000]; 69 | let m = Dimensional::, 2>::ones(shape); 70 | 71 | c.bench_function("matrix_trace", |b| b.iter(|| m.trace())); 72 | } 73 | 74 | criterion_group!( 75 | benches, 76 | bench_dimensional_array_creation_zeros, 77 | bench_dimensional_array_creation_ones, 78 | bench_dimensional_array_indexing, 79 | bench_dimensional_array_mutable_indexing, 80 | bench_matrix_multiplication, 81 | bench_matrix_transpose, 82 | bench_matrix_trace 83 | ); 84 | criterion_main!(benches); 85 | -------------------------------------------------------------------------------- /src/core.rs: -------------------------------------------------------------------------------- 1 | use crate::storage::DimensionalStorage; 2 | use num_traits::Num; 3 | use std::marker::PhantomData; 4 | 5 | /// A multidimensional array type. 6 | /// 7 | /// This struct represents a multidimensional array with a generic storage backend. 8 | /// 9 | /// # Type Parameters 10 | /// 11 | /// * `T`: The element type of the array. Must implement `Num` and `Copy`. 12 | /// * `S`: The storage backend for the array. Must implement `DimensionalStorage`. 13 | /// * `N`: The number of dimensions of the array (a `usize` const generic).
14 | #[derive(Debug, Clone, Eq, Copy)] 15 | pub struct Dimensional 16 | where 17 | S: DimensionalStorage, 18 | { 19 | pub(crate) shape: [usize; N], 20 | pub(crate) storage: S, 21 | _marker: PhantomData, 22 | } 23 | 24 | impl Dimensional 25 | where 26 | S: DimensionalStorage, 27 | { 28 | /// Creates a new array filled with zeros. 29 | /// 30 | /// # Arguments 31 | /// 32 | /// * `shape`: The shape of the array. 33 | /// 34 | /// # Examples 35 | /// 36 | /// ``` 37 | /// use dimensionals::{Dimensional, LinearArrayStorage}; 38 | /// 39 | /// let zeros: Dimensional, 2> = Dimensional::zeros([2, 3]); 40 | /// assert_eq!(zeros.shape(), [2, 3]); 41 | /// assert!(zeros.as_slice().iter().all(|&x| x == 0)); 42 | /// ``` 43 | pub fn zeros(shape: [usize; N]) -> Self { 44 | let storage = S::zeros(shape); 45 | 46 | Self { 47 | shape, 48 | storage, 49 | _marker: PhantomData, 50 | } 51 | } 52 | 53 | /// Creates a new array filled with ones. 54 | /// 55 | /// # Arguments 56 | /// 57 | /// * `shape`: The shape of the array. 58 | /// 59 | /// # Examples 60 | /// 61 | /// ``` 62 | /// use dimensionals::{Dimensional, LinearArrayStorage}; 63 | /// 64 | /// let ones: Dimensional, 2> = Dimensional::ones([2, 3]); 65 | /// assert_eq!(ones.shape(), [2, 3]); 66 | /// assert!(ones.as_slice().iter().all(|&x| x == 1)); 67 | /// ``` 68 | pub fn ones(shape: [usize; N]) -> Self { 69 | let storage = S::ones(shape); 70 | 71 | Self { 72 | shape, 73 | storage, 74 | _marker: PhantomData, 75 | } 76 | } 77 | 78 | /// Creates a new multidimensional array. 79 | /// 80 | /// # Arguments 81 | /// 82 | /// * `shape`: The shape of the array. 83 | /// * `storage`: The storage backend; its length must equal the product of `shape` (panics otherwise).
84 | /// 85 | /// # Examples 86 | /// 87 | /// ``` 88 | /// use dimensionals::{Dimensional, LinearArrayStorage, DimensionalStorage}; 89 | /// 90 | /// let storage = LinearArrayStorage::from_vec([2, 3], vec![1, 2, 3, 4, 5, 6]); 91 | /// let array = Dimensional::new([2, 3], storage); 92 | /// assert_eq!(array.shape(), [2, 3]); 93 | /// assert_eq!(array.as_slice(), &[1, 2, 3, 4, 5, 6]); 94 | /// ``` 95 | pub fn new(shape: [usize; N], storage: S) -> Self { 96 | assert_eq!( 97 | shape.iter().product::(), 98 | storage.as_slice().len(), 99 | "Storage size must match the product of shape dimensions" 100 | ); 101 | Self { 102 | shape, 103 | storage, 104 | _marker: PhantomData, 105 | } 106 | } 107 | 108 | /// Creates a new array using a function to initialize each element. 109 | /// 110 | /// # Arguments 111 | /// 112 | /// * `shape`: The shape of the array. 113 | /// * `f`: A closure called once per element, in row-major order, with that element's multidimensional index; it returns the element's value and may mutate captured state. 114 | /// 115 | /// # Examples 116 | /// 117 | /// ``` 118 | /// use dimensionals::{Dimensional, LinearArrayStorage}; 119 | /// 120 | /// let array: Dimensional, 2> = 121 | /// Dimensional::from_fn([2, 3], |[i, j]| (i * 3 + j) as i32); 122 | /// assert_eq!(array.shape(), [2, 3]); 123 | /// assert_eq!(array.as_slice(), &[0, 1, 2, 3, 4, 5]); 124 | /// ``` 125 | pub fn from_fn(shape: [usize; N], mut f: F) -> Self 126 | where 127 | F: FnMut([usize; N]) -> T, 128 | { 129 | // Start from a zero-filled buffer; every element is overwritten below. 130 | let storage = S::zeros(shape); 131 | let mut array = Self { 132 | shape, 133 | storage, 134 | _marker: PhantomData, 135 | }; 136 | 137 | // Fill each linear position with f(index), where index is its row-major unraveling. 138 | for i in 0..array.len() { 139 | let index = array.unravel_index(i); 140 | array.storage.as_mut_slice()[i] = f(index); 141 | } 142 | 143 | array 144 | } 145 | 146 | /// Converts a linear (row-major) index to a multidimensional index. 147 | /// 148 | /// # Arguments 149 | /// 150 | /// * `index`: The linear index. It is not bounds-checked: values past the end wrap around modulo the product of the shape.
151 | /// 152 | /// # Returns 153 | /// 154 | /// A multidimensional index as an array of `usize`. Panics (division by zero) if any dimension of the shape is 0. 155 | pub fn unravel_index(&self, index: usize) -> [usize; N] { 156 | let mut index = index; 157 | let mut unraveled = [0; N]; 158 | 159 | for i in (0..N).rev() { 160 | unraveled[i] = index % self.shape[i]; 161 | index /= self.shape[i]; 162 | } 163 | 164 | unraveled 165 | } 166 | 167 | /// Converts a multidimensional index to a linear (row-major) index. 168 | /// 169 | /// # Arguments 170 | /// 171 | /// * `indices`: The multidimensional index. Components are not bounds-checked against the shape. 172 | /// 173 | /// # Returns 174 | /// 175 | /// A linear index as `usize`. 176 | pub fn ravel_index(&self, indices: &[usize; N]) -> usize { 177 | indices 178 | .iter() 179 | .zip(self.shape.iter()) 180 | .fold(0, |acc, (&i, &s)| acc * s + i) 181 | } 182 | 183 | // TODO: what, if any, is the use case for jagged arrays? 184 | 185 | /// Returns the shape of the array. 186 | /// 187 | /// # Returns 188 | /// 189 | /// An array of `usize` representing the shape of the array. 190 | pub fn shape(&self) -> [usize; N] { 191 | self.shape 192 | } 193 | 194 | /// Returns the number of dimensions of the array. 195 | /// 196 | /// # Returns 197 | /// 198 | /// The number of dimensions as `usize`. 199 | pub fn ndim(&self) -> usize { 200 | N 201 | } 202 | 203 | /// Returns the total number of elements in the array. 204 | /// 205 | /// # Returns 206 | /// 207 | /// The total number of elements as `usize`. 208 | pub fn len(&self) -> usize { 209 | self.storage.len() 210 | } 211 | 212 | /// Returns `true` if the array is empty. 213 | /// 214 | /// # Returns 215 | /// 216 | /// A boolean indicating whether the array is empty. 217 | pub fn is_empty(&self) -> bool { 218 | self.storage.len() == 0 219 | } 220 | 221 | /// Returns the length of the array along a given axis. 222 | /// 223 | /// # Arguments 224 | /// 225 | /// * `axis`: The axis to get the length of. 226 | /// 227 | /// # Returns 228 | /// 229 | /// The length of the specified axis as `usize`.
230 | /// 231 | /// # Panics 232 | /// 233 | /// Panics if the axis is out of bounds. 234 | pub fn len_axis(&self, axis: usize) -> usize { 235 | assert!(axis < N, "Axis out of bounds"); 236 | self.shape[axis] 237 | } 238 | 239 | // TODO: there may need to be an abstraction layer between the array and its storage here. 240 | 241 | /// Returns a mutable slice of the underlying data. 242 | /// 243 | /// # Returns 244 | /// 245 | /// A mutable slice of the underlying data. 246 | pub fn as_mut_slice(&mut self) -> &mut [T] { 247 | self.storage.as_mut_slice() 248 | } 249 | 250 | // TODO: same story here; this is tightly coupled to the storage. 251 | 252 | /// Returns an immutable slice of the underlying data. 253 | /// 254 | /// # Returns 255 | /// 256 | /// An immutable slice of the underlying data. 257 | pub fn as_slice(&self) -> &[T] { 258 | self.storage.as_slice() 259 | } 260 | } 261 | 262 | // Specific implementations for 2D arrays 263 | impl Dimensional 264 | where 265 | S: DimensionalStorage, 266 | { 267 | /// Creates a new identity array (square matrix with ones on the diagonal and zeros elsewhere). 268 | /// 269 | /// # Arguments 270 | /// 271 | /// * `n`: The size of the square matrix. 272 | /// 273 | /// # Examples 274 | /// 275 | /// ``` 276 | /// use dimensionals::{Dimensional, LinearArrayStorage}; 277 | /// 278 | /// let eye: Dimensional, 2> = Dimensional::eye(3); 279 | /// assert_eq!(eye.shape(), [3, 3]); 280 | /// assert_eq!(eye[[0, 0]], 1); 281 | /// assert_eq!(eye[[1, 1]], 1); 282 | /// assert_eq!(eye[[2, 2]], 1); 283 | /// assert_eq!(eye[[0, 1]], 0); 284 | /// ``` 285 | pub fn eye(n: usize) -> Self { 286 | Self::from_fn([n, n], |[i, j]| if i == j { T::one() } else { T::zero() }) 287 | } 288 | 289 | /// Creates a new identity-like array with a specified value on the diagonal. 290 | /// 291 | /// # Arguments 292 | /// 293 | /// * `n`: The size of the square matrix. 294 | /// * `value`: The value to place on the diagonal.
295 | /// 296 | /// # Examples 297 | /// 298 | /// ``` 299 | /// use dimensionals::{Dimensional, LinearArrayStorage}; 300 | /// 301 | /// let eye: Dimensional, 2> = Dimensional::eye_value(3, 2.5); 302 | /// assert_eq!(eye.shape(), [3, 3]); 303 | /// assert_eq!(eye[[0, 0]], 2.5); 304 | /// assert_eq!(eye[[1, 1]], 2.5); 305 | /// assert_eq!(eye[[2, 2]], 2.5); 306 | /// assert_eq!(eye[[0, 1]], 0.0); 307 | /// ``` 308 | pub fn eye_value(n: usize, value: T) -> Self { 309 | Self::from_fn([n, n], |[i, j]| if i == j { value } else { T::zero() }) 310 | } 311 | } 312 | 313 | #[cfg(test)] 314 | mod tests { 315 | use super::*; 316 | use crate::LinearArrayStorage; 317 | use num_traits::FloatConst; 318 | 319 | #[test] 320 | fn test_zeros_and_ones() { 321 | let zeros: Dimensional, 2> = Dimensional::zeros([2, 3]); 322 | assert_eq!(zeros.shape(), [2, 3]); 323 | assert!(zeros.as_slice().iter().all(|&x| x == 0)); 324 | 325 | let ones: Dimensional, 2> = Dimensional::ones([2, 3]); 326 | assert_eq!(ones.shape(), [2, 3]); 327 | assert!(ones.as_slice().iter().all(|&x| x == 1)); 328 | } 329 | 330 | #[test] 331 | fn test_new() { 332 | let storage = LinearArrayStorage::from_vec([2, 3], vec![1, 2, 3, 4, 5, 6]); 333 | let array = Dimensional::new([2, 3], storage); 334 | assert_eq!(array.shape(), [2, 3]); 335 | assert_eq!(array.as_slice(), &[1, 2, 3, 4, 5, 6]); 336 | } 337 | 338 | #[test] 339 | #[should_panic(expected = "Storage size must match the product of shape dimensions")] 340 | fn test_new_mismatched_shape() { 341 | let storage = LinearArrayStorage::from_vec([2, 2], vec![1, 2, 3, 4]); 342 | Dimensional::new([2, 3], storage); 343 | } 344 | 345 | #[test] 346 | fn test_from_fn() { 347 | let array: Dimensional, 2> = 348 | Dimensional::from_fn([2, 3], |[i, j]| (i * 3 + j) as i32); 349 | assert_eq!(array.shape(), [2, 3]); 350 | assert_eq!(array.as_slice(), &[0, 1, 2, 3, 4, 5]); 351 | } 352 | 353 | #[test] 354 | fn test_unravel_and_ravel_index() { 355 | let array: Dimensional, 3> =
Dimensional::zeros([2, 3, 4]); 356 | for i in 0..24 { 357 | let unraveled = array.unravel_index(i); 358 | let raveled = array.ravel_index(&unraveled); 359 | assert_eq!(i, raveled); 360 | } 361 | } 362 | 363 | #[test] 364 | fn test_ravel_unravel_consistency() { 365 | let array: Dimensional, 3> = Dimensional::zeros([2, 3, 4]); 366 | 367 | for i in 0..2 { 368 | for j in 0..3 { 369 | for k in 0..4 { 370 | let index = [i, j, k]; 371 | let raveled = array.ravel_index(&index); 372 | let unraveled = array.unravel_index(raveled); 373 | assert_eq!( 374 | index, unraveled, 375 | "Ravel/unravel mismatch for index {:?}", 376 | index 377 | ); 378 | } 379 | } 380 | } 381 | } 382 | 383 | #[test] 384 | fn test_shape_and_dimensions() { 385 | let array: Dimensional, 3> = Dimensional::zeros([2, 3, 4]); 386 | assert_eq!(array.shape(), [2, 3, 4]); 387 | assert_eq!(array.ndim(), 3); 388 | assert_eq!(array.len(), 24); 389 | assert!(!array.is_empty()); 390 | assert_eq!(array.len_axis(0), 2); 391 | assert_eq!(array.len_axis(1), 3); 392 | assert_eq!(array.len_axis(2), 4); 393 | } 394 | 395 | #[test] 396 | #[should_panic(expected = "Axis out of bounds")] 397 | fn test_len_axis_out_of_bounds() { 398 | let array: Dimensional, 2> = Dimensional::zeros([2, 3]); 399 | array.len_axis(2); 400 | } 401 | 402 | #[test] 403 | fn test_as_slice_and_as_mut_slice() { 404 | let mut array: Dimensional, 2> = 405 | Dimensional::from_fn([2, 3], |[i, j]| (i * 3 + j) as i32); 406 | 407 | assert_eq!(array.as_slice(), &[0, 1, 2, 3, 4, 5]); 408 | 409 | { 410 | let slice = array.as_mut_slice(); 411 | slice[0] = 10; 412 | slice[5] = 50; 413 | } 414 | 415 | assert_eq!(array.as_slice(), &[10, 1, 2, 3, 4, 50]); 416 | } 417 | 418 | #[test] 419 | fn test_eye() { 420 | let eye: Dimensional, 2> = Dimensional::eye(3); 421 | 422 | // Check shape 423 | assert_eq!(eye.shape(), [3, 3]); 424 | 425 | // Check diagonal elements 426 | assert_eq!(eye[[0, 0]], 1); 427 | assert_eq!(eye[[1, 1]], 1); 428 | assert_eq!(eye[[2, 2]], 1); 429 |
430 | // Check off-diagonal elements 431 | assert_eq!(eye[[0, 1]], 0); 432 | assert_eq!(eye[[0, 2]], 0); 433 | assert_eq!(eye[[1, 0]], 0); 434 | assert_eq!(eye[[1, 2]], 0); 435 | assert_eq!(eye[[2, 0]], 0); 436 | assert_eq!(eye[[2, 1]], 0); 437 | 438 | // Check sum of all elements (should equal the size of the matrix) 439 | let sum: i32 = eye.as_slice().iter().sum(); 440 | assert_eq!(sum, 3); 441 | 442 | // Test with a different size 443 | let eye_4x4: Dimensional, 2> = Dimensional::eye(4); 444 | assert_eq!(eye_4x4.shape(), [4, 4]); 445 | assert_eq!(eye_4x4[[3, 3]], 1); 446 | assert_eq!(eye_4x4[[0, 3]], 0); 447 | 448 | // Test with a different type 449 | let eye_float: Dimensional, 2> = Dimensional::eye(2); 450 | assert_eq!(eye_float[[0, 0]], 1.0); 451 | assert_eq!(eye_float[[0, 1]], 0.0); 452 | assert_eq!(eye_float[[1, 0]], 0.0); 453 | assert_eq!(eye_float[[1, 1]], 1.0); 454 | } 455 | 456 | #[test] 457 | fn test_eye_value() { 458 | // Test with integer type 459 | let eye_int: Dimensional, 2> = Dimensional::eye_value(3, 5); 460 | assert_eq!(eye_int.shape(), [3, 3]); 461 | assert_eq!(eye_int[[0, 0]], 5); 462 | assert_eq!(eye_int[[1, 1]], 5); 463 | assert_eq!(eye_int[[2, 2]], 5); 464 | assert_eq!(eye_int[[0, 1]], 0); 465 | assert_eq!(eye_int[[1, 2]], 0); 466 | 467 | // Test with floating-point type 468 | let eye_float: Dimensional, 2> = 469 | Dimensional::eye_value(2, f64::PI()); 470 | assert_eq!(eye_float.shape(), [2, 2]); 471 | assert_eq!(eye_float[[0, 0]], f64::PI()); 472 | assert_eq!(eye_float[[1, 1]], f64::PI()); 473 | assert_eq!(eye_float[[0, 1]], 0.0); 474 | assert_eq!(eye_float[[1, 0]], 0.0); 475 | 476 | // Test with a negative value 477 | let eye_neg: Dimensional, 2> = 478 | Dimensional::eye_value(2, -1); 479 | assert_eq!(eye_neg.shape(), [2, 2]); 480 | assert_eq!(eye_neg[[0, 0]], -1); 481 | assert_eq!(eye_neg[[1, 1]], -1); 482 | assert_eq!(eye_neg[[0, 1]], 0); 483 | assert_eq!(eye_neg[[1, 0]], 0); 484 | } 485 | 486 | #[test] 487 | fn test_len() { 488 |
let array_2d: Dimensional, 2> = Dimensional::zeros([2, 3]); 489 | assert_eq!(array_2d.len(), 6); 490 | 491 | let array_3d: Dimensional, 3> = 492 | Dimensional::zeros([2, 3, 4]); 493 | assert_eq!(array_3d.len(), 24); 494 | 495 | let array_1d: Dimensional, 1> = Dimensional::zeros([5]); 496 | assert_eq!(array_1d.len(), 5); 497 | 498 | let array_empty: Dimensional, 1> = Dimensional::zeros([0]); 499 | assert_eq!(array_empty.len(), 0); 500 | } 501 | 502 | #[test] 503 | fn test_from_fn_accepts_fn_mut() { 504 | let mut next = 0; 505 | let array: Dimensional<i32, LinearArrayStorage<i32, 2>, 2> = 506 | Dimensional::from_fn([2, 2], |_| { 507 | next += 1; 508 | next 509 | }); 510 | assert_eq!(array.as_slice(), &[1, 2, 3, 4]); 511 | } 512 | } 513 | -------------------------------------------------------------------------------- /src/display.rs: -------------------------------------------------------------------------------- 1 | //! This module implements the Display trait for the Dimensional struct, 2 | //! allowing for pretty-printing of multidimensional arrays. 3 | 4 | use crate::{Dimensional, DimensionalStorage}; 5 | use num_traits::Num; 6 | use std::fmt; 7 | 8 | impl fmt::Display for Dimensional 9 | where 10 | T: Num + Copy + fmt::Display, 11 | S: DimensionalStorage, 12 | { 13 | /// Formats the Dimensional array for display. 14 | /// 15 | /// The format differs based on the number of dimensions: 16 | /// - 1D arrays are displayed as a single row 17 | /// - 2D arrays are displayed as a matrix 18 | /// - All other arrays (0D and 3D+) are displayed in a compact format 19 | /// 20 | /// # Arguments 21 | /// 22 | /// * `f` - A mutable reference to the Formatter 23 | /// 24 | /// # Returns 25 | /// 26 | /// A fmt::Result indicating whether the operation was successful 27 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 28 | match N { 29 | 1 => self.fmt_1d(f), 30 | 2 => self.fmt_2d(f), 31 | _ => self.fmt_nd(f), 32 | } 33 | } 34 | } 35 | 36 | impl Dimensional 37 | where 38 | T: Num + Copy + fmt::Display, 39 | S: DimensionalStorage, 40 | { 41 | /// Formats a 1D array for display.
42 | /// 43 | /// # Arguments 44 | /// 45 | /// * `f` - A mutable reference to the Formatter 46 | /// 47 | /// # Returns 48 | /// 49 | /// A fmt::Result indicating whether the operation was successful 50 | fn fmt_1d(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 51 | write!(f, "[")?; 52 | let mut iter = self.as_slice().iter().peekable(); 53 | while let Some(val) = iter.next() { 54 | // Check if a precision is specified in the formatter 55 | if let Some(precision) = f.precision() { 56 | write!(f, "{:.1$}", val, precision)?; 57 | } else { 58 | write!(f, "{}", val)?; 59 | } 60 | if iter.peek().is_some() { 61 | write!(f, ", ")?; 62 | } 63 | } 64 | write!(f, "]") 65 | } 66 | 67 | /// Formats a 2D array for display as a matrix. 68 | /// 69 | /// # Arguments 70 | /// 71 | /// * `f` - A mutable reference to the Formatter 72 | /// 73 | /// # Returns 74 | /// 75 | /// A fmt::Result indicating whether the operation was successful 76 | fn fmt_2d(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 77 | assert_eq!(N, 2, "fmt_2d should only be called for 2D arrays"); 78 | let shape = self.shape(); 79 | writeln!(f, "[")?; 80 | for i in 0..shape[0] { 81 | write!(f, " [")?; 82 | for j in 0..shape[1] { 83 | let mut index_array = [0; N]; 84 | index_array[0] = i; 85 | index_array[1] = j; 86 | let index = self.ravel_index(&index_array); 87 | // Check if a precision is specified in the formatter 88 | if let Some(precision) = f.precision() { 89 | write!(f, "{:.1$}", self.as_slice()[index], precision)?; 90 | } else { 91 | write!(f, "{}", self.as_slice()[index])?; 92 | } 93 | if j < shape[1] - 1 { 94 | write!(f, ", ")?; 95 | } 96 | } 97 | if i < shape[0] - 1 { 98 | writeln!(f, "],")?; 99 | } else { 100 | writeln!(f, "]")?; 101 | } 102 | } 103 | write!(f, "]") 104 | } 105 | 106 | /// Formats a higher dimensional array for display in a compact format. 
107 | /// 108 | /// # Arguments 109 | /// 110 | /// * `f` - A mutable reference to the Formatter 111 | /// 112 | /// # Returns 113 | /// 114 | /// A fmt::Result indicating whether the operation was successful 115 | fn fmt_nd(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 116 | write!(f, "{}D array: ", N)?; 117 | write!(f, "shape {:?}, ", self.shape())?; 118 | write!(f, "data [")?; 119 | 120 | let slice = self.as_slice(); 121 | let len = slice.len(); 122 | let display_count = if len > 6 { 3 } else { len }; 123 | 124 | for (i, val) in slice.iter().take(display_count).enumerate() { 125 | if i > 0 { 126 | write!(f, ", ")?; 127 | } 128 | if let Some(precision) = f.precision() { 129 | write!(f, "{:.1$}", val, precision)?; 130 | } else { 131 | write!(f, "{}", val)?; 132 | } 133 | } 134 | 135 | if len > 6 { 136 | write!(f, ", ..., ")?; 137 | if let Some(precision) = f.precision() { 138 | write!(f, "{:.1$}", slice.last().ok_or(fmt::Error)?, precision)?; 139 | } else { 140 | write!(f, "{}", slice.last().ok_or(fmt::Error)?)?; 141 | } 142 | } 143 | 144 | write!(f, "]") 145 | } 146 | } 147 | 148 | #[cfg(test)] 149 | mod tests { 150 | use super::*; 151 | use crate::LinearArrayStorage; 152 | 153 | #[test] 154 | fn test_display_1d() { 155 | let array: Dimensional, 1> = 156 | Dimensional::from_fn([5], |[i]| i as i32); 157 | assert_eq!(format!("{}", array), "[0, 1, 2, 3, 4]"); 158 | 159 | // Test empty 1D array 160 | let empty: Dimensional, 1> = Dimensional::zeros([0]); 161 | assert_eq!(format!("{}", empty), "[]"); 162 | 163 | // Test 1D array with single element 164 | let single: Dimensional, 1> = Dimensional::ones([1]); 165 | assert_eq!(format!("{}", single), "[1]"); 166 | 167 | // Test large 1D array 168 | let large: Dimensional, 1> = 169 | Dimensional::from_fn([10], |[i]| i as i32); 170 | assert_eq!(format!("{}", large), "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]"); 171 | } 172 | 173 | #[test] 174 | fn test_display_2d() { 175 | let array: Dimensional, 2> = 176 | 
Dimensional::from_fn([2, 3], |[i, j]| (i * 3 + j) as i32); 177 | assert_eq!(format!("{}", array), "[\n [0, 1, 2],\n [3, 4, 5]\n]"); 178 | 179 | // Test 2D array with single row 180 | let single_row: Dimensional, 2> = 181 | Dimensional::from_fn([1, 3], |[_, j]| j as i32); 182 | assert_eq!(format!("{}", single_row), "[\n [0, 1, 2]\n]"); 183 | 184 | // Test 2D array with single column 185 | let single_column: Dimensional, 2> = 186 | Dimensional::from_fn([3, 1], |[i, _]| i as i32); 187 | assert_eq!(format!("{}", single_column), "[\n [0],\n [1],\n [2]\n]"); 188 | 189 | // Test empty 2D array 190 | let empty: Dimensional, 2> = Dimensional::zeros([0, 0]); 191 | assert_eq!(format!("{}", empty), "[\n]"); 192 | } 193 | 194 | #[test] 195 | fn test_display_3d() { 196 | let array: Dimensional, 3> = 197 | Dimensional::from_fn([2, 2, 2], |[i, j, k]| (i * 4 + j * 2 + k) as i32); 198 | assert_eq!( 199 | format!("{}", array), 200 | "3D array: shape [2, 2, 2], data [0, 1, 2, ..., 7]" 201 | ); 202 | 203 | // Test 3D array with small size 204 | let small: Dimensional, 3> = 205 | Dimensional::from_fn([1, 2, 3], |[i, j, k]| (i * 6 + j * 3 + k) as i32); 206 | assert_eq!( 207 | format!("{}", small), 208 | "3D array: shape [1, 2, 3], data [0, 1, 2, 3, 4, 5]" 209 | ); 210 | 211 | // Test empty 3D array 212 | let empty: Dimensional, 3> = Dimensional::zeros([0, 0, 0]); 213 | assert_eq!(format!("{}", empty), "3D array: shape [0, 0, 0], data []"); 214 | } 215 | 216 | #[test] 217 | fn test_display_float() { 218 | let array: Dimensional, 2> = 219 | Dimensional::from_fn([2, 2], |[i, j]| (i * 2 + j) as f64 + 0.5); 220 | assert_eq!(format!("{}", array), "[\n [0.5, 1.5],\n [2.5, 3.5]\n]"); 221 | 222 | // Test float precision for 1D array 223 | let precise_1d: Dimensional, 1> = 224 | Dimensional::from_fn([3], |[i]| i as f64 / 3.0); 225 | assert_eq!(format!("{:.2}", precise_1d), "[0.00, 0.33, 0.67]"); 226 | 227 | // Test float precision for 2D array 228 | let precise_2d: Dimensional, 2> = 229 | 
Dimensional::from_fn([2, 2], |[i, j]| (i + j) as f64 / 3.0); 230 | assert_eq!( 231 | format!("{:.2}", precise_2d), 232 | "[\n [0.00, 0.33],\n [0.33, 0.67]\n]" 233 | ); 234 | 235 | // TODO(Fix this test case, maybe issue with ravel_index) 236 | // Test float precision for 3D array 237 | let precise_3d: Dimensional, 3> = 238 | Dimensional::from_fn([3, 3, 3], |[i, j, k]| (i + j + k) as f64 / 3.0); 239 | assert_eq!( 240 | format!("{:.2}", precise_3d), 241 | "3D array: shape [3, 3, 3], data [0.00, 0.33, 0.67, ..., 2.00]" 242 | ); 243 | } 244 | 245 | #[test] 246 | fn test_display_large_dimensions() { 247 | let array: Dimensional, 4> = 248 | Dimensional::from_fn([2, 2, 2, 2], |[i, j, k, l]| { 249 | (i * 8 + j * 4 + k * 2 + l) as i32 250 | }); 251 | assert_eq!( 252 | format!("{}", array), 253 | "4D array: shape [2, 2, 2, 2], data [0, 1, 2, ..., 15]" 254 | ); 255 | 256 | // Test 5D array 257 | let array_5d: Dimensional, 5> = 258 | Dimensional::from_fn([2, 2, 2, 2, 2], |[i, j, k, l, m]| { 259 | (i * 16 + j * 8 + k * 4 + l * 2 + m) as i32 260 | }); 261 | assert_eq!( 262 | format!("{}", array_5d), 263 | "5D array: shape [2, 2, 2, 2, 2], data [0, 1, 2, ..., 31]" 264 | ); 265 | } 266 | 267 | #[test] 268 | fn test_display_consistency() { 269 | let array: Dimensional, 2> = 270 | Dimensional::from_fn([3, 4], |[i, j]| (i * 4 + j) as i32); 271 | 272 | let display_output = format!("{}", array); 273 | let expected_output = "[\n [0, 1, 2, 3],\n [4, 5, 6, 7],\n [8, 9, 10, 11]\n]"; 274 | assert_eq!(display_output, expected_output, "Display output mismatch"); 275 | } 276 | } 277 | -------------------------------------------------------------------------------- /src/iterators.rs: -------------------------------------------------------------------------------- 1 | //! This module provides iterator implementations for the Dimensional struct. 2 | //! It includes both immutable and mutable iterators, allowing for efficient 3 | //! traversal and modification of Dimensional arrays. 
4 | 5 | use crate::{storage::DimensionalStorage, Dimensional}; 6 | use num_traits::Num; 7 | 8 | // TODO: Parallel iterators 9 | 10 | /// An iterator over the elements of a Dimensional array. 11 | /// 12 | /// This struct is created by the `iter` method on Dimensional. It provides 13 | /// a way to iterate over the elements of the array in row-major order. 14 | pub struct DimensionalIter<'a, T, S, const N: usize> 15 | where 16 | T: Num + Copy, 17 | S: DimensionalStorage, 18 | { 19 | dimensional: &'a Dimensional, 20 | current_index: [usize; N], 21 | remaining: usize, 22 | } 23 | 24 | impl<'a, T, S, const N: usize> Iterator for DimensionalIter<'a, T, S, N> 25 | where 26 | T: Num + Copy, 27 | S: DimensionalStorage, 28 | { 29 | type Item = &'a T; 30 | 31 | fn next(&mut self) -> Option { 32 | if self.remaining == 0 { 33 | return None; 34 | } 35 | 36 | let result = &self.dimensional[self.current_index]; 37 | 38 | // TODO: Actually iterate correctly here over an `N`-dimensional array 39 | // with `N` axes each with a possibly different length. 40 | // and determine iteration pattern 41 | 42 | // Update the index for the next iteration 43 | for i in (0..N).rev() { 44 | self.current_index[i] += 1; 45 | if self.current_index[i] < self.dimensional.shape()[i] { 46 | break; 47 | } 48 | self.current_index[i] = 0; 49 | } 50 | 51 | self.remaining -= 1; 52 | Some(result) 53 | } 54 | 55 | fn size_hint(&self) -> (usize, Option) { 56 | (self.remaining, Some(self.remaining)) 57 | } 58 | } 59 | 60 | impl<'a, T, S, const N: usize> ExactSizeIterator for DimensionalIter<'a, T, S, N> 61 | where 62 | T: Num + Copy, 63 | S: DimensionalStorage, 64 | { 65 | } 66 | 67 | /// A mutable iterator over the elements of a Dimensional array. 68 | /// 69 | /// This struct is created by the `iter_mut` method on Dimensional. It provides 70 | /// a way to iterate over and modify the elements of the array in row-major order. 
71 | pub struct DimensionalIterMut<'a, T, S, const N: usize> 72 | where 73 | T: Num + Copy, 74 | S: DimensionalStorage, 75 | { 76 | dimensional: &'a mut Dimensional, 77 | current_index: [usize; N], 78 | remaining: usize, 79 | } 80 | 81 | impl<'a, T, S, const N: usize> Iterator for DimensionalIterMut<'a, T, S, N> 82 | where 83 | T: Num + Copy, 84 | S: DimensionalStorage, 85 | { 86 | type Item = &'a mut T; 87 | 88 | fn next(&mut self) -> Option { 89 | if self.remaining == 0 { 90 | return None; 91 | } 92 | 93 | let index = self.current_index; 94 | 95 | self.remaining -= 1; 96 | 97 | // Update the index for the next iteration 98 | for i in (0..N).rev() { 99 | if self.current_index[i] < self.dimensional.shape()[i] - 1 { 100 | self.current_index[i] += 1; 101 | break; 102 | } else { 103 | self.current_index[i] = 0; 104 | } 105 | } 106 | 107 | let linear_index = self.dimensional.ravel_index(&index); 108 | // TODO: We really don't want to use unsafe rust here 109 | // SAFETY: This is safe because we're returning a unique reference to each element, 110 | // and we're iterating over each element only once. 111 | // But what if we modify the array while iterating? 112 | // What if the array is deleted while iterating? 113 | // What if we want to use parallel iterators? 114 | unsafe { Some(&mut *(&mut self.dimensional.as_mut_slice()[linear_index] as *mut T)) } 115 | } 116 | 117 | fn size_hint(&self) -> (usize, Option) { 118 | (self.remaining, Some(self.remaining)) 119 | } 120 | } 121 | 122 | impl Dimensional 123 | where 124 | T: Num + Copy, 125 | S: DimensionalStorage, 126 | { 127 | /// Returns an iterator over the elements of the array. 128 | /// 129 | /// The iterator yields all items from the array in row-major order. 
130 | /// 131 | /// # Examples 132 | /// 133 | /// ``` 134 | /// use dimensionals::{Dimensional, LinearArrayStorage, vector, matrix}; 135 | /// 136 | /// let v = vector![1, 2, 3, 4, 5]; 137 | /// let mut iter = v.iter(); 138 | /// assert_eq!(iter.next(), Some(&1)); 139 | /// assert_eq!(iter.next(), Some(&2)); 140 | /// // ... 141 | /// 142 | /// let m = matrix![[1, 2], [3, 4]]; 143 | /// let mut iter = m.iter(); 144 | /// assert_eq!(iter.next(), Some(&1)); 145 | /// assert_eq!(iter.next(), Some(&2)); 146 | /// assert_eq!(iter.next(), Some(&3)); 147 | /// assert_eq!(iter.next(), Some(&4)); 148 | /// assert_eq!(iter.next(), None); 149 | /// ``` 150 | pub fn iter(&self) -> DimensionalIter { 151 | DimensionalIter { 152 | dimensional: self, 153 | current_index: [0; N], 154 | remaining: self.len(), 155 | } 156 | } 157 | 158 | /// Returns a mutable iterator over the elements of the array. 159 | /// 160 | /// The iterator yields all items from the array in row-major order, 161 | /// and allows modifying each value. 162 | /// 163 | /// # Examples 164 | /// 165 | /// ``` 166 | /// use dimensionals::{Dimensional, LinearArrayStorage, vector, matrix}; 167 | /// 168 | /// let mut v = vector![1, 2, 3, 4, 5]; 169 | /// for elem in v.iter_mut() { 170 | /// *elem *= 2; 171 | /// } 172 | /// assert_eq!(v, vector![2, 4, 6, 8, 10]); 173 | /// 174 | /// let mut m = matrix![[1, 2], [3, 4]]; 175 | /// for elem in m.iter_mut() { 176 | /// *elem += 1; 177 | /// } 178 | /// assert_eq!(m, matrix![[2, 3], [4, 5]]); 179 | /// ``` 180 | pub fn iter_mut(&mut self) -> DimensionalIterMut { 181 | let len = self.len(); 182 | DimensionalIterMut { 183 | dimensional: self, 184 | current_index: [0; N], 185 | remaining: len, 186 | } 187 | } 188 | } 189 | 190 | // TODO: Since these are consuming, do they really need a lifetime? 
191 | 192 | impl<'a, T, S, const N: usize> IntoIterator for &'a Dimensional 193 | where 194 | T: Num + Copy, 195 | S: DimensionalStorage, 196 | { 197 | type Item = &'a T; 198 | type IntoIter = DimensionalIter<'a, T, S, N>; 199 | 200 | fn into_iter(self) -> Self::IntoIter { 201 | self.iter() 202 | } 203 | } 204 | 205 | impl<'a, T, S, const N: usize> IntoIterator for &'a mut Dimensional 206 | where 207 | T: Num + Copy, 208 | S: DimensionalStorage, 209 | { 210 | type Item = &'a mut T; 211 | type IntoIter = DimensionalIterMut<'a, T, S, N>; 212 | 213 | fn into_iter(self) -> Self::IntoIter { 214 | self.iter_mut() 215 | } 216 | } 217 | 218 | #[cfg(test)] 219 | mod tests { 220 | use crate::{matrix, storage::LinearArrayStorage, Dimensional}; 221 | 222 | #[test] 223 | fn test_iter_mut_borrow() { 224 | let mut m = matrix![[1, 2], [3, 4]]; 225 | let mut iter = m.iter_mut(); 226 | assert_eq!(iter.next(), Some(&mut 1)); 227 | assert_eq!(iter.next(), Some(&mut 2)); 228 | assert_eq!(iter.next(), Some(&mut 3)); 229 | assert_eq!(iter.next(), Some(&mut 4)); 230 | assert_eq!(iter.next(), None); 231 | } 232 | 233 | #[test] 234 | fn test_iter_next() { 235 | let array_1d: Dimensional, 1> = Dimensional::zeros([5]); 236 | let mut iter = array_1d.iter(); 237 | assert_eq!(iter.next(), Some(&0)); 238 | assert_eq!(iter.next(), Some(&0)); 239 | assert_eq!(iter.next(), Some(&0)); 240 | assert_eq!(iter.next(), Some(&0)); 241 | assert_eq!(iter.next(), Some(&0)); 242 | assert_eq!(iter.next(), None); 243 | } 244 | 245 | #[test] 246 | fn test_iter_next_matrix() { 247 | let array_2d: Dimensional, 2> = Dimensional::zeros([2, 3]); 248 | let mut iter = array_2d.iter(); 249 | assert_eq!(iter.next(), Some(&0)); 250 | assert_eq!(iter.next(), Some(&0)); 251 | assert_eq!(iter.next(), Some(&0)); 252 | assert_eq!(iter.next(), Some(&0)); 253 | assert_eq!(iter.next(), Some(&0)); 254 | assert_eq!(iter.next(), Some(&0)); 255 | assert_eq!(iter.next(), None); 256 | } 257 | 258 | #[test] 259 | fn test_iter_mut_next() 
{ 260 | let mut array_1d: Dimensional, 1> = Dimensional::zeros([5]); 261 | let mut iter = array_1d.iter_mut(); 262 | if let Some(elem) = iter.next() { 263 | *elem = 1; 264 | } 265 | if let Some(elem) = iter.next() { 266 | *elem = 2; 267 | } 268 | if let Some(elem) = iter.next() { 269 | *elem = 3; 270 | } 271 | if let Some(elem) = iter.next() { 272 | *elem = 4; 273 | } 274 | if let Some(elem) = iter.next() { 275 | *elem = 5; 276 | } 277 | 278 | let mut iter = array_1d.iter_mut(); 279 | assert_eq!(iter.next(), Some(&mut 1)); 280 | assert_eq!(iter.next(), Some(&mut 2)); 281 | assert_eq!(iter.next(), Some(&mut 3)); 282 | assert_eq!(iter.next(), Some(&mut 4)); 283 | assert_eq!(iter.next(), Some(&mut 5)); 284 | assert_eq!(iter.next(), None); 285 | } 286 | 287 | #[test] 288 | fn test_iter_mut_next_matrix() { 289 | let mut array_2d: Dimensional, 2> = 290 | Dimensional::zeros([2, 3]); 291 | let mut iter = array_2d.iter_mut(); 292 | if let Some(elem) = iter.next() { 293 | *elem = 1; 294 | } 295 | if let Some(elem) = iter.next() { 296 | *elem = 2; 297 | } 298 | if let Some(elem) = iter.next() { 299 | *elem = 3; 300 | } 301 | if let Some(elem) = iter.next() { 302 | *elem = 4; 303 | } 304 | if let Some(elem) = iter.next() { 305 | *elem = 5; 306 | } 307 | if let Some(elem) = iter.next() { 308 | *elem = 6; 309 | } 310 | 311 | let mut iter = array_2d.iter_mut(); 312 | assert_eq!(iter.next(), Some(&mut 1)); 313 | assert_eq!(iter.next(), Some(&mut 2)); 314 | assert_eq!(iter.next(), Some(&mut 3)); 315 | assert_eq!(iter.next(), Some(&mut 4)); 316 | assert_eq!(iter.next(), Some(&mut 5)); 317 | assert_eq!(iter.next(), Some(&mut 6)); 318 | assert_eq!(iter.next(), None); 319 | } 320 | 321 | #[test] 322 | fn test_iter_empty() { 323 | let array_empty: Dimensional, 1> = Dimensional::zeros([0]); 324 | let mut iter = array_empty.iter(); 325 | assert_eq!(iter.next(), None); 326 | } 327 | 328 | #[test] 329 | fn test_iter_mut_empty() { 330 | let mut array_empty: Dimensional, 1> = 331 | 
Dimensional::zeros([0]); 332 | let mut iter = array_empty.iter_mut(); 333 | assert_eq!(iter.next(), None); 334 | } 335 | 336 | #[test] 337 | fn test_iter_high_dimensional() { 338 | let array_3d: Dimensional, 3> = 339 | Dimensional::zeros([2, 3, 2]); 340 | let mut iter = array_3d.iter(); 341 | assert_eq!(iter.next(), Some(&0)); 342 | assert_eq!(iter.next(), Some(&0)); 343 | assert_eq!(iter.next(), Some(&0)); 344 | assert_eq!(iter.next(), Some(&0)); 345 | assert_eq!(iter.next(), Some(&0)); 346 | assert_eq!(iter.next(), Some(&0)); 347 | assert_eq!(iter.next(), Some(&0)); 348 | assert_eq!(iter.next(), Some(&0)); 349 | assert_eq!(iter.next(), Some(&0)); 350 | assert_eq!(iter.next(), Some(&0)); 351 | assert_eq!(iter.next(), Some(&0)); 352 | assert_eq!(iter.next(), Some(&0)); 353 | assert_eq!(iter.next(), None); 354 | } 355 | 356 | #[test] 357 | fn test_iter_mut_high_dimensional() { 358 | let mut array_3d: Dimensional, 3> = 359 | Dimensional::zeros([2, 3, 2]); 360 | let mut iter = array_3d.iter_mut(); 361 | for i in 1..=12 { 362 | if let Some(elem) = iter.next() { 363 | *elem = i; 364 | } 365 | } 366 | 367 | let mut iter = array_3d.iter_mut(); 368 | for mut i in 1..=12 { 369 | assert_eq!(iter.next(), Some(&mut i)); 370 | } 371 | assert_eq!(iter.next(), None); 372 | } 373 | } 374 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! The Dimensionals library provides a multidimensional array implementation 2 | //! with a generic storage backend over a generic number type. 3 | //! 4 | //! # Core Concepts 5 | //! 6 | //! - Element type `T`: The type of data stored in the array. 7 | //! - Storage backend `S`: The underlying storage mechanism for the array. 8 | //! - Number of dimensions `N`: The dimensionality of the array. 9 | //! 10 | //! # Dimensional Types 11 | //! 12 | //! 
- Scalar: A 0-dimensional object, or just the element of type `T` itself. 13 | //! - Vector: A 1-dimensional array of elements with the type `T`. 14 | //! - Matrix: A 2-dimensional array of elements with the type `T`. 15 | //! - Tensor: An `N`-dimensional array of elements with the type `T`, where `N` > 2. 16 | //! 17 | //! # Goals 18 | //! 19 | //! The primary goal of this library is to provide a flexible and efficient way to work with 20 | //! multidimensional arrays of numeric types in Rust. 21 | //! 22 | //! Using a generic storage backend, `S`, allows for different memory layouts and optimizations. 23 | //! 24 | //! # Convenience Macros 25 | //! 26 | //! The library provides convenience macros for creating arrays: 27 | //! 28 | //! - [`vector!`]: Creates a 1-dimensional array. 29 | //! - [`matrix!`]: Creates a 2-dimensional array. 30 | //! 31 | //! # Example 32 | //! 33 | //! ``` 34 | //! use dimensionals::{matrix, vector, Dimensional, LinearArrayStorage}; 35 | //! 36 | //! // Create a vector 37 | //! let v: Dimensional, 1> = vector![1, 2, 3, 4, 5]; 38 | //! assert_eq!(v[[0]], 1); 39 | //! 40 | //! // Create a matrix 41 | //! let m: Dimensional, 2> = matrix![ 42 | //! [1.0, 2.0, 3.0], 43 | //! [4.0, 5.0, 6.0] 44 | //! ]; 45 | //! assert_eq!(m[[0, 0]], 1.0); 46 | //! assert_eq!(m[[1, 1]], 5.0); 47 | //! ``` 48 | mod core; 49 | mod display; 50 | mod iterators; 51 | mod operators; 52 | mod storage; 53 | 54 | // Public API 55 | pub use crate::core::Dimensional; 56 | pub use iterators::*; 57 | pub use storage::DimensionalStorage; 58 | pub use storage::LinearArrayStorage; 59 | 60 | /// Creates a 1-dimensional array (vector). 61 | /// 62 | /// # Examples 63 | /// 64 | /// ``` 65 | /// use dimensionals::{vector, Dimensional, LinearArrayStorage}; 66 | /// 67 | /// let v: Dimensional, 1> = vector![1, 2, 3, 4, 5]; 68 | /// assert_eq!(v[[0]], 1); 69 | /// assert_eq!(v[[4]], 5); 70 | /// ``` 71 | #[macro_export] 72 | macro_rules! vector { 73 | ($($value:expr),+ $(,)?) 
=> { 74 | { 75 | let data = vec![$($value),+]; 76 | let shape = [data.len()]; 77 | Dimensional::<_, LinearArrayStorage<_, 1>, 1>::from_fn(shape, |[i]| data[i]) 78 | } 79 | }; 80 | } 81 | 82 | /// Creates a 2-dimensional array (matrix). 83 | /// 84 | /// # Examples 85 | /// 86 | /// ``` 87 | /// use dimensionals::{matrix, Dimensional, LinearArrayStorage}; 88 | /// 89 | /// let m: Dimensional, 2> = matrix![ 90 | /// [1, 2, 3], 91 | /// [4, 5, 6] 92 | /// ]; 93 | /// assert_eq!(m[[0, 0]], 1); 94 | /// assert_eq!(m[[1, 2]], 6); 95 | /// ``` 96 | #[macro_export] 97 | macro_rules! matrix { 98 | ($([$($value:expr),* $(,)?]),+ $(,)?) => { 99 | { 100 | let data: Vec> = vec![$(vec![$($value),*]),+]; 101 | let rows = data.len(); 102 | let cols = data[0].len(); 103 | let shape = [rows, cols]; 104 | Dimensional::<_, LinearArrayStorage<_, 2>, 2>::from_fn(shape, |[i, j]| data[i][j]) 105 | } 106 | }; 107 | } 108 | 109 | // TODO: Implement a generic tensor macro 110 | // The tensor macro should create an N-dimensional array (N > 2) with the following features: 111 | // - Infer the number of dimensions and shape from the input 112 | // - Work with any number of dimensions (3 or more) 113 | // - Be as user-friendly as the vector! and matrix! 
macros 114 | // - Handle type inference correctly 115 | // - Integrate seamlessly with the Dimensional struct and LinearArrayStorage 116 | 117 | #[cfg(test)] 118 | mod tests { 119 | use super::*; 120 | use crate::{matrix, vector}; 121 | 122 | #[test] 123 | fn test_vector_creation() { 124 | let v: Dimensional, 1> = vector![1, 2, 3, 4, 5]; 125 | assert_eq!(v.shape(), [5]); 126 | assert_eq!(v[[0]], 1); 127 | assert_eq!(v[[4]], 5); 128 | } 129 | 130 | #[test] 131 | fn test_vector_indexing() { 132 | let v = vector![10, 20, 30, 40, 50]; 133 | assert_eq!(v[[0]], 10); 134 | assert_eq!(v[[2]], 30); 135 | assert_eq!(v[[4]], 50); 136 | } 137 | 138 | #[test] 139 | fn test_vector_iteration() { 140 | let v = vector![1, 2, 3, 4, 5]; 141 | let sum: i32 = v.iter().sum(); 142 | assert_eq!(sum, 15); 143 | } 144 | 145 | #[test] 146 | fn test_matrix_creation() { 147 | let m: Dimensional, 2> = matrix![[1, 2, 3], [4, 5, 6]]; 148 | assert_eq!(m.shape(), [2, 3]); 149 | assert_eq!(m[[0, 0]], 1); 150 | assert_eq!(m[[1, 2]], 6); 151 | } 152 | 153 | #[test] 154 | fn test_matrix_indexing() { 155 | let m = matrix![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; 156 | assert_eq!(m[[0, 0]], 1); 157 | assert_eq!(m[[1, 1]], 5); 158 | assert_eq!(m[[2, 2]], 9); 159 | } 160 | 161 | #[test] 162 | fn test_matrix_iteration() { 163 | let m = matrix![[1, 2], [3, 4]]; 164 | let sum: i32 = m.iter().sum(); 165 | assert_eq!(sum, 10); 166 | } 167 | 168 | #[test] 169 | fn test_dimensional_properties() { 170 | let v = vector![1, 2, 3, 4, 5]; 171 | assert_eq!(v.ndim(), 1); 172 | assert_eq!(v.len(), 5); 173 | assert_eq!(v.len_axis(0), 5); 174 | 175 | let m = matrix![[1, 2, 3], [4, 5, 6]]; 176 | assert_eq!(m.ndim(), 2); 177 | assert_eq!(m.len(), 6); 178 | assert_eq!(m.len_axis(0), 2); 179 | assert_eq!(m.len_axis(1), 3); 180 | } 181 | 182 | #[test] 183 | fn test_dimensional_from_fn() { 184 | let v = Dimensional::<_, LinearArrayStorage<_, 1>, 1>::from_fn([5], |[i]| i * 2); 185 | assert_eq!(v[[0]], 0); 186 | assert_eq!(v[[2]], 4); 
187 | assert_eq!(v[[4]], 8); 188 | 189 | let m = Dimensional::<_, LinearArrayStorage<_, 2>, 2>::from_fn([3, 3], |[i, j]| i + j); 190 | assert_eq!(m[[0, 0]], 0); 191 | assert_eq!(m[[1, 1]], 2); 192 | assert_eq!(m[[2, 2]], 4); 193 | } 194 | 195 | #[test] 196 | fn test_dimensional_zeros_and_ones() { 197 | let v_zeros = Dimensional::, 1>::zeros([5]); 198 | assert_eq!(v_zeros.iter().sum::(), 0); 199 | 200 | let v_ones = Dimensional::, 1>::ones([5]); 201 | assert_eq!(v_ones.iter().sum::(), 5); 202 | 203 | let m_zeros = Dimensional::, 2>::zeros([3, 3]); 204 | assert_eq!(m_zeros.iter().sum::(), 0); 205 | 206 | let m_ones = Dimensional::, 2>::ones([3, 3]); 207 | assert_eq!(m_ones.iter().sum::(), 9); 208 | } 209 | } 210 | -------------------------------------------------------------------------------- /src/operators.rs: -------------------------------------------------------------------------------- 1 | use crate::{storage::DimensionalStorage, Dimensional}; 2 | use num_traits::Num; 3 | use std::ops::{ 4 | Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Neg, Sub, SubAssign, 5 | }; 6 | 7 | /// Implements indexing operations for Dimensional arrays. 8 | 9 | impl Index<[usize; N]> for Dimensional 10 | where 11 | S: DimensionalStorage, 12 | { 13 | type Output = T; 14 | 15 | /// Returns an index into the array using a multidimensional index à la [i, j, k]. 16 | fn index(&self, index: [usize; N]) -> &Self::Output { 17 | // TODO(This is too tightly coupled to the storage layout) 18 | &self.storage[index] 19 | } 20 | } 21 | 22 | /// Implements mutable indexing operations for Dimensional arrays. 23 | impl IndexMut<[usize; N]> for Dimensional 24 | where 25 | S: DimensionalStorage, 26 | { 27 | /// Returns a mutable index into the array using a multidimensional index à la [i, j, k]. 
28 | fn index_mut(&mut self, index: [usize; N]) -> &mut Self::Output { 29 | // TODO(This is too tightly coupled to the storage layout) 30 | &mut self.storage[index] 31 | } 32 | } 33 | 34 | /// Implements partial equality comparison for Dimensional arrays. 35 | impl PartialEq for Dimensional 36 | where 37 | S: DimensionalStorage, 38 | { 39 | /// Compares two `Dimensional` arrays for partial equality. 40 | fn eq(&self, other: &Self) -> bool { 41 | if self.shape != other.shape { 42 | return false; 43 | } 44 | 45 | // TODO(Benchmark copying these to slice vs not) 46 | self.as_slice() == other.as_slice() 47 | } 48 | } 49 | 50 | // TODO(These operators in general need an audit and correction) 51 | // for varying operations between scalars, vectors, matrices and tensors. 52 | // 53 | // Some will fit neatly into the rust operators 54 | // Some will need to be implemented as methods 55 | // 56 | // Scalars: 57 | // Addition: a + b 58 | // Subtraction: a - b 59 | // Multiplication: a * b 60 | // Division: a / b 61 | // Vectors: 62 | // Addition: (a + b)_i = a_i + b_i 63 | // Subtraction: (a - b)_i = a_i - b_i 64 | // Scalar Multiplication: (ca)_i = c * a_i 65 | // Dot Product: a · b = Σ a_i * b_i 66 | // Cross Product (3D): (a × b)_1 = a_2b_3 - a_3b_2, (a × b)_2 = a_3b_1 - a_1b_3, (a × b)_3 = a_1b_2 - a_2b_1 67 | // Tensor Product: (a ⊗ b)_ij = a_i * b_j 68 | // Matrices: 69 | // Addition: (A + B)_ij = A_ij + B_ij 70 | // Subtraction: (A - B)_ij = A_ij - B_ij 71 | // Scalar Multiplication: (cA)_ij = c * A_ij 72 | // Matrix Multiplication: (AB)_ij = Σ A_ik * B_kj 73 | // Transpose: (A^T)_ij = A_ji 74 | // Inverse: AA^(-1) = A^(-1)A = I 75 | // Trace: tr(A) = Σ A_ii 76 | // Determinant: det(A) 77 | // Tensors: 78 | // Addition: (A + B)_i1i2...in = A_i1i2...in + B_i1i2...in 79 | // Subtraction: (A - B)_i1i2...in = A_i1i2...in - B_i1i2...in 80 | // Scalar Multiplication: (cA)_i1i2...in = c * A_i1i2...in 81 | // Hadamard Product: (A ⊙ B)_i1i2...in = A_i1i2...in * B_i1i2...in 82 
| // Tensor Product: (A ⊗ B)_i1...in,j1...jm = A_i1...in * B_j1...jm 83 | // Transpose: (A^T)_i1...ik...in = A_i1...in...ik 84 | // Contraction: Σ A_i1...ik...in 85 | // There is also broadcasting between different shapes to consider. 86 | 87 | // Scalar arithmetic operations 88 | 89 | /// Implements scalar addition for Dimensional arrays. 90 | impl Add for &Dimensional 91 | where 92 | S: DimensionalStorage, 93 | { 94 | type Output = Dimensional; 95 | 96 | /// Adds a scalar to a `Dimensional` array element-wise. 97 | fn add(self, rhs: T) -> Self::Output { 98 | self.map(|x| x + rhs) 99 | } 100 | } 101 | 102 | /// Implements scalar subtraction for Dimensional arrays. 103 | impl Sub for &Dimensional 104 | where 105 | S: DimensionalStorage, 106 | { 107 | type Output = Dimensional; 108 | 109 | /// Subtracts a scalar from a `Dimensional` array element-wise. 110 | fn sub(self, rhs: T) -> Self::Output { 111 | self.map(|x| x - rhs) 112 | } 113 | } 114 | 115 | /// Implements scalar multiplication for Dimensional arrays. 116 | impl Mul for &Dimensional 117 | where 118 | S: DimensionalStorage, 119 | { 120 | type Output = Dimensional; 121 | 122 | /// Multiplies a `Dimensional` array by a scalar element-wise. 123 | fn mul(self, rhs: T) -> Self::Output { 124 | self.map(|x| x * rhs) 125 | } 126 | } 127 | 128 | /// Implements scalar division for Dimensional arrays. 129 | impl Div for &Dimensional 130 | where 131 | S: DimensionalStorage, 132 | { 133 | type Output = Dimensional; 134 | 135 | // Divides a `Dimensional` array by a scalar element-wise. 136 | fn div(self, rhs: T) -> Self::Output { 137 | self.map(|x| x / rhs) 138 | } 139 | } 140 | 141 | // Element-wise operations 142 | 143 | /// Implements element-wise addition for Dimensional arrays. 144 | impl Add for &Dimensional 145 | where 146 | S: DimensionalStorage, 147 | { 148 | type Output = Dimensional; 149 | 150 | /// Adds two `Dimensional` arrays element-wise. 
151 | fn add(self, rhs: Self) -> Self::Output { 152 | assert_eq!( 153 | self.shape(), 154 | rhs.shape(), 155 | "Shapes must match for element-wise addition" 156 | ); 157 | self.zip_map(rhs, |a, b| a + b) 158 | } 159 | } 160 | 161 | /// Implements element-wise subtraction for Dimensional arrays. 162 | impl Sub for &Dimensional 163 | where 164 | S: DimensionalStorage, 165 | { 166 | type Output = Dimensional; 167 | 168 | /// Subtracts one `Dimensional` array from another element-wise. 169 | fn sub(self, rhs: Self) -> Self::Output { 170 | assert_eq!( 171 | self.shape(), 172 | rhs.shape(), 173 | "Shapes must match for element-wise subtraction" 174 | ); 175 | self.zip_map(rhs, |a, b| a - b) 176 | } 177 | } 178 | 179 | /// Implements element-wise multiplication for Dimensional arrays. 180 | impl Mul for &Dimensional 181 | where 182 | S: DimensionalStorage, 183 | { 184 | type Output = Dimensional; 185 | 186 | /// Multiplies two `Dimensional` arrays element-wise. 187 | fn mul(self, rhs: Self) -> Self::Output { 188 | assert_eq!( 189 | self.shape(), 190 | rhs.shape(), 191 | "Shapes must match for element-wise multiplication" 192 | ); 193 | self.zip_map(rhs, |a, b| a * b) 194 | } 195 | } 196 | 197 | /// Implements element-wise division for Dimensional arrays. 198 | impl Div for &Dimensional 199 | where 200 | S: DimensionalStorage, 201 | { 202 | type Output = Dimensional; 203 | 204 | /// Divides one `Dimensional` array by another element-wise. 205 | fn div(self, rhs: Self) -> Self::Output { 206 | assert_eq!( 207 | self.shape(), 208 | rhs.shape(), 209 | "Shapes must match for element-wise division" 210 | ); 211 | self.zip_map(rhs, |a, b| a / b) 212 | } 213 | } 214 | 215 | // Matrix operations 216 | 217 | /// Implements matrix multiplication for Dimensional arrays. 218 | impl Dimensional 219 | where 220 | S: DimensionalStorage, 221 | { 222 | /// Multiplies two matrices. 
223 | pub fn dot(&self, rhs: &Self) -> Self { 224 | assert_eq!( 225 | self.shape()[1], 226 | rhs.shape()[0], 227 | "Matrix dimensions must match for multiplication" 228 | ); 229 | let (rows, cols) = (self.shape()[0], rhs.shape()[1]); 230 | 231 | Self::from_fn([rows, cols], |[i, j]| { 232 | (0..self.shape()[1]).fold(T::zero(), |sum, k| sum + self[[i, k]] * rhs[[k, j]]) 233 | }) 234 | } 235 | } 236 | 237 | // TODO Find a zero copy way to do transpose. 238 | /// Implements matrix transpose for Dimensional arrays. 239 | impl Dimensional 240 | where 241 | S: DimensionalStorage, 242 | { 243 | /// Transposes a matrix. 244 | pub fn transpose(&self) -> Self { 245 | let (rows, cols) = (self.shape()[1], self.shape()[0]); 246 | Self::from_fn([rows, cols], |[i, j]| self[[j, i]]) 247 | } 248 | } 249 | 250 | /// Implements matrix trace for Dimensional arrays. 251 | impl Dimensional 252 | where 253 | S: DimensionalStorage, 254 | { 255 | /// Computes the trace of a matrix. 256 | pub fn trace(&self) -> T { 257 | assert_eq!( 258 | self.shape()[0], 259 | self.shape()[1], 260 | "Matrix must be square to compute trace" 261 | ); 262 | (0..self.shape()[0]).fold(T::zero(), |sum, i| sum + self[[i, i]]) 263 | } 264 | } 265 | 266 | // Assignment operations 267 | 268 | /// Implements scalar addition assignment for Dimensional arrays. 269 | impl AddAssign for Dimensional 270 | where 271 | S: DimensionalStorage, 272 | { 273 | /// Adds a scalar to a `Dimensional` array element-wise in-place. 274 | fn add_assign(&mut self, rhs: T) { 275 | self.map_inplace(|x| *x += rhs); 276 | } 277 | } 278 | 279 | /// Implements scalar subtraction assignment for Dimensional arrays. 280 | impl SubAssign for Dimensional 281 | where 282 | S: DimensionalStorage, 283 | { 284 | /// Subtracts a scalar from a `Dimensional` array element-wise in-place. 
285 | fn sub_assign(&mut self, rhs: T) { 286 | self.map_inplace(|x| *x -= rhs); 287 | } 288 | } 289 | 290 | /// Implements scalar multiplication assignment for Dimensional arrays. 291 | impl MulAssign for Dimensional 292 | where 293 | S: DimensionalStorage, 294 | { 295 | /// Multiplies a `Dimensional` array by a scalar element-wise in-place. 296 | fn mul_assign(&mut self, rhs: T) { 297 | self.map_inplace(|x| *x *= rhs); 298 | } 299 | } 300 | 301 | /// Implements scalar division assignment for Dimensional arrays. 302 | impl DivAssign for Dimensional 303 | where 304 | S: DimensionalStorage, 305 | { 306 | /// Divides a `Dimensional` array by a scalar element-wise in-place. 307 | fn div_assign(&mut self, rhs: T) { 308 | self.map_inplace(|x| *x /= rhs); 309 | } 310 | } 311 | 312 | /// Implements element-wise addition assignment for Dimensional arrays. 313 | impl AddAssign<&Dimensional> 314 | for Dimensional 315 | where 316 | S: DimensionalStorage, 317 | { 318 | /// Adds two `Dimensional` arrays element-wise in-place. 319 | fn add_assign(&mut self, rhs: &Dimensional) { 320 | assert_eq!( 321 | self.shape, rhs.shape, 322 | "Shapes must match for element-wise addition assignment" 323 | ); 324 | self.zip_map_inplace(rhs, |a, b| *a += b); 325 | } 326 | } 327 | 328 | /// Implements element-wise subtraction assignment for Dimensional arrays. 329 | impl SubAssign<&Dimensional> 330 | for Dimensional 331 | where 332 | S: DimensionalStorage, 333 | { 334 | /// Subtracts one `Dimensional` array from another element-wise in-place. 335 | fn sub_assign(&mut self, rhs: &Dimensional) { 336 | assert_eq!( 337 | self.shape, rhs.shape, 338 | "Shapes must match for element-wise subtraction assignment" 339 | ); 340 | self.zip_map_inplace(rhs, |a, b| *a -= b); 341 | } 342 | } 343 | 344 | /// Implements element-wise multiplication assignment for Dimensional arrays. 
345 | impl MulAssign<&Dimensional> 346 | for Dimensional 347 | where 348 | S: DimensionalStorage, 349 | { 350 | /// Multiplies two `Dimensional` arrays element-wise in-place. 351 | fn mul_assign(&mut self, rhs: &Dimensional) { 352 | assert_eq!( 353 | self.shape, rhs.shape, 354 | "Shapes must match for element-wise multiplication assignment" 355 | ); 356 | self.zip_map_inplace(rhs, |a, b| *a *= b); 357 | } 358 | } 359 | 360 | /// Implements element-wise division assignment for Dimensional arrays. 361 | impl DivAssign<&Dimensional> 362 | for Dimensional 363 | where 364 | S: DimensionalStorage, 365 | { 366 | /// Divides one `Dimensional` array by another element-wise in-place. 367 | fn div_assign(&mut self, rhs: &Dimensional) { 368 | assert_eq!( 369 | self.shape, rhs.shape, 370 | "Shapes must match for element-wise division assignment" 371 | ); 372 | self.zip_map_inplace(rhs, |a, b| *a /= b); 373 | } 374 | } 375 | 376 | // Implement unary negation for references 377 | impl, S, const N: usize> Neg for &Dimensional 378 | where 379 | S: DimensionalStorage, 380 | { 381 | type Output = Dimensional; 382 | 383 | /// Negates a `Dimensional` array element-wise. 384 | fn neg(self) -> Self::Output { 385 | self.map(|x| -x) 386 | } 387 | } 388 | 389 | // TODO How much are these helper abstractions really helping? 390 | // Seems like .zip .map etc should do it without these. 391 | // We don't want bloat, we want a razor sharp and performant tool. 392 | // We can likely create a map/zip/collect implementation or override 393 | // to make this better. 394 | 395 | impl Dimensional 396 | where 397 | T: Num + Copy, 398 | S: DimensionalStorage, 399 | { 400 | /// Applies a function to each element of the array, creating a new array. 401 | fn map(&self, f: F) -> Self 402 | where 403 | F: Fn(T) -> T, 404 | { 405 | Self::from_fn(self.shape, |idx| f(self[idx])) 406 | } 407 | 408 | /// Applies a function to each element of the array in-place. 
409 | fn map_inplace(&mut self, f: F) 410 | where 411 | F: Fn(&mut T), 412 | { 413 | for x in self.as_mut_slice() { 414 | f(x); 415 | } 416 | } 417 | 418 | /// Applies a function to pairs of elements from two arrays, creating a new array. 419 | fn zip_map(&self, other: &Self, f: F) -> Self 420 | where 421 | F: Fn(T, T) -> T, 422 | { 423 | assert_eq!( 424 | self.shape, other.shape, 425 | "Shapes must match for zip_map operation" 426 | ); 427 | Self::from_fn(self.shape, |idx| f(self[idx], other[idx])) 428 | } 429 | 430 | /// Applies a function to pairs of elements from two arrays in-place. 431 | fn zip_map_inplace(&mut self, other: &Self, f: F) 432 | where 433 | F: Fn(&mut T, T), 434 | { 435 | assert_eq!( 436 | self.shape, other.shape, 437 | "Shapes must match for zip_map_inplace operation" 438 | ); 439 | for (a, &b) in self.as_mut_slice().iter_mut().zip(other.as_slice().iter()) { 440 | f(a, b); 441 | } 442 | } 443 | } 444 | 445 | #[cfg(test)] 446 | mod tests { 447 | use super::*; 448 | use crate::{matrix, vector, LinearArrayStorage}; 449 | 450 | #[test] 451 | fn test_scalar_operations() { 452 | let v = vector![1, 2, 3, 4, 5]; 453 | 454 | assert_eq!(&v + 1, vector![2, 3, 4, 5, 6]); 455 | assert_eq!(&v - 1, vector![0, 1, 2, 3, 4]); 456 | assert_eq!(&v * 2, vector![2, 4, 6, 8, 10]); 457 | assert_eq!(&v / 2, vector![0, 1, 1, 2, 2]); // Integer division 458 | } 459 | 460 | #[test] 461 | fn test_element_wise_operations() { 462 | let v1 = vector![1, 2, 3, 4, 5]; 463 | let v2 = vector![5, 4, 3, 2, 1]; 464 | 465 | assert_eq!(&v1 + &v2, vector![6, 6, 6, 6, 6]); 466 | assert_eq!(&v1 - &v2, vector![-4, -2, 0, 2, 4]); 467 | assert_eq!(&v1 * &v2, vector![5, 8, 9, 8, 5]); 468 | assert_eq!(&v1 / &v2, vector![0, 0, 1, 2, 5]); // Integer division 469 | } 470 | 471 | #[test] 472 | fn test_assignment_operations() { 473 | let mut v = vector![1, 2, 3, 4, 5]; 474 | 475 | v += 1; 476 | assert_eq!(v, vector![2, 3, 4, 5, 6]); 477 | 478 | v -= 1; 479 | assert_eq!(v, vector![1, 2, 3, 4, 5]); 
480 | 481 | v *= 2; 482 | assert_eq!(v, vector![2, 4, 6, 8, 10]); 483 | 484 | v /= 2; 485 | assert_eq!(v, vector![1, 2, 3, 4, 5]); 486 | } 487 | 488 | #[test] 489 | fn test_element_wise_assignment_operations() { 490 | let mut v1 = vector![1, 2, 3, 4, 5]; 491 | let v2 = vector![5, 4, 3, 2, 1]; 492 | 493 | v1 += &v2; 494 | assert_eq!(v1, vector![6, 6, 6, 6, 6]); 495 | 496 | v1 -= &v2; 497 | assert_eq!(v1, vector![1, 2, 3, 4, 5]); 498 | 499 | v1 *= &v2; 500 | assert_eq!(v1, vector![5, 8, 9, 8, 5]); 501 | 502 | v1 /= &v2; 503 | assert_eq!(v1, vector![1, 2, 3, 4, 5]); 504 | } 505 | 506 | #[test] 507 | fn test_negation() { 508 | let v = vector![1, -2, 3, -4, 5]; 509 | assert_eq!(-&v, vector![-1, 2, -3, 4, -5]); 510 | } 511 | 512 | #[test] 513 | fn test_matrix_operations() { 514 | let m1 = matrix![[1, 2], [3, 4]]; 515 | let m2 = matrix![[5, 6], [7, 8]]; 516 | 517 | assert_eq!(&m1 + &m2, matrix![[6, 8], [10, 12]]); 518 | assert_eq!(&m1 - &m2, matrix![[-4, -4], [-4, -4]]); 519 | assert_eq!(&m1 * &m2, matrix![[5, 12], [21, 32]]); 520 | assert_eq!(&m1 / &m2, matrix![[0, 0], [0, 0]]); // Integer division 521 | 522 | let mut m3 = m1.clone(); 523 | m3 += 1; 524 | assert_eq!(m3, matrix![[2, 3], [4, 5]]); 525 | 526 | m3 -= 1; 527 | assert_eq!(m3, m1); 528 | 529 | m3 *= 2; 530 | assert_eq!(m3, matrix![[2, 4], [6, 8]]); 531 | 532 | m3 /= 2; 533 | assert_eq!(m3, m1); 534 | 535 | m3 += &m2; 536 | assert_eq!(m3, matrix![[6, 8], [10, 12]]); 537 | 538 | m3 -= &m2; 539 | assert_eq!(m3, m1); 540 | 541 | m3 *= &m2; 542 | assert_eq!(m3, matrix![[5, 12], [21, 32]]); 543 | 544 | // Note: We don't test m3 /= m2 here because it would result in a matrix of zeros due to integer division 545 | } 546 | 547 | #[test] 548 | fn test_mixed_dimensional_operations() { 549 | let v = vector![1, 2, 3]; 550 | let m = matrix![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; 551 | 552 | assert_eq!(&v + 1, vector![2, 3, 4]); 553 | assert_eq!(&m + 1, matrix![[2, 3, 4], [5, 6, 7], [8, 9, 10]]); 554 | 555 | assert_eq!(&v * 2, 
vector![2, 4, 6]); 556 | assert_eq!(&m * 2, matrix![[2, 4, 6], [8, 10, 12], [14, 16, 18]]); 557 | } 558 | 559 | #[test] 560 | #[should_panic(expected = "Shapes must match for element-wise addition")] 561 | fn test_mismatched_shapes_addition() { 562 | let v1 = vector![1, 2, 3]; 563 | let v2 = vector![1, 2, 3, 4]; 564 | let _ = &v1 + &v2; 565 | } 566 | 567 | #[test] 568 | #[should_panic(expected = "Shapes must match for element-wise multiplication")] 569 | fn test_mismatched_shapes_multiplication() { 570 | let m1 = matrix![[1, 2], [3, 4]]; 571 | let m2 = matrix![[1, 2, 3], [4, 5, 6]]; 572 | let _ = &m1 * &m2; 573 | } 574 | 575 | #[test] 576 | fn test_scalar_operations_with_floats() { 577 | let v: Dimensional, 1> = vector![1.0, 2.0, 3.0, 4.0, 5.0]; 578 | 579 | assert_eq!(&v + 1.5, vector![2.5, 3.5, 4.5, 5.5, 6.5]); 580 | assert_eq!(&v - 0.5, vector![0.5, 1.5, 2.5, 3.5, 4.5]); 581 | assert_eq!(&v * 2.0, vector![2.0, 4.0, 6.0, 8.0, 10.0]); 582 | assert_eq!(&v / 2.0, vector![0.5, 1.0, 1.5, 2.0, 2.5]); 583 | } 584 | 585 | #[test] 586 | fn test_element_wise_operations_with_floats() { 587 | let v1: Dimensional, 1> = vector![1.0, 2.0, 3.0, 4.0, 5.0]; 588 | let v2: Dimensional, 1> = vector![0.5, 1.0, 1.5, 2.0, 2.5]; 589 | 590 | assert_eq!(&v1 + &v2, vector![1.5, 3.0, 4.5, 6.0, 7.5]); 591 | assert_eq!(&v1 - &v2, vector![0.5, 1.0, 1.5, 2.0, 2.5]); 592 | assert_eq!(&v1 * &v2, vector![0.5, 2.0, 4.5, 8.0, 12.5]); 593 | assert_eq!(&v1 / &v2, vector![2.0, 2.0, 2.0, 2.0, 2.0]); 594 | } 595 | 596 | #[test] 597 | fn test_negation_with_floats() { 598 | let v: Dimensional, 1> = vector![1.5, -2.5, 3.5, -4.5, 5.5]; 599 | assert_eq!(-&v, vector![-1.5, 2.5, -3.5, 4.5, -5.5]); 600 | } 601 | 602 | #[test] 603 | fn test_equality() { 604 | let v1 = vector![1, 2, 3, 4, 5]; 605 | let v2 = vector![1, 2, 3, 4, 5]; 606 | let v3 = vector![1, 2, 3, 4, 6]; 607 | 608 | assert_eq!(v1, v2); 609 | assert_ne!(v1, v3); 610 | 611 | let m1 = matrix![[1, 2], [3, 4]]; 612 | let m2 = matrix![[1, 2], [3, 4]]; 613 | 
let m3 = matrix![[1, 2], [3, 5]]; 614 | 615 | assert_eq!(m1, m2); 616 | assert_ne!(m1, m3); 617 | } 618 | 619 | #[test] 620 | fn test_higher_dimensional_arrays() { 621 | let a1: Dimensional, 3> = 622 | Dimensional::from_fn([2, 2, 2], |[i, j, k]| (i * 4 + j * 2 + k + 1) as i32); 623 | let a2: Dimensional, 3> = 624 | Dimensional::from_fn([2, 2, 2], |[i, j, k]| (8 - i * 4 - j * 2 - k) as i32); 625 | 626 | let sum = &a1 + &a2; 627 | assert_eq!(sum.as_slice(), &[9; 8]); 628 | 629 | let product = &a1 * &a2; 630 | assert_eq!(product.as_slice(), &[8, 14, 18, 20, 20, 18, 14, 8]); 631 | } 632 | 633 | #[test] 634 | fn test_matrix_multiplication() { 635 | // Define a 2x3 matrix 636 | let m1 = matrix![[1, 2, 3], [4, 5, 6]]; 637 | 638 | // Define a 3x2 matrix 639 | let m2 = matrix![[7, 8], [9, 10], [11, 12]]; 640 | 641 | // Expected 2x2 product matrix 642 | let product = matrix![[58, 64], [139, 154]]; 643 | 644 | assert_eq!(m1.dot(&m2), product); 645 | } 646 | 647 | #[test] 648 | fn test_matrix_transpose() { 649 | let m = matrix![[1, 2, 3], [4, 5, 6]]; 650 | let m_t = matrix![[1, 4], [2, 5], [3, 6]]; 651 | assert_eq!(m.transpose(), m_t); 652 | } 653 | 654 | #[test] 655 | fn test_matrix_trace() { 656 | let m = matrix![[1, 2], [3, 4]]; 657 | assert_eq!(m.trace(), 5); 658 | } 659 | } 660 | -------------------------------------------------------------------------------- /src/storage.rs: -------------------------------------------------------------------------------- 1 | use num_traits::Num; 2 | use std::ops::{Index, IndexMut}; 3 | 4 | /// A trait for storage backends for multidimensional arrays. 5 | /// 6 | /// This trait defines methods for creating arrays filled with zeros or ones, 7 | /// and for creating an array from a vector of data. 8 | /// 9 | /// # Type Parameters 10 | /// 11 | /// * `T`: The element type of the array. Must implement `Num` and `Copy`. 12 | /// * `N`: The number of dimensions of the array. 
13 | pub trait DimensionalStorage: 14 | Index<[usize; N], Output = T> + IndexMut<[usize; N], Output = T> 15 | { 16 | /// Creates an array filled with zeros. 17 | /// 18 | /// # Arguments 19 | /// 20 | /// * `shape`: The shape of the array. 21 | fn zeros(shape: [usize; N]) -> Self; 22 | 23 | /// Creates an array filled with ones. 24 | /// 25 | /// # Arguments 26 | /// 27 | /// * `shape`: The shape of the array. 28 | fn ones(shape: [usize; N]) -> Self; 29 | 30 | /// Creates an array from a vector of data. 31 | /// 32 | /// # Arguments 33 | /// 34 | /// * `shape`: The shape of the array. 35 | /// * `data`: The data to initialize the array with. 36 | fn from_vec(shape: [usize; N], data: Vec) -> Self; 37 | 38 | /// Returns the total number of elements in the storage. 39 | fn len(&self) -> usize; 40 | 41 | /// Checks if the storage is empty. 42 | fn is_empty(&self) -> bool { 43 | self.len() == 0 44 | } 45 | 46 | /// Returns a mutable slice of the underlying data from storage. 47 | fn as_mut_slice(&mut self) -> &mut [T]; 48 | 49 | /// Returns an immutable slice of the underlying data from storage. 50 | fn as_slice(&self) -> &[T]; 51 | } 52 | 53 | /// An enum representing the memory layout of a linear array. 54 | #[derive(Debug, Copy, Clone, PartialEq)] 55 | pub enum LinearArrayLayout { 56 | /// Row-major layout (default). 57 | RowMajor, 58 | /// Column-major layout. 59 | ColumnMajor, 60 | } 61 | 62 | /// A linear array storage backend for multidimensional arrays. 63 | /// 64 | /// This struct stores the array data in a contiguous block of memory, 65 | /// using either row-major or column-major layout. 66 | /// 67 | /// # Type Parameters 68 | /// 69 | /// * `T`: The element type of the array. Must implement `Num` and `Copy`. 70 | /// * `N`: The number of dimensions of the array. 
71 | #[derive(Debug, Clone, PartialEq)] 72 | pub struct LinearArrayStorage { 73 | data: Vec, 74 | shape: [usize; N], 75 | layout: LinearArrayLayout, 76 | strides: [usize; N], 77 | len: usize, 78 | } 79 | 80 | impl LinearArrayStorage { 81 | /// Computes the strides for a given shape and layout. 82 | /// 83 | /// In this implementation, strides represent the number of elements (not bytes) to skip 84 | /// in each dimension when traversing the array. This approach simplifies indexing calculations 85 | /// while still providing efficient access to elements. 86 | /// 87 | /// # Arguments 88 | /// 89 | /// * `shape`: The shape of the array. 90 | /// * `layout`: The memory layout of the array. 91 | /// 92 | /// # Returns 93 | /// 94 | /// An array of strides, where each stride represents the number of elements to skip 95 | /// in the corresponding dimension. 96 | fn compute_strides(shape: &[usize; N], layout: &LinearArrayLayout) -> [usize; N] { 97 | let mut strides = [0; N]; 98 | match layout { 99 | LinearArrayLayout::RowMajor => { 100 | strides[N - 1] = 1; 101 | for i in (0..N - 1).rev() { 102 | strides[i] = strides[i + 1] * shape[i + 1]; 103 | } 104 | } 105 | LinearArrayLayout::ColumnMajor => { 106 | strides[0] = 1; 107 | for i in 1..N { 108 | strides[i] = strides[i - 1] * shape[i - 1]; 109 | } 110 | } 111 | } 112 | strides 113 | } 114 | 115 | /// Computes the linear index for a given multidimensional index. 116 | /// 117 | /// This method calculates the position of an element in the underlying 1D vector 118 | /// based on its multidimensional index and the array's strides. 119 | /// 120 | /// # Arguments 121 | /// 122 | /// * `index`: The multidimensional index. 123 | /// 124 | /// # Returns 125 | /// 126 | /// The linear index in the underlying data vector. 
127 | fn layout_index(&self, index: [usize; N]) -> usize { 128 | index 129 | .iter() 130 | .zip(self.strides.iter()) 131 | .map(|(&i, &stride)| i * stride) 132 | .sum() 133 | } 134 | 135 | /// Creates a new `LinearArrayStorage` with the given parameters. 136 | /// 137 | /// # Arguments 138 | /// 139 | /// * `shape`: The shape of the array. 140 | /// * `data`: The data to initialize the array with. 141 | /// * `layout`: The memory layout of the array. 142 | /// 143 | /// # Panics 144 | /// 145 | /// Panic if the length of `data` doesn't match the product of dimensions in `shape`. 146 | pub fn new(shape: [usize; N], data: Vec, layout: LinearArrayLayout) -> Self { 147 | assert_eq!( 148 | shape.iter().product::(), 149 | data.len(), 150 | "Data length must match the product of shape dimensions" 151 | ); 152 | let strides = Self::compute_strides(&shape, &layout); 153 | let len = data.len(); 154 | Self { 155 | data, 156 | shape, 157 | layout, 158 | strides, 159 | len, 160 | } 161 | } 162 | 163 | /// Returns the shape of the array. 164 | pub fn shape(&self) -> &[usize; N] { 165 | &self.shape 166 | } 167 | 168 | /// Returns the layout of the array. 169 | pub fn layout(&self) -> LinearArrayLayout { 170 | self.layout 171 | } 172 | 173 | /// Returns the strides of the array. 
174 | pub fn strides(&self) -> &[usize; N] { 175 | &self.strides 176 | } 177 | } 178 | 179 | impl Index<[usize; N]> for LinearArrayStorage { 180 | type Output = T; 181 | 182 | fn index(&self, index: [usize; N]) -> &Self::Output { 183 | let linear_index = self.layout_index(index); 184 | &self.data[linear_index] 185 | } 186 | } 187 | 188 | impl IndexMut<[usize; N]> for LinearArrayStorage { 189 | fn index_mut(&mut self, index: [usize; N]) -> &mut Self::Output { 190 | let linear_index = self.layout_index(index); 191 | &mut self.data[linear_index] 192 | } 193 | } 194 | 195 | impl DimensionalStorage for LinearArrayStorage { 196 | fn zeros(shape: [usize; N]) -> Self { 197 | let data = vec![T::zero(); shape.iter().product::()]; 198 | LinearArrayStorage::new(shape, data, LinearArrayLayout::RowMajor) 199 | } 200 | 201 | fn ones(shape: [usize; N]) -> Self { 202 | let data = vec![T::one(); shape.iter().product::()]; 203 | LinearArrayStorage::new(shape, data, LinearArrayLayout::RowMajor) 204 | } 205 | 206 | fn from_vec(shape: [usize; N], data: Vec) -> Self { 207 | LinearArrayStorage::new(shape, data, LinearArrayLayout::RowMajor) 208 | } 209 | 210 | fn len(&self) -> usize { 211 | self.len 212 | } 213 | 214 | fn as_mut_slice(&mut self) -> &mut [T] { 215 | &mut self.data 216 | } 217 | 218 | fn as_slice(&self) -> &[T] { 219 | &self.data 220 | } 221 | } 222 | 223 | #[cfg(test)] 224 | mod tests { 225 | use super::*; 226 | 227 | #[test] 228 | fn test_zeros_and_ones() { 229 | let zeros = LinearArrayStorage::::zeros([2, 3]); 230 | assert_eq!(zeros.as_slice(), &[0, 0, 0, 0, 0, 0]); 231 | 232 | let ones = LinearArrayStorage::::ones([2, 3]); 233 | assert_eq!(ones.as_slice(), &[1, 1, 1, 1, 1, 1]); 234 | } 235 | 236 | #[test] 237 | fn test_from_vec() { 238 | let data = vec![1, 2, 3, 4, 5, 6]; 239 | let array = LinearArrayStorage::::from_vec([2, 3], data.clone()); 240 | assert_eq!(array.as_slice(), &data); 241 | } 242 | 243 | #[test] 244 | #[should_panic(expected = "Data length must match the 
product of shape dimensions")] 245 | fn test_from_vec_wrong_size() { 246 | let data = vec![1, 2, 3, 4, 5]; 247 | LinearArrayStorage::::from_vec([2, 3], data); 248 | } 249 | 250 | #[test] 251 | fn test_indexing() { 252 | let data = vec![1, 2, 3, 4, 5, 6]; 253 | let array = LinearArrayStorage::::from_vec([2, 3], data); 254 | assert_eq!(array[[0, 0]], 1); 255 | assert_eq!(array[[0, 2]], 3); 256 | assert_eq!(array[[1, 1]], 5); 257 | } 258 | 259 | #[test] 260 | fn test_mutable_indexing() { 261 | let data = vec![1, 2, 3, 4, 5, 6]; 262 | let mut array = LinearArrayStorage::::from_vec([2, 3], data); 263 | array[[0, 0]] = 10; 264 | array[[1, 2]] = 20; 265 | assert_eq!(array[[0, 0]], 10); 266 | assert_eq!(array[[1, 2]], 20); 267 | } 268 | 269 | #[test] 270 | fn test_strides_calculation() { 271 | let row_major = 272 | LinearArrayStorage::::new([2, 3, 4], vec![0; 24], LinearArrayLayout::RowMajor); 273 | assert_eq!(row_major.strides(), &[12, 4, 1]); 274 | 275 | let col_major = LinearArrayStorage::::new( 276 | [2, 3, 4], 277 | vec![0; 24], 278 | LinearArrayLayout::ColumnMajor, 279 | ); 280 | assert_eq!(col_major.strides(), &[1, 2, 6]); 281 | } 282 | 283 | #[test] 284 | fn test_layout_index() { 285 | let row_major = LinearArrayStorage::::new( 286 | [2, 3, 4], 287 | (0..24).collect(), 288 | LinearArrayLayout::RowMajor, 289 | ); 290 | assert_eq!(row_major[[0, 0, 0]], 0); 291 | assert_eq!(row_major[[1, 2, 3]], 23); 292 | assert_eq!(row_major[[0, 1, 2]], 6); 293 | 294 | let col_major = LinearArrayStorage::::new( 295 | [2, 3, 4], 296 | (0..24).collect(), 297 | LinearArrayLayout::ColumnMajor, 298 | ); 299 | assert_eq!(col_major[[0, 0, 0]], 0); 300 | assert_eq!(col_major[[1, 2, 3]], 23); 301 | assert_eq!(col_major[[0, 1, 2]], 14); 302 | } 303 | 304 | #[test] 305 | fn test_different_layouts() { 306 | let data: Vec = (0..6).collect(); 307 | 308 | let row_major = LinearArrayStorage::new([2, 3], data.clone(), LinearArrayLayout::RowMajor); 309 | assert_eq!(row_major[[0, 0]], 0); 310 | 
assert_eq!(row_major[[0, 2]], 2); 311 | assert_eq!(row_major[[1, 0]], 3); 312 | 313 | let col_major = LinearArrayStorage::new([2, 3], data, LinearArrayLayout::ColumnMajor); 314 | assert_eq!(col_major[[0, 0]], 0); 315 | assert_eq!(col_major[[0, 2]], 4); 316 | assert_eq!(col_major[[1, 0]], 1); 317 | } 318 | 319 | #[test] 320 | fn test_as_slice_and_as_mut_slice() { 321 | let mut array = LinearArrayStorage::::from_vec([2, 3], vec![1, 2, 3, 4, 5, 6]); 322 | 323 | assert_eq!(array.as_slice(), &[1, 2, 3, 4, 5, 6]); 324 | 325 | { 326 | let slice = array.as_mut_slice(); 327 | slice[0] = 10; 328 | slice[5] = 60; 329 | } 330 | 331 | assert_eq!(array.as_slice(), &[10, 2, 3, 4, 5, 60]); 332 | } 333 | } 334 | --------------------------------------------------------------------------------