├── .git_archival.txt ├── .gitattributes ├── .github ├── CONTRIBUTING.md ├── dependabot.yml ├── matchers │ └── pylint.json └── workflows │ ├── cd.yml │ ├── ci.yml │ └── lower-bound-requirements.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── LICENSE ├── README.md ├── docs ├── conf.py ├── index.md └── logo.svg ├── noxfile.py ├── pyproject.toml ├── src └── ragged │ ├── __init__.py │ ├── _helper_functions.py │ ├── _import.py │ ├── _spec_array_object.py │ ├── _spec_constants.py │ ├── _spec_creation_functions.py │ ├── _spec_data_type_functions.py │ ├── _spec_elementwise_functions.py │ ├── _spec_indexing_functions.py │ ├── _spec_linear_algebra_functions.py │ ├── _spec_manipulation_functions.py │ ├── _spec_searching_functions.py │ ├── _spec_set_functions.py │ ├── _spec_sorting_functions.py │ ├── _spec_statistical_functions.py │ ├── _spec_utility_functions.py │ ├── _typing.py │ ├── _version.pyi │ ├── io │ ├── __init__.py │ └── cf.py │ └── py.typed ├── tests-cuda └── test_cuda_spec_set_functions.py └── tests ├── conftest.py ├── test_spec_array_object.py ├── test_spec_broadcasting.py ├── test_spec_constants.py ├── test_spec_creation_functions.py ├── test_spec_data_type_functions.py ├── test_spec_elementwise_functions.py ├── test_spec_indexing.py ├── test_spec_indexing_functions.py ├── test_spec_linear_algebra_functions.py ├── test_spec_manipulation_functions.py ├── test_spec_searching_functions.py ├── test_spec_set_functions.py ├── test_spec_sorting_functions.py ├── test_spec_statistical_functions.py ├── test_spec_utility_functions.py ├── test_spec_version.py └── test_type_promotion.py /.git_archival.txt: -------------------------------------------------------------------------------- 1 | node: 98709142fe1d0b51973e8cc82c5cfb717b96ee96 2 | node-date: 2025-04-22T08:45:46+02:00 3 | describe-name: v0.2.0-1-g9870914 4 | ref-names: HEAD -> main 5 | -------------------------------------------------------------------------------- /.gitattributes: 
-------------------------------------------------------------------------------- 1 | .git_archival.txt export-subst 2 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | See the [Scientific Python Developer Guide][spc-dev-intro] for a detailed 2 | description of best practices for developing scientific packages. 3 | 4 | [spc-dev-intro]: https://learn.scientific-python.org/development/ 5 | 6 | # Quick development 7 | 8 | The fastest way to start with development is to use nox. If you don't have nox, 9 | you can use `pipx run nox` to run it without installing, or `pipx install nox`. 10 | If you don't have pipx (pip for applications), then you can install with 11 | `pip install pipx` (the only case where installing an application with regular 12 | pip is reasonable). If you use macOS, then pipx and nox are both in brew, use 13 | `brew install pipx nox`. 14 | 15 | To use, run `nox`. This will lint and test using every installed version of 16 | Python on your system, skipping ones that are not installed. You can also run 17 | specific jobs: 18 | 19 | ```console 20 | $ nox -s lint # Lint only 21 | $ nox -s tests # Python tests 22 | $ nox -s docs -- --serve # Build and serve the docs 23 | $ nox -s build # Make an SDist and wheel 24 | ``` 25 | 26 | Nox handles everything for you, including setting up a temporary virtual 27 | environment for each run.
28 | 29 | # Setting up a development environment manually 30 | 31 | You can set up a development environment by running: 32 | 33 | ```bash 34 | python3 -m venv .venv 35 | source ./.venv/bin/activate 36 | pip install -v -e .[dev] 37 | ``` 38 | 39 | If you have the 40 | [Python Launcher for Unix](https://github.com/brettcannon/python-launcher), you 41 | can instead do: 42 | 43 | ```bash 44 | py -m venv .venv 45 | py -m pip install -v -e .[dev] 46 | ``` 47 | 48 | # Post setup 49 | 50 | You should prepare pre-commit, which will help you by checking that commits pass 51 | required checks: 52 | 53 | ```bash 54 | pip install pre-commit # or brew install pre-commit on macOS 55 | pre-commit install # Will install a pre-commit hook into the git repo 56 | ``` 57 | 58 | You can also/alternatively run `pre-commit run` (changes only) or 59 | `pre-commit run --all-files` to check even without installing the hook. 60 | 61 | # Testing 62 | 63 | Use pytest to run the unit checks: 64 | 65 | ```bash 66 | pytest 67 | ``` 68 | 69 | # Coverage 70 | 71 | Use pytest-cov to generate coverage reports: 72 | 73 | ```bash 74 | pytest --cov=ragged 75 | ``` 76 | 77 | # Building docs 78 | 79 | You can build the docs using: 80 | 81 | ```bash 82 | nox -s docs 83 | ``` 84 | 85 | You can see a preview with: 86 | 87 | ```bash 88 | nox -s docs -- --serve 89 | ``` 90 | 91 | # Pre-commit 92 | 93 | This project uses pre-commit for all style checking. While you can run it with 94 | nox, this is such an important tool that it deserves to be installed on its own. 95 | Install pre-commit and run: 96 | 97 | ```bash 98 | pre-commit run -a 99 | ``` 100 | 101 | to check all files.
102 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | # Maintain dependencies for GitHub Actions 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: "weekly" 8 | groups: 9 | actions: 10 | patterns: 11 | - "*" 12 | -------------------------------------------------------------------------------- /.github/matchers/pylint.json: -------------------------------------------------------------------------------- 1 | { 2 | "problemMatcher": [ 3 | { 4 | "severity": "warning", 5 | "pattern": [ 6 | { 7 | "regexp": "^([^:]+):(\\d+):(\\d+): ([A-DF-Z]\\d+): \\033\\[[\\d;]+m([^\\033]+).*$", 8 | "file": 1, 9 | "line": 2, 10 | "column": 3, 11 | "code": 4, 12 | "message": 5 13 | } 14 | ], 15 | "owner": "pylint-warning" 16 | }, 17 | { 18 | "severity": "error", 19 | "pattern": [ 20 | { 21 | "regexp": "^([^:]+):(\\d+):(\\d+): (E\\d+): \\033\\[[\\d;]+m([^\\033]+).*$", 22 | "file": 1, 23 | "line": 2, 24 | "column": 3, 25 | "code": 4, 26 | "message": 5 27 | } 28 | ], 29 | "owner": "pylint-error" 30 | } 31 | ] 32 | } 33 | -------------------------------------------------------------------------------- /.github/workflows/cd.yml: -------------------------------------------------------------------------------- 1 | name: CD 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | push: 7 | branches: 8 | - main 9 | release: 10 | types: 11 | - published 12 | 13 | concurrency: 14 | group: ${{ github.workflow }}-${{ github.ref }} 15 | cancel-in-progress: true 16 | 17 | env: 18 | FORCE_COLOR: 3 19 | 20 | jobs: 21 | dist: 22 | name: Distribution build 23 | runs-on: ubuntu-latest 24 | 25 | steps: 26 | - uses: actions/checkout@v4 27 | with: 28 | fetch-depth: 0 29 | 30 | - uses: hynek/build-and-inspect-python-package@v2 31 | 32 | publish: 33 | needs: [dist] 34 | name: Publish to PyPI 35 | 
environment: pypi 36 | permissions: 37 | id-token: write 38 | runs-on: ubuntu-latest 39 | if: github.event_name == 'release' && github.event.action == 'published' 40 | 41 | steps: 42 | - uses: actions/download-artifact@v4 43 | with: 44 | name: Packages 45 | path: dist 46 | 47 | - uses: pypa/gh-action-pypi-publish@release/v1 48 | if: github.event_name == 'release' && github.event.action == 'published' 49 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | push: 7 | branches: 8 | - main 9 | 10 | concurrency: 11 | group: ${{ github.workflow }}-${{ github.ref }} 12 | cancel-in-progress: true 13 | 14 | env: 15 | FORCE_COLOR: 3 16 | 17 | jobs: 18 | pre-commit: 19 | name: Format 20 | runs-on: ubuntu-latest 21 | steps: 22 | - uses: actions/checkout@v4 23 | with: 24 | fetch-depth: 0 25 | - uses: actions/setup-python@v5 26 | with: 27 | python-version: "3.x" 28 | - uses: pre-commit/action@v3.0.1 29 | with: 30 | extra_args: --hook-stage manual --all-files 31 | - name: Run PyLint 32 | run: | 33 | echo "::add-matcher::$GITHUB_WORKSPACE/.github/matchers/pylint.json" 34 | pipx run nox -s pylint 35 | 36 | checks: 37 | name: 38 | "py:${{ matrix.python-version }} np:${{ matrix.numpy-version }} os:${{ 39 | matrix.runs-on }}" 40 | runs-on: ${{ matrix.runs-on }} 41 | needs: [pre-commit] 42 | strategy: 43 | fail-fast: false 44 | matrix: 45 | python-version: ["3.9", "3.12"] 46 | numpy-version: ["latest"] 47 | runs-on: [ubuntu-latest, macos-latest, windows-latest] 48 | 49 | include: 50 | - python-version: "pypy-3.10" 51 | numpy-version: "latest" 52 | runs-on: ubuntu-latest 53 | - python-version: "3.9" 54 | numpy-version: "1.24.0" 55 | runs-on: ubuntu-latest 56 | 57 | steps: 58 | - uses: actions/checkout@v4 59 | with: 60 | fetch-depth: 0 61 | 62 | - uses: actions/setup-python@v5 63 | 
with: 64 | python-version: ${{ matrix.python-version }} 65 | allow-prereleases: true 66 | 67 | - name: Install uv 68 | run: python -m pip install --upgrade uv 69 | 70 | - name: Install old NumPy 71 | if: matrix.numpy-version != 'latest' 72 | run: uv pip install --system numpy==${{ matrix.numpy-version }} 73 | 74 | - name: Install package 75 | run: uv pip install --system '.[test]' 76 | 77 | - name: Print NumPy version 78 | run: python -c 'import numpy as np; print(np.__version__)' 79 | 80 | - name: Test package 81 | run: >- 82 | pytest -ra --cov --cov-report=xml --cov-report=term --durations=20 83 | 84 | - name: Upload coverage report 85 | uses: codecov/codecov-action@v5.4.0 86 | -------------------------------------------------------------------------------- /.github/workflows/lower-bound-requirements.yml: -------------------------------------------------------------------------------- 1 | name: Minimum supported dependencies 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - main 8 | workflow_dispatch: 9 | 10 | jobs: 11 | test: 12 | runs-on: ${{ matrix.os }} 13 | strategy: 14 | matrix: 15 | os: [ubuntu-latest] 16 | # minimum supported Python 17 | python-version: ["3.9"] 18 | 19 | steps: 20 | - uses: actions/checkout@v4 21 | 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | 27 | - name: Install dependencies and force lowest bound 28 | run: | 29 | python -m pip install uv 30 | uv pip install --system --upgrade ".[test]" 31 | uv pip install --system --upgrade --resolution lowest-direct . 
32 | 33 | - name: List installed Python packages 34 | run: uv pip list --system 35 | 36 | - name: Test with pytest 37 | run: pytest tests/ 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended 
to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # setuptools_scm 163 | src/*/_version.py 164 | 165 | # ruff 166 | .ruff_cache/ 167 | 168 | # OS specific stuff 169 | .DS_Store 170 | .DS_Store? 
171 | ._* 172 | .Spotlight-V100 173 | .Trashes 174 | ehthumbs.db 175 | Thumbs.db 176 | 177 | # Common editor files 178 | *~ 179 | *.swp 180 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autoupdate_commit_msg: "chore: update pre-commit hooks" 3 | autofix_commit_msg: "style: pre-commit fixes" 4 | 5 | repos: 6 | - repo: https://github.com/pre-commit/pre-commit-hooks 7 | rev: "v4.5.0" 8 | hooks: 9 | - id: check-added-large-files 10 | - id: check-case-conflict 11 | - id: check-merge-conflict 12 | - id: check-symlinks 13 | - id: check-yaml 14 | - id: debug-statements 15 | - id: end-of-file-fixer 16 | - id: mixed-line-ending 17 | - id: name-tests-test 18 | args: ["--pytest-test-first"] 19 | - id: requirements-txt-fixer 20 | - id: trailing-whitespace 21 | 22 | - repo: https://github.com/pre-commit/pygrep-hooks 23 | rev: "v1.10.0" 24 | hooks: 25 | - id: rst-backticks 26 | - id: rst-directive-colons 27 | - id: rst-inline-touching-normal 28 | 29 | - repo: https://github.com/pre-commit/mirrors-prettier 30 | rev: "v3.1.0" 31 | hooks: 32 | - id: prettier 33 | types_or: [yaml, markdown, html, css, scss, javascript, json] 34 | args: [--prose-wrap=always] 35 | 36 | - repo: https://github.com/astral-sh/ruff-pre-commit 37 | rev: "v0.1.9" 38 | hooks: 39 | - id: ruff 40 | args: ["--fix", "--show-fixes"] 41 | - id: ruff-format 42 | 43 | - repo: https://github.com/pre-commit/mirrors-mypy 44 | rev: "v1.7.1" 45 | hooks: 46 | - id: mypy 47 | files: src|tests 48 | args: [] 49 | additional_dependencies: 50 | - pytest 51 | 52 | - repo: https://github.com/codespell-project/codespell 53 | rev: "v2.2.6" 54 | hooks: 55 | - id: codespell 56 | 57 | - repo: https://github.com/shellcheck-py/shellcheck-py 58 | rev: "v0.9.0.6" 59 | hooks: 60 | - id: shellcheck 61 | 62 | - repo: https://github.com/abravalheri/validate-pyproject 63 | rev: v0.15 64 | 
hooks: 65 | - id: validate-pyproject 66 | 67 | - repo: https://github.com/python-jsonschema/check-jsonschema 68 | rev: 0.27.0 69 | hooks: 70 | - id: check-dependabot 71 | - id: check-github-workflows 72 | - id: check-readthedocs 73 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | version: 2 5 | 6 | build: 7 | os: ubuntu-22.04 8 | tools: 9 | python: "3.11" 10 | sphinx: 11 | configuration: docs/conf.py 12 | 13 | python: 14 | install: 15 | - method: pip 16 | path: . 17 | extra_requirements: 18 | - docs 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, Jim Pivarski. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the vector package developers nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 
19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Ragged 2 | 3 | [![Actions Status][actions-badge]][actions-link] 4 | [![PyPI version][pypi-version]][pypi-link] 5 | [![PyPI platforms][pypi-platforms]][pypi-link] 6 | [![GitHub Discussion][github-discussions-badge]][github-discussions-link] 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | [actions-badge]: https://github.com/jpivarski/ragged/workflows/CI/badge.svg 15 | [actions-link]: https://github.com/jpivarski/ragged/actions 16 | [conda-badge]: https://img.shields.io/conda/vn/conda-forge/ragged 17 | [conda-link]: https://github.com/conda-forge/ragged-feedstock 18 | [github-discussions-badge]: https://img.shields.io/static/v1?label=Discussions&message=Ask&color=blue&logo=github 19 | [github-discussions-link]: https://github.com/jpivarski/ragged/discussions 20 | [pypi-link]: https://pypi.org/project/ragged/ 21 | [pypi-platforms]: https://img.shields.io/pypi/pyversions/ragged 22 | [pypi-version]: https://img.shields.io/pypi/v/ragged 23 | [rtd-badge]: 
https://readthedocs.org/projects/ragged/badge/?version=latest 24 | [rtd-link]: https://ragged.readthedocs.io/en/latest/?badge=latest 25 | 26 | 27 | ## Introduction 28 | 29 | **Ragged** is a library for manipulating ragged arrays as though they were 30 | **NumPy** or **CuPy** arrays, following the 31 | [Array API specification](https://data-apis.org/array-api/latest/API_specification). 32 | 33 | For example, this is a 34 | [ragged/jagged array](https://en.wikipedia.org/wiki/Jagged_array): 35 | 36 | ```python 37 | >>> import ragged 38 | >>> a = ragged.array([[[1.1, 2.2, 3.3], []], [[4.4]], [], [[5.5, 6.6, 7.7, 8.8], [9.9]]]) 39 | >>> a 40 | ragged.array([ 41 | [[1.1, 2.2, 3.3], []], 42 | [[4.4]], 43 | [], 44 | [[5.5, 6.6, 7.7, 8.8], [9.9]] 45 | ]) 46 | ``` 47 | 48 | The values are all floating-point numbers, so `a.dtype` is `float64`, 49 | 50 | ```python 51 | >>> a.dtype 52 | dtype('float64') 53 | ``` 54 | 55 | but `a.shape` has non-integer dimensions to account for the fact that some of 56 | its list lengths are non-uniform: 57 | 58 | ```python 59 | >>> a.shape 60 | (4, None, None) 61 | ``` 62 | 63 | In general, a `ragged.array` can have any mixture of regular and irregular 64 | dimensions, though `shape[0]` (the length) is always an integer. This convention 65 | follows the **Array API**'s specification for 66 | [array.shape](https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.shape.html#array_api.array.shape), 67 | which must be a tuple of `int` or `None`: 68 | 69 | ```python 70 | array.shape: Tuple[Optional[int], ...] 71 | ``` 72 | 73 | (Our use of `None` to indicate a dimension without a single-valued size differs 74 | from the **Array API**'s intention of specifying dimensions of _unknown_ size, 75 | but it follows the technical specification. **Array API**-consuming libraries 76 | can try using **Ragged** to find out if they are ragged-ready.) 
77 | 78 | All of the normal elementwise and reducing functions apply, as well as slices: 79 | 80 | ```python 81 | >>> ragged.sqrt(a) 82 | ragged.array([ 83 | [[1.05, 1.48, 1.82], []], 84 | [[2.1]], 85 | [], 86 | [[2.35, 2.57, 2.77, 2.97], [3.15]] 87 | ]) 88 | 89 | >>> ragged.sum(a, axis=0) 90 | ragged.array([ 91 | [11, 8.8, 11, 8.8], 92 | [9.9] 93 | ]) 94 | 95 | >>> ragged.sum(a, axis=-1) 96 | ragged.array([ 97 | [6.6, 0], 98 | [4.4], 99 | [], 100 | [28.6, 9.9] 101 | ]) 102 | 103 | >>> a[-1, 0, 2] 104 | ragged.array(7.7) 105 | 106 | >>> a[a * 10 % 2 == 0] 107 | ragged.array([ 108 | [[2.2], []], 109 | [[4.4]], 110 | [], 111 | [[6.6, 8.8], []] 112 | ]) 113 | ``` 114 | 115 | All of the methods, attributes, and functions in the **Array API** will be 116 | implemented for **Ragged**, as well as conveniences that are not required by the 117 | **Array API**. See 118 | [open issues marked "todo"](https://github.com/jpivarski/ragged/issues?q=is%3Aissue+is%3Aopen+label%3Atodo) 119 | for **Array API** functions that still need to be written (out of 120 in total). 120 | 121 | **Ragged** has two `device` values, `"cpu"` (backed by **NumPy**) and `"cuda"` 122 | (backed by **CuPy**). Eventually, all operations will be identical for CPU and 123 | GPU. 124 | 125 | ## Implementation 126 | 127 | **Ragged** is implemented using **Awkward Array** 128 | ([code](https://github.com/scikit-hep/awkward), 129 | [docs](https://awkward-array.org/)), which is an array library for arbitrary 130 | tree-like (JSON-like) data. Because of its generality, **Awkward Array** cannot 131 | follow the **Array API**—in fact, its array objects can't have separate `dtype` 132 | and `shape` attributes (the array `type` can't be factorized). **Ragged** is 133 | therefore 134 | 135 | - a _specialization_ of **Awkward Array** for numeric data in fixed-length and 136 | variable-length lists, and 137 | - a _formalization_ to adhere to the **Array API** and its fully typed 138 | protocols. 
139 | 140 | See 141 | [Why does this library exist?](https://github.com/jpivarski/ragged/discussions/6) 142 | under the [Discussions](https://github.com/jpivarski/ragged/discussions) tab for 143 | more details. 144 | 145 | **Ragged** is a thin wrapper around **Awkward Array**, restricting it to ragged 146 | arrays and transforming its function arguments and return values to fit the 147 | specification. 148 | 149 | **Awkward Array**, in turn, is time- and memory-efficient, ready for big 150 | datasets. Consider the following: 151 | 152 | ```python 153 | import gc # control for garbage collection 154 | import psutil # measure process memory 155 | import time # measure time 156 | 157 | import math 158 | import ragged 159 | 160 | this_process = psutil.Process() 161 | 162 | def measure_memory(task): 163 | gc.collect() 164 | start_memory = this_process.memory_full_info().uss 165 | out = task() 166 | gc.collect() 167 | stop_memory = this_process.memory_full_info().uss 168 | print(f"memory: {(stop_memory - start_memory) * 1e-9:.3f} GB") 169 | return out 170 | 171 | def measure_time(task): 172 | gc.disable() 173 | start_time = time.perf_counter() 174 | out = task() 175 | stop_time = time.perf_counter() 176 | gc.enable() 177 | print(f"time: {stop_time - start_time:.3f} sec") 178 | return out 179 | 180 | def make_big_python_object(): 181 | out = [] 182 | for i in range(10000000): 183 | out.append([j * 1.1 for j in range(i % 10)]) 184 | return out 185 | 186 | def make_ragged_array(): 187 | return ragged.array(pyobj) 188 | 189 | def compute_on_python_object(): 190 | out = [] 191 | for row in pyobj: 192 | out.append([math.sqrt(x) for x in row]) 193 | return out 194 | 195 | def compute_on_ragged_array(): 196 | return ragged.sqrt(arr) 197 | ``` 198 | 199 | The `ragged.array` is 3 times smaller: 200 | 201 | ```python 202 | >>> pyobj = measure_memory(make_big_python_object) 203 | memory: 2.687 GB 204 | 205 | >>> arr = measure_memory(make_ragged_array) 206 | memory: 0.877 GB 207 | 
``` 208 | 209 | and a sample calculation on it (square root of each value) is 50 times faster: 210 | 211 | ```python 212 | >>> result = measure_time(compute_on_python_object) 213 | time: 4.180 sec 214 | 215 | >>> result = measure_time(compute_on_ragged_array) 216 | time: 0.082 sec 217 | ``` 218 | 219 | **Awkward Array** and **Ragged** are generally smaller and faster than their 220 | Python equivalents for the same reasons that **NumPy** is smaller and faster 221 | than Python lists. See **Awkward Array** 222 | [papers and presentations](https://awkward-array.org/doc/main/getting-started/papers-and-talks.html) 223 | for more. 224 | 225 | ## Installation 226 | 227 | **Ragged** is on PyPI: 228 | 229 | ```bash 230 | pip install ragged 231 | ``` 232 | 233 | and will someday be on conda-forge. 234 | 235 | `ragged` is a pure-Python library that only depends on `awkward` (which, in 236 | turn, only depends on `numpy` and a compiled extension). In principle (i.e. 237 | eventually), `ragged` can be loaded into Pyodide and JupyterLite. 238 | 239 | ## Acknowledgements 240 | 241 | Support for this work was provided by NSF grant 242 | [OAC-2103945](https://www.nsf.gov/awardsearch/showAward?AWD_ID=2103945) and the 243 | gracious help of 244 | [Awkward Array contributors](https://github.com/scikit-hep/awkward?tab=readme-ov-file#acknowledgements).
245 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import importlib.metadata 4 | 5 | project = "ragged" 6 | copyright = "2023, Jim Pivarski" 7 | author = "Jim Pivarski" 8 | version = release = importlib.metadata.version("ragged") 9 | 10 | extensions = [ 11 | "myst_parser", 12 | "sphinx.ext.autodoc", 13 | "sphinx.ext.intersphinx", 14 | "sphinx.ext.mathjax", 15 | "sphinx.ext.napoleon", 16 | "sphinx_autodoc_typehints", 17 | "sphinx_copybutton", 18 | ] 19 | 20 | source_suffix = [".rst", ".md"] 21 | exclude_patterns = [ 22 | "_build", 23 | "**.ipynb_checkpoints", 24 | "Thumbs.db", 25 | ".DS_Store", 26 | ".env", 27 | ".venv", 28 | ] 29 | 30 | html_theme = "furo" 31 | 32 | myst_enable_extensions = [ 33 | "colon_fence", 34 | ] 35 | 36 | intersphinx_mapping = { 37 | "python": ("https://docs.python.org/3", None), 38 | } 39 | 40 | nitpick_ignore = [ 41 | ("py:class", "_io.StringIO"), 42 | ("py:class", "_io.BytesIO"), 43 | ] 44 | 45 | always_document_param_types = True 46 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # ragged 2 | 3 | ```{toctree} 4 | :maxdepth: 2 5 | :hidden: 6 | 7 | ``` 8 | 9 | ```{include} ../README.md 10 | :start-after: 11 | ``` 12 | 13 | ## Indices and tables 14 | 15 | - {ref}`genindex` 16 | - {ref}`modindex` 17 | - {ref}`search` 18 | -------------------------------------------------------------------------------- /docs/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 16 | 39 | 41 | 46 | 50 | 54 | 58 | 62 | 66 | 70 | 74 | 78 | 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /noxfile.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import argparse 4 | import shutil 5 | from pathlib import Path 6 | 7 | import nox 8 | 9 | DIR = Path(__file__).parent.resolve() 10 | 11 | nox.options.sessions = ["lint", "pylint", "tests"] 12 | 13 | 14 | @nox.session 15 | def lint(session: nox.Session) -> None: 16 | """ 17 | Run the linter. 18 | """ 19 | session.install("pre-commit") 20 | session.run( 21 | "pre-commit", "run", "--all-files", "--show-diff-on-failure", *session.posargs 22 | ) 23 | 24 | 25 | @nox.session 26 | def pylint(session: nox.Session) -> None: 27 | """ 28 | Run PyLint. 29 | """ 30 | # This needs to be installed into the package environment, and is slower 31 | # than a pre-commit check 32 | session.install(".", "pylint") 33 | session.run("pylint", "ragged", *session.posargs) 34 | 35 | 36 | @nox.session 37 | def tests(session: nox.Session) -> None: 38 | """ 39 | Run the unit and regular tests. 40 | """ 41 | session.install(".[test]") 42 | session.run("pytest", *session.posargs) 43 | 44 | 45 | @nox.session(reuse_venv=True) 46 | def docs(session: nox.Session) -> None: 47 | """ 48 | Build the docs. Pass "--serve" to serve. Pass "-b linkcheck" to check links. 
49 | """ 50 | 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument("--serve", action="store_true", help="Serve after building") 53 | parser.add_argument( 54 | "-b", dest="builder", default="html", help="Build target (default: html)" 55 | ) 56 | args, posargs = parser.parse_known_args(session.posargs) 57 | 58 | if args.builder != "html" and args.serve: 59 | session.error("Must not specify non-HTML builder with --serve") 60 | 61 | extra_installs = ["sphinx-autobuild"] if args.serve else [] 62 | 63 | session.install("-e.[docs]", *extra_installs) 64 | session.chdir("docs") 65 | 66 | if args.builder == "linkcheck": 67 | session.run( 68 | "sphinx-build", "-b", "linkcheck", ".", "_build/linkcheck", *posargs 69 | ) 70 | return 71 | 72 | shared_args = ( 73 | "-n", # nitpicky mode 74 | "-T", # full tracebacks 75 | f"-b={args.builder}", 76 | ".", 77 | f"_build/{args.builder}", 78 | *posargs, 79 | ) 80 | 81 | if args.serve: 82 | session.run("sphinx-autobuild", *shared_args) 83 | else: 84 | session.run("sphinx-build", "--keep-going", *shared_args) 85 | 86 | 87 | @nox.session 88 | def build_api_docs(session: nox.Session) -> None: 89 | """ 90 | Build (regenerate) API docs. 91 | """ 92 | 93 | session.install("sphinx") 94 | session.chdir("docs") 95 | session.run( 96 | "sphinx-apidoc", 97 | "-o", 98 | "api/", 99 | "--module-first", 100 | "--no-toc", 101 | "--force", 102 | "../src/ragged", 103 | ) 104 | 105 | 106 | @nox.session 107 | def build(session: nox.Session) -> None: 108 | """ 109 | Build an SDist and wheel. 
110 | """ 111 | 112 | build_path = DIR.joinpath("build") 113 | if build_path.exists(): 114 | shutil.rmtree(build_path) 115 | 116 | session.install("build") 117 | session.run("python", "-m", "build") 118 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling", "hatch-vcs"] 3 | build-backend = "hatchling.build" 4 | 5 | 6 | [project] 7 | name = "ragged" 8 | authors = [ 9 | { name = "Jim Pivarski", email = "jpivarski@gmail.com" }, 10 | ] 11 | description = "Ragged array library, complying with Python API specification." 12 | readme = "README.md" 13 | license.file = "LICENSE" 14 | requires-python = ">=3.9" 15 | classifiers = [ 16 | "Development Status :: 1 - Planning", 17 | "Intended Audience :: Science/Research", 18 | "Intended Audience :: Developers", 19 | "License :: OSI Approved :: BSD License", 20 | "Operating System :: OS Independent", 21 | "Programming Language :: Python", 22 | "Programming Language :: Python :: 3", 23 | "Programming Language :: Python :: 3 :: Only", 24 | "Programming Language :: Python :: 3.9", 25 | "Programming Language :: Python :: 3.10", 26 | "Programming Language :: Python :: 3.11", 27 | "Programming Language :: Python :: 3.12", 28 | "Topic :: Scientific/Engineering", 29 | "Typing :: Typed", 30 | ] 31 | dynamic = ["version"] 32 | dependencies = [ 33 | "awkward>=2.6.7", 34 | "numpy>=1.24.0", 35 | ] 36 | 37 | [project.optional-dependencies] 38 | test = [ 39 | "pytest >=6", 40 | "pytest-cov >=3", 41 | ] 42 | dev = [ 43 | "pytest >=6", 44 | "pytest-cov >=3", 45 | ] 46 | docs = [ 47 | "sphinx>=7.0", 48 | "myst_parser>=0.13", 49 | "sphinx_copybutton", 50 | "sphinx_autodoc_typehints", 51 | "furo>=2023.08.17", 52 | ] 53 | 54 | [project.urls] 55 | Homepage = "https://github.com/jpivarski/ragged" 56 | "Bug Tracker" = "https://github.com/jpivarski/ragged/issues" 57 | Discussions = 
"https://github.com/jpivarski/ragged/discussions" 58 | Changelog = "https://github.com/jpivarski/ragged/releases" 59 | 60 | 61 | [tool.hatch] 62 | version.source = "vcs" 63 | build.hooks.vcs.version-file = "src/ragged/_version.py" 64 | 65 | [tool.hatch.env.default] 66 | features = ["test"] 67 | scripts.test = "pytest {args}" 68 | 69 | 70 | [tool.pytest.ini_options] 71 | minversion = "6.0" 72 | addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"] 73 | xfail_strict = true 74 | filterwarnings = [ 75 | "error", 76 | ] 77 | log_cli_level = "INFO" 78 | testpaths = [ 79 | "tests", 80 | ] 81 | 82 | 83 | [tool.coverage] 84 | run.source = ["ragged"] 85 | report.exclude_also = [ 86 | '\.\.\.', 87 | 'if typing.TYPE_CHECKING:', 88 | ] 89 | 90 | [tool.mypy] 91 | files = ["src", "tests"] 92 | python_version = "3.9" 93 | warn_unused_configs = true 94 | strict = true 95 | enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] 96 | warn_unreachable = true 97 | disallow_untyped_defs = false 98 | disallow_incomplete_defs = false 99 | 100 | [[tool.mypy.overrides]] 101 | module = "ragged.*" 102 | disallow_untyped_defs = true 103 | disallow_incomplete_defs = true 104 | 105 | [[tool.mypy.overrides]] 106 | module = "numpy.*" 107 | ignore_missing_imports = true 108 | 109 | [[tool.mypy.overrides]] 110 | module = "cupy.*" 111 | ignore_missing_imports = true 112 | 113 | [[tool.mypy.overrides]] 114 | module = "awkward.*" 115 | ignore_missing_imports = true 116 | 117 | [tool.ruff] 118 | src = ["src"] 119 | 120 | [tool.ruff.lint] 121 | extend-select = [ 122 | "B", # flake8-bugbear 123 | "I", # isort 124 | "ARG", # flake8-unused-arguments 125 | "C4", # flake8-comprehensions 126 | "EM", # flake8-errmsg 127 | "ICN", # flake8-import-conventions 128 | "G", # flake8-logging-format 129 | "PGH", # pygrep-hooks 130 | "PIE", # flake8-pie 131 | "PL", # pylint 132 | "PT", # flake8-pytest-style 133 | "PTH", # flake8-use-pathlib 134 | "RET", # flake8-return 135 | 
"RUF", # Ruff-specific 136 | "SIM", # flake8-simplify 137 | "T20", # flake8-print 138 | "UP", # pyupgrade 139 | "YTT", # flake8-2020 140 | "EXE", # flake8-executable 141 | "NPY", # NumPy specific rules 142 | "PD", # pandas-vet 143 | ] 144 | ignore = [ 145 | "PLR09", # Too many <...> 146 | "PLR2004", # Magic value used in comparison 147 | "ISC001", # Conflicts with formatter 148 | "RET505", # I like my if (return) elif (return) else (return) pattern 149 | "PLR5501", # I like my if (return) elif (return) else (return) pattern 150 | "RET506", # I like my if (raise) elif ... else ... pattern 151 | ] 152 | isort.required-imports = ["from __future__ import annotations"] 153 | # Uncomment if using a _compat.typing backport 154 | # typing-modules = ["ragged._compat.typing"] 155 | 156 | [tool.ruff.lint.per-file-ignores] 157 | "tests/**" = ["T20"] 158 | "noxfile.py" = ["T20"] 159 | 160 | 161 | [tool.pylint] 162 | py-version = "3.9" 163 | ignore-paths = [".*/_version.py"] 164 | reports.output-format = "colorized" 165 | similarities.ignore-imports = "yes" 166 | messages_control.disable = [ 167 | "design", 168 | "fixme", 169 | "line-too-long", 170 | "missing-module-docstring", 171 | "wrong-import-position", 172 | "missing-class-docstring", 173 | "missing-function-docstring", 174 | "R1705", # I like my if (return) elif (return) else (return) pattern 175 | "R1720", # I like my if (raise) elif ... else ... pattern 176 | "R0801", # Different files can have similar lines; that's okay 177 | "C0302", # I can have as many lines as I want; what's it with you? 178 | ] 179 | -------------------------------------------------------------------------------- /src/ragged/__init__.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | Ragged array module. 5 | 6 | FIXME: needs more documentation! 
7 | """ 8 | 9 | from __future__ import annotations 10 | 11 | from ._spec_array_object import array 12 | from ._spec_constants import ( 13 | e, 14 | inf, 15 | nan, 16 | newaxis, 17 | pi, 18 | ) 19 | from ._spec_creation_functions import ( 20 | arange, 21 | asarray, 22 | empty, 23 | empty_like, 24 | eye, 25 | from_dlpack, 26 | full, 27 | full_like, 28 | linspace, 29 | meshgrid, 30 | ones, 31 | ones_like, 32 | tril, 33 | triu, 34 | zeros, 35 | zeros_like, 36 | ) 37 | from ._spec_data_type_functions import ( 38 | astype, 39 | can_cast, 40 | finfo, 41 | iinfo, 42 | isdtype, 43 | result_type, 44 | ) 45 | from ._spec_elementwise_functions import ( # pylint: disable=W0622 46 | abs, 47 | acos, 48 | acosh, 49 | add, 50 | asin, 51 | asinh, 52 | atan, 53 | atan2, 54 | atanh, 55 | bitwise_and, 56 | bitwise_invert, 57 | bitwise_left_shift, 58 | bitwise_or, 59 | bitwise_right_shift, 60 | bitwise_xor, 61 | ceil, 62 | conj, 63 | cos, 64 | cosh, 65 | divide, 66 | equal, 67 | exp, 68 | expm1, 69 | floor, 70 | floor_divide, 71 | greater, 72 | greater_equal, 73 | imag, 74 | isfinite, 75 | isinf, 76 | isnan, 77 | less, 78 | less_equal, 79 | log, 80 | log1p, 81 | log2, 82 | log10, 83 | logaddexp, 84 | logical_and, 85 | logical_not, 86 | logical_or, 87 | logical_xor, 88 | multiply, 89 | negative, 90 | not_equal, 91 | positive, 92 | pow, 93 | real, 94 | remainder, 95 | round, 96 | sign, 97 | sin, 98 | sinh, 99 | sqrt, 100 | square, 101 | subtract, 102 | tan, 103 | tanh, 104 | trunc, 105 | ) 106 | from ._spec_indexing_functions import ( 107 | take, 108 | ) 109 | from ._spec_linear_algebra_functions import ( 110 | matmul, 111 | matrix_transpose, 112 | tensordot, 113 | vecdot, 114 | ) 115 | from ._spec_manipulation_functions import ( 116 | broadcast_arrays, 117 | broadcast_to, 118 | concat, 119 | expand_dims, 120 | flip, 121 | permute_dims, 122 | reshape, 123 | roll, 124 | squeeze, 125 | stack, 126 | ) 127 | from ._spec_searching_functions import ( 128 | argmax, 129 | argmin, 130 | nonzero, 
131 | where, 132 | ) 133 | from ._spec_set_functions import ( # pylint: disable=R0401 134 | unique_all, 135 | unique_counts, 136 | unique_inverse, 137 | unique_values, 138 | ) 139 | from ._spec_sorting_functions import ( 140 | argsort, 141 | sort, 142 | ) 143 | from ._spec_statistical_functions import ( # pylint: disable=W0622 144 | max, 145 | mean, 146 | min, 147 | prod, 148 | std, 149 | sum, 150 | var, 151 | ) 152 | from ._spec_utility_functions import ( # pylint: disable=W0622 153 | all, 154 | any, 155 | ) 156 | 157 | __array_api_version__ = "2022.12" 158 | 159 | __all__ = [ 160 | "__array_api_version__", 161 | # _spec_array_object 162 | "array", 163 | # _spec_constants 164 | "e", 165 | "inf", 166 | "nan", 167 | "newaxis", 168 | "pi", 169 | # _spec_creation_functions 170 | "arange", 171 | "asarray", 172 | "empty", 173 | "empty_like", 174 | "eye", 175 | "from_dlpack", 176 | "full", 177 | "full_like", 178 | "linspace", 179 | "meshgrid", 180 | "ones", 181 | "ones_like", 182 | "tril", 183 | "triu", 184 | "zeros", 185 | "zeros_like", 186 | # _spec_data_type_functions 187 | "astype", 188 | "can_cast", 189 | "finfo", 190 | "iinfo", 191 | "isdtype", 192 | "result_type", 193 | # _spec_elementwise_functions 194 | "abs", 195 | "acos", 196 | "acosh", 197 | "add", 198 | "asin", 199 | "asinh", 200 | "atan", 201 | "atan2", 202 | "atanh", 203 | "bitwise_and", 204 | "bitwise_left_shift", 205 | "bitwise_invert", 206 | "bitwise_or", 207 | "bitwise_right_shift", 208 | "bitwise_xor", 209 | "ceil", 210 | "conj", 211 | "cos", 212 | "cosh", 213 | "divide", 214 | "equal", 215 | "exp", 216 | "expm1", 217 | "floor", 218 | "floor_divide", 219 | "greater", 220 | "greater_equal", 221 | "imag", 222 | "isfinite", 223 | "isinf", 224 | "isnan", 225 | "less", 226 | "less_equal", 227 | "log", 228 | "log1p", 229 | "log2", 230 | "log10", 231 | "logaddexp", 232 | "logical_and", 233 | "logical_not", 234 | "logical_or", 235 | "logical_xor", 236 | "multiply", 237 | "negative", 238 | "not_equal", 239 | 
"positive", 240 | "pow", 241 | "real", 242 | "remainder", 243 | "round", 244 | "sign", 245 | "sin", 246 | "sinh", 247 | "square", 248 | "sqrt", 249 | "subtract", 250 | "tan", 251 | "tanh", 252 | "trunc", 253 | # _spec_indexing_functions 254 | "take", 255 | # _spec_linear_algebra_functions 256 | "matmul", 257 | "matrix_transpose", 258 | "tensordot", 259 | "vecdot", 260 | # _spec_manipulation_functions 261 | "broadcast_arrays", 262 | "broadcast_to", 263 | "concat", 264 | "expand_dims", 265 | "flip", 266 | "permute_dims", 267 | "reshape", 268 | "roll", 269 | "squeeze", 270 | "stack", 271 | # _spec_searching_functions 272 | "argmax", 273 | "argmin", 274 | "nonzero", 275 | "where", 276 | # _spec_set_functions 277 | "unique_all", 278 | "unique_counts", 279 | "unique_inverse", 280 | "unique_values", 281 | # _spec_sorting_functions 282 | "argsort", 283 | "sort", 284 | # _spec_statistical_functions 285 | "max", 286 | "mean", 287 | "min", 288 | "prod", 289 | "std", 290 | "sum", 291 | "var", 292 | # _spec_utility_functions 293 | "all", 294 | "any", 295 | ] 296 | -------------------------------------------------------------------------------- /src/ragged/_helper_functions.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | from __future__ import annotations 3 | 4 | import numpy as np 5 | 6 | 7 | def regularise_to_float(t: np.dtype, /) -> np.dtype: 8 | # Ensure compatibility with numpy 2.0.0 9 | if np.__version__ >= "2.1": 10 | # Just pass and return the input type if the numpy version is not 2.0.0 11 | return t 12 | 13 | if t in [np.int8, np.uint8, np.bool_, bool]: 14 | return np.float16 15 | elif t in [np.int16, np.uint16]: 16 | return np.float32 17 | elif t in [np.int32, np.uint32, np.int64, np.uint64]: 18 | return np.float64 19 | else: 20 | return t 21 | -------------------------------------------------------------------------------- 
/src/ragged/_import.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any 6 | 7 | import numpy as np 8 | 9 | from ._typing import Device 10 | 11 | 12 | def device_namespace(device: None | Device = None) -> tuple[Device, Any]: 13 | if device is None or device == "cpu": 14 | return "cpu", np 15 | elif device == "cuda": 16 | return "cuda", cupy() 17 | 18 | msg = f"unrecognized device: {device!r}" # type: ignore[unreachable] 19 | raise ValueError(msg) 20 | 21 | 22 | def cupy() -> Any: 23 | try: 24 | import cupy as cp # pylint: disable=C0415 25 | 26 | return cp 27 | except ModuleNotFoundError as err: 28 | error_message = """to use the "cuda" backend, you must install cupy: 29 | 30 | pip install cupy 31 | 32 | or 33 | 34 | conda install -c conda-forge cupy 35 | """ 36 | raise ModuleNotFoundError(error_message) from err 37 | -------------------------------------------------------------------------------- /src/ragged/_spec_constants.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/constants.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | from numpy import e, inf, nan, newaxis, pi 10 | 11 | __all__ = ["e", "inf", "nan", "newaxis", "pi"] 12 | -------------------------------------------------------------------------------- /src/ragged/_spec_creation_functions.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/creation_functions.html 5 | """ 6 | 7 | from __future__ import
annotations 8 | 9 | import enum 10 | 11 | import awkward as ak 12 | import numpy as np 13 | 14 | from . import _import 15 | from ._import import device_namespace 16 | from ._spec_array_object import _box, _unbox, array 17 | from ._typing import ( 18 | Device, 19 | Dtype, 20 | NestedSequence, 21 | SupportsBufferProtocol, 22 | SupportsDLPack, 23 | ) 24 | 25 | 26 | def arange( 27 | start: int | float, 28 | /, 29 | stop: None | int | float = None, 30 | step: int | float = 1, 31 | *, 32 | dtype: None | Dtype = None, 33 | device: None | Device = None, 34 | ) -> array: 35 | """ 36 | Returns evenly spaced values within the half-open interval `[start, stop)` 37 | as a one-dimensional array. 38 | 39 | Args: 40 | start: If `stop` is specified, the start of interval (inclusive); 41 | otherwise, the end of the interval (exclusive). If `stop` is not 42 | specified, the default starting value is 0. 43 | stop: The end of the interval. 44 | step: The distance between two adjacent elements `(out[i+1] - out[i])`. 45 | Must not be 0; may be negative, this results in an empty array if 46 | `stop >= start`. 47 | dtype: Output array data type. If dtype is `None`, the output array 48 | data type is inferred from `start`, `stop` and `step`. If those are 49 | all integers, the output array dtype is `np.int64`; if one or more 50 | have type `float`, then the output array dtype is `np.float64`. 51 | device: Device on which to place the created array. 52 | 53 | Returns: 54 | A one-dimensional array containing evenly spaced values. The length of 55 | the output array is `ceil((stop-start)/step)` if `stop - start` and 56 | `step` have the same sign, and length 0 otherwise. 
57 | 58 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.arange.html 59 | """ 60 | 61 | device, ns = device_namespace(device) 62 | return _box(array, ns.arange(start, stop, step, dtype=dtype)) 63 | 64 | 65 | def asarray( 66 | obj: ( 67 | array 68 | | ak.Array 69 | | bool 70 | | int 71 | | float 72 | | complex 73 | | NestedSequence[bool | int | float | complex] 74 | | SupportsBufferProtocol 75 | | SupportsDLPack 76 | ), 77 | dtype: None | Dtype | type | str = None, 78 | device: None | Device = None, 79 | copy: None | bool = None, 80 | ) -> array: 81 | """ 82 | Convert the input to an array. 83 | 84 | Args: 85 | obj: Object to be converted to an array. May be a Python scalar, a 86 | (possibly nested) sequence of Python scalars, or an object 87 | supporting the Python buffer protocol or DLPack. 88 | dtype: Output array data type. If `dtype` is `None`, the output array 89 | data type is inferred from the data type(s) in `obj`. If all input 90 | values are Python scalars, then, in order of precedence, 91 | - if all values are of type `bool`, the output data type is 92 | `bool`. 93 | - if all values are of type `int` or are a mixture of `bool` 94 | and `int`, the output data type is `np.int64`. 95 | - if one or more values are `complex` numbers, the output data 96 | type is `np.complex128`. 97 | - if one or more values are `float`s, the output data type is 98 | `np.float64`. 99 | device: Device on which to place the created array. If device is `None` 100 | and `obj` is an array, the output array device is inferred from 101 | `obj`. If `"cpu"`, the array is backed by NumPy and resides in main 102 | memory; if `"cuda"`, the array is backed by CuPy and resides in 103 | CUDA global memory. 104 | copy: Boolean indicating whether or not to copy the input. If `True`, 105 | this function always copies. 
If `False`, the function never copies 106 | for input which supports the buffer protocol and raises a 107 | ValueError in case a copy would be necessary. If `None`, the 108 | function reuses the existing memory buffer if possible and copies 109 | otherwise. 110 | 111 | Returns: 112 | An array containing the data from `obj`. 113 | 114 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.asarray.html 115 | """ 116 | 117 | return array(obj, dtype=dtype, device=device, copy=copy) 118 | 119 | 120 | def empty( 121 | shape: int | tuple[int, ...], 122 | *, 123 | dtype: None | Dtype = None, 124 | device: None | Device = None, 125 | ) -> array: 126 | """ 127 | Returns an uninitialized array having a specified shape. 128 | 129 | Args: 130 | shape: Output array shape. 131 | dtype: Output array data type. If `dtype` is `None`, the output array 132 | data type is `np.float64`. 133 | device: Device on which to place the created array. 134 | 135 | Returns: 136 | An array containing uninitialized data. 137 | 138 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.empty.html 139 | """ 140 | 141 | device, ns = device_namespace(device) 142 | return _box(array, ns.empty(shape, dtype=dtype)) 143 | 144 | 145 | def empty_like( 146 | x: array, /, *, dtype: None | Dtype = None, device: None | Device = None 147 | ) -> array: 148 | """ 149 | Returns an uninitialized array with the same shape as an input array x. 150 | 151 | Args: 152 | x: Input array from which to derive the output array shape. 153 | dtype: Output array data type. If `dtype` is `None`, the output array 154 | data type is inferred from `x`. 155 | device: Device on which to place the created array. If `device` is 156 | `None`, output array device is inferred from `x`. 157 | 158 | Returns: 159 | An array having the same shape as `x` and containing uninitialized data. 
160 | 161 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.empty_like.html 162 | """ 163 | 164 | (impl,) = _unbox(x) 165 | if isinstance(impl, ak.Array): 166 | return _box(type(x), ak.zeros_like(impl), dtype=dtype, device=device) 167 | else: 168 | _, ns = device_namespace(x.device if device is None else device) 169 | return _box(type(x), ns.empty_like(impl), dtype=dtype, device=device) 170 | 171 | 172 | def eye( 173 | n_rows: int, 174 | n_cols: None | int = None, 175 | /, 176 | *, 177 | k: int = 0, 178 | dtype: None | Dtype = None, 179 | device: None | Device = None, 180 | ) -> array: 181 | """ 182 | Returns a two-dimensional array with ones on the kth diagonal and zeros elsewhere. 183 | 184 | Args: 185 | n_rows: Number of rows in the output array. 186 | n_cols: Number of columns in the output array. If `None`, the default 187 | number of columns in the output array is equal to `n_rows`. 188 | k: Index of the diagonal. A positive value refers to an upper diagonal, 189 | a negative value to a lower diagonal, and 0 to the main diagonal. 190 | dtype: Output array data type. If `dtype` is `None`, the output array 191 | data type is `np.float64`. 192 | device: Device on which to place the created array. 193 | 194 | Returns: 195 | An array where all elements are equal to zero, except for the kth 196 | diagonal, whose values are equal to one. 197 | 198 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.eye.html 199 | """ 200 | 201 | device, ns = device_namespace(device) 202 | return _box(array, ns.eye(n_rows, n_cols, k, dtype=dtype)) 203 | 204 | 205 | def from_dlpack(x: object, /) -> array: 206 | """ 207 | Returns a new array containing the data from another (array) object with a `__dlpack__` method. 208 | 209 | Args: 210 | x: Input (array) object. 211 | 212 | Returns: 213 | An array containing the data in `x`. 
214 | 215 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.from_dlpack.html 216 | """ 217 | 218 | device_type, _ = x.__dlpack_device__() # type: ignore[attr-defined] 219 | if ( 220 | isinstance(device_type, enum.Enum) and device_type.value == 1 221 | ) or device_type == 1: 222 | y = np.from_dlpack(x) 223 | elif ( 224 | isinstance(device_type, enum.Enum) and device_type.value == 2 225 | ) or device_type == 2: 226 | cp = _import.cupy() 227 | y = cp.from_dlpack(x) 228 | else: 229 | msg = f"unsupported __dlpack_device__ type: {device_type}" 230 | raise TypeError(msg) 231 | 232 | return _box(array, y) 233 | 234 | 235 | def full( 236 | shape: int | tuple[int, ...], 237 | fill_value: bool | int | float | complex, 238 | *, 239 | dtype: None | Dtype = None, 240 | device: None | Device = None, 241 | ) -> array: 242 | """ 243 | Returns a new array having a specified shape and filled with fill_value. 244 | 245 | Args: 246 | shape: Output array shape. 247 | fill_value: Fill value. 248 | dtype: Output array data type. If `dtype` is `None`, the output array 249 | data type is inferred from `fill_value` according to the following 250 | rules: 251 | - if the fill value is an `int`, the output array data type is 252 | `np.int64`. 253 | - if the fill value is a `float`, the output array data type 254 | is `np.float64`. 255 | - if the fill value is a `complex` number, the output array 256 | data type is `np.complex128`. 257 | - if the fill value is a `bool`, the output array is 258 | `np.bool_`. 259 | device: Device on which to place the created array. 260 | 261 | Returns: 262 | An array where every element is equal to fill_value. 
263 | 264 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.full.html 265 | """ 266 | 267 | device, ns = device_namespace(device) 268 | return _box(array, ns.full(shape, fill_value, dtype=dtype)) 269 | 270 | 271 | def full_like( 272 | x: array, 273 | /, 274 | fill_value: bool | int | float | complex, 275 | *, 276 | dtype: None | Dtype = None, 277 | device: None | Device = None, 278 | ) -> array: 279 | """ 280 | Returns a new array filled with fill_value and having the same shape as an input array x. 281 | 282 | Args: 283 | x: Input array from which to derive the output array shape. 284 | fill_value: Fill value. 285 | dtype: Output array data type. If `dtype` is `None`, the output array 286 | data type is inferred from `x`. 287 | device: Device on which to place the created array. If `device` is 288 | `None`, the output array device is inferred from `x`. 289 | 290 | Returns: 291 | An array having the same shape as `x` and where every element is equal 292 | to `fill_value`. 293 | 294 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.full_like.html 295 | """ 296 | 297 | (impl,) = _unbox(x) 298 | if isinstance(impl, ak.Array): 299 | return _box(type(x), ak.full_like(impl, fill_value), dtype=dtype, device=device) 300 | else: 301 | _, ns = device_namespace(x.device if device is None else device) 302 | return _box(type(x), ns.full_like(impl, fill_value), dtype=dtype, device=device) 303 | 304 | 305 | def linspace( 306 | start: int | float | complex, 307 | stop: int | float | complex, 308 | /, 309 | num: int, 310 | *, 311 | dtype: None | Dtype = None, 312 | device: None | Device = None, 313 | endpoint: bool = True, 314 | ) -> array: 315 | r""" 316 | Returns evenly spaced numbers over a specified interval. 317 | 318 | Let `N` be the number of generated values (which is either `num` or `num+1` 319 | depending on whether `endpoint` is `True` or `False`, respectively). 
For 320 | real-valued output arrays, the spacing between values is given by 321 | 322 | $$\Delta_{\textrm{real}} = \frac{\textrm{stop} - \textrm{start}}{N - 1}$$ 323 | 324 | For complex output arrays, let `a = real(start)`, `b = imag(start)`, 325 | `c = real(stop)`, and `d = imag(stop)`. The spacing between complex values 326 | is given by 327 | 328 | $$\Delta_{\textrm{complex}} = \frac{c-a}{N-1} + \frac{d-b}{N-1} j$$ 329 | 330 | Args: 331 | start: The start of the interval. 332 | stop: The end of the interval. If `endpoint` is `False`, the function 333 | generates a sequence of `num+1` evenly spaced numbers starting with 334 | `start` and ending with `stop` and exclude the `stop` from the 335 | returned array such that the returned array consists of evenly 336 | spaced numbers over the half-open interval `[start, stop)`. If 337 | endpoint is `True`, the output array consists of evenly spaced 338 | numbers over the closed interval `[start, stop]`. 339 | num: Number of samples. Must be a nonnegative integer value. 340 | dtype: Output array data type. Should be a floating-point data type. 341 | If `dtype` is `None`, 342 | - if either `start` or `stop` is a `complex` number, the 343 | output data type is `np.complex128`. 344 | - if both `start` and `stop` are real-valued, the output data 345 | type is `np.float64`. 346 | device: Device on which to place the created array. 347 | endpoint: Boolean indicating whether to include `stop` in the interval. 348 | 349 | Returns: 350 | A one-dimensional array containing evenly spaced values. 351 | 352 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.linspace.html 353 | """ 354 | 355 | device, ns = device_namespace(device) 356 | return _box( 357 | array, ns.linspace(start, stop, num=num, endpoint=endpoint, dtype=dtype) 358 | ) 359 | 360 | 361 | def meshgrid(*arrays: array, indexing: str = "xy") -> list[array]: 362 | """ 363 | Returns coordinate matrices from coordinate vectors. 
364 | 365 | Args: 366 | arrays: An arbitrary number of one-dimensional arrays representing 367 | grid coordinates. Each array should have the same numeric data type. 368 | indexing: Cartesian `"xy"` or matrix `"ij"` indexing of output. If 369 | provided zero or one one-dimensional vector(s) (i.e., the zero- and 370 | one-dimensional cases, respectively), the `indexing` keyword has no 371 | effect and should be ignored. 372 | 373 | Returns: 374 | List of `N` arrays, where `N` is the number of provided one-dimensional 375 | input arrays. Each returned array must have rank `N`. For `N` 376 | one-dimensional arrays having lengths `Ni = len(xi)`, 377 | - if matrix indexing `"ij"`, then each returned array must have the 378 | shape `(N1, N2, N3, ..., Nn)`. 379 | - if Cartesian indexing `"xy"`, then each returned array must have 380 | shape `(N2, N1, N3, ..., Nn)`. 381 | 382 | Accordingly, for the two-dimensional case with input one-dimensional 383 | arrays of length `M` and `N`, if matrix indexing `"ij"`, then each 384 | returned array must have shape `(M, N)`, and, if Cartesian indexing 385 | `"xy"`, then each returned array must have shape `(N, M)`. 386 | 387 | Similarly, for the three-dimensional case with input one-dimensional 388 | arrays of length `M`, `N`, and `P`, if matrix indexing `"ij"`, then 389 | each returned array must have shape `(M, N, P)`, and, if Cartesian 390 | indexing `"xy"`, then each returned array must have shape `(N, M, P)`. 391 | 392 | Each returned array should have the same data type as the input arrays. 
393 | 394 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.meshgrid.html 395 | """ 396 | 397 | arrays # noqa: B018, pylint: disable=W0104 398 | indexing # noqa: B018, pylint: disable=W0104 399 | raise NotImplementedError("TODO 43") # noqa: EM101 400 | 401 | 402 | def ones( 403 | shape: int | tuple[int, ...], 404 | *, 405 | dtype: None | Dtype = None, 406 | device: None | Device = None, 407 | ) -> array: 408 | """ 409 | Returns a new array having a specified `shape` and filled with ones. 410 | 411 | Args: 412 | shape: Output array shape. 413 | dtype: Output array data type. If `dtype` is `None`, the output array 414 | data type is `np.float64`. 415 | device: Device on which to place the created array. 416 | 417 | Returns: 418 | An array containing ones. 419 | 420 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.ones.html 421 | """ 422 | 423 | device, ns = device_namespace(device) 424 | return _box(array, ns.ones(shape, dtype=dtype)) 425 | 426 | 427 | def ones_like( 428 | x: array, /, *, dtype: None | Dtype = None, device: None | Device = None 429 | ) -> array: 430 | """ 431 | Returns a new array filled with ones and having the same `shape` as an 432 | input array `x`. 433 | 434 | Args: 435 | x: Input array from which to derive the output array shape. 436 | dtype: Output array data type. If `dtype` is `None`, the output array 437 | data type is inferred from `x`. 438 | device: Device on which to place the created array. If `device` is 439 | `None`, the output array device is inferred from `x`. 440 | 441 | Returns: 442 | An array having the same shape as x and filled with ones. 
443 | 444 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.ones_like.html 445 | """ 446 | 447 | (impl,) = _unbox(x) 448 | if isinstance(impl, ak.Array): 449 | return _box(type(x), ak.ones_like(impl), dtype=dtype, device=device) 450 | else: 451 | _, ns = device_namespace(x.device if device is None else device) 452 | return _box(type(x), ns.ones_like(impl), dtype=dtype, device=device) 453 | 454 | 455 | def tril(x: array, /, *, k: int = 0) -> array: 456 | """ 457 | Returns the lower triangular part of a matrix (or a stack of matrices) `x`. 458 | 459 | Args: 460 | x: Input array having shape `(..., M, N)` and whose innermost two 461 | dimensions form `M` by `N` matrices. 462 | k: Diagonal above which to zero elements. If `k = 0`, the diagonal is 463 | the main diagonal. If `k < 0`, the diagonal is below the main 464 | diagonal. If `k > 0`, the diagonal is above the main diagonal. 465 | 466 | Returns: 467 | An array containing the lower triangular part(s). The returned array 468 | has the same shape and data type as `x`. All elements above the 469 | specified diagonal `k` are zero. The returned array is allocated on the 470 | same device as `x`. 471 | 472 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.tril.html 473 | """ 474 | 475 | x # noqa: B018, pylint: disable=W0104 476 | k # noqa: B018, pylint: disable=W0104 477 | raise NotImplementedError("TODO 46") # noqa: EM101 478 | 479 | 480 | def triu(x: array, /, *, k: int = 0) -> array: 481 | """ 482 | Returns the upper triangular part of a matrix (or a stack of matrices) `x`. 483 | 484 | Args: 485 | x: Input array having shape `(..., M, N)` and whose innermost two 486 | dimensions form `M` by `N` matrices. 487 | k: Diagonal below which to zero elements. If `k = 0`, the diagonal is 488 | the main diagonal. If `k < 0`, the diagonal is below the main 489 | diagonal. If `k > 0`, the diagonal is above the main diagonal.
490 | 491 | Returns: 492 | An array containing the upper triangular part(s). The returned array 493 | has the same shape and data type as `x`. All elements below the 494 | specified diagonal `k` are zero. The returned array is allocated on the 495 | same device as `x`. 496 | 497 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.triu.html 498 | """ 499 | 500 | x # noqa: B018, pylint: disable=W0104 501 | k # noqa: B018, pylint: disable=W0104 502 | raise NotImplementedError("TODO 47") # noqa: EM101 503 | 504 | 505 | def zeros( 506 | shape: int | tuple[int, ...], 507 | *, 508 | dtype: None | Dtype = None, 509 | device: None | Device = None, 510 | ) -> array: 511 | """ 512 | Returns a new array having a specified shape and filled with zeros. 513 | 514 | Args: 515 | shape: Output array shape. 516 | dtype: Output array data type. If `dtype` is `None`, the output array 517 | data type is `np.float64`. 518 | device: Device on which to place the created array. 519 | 520 | Returns: 521 | An array containing zeros. 522 | 523 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.zeros.html 524 | """ 525 | 526 | device, ns = device_namespace(device) 527 | return _box(array, ns.zeros(shape, dtype=dtype)) 528 | 529 | 530 | def zeros_like( 531 | x: array, /, *, dtype: None | Dtype = None, device: None | Device = None 532 | ) -> array: 533 | """ 534 | Returns a new array filled with zeros and having the same `shape` as an 535 | input array `x`. 536 | 537 | Args: 538 | x: Input array from which to derive the output array shape. 539 | dtype: Output array data type. If `dtype` is `None`, the output array 540 | data type is inferred from `x`. 541 | device: Device on which to place the created array. If `device` is 542 | `None`, the output array device is inferred from `x`. 543 | 544 | Returns: 545 | An array having the same shape as `x` and filled with zeros. 
546 | 547 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.zeros_like.html 548 | """ 549 | 550 | (impl,) = _unbox(x) 551 | if isinstance(impl, ak.Array): 552 | return _box(type(x), ak.zeros_like(impl), dtype=dtype, device=device) 553 | else: 554 | _, ns = device_namespace(x.device if device is None else device) 555 | return _box(type(x), ns.zeros_like(impl), dtype=dtype, device=device) 556 | -------------------------------------------------------------------------------- /src/ragged/_spec_data_type_functions.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/data_type_functions.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | from dataclasses import dataclass 10 | 11 | import numpy as np 12 | 13 | from ._spec_array_object import _box, _unbox, array 14 | from ._typing import Dtype 15 | 16 | _type = type 17 | 18 | 19 | def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array: 20 | """ 21 | Copies an array to a specified data type irrespective of type promotion rules. 22 | 23 | Args: 24 | x: Array to cast. 25 | dtype: Desired data type. 26 | copy: Ignored because `ragged.array` data buffers are immutable. 27 | 28 | Returns: 29 | An array having the specified data type. The returned array has the 30 | same `shape` as `x`. 31 | 32 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html 33 | """ 34 | 35 | copy # noqa: B018, argument is ignored, pylint: disable=W0104 36 | 37 | return _box(type(x), *_unbox(x), dtype=dtype) 38 | 39 | 40 | def can_cast(from_: Dtype | array, to: Dtype, /) -> bool: 41 | """ 42 | Determines if one data type can be cast to another data type according type 43 | promotion rules. 44 | 45 | Args: 46 | from: Input data type or array from which to cast. 
47 | to: Desired data type. 48 | 49 | Returns: 50 | `True` if the cast can occur according to type promotion rules; 51 | otherwise, `False`. 52 | 53 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.can_cast.html 54 | """ 55 | 56 | return bool(np.can_cast(from_, to)) 57 | 58 | 59 | @dataclass 60 | class finfo_object: # pylint: disable=C0103 61 | """ 62 | Output of `ragged.finfo` with the following attributes. 63 | 64 | - bits (int): number of bits occupied by the real-valued floating-point 65 | data type. 66 | - eps (float): difference between 1.0 and the next smallest representable 67 | real-valued floating-point number larger than 1.0 according to the 68 | IEEE-754 standard. 69 | - max (float): largest representable real-valued number. 70 | - min (float): smallest representable real-valued number. 71 | - smallest_normal (float): smallest positive real-valued floating-point 72 | number with full precision. 73 | - dtype (np.dtype): real-valued floating-point data type. 74 | 75 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.finfo.html 76 | """ 77 | 78 | bits: int 79 | eps: float 80 | max: float 81 | min: float 82 | smallest_normal: float 83 | dtype: np.dtype 84 | 85 | 86 | def finfo(type: Dtype | array, /) -> finfo_object: # pylint: disable=W0622 87 | """ 88 | Machine limits for floating-point data types. 89 | 90 | Args: 91 | type: the kind of floating-point data-type about which to get 92 | information. If complex, the information is about its component 93 | data type. 94 | 95 | Returns: 96 | An object having the following attributes: 97 | 98 | - bits (int): number of bits occupied by the real-valued floating-point 99 | data type. 100 | - eps (float): difference between 1.0 and the next smallest 101 | representable real-valued floating-point number larger than 1.0 102 | according to the IEEE-754 standard. 103 | - max (float): largest representable real-valued number. 
104 | - min (float): smallest representable real-valued number. 105 | - smallest_normal (float): smallest positive real-valued floating-point 106 | number with full precision. 107 | - dtype (np.dtype): real-valued floating-point data type. 108 | 109 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.finfo.html 110 | """ 111 | 112 | if not isinstance(type, np.dtype): 113 | if not isinstance(type, _type) and hasattr(type, "dtype"): 114 | out = np.finfo(type.dtype) 115 | else: 116 | out = np.finfo(np.dtype(type)) 117 | else: 118 | out = np.finfo(type) 119 | return finfo_object( 120 | out.bits, out.eps, out.max, out.min, out.smallest_normal, out.dtype 121 | ) 122 | 123 | 124 | @dataclass 125 | class iinfo_object: # pylint: disable=C0103 126 | """ 127 | Output of `ragged.iinfo` with the following attributes. 128 | 129 | - bits (int): number of bits occupied by the type. 130 | - max (int): largest representable number. 131 | - min (int): smallest representable number. 132 | - dtype (np.dtype): integer data type. 133 | 134 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.iinfo.html 135 | """ 136 | 137 | bits: int 138 | max: int 139 | min: int 140 | dtype: np.dtype 141 | 142 | 143 | def iinfo(type: Dtype | array, /) -> iinfo_object: # pylint: disable=W0622 144 | """ 145 | Machine limits for integer data types. 146 | 147 | Args: 148 | type: The kind of integer data-type about which to get information. 149 | 150 | Returns: 151 | An object having the following attributes: 152 | 153 | - bits (int): number of bits occupied by the type. 154 | - max (int): largest representable number. 155 | - min (int): smallest representable number. 156 | - dtype (np.dtype): integer data type. 
157 | 158 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.iinfo.html 159 | """ 160 | 161 | if not isinstance(type, np.dtype): 162 | if not isinstance(type, _type) and hasattr(type, "dtype"): 163 | out = np.iinfo(type.dtype) 164 | else: 165 | out = np.iinfo(np.dtype(type)) 166 | else: 167 | out = np.iinfo(type) 168 | return iinfo_object(out.bits, out.max, out.min, out.dtype) 169 | 170 | 171 | def isdtype(dtype: Dtype, kind: Dtype | str | tuple[Dtype | str, ...]) -> bool: 172 | """ 173 | Returns a boolean indicating whether a provided dtype is of a specified data type "kind". 174 | 175 | Args: 176 | dtype: The input dtype. 177 | kind: Data type kind. 178 | If `kind` is a `dtype`, the function returns a boolean indicating 179 | whether the input `dtype` is equal to the dtype specified by `kind`. 180 | 181 | If `kind` is a string, the function returns a boolean indicating 182 | whether the input `dtype` is of a specified data type kind. The 183 | following dtype kinds must be supported: 184 | 185 | - `"bool"`: boolean data types (e.g., bool). 186 | - `"signed integer"`: signed integer data types (e.g., `int8`, 187 | `int16`, `int32`, `int64`). 188 | - `"unsigned integer"`: unsigned integer data types (e.g., 189 | `uint8`, `uint16`, `uint32`, `uint64`). 190 | - `"integral"`: integer data types. Shorthand for 191 | (`"signed integer"`, `"unsigned integer"`). 192 | - `"real floating"`: real-valued floating-point data types 193 | (e.g., `float32`, `float64`). 194 | - `"complex floating"`: complex floating-point data types 195 | (e.g., `complex64`, `complex128`). 196 | - `"numeric"`: numeric data types. Shorthand for (`"integral"`, 197 | `"real floating"`, `"complex floating"`). 198 | 199 | If `kind` is a tuple, the tuple specifies a union of dtypes and/or 200 | kinds, and the function returns a boolean indicating whether the 201 | input `dtype` is either equal to a specified dtype or belongs to at 202 | least one specified data type kind. 
203 | 204 | Returns: 205 | Boolean indicating whether a provided dtype is of a specified data type 206 | kind. 207 | 208 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html 209 | """ 210 | 211 | dtype # noqa: B018, pylint: disable=W0104 212 | kind # noqa: B018, pylint: disable=W0104 213 | raise NotImplementedError("TODO 54") # noqa: EM101 214 | 215 | 216 | def result_type(*arrays_and_dtypes: array | Dtype) -> Dtype: 217 | """ 218 | Returns the dtype that results from applying the type promotion rules to 219 | the arguments. 220 | 221 | Args: 222 | arrays_and_dtypes: An arbitrary number of input arrays and/or dtypes. 223 | 224 | Returns: 225 | The dtype resulting from an operation involving the input arrays and dtypes. 226 | 227 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.result_type.html 228 | """ 229 | 230 | return np.result_type(*arrays_and_dtypes) 231 | -------------------------------------------------------------------------------- /src/ragged/_spec_indexing_functions.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/indexing_functions.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import awkward as ak 10 | import numpy as np 11 | 12 | from ._spec_array_object import _box, array 13 | 14 | 15 | def take(x: array, indices: array, /, *, axis: None | int = None) -> array: 16 | """ 17 | Returns elements of an array along an axis. 18 | 19 | Conceptually, `take(x, indices, axis=3)` is equivalent to 20 | `x[:,:,:,indices,...]`. 21 | 22 | Args: 23 | x: Input array. 24 | indices: Array indices. The array must be one-dimensional and have an 25 | integer data type. 26 | axis: Axis over which to select values. 
If `axis` is negative, the 27 | function determines the axis along which to select values by 28 | counting from the last dimension. 29 | 30 | If `x` is a one-dimensional array, providing an axis is optional; 31 | however, if `x` has more than one dimension, providing an `axis` is 32 | required. 33 | 34 | Returns: 35 | An array having the same data type as `x`. The output array has the 36 | same rank (i.e., number of dimensions) as `x` and has the same shape as 37 | `x`, except for the axis specified by `axis` whose size must equal the 38 | number of elements in indices. 39 | 40 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.take.html 41 | """ 42 | 43 | if axis is None: 44 | if x.ndim <= 1: 45 | axis = 0 46 | else: 47 | msg = f"for an {x.ndim}-dimensional array (greater than 1-dimensional), the 'axis' argument is required" 48 | raise TypeError(msg) 49 | 50 | original_axis = axis 51 | if axis < 0: 52 | axis += x.ndim + 1 53 | if not 0 <= axis < x.ndim: 54 | msg = f"axis {original_axis} is out of bounds for array of dimension {x.ndim}" 55 | raise ak.errors.AxisError(msg) 56 | 57 | toslice = x._impl # pylint: disable=W0212 58 | if not isinstance(toslice, ak.Array): 59 | toslice = ak.Array(toslice[np.newaxis]) # type: ignore[index] 60 | 61 | if not isinstance(indices, array): 62 | indices = array(indices) # type: ignore[unreachable] 63 | indexarray = indices._impl # pylint: disable=W0212 64 | 65 | slicer = (slice(None),) * axis + (indexarray,) 66 | return _box(type(x), toslice[slicer]) 67 | -------------------------------------------------------------------------------- /src/ragged/_spec_linear_algebra_functions.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/linear_algebra_functions.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 
| from collections.abc import Sequence 10 | 11 | from ._spec_array_object import array 12 | 13 | 14 | def matmul(x1: array, x2: array, /) -> array: 15 | """ 16 | Computes the matrix product. 17 | 18 | Args: 19 | x1: First input array. Must have at least one dimension. If `x1` is 20 | one-dimensional having shape `(M,)` and `x2` has more than one 21 | dimension, `x1` is promoted to a two-dimensional array by prepending 1 22 | to its dimensions (i.e., has shape `(1, M)`). After matrix 23 | multiplication, the prepended dimensions in the returned array are 24 | removed. If `x1` has more than one dimension (including after 25 | vector-to-matrix promotion), `shape(x1)[:-2]` is compatible with 26 | `shape(x2)[:-2]` (after vector-to-matrix promotion). If `x1` has shape 27 | `(..., M, K)`, the innermost two dimensions form matrices on which to 28 | perform matrix multiplication. 29 | x2: Second input array. Must have at least one dimension. If `x2` is 30 | one-dimensional having shape `(N,)` and `x1` has more than one 31 | dimension, `x2` is promoted to a two-dimensional array by appending 1 32 | to its dimensions (i.e., has shape `(N, 1)`). After matrix 33 | multiplication, the appended dimensions in the returned array are 34 | removed. If `x2` has more than one dimension (including after 35 | vector-to-matrix promotion), `shape(x2)[:-2]` is compatible with 36 | `shape(x1)[:-2]` (after vector-to-matrix promotion). If `x2` has shape 37 | `(..., K, N)`, the innermost two dimensions form matrices on which to 38 | perform matrix multiplication. 39 | 40 | Returns: 41 | If both `x1` and `x2` are one-dimensional arrays having shape `(N,)`, a 42 | zero-dimensional array containing the inner product as its only 43 | element. 44 | 45 | If `x1` is a two-dimensional array having shape `(M, K)` and `x2` is a 46 | two-dimensional array having shape `(K, N)`, a two-dimensional array 47 | containing the conventional matrix product and having shape `(M, N)`. 
48 | 49 | If `x1` is a one-dimensional array having shape `(K,)` and `x2` is an 50 | array having shape `(..., K, N)`, an array having shape `(..., N)` 51 | (i.e., prepended dimensions during vector-to-matrix promotion are 52 | removed) and containing the conventional matrix product. 53 | 54 | If `x1` is an array having shape `(..., M, K)` and `x2` is a 55 | one-dimensional array having shape `(K,)`, an array having shape 56 | `(..., M)` (i.e., appended dimensions during vector-to-matrix promotion 57 | are removed) and containing the conventional matrix product. 58 | 59 | If `x1` is a two-dimensional array having shape `(M, K)` and `x2` is an 60 | array having shape `(..., K, N)`, an array having shape `(..., M, N)` 61 | and containing the conventional matrix product for each stacked matrix. 62 | 63 | If `x1` is an array having shape `(..., M, K)` and `x2` is a 64 | two-dimensional array having shape `(K, N)`, an array having shape 65 | `(..., M, N)` and containing the conventional matrix product for each 66 | stacked matrix. 67 | 68 | If either `x1` or `x2` has more than two dimensions, an array having a 69 | shape determined by broadcasting `shape(x1)[:-2]` against 70 | `Shape(x2)[:-2]` and containing the conventional matrix product for 71 | each stacked matrix. 72 | 73 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.matmul.html 74 | """ 75 | 76 | x1 # noqa: B018, pylint: disable=W0104 77 | x2 # noqa: B018, pylint: disable=W0104 78 | raise NotImplementedError("TODO 110") # noqa: EM101 79 | 80 | 81 | def matrix_transpose(x: array, /) -> array: 82 | """ 83 | Transposes a matrix (or a stack of matrices) x. 84 | 85 | Args: 86 | x: Input array having shape `(..., M, N)` and whose innermost two 87 | dimensions form `M` by `N` matrices. 88 | 89 | Returns: 90 | An array containing the transpose for each matrix and having shape 91 | `(..., N, M)`. The returned array has the same data type as `x`. 
92 | 93 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.matrix_transpose.html 94 | """ 95 | 96 | x # noqa: B018, pylint: disable=W0104 97 | raise NotImplementedError("TODO 111") # noqa: EM101 98 | 99 | 100 | def tensordot( 101 | x1: array, x2: array, /, *, axes: int | tuple[Sequence[int], Sequence[int]] = 2 102 | ) -> array: 103 | """ 104 | Returns a tensor contraction of `x1` and `x2` over specific axes. 105 | 106 | The tensordot function corresponds to the generalized matrix product. 107 | 108 | Args: 109 | x1: First input array. 110 | x2: Second input array. Corresponding contracted axes of `x1` and `x2` 111 | must be equal. 112 | axes: Number of axes (dimensions) to contract or explicit sequences of 113 | axes (dimensions) for `x1` and `x2`, respectively. 114 | 115 | If `axes` is an `int` equal to `N`, then contraction is performed 116 | over the last `N` axes of `x1` and the first `N` axes of `x2` in 117 | order. The size of each corresponding axis (dimension) match. Must 118 | be nonnegative. 119 | 120 | If `N` equals 0, the result is the tensor (outer) product. 121 | 122 | If `N` equals 1, the result is the tensor dot product. 123 | 124 | If `N` equals 2, the result is the tensor double contraction. 125 | 126 | If `axes` is a tuple of two sequences `(x1_axes, x2_axes)`, the 127 | first sequence applies to `x1` and the second sequence to `x2`. 128 | Both sequences must have the same length. Each axis (dimension) 129 | `x1_axes[i]` for `x1` must have the same size as the respective 130 | axis (dimension) `x2_axes[i]` for `x2`. Each sequence must consist 131 | of unique (nonnegative) integers that specify valid axes for each 132 | respective array. 133 | 134 | Returns: 135 | An array containing the tensor contraction whose shape consists of the 136 | non-contracted axes (dimensions) of the first array `x1`, followed by 137 | the non-contracted axes (dimensions) of the second array `x2`. 
The 138 | returned array has a data type determined by type promotion rules. 139 | 140 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.tensordot.html 141 | """ 142 | 143 | x1 # noqa: B018, pylint: disable=W0104 144 | x2 # noqa: B018, pylint: disable=W0104 145 | axes # noqa: B018, pylint: disable=W0104 146 | raise NotImplementedError("TODO 112") # noqa: EM101 147 | 148 | 149 | def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: 150 | r""" 151 | Computes the (vector) dot product of two arrays. 152 | 153 | Let $\mathbf{a}$ be a vector in `x1` and $\mathbf{b}$ be a corresponding 154 | vector in `x2`. The dot product is defined as 155 | 156 | $$\mathbf{a} \cdot \mathbf{b} = \sum_{i=0}^{n-1} \overline{a_i}b_i$$ 157 | 158 | over the dimension specified by `axis` and where $n$ is the dimension size 159 | and $\overline{a_i}$ denotes the complex conjugate if $a_i$ is complex and 160 | the identity if $a_i$ is real-valued. 161 | 162 | Args: 163 | x1: First input array. 164 | x2: Second input array. Must be broadcastable with `x1` for all 165 | non-contracted axes. The size of the axis over which to compute the 166 | dot product is the same size as the respective axis in `x1`. 167 | 168 | The contracted axis (dimension) is not broadcasted. 169 | axis: Axis over which to compute the dot product. Must be an integer on 170 | the interval `[-N, N)`, where `N` is the rank (number of dimensions) of 171 | the shape determined by broadcasting. If specified as a negative 172 | integer, the function determines the axis along which to compute the 173 | dot product by counting backward from the last dimension (where `-1` 174 | refers to the last dimension). 
175 | 176 | Returns: 177 | If `x1` and `x2` are both one-dimensional arrays, a zero-dimensional 178 | containing the dot product; otherwise, a non-zero-dimensional array 179 | containing the dot products and having rank `N - 1`, where `N` is the 180 | rank (number of dimensions) of the shape determined by broadcasting 181 | along the non-contracted axes. The returned array has a data type 182 | determined by type promotion. 183 | 184 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.vecdot.html 185 | """ 186 | 187 | x1 # noqa: B018, pylint: disable=W0104 188 | x2 # noqa: B018, pylint: disable=W0104 189 | axis # noqa: B018, pylint: disable=W0104 190 | raise NotImplementedError("TODO 113") # noqa: EM101 191 | -------------------------------------------------------------------------------- /src/ragged/_spec_manipulation_functions.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/manipulation_functions.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import numbers 10 | 11 | import awkward as ak 12 | import numpy as np 13 | 14 | from ._spec_array_object import _box, _unbox, array 15 | 16 | 17 | def broadcast_arrays(*arrays: array) -> list[array]: 18 | """ 19 | Broadcasts one or more arrays against one another. 20 | 21 | Args: 22 | arrays: An arbitrary number of to-be broadcasted arrays. 23 | 24 | Returns: 25 | A list of broadcasted arrays. Each array has the same shape. Each array 26 | has the same dtype as its corresponding input array. 
27 | 28 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.broadcast_arrays.html 29 | """ 30 | 31 | impls = _unbox(*arrays) 32 | if all(not isinstance(x, ak.Array) for x in impls): 33 | return [_box(type(arrays[i]), x) for i, x in enumerate(impls)] 34 | else: 35 | out = [x if isinstance(x, ak.Array) else x.reshape((1,)) for x in impls] # type: ignore[union-attr] 36 | return [ 37 | _box(type(arrays[i]), x) for i, x in enumerate(ak.broadcast_arrays(*out)) 38 | ] 39 | 40 | 41 | def broadcast_to(x: array, /, shape: tuple[int, ...]) -> array: 42 | """ 43 | Broadcasts an array to a specified shape. 44 | 45 | Args: 46 | x: Array to broadcast. 47 | shape: Array shape. Must be compatible with `x`. If the array is 48 | incompatible with the specified shape, the function raises an 49 | exception. 50 | 51 | Returns: 52 | An array having a specified shape. Must have the same data type as x. 53 | 54 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.broadcast_to.html 55 | """ 56 | 57 | x # noqa: B018, pylint: disable=W0104 58 | shape # noqa: B018, pylint: disable=W0104 59 | raise NotImplementedError("TODO 115") # noqa: EM101 60 | 61 | 62 | def concat( 63 | arrays: tuple[array, ...] | list[array], /, *, axis: None | int = 0 64 | ) -> array: 65 | """ 66 | Joins a sequence of arrays along an existing axis. 67 | 68 | Args: 69 | arrays: Input arrays to join. The arrays must have the same shape, 70 | except in the dimension specified by `axis`. 71 | axis: Axis along which the arrays will be joined. If `axis` is `None`, 72 | arrays are flattened before concatenation. If `axis` is negative, 73 | the function determines the axis along which to join by counting 74 | from the last dimension. 75 | 76 | Returns: 77 | An output array containing the concatenated values. If the input arrays 78 | have different data types, normal type promotion rules apply. 
If the 79 | input arrays have the same data type, the output array has the same 80 | data type as the input arrays. 81 | 82 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.concat.html 83 | """ 84 | 85 | if len(arrays) == 0: 86 | msg = "need at least one array to concatenate" 87 | raise ValueError(msg) 88 | 89 | first = arrays[0] 90 | if not all(first.ndim == x.ndim for x in arrays[1:]): 91 | msg = "all the input arrays must have the same number of dimensions" 92 | raise ValueError(msg) 93 | 94 | if first.ndim == 0: 95 | msg = "zero-dimensional arrays cannot be concatenated" 96 | raise ValueError(msg) 97 | 98 | impls = _unbox(*arrays) 99 | assert all(isinstance(x, ak.Array) for x in impls) 100 | 101 | if axis is None: 102 | impls = [ak.ravel(x) for x in impls] # type: ignore[assignment] 103 | axis = 0 104 | 105 | return _box(type(first), ak.concatenate(impls, axis=axis)) 106 | 107 | 108 | def expand_dims(x: array, /, *, axis: int = 0) -> array: 109 | """ 110 | Expands the shape of an array by inserting a new axis (dimension) of size 111 | one at the position specified by `axis`. 112 | 113 | Args: 114 | x: Input array. 115 | axis: Axis position (zero-based). If `x` has rank (i.e, number of 116 | dimensions) `N`, a valid `axis` must reside on the closed-interval 117 | `[-N-1, N]`. If provided a negative axis, the axis position at 118 | which to insert a singleton dimension is computed as 119 | `N + axis + 1`. Hence, if provided -1, the resolved axis position 120 | is `N` (i.e., a singleton dimension is appended to the input array 121 | `x`). If provided `-N - 1`, the resolved axis position is 0 (i.e., 122 | a singleton dimension is prepended to the input array x). An 123 | `IndexError` exception is raised if provided an invalid axis 124 | position. 125 | 126 | Returns: 127 | An expanded output array having the same data type as `x`. 128 | 129 | This is the opposite of `ragged.squeeze`. 
130 | 131 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.expand_dims.html 132 | """ 133 | 134 | original_axis = axis 135 | if axis < 0: 136 | axis += x.ndim + 1 137 | if not 0 <= axis <= x.ndim: 138 | msg = ( 139 | f"axis {original_axis} is out of bounds for array of dimension {x.ndim + 1}" 140 | ) 141 | raise ak.errors.AxisError(msg) 142 | 143 | slicer = (slice(None),) * axis + (np.newaxis,) 144 | shape = x.shape[:axis] + (1,) + x.shape[axis:] 145 | 146 | out = x._impl[slicer] # type: ignore[index] # pylint: disable=W0212 147 | if not isinstance(out, ak.Array): 148 | out = ak.Array(out) 149 | 150 | return x._new(out, shape, x.dtype, x.device) # pylint: disable=W0212 151 | 152 | 153 | def flip(x: array, /, *, axis: None | int | tuple[int, ...] = None) -> array: 154 | """ 155 | Reverses the order of elements in an array along the given axis. The shape 156 | of the array is preserved. 157 | 158 | Args: 159 | x: Input array. 160 | axis: Axis (or axes) along which to flip. If `axis` is `None`, the 161 | function flips all input array axes. If `axis` is negative, the 162 | function counts from the last dimension. If provided more than one 163 | axis, the function flips only the specified axes. 164 | 165 | Returns: 166 | An output array having the same data type and shape as `x` and whose 167 | elements, relative to `x`, are reordered. 168 | 169 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.flip.html 170 | """ 171 | 172 | x # noqa: B018, pylint: disable=W0104 173 | axis # noqa: B018, pylint: disable=W0104 174 | raise NotImplementedError("TODO 118") # noqa: EM101 175 | 176 | 177 | def permute_dims(x: array, /, axes: tuple[int, ...]) -> array: 178 | """ 179 | Permutes the axes (dimensions) of an array `x`. 180 | 181 | Args: 182 | x: Input array. 183 | axes: Tuple containing a permutation of `(0, 1, ..., N-1)` where `N` is 184 | the number of axes (dimensions) of `x`. 
185 | 186 | Returns: 187 | An array containing the axes permutation. The returned array has the 188 | same data type as `x`. 189 | 190 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.permute_dims.html 191 | """ 192 | 193 | x # noqa: B018, pylint: disable=W0104 194 | axes # noqa: B018, pylint: disable=W0104 195 | raise NotImplementedError("TODO 119") # noqa: EM101 196 | 197 | 198 | def reshape(x: array, /, shape: tuple[int, ...], *, copy: None | bool = None) -> array: 199 | """ 200 | Reshapes an array without changing its data. 201 | 202 | Args: 203 | x: Input array to reshape. 204 | shape: A new shape compatible with the original shape. One shape 205 | dimension is allowed to be -1. When a shape dimension is -1, the 206 | corresponding output array shape dimension is inferred from the 207 | length of the array and the remaining dimensions. 208 | copy: Boolean indicating whether or not to copy the input array. If 209 | `True`, the function always copies. If `False`, the function never 210 | copies and raises a `ValueError` in case a copy would be necessary. 211 | If `None`, the function reuses the existing memory buffer if 212 | possible and copies otherwise. 213 | 214 | Returns: 215 | An output array having the same data type and elements as `x`. 216 | 217 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.reshape.html 218 | """ 219 | 220 | x # noqa: B018, pylint: disable=W0104 221 | shape # noqa: B018, pylint: disable=W0104 222 | copy # noqa: B018, pylint: disable=W0104 223 | raise NotImplementedError("TODO 120") # noqa: EM101 224 | 225 | 226 | def roll( 227 | x: array, 228 | /, 229 | shift: int | tuple[int, ...], 230 | *, 231 | axis: None | int | tuple[int, ...] = None, 232 | ) -> array: 233 | """ 234 | Rolls array elements along a specified axis. Array elements that roll 235 | beyond the last position are re-introduced at the first position. 
Array 236 | elements that roll beyond the first position are re-introduced at the last 237 | position. 238 | 239 | Args: 240 | x: Input array. 241 | shift: Number of places by which the elements are shifted. If `shift` 242 | is a tuple, then `axis` must be a tuple of the same size, and each 243 | of the given axes must be shifted by the corresponding element in 244 | `shift`. If `shift` is an `int` and `axis` a tuple, then the same 245 | shift is used for all specified axes. If a shift is positive, then 246 | array elements are shifted positively (toward larger indices) along 247 | the dimension of `axis`. If a `shift` is negative, then array 248 | elements are shifted negatively (toward smaller indices) along the 249 | dimension of `axis`. 250 | axis: Axis (or axes) along which elements to shift. If `axis` is 251 | `None`, the array is flattened, shifted, and then restored to its 252 | original shape. 253 | 254 | Returns: 255 | An output array having the same data type as `x` and whose elements, 256 | relative to `x`, are shifted. 257 | 258 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.roll.html 259 | """ 260 | 261 | x # noqa: B018, pylint: disable=W0104 262 | shift # noqa: B018, pylint: disable=W0104 263 | axis # noqa: B018, pylint: disable=W0104 264 | raise NotImplementedError("TODO 121") # noqa: EM101 265 | 266 | 267 | def squeeze(x: array, /, axis: int | tuple[int, ...]) -> array: 268 | """ 269 | Removes singleton dimensions (axes) from `x`. 270 | 271 | Args: 272 | x: Input array. 273 | axis: Axis (or axes) to squeeze. If a specified axis has a size 274 | greater than one, a `ValueError` is raised. 275 | 276 | Returns: 277 | An output array having the same data type and elements as `x`. 278 | 279 | This is the opposite of `ragged.expand_dims`. 
280 | 281 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.squeeze.html 282 | """ 283 | 284 | if isinstance(axis, numbers.Integral): 285 | axis = (axis,) # type: ignore[assignment] 286 | 287 | posaxis = [] 288 | for axisitem in axis: # type: ignore[union-attr] 289 | posaxisitem = axisitem + x.ndim if axisitem < 0 else axisitem 290 | if not 0 <= posaxisitem < x.ndim and not posaxisitem == x.ndim == 0: 291 | msg = f"axis {axisitem} is out of bounds for array of dimension {x.ndim}" 292 | raise ak.errors.AxisError(msg) 293 | posaxis.append(posaxisitem) 294 | 295 | if not isinstance(x._impl, ak.Array): # pylint: disable=W0212 296 | return x._new(x._impl, x._shape, x._dtype, x._device) # pylint: disable=W0212 297 | 298 | out = x._impl # pylint: disable=W0212 299 | shape = list(x.shape) 300 | for i, shapeitem in reversed(list(enumerate(x.shape))): 301 | if i in posaxis: 302 | if shapeitem is None: 303 | if not np.all(ak.num(out, axis=i) == 1): 304 | msg = "cannot select an axis to squeeze out which has size not equal to one" 305 | raise ValueError(msg) 306 | else: 307 | out = out[(slice(None),) * i + (0,)] 308 | del shape[i] 309 | 310 | elif shapeitem == 1: 311 | out = out[(slice(None),) * i + (0,)] 312 | del shape[i] 313 | 314 | else: 315 | msg = "cannot select an axis to squeeze out which has size not equal to one" 316 | raise ValueError(msg) 317 | 318 | return x._new(out, tuple(shape), x.dtype, x.device) # pylint: disable=W0212 319 | 320 | 321 | def stack(arrays: tuple[array, ...] | list[array], /, *, axis: int = 0) -> array: 322 | """ 323 | Joins a sequence of arrays along a new axis. 324 | 325 | Args: 326 | arrays: Input arrays to join. Each array must have the same shape. 327 | axis: Axis along which the arrays will be joined. Providing an `axis` 328 | specifies the index of the new axis in the dimensions of the 329 | result. 
For example, if `axis` is 0, the new axis will be the first 330 | dimension and the output array will have shape `(N, A, B, C)`; if 331 | `axis` is 1, the new axis will be the second dimension and the 332 | output array will have shape `(A, N, B, C)`; and, if `axis` is -1, 333 | the new axis will be the last dimension and the output array will 334 | have shape `(A, B, C, N)`. A valid axis must be on the interval 335 | `[-N, N)`, where `N` is the rank (number of dimensions) of `x`. 336 | If provided an `axis` outside of the required interval, the 337 | function raises an exception. 338 | 339 | Returns: 340 | An output array having rank `N + 1`, where `N` is the rank (number of 341 | dimensions) of `x`. If the input arrays have different data types, 342 | normal type promotion rules apply. If the input arrays have the same 343 | data type, the output array has the same data type as the input arrays. 344 | 345 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.stack.html 346 | """ 347 | 348 | arrays # noqa: B018, pylint: disable=W0104 349 | axis # noqa: B018, pylint: disable=W0104 350 | raise NotImplementedError("TODO 123") # noqa: EM101 351 | -------------------------------------------------------------------------------- /src/ragged/_spec_searching_functions.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/searching_functions.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import awkward as ak 10 | import numpy as np 11 | 12 | from ._import import device_namespace 13 | from ._spec_array_object import _box, _unbox, array 14 | 15 | 16 | def _remove_optiontype(x: ak.contents.Content) -> ak.contents.Content: 17 | if x.is_list: 18 | return x.copy(content=_remove_optiontype(x.content)) 19 | elif x.is_option: 20 | return x.content 21 | 
else: 22 | return x 23 | 24 | 25 | def argmax(x: array, /, *, axis: None | int = None, keepdims: bool = False) -> array: 26 | """ 27 | Returns the indices of the maximum values along a specified axis. 28 | 29 | When the maximum value occurs multiple times, only the indices 30 | corresponding to the first occurrence are returned. 31 | 32 | Args: 33 | x: Input array. 34 | axis: Axis along which to search. If `None`, the function returns the 35 | index of the maximum value of the flattened array. 36 | keepdims: If `True`, the reduced axes (dimensions) are included in the 37 | result as singleton dimensions, and, accordingly, the result is 38 | broadcastable with the input array. Otherwise, if `False`, the 39 | reduced axes (dimensions) are not included in the result. 40 | 41 | Returns: 42 | If `axis` is `None`, a zero-dimensional array containing the index of 43 | the first occurrence of the maximum value; otherwise, a 44 | non-zero-dimensional array containing the indices of the maximum 45 | values. The returned array has data type `np.int64`. 46 | 47 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.argmax.html 48 | """ 49 | 50 | out = np.argmax(*_unbox(x), axis=axis, keepdims=keepdims) 51 | 52 | if out is None: 53 | msg = "cannot compute argmax of an array with no data" 54 | raise ValueError(msg) 55 | 56 | if isinstance(out, ak.Array): 57 | if ak.any(ak.is_none(out, axis=-1)): 58 | msg = f"cannot compute argmax at axis={axis} because some lists at this depth have zero length" 59 | raise ValueError(msg) 60 | out = ak.Array( 61 | _remove_optiontype(out.layout), behavior=out.behavior, attrs=out.attrs 62 | ) 63 | 64 | return _box(type(x), out) 65 | 66 | 67 | def argmin(x: array, /, *, axis: None | int = None, keepdims: bool = False) -> array: 68 | """ 69 | Returns the indices of the minimum values along a specified axis. 
70 | 71 | When the minimum value occurs multiple times, only the indices 72 | corresponding to the first occurrence are returned. 73 | 74 | Args: 75 | x: Input array. 76 | axis: Axis along which to search. If `None`, the function returns the 77 | index of the minimum value of the flattened array. 78 | keepdims: If `True`, the reduced axes (dimensions) are included in the 79 | result as singleton dimensions, and, accordingly, the result is 80 | broadcastable with the input array. Otherwise, if `False`, the 81 | reduced axes (dimensions) are not included in the result. 82 | 83 | Returns: 84 | If `axis` is `None`, a zero-dimensional array containing the index of 85 | the first occurrence of the minimum value; otherwise, a 86 | non-zero-dimensional array containing the indices of the minimum 87 | values. The returned array has data type `np.int64`. 88 | 89 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.argmin.html 90 | """ 91 | 92 | out = np.argmin(*_unbox(x), axis=axis, keepdims=keepdims) 93 | 94 | if out is None: 95 | msg = "cannot compute argmin of an array with no data" 96 | raise ValueError(msg) 97 | 98 | if isinstance(out, ak.Array): 99 | if ak.any(ak.is_none(out, axis=-1)): 100 | msg = f"cannot compute argmin at axis={axis} because some lists at this depth have zero length" 101 | raise ValueError(msg) 102 | out = ak.Array( 103 | _remove_optiontype(out.layout), behavior=out.behavior, attrs=out.attrs 104 | ) 105 | 106 | return _box(type(x), out) 107 | 108 | 109 | def nonzero(x: array, /) -> tuple[array, ...]: 110 | """ 111 | Returns the indices of the array elements which are non-zero. 112 | 113 | Args: 114 | x: Input array. Must have a positive rank. If `x` is zero-dimensional, 115 | the function raises an exception. 
116 | 117 | Returns: 118 | A tuple of `k` arrays, one for each dimension of `x` and each of size 119 | `n` (where `n` is the total number of non-zero elements), containing 120 | the indices of the non-zero elements in that dimension. The indices 121 | are returned in row-major, C-style order. The returned array has data 122 | type `np.int64`. 123 | 124 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.nonzero.html 125 | """ 126 | 127 | (impl,) = _unbox(x) 128 | if not isinstance(impl, ak.Array): 129 | impl = ak.Array(impl.reshape((1,))) # type: ignore[union-attr] 130 | 131 | return tuple(_box(type(x), item) for item in ak.where(impl)) 132 | 133 | 134 | def where(condition: array, x1: array, x2: array, /) -> array: 135 | """ 136 | Returns elements chosen from `x1` or `x2` depending on `condition`. 137 | 138 | Args: 139 | condition: When `True`, yield `x1_i`; otherwise, yield `x2_i`. Must be 140 | broadcastable with `x1` and `x2`. 141 | x1: First input array. Must be broadcastable with `condition` and `x2`. 142 | x2: Second input array. Must be broadcastable with `condition` and 143 | `x1`. 144 | 145 | Returns: 146 | An array with elements from `x1` where condition is `True`, and 147 | elements from `x2` elsewhere. The returned array has a data type 148 | determined by type promotion rules with the arrays `x1` and `x2`. 
149 | 150 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.where.html 151 | """ 152 | 153 | if condition.ndim == x1.ndim == x2.ndim == 0: 154 | cond_impl, x1_impl, x2_impl = _unbox(condition, x1, x2) 155 | _, ns = device_namespace(condition.device) 156 | return _box(type(condition), ns.where(cond_impl, x1_impl, x2_impl)) 157 | 158 | else: 159 | cond_impl, x1_impl, x2_impl = _unbox(condition, x1, x2) 160 | if not isinstance(cond_impl, ak.Array): 161 | cond_impl = ak.Array(cond_impl.reshape((1,))) # type: ignore[union-attr] 162 | if not isinstance(x1_impl, ak.Array): 163 | x1_impl = ak.Array(x1_impl.reshape((1,))) # type: ignore[union-attr] 164 | if not isinstance(x2_impl, ak.Array): 165 | x2_impl = ak.Array(x2_impl.reshape((1,))) # type: ignore[union-attr] 166 | 167 | cond_impl, x1_impl, x2_impl = ak.broadcast_arrays(cond_impl, x1_impl, x2_impl) 168 | 169 | return _box(type(condition), ak.where(cond_impl, x1_impl, x2_impl)) 170 | -------------------------------------------------------------------------------- /src/ragged/_spec_set_functions.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/set_functions.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | from collections import namedtuple 10 | 11 | import awkward as ak 12 | import numpy as np 13 | 14 | import ragged 15 | 16 | from ._spec_array_object import array 17 | 18 | unique_all_result = namedtuple( # pylint: disable=C0103 19 | "unique_all_result", ["values", "indices", "inverse_indices", "counts"] 20 | ) 21 | 22 | 23 | def unique_all(x: array, /) -> tuple[array, array, array, array]: 24 | """ 25 | Returns the unique elements of an input array `x`, the first occurring 26 | indices for each unique element in `x`, the indices from the set of unique 27 | elements that reconstruct 
`x`, and the corresponding `counts` for each 28 | unique element in `x`. 29 | 30 | Args: 31 | x: Input array. If `x` has more than one dimension, the function 32 | flattens `x` and returns the unique elements of the flattened 33 | array. 34 | 35 | Returns: 36 | A namedtuple `(values, indices, inverse_indices, counts)` whose 37 | 38 | - first element has the field name `values` and must be an array 39 | containing the unique elements of `x`. The array has the same data 40 | type as `x`. 41 | - second element has the field name `indices` and is an array containing 42 | the indices (first occurrences) of `x` that result in values. The 43 | array has the same shape as `values` and has the default array index 44 | data type. 45 | - third element has the field name `inverse_indices` and is an array 46 | containing the indices of values that reconstruct `x`. The array has 47 | the same shape as `x` and has data type `np.int64`. 48 | - fourth element has the field name `counts` and is an array containing 49 | the number of times each unique element occurs in `x`. The returned 50 | array has same shape as `values` and has data type `np.int64`. 
51 | 52 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_all.html 53 | """ 54 | 55 | if isinstance(x, ragged.array): 56 | if x.ndim == 0: 57 | return unique_all_result( 58 | values=ragged.array(np.unique(x._impl, equal_nan=False)), # pylint: disable=W0212 59 | indices=ragged.array([0]), 60 | inverse_indices=ragged.array([0]), 61 | counts=ragged.array([1]), 62 | ) 63 | else: 64 | x_flat = ak.ravel(x._impl) # pylint: disable=W0212 65 | if isinstance(x_flat.layout, ak.contents.EmptyArray): # pylint: disable=E1101 66 | return unique_all_result( 67 | values=ragged.array(np.empty(0, x.dtype)), 68 | indices=ragged.array(np.empty(0, np.int64)), 69 | inverse_indices=ragged.array(np.empty(0, np.int64)), 70 | counts=ragged.array(np.empty(0, np.int64)), 71 | ) 72 | values, indices, inverse_indices, counts = np.unique( 73 | x_flat.layout.data, # pylint: disable=E1101 74 | return_index=True, 75 | return_inverse=True, 76 | return_counts=True, 77 | equal_nan=False, 78 | ) 79 | return unique_all_result( 80 | values=ragged.array(values), 81 | indices=ragged.array(indices), 82 | inverse_indices=ragged.array(inverse_indices), 83 | counts=ragged.array(counts), 84 | ) 85 | else: 86 | msg = f"Expected ragged type but got {type(x)}" # type: ignore[unreachable] 87 | raise TypeError(msg) 88 | 89 | 90 | unique_counts_result = namedtuple( # pylint: disable=C0103 91 | "unique_counts_result", ["values", "counts"] 92 | ) 93 | 94 | 95 | def unique_counts(x: array, /) -> tuple[array, array]: 96 | """ 97 | Returns the unique elements of an input array `x` and the corresponding 98 | counts for each unique element in `x`. 99 | 100 | Args: 101 | x: Input array. If `x` has more than one dimension, the function 102 | flattens `x` and returns the unique elements of the flattened 103 | array. 
104 | 105 | Returns: 106 | A namedtuple `(values, counts)` whose 107 | 108 | - first element has the field name `values` and is an array containing 109 | the unique elements of `x`. The array has the same data type as `x`. 110 | - second element has the field name `counts` and is an array containing 111 | the number of times each unique element occurs in `x`. The returned 112 | array has same shape as `values` and has data type `np.int64`. 113 | 114 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_counts.html 115 | """ 116 | if isinstance(x, ragged.array): 117 | if x.ndim == 0: 118 | return unique_counts_result( 119 | values=ragged.array(np.unique(x._impl, equal_nan=False)), # pylint: disable=W0212 120 | counts=ragged.array([1]), # pylint: disable=W0212 121 | ) 122 | else: 123 | x_flat = ak.ravel(x._impl) # pylint: disable=W0212 124 | if isinstance(x_flat.layout, ak.contents.EmptyArray): # pylint: disable=E1101 125 | return unique_counts_result( 126 | values=ragged.array(np.empty(0, x.dtype)), 127 | counts=ragged.array(np.empty(0, np.int64)), 128 | ) 129 | values, counts = np.unique( 130 | x_flat.layout.data, # pylint: disable=E1101 131 | return_counts=True, 132 | equal_nan=False, 133 | ) 134 | return unique_counts_result( 135 | values=ragged.array(values), counts=ragged.array(counts) 136 | ) 137 | else: 138 | msg = f"Expected ragged type but got {type(x)}" # type: ignore[unreachable] 139 | raise TypeError(msg) 140 | 141 | 142 | unique_inverse_result = namedtuple( # pylint: disable=C0103 143 | "unique_inverse_result", ["values", "inverse_indices"] 144 | ) 145 | 146 | 147 | def unique_inverse(x: array, /) -> tuple[array, array]: 148 | """ 149 | Returns the unique elements of an input array `x` and the indices from the 150 | set of unique elements that reconstruct `x`. 151 | 152 | Args: 153 | x: Input array. 
If `x` has more than one dimension, the function 154 | flattens `x` and returns the unique elements of the flattened 155 | array. 156 | 157 | Returns: 158 | A namedtuple `(values, inverse_indices)` whose 159 | 160 | - first element has the field name `values` and is an array containing 161 | the unique elements of `x`. The array has the same data type as `x`. 162 | - second element has the field name `inverse_indices` and is an array 163 | containing the indices of `values` that reconstruct `x`. The array 164 | has the same shape as `x` and data type `np.int64`. 165 | 166 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_inverse.html 167 | """ 168 | if isinstance(x, ragged.array): 169 | if x.ndim == 0: 170 | return unique_inverse_result( 171 | values=ragged.array(np.unique(x._impl, equal_nan=False)), # pylint: disable=W0212 172 | inverse_indices=ragged.array([0]), 173 | ) 174 | else: 175 | x_flat = ak.ravel(x._impl) # pylint: disable=W0212 176 | if isinstance(x_flat.layout, ak.contents.EmptyArray): # pylint: disable=E1101 177 | return unique_inverse_result( 178 | values=ragged.array(np.empty(0, x.dtype)), 179 | inverse_indices=ragged.array(np.empty(0, np.int64)), 180 | ) 181 | values, inverse_indices = np.unique( 182 | x_flat.layout.data, # pylint: disable=E1101 183 | return_inverse=True, 184 | equal_nan=False, 185 | ) 186 | 187 | return unique_inverse_result( 188 | values=ragged.array(values), 189 | inverse_indices=ragged.array(inverse_indices), 190 | ) 191 | else: 192 | msg = f"Expected ragged type but got {type(x)}" # type: ignore[unreachable] 193 | raise TypeError(msg) 194 | 195 | 196 | def unique_values(x: array, /) -> array: 197 | """ 198 | Returns the unique elements of an input array `x`. 199 | 200 | Args: 201 | x: Input array. If `x` has more than one dimension, the function 202 | flattens `x` and returns the unique elements of the flattened 203 | array. 
204 | 205 | Returns: 206 | An array containing the set of unique elements in `x`. The returned 207 | array has the same data type as `x`. 208 | 209 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_values.html 210 | """ 211 | if isinstance(x, ragged.array): 212 | if x.ndim == 0: 213 | return ragged.array(np.unique(x._impl, equal_nan=False)) # pylint: disable=W0212 214 | 215 | else: 216 | x_flat = ak.ravel(x._impl) # pylint: disable=W0212 217 | if isinstance(x_flat.layout, ak.contents.EmptyArray): # pylint: disable=E1101 218 | return ragged.array(np.empty(0, x.dtype)) 219 | return ragged.array(np.unique(x_flat.layout.data, equal_nan=False)) # pylint: disable=E1101 220 | else: 221 | err = f"Expected ragged type but got {type(x)}" # type: ignore[unreachable] 222 | raise TypeError(err) 223 | -------------------------------------------------------------------------------- /src/ragged/_spec_sorting_functions.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/sorting_functions.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import awkward as ak 10 | 11 | from ._spec_array_object import _box, _unbox, array 12 | 13 | 14 | def argsort( 15 | x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True 16 | ) -> array: 17 | """ 18 | Returns the indices that sort an array `x` along a specified axis. 19 | 20 | Args: 21 | x: Input array. 22 | axis: Axis along which to sort. If set to -1, the function sorts along 23 | the last axis. 24 | descending: Sort order. If `True`, the returned indices sort `x` in 25 | descending order (by value). If `False`, the returned indices sort 26 | `x` in ascending order (by value). 27 | stable: Sort stability. 
If `True`, the returned indices will maintain 28 | the relative order of `x` values which compare as equal. If 29 | `False`, the returned indices may or may not maintain the relative 30 | order of `x` values which compare as equal. 31 | 32 | Returns: 33 | An array of indices. The returned array has the same shape as `x`. 34 | The returned array has data type `np.int64`. 35 | 36 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.argsort.html 37 | """ 38 | 39 | (impl,) = _unbox(x) 40 | if not isinstance(impl, ak.Array): 41 | msg = f"axis {axis} is out of bounds for array of dimension 0" 42 | raise ak.errors.AxisError(msg) 43 | out = ak.argsort(impl, axis=axis, ascending=not descending, stable=stable) 44 | return _box(type(x), out) 45 | 46 | 47 | def sort( 48 | x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True 49 | ) -> array: 50 | """ 51 | Returns a sorted copy of an input array `x`. 52 | 53 | Args: 54 | x: Input array. 55 | axis: Axis along which to sort. If set to -1, the function sorts along 56 | the last axis. 57 | descending: Sort order. If `True`, the array is sorted in descending 58 | order (by value). If `False`, the array is sorted in ascending 59 | order (by value). 60 | stable: Sort stability. If `True`, the returned array will maintain the 61 | relative order of `x` values which compare as equal. If `False`, 62 | the returned array may or may not maintain the relative order of 63 | `x` values which compare as equal. 64 | 65 | Returns: 66 | A sorted array. The returned array has the same data type and shape as 67 | `x`. 
68 | 69 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.sort.html 70 | """ 71 | 72 | (impl,) = _unbox(x) 73 | if not isinstance(impl, ak.Array): 74 | msg = f"axis {axis} is out of bounds for array of dimension 0" 75 | raise ak.errors.AxisError(msg) 76 | out = ak.sort(impl, axis=axis, ascending=not descending, stable=stable) 77 | return _box(type(x), out) 78 | -------------------------------------------------------------------------------- /src/ragged/_spec_statistical_functions.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/statistical_functions.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import numbers 10 | 11 | import awkward as ak 12 | import numpy as np 13 | 14 | from ._spec_array_object import _box, _unbox, array 15 | from ._typing import Dtype 16 | 17 | 18 | def _regularize_axis( 19 | axis: None | int | tuple[int, ...], ndim: int 20 | ) -> None | tuple[int, ...]: 21 | if axis is None: 22 | return axis 23 | elif isinstance(axis, numbers.Integral): 24 | out = axis + ndim if axis < 0 else axis # type: ignore[operator] 25 | if not 0 <= out < ndim: 26 | msg = f"axis {axis} is out of bounds for an array with {ndim} dimensions" 27 | raise ak.errors.AxisError(msg) 28 | return out # type: ignore[no-any-return] 29 | else: 30 | out = [] 31 | for x in axis: # type: ignore[union-attr] 32 | out.append(x + ndim if x < 0 else x) 33 | if not 0 < out[-1] < ndim: 34 | msg = f"axis {x} is out of bounds for an array with {ndim} dimensions" 35 | if len(out) == 0: 36 | msg = "at least one axis must be specified" 37 | raise ak.errors.AxisError(msg) 38 | return tuple(sorted(out)) 39 | 40 | 41 | def _regularize_dtype(dtype: None | Dtype, array_dtype: Dtype) -> Dtype: 42 | if dtype is None: 43 | if array_dtype.kind in ("b", "i"): 44 | return 
np.dtype(np.int64) 45 | elif array_dtype.kind == "u": 46 | return np.dtype(np.uint64) 47 | elif array_dtype.kind == "f": 48 | return np.dtype(np.float64) 49 | elif array_dtype.kind == "c": 50 | return np.dtype(np.complex128) 51 | else: 52 | msg = f"unrecognized dtype.kind: {array_dtype.kind}" 53 | raise AssertionError(msg) 54 | else: 55 | return dtype 56 | 57 | 58 | def _ensure_dtype(data: array, dtype: Dtype) -> array: 59 | if data.dtype == dtype: 60 | return data 61 | else: 62 | (tmp,) = _unbox(data) 63 | if isinstance(tmp, ak.Array): 64 | return _box(type(data), ak.values_astype(tmp, dtype)) 65 | else: 66 | return _box(type(data), tmp.astype(dtype)) # type: ignore[union-attr] 67 | 68 | 69 | def max( # pylint: disable=W0622 70 | x: array, /, *, axis: None | int | tuple[int, ...] = None, keepdims: bool = False 71 | ) -> array: 72 | """ 73 | Calculates the maximum value of the input array `x`. 74 | 75 | Args: 76 | x: Input array. 77 | axis: Axis or axes along which maximum values are computed. By default, 78 | the maximum value is computed over the entire array. If a tuple of 79 | integers, maximum values must be computed over multiple axes. 80 | keepdims: If `True`, the reduced axes (dimensions) are included in the 81 | result as singleton dimensions, and, accordingly, the result is 82 | broadcastable with the input array. Otherwise, if `False`, the 83 | reduced axes (dimensions) are not included in the result. 84 | 85 | Returns: 86 | If the maximum value was computed over the entire array, a 87 | zero-dimensional array containing the maximum value; otherwise, a 88 | non-zero-dimensional array containing the maximum values. The returned 89 | array has the same data type as `x`. 
90 | 91 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.max.html 92 | """ 93 | 94 | axis = _regularize_axis(axis, x.ndim) 95 | 96 | if isinstance(axis, tuple): 97 | (out,) = _unbox(x) 98 | for axis_item in axis[::-1]: 99 | if isinstance(out, ak.Array): 100 | out = ak.max( 101 | out, axis=axis_item, keepdims=keepdims, mask_identity=False 102 | ) 103 | else: 104 | out = np.max(out, axis=axis_item, keepdims=keepdims) 105 | return _box(type(x), out) 106 | else: 107 | (tmp,) = _unbox(x) 108 | if isinstance(tmp, ak.Array): 109 | out = ak.max(tmp, axis=axis, keepdims=keepdims, mask_identity=False) 110 | else: 111 | out = np.max(tmp, axis=axis, keepdims=keepdims) 112 | return _box(type(x), out) 113 | 114 | 115 | def mean( 116 | x: array, /, *, axis: None | int | tuple[int, ...] = None, keepdims: bool = False 117 | ) -> array: 118 | """ 119 | Calculates the arithmetic mean of the input array `x`. 120 | 121 | Args: 122 | x: Input array. 123 | axis: Axis or axes along which arithmetic means are computed. By 124 | default, the mean is computed over the entire array. If a tuple of 125 | integers, arithmetic means are computed over multiple axes. 126 | keepdims: If `True`, the reduced axes (dimensions) are included in the 127 | result as singleton dimensions, and, accordingly, the result is 128 | broadcastable with the input array. Otherwise, if `False`, the 129 | reduced axes (dimensions) are not included in the result. 130 | 131 | Returns: 132 | If the arithmetic mean was computed over the entire array, a 133 | zero-dimensional array containing the arithmetic mean; otherwise, a 134 | non-zero-dimensional array containing the arithmetic means. The 135 | returned array has the same data type as `x`. 
136 | 137 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.mean.html 138 | """ 139 | 140 | axis = _regularize_axis(axis, x.ndim) 141 | 142 | if isinstance(axis, tuple): 143 | sumwx = np.sum(*_unbox(x), axis=axis[-1], keepdims=keepdims) 144 | sumw = ak.count(*_unbox(x), axis=axis[-1], keepdims=keepdims) 145 | for axis_item in axis[-2::-1]: 146 | sumwx = np.sum(sumwx, axis=axis_item, keepdims=keepdims) 147 | sumw = np.sum(sumw, axis=axis_item, keepdims=keepdims) 148 | else: 149 | sumwx = np.sum(*_unbox(x), axis=axis, keepdims=keepdims) 150 | sumw = ak.count(*_unbox(x), axis=axis, keepdims=keepdims) 151 | 152 | with np.errstate(invalid="ignore", divide="ignore"): 153 | return _ensure_dtype(_box(type(x), sumwx / sumw), x.dtype) 154 | 155 | 156 | def min( # pylint: disable=W0622 157 | x: array, /, *, axis: None | int | tuple[int, ...] = None, keepdims: bool = False 158 | ) -> array: 159 | """ 160 | Calculates the minimum value of the input array `x`. 161 | 162 | Args: 163 | x: Input array. 164 | axis: Axis or axes along which minimum values are computed. By default, 165 | the minimum value are computed over the entire array. If a tuple of 166 | integers, minimum values are computed over multiple axes. 167 | keepdims: If `True`, the reduced axes (dimensions) are included in the 168 | result as singleton dimensions, and, accordingly, the result is 169 | broadcastable with the input array. Otherwise, if `False`, the 170 | reduced axes (dimensions) are not included in the result. 171 | 172 | Returns: 173 | If the minimum value was computed over the entire array, a 174 | zero-dimensional array containing the minimum value; otherwise, a 175 | non-zero-dimensional array containing the minimum values. The returned 176 | array has the same data type as `x`. 
177 | 178 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.min.html 179 | """ 180 | 181 | axis = _regularize_axis(axis, x.ndim) 182 | 183 | if isinstance(axis, tuple): 184 | (out,) = _unbox(x) 185 | for axis_item in axis[::-1]: 186 | if isinstance(out, ak.Array): 187 | out = ak.min( 188 | out, axis=axis_item, keepdims=keepdims, mask_identity=False 189 | ) 190 | else: 191 | out = np.min(out, axis=axis_item, keepdims=keepdims) 192 | return _box(type(x), out) 193 | else: 194 | (tmp,) = _unbox(x) 195 | if isinstance(tmp, ak.Array): 196 | out = ak.min(tmp, axis=axis, keepdims=keepdims, mask_identity=False) 197 | else: 198 | out = np.min(tmp, axis=axis, keepdims=keepdims) 199 | return _box(type(x), out) 200 | 201 | 202 | def prod( 203 | x: array, 204 | /, 205 | *, 206 | axis: None | int | tuple[int, ...] = None, 207 | dtype: None | Dtype = None, 208 | keepdims: bool = False, 209 | ) -> array: 210 | """ 211 | Calculates the product of input array `x` elements. 212 | 213 | Args: 214 | x: Input array. 215 | axis: Axis or axes along which products are computed. By default, the 216 | product is computed over the entire array. If a tuple of integers, 217 | products are computed over multiple axes. 218 | dtype: Data type of the returned array. If `None`, 219 | 220 | - if the default data type corresponding to the data type "kind" 221 | a (integer, real-valued floating-point, or complex floating-point) 222 | of `x` has a smaller range of values than the data type of `x` 223 | (e.g., `x` has data type `int64` and the default data type is 224 | `int32`, or `x` has data type `uint64` and the default data type 225 | is `int64`), the returned array has the same data type as `x`. 226 | - if `x` has a real-valued floating-point data type, the returned 227 | array has the default real-valued floating-point data type. 228 | - if `x` has a complex floating-point data type, the returned array 229 | has data type `np.complex128`. 
230 | - if `x` has a signed integer data type (e.g., `int16`), the 231 | returned array has data type `np.int64`. 232 | - if `x` has an unsigned integer data type (e.g., `uint16`), the 233 | returned array has data type `np.uint64`. 234 | 235 | If the data type (either specified or resolved) differs from the 236 | data type of `x`, the input array will be cast to the specified 237 | data type before computing the product. 238 | 239 | keepdims: If `True`, the reduced axes (dimensions) are included in the 240 | result as singleton dimensions, and, accordingly, the result is 241 | broadcastable with the input array. Otherwise, if `False`, the 242 | reduced axes (dimensions) are not included in the result. 243 | 244 | Returns: 245 | If the product was computed over the entire array, a zero-dimensional 246 | array containing the product; otherwise, a non-zero-dimensional array 247 | containing the products. The returned array has a data type as 248 | described by the `dtype` parameter above. 249 | 250 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.prod.html 251 | """ 252 | 253 | axis = _regularize_axis(axis, x.ndim) 254 | dtype = _regularize_dtype(dtype, x.dtype) 255 | arr = _box(type(x), ak.values_astype(*_unbox(x), dtype)) if x.dtype == dtype else x 256 | 257 | if isinstance(axis, tuple): 258 | (out,) = _unbox(arr) 259 | for axis_item in axis[::-1]: 260 | out = np.prod(out, axis=axis_item, keepdims=keepdims) 261 | return _box(type(x), out) 262 | else: 263 | return _box(type(x), np.prod(*_unbox(arr), axis=axis, keepdims=keepdims)) 264 | 265 | 266 | def std( 267 | x: array, 268 | /, 269 | *, 270 | axis: None | int | tuple[int, ...] = None, 271 | correction: None | int | float = 0.0, 272 | keepdims: bool = False, 273 | ) -> array: 274 | """ 275 | Calculates the standard deviation of the input array `x`. 276 | 277 | Args: 278 | x: Input array. 279 | axis: Axis or axes along which standard deviations are computed. 
By 280 | default, the standard deviation is computed over the entire array. 281 | If a tuple of integers, standard deviations is computed over 282 | multiple axes. 283 | correction: Degrees of freedom adjustment. Setting this parameter to a 284 | value other than 0 has the effect of adjusting the divisor during 285 | the calculation of the standard deviation according to `N - c` 286 | where `N` corresponds to the total number of elements over which 287 | the standard deviation is computed and `c` corresponds to the 288 | provided degrees of freedom adjustment. When computing the standard 289 | deviation of a population, setting this parameter to 0 is the 290 | standard choice (i.e., the provided array contains data 291 | constituting an entire population). When computing the corrected 292 | sample standard deviation, setting this parameter to 1 is the 293 | standard choice (i.e., the provided array contains data sampled 294 | from a larger population; this is commonly referred to as Bessel's 295 | correction). 296 | keepdims: If `True`, the reduced axes (dimensions) are included in the 297 | result as singleton dimensions, and, accordingly, the result is 298 | broadcastable with the input array. Otherwise, if `False`, the 299 | reduced axes (dimensions) are not included in the result. 300 | 301 | Returns: 302 | If the standard deviation was computed over the entire array, a 303 | zero-dimensional array containing the standard deviation; otherwise, a 304 | non-zero-dimensional array containing the standard deviations. 305 | The returned array has the same data type as `x`. 306 | 307 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.std.html 308 | """ 309 | 310 | return _box( 311 | type(x), 312 | np.sqrt(*_unbox(var(x, axis=axis, correction=correction, keepdims=keepdims))), 313 | ) 314 | 315 | 316 | def sum( # pylint: disable=W0622 317 | x: array, 318 | /, 319 | *, 320 | axis: None | int | tuple[int, ...] 
= None, 321 | dtype: None | Dtype = None, 322 | keepdims: bool = False, 323 | ) -> array: 324 | """ 325 | Calculates the sum of the input array `x`. 326 | 327 | Args: 328 | x: Input array. 329 | axis: Axis or axes along which sums are computed. By default, the sum 330 | is computed over the entire array. If a tuple of integers, sums 331 | are computed over multiple axes. 332 | dtype: Data type of the returned array. If `None`, 333 | 334 | - if the default data type corresponding to the data type "kind" 335 | (integer, real-valued floating-point, or complex floating-point) 336 | of `x` has a smaller range of values than the data type of `x` 337 | (e.g., `x` has data type `int64` and the default data type is 338 | `int32`, or `x` has data type `uint64` and the default data type 339 | is `int64`), the returned array has the same data type as `x`. 340 | - if `x` has a real-valued floating-point data type, the returned 341 | array has the default real-valued floating-point data type. 342 | - if `x` has a complex floating-point data type, the returned array 343 | has data type `np.complex128`. 344 | - if `x` has a signed integer data type (e.g., `int16`), the 345 | returned array has data type `np.int64`. 346 | - if `x` has an unsigned integer data type (e.g., `uint16`), the 347 | returned array has data type `np.uint64`. 348 | 349 | If the data type (either specified or resolved) differs from the 350 | data type of `x`, the input array is cast to the specified data 351 | type before computing the sum. 352 | 353 | keepdims: If `True`, the reduced axes (dimensions) are included in the 354 | result as singleton dimensions, and, accordingly, the result is 355 | broadcastable with the input array. Otherwise, if `False`, the 356 | reduced axes (dimensions) are not included in the result. 357 | 358 | Returns: 359 | If the sum was computed over the entire array, a zero-dimensional array 360 | containing the sum; otherwise, an array containing the sums. 
The 361 | returned array must have a data type as described by the `dtype` 362 | parameter above. 363 | 364 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.sum.html 365 | """ 366 | 367 | axis = _regularize_axis(axis, x.ndim) 368 | dtype = _regularize_dtype(dtype, x.dtype) 369 | arr = _box(type(x), ak.values_astype(*_unbox(x), dtype)) if x.dtype == dtype else x 370 | 371 | if isinstance(axis, tuple): 372 | (out,) = _unbox(arr) 373 | for axis_item in axis[::-1]: 374 | out = np.sum(out, axis=axis_item, keepdims=keepdims) 375 | return _box(type(x), out) 376 | else: 377 | return _box(type(x), np.sum(*_unbox(arr), axis=axis, keepdims=keepdims)) 378 | 379 | 380 | def var( 381 | x: array, 382 | /, 383 | *, 384 | axis: None | int | tuple[int, ...] = None, 385 | correction: None | int | float = 0.0, 386 | keepdims: bool = False, 387 | ) -> array: 388 | """ 389 | Calculates the variance of the input array `x`. 390 | 391 | Args: 392 | x: Input array. 393 | axis: Axis or axes along which variances are computed. By default, the 394 | variance is computed over the entire array. If a tuple of integers, 395 | variances are computed over multiple axes. 396 | correction: Degrees of freedom adjustment. Setting this parameter to a 397 | value other than 0 has the effect of adjusting the divisor during 398 | the calculation of the variance according to `N - c` where `N` 399 | corresponds to the total number of elements over which the variance 400 | is computed and `c` corresponds to the provided degrees of freedom 401 | adjustment. When computing the variance of a population, setting 402 | this parameter to 0 is the standard choice (i.e., the provided 403 | array contains data constituting an entire population). When 404 | computing the unbiased sample variance, setting this parameter to 1 405 | is the standard choice (i.e., the provided array contains data 406 | sampled from a larger population; this is commonly referred to as 407 | Bessel's correction). 
408 | keepdims: If `True`, the reduced axes (dimensions) are included in the 409 | result as singleton dimensions, and, accordingly, the result is 410 | broadcastable with the input array. Otherwise, if `False`, the 411 | reduced axes (dimensions) are not included in the result. 412 | 413 | Returns: 414 | If the variance was computed over the entire array, a zero-dimensional 415 | array containing the variance; otherwise, a non-zero-dimensional array 416 | containing the variances. The returned array has the same data type as 417 | `x`. 418 | 419 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.var.html 420 | """ 421 | 422 | axis = _regularize_axis(axis, x.ndim) 423 | 424 | if isinstance(axis, tuple): 425 | sumwxx = np.sum(np.square(*_unbox(x)), axis=axis[-1], keepdims=keepdims) 426 | sumwx = np.sum(*_unbox(x), axis=axis[-1], keepdims=keepdims) 427 | sumw = ak.count(*_unbox(x), axis=axis[-1], keepdims=keepdims) 428 | for axis_item in axis[-2::-1]: 429 | sumwxx = np.sum(sumwxx, axis=axis_item, keepdims=keepdims) 430 | sumwx = np.sum(sumwx, axis=axis_item, keepdims=keepdims) 431 | sumw = np.sum(sumw, axis=axis_item, keepdims=keepdims) 432 | else: 433 | sumwxx = np.sum(np.square(*_unbox(x)), axis=axis, keepdims=keepdims) 434 | sumwx = np.sum(*_unbox(x), axis=axis, keepdims=keepdims) 435 | sumw = ak.count(*_unbox(x), axis=axis, keepdims=keepdims) 436 | 437 | with np.errstate(invalid="ignore", divide="ignore"): 438 | out = sumwxx / sumw - np.square(sumwx / sumw) 439 | if correction is not None and correction != 0: 440 | out *= sumw / (sumw - correction) 441 | return _ensure_dtype(_box(type(x), out), x.dtype) 442 | -------------------------------------------------------------------------------- /src/ragged/_spec_utility_functions.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | 
https://data-apis.org/array-api/latest/API_specification/utility_functions.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import numpy as np 10 | 11 | from ._spec_array_object import _box, _unbox, array 12 | from ._spec_statistical_functions import _regularize_axis 13 | 14 | 15 | def all( # pylint: disable=W0622 16 | x: array, /, *, axis: None | int | tuple[int, ...] = None, keepdims: bool = False 17 | ) -> array: 18 | """ 19 | Tests whether all input array elements evaluate to `True` along a specified 20 | axis. 21 | 22 | Args: 23 | x: Input array. 24 | axis: Axis or axes along which to perform a logical AND reduction. By 25 | default, a logical AND reduction is performed over the entire 26 | array. If a tuple of integers, logical AND reductions are performed 27 | over multiple axes. A valid `axis` must be an integer on the 28 | interval `[-N, N)`, where `N` is the rank (number of dimensions) of 29 | `x`. If an `axis` is specified as a negative integer, the function 30 | must determine the axis along which to perform a reduction by 31 | counting backward from the last dimension (where -1 refers to the 32 | last dimension). If provided an invalid `axis`, the function raises 33 | an exception. 34 | keepdims: If `True`, the reduced axes (dimensions) are included in the 35 | result as singleton dimensions, and, accordingly, the result is 36 | broadcastable with the input array. Otherwise, if `False`, the 37 | reduced axes (dimensions) are not included in the result. 38 | 39 | Returns: 40 | If a logical AND reduction was performed over the entire array, the 41 | returned array is a zero-dimensional array containing the test result; 42 | otherwise, the returned array is a non-zero-dimensional array 43 | containing the test results. The returned array has data type 44 | `np.bool_`. 
45 | 46 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.all.html 47 | """ 48 | 49 | axis = _regularize_axis(axis, x.ndim) 50 | 51 | if isinstance(axis, tuple): 52 | (out,) = _unbox(x) 53 | for axis_item in axis[::-1]: 54 | out = np.all(out, axis=axis_item, keepdims=keepdims) 55 | return _box(type(x), out) 56 | else: 57 | return _box(type(x), np.all(*_unbox(x), axis=axis, keepdims=keepdims)) 58 | 59 | 60 | def any( # pylint: disable=W0622 61 | x: array, /, *, axis: None | int | tuple[int, ...] = None, keepdims: bool = False 62 | ) -> array: 63 | """ 64 | Tests whether any input array element evaluates to True along a specified 65 | axis. 66 | 67 | Args: 68 | x: Input array. 69 | axis: Axis or axes along which to perform a logical OR reduction. By 70 | default, a logical OR reduction is performed over the entire array. 71 | If a tuple of integers, logical OR reductions aer performed over 72 | multiple axes. A valid `axis` must be an integer on the interval 73 | `[-N, N)`, where `N` is the rank (number of dimensions) of `x`. If 74 | an `axis` is specified as a negative integer, the function 75 | determines the axis along which to perform a reduction by counting 76 | backward from the last dimension (where -1 refers to the last 77 | dimension). If provided an invalid `axis`, the function raises an 78 | exception. 79 | keepdims: If `True`, the reduced axes (dimensions) aer included in the 80 | result as singleton dimensions, and, accordingly, the result is 81 | broadcastable with the input array. Otherwise, if `False`, the 82 | reduced axes (dimensions) are not included in the result. 83 | 84 | Returns: 85 | If a logical OR reduction was performed over the entire array, the 86 | returned array is a zero-dimensional array containing the test result; 87 | otherwise, the returned array is a non-zero-dimensional array 88 | containing the test results. The returned array has data type 89 | `np.bool_`. 
90 | 91 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.any.html 92 | """ 93 | 94 | axis = _regularize_axis(axis, x.ndim) 95 | 96 | if isinstance(axis, tuple): 97 | (out,) = _unbox(x) 98 | for axis_item in axis[::-1]: 99 | out = np.any(out, axis=axis_item, keepdims=keepdims) 100 | return _box(type(x), out) 101 | else: 102 | return _box(type(x), np.any(*_unbox(x), axis=axis, keepdims=keepdims)) 103 | -------------------------------------------------------------------------------- /src/ragged/_typing.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | Borrows liberally from https://github.com/numpy/numpy/blob/main/numpy/array_api/_typing.py 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import enum 10 | import sys 11 | from typing import Any, Literal, Optional, Protocol, TypeVar, Union 12 | 13 | import numpy as np 14 | 15 | T_co = TypeVar("T_co", covariant=True) 16 | 17 | 18 | if sys.version_info >= (3, 12): 19 | from collections.abc import ( # pylint: disable=W0611 20 | Buffer as SupportsBufferProtocol, 21 | ) 22 | else: 23 | SupportsBufferProtocol = Any 24 | 25 | 26 | # not actually checked because of https://github.com/python/typing/discussions/1145 27 | class NestedSequence(Protocol[T_co]): 28 | def __getitem__(self, key: int, /) -> T_co | NestedSequence[T_co]: 29 | ... 30 | 31 | def __len__(self, /) -> int: 32 | ... 33 | 34 | 35 | PyCapsule = Any 36 | 37 | 38 | class SupportsDLPack(Protocol): 39 | def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: 40 | ... 41 | 42 | def __dlpack_device__(self, /) -> tuple[enum.Enum, int]: 43 | ... 44 | 45 | 46 | Shape = tuple[Optional[int], ...] 
47 | 48 | Dtype = np.dtype[ 49 | Union[ 50 | np.bool_, 51 | np.int8, 52 | np.int16, 53 | np.int32, 54 | np.int64, 55 | np.uint8, 56 | np.uint16, 57 | np.uint32, 58 | np.uint64, 59 | np.float32, 60 | np.float64, 61 | np.complex64, 62 | np.complex128, 63 | ] 64 | ] 65 | 66 | numeric_types = ( 67 | np.bool_, 68 | np.int8, 69 | np.int16, 70 | np.int32, 71 | np.int64, 72 | np.uint8, 73 | np.uint16, 74 | np.uint32, 75 | np.uint64, 76 | np.float32, 77 | np.float64, 78 | np.complex64, 79 | np.complex128, 80 | ) 81 | 82 | Device = Literal["cpu", "cuda"] 83 | -------------------------------------------------------------------------------- /src/ragged/_version.pyi: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | version: str 4 | version_tuple: tuple[int, int, int] | tuple[int, int, int, str, str] 5 | -------------------------------------------------------------------------------- /src/ragged/io/__init__.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | from __future__ import annotations 4 | 5 | from .cf import from_cf_contiguous, from_cf_indexed, to_cf_contiguous, to_cf_indexed 6 | 7 | __all__ = ["to_cf_contiguous", "from_cf_contiguous", "to_cf_indexed", "from_cf_indexed"] 8 | -------------------------------------------------------------------------------- /src/ragged/io/cf.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | from __future__ import annotations 4 | 5 | import awkward as ak 6 | 7 | from .._import import device_namespace 8 | from .._spec_array_object import _box, _unbox, array 9 | 10 | 11 | def to_cf_contiguous(x: array) -> tuple[array, array]: 12 | if x.ndim != 2: 13 | raise NotImplementedError 14 | 15 | (y,) = _unbox(x) 16 | 17 | 
return _box(type(x), ak.flatten(y)), _box(type(x), ak.num(y)) 18 | 19 | 20 | def from_cf_contiguous(content: array, counts: array) -> array: 21 | if content.ndim != 1 or counts.ndim != 1: 22 | raise NotImplementedError 23 | 24 | cont, cnts = _unbox(content, counts) 25 | 26 | return _box(type(content), ak.unflatten(cont, cnts)) 27 | 28 | 29 | def to_cf_indexed(x: array) -> tuple[array, array]: 30 | if x.ndim != 2: 31 | raise NotImplementedError 32 | 33 | _, ns = device_namespace(x.device) 34 | (y,) = _unbox(x) 35 | 36 | index, _ = ak.broadcast_arrays(ns.arange(len(x), dtype=ns.int64), y) 37 | 38 | return _box(type(x), ak.flatten(y)), _box(type(x), ak.flatten(index)) 39 | 40 | 41 | def from_cf_indexed(content: array, index: array) -> array: 42 | if content.ndim != 1 or index.ndim != 1: 43 | raise NotImplementedError 44 | 45 | _, ns = device_namespace(content.device) 46 | cont, ind = _unbox(content, index) 47 | 48 | counts = ns.zeros(ak.max(ind) + 1, dtype=ns.int64) 49 | ns.add.at(counts, ns.asarray(ind), 1) 50 | 51 | return _box(type(content), ak.unflatten(cont[ns.argsort(ind)], counts)) # type: ignore[index] 52 | 53 | 54 | __all__ = ["to_cf_contiguous", "from_cf_contiguous", "to_cf_indexed", "from_cf_indexed"] 55 | -------------------------------------------------------------------------------- /src/ragged/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-hep/ragged/98709142fe1d0b51973e8cc82c5cfb717b96ee96/src/ragged/py.typed -------------------------------------------------------------------------------- /tests-cuda/test_cuda_spec_set_functions.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/set_functions.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import awkward as ak 10 
| import cupy as cp 11 | 12 | import ragged 13 | 14 | 15 | def test_existence(): 16 | assert ragged.unique_all is not None 17 | assert ragged.unique_counts is not None 18 | assert ragged.unique_inverse is not None 19 | assert ragged.unique_values is not None 20 | 21 | 22 | # unique_values tests 23 | def test_can_take_list(): 24 | arr = ragged.array(cp.array([1, 2, 4, 3, 4, 5, 6, 20])) 25 | expected_unique_values = ragged.array([1, 2, 3, 4, 5, 6, 20]) 26 | unique_values = ragged.unique_values(arr) 27 | assert ak.to_list(expected_unique_values) == ak.to_list(unique_values) 28 | 29 | 30 | def test_can_take_empty_arr(): 31 | arr = ragged.array(cp.array([])) 32 | expected_unique_values = ragged.array([]) 33 | unique_values = ragged.unique_values(arr) 34 | assert ak.to_list(expected_unique_values) == ak.to_list(unique_values) 35 | 36 | 37 | def test_can_take_moredimensions(): 38 | arr = ragged.array(ak.Array([[1, 2, 2, 3, 4], [5, 6]], backend="cuda")) 39 | expected_unique_values = ragged.array([1, 2, 3, 4, 5, 6]) 40 | unique_values = ragged.unique_values(arr) 41 | assert ak.to_list(expected_unique_values) == ak.to_list(unique_values) 42 | 43 | 44 | def test_can_take_1d_array(): 45 | arr = ragged.array(cp.array([5, 6, 7, 8, 8, 9, 1, 2, 3, 4, 10, 0, 15, 2])) 46 | expected_unique_values = ragged.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15]) 47 | assert ak.to_list(ragged.unique_values(arr)) == ak.to_list(expected_unique_values) 48 | 49 | 50 | # unique_counts tests 51 | def test_can_count_list(): 52 | arr = ragged.array(cp.array([1, 2, 4, 3, 4, 5, 6, 20])) 53 | expected_unique_values = ragged.array([1, 2, 3, 4, 5, 6, 20]) 54 | expected_unique_counts = ragged.array([1, 1, 1, 2, 1, 1, 1]) 55 | unique_values, unique_counts = ragged.unique_counts(arr) 56 | assert ak.to_list(unique_values) == ak.to_list(expected_unique_values) 57 | assert ak.to_list(unique_counts) == ak.to_list(expected_unique_counts) 58 | 59 | 60 | def test_can_count_empty_arr(): 61 | arr = 
ragged.array(cp.array([])) 62 | expected_unique_values = ragged.array([]) 63 | expected_counts = ragged.array([]) 64 | unique_values, unique_counts = ragged.unique_counts(arr) 65 | assert ak.to_list(expected_unique_values) == ak.to_list(unique_values) 66 | assert ak.to_list(expected_counts) == ak.to_list(unique_counts) 67 | 68 | 69 | def test_can_count_simple_array(): 70 | arr = ragged.array(cp.array([1, 2, 2, 3, 3, 3, 4, 4, 4, 4])) 71 | expected_unique_values = ragged.array([1, 2, 3, 4]) 72 | expected_counts = ragged.array([1, 2, 3, 4]) 73 | unique_values, unique_counts = ragged.unique_counts(arr) 74 | assert ak.to_list(unique_values) == ak.to_list(expected_unique_values) 75 | assert ak.to_list(unique_counts) == ak.to_list(expected_counts) 76 | 77 | 78 | def test_can_count_normal_array(): 79 | arr = ragged.array( 80 | ak.Array([[1, 2, 2], [3], [3, 3], [4, 4, 4], [4]], backend="cuda") 81 | ) 82 | expected_unique_values = ragged.array([1, 2, 3, 4]) 83 | expected_counts = ragged.array([1, 2, 3, 4]) 84 | unique_values, unique_counts = ragged.unique_counts(arr) 85 | assert ak.to_list(unique_values) == ak.to_list(expected_unique_values) 86 | assert ak.to_list(unique_counts) == ak.to_list(expected_counts) 87 | 88 | 89 | # unique_inverse tests 90 | def test_can_inverse_list(): 91 | arr = ragged.array(cp.array([1, 2, 4, 3, 4, 5, 6, 20])) 92 | expected_unique_values = ragged.array([1, 2, 3, 4, 5, 6, 20]) 93 | expected_inverse_indices = ragged.array([0, 1, 3, 2, 3, 4, 5, 6]) 94 | unique_values, inverse_indices = ragged.unique_inverse(arr) 95 | assert ak.to_list(unique_values) == ak.to_list(expected_unique_values) 96 | assert ak.to_list(inverse_indices) == ak.to_list(expected_inverse_indices) 97 | 98 | 99 | def test_can_inverse_empty_arr(): 100 | arr = ragged.array(cp.array([])) 101 | expected_unique_values = ragged.array([]) 102 | expected_inverse_indices = ragged.array([]) 103 | unique_values, inverse_indices = ragged.unique_inverse(arr) 104 | assert 
ak.to_list(unique_values) == ak.to_list(expected_unique_values) 105 | assert ak.to_list(inverse_indices) == ak.to_list(expected_inverse_indices) 106 | 107 | 108 | def test_can_inverse_simple_array(): 109 | arr = ragged.array(ak.Array([[1, 2, 2], [3, 3, 3], [4, 4, 4, 4]], backend="cuda")) 110 | expected_unique_values = ragged.array([1, 2, 3, 4]) 111 | expected_inverse_indices = ragged.array([0, 1, 1, 2, 2, 2, 3, 3, 3, 3]) 112 | unique_values, inverse_indices = ragged.unique_inverse(arr) 113 | assert ak.to_list(unique_values) == ak.to_list(expected_unique_values) 114 | assert ak.to_list(inverse_indices) == ak.to_list(expected_inverse_indices) 115 | 116 | 117 | def test_can_inverse_normal_array(): 118 | arr = ragged.array( 119 | ak.Array([[1, 2, 2], [3], [3, 3], [4, 4, 4], [4]], backend="cuda") 120 | ) 121 | expected_unique_values = ragged.array([1, 2, 3, 4]) 122 | expected_inverse_indices = ragged.array([0, 1, 1, 2, 2, 2, 3, 3, 3, 3]) 123 | unique_values, inverse_indices = ragged.unique_inverse(arr) 124 | assert ak.to_list(unique_values) == ak.to_list(expected_unique_values) 125 | assert ak.to_list(inverse_indices) == ak.to_list(expected_inverse_indices) 126 | 127 | 128 | # unique_all tests 129 | def test_can_all_list(): 130 | arr = ragged.array(cp.array([1, 2, 2, 3, 3, 3, 4, 4, 4, 4])) 131 | expected_unique_values = ragged.array([1, 2, 3, 4]) 132 | expected_unique_indices = ragged.array([0, 1, 3, 6]) 133 | expected_unique_inverse = ragged.array([0, 1, 1, 2, 2, 2, 3, 3, 3, 3]) 134 | expected_unique_counts = ragged.array([1, 2, 3, 4]) 135 | unique_values, unique_indices, unique_inverse, unique_counts = ragged.unique_all( 136 | arr 137 | ) 138 | assert ak.to_list(unique_values) == ak.to_list(expected_unique_values) 139 | assert ak.to_list(unique_indices) == ak.to_list(expected_unique_indices) 140 | assert ak.to_list(unique_inverse) == ak.to_list(expected_unique_inverse) 141 | assert ak.to_list(unique_counts) == ak.to_list(expected_unique_counts) 142 | 143 | 144 | def 
test_can_all_empty_arr(): 145 | arr = ragged.array(cp.array([])) 146 | expected_unique_values = ragged.array([]) 147 | expected_unique_indices = ragged.array([]) 148 | expected_unique_inverse = ragged.array([]) 149 | expected_unique_counts = ragged.array([]) 150 | unique_values, unique_indices, unique_inverse, unique_counts = ragged.unique_all( 151 | arr 152 | ) 153 | assert ak.to_list(unique_values) == ak.to_list(expected_unique_values) 154 | assert ak.to_list(unique_indices) == ak.to_list(expected_unique_indices) 155 | assert ak.to_list(unique_inverse) == ak.to_list(expected_unique_inverse) 156 | assert ak.to_list(unique_counts) == ak.to_list(expected_unique_counts) 157 | 158 | 159 | def test_can_all_normal_array(): 160 | arr = ragged.array(ak.Array([[2, 2, 2], [3], [3, 5], [4, 4, 4], [4]])) 161 | expected_unique_values = ragged.array([2, 3, 4, 5]) 162 | expected_unique_indices = ragged.array([0, 3, 6, 5]) 163 | expected_unique_inverse = ragged.array([0, 0, 0, 1, 1, 3, 2, 2, 2, 2]) 164 | expected_unique_counts = ragged.array([3, 2, 4, 1]) 165 | unique_values, unique_indices, unique_inverse, unique_counts = ragged.unique_all( 166 | arr 167 | ) 168 | assert ak.to_list(unique_values) == ak.to_list(expected_unique_values) 169 | assert ak.to_list(unique_indices) == ak.to_list(expected_unique_indices) 170 | assert ak.to_list(unique_inverse) == ak.to_list(expected_unique_inverse) 171 | assert ak.to_list(unique_counts) == ak.to_list(expected_unique_counts) 172 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | from __future__ import annotations 4 | 5 | import reprlib 6 | 7 | import awkward as ak 8 | import numpy as np 9 | import pytest 10 | 11 | import ragged 12 | 13 | 14 | @pytest.fixture(scope="session", autouse=True) 15 | def _patch_reprlib(): 
16 | if not hasattr(reprlib.Repr, "repr1_original"): 17 | 18 | def repr1(self, x, level): 19 | if isinstance(x, ragged.array): 20 | return self.repr_instance(x, level) 21 | return self.repr1_original(x, level) 22 | 23 | reprlib.Repr.repr1_original = reprlib.Repr.repr1 # type: ignore[attr-defined] 24 | reprlib.Repr.repr1 = repr1 # type: ignore[method-assign] 25 | 26 | 27 | @pytest.fixture(params=["regular", "irregular", "scalar"]) 28 | def x(request): 29 | if request.param == "regular": 30 | return ragged.array(np.array([1.0, 2.0, 3.0])) 31 | elif request.param == "irregular": 32 | return ragged.array(ak.Array([[1.1, 2.2, 3.3], [], [4.4, 5.5]])) 33 | else: # request.param == "scalar" 34 | return ragged.array(np.array(10.0)) 35 | 36 | 37 | @pytest.fixture(params=["regular", "irregular", "scalar"]) 38 | def x_lt1(request): 39 | if request.param == "regular": 40 | return ragged.array(np.array([0.1, 0.2, 0.3])) 41 | elif request.param == "irregular": 42 | return ragged.array(ak.Array([[0.1, 0.2, 0.3], [], [0.4, 0.5]])) 43 | else: # request.param == "scalar" 44 | return ragged.array(np.array(0.5)) 45 | 46 | 47 | @pytest.fixture(params=["regular", "irregular", "scalar"]) 48 | def x_bool(request): 49 | if request.param == "regular": 50 | return ragged.array(np.array([False, True, False])) 51 | elif request.param == "irregular": 52 | return ragged.array(ak.Array([[True, True, False], [], [False, False]])) 53 | else: # request.param == "scalar" 54 | return ragged.array(np.array(True)) 55 | 56 | 57 | @pytest.fixture(params=["regular", "irregular", "scalar"]) 58 | def x_int(request): 59 | if request.param == "regular": 60 | return ragged.array(np.array([0, 1, 2], dtype=np.int64)) 61 | elif request.param == "irregular": 62 | return ragged.array(ak.Array([[1, 2, 3], [], [4, 5]])) 63 | else: # request.param == "scalar" 64 | return ragged.array(np.array(10, dtype=np.int64)) 65 | 66 | 67 | @pytest.fixture(params=["regular", "irregular", "scalar"]) 68 | def x_complex(request): 69 | 
if request.param == "regular": 70 | return ragged.array(np.array([1 + 0.1j, 2 + 0.2j, 3 + 0.3j])) 71 | elif request.param == "irregular": 72 | return ragged.array(ak.Array([[1 + 0j, 2 + 0j, 3 + 0j], [], [4 + 0j, 5 + 0j]])) 73 | else: # request.param == "scalar" 74 | return ragged.array(np.array(10 + 1j)) 75 | 76 | 77 | y = x 78 | y_lt1 = x_lt1 79 | y_bool = x_bool 80 | y_int = x_int 81 | y_complex = x_complex 82 | -------------------------------------------------------------------------------- /tests/test_spec_array_object.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/array_object.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import numpy as np 10 | import pytest 11 | 12 | import ragged 13 | 14 | devices = ["cpu"] 15 | try: 16 | import cupy as cp 17 | 18 | devices.append("cuda") 19 | except ModuleNotFoundError: 20 | cp = None 21 | 22 | 23 | def test_existence(): 24 | assert ragged.array is not None 25 | 26 | 27 | def test_item(): 28 | a = ragged.array(np.asarray(123)).item() 29 | assert isinstance(a, int) 30 | assert a == 123 31 | 32 | a = ragged.array(np.asarray([123])).item() 33 | assert isinstance(a, int) 34 | assert a == 123 35 | 36 | a = ragged.array(np.asarray([[123]])).item() 37 | assert isinstance(a, int) 38 | assert a == 123 39 | 40 | 41 | def test_contains(): 42 | a = ragged.array([[1, 2, 3], [], [4, 5]]) 43 | assert 4 in a 44 | assert 6 not in a 45 | 46 | b = a[0, 0] 47 | assert 1 in b 48 | assert 2 not in b 49 | 50 | 51 | def test_len(): 52 | assert len(ragged.array([1, 2, 3])) == 3 53 | with pytest.raises(TypeError, match="unsized object"): 54 | len(ragged.array(123)) 55 | 56 | 57 | def test_iter(): 58 | a = list(ragged.array([1, 2, 3])) 59 | assert isinstance(a[0], ragged.array) 60 | assert isinstance(a[1], ragged.array) 61 | assert 
isinstance(a[2], ragged.array) 62 | assert a[0] == 1 63 | assert a[1] == 2 64 | assert a[2] == 3 65 | 66 | b = list(ragged.array([[1], [2, 3]])) 67 | assert isinstance(b[0], ragged.array) 68 | assert isinstance(b[1], ragged.array) 69 | assert b[0].tolist() == [1] 70 | assert b[1].tolist() == [2, 3] 71 | 72 | with pytest.raises(TypeError, match="0-d array"): 73 | list(ragged.array(123)) 74 | 75 | 76 | def test_namespace(): 77 | assert ragged.array(123).__array_namespace__() is ragged 78 | assert ( 79 | ragged.array(123).__array_namespace__(api_version=ragged.__array_api_version__) 80 | is ragged 81 | ) 82 | with pytest.raises(NotImplementedError): 83 | ragged.array(123).__array_namespace__(api_version="does not exist") 84 | 85 | 86 | def test_bool(): 87 | assert bool(ragged.array(True)) is True 88 | assert bool(ragged.array(False)) is False 89 | 90 | 91 | def test_complex(): 92 | assert isinstance(complex(ragged.array(1.1 + 0.1j)), complex) 93 | assert complex(ragged.array(1.1 + 0.1j)) == 1.1 + 0.1j 94 | 95 | 96 | @pytest.mark.parametrize("device", devices) 97 | def test_dlpack(device): 98 | lib = np if device == "cpu" else cp 99 | 100 | if not hasattr(lib, "from_dlpack"): 101 | return 102 | 103 | a = ragged.array(lib.arange(2 * 3 * 5).reshape(2, 3, 5), device=device) 104 | assert a.device == device 105 | assert isinstance(a._impl.layout.data, lib.ndarray) # type: ignore[union-attr] 106 | 107 | b = lib.from_dlpack(a) 108 | assert isinstance(b, lib.ndarray) 109 | assert b.shape == a.shape 110 | assert b.dtype == a.dtype 111 | assert b.tolist() == a.tolist() 112 | 113 | a = ragged.array(lib.asarray(123), device=device) 114 | assert a.device == device 115 | assert isinstance(a._impl, lib.ndarray) 116 | 117 | b = lib.from_dlpack(a) 118 | assert isinstance(b, lib.ndarray) 119 | assert b.shape == a.shape 120 | assert b.dtype == a.dtype 121 | assert b.item() == a.item() == 123 122 | 123 | 124 | def test_float(): 125 | assert isinstance(float(ragged.array(1.1)), float) 126 
| assert float(ragged.array(1.1)) == 1.1 127 | 128 | 129 | def test_getitem(): 130 | # slices are extensively tested in Awkward Array 131 | a = ragged.array([[1, 2, 3], [4], [5, 6, 7, 8]]) 132 | assert a[..., 1:].tolist() == [[2, 3], [], [6, 7, 8]] # type: ignore[comparison-overlap,index] 133 | 134 | 135 | def test_index(): 136 | assert isinstance(ragged.array(10).__index__(), int) 137 | assert ragged.array(10).__index__() == 10 138 | 139 | 140 | def test_int(): 141 | assert isinstance(int(ragged.array(10)), int) 142 | assert int(ragged.array(10)) == 10 143 | -------------------------------------------------------------------------------- /tests/test_spec_broadcasting.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/broadcasting.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | 10 | def test(): 11 | pass 12 | -------------------------------------------------------------------------------- /tests/test_spec_constants.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/constants.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import math 10 | 11 | import ragged 12 | 13 | 14 | def test_values(): 15 | assert ragged.e == math.e 16 | assert not math.isfinite(ragged.inf) 17 | assert ragged.inf > 0 18 | assert math.isnan(ragged.nan) 19 | assert ragged.newaxis is None 20 | assert ragged.pi == math.pi 21 | -------------------------------------------------------------------------------- /tests/test_spec_creation_functions.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see 
https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/creation_functions.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import numpy as np 10 | import pytest 11 | 12 | import ragged 13 | 14 | devices = ["cpu"] 15 | ns = {"cpu": np} 16 | try: 17 | import cupy as cp 18 | 19 | devices.append("cuda") 20 | ns["cuda"] = cp 21 | except ModuleNotFoundError: 22 | cp = None 23 | 24 | 25 | def test_existence(): 26 | assert ragged.arange is not None 27 | assert ragged.asarray is not None 28 | assert ragged.empty is not None 29 | assert ragged.empty_like is not None 30 | assert ragged.eye is not None 31 | assert ragged.from_dlpack is not None 32 | assert ragged.full is not None 33 | assert ragged.full_like is not None 34 | assert ragged.linspace is not None 35 | assert ragged.meshgrid is not None 36 | assert ragged.ones is not None 37 | assert ragged.ones_like is not None 38 | assert ragged.tril is not None 39 | assert ragged.triu is not None 40 | assert ragged.zeros is not None 41 | assert ragged.zeros_like is not None 42 | 43 | 44 | @pytest.mark.parametrize("device", devices) 45 | def test_arange(device): 46 | a = ragged.arange(5, 10, 2, device=device) 47 | assert a.tolist() == [5, 7, 9] 48 | assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] 49 | 50 | 51 | @pytest.mark.parametrize("device", devices) 52 | def test_empty(device): 53 | a = ragged.empty((2, 3, 5), device=device) 54 | assert a.shape == (2, 3, 5) 55 | assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] 56 | 57 | 58 | @pytest.mark.parametrize("device", devices) 59 | def test_empty_ndim0(device): 60 | a = ragged.empty((), device=device) 61 | assert a.ndim == 0 62 | assert a.shape == () 63 | assert isinstance(a._impl, ns[device].ndarray) 64 | 65 | 66 | @pytest.mark.parametrize("device", devices) 67 | def test_empty_like(device): 68 | a = ragged.array([[1, 2, 
3], [], [4, 5]], device=device) 69 | b = ragged.empty_like(a) 70 | assert (b * 0).tolist() == [[0, 0, 0], [], [0, 0]] # type: ignore[comparison-overlap] 71 | assert a.dtype == b.dtype 72 | assert a.device == b.device == device 73 | 74 | 75 | @pytest.mark.parametrize("device", devices) 76 | def test_eye(device): 77 | a = ragged.eye(3, 5, k=1, device=device) 78 | assert a.tolist() == [[0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0]] 79 | assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] 80 | 81 | 82 | @pytest.mark.skipif( 83 | not hasattr(np, "from_dlpack"), reason=f"np.from_dlpack not in {np.__version__}" 84 | ) 85 | @pytest.mark.parametrize("device", devices) 86 | def test_from_dlpack(device): 87 | a = ns[device].array([1, 2, 3, 4, 5]) 88 | b = ragged.from_dlpack(a) 89 | assert b.tolist() == [1, 2, 3, 4, 5] 90 | assert isinstance(b._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] 91 | 92 | 93 | @pytest.mark.parametrize("device", devices) 94 | def test_full(device): 95 | a = ragged.full(5, 3, device=device) 96 | assert a.tolist() == [3, 3, 3, 3, 3] 97 | assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] 98 | 99 | 100 | @pytest.mark.parametrize("device", devices) 101 | def test_full_ndim0(device): 102 | a = ragged.full((), 3, device=device) 103 | assert a.ndim == 0 104 | assert a.shape == () 105 | assert a == 3 106 | assert isinstance(a._impl, ns[device].ndarray) 107 | 108 | 109 | @pytest.mark.parametrize("device", devices) 110 | def test_full_like(device): 111 | a = ragged.array([[1, 2, 3], [], [4, 5]], device=device) 112 | b = ragged.full_like(a, 5) 113 | assert b.tolist() == [[5, 5, 5], [], [5, 5]] # type: ignore[comparison-overlap] 114 | assert a.dtype == b.dtype 115 | assert a.device == b.device == device 116 | 117 | 118 | @pytest.mark.parametrize("device", devices) 119 | def test_linspace(device): 120 | a = ragged.linspace(5, 8, 5, device=device) 121 | assert a.tolist() == 
[5, 5.75, 6.5, 7.25, 8] 122 | assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] 123 | 124 | 125 | @pytest.mark.parametrize("device", devices) 126 | def test_ones(device): 127 | a = ragged.ones(5, device=device) 128 | assert a.tolist() == [1, 1, 1, 1, 1] 129 | assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] 130 | 131 | 132 | @pytest.mark.parametrize("device", devices) 133 | def test_ones_ndim0(device): 134 | a = ragged.ones((), device=device) 135 | assert a.ndim == 0 136 | assert a.shape == () 137 | assert a == 1 138 | assert isinstance(a._impl, ns[device].ndarray) 139 | 140 | 141 | @pytest.mark.parametrize("device", devices) 142 | def test_ones_like(device): 143 | a = ragged.array([[1, 2, 3], [], [4, 5]], device=device) 144 | b = ragged.ones_like(a) 145 | assert b.tolist() == [[1, 1, 1], [], [1, 1]] # type: ignore[comparison-overlap] 146 | assert a.dtype == b.dtype 147 | assert a.device == b.device == device 148 | 149 | 150 | @pytest.mark.parametrize("device", devices) 151 | def test_zeros(device): 152 | a = ragged.zeros(5, device=device) 153 | assert a.tolist() == [0, 0, 0, 0, 0] 154 | assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] 155 | 156 | 157 | @pytest.mark.parametrize("device", devices) 158 | def test_zeros_ndim0(device): 159 | a = ragged.zeros((), device=device) 160 | assert a.ndim == 0 161 | assert a.shape == () 162 | assert a == 0 163 | assert isinstance(a._impl, ns[device].ndarray) 164 | 165 | 166 | @pytest.mark.parametrize("device", devices) 167 | def test_zeros_like(device): 168 | a = ragged.array([[1, 2, 3], [], [4, 5]], device=device) 169 | b = ragged.zeros_like(a) 170 | assert b.tolist() == [[0, 0, 0], [], [0, 0]] # type: ignore[comparison-overlap] 171 | assert a.dtype == b.dtype 172 | assert a.device == b.device == device 173 | -------------------------------------------------------------------------------- 
/tests/test_spec_data_type_functions.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/data_type_functions.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | from typing import Any 10 | 11 | import awkward as ak 12 | import numpy as np 13 | import pytest 14 | 15 | import ragged 16 | 17 | devices = ["cpu"] 18 | try: 19 | import cupy as cp 20 | 21 | devices.append("cuda") 22 | except ModuleNotFoundError: 23 | cp = None 24 | 25 | 26 | def first(x: ragged.array) -> Any: 27 | out = ak.flatten(x._impl, axis=None)[0] if x.shape != () else x._impl 28 | return np.asarray(out.item(), dtype=x.dtype) 29 | 30 | 31 | def test_existence(): 32 | assert ragged.astype is not None 33 | assert ragged.can_cast is not None 34 | assert ragged.finfo is not None 35 | assert ragged.iinfo is not None 36 | assert ragged.isdtype is not None 37 | assert ragged.result_type is not None 38 | 39 | 40 | @pytest.mark.parametrize("device", devices) 41 | @pytest.mark.parametrize("dt", ["float64", np.float64, np.dtype(np.float64)]) 42 | def test_astype(device, x_int, dt): 43 | x = x_int.to_device(device) 44 | y = ragged.astype(x, dt) 45 | assert first(y) == first(x) 46 | assert y.dtype == np.dtype(np.float64) 47 | assert y.device == x.device 48 | 49 | 50 | def test_can_cast(): 51 | assert ragged.can_cast(np.float32, np.complex128) 52 | assert not ragged.can_cast(np.complex128, np.float32) 53 | 54 | 55 | def test_finfo(): 56 | f = ragged.finfo(np.float64) 57 | assert f.bits == 64 58 | assert f.eps == 2.220446049250313e-16 59 | assert f.max == 1.7976931348623157e308 60 | assert f.min == -1.7976931348623157e308 61 | assert f.smallest_normal == 2.2250738585072014e-308 62 | assert f.dtype == np.dtype(np.float64) 63 | 64 | 65 | def test_finfo_array(): 66 | f = ragged.finfo(np.array([1.1, 2.2, 3.3])) 67 | 
assert f.bits == 64 68 | assert f.dtype == np.dtype(np.float64) 69 | 70 | 71 | def test_finfo_array2(): 72 | f = ragged.finfo(ragged.array([1.1, 2.2, 3.3])) 73 | assert f.bits == 64 74 | assert f.dtype == np.dtype(np.float64) 75 | 76 | 77 | def test_iinfo(): 78 | f = ragged.iinfo(np.int16) 79 | assert f.bits == 16 80 | assert f.max == 32767 81 | assert f.min == -32768 82 | assert f.dtype == np.dtype(np.int16) 83 | 84 | 85 | def test_iinfo_array(): 86 | f = ragged.iinfo(np.array([1, 2, 3], np.int16)) 87 | assert f.bits == 16 88 | assert f.dtype == np.dtype(np.int16) 89 | 90 | 91 | def test_iinfo_array2(): 92 | f = ragged.iinfo(ragged.array([1, 2, 3], np.int16)) 93 | assert f.bits == 16 94 | assert f.dtype == np.dtype(np.int16) 95 | 96 | 97 | def test_result_type(): 98 | dt = ragged.result_type(ragged.array([1, 2, 3]), ragged.array([1.1, 2.2, 3.3])) 99 | assert dt == np.dtype(np.float64) 100 | -------------------------------------------------------------------------------- /tests/test_spec_indexing.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/indexing.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import ragged 10 | 11 | 12 | def test(): 13 | # slices are extensively tested in Awkward Array, just check 'axis' argument 14 | 15 | a = ragged.array([0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9]) 16 | assert ragged.take(a, ragged.array([5, 3, 3, 9, 0, 1]), axis=0).tolist() == [ 17 | 5.5, 18 | 3.3, 19 | 3.3, 20 | 9.9, 21 | 0, 22 | 1.1, 23 | ] 24 | 25 | b = ragged.array([[0.0, 1.1, 2.2], [3.3, 4.4], [5.5, 6.6, 7.7, 8.8, 9.9]]) 26 | assert ragged.take(b, ragged.array([0, 1, 1, 0]), axis=1).tolist() == [ 27 | [0, 1.1, 1.1, 0], 28 | [3.3, 4.4, 4.4, 3.3], 29 | [5.5, 6.6, 6.6, 5.5], 30 | ] 31 | 
-------------------------------------------------------------------------------- /tests/test_spec_indexing_functions.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/indexing_functions.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import ragged 10 | 11 | 12 | def test_existence(): 13 | assert ragged.take is not None 14 | -------------------------------------------------------------------------------- /tests/test_spec_linear_algebra_functions.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/linear_algebra_functions.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import ragged 10 | 11 | 12 | def test_existence(): 13 | assert ragged.matmul is not None 14 | assert ragged.matrix_transpose is not None 15 | assert ragged.tensordot is not None 16 | assert ragged.vecdot is not None 17 | -------------------------------------------------------------------------------- /tests/test_spec_manipulation_functions.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/manipulation_functions.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import awkward as ak 10 | import pytest 11 | 12 | import ragged 13 | 14 | devices = ["cpu"] 15 | try: 16 | import cupy as cp 17 | 18 | devices.append("cuda") 19 | except ModuleNotFoundError: 20 | cp = None 21 | 22 | 23 | def test_existence(): 24 | assert ragged.broadcast_arrays is not None 25 | assert ragged.broadcast_to is not 
None 26 | assert ragged.concat is not None 27 | assert ragged.expand_dims is not None 28 | assert ragged.flip is not None 29 | assert ragged.permute_dims is not None 30 | assert ragged.reshape is not None 31 | assert ragged.roll is not None 32 | assert ragged.squeeze is not None 33 | assert ragged.stack is not None 34 | 35 | 36 | @pytest.mark.parametrize("device", devices) 37 | def test_broadcast_arrays(device, x, y): 38 | x_bc, y_bc = ragged.broadcast_arrays(x.to_device(device), y.to_device(device)) 39 | if x.shape == () and y.shape == (): 40 | assert x_bc.shape == () 41 | assert y_bc.shape == () 42 | else: 43 | assert x_bc.shape == y_bc.shape 44 | if x_bc.shape == (3,): 45 | assert (x_bc * 0).tolist() == (y_bc * 0).tolist() == [0, 0, 0] 46 | if x_bc.shape == (3, None): 47 | assert (x_bc * 0).tolist() == (y_bc * 0).tolist() == [[0, 0, 0], [], [0, 0]] # type: ignore[comparison-overlap] 48 | 49 | 50 | def test_concat(x, y): 51 | if x.ndim != y.ndim: 52 | with pytest.raises(ValueError, match="same number of dimensions"): 53 | ragged.concat([x, y]) 54 | 55 | elif x.ndim == 0: 56 | with pytest.raises(ValueError, match="zero-dimensional"): 57 | ragged.concat([x, y]) 58 | 59 | elif x.ndim == 1: 60 | assert ragged.concat([x, y], axis=None).tolist() == x.tolist() + y.tolist() 61 | assert ragged.concat([x, y], axis=0).tolist() == x.tolist() + y.tolist() 62 | 63 | else: 64 | assert ragged.concat([x, y], axis=None).tolist() == [ 65 | 1.1, 66 | 2.2, 67 | 3.3, 68 | 4.4, 69 | 5.5, 70 | 1.1, 71 | 2.2, 72 | 3.3, 73 | 4.4, 74 | 5.5, 75 | ] 76 | assert ragged.concat([x, y], axis=0).tolist() == [ # type: ignore[comparison-overlap] 77 | [1.1, 2.2, 3.3], 78 | [], 79 | [4.4, 5.5], 80 | [1.1, 2.2, 3.3], 81 | [], 82 | [4.4, 5.5], 83 | ] 84 | assert ragged.concat([x, y], axis=1).tolist() == [ # type: ignore[comparison-overlap] 85 | [1.1, 2.2, 3.3, 1.1, 2.2, 3.3], 86 | [], 87 | [4.4, 5.5, 4.4, 5.5], 88 | ] 89 | 90 | 91 | @pytest.mark.parametrize("axis", [0, 1, 2]) 92 | def 
test_expand_dims(x, axis): 93 | if 0 <= axis <= x.ndim: 94 | a = ragged.expand_dims(x, axis=axis) 95 | assert a.shape == x.shape[:axis] + (1,) + x.shape[axis:] 96 | assert str(a._impl.type) == " * ".join( # type: ignore[union-attr] 97 | ["var" if ai is None else str(ai) for ai in a.shape] + [str(a.dtype)] 98 | ) 99 | 100 | else: 101 | with pytest.raises(ak.errors.AxisError): 102 | ragged.expand_dims(x, axis=axis) 103 | 104 | 105 | @pytest.mark.parametrize("axis", [0, 1, 2]) 106 | def test_squeeze(x, axis): 107 | if 0 <= axis <= x.ndim: 108 | a = ragged.expand_dims(x, axis=axis) 109 | b = ragged.squeeze(a, axis=axis) 110 | assert b.shape == x.shape 111 | assert b.tolist() == x.tolist() 112 | -------------------------------------------------------------------------------- /tests/test_spec_searching_functions.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/searching_functions.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | from typing import Any 10 | 11 | import awkward as ak 12 | import numpy as np 13 | import pytest 14 | 15 | import ragged 16 | 17 | 18 | def first(x: ragged.array) -> Any: 19 | out = ak.flatten(x._impl, axis=None)[0] if x.shape != () else x._impl 20 | return np.asarray(out.item(), dtype=x.dtype) 21 | 22 | 23 | def test_existence(): 24 | assert ragged.argmax is not None 25 | assert ragged.argmin is not None 26 | assert ragged.nonzero is not None 27 | assert ragged.where is not None 28 | 29 | 30 | def test_argmax(): 31 | data = ragged.array( 32 | [[[0, 1.1, 2.2], []], [], [[3.3, 4.4], [5.5], [6.6, 7.7, 8.8, 9.9]]] 33 | ) 34 | assert ragged.argmax(data, axis=None).tolist() == 9 35 | assert ( 36 | ragged.argmax(data, axis=0).tolist() 37 | == ragged.argmax(data, axis=-3).tolist() 38 | == [[1, 1, 0], [1], [0, 0, 0, 0]] 39 | ) 40 | assert ( 41 | 
ragged.argmax(data, axis=1).tolist() # type: ignore[comparison-overlap] 42 | == ragged.argmax(data, axis=-2).tolist() 43 | == [[0, 0, 0], [], [2, 2, 2, 2]] 44 | ) 45 | with pytest.raises(ValueError, match=".*axis.*"): 46 | ragged.argmax(data, axis=2) 47 | with pytest.raises(ValueError, match=".*axis.*"): 48 | ragged.argmax(data, axis=-1) 49 | 50 | 51 | def test_argmin(): 52 | data = ragged.array( 53 | [[[0, 1.1, 2.2], []], [], [[3.3, 4.4], [5.5], [6.6, 7.7, 8.8, 9.9]]] 54 | ) 55 | assert ragged.argmin(data, axis=None).tolist() == 0 56 | assert ( 57 | ragged.argmin(data, axis=0).tolist() 58 | == ragged.argmin(data, axis=-3).tolist() 59 | == [[0, 0, 0], [1], [0, 0, 0, 0]] 60 | ) 61 | assert ( 62 | ragged.argmin(data, axis=1).tolist() # type: ignore[comparison-overlap] 63 | == ragged.argmin(data, axis=-2).tolist() 64 | == [[0, 0, 0], [], [0, 0, 2, 2]] 65 | ) 66 | with pytest.raises(ValueError, match=".*axis.*"): 67 | ragged.argmin(data, axis=2) 68 | with pytest.raises(ValueError, match=".*axis.*"): 69 | ragged.argmin(data, axis=-1) 70 | 71 | 72 | def test_nonzero(): 73 | (result,) = ragged.nonzero(ragged.array(0)) 74 | assert result.tolist() == [] 75 | 76 | (result,) = ragged.nonzero(ragged.array(123)) 77 | assert result.tolist() == [0] 78 | 79 | (result,) = ragged.nonzero(ragged.array([0])) 80 | assert result.tolist() == [] 81 | 82 | (result,) = ragged.nonzero(ragged.array([123])) 83 | assert result.tolist() == [0] 84 | 85 | (result,) = ragged.nonzero(ragged.array([111, 222, 0, 333])) 86 | assert result.tolist() == [0, 1, 3] 87 | 88 | result1, result2 = ragged.nonzero(ragged.array([[111, 222, 0], [333, 0, 444]])) 89 | assert result1.tolist() == [0, 0, 1, 1] 90 | assert result2.tolist() == [0, 1, 0, 2] 91 | 92 | 93 | def test_where(x_bool, x, y): 94 | z = ragged.where(x_bool, x, y) 95 | if x_bool.ndim == x.ndim == y.ndim == 0: 96 | assert z.ndim == 0 97 | assert z == (x if x_bool else y) # parenthesized: unparenthesized this parses as `(z == x) if x_bool else y` 98 | else: 99 | assert z.ndim == max(x_bool.ndim, x.ndim, y.ndim) 100 | assert
first(z) == (first(x) if first(x_bool) else first(y)) # compare against the selected operand, not `(a == b) if c else d` 101 | -------------------------------------------------------------------------------- /tests/test_spec_set_functions.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/set_functions.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import awkward as ak 10 | 11 | import ragged 12 | 13 | 14 | def test_existence(): 15 | assert ragged.unique_all is not None 16 | assert ragged.unique_counts is not None 17 | assert ragged.unique_inverse is not None 18 | assert ragged.unique_values is not None 19 | 20 | 21 | # unique_values tests 22 | def test_can_take_list(): 23 | arr = ragged.array([1, 2, 4, 3, 4, 5, 6, 20]) 24 | expected_unique_values = ragged.array([1, 2, 3, 4, 5, 6, 20]) 25 | unique_values = ragged.unique_values(arr) 26 | assert ak.to_list(expected_unique_values) == ak.to_list(unique_values) 27 | 28 | 29 | def test_can_take_empty_arr(): 30 | arr = ragged.array([]) 31 | expected_unique_values = ragged.array([]) 32 | unique_values = ragged.unique_values(arr) 33 | assert ak.to_list(expected_unique_values) == ak.to_list(unique_values) 34 | 35 | 36 | def test_can_take_moredimensions(): 37 | arr = ragged.array([[1, 2, 2, 3, 4], [5, 6]]) 38 | expected_unique_values = ragged.array([1, 2, 3, 4, 5, 6]) 39 | unique_values = ragged.unique_values(arr) 40 | assert ak.to_list(expected_unique_values) == ak.to_list(unique_values) 41 | 42 | 43 | def test_can_take_1d_array(): 44 | arr = ragged.array([5, 6, 7, 8, 8, 9, 1, 2, 3, 4, 10, 0, 15, 2]) 45 | expected_unique_values = ragged.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15]) 46 | assert ak.to_list(ragged.unique_values(arr)) == ak.to_list(expected_unique_values) 47 | 48 | 49 | def test_can_take_scalar_int(): 50 | arr = ragged.array(5) 51 | expected_unique_values = ragged.array(5)
52 | unique_values = ragged.unique_values(arr) 53 | assert unique_values == expected_unique_values 54 | 55 | 56 | def test_can_take_scalar_float(): 57 | arr = ragged.array(4.326) 58 | expected_unique_values = ragged.array(4.326) 59 | unique_values = ragged.unique_values(arr) 60 | assert unique_values == expected_unique_values 61 | 62 | 63 | # unique_counts tests 64 | def test_can_count_list(): 65 | arr = ragged.array([1, 2, 4, 3, 4, 5, 6, 20]) 66 | expected_unique_values = ragged.array([1, 2, 3, 4, 5, 6, 20]) 67 | expected_unique_counts = ragged.array([1, 1, 1, 2, 1, 1, 1]) 68 | unique_values, unique_counts = ragged.unique_counts(arr) 69 | assert ak.to_list(unique_values) == ak.to_list(expected_unique_values) 70 | assert ak.to_list(unique_counts) == ak.to_list(expected_unique_counts) 71 | 72 | 73 | def test_can_count_empty_arr(): 74 | arr = ragged.array([]) 75 | expected_unique_values = ragged.array([]) 76 | expected_counts = ragged.array([]) 77 | unique_values, unique_counts = ragged.unique_counts(arr) 78 | assert ak.to_list(expected_unique_values) == ak.to_list(unique_values) 79 | assert ak.to_list(expected_counts) == ak.to_list(unique_counts) 80 | 81 | 82 | def test_can_count_simple_array(): 83 | arr = ragged.array([1, 2, 2, 3, 3, 3, 4, 4, 4, 4]) 84 | expected_unique_values = ragged.array([1, 2, 3, 4]) 85 | expected_counts = ragged.array([1, 2, 3, 4]) 86 | unique_values, unique_counts = ragged.unique_counts(arr) 87 | assert ak.to_list(unique_values) == ak.to_list(expected_unique_values) 88 | assert ak.to_list(unique_counts) == ak.to_list(expected_counts) 89 | 90 | 91 | def test_can_count_normal_array(): 92 | arr = ragged.array([[1, 2, 2], [3], [3, 3], [4, 4, 4], [4]]) 93 | expected_unique_values = ragged.array([1, 2, 3, 4]) 94 | expected_counts = ragged.array([1, 2, 3, 4]) 95 | unique_values, unique_counts = ragged.unique_counts(arr) 96 | assert ak.to_list(unique_values) == ak.to_list(expected_unique_values) 97 | assert ak.to_list(unique_counts) == 
ak.to_list(expected_counts) 98 | 99 | 100 | def test_can_count_scalar_int(): 101 | arr = ragged.array(5) 102 | expected_unique_values = ragged.array(5) 103 | expected_counts = ragged.array([1]) 104 | unique_values, unique_counts = ragged.unique_counts(arr) 105 | assert unique_values == expected_unique_values 106 | assert unique_counts == expected_counts 107 | 108 | 109 | def test_can_count_scalar_float(): 110 | arr = ragged.array(4.326) 111 | expected_unique_values = ragged.array(4.326) 112 | expected_counts = ragged.array([1]) 113 | unique_values, unique_counts = ragged.unique_counts(arr) 114 | assert unique_values == expected_unique_values 115 | assert unique_counts == expected_counts 116 | 117 | 118 | # unique_inverse tests 119 | def test_can_inverse_list(): 120 | arr = ragged.array([1, 2, 4, 3, 4, 5, 6, 20]) 121 | expected_unique_values = ragged.array([1, 2, 3, 4, 5, 6, 20]) 122 | expected_inverse_indices = ragged.array([0, 1, 3, 2, 3, 4, 5, 6]) 123 | unique_values, inverse_indices = ragged.unique_inverse(arr) 124 | assert ak.to_list(unique_values) == ak.to_list(expected_unique_values) 125 | assert ak.to_list(inverse_indices) == ak.to_list(expected_inverse_indices) 126 | 127 | 128 | def test_can_inverse_empty_arr(): 129 | arr = ragged.array([]) 130 | expected_unique_values = ragged.array([]) 131 | expected_inverse_indices = ragged.array([]) 132 | unique_values, inverse_indices = ragged.unique_inverse(arr) 133 | assert ak.to_list(unique_values) == ak.to_list(expected_unique_values) 134 | assert ak.to_list(inverse_indices) == ak.to_list(expected_inverse_indices) 135 | 136 | 137 | def test_can_inverse_simple_array(): 138 | arr = ragged.array([[1, 2, 2], [3, 3, 3], [4, 4, 4, 4]]) 139 | expected_unique_values = ragged.array([1, 2, 3, 4]) 140 | expected_inverse_indices = ragged.array([0, 1, 1, 2, 2, 2, 3, 3, 3, 3]) 141 | unique_values, inverse_indices = ragged.unique_inverse(arr) 142 | assert ak.to_list(unique_values) == ak.to_list(expected_unique_values) 143 | 
assert ak.to_list(inverse_indices) == ak.to_list(expected_inverse_indices) 144 | 145 | 146 | def test_can_inverse_normal_array(): 147 | arr = ragged.array([[1, 2, 2], [3], [3, 3], [4, 4, 4], [4]]) 148 | expected_unique_values = ragged.array([1, 2, 3, 4]) 149 | expected_inverse_indices = ragged.array([0, 1, 1, 2, 2, 2, 3, 3, 3, 3]) 150 | unique_values, inverse_indices = ragged.unique_inverse(arr) 151 | assert ak.to_list(unique_values) == ak.to_list(expected_unique_values) 152 | assert ak.to_list(inverse_indices) == ak.to_list(expected_inverse_indices) 153 | 154 | 155 | def test_can_inverse_scalar_int(): 156 | arr = ragged.array(5) 157 | expected_unique_values = ragged.array(5) 158 | expected_inverse_indices = ragged.array([0]) 159 | unique_values, inverse_indices = ragged.unique_inverse(arr) 160 | assert unique_values == expected_unique_values 161 | assert inverse_indices == expected_inverse_indices 162 | 163 | 164 | def test_can_inverse_scalar_float(): 165 | arr = ragged.array(4.326) 166 | expected_unique_values = ragged.array(4.326) 167 | expected_inverse_indices = ragged.array([0]) 168 | unique_values, inverse_indices = ragged.unique_inverse(arr) 169 | assert unique_values == expected_unique_values 170 | assert inverse_indices == expected_inverse_indices 171 | 172 | 173 | # unique_all tests 174 | def test_can_all_list(): 175 | arr = ragged.array([1, 2, 2, 3, 3, 3, 4, 4, 4, 4]) 176 | expected_unique_values = ragged.array([1, 2, 3, 4]) 177 | expected_unique_indices = ragged.array([0, 1, 3, 6]) 178 | expected_unique_inverse = ragged.array([0, 1, 1, 2, 2, 2, 3, 3, 3, 3]) 179 | expected_unique_counts = ragged.array([1, 2, 3, 4]) 180 | unique_values, unique_indices, unique_inverse, unique_counts = ragged.unique_all( 181 | arr 182 | ) 183 | assert ak.to_list(unique_values) == ak.to_list(expected_unique_values) 184 | assert ak.to_list(unique_indices) == ak.to_list(expected_unique_indices) 185 | assert ak.to_list(unique_inverse) == ak.to_list(expected_unique_inverse) 186 
| assert ak.to_list(unique_counts) == ak.to_list(expected_unique_counts) 187 | 188 | 189 | def test_can_all_empty_arr(): 190 | arr = ragged.array([]) 191 | expected_unique_values = ragged.array([]) 192 | expected_unique_indices = ragged.array([]) 193 | expected_unique_inverse = ragged.array([]) 194 | expected_unique_counts = ragged.array([]) 195 | unique_values, unique_indices, unique_inverse, unique_counts = ragged.unique_all( 196 | arr 197 | ) 198 | assert ak.to_list(unique_values) == ak.to_list(expected_unique_values) 199 | assert ak.to_list(unique_indices) == ak.to_list(expected_unique_indices) 200 | assert ak.to_list(unique_inverse) == ak.to_list(expected_unique_inverse) 201 | assert ak.to_list(unique_counts) == ak.to_list(expected_unique_counts) 202 | 203 | 204 | def test_can_all_normal_array(): 205 | arr = ragged.array([[2, 2, 2], [3], [3, 5], [4, 4, 4], [4]]) 206 | expected_unique_values = ragged.array([2, 3, 4, 5]) 207 | expected_unique_indices = ragged.array([0, 3, 6, 5]) 208 | expected_unique_inverse = ragged.array([0, 0, 0, 1, 1, 3, 2, 2, 2, 2]) 209 | expected_unique_counts = ragged.array([3, 2, 4, 1]) 210 | unique_values, unique_indices, unique_inverse, unique_counts = ragged.unique_all( 211 | arr 212 | ) 213 | assert ak.to_list(unique_values) == ak.to_list(expected_unique_values) 214 | assert ak.to_list(unique_indices) == ak.to_list(expected_unique_indices) 215 | assert ak.to_list(unique_inverse) == ak.to_list(expected_unique_inverse) 216 | assert ak.to_list(unique_counts) == ak.to_list(expected_unique_counts) 217 | 218 | 219 | def test_can_all_scalar_int(): 220 | arr = ragged.array(5) 221 | expected_unique_values = ragged.array(5) 222 | expected_unique_indices = ragged.array([0]) 223 | expected_unique_inverse = ragged.array([0]) 224 | expected_unique_counts = ragged.array([1]) 225 | unique_values, unique_indices, unique_inverse, unique_counts = ragged.unique_all( 226 | arr 227 | ) 228 | assert unique_values == expected_unique_values 229 | assert 
unique_indices == expected_unique_indices 230 | assert unique_inverse == expected_unique_inverse 231 | assert unique_counts == expected_unique_counts 232 | 233 | 234 | def test_can_all_scalar_float(): 235 | arr = ragged.array(4.326) 236 | expected_unique_values = ragged.array(4.326) 237 | expected_unique_indices = ragged.array([0]) 238 | expected_unique_inverse = ragged.array([0]) 239 | expected_unique_counts = ragged.array([1]) 240 | unique_values, unique_indices, unique_inverse, unique_counts = ragged.unique_all( 241 | arr 242 | ) 243 | assert unique_values == expected_unique_values 244 | assert unique_indices == expected_unique_indices 245 | assert unique_inverse == expected_unique_inverse 246 | assert unique_counts == expected_unique_counts 247 | -------------------------------------------------------------------------------- /tests/test_spec_sorting_functions.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/sorting_functions.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import pytest 10 | 11 | import ragged 12 | 13 | devices = ["cpu"] 14 | try: 15 | import cupy as cp 16 | 17 | # FIXME! 
18 | # devices.append("cuda") 19 | except ModuleNotFoundError: 20 | cp = None 21 | 22 | 23 | def test_existence(): 24 | assert ragged.argsort is not None 25 | assert ragged.sort is not None 26 | 27 | 28 | @pytest.mark.parametrize("device", devices) 29 | def test_argsort(device): 30 | x = ragged.array( 31 | [[1.1, 0, 2.2], [], [3.3, 4.4], [5.5], [9.9, 7.7, 8.8, 6.6]], device=device 32 | ) 33 | assert ragged.argsort(x, axis=1, stable=True, descending=False).tolist() == [ # type: ignore[comparison-overlap] 34 | [1, 0, 2], 35 | [], 36 | [0, 1], 37 | [0], 38 | [3, 1, 2, 0], 39 | ] 40 | assert ragged.argsort(x, axis=1, stable=True, descending=True).tolist() == [ # type: ignore[comparison-overlap] 41 | [2, 0, 1], 42 | [], 43 | [1, 0], 44 | [0], 45 | [0, 2, 1, 3], 46 | ] 47 | assert ragged.argsort(x, axis=0, stable=True, descending=False).tolist() == [ # type: ignore[comparison-overlap] 48 | [0, 0, 0], 49 | [], 50 | [2, 2], 51 | [3], 52 | [4, 4, 4, 4], 53 | ] 54 | assert ragged.argsort(x, axis=0, stable=True, descending=True).tolist() == [ # type: ignore[comparison-overlap] 55 | [4, 4, 4], 56 | [], 57 | [3, 2], 58 | [2], 59 | [0, 0, 0, 4], 60 | ] 61 | 62 | 63 | @pytest.mark.parametrize("device", devices) 64 | def test_sort(device): 65 | x = ragged.array( 66 | [[1.1, 0, 2.2], [], [3.3, 4.4], [5.5], [9.9, 7.7, 8.8, 6.6]], device=device 67 | ) 68 | assert ragged.sort(x, axis=1, stable=True, descending=False).tolist() == [ # type: ignore[comparison-overlap] 69 | [0, 1.1, 2.2], 70 | [], 71 | [3.3, 4.4], 72 | [5.5], 73 | [6.6, 7.7, 8.8, 9.9], 74 | ] 75 | assert ragged.sort(x, axis=1, stable=True, descending=True).tolist() == [ # type: ignore[comparison-overlap] 76 | [2.2, 1.1, 0], 77 | [], 78 | [4.4, 3.3], 79 | [5.5], 80 | [9.9, 8.8, 7.7, 6.6], 81 | ] 82 | assert ragged.sort(x, axis=0, stable=True, descending=False).tolist() == [ # type: ignore[comparison-overlap] 83 | [1.1, 0.0, 2.2], 84 | [], 85 | [3.3, 4.4], 86 | [5.5], 87 | [9.9, 7.7, 8.8, 6.6], 88 | ] 89 | assert 
ragged.sort(x, axis=0, stable=True, descending=True).tolist() == [ # type: ignore[comparison-overlap] 90 | [9.9, 7.7, 8.8], 91 | [], 92 | [5.5, 4.4], 93 | [3.3], 94 | [1.1, 0.0, 2.2, 6.6], 95 | ] 96 | -------------------------------------------------------------------------------- /tests/test_spec_statistical_functions.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/statistical_functions.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import pytest 10 | 11 | import ragged 12 | 13 | 14 | def test_existence(): 15 | assert ragged.max is not None 16 | assert ragged.mean is not None 17 | assert ragged.min is not None 18 | assert ragged.prod is not None 19 | assert ragged.std is not None 20 | assert ragged.sum is not None 21 | assert ragged.var is not None 22 | 23 | 24 | def test_max(): 25 | data = ragged.array( 26 | [[[0, 1.1, 2.2], []], [], [[3.3, 4.4], [5.5], [6.6, 7.7, 8.8, 9.9]]] 27 | ) 28 | assert ragged.max(data, axis=None).tolist() == 9.9 29 | assert ( 30 | ragged.max(data, axis=0).tolist() 31 | == ragged.max(data, axis=-3).tolist() 32 | == [[3.3, 4.4, 2.2], [5.5], [6.6, 7.7, 8.8, 9.9]] 33 | ) 34 | assert ( 35 | ragged.max(data, axis=1).tolist() # type: ignore[comparison-overlap] 36 | == ragged.max(data, axis=-2).tolist() 37 | == [[0, 1.1, 2.2], [], [6.6, 7.7, 8.8, 9.9]] 38 | ) 39 | assert ( 40 | ragged.max(data, axis=2).tolist() 41 | == ragged.max(data, axis=-1).tolist() 42 | == [[2.2, -ragged.inf], [], [4.4, 5.5, 9.9]] 43 | ) 44 | assert ( 45 | ragged.max(data, axis=(0, 1)).tolist() 46 | == ragged.max(data, axis=(1, 0)).tolist() 47 | == [6.6, 7.7, 8.8, 9.9] 48 | ) 49 | assert ( 50 | ragged.max(data, axis=(0, 2)).tolist() 51 | == ragged.max(data, axis=(2, 0)).tolist() 52 | == [4.4, 5.5, 9.9] 53 | ) 54 | assert ( 55 | ragged.max(data, axis=(1, 
2)).tolist() 56 | == ragged.max(data, axis=(2, 1)).tolist() 57 | == [2.2, -ragged.inf, 9.9] 58 | ) 59 | assert ( 60 | ragged.max(data, axis=(0, 1, 2)).tolist() 61 | == ragged.max(data, axis=(-1, 0, 1)).tolist() 62 | == 9.9 63 | ) 64 | 65 | 66 | def test_mean(): 67 | data = ragged.array( 68 | [[[0, 1.1, 2.2], []], [], [[3.3, 4.4], [5.5], [6.6, 7.7, 8.8, 9.9]]] 69 | ) 70 | assert ragged.mean(data, axis=None).tolist() == pytest.approx(4.95) 71 | assert ( 72 | ragged.mean(data, axis=0).tolist() # type: ignore[comparison-overlap] 73 | == ragged.mean(data, axis=-3).tolist() 74 | == [ 75 | pytest.approx([1.65, 2.75, 2.2]), 76 | pytest.approx([5.5]), 77 | pytest.approx([6.6, 7.7, 8.8, 9.9]), 78 | ] 79 | ) 80 | assert ( 81 | ragged.mean(data, axis=1).tolist() # type: ignore[comparison-overlap] 82 | == ragged.mean(data, axis=-2).tolist() 83 | == [ 84 | pytest.approx([0, 1.1, 2.2]), 85 | pytest.approx([]), 86 | pytest.approx([5.13333, 6.05, 8.8, 9.9]), 87 | ] 88 | ) 89 | assert ( 90 | ragged.mean(data, axis=2).tolist() # type: ignore[comparison-overlap] 91 | == [ 92 | pytest.approx([1.1, ragged.nan], nan_ok=True), 93 | pytest.approx([]), 94 | pytest.approx([3.85, 5.5, 8.25]), 95 | ] 96 | ) 97 | assert ( 98 | ragged.mean(data, axis=-1).tolist() # type: ignore[comparison-overlap] 99 | == [ 100 | pytest.approx([1.1, ragged.nan], nan_ok=True), 101 | pytest.approx([]), 102 | pytest.approx([3.85, 5.5, 8.25]), 103 | ] 104 | ) 105 | assert ( 106 | ragged.mean(data, axis=(0, 1)).tolist() 107 | == ragged.mean(data, axis=(1, 0)).tolist() 108 | == pytest.approx([3.85, 4.4, 5.5, 9.9]) 109 | ) 110 | assert ( 111 | ragged.mean(data, axis=(0, 2)).tolist() 112 | == ragged.mean(data, axis=(2, 0)).tolist() 113 | == pytest.approx([2.2, 5.5, 8.25]) 114 | ) 115 | assert ragged.mean(data, axis=(1, 2)).tolist() == pytest.approx( 116 | [1.1, ragged.nan, 6.6], nan_ok=True 117 | ) 118 | assert ragged.mean(data, axis=(2, 1)).tolist() == pytest.approx( 119 | [1.1, ragged.nan, 6.6], nan_ok=True 120 | ) 
121 | assert ( 122 | ragged.mean(data, axis=(0, 1, 2)).tolist() 123 | == ragged.mean(data, axis=(-1, 0, 1)).tolist() 124 | == pytest.approx(4.95) 125 | ) 126 | 127 | 128 | def test_min(): 129 | data = ragged.array( 130 | [[[0, 1.1, 2.2], []], [], [[3.3, 4.4], [5.5], [6.6, 7.7, 8.8, 9.9]]] 131 | ) 132 | assert ragged.min(data, axis=None).tolist() == 0 133 | assert ( 134 | ragged.min(data, axis=0).tolist() 135 | == ragged.min(data, axis=-3).tolist() 136 | == [[0, 1.1, 2.2], [5.5], [6.6, 7.7, 8.8, 9.9]] 137 | ) 138 | assert ( 139 | ragged.min(data, axis=1).tolist() # type: ignore[comparison-overlap] 140 | == ragged.min(data, axis=-2).tolist() 141 | == [[0, 1.1, 2.2], [], [3.3, 4.4, 8.8, 9.9]] 142 | ) 143 | assert ( 144 | ragged.min(data, axis=2).tolist() 145 | == ragged.min(data, axis=-1).tolist() 146 | == [[0, ragged.inf], [], [3.3, 5.5, 6.6]] 147 | ) 148 | assert ( 149 | ragged.min(data, axis=(0, 1)).tolist() 150 | == ragged.min(data, axis=(1, 0)).tolist() 151 | == [0, 1.1, 2.2, 9.9] 152 | ) 153 | assert ( 154 | ragged.min(data, axis=(0, 2)).tolist() 155 | == ragged.min(data, axis=(2, 0)).tolist() 156 | == [0, 5.5, 6.6] 157 | ) 158 | assert ( 159 | ragged.min(data, axis=(1, 2)).tolist() 160 | == ragged.min(data, axis=(2, 1)).tolist() 161 | == [0, ragged.inf, 3.3] 162 | ) 163 | assert ( 164 | ragged.min(data, axis=(0, 1, 2)).tolist() 165 | == ragged.min(data, axis=(-1, 0, 1)).tolist() 166 | == 0 167 | ) 168 | 169 | 170 | def test_prod(): 171 | data = ragged.array([[[2, 3, 5], []], [], [[7, 11], [13], [17, 19, 23, 27]]]) 172 | assert ragged.prod(data, axis=None).tolist() == 6023507490 173 | assert ( 174 | ragged.prod(data, axis=0).tolist() 175 | == ragged.prod(data, axis=-3).tolist() 176 | == [[14, 33, 5], [13], [17, 19, 23, 27]] 177 | ) 178 | assert ( 179 | ragged.prod(data, axis=1).tolist() # type: ignore[comparison-overlap] 180 | == ragged.prod(data, axis=-2).tolist() 181 | == [[2, 3, 5], [], [1547, 209, 23, 27]] 182 | ) 183 | assert ( 184 | ragged.prod(data, 
axis=2).tolist() # type: ignore[comparison-overlap] 185 | == ragged.prod(data, axis=-1).tolist() 186 | == [[30, 1], [], [77, 13, 200583]] 187 | ) 188 | assert ( 189 | ragged.prod(data, axis=(0, 1)).tolist() 190 | == ragged.prod(data, axis=(1, 0)).tolist() 191 | == [3094, 627, 115, 27] 192 | ) 193 | assert ( 194 | ragged.prod(data, axis=(0, 2)).tolist() 195 | == ragged.prod(data, axis=(2, 0)).tolist() 196 | == [2310, 13, 200583] 197 | ) 198 | assert ( 199 | ragged.prod(data, axis=(1, 2)).tolist() 200 | == ragged.prod(data, axis=(2, 1)).tolist() 201 | == [30, 1, 200783583] 202 | ) 203 | assert ( 204 | ragged.prod(data, axis=(0, 1, 2)).tolist() 205 | == ragged.prod(data, axis=(-1, 0, 1)).tolist() 206 | == 6023507490 207 | ) 208 | 209 | 210 | def test_std(): 211 | data = ragged.array( 212 | [[[0, 1.1, 2.2], []], [], [[3.3, 4.4], [5.5], [6.6, 7.7, 8.8, 9.9]]] 213 | ) 214 | assert ragged.std(data, axis=None).tolist() == pytest.approx(3.159509) 215 | assert ( 216 | ragged.std(data, axis=0).tolist() # type: ignore[comparison-overlap] 217 | == ragged.std(data, axis=-3).tolist() 218 | == [ 219 | pytest.approx([1.65, 1.65, 0]), 220 | pytest.approx([0]), 221 | pytest.approx([0, 0, 0, 0]), 222 | ] 223 | ) 224 | assert ( 225 | ragged.std(data, axis=1).tolist() # type: ignore[comparison-overlap] 226 | == ragged.std(data, axis=-2).tolist() 227 | == [ 228 | pytest.approx([0, 0, 0]), 229 | pytest.approx([]), 230 | pytest.approx([1.37194, 1.65, 0, 0]), 231 | ] 232 | ) 233 | assert ( 234 | ragged.std(data, axis=2).tolist() # type: ignore[comparison-overlap] 235 | == [ 236 | pytest.approx([0.898146, ragged.nan], nan_ok=True), 237 | pytest.approx([]), 238 | pytest.approx([0.55, 0, 1.229837]), 239 | ] 240 | ) 241 | assert ( 242 | ragged.std(data, axis=-1).tolist() # type: ignore[comparison-overlap] 243 | == [ 244 | pytest.approx([0.898146, ragged.nan], nan_ok=True), 245 | pytest.approx([]), 246 | pytest.approx([0.55, 0, 1.229837]), 247 | ] 248 | ) 249 | assert ( 250 | ragged.std(data, 
axis=(0, 1, 2)).tolist() 251 | == ragged.std(data, axis=(-1, 0, 1)).tolist() 252 | == pytest.approx(3.159509) 253 | ) 254 | 255 | 256 | def test_sum(): 257 | data = ragged.array( 258 | [[[0, 1.1, 2.2], []], [], [[3.3, 4.4], [5.5], [6.6, 7.7, 8.8, 9.9]]] 259 | ) 260 | assert ragged.sum(data, axis=None).tolist() == pytest.approx(49.5) 261 | assert ( 262 | ragged.sum(data, axis=0).tolist() # type: ignore[comparison-overlap] 263 | == ragged.sum(data, axis=-3).tolist() 264 | == [ 265 | pytest.approx([3.3, 5.5, 2.2]), 266 | pytest.approx([5.5]), 267 | pytest.approx([6.6, 7.7, 8.8, 9.9]), 268 | ] 269 | ) 270 | assert ( 271 | ragged.sum(data, axis=1).tolist() # type: ignore[comparison-overlap] 272 | == ragged.sum(data, axis=-2).tolist() 273 | == [ 274 | pytest.approx([0, 1.1, 2.2]), 275 | pytest.approx([]), 276 | pytest.approx([15.4, 12.1, 8.8, 9.9]), 277 | ] 278 | ) 279 | assert ( 280 | ragged.sum(data, axis=2).tolist() # type: ignore[comparison-overlap] 281 | == ragged.sum(data, axis=-1).tolist() 282 | == [ 283 | pytest.approx([3.3, 0]), 284 | pytest.approx([]), 285 | pytest.approx([7.7, 5.5, 33.0]), 286 | ] 287 | ) 288 | assert ( 289 | ragged.sum(data, axis=(0, 1)).tolist() 290 | == ragged.sum(data, axis=(1, 0)).tolist() 291 | == pytest.approx([15.4, 13.2, 11.0, 9.9]) 292 | ) 293 | assert ( 294 | ragged.sum(data, axis=(0, 2)).tolist() 295 | == ragged.sum(data, axis=(2, 0)).tolist() 296 | == pytest.approx([11.0, 5.5, 33.0]) 297 | ) 298 | assert ( 299 | ragged.sum(data, axis=(1, 2)).tolist() 300 | == ragged.sum(data, axis=(2, 1)).tolist() 301 | == pytest.approx([3.3, 0, 46.2]) 302 | ) 303 | assert ( 304 | ragged.sum(data, axis=(0, 1, 2)).tolist() 305 | == ragged.sum(data, axis=(-1, 0, 1)).tolist() 306 | == pytest.approx(49.5) 307 | ) 308 | 309 | 310 | def test_var(): 311 | data = ragged.array( 312 | [[[0, 1.1, 2.2], []], [], [[3.3, 4.4], [5.5], [6.6, 7.7, 8.8, 9.9]]] 313 | ) 314 | assert ragged.var(data, axis=None).tolist() == pytest.approx(9.9825) 315 | assert ( 316 | 
ragged.var(data, axis=0).tolist() # type: ignore[comparison-overlap] 317 | == ragged.var(data, axis=-3).tolist() 318 | == [ 319 | pytest.approx([2.7225, 2.7225, 0]), 320 | pytest.approx([0]), 321 | pytest.approx([0, 0, 0, 0]), 322 | ] 323 | ) 324 | assert ( 325 | ragged.var(data, axis=1).tolist() # type: ignore[comparison-overlap] 326 | == ragged.var(data, axis=-2).tolist() 327 | == [ 328 | pytest.approx([0, 0, 0]), 329 | pytest.approx([]), 330 | pytest.approx([1.88222222, 2.7225, 0, 0]), 331 | ] 332 | ) 333 | assert ( 334 | ragged.var(data, axis=2).tolist() # type: ignore[comparison-overlap] 335 | == [ 336 | pytest.approx([0.80666667, ragged.nan], nan_ok=True), 337 | pytest.approx([]), 338 | pytest.approx([0.3025, 0, 1.5125]), 339 | ] 340 | ) 341 | assert ( 342 | ragged.var(data, axis=-1).tolist() # type: ignore[comparison-overlap] 343 | == [ 344 | pytest.approx([0.80666667, ragged.nan], nan_ok=True), 345 | pytest.approx([]), 346 | pytest.approx([0.3025, 0, 1.5125]), 347 | ] 348 | ) 349 | assert ( 350 | ragged.var(data, axis=(0, 1, 2)).tolist() 351 | == ragged.var(data, axis=(-1, 0, 1)).tolist() 352 | == pytest.approx(9.9825) 353 | ) 354 | -------------------------------------------------------------------------------- /tests/test_spec_utility_functions.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/utility_functions.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import ragged 10 | 11 | 12 | def test_existence(): 13 | assert ragged.all is not None 14 | assert ragged.any is not None 15 | 16 | 17 | def test_all(): 18 | data = ragged.array([[[0, 1, 2], []], [], [[0, 0], [1], [1, 1, 1, 1]]]) 19 | assert ragged.all(data, axis=None).tolist() is False 20 | assert ( 21 | ragged.all(data, axis=0).tolist() 22 | == ragged.all(data, axis=-3).tolist() 23 | == 
[[False, False, True], [True], [True, True, True, True]] 24 | ) 25 | assert ( 26 | ragged.all(data, axis=1).tolist() # type: ignore[comparison-overlap] 27 | == ragged.all(data, axis=-2).tolist() 28 | == [[False, True, True], [], [False, False, True, True]] 29 | ) 30 | assert ( 31 | ragged.all(data, axis=2).tolist() # type: ignore[comparison-overlap] 32 | == ragged.all(data, axis=-1).tolist() 33 | == [[False, True], [], [False, True, True]] 34 | ) 35 | assert ( 36 | ragged.all(data, axis=(0, 1)).tolist() 37 | == ragged.all(data, axis=(1, 0)).tolist() 38 | == [False, False, True, True] 39 | ) 40 | assert ( 41 | ragged.all(data, axis=(0, 2)).tolist() 42 | == ragged.all(data, axis=(2, 0)).tolist() 43 | == [False, True, True] 44 | ) 45 | assert ( 46 | ragged.all(data, axis=(1, 2)).tolist() 47 | == ragged.all(data, axis=(2, 1)).tolist() 48 | == [False, True, False] 49 | ) 50 | assert ( 51 | ragged.all(data, axis=(0, 1, 2)).tolist() 52 | is ragged.all(data, axis=(-1, 0, 1)).tolist() 53 | is False 54 | ) 55 | 56 | 57 | def test_any(): 58 | data = ragged.array([[[0, 1, 2], []], [], [[0, 0], [1], [1, 1, 1, 1]]]) 59 | assert ragged.any(data, axis=None).tolist() is True 60 | assert ( 61 | ragged.any(data, axis=0).tolist() 62 | == ragged.any(data, axis=-3).tolist() 63 | == [[False, True, True], [True], [True, True, True, True]] 64 | ) 65 | assert ( 66 | ragged.any(data, axis=1).tolist() # type: ignore[comparison-overlap] 67 | == ragged.any(data, axis=-2).tolist() 68 | == [[False, True, True], [], [True, True, True, True]] 69 | ) 70 | assert ( 71 | ragged.any(data, axis=2).tolist() # type: ignore[comparison-overlap] 72 | == ragged.any(data, axis=-1).tolist() 73 | == [[True, False], [], [False, True, True]] 74 | ) 75 | assert ( 76 | ragged.any(data, axis=(0, 1)).tolist() 77 | == ragged.any(data, axis=(1, 0)).tolist() 78 | == [True, True, True, True] 79 | ) 80 | assert ( 81 | ragged.any(data, axis=(0, 2)).tolist() 82 | == ragged.any(data, axis=(2, 0)).tolist() 83 | == [True, True, 
True] 84 | ) 85 | assert ( 86 | ragged.any(data, axis=(1, 2)).tolist() 87 | == ragged.any(data, axis=(2, 1)).tolist() 88 | == [True, False, True] 89 | ) 90 | assert ( 91 | ragged.any(data, axis=(0, 1, 2)).tolist() 92 | is ragged.any(data, axis=(-1, 0, 1)).tolist() 93 | is True 94 | ) 95 | -------------------------------------------------------------------------------- /tests/test_spec_version.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/version.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import ragged 10 | 11 | 12 | def test_values(): 13 | assert ragged.__array_api_version__ == "2022.12" 14 | -------------------------------------------------------------------------------- /tests/test_type_promotion.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE 2 | 3 | """ 4 | https://data-apis.org/array-api/latest/API_specification/type_promotion.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | 10 | def test(): 11 | pass 12 | --------------------------------------------------------------------------------