├── .coveragerc ├── .github └── workflows │ └── run_tests.yml ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── README.md ├── azure-pipelines.yml ├── setup.py └── slicer ├── __init__.py ├── slicer.py ├── slicer_internal.py ├── test_slicer.py └── utils_testing.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = *test* 3 | [tool:pytest] 4 | addopts = --cov=slicer --cov-report term-missing --doctest-modules -------------------------------------------------------------------------------- /.github/workflows/run_tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | 11 | jobs: 12 | run_tests: 13 | strategy: 14 | matrix: 15 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] 16 | fail-fast: false 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: actions/checkout@v4 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install pytest numpy pandas scipy torch 28 | - name: Install project 29 | run: pip install . 
30 | - name: Test with pytest 31 | run: pytest 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .pytest_cache/ 3 | *.pyc 4 | .coverage 5 | coverage.xml 6 | build/ 7 | junit/ 8 | dist/ 9 | *.egg-info/ 10 | htmlcov/ 11 | .idea/ 12 | staging/ 13 | .vscode/ -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | All notable changes to this project will be documented in this file. 3 | 4 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 5 | and the versioning is mostly derived from [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 6 | 7 | ## [v0.0.7] - 2020-12-15 8 | ### Fixed 9 | - Fix around alias re-assignment. 10 | 11 | ## [v0.0.6] - 2020-12-11 12 | ### Added 13 | - Added support for ragged numpy arrays. 14 | - Added support for assignment on existing tracked objects. 15 | 16 | ## [v0.0.5] - 2020-11-04 17 | ### Added 18 | - Added support for SciPy csr, csc, lil, dok matrices. 19 | 20 | ## [v0.0.4] - 2020-09-17 21 | ### Added 22 | - Initial public release. 
23 | 24 | [v0.0.7]: https://github.com/interpretml/slicer/releases/tag/v0.0.7 25 | [v0.0.6]: https://github.com/interpretml/slicer/releases/tag/v0.0.6 26 | [v0.0.5]: https://github.com/interpretml/slicer/releases/tag/v0.0.5 27 | [v0.0.4]: https://github.com/interpretml/slicer/releases/tag/v0.0.4 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 The InterpretML Contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # slicer [alpha] 2 | ![License](https://img.shields.io/github/license/interpretml/slicer.svg?style=flat-square) 3 | ![Python Version](https://img.shields.io/pypi/pyversions/slicer.svg?style=flat-square) 4 | ![Package Version](https://img.shields.io/pypi/v/slicer.svg?style=flat-square) 5 | ![Maintenance](https://img.shields.io/maintenance/yes/2025.svg?style=flat-square) 6 | 7 | *(Equal Contribution) Samuel Jenkins & Harsha Nori & Scott Lundberg* 8 | 9 | **slicer** wraps tensor-like objects and provides a uniform slicing interface via `__getitem__`. 10 | 11 |
12 | It supports many data types including: 13 | 14 |    15 | [numpy](https://github.com/numpy/numpy) | 16 | [pandas](https://github.com/pandas-dev/pandas) | 17 | [scipy](https://docs.scipy.org/doc/scipy/reference/sparse.html) | 18 | [pytorch](https://github.com/pytorch/pytorch) | 19 | [list](https://github.com/python/cpython) | 20 | [tuple](https://github.com/python/cpython) | 21 | [dict](https://github.com/python/cpython) 22 | 23 | And enables upgraded slicing functionality on its objects: 24 | ```python 25 | # Handles non-integer indexes for slicing. 26 | S(df)[:, ["Age", "Income"]] 27 | 28 | # Handles nested slicing in one call. 29 | S(nested_list)[..., :5] 30 | ``` 31 | 32 | It can also simultaneously slice many objects at once: 33 | ```python 34 | # Gets first elements of both objects. 35 | S(first=df, second=ar)[0, :] 36 | ``` 37 | 38 | This package has **0** dependencies. Not even one. 39 | 40 | ## Installation 41 | 42 | Python 3.6+ | Linux, Mac, Windows 43 | ```sh 44 | pip install slicer 45 | ``` 46 | 47 | ## Getting Started 48 | 49 | Basic anonymous slicing: 50 | ```python 51 | from slicer import Slicer as S 52 | li = [[1, 2, 3], [4, 5, 6]] 53 | S(li)[:, 0:2].o 54 | # [[1, 2], [4, 5]] 55 | di = {'x': [1, 2, 3], 'y': [4, 5, 6]} 56 | S(di)[:, 0:2].o 57 | # {'x': [1, 2], 'y': [4, 5]} 58 | ``` 59 | 60 | Basic named slicing: 61 | ```python 62 | import pandas as pd 63 | import numpy as np 64 | df = pd.DataFrame({'A': [1, 3], 'B': [2, 4]}) 65 | ar = np.array([[5, 6], [7, 8]]) 66 | sliced = S(first=df, second=ar)[0, :] 67 | sliced.first 68 | # A 1 69 | # B 2 70 | # Name: 0, dtype: int64 71 | sliced.second 72 | # array([5, 6]) 73 | ``` 74 | 75 | Real example: 76 | ```python 77 | from slicer import Slicer as S 78 | from slicer import Alias as A 79 | 80 | data = [[1, 2], [3, 4]] 81 | values = [[5, 6], [7, 8]] 82 | identifiers = ["id1", "id1"] 83 | instance_names = ["r1", "r2"] 84 | feature_names = ["f1", "f2"] 85 | full_name = "A" 86 | 87 | slicer = S( 88 | 
data=data, 89 | values=values, 90 | # Aliases are objects that also function as slicing keys. 91 | # A(obj, dim) where dim informs what dimension it can be sliced on. 92 | identifiers=A(identifiers, 0), 93 | instance_names=A(instance_names, 0), 94 | feature_names=A(feature_names, 1), 95 | full_name=full_name, 96 | ) 97 | 98 | sliced = slicer[:, 1] # Tensor-like parallel slicing on all objects 99 | assert sliced.data == [2, 4] 100 | assert sliced.instance_names == ["r1", "r2"] 101 | assert sliced.feature_names == "f2" 102 | assert sliced.values == [6, 8] 103 | 104 | sliced = slicer["r1", "f2"] # Example use of aliasing 105 | assert sliced.data == 2 106 | assert sliced.feature_names == "f2" 107 | assert sliced.instance_names == "r1" 108 | assert sliced.values == 6 109 | ``` 110 | 111 | ## Contact us 112 | Raise an issue on GitHub, or contact us at interpret@microsoft.com 113 | -------------------------------------------------------------------------------- /azure-pipelines.yml: -------------------------------------------------------------------------------- 1 | jobs: 2 | - job: 'Test' 3 | strategy: 4 | matrix: 5 | LinuxPython36: 6 | python.version: '3.6' 7 | image.name: 'ubuntu-18.04' 8 | LinuxPython37: 9 | python.version: '3.7' 10 | image.name: 'ubuntu-18.04' 11 | LinuxPython38: 12 | python.version: '3.8' 13 | image.name: 'ubuntu-18.04' 14 | WindowsPython36: 15 | python.version: '3.6' 16 | image.name: 'windows-2019' 17 | WindowsPython37: 18 | python.version: '3.7' 19 | image.name: 'windows-2019' 20 | WindowsPython38: 21 | python.version: '3.8' 22 | image.name: 'windows-2019' 23 | MacPython36: 24 | python.version: '3.6' 25 | image.name: 'macOS-10.14' 26 | MacPython37: 27 | python.version: '3.7' 28 | image.name: 'macOS-10.14' 29 | MacPython38: 30 | python.version: '3.8' 31 | image.name: 'macOS-10.14' 32 | maxParallel: 9 33 | pool: 34 | vmImage: '$(image.name)' 35 | steps: 36 | - task: UsePythonVersion@0 37 | condition: succeeded() 38 | inputs: 39 | versionSpec: 
'$(python.version)' 40 | architecture: 'x64' 41 | - script: python -m pip install --upgrade pip setuptools wheel 42 | condition: succeeded() 43 | displayName: 'Install tools' 44 | - script: python -m pip install pytest pytest-cov numpy pandas scipy 45 | condition: succeeded() 46 | displayName: 'Install test requirements' 47 | - script: python -m pip install torch==1.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html 48 | condition: or(startsWith(variables['image.name'], 'windows'), startsWith(variables['image.name'], 'ubuntu')) 49 | displayName: 'Install pytorch (windows/linux)' 50 | - script: python -m pip install torch 51 | condition: startsWith(variables['image.name'], 'macOS') 52 | displayName: 'Install pytorch (mac)' 53 | - script: | 54 | python -m pytest --doctest-modules --junitxml=junit/test-results.xml --cov=slicer --cov-report=xml --cov-report=html 55 | displayName: 'Run pytest' 56 | - task: PublishTestResults@2 57 | condition: succeededOrFailed() 58 | inputs: 59 | testResultsFiles: '**/test-*.xml' 60 | testRunTitle: 'Publish test results for Python $(python.version) at $(image.name)' 61 | displayName: 'Publish test results' 62 | - task: PublishCodeCoverageResults@1 63 | inputs: 64 | codeCoverageTool: Cobertura 65 | summaryFileLocation: '$(System.DefaultWorkingDirectory)/**/coverage.xml' 66 | reportDirectory: '$(System.DefaultWorkingDirectory)/**/htmlcov' 67 | displayName: 'Publish test coverage results' 68 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="slicer", 8 | version="0.0.8", 9 | author="InterpretML", 10 | author_email="interpret@microsoft.com", 11 | description="A small package for big slicing.", 12 | long_description=long_description, 13 | 
long_description_content_type="text/markdown", 14 | url="https://github.com/interpretml/slicer", 15 | packages=setuptools.find_packages(), 16 | classifiers=[ 17 | "Programming Language :: Python :: 3.8", 18 | "Programming Language :: Python :: 3.9", 19 | "Programming Language :: Python :: 3.10", 20 | "Programming Language :: Python :: 3.11", 21 | "Programming Language :: Python :: 3.12", 22 | "Development Status :: 3 - Alpha", 23 | "License :: OSI Approved :: MIT License", 24 | "Operating System :: OS Independent", 25 | ], 26 | python_requires='>=3.6', 27 | ) 28 | -------------------------------------------------------------------------------- /slicer/__init__.py: -------------------------------------------------------------------------------- 1 | """ Unified slicing for data. 2 | Less API, less think. 3 | """ 4 | from .slicer import Slicer, Alias, Obj 5 | -------------------------------------------------------------------------------- /slicer/slicer.py: -------------------------------------------------------------------------------- 1 | """ Public facing layer for slicer. 2 | The little slicer that could. 3 | """ 4 | # TODO: Move Obj and Alias class here. 5 | 6 | from .slicer_internal import AtomicSlicer, Alias, Obj, AliasLookup, Tracked, UnifiedDataHandler 7 | from .slicer_internal import reduced_o, resolve_dim, unify_slice 8 | 9 | 10 | class Slicer: 11 | """ Provides unified slicing to tensor-like objects. """ 12 | 13 | def __init__(self, *args, **kwargs): 14 | """ Wraps objects in args and provides unified numpy-like slicing. 15 | 16 | Currently supports (with arbitrary nesting): 17 | 18 | - lists and tuples 19 | - dictionaries 20 | - numpy arrays 21 | - pandas dataframes and series 22 | - pytorch tensors 23 | 24 | Args: 25 | *args: Unnamed tensor-like objects. 26 | **kwargs: Named tensor-like objects. 
27 | 28 | Examples: 29 | 30 | Basic anonymous slicing: 31 | 32 | >>> from slicer import Slicer as S 33 | >>> li = [[1, 2, 3], [4, 5, 6]] 34 | >>> S(li)[:, 0:2].o 35 | [[1, 2], [4, 5]] 36 | >>> di = {'x': [1, 2, 3], 'y': [4, 5, 6]} 37 | >>> S(di)[:, 0:2].o 38 | {'x': [1, 2], 'y': [4, 5]} 39 | 40 | Basic named slicing: 41 | 42 | >>> import pandas as pd 43 | >>> import numpy as np 44 | >>> df = pd.DataFrame({'A': [1, 3], 'B': [2, 4]}) 45 | >>> ar = np.array([[5, 6], [7, 8]]) 46 | >>> sliced = S(first=df, second=ar)[0, :] 47 | >>> sliced.first 48 | A 1 49 | B 2 50 | Name: 0, dtype: int64 51 | >>> sliced.second 52 | array([5, 6]) 53 | 54 | """ 55 | self.__class__._init_slicer(self, *args, **kwargs) 56 | 57 | @classmethod 58 | def from_slicer(cls, *args, **kwargs): 59 | """ Alternative constructor that builds a Slicer from existing tracked objects. 60 | Args: 61 | *args: Unnamed tracked or tensor-like objects. 62 | **kwargs: Named tracked or tensor-like objects. 63 | 64 | Returns: 65 | A new Slicer instance wrapping the given objects. 66 | """ 67 | slicer_instance = cls.__new__(cls) 68 | cls._init_slicer(slicer_instance, *args, **kwargs) 69 | return slicer_instance 70 | 71 | @classmethod 72 | def _init_slicer(cls, slicer_instance, *args, **kwargs): 73 | # NOTE: Protected attributes. 74 | slicer_instance._max_dim = 0 75 | 76 | # NOTE: Private attributes. 
77 | slicer_instance._anon = [] 78 | slicer_instance._objects = {} 79 | slicer_instance._aliases = {} 80 | slicer_instance._alias_lookup = None 81 | 82 | # Go through unnamed objects / aliases 83 | slicer_instance.__setattr__("o", args) 84 | 85 | # Go through named objects / aliases 86 | for key, value in kwargs.items(): 87 | slicer_instance.__setattr__(key, value) 88 | 89 | # Generate default aliases only if one object and no aliases exist 90 | objects_len = len(slicer_instance._objects) 91 | anon_len = len(slicer_instance._anon) 92 | aliases_len = len(slicer_instance._aliases) 93 | if ((objects_len == 1) ^ (anon_len == 1)) and aliases_len == 0: 94 | obj = None 95 | for _, t in slicer_instance._iter_tracked(): 96 | obj = t 97 | 98 | generated_aliases = UnifiedDataHandler.default_alias(obj.o) 99 | for generated_alias in generated_aliases: 100 | slicer_instance.__setattr__(generated_alias._name, generated_alias) 101 | 102 | def __getitem__(self, item): 103 | index_tup = unify_slice(item, self._max_dim, self._alias_lookup) 104 | new_args = [] 105 | new_kwargs = {} 106 | for name, tracked in self._iter_tracked(include_aliases=True): 107 | if len(tracked.dim) == 0: # No slice on empty dim 108 | new_tracked = tracked 109 | else: 110 | index_slicer = AtomicSlicer(index_tup, max_dim=1) 111 | slicer_index = index_slicer[tracked.dim] 112 | sliced_o = tracked[slicer_index] 113 | sliced_dim = resolve_dim(index_tup, tracked.dim) 114 | 115 | new_tracked = tracked.__class__(sliced_o, sliced_dim) 116 | new_tracked._name = tracked._name 117 | 118 | if name == "o": 119 | new_args.append(new_tracked) 120 | else: 121 | new_kwargs[name] = new_tracked 122 | 123 | return self.__class__.from_slicer(*new_args, **new_kwargs) 124 | 125 | def __getattr__(self, item): 126 | """ Override default getattr to return tracked attribute. 127 | 128 | Args: 129 | item: Name of tracked attribute. 130 | Returns: 131 | Corresponding object. 
132 | """ 133 | if item.startswith("_"): 134 | return super(Slicer, self).__getattr__(item) 135 | 136 | if item == "o": 137 | return reduced_o(self._anon) 138 | else: 139 | tracked = self._objects.get(item, None) 140 | if tracked is None: 141 | tracked = self._aliases.get(item, None) 142 | 143 | if tracked is None: 144 | raise AttributeError(f"Attribute '{item}' does not exist.") 145 | 146 | return tracked.o 147 | 148 | def __setattr__(self, key, value): 149 | """ Override default setattr to sync tracking of slicer. 150 | 151 | Args: 152 | key: Name of tracked attribute. 153 | value: Either an Obj, Alias or Python Object. 154 | """ 155 | if key.startswith("_"): 156 | return super(Slicer, self).__setattr__(key, value) 157 | 158 | # Grab previous objects if they exist: 159 | old_obj = self._objects.get(key, None) 160 | old_alias = self._aliases.get(key, None) 161 | 162 | # For existing attributes, honor Alias status and dimension unless specified otherwise 163 | if getattr(self, key, None) is not None and key != "o": 164 | if not isinstance(value, Tracked): 165 | if old_obj: 166 | value = Obj(value, dim=old_obj.dim) 167 | elif old_alias: 168 | value = Alias(value, dim=old_alias.dim) 169 | 170 | if isinstance(value, Alias): 171 | value._name = key 172 | self._aliases[key] = value 173 | 174 | if old_obj: # If object previously existed as an object, clean up all references. 
175 | del self._objects[key] 176 | 177 | # Build lookup (for perf) 178 | if self._alias_lookup is None: 179 | self._alias_lookup = AliasLookup(self._aliases) 180 | else: 181 | if old_alias: 182 | self._alias_lookup.delete(old_alias) 183 | self._alias_lookup.update(value) 184 | else: 185 | if key == "o": 186 | tracked = [Obj(x) if not isinstance(x, Obj) else x for x in value] 187 | self._anon = tracked 188 | for t in tracked: 189 | self._update_max_dim(t) 190 | 191 | os = reduced_o(self._anon) 192 | super(Slicer, self).__setattr__(key, os) 193 | else: 194 | if old_alias: # If object previously existed as an alias, clean up all references. 195 | self._alias_lookup.delete(old_alias) 196 | del self._aliases[key] 197 | 198 | value = Obj(value) if not isinstance(value, Obj) else value 199 | value._name = key 200 | self._objects[key] = value 201 | self._update_max_dim(value) 202 | super(Slicer, self).__setattr__(key, value.o) 203 | 204 | def __delattr__(self, item): 205 | """ Override default delattr to remove tracked attribute. 206 | 207 | Args: 208 | item: Name of tracked attribute to delete. 209 | """ 210 | if item.startswith("_"): 211 | return super(Slicer, self).__delattr__(item) 212 | 213 | # Sync private attributes that help track 214 | self._objects.pop(item, None) 215 | self._aliases.pop(item, None) 216 | if item == "o": 217 | self._anon.clear() 218 | 219 | # Recompute max_dim 220 | self._recompute_max_dim() 221 | 222 | # Recompute alias lookup 223 | # NOTE: This doesn't use diff-style deletes, but we don't care (not a perf target). 224 | self._alias_lookup = AliasLookup(self._aliases) 225 | 226 | # TODO: Mutate and check interactively what it does 227 | super(Slicer, self).__delattr__(item) 228 | 229 | def __repr__(self): 230 | """ Override default repr for human readability. 231 | 232 | Returns: 233 | String to display. 
234 | """ 235 | orig = self.__dict__ 236 | di = {} 237 | for key, value in orig.items(): 238 | if not key.startswith("_"): 239 | di[key] = value 240 | return f"{self.__class__.__name__}({str(di)})" 241 | 242 | def _update_max_dim(self, tracked): 243 | self._max_dim = max(self._max_dim, max(tracked.dim, default=-1) + 1) 244 | 245 | def _iter_tracked(self, include_aliases=False): 246 | for tracked in self._anon: 247 | yield "o", tracked 248 | for name, tracked in self._objects.items(): 249 | yield name, tracked 250 | if include_aliases: 251 | for name, tracked in self._aliases.items(): 252 | yield name, tracked 253 | 254 | def _recompute_max_dim(self): 255 | self._max_dim = max( 256 | [max(o.dim, default=-1) + 1 for _, o in self._iter_tracked()], default=0 257 | ) 258 | -------------------------------------------------------------------------------- /slicer/slicer_internal.py: -------------------------------------------------------------------------------- 1 | """ Lower level layer for slicer. 2 | Mom's spaghetti. 3 | """ 4 | # TODO: Consider boolean array indexing. 5 | 6 | from typing import Any, AnyStr, Union, List, Tuple 7 | from abc import abstractmethod 8 | import numbers 9 | 10 | 11 | class AtomicSlicer: 12 | """ Wrapping object that will unify slicing across data structures. 13 | 14 | What we support: 15 | Basic indexing (return references): 16 | - (start:stop:step) slicing 17 | - support ellipses 18 | Advanced indexing (return references): 19 | - integer array indexing 20 | 21 | Numpy Reference: 22 | Basic indexing (return views): 23 | - (start:stop:step) slicing 24 | - support ellipses and newaxis (alias for None) 25 | Advanced indexing (return copy): 26 | - integer array indexing, i.e. 
X[[1,2], [3,4]] 27 | - boolean array indexing 28 | - mixed array indexing (has integer array, ellipses, newaxis in same slice) 29 | """ 30 | 31 | def __init__(self, o: Any, max_dim: Union[None, int, AnyStr] = "auto"): 32 | """ Provides a consistent slicing API to the object provided. 33 | 34 | Args: 35 | o: Object to enable consistent slicing. 36 | Currently supports numpy dense arrays, recursive lists ending with list or numpy. 37 | max_dim: Max number of dimensions the wrapped object has. 38 | If set to "auto", max dimensions will be inferred. This comes at compute cost. 39 | """ 40 | self.o = o 41 | self.max_dim = max_dim 42 | if self.max_dim == "auto": 43 | self.max_dim = UnifiedDataHandler.max_dim(o) 44 | 45 | def __repr__(self) -> AnyStr: 46 | """ Override default repr for human readability. 47 | 48 | Returns: 49 | String to display. 50 | """ 51 | return f"{self.__class__.__name__}({self.o.__repr__()})" 52 | 53 | def __getitem__(self, item: Any) -> Any: 54 | """ Consistent slicing into wrapped object. 55 | 56 | Args: 57 | item: Slicing key of type integer or slice. 58 | 59 | Returns: 60 | Sliced object. 61 | 62 | Raises: 63 | ValueError: If slicing is not compatible with wrapped object. 64 | """ 65 | # Turn item into tuple if not already. 66 | index_tup = unify_slice(item, self.max_dim) 67 | 68 | # Slice according to object type. 69 | return UnifiedDataHandler.slice(self.o, index_tup, self.max_dim) 70 | 71 | 72 | def unify_slice(item: Any, max_dim: int, alias_lookup=None) -> Tuple: 73 | """ Resolves aliases and ellipses in a slice item. 74 | 75 | Args: 76 | item: Slicing key that is passed to __getitem__. 77 | max_dim: Max dimension of object to be sliced. 78 | alias_lookup: AliasLookup structure. 79 | 80 | Returns: 81 | A tuple representation of the item. 
82 | """ 83 | item = _normalize_slice_key(item) 84 | index_tup = _normalize_subkey_types(item) 85 | index_tup = _handle_newaxis_ellipses(index_tup, max_dim) 86 | if alias_lookup: 87 | index_tup = _handle_aliases(index_tup, alias_lookup) 88 | return index_tup 89 | 90 | 91 | def _normalize_subkey_types(index_tup: Tuple) -> Tuple: 92 | """ Casts subkeys into basic types such as int. 93 | 94 | Args: 95 | index_tup: Slicing key as a tuple, as passed within __getitem__. 96 | 97 | Returns: 98 | Tuple with subkeys cast to basic types. 99 | """ 100 | new_index_tup = [] # Gets cast to a tuple at the end 101 | 102 | np_int_types = { 103 | "int8", 104 | "int16", 105 | "int32", 106 | "int64", 107 | "uint8", 108 | "uint16", 109 | "uint32", 110 | "uint64", 111 | } 112 | for subkey in index_tup: 113 | if _safe_isinstance(subkey, "numpy", np_int_types): 114 | new_subkey = int(subkey) 115 | elif _safe_isinstance(subkey, "numpy", "ndarray"): 116 | if len(subkey.shape) == 1: 117 | new_subkey = subkey.tolist() 118 | else: 119 | raise ValueError(f"Cannot use array of shape {subkey.shape} as subkey.") 120 | else: 121 | new_subkey = subkey 122 | 123 | new_index_tup.append(new_subkey) 124 | return tuple(new_index_tup) 125 | 126 | 127 | def _normalize_slice_key(key: Any) -> Tuple: 128 | """ Normalizes slice key into always being a top-level tuple. 129 | 130 | Args: 131 | key: Slicing key that is passed within __getitem__. 132 | 133 | Returns: 134 | Expanded slice as a tuple. 135 | """ 136 | if not isinstance(key, tuple): 137 | return (key,) 138 | else: 139 | return key 140 | 141 | 142 | def _handle_newaxis_ellipses(index_tup: Tuple, max_dim: int) -> Tuple: 143 | """ Expands newaxis and ellipses within a slice for simplification. 144 | This code is mostly adapted from: https://github.com/clbarnes/h5py_like/blob/master/h5py_like/shape_utils.py#L111 145 | 146 | Args: 147 | index_tup: Slicing key as a tuple. 148 | max_dim: Maximum number of dimensions in the respective sliceable object. 
149 | 150 | Returns: 151 | Expanded slice as a tuple. 152 | """ 153 | non_indexes = (None, Ellipsis) 154 | concrete_indices = sum(idx not in non_indexes for idx in index_tup) 155 | index_list = [] 156 | # newaxis_at = [] 157 | has_ellipsis = False 158 | int_count = 0 159 | for item in index_tup: 160 | if isinstance(item, numbers.Number): 161 | int_count += 1 162 | 163 | # NOTE: If we need locations of new axis, re-enable this. 164 | if item is None: # pragma: no cover 165 | pass 166 | # newaxis_at.append(len(index_list) + len(newaxis_at) - int_count) 167 | elif item == Ellipsis: 168 | if has_ellipsis: # pragma: no cover 169 | raise IndexError("an index can only have a single ellipsis ('...')") 170 | has_ellipsis = True 171 | initial_len = len(index_list) 172 | while len(index_list) + (concrete_indices - initial_len) < max_dim: 173 | index_list.append(slice(None)) 174 | else: 175 | index_list.append(item) 176 | 177 | if len(index_list) > max_dim: # pragma: no cover 178 | raise IndexError("too many indices for array") 179 | while len(index_list) < max_dim: 180 | index_list.append(slice(None)) 181 | 182 | # return index_list, newaxis_at 183 | return tuple(index_list) 184 | 185 | 186 | def _handle_aliases(index_tup: Tuple, alias_lookup) -> Tuple: 187 | new_index_tup = [] 188 | 189 | def resolve(item, dim): 190 | if isinstance(item, slice): 191 | return item 192 | # Replace element if in alias lookup, otherwise use original. 193 | item = alias_lookup.get(dim, item, item) 194 | return item 195 | 196 | # Go through each element within the index and resolve if needed. 
197 | for dim, item in enumerate(index_tup): 198 | if isinstance(item, list): 199 | new_item = [] 200 | for sub_item in item: 201 | new_item.append(resolve(sub_item, dim)) 202 | else: 203 | new_item = resolve(item, dim) 204 | new_index_tup.append(new_item) 205 | 206 | return tuple(new_index_tup) 207 | 208 | 209 | class Tracked(AtomicSlicer): 210 | """ Tracked defines an object that slicer wraps.""" 211 | 212 | def __init__(self, o: Any, dim: Union[int, List, tuple, None, str] = "auto"): 213 | """ Defines an object that will be wrapped by slicer. 214 | 215 | Args: 216 | o: Object that will be tracked for slicer. 217 | dim: Target dimension(s) slicer will index on for this object. 218 | """ 219 | super().__init__(o) 220 | 221 | # Protected attribute that can be overridden. 222 | self._name = None 223 | 224 | # Place dim into coordinate form. 225 | if dim == "auto": 226 | self.dim = list(range(self.max_dim)) 227 | elif dim is None: 228 | self.dim = [] 229 | elif isinstance(dim, int): 230 | self.dim = [dim] 231 | elif isinstance(dim, list): 232 | self.dim = dim 233 | elif isinstance(dim, tuple): 234 | self.dim = list(dim) 235 | else: # pragma: no cover 236 | raise ValueError(f"Cannot handle dim of type: {type(dim)}") 237 | 238 | 239 | class Obj(Tracked): 240 | """ An object that slicer wraps. """ 241 | def __init__(self, o, dim="auto"): 242 | super().__init__(o, dim) 243 | 244 | 245 | class Alias(Tracked): 246 | """ Defines a tracked object as well as additional __getitem__ keys. """ 247 | def __init__(self, o, dim): 248 | if not ( 249 | isinstance(dim, int) or (isinstance(dim, (list, tuple)) and len(dim) <= 1) 250 | ): # pragma: no cover 251 | raise ValueError("Aliases must track a single dimension") 252 | super().__init__(o, dim) 253 | 254 | 255 | class AliasLookup: 256 | def __init__(self, aliases): 257 | self._lookup = {} 258 | 259 | # Populate lookup and merge indexes. 
260 | for _, alias in aliases.items(): 261 | self.update(alias) 262 | 263 | def update(self, alias): 264 | if alias.dim is None or len(alias.dim) == 0: 265 | return 266 | 267 | dim = alias.dim[0] 268 | if dim not in self._lookup: 269 | self._lookup[dim] = {} 270 | 271 | dim_lookup = self._lookup[dim] 272 | # NOTE: Alias must be backed by either a list or dictionary. 273 | itr = enumerate(alias.o) if isinstance(alias.o, list) else alias.o.items() 274 | for i, x in itr: 275 | if x not in dim_lookup: 276 | dim_lookup[x] = set() 277 | dim_lookup[x].add(i) 278 | 279 | def delete(self, alias): 280 | """ Deletes an existing alias from the lookup. """ 281 | dim = alias.dim[0] 282 | dim_lookup = self._lookup[dim] 283 | # NOTE: Alias must be backed by either a list or dictionary. 284 | itr = enumerate(alias.o) if isinstance(alias.o, list) else alias.o.items() 285 | for _, x in itr: 286 | del dim_lookup[x] 287 | 288 | def get(self, dim, target, default=None): 289 | if dim not in self._lookup: 290 | return default 291 | 292 | indexes = self._lookup[dim].get(target, None) 293 | if indexes is None: 294 | return default 295 | 296 | if len(indexes) == 1: 297 | return next(iter(indexes)) 298 | else: 299 | return list(indexes) 300 | 301 | 302 | def resolve_dim(slicer_index: Tuple, slicer_dim: List) -> List: 303 | """ Extracts new dim after applying slicing index and maps it back to the original index list. 
""" 304 | 305 | new_slicer_dim = [] 306 | reduced_mask = [] 307 | 308 | for curr_idx in slicer_index: 309 | if isinstance(curr_idx, (tuple, list, slice)): 310 | reduced_mask.append(0) 311 | else: 312 | reduced_mask.append(1) 313 | 314 | for curr_dim in slicer_dim: 315 | if reduced_mask[curr_dim] == 0: 316 | new_slicer_dim.append(curr_dim - sum(reduced_mask[:curr_dim])) 317 | 318 | return new_slicer_dim 319 | 320 | 321 | def reduced_o(tracked: List[Tracked]) -> Union[List, Any]: 322 | os = [t.o for t in tracked] 323 | os = os[0] if len(os) == 1 else os 324 | return os 325 | 326 | 327 | class BaseHandler: 328 | @classmethod 329 | @abstractmethod 330 | def head_slice(cls, o, index_tup, max_dim): 331 | raise NotImplementedError() # pragma: no cover 332 | 333 | @classmethod 334 | @abstractmethod 335 | def tail_slice(cls, o, tail_index, max_dim, flatten=True): 336 | raise NotImplementedError() # pragma: no cover 337 | 338 | @classmethod 339 | @abstractmethod 340 | def max_dim(cls, o): 341 | raise NotImplementedError() # pragma: no cover 342 | 343 | @classmethod 344 | def default_alias(cls, o): 345 | return [] 346 | 347 | 348 | class SeriesHandler(BaseHandler): 349 | @classmethod 350 | def head_slice(cls, o, index_tup, max_dim): 351 | head_index = index_tup[0] 352 | is_element = True if isinstance(head_index, int) else False 353 | sliced_o = o.iloc[head_index] 354 | 355 | return is_element, sliced_o, 1 356 | 357 | @classmethod 358 | def tail_slice(cls, o, tail_index, max_dim, flatten=True): 359 | # NOTE: Series only has one dimension, 360 | # call slicer again to end the recursion. 
361 | return AtomicSlicer(o, max_dim=max_dim)[tail_index] 362 | 363 | @classmethod 364 | def max_dim(cls, o): 365 | return len(o.shape) 366 | 367 | @classmethod 368 | def default_alias(cls, o): 369 | index_alias = Alias(o.index.to_list(), 0) 370 | index_alias._name = "index" 371 | return [index_alias] 372 | 373 | 374 | class DataFrameHandler(BaseHandler): 375 | @classmethod 376 | def head_slice(cls, o, index_tup, max_dim): 377 | # NOTE: At head slice, we know there are two fixed dimensions. 378 | cut_index = index_tup 379 | is_element = True if isinstance(cut_index[-1], int) else False 380 | sliced_o = o.iloc[cut_index] 381 | 382 | return is_element, sliced_o, 2 383 | 384 | @classmethod 385 | def tail_slice(cls, o, tail_index, max_dim, flatten=True): 386 | # NOTE: Dataframe has fixed dimensions, 387 | # call slicer again to end the recursion. 388 | return AtomicSlicer(o, max_dim=max_dim)[tail_index] 389 | 390 | @classmethod 391 | def max_dim(cls, o): 392 | return len(o.shape) 393 | 394 | @classmethod 395 | def default_alias(cls, o): 396 | index_alias = Alias(o.index.to_list(), 0) 397 | index_alias._name = "index" 398 | column_alias = Alias(o.columns.to_list(), 1) 399 | column_alias._name = "columns" 400 | return [index_alias, column_alias] 401 | 402 | 403 | class ArrayHandler(BaseHandler): 404 | @classmethod 405 | def head_slice(cls, o, index_tup, max_dim): 406 | # Check if head is string 407 | head_index, tail_index = index_tup[0], index_tup[1:] 408 | cut = 1 409 | 410 | for sub_index in tail_index: 411 | if isinstance(sub_index, str) or cut == len(o.shape): 412 | break 413 | cut += 1 414 | 415 | # Process native array dimensions 416 | cut_index = index_tup[:cut] 417 | is_element = any([True if isinstance(x, int) else False for x in cut_index]) 418 | sliced_o = o[cut_index] 419 | 420 | return is_element, sliced_o, cut 421 | 422 | @classmethod 423 | def tail_slice(cls, o, tail_index, max_dim, flatten=True): 424 | if flatten: 425 | # NOTE: If we're dealing with a 
scipy matrix, 426 | # we have to manually flatten it ourselves 427 | # to keep consistent to the rest of slicer's API. 428 | if _safe_isinstance(o, "scipy.sparse.csc", "csc_matrix"): 429 | return AtomicSlicer(o.toarray().flatten(), max_dim=max_dim)[tail_index] 430 | elif _safe_isinstance(o, "scipy.sparse.csr", "csr_matrix"): 431 | return AtomicSlicer(o.toarray().flatten(), max_dim=max_dim)[tail_index] 432 | elif _safe_isinstance(o, "scipy.sparse.dok", "dok_matrix"): 433 | return AtomicSlicer(o.toarray().flatten(), max_dim=max_dim)[tail_index] 434 | elif _safe_isinstance(o, "scipy.sparse.lil", "lil_matrix"): 435 | return AtomicSlicer(o.toarray().flatten(), max_dim=max_dim)[tail_index] 436 | else: 437 | return AtomicSlicer(o, max_dim=max_dim)[tail_index] 438 | else: 439 | inner = [AtomicSlicer(e, max_dim=max_dim)[tail_index] for e in o] 440 | if _safe_isinstance(o, "numpy", "ndarray"): 441 | import numpy 442 | if len(inner) > 0 and hasattr(inner[0], "__len__"): 443 | ragged = not all(len(x) == len(inner[0]) for x in inner) 444 | else: 445 | ragged = False 446 | if ragged: 447 | return numpy.array(inner, dtype=object) 448 | else: 449 | return numpy.array(inner) 450 | elif _safe_isinstance(o, "torch", "Tensor"): 451 | import torch 452 | 453 | if len(inner) > 0 and isinstance(inner[0], torch.Tensor): 454 | return torch.stack(inner) 455 | else: 456 | return torch.tensor(inner) 457 | elif _safe_isinstance(o, "scipy.sparse.csc", "csc_matrix"): 458 | from scipy.sparse import vstack 459 | out = vstack(inner, format='csc') 460 | return out 461 | elif _safe_isinstance(o, "scipy.sparse.csr", "csr_matrix"): 462 | from scipy.sparse import vstack 463 | out = vstack(inner, format='csr') 464 | return out 465 | elif _safe_isinstance(o, "scipy.sparse.dok", "dok_matrix"): 466 | from scipy.sparse import vstack 467 | out = vstack(inner, format='dok') 468 | return out 469 | elif _safe_isinstance(o, "scipy.sparse.lil", "lil_matrix"): 470 | from scipy.sparse import vstack 471 | out = 
vstack(inner, format='lil') 472 | return out 473 | else: 474 | raise ValueError(f"Cannot handle type {type(o)}.") # pragma: no cover 475 | 476 | @classmethod 477 | def max_dim(cls, o): 478 | if _safe_isinstance(o, "numpy", "ndarray") and o.dtype == "object": 479 | return max([UnifiedDataHandler.max_dim(x) for x in o], default=-1) + 1 480 | else: 481 | return len(o.shape) 482 | 483 | 484 | class DictHandler(BaseHandler): 485 | @classmethod 486 | def head_slice(cls, o, index_tup, max_dim): 487 | head_index = index_tup[0] 488 | if isinstance(head_index, (tuple, list)) and len(head_index) == 0: 489 | return False, o, 1 490 | 491 | if isinstance(head_index, (list, tuple)): 492 | return ( 493 | False, 494 | { 495 | sub_index: AtomicSlicer(o, max_dim=max_dim)[sub_index] 496 | for sub_index in head_index 497 | }, 498 | 1, 499 | ) 500 | elif isinstance(head_index, slice): 501 | if head_index == slice(None, None, None): 502 | return False, o, 1 503 | return False, o[head_index], 1 504 | else: 505 | return True, o[head_index], 1 506 | 507 | @classmethod 508 | def tail_slice(cls, o, tail_index, max_dim, flatten=True): 509 | if flatten: 510 | return AtomicSlicer(o, max_dim=max_dim)[tail_index] 511 | else: 512 | return { 513 | k: AtomicSlicer(e, max_dim=max_dim)[tail_index] for k, e in o.items() 514 | } 515 | 516 | @classmethod 517 | def max_dim(cls, o): 518 | return max([UnifiedDataHandler.max_dim(x) for x in o.values()], default=-1) + 1 519 | 520 | 521 | class ListTupleHandler(BaseHandler): 522 | @classmethod 523 | def head_slice(cls, o, index_tup, max_dim): 524 | head_index = index_tup[0] 525 | if isinstance(head_index, (tuple, list)) and len(head_index) == 0: 526 | return False, o, 1 527 | 528 | if isinstance(head_index, (list, tuple)): 529 | if len(head_index) == 0: 530 | return False, o, 1 531 | else: 532 | results = [ 533 | AtomicSlicer(o, max_dim=max_dim)[sub_index] 534 | for sub_index in head_index 535 | ] 536 | results = tuple(results) if isinstance(o, tuple) else
results 537 | return False, results, 1 538 | elif isinstance(head_index, slice): 539 | return False, o[head_index], 1 540 | elif isinstance(head_index, int): 541 | return True, o[head_index], 1 542 | else: # pragma: no cover 543 | raise ValueError(f"Invalid key {head_index} for {o}") 544 | 545 | @classmethod 546 | def tail_slice(cls, o, tail_index, max_dim, flatten=True): 547 | if flatten: 548 | return AtomicSlicer(o, max_dim=max_dim)[tail_index] 549 | else: 550 | results = [AtomicSlicer(e, max_dim=max_dim)[tail_index] for e in o] 551 | return tuple(results) if isinstance(o, tuple) else results 552 | 553 | @classmethod 554 | def max_dim(cls, o): 555 | return max([UnifiedDataHandler.max_dim(x) for x in o], default=-1) + 1 556 | 557 | 558 | class UnifiedDataHandler: 559 | """ Registry that maps types to their unified slice calls.""" 560 | 561 | # Class attribute that maps (module, type) names to their slice handlers. 562 | type_map = { 563 | ("builtins", "list"): ListTupleHandler, 564 | ("builtins", "tuple"): ListTupleHandler, 565 | ("builtins", "dict"): DictHandler, 566 | ("torch", "Tensor"): ArrayHandler, 567 | ("numpy", "ndarray"): ArrayHandler, 568 | ("scipy.sparse.csc", "csc_matrix"): ArrayHandler, 569 | ("scipy.sparse.csr", "csr_matrix"): ArrayHandler, 570 | ("scipy.sparse.dok", "dok_matrix"): ArrayHandler, 571 | ("scipy.sparse.lil", "lil_matrix"): ArrayHandler, 572 | ("pandas.core.frame", "DataFrame"): DataFrameHandler, 573 | ("pandas.core.series", "Series"): SeriesHandler, 574 | } 575 | 576 | @classmethod 577 | def slice(cls, o, index_tup, max_dim): 578 | # NOTE: Unified handles base cases such as empty tuples, which 579 | # specialized handlers do not. 580 | if isinstance(index_tup, (tuple, list)) and len(index_tup) == 0: 581 | return o 582 | 583 | # Slice as delegated by data handler.
584 | o_type = _type_name(o) 585 | head_slice = cls.type_map[o_type].head_slice 586 | tail_slice = cls.type_map[o_type].tail_slice 587 | 588 | is_element, sliced_o, cut = head_slice(o, index_tup, max_dim) 589 | out = tail_slice(sliced_o, index_tup[cut:], max_dim - cut, is_element) 590 | return out 591 | 592 | @classmethod 593 | def max_dim(cls, o): 594 | o_type = _type_name(o) 595 | if o_type not in cls.type_map: 596 | return 0 597 | return cls.type_map[o_type].max_dim(o) 598 | 599 | @classmethod 600 | def default_alias(cls, o): 601 | o_type = _type_name(o) 602 | if o_type not in cls.type_map: 603 | return [] 604 | return cls.type_map[o_type].default_alias(o) 605 | 606 | 607 | def _type_name(o: object) -> Tuple[str, str]: 608 | return _handle_module_aliases(o.__class__.__module__), o.__class__.__name__ 609 | 610 | 611 | def _safe_isinstance( 612 | o: object, module_name: str, type_name: Union[str, set, tuple] 613 | ) -> bool: 614 | o_module, o_type = _type_name(o) 615 | if isinstance(type_name, str): 616 | return o_module == module_name and o_type == type_name 617 | else: 618 | return o_module == module_name and o_type in type_name 619 | 620 | 621 | def _handle_module_aliases(module_name: str) -> str: 622 | # scipy modules such as "scipy.sparse.csc" were renamed to "scipy.sparse._csc" in v1.8. 623 | # Standardise on the public (underscore-free) names so both scipy generations 624 | # map to the same keys; unknown module names pass through unchanged. 625 | module_map = { 626 | "scipy.sparse._csc": "scipy.sparse.csc", 627 | "scipy.sparse._csr": "scipy.sparse.csr", 628 | "scipy.sparse._dok": "scipy.sparse.dok", 629 | "scipy.sparse._lil": "scipy.sparse.lil", 630 | } 631 | return module_map.get(module_name, module_name) 632 | -------------------------------------------------------------------------------- /slicer/test_slicer.py: -------------------------------------------------------------------------------- 1 | """ Basic tests for slicer. 2 | An unholy balance of use cases and test coverage.
3 | """ 4 | 5 | import pytest 6 | 7 | from .slicer import AtomicSlicer 8 | 9 | from . import Slicer as S 10 | from . import Alias as A 11 | from . import Obj as O 12 | 13 | import pandas as pd 14 | import numpy as np 15 | import torch 16 | from scipy.sparse import csc_matrix 17 | from scipy.sparse import csr_matrix 18 | from scipy.sparse import dok_matrix 19 | from scipy.sparse import lil_matrix 20 | 21 | 22 | from .utils_testing import ctr_eq 23 | 24 | 25 | def test_slicer_ragged_numpy(): 26 | values = np.array([ 27 | np.array([0, 1]), 28 | np.array([2, 3, 4]) 29 | ], dtype=object) 30 | data = np.array([ 31 | np.array([5, 6, 7]), 32 | ]) 33 | 34 | slicer = S(values=values, data=data) 35 | sliced = slicer[0, 1] 36 | 37 | assert ctr_eq(sliced.data, data[0][1]) 38 | assert ctr_eq(sliced.values, values[0][1]) 39 | 40 | 41 | def test_slicer_basic(): 42 | data = [[1, 2], [3, 4]] 43 | values = [[5, 6], [7, 8]] 44 | identifiers = ["id1", "id1"] 45 | instance_names = ["r1", "r2"] 46 | feature_names = ["f1", "f2"] 47 | full_name = "A" 48 | 49 | slicer = S( 50 | data=data, 51 | values=values, 52 | identifiers=A(identifiers, 0), 53 | instance_names=A(instance_names, 0), 54 | feature_names=A(feature_names, 1), 55 | full_name=full_name, 56 | ) 57 | 58 | colon_actual = slicer[:, 1] 59 | assert colon_actual.data == [2, 4] 60 | assert colon_actual.instance_names == ["r1", "r2"] 61 | assert colon_actual.feature_names == "f2" 62 | assert colon_actual.values == [6, 8] 63 | 64 | ellipses_actual = slicer[..., 1] 65 | assert ellipses_actual.data == [2, 4] 66 | assert ellipses_actual.instance_names == ["r1", "r2"] 67 | assert ellipses_actual.feature_names == "f2" 68 | assert ellipses_actual.values == [6, 8] 69 | 70 | array_index_actual = slicer[[0, 1], 1] 71 | assert array_index_actual.data == [2, 4] 72 | assert array_index_actual.feature_names == "f2" 73 | assert array_index_actual.instance_names == ["r1", "r2"] 74 | assert array_index_actual.values == [6, 8] 75 | 76 | alias_actual = 
slicer["r1", "f2"] 77 | assert alias_actual.data == 2 78 | assert alias_actual.feature_names == "f2" 79 | assert alias_actual.instance_names == "r1" 80 | assert alias_actual.values == 6 81 | 82 | alias_actual = slicer["id1", "f2"] 83 | assert alias_actual.data == [2, 4] 84 | assert alias_actual.feature_names == "f2" 85 | assert alias_actual.instance_names == ["r1", "r2"] 86 | assert alias_actual.values == [6, 8] 87 | 88 | chained_actual = slicer[:][:, 1] 89 | assert chained_actual.data == [2, 4] 90 | assert chained_actual.feature_names == "f2" 91 | assert chained_actual.instance_names == ["r1", "r2"] 92 | assert chained_actual.values == [6, 8] 93 | 94 | alias_actual = slicer["id1"][:, "f2"] 95 | assert alias_actual.data == [2, 4] 96 | assert alias_actual.feature_names == "f2" 97 | assert alias_actual.instance_names == ["r1", "r2"] 98 | assert alias_actual.values == [6, 8] 99 | 100 | alias_actual = slicer["r1"] 101 | alias_actual = alias_actual["f2"] 102 | assert alias_actual.data == 2 103 | assert alias_actual.feature_names == "f2" 104 | assert alias_actual.instance_names == "r1" 105 | assert alias_actual.values == 6 106 | 107 | 108 | def test_slicer_unnamed(): 109 | a = [1, 2, 3] 110 | b = [4, 5, 6] 111 | 112 | slicer = S(a, b) 113 | actual_a, actual_b = slicer[1].o 114 | assert actual_a == 2 115 | assert actual_b == 5 116 | 117 | df1 = pd.DataFrame([[1, 2], [3, 4]]) 118 | df2 = pd.DataFrame([[5, 6], [7, 8]]) 119 | slicer = S(df1, df2) 120 | actual_1, actual_2 = slicer[:, 0].o 121 | 122 | assert ctr_eq(actual_1.values, [1, 3]) 123 | assert ctr_eq(actual_2.values, [5, 7]) 124 | 125 | 126 | def test_slicer_crud(): 127 | data = [[1, 2], [3, 4]] 128 | values = [[5, 6], [7, 8]] 129 | extra = [[9, 10], [11, 12]] 130 | overridden = [[13, 14], [15, 16]] 131 | 132 | slicer = S(data=data, values=values) 133 | slicer.extra = extra # Create 134 | slicer.data = overridden # Update 135 | del slicer.values # Delete 136 | 137 | sliced = slicer[0, 1] # Read 138 | assert 
sliced.data == 14 139 | with pytest.raises(Exception): 140 | _ = sliced.values 141 | 142 | assert sliced.extra == 10 143 | 144 | del slicer.o 145 | assert slicer.o == [] 146 | 147 | 148 | def test_slicer_default_alias(): 149 | 150 | df = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"]) 151 | slicer = S(df) 152 | assert getattr(slicer, "index", None) 153 | assert getattr(slicer, "columns", None) 154 | actual = slicer[:, "A"].o 155 | assert ctr_eq(actual, [1, 3]) 156 | 157 | 158 | def test_slicer_anon_dict(): 159 | di = {"a": [1, 2, 3], "b": [4, 5, 6]} 160 | slicer = S(di) 161 | 162 | result = slicer["a", 1].o 163 | assert result == 2 164 | 165 | 166 | def test_slicer_3d(): 167 | data_2d = [[1, 2], [3, 4], [5, 6]] 168 | values_3d = [ 169 | [[1, 2, 3], [4, 5, 6]], 170 | [[7, 8, 9], [10, 11, 12]], 171 | [[13, 14, 15], [16, 17, 18]], 172 | ] 173 | names = ["a", "b", "c"] 174 | 175 | slicer = S(data=data_2d, values=values_3d, names=A(names, 2)) 176 | actual = slicer[..., 1] 177 | assert ctr_eq(actual.data, data_2d) 178 | assert actual.names == "b" 179 | 180 | actual = slicer[0, :, 1] 181 | assert ctr_eq(actual.data, data_2d[0]) 182 | assert actual.names == "b" 183 | 184 | actual = slicer[0, :][:, 1] 185 | assert ctr_eq(actual.data, data_2d[0]) 186 | assert actual.names == "b" 187 | 188 | 189 | def test_untracked(): 190 | data = [1, 2, 3, 4] 191 | primitive = 1 192 | collection = [[8, 9]] 193 | slicer = S(data=data, primitive=O(primitive, None), collection=O(collection, None)) 194 | actual = slicer[:2] 195 | assert actual.data == data[:2] 196 | assert actual.primitive == primitive 197 | assert ctr_eq(actual.collection, collection) 198 | 199 | 200 | def test_partial_untracked(): 201 | s = S(a=np.zeros((4, 5, 6)), b=O(np.ones((4, 2, 2)), [0])) 202 | assert s[:, :, 1].b.shape == (4, 2, 2) 203 | 204 | 205 | def test_numpy_subkeys(): 206 | data = [1, 2, 3, 4] 207 | slicer = S(data=data) 208 | 209 | subkey = np.int64(1) 210 | assert slicer[subkey].data == 2 211 | 212 | subkey 
= np.array([0, 1]) 213 | assert ctr_eq(slicer[subkey].data, [1, 2]) 214 | 215 | subkey = np.array([[0, 1], [3, 4]]) 216 | with pytest.raises(ValueError): 217 | _ = slicer[subkey] 218 | 219 | 220 | def test_repr_smoke(): 221 | slicer = S([1, 2], ["a", "b"], named=[3, 4]) 222 | print(slicer) 223 | 224 | atomic = AtomicSlicer([1, 2, 3, 4]) 225 | print(atomic) 226 | 227 | 228 | def test_slicer_simple_di(): 229 | di = {"A": [1, 2], "B": [3, 4], "C": [5, 6]} 230 | slicer = S(di) 231 | actual = slicer["B", 0] 232 | actual = actual.o 233 | assert ctr_eq(actual, 3) 234 | 235 | nested_di = {"X": di, "Y": di} 236 | actual = S(nested_di)["X", "B", 0].o 237 | assert ctr_eq(actual, 3) 238 | 239 | 240 | def test_slicer_sparse(): 241 | array = np.array([[1, 0, 4], [0, 0, 5], [2, 3, 6]]) 242 | csc_array = csc_matrix(array) 243 | csr_array = csr_matrix(array) 244 | dok_array = dok_matrix(array) 245 | lil_array = lil_matrix(array) 246 | 247 | candidates = [csc_array, csr_array, dok_array, lil_array] 248 | for candidate in candidates: 249 | print("testing:", type(candidate)) 250 | slicer = S(candidate) 251 | actual = slicer[0, 0] 252 | assert ctr_eq(actual.o, 1) 253 | actual = slicer[1, 1] 254 | assert ctr_eq(actual.o, 0) 255 | 256 | actual = slicer[0] 257 | expected = np.array([1, 0, 4]) 258 | assert ctr_eq(actual.o, expected) 259 | 260 | actual = slicer[:, 1] 261 | expected = np.array([0, 0, 3]) 262 | assert ctr_eq(actual.o, expected) 263 | 264 | actual = slicer[:, :] 265 | expected = np.array([[1, 0, 4], [0, 0, 5], [2, 3, 6]]) 266 | assert ctr_eq(actual.o, expected) 267 | 268 | actual = slicer[0, :] 269 | expected = np.array([1, 0, 4]) 270 | assert ctr_eq(actual.o, expected) 271 | 272 | 273 | def test_slicer_torch(): 274 | import torch 275 | 276 | data = torch.tensor([[1, 2], [3, 4]]) 277 | values = torch.tensor([[5, 6], [7, 8]]) 278 | alias = ["f1", "f2"] 279 | 280 | slicer = S(data=data, values=values, alias=A(alias, 1)) 281 | sliced = slicer[0, "f2"] 282 | assert sliced.data == 
2 283 | assert sliced.values == 6 284 | 285 | 286 | def test_slicer_pandas(): 287 | di = {"A": [1, 2], "B": [3, 4], "C": [5, 6]} 288 | df = pd.DataFrame(di) 289 | 290 | slicer = S(df) 291 | assert slicer[0, "A"].o == 1 292 | assert ctr_eq(slicer[:, "A"].o, [1, 2]) 293 | assert ctr_eq(slicer[0, :].o, [1, 3, 5]) 294 | 295 | df = pd.DataFrame(di, index=["X", "Y"]) 296 | slicer = S(df) 297 | assert slicer["X", "A"].o == 1 298 | assert slicer[0, "A"].o == 1 299 | assert slicer[0, 0].o == 1 300 | slicer = S(df["A"]) 301 | assert slicer["X"].o == 1 302 | assert slicer[0].o == 1 303 | assert ctr_eq(slicer[:].o, [1, 2]) 304 | 305 | 306 | def test_handle_newaxis_ellipses(): 307 | from .slicer_internal import _handle_newaxis_ellipses 308 | 309 | index_tup = (1,) 310 | max_dim = 3 311 | 312 | expanded_index_tup = _handle_newaxis_ellipses(index_tup, max_dim) 313 | assert expanded_index_tup == (1, slice(None), slice(None)) 314 | 315 | 316 | def test_tracked_dim_arg_smoke(): 317 | li = ['A', 'B'] 318 | _ = A(li, dim=0) 319 | _ = A(li, dim=[0]) 320 | _ = A(li, dim=(0,)) 321 | 322 | # Aliases must have a single dim 323 | with pytest.raises(Exception): 324 | _ = A(li, dim=None) 325 | 326 | with pytest.raises(Exception): 327 | _ = A(li, dim=[0,1]) 328 | 329 | _ = O(li, dim=0) 330 | _ = O(li, dim=[0]) 331 | _ = O(li, dim=(0,)) 332 | 333 | assert True 334 | 335 | 336 | def test_operations_1d(): 337 | elements = [1, 2, 3, 4] 338 | li = elements 339 | tup = tuple(elements) 340 | di = {i: x for i, x in enumerate(elements)} 341 | series = pd.Series(elements) 342 | array = np.array(elements) 343 | torch_array = torch.tensor(elements) 344 | containers = [li, tup, array, torch_array, di, series] 345 | for ctr in containers: 346 | print("testing:", type(ctr)) 347 | slicer = AtomicSlicer(ctr) 348 | 349 | assert ctr_eq(slicer[0], elements[0]) 350 | 351 | # Array 352 | assert ctr_eq(slicer[[0, 1, 2, 3]], elements) 353 | assert ctr_eq(slicer[[0, 1, 2]], elements[:-1]) 354 | 355 | # All 356 | 
assert ctr_eq(slicer[:], elements[:]) 357 | assert ctr_eq(slicer[tuple()], elements) 358 | 359 | # Ranged slicing 360 | if not isinstance(ctr, dict): # Do not test on dictionaries. 361 | assert ctr_eq(slicer[-1], elements[-1]) 362 | assert ctr_eq(slicer[0:3:2], elements[0:3:2]) 363 | 364 | 365 | def test_operations_2d(): 366 | elements = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] 367 | li = elements 368 | df = pd.DataFrame(elements, columns=["A", "B", "C"]) 369 | 370 | sparse_csc = csc_matrix(elements) 371 | sparse_csr = csr_matrix(elements) 372 | sparse_dok = dok_matrix(elements) 373 | sparse_lil = lil_matrix(elements) 374 | 375 | containers = [li, df, sparse_csc, sparse_csr, sparse_dok, sparse_lil] 376 | for ctr in containers: 377 | print("testing:", type(ctr)) 378 | slicer = AtomicSlicer(ctr) 379 | 380 | assert ctr_eq(slicer[0], elements[0]) 381 | 382 | # Ranged slicing 383 | if not isinstance(ctr, dict): 384 | assert ctr_eq(slicer[-1], elements[-1]) 385 | assert ctr_eq(slicer[0, 0:3:2], elements[0][0:3:2]) 386 | 387 | # Array 388 | assert ctr_eq(slicer[[0, 1, 2], :], elements) 389 | 390 | # All 391 | assert ctr_eq(slicer[:], elements) 392 | assert ctr_eq(slicer[tuple()], elements) 393 | 394 | assert ctr_eq(slicer[:, 0], [elements[i][0] for i, _ in enumerate(elements)]) 395 | assert ctr_eq(slicer[[0, 1], 0], [elements[i][0] for i in [0, 1]]) 396 | assert ctr_eq(slicer[[0, 1], 1], [elements[i][1] for i in [0, 1]]) 397 | assert ctr_eq(slicer[0, :], elements[0]) 398 | assert ctr_eq(slicer[0, 1], elements[0][1]) 399 | 400 | assert ctr_eq(slicer[..., 0], [elements[i][0] for i, _ in enumerate(elements)]) 401 | 402 | 403 | def test_operations_3d(): 404 | # 3-dimensional fixed dimension case 405 | elements = [ 406 | [[1, 2, 3], [4, 5, 6]], 407 | [[7, 8, 9], [10, 11, 12]], 408 | [[13, 14, 15], [16, 17, 18]], 409 | ] 410 | tuple_elements = ( 411 | ((1, 2, 3), (4, 5, 6)), 412 | ((7, 8, 9), (10, 11, 12)), 413 | ((13, 14, 15), (16, 17, 18)), 414 | ) 415 | torch_array = 
torch.tensor(elements) 416 | multi_array = np.array(elements) 417 | list_of_lists = elements 418 | tuples_of_tuples = tuple_elements 419 | list_of_multi_arrays = [ 420 | np.array(elements[0]), 421 | np.array(elements[1]), 422 | np.array(elements[2]), 423 | ] 424 | di_of_multi_arrays = { 425 | 0: np.array(elements[0]), 426 | 1: np.array(elements[1]), 427 | 2: np.array(elements[2]), 428 | } 429 | 430 | containers = [ 431 | torch_array, 432 | multi_array, 433 | tuples_of_tuples, 434 | list_of_lists, 435 | list_of_multi_arrays, 436 | di_of_multi_arrays, 437 | ] 438 | for ctr in containers: 439 | print("testing:", type(ctr)) 440 | slicer = AtomicSlicer(ctr) 441 | 442 | assert ctr_eq(slicer[0], elements[0]) 443 | 444 | # Ranged slicing 445 | if not isinstance(ctr, dict): 446 | assert ctr_eq(slicer[-1], elements[-1]) 447 | assert ctr_eq(slicer[0, 0:3:2], elements[0][0:3:2]) 448 | 449 | # Array 450 | assert ctr_eq(slicer[[0, 1, 2], :], elements) 451 | 452 | # All 453 | assert ctr_eq(slicer[:], elements) 454 | assert ctr_eq(slicer[tuple()], elements) 455 | 456 | assert ctr_eq(slicer[:, 0], [elements[i][0] for i, _ in enumerate(elements)]) 457 | assert ctr_eq(slicer[[0, 1], 0], [elements[i][0] for i in [0, 1]]) 458 | assert ctr_eq(slicer[[0, 1], 1], [elements[i][1] for i in [0, 1]]) 459 | assert ctr_eq(slicer[0, :], elements[0]) 460 | assert ctr_eq(slicer[0, 1], elements[0][1]) 461 | 462 | rows = [] 463 | for i, _ in enumerate(elements): 464 | cols = [] 465 | for j, _ in enumerate(elements[i]): 466 | cols.append(elements[i][j][1]) 467 | rows.append(cols) 468 | assert ctr_eq(slicer[..., 1], rows) 469 | assert ctr_eq( 470 | slicer[0, ..., 1], [elements[0][i][1] for i in range(len(elements[0]))] 471 | ) 472 | 473 | def test_attribute_assignment(): 474 | data = [[1, 2], [3, 4]] 475 | values = [[5, 6], [7, 8]] 476 | identifiers = ["id1", "id1"] 477 | instance_names = ["r1", "r2"] 478 | feature_names = ["f1", "f2"] 479 | full_name = "A" 480 | 481 | exp = S( 482 | data=data, 483 | 
values=values, 484 | identifiers=A(identifiers, 0), 485 | instance_names=A(instance_names, 0), 486 | feature_names=A(feature_names, 1), 487 | full_name=full_name, 488 | ) 489 | 490 | exp.feature_names = ['f3', 'f4'] 491 | 492 | assert exp.feature_names == ['f3', 'f4'] 493 | assert exp[:, 0].feature_names == 'f3' 494 | 495 | with pytest.raises(Exception): 496 | _ = exp[:, 'f1'] # f1 should no longer exist as a valid alias 497 | 498 | exp.feature_names = A(['f5', 'f6'], dim=0) 499 | 500 | assert exp.feature_names == ['f5', 'f6'] 501 | assert exp[1, :].feature_names == 'f6' # feature_names now tracks dim 0 502 | -------------------------------------------------------------------------------- /slicer/utils_testing.py: -------------------------------------------------------------------------------- 1 | """ Testing utilities that allow for easier assertions on collections. 2 | Do you love tests? Neither do we. 3 | """ 4 | # TODO: This module is due for a refactor. 5 | from typing import Any 6 | import numbers 7 | import numpy as np 8 | import torch 9 | import pandas as pd 10 | from scipy.sparse import csc_matrix 11 | from scipy.sparse import csr_matrix 12 | from scipy.sparse import dok_matrix 13 | from scipy.sparse import lil_matrix 14 | 15 | 16 | def coerced(o: Any): 17 | if isinstance(o, (csc_matrix, csr_matrix, dok_matrix, lil_matrix)): 18 | o = o.toarray() 19 | 20 | to_list_collections = (np.ndarray, torch.Tensor, pd.core.series.Series) 21 | if isinstance(o, (list, tuple)): 22 | return o 23 | elif isinstance(o, to_list_collections): 24 | return o.tolist() 25 | elif isinstance(o, pd.core.frame.DataFrame): 26 | return o.values.tolist() 27 | elif isinstance(o, dict): # NOTE: assumes integer keys 0..len(o)-1 28 | li = [np.nan] * len(o) 29 | for k, v in o.items(): 30 | li[k] = v 31 | return li 32 | else: 33 | raise ValueError(f"Object {o} of {type(o)} is not a list, tuple, or array.") 34 | 35 | 36 | def is_close( 37 | a: numbers.Number, b: numbers.Number, rel_tol: float = 1e-09, abs_tol: float = 0.0 38 | ):
39 | return abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol) 40 | 41 | 42 | def ctr_eq(c1: Any, c2: Any): 43 | if isinstance(c1, torch.Tensor) and c1.shape == torch.Size([]): 44 | c1 = c1.item() 45 | if isinstance(c2, torch.Tensor) and c2.shape == torch.Size([]): 46 | c2 = c2.item() 47 | 48 | if isinstance(c1, numbers.Number) and isinstance(c2, numbers.Number): 49 | return is_close(c1, c2) 50 | 51 | c1 = coerced(c1) 52 | c2 = coerced(c2) 53 | 54 | return len(c1) == len(c2) and all(ctr_eq(a, b) for a, b in zip(c1, c2)) 55 | --------------------------------------------------------------------------------
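The `_type_name` / `_safe_isinstance` / `_handle_module_aliases` trio in `slicer_internal.py` dispatches on `(module, class)` name pairs instead of real `isinstance` checks, so torch, scipy, and pandas stay optional dependencies, and the scipy ≥1.8 rename of `scipy.sparse.csc` to `scipy.sparse._csc` is absorbed by a small alias map. A minimal standalone sketch of that technique (the function names mirror the source, but this is a simplified re-implementation and needs no scipy installed):

```python
from typing import Tuple, Union

# Map the post-1.8 private scipy module names back to the public pre-1.8 ones,
# so a single key works for both scipy generations (mirrors _handle_module_aliases).
_MODULE_MAP = {
    "scipy.sparse._csc": "scipy.sparse.csc",
    "scipy.sparse._csr": "scipy.sparse.csr",
}

def type_name(o: object) -> Tuple[str, str]:
    # Identify an object by (module, class) strings rather than isinstance,
    # so the library being checked never has to be imported.
    module = o.__class__.__module__
    return _MODULE_MAP.get(module, module), o.__class__.__name__

def safe_isinstance(o: object, module_name: str, type_names: Union[str, tuple]) -> bool:
    o_module, o_type = type_name(o)
    if isinstance(type_names, str):
        return o_module == module_name and o_type == type_names
    return o_module == module_name and o_type in type_names

# Fake a scipy>=1.8 csc_matrix without importing scipy at all.
CscMatrix = type("csc_matrix", (), {"__module__": "scipy.sparse._csc"})
print(safe_isinstance(CscMatrix(), "scipy.sparse.csc", "csc_matrix"))  # True
```

The same normalized pair is what `UnifiedDataHandler.type_map` uses as its lookup key, which is why the map lists only the underscore-free module names.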
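The `reduced_mask` / `new_slicer_dim` loop at the top of `slicer_internal.py` does the dimension bookkeeping after a slice: an integer index collapses a dimension, while lists, tuples, and slices keep it, and each surviving tracked dimension shifts left by the number of collapsed dimensions before it. A self-contained sketch of that logic (the function name `remap_dims` is ours, not the library's):

```python
from typing import List, Sequence, Union

Index = Union[int, list, tuple, slice]

def remap_dims(slicer_index: Sequence[Index], slicer_dims: Sequence[int]) -> List[int]:
    # 1 marks a dimension collapsed by an integer index; 0 marks a kept one
    # (lists, tuples, and slices all preserve the dimension).
    reduced_mask = [0 if isinstance(idx, (tuple, list, slice)) else 1
                    for idx in slicer_index]

    new_dims = []
    for dim in slicer_dims:
        if reduced_mask[dim] == 0:  # dimension survives the slice
            # Shift left by the number of collapsed dimensions before it.
            new_dims.append(dim - sum(reduced_mask[:dim]))
    return new_dims

# Index [0, :, [1, 2]] over tracked dims (0, 1, 2): dim 0 collapses,
# dims 1 and 2 survive and shift down to 0 and 1.
print(remap_dims([0, slice(None), [1, 2]], [0, 1, 2]))  # [0, 1]
```

This is why, in `test_attribute_assignment`, an `Alias` tracking dim 1 still lines up correctly after a row is selected away.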
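`UnifiedDataHandler.slice` drives every handler with the same head/tail recursion: the handler for the outer container consumes as many leading index components as it natively supports (`head_slice` returns the `cut`), then the remaining tail is re-dispatched on the sliced-out pieces (`tail_slice`). A toy version of that recursion over nested lists only (no handler registry, just int and slice indexes) shows the shape of the control flow:

```python
def deep_slice(o, index_tup):
    # Base case mirrors UnifiedDataHandler.slice: an empty index returns o as-is.
    if len(index_tup) == 0:
        return o
    head, tail = index_tup[0], index_tup[1:]
    if isinstance(head, int):
        # An int collapses this dimension; recurse straight into the element.
        return deep_slice(o[head], tail)
    # A slice keeps the dimension; recurse element-wise over the kept rows.
    return [deep_slice(e, tail) for e in o[head]]

cube = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
print(deep_slice(cube, (0, slice(None), 1)))  # [2, 4]
```

The real implementation additionally threads `max_dim` and an `is_element`/`flatten` flag through the recursion so that, e.g., a 2-D scipy matrix nested inside a list behaves like one 3-D container.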