├── ragged_buffer ├── py.typed ├── __init__.py └── __init__.pyi ├── .gitignore ├── monomorphize.sh ├── pyproject.toml ├── Cargo.toml ├── .github └── workflows │ ├── test.yaml │ └── publish-wheel.yaml ├── src ├── lib.rs ├── monomorphs.rs ├── monomorphs │ ├── i64.rs │ ├── f32.rs │ └── bool.rs ├── ragged_buffer.rs └── ragged_buffer_view.rs ├── LICENSE-MIT ├── README.md ├── LICENSE-APACHE └── tests └── test.py /ragged_buffer/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | Cargo.lock 3 | .vscode 4 | **/__pycache__ 5 | **/*.so 6 | -------------------------------------------------------------------------------- /monomorphize.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cp src/monomorphs/{f32,i64}.rs 4 | sed -i s/F32/I64/g src/monomorphs/i64.rs 5 | sed -i s/f32/i64/g src/monomorphs/i64.rs 6 | cp src/monomorphs/{f32,bool}.rs 7 | sed -i s/F32/Bool/g src/monomorphs/bool.rs 8 | sed -i s/f32/bool/g src/monomorphs/bool.rs 9 | sed -i 's/cfg(all())/cfg(any())/g' src/monomorphs/bool.rs 10 | sed -i "/use crate::monomorphs::RaggedBufferI64;/d" src/monomorphs/i64.rs 11 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["maturin>=0.13,<0.14"] 3 | build-backend = "maturin" 4 | 5 | [project] 6 | name = "ragged-buffer" 7 | dependencies = ["numpy~=1.21"] 8 | requires-python = ">=3.7" 9 | classifiers = [ 10 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 11 | "Programming Language :: Rust", 12 | "Programming Language :: Python :: Implementation :: CPython", 13 | ] 14 | 15 | [tool.maturin] 16 | features = ["python"] 
-------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Clemens Winter "] 3 | description = "Efficient RaggedBuffer datatype that implements 3D arrays with variable-length 2nd dimension." 4 | edition = "2021" 5 | license = "MIT OR Apache-2.0" 6 | name = "ragged-buffer" 7 | readme = "README.md" 8 | repository = "https://github.com/entity-neural-network/ragged-buffer" 9 | version = "0.4.8" 10 | 11 | [lib] 12 | crate-type = ["cdylib", "rlib"] 13 | name = "ragged_buffer" 14 | 15 | [dependencies] 16 | ndarray = "0.15.4" 17 | numpy = {version = "0.16.2", optional = true} 18 | pyo3 = {version = "0.16.5", features = ["extension-module"], optional = true} 19 | 20 | [profile.release] 21 | debug = true 22 | 23 | [features] 24 | python = ["pyo3", "numpy"] 25 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: [ '*' ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | test: 11 | 12 | runs-on: ubuntu-latest 13 | 14 | name: Test 15 | steps: 16 | - uses: actions/checkout@v2 17 | - uses: actions-rs/toolchain@v1 18 | with: 19 | profile: minimal 20 | toolchain: stable 21 | override: true 22 | - uses: messense/maturin-action@v1 23 | with: 24 | maturin-version: latest 25 | command: build 26 | args: --features=python 27 | - name: Test 28 | run: | 29 | python -m pip install --upgrade pip 30 | pip install target/wheels/ragged_buffer-*.whl 31 | python tests/test.py 32 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "python")] 2 | use pyo3::prelude::*; 3 | #[cfg(feature = "python")] 4 | use 
pyo3::wrap_pyfunction; 5 | #[cfg(feature = "python")] 6 | pub mod monomorphs; 7 | 8 | pub mod ragged_buffer; 9 | 10 | #[cfg(feature = "python")] 11 | pub mod ragged_buffer_view; 12 | 13 | #[cfg(feature = "python")] 14 | #[pymodule] 15 | fn ragged_buffer(_py: Python, m: &PyModule) -> PyResult<()> { 16 | // New exports also have to be added to __init__.py 17 | m.add_class::()?; 18 | m.add_class::()?; 19 | m.add_class::()?; 20 | m.add_function(wrap_pyfunction!(translate_rotate, m)?)?; 21 | Ok(()) 22 | } 23 | 24 | #[cfg(feature = "python")] 25 | #[pyfunction] 26 | fn translate_rotate( 27 | source: &monomorphs::RaggedBufferF32, 28 | translation: monomorphs::RaggedBufferF32, 29 | rotation: monomorphs::RaggedBufferF32, 30 | ) -> PyResult<()> { 31 | ragged_buffer_view::translate_rotate(&source.0, &translation.0, &rotation.0) 32 | } 33 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Clemens Winter 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /src/monomorphs.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::format_push_string)] // Caused by FromPyObject macro 2 | use numpy::PyReadonlyArray1; 3 | use pyo3::types::PySlice; 4 | use pyo3::{FromPyObject, Py, PyResult}; 5 | 6 | mod bool; 7 | mod f32; 8 | mod i64; 9 | 10 | pub use self::bool::RaggedBufferBool; 11 | pub use self::f32::RaggedBufferF32; 12 | pub use self::i64::RaggedBufferI64; 13 | 14 | #[derive(FromPyObject)] 15 | pub enum Index<'a> { 16 | PermutationNP(PyReadonlyArray1<'a, i64>), 17 | Permutation(Vec), 18 | Int(usize), 19 | Slice(Py), 20 | } 21 | 22 | impl<'a> std::fmt::Debug for Index<'a> { 23 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 24 | match self { 25 | Self::PermutationNP(arg0) => f 26 | .debug_tuple("PermutationNP") 27 | .field(&arg0.to_vec().unwrap()) 28 | .finish(), 29 | Self::Permutation(arg0) => f.debug_tuple("Permutation").field(arg0).finish(), 30 | Self::Int(arg0) => f.debug_tuple("Int").field(arg0).finish(), 31 | Self::Slice(arg0) => f.debug_tuple("Slice").field(arg0).finish(), 32 | } 33 | } 34 | } 35 | 36 | #[derive(FromPyObject, Debug)] 37 | pub enum MultiIndex<'a> { 38 | Index1(Index<'a>), 39 | Index2((Index<'a>, Index<'a>)), 40 | Index3((Index<'a>, Index<'a>, Index<'a>)), 41 | } 42 | 43 | type PyArray<'a, T, D> = &'a numpy::PyArray>; 44 | type PadpackResult<'a> = PyResult< 45 | Option<( 46 | PyArray<'a, i64, [usize; 2]>, 47 | PyArray<'a, f32, [usize; 2]>, 48 | PyArray<'a, i64, [usize; 1]>, 49 | )>, 50 | >; 51 | -------------------------------------------------------------------------------- 
/.github/workflows/publish-wheel.yaml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | workflow_dispatch: 8 | 9 | jobs: 10 | publish: 11 | name: Publish for ${{ matrix.os }} 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | target: 16 | - x86_64-unknown-linux-musl 17 | - x86_64-apple-darwin 18 | - x86_64-pc-windows-msvc 19 | - i686-pc-windows-msvc 20 | - aarch64-pc-windows-msvc 21 | include: 22 | - target: x86_64-unknown-linux-musl 23 | os: ubuntu-latest 24 | - target: x86_64-apple-darwin 25 | os: macos-latest 26 | - target: x86_64-pc-windows-msvc 27 | os: windows-latest 28 | - target: i686-pc-windows-msvc 29 | os: windows-latest 30 | - target: aarch64-pc-windows-msvc 31 | os: windows-latest 32 | 33 | runs-on: ${{ matrix.os }} 34 | environment: PyPI 35 | steps: 36 | - uses: actions/checkout@v2 37 | - uses: actions/setup-python@v2 38 | with: 39 | python-version: 3.7 40 | - uses: actions/setup-python@v2 41 | with: 42 | python-version: 3.8 43 | - uses: actions/setup-python@v2 44 | with: 45 | python-version: 3.9 46 | - uses: actions/setup-python@v2 47 | with: 48 | python-version: "3.10" 49 | - uses: actions-rs/toolchain@v1 50 | with: 51 | profile: minimal 52 | toolchain: stable 53 | override: true 54 | 55 | - name: Publish 56 | uses: messense/maturin-action@v1 57 | env: 58 | MATURIN_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 59 | with: 60 | maturin-version: latest 61 | command: publish 62 | args: --username=__token__ --skip-existing --features=python --find-interpreter 63 | -------------------------------------------------------------------------------- /ragged_buffer/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Generic, List, Protocol, Type, TypeVar, Union, cast, overload 2 | from numpy.typing import NDArray 3 | import numpy as np 4 | 5 | from .ragged_buffer import ( 6 | RaggedBufferF32, 7 | 
RaggedBufferI64, 8 | RaggedBufferBool, 9 | translate_rotate, 10 | ) 11 | 12 | ScalarType = TypeVar("ScalarType", bound=np.generic) 13 | 14 | 15 | class RaggedBuffer(Generic[ScalarType]): 16 | def __init__(self, features: int) -> None: 17 | raise ValueError( 18 | "RaggedBuffer is an abstract class, use RaggedBufferF32 or RaggedBufferI64" 19 | ) 20 | 21 | @classmethod 22 | def from_array(cls, x: NDArray[ScalarType]) -> "RaggedBuffer[ScalarType]": 23 | if x.dtype == np.float32: 24 | return RaggedBufferF32.from_array(x) 25 | elif x.dtype == np.int64: 26 | return RaggedBufferI64.from_array(x) 27 | elif x.dtype == np.bool_: 28 | return RaggedBufferBool.from_array(x) 29 | else: 30 | raise ValueError( 31 | f"Unsupported dtype {x.dtype}. Only float32 and int64 are currently supported." 32 | ) 33 | 34 | @classmethod 35 | def from_flattened( 36 | cls, flattened: NDArray[ScalarType], lengths: NDArray[np.int64] 37 | ) -> "RaggedBuffer[ScalarType]": 38 | if flattened.dtype == np.float32: 39 | return RaggedBufferF32.from_flattened(flattened, lengths) 40 | elif flattened.dtype == np.int64: 41 | return RaggedBufferI64.from_flattened(flattened, lengths) 42 | elif flattened.dtype == np.bool_: 43 | return RaggedBufferBool.from_flattened(flattened, lengths) 44 | else: 45 | raise ValueError( 46 | f"Unsupported dtype {flattened.dtype}. Only float32 and int64 are currently supported." 
47 | ) 48 | 49 | 50 | def cat( 51 | buffers: List[RaggedBuffer[ScalarType]], dim: int = 0 52 | ) -> RaggedBuffer[ScalarType]: 53 | if len(buffers) == 0: 54 | raise ValueError("Can't concatenate an empty list of buffers") 55 | else: 56 | if isinstance(buffers[0], RaggedBufferF32): 57 | return RaggedBufferF32.cat(buffers, dim) 58 | elif isinstance(buffers[0], RaggedBufferI64): 59 | return RaggedBufferI64.cat(buffers, dim) 60 | elif isinstance(buffers[0], RaggedBufferBool): 61 | return RaggedBufferBool.cat(buffers, dim) 62 | else: 63 | raise TypeError(f"Type {type(buffers[0])} is not a RaggedBuffer") 64 | -------------------------------------------------------------------------------- /ragged_buffer/__init__.pyi: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Any, Generic, Tuple, TypeVar, Union, overload, List, Optional 3 | from numpy.typing import NDArray 4 | import numpy as np 5 | 6 | ScalarType = TypeVar("ScalarType", bound=np.generic) 7 | 8 | class RaggedBuffer(Generic[ScalarType]): 9 | def __init__(self, features: int) -> None: ... 10 | @classmethod 11 | def from_array(cls, x: NDArray[ScalarType]) -> RaggedBuffer[ScalarType]: ... 12 | @classmethod 13 | def from_flattened( 14 | cls, flattened: NDArray[ScalarType], lengths: NDArray[np.int64] 15 | ) -> RaggedBuffer[ScalarType]: ... 16 | def push(self, x: NDArray[ScalarType]) -> None: ... 17 | def push_empty(self) -> None: ... 18 | def extend(self, x: RaggedBuffer[ScalarType]) -> None: ... 19 | def as_array(self) -> NDArray[ScalarType]: ... 20 | def size0(self) -> int: ... 21 | @overload 22 | def size1(self) -> NDArray[np.int64]: ... 23 | @overload 24 | def size1(self, i: int) -> int: ... 25 | def size2(self) -> int: ... 26 | @overload 27 | def __add__(self, other: RaggedBuffer[ScalarType]) -> RaggedBuffer[ScalarType]: ... 28 | @overload 29 | def __add__(self, other: int) -> RaggedBuffer[ScalarType]: ... 
30 | @overload 31 | def __add__(self, other: float) -> RaggedBuffer[np.float32]: ... 32 | @overload 33 | def __mul__(self, other: RaggedBuffer[ScalarType]) -> RaggedBuffer[ScalarType]: ... 34 | @overload 35 | def __mul__(self, other: int) -> RaggedBuffer[ScalarType]: ... 36 | @overload 37 | def __mul__(self, other: float) -> RaggedBuffer[np.float32]: ... 38 | def __getitem__( 39 | self, 40 | i: Union[ 41 | int, 42 | NDArray[np.int64], 43 | Tuple[Union[int, List[int], slice, NDArray[np.int64]], ...], 44 | ], 45 | ) -> RaggedBuffer[ScalarType]: ... 46 | def __eq__(self, other: Any) -> bool: ... 47 | def __ne__(self, other: Any) -> bool: ... 48 | def clear(self) -> None: ... 49 | def indices(self, dim: int) -> RaggedBufferI64: ... 50 | def flat_indices(self) -> RaggedBufferI64: ... 51 | def padpack( 52 | self, 53 | ) -> Optional[Tuple[NDArray[np.int64], NDArray[np.float32], NDArray[np.int64]]]: ... 54 | def __isub__(self, other: RaggedBuffer[ScalarType]) -> RaggedBuffer[ScalarType]: ... 55 | def __len__(self) -> int: ... 56 | def items(self) -> int: ... 57 | def clone(self) -> RaggedBuffer[ScalarType]: ... 58 | def materialize(self) -> RaggedBuffer[ScalarType]: ... 59 | 60 | RaggedBufferF32 = RaggedBuffer[np.float32] 61 | RaggedBufferI64 = RaggedBuffer[np.int64] 62 | RaggedBufferBool = RaggedBuffer[np.bool_] 63 | 64 | def cat( 65 | buffers: List[RaggedBuffer[ScalarType]], dim: int = 0 66 | ) -> RaggedBuffer[ScalarType]: ... 67 | def translate_rotate( 68 | source: RaggedBuffer[np.float32], 69 | translation: RaggedBuffer[np.float32], 70 | rotation: RaggedBuffer[np.float32], 71 | ) -> None: ... 
72 | -------------------------------------------------------------------------------- /src/monomorphs/i64.rs: -------------------------------------------------------------------------------- 1 | use numpy::{PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3, PyReadonlyArrayDyn, ToPyArray}; 2 | use pyo3::basic::CompareOp; 3 | use pyo3::prelude::*; 4 | use pyo3::types::PyType; 5 | 6 | use crate::ragged_buffer_view::RaggedBufferView; 7 | 8 | use super::{Index, MultiIndex, PadpackResult}; 9 | 10 | #[pyclass] 11 | #[derive(Clone)] 12 | pub struct RaggedBufferI64(pub RaggedBufferView); 13 | 14 | #[pymethods] 15 | impl RaggedBufferI64 { 16 | #[new] 17 | pub fn new(features: usize) -> Self { 18 | RaggedBufferI64(RaggedBufferView::new(features)) 19 | } 20 | #[classmethod] 21 | fn from_array(_cls: &PyType, array: PyReadonlyArray3) -> Self { 22 | RaggedBufferI64(RaggedBufferView::from_array(array)) 23 | } 24 | #[classmethod] 25 | fn from_flattened( 26 | _cls: &PyType, 27 | flattened: PyReadonlyArray2, 28 | lengths: PyReadonlyArray1, 29 | ) -> PyResult { 30 | Ok(RaggedBufferI64(RaggedBufferView::from_flattened( 31 | flattened, lengths, 32 | )?)) 33 | } 34 | fn push(&mut self, items: PyReadonlyArrayDyn) -> PyResult<()> { 35 | if items.ndim() == 1 && items.len() == 0 { 36 | self.0.push_empty() 37 | } else if items.ndim() == 2 { 38 | self.0.push( 39 | &items 40 | .reshape((items.shape()[0], items.shape()[1]))? 
41 | .readonly(), 42 | ) 43 | } else { 44 | Err(pyo3::exceptions::PyValueError::new_err( 45 | "Expected 2 dimensional array", 46 | )) 47 | } 48 | } 49 | fn push_empty(&mut self) -> PyResult<()> { 50 | self.0.push_empty() 51 | } 52 | 53 | fn clear(&mut self) -> PyResult<()> { 54 | self.0.clear() 55 | } 56 | 57 | fn as_array<'a>( 58 | &self, 59 | py: Python<'a>, 60 | ) -> PyResult<&'a numpy::PyArray>> { 61 | self.0.as_array(py) 62 | } 63 | 64 | fn extend(&mut self, other: &RaggedBufferI64) -> PyResult<()> { 65 | self.0.extend(&other.0) 66 | } 67 | fn size0(&self) -> usize { 68 | self.0.size0() 69 | } 70 | fn size1(&mut self, py: Python, i: Option) -> PyResult { 71 | match i { 72 | Some(i) => self.0.size1(i).map(|s| s.into_py(py)), 73 | None => self.0.lengths(py).map(|ok| ok.into_py(py)), 74 | } 75 | } 76 | fn size2(&self) -> usize { 77 | self.0.size2() 78 | } 79 | fn indices(&mut self, dim: usize) -> PyResult { 80 | Ok(RaggedBufferI64(self.0.indices(dim)?)) 81 | } 82 | fn flat_indices(&mut self) -> PyResult { 83 | Ok(RaggedBufferI64(self.0.flat_indices()?)) 84 | } 85 | #[classmethod] 86 | fn cat(_cls: &PyType, buffers: Vec>, dim: usize) -> PyResult { 87 | Ok(RaggedBufferI64(RaggedBufferView::cat( 88 | &buffers.iter().map(|b| &b.0).collect::>(), 89 | dim, 90 | )?)) 91 | } 92 | #[allow(clippy::type_complexity)] 93 | fn padpack<'a>(&mut self, py: Python<'a>) -> PadpackResult<'a> { 94 | match self.0.padpack()? 
{ 95 | Some((padbpack_index, padpack_batch, padpack_inverse_index, dims)) => Ok(Some(( 96 | padbpack_index.to_pyarray(py).reshape(dims)?, 97 | padpack_batch.to_pyarray(py).reshape(dims)?, 98 | padpack_inverse_index 99 | .to_pyarray(py) 100 | .reshape(self.0.len()?)?, 101 | ))), 102 | _ => Ok(None), 103 | } 104 | } 105 | fn items(&mut self) -> PyResult { 106 | self.0.items() 107 | } 108 | fn clone(&self) -> Self { 109 | RaggedBufferI64(self.0.deepclone()) 110 | } 111 | fn materialize(&self) -> Self { 112 | RaggedBufferI64(self.0.materialize()) 113 | } 114 | fn __str__(&self) -> PyResult { 115 | self.0.__str__() 116 | } 117 | 118 | fn __repr__(&self) -> PyResult { 119 | self.0.__str__() 120 | } 121 | 122 | fn __richcmp__(&self, other: RaggedBufferI64, op: CompareOp) -> PyResult { 123 | match op { 124 | CompareOp::Eq => Ok(self.0 == other.0), 125 | CompareOp::Ne => Ok(self.0 != other.0), 126 | _ => Err(pyo3::exceptions::PyTypeError::new_err( 127 | "Only == and != are supported", 128 | )), 129 | } 130 | } 131 | 132 | // Is substituted for #[cfg(any())] for bool.rs to omit method 133 | #[cfg(all())] 134 | fn __add__( 135 | lhs: PyRef, 136 | rhs: RaggedBufferI64OrI64, 137 | ) -> PyResult { 138 | match rhs { 139 | RaggedBufferI64OrI64::RB(rhs) => Ok(RaggedBufferI64( 140 | lhs.0.binop::(&rhs.0)?, 141 | )), 142 | RaggedBufferI64OrI64::Scalar(rhs) => Ok(RaggedBufferI64( 143 | lhs.0.op_scalar::(rhs)?, 144 | )), 145 | } 146 | } 147 | 148 | #[cfg(all())] 149 | fn __mul__( 150 | lhs: PyRef, 151 | rhs: RaggedBufferI64OrI64, 152 | ) -> PyResult { 153 | match rhs { 154 | RaggedBufferI64OrI64::RB(rhs) => Ok(RaggedBufferI64( 155 | lhs.0.binop::(&rhs.0)?, 156 | )), 157 | RaggedBufferI64OrI64::Scalar(rhs) => Ok(RaggedBufferI64( 158 | lhs.0.op_scalar::(rhs)?, 159 | )), 160 | } 161 | } 162 | 163 | #[cfg(all())] 164 | fn __isub__(&mut self, rhs: RaggedBufferI64) -> PyResult<()> { 165 | self.0.binop_mut::(&rhs.0) 166 | } 167 | 168 | fn __getitem__(&self, index: MultiIndex) -> PyResult { 
169 | match index { 170 | MultiIndex::Index1(index) => match index { 171 | Index::PermutationNP(indices) => Ok(RaggedBufferI64(self.0.swizzle(indices)?)), 172 | Index::Permutation(indices) => Ok(RaggedBufferI64(self.0.swizzle_usize(&indices)?)), 173 | Index::Int(i) => Ok(RaggedBufferI64(self.0.get_sequence(i)?)), 174 | Index::Slice(slice) => panic!("{:?}", slice), 175 | }, 176 | MultiIndex::Index3((i0, i1, i2)) => Ok(RaggedBufferI64(Python::with_gil(|py| { 177 | self.0.get_slice(py, i0, i1, i2) 178 | })?)), 179 | x => panic!("{:?}", x), 180 | } 181 | } 182 | fn __len__(&self) -> PyResult { 183 | self.0.len() 184 | } 185 | } 186 | 187 | #[derive(FromPyObject)] 188 | pub enum RaggedBufferI64OrI64<'p> { 189 | RB(PyRef<'p, RaggedBufferI64>), 190 | Scalar(i64), 191 | } 192 | -------------------------------------------------------------------------------- /src/monomorphs/f32.rs: -------------------------------------------------------------------------------- 1 | use numpy::{PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3, PyReadonlyArrayDyn, ToPyArray}; 2 | use pyo3::basic::CompareOp; 3 | use pyo3::prelude::*; 4 | use pyo3::types::PyType; 5 | 6 | use crate::monomorphs::RaggedBufferI64; 7 | use crate::ragged_buffer_view::RaggedBufferView; 8 | 9 | use super::{Index, MultiIndex, PadpackResult}; 10 | 11 | #[pyclass] 12 | #[derive(Clone)] 13 | pub struct RaggedBufferF32(pub RaggedBufferView); 14 | 15 | #[pymethods] 16 | impl RaggedBufferF32 { 17 | #[new] 18 | pub fn new(features: usize) -> Self { 19 | RaggedBufferF32(RaggedBufferView::new(features)) 20 | } 21 | #[classmethod] 22 | fn from_array(_cls: &PyType, array: PyReadonlyArray3) -> Self { 23 | RaggedBufferF32(RaggedBufferView::from_array(array)) 24 | } 25 | #[classmethod] 26 | fn from_flattened( 27 | _cls: &PyType, 28 | flattened: PyReadonlyArray2, 29 | lengths: PyReadonlyArray1, 30 | ) -> PyResult { 31 | Ok(RaggedBufferF32(RaggedBufferView::from_flattened( 32 | flattened, lengths, 33 | )?)) 34 | } 35 | fn push(&mut 
self, items: PyReadonlyArrayDyn) -> PyResult<()> { 36 | if items.ndim() == 1 && items.len() == 0 { 37 | self.0.push_empty() 38 | } else if items.ndim() == 2 { 39 | self.0.push( 40 | &items 41 | .reshape((items.shape()[0], items.shape()[1]))? 42 | .readonly(), 43 | ) 44 | } else { 45 | Err(pyo3::exceptions::PyValueError::new_err( 46 | "Expected 2 dimensional array", 47 | )) 48 | } 49 | } 50 | fn push_empty(&mut self) -> PyResult<()> { 51 | self.0.push_empty() 52 | } 53 | 54 | fn clear(&mut self) -> PyResult<()> { 55 | self.0.clear() 56 | } 57 | 58 | fn as_array<'a>( 59 | &self, 60 | py: Python<'a>, 61 | ) -> PyResult<&'a numpy::PyArray>> { 62 | self.0.as_array(py) 63 | } 64 | 65 | fn extend(&mut self, other: &RaggedBufferF32) -> PyResult<()> { 66 | self.0.extend(&other.0) 67 | } 68 | fn size0(&self) -> usize { 69 | self.0.size0() 70 | } 71 | fn size1(&mut self, py: Python, i: Option) -> PyResult { 72 | match i { 73 | Some(i) => self.0.size1(i).map(|s| s.into_py(py)), 74 | None => self.0.lengths(py).map(|ok| ok.into_py(py)), 75 | } 76 | } 77 | fn size2(&self) -> usize { 78 | self.0.size2() 79 | } 80 | fn indices(&mut self, dim: usize) -> PyResult { 81 | Ok(RaggedBufferI64(self.0.indices(dim)?)) 82 | } 83 | fn flat_indices(&mut self) -> PyResult { 84 | Ok(RaggedBufferI64(self.0.flat_indices()?)) 85 | } 86 | #[classmethod] 87 | fn cat(_cls: &PyType, buffers: Vec>, dim: usize) -> PyResult { 88 | Ok(RaggedBufferF32(RaggedBufferView::cat( 89 | &buffers.iter().map(|b| &b.0).collect::>(), 90 | dim, 91 | )?)) 92 | } 93 | #[allow(clippy::type_complexity)] 94 | fn padpack<'a>(&mut self, py: Python<'a>) -> PadpackResult<'a> { 95 | match self.0.padpack()? 
{ 96 | Some((padbpack_index, padpack_batch, padpack_inverse_index, dims)) => Ok(Some(( 97 | padbpack_index.to_pyarray(py).reshape(dims)?, 98 | padpack_batch.to_pyarray(py).reshape(dims)?, 99 | padpack_inverse_index 100 | .to_pyarray(py) 101 | .reshape(self.0.len()?)?, 102 | ))), 103 | _ => Ok(None), 104 | } 105 | } 106 | fn items(&mut self) -> PyResult { 107 | self.0.items() 108 | } 109 | fn clone(&self) -> Self { 110 | RaggedBufferF32(self.0.deepclone()) 111 | } 112 | fn materialize(&self) -> Self { 113 | RaggedBufferF32(self.0.materialize()) 114 | } 115 | fn __str__(&self) -> PyResult { 116 | self.0.__str__() 117 | } 118 | 119 | fn __repr__(&self) -> PyResult { 120 | self.0.__str__() 121 | } 122 | 123 | fn __richcmp__(&self, other: RaggedBufferF32, op: CompareOp) -> PyResult { 124 | match op { 125 | CompareOp::Eq => Ok(self.0 == other.0), 126 | CompareOp::Ne => Ok(self.0 != other.0), 127 | _ => Err(pyo3::exceptions::PyTypeError::new_err( 128 | "Only == and != are supported", 129 | )), 130 | } 131 | } 132 | 133 | // Is substituted for #[cfg(any())] for bool.rs to omit method 134 | #[cfg(all())] 135 | fn __add__( 136 | lhs: PyRef, 137 | rhs: RaggedBufferF32OrF32, 138 | ) -> PyResult { 139 | match rhs { 140 | RaggedBufferF32OrF32::RB(rhs) => Ok(RaggedBufferF32( 141 | lhs.0.binop::(&rhs.0)?, 142 | )), 143 | RaggedBufferF32OrF32::Scalar(rhs) => Ok(RaggedBufferF32( 144 | lhs.0.op_scalar::(rhs)?, 145 | )), 146 | } 147 | } 148 | 149 | #[cfg(all())] 150 | fn __mul__( 151 | lhs: PyRef, 152 | rhs: RaggedBufferF32OrF32, 153 | ) -> PyResult { 154 | match rhs { 155 | RaggedBufferF32OrF32::RB(rhs) => Ok(RaggedBufferF32( 156 | lhs.0.binop::(&rhs.0)?, 157 | )), 158 | RaggedBufferF32OrF32::Scalar(rhs) => Ok(RaggedBufferF32( 159 | lhs.0.op_scalar::(rhs)?, 160 | )), 161 | } 162 | } 163 | 164 | #[cfg(all())] 165 | fn __isub__(&mut self, rhs: RaggedBufferF32) -> PyResult<()> { 166 | self.0.binop_mut::(&rhs.0) 167 | } 168 | 169 | fn __getitem__(&self, index: MultiIndex) -> PyResult { 
170 | match index { 171 | MultiIndex::Index1(index) => match index { 172 | Index::PermutationNP(indices) => Ok(RaggedBufferF32(self.0.swizzle(indices)?)), 173 | Index::Permutation(indices) => Ok(RaggedBufferF32(self.0.swizzle_usize(&indices)?)), 174 | Index::Int(i) => Ok(RaggedBufferF32(self.0.get_sequence(i)?)), 175 | Index::Slice(slice) => panic!("{:?}", slice), 176 | }, 177 | MultiIndex::Index3((i0, i1, i2)) => Ok(RaggedBufferF32(Python::with_gil(|py| { 178 | self.0.get_slice(py, i0, i1, i2) 179 | })?)), 180 | x => panic!("{:?}", x), 181 | } 182 | } 183 | fn __len__(&self) -> PyResult { 184 | self.0.len() 185 | } 186 | } 187 | 188 | #[derive(FromPyObject)] 189 | pub enum RaggedBufferF32OrF32<'p> { 190 | RB(PyRef<'p, RaggedBufferF32>), 191 | Scalar(f32), 192 | } 193 | -------------------------------------------------------------------------------- /src/monomorphs/bool.rs: -------------------------------------------------------------------------------- 1 | use numpy::{PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3, PyReadonlyArrayDyn, ToPyArray}; 2 | use pyo3::basic::CompareOp; 3 | use pyo3::prelude::*; 4 | use pyo3::types::PyType; 5 | 6 | use crate::monomorphs::RaggedBufferI64; 7 | use crate::ragged_buffer_view::RaggedBufferView; 8 | 9 | use super::{Index, MultiIndex, PadpackResult}; 10 | 11 | #[pyclass] 12 | #[derive(Clone)] 13 | pub struct RaggedBufferBool(pub RaggedBufferView); 14 | 15 | #[pymethods] 16 | impl RaggedBufferBool { 17 | #[new] 18 | pub fn new(features: usize) -> Self { 19 | RaggedBufferBool(RaggedBufferView::new(features)) 20 | } 21 | #[classmethod] 22 | fn from_array(_cls: &PyType, array: PyReadonlyArray3) -> Self { 23 | RaggedBufferBool(RaggedBufferView::from_array(array)) 24 | } 25 | #[classmethod] 26 | fn from_flattened( 27 | _cls: &PyType, 28 | flattened: PyReadonlyArray2, 29 | lengths: PyReadonlyArray1, 30 | ) -> PyResult { 31 | Ok(RaggedBufferBool(RaggedBufferView::from_flattened( 32 | flattened, lengths, 33 | )?)) 34 | } 35 | fn 
push(&mut self, items: PyReadonlyArrayDyn) -> PyResult<()> { 36 | if items.ndim() == 1 && items.len() == 0 { 37 | self.0.push_empty() 38 | } else if items.ndim() == 2 { 39 | self.0.push( 40 | &items 41 | .reshape((items.shape()[0], items.shape()[1]))? 42 | .readonly(), 43 | ) 44 | } else { 45 | Err(pyo3::exceptions::PyValueError::new_err( 46 | "Expected 2 dimensional array", 47 | )) 48 | } 49 | } 50 | fn push_empty(&mut self) -> PyResult<()> { 51 | self.0.push_empty() 52 | } 53 | 54 | fn clear(&mut self) -> PyResult<()> { 55 | self.0.clear() 56 | } 57 | 58 | fn as_array<'a>( 59 | &self, 60 | py: Python<'a>, 61 | ) -> PyResult<&'a numpy::PyArray>> { 62 | self.0.as_array(py) 63 | } 64 | 65 | fn extend(&mut self, other: &RaggedBufferBool) -> PyResult<()> { 66 | self.0.extend(&other.0) 67 | } 68 | fn size0(&self) -> usize { 69 | self.0.size0() 70 | } 71 | fn size1(&mut self, py: Python, i: Option) -> PyResult { 72 | match i { 73 | Some(i) => self.0.size1(i).map(|s| s.into_py(py)), 74 | None => self.0.lengths(py).map(|ok| ok.into_py(py)), 75 | } 76 | } 77 | fn size2(&self) -> usize { 78 | self.0.size2() 79 | } 80 | fn indices(&mut self, dim: usize) -> PyResult { 81 | Ok(RaggedBufferI64(self.0.indices(dim)?)) 82 | } 83 | fn flat_indices(&mut self) -> PyResult { 84 | Ok(RaggedBufferI64(self.0.flat_indices()?)) 85 | } 86 | #[classmethod] 87 | fn cat(_cls: &PyType, buffers: Vec>, dim: usize) -> PyResult { 88 | Ok(RaggedBufferBool(RaggedBufferView::cat( 89 | &buffers.iter().map(|b| &b.0).collect::>(), 90 | dim, 91 | )?)) 92 | } 93 | #[allow(clippy::type_complexity)] 94 | fn padpack<'a>(&mut self, py: Python<'a>) -> PadpackResult<'a> { 95 | match self.0.padpack()? 
{ 96 | Some((padbpack_index, padpack_batch, padpack_inverse_index, dims)) => Ok(Some(( 97 | padbpack_index.to_pyarray(py).reshape(dims)?, 98 | padpack_batch.to_pyarray(py).reshape(dims)?, 99 | padpack_inverse_index 100 | .to_pyarray(py) 101 | .reshape(self.0.len()?)?, 102 | ))), 103 | _ => Ok(None), 104 | } 105 | } 106 | fn items(&mut self) -> PyResult { 107 | self.0.items() 108 | } 109 | fn clone(&self) -> Self { 110 | RaggedBufferBool(self.0.deepclone()) 111 | } 112 | fn materialize(&self) -> Self { 113 | RaggedBufferBool(self.0.materialize()) 114 | } 115 | fn __str__(&self) -> PyResult { 116 | self.0.__str__() 117 | } 118 | 119 | fn __repr__(&self) -> PyResult { 120 | self.0.__str__() 121 | } 122 | 123 | fn __richcmp__(&self, other: RaggedBufferBool, op: CompareOp) -> PyResult { 124 | match op { 125 | CompareOp::Eq => Ok(self.0 == other.0), 126 | CompareOp::Ne => Ok(self.0 != other.0), 127 | _ => Err(pyo3::exceptions::PyTypeError::new_err( 128 | "Only == and != are supported", 129 | )), 130 | } 131 | } 132 | 133 | // Is substituted for #[cfg(any())] for bool.rs to omit method 134 | #[cfg(any())] 135 | fn __add__( 136 | lhs: PyRef, 137 | rhs: RaggedBufferBoolOrBool, 138 | ) -> PyResult { 139 | match rhs { 140 | RaggedBufferBoolOrBool::RB(rhs) => Ok(RaggedBufferBool( 141 | lhs.0.binop::(&rhs.0)?, 142 | )), 143 | RaggedBufferBoolOrBool::Scalar(rhs) => Ok(RaggedBufferBool( 144 | lhs.0.op_scalar::(rhs)?, 145 | )), 146 | } 147 | } 148 | 149 | #[cfg(any())] 150 | fn __mul__( 151 | lhs: PyRef, 152 | rhs: RaggedBufferBoolOrBool, 153 | ) -> PyResult { 154 | match rhs { 155 | RaggedBufferBoolOrBool::RB(rhs) => Ok(RaggedBufferBool( 156 | lhs.0.binop::(&rhs.0)?, 157 | )), 158 | RaggedBufferBoolOrBool::Scalar(rhs) => Ok(RaggedBufferBool( 159 | lhs.0.op_scalar::(rhs)?, 160 | )), 161 | } 162 | } 163 | 164 | #[cfg(any())] 165 | fn __isub__(&mut self, rhs: RaggedBufferBool) -> PyResult<()> { 166 | self.0.binop_mut::(&rhs.0) 167 | } 168 | 169 | fn __getitem__(&self, index: 
MultiIndex) -> PyResult { 170 | match index { 171 | MultiIndex::Index1(index) => match index { 172 | Index::PermutationNP(indices) => Ok(RaggedBufferBool(self.0.swizzle(indices)?)), 173 | Index::Permutation(indices) => Ok(RaggedBufferBool(self.0.swizzle_usize(&indices)?)), 174 | Index::Int(i) => Ok(RaggedBufferBool(self.0.get_sequence(i)?)), 175 | Index::Slice(slice) => panic!("{:?}", slice), 176 | }, 177 | MultiIndex::Index3((i0, i1, i2)) => Ok(RaggedBufferBool(Python::with_gil(|py| { 178 | self.0.get_slice(py, i0, i1, i2) 179 | })?)), 180 | x => panic!("{:?}", x), 181 | } 182 | } 183 | fn __len__(&self) -> PyResult { 184 | self.0.len() 185 | } 186 | } 187 | 188 | #[derive(FromPyObject)] 189 | pub enum RaggedBufferBoolOrBool<'p> { 190 | RB(PyRef<'p, RaggedBufferBool>), 191 | Scalar(bool), 192 | } 193 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ENN Ragged Buffer 2 | 3 | [![Actions Status](https://github.com/entity-neural-network/ragged-buffer/workflows/Test/badge.svg)](https://github.com/entity-neural-network/ragged-buffer/actions) 4 | [![PyPI](https://img.shields.io/pypi/v/ragged-buffer.svg?style=flat-square)](https://pypi.org/project/ragged-buffer/) 5 | [![Discord](https://img.shields.io/discord/913497968701747270?style=flat-square)](https://discord.gg/SjVqhSW4Qf) 6 | 7 | This Python package implements an efficient `RaggedBuffer` datatype that is similar to 8 | a 3D numpy array, but which allows for variable sequence length in the second 9 | dimension. It was created primarily for use in [enn-trainer](https://github.com/entity-neural-network/enn-trainer) 10 | and currently only supports a small selection of the numpy array methods. 
11 | 12 | ![Ragged Buffer](https://user-images.githubusercontent.com/12845088/143787823-c6a585de-aeda-429c-9824-f4b4a98e6cea.png) 13 | 14 | ## User Guide 15 | 16 | Install the package with `pip install ragged-buffer`. 17 | The package currently supports three `RaggedBuffer` variants, `RaggedBufferF32`, `RaggedBufferI64`, and `RaggedBufferBool`. 18 | 19 | 20 | - [Creating a RaggedBuffer](#creating-a-raggedbuffer) 21 | - [Get size](#get-size) 22 | - [Convert to numpy array](#convert-to-numpy-array) 23 | - [Indexing](#indexing) 24 | - [Addition](#addition) 25 | - [Concatenation](#concatenation) 26 | - [Clear](#clear) 27 | 28 | ### Creating a RaggedBuffer 29 | 30 | There are three ways to create a `RaggedBuffer`: 31 | - `RaggedBufferF32(features: int)` creates an empty `RaggedBuffer` with the specified number of features. 32 | - `RaggedBufferF32.from_flattened(flattened: np.ndarray, lengths: np.ndarray)` creates a `RaggedBuffer` from a flattened 2D numpy array and a 1D numpy array of lengths. 33 | - `RaggedBufferF32.from_array` creates a `RaggedBuffer` (with equal sequence lengths) from a 3D numpy array. 
34 | 35 | Creating an empty buffer and pushing each row: 36 | 37 | ```python 38 | import numpy as np 39 | from ragged_buffer import RaggedBufferF32 40 | 41 | # Create an empty RaggedBuffer with a feature size of 3 42 | buffer = RaggedBufferF32(3) 43 | # Push sequences with 3, 5, 0, and 1 elements 44 | buffer.push(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)) 45 | buffer.push(np.array([[10, 11, 12], [13, 14, 15], [16, 17, 18], [19, 20, 21], [22, 23, 24]], dtype=np.float32)) 46 | buffer.push(np.array([], dtype=np.float32)) # Alternative: `buffer.push_empty()` 47 | buffer.push(np.array([[25, 25, 27]], dtype=np.float32)) 48 | ``` 49 | 50 | Creating a RaggedBuffer from a flat 2D numpy array which combines the first and second dimension, 51 | and an array of sequence lengths: 52 | 53 | ```python 54 | import numpy as np 55 | from ragged_buffer import RaggedBufferF32 56 | 57 | buffer = RaggedBufferF32.from_flattened( 58 | np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15], [16, 17, 18], [19, 20, 21], [22, 23, 24], [25, 25, 27]], dtype=np.float32), 59 | np.array([3, 5, 0, 1], dtype=np.int64) 60 | ) 61 | ``` 62 | 63 | Creating a RaggedBuffer from a 3D numpy array (all sequences have the same length): 64 | 65 | ```python 66 | import numpy as np 67 | from ragged_buffer import RaggedBufferF32 68 | 69 | buffer = RaggedBufferF32.from_array(np.zeros((4, 5, 3), dtype=np.float32)) 70 | ``` 71 | 72 | ### Get size 73 | 74 | The `size0`, `size1`, and `size2` methods return the number of sequences, the number of elements in a sequence, and the number of features respectively. 75 | 76 | ```python 77 | import numpy as np 78 | from ragged_buffer import RaggedBufferF32 79 | 80 | buffer = RaggedBufferF32.from_flattened( 81 | np.zeros((9, 64), dtype=np.float32), 82 | np.array([3, 5, 0, 1], dtype=np.int64) 83 | ) 84 | 85 | # Get size of the first/batch dimension. 86 | assert buffer.size0() == 4 87 | # Get size of individual sequences. 
88 | assert buffer.size1(1) == 5 89 | assert buffer.size1(2) == 0 90 | # Get size of the last/feature dimension. 91 | assert buffer.size2() == 64 92 | ``` 93 | 94 | ### Convert to numpy array 95 | 96 | `as_array` converts a `RaggedBuffer` to a flat 2D numpy array that combines the first and second dimension. 97 | 98 | ```python 99 | import numpy as np 100 | from ragged_buffer import RaggedBufferI64 101 | 102 | buffer = RaggedBufferI64(1) 103 | buffer.push(np.array([[1], [1], [1]], dtype=np.int64)) 104 | buffer.push(np.array([[2], [2]], dtype=np.int64)) 105 | assert np.all(buffer.as_array() == np.array([[1], [1], [1], [2], [2]], dtype=np.int64)) 106 | ``` 107 | 108 | ### Indexing 109 | 110 | You can index a `RaggedBuffer` with a single integer (returning a `RaggedBuffer` with a single sequence), or with a numpy array of integers selecting/permuting multiple sequences. 111 | 112 | ```python 113 | import numpy as np 114 | from ragged_buffer import RaggedBufferF32 115 | 116 | # Create a new `RaggedBufferF32` 117 | buffer = RaggedBufferF32.from_flattened( 118 | np.arange(0, 36, dtype=np.float32).reshape(9, 4), 119 | np.array([3, 5, 0, 1], dtype=np.int64) 120 | ) 121 | 122 | # Retrieve the first sequence. 123 | assert np.all( 124 | buffer[0].as_array() == 125 | np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=np.float32) 126 | ) 127 | 128 | # Get a RaggedBuffer with 2 randomly selected sequences. 129 | buffer[np.random.permutation(4)[:2]] 130 | ``` 131 | 132 | ### Addition 133 | 134 | You can add two `RaggedBuffer`s with the `+` operator if they have the same number of sequences, sequence lengths, and features. You can also add a `RaggedBuffer` where all sequences have a length of 1 to a `RaggedBuffer` with variable length sequences, broadcasting along each sequence. 
135 | 136 | ```python 137 | import numpy as np 138 | from ragged_buffer import RaggedBufferI64 139 | 140 | # Create ragged buffer with dimensions (3, [1, 3, 2], 1) 141 | rb3 = RaggedBufferI64(1) 142 | rb3.push(np.array([[0]], dtype=np.int64)) 143 | rb3.push(np.array([[0], [1], [2]], dtype=np.int64)) 144 | rb3.push(np.array([[0], [5]], dtype=np.int64)) 145 | 146 | # Create ragged buffer with dimensions (3, [1, 1, 1], 1) 147 | rb4 = RaggedBufferI64.from_array(np.array([0, 3, 10], dtype=np.int64).reshape(3, 1, 1)) 148 | 149 | # Add rb3 and rb4, broadcasting along the sequence dimension. 150 | rb5 = rb3 + rb4 151 | assert np.all( 152 | rb5.as_array() == np.array([[0], [3], [4], [5], [10], [15]], dtype=np.int64) 153 | ) 154 | ``` 155 | 156 | ### Concatenation 157 | 158 | The `extend` method can be used to mutate a `RaggedBuffer` by appending another `RaggedBuffer` to it. 159 | 160 | ```python 161 | import numpy as np 162 | from ragged_buffer import RaggedBufferF32 163 | 164 | 165 | rb1 = RaggedBufferF32.from_array(np.zeros((4, 5, 3), dtype=np.float32)) 166 | rb2 = RaggedBufferF32.from_array(np.zeros((2, 5, 3), dtype=np.float32)) 167 | rb1.extend(rb2) 168 | assert rb1.size0() == 6 169 | ``` 170 | 171 | ### Clear 172 | 173 | The `clear` method removes all elements from a `RaggedBuffer` without deallocating the underlying memory. 174 | 175 | 176 | ```python 177 | import numpy as np 178 | from ragged_buffer import RaggedBufferF32 179 | 180 | rb = RaggedBufferF32.from_array(np.zeros((4, 5, 3), dtype=np.float32)) 181 | rb.clear() 182 | assert rb.size0() == 0 183 | ``` 184 | 185 | ## License 186 | 187 | ENN Ragged Buffer is dual-licensed under Apache-2.0 and MIT. 
188 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 
-------------------------------------------------------------------------------- /tests/test.py: -------------------------------------------------------------------------------- 1 | from typing import TypeVar 2 | import math 3 | import numpy as np 4 | from ragged_buffer import ( 5 | RaggedBufferF32, 6 | RaggedBufferI64, 7 | RaggedBufferBool, 8 | RaggedBuffer, 9 | ) 10 | import ragged_buffer 11 | 12 | rba = RaggedBufferF32(3) 13 | 14 | ScalarType = TypeVar("ScalarType", bound=np.generic) 15 | 16 | 17 | def generic_len(r: RaggedBuffer[ScalarType]) -> int: 18 | return sum([r.size1(s) for s in range(r.size0())]) * r.size2() 19 | 20 | 21 | expected = """RaggedBuffer([ 22 | ], '0 * var * 3 * f32)""" 23 | assert str(rba) == expected 24 | 25 | rba.push(np.array([[2.0, 3.0, 1.0], [1.0, 2.0, 3.0]], dtype=np.float32)) 26 | rba.push( 27 | np.array([[2.0, 3.0, 1.0], [1.0, 2.0, 3.0], [1.4, 2.4, 3.4]], dtype=np.float32) 28 | ) 29 | rba.push( 30 | np.array( 31 | [[2.0, 3.0, 1.0], [1.0, 2.0, 3.0], [1.4, 2.4, 3.4], [1.4, 2.4, 3.4]], 32 | dtype=np.float32, 33 | ) 34 | ) 35 | rba.push(np.array([], dtype=np.float32)) 36 | rba.push_empty() 37 | 38 | assert generic_len(rba) == 27, f"Expected 27 elements, got {generic_len(rba)}" 39 | 40 | expected = """RaggedBuffer([ 41 | [ 42 | [2, 3, 1], 43 | [1, 2, 3], 44 | ], 45 | [ 46 | [2, 3, 1], 47 | [1, 2, 3], 48 | [1.4, 2.4, 3.4], 49 | ], 50 | [ 51 | [2, 3, 1], 52 | [1, 2, 3], 53 | [1.4, 2.4, 3.4], 54 | [1.4, 2.4, 3.4], 55 | ], 56 | [], 57 | [], 58 | ], '5 * var * 3 * f32)""" 59 | 60 | 61 | assert str(rba) == expected, str(rba) 62 | flattened = np.array( 63 | [ 64 | [2.0, 3.0, 1.0], 65 | [1.0, 2.0, 3.0], 66 | [2.0, 3.0, 1.0], 67 | [1.0, 2.0, 3.0], 68 | [1.4, 2.4, 3.4], 69 | [2.0, 3.0, 1.0], 70 | [1.0, 2.0, 3.0], 71 | [1.4, 2.4, 3.4], 72 | [1.4, 2.4, 3.4], 73 | ], 74 | dtype=np.float32, 75 | ) 76 | assert np.all(rba.as_array() == flattened) 77 | assert rba == RaggedBufferF32.from_flattened( 78 | flattened=flattened, 79 | lengths=np.array([2, 3, 4, 
0, 0], dtype=np.int64), 80 | ) 81 | assert RaggedBufferF32(3) == RaggedBufferF32(3) 82 | 83 | 84 | rba2 = RaggedBufferF32(3) 85 | rba2.push(np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=np.float32)) 86 | rba2.push( 87 | np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=np.float32) 88 | ) 89 | rba.extend(rba2) 90 | assert rba == RaggedBufferF32.from_flattened( 91 | flattened=np.concatenate([flattened, np.zeros((5, 3), dtype=np.float32)]), 92 | lengths=np.array([2, 3, 4, 0, 0, 2, 3], dtype=np.int64), 93 | ) 94 | rba[np.random.permutation(rba.size0())] 95 | 96 | assert rba.size0() == 7 97 | assert rba.size1(0) == 2 98 | assert rba.size1(1) == 3 99 | assert rba.size1(2) == 4 100 | assert rba.size1(3) == 0 101 | assert rba.size1(4) == 0 102 | assert rba.size1(5) == 2 103 | assert rba.size1(6) == 3 104 | assert np.all(rba.size1() == np.array([2, 3, 4, 0, 0, 2, 3], dtype=np.int64)) 105 | assert rba.size2() == 3 106 | 107 | rba.clear() 108 | assert rba == RaggedBufferF32(3) 109 | 110 | rb3 = RaggedBufferI64(1) 111 | rb3.push(np.array([[0]], dtype=np.int64)) 112 | rb3.push(np.array([[0], [1], [2]], dtype=np.int64)) 113 | rb3.push(np.array([[0], [5]], dtype=np.int64)) 114 | assert rb3 == RaggedBufferI64.from_flattened( 115 | flattened=np.array([[0], [0], [1], [2], [0], [5]], dtype=np.int64), 116 | lengths=np.array([1, 3, 2], dtype=np.int64), 117 | ) 118 | # Shuffle 119 | rb3[np.random.permutation(rb3.size0())] 120 | 121 | rb4 = RaggedBufferI64.from_array(np.array([0, 3, 10], dtype=np.int64).reshape(3, 1, 1)) 122 | assert rb4 == RaggedBufferI64.from_flattened( 123 | flattened=np.array([[0], [3], [10]], dtype=np.int64), 124 | lengths=np.array([1, 1, 1], dtype=np.int64), 125 | ) 126 | rb5 = rb3 + rb4 127 | assert np.all( 128 | rb5.as_array() == np.array([[0], [3], [4], [5], [10], [15]], dtype=np.int64) 129 | ), f"{rb5.as_array()}" 130 | assert rb3 + rb4 == rb4 + rb3 131 | assert rb5 * 10 == RaggedBufferI64.from_flattened( 132 | flattened=np.array([[0], 
[30], [40], [50], [100], [150]], dtype=np.int64), 133 | lengths=np.array([1, 3, 2], dtype=np.int64), 134 | ) 135 | assert rb5.indices(1) == RaggedBufferI64.from_flattened( 136 | flattened=np.array([[0], [0], [1], [2], [0], [1]], dtype=np.int64), 137 | lengths=np.array([1, 3, 2], dtype=np.int64), 138 | ) 139 | 140 | 141 | rb6 = RaggedBufferF32.from_flattened( 142 | np.array( 143 | [ 144 | [0.0, 0.0, 0.0], 145 | [1.0, 2.0, 3.0], 146 | [4.0, 5.0, 6.0], 147 | [7.0, 8.0, 9.0], 148 | [10.0, 11.0, 12.0], 149 | [13.0, 14.0, 15.0], 150 | ], 151 | dtype=np.float32, 152 | ), 153 | np.array([3, 0, 2, 1], dtype=np.int64), 154 | ) 155 | assert rb6[np.array([1, 3, 0], dtype=np.int64)] == RaggedBufferF32.from_flattened( 156 | np.array( 157 | [ 158 | [13.0, 14.0, 15.0], 159 | [0, 0, 0], 160 | [1, 2, 3], 161 | [4, 5, 6], 162 | ], 163 | dtype=np.float32, 164 | ), 165 | np.array([0, 1, 3], dtype=np.int64), 166 | ) 167 | 168 | assert np.all( 169 | rb6[2].as_array() 170 | == np.array([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], dtype=np.float32) 171 | ), f"{rb6[2]}" 172 | 173 | 174 | entities1 = RaggedBufferF32.from_flattened( 175 | np.zeros((6, 64), dtype=np.float32), np.array([3, 1, 2], dtype=np.int64) 176 | ) 177 | entities2 = RaggedBufferF32.from_flattened( 178 | np.zeros((3, 64), dtype=np.float32), np.array([1, 2, 0], dtype=np.int64) 179 | ) 180 | assert np.all(entities1.size1() == np.array([3, 1, 2], dtype=np.int64)) 181 | assert np.all(entities2.size1() == np.array([1, 2, 0], dtype=np.int64)) 182 | print("TEST 1 PASSED") 183 | 184 | bi1 = entities1.indices(0).as_array().flatten() 185 | bi2 = entities2.indices(0).as_array().flatten() 186 | assert np.all(bi1 == np.array([0, 0, 0, 1, 2, 2], dtype=np.int64)), f"{bi1}" 187 | assert np.all(bi2 == np.array([0, 1, 1], dtype=np.int64)), f"{bi2}" 188 | print("TEST 2 PASSED") 189 | 190 | flati1 = entities1.flat_indices() 191 | print("TEST 3 PASSED") 192 | flati2 = entities2.flat_indices() + 6 193 | print("TEST 4 PASSED") 194 | flat = 
ragged_buffer.cat([flati1, flati2, flati1, flati2], dim=1).as_array().flatten() 195 | assert np.all( 196 | flat 197 | == np.array([0, 1, 2, 6, 0, 1, 2, 6, 3, 7, 8, 3, 7, 8, 4, 5, 4, 5], dtype=np.int64), 198 | ), f"{flat} {ragged_buffer.cat([flati1, flati2, flati1, flati2], dim=1)}" 199 | print("TEST 5 PASSED") 200 | 201 | 202 | mask = RaggedBufferI64.from_array(np.zeros((4, 1, 1), dtype=np.int64)) 203 | offset = RaggedBufferI64.from_flattened( 204 | np.array([0, 1, 2, 3, 13, 22, 32, 41, 42, 43, 44, 45,], dtype=np.int64).reshape( 205 | -1, 206 | 1, 207 | ), 208 | np.ones(12, dtype=np.int64), 209 | ) 210 | try: 211 | mask = mask + offset 212 | except ValueError as e: 213 | pass 214 | else: 215 | assert False, "Did not raise ValueError" 216 | 217 | 218 | zerofeats = RaggedBufferF32(features=0) 219 | zerofeats.push(np.zeros((1, 0), dtype=np.float32)) 220 | zerofeats.push(np.zeros((0, 0), dtype=np.float32)) 221 | assert zerofeats.as_array().shape == (1, 0), f"{zerofeats.as_array().shape}" 222 | 223 | 224 | boolrb = RaggedBufferBool.from_flattened( 225 | np.array([[True, False, True], [False, True, False]], dtype=np.bool_), 226 | np.array([2, 0], dtype=np.int64), 227 | ) 228 | assert boolrb.as_array().shape == (2, 3), f"{boolrb.as_array().shape}" 229 | assert np.all( 230 | boolrb.as_array() 231 | == np.array([[True, False, True], [False, True, False]], dtype=np.bool_) 232 | ), f"{boolrb.as_array()}" 233 | 234 | batch_index = RaggedBufferI64.from_flattened( 235 | np.array( 236 | [[0], [0], [0], [0], [0], [0], [1], [1], [1], [2], [2], [2], [2], [4]], 237 | dtype=np.int64, 238 | ), 239 | np.array([6, 3, 4, 0, 1], dtype=np.int64), 240 | ) 241 | 242 | padpack = batch_index.padpack() 243 | assert padpack is not None 244 | padpack_index, padpack_batch, padpack_inverse_index = padpack 245 | 246 | assert np.all( 247 | padpack_index 248 | == np.array( 249 | [[0, 1, 2, 3, 4, 5], [6, 7, 8, 13, 0, 0], [9, 10, 11, 12, 0, 0]], dtype=np.int64 250 | ), 251 | ), f"{padpack_index}" 252 | 
253 | 254 | np.testing.assert_equal( 255 | padpack_batch, 256 | np.array( 257 | [ 258 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 259 | [1.0, 1.0, 1.0, 4.0, np.NaN, np.NaN], 260 | [2.0, 2.0, 2.0, 2.0, np.NaN, np.NaN], 261 | ], 262 | dtype=np.float32, 263 | ), 264 | ) 265 | 266 | np.testing.assert_equal( 267 | padpack_inverse_index, 268 | np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 9], dtype=np.int64), 269 | ) 270 | 271 | 272 | origin = RaggedBufferF32.from_array( 273 | np.array( 274 | [ 275 | [[0.0, 0.0, 100.0, -23.0]], 276 | [[1.0, -1.0, 200.0, -23.0]], 277 | [[2.0, -2.0, 300.0, -23.0]], 278 | [[-10.0, -10.0, 400.0, -23.0]], 279 | ], 280 | dtype=np.float32, 281 | ) 282 | ) 283 | entities = RaggedBufferF32.from_flattened( 284 | np.array( 285 | [ 286 | [10, 3.0, 10, 1.0], 287 | [11, 4.0, 11, 2.0], 288 | [12, 5.0, 12, 3.0], 289 | [13, 6.0, 13, 4.0], 290 | [14, 7.0, 14, 5.0], 291 | [15, 8.0, 15, 6.0], 292 | ], 293 | dtype=np.float32, 294 | ), 295 | np.array([3, 0, 2, 1], dtype=np.int64), 296 | ) 297 | 298 | 299 | entities_slice = entities[:, :, [1, 3]] 300 | entities_clone = entities.clone() 301 | entities_slice -= origin[:, :, [0, 1]] 302 | assert entities == RaggedBufferF32.from_flattened( 303 | np.array( 304 | [ 305 | [10, 3.0, 10, 1.0], 306 | [11, 4.0, 11, 2.0], 307 | [12, 5.0, 12, 3.0], 308 | [13, 4.0, 13, 6.0], 309 | [14, 5.0, 14, 7.0], 310 | [15, 18.0, 15, 16.0], 311 | ], 312 | dtype=np.float32, 313 | ), 314 | np.array([3, 0, 2, 1], dtype=np.int64), 315 | ), f"{entities}" 316 | 317 | assert len(entities) == 24, f"{len(entities)}" 318 | assert entities.items() == 6 319 | assert entities_clone != entities 320 | 321 | 322 | origin = RaggedBufferF32.from_array( 323 | np.array( 324 | [ 325 | [[0.0, 0.0, 100.0, -23.0, 1.0, 0.0]], 326 | [[1.0, -1.0, 200.0, -23.0, 1.0, 0.0]], 327 | [[2.0, -2.0, 300.0, -23.0, math.sqrt(2) / 2, -math.sqrt(2) / 2]], 328 | [[-10.0, -10.0, 400.0, -23.0, -1.0, 0.0]], 329 | ], 330 | dtype=np.float32, 331 | ) 332 | ) 333 | entities = 
RaggedBufferF32.from_flattened( 334 | np.array( 335 | [ 336 | [10, 3.0, 10, 1.0], 337 | [11, 4.0, 11, 2.0], 338 | [12, 5.0, 12, 3.0], 339 | [13, 6.0, 13, 4.0], 340 | [14, 7.0, 14, 5.0], 341 | [15, 8.0, 15, 6.0], 342 | ], 343 | dtype=np.float32, 344 | ), 345 | np.array([3, 0, 2, 1], dtype=np.int64), 346 | ) 347 | 348 | 349 | ragged_buffer.translate_rotate( 350 | entities[:, :, [1, 3]], origin[:, :, [0, 1]], origin[:, :, [4, 5]] 351 | ) 352 | assert np.allclose( 353 | entities.as_array(), 354 | np.array( 355 | [ 356 | [10, 3.0, 10, 1.0], 357 | [11, 4.0, 11, 2.0], 358 | [12, 5.0, 12, 3.0], 359 | [13, -1.4142134, 13, 7.071068], 360 | [14, -1.4142137, 14, 8.485281], 361 | [15, -18.0, 15, -16.0], 362 | ], 363 | dtype=np.float32, 364 | ), 365 | 1e-6, 366 | ), f"{entities}" 367 | 368 | 369 | feats = RaggedBufferF32.from_flattened( 370 | np.array( 371 | [ 372 | [0.0, 3.0, 1.0, 1.0], 373 | [1.0, 4.0, 0, 2.0], 374 | [2.0, 5.0, 1.0, 3.0], 375 | [3.0, 6.0, 0.0, 4.0], 376 | [4.0, 7.0, 0.0, 5.0], 377 | [5.0, 8.0, 0.0, 6.0], 378 | ], 379 | dtype=np.float32, 380 | ), 381 | np.array([3, 2, 1], dtype=np.int64), 382 | ) 383 | assert np.array_equal( 384 | feats[1:2, :, :].as_array(), 385 | np.array( 386 | [ 387 | [3.0, 6.0, 0.0, 4.0], 388 | [4.0, 7.0, 0.0, 5.0], 389 | ], 390 | dtype=np.float32, 391 | ), 392 | ), f"{feats[1:2, :, :].as_array()}" 393 | assert np.array_equal( 394 | feats[:, :, 1:3].as_array(), 395 | np.array( 396 | [ 397 | [3.0, 1.0], 398 | [4.0, 0.0], 399 | [5.0, 1.0], 400 | [6.0, 0.0], 401 | [7.0, 0.0], 402 | [8.0, 0.0], 403 | ], 404 | dtype=np.float32, 405 | ), 406 | ), f"{feats[:, :, 1:3].as_array()}" 407 | assert np.array_equal( 408 | feats[:, 0, :].as_array(), 409 | np.array( 410 | [ 411 | [0.0, 3.0, 1.0, 1.0], 412 | [3.0, 6.0, 0.0, 4.0], 413 | [5.0, 8.0, 0.0, 6.0], 414 | ], 415 | dtype=np.float32, 416 | ), 417 | ), f"{feats[:, :, 0].as_array()}" 418 | assert np.array_equal( 419 | feats[:, 0:2:3, :].size1(), 420 | np.array([1, 1, 1], dtype=np.int64), 421 | ), 
f"{feats[:, 0:2:3, :].size1()}" 422 | assert np.array_equal( 423 | feats[:, 1:3, 0].size1(), 424 | np.array([2, 1, 0], dtype=np.int64), 425 | ), f"{feats[:, 1:3, 0].size1()}" 426 | assert np.array_equal( 427 | feats[1:, 1:10, 0].size1(), 428 | np.array([1, 0], dtype=np.int64), 429 | ), f"{feats[1:, 1:10, 0].size1()}" 430 | 431 | 432 | # Test broadcasting concatenation along dim 2 433 | entities = RaggedBufferF32.from_flattened( 434 | np.array( 435 | [ 436 | [10, 3.0, 10, 1.0], 437 | [11, 4.0, 11, 2.0], 438 | [12, 5.0, 12, 3.0], 439 | [13, 4.0, 13, 6.0], 440 | [14, 5.0, 14, 7.0], 441 | [15, 18.0, 15, 16.0], 442 | ], 443 | dtype=np.float32, 444 | ), 445 | np.array([3, 0, 2, 1], dtype=np.int64), 446 | ) 447 | global_feats = RaggedBufferF32.from_array( 448 | np.array( 449 | [ 450 | [[0.0]], 451 | [[1.0]], 452 | [[2.0]], 453 | [[3.0]], 454 | ], 455 | dtype=np.float32, 456 | ) 457 | ) 458 | result = ragged_buffer.cat([entities, global_feats], dim=2) 459 | assert np.array_equal( 460 | result.as_array(), 461 | np.array( 462 | [ 463 | [10, 3.0, 10, 1.0, 0.0], 464 | [11, 4.0, 11, 2.0, 0.0], 465 | [12, 5.0, 12, 3.0, 0.0], 466 | [13, 4.0, 13, 6.0, 2.0], 467 | [14, 5.0, 14, 7.0, 2.0], 468 | [15, 18.0, 15, 16.0, 3.0], 469 | ] 470 | ), 471 | ), f"{result}" 472 | 473 | # Test complex indexing 474 | sliced = entities[[1, 0], 0:3:2, [0, 3, 2]] 475 | assert np.array_equal( 476 | sliced.as_array(), 477 | np.array( 478 | [ 479 | [10, 1.0, 10], 480 | [12, 3.0, 12], 481 | ] 482 | ), 483 | ), f"{sliced}" 484 | assert np.array_equal( 485 | sliced.materialize().size1(), 486 | np.array([0, 2], dtype=np.int64), 487 | ), f"{sliced.size1()}" 488 | 489 | print("ALL TESTS PASSED") 490 | -------------------------------------------------------------------------------- /src/ragged_buffer.rs: -------------------------------------------------------------------------------- 1 | use std::cmp::Ordering; 2 | use std::collections::{binary_heap, BinaryHeap}; 3 | use std::fmt::{Display, Write}; 4 | use 
std::ops::{Add, Mul, Range, Sub}; 5 | 6 | use ndarray::{ArrayView1, ArrayView2, ArrayView3}; 7 | 8 | #[derive(Debug)] 9 | pub enum Error { 10 | Generic(String), 11 | } 12 | 13 | impl Error { 14 | fn generic>(s: S) -> Self { 15 | Self::Generic(s.into()) 16 | } 17 | } 18 | 19 | pub type Result = std::result::Result; 20 | 21 | #[derive(Clone, PartialEq, Eq, Hash, Debug)] 22 | pub struct RaggedBuffer { 23 | pub data: Vec, 24 | // Each element of `subarrays` gives the start/end index of the items within that subarray (step size 1). 25 | // The start index of the data of an item is obtained by multiplying its index by `features`. 26 | pub subarrays: Vec>, 27 | pub features: usize, 28 | } 29 | 30 | pub trait BinOp { 31 | fn op(lhs: T, rhs: T) -> T; 32 | } 33 | 34 | pub struct BinOpAdd; 35 | 36 | impl> BinOp for BinOpAdd { 37 | #[inline] 38 | fn op(lhs: T, rhs: T) -> T { 39 | lhs + rhs 40 | } 41 | } 42 | 43 | pub struct BinOpSub; 44 | 45 | impl> BinOp for BinOpSub { 46 | #[inline] 47 | fn op(lhs: T, rhs: T) -> T { 48 | lhs - rhs 49 | } 50 | } 51 | 52 | pub struct BinOpMul; 53 | 54 | impl> BinOp for BinOpMul { 55 | #[inline] 56 | fn op(lhs: T, rhs: T) -> T { 57 | lhs * rhs 58 | } 59 | } 60 | 61 | impl RaggedBuffer { 62 | pub fn new(features: usize) -> Self { 63 | RaggedBuffer { 64 | data: Vec::new(), 65 | subarrays: Vec::new(), 66 | features, 67 | } 68 | } 69 | 70 | pub fn from_array(data: ArrayView3) -> Self { 71 | let features = data.shape()[2]; 72 | RaggedBuffer { 73 | data: data.iter().cloned().collect(), 74 | subarrays: (0..data.shape()[0]) 75 | .map(|i| i * data.shape()[1]..(i + 1) * data.shape()[1]) 76 | .collect(), 77 | features, 78 | } 79 | } 80 | 81 | pub fn from_flattened(data: ArrayView2, lengths: ArrayView1) -> Result { 82 | let features = data.shape()[1]; 83 | let mut subarrays = Vec::new(); 84 | let mut item = 0; 85 | for len in lengths.iter().cloned() { 86 | subarrays.push(item..(item + len as usize)); 87 | item += len as usize; 88 | } 89 | if item != 
data.shape()[0] { 90 | Err(Error::generic(format!( 91 | "Lengths array specifies {} items, but data array has {} items", 92 | item, 93 | data.shape()[0] 94 | ))) 95 | } else { 96 | Ok(RaggedBuffer { 97 | data: data.iter().cloned().collect(), 98 | subarrays, 99 | features, 100 | }) 101 | } 102 | } 103 | 104 | pub fn extend(&mut self, other: &RaggedBuffer) -> Result<()> { 105 | if self.features != other.features { 106 | return Err(Error::generic(format!( 107 | "Features mismatch: {} != {}", 108 | self.features, other.features 109 | ))); 110 | } 111 | let item = self.items(); 112 | self.data.extend(other.data.iter()); 113 | self.subarrays 114 | .extend(other.subarrays.iter().map(|r| r.start + item..r.end + item)); 115 | Ok(()) 116 | } 117 | 118 | pub fn clear(&mut self) { 119 | self.data.clear(); 120 | self.subarrays.clear(); 121 | } 122 | 123 | // pub fn as_array<'a>( 124 | // &self, 125 | // py: Python<'a>, 126 | // ) -> PyResult<&'a numpy::PyArray>> { 127 | // self.data 128 | // .to_pyarray(py) 129 | // .reshape((self.items, self.features)) 130 | // } 131 | 132 | pub fn push(&mut self, data: &ArrayView2) -> Result<()> { 133 | if data.dim().1 != self.features { 134 | return Err(Error::generic(format!( 135 | "Features mismatch: {} != {}", 136 | self.features, 137 | data.dim().1 138 | ))); 139 | } 140 | self.subarrays 141 | .push(self.items()..(self.items() + data.dim().0)); 142 | match data.as_slice() { 143 | Some(slice) => self.data.extend_from_slice(slice), 144 | None => { 145 | for x in data.iter() { 146 | self.data.push(*x); 147 | } 148 | } 149 | } 150 | Ok(()) 151 | } 152 | 153 | pub fn push_empty(&mut self) { 154 | self.subarrays.push(self.items()..self.items()); 155 | } 156 | 157 | pub fn swizzle(&self, indices: ArrayView1) -> Result> { 158 | let indices = indices 159 | .as_slice() 160 | .ok_or_else(|| Error::generic("Indices must be a **contiguous** 1D array"))?; 161 | let mut subarrays = Vec::with_capacity(indices.len()); 162 | let mut item = 0usize; 163 | 
for i in indices { 164 | let sublen = self.subarrays[*i as usize].end - self.subarrays[*i as usize].start; 165 | subarrays.push(item..(item + sublen)); 166 | item += sublen; 167 | } 168 | let mut data = Vec::with_capacity(item * self.features); 169 | for i in indices { 170 | let Range { start, end } = self.subarrays[*i as usize]; 171 | data.extend_from_slice(&self.data[start * self.features..end * self.features]); 172 | } 173 | Ok(RaggedBuffer { 174 | data, 175 | subarrays, 176 | features: self.features, 177 | }) 178 | } 179 | 180 | // TODO: dedupe with swizzle 181 | pub fn swizzle_usize(&self, indices: &[usize]) -> Result> { 182 | let mut subarrays = Vec::with_capacity(indices.len()); 183 | let mut item = 0usize; 184 | for &i in indices { 185 | let sublen = self.subarrays[i].end - self.subarrays[i].start; 186 | subarrays.push(item..(item + sublen)); 187 | item += sublen; 188 | } 189 | let mut data = Vec::with_capacity(item * self.features); 190 | for i in indices { 191 | let Range { start, end } = self.subarrays[*i as usize]; 192 | data.extend_from_slice(&self.data[start * self.features..end * self.features]); 193 | } 194 | Ok(RaggedBuffer { 195 | data, 196 | subarrays, 197 | features: self.features, 198 | }) 199 | } 200 | 201 | pub fn get(&self, i: usize) -> RaggedBuffer { 202 | let subarray = self.subarrays[i].clone(); 203 | let Range { start, end } = subarray; 204 | RaggedBuffer { 205 | subarrays: vec![0..subarray.len()], 206 | data: self.data[start * self.features..end * self.features].to_vec(), 207 | features: self.features, 208 | } 209 | } 210 | 211 | pub fn size0(&self) -> usize { 212 | self.subarrays.len() 213 | } 214 | 215 | pub fn lengths(&self) -> Vec { 216 | self.subarrays 217 | .iter() 218 | .map(|r| (r.end - r.start) as i64) 219 | .collect::>() 220 | } 221 | 222 | pub fn size1(&self, i: usize) -> Result { 223 | if i >= self.subarrays.len() { 224 | Err(Error::generic(format!("Index {} out of range", i))) 225 | } else { 226 | Ok(self.subarrays[i].end - 
self.subarrays[i].start) 227 | } 228 | } 229 | 230 | pub fn size2(&self) -> usize { 231 | self.features 232 | } 233 | 234 | pub fn __str__(&self) -> Result { 235 | let mut array = String::new(); 236 | array.push_str("RaggedBuffer(["); 237 | array.push('\n'); 238 | for range in &self.subarrays { 239 | let slice = range.start * self.features..range.end * self.features; 240 | if range.start == range.end { 241 | writeln!(array, " [],").unwrap(); 242 | } else if range.start + 1 == range.end { 243 | writeln!(array, " [{:?}],", &self.data[slice]).unwrap(); 244 | } else { 245 | writeln!(array, " [").unwrap(); 246 | for i in slice.clone() { 247 | if i % self.features == 0 { 248 | if i != slice.start { 249 | writeln!(array, "],").unwrap(); 250 | } 251 | write!(array, " [").unwrap(); 252 | } 253 | write!(array, "{}", self.data[i]).unwrap(); 254 | if i % self.features != self.features - 1 { 255 | write!(array, ", ").unwrap(); 256 | } 257 | } 258 | writeln!(array, "],").unwrap(); 259 | writeln!(array, " ],").unwrap(); 260 | } 261 | } 262 | write!( 263 | array, 264 | "], '{} * var * {} * {})", 265 | self.subarrays.len(), 266 | self.features, 267 | std::any::type_name::(), 268 | ) 269 | .unwrap(); 270 | 271 | Ok(array) 272 | } 273 | 274 | pub fn binop>(&self, rhs: &RaggedBuffer) -> Result> { 275 | if self.features == rhs.features && self.subarrays == rhs.subarrays { 276 | let mut data = Vec::with_capacity(self.data.len()); 277 | for i in 0..self.data.len() { 278 | data.push(Op::op(self.data[i], rhs.data[i])); 279 | } 280 | Ok(RaggedBuffer { 281 | data, 282 | subarrays: self.subarrays.clone(), 283 | features: self.features, 284 | }) 285 | } else if self.features == rhs.features 286 | && self.subarrays.len() == rhs.subarrays.len() 287 | && rhs.subarrays.iter().all(|r| r.end - r.start == 1) 288 | { 289 | let mut data = Vec::with_capacity(self.data.len()); 290 | for (subarray, rhs_subarray) in self.subarrays.iter().zip(rhs.subarrays.iter()) { 291 | for item in subarray.clone() { 292 
| let lhs_offset = item * self.features; 293 | let rhs_offset = rhs_subarray.start * self.features; 294 | for i in 0..self.features { 295 | data.push(Op::op(self.data[lhs_offset + i], rhs.data[rhs_offset + i])); 296 | } 297 | } 298 | } 299 | Ok(RaggedBuffer { 300 | data, 301 | subarrays: self.subarrays.clone(), 302 | features: self.features, 303 | }) 304 | } else if self.features == rhs.features 305 | && self.subarrays.len() == rhs.subarrays.len() 306 | && self.subarrays.iter().all(|r| r.end - r.start == 1) 307 | { 308 | rhs.binop::(self) 309 | } else { 310 | Err(Error::generic(format!( 311 | "Dimensions mismatch: ({}, {:?}, {}) != ({}, {:?}, {})", 312 | self.size0(), 313 | self.subarrays 314 | .iter() 315 | .map(|r| r.end - r.start) 316 | .collect::>(), 317 | self.size2(), 318 | rhs.size0(), 319 | rhs.subarrays 320 | .iter() 321 | .map(|r| r.end - r.start) 322 | .collect::>(), 323 | rhs.size2(), 324 | ))) 325 | } 326 | } 327 | 328 | pub fn op_scalar>(&self, scalar: T) -> RaggedBuffer { 329 | RaggedBuffer { 330 | data: self.data.iter().map(|x| Op::op(*x, scalar)).collect(), 331 | subarrays: self.subarrays.clone(), 332 | features: self.features, 333 | } 334 | } 335 | 336 | pub fn indices(&self, dim: usize) -> Result> { 337 | match dim { 338 | 0 => { 339 | let mut indices = Vec::with_capacity(self.items()); 340 | for (index, subarray) in self.subarrays.iter().enumerate() { 341 | for _ in subarray.clone() { 342 | indices.push(index as i64); 343 | } 344 | } 345 | Ok(RaggedBuffer { 346 | subarrays: self.subarrays.clone(), 347 | data: indices, 348 | features: 1, 349 | }) 350 | } 351 | 1 => { 352 | let mut indices = Vec::with_capacity(self.items()); 353 | for subarray in &self.subarrays { 354 | for (i, _) in subarray.clone().enumerate() { 355 | indices.push(i as i64); 356 | } 357 | } 358 | Ok(RaggedBuffer { 359 | subarrays: self.subarrays.clone(), 360 | data: indices, 361 | features: 1, 362 | }) 363 | } 364 | _ => Err(Error::generic(format!("Invalid dimension {}", dim))), 
365 | } 366 | } 367 | 368 | pub fn flat_indices(&self) -> Result> { 369 | Ok(RaggedBuffer { 370 | subarrays: self.subarrays.clone(), 371 | data: (0..self.items()).map(|i| i as i64).collect(), 372 | features: 1, 373 | }) 374 | } 375 | 376 | pub fn cat(buffers: &[&RaggedBuffer], dim: usize) -> Result> { 377 | match dim { 378 | 0 => { 379 | if buffers.iter().any(|b| b.features != buffers[0].features) { 380 | return Err(Error::generic(format!( 381 | "All buffers must have the same number of features, but found {}", 382 | buffers 383 | .iter() 384 | .map(|b| b.features.to_string()) 385 | .collect::>() 386 | .join(", ") 387 | ))); 388 | } 389 | let mut data = Vec::with_capacity(buffers.iter().map(|b| b.data.len()).sum()); 390 | for buffer in buffers { 391 | data.extend_from_slice(&buffer.data); 392 | } 393 | let mut subarrays = 394 | Vec::with_capacity(buffers.iter().map(|b| b.subarrays.len()).sum()); 395 | let mut item = 0; 396 | for buffer in buffers { 397 | subarrays.extend_from_slice( 398 | &buffer 399 | .subarrays 400 | .iter() 401 | .map(|r| { 402 | let start = r.start + item; 403 | let end = r.end + item; 404 | start..end 405 | }) 406 | .collect::>(), 407 | ); 408 | item += buffer.items(); 409 | } 410 | Ok(RaggedBuffer { 411 | data, 412 | subarrays, 413 | features: buffers[0].features, 414 | }) 415 | } 416 | 1 => { 417 | if buffers 418 | .iter() 419 | .any(|b| b.subarrays.len() != buffers[0].subarrays.len()) 420 | { 421 | return Err(Error::generic(format!( 422 | "All buffers must have the same number of subarrays, but found {}", 423 | buffers 424 | .iter() 425 | .map(|b| b.subarrays.len().to_string()) 426 | .collect::>() 427 | .join(", ") 428 | ))); 429 | } 430 | if buffers.iter().any(|b| b.features != buffers[0].features) { 431 | return Err(Error::generic(format!( 432 | "All buffers must have the same number of features, but found {}", 433 | buffers 434 | .iter() 435 | .map(|b| b.features.to_string()) 436 | .collect::>() 437 | .join(", ") 438 | ))); 439 | } 440 | 
let mut data = Vec::with_capacity(buffers.iter().map(|b| b.data.len()).sum()); 441 | let mut subarrays = 442 | Vec::with_capacity(buffers.iter().map(|b| b.subarrays.len()).sum()); 443 | let mut item = 0; 444 | let mut last_item = 0; 445 | for i in 0..buffers[0].subarrays.len() { 446 | for buffer in buffers { 447 | let Range { start, end } = &buffer.subarrays[i]; 448 | data.extend_from_slice( 449 | &buffer.data[start * buffer.features..end * buffer.features], 450 | ); 451 | item += end - start; 452 | } 453 | subarrays.push(Range { 454 | start: last_item, 455 | end: item, 456 | }); 457 | last_item = item; 458 | } 459 | Ok(RaggedBuffer { 460 | data, 461 | subarrays, 462 | features: buffers[0].features, 463 | }) 464 | } 465 | 2 => { 466 | // TODO: disallow broadcasting on some sequences but not other? 467 | // TODO: think more about empty sequences 468 | let sequences = buffers[0].size0(); 469 | if buffers.iter().any(|b| b.size0() != sequences) { 470 | return Err(Error::generic(format!( 471 | "All buffers must have the same number of sequences, but found {}", 472 | buffers 473 | .iter() 474 | .map(|b| b.size0().to_string()) 475 | .collect::>() 476 | .join(", ") 477 | ))); 478 | } 479 | 480 | let features = buffers.iter().map(|b| b.features).sum(); 481 | let mut subarrays = Vec::with_capacity(sequences); 482 | let mut data = Vec::with_capacity(sequences * features); 483 | let mut items = 0; 484 | for iseq in 0..sequences { 485 | let seqlen = if buffers.iter().any(|b| { 486 | b.size1(iseq) 487 | .expect("All sequences should be the same length.") 488 | == 0 489 | }) { 490 | 0 491 | } else { 492 | buffers 493 | .iter() 494 | .map(|b| { 495 | b.size1(iseq) 496 | .expect("All sequences should be the same length.") 497 | }) 498 | .max() 499 | .expect("There should be at least one buffer.") 500 | }; 501 | subarrays.push(items..items + seqlen); 502 | items += seqlen; 503 | for iitem in 0..seqlen { 504 | for (ibuf, buffer) in buffers.iter().enumerate() { 505 | let _items = 
buffer.subarrays[iseq].len(); 506 | if _items == 1 { 507 | data.extend_from_slice( 508 | &buffer.data[buffer.subarrays[iseq].start * buffer.features 509 | ..buffer.subarrays[iseq].end * buffer.features], 510 | ); 511 | } else { 512 | if _items != seqlen { 513 | return Err(Error::generic(format!( 514 | "Buffer {} has {} items for sequence {}, but expected {}", 515 | ibuf, _items, iseq, seqlen 516 | ))); 517 | } 518 | let start_item = buffer.subarrays[iseq].start + iitem; 519 | data.extend_from_slice( 520 | &buffer.data[start_item * buffer.features 521 | ..(start_item + 1) * buffer.features], 522 | ); 523 | } 524 | } 525 | } 526 | } 527 | 528 | Ok(RaggedBuffer { 529 | data, 530 | subarrays, 531 | features, 532 | }) 533 | } 534 | _ => Err(Error::generic(format!( 535 | "Invalid dimension {}, RaggedBuffer only has 3 dimensions", 536 | dim 537 | ))), 538 | } 539 | } 540 | 541 | #[allow(clippy::type_complexity)] 542 | pub fn padpack(&self) -> Option<(Vec, Vec, Vec, (usize, usize))> { 543 | if self.subarrays.is_empty() 544 | || self 545 | .subarrays 546 | .iter() 547 | .all(|r| r.end - r.start == self.subarrays[0].end - self.subarrays[0].start) 548 | { 549 | return None; 550 | } 551 | 552 | let mut padbpack_index = vec![]; 553 | let mut padpack_batch = vec![]; 554 | let mut padpack_inverse_index = vec![]; 555 | let max_seq_len = self 556 | .subarrays 557 | .iter() 558 | .map(|r| r.end - r.start) 559 | .max() 560 | .unwrap(); 561 | let mut sequences: BinaryHeap = binary_heap::BinaryHeap::new(); 562 | 563 | for (batch_index, subarray) in self.subarrays.iter().enumerate() { 564 | let (free, packed_batch_index) = match sequences.peek().cloned() { 565 | Some(seq) if seq.free >= subarray.end - subarray.start => { 566 | sequences.pop(); 567 | (seq.free, seq.batch_index) 568 | } 569 | _ => { 570 | for _ in 0..max_seq_len { 571 | padbpack_index.push(0); 572 | padpack_batch.push(f32::NAN); 573 | } 574 | (max_seq_len, sequences.len()) 575 | } 576 | }; 577 | 578 | for (i, item) in 
subarray.clone().enumerate() { 579 | let packed_index = packed_batch_index * max_seq_len + max_seq_len - free + i; 580 | padbpack_index[packed_index] = item as i64; 581 | padpack_batch[packed_index] = batch_index as f32; 582 | padpack_inverse_index.push(packed_index as i64); 583 | } 584 | sequences.push(Sequence { 585 | batch_index: packed_batch_index, 586 | free: free - (subarray.end - subarray.start), 587 | }); 588 | } 589 | 590 | Some(( 591 | padbpack_index, 592 | padpack_batch, 593 | padpack_inverse_index, 594 | (sequences.len(), max_seq_len), 595 | )) 596 | } 597 | 598 | pub fn items(&self) -> usize { 599 | self.subarrays.last().map(|r| r.end).unwrap_or(0) 600 | } 601 | 602 | pub fn len(&self) -> usize { 603 | self.data.len() 604 | } 605 | 606 | pub fn is_empty(&self) -> bool { 607 | self.data.is_empty() 608 | } 609 | } 610 | 611 | #[derive(Copy, Clone, Eq, PartialEq, Debug)] 612 | struct Sequence { 613 | free: usize, 614 | batch_index: usize, 615 | } 616 | 617 | impl Ord for Sequence { 618 | fn cmp(&self, other: &Self) -> Ordering { 619 | self.free 620 | .cmp(&other.free) 621 | .then_with(|| other.batch_index.cmp(&self.batch_index)) 622 | } 623 | } 624 | 625 | impl PartialOrd for Sequence { 626 | fn partial_cmp(&self, other: &Self) -> Option { 627 | Some(self.cmp(other)) 628 | } 629 | } 630 | -------------------------------------------------------------------------------- /src/ragged_buffer_view.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Display; 2 | use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}; 3 | 4 | use numpy::{PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3, ToPyArray}; 5 | use pyo3::{exceptions, PyErr, PyResult, Python}; 6 | 7 | use crate::monomorphs::Index; 8 | use crate::ragged_buffer::{BinOp, Error, RaggedBuffer}; 9 | 10 | #[derive(Clone, PartialEq, Eq, Hash, Debug)] 11 | enum Slice { 12 | Range { 13 | start: usize, 14 | end: usize, 15 | step: usize, 16 | }, 17 | 
Permutation(Vec), 18 | } 19 | 20 | impl Slice { 21 | fn into_iter(self) -> Box> { 22 | match self { 23 | Slice::Range { start, end, step } => Box::new((start..end).step_by(step)), 24 | Slice::Permutation(permutation) => Box::new(permutation.into_iter()), 25 | } 26 | } 27 | 28 | fn len(&self) -> usize { 29 | match self { 30 | Slice::Range { start, end, step } => (end - start + step - 1) / step, 31 | Slice::Permutation(permutation) => permutation.len(), 32 | } 33 | } 34 | } 35 | 36 | // TODO: Eq/PartialEq/Hash 37 | #[derive(Clone, Debug)] 38 | pub struct RaggedBufferView { 39 | pub inner: Arc>>, 40 | view: Option<(Slice, Slice, Slice)>, 41 | } 42 | 43 | impl RaggedBufferView { 44 | pub fn new(features: usize) -> Self { 45 | RaggedBufferView { 46 | inner: Arc::new(RwLock::new(RaggedBuffer::new(features))), 47 | view: None, 48 | } 49 | } 50 | 51 | pub fn get_slice<'a>( 52 | &self, 53 | py: Python<'a>, 54 | i0: Index, 55 | i1: Index, 56 | i2: Index, 57 | ) -> PyResult> { 58 | // TODO: Check that i0, i1, i2 are valid indices 59 | let materialized = self.materialize(); 60 | let v0 = match i0 { 61 | Index::PermutationNP(np) => { 62 | Slice::Permutation(np.to_vec()?.into_iter().map(|x| x as usize).collect()) 63 | } 64 | Index::Permutation(p) => Slice::Permutation(p), 65 | Index::Int(i) => Slice::Range { 66 | start: i, 67 | end: i + 1, 68 | step: 1, 69 | }, 70 | Index::Slice(slice) => { 71 | let indices = slice 72 | .as_ref(py) 73 | .indices(materialized.size0().try_into().unwrap())?; 74 | Slice::Range { 75 | start: indices.start as usize, 76 | end: indices.stop as usize, 77 | step: indices.step as usize, 78 | } 79 | } 80 | }; 81 | let v1 = match i1 { 82 | Index::PermutationNP(np) => { 83 | Slice::Permutation(np.to_vec()?.into_iter().map(|x| x as usize).collect()) 84 | } 85 | Index::Permutation(p) => Slice::Permutation(p), 86 | Index::Int(i) => Slice::Range { 87 | start: i, 88 | end: i + 1, 89 | step: 1, 90 | }, 91 | Index::Slice(slice) => { 92 | let indices = slice 93 | 
.as_ref(py) 94 | .indices(materialized.len()?.try_into().unwrap())?; 95 | Slice::Range { 96 | start: indices.start as usize, 97 | end: indices.stop as usize, 98 | step: indices.step as usize, 99 | } 100 | } 101 | }; 102 | let v2 = match i2 { 103 | Index::PermutationNP(np) => { 104 | Slice::Permutation(np.to_vec()?.into_iter().map(|x| x as usize).collect()) 105 | } 106 | Index::Permutation(p) => Slice::Permutation(p), 107 | Index::Int(i) => Slice::Range { 108 | start: i, 109 | end: i + 1, 110 | step: 1, 111 | }, 112 | Index::Slice(slice) => { 113 | let indices = slice 114 | .as_ref(py) 115 | .indices(materialized.size2().try_into().unwrap())?; 116 | Slice::Range { 117 | start: indices.start as usize, 118 | end: indices.stop as usize, 119 | step: indices.step as usize, 120 | } 121 | } 122 | }; 123 | 124 | Ok(RaggedBufferView { 125 | inner: materialized.inner, 126 | view: Some((v0, v1, v2)), 127 | }) 128 | } 129 | 130 | fn get(&self) -> RwLockReadGuard> { 131 | self.inner.read().unwrap() 132 | } 133 | 134 | fn get_mut(&self) -> RwLockWriteGuard> { 135 | self.inner.write().unwrap() 136 | } 137 | 138 | fn make_contiguous(&mut self) { 139 | let materialized = self.materialize(); 140 | self.inner = materialized.inner; 141 | self.view = None; 142 | } 143 | fn require_contiguous(&self, method_name: &str) -> PyResult<()> { 144 | match self.view { 145 | Some(_) => Err(pyo3::exceptions::PyValueError::new_err(format!( 146 | "Cannot call method {} on a view. 
Call .materialize() first to get a materialized copy of the view.", 147 | method_name 148 | ))), 149 | None => Ok(()), 150 | } 151 | } 152 | 153 | pub fn from_array(data: PyReadonlyArray3) -> Self { 154 | RaggedBufferView { 155 | inner: Arc::new(RwLock::new(RaggedBuffer::from_array(data.as_array()))), 156 | view: None, 157 | } 158 | } 159 | 160 | pub fn from_flattened( 161 | data: PyReadonlyArray2, 162 | lengths: PyReadonlyArray1, 163 | ) -> PyResult { 164 | Ok(RaggedBufferView { 165 | inner: Arc::new(RwLock::new(RaggedBuffer::from_flattened( 166 | data.as_array(), 167 | lengths.as_array(), 168 | )?)), 169 | view: None, 170 | }) 171 | } 172 | 173 | pub fn extend(&mut self, other: &RaggedBufferView) -> PyResult<()> { 174 | self.make_contiguous(); 175 | let other = other.materialize(); 176 | let other = other.get(); 177 | self.get_mut().extend(&*other).map_err(Into::into) 178 | } 179 | 180 | pub fn clear(&mut self) -> PyResult<()> { 181 | self.make_contiguous(); 182 | self.get_mut().clear(); 183 | Ok(()) 184 | } 185 | 186 | pub fn as_array<'a>( 187 | &self, 188 | py: Python<'a>, 189 | ) -> PyResult<&'a numpy::PyArray>> { 190 | match self.view { 191 | None => { 192 | let inner = self.get(); 193 | inner 194 | .data 195 | .to_pyarray(py) 196 | .reshape((inner.items(), inner.features)) 197 | .map_err(Into::into) 198 | } 199 | _ => self.materialize().as_array(py), 200 | } 201 | } 202 | 203 | pub fn materialize(&self) -> RaggedBufferView { 204 | match self.view.clone() { 205 | Some(( 206 | Slice::Range { 207 | start: start0, 208 | end: end0, 209 | step: step0, 210 | }, 211 | Slice::Range { 212 | start: start1, 213 | end: end1, 214 | step: step1, 215 | }, 216 | Slice::Range { 217 | start: start2, 218 | end: end2, 219 | step: step2, 220 | }, 221 | )) => { 222 | let mut data = Vec::new(); 223 | let mut subarrays = Vec::new(); 224 | let mut item = 0; 225 | let inner = self.get(); 226 | for i0 in (start0..end0).step_by(step0) { 227 | let mut items = 0; 228 | for i1 in 
inner.subarrays[i0] 229 | .clone() 230 | .skip(start1) 231 | .take(end1 - start1) 232 | .step_by(step1) 233 | { 234 | for i2 in (start2..end2).step_by(step2) { 235 | data.push(inner.data[i1 * inner.features + i2]); 236 | } 237 | items += 1; 238 | } 239 | subarrays.push(item..item + items); 240 | item += items; 241 | } 242 | let features = (end2 - start2 + step2 - 1) / step2; 243 | let materialized = RaggedBuffer { 244 | data, 245 | subarrays, 246 | features, 247 | }; 248 | RaggedBufferView { 249 | inner: Arc::new(RwLock::new(materialized)), 250 | view: None, 251 | } 252 | } 253 | Some((v0, v1, v2)) => { 254 | let mut data = Vec::new(); 255 | let mut items = 0; 256 | let mut subarrays = Vec::new(); 257 | let inner = self.get(); 258 | for i0 in v0.into_iter() { 259 | let item_start = items; 260 | let subarray = inner.subarrays[i0].clone(); 261 | for i1 in v1.clone().into_iter() { 262 | if i1 >= subarray.len() { 263 | break; 264 | } 265 | let offset = (subarray.start + i1) * inner.features; 266 | for i2 in v2.clone().into_iter() { 267 | data.push(inner.data[offset + i2]); 268 | } 269 | items += 1; 270 | } 271 | subarrays.push(item_start..items); 272 | } 273 | let features = v2.len(); 274 | let materialized = RaggedBuffer { 275 | data, 276 | subarrays, 277 | features, 278 | }; 279 | RaggedBufferView { 280 | inner: Arc::new(RwLock::new(materialized)), 281 | view: None, 282 | } 283 | } 284 | None => self.clone(), 285 | } 286 | } 287 | 288 | pub fn push(&mut self, x: &PyReadonlyArray2) -> PyResult<()> { 289 | self.make_contiguous(); 290 | self.get_mut().push(&x.as_array()).map_err(Into::into) 291 | } 292 | 293 | pub fn push_empty(&mut self) -> PyResult<()> { 294 | self.make_contiguous(); 295 | self.get_mut().push_empty(); 296 | Ok(()) 297 | } 298 | 299 | pub fn swizzle(&self, indices: PyReadonlyArray1) -> PyResult> { 300 | match self.view { 301 | Some((_, _, _)) => todo!(), 302 | None => Ok(self.get().swizzle(indices.as_array())?.view()), 303 | } 304 | } 305 | 306 | pub 
fn swizzle_usize(&self, indices: &[usize]) -> PyResult> { 307 | match self.view { 308 | Some((_, _, _)) => todo!(), 309 | None => Ok(self.get().swizzle_usize(indices)?.view()), 310 | } 311 | } 312 | 313 | pub fn get_sequence(&self, i: usize) -> PyResult> { 314 | self.require_contiguous("get_sequence")?; 315 | Ok(self.get().get(i).view()) 316 | } 317 | 318 | pub fn size0(&self) -> usize { 319 | match &self.view { 320 | Some((s0, _, _)) => s0.len(), 321 | None => self.get().size0(), 322 | } 323 | } 324 | 325 | pub fn size2(&self) -> usize { 326 | match &self.view { 327 | Some((_, _, s2)) => s2.len(), 328 | None => self.get().size2(), 329 | } 330 | } 331 | 332 | pub fn lengths<'a>( 333 | &self, 334 | py: Python<'a>, 335 | ) -> PyResult<&'a numpy::PyArray>> { 336 | match self.view { 337 | None => Ok(self.get().lengths().to_pyarray(py)), 338 | Some(( 339 | Slice::Range { 340 | start: start0, 341 | end: end0, 342 | step: step0, 343 | }, 344 | Slice::Range { 345 | start: start1, 346 | end: end1, 347 | step: step1, 348 | }, 349 | _, 350 | )) => { 351 | let mut lengths = Vec::with_capacity((end0 - start0) / step0); 352 | let inner = self.get(); 353 | for i0 in (start0..end0).step_by(step0) { 354 | let end1 = std::cmp::min(end1, inner.subarrays[i0].len()); 355 | if end1 > start1 { 356 | let stepsf = (end1 - start1) / step1; 357 | lengths.push( 358 | stepsf as i64 + if stepsf * step1 < end1 - start1 { 1 } else { 0 }, 359 | ); 360 | } else { 361 | lengths.push(0); 362 | } 363 | } 364 | Ok(lengths.to_pyarray(py)) 365 | } 366 | _ => { 367 | self.require_contiguous("lengths")?; 368 | Ok(self.get().lengths().to_pyarray(py)) 369 | } 370 | } 371 | } 372 | 373 | pub fn size1(&mut self, i: usize) -> PyResult { 374 | self.make_contiguous(); 375 | self.get().size1(i).map_err(Into::into) 376 | } 377 | 378 | pub fn __str__(&self) -> PyResult { 379 | self.materialize().get().__str__().map_err(Into::into) 380 | } 381 | 382 | pub fn binop>(&self, rhs: &RaggedBufferView) -> PyResult> { 383 | 
self.require_contiguous("binop")?; 384 | Ok(self.get().binop::(&*rhs.get())?.view()) 385 | } 386 | 387 | pub fn op_scalar>(&self, scalar: T) -> PyResult> { 388 | self.require_contiguous("op_scalar")?; 389 | Ok(self.get().op_scalar::(scalar).view()) 390 | } 391 | 392 | pub fn indices(&mut self, dim: usize) -> PyResult> { 393 | self.make_contiguous(); 394 | Ok(self.get().indices(dim)?.view()) 395 | } 396 | 397 | pub fn flat_indices(&mut self) -> PyResult> { 398 | self.make_contiguous(); 399 | Ok(self.get().flat_indices()?.view()) 400 | } 401 | 402 | pub fn cat(buffers: &[&RaggedBufferView], dim: usize) -> PyResult> { 403 | if buffers.is_empty() { 404 | return Err(pyo3::exceptions::PyValueError::new_err( 405 | "cat requires at least one ragged buffer", 406 | )); 407 | } 408 | let mut rbs = Vec::new(); 409 | for b in buffers { 410 | b.require_contiguous("cat")?; 411 | rbs.push(b.get()); 412 | } 413 | let rb = RaggedBuffer::cat(&rbs.iter().map(|r| &**r).collect::>(), dim)?; 414 | Ok(RaggedBufferView { 415 | inner: Arc::new(RwLock::new(rb)), 416 | view: None, 417 | }) 418 | } 419 | 420 | #[allow(clippy::type_complexity)] 421 | pub fn padpack(&mut self) -> PyResult, Vec, Vec, (usize, usize))>> { 422 | self.make_contiguous(); 423 | Ok(self.get().padpack()) 424 | } 425 | 426 | #[allow(clippy::len_without_is_empty)] 427 | pub fn len(&self) -> PyResult { 428 | self.require_contiguous("len")?; 429 | Ok(self.get().len()) 430 | } 431 | 432 | pub fn is_empty(&mut self) -> PyResult { 433 | self.make_contiguous(); 434 | Ok(self.get().is_empty()) 435 | } 436 | 437 | pub fn items(&mut self) -> PyResult { 438 | self.make_contiguous(); 439 | Ok(self.get().items()) 440 | } 441 | 442 | pub fn binop_mut>(&self, rhs: &RaggedBufferView) -> PyResult<()> { 443 | let (lhs_i0, lhs_i1, lhs_i2) = self.view.clone().unwrap(); 444 | let (rhs_i0, rhs_i1, rhs_i2) = rhs.view.clone().unwrap(); 445 | 446 | let (lhs_iter_0, rhs_iter_0) = if self.size0() == rhs.size0() { 447 | (lhs_i0.into_iter(), 
rhs_i0.into_iter()) 448 | } else { 449 | return Err(exceptions::PyValueError::new_err(format!( 450 | "size mismatch in first dimension: {} != {}", 451 | self.size0(), 452 | rhs.size0(), 453 | ))); 454 | }; 455 | assert!(matches!(lhs_i1, Slice::Range { .. })); 456 | assert!(matches!(rhs_i1, Slice::Range { .. })); 457 | if self.size2() != rhs.size2() { 458 | return Err(exceptions::PyValueError::new_err(format!( 459 | "size mismatch in third dimension: {} != {}", 460 | self.size2(), 461 | rhs.size2(), 462 | ))); 463 | }; 464 | 465 | let stride2l = self.get().size2(); 466 | let stride2r = rhs.get().size2(); 467 | 468 | let mut lhs = self.get_mut(); 469 | let rhs = rhs.get(); 470 | for (l0, r0) in lhs_iter_0.zip(rhs_iter_0) { 471 | let (lhs_iter_1, rhs_iter_1): ( 472 | Box>, 473 | Box>, 474 | ) = if lhs.subarrays[l0].len() != rhs.subarrays[r0].len() { 475 | if lhs.subarrays[l0].len() == 1 { 476 | ( 477 | Box::new( 478 | vec![lhs.subarrays[l0].start; rhs.subarrays[r0].len()].into_iter(), 479 | ), 480 | Box::new(rhs.subarrays[r0].clone()), 481 | ) 482 | } else if rhs.subarrays[r0].len() == 1 { 483 | ( 484 | Box::new(lhs.subarrays[l0].clone()), 485 | Box::new( 486 | vec![rhs.subarrays[r0].start; lhs.subarrays[l0].len()].into_iter(), 487 | ), 488 | ) 489 | } else { 490 | return Err(exceptions::PyValueError::new_err(format!( 491 | "size mismatch between {}th and {}th sequence: {} != {}", 492 | l0, 493 | r0, 494 | lhs.subarrays[l0].len(), 495 | rhs.subarrays[r0].len(), 496 | ))); 497 | } 498 | } else { 499 | ( 500 | Box::new(lhs.subarrays[l0].clone()), 501 | Box::new(rhs.subarrays[r0].clone()), 502 | ) 503 | }; 504 | for (l1, r1) in lhs_iter_1.zip(rhs_iter_1) { 505 | for (l2, r2) in lhs_i2.clone().into_iter().zip(rhs_i2.clone().into_iter()) { 506 | lhs.data[l1 * stride2l + l2] = 507 | Op::op(lhs.data[l1 * stride2l + l2], rhs.data[r1 * stride2r + r2]); 508 | } 509 | } 510 | } 511 | 512 | Ok(()) 513 | } 514 | 515 | pub fn deepclone(&self) -> RaggedBufferView { 516 | let inner = 
self.get().clone(); 517 | RaggedBufferView { 518 | inner: Arc::new(RwLock::new(inner)), 519 | view: self.view.clone(), 520 | } 521 | } 522 | } 523 | 524 | pub fn translate_rotate( 525 | source: &RaggedBufferView, 526 | translation: &RaggedBufferView, 527 | rotation: &RaggedBufferView, 528 | ) -> PyResult<()> { 529 | if source.size0() != translation.size0() { 530 | return Err(exceptions::PyValueError::new_err(format!( 531 | "size mismatch in first dimension: {} != {}", 532 | source.size0(), 533 | translation.size0(), 534 | ))); 535 | } 536 | if source.size2() != 2 { 537 | return Err(exceptions::PyValueError::new_err(format!( 538 | "expected 2D source, got {}D", 539 | source.size2(), 540 | ))); 541 | } 542 | if translation.size2() != 2 { 543 | return Err(exceptions::PyValueError::new_err(format!( 544 | "expected 2D translation, got {}D", 545 | translation.size2(), 546 | ))); 547 | } 548 | if rotation.size2() != 2 { 549 | return Err(exceptions::PyValueError::new_err(format!( 550 | "expected rotation to be a 2D direction, got {}D", 551 | rotation.size2(), 552 | ))); 553 | } 554 | let (s0, _, s2) = source.view.clone().unwrap(); 555 | let (t0, _, t2) = translation.view.clone().unwrap(); 556 | let (r0, _, r2) = rotation.view.clone().unwrap(); 557 | let mut source = source.get_mut(); 558 | let translation = translation.get(); 559 | let rotation = rotation.get(); 560 | 561 | let ss0 = source.size0(); 562 | let ts0 = translation.size0(); 563 | let rs0 = rotation.size0(); 564 | match s0 { 565 | Slice::Range { start, end, step } if start == 0 && end == ss0 && step == 1 => {} 566 | _ => { 567 | return Err(exceptions::PyValueError::new_err( 568 | "view on first dimension of source not supported".to_string(), 569 | )) 570 | } 571 | } 572 | match t0 { 573 | Slice::Range { start, end, step } if start == 0 && end == ts0 && step == 1 => {} 574 | _ => { 575 | return Err(exceptions::PyValueError::new_err( 576 | "view on first dimension of translation not supported".to_string(), 577 | 
)) 578 | } 579 | } 580 | match r0 { 581 | Slice::Range { start, end, step } if start == 0 && end == rs0 && step == 1 => {} 582 | _ => { 583 | return Err(exceptions::PyValueError::new_err( 584 | "view on first dimension of rotation not supported".to_string(), 585 | )) 586 | } 587 | } 588 | let (sxi, syi) = match s2 { 589 | Slice::Range { start, step, .. } => (start, start + step), 590 | Slice::Permutation(indices) => (indices[0], indices[1]), 591 | }; 592 | let (txi, tyi) = match t2 { 593 | Slice::Range { start, step, .. } => (start, start + step), 594 | Slice::Permutation(indices) => (indices[0], indices[1]), 595 | }; 596 | let (rxi, ryi) = match r2 { 597 | Slice::Range { start, step, .. } => (start, start + step), 598 | Slice::Permutation(indices) => (indices[0], indices[1]), 599 | }; 600 | let sstride = source.features; 601 | for i0 in 0..source.size0() { 602 | if translation.size1(i0)? != 1 || rotation.size1(i0)? != 1 { 603 | return Err(exceptions::PyValueError::new_err(format!( 604 | "must have single item in translation and rotation for each sequence, but got {} and {} items for sequence {}", 605 | translation.size1(i0)?, rotation.size1(i0)?, i0, 606 | ))); 607 | } 608 | // TODO: check no view on dim 1 609 | for i1 in source.subarrays[i0].clone() { 610 | let sstart = i1 * sstride; 611 | source.data[sstart + sxi] -= translation.data[i0 * translation.features + txi]; 612 | source.data[sstart + syi] -= translation.data[i0 * translation.features + tyi]; 613 | let rx = rotation.data[i0 * rotation.features + rxi]; 614 | let ry = rotation.data[i0 * rotation.features + ryi]; 615 | let sx = source.data[sstart + sxi]; 616 | let sy = source.data[sstart + syi]; 617 | source.data[sstart + sxi] = rx * sx + ry * sy; 618 | source.data[sstart + syi] = -ry * sx + rx * sy; 619 | } 620 | } 621 | Ok(()) 622 | } 623 | 624 | impl PartialEq 625 | for RaggedBufferView 626 | { 627 | fn eq(&self, other: &RaggedBufferView) -> bool { 628 | // TODO: implement for views 629 | 
self.require_contiguous("eq").unwrap(); 630 | other.require_contiguous("eq").unwrap(); 631 | *self.get() == *other.get() 632 | } 633 | } 634 | 635 | impl Eq for RaggedBufferView {} 636 | 637 | impl RaggedBuffer { 638 | pub fn view(self) -> RaggedBufferView { 639 | RaggedBufferView { 640 | inner: Arc::new(RwLock::new(self)), 641 | view: None, 642 | } 643 | } 644 | } 645 | 646 | impl From for PyErr { 647 | fn from(Error::Generic(msg): Error) -> PyErr { 648 | exceptions::PyValueError::new_err(msg) 649 | } 650 | } 651 | --------------------------------------------------------------------------------