├── .github
    ├── dependabot.yml
    └── workflows
    │   ├── array-api-tests-dask.yml
    │   ├── array-api-tests-numpy-1-22.yml
    │   ├── array-api-tests-numpy-1-26.yml
    │   ├── array-api-tests-numpy-dev.yml
    │   ├── array-api-tests-numpy-latest.yml
    │   ├── array-api-tests-torch.yml
    │   ├── array-api-tests.yml
    │   ├── dependabot-auto-merge.yml
    │   ├── docs-build.yml
    │   ├── docs-deploy.yml
    │   ├── publish-package.yml
    │   ├── ruff.yml
    │   └── tests.yml
├── .gitignore
├── CHANGELOG.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── array_api_compat
    ├── __init__.py
    ├── _internal.py
    ├── common
    │   ├── __init__.py
    │   ├── _aliases.py
    │   ├── _fft.py
    │   ├── _helpers.py
    │   ├── _linalg.py
    │   └── _typing.py
    ├── cupy
    │   ├── __init__.py
    │   ├── _aliases.py
    │   ├── _info.py
    │   ├── _typing.py
    │   ├── fft.py
    │   └── linalg.py
    ├── dask
    │   ├── __init__.py
    │   └── array
    │   │   ├── __init__.py
    │   │   ├── _aliases.py
    │   │   ├── _info.py
    │   │   ├── fft.py
    │   │   └── linalg.py
    ├── numpy
    │   ├── __init__.py
    │   ├── _aliases.py
    │   ├── _info.py
    │   ├── _typing.py
    │   ├── fft.py
    │   └── linalg.py
    ├── py.typed
    └── torch
    │   ├── __init__.py
    │   ├── _aliases.py
    │   ├── _info.py
    │   ├── _typing.py
    │   ├── fft.py
    │   └── linalg.py
├── cupy-xfails.txt
├── dask-skips.txt
├── dask-xfails.txt
├── docs
    ├── Makefile
    ├── _static
    │   ├── custom.css
    │   └── favicon.png
    ├── changelog.md
    ├── conf.py
    ├── dev
    │   ├── implementation-notes.md
    │   ├── index.md
    │   ├── releasing.md
    │   ├── special-considerations.md
    │   └── tests.md
    ├── helper-functions.rst
    ├── index.md
    ├── make.bat
    └── supported-array-libraries.md
├── numpy-1-22-xfails.txt
├── numpy-1-26-xfails.txt
├── numpy-dev-xfails.txt
├── numpy-skips.txt
├── numpy-xfails.txt
├── pyproject.toml
├── test_cupy.sh
├── tests
    ├── __init__.py
    ├── _helpers.py
    ├── test_all.py
    ├── test_array_namespace.py
    ├── test_common.py
    ├── test_cupy.py
    ├── test_dask.py
    ├── test_isdtype.py
    ├── test_jax.py
    ├── test_no_dependencies.py
    ├── test_torch.py
    └── test_vendoring.py
├── torch-skips.txt
├── torch-xfails.txt
└── vendor_test
    ├── __init__.py
    ├── uses_cupy.py
    ├── uses_dask.py
    ├── uses_numpy.py
    ├── uses_torch.py
    └── vendored
        ├── __init__.py
        └── _compat
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 |   # Maintain dependencies for GitHub Actions
4 |   - package-ecosystem: "github-actions"
5 |     directory: "/"
6 |     schedule:
7 |       interval: "weekly"
8 |     groups:
9 |       actions:
10 |         patterns:
11 |           - "*"
12 |     labels:
13 |       - "github-actions"
14 |       - "dependencies"
15 |     reviewers:
16 |       - "asmeurer"
--------------------------------------------------------------------------------
/.github/workflows/array-api-tests-dask.yml:
--------------------------------------------------------------------------------
1 | name: Array API Tests (Dask)
2 | 
3 | on: [push, pull_request]
4 | 
5 | jobs:
6 |   array-api-tests-dask:
7 |     uses: ./.github/workflows/array-api-tests.yml
8 |     with:
9 |       package-name: dask
10 |       module-name: dask.array
11 |       extra-requires: numpy
12 |       # Dask is substantially slower than other libraries on unit tests.
13 |       # Reduce the number of examples to speed up CI, even though this means that this
14 |       # workflow is barely more than a smoke test, and one should expect extreme
15 |       # flakiness. Before changes to dask-xfails.txt or dask-skips.txt, please run
16 |       # the full test suite with at least 200 examples.
17 |       pytest-extra-args: --max-examples=200 -n 4
18 |       python-versions: '[''3.10'', ''3.13'']'
19 |       extra-env-vars: |
20 |         ARRAY_API_TESTS_XFAIL_MARK=skip
--------------------------------------------------------------------------------
/.github/workflows/array-api-tests-numpy-1-22.yml:
--------------------------------------------------------------------------------
1 | name: Array API Tests (NumPy 1.22)
2 | 
3 | on: [push, pull_request]
4 | 
5 | jobs:
6 |   array-api-tests-numpy-1-22:
7 |     uses: ./.github/workflows/array-api-tests.yml
8 |     with:
9 |       package-name: numpy
10 |       package-version: '== 1.22.*'
11 |       xfails-file-extra: '-1-22'
12 |       python-versions: '[''3.10'']'
13 |       pytest-extra-args: -n 4
14 |       extra-env-vars: |
15 |         ARRAY_API_TESTS_XFAIL_MARK=skip
16 | 
--------------------------------------------------------------------------------
/.github/workflows/array-api-tests-numpy-1-26.yml:
--------------------------------------------------------------------------------
1 | name: Array API Tests (NumPy 1.26)
2 | 
3 | on: [push, pull_request]
4 | 
5 | jobs:
6 |   array-api-tests-numpy-1-26:
7 |     uses: ./.github/workflows/array-api-tests.yml
8 |     with:
9 |       package-name: numpy
10 |       package-version: '== 1.26.*'
11 |       xfails-file-extra: '-1-26'
12 |       python-versions: '[''3.10'', ''3.12'']'
13 |       pytest-extra-args: -n 4
14 |       extra-env-vars: |
15 |         ARRAY_API_TESTS_XFAIL_MARK=skip
16 | 
--------------------------------------------------------------------------------
/.github/workflows/array-api-tests-numpy-dev.yml:
--------------------------------------------------------------------------------
1 | name: Array API Tests (NumPy dev)
2 | 
3 | on: [push, pull_request]
4 | 
5 | jobs:
6 |   array-api-tests-numpy-dev:
7 |     uses: ./.github/workflows/array-api-tests.yml
8 |     with:
9 |       package-name: numpy
10 |       extra-requires: '--pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple'
11 |       xfails-file-extra: '-dev'
12 |       python-versions: '[''3.11'', ''3.13'']'
13 |       pytest-extra-args: -n 4
14 |       extra-env-vars: |
15 |         ARRAY_API_TESTS_XFAIL_MARK=skip
16 | 
--------------------------------------------------------------------------------
/.github/workflows/array-api-tests-numpy-latest.yml:
--------------------------------------------------------------------------------
1 | name: Array API Tests (NumPy latest)
2 | 
3 | on: [push, pull_request]
4 | 
5 | jobs:
6 |   array-api-tests-numpy-latest:
7 |     uses: ./.github/workflows/array-api-tests.yml
8 |     with:
9 |       package-name: numpy
10 |       python-versions: '[''3.10'', ''3.13'']'
11 |       pytest-extra-args: -n 4
12 |       extra-env-vars: |
13 |         ARRAY_API_TESTS_XFAIL_MARK=skip
14 | 
--------------------------------------------------------------------------------
/.github/workflows/array-api-tests-torch.yml:
--------------------------------------------------------------------------------
1 | name: Array API Tests (PyTorch CPU)
2 | 
3 | on: [push, pull_request]
4 | 
5 | jobs:
6 |   array-api-tests-torch:
7 |     uses: ./.github/workflows/array-api-tests.yml
8 |     with:
9 |       package-name: torch
10 |       extra-requires: '--index-url https://download.pytorch.org/whl/cpu'
11 |       extra-env-vars: |
12 |         ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64
13 |         ARRAY_API_TESTS_XFAIL_MARK=skip
14 |       python-versions: '[''3.10'', ''3.13'']'
15 |       pytest-extra-args: -n 4
16 | 
--------------------------------------------------------------------------------
/.github/workflows/array-api-tests.yml:
--------------------------------------------------------------------------------
1 | name: Array API Tests
2 | 
3 | on:
4
| workflow_call: 5 | inputs: 6 | package-name: 7 | required: true 8 | type: string 9 | module-name: 10 | required: false 11 | type: string 12 | extra-requires: 13 | required: false 14 | type: string 15 | package-version: 16 | required: false 17 | type: string 18 | default: '>= 0' 19 | python-versions: 20 | required: true 21 | type: string 22 | description: JSON array of Python versions to test against. 23 | pytest-extra-args: 24 | required: false 25 | type: string 26 | # This is not how I would prefer to implement this but it's the only way 27 | # that seems possible with GitHub Actions' limited expressions syntax 28 | xfails-file-extra: 29 | required: false 30 | type: string 31 | skips-file-extra: 32 | required: false 33 | type: string 34 | extra-env-vars: 35 | required: false 36 | type: string 37 | description: Multiline string of environment variables to set for the test run. 38 | 39 | env: 40 | PYTEST_ARGS: "--max-examples 1000 -v -rxXfE ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 20" 41 | 42 | jobs: 43 | tests: 44 | runs-on: ubuntu-latest 45 | strategy: 46 | fail-fast: false 47 | matrix: 48 | python-version: ${{ fromJson(inputs.python-versions) }} 49 | 50 | steps: 51 | - name: Checkout array-api-compat 52 | uses: actions/checkout@v4 53 | with: 54 | path: array-api-compat 55 | 56 | - name: Checkout array-api-tests 57 | uses: actions/checkout@v4 58 | with: 59 | repository: data-apis/array-api-tests 60 | submodules: 'true' 61 | path: array-api-tests 62 | 63 | - name: Set up Python ${{ matrix.python-version }} 64 | uses: actions/setup-python@v5 65 | with: 66 | python-version: ${{ matrix.python-version }} 67 | 68 | - name: Set Extra Environment Variables 69 | # Set additional environment variables if provided 70 | if: inputs.extra-env-vars 71 | run: | 72 | echo "${{ inputs.extra-env-vars }}" >> $GITHUB_ENV 73 | 74 | - name: Install dependencies 75 | run: | 76 | python -m pip install --upgrade pip 77 | python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}' ${{ inputs.extra-requires }} 78 | python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt 79 | python -m pip install pytest-xdist 80 | 81 | - name: Dump pip environment 82 | run: pip freeze 83 | 84 | - name: Run the array API testsuite (${{ inputs.package-name }}) 85 | env: 86 | ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.module-name || inputs.package-name }} 87 | ARRAY_API_TESTS_VERSION: 2024.12 88 | # This enables the NEP 50 type promotion behavior (without it a lot of 89 | # tests fail on bad scalar type promotion behavior) 90 | NPY_PROMOTION_STATE: weak 91 | run: | 92 | export PYTHONPATH="${GITHUB_WORKSPACE}/array-api-compat" 93 | cd ${GITHUB_WORKSPACE}/array-api-tests 94 | pytest array_api_tests/ --xfails-file ${GITHUB_WORKSPACE}/array-api-compat/${{ inputs.package-name }}${{ inputs.xfails-file-extra }}-xfails.txt --skips-file ${GITHUB_WORKSPACE}/array-api-compat/${{ inputs.package-name }}${{ inputs.skips-file-extra}}-skips.txt ${PYTEST_ARGS} 95 | -------------------------------------------------------------------------------- /.github/workflows/dependabot-auto-merge.yml: -------------------------------------------------------------------------------- 1 | # https://docs.github.com/en/code-security/dependabot/working-with-dependabot/automating-dependabot-with-github-actions#approve-a-pull-request 2 | name: Dependabot auto-merge 3 | on: pull_request 4 | 5 | permissions: 6 | contents: write 7 | pull-requests: write 8 | 9 | jobs: 10 | dependabot: 11 | 
runs-on: ubuntu-latest
12 |     if: github.actor == 'dependabot[bot]'
13 |     steps:
14 |       - name: Dependabot metadata
15 |         id: metadata
16 |         uses: dependabot/fetch-metadata@v2
17 |         with:
18 |           github-token: "${{ secrets.GITHUB_TOKEN }}"
19 |       - name: Enable auto-merge for Dependabot PRs
20 |         run: gh pr merge --auto --merge "$PR_URL"
21 |         env:
22 |           PR_URL: ${{github.event.pull_request.html_url}}
23 |           GH_TOKEN: ${{secrets.GITHUB_TOKEN}}
24 | 
--------------------------------------------------------------------------------
/.github/workflows/docs-build.yml:
--------------------------------------------------------------------------------
1 | name: Docs Build
2 | 
3 | on: [push, pull_request]
4 | 
5 | jobs:
6 |   docs-build:
7 |     runs-on: ubuntu-latest
8 |     steps:
9 |       - uses: actions/checkout@v4
10 |       - uses: actions/setup-python@v5
11 |       - name: Install Dependencies
12 |         run: |
13 |           python -m pip install .[docs]
14 |       - name: Build Docs
15 |         run: |
16 |           cd docs
17 |           make html
18 |       - name: Upload Artifact
19 |         uses: actions/upload-artifact@v4
20 |         with:
21 |           name: docs-build
22 |           path: docs/_build/html
23 | 
--------------------------------------------------------------------------------
/.github/workflows/docs-deploy.yml:
--------------------------------------------------------------------------------
1 | name: Docs Deploy
2 | 
3 | on:
4 |   push:
5 |     branches:
6 |       - main
7 | 
8 | jobs:
9 |   docs-deploy:
10 |     runs-on: ubuntu-latest
11 |     environment:
12 |       name: docs-deploy
13 |     steps:
14 |       - uses: actions/checkout@v4
15 |       - name: Download Artifact
16 |         uses: dawidd6/action-download-artifact@v10
17 |         with:
18 |           workflow: docs-build.yml
19 |           name: docs-build
20 |           path: docs/_build/html
21 | 
22 |       # Note, the gh-pages deployment requires setting up an SSH deploy key.
23 |       # See
24 |       # https://github.com/JamesIves/github-pages-deploy-action/tree/dev#using-an-ssh-deploy-key-
25 |       - name: Deploy
26 |         uses: JamesIves/github-pages-deploy-action@v4
27 |         with:
28 |           folder: docs/_build/html
29 |           ssh-key: ${{ secrets.DEPLOY_KEY }}
30 |           force: no
31 | 
--------------------------------------------------------------------------------
/.github/workflows/publish-package.yml:
--------------------------------------------------------------------------------
1 | name: publish distributions
2 | on:
3 |   push:
4 |     branches:
5 |       - main
6 |     tags:
7 |       - '[0-9]+.[0-9]+'
8 |       - '[0-9]+.[0-9]+.[0-9]+'
9 |   pull_request:
10 |     branches:
11 |       - main
12 |   release:
13 |     types: [published]
14 |   workflow_dispatch:
15 |     inputs:
16 |       publish:
17 |         type: choice
18 |         description: 'Publish to TestPyPI?'
19 |         options:
20 |           - false
21 |           - true
22 | 
23 | concurrency:
24 |   group: ${{ github.workflow }}-${{ github.ref }}
25 |   cancel-in-progress: true
26 | 
27 | jobs:
28 |   build:
29 |     name: Build Python distribution
30 |     runs-on: ubuntu-latest
31 | 
32 |     steps:
33 |       - uses: actions/checkout@v4
34 |         with:
35 |           fetch-depth: 0
36 | 
37 |       - name: Set up Python
38 |         uses: actions/setup-python@v5
39 |         with:
40 |           python-version: '3.x'
41 | 
42 |       - name: Install python-build and twine
43 |         run: |
44 |           python -m pip install --upgrade pip "setuptools<=67"
45 |           python -m pip install build twine
46 |           python -m pip list
47 | 
48 |       - name: Build a wheel and a sdist
49 |         run: |
50 |           #PYTHONWARNINGS=error,default::DeprecationWarning python -m build .
51 |           python -m build .
52 | 
53 |       - name: Verify the distribution
54 |         run: twine check --strict dist/*
55 | 
56 |       - name: List contents of sdist
57 |         run: python -m tarfile --list dist/array_api_compat-*.tar.gz
58 | 
59 |       - name: List contents of wheel
60 |         run: python -m zipfile --list dist/array_api_compat-*.whl
61 | 
62 |       - name: Upload distribution artifact
63 |         uses: actions/upload-artifact@v4
64 |         with:
65 |           name: dist-artifact
66 |           path: dist
67 | 
68 |   publish:
69 |     name: Publish Python distribution to (Test)PyPI
70 |     if: github.event_name != 'pull_request' && github.repository == 'data-apis/array-api-compat' && github.ref_type == 'tag'
71 |     needs: build
72 |     runs-on: ubuntu-latest
73 |     # Mandatory for publishing with a trusted publisher
74 |     # c.f. https://docs.pypi.org/trusted-publishers/using-a-publisher/
75 |     permissions:
76 |       id-token: write
77 |       contents: write
78 |     # Restrict to the environment set for the trusted publisher
79 |     environment:
80 |       name: publish-package
81 | 
82 |     steps:
83 |       - name: Download distribution artifact
84 |         uses: actions/download-artifact@v4
85 |         with:
86 |           name: dist-artifact
87 |           path: dist
88 | 
89 |       - name: List all files
90 |         run: ls -lh dist
91 | 
92 |       # - name: Publish distribution 📦 to Test PyPI
93 |       #   # Publish to TestPyPI on tag events or if manually triggered
94 |       #   # Compare to 'true' string as booleans get turned into strings in the console
95 |       #   if: >-
96 |       #     (github.event_name == 'push' && startsWith(github.ref, 'refs/tags'))
97 |       #     || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true')
98 |       #   uses: pypa/gh-action-pypi-publish@v1.12.4
99 |       #   with:
100 |       #     repository-url: https://test.pypi.org/legacy/
101 |       #     print-hash: true
102 | 
103 |       - name: Publish distribution 📦 to PyPI
104 |         if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
105 |         uses: pypa/gh-action-pypi-publish@v1.12.4
106 |         with:
107 |           print-hash: true
108 | 
109 |       - name: Create GitHub Release from a Tag
110 |         uses: softprops/action-gh-release@v2
111 |         if: startsWith(github.ref, 'refs/tags/')
112 |         with:
113 |           files: dist/*
--------------------------------------------------------------------------------
/.github/workflows/ruff.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 | on: [push, pull_request]
3 | jobs:
4 |   check-ruff:
5 |     runs-on: ubuntu-latest
6 |     continue-on-error: true
7 |     steps:
8 |       - uses: actions/checkout@v4
9 |       - name: Install Python
10 |         uses: actions/setup-python@v5
11 |         with:
12 |           python-version: "3.11"
13 |       - name: Install dependencies
14 |         run: |
15 |           python -m pip install --upgrade pip
16 |           pip install ruff
17 |       # Update output format to enable automatic inline annotations.
18 |       - name: Run Ruff
19 |         run: ruff check --output-format=github .
20 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | on: [push, pull_request] 3 | jobs: 4 | tests: 5 | runs-on: ubuntu-latest 6 | strategy: 7 | fail-fast: false 8 | matrix: 9 | include: 10 | - numpy-version: '1.22' 11 | python-version: '3.10' 12 | - numpy-version: '1.26' 13 | python-version: '3.10' 14 | - numpy-version: '1.26' 15 | python-version: '3.12' 16 | - numpy-version: 'latest' 17 | python-version: '3.10' 18 | - numpy-version: 'latest' 19 | python-version: '3.13' 20 | - numpy-version: 'dev' 21 | python-version: '3.11' 22 | - numpy-version: 'dev' 23 | python-version: '3.13' 24 | 25 | steps: 26 | - uses: actions/checkout@v4 27 | - uses: actions/setup-python@v5 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | - name: Install Dependencies 31 | run: | 32 | python -m pip install --upgrade pip 33 | python -m pip install pytest 34 | 35 | # Don't `pip install .[dev]` as it would pull in the whole torch cuda stack 36 | python -m pip install array-api-strict 37 | python -m pip install torch --index-url https://download.pytorch.org/whl/cpu 38 | 39 | if [ "${{ matrix.numpy-version }}" == "dev" ]; then 40 | python -m pip install numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple 41 | python -m pip install dask[array] jax[cpu] sparse ndonnx 42 | elif [ "${{ matrix.numpy-version }}" == "1.22" ]; then 43 | python -m pip install 'numpy==1.22.*' 44 | elif [ "${{ matrix.numpy-version }}" == "1.26" ]; then 45 | python -m pip install 'numpy==1.26.*' 46 | else 47 | python -m pip install numpy 48 | python -m pip install dask[array] jax[cpu] sparse ndonnx 49 | fi 50 | 51 | - name: Dump pip environment 52 | run: pip freeze 53 | 54 | - name: Test it installs 55 | run: python -m pip install . 56 | 57 | - name: Run Tests 58 | run: pytest -v 59 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
131 | # macOS specific files
132 | .DS_Store
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | docs/changelog.md
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | Contributions to array-api-compat are welcome, so long as they are [in
2 | scope](https://data-apis.org/array-api-compat/index.html#scope).
3 | 
4 | Contributors are encouraged to read through the [development
5 | notes](https://data-apis.org/array-api-compat/dev/index.html) for the package
6 | to get full context on some of the design decisions and implementation
7 | details used in the codebase.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 Consortium for Python Data API Standards
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Array API compatibility library 2 | 3 | This is a small wrapper around common array libraries that is compatible with 4 | the [Array API standard](https://data-apis.org/array-api/latest/). Currently, 5 | NumPy, CuPy, PyTorch, Dask, JAX, ndonnx and `sparse` are supported. If you want 6 | support for other array libraries, or if you encounter any issues, please [open 7 | an issue](https://github.com/data-apis/array-api-compat/issues). 8 | 9 | See the documentation for more details https://data-apis.org/array-api-compat/ 10 | -------------------------------------------------------------------------------- /array_api_compat/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | NumPy Array API compatibility library 3 | 4 | This is a small wrapper around NumPy, CuPy, JAX, sparse and others that are 5 | compatible with the Array API standard https://data-apis.org/array-api/latest/. 6 | See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html. 7 | 8 | Unlike array_api_strict, this is not a strict minimal implementation of the 9 | Array API, but rather just an extension of the main NumPy namespace with 10 | changes needed to be compliant with the Array API. See 11 | https://numpy.org/doc/stable/reference/array_api.html for a full list of 12 | changes. In particular, unlike array_api_strict, this package does not use a 13 | separate Array object, but rather just uses numpy.ndarray directly. 14 | 15 | Library authors using the Array API may wish to test against array_api_strict 16 | to ensure they are not using functionality outside of the standard, but prefer 17 | this implementation for the default when working with NumPy arrays. 18 | 19 | """ 20 | __version__ = '1.13.0.dev0' 21 | 22 | from .common import * # noqa: F401, F403 23 | -------------------------------------------------------------------------------- /array_api_compat/_internal.py: -------------------------------------------------------------------------------- 1 | """ 2 | Internal helpers 3 | """ 4 | 5 | import importlib 6 | from collections.abc import Callable 7 | from functools import wraps 8 | from inspect import signature 9 | from types import ModuleType 10 | from typing import TypeVar 11 | 12 | _T = TypeVar("_T") 13 | 14 | 15 | def get_xp(xp: ModuleType) -> Callable[[Callable[..., _T]], Callable[..., _T]]: 16 | """ 17 | Decorator to automatically replace xp with the corresponding array module. 18 | 19 | Use like 20 | 21 | import numpy as np 22 | 23 | @get_xp(np) 24 | def func(x, /, xp, kwarg=None): 25 | return xp.func(x, kwarg=kwarg) 26 | 27 | Note that xp must be a keyword argument and come after all non-keyword 28 | arguments. 
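    For example (illustrative), ``get_xp(np)(func)`` removes ``xp`` from the
    public signature and always supplies ``xp=np`` internally:

        wrapped = get_xp(np)(func)
        wrapped(x)   # same as func(x, xp=np)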
29 | 30 | """ 31 | 32 | def inner(f: Callable[..., _T], /) -> Callable[..., _T]: 33 | @wraps(f) 34 | def wrapped_f(*args: object, **kwargs: object) -> object: 35 | return f(*args, xp=xp, **kwargs) 36 | 37 | sig = signature(f) 38 | new_sig = sig.replace( 39 | parameters=[par for i, par in sig.parameters.items() if i != "xp"] 40 | ) 41 | 42 | if wrapped_f.__doc__ is None: 43 | wrapped_f.__doc__ = f"""\ 44 | Array API compatibility wrapper for {f.__name__}. 45 | 46 | See the corresponding documentation in NumPy/CuPy and/or the array API 47 | specification for more details. 48 | 49 | """ 50 | wrapped_f.__signature__ = new_sig # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] 51 | return wrapped_f # type: ignore[return-value] # pyright: ignore[reportReturnType] 52 | 53 | return inner 54 | 55 | 56 | def clone_module(mod_name: str, globals_: dict[str, object]) -> list[str]: 57 | """Import everything from module, updating globals(). 58 | Returns __all__. 59 | """ 60 | mod = importlib.import_module(mod_name) 61 | # Neither of these two methods is sufficient by itself, 62 | # depending on various idiosyncrasies of the libraries we're wrapping. 63 | objs = {} 64 | exec(f"from {mod.__name__} import *", objs) 65 | 66 | for n in dir(mod): 67 | if not n.startswith("_") and hasattr(mod, n): 68 | objs[n] = getattr(mod, n) 69 | 70 | globals_.update(objs) 71 | return list(objs) 72 | 73 | 74 | __all__ = ["get_xp", "clone_module"] 75 | 76 | def __dir__() -> list[str]: 77 | return __all__ 78 | -------------------------------------------------------------------------------- /array_api_compat/common/__init__.py: -------------------------------------------------------------------------------- 1 | from ._helpers import * # noqa: F403 2 | -------------------------------------------------------------------------------- /array_api_compat/common/_fft.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Sequence 4 | from typing import Literal, TypeAlias 5 | 6 | from ._typing import Array, Device, DType, Namespace 7 | 8 | _Norm: TypeAlias = Literal["backward", "ortho", "forward"] 9 | 10 | # Note: NumPy fft functions improperly upcast float32 and complex64 to 11 | # complex128, which is why we require wrapping them all here. 
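# A sketch of the behavior being worked around (checked against NumPy 1.x;
# NumPy 2.x largely preserves single precision upstream):
#
#     >>> import numpy as np
#     >>> np.fft.fft(np.ones(4, dtype=np.float32)).dtype
#     dtype('complex128')  # the standard requires complex64 here
#
# Hence each wrapper below calls the underlying function and casts the result
# back to the precision-matched dtype (complex64/float32) when the input was
# single precision.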
12 | 13 | def fft( 14 | x: Array, 15 | /, 16 | xp: Namespace, 17 | *, 18 | n: int | None = None, 19 | axis: int = -1, 20 | norm: _Norm = "backward", 21 | ) -> Array: 22 | res = xp.fft.fft(x, n=n, axis=axis, norm=norm) 23 | if x.dtype in [xp.float32, xp.complex64]: 24 | return res.astype(xp.complex64) 25 | return res 26 | 27 | def ifft( 28 | x: Array, 29 | /, 30 | xp: Namespace, 31 | *, 32 | n: int | None = None, 33 | axis: int = -1, 34 | norm: _Norm = "backward", 35 | ) -> Array: 36 | res = xp.fft.ifft(x, n=n, axis=axis, norm=norm) 37 | if x.dtype in [xp.float32, xp.complex64]: 38 | return res.astype(xp.complex64) 39 | return res 40 | 41 | def fftn( 42 | x: Array, 43 | /, 44 | xp: Namespace, 45 | *, 46 | s: Sequence[int] | None = None, 47 | axes: Sequence[int] | None = None, 48 | norm: _Norm = "backward", 49 | ) -> Array: 50 | res = xp.fft.fftn(x, s=s, axes=axes, norm=norm) 51 | if x.dtype in [xp.float32, xp.complex64]: 52 | return res.astype(xp.complex64) 53 | return res 54 | 55 | def ifftn( 56 | x: Array, 57 | /, 58 | xp: Namespace, 59 | *, 60 | s: Sequence[int] | None = None, 61 | axes: Sequence[int] | None = None, 62 | norm: _Norm = "backward", 63 | ) -> Array: 64 | res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm) 65 | if x.dtype in [xp.float32, xp.complex64]: 66 | return res.astype(xp.complex64) 67 | return res 68 | 69 | def rfft( 70 | x: Array, 71 | /, 72 | xp: Namespace, 73 | *, 74 | n: int | None = None, 75 | axis: int = -1, 76 | norm: _Norm = "backward", 77 | ) -> Array: 78 | res = xp.fft.rfft(x, n=n, axis=axis, norm=norm) 79 | if x.dtype == xp.float32: 80 | return res.astype(xp.complex64) 81 | return res 82 | 83 | def irfft( 84 | x: Array, 85 | /, 86 | xp: Namespace, 87 | *, 88 | n: int | None = None, 89 | axis: int = -1, 90 | norm: _Norm = "backward", 91 | ) -> Array: 92 | res = xp.fft.irfft(x, n=n, axis=axis, norm=norm) 93 | if x.dtype == xp.complex64: 94 | return res.astype(xp.float32) 95 | return res 96 | 97 | def rfftn( 98 | x: Array, 99 | /, 100 | xp: Namespace, 101 | *, 102 | s: Sequence[int] | None = None, 103 | axes: Sequence[int] | None = None, 104 | norm: _Norm = "backward", 105 | ) -> Array: 106 | res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm) 107 | if x.dtype == xp.float32: 108 | return res.astype(xp.complex64) 109 | return res 110 | 111 | def irfftn( 112 | x: Array, 113 | /, 114 | xp: Namespace, 115 | *, 116 | s: Sequence[int] | None = None, 117 | axes: Sequence[int] | None = None, 118 | norm: _Norm = "backward", 119 | ) -> Array: 120 | res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm) 121 | if x.dtype == xp.complex64: 122 | return res.astype(xp.float32) 123 | return res 124 | 125 | def hfft( 126 | x: Array, 127 | /, 128 | xp: Namespace, 129 | *, 130 | n: int | None = None, 131 | axis: int = -1, 132 | norm: _Norm = "backward", 133 | ) -> Array: 134 | res = xp.fft.hfft(x, n=n, axis=axis, norm=norm) 135 | if x.dtype in [xp.float32, xp.complex64]: 136 | return res.astype(xp.float32) 137 | return res 138 | 139 | def ihfft( 140 | x: Array, 141 | /, 142 | xp: Namespace, 143 | *, 144 | n: int | None = None, 145 | axis: int = -1, 146 | norm: _Norm = "backward", 147 | ) -> Array: 148 | res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm) 149 | if x.dtype in [xp.float32, xp.complex64]: 150 | return res.astype(xp.complex64) 151 | return res 152 | 153 | def fftfreq( 154 | n: int, 155 | /, 156 | xp: Namespace, 157 | *, 158 | d: float = 1.0, 159 | dtype: DType | None = None, 160 | device: Device | None = None, 161 | ) -> Array: 162 | if device not in ["cpu", None]: 163 | raise 
ValueError(f"Unsupported device {device!r}") 164 | res = xp.fft.fftfreq(n, d=d) 165 | if dtype is not None: 166 | return res.astype(dtype) 167 | return res 168 | 169 | def rfftfreq( 170 | n: int, 171 | /, 172 | xp: Namespace, 173 | *, 174 | d: float = 1.0, 175 | dtype: DType | None = None, 176 | device: Device | None = None, 177 | ) -> Array: 178 | if device not in ["cpu", None]: 179 | raise ValueError(f"Unsupported device {device!r}") 180 | res = xp.fft.rfftfreq(n, d=d) 181 | if dtype is not None: 182 | return res.astype(dtype) 183 | return res 184 | 185 | def fftshift( 186 | x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None 187 | ) -> Array: 188 | return xp.fft.fftshift(x, axes=axes) 189 | 190 | def ifftshift( 191 | x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None 192 | ) -> Array: 193 | return xp.fft.ifftshift(x, axes=axes) 194 | 195 | __all__ = [ 196 | "fft", 197 | "ifft", 198 | "fftn", 199 | "ifftn", 200 | "rfft", 201 | "irfft", 202 | "rfftn", 203 | "irfftn", 204 | "hfft", 205 | "ihfft", 206 | "fftfreq", 207 | "rfftfreq", 208 | "fftshift", 209 | "ifftshift", 210 | ] 211 | 212 | def __dir__() -> list[str]: 213 | return __all__ 214 | -------------------------------------------------------------------------------- /array_api_compat/common/_linalg.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | from typing import Literal, NamedTuple, cast 5 | 6 | import numpy as np 7 | 8 | if np.__version__[0] == "2": 9 | from numpy.lib.array_utils import normalize_axis_tuple 10 | else: 11 | from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef] 12 | 13 | from .._internal import get_xp 14 | from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot 15 | from ._typing import Array, DType, JustFloat, JustInt, Namespace 16 | 17 | 18 | # These are in the main NumPy namespace but not in numpy.linalg 19 | def cross( 20 | x1: Array, 21 | x2: Array, 22 | /, 23 | xp: Namespace, 24 | *, 25 | axis: int = -1, 26 | **kwargs: object, 27 | ) -> Array: 28 | return xp.cross(x1, x2, axis=axis, **kwargs) 29 | 30 | def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array: 31 | return xp.outer(x1, x2, **kwargs) 32 | 33 | class EighResult(NamedTuple): 34 | eigenvalues: Array 35 | eigenvectors: Array 36 | 37 | class QRResult(NamedTuple): 38 | Q: Array 39 | R: Array 40 | 41 | class SlogdetResult(NamedTuple): 42 | sign: Array 43 | logabsdet: Array 44 | 45 | class SVDResult(NamedTuple): 46 | U: Array 47 | S: Array 48 | Vh: Array 49 | 50 | # These functions are the same as their NumPy counterparts except they return 51 | # a namedtuple. 
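# For example (illustrative):
#
#     >>> res = eigh(x, xp)
#     >>> res.eigenvalues, res.eigenvectors   # named field access...
#     >>> w, v = res                          # ...while still unpacking like a tuple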
52 | def eigh(x: Array, /, xp: Namespace, **kwargs: object) -> EighResult: 53 | return EighResult(*xp.linalg.eigh(x, **kwargs)) 54 | 55 | def qr( 56 | x: Array, 57 | /, 58 | xp: Namespace, 59 | *, 60 | mode: Literal["reduced", "complete"] = "reduced", 61 | **kwargs: object, 62 | ) -> QRResult: 63 | return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs)) 64 | 65 | def slogdet(x: Array, /, xp: Namespace, **kwargs: object) -> SlogdetResult: 66 | return SlogdetResult(*xp.linalg.slogdet(x, **kwargs)) 67 | 68 | def svd( 69 | x: Array, 70 | /, 71 | xp: Namespace, 72 | *, 73 | full_matrices: bool = True, 74 | **kwargs: object, 75 | ) -> SVDResult: 76 | return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs)) 77 | 78 | # These functions have additional keyword arguments 79 | 80 | # The upper keyword argument is new from NumPy 81 | def cholesky( 82 | x: Array, 83 | /, 84 | xp: Namespace, 85 | *, 86 | upper: bool = False, 87 | **kwargs: object, 88 | ) -> Array: 89 | L = xp.linalg.cholesky(x, **kwargs) 90 | if upper: 91 | U = get_xp(xp)(matrix_transpose)(L) 92 | if get_xp(xp)(isdtype)(U.dtype, 'complex floating'): 93 | U = xp.conj(U) # pyright: ignore[reportConstantRedefinition] 94 | return U 95 | return L 96 | 97 | # The rtol keyword argument of matrix_rank() and pinv() is new from NumPy. 98 | # Note that it has a different semantic meaning from tol and rcond. 99 | def matrix_rank( 100 | x: Array, 101 | /, 102 | xp: Namespace, 103 | *, 104 | rtol: float | Array | None = None, 105 | **kwargs: object, 106 | ) -> Array: 107 | # this is different from xp.linalg.matrix_rank, which supports 1 108 | # dimensional arrays. 109 | if x.ndim < 2: 110 | raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional") 111 | S: Array = get_xp(xp)(svdvals)(x, **kwargs) 112 | if rtol is None: 113 | tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps 114 | else: 115 | # this is different from xp.linalg.matrix_rank, which does not 116 | # multiply the tolerance by the largest singular value. 117 | tol = S.max(axis=-1, keepdims=True)*xp.asarray(rtol)[..., xp.newaxis] 118 | return xp.count_nonzero(S > tol, axis=-1) 119 | 120 | def pinv( 121 | x: Array, 122 | /, 123 | xp: Namespace, 124 | *, 125 | rtol: float | Array | None = None, 126 | **kwargs: object, 127 | ) -> Array: 128 | # this is different from xp.linalg.pinv, which does not multiply the 129 | # default tolerance by max(M, N). 130 | if rtol is None: 131 | rtol = max(x.shape[-2:]) * xp.finfo(x.dtype).eps 132 | return xp.linalg.pinv(x, rcond=rtol, **kwargs) 133 | 134 | # These functions are new in the array API spec 135 | 136 | def matrix_norm( 137 | x: Array, 138 | /, 139 | xp: Namespace, 140 | *, 141 | keepdims: bool = False, 142 | ord: Literal[1, 2, -1, -2] | JustFloat | Literal["fro", "nuc"] | None = "fro", 143 | ) -> Array: 144 | return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord) 145 | 146 | # svdvals is not in NumPy (but it is in SciPy). It is equivalent to 147 | # xp.linalg.svd(compute_uv=False). 148 | def svdvals(x: Array, /, xp: Namespace) -> Array | tuple[Array, ...]: 149 | return xp.linalg.svd(x, compute_uv=False) 150 | 151 | def vector_norm( 152 | x: Array, 153 | /, 154 | xp: Namespace, 155 | *, 156 | axis: int | tuple[int, ...] 
| None = None,
157 |     keepdims: bool = False,
158 |     ord: JustInt | JustFloat = 2,
159 | ) -> Array:
160 |     # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
161 |     # when axis=None and the input is 2-D, so to force a vector norm, we make
162 |     # it so the input is 1-D (for axis=None), or reshape so that norm is done
163 |     # on a single dimension.
164 |     if axis is None:
165 |         # Note: xp.linalg.norm() doesn't handle 0-D arrays
166 |         _x = x.ravel()
167 |         _axis = 0
168 |     elif isinstance(axis, tuple):
169 |         # Note: The axis argument supports any number of axes, whereas
170 |         # xp.linalg.norm() only supports a single axis for vector norm.
171 |         normalized_axis = cast(
172 |             "tuple[int, ...]",
173 |             normalize_axis_tuple(axis, x.ndim),  # pyright: ignore[reportCallIssue]
174 |         )
175 |         rest = tuple(i for i in range(x.ndim) if i not in normalized_axis)
176 |         newshape = axis + rest
177 |         _x = xp.transpose(x, newshape).reshape(
178 |             (math.prod([x.shape[i] for i in axis]), *[x.shape[i] for i in rest]))
179 |         _axis = 0
180 |     else:
181 |         _x = x
182 |         _axis = axis
183 | 
184 |     res = xp.linalg.norm(_x, axis=_axis, ord=ord)
185 | 
186 |     if keepdims:
187 |         # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
188 |         # above to avoid matrix norm logic.
189 |         shape = list(x.shape)
190 |         axes = cast(
191 |             "tuple[int, ...]",
192 |             normalize_axis_tuple(  # pyright: ignore[reportCallIssue]
193 |                 range(x.ndim) if axis is None else axis,
194 |                 x.ndim,
195 |             ),
196 |         )
197 |         for i in axes:
198 |             shape[i] = 1
199 |         res = xp.reshape(res, tuple(shape))
200 | 
201 |     return res
202 | 
203 | # xp.diagonal and xp.trace operate on the first two axes whereas these
204 | # operate on the last two
205 | 
206 | def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs: object) -> Array:
207 |     return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)
208 | 
209 | def trace(
210 |     x: Array,
211 |     /,
212 |     xp: Namespace,
213 |     *,
214 |     offset: int = 0,
215 |     dtype: DType | None = None,
216 |     **kwargs: object,
217 | ) -> Array:
218 |     return xp.asarray(
219 |         xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)
220 |     )
221 | 
222 | __all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
223 |            'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet',
224 |            'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm',
225 |            'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
226 |            'trace']
227 | 
228 | 
229 | def __dir__() -> list[str]:
230 |     return __all__
231 | 
--------------------------------------------------------------------------------
/array_api_compat/common/_typing.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | from collections.abc import Mapping
4 | from types import ModuleType as Namespace
5 | from typing import (
6 |     TYPE_CHECKING,
7 |     Literal,
8 |     Protocol,
9 |     TypeAlias,
10 |     TypedDict,
11 |     TypeVar,
12 |     final,
13 | )
14 | 
15 | if TYPE_CHECKING:
16 |     from _typeshed import Incomplete
17 | 
18 |     SupportsBufferProtocol: TypeAlias = Incomplete
19 |     Array: TypeAlias = Incomplete
20 |     Device: TypeAlias = Incomplete
21 |     DType: TypeAlias = Incomplete
22 | else:
23 |     SupportsBufferProtocol = object
24 |     Array = object
25 |     Device = object
26 |     DType = object
27 | 
28 | 
29 | _T_co = TypeVar("_T_co", covariant=True)
30 | 
31 | 
32 | # These "Just" types are equivalent to the `Just` type from the `optype` library,
33 | # apart from them not being `@runtime_checkable`.
34 | # - docs: https://github.com/jorenham/optype/blob/master/README.md#just
35 | # - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py
36 | @final
37 | class JustInt(Protocol):  # type: ignore[misc]
38 |     @property  # type: ignore[override]
39 |     def __class__(self, /) -> type[int]: ...
40 |     @__class__.setter
41 |     def __class__(self, value: type[int], /) -> None: ...  # pyright: ignore[reportIncompatibleMethodOverride]
42 | 
43 | 
44 | @final
45 | class JustFloat(Protocol):  # type: ignore[misc]
46 |     @property  # type: ignore[override]
47 |     def __class__(self, /) -> type[float]: ...
48 |     @__class__.setter
49 |     def __class__(self, value: type[float], /) -> None: ...  # pyright: ignore[reportIncompatibleMethodOverride]
50 | 
51 | 
52 | @final
53 | class JustComplex(Protocol):  # type: ignore[misc]
54 |     @property  # type: ignore[override]
55 |     def __class__(self, /) -> type[complex]: ...
56 |     @__class__.setter
57 |     def __class__(self, value: type[complex], /) -> None: ...  # pyright: ignore[reportIncompatibleMethodOverride]
58 | 
59 | 
60 | class NestedSequence(Protocol[_T_co]):
61 |     def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
62 |     def __len__(self, /) -> int: ...
63 | 
64 | 
65 | class SupportsArrayNamespace(Protocol[_T_co]):
66 |     def __array_namespace__(self, /, *, api_version: str | None) -> _T_co: ...
67 | 
68 | 
69 | class HasShape(Protocol[_T_co]):
70 |     @property
71 |     def shape(self, /) -> _T_co: ...
72 | 
73 | 
74 | # Return type of `__array_namespace_info__.capabilities`
75 | Capabilities = TypedDict(
76 |     "Capabilities",
77 |     {
78 |         "boolean indexing": bool,
79 |         "data-dependent shapes": bool,
80 |         "max dimensions": int,
81 |     },
82 | )
83 | 
84 | # Return type of `__array_namespace_info__.default_dtypes`
85 | DefaultDTypes = TypedDict(
86 |     "DefaultDTypes",
87 |     {
88 |         "real floating": DType,
89 |         "complex floating": DType,
90 |         "integral": DType,
91 |         "indexing": DType,
92 |     },
93 | )
94 | 
95 | 
96 | _DTypeKind: TypeAlias = Literal[
97 |     "bool",
98 |     "signed integer",
99 |     "unsigned integer",
100 |     "integral",
101 |     "real floating",
102 |     "complex floating",
103 |     "numeric",
104 | ]
105 | # Type of the `kind` parameter in `__array_namespace_info__.dtypes`
106 | DTypeKind: TypeAlias = _DTypeKind | tuple[_DTypeKind, ...]
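# For example (illustrative, using the ``dtypes`` method of an
# ``__array_namespace_info__()`` instance):
#
#     info.dtypes(kind="bool")                         # -> DTypesBool
#     info.dtypes(kind=("integral", "real floating"))  # -> union of both kinds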
107 | 108 | 109 | # `__array_namespace_info__.dtypes(kind="bool")` 110 | class DTypesBool(TypedDict): 111 | bool: DType 112 | 113 | 114 | # `__array_namespace_info__.dtypes(kind="signed integer")` 115 | class DTypesSigned(TypedDict): 116 | int8: DType 117 | int16: DType 118 | int32: DType 119 | int64: DType 120 | 121 | 122 | # `__array_namespace_info__.dtypes(kind="unsigned integer")` 123 | class DTypesUnsigned(TypedDict): 124 | uint8: DType 125 | uint16: DType 126 | uint32: DType 127 | uint64: DType 128 | 129 | 130 | # `__array_namespace_info__.dtypes(kind="integral")` 131 | class DTypesIntegral(DTypesSigned, DTypesUnsigned): 132 | pass 133 | 134 | 135 | # `__array_namespace_info__.dtypes(kind="real floating")` 136 | class DTypesReal(TypedDict): 137 | float32: DType 138 | float64: DType 139 | 140 | 141 | # `__array_namespace_info__.dtypes(kind="complex floating")` 142 | class DTypesComplex(TypedDict): 143 | complex64: DType 144 | complex128: DType 145 | 146 | 147 | # `__array_namespace_info__.dtypes(kind="numeric")` 148 | class DTypesNumeric(DTypesIntegral, DTypesReal, DTypesComplex): 149 | pass 150 | 151 | 152 | # `__array_namespace_info__.dtypes(kind=None)` (default) 153 | class DTypesAll(DTypesBool, DTypesNumeric): 154 | pass 155 | 156 | 157 | # `__array_namespace_info__.dtypes(kind=?)` (fallback) 158 | DTypesAny: TypeAlias = Mapping[str, DType] 159 | 160 | 161 | __all__ = [ 162 | "Array", 163 | "Capabilities", 164 | "DType", 165 | "DTypeKind", 166 | "DTypesAny", 167 | "DTypesAll", 168 | "DTypesBool", 169 | "DTypesNumeric", 170 | "DTypesIntegral", 171 | "DTypesSigned", 172 | "DTypesUnsigned", 173 | "DTypesReal", 174 | "DTypesComplex", 175 | "DefaultDTypes", 176 | "Device", 177 | "HasShape", 178 | "Namespace", 179 | "JustInt", 180 | "JustFloat", 181 | "JustComplex", 182 | "NestedSequence", 183 | "SupportsArrayNamespace", 184 | "SupportsBufferProtocol", 185 | ] 186 | 187 | 188 | def __dir__() -> list[str]: 189 | return __all__ 190 | -------------------------------------------------------------------------------- /array_api_compat/cupy/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Final 2 | from cupy import * # noqa: F403 3 | 4 | # from cupy import * doesn't overwrite these builtin names 5 | from cupy import abs, max, min, round # noqa: F401 6 | 7 | # These imports may overwrite names from the import * above. 
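# (For instance, the plain ``asarray`` pulled in from cupy is replaced by the
# wrapped ``asarray`` defined in ``._aliases``.)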
8 | from ._aliases import * # noqa: F403 9 | from ._info import __array_namespace_info__ # noqa: F401 10 | 11 | # See the comment in the numpy __init__.py 12 | __import__(__package__ + '.linalg') 13 | __import__(__package__ + '.fft') 14 | 15 | __array_api_version__: Final = '2024.12' 16 | 17 | __all__ = sorted( 18 | {name for name in globals() if not name.startswith("__")} 19 | - {"Final", "_aliases", "_info", "_typing"} 20 | | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} 21 | ) 22 | 23 | def __dir__() -> list[str]: 24 | return __all__ 25 | -------------------------------------------------------------------------------- /array_api_compat/cupy/_aliases.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from builtins import bool as py_bool 4 | 5 | import cupy as cp 6 | 7 | from ..common import _aliases, _helpers 8 | from ..common._typing import NestedSequence, SupportsBufferProtocol 9 | from .._internal import get_xp 10 | from ._typing import Array, Device, DType 11 | 12 | bool = cp.bool_ 13 | 14 | # Basic renames 15 | acos = cp.arccos 16 | acosh = cp.arccosh 17 | asin = cp.arcsin 18 | asinh = cp.arcsinh 19 | atan = cp.arctan 20 | atan2 = cp.arctan2 21 | atanh = cp.arctanh 22 | bitwise_left_shift = cp.left_shift 23 | bitwise_invert = cp.invert 24 | bitwise_right_shift = cp.right_shift 25 | concat = cp.concatenate 26 | pow = cp.power 27 | 28 | arange = get_xp(cp)(_aliases.arange) 29 | empty = get_xp(cp)(_aliases.empty) 30 | empty_like = get_xp(cp)(_aliases.empty_like) 31 | eye = get_xp(cp)(_aliases.eye) 32 | full = get_xp(cp)(_aliases.full) 33 | full_like = get_xp(cp)(_aliases.full_like) 34 | linspace = get_xp(cp)(_aliases.linspace) 35 | ones = get_xp(cp)(_aliases.ones) 36 | ones_like = get_xp(cp)(_aliases.ones_like) 37 | zeros = get_xp(cp)(_aliases.zeros) 38 | zeros_like = get_xp(cp)(_aliases.zeros_like) 39 | UniqueAllResult = get_xp(cp)(_aliases.UniqueAllResult) 40 | UniqueCountsResult = get_xp(cp)(_aliases.UniqueCountsResult) 41 | UniqueInverseResult = get_xp(cp)(_aliases.UniqueInverseResult) 42 | unique_all = get_xp(cp)(_aliases.unique_all) 43 | unique_counts = get_xp(cp)(_aliases.unique_counts) 44 | unique_inverse = get_xp(cp)(_aliases.unique_inverse) 45 | unique_values = get_xp(cp)(_aliases.unique_values) 46 | std = get_xp(cp)(_aliases.std) 47 | var = get_xp(cp)(_aliases.var) 48 | cumulative_sum = get_xp(cp)(_aliases.cumulative_sum) 49 | cumulative_prod = get_xp(cp)(_aliases.cumulative_prod) 50 | clip = get_xp(cp)(_aliases.clip) 51 | permute_dims = get_xp(cp)(_aliases.permute_dims) 52 | reshape = get_xp(cp)(_aliases.reshape) 53 | argsort = get_xp(cp)(_aliases.argsort) 54 | sort = get_xp(cp)(_aliases.sort) 55 | nonzero = get_xp(cp)(_aliases.nonzero) 56 | ceil = get_xp(cp)(_aliases.ceil) 57 | floor = get_xp(cp)(_aliases.floor) 58 | trunc = get_xp(cp)(_aliases.trunc) 59 | matmul = get_xp(cp)(_aliases.matmul) 60 | matrix_transpose = get_xp(cp)(_aliases.matrix_transpose) 61 | tensordot = get_xp(cp)(_aliases.tensordot) 62 | sign = get_xp(cp)(_aliases.sign) 63 | finfo = get_xp(cp)(_aliases.finfo) 64 | iinfo = get_xp(cp)(_aliases.iinfo) 65 | 66 | 67 | # asarray also adds the copy keyword, which is not present in numpy 1.0. 
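# A sketch of the resulting copy semantics (illustrative):
#
#     asarray(x, copy=True)    # always copies
#     asarray(x, copy=False)   # never copies; raises ValueError if a copy
#                              # cannot be avoided
#     asarray(x, copy=None)    # default: copies only when necessary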
68 | def asarray( 69 | obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, 70 | /, 71 | *, 72 | dtype: DType | None = None, 73 | device: Device | None = None, 74 | copy: py_bool | None = None, 75 | **kwargs: object, 76 | ) -> Array: 77 | """ 78 | Array API compatibility wrapper for asarray(). 79 | 80 | See the corresponding documentation in the array library and/or the array API 81 | specification for more details. 82 | """ 83 | with cp.cuda.Device(device): 84 | if copy is None: 85 | return cp.asarray(obj, dtype=dtype, **kwargs) 86 | else: 87 | res = cp.array(obj, dtype=dtype, copy=copy, **kwargs) 88 | if not copy and res is not obj: 89 | raise ValueError("Unable to avoid copy while creating an array as requested") 90 | return res 91 | 92 | 93 | def astype( 94 | x: Array, 95 | dtype: DType, 96 | /, 97 | *, 98 | copy: py_bool = True, 99 | device: Device | None = None, 100 | ) -> Array: 101 | if device is None: 102 | return x.astype(dtype=dtype, copy=copy) 103 | out = _helpers.to_device(x.astype(dtype=dtype, copy=False), device) 104 | return out.copy() if copy and out is x else out 105 | 106 | 107 | # cupy.count_nonzero does not have keepdims 108 | def count_nonzero( 109 | x: Array, 110 | axis: int | tuple[int, ...] | None = None, 111 | keepdims: py_bool = False, 112 | ) -> Array: 113 | result = cp.count_nonzero(x, axis) 114 | if keepdims: 115 | if axis is None: 116 | return cp.reshape(result, [1]*x.ndim) 117 | return cp.expand_dims(result, axis) 118 | return result 119 | 120 | 121 | # take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg 122 | def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: 123 | return cp.take_along_axis(x, indices, axis=axis) 124 | 125 | 126 | # These functions are completely new here. If the library already has them 127 | # (i.e., numpy 2.0), use the library version instead of our wrapper. 128 | if hasattr(cp, 'vecdot'): 129 | vecdot = cp.vecdot 130 | else: 131 | vecdot = get_xp(cp)(_aliases.vecdot) 132 | 133 | if hasattr(cp, 'isdtype'): 134 | isdtype = cp.isdtype 135 | else: 136 | isdtype = get_xp(cp)(_aliases.isdtype) 137 | 138 | if hasattr(cp, 'unstack'): 139 | unstack = cp.unstack 140 | else: 141 | unstack = get_xp(cp)(_aliases.unstack) 142 | 143 | __all__ = _aliases.__all__ + ['asarray', 'astype', 144 | 'acos', 'acosh', 'asin', 'asinh', 'atan', 145 | 'atan2', 'atanh', 'bitwise_left_shift', 146 | 'bitwise_invert', 'bitwise_right_shift', 147 | 'bool', 'concat', 'count_nonzero', 'pow', 'sign', 148 | 'take_along_axis'] 149 | 150 | 151 | def __dir__() -> list[str]: 152 | return __all__ 153 | -------------------------------------------------------------------------------- /array_api_compat/cupy/_info.py: -------------------------------------------------------------------------------- 1 | """ 2 | Array API Inspection namespace 3 | 4 | This is the namespace for inspection functions as defined by the array API 5 | standard. See 6 | https://data-apis.org/array-api/latest/API_specification/inspection.html for 7 | more details. 8 | 9 | """ 10 | from cupy import ( 11 | dtype, 12 | cuda, 13 | bool_ as bool, 14 | intp, 15 | int8, 16 | int16, 17 | int32, 18 | int64, 19 | uint8, 20 | uint16, 21 | uint32, 22 | uint64, 23 | float32, 24 | float64, 25 | complex64, 26 | complex128, 27 | ) 28 | 29 | 30 | class __array_namespace_info__: 31 | """ 32 | Get the array API inspection namespace for CuPy. 
33 | 34 | The array API inspection namespace defines the following functions: 35 | 36 | - capabilities() 37 | - default_device() 38 | - default_dtypes() 39 | - dtypes() 40 | - devices() 41 | 42 | See 43 | https://data-apis.org/array-api/latest/API_specification/inspection.html 44 | for more details. 45 | 46 | Returns 47 | ------- 48 | info : ModuleType 49 | The array API inspection namespace for CuPy. 50 | 51 | Examples 52 | -------- 53 | >>> info = xp.__array_namespace_info__() 54 | >>> info.default_dtypes() 55 | {'real floating': cupy.float64, 56 | 'complex floating': cupy.complex128, 57 | 'integral': cupy.int64, 58 | 'indexing': cupy.int64} 59 | 60 | """ 61 | 62 | __module__ = 'cupy' 63 | 64 | def capabilities(self): 65 | """ 66 | Return a dictionary of array API library capabilities. 67 | 68 | The resulting dictionary has the following keys: 69 | 70 | - **"boolean indexing"**: boolean indicating whether an array library 71 | supports boolean indexing. Always ``True`` for CuPy. 72 | 73 | - **"data-dependent shapes"**: boolean indicating whether an array 74 | library supports data-dependent output shapes. Always ``True`` for 75 | CuPy. 76 | 77 | See 78 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html 79 | for more details. 80 | 81 | See Also 82 | -------- 83 | __array_namespace_info__.default_device, 84 | __array_namespace_info__.default_dtypes, 85 | __array_namespace_info__.dtypes, 86 | __array_namespace_info__.devices 87 | 88 | Returns 89 | ------- 90 | capabilities : dict 91 | A dictionary of array API library capabilities. 92 | 93 | Examples 94 | -------- 95 | >>> info = xp.__array_namespace_info__() 96 | >>> info.capabilities() 97 | {'boolean indexing': True, 98 | 'data-dependent shapes': True, 99 | 'max dimensions': 64} 100 | 101 | """ 102 | return { 103 | "boolean indexing": True, 104 | "data-dependent shapes": True, 105 | "max dimensions": 64, 106 | } 107 | 108 | def default_device(self): 109 | """ 110 | The default device used for new CuPy arrays. 111 | 112 | See Also 113 | -------- 114 | __array_namespace_info__.capabilities, 115 | __array_namespace_info__.default_dtypes, 116 | __array_namespace_info__.dtypes, 117 | __array_namespace_info__.devices 118 | 119 | Returns 120 | ------- 121 | device : Device 122 | The default device used for new CuPy arrays. 123 | 124 | Examples 125 | -------- 126 | >>> info = xp.__array_namespace_info__() 127 | >>> info.default_device() 128 | Device(0) 129 | 130 | Notes 131 | ----- 132 | This method returns the static default device when CuPy is initialized. 133 | However, the *current* device used by creation functions (``empty`` etc.) 134 | can be changed globally or with a context manager. 135 | 136 | See Also 137 | -------- 138 | https://github.com/data-apis/array-api/issues/835 139 | """ 140 | return cuda.Device(0) 141 | 142 | def default_dtypes(self, *, device=None): 143 | """ 144 | The default data types used for new CuPy arrays. 145 | 146 | For CuPy, this always returns the following dictionary: 147 | 148 | - **"real floating"**: ``cupy.float64`` 149 | - **"complex floating"**: ``cupy.complex128`` 150 | - **"integral"**: ``cupy.intp`` 151 | - **"indexing"**: ``cupy.intp`` 152 | 153 | Parameters 154 | ---------- 155 | device : str, optional 156 | The device to get the default data types for. 157 | 158 | Returns 159 | ------- 160 | dtypes : dict 161 | A dictionary describing the default data types used for new CuPy 162 | arrays. 
163 | 164 | See Also 165 | -------- 166 | __array_namespace_info__.capabilities, 167 | __array_namespace_info__.default_device, 168 | __array_namespace_info__.dtypes, 169 | __array_namespace_info__.devices 170 | 171 | Examples 172 | -------- 173 | >>> info = xp.__array_namespace_info__() 174 | >>> info.default_dtypes() 175 | {'real floating': cupy.float64, 176 | 'complex floating': cupy.complex128, 177 | 'integral': cupy.int64, 178 | 'indexing': cupy.int64} 179 | 180 | """ 181 | # TODO: Does this depend on device? 182 | return { 183 | "real floating": dtype(float64), 184 | "complex floating": dtype(complex128), 185 | "integral": dtype(intp), 186 | "indexing": dtype(intp), 187 | } 188 | 189 | def dtypes(self, *, device=None, kind=None): 190 | """ 191 | The array API data types supported by CuPy. 192 | 193 | Note that this function only returns data types that are defined by 194 | the array API. 195 | 196 | Parameters 197 | ---------- 198 | device : str, optional 199 | The device to get the data types for. 200 | kind : str or tuple of str, optional 201 | The kind of data types to return. If ``None``, all data types are 202 | returned. If a string, only data types of that kind are returned. 203 | If a tuple, a dictionary containing the union of the given kinds 204 | is returned. The following kinds are supported: 205 | 206 | - ``'bool'``: boolean data types (i.e., ``bool``). 207 | - ``'signed integer'``: signed integer data types (i.e., ``int8``, 208 | ``int16``, ``int32``, ``int64``). 209 | - ``'unsigned integer'``: unsigned integer data types (i.e., 210 | ``uint8``, ``uint16``, ``uint32``, ``uint64``). 211 | - ``'integral'``: integer data types. Shorthand for ``('signed 212 | integer', 'unsigned integer')``. 213 | - ``'real floating'``: real-valued floating-point data types 214 | (i.e., ``float32``, ``float64``). 215 | - ``'complex floating'``: complex floating-point data types (i.e., 216 | ``complex64``, ``complex128``). 217 | - ``'numeric'``: numeric data types. Shorthand for ``('integral', 218 | 'real floating', 'complex floating')``. 219 | 220 | Returns 221 | ------- 222 | dtypes : dict 223 | A dictionary mapping the names of data types to the corresponding 224 | CuPy data types. 225 | 226 | See Also 227 | -------- 228 | __array_namespace_info__.capabilities, 229 | __array_namespace_info__.default_device, 230 | __array_namespace_info__.default_dtypes, 231 | __array_namespace_info__.devices 232 | 233 | Examples 234 | -------- 235 | >>> info = xp.__array_namespace_info__() 236 | >>> info.dtypes(kind='signed integer') 237 | {'int8': cupy.int8, 238 | 'int16': cupy.int16, 239 | 'int32': cupy.int32, 240 | 'int64': cupy.int64} 241 | 242 | """ 243 | # TODO: Does this depend on device? 
244 |         if kind is None:
245 |             return {
246 |                 "bool": dtype(bool),
247 |                 "int8": dtype(int8),
248 |                 "int16": dtype(int16),
249 |                 "int32": dtype(int32),
250 |                 "int64": dtype(int64),
251 |                 "uint8": dtype(uint8),
252 |                 "uint16": dtype(uint16),
253 |                 "uint32": dtype(uint32),
254 |                 "uint64": dtype(uint64),
255 |                 "float32": dtype(float32),
256 |                 "float64": dtype(float64),
257 |                 "complex64": dtype(complex64),
258 |                 "complex128": dtype(complex128),
259 |             }
260 |         if kind == "bool":
261 |             # Return a dtype instance, consistent with every other branch.
262 |             return {"bool": dtype(bool)}
262 |         if kind == "signed integer":
263 |             return {
264 |                 "int8": dtype(int8),
265 |                 "int16": dtype(int16),
266 |                 "int32": dtype(int32),
267 |                 "int64": dtype(int64),
268 |             }
269 |         if kind == "unsigned integer":
270 |             return {
271 |                 "uint8": dtype(uint8),
272 |                 "uint16": dtype(uint16),
273 |                 "uint32": dtype(uint32),
274 |                 "uint64": dtype(uint64),
275 |             }
276 |         if kind == "integral":
277 |             return {
278 |                 "int8": dtype(int8),
279 |                 "int16": dtype(int16),
280 |                 "int32": dtype(int32),
281 |                 "int64": dtype(int64),
282 |                 "uint8": dtype(uint8),
283 |                 "uint16": dtype(uint16),
284 |                 "uint32": dtype(uint32),
285 |                 "uint64": dtype(uint64),
286 |             }
287 |         if kind == "real floating":
288 |             return {
289 |                 "float32": dtype(float32),
290 |                 "float64": dtype(float64),
291 |             }
292 |         if kind == "complex floating":
293 |             return {
294 |                 "complex64": dtype(complex64),
295 |                 "complex128": dtype(complex128),
296 |             }
297 |         if kind == "numeric":
298 |             return {
299 |                 "int8": dtype(int8),
300 |                 "int16": dtype(int16),
301 |                 "int32": dtype(int32),
302 |                 "int64": dtype(int64),
303 |                 "uint8": dtype(uint8),
304 |                 "uint16": dtype(uint16),
305 |                 "uint32": dtype(uint32),
306 |                 "uint64": dtype(uint64),
307 |                 "float32": dtype(float32),
308 |                 "float64": dtype(float64),
309 |                 "complex64": dtype(complex64),
310 |                 "complex128": dtype(complex128),
311 |             }
312 |         if isinstance(kind, tuple):
313 |             res = {}
314 |             for k in kind:
315 |                 res.update(self.dtypes(kind=k))
316 |             return res
317 |         raise ValueError(f"unsupported kind: {kind!r}")
318 | 
319 |     def devices(self):
320 |         """
321 |         The devices supported by CuPy.
322 | 
323 |         Returns
324 |         -------
325 |         devices : list[Device]
326 |             The devices supported by CuPy.
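
        Examples
        --------
        A short illustrative session; the exact output depends on how many
        GPUs are visible (a single-GPU machine is assumed here):

        >>> info = xp.__array_namespace_info__()
        >>> info.devices()
        [Device(0)]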
327 | 328 | See Also 329 | -------- 330 | __array_namespace_info__.capabilities, 331 | __array_namespace_info__.default_device, 332 | __array_namespace_info__.default_dtypes, 333 | __array_namespace_info__.dtypes 334 | 335 | """ 336 | return [cuda.Device(i) for i in range(cuda.runtime.getDeviceCount())] 337 | -------------------------------------------------------------------------------- /array_api_compat/cupy/_typing.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | __all__ = ["Array", "DType", "Device"] 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | import cupy as cp 8 | from cupy import ndarray as Array 9 | from cupy.cuda.device import Device 10 | 11 | if TYPE_CHECKING: 12 | # NumPy 1.x on Python 3.10 fails to parse np.dtype[] 13 | DType = cp.dtype[ 14 | cp.intp 15 | | cp.int8 16 | | cp.int16 17 | | cp.int32 18 | | cp.int64 19 | | cp.uint8 20 | | cp.uint16 21 | | cp.uint32 22 | | cp.uint64 23 | | cp.float32 24 | | cp.float64 25 | | cp.complex64 26 | | cp.complex128 27 | | cp.bool_ 28 | ] 29 | else: 30 | DType = cp.dtype 31 | -------------------------------------------------------------------------------- /array_api_compat/cupy/fft.py: -------------------------------------------------------------------------------- 1 | from cupy.fft import * # noqa: F403 2 | 3 | # cupy.fft doesn't have __all__. If it is added, replace this with 4 | # 5 | # from cupy.fft import __all__ as linalg_all 6 | _n: dict[str, object] = {} 7 | exec("from cupy.fft import *", _n) 8 | del _n["__builtins__"] 9 | fft_all = list(_n) 10 | del _n 11 | 12 | from ..common import _fft 13 | from .._internal import get_xp 14 | 15 | import cupy as cp 16 | 17 | fft = get_xp(cp)(_fft.fft) 18 | ifft = get_xp(cp)(_fft.ifft) 19 | fftn = get_xp(cp)(_fft.fftn) 20 | ifftn = get_xp(cp)(_fft.ifftn) 21 | rfft = get_xp(cp)(_fft.rfft) 22 | irfft = get_xp(cp)(_fft.irfft) 23 | rfftn = get_xp(cp)(_fft.rfftn) 24 | irfftn = get_xp(cp)(_fft.irfftn) 25 | hfft = get_xp(cp)(_fft.hfft) 26 | ihfft = get_xp(cp)(_fft.ihfft) 27 | fftfreq = get_xp(cp)(_fft.fftfreq) 28 | rfftfreq = get_xp(cp)(_fft.rfftfreq) 29 | fftshift = get_xp(cp)(_fft.fftshift) 30 | ifftshift = get_xp(cp)(_fft.ifftshift) 31 | 32 | __all__ = fft_all + _fft.__all__ 33 | 34 | def __dir__() -> list[str]: 35 | return __all__ 36 | 37 | -------------------------------------------------------------------------------- /array_api_compat/cupy/linalg.py: -------------------------------------------------------------------------------- 1 | from cupy.linalg import * # noqa: F403 2 | # cupy.linalg doesn't have __all__. 
If it is added, replace this with 3 | # 4 | # from cupy.linalg import __all__ as linalg_all 5 | _n: dict[str, object] = {} 6 | exec('from cupy.linalg import *', _n) 7 | del _n['__builtins__'] 8 | linalg_all = list(_n) 9 | del _n 10 | 11 | from ..common import _linalg 12 | from .._internal import get_xp 13 | 14 | import cupy as cp 15 | 16 | # These functions are in both the main and linalg namespaces 17 | from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 18 | 19 | cross = get_xp(cp)(_linalg.cross) 20 | outer = get_xp(cp)(_linalg.outer) 21 | EighResult = _linalg.EighResult 22 | QRResult = _linalg.QRResult 23 | SlogdetResult = _linalg.SlogdetResult 24 | SVDResult = _linalg.SVDResult 25 | eigh = get_xp(cp)(_linalg.eigh) 26 | qr = get_xp(cp)(_linalg.qr) 27 | slogdet = get_xp(cp)(_linalg.slogdet) 28 | svd = get_xp(cp)(_linalg.svd) 29 | cholesky = get_xp(cp)(_linalg.cholesky) 30 | matrix_rank = get_xp(cp)(_linalg.matrix_rank) 31 | pinv = get_xp(cp)(_linalg.pinv) 32 | matrix_norm = get_xp(cp)(_linalg.matrix_norm) 33 | svdvals = get_xp(cp)(_linalg.svdvals) 34 | diagonal = get_xp(cp)(_linalg.diagonal) 35 | trace = get_xp(cp)(_linalg.trace) 36 | 37 | # These functions are completely new here. If the library already has them 38 | # (i.e., numpy 2.0), use the library version instead of our wrapper. 39 | if hasattr(cp.linalg, 'vector_norm'): 40 | vector_norm = cp.linalg.vector_norm 41 | else: 42 | vector_norm = get_xp(cp)(_linalg.vector_norm) 43 | 44 | __all__ = linalg_all + _linalg.__all__ 45 | 46 | def __dir__() -> list[str]: 47 | return __all__ 48 | -------------------------------------------------------------------------------- /array_api_compat/dask/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/data-apis/array-api-compat/cddc9ef8a19b453b09884987ca6a0626408a1478/array_api_compat/dask/__init__.py -------------------------------------------------------------------------------- /array_api_compat/dask/array/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Final 2 | 3 | from ..._internal import clone_module 4 | 5 | __all__ = clone_module("dask.array", globals()) 6 | 7 | # These imports may overwrite names from the import * above. 8 | from . 
import _aliases 9 | from ._aliases import * # type: ignore[assignment] # noqa: F403 10 | from ._info import __array_namespace_info__ # noqa: F401 11 | 12 | __array_api_version__: Final = "2024.12" 13 | del Final 14 | 15 | # See the comment in the numpy __init__.py 16 | __import__(__package__ + '.linalg') 17 | __import__(__package__ + '.fft') 18 | 19 | __all__ = sorted( 20 | set(__all__) 21 | | set(_aliases.__all__) 22 | | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} 23 | ) 24 | 25 | def __dir__() -> list[str]: 26 | return __all__ 27 | -------------------------------------------------------------------------------- /array_api_compat/dask/array/_aliases.py: -------------------------------------------------------------------------------- 1 | # pyright: reportPrivateUsage=false 2 | # pyright: reportUnknownArgumentType=false 3 | # pyright: reportUnknownMemberType=false 4 | # pyright: reportUnknownVariableType=false 5 | 6 | from __future__ import annotations 7 | 8 | from builtins import bool as py_bool 9 | from collections.abc import Callable 10 | from typing import TYPE_CHECKING, Any 11 | 12 | if TYPE_CHECKING: 13 | from typing_extensions import TypeIs 14 | 15 | import dask.array as da 16 | import numpy as np 17 | from numpy import bool_ as bool 18 | from numpy import ( 19 | can_cast, 20 | complex64, 21 | complex128, 22 | float32, 23 | float64, 24 | int8, 25 | int16, 26 | int32, 27 | int64, 28 | result_type, 29 | uint8, 30 | uint16, 31 | uint32, 32 | uint64, 33 | ) 34 | 35 | from ..._internal import get_xp 36 | from ...common import _aliases, _helpers, array_namespace 37 | from ...common._typing import ( 38 | Array, 39 | Device, 40 | DType, 41 | NestedSequence, 42 | SupportsBufferProtocol, 43 | ) 44 | 45 | isdtype = get_xp(np)(_aliases.isdtype) 46 | unstack = get_xp(da)(_aliases.unstack) 47 | 48 | 49 | # da.astype doesn't respect copy=True 50 | def astype( 51 | x: Array, 52 | dtype: DType, 53 | /, 54 | *, 55 | copy: py_bool = True, 56 | device: Device | None = None, 57 | ) -> Array: 58 | """ 59 | Array API compatibility wrapper for astype(). 60 | 61 | See the corresponding documentation in the array library and/or the array API 62 | specification for more details. 63 | """ 64 | # TODO: respect device keyword? 65 | _helpers._check_device(da, device) 66 | 67 | if not copy and dtype == x.dtype: 68 | return x 69 | x = x.astype(dtype) 70 | return x.copy() if copy else x 71 | 72 | 73 | # Common aliases 74 | 75 | 76 | # This arange func is modified from the common one to 77 | # not pass stop/step as keyword arguments, which will cause 78 | # an error with dask 79 | def arange( 80 | start: float, 81 | /, 82 | stop: float | None = None, 83 | step: float = 1, 84 | *, 85 | dtype: DType | None = None, 86 | device: Device | None = None, 87 | **kwargs: object, 88 | ) -> Array: 89 | """ 90 | Array API compatibility wrapper for arange(). 91 | 92 | See the corresponding documentation in the array library and/or the array API 93 | specification for more details. 94 | """ 95 | # TODO: respect device keyword? 
96 | _helpers._check_device(da, device) 97 | 98 | args: list[Any] = [start] 99 | if stop is not None: 100 | args.append(stop) 101 | else: 102 | # stop is None, so start is actually stop 103 | # prepend the default value for start which is 0 104 | args.insert(0, 0) 105 | args.append(step) 106 | 107 | return da.arange(*args, dtype=dtype, **kwargs) 108 | 109 | 110 | eye = get_xp(da)(_aliases.eye) 111 | linspace = get_xp(da)(_aliases.linspace) 112 | UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult) 113 | UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult) 114 | UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult) 115 | unique_all = get_xp(da)(_aliases.unique_all) 116 | unique_counts = get_xp(da)(_aliases.unique_counts) 117 | unique_inverse = get_xp(da)(_aliases.unique_inverse) 118 | unique_values = get_xp(da)(_aliases.unique_values) 119 | permute_dims = get_xp(da)(_aliases.permute_dims) 120 | std = get_xp(da)(_aliases.std) 121 | var = get_xp(da)(_aliases.var) 122 | cumulative_sum = get_xp(da)(_aliases.cumulative_sum) 123 | cumulative_prod = get_xp(da)(_aliases.cumulative_prod) 124 | empty = get_xp(da)(_aliases.empty) 125 | empty_like = get_xp(da)(_aliases.empty_like) 126 | full = get_xp(da)(_aliases.full) 127 | full_like = get_xp(da)(_aliases.full_like) 128 | ones = get_xp(da)(_aliases.ones) 129 | ones_like = get_xp(da)(_aliases.ones_like) 130 | zeros = get_xp(da)(_aliases.zeros) 131 | zeros_like = get_xp(da)(_aliases.zeros_like) 132 | reshape = get_xp(da)(_aliases.reshape) 133 | matrix_transpose = get_xp(da)(_aliases.matrix_transpose) 134 | vecdot = get_xp(da)(_aliases.vecdot) 135 | nonzero = get_xp(da)(_aliases.nonzero) 136 | ceil = get_xp(np)(_aliases.ceil) 137 | floor = get_xp(np)(_aliases.floor) 138 | trunc = get_xp(np)(_aliases.trunc) 139 | matmul = get_xp(np)(_aliases.matmul) 140 | tensordot = get_xp(np)(_aliases.tensordot) 141 | sign = get_xp(np)(_aliases.sign) 142 | finfo = get_xp(np)(_aliases.finfo) 143 | iinfo = get_xp(np)(_aliases.iinfo) 144 | 145 | 146 | # asarray also adds the copy keyword, which is not present in numpy 1.0. 147 | def asarray( 148 | obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, 149 | /, 150 | *, 151 | dtype: DType | None = None, 152 | device: Device | None = None, 153 | copy: py_bool | None = None, 154 | **kwargs: object, 155 | ) -> Array: 156 | """ 157 | Array API compatibility wrapper for asarray(). 158 | 159 | See the corresponding documentation in the array library and/or the array API 160 | specification for more details. 161 | """ 162 | # TODO: respect device keyword? 
163 | _helpers._check_device(da, device) 164 | 165 | if isinstance(obj, da.Array): 166 | if dtype is not None and dtype != obj.dtype: 167 | if copy is False: 168 | raise ValueError("Unable to avoid copy when changing dtype") 169 | obj = obj.astype(dtype) 170 | return obj.copy() if copy else obj # pyright: ignore[reportAttributeAccessIssue] 171 | 172 | if copy is False: 173 | raise ValueError( 174 | "Unable to avoid copy when converting a non-dask object to dask" 175 | ) 176 | 177 | # copy=None to be uniform across dask < 2024.12 and >= 2024.12 178 | # see https://github.com/dask/dask/pull/11524/ 179 | obj = np.array(obj, dtype=dtype, copy=True) 180 | return da.from_array(obj) 181 | 182 | 183 | # Element wise aliases 184 | from dask.array import arccos as acos 185 | from dask.array import arccosh as acosh 186 | from dask.array import arcsin as asin 187 | from dask.array import arcsinh as asinh 188 | from dask.array import arctan as atan 189 | from dask.array import arctan2 as atan2 190 | from dask.array import arctanh as atanh 191 | 192 | # Other 193 | from dask.array import concatenate as concat 194 | from dask.array import invert as bitwise_invert 195 | from dask.array import left_shift as bitwise_left_shift 196 | from dask.array import power as pow 197 | from dask.array import right_shift as bitwise_right_shift 198 | 199 | 200 | # dask.array.clip does not work unless all three arguments are provided. 201 | # Furthermore, the masking workaround in common._aliases.clip cannot work with 202 | # dask (meaning uint64 promoting to float64 is going to just be unfixed for 203 | # now). 204 | def clip( 205 | x: Array, 206 | /, 207 | min: float | Array | None = None, 208 | max: float | Array | None = None, 209 | ) -> Array: 210 | """ 211 | Array API compatibility wrapper for clip(). 212 | 213 | See the corresponding documentation in the array library and/or the array API 214 | specification for more details. 215 | """ 216 | 217 | def _isscalar(a: float | Array | None, /) -> TypeIs[float | None]: 218 | return a is None or isinstance(a, (int, float)) 219 | 220 | min_shape = () if _isscalar(min) else min.shape 221 | max_shape = () if _isscalar(max) else max.shape 222 | 223 | # TODO: This won't handle dask unknown shapes 224 | result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape) 225 | 226 | if min is not None: 227 | min = da.broadcast_to(da.asarray(min), result_shape) 228 | if max is not None: 229 | max = da.broadcast_to(da.asarray(max), result_shape) 230 | 231 | if min is None and max is None: 232 | return da.positive(x) 233 | 234 | if min is None: 235 | return astype(da.minimum(x, max), x.dtype) 236 | if max is None: 237 | return astype(da.maximum(x, min), x.dtype) 238 | 239 | return astype(da.minimum(da.maximum(x, min), max), x.dtype) 240 | 241 | 242 | def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array], Array]]: 243 | """ 244 | Make sure that Array is not broken into multiple chunks along axis. 245 | 246 | Returns 247 | ------- 248 | x : Array 249 | The input Array with a single chunk along axis. 
250 |     restore : Callable[[Array], Array]
251 |         Function to apply to the output to rechunk it back into reasonable chunks.
252 |     """
253 |     if axis < 0:
254 |         axis += x.ndim
255 |     if x.numblocks[axis] < 2:
256 |         return x, lambda x: x
257 | 
258 |     # Break chunks on other axes in an attempt to keep chunk size low
259 |     x = x.rechunk({i: -1 if i == axis else "auto" for i in range(x.ndim)})
260 | 
261 |     # Rather than reconstructing the original chunks, which can be a
262 |     # very expensive affair, just break down oversized chunks without
263 |     # incurring any transfers over the network.
264 |     # The downside is a risk of overchunking if the array is
265 |     # then used in operations against other arrays that match the
266 |     # original chunking pattern.
267 |     return x, lambda x: x.rechunk()
268 | 
269 | 
270 | def sort(
271 |     x: Array,
272 |     /,
273 |     *,
274 |     axis: int = -1,
275 |     descending: py_bool = False,
276 |     stable: py_bool = True,
277 | ) -> Array:
278 |     """
279 |     Array API compatibility layer around the lack of sort() in Dask.
280 | 
281 |     Warnings
282 |     --------
283 |     This function temporarily rechunks the array along `axis` into a single chunk.
284 |     This can be extremely inefficient and can lead to out-of-memory errors.
285 | 
286 |     See the corresponding documentation in the array library and/or the array API
287 |     specification for more details.
288 |     """
289 |     x, restore = _ensure_single_chunk(x, axis)
290 | 
291 |     meta_xp = array_namespace(x._meta)
292 |     x = da.map_blocks(
293 |         meta_xp.sort,
294 |         x,
295 |         axis=axis,
296 |         meta=x._meta,
297 |         dtype=x.dtype,
298 |         descending=descending,
299 |         stable=stable,
300 |     )
301 | 
302 |     return restore(x)
303 | 
304 | 
305 | def argsort(
306 |     x: Array,
307 |     /,
308 |     *,
309 |     axis: int = -1,
310 |     descending: py_bool = False,
311 |     stable: py_bool = True,
312 | ) -> Array:
313 |     """
314 |     Array API compatibility layer around the lack of argsort() in Dask.
315 | 
316 |     See the corresponding documentation in the array library and/or the array API
317 |     specification for more details.
318 | 
319 |     Warnings
320 |     --------
321 |     This function temporarily rechunks the array along `axis` into a single chunk.
322 |     This can be extremely inefficient and can lead to out-of-memory errors.
323 | """ 324 | x, restore = _ensure_single_chunk(x, axis) 325 | 326 | meta_xp = array_namespace(x._meta) 327 | dtype = meta_xp.argsort(x._meta).dtype 328 | meta = meta_xp.astype(x._meta, dtype) 329 | x = da.map_blocks( 330 | meta_xp.argsort, 331 | x, 332 | axis=axis, 333 | meta=meta, 334 | dtype=dtype, 335 | descending=descending, 336 | stable=stable, 337 | ) 338 | 339 | return restore(x) 340 | 341 | 342 | # dask.array.count_nonzero does not have keepdims 343 | def count_nonzero( 344 | x: Array, 345 | axis: int | None = None, 346 | keepdims: py_bool = False, 347 | ) -> Array: 348 | result = da.count_nonzero(x, axis) 349 | if keepdims: 350 | if axis is None: 351 | return da.reshape(result, [1] * x.ndim) 352 | return da.expand_dims(result, axis) 353 | return result 354 | 355 | 356 | __all__ = [ 357 | "count_nonzero", 358 | "bool", 359 | "int8", "int16", "int32", "int64", 360 | "uint8", "uint16", "uint32", "uint64", 361 | "float32", "float64", 362 | "complex64", "complex128", 363 | "asarray", "astype", "can_cast", "result_type", 364 | "pow", 365 | "concat", 366 | "acos", "acosh", "asin", "asinh", "atan", "atan2", "atanh", 367 | "bitwise_left_shift", "bitwise_right_shift", "bitwise_invert", 368 | ] # fmt: skip 369 | __all__ += _aliases.__all__ 370 | 371 | def __dir__() -> list[str]: 372 | return __all__ 373 | -------------------------------------------------------------------------------- /array_api_compat/dask/array/fft.py: -------------------------------------------------------------------------------- 1 | from ..._internal import clone_module 2 | 3 | __all__ = clone_module("dask.array.fft", globals()) 4 | 5 | from ...common import _fft 6 | from ..._internal import get_xp 7 | 8 | import dask.array as da 9 | 10 | fftfreq = get_xp(da)(_fft.fftfreq) 11 | rfftfreq = get_xp(da)(_fft.rfftfreq) 12 | 13 | __all__ += ["fftfreq", "rfftfreq"] 14 | 15 | def __dir__() -> list[str]: 16 | return __all__ 17 | -------------------------------------------------------------------------------- /array_api_compat/dask/array/linalg.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Literal 4 | 5 | import dask.array as da 6 | 7 | # The `matmul` and `tensordot` functions are in both the main and linalg namespaces 8 | from dask.array import matmul, outer, tensordot 9 | 10 | # Exports 11 | from ..._internal import clone_module, get_xp 12 | from ...common import _linalg 13 | from ...common._typing import Array 14 | 15 | __all__ = clone_module("dask.array.linalg", globals()) 16 | 17 | from ._aliases import matrix_transpose, vecdot 18 | 19 | EighResult = _linalg.EighResult 20 | QRResult = _linalg.QRResult 21 | SlogdetResult = _linalg.SlogdetResult 22 | SVDResult = _linalg.SVDResult 23 | # TODO: use the QR wrapper once dask 24 | # supports the mode keyword on QR 25 | # https://github.com/dask/dask/issues/10388 26 | #qr = get_xp(da)(_linalg.qr) 27 | def qr( # type: ignore[no-redef] 28 | x: Array, 29 | mode: Literal["reduced", "complete"] = "reduced", 30 | **kwargs: object, 31 | ) -> QRResult: 32 | if mode != "reduced": 33 | raise ValueError("dask arrays only support using mode='reduced'") 34 | return QRResult(*da.linalg.qr(x, **kwargs)) 35 | trace = get_xp(da)(_linalg.trace) 36 | cholesky = get_xp(da)(_linalg.cholesky) 37 | matrix_rank = get_xp(da)(_linalg.matrix_rank) 38 | matrix_norm = get_xp(da)(_linalg.matrix_norm) 39 | 40 | 41 | # Wrap the svd functions to not pass full_matrices to dask 42 | # when full_matrices=False 
(as that is the default behavior for dask),
43 | # since dask does not support the full_matrices keyword at all.
44 | def svd(x: Array, full_matrices: bool = True, **kwargs: object) -> SVDResult:  # type: ignore[no-redef]
45 |     if full_matrices:
46 |         raise ValueError("full_matrices=True is not supported by dask.")
47 |     return da.linalg.svd(x, coerce_signs=False, **kwargs)
48 | 
49 | def svdvals(x: Array) -> Array:
50 |     # TODO: can't avoid computing U or V for dask
51 |     _, s, _ = svd(x)
52 |     return s
53 | 
54 | vector_norm = get_xp(da)(_linalg.vector_norm)
55 | diagonal = get_xp(da)(_linalg.diagonal)
56 | 
57 | __all__ += ["trace", "outer", "matmul", "tensordot",
58 |             "matrix_transpose", "vecdot", "EighResult",
59 |             "QRResult", "SlogdetResult", "SVDResult", "qr",
60 |             "cholesky", "matrix_rank", "matrix_norm", "svdvals",
61 |             "vector_norm", "diagonal"]
62 | 
63 | def __dir__() -> list[str]:
64 |     return __all__
65 | 
--------------------------------------------------------------------------------
/array_api_compat/numpy/__init__.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: PLC0414
2 | from typing import Final
3 | 
4 | from .._internal import clone_module
5 | 
6 | # This needs to be loaded explicitly before cloning
7 | import numpy.typing  # noqa: F401
8 | 
9 | __all__ = clone_module("numpy", globals())
10 | 
11 | # These imports may overwrite names from the import * above.
12 | from . import _aliases
13 | from ._aliases import *  # type: ignore[assignment,no-redef]  # noqa: F403
14 | from ._info import __array_namespace_info__  # noqa: F401
15 | 
16 | # Don't know why, but we have to do an absolute import to import linalg. If we
17 | # instead do
18 | #
19 | # from . import linalg
20 | #
21 | # it doesn't overwrite np.linalg from above. The import is generated
22 | # dynamically so that the library can be vendored.
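# For example (hypothetical vendored layout): if this package were vendored as
# ``mylib._vendor.array_api_compat.numpy``, then ``__package__`` would resolve
# to that full dotted path at runtime, so the generated import below targets
# the vendored copy rather than a hard-coded top-level ``array_api_compat``.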
23 | __import__(__package__ + ".linalg") 24 | 25 | __import__(__package__ + ".fft") 26 | 27 | from .linalg import matrix_transpose, vecdot # type: ignore[no-redef] # noqa: F401 28 | 29 | __array_api_version__: Final = "2024.12" 30 | 31 | __all__ = sorted( 32 | set(__all__) 33 | | set(_aliases.__all__) 34 | | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} 35 | ) 36 | 37 | def __dir__() -> list[str]: 38 | return __all__ 39 | -------------------------------------------------------------------------------- /array_api_compat/numpy/_aliases.py: -------------------------------------------------------------------------------- 1 | # pyright: reportPrivateUsage=false 2 | from __future__ import annotations 3 | 4 | from builtins import bool as py_bool 5 | from typing import Any, cast 6 | 7 | import numpy as np 8 | 9 | from .._internal import get_xp 10 | from ..common import _aliases, _helpers 11 | from ..common._typing import NestedSequence, SupportsBufferProtocol 12 | from ._typing import Array, Device, DType 13 | 14 | bool = np.bool_ 15 | 16 | # Basic renames 17 | acos = np.arccos 18 | acosh = np.arccosh 19 | asin = np.arcsin 20 | asinh = np.arcsinh 21 | atan = np.arctan 22 | atan2 = np.arctan2 23 | atanh = np.arctanh 24 | bitwise_left_shift = np.left_shift 25 | bitwise_invert = np.invert 26 | bitwise_right_shift = np.right_shift 27 | concat = np.concatenate 28 | pow = np.power 29 | 30 | arange = get_xp(np)(_aliases.arange) 31 | empty = get_xp(np)(_aliases.empty) 32 | empty_like = get_xp(np)(_aliases.empty_like) 33 | eye = get_xp(np)(_aliases.eye) 34 | full = get_xp(np)(_aliases.full) 35 | full_like = get_xp(np)(_aliases.full_like) 36 | linspace = get_xp(np)(_aliases.linspace) 37 | ones = get_xp(np)(_aliases.ones) 38 | ones_like = get_xp(np)(_aliases.ones_like) 39 | zeros = get_xp(np)(_aliases.zeros) 40 | zeros_like = get_xp(np)(_aliases.zeros_like) 41 | UniqueAllResult = get_xp(np)(_aliases.UniqueAllResult) 42 | UniqueCountsResult = get_xp(np)(_aliases.UniqueCountsResult) 43 | UniqueInverseResult = get_xp(np)(_aliases.UniqueInverseResult) 44 | unique_all = get_xp(np)(_aliases.unique_all) 45 | unique_counts = get_xp(np)(_aliases.unique_counts) 46 | unique_inverse = get_xp(np)(_aliases.unique_inverse) 47 | unique_values = get_xp(np)(_aliases.unique_values) 48 | std = get_xp(np)(_aliases.std) 49 | var = get_xp(np)(_aliases.var) 50 | cumulative_sum = get_xp(np)(_aliases.cumulative_sum) 51 | cumulative_prod = get_xp(np)(_aliases.cumulative_prod) 52 | clip = get_xp(np)(_aliases.clip) 53 | permute_dims = get_xp(np)(_aliases.permute_dims) 54 | reshape = get_xp(np)(_aliases.reshape) 55 | argsort = get_xp(np)(_aliases.argsort) 56 | sort = get_xp(np)(_aliases.sort) 57 | nonzero = get_xp(np)(_aliases.nonzero) 58 | ceil = get_xp(np)(_aliases.ceil) 59 | floor = get_xp(np)(_aliases.floor) 60 | trunc = get_xp(np)(_aliases.trunc) 61 | matmul = get_xp(np)(_aliases.matmul) 62 | matrix_transpose = get_xp(np)(_aliases.matrix_transpose) 63 | tensordot = get_xp(np)(_aliases.tensordot) 64 | sign = get_xp(np)(_aliases.sign) 65 | finfo = get_xp(np)(_aliases.finfo) 66 | iinfo = get_xp(np)(_aliases.iinfo) 67 | 68 | 69 | # asarray also adds the copy keyword, which is not present in numpy 1.0. 
70 | # asarray() is different enough between numpy, cupy, and dask, the logic 71 | # complicated enough that it's easier to define it separately for each module 72 | # rather than trying to combine everything into one function in common/ 73 | def asarray( 74 | obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, 75 | /, 76 | *, 77 | dtype: DType | None = None, 78 | device: Device | None = None, 79 | copy: py_bool | None = None, 80 | **kwargs: Any, 81 | ) -> Array: 82 | """ 83 | Array API compatibility wrapper for asarray(). 84 | 85 | See the corresponding documentation in the array library and/or the array API 86 | specification for more details. 87 | """ 88 | _helpers._check_device(np, device) 89 | 90 | # None is unsupported in NumPy 1.0, but we can use an internal enum 91 | # False in NumPy 1.0 means None in NumPy 2.0 and in the Array API 92 | if copy is None: 93 | copy = np._CopyMode.IF_NEEDED # type: ignore[assignment,attr-defined] 94 | elif copy is False: 95 | copy = np._CopyMode.NEVER # type: ignore[assignment,attr-defined] 96 | 97 | return np.array(obj, copy=copy, dtype=dtype, **kwargs) 98 | 99 | 100 | def astype( 101 | x: Array, 102 | dtype: DType, 103 | /, 104 | *, 105 | copy: py_bool = True, 106 | device: Device | None = None, 107 | ) -> Array: 108 | _helpers._check_device(np, device) 109 | return x.astype(dtype=dtype, copy=copy) 110 | 111 | 112 | # count_nonzero returns a python int for axis=None and keepdims=False 113 | # https://github.com/numpy/numpy/issues/17562 114 | def count_nonzero( 115 | x: Array, 116 | axis: int | tuple[int, ...] | None = None, 117 | keepdims: py_bool = False, 118 | ) -> Array: 119 | # NOTE: this is currently incorrectly typed in numpy, but will be fixed in 120 | # numpy 2.2.5 and 2.3.0: https://github.com/numpy/numpy/pull/28750 121 | result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims)) # pyright: ignore[reportArgumentType, reportCallIssue] 122 | if axis is None and not keepdims: 123 | return np.asarray(result) 124 | return result 125 | 126 | 127 | # take_along_axis: axis defaults to -1 but in numpy axis is a required arg 128 | def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: 129 | return np.take_along_axis(x, indices, axis=axis) 130 | 131 | 132 | # These functions are completely new here. If the library already has them 133 | # (i.e., numpy 2.0), use the library version instead of our wrapper. 
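# For example, after importing this module on NumPy >= 2.0, ``vecdot is
# np.vecdot`` holds, while on NumPy 1.x the pure-Python fallback from
# ``common._aliases`` is used instead.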
134 | if hasattr(np, "vecdot"): 135 | vecdot = np.vecdot 136 | else: 137 | vecdot = get_xp(np)(_aliases.vecdot) # type: ignore[assignment] 138 | 139 | if hasattr(np, "isdtype"): 140 | isdtype = np.isdtype 141 | else: 142 | isdtype = get_xp(np)(_aliases.isdtype) 143 | 144 | if hasattr(np, "unstack"): 145 | unstack = np.unstack 146 | else: 147 | unstack = get_xp(np)(_aliases.unstack) 148 | 149 | __all__ = _aliases.__all__ + [ 150 | "asarray", 151 | "astype", 152 | "acos", 153 | "acosh", 154 | "asin", 155 | "asinh", 156 | "atan", 157 | "atan2", 158 | "atanh", 159 | "bitwise_left_shift", 160 | "bitwise_invert", 161 | "bitwise_right_shift", 162 | "bool", 163 | "concat", 164 | "count_nonzero", 165 | "pow", 166 | "take_along_axis" 167 | ] 168 | 169 | 170 | def __dir__() -> list[str]: 171 | return __all__ 172 | -------------------------------------------------------------------------------- /array_api_compat/numpy/_info.py: -------------------------------------------------------------------------------- 1 | """ 2 | Array API Inspection namespace 3 | 4 | This is the namespace for inspection functions as defined by the array API 5 | standard. See 6 | https://data-apis.org/array-api/latest/API_specification/inspection.html for 7 | more details. 8 | 9 | """ 10 | from __future__ import annotations 11 | 12 | from numpy import bool_ as bool 13 | from numpy import ( 14 | complex64, 15 | complex128, 16 | dtype, 17 | float32, 18 | float64, 19 | int8, 20 | int16, 21 | int32, 22 | int64, 23 | intp, 24 | uint8, 25 | uint16, 26 | uint32, 27 | uint64, 28 | ) 29 | 30 | from ..common._typing import DefaultDTypes 31 | from ._typing import Device, DType 32 | 33 | 34 | class __array_namespace_info__: 35 | """ 36 | Get the array API inspection namespace for NumPy. 37 | 38 | The array API inspection namespace defines the following functions: 39 | 40 | - capabilities() 41 | - default_device() 42 | - default_dtypes() 43 | - dtypes() 44 | - devices() 45 | 46 | See 47 | https://data-apis.org/array-api/latest/API_specification/inspection.html 48 | for more details. 49 | 50 | Returns 51 | ------- 52 | info : ModuleType 53 | The array API inspection namespace for NumPy. 54 | 55 | Examples 56 | -------- 57 | >>> info = np.__array_namespace_info__() 58 | >>> info.default_dtypes() 59 | {'real floating': numpy.float64, 60 | 'complex floating': numpy.complex128, 61 | 'integral': numpy.int64, 62 | 'indexing': numpy.int64} 63 | 64 | """ 65 | 66 | __module__ = 'numpy' 67 | 68 | def capabilities(self): 69 | """ 70 | Return a dictionary of array API library capabilities. 71 | 72 | The resulting dictionary has the following keys: 73 | 74 | - **"boolean indexing"**: boolean indicating whether an array library 75 | supports boolean indexing. Always ``True`` for NumPy. 76 | 77 | - **"data-dependent shapes"**: boolean indicating whether an array 78 | library supports data-dependent output shapes. Always ``True`` for 79 | NumPy. 80 | 81 | See 82 | https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html 83 | for more details. 84 | 85 | See Also 86 | -------- 87 | __array_namespace_info__.default_device, 88 | __array_namespace_info__.default_dtypes, 89 | __array_namespace_info__.dtypes, 90 | __array_namespace_info__.devices 91 | 92 | Returns 93 | ------- 94 | capabilities : dict 95 | A dictionary of array API library capabilities. 
96 | 97 | Examples 98 | -------- 99 | >>> info = np.__array_namespace_info__() 100 | >>> info.capabilities() 101 | {'boolean indexing': True, 102 | 'data-dependent shapes': True, 103 | 'max dimensions': 64} 104 | 105 | """ 106 | return { 107 | "boolean indexing": True, 108 | "data-dependent shapes": True, 109 | "max dimensions": 64, 110 | } 111 | 112 | def default_device(self): 113 | """ 114 | The default device used for new NumPy arrays. 115 | 116 | For NumPy, this always returns ``'cpu'``. 117 | 118 | See Also 119 | -------- 120 | __array_namespace_info__.capabilities, 121 | __array_namespace_info__.default_dtypes, 122 | __array_namespace_info__.dtypes, 123 | __array_namespace_info__.devices 124 | 125 | Returns 126 | ------- 127 | device : Device 128 | The default device used for new NumPy arrays. 129 | 130 | Examples 131 | -------- 132 | >>> info = np.__array_namespace_info__() 133 | >>> info.default_device() 134 | 'cpu' 135 | 136 | """ 137 | return "cpu" 138 | 139 | def default_dtypes( 140 | self, 141 | *, 142 | device: Device | None = None, 143 | ) -> DefaultDTypes: 144 | """ 145 | The default data types used for new NumPy arrays. 146 | 147 | For NumPy, this always returns the following dictionary: 148 | 149 | - **"real floating"**: ``numpy.float64`` 150 | - **"complex floating"**: ``numpy.complex128`` 151 | - **"integral"**: ``numpy.intp`` 152 | - **"indexing"**: ``numpy.intp`` 153 | 154 | Parameters 155 | ---------- 156 | device : str, optional 157 | The device to get the default data types for. For NumPy, only 158 | ``'cpu'`` is allowed. 159 | 160 | Returns 161 | ------- 162 | dtypes : dict 163 | A dictionary describing the default data types used for new NumPy 164 | arrays. 165 | 166 | See Also 167 | -------- 168 | __array_namespace_info__.capabilities, 169 | __array_namespace_info__.default_device, 170 | __array_namespace_info__.dtypes, 171 | __array_namespace_info__.devices 172 | 173 | Examples 174 | -------- 175 | >>> info = np.__array_namespace_info__() 176 | >>> info.default_dtypes() 177 | {'real floating': numpy.float64, 178 | 'complex floating': numpy.complex128, 179 | 'integral': numpy.int64, 180 | 'indexing': numpy.int64} 181 | 182 | """ 183 | if device not in ["cpu", None]: 184 | raise ValueError( 185 | 'Device not understood. Only "cpu" is allowed, but received:' 186 | f' {device}' 187 | ) 188 | return { 189 | "real floating": dtype(float64), 190 | "complex floating": dtype(complex128), 191 | "integral": dtype(intp), 192 | "indexing": dtype(intp), 193 | } 194 | 195 | def dtypes( 196 | self, 197 | *, 198 | device: Device | None = None, 199 | kind: str | tuple[str, ...] | None = None, 200 | ) -> dict[str, DType]: 201 | """ 202 | The array API data types supported by NumPy. 203 | 204 | Note that this function only returns data types that are defined by 205 | the array API. 206 | 207 | Parameters 208 | ---------- 209 | device : str, optional 210 | The device to get the data types for. For NumPy, only ``'cpu'`` is 211 | allowed. 212 | kind : str or tuple of str, optional 213 | The kind of data types to return. If ``None``, all data types are 214 | returned. If a string, only data types of that kind are returned. 215 | If a tuple, a dictionary containing the union of the given kinds 216 | is returned. The following kinds are supported: 217 | 218 | - ``'bool'``: boolean data types (i.e., ``bool``). 219 | - ``'signed integer'``: signed integer data types (i.e., ``int8``, 220 | ``int16``, ``int32``, ``int64``). 
221 | - ``'unsigned integer'``: unsigned integer data types (i.e., 222 | ``uint8``, ``uint16``, ``uint32``, ``uint64``). 223 | - ``'integral'``: integer data types. Shorthand for ``('signed 224 | integer', 'unsigned integer')``. 225 | - ``'real floating'``: real-valued floating-point data types 226 | (i.e., ``float32``, ``float64``). 227 | - ``'complex floating'``: complex floating-point data types (i.e., 228 | ``complex64``, ``complex128``). 229 | - ``'numeric'``: numeric data types. Shorthand for ``('integral', 230 | 'real floating', 'complex floating')``. 231 | 232 | Returns 233 | ------- 234 | dtypes : dict 235 | A dictionary mapping the names of data types to the corresponding 236 | NumPy data types. 237 | 238 | See Also 239 | -------- 240 | __array_namespace_info__.capabilities, 241 | __array_namespace_info__.default_device, 242 | __array_namespace_info__.default_dtypes, 243 | __array_namespace_info__.devices 244 | 245 | Examples 246 | -------- 247 | >>> info = np.__array_namespace_info__() 248 | >>> info.dtypes(kind='signed integer') 249 | {'int8': numpy.int8, 250 | 'int16': numpy.int16, 251 | 'int32': numpy.int32, 252 | 'int64': numpy.int64} 253 | 254 | """ 255 | if device not in ["cpu", None]: 256 | raise ValueError( 257 | 'Device not understood. Only "cpu" is allowed, but received:' 258 | f' {device}' 259 | ) 260 | if kind is None: 261 | return { 262 | "bool": dtype(bool), 263 | "int8": dtype(int8), 264 | "int16": dtype(int16), 265 | "int32": dtype(int32), 266 | "int64": dtype(int64), 267 | "uint8": dtype(uint8), 268 | "uint16": dtype(uint16), 269 | "uint32": dtype(uint32), 270 | "uint64": dtype(uint64), 271 | "float32": dtype(float32), 272 | "float64": dtype(float64), 273 | "complex64": dtype(complex64), 274 | "complex128": dtype(complex128), 275 | } 276 | if kind == "bool": 277 | return {"bool": dtype(bool)} 278 | if kind == "signed integer": 279 | return { 280 | "int8": dtype(int8), 281 | "int16": dtype(int16), 282 | "int32": dtype(int32), 283 | "int64": dtype(int64), 284 | } 285 | if kind == "unsigned integer": 286 | return { 287 | "uint8": dtype(uint8), 288 | "uint16": dtype(uint16), 289 | "uint32": dtype(uint32), 290 | "uint64": dtype(uint64), 291 | } 292 | if kind == "integral": 293 | return { 294 | "int8": dtype(int8), 295 | "int16": dtype(int16), 296 | "int32": dtype(int32), 297 | "int64": dtype(int64), 298 | "uint8": dtype(uint8), 299 | "uint16": dtype(uint16), 300 | "uint32": dtype(uint32), 301 | "uint64": dtype(uint64), 302 | } 303 | if kind == "real floating": 304 | return { 305 | "float32": dtype(float32), 306 | "float64": dtype(float64), 307 | } 308 | if kind == "complex floating": 309 | return { 310 | "complex64": dtype(complex64), 311 | "complex128": dtype(complex128), 312 | } 313 | if kind == "numeric": 314 | return { 315 | "int8": dtype(int8), 316 | "int16": dtype(int16), 317 | "int32": dtype(int32), 318 | "int64": dtype(int64), 319 | "uint8": dtype(uint8), 320 | "uint16": dtype(uint16), 321 | "uint32": dtype(uint32), 322 | "uint64": dtype(uint64), 323 | "float32": dtype(float32), 324 | "float64": dtype(float64), 325 | "complex64": dtype(complex64), 326 | "complex128": dtype(complex128), 327 | } 328 | if isinstance(kind, tuple): 329 | res: dict[str, DType] = {} 330 | for k in kind: 331 | res.update(self.dtypes(kind=k)) 332 | return res 333 | raise ValueError(f"unsupported kind: {kind!r}") 334 | 335 | def devices(self) -> list[Device]: 336 | """ 337 | The devices supported by NumPy. 338 | 339 | For NumPy, this always returns ``['cpu']``. 
340 | 341 | Returns 342 | ------- 343 | devices : list[Device] 344 | The devices supported by NumPy. 345 | 346 | See Also 347 | -------- 348 | __array_namespace_info__.capabilities, 349 | __array_namespace_info__.default_device, 350 | __array_namespace_info__.default_dtypes, 351 | __array_namespace_info__.dtypes 352 | 353 | Examples 354 | -------- 355 | >>> info = np.__array_namespace_info__() 356 | >>> info.devices() 357 | ['cpu'] 358 | 359 | """ 360 | return ["cpu"] 361 | 362 | 363 | __all__ = ["__array_namespace_info__"] 364 | 365 | 366 | def __dir__() -> list[str]: 367 | return __all__ 368 | -------------------------------------------------------------------------------- /array_api_compat/numpy/_typing.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Any, Literal, TypeAlias 4 | 5 | import numpy as np 6 | 7 | Device: TypeAlias = Literal["cpu"] 8 | 9 | if TYPE_CHECKING: 10 | 11 | # NumPy 1.x on Python 3.10 fails to parse np.dtype[] 12 | DType: TypeAlias = np.dtype[ 13 | np.bool_ 14 | | np.integer[Any] 15 | | np.float32 16 | | np.float64 17 | | np.complex64 18 | | np.complex128 19 | ] 20 | Array: TypeAlias = np.ndarray[Any, DType] 21 | else: 22 | DType: TypeAlias = np.dtype 23 | Array: TypeAlias = np.ndarray 24 | 25 | __all__ = ["Array", "DType", "Device"] 26 | 27 | 28 | def __dir__() -> list[str]: 29 | return __all__ 30 | -------------------------------------------------------------------------------- /array_api_compat/numpy/fft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .._internal import clone_module 4 | 5 | __all__ = clone_module("numpy.fft", globals()) 6 | 7 | from .._internal import get_xp 8 | from ..common import _fft 9 | 10 | fft = get_xp(np)(_fft.fft) 11 | ifft = get_xp(np)(_fft.ifft) 12 | fftn = get_xp(np)(_fft.fftn) 13 | ifftn = get_xp(np)(_fft.ifftn) 14 | rfft = get_xp(np)(_fft.rfft) 15 | irfft = get_xp(np)(_fft.irfft) 16 | rfftn = get_xp(np)(_fft.rfftn) 17 | irfftn = get_xp(np)(_fft.irfftn) 18 | hfft = get_xp(np)(_fft.hfft) 19 | ihfft = get_xp(np)(_fft.ihfft) 20 | fftfreq = get_xp(np)(_fft.fftfreq) 21 | rfftfreq = get_xp(np)(_fft.rfftfreq) 22 | fftshift = get_xp(np)(_fft.fftshift) 23 | ifftshift = get_xp(np)(_fft.ifftshift) 24 | 25 | 26 | __all__ = sorted(set(__all__) | set(_fft.__all__)) 27 | 28 | def __dir__() -> list[str]: 29 | return __all__ 30 | 31 | -------------------------------------------------------------------------------- /array_api_compat/numpy/linalg.py: -------------------------------------------------------------------------------- 1 | # pyright: reportAttributeAccessIssue=false 2 | # pyright: reportUnknownArgumentType=false 3 | # pyright: reportUnknownMemberType=false 4 | # pyright: reportUnknownVariableType=false 5 | 6 | from __future__ import annotations 7 | 8 | import numpy as np 9 | 10 | from .._internal import clone_module, get_xp 11 | from ..common import _linalg 12 | 13 | __all__ = clone_module("numpy.linalg", globals()) 14 | 15 | # These functions are in both the main and linalg namespaces 16 | from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 17 | from ._typing import Array 18 | 19 | cross = get_xp(np)(_linalg.cross) 20 | outer = get_xp(np)(_linalg.outer) 21 | EighResult = _linalg.EighResult 22 | QRResult = _linalg.QRResult 23 | SlogdetResult = _linalg.SlogdetResult 24 | SVDResult = _linalg.SVDResult 25 | eigh = 
get_xp(np)(_linalg.eigh) 26 | qr = get_xp(np)(_linalg.qr) 27 | slogdet = get_xp(np)(_linalg.slogdet) 28 | svd = get_xp(np)(_linalg.svd) 29 | cholesky = get_xp(np)(_linalg.cholesky) 30 | matrix_rank = get_xp(np)(_linalg.matrix_rank) 31 | pinv = get_xp(np)(_linalg.pinv) 32 | matrix_norm = get_xp(np)(_linalg.matrix_norm) 33 | svdvals = get_xp(np)(_linalg.svdvals) 34 | diagonal = get_xp(np)(_linalg.diagonal) 35 | trace = get_xp(np)(_linalg.trace) 36 | 37 | # Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a 38 | # vector when it is exactly 1-dimensional. All other cases treat x2 as a stack 39 | # of matrices. The np.linalg.solve behavior of allowing stacks of both 40 | # matrices and vectors is ambiguous c.f. 41 | # https://github.com/numpy/numpy/issues/15349 and 42 | # https://github.com/data-apis/array-api/issues/285. 43 | 44 | # To workaround this, the below is the code from np.linalg.solve except 45 | # only calling solve1 in the exactly 1D case. 46 | 47 | 48 | # This code is here instead of in common because it is numpy specific. Also 49 | # note that CuPy's solve() does not currently support broadcasting (see 50 | # https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43). 51 | def solve(x1: Array, x2: Array, /) -> Array: 52 | try: 53 | from numpy.linalg._linalg import ( # type: ignore[attr-defined] 54 | _assert_stacked_2d, 55 | _assert_stacked_square, 56 | _commonType, 57 | _makearray, 58 | _raise_linalgerror_singular, 59 | isComplexType, 60 | ) 61 | except ImportError: 62 | from numpy.linalg.linalg import ( # type: ignore[attr-defined] 63 | _assert_stacked_2d, 64 | _assert_stacked_square, 65 | _commonType, 66 | _makearray, 67 | _raise_linalgerror_singular, 68 | isComplexType, 69 | ) 70 | from numpy.linalg import _umath_linalg 71 | 72 | x1, _ = _makearray(x1) 73 | _assert_stacked_2d(x1) 74 | _assert_stacked_square(x1) 75 | x2, wrap = _makearray(x2) 76 | t, result_t = _commonType(x1, x2) 77 | 78 | # This part is different from np.linalg.solve 79 | gufunc: np.ufunc 80 | if x2.ndim == 1: 81 | gufunc = _umath_linalg.solve1 82 | else: 83 | gufunc = _umath_linalg.solve 84 | 85 | # This does nothing currently but is left in because it will be relevant 86 | # when complex dtype support is added to the spec in 2022. 87 | signature = "DD->D" if isComplexType(t) else "dd->d" 88 | with np.errstate( 89 | call=_raise_linalgerror_singular, 90 | invalid="call", 91 | over="ignore", 92 | divide="ignore", 93 | under="ignore", 94 | ): 95 | r: Array = gufunc(x1, x2, signature=signature) 96 | 97 | return wrap(r.astype(result_t, copy=False)) 98 | 99 | 100 | # These functions are completely new here. If the library already has them 101 | # (i.e., numpy 2.0), use the library version instead of our wrapper. 
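# (np.linalg.vector_norm exists from NumPy 2.0 onwards; on NumPy 1.x the
# wrapped fallback below is used.)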
102 | if hasattr(np.linalg, "vector_norm"): 103 | vector_norm = np.linalg.vector_norm 104 | else: 105 | vector_norm = get_xp(np)(_linalg.vector_norm) 106 | 107 | 108 | _all = [ 109 | "LinAlgError", 110 | "cond", 111 | "det", 112 | "eig", 113 | "eigvals", 114 | "eigvalsh", 115 | "inv", 116 | "lstsq", 117 | "matrix_power", 118 | "multi_dot", 119 | "norm", 120 | "solve", 121 | "tensorinv", 122 | "tensorsolve", 123 | "vector_norm", 124 | ] 125 | __all__ = sorted(set(__all__) | set(_linalg.__all__) | set(_all)) 126 | 127 | def __dir__() -> list[str]: 128 | return __all__ 129 | -------------------------------------------------------------------------------- /array_api_compat/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/data-apis/array-api-compat/cddc9ef8a19b453b09884987ca6a0626408a1478/array_api_compat/py.typed -------------------------------------------------------------------------------- /array_api_compat/torch/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Final 2 | 3 | from .._internal import clone_module 4 | 5 | __all__ = clone_module("torch", globals()) 6 | 7 | # These imports may overwrite names from the import * above. 8 | from . import _aliases 9 | from ._aliases import * # noqa: F403 10 | from ._info import __array_namespace_info__ # noqa: F401 11 | 12 | # See the comment in the numpy __init__.py 13 | __import__(__package__ + '.linalg') 14 | __import__(__package__ + '.fft') 15 | 16 | __array_api_version__: Final = '2024.12' 17 | 18 | __all__ = sorted( 19 | set(__all__) 20 | | set(_aliases.__all__) 21 | | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} 22 | ) 23 | 24 | def __dir__() -> list[str]: 25 | return __all__ 26 | -------------------------------------------------------------------------------- /array_api_compat/torch/_typing.py: -------------------------------------------------------------------------------- 1 | __all__ = ["Array", "Device", "DType"] 2 | 3 | from torch import device as Device, dtype as DType, Tensor as Array 4 | -------------------------------------------------------------------------------- /array_api_compat/torch/fft.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Sequence 4 | from typing import Literal 5 | 6 | import torch 7 | import torch.fft 8 | 9 | from ._typing import Array 10 | from .._internal import clone_module 11 | 12 | __all__ = clone_module("torch.fft", globals()) 13 | 14 | # Several torch fft functions do not map axes to dim 15 | 16 | def fftn( 17 | x: Array, 18 | /, 19 | *, 20 | s: Sequence[int] = None, 21 | axes: Sequence[int] = None, 22 | norm: Literal["backward", "ortho", "forward"] = "backward", 23 | **kwargs: object, 24 | ) -> Array: 25 | return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs) 26 | 27 | def ifftn( 28 | x: Array, 29 | /, 30 | *, 31 | s: Sequence[int] = None, 32 | axes: Sequence[int] = None, 33 | norm: Literal["backward", "ortho", "forward"] = "backward", 34 | **kwargs: object, 35 | ) -> Array: 36 | return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs) 37 | 38 | def rfftn( 39 | x: Array, 40 | /, 41 | *, 42 | s: Sequence[int] = None, 43 | axes: Sequence[int] = None, 44 | norm: Literal["backward", "ortho", "forward"] = "backward", 45 | **kwargs: object, 46 | ) -> Array: 47 | return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs) 
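
# A minimal illustrative check (the helper below is hypothetical and not part
# of this module's public API): these wrappers only translate the array API
# ``axes`` keyword to torch's ``dim`` keyword, so results match torch.fft
# exactly.
def _axes_to_dim_example() -> None:
    x = torch.ones((4, 8))
    assert torch.equal(fftn(x, axes=(0,)), torch.fft.fftn(x, dim=(0,)))
    assert torch.equal(rfftn(x, axes=(1,)), torch.fft.rfftn(x, dim=(1,)))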
48 | 49 | def irfftn( 50 | x: Array, 51 | /, 52 | *, 53 | s: Sequence[int] = None, 54 | axes: Sequence[int] = None, 55 | norm: Literal["backward", "ortho", "forward"] = "backward", 56 | **kwargs: object, 57 | ) -> Array: 58 | return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs) 59 | 60 | def fftshift( 61 | x: Array, 62 | /, 63 | *, 64 | axes: int | Sequence[int] = None, 65 | **kwargs: object, 66 | ) -> Array: 67 | return torch.fft.fftshift(x, dim=axes, **kwargs) 68 | 69 | def ifftshift( 70 | x: Array, 71 | /, 72 | *, 73 | axes: int | Sequence[int] = None, 74 | **kwargs: object, 75 | ) -> Array: 76 | return torch.fft.ifftshift(x, dim=axes, **kwargs) 77 | 78 | 79 | __all__ += ["fftn", "ifftn", "rfftn", "irfftn", "fftshift", "ifftshift"] 80 | 81 | def __dir__() -> list[str]: 82 | return __all__ 83 | -------------------------------------------------------------------------------- /array_api_compat/torch/linalg.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | import torch.linalg 5 | 6 | from .._internal import clone_module 7 | 8 | __all__ = clone_module("torch.linalg", globals()) 9 | 10 | # outer is implemented in torch but aren't in the linalg namespace 11 | from torch import outer 12 | from ._aliases import _fix_promotion, sum 13 | # These functions are in both the main and linalg namespaces 14 | from ._aliases import matmul, matrix_transpose, tensordot 15 | from ._typing import Array, DType 16 | from ..common._typing import JustInt, JustFloat 17 | 18 | # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the 19 | # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 20 | 21 | # torch.cross also does not support broadcasting when it would add new 22 | # dimensions https://github.com/pytorch/pytorch/issues/39656 23 | def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: 24 | x1, x2 = _fix_promotion(x1, x2, only_scalar=False) 25 | if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)): 26 | raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}") 27 | if not (x1.shape[axis] == x2.shape[axis] == 3): 28 | raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}") 29 | x1, x2 = torch.broadcast_tensors(x1, x2) 30 | return torch.linalg.cross(x1, x2, dim=axis) 31 | 32 | def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs: object) -> Array: 33 | from ._aliases import isdtype 34 | 35 | x1, x2 = _fix_promotion(x1, x2, only_scalar=False) 36 | 37 | # torch.linalg.vecdot incorrectly allows broadcasting along the contracted dimension 38 | if x1.shape[axis] != x2.shape[axis]: 39 | raise ValueError("x1 and x2 must have the same size along the given axis") 40 | 41 | # torch.linalg.vecdot doesn't support integer dtypes 42 | if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'): 43 | if kwargs: 44 | raise RuntimeError("vecdot kwargs not supported for integral dtypes") 45 | 46 | x1_ = torch.moveaxis(x1, axis, -1) 47 | x2_ = torch.moveaxis(x2, axis, -1) 48 | x1_, x2_ = torch.broadcast_tensors(x1_, x2_) 49 | 50 | res = x1_[..., None, :] @ x2_[..., None] 51 | return res[..., 0, 0] 52 | return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs) 53 | 54 | def solve(x1: Array, x2: Array, /, **kwargs: object) -> Array: 55 | x1, x2 = _fix_promotion(x1, x2, only_scalar=False) 56 | # Torch tries to emulate NumPy 1 solve behavior by using 
batched 1-D solve
57 |     # whenever
58 |     # 1. x1.ndim - 1 == x2.ndim
59 |     # 2. x1.shape[:-1] == x2.shape
60 |     #
61 |     # See linalg_solve_is_vector_rhs in
62 |     # aten/src/ATen/native/LinearAlgebraUtils.h and
63 |     # TORCH_META_FUNC(_linalg_solve_ex) in
64 |     # aten/src/ATen/native/BatchLinearAlgebra.cpp in the PyTorch source code.
65 |     #
66 |     # The easiest way to work around this is to prepend a size 1 dimension to
67 |     # x2, since x2 is already one dimension less than x1.
68 |     #
69 |     # See https://github.com/pytorch/pytorch/issues/52915
70 |     if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape:
71 |         x2 = x2[None]
72 |     return torch.linalg.solve(x1, x2, **kwargs)
73 | 
74 | # torch.trace doesn't support the offset argument and doesn't support stacking
75 | def trace(x: Array, /, *, offset: int = 0, dtype: DType | None = None) -> Array:
76 |     # Use our wrapped sum to make sure it does upcasting correctly
77 |     return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
78 | 
79 | def vector_norm(
80 |     x: Array,
81 |     /,
82 |     *,
83 |     axis: int | tuple[int, ...] | None = None,
84 |     keepdims: bool = False,
85 |     # JustFloat stands for inf | -inf, which are not valid for Literal
86 |     ord: JustInt | JustFloat = 2,
87 |     **kwargs: object,
88 | ) -> Array:
89 |     # torch.linalg.vector_norm incorrectly treats axis=() the same as axis=None
90 |     if axis == ():
91 |         out = kwargs.get('out')
92 |         if out is None:
93 |             dtype = None
94 |             if x.dtype == torch.complex64:
95 |                 dtype = torch.float32
96 |             elif x.dtype == torch.complex128:
97 |                 dtype = torch.float64
98 | 
99 |             out = torch.zeros_like(x, dtype=dtype)
100 | 
101 |         # The norm of a single scalar works out to abs(x) in every case except
102 |         # for ord=0, which is x != 0.
103 |         if ord == 0:
104 |             out[:] = (x != 0)
105 |         else:
106 |             out[:] = torch.abs(x)
107 |         return out
108 |     return torch.linalg.vector_norm(x, ord=ord, dim=axis, keepdim=keepdims, **kwargs)
109 | 
110 | __all__ += ['outer', 'matmul', 'matrix_transpose', 'tensordot',
111 |             'cross', 'vecdot', 'solve', 'trace', 'vector_norm']
112 | 
113 | def __dir__() -> list[str]:
114 |     return __all__
115 | 
--------------------------------------------------------------------------------
/dask-skips.txt:
--------------------------------------------------------------------------------
1 | # NOTE: dask tests run on a very small number of examples in CI due to
2 | # slowness. This causes very high flakiness in the tests.
3 | # Before changing this file, please run with at least 200 examples.
4 | 
5 | # Passes, but extremely slow
6 | array_api_tests/test_linalg.py::test_outer
7 | 
8 | # Hangs
9 | array_api_tests/test_creation_functions.py::test_eye
--------------------------------------------------------------------------------
/dask-xfails.txt:
--------------------------------------------------------------------------------
1 | # NOTE: dask tests run on a very small number of examples in CI due to
2 | # slowness. This causes very high flakiness in the tests.
3 | # Before changing this file, please run with at least 200 examples.
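# A typical local invocation for such a run might look like the following
# (illustrative; the array-api-tests checkout location, worker count, and
# environment variable are assumptions, not a documented recipe):
#   ARRAY_API_TESTS_MODULE=array_api_compat.dask.array pytest array_api_tests/ -n 4 --max-examples=200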
4 | 5 | # Broken edge case with shape 0 6 | # https://github.com/dask/dask/issues/11800 7 | array_api_tests/test_array_object.py::test_setitem 8 | 9 | # Various indexing errors 10 | array_api_tests/test_array_object.py::test_getitem_masking 11 | 12 | # zero division error, and typeerror: tuple indices must be integers or slices not tuple 13 | array_api_tests/test_creation_functions.py::test_eye 14 | 15 | # attributes are np.float32 instead of float 16 | # (see also https://github.com/data-apis/array-api/issues/405) 17 | array_api_tests/test_data_type_functions.py::test_finfo[float32] 18 | array_api_tests/test_data_type_functions.py::test_finfo[complex64] 19 | 20 | # out[-1]=dask.array but should be some floating number 21 | # (I think the test is not forcing the op to be computed?) 22 | array_api_tests/test_creation_functions.py::test_linspace 23 | 24 | # Shape mismatch 25 | array_api_tests/test_indexing_functions.py::test_take 26 | 27 | # missing `take_along_axis`, https://github.com/dask/dask/issues/3663 28 | array_api_tests/test_indexing_functions.py::test_take_along_axis 29 | 30 | # Array methods and attributes not already on da.Array cannot be wrapped 31 | array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] 32 | array_api_tests/test_has_names.py::test_has_names[array_method-to_device] 33 | array_api_tests/test_has_names.py::test_has_names[array_attribute-device] 34 | array_api_tests/test_has_names.py::test_has_names[array_attribute-mT] 35 | 36 | # Data-dependent output shape 37 | # These tests fail as array-api-tests doesn't cope with unknown shapes 38 | # Also, output shape is (math.nan, ) instead of (None, ) 39 | # Also, da.unique() doesn't accept equals_nan which causes non-compliant 40 | # output when there are NaNs in the input. 
41 | array_api_tests/test_searching_functions.py::test_nonzero
42 | array_api_tests/test_set_functions.py::test_unique_all
43 | array_api_tests/test_set_functions.py::test_unique_counts
44 | array_api_tests/test_set_functions.py::test_unique_inverse
45 | array_api_tests/test_set_functions.py::test_unique_values
46 | 
47 | # Linalg failures (signature failures/missing methods)
48 | 
49 | # fails for ndim > 2
50 | array_api_tests/test_linalg.py::test_svdvals
51 | 
52 | # dtype mismatch: got uint64, but should be uint8; NPY_PROMOTION_STATE=weak doesn't help
53 | array_api_tests/test_linalg.py::test_tensordot
54 | 
55 | # AssertionError: out.dtype=uint64, but should be uint8 [tensordot(uint8, uint8)]
56 | array_api_tests/test_linalg.py::test_linalg_tensordot
57 | 
58 | # ZeroDivisionError in dask's normalize_chunks/auto_chunks internals
59 | array_api_tests/test_linalg.py::test_inv
60 | array_api_tests/test_linalg.py::test_matrix_power
61 | 
62 | # Linalg - these don't exist in dask
63 | array_api_tests/test_signatures.py::test_extension_func_signature[linalg.cross]
64 | array_api_tests/test_signatures.py::test_extension_func_signature[linalg.det]
65 | array_api_tests/test_signatures.py::test_extension_func_signature[linalg.eigh]
66 | array_api_tests/test_signatures.py::test_extension_func_signature[linalg.eigvalsh]
67 | array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_power]
68 | array_api_tests/test_signatures.py::test_extension_func_signature[linalg.pinv]
69 | array_api_tests/test_signatures.py::test_extension_func_signature[linalg.slogdet]
70 | array_api_tests/test_linalg.py::test_cross
71 | array_api_tests/test_linalg.py::test_det
72 | array_api_tests/test_linalg.py::test_eigh
73 | array_api_tests/test_linalg.py::test_eigvalsh
74 | array_api_tests/test_linalg.py::test_matrix_rank
75 | array_api_tests/test_linalg.py::test_pinv
76 | array_api_tests/test_linalg.py::test_slogdet
77 | array_api_tests/test_has_names.py::test_has_names[linalg-cross]
78 | array_api_tests/test_has_names.py::test_has_names[linalg-det]
79 | array_api_tests/test_has_names.py::test_has_names[linalg-eigh]
80 | array_api_tests/test_has_names.py::test_has_names[linalg-eigvalsh]
81 | array_api_tests/test_has_names.py::test_has_names[linalg-matrix_power]
82 | array_api_tests/test_has_names.py::test_has_names[linalg-pinv]
83 | array_api_tests/test_has_names.py::test_has_names[linalg-slogdet]
84 | 
85 | # Constructing the input arrays fails with a weird shape error...
86 | array_api_tests/test_linalg.py::test_solve 87 | 88 | # missing full_matrices kw 89 | # https://github.com/dask/dask/issues/10389 90 | # also only supports 2-d inputs 91 | array_api_tests/test_linalg.py::test_svd 92 | 93 | # Missing dlpack stuff 94 | array_api_tests/test_signatures.py::test_func_signature[from_dlpack] 95 | array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__] 96 | array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] 97 | array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__] 98 | array_api_tests/test_signatures.py::test_array_method_signature[to_device] 99 | array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack] 100 | array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__] 101 | array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__] 102 | 103 | # No mT on dask array 104 | array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices 105 | 106 | # Edge case of args near 2**63 107 | # https://github.com/dask/dask/issues/11706 108 | array_api_tests/test_creation_functions.py::test_arange 109 | 110 | # da.searchsorted with a sorter argument is not supported 111 | array_api_tests/test_searching_functions.py::test_searchsorted 112 | 113 | # 2023.12 support 114 | array_api_tests/test_manipulation_functions.py::test_repeat 115 | 116 | # 2024.12 support 117 | array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_1[1] 118 | array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_1[None] 119 | array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_2[1] 120 | array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_2[None] 121 | array_api_tests/test_has_names.py::test_has_names[indexing-take_along_axis] 122 | array_api_tests/test_signatures.py::test_func_signature[count_nonzero] 123 | array_api_tests/test_signatures.py::test_func_signature[take_along_axis] 124 | 125 | array_api_tests/test_linalg.py::test_cholesky 126 | array_api_tests/test_linalg.py::test_linalg_matmul 127 | array_api_tests/test_linalg.py::test_matmul 128 | array_api_tests/test_linalg.py::test_matrix_norm 129 | array_api_tests/test_linalg.py::test_qr 130 | array_api_tests/test_manipulation_functions.py::test_roll 131 | 132 | # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.) 
133 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] 134 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] 135 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] 136 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] 137 | array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] 138 | array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] 139 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] 140 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] 141 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] 142 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] 143 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] 144 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] 145 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] 146 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] 147 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] 148 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] 149 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] 150 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] 151 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | 22 | livehtml: 23 | sphinx-autobuild --open-browser --watch .. 
--port 0 -b html $(SOURCEDIR) $(ALLSPHINXOPTS) $(BUILDDIR)/html 24 | -------------------------------------------------------------------------------- /docs/_static/custom.css: -------------------------------------------------------------------------------- 1 | /* Makes the text look better on Mac retina displays (the Furo CSS disables*/ 2 | /* subpixel antialiasing). */ 3 | body { 4 | -webkit-font-smoothing: auto; 5 | -moz-osx-font-smoothing: auto; 6 | } 7 | 8 | /* Disable the fancy scrolling behavior when jumping to headers (this is too 9 | slow for long pages) */ 10 | html { 11 | scroll-behavior: auto; 12 | } 13 | 14 | /* Make checkboxes from the tasklist extension ('- [ ]' in Markdown) not add bullet points to the checkboxes. 15 | 16 | This can be removed once https://github.com/executablebooks/mdit-py-plugins/issues/59 is addressed. 17 | */ 18 | 19 | .contains-task-list { 20 | list-style: none; 21 | } 22 | 23 | /* Make the checkboxes indented like they are bullets */ 24 | .task-list-item-checkbox { 25 | margin: 0 0.2em 0.25em -1.4em; 26 | } 27 | -------------------------------------------------------------------------------- /docs/_static/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/data-apis/array-api-compat/cddc9ef8a19b453b09884987ca6a0626408a1478/docs/_static/favicon.png -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | import sys 10 | import os 11 | sys.path.insert(0, os.path.abspath('..')) 12 | 13 | project = 'array-api-compat' 14 | copyright = '2024, Consortium for Python Data API Standards' 15 | author = 'Consortium for Python Data API Standards' 16 | 17 | import array_api_compat 18 | release = array_api_compat.__version__ 19 | 20 | # -- General configuration --------------------------------------------------- 21 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 22 | 23 | extensions = [ 24 | 'myst_parser', 25 | 'sphinx.ext.autodoc', 26 | 'sphinx.ext.napoleon', 27 | 'sphinx.ext.intersphinx', 28 | 'sphinx_copybutton', 29 | ] 30 | 31 | intersphinx_mapping = { 32 | 'cupy': ('https://docs.cupy.dev/en/stable', None), 33 | 'torch': ('https://pytorch.org/docs/stable/', None), 34 | } 35 | # Require :external: to reference intersphinx. 36 | intersphinx_disabled_reftypes = ['*'] 37 | 38 | templates_path = ['_templates'] 39 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 40 | 41 | myst_enable_extensions = ["dollarmath", "linkify", "tasklist"] 42 | myst_enable_checkboxes = True 43 | 44 | napoleon_use_rtype = False 45 | napoleon_use_param = False 46 | 47 | # Make sphinx give errors for bad cross-references 48 | nitpicky = True 49 | # autodoc wants to make cross-references for every type hint. But a lot of 50 | # them don't actually refer to anything that we have a document for. 
51 | nitpick_ignore = [
52 |     ("py:class", "Array"),
53 |     ("py:class", "Device"),
54 | ]
55 | 
56 | # Lets us use single backticks for code in RST
57 | default_role = 'code'
58 | 
59 | # -- Options for HTML output -------------------------------------------------
60 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
61 | 
62 | html_theme = 'furo'
63 | html_static_path = ['_static']
64 | 
65 | html_css_files = ['custom.css']
66 | 
67 | html_theme_options = {
68 |     # See https://pradyunsg.me/furo/customisation/footer/
69 |     "footer_icons": [
70 |         {
71 |             "name": "GitHub",
72 |             "url": "https://github.com/data-apis/array-api-compat",
73 |             "html": """
74 | 
75 | 
76 | 
77 |             """,
78 |             "class": "",
79 |         },
80 |     ],
81 | }
82 | 
83 | # Logo
84 | 
85 | html_favicon = "_static/favicon.png"
86 | 
87 | # html_logo = "_static/logo.svg"
88 | 
--------------------------------------------------------------------------------
/docs/dev/implementation-notes.md:
--------------------------------------------------------------------------------
1 | # Implementation Notes
2 | 
3 | Since NumPy, CuPy, and to a degree, Dask, are nearly identical in behavior,
4 | most wrapping logic can be shared between them. Wrapped functions that have
5 | the same logic between multiple libraries are in `array_api_compat/common/`.
6 | These functions are defined like
7 | 
8 | ```py
9 | # In array_api_compat/common/_aliases.py
10 | 
11 | def acos(x, /, xp):
12 |     return xp.arccos(x)
13 | ```
14 | 
15 | The `xp` argument refers to the original array namespace (e.g., `numpy` or
16 | `cupy`). Then in the specific `array_api_compat/numpy/` and
17 | `array_api_compat/cupy/` namespaces, the `@get_xp` decorator is applied to
18 | these functions, which automatically removes the `xp` argument from the
19 | function signature and replaces it with the corresponding array library, like
20 | 
21 | ```py
22 | # In array_api_compat/numpy/_aliases.py
23 | 
24 | from ..common import _aliases
25 | 
26 | import numpy as np
27 | 
28 | acos = get_xp(np)(_aliases.acos)
29 | ```
30 | 
31 | This `acos` now has the signature `acos(x, /)` and calls `numpy.arccos`.
32 | 
33 | Similarly, for CuPy:
34 | 
35 | ```py
36 | # In array_api_compat/cupy/_aliases.py
37 | 
38 | from ..common import _aliases
39 | 
40 | import cupy as cp
41 | 
42 | acos = get_xp(cp)(_aliases.acos)
43 | ```
44 | 
45 | Most NumPy and CuPy functions are defined in this way, since their behaviors are nearly
46 | identical. PyTorch uses a similar layout in `array_api_compat/torch/`, but it
47 | differs enough from NumPy/CuPy that very few common wrappers for those
48 | libraries are reused. Dask is close to NumPy in behavior and so most Dask
49 | functions also reuse the NumPy/CuPy common wrappers.
50 | 
51 | Occasionally, a wrapper implementation will need to reference another wrapper
52 | implementation, rather than the base `xp` version. The easiest way to do this
53 | is to call `array_namespace`, like
54 | 
55 | ```py
56 | wrapped_xp = array_namespace(x)
57 | wrapped_xp.wrapped_func(...)
58 | ```
59 | 
60 | Also, if there is a very minor difference required for wrapping, say, CuPy and
61 | NumPy, they can still use a common implementation in `common/_aliases.py` and
62 | use the `is_*_namespace()` or `is_*_function()` [helper
63 | functions](../helper-functions.rst) to branch as necessary.
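
For instance, a minimal sketch of such a branch (the `sign` wrapper and the
CuPy-specific cast shown here are hypothetical, purely for illustration):

```py
# In array_api_compat/common/_aliases.py (hypothetical sketch)

from ._helpers import is_cupy_namespace

def sign(x, /, xp):
    if is_cupy_namespace(xp):
        # Pretend CuPy needed the result cast back to the input dtype here.
        return xp.sign(x).astype(x.dtype)
    return xp.sign(x)
```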
64 | 
--------------------------------------------------------------------------------
/docs/dev/index.md:
--------------------------------------------------------------------------------
1 | # Development Notes
2 | 
3 | This is internal documentation related to the development of array-api-compat.
4 | It is recommended that contributors read through this documentation.
5 | 
6 | ```{toctree}
7 | :titlesonly:
8 | 
9 | special-considerations.md
10 | implementation-notes.md
11 | tests.md
12 | releasing.md
13 | ```
14 | 
--------------------------------------------------------------------------------
/docs/dev/releasing.md:
--------------------------------------------------------------------------------
1 | # Releasing
2 | 
3 | - [ ] **Create a PR with a release branch.**
4 | 
5 |   This makes it easy to verify that CI is passing, and also gives you a place
6 |   to push up updates to the changelog and any last minute fixes for the
7 |   release.
8 | 
9 | - [ ] **Double check the release branch is fully merged with `main`.**
10 | 
11 |   (e.g., if the release branch is called `release`)
12 | 
13 |   ```
14 |   git checkout main
15 |   git pull
16 |   git checkout release
17 |   git merge main
18 |   ```
19 | 
20 | - [ ] **Make sure that all CI tests are passing.**
21 | 
22 |   Note that the GitHub action that publishes to PyPI does not check if CI is
23 |   passing before publishing. So you need to check this manually.
24 | 
25 |   This does mean you can ignore CI failures, but ideally you should fix any
26 |   failures or update the `*-xfails.txt` files before tagging, so that CI and
27 |   the CuPy tests fully pass. Otherwise it will be hard to tell what things are
28 |   breaking in the future. It's also a good idea to remove any xpasses from
29 |   those files (but be aware that some xfails are from flaky failures, so
30 |   unless you know the underlying issue has been fixed, an xpass test is
31 |   probably still xfail).
32 | 
33 | - [ ] **Test CuPy.**
34 | 
35 |   CuPy must be tested manually (it isn't tested on CI, see
36 |   https://github.com/data-apis/array-api-compat/issues/197). Use the script
37 | 
38 |   ```
39 |   ./test_cupy.sh
40 |   ```
41 | 
42 |   on a machine with a CUDA GPU.
43 | 
44 | 
45 | - [ ] **Update the version.**
46 | 
47 |   You must edit
48 | 
49 |   ```
50 |   array_api_compat/__init__.py
51 |   ```
52 | 
53 |   and update the version (the version is not computed from the tag because
54 |   that would break vendorability).
55 | 
56 | - [ ] **Update the [changelog](../changelog.md).**
57 | 
58 |   Edit
59 | 
60 |   ```
61 |   docs/changelog.md
62 |   ```
63 | 
64 |   with the changes for the release.
65 | 
66 | - [ ] **Create the release tag.**
67 | 
68 |   Once everything is ready, create a tag
69 | 
70 |   ```
71 |   git tag -a <version>
72 |   ```
73 | 
74 |   (note the tag names are not prefixed, for instance, the tag for version 1.5 is
75 |   just `1.5`)
76 | 
77 | - [ ] **Push the tag to GitHub.**
78 | 
79 |   *This is the final step. Doing this will build and publish the release!*
80 | 
81 |   ```
82 |   git push origin <version>
83 |   ```
84 | 
85 |   This will trigger the [`publish
86 |   distributions`](https://github.com/data-apis/array-api-compat/actions/workflows/publish-package.yml)
87 |   GitHub Action that will build the release and push it to PyPI.
88 | 
89 | - [ ] **Check that the [`publish
90 |   distributions`](https://github.com/data-apis/array-api-compat/actions/workflows/publish-package.yml)
91 |   action run on the tag worked.** Note that this action will run even if the
92 |   other CI fails, so you must make sure that CI is passing *before* tagging.
93 | 
94 |   If it failed for some reason, you may need to delete the tag and try again.
95 | 
96 | - [ ] **Merge the release branch.**
97 | 
98 |   This way any changes you made in the branch, such as updates to the
99 |   changelog or xfails files, are updated in `main`. This will also make the
100 |   docs update (the docs are published automatically from the sources on
101 |   `main`).
102 | 
103 | - [ ] **Update conda-forge.**
104 | 
105 |   After the PyPI package is published, the conda-forge bot should update the
106 |   feedstock automatically after some time. The bot should automerge, so in
107 |   most cases you don't need to do anything here, unless some metadata on the
108 |   feedstock needs to be updated.
109 | 
--------------------------------------------------------------------------------
/docs/dev/special-considerations.md:
--------------------------------------------------------------------------------
1 | # Special Considerations
2 | 
3 | array-api-compat requires some special development considerations that are
4 | different from most other Python libraries. The goal of array-api-compat is to
5 | be a small library that packages can either vendor or add as a dependency to
6 | implement array API support. Consequently, certain design considerations
7 | should be taken into account:
8 | 
9 | (no-dependencies)=
10 | - **No Hard Dependencies.** Although array-api-compat "depends" on NumPy, CuPy,
11 |   PyTorch, etc., it does not hard depend on them. These libraries are not
12 |   imported unless either an array object is passed to
13 |   {func}`~.array_namespace()`, or the specific `array_api_compat.<library>`
14 |   sub-namespace is explicitly imported. This is tested (as best as possible)
15 |   in `tests/test_no_dependencies.py`.
16 | 
17 | - **Vendorability.** array-api-compat should be [vendorable](vendoring). This
18 |   means that, for instance, all imports in the library are relative imports.
19 |   No code in the package specifically references the name `array_api_compat`
20 |   (we also support renaming the package to something else).
21 |   Vendorability support is tested in `tests/test_vendoring.py`.
22 | 
23 | - **Pure Python.** To make array-api-compat as easy as possible to add as a
24 |   dependency, the code is all pure Python.
25 | 
26 | - **Minimal Wrapping Only.** The wrapping functionality is minimal. This means
27 |   that if something is difficult to wrap using pure Python, or if trying to
28 |   support some array API behavior would require a significant amount of code,
29 |   we prefer to leave the behavior as an upstream issue for the array library,
30 |   and [document it as a known difference](../supported-array-libraries.md).
31 | 
32 |   This also means that we do not at this point in time implement anything
33 |   other than wrappers for functions in the standard, and basic [helper
34 |   functions](../helper-functions.rst) that would be useful for most users of
35 |   array-api-compat. The addition of functions that are not part of the array
36 |   API standard is currently out-of-scope for this package (see the
37 |   [Scope](scope) section of the documentation).
38 | 
39 | - **No Side-Effects**. array-api-compat behavior should be localized to only the
40 |   specific code that imports and uses it. It should be invisible to end-users
41 |   or users of dependent codes. This in particular implies the next two
42 |   points.
43 | 
44 | - **No Monkey Patching.** `array-api-compat` should not attempt to modify
45 |   anything about the underlying library. It is a *wrapper* library only.
46 | 
47 | - **No Modifying the Array Object.** The array (or tensor) object of the array
48 |   library cannot be modified. This also precludes the creation of array
49 |   subclasses or wrapper classes.
50 | 
51 |   Any non-standard behavior that is built-in to the array object, such as the
52 |   behavior of [array
53 |   methods](https://data-apis.org/array-api/latest/API_specification/array_object.html),
54 |   is therefore left unwrapped. Users can work around issues by using
55 |   corresponding [elementwise
56 |   functions](https://data-apis.org/array-api/latest/API_specification/elementwise_functions.html)
57 |   instead of
58 |   [operators](https://data-apis.org/array-api/latest/API_specification/array_object.html#operators),
59 |   and by using the [helper functions](../helper-functions.rst) provided by
60 |   array-api-compat instead of attributes or methods like `x.to_device()`.
61 | 
62 | - **Avoid Restricting Behavior that is Outside the Scope of the Standard.** All
63 |   array libraries have functions and behaviors that are outside of the scope
64 |   of what is specified by the standard. These behaviors should be left intact
65 |   whenever possible, unless the standard explicitly disallows something. This
66 |   means
67 | 
68 |   - All namespaces are *extended* with wrapper functions. You may notice the
69 |     extensive use of `import *` in various files in `array_api_compat`. While
70 |     this would normally be questionable, this is the [one actual legitimate
71 |     use-case for `import *`](https://peps.python.org/pep-0008/#imports), to
72 |     re-export names from an external namespace.
73 | 
74 |   - All wrapper functions pass `**kwargs` through to the wrapped function.
75 | 
76 |   - Input types not supported by the standard should work if they work in the
77 |     underlying wrapped function (for instance, Python scalars or `np.ndarray`
78 |     subclasses).
79 | 
80 |   By keeping underlying behaviors intact, it is easier for libraries to swap
81 |   out NumPy or other array libraries for array-api-compat, and it is easier
82 |   for libraries to write array library-specific code paths.
83 | 
84 | The onus is on users of array-api-compat to ensure their array API code is
85 | portable, e.g., by testing against [array-api-strict](array-api-strict).
86 | 
--------------------------------------------------------------------------------
/docs/dev/tests.md:
--------------------------------------------------------------------------------
1 | # Tests
2 | 
3 | The majority of the behavior for array-api-compat is tested by the
4 | [array-api-tests](https://github.com/data-apis/array-api-tests) test suite for
5 | the array API standard. There are also array-api-compat specific tests in
6 | [`tests/`](https://github.com/data-apis/array-api-compat/tree/main/tests).
7 | These tests should be limited to things that are not tested by the test suite,
8 | e.g., tests for [helper functions](../helper-functions.rst) or for behavior
9 | that is not strictly required by the standard. To run these tests, install the
10 | dependencies from the `dev` optional group (array-api-compat has [no hard
11 | runtime dependencies](no-dependencies)).
12 | 
13 | array-api-tests is run on CI against all supported libraries
14 | ([except for JAX](jax-support) and [Sparse](sparse-support)). This is achieved
15 | by a [reusable GitHub Actions
16 | Workflow](https://github.com/data-apis/array-api-compat/blob/main/.github/workflows/array-api-tests.yml).
17 | Most libraries have tests that must be xfailed or skipped for various reasons.
18 | These are defined in specific `<library>-xfails.txt` files and are
19 | automatically forwarded to array-api-tests.
20 | 
21 | You may often need to update these xfail files, either to add new xfails
22 | (e.g., because of new test suite features, or because a test that was
23 | previously thought to be passing turns out to fail flakily). Try to keep the xfails
24 | files organized, with comments pointing to upstream issues whenever possible.
25 | 
26 | From time to time, xpass tests should be removed from the xfail files, but be
27 | aware that many xfail tests are flaky, so an xpass should only be removed if
28 | you know that the underlying issue has been fixed.
29 | 
30 | Array libraries that require a GPU to run (currently only CuPy) cannot be
31 | tested on CI. There is a helper script `test_cupy.sh` that can be used to
32 | manually test CuPy on a machine with a CUDA GPU.
33 | 
--------------------------------------------------------------------------------
/docs/helper-functions.rst:
--------------------------------------------------------------------------------
1 | Helper Functions
2 | ================
3 | 
4 | .. currentmodule:: array_api_compat
5 | 
6 | In addition to the wrapped library namespaces and functions in the array API
7 | specification, there are several helper functions included here that aren't
8 | part of the specification but which are useful for using the array API:
9 | 
10 | Entry-point Helpers
11 | -------------------
12 | 
13 | The `array_namespace()` function is the primary entry-point for array API
14 | consuming libraries.
15 | 
16 | 
17 | .. autofunction:: array_namespace
18 | .. autofunction:: is_array_api_obj
19 | 
20 | Array Method Helpers
21 | --------------------
22 | 
23 | array-api-compat does not attempt to wrap or monkey patch the array object for
24 | any library. Consequently, any API differences for the `array object
25 | <https://data-apis.org/array-api/latest/API_specification/array_object.html>`__
26 | cannot be directly wrapped. Some libraries do not define some of these methods
27 | or define them differently. For these, helper functions are provided which can
28 | be used instead.
29 | 
30 | Note that if you have a compatibility issue with an operator method (like
31 | `__add__`, i.e., `+`) you may prefer to use the corresponding `elementwise
32 | function
33 | <https://data-apis.org/array-api/latest/API_specification/elementwise_functions.html>`__
34 | instead, which would be wrapped.
35 | 
36 | .. autofunction:: device
37 | .. autofunction:: to_device
38 | .. autofunction:: size
39 | 
40 | Inspection Helpers
41 | ------------------
42 | 
43 | These convenience functions can be used to test if an array or namespace comes from a
44 | specific library without importing that library if it hasn't been imported
45 | yet.
46 | 
47 | .. autofunction:: is_numpy_array
48 | .. autofunction:: is_cupy_array
49 | .. autofunction:: is_torch_array
50 | .. autofunction:: is_dask_array
51 | .. autofunction:: is_jax_array
52 | .. autofunction:: is_pydata_sparse_array
53 | .. autofunction:: is_ndonnx_array
54 | .. autofunction:: is_writeable_array
55 | .. autofunction:: is_lazy_array
56 | .. autofunction:: is_numpy_namespace
57 | .. autofunction:: is_cupy_namespace
58 | .. autofunction:: is_torch_namespace
59 | .. autofunction:: is_dask_namespace
60 | .. autofunction:: is_jax_namespace
61 | .. autofunction:: is_pydata_sparse_namespace
62 | .. autofunction:: is_ndonnx_namespace
63 | .. autofunction:: is_array_api_strict_namespace
64 | 
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # Array API compatibility library
2 | 
3 | This is a small wrapper around common array libraries that is compatible with
4 | the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
5 | NumPy, CuPy, PyTorch, Dask, JAX, ndonnx, and Sparse are supported. If you want
6 | support for other array libraries, or if you encounter any issues, please
7 | [open an issue](https://github.com/data-apis/array-api-compat/issues).
8 | 
9 | Note that some of the functionality in this library is backwards incompatible
10 | with the corresponding wrapped libraries. The end-goal is to eventually make
11 | each array library itself fully compatible with the array API, but this
12 | requires making backwards incompatible changes in many cases, so this will
13 | take some time.
14 | 
15 | Currently all libraries here are implemented against the [2024.12
16 | version](https://data-apis.org/array-api/2024.12/) of the standard.
17 | 
18 | ## Installation
19 | 
20 | `array-api-compat` is available on both [PyPI](https://pypi.org/project/array-api-compat/)
21 | 
22 | ```
23 | python -m pip install array-api-compat
24 | ```
25 | 
26 | and [conda-forge](https://anaconda.org/conda-forge/array-api-compat)
27 | 
28 | ```
29 | conda install --channel conda-forge array-api-compat
30 | ```
31 | 
32 | ## Usage
33 | 
34 | The typical usage of this library will be to get the corresponding array API
35 | compliant namespace from the input arrays using {func}`~.array_namespace()`, like
36 | 
37 | ```py
38 | def your_function(x, y):
39 |     xp = array_api_compat.array_namespace(x, y)
40 |     # Now use xp as the array library namespace
41 |     return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)
42 | ```
43 | 
44 | If you wish to have library-specific code-paths, you can import the
45 | corresponding wrapped namespace for each library, like
46 | 
47 | ```py
48 | import array_api_compat.numpy as np
49 | ```
50 | 
51 | ```py
52 | import array_api_compat.cupy as cp
53 | ```
54 | 
55 | ```py
56 | import array_api_compat.torch as torch
57 | ```
58 | 
59 | ```py
60 | import array_api_compat.dask as da
61 | ```
62 | 
63 | ```{note}
64 | There are no `array_api_compat` submodules for JAX, sparse, or ndonnx. The
65 | support for these libraries is contained in the libraries themselves (JAX
66 | support is in the `jax.numpy` module in JAX v0.4.32 or newer, and in the
67 | `jax.experimental.array_api` module for older JAX versions). The
68 | array-api-compat support for these libraries consists of supporting them in
69 | the [helper functions](helper-functions).
70 | ```
71 | 
72 | Each will include all the functions from the normal NumPy/CuPy/PyTorch/dask.array
73 | namespace, except that functions that are part of the array API are wrapped so
74 | that they have the correct array API behavior. In each case, the array object
75 | used will be the same array object from the wrapped library.
76 | 
77 | (array-api-strict)=
78 | ## Difference between `array_api_compat` and `array_api_strict`
79 | 
80 | [`array_api_strict`](https://data-apis.org/array-api-strict/) is a
81 | strict minimal implementation of the array API standard, formerly known as
82 | `numpy.array_api` (see [NEP
83 | 47](https://numpy.org/neps/nep-0047-array-api-standard.html)). For example,
84 | `array_api_strict` does not include any functions that are not part of the
85 | array API specification, and will explicitly disallow behaviors that are not
86 | required by the spec (e.g., [cross-kind type
87 | promotions](https://data-apis.org/array-api/latest/API_specification/type_promotion.html)).
88 | (`cupy.array_api` is similar to `array_api_strict`.)
89 | 
90 | `array_api_compat`, on the other hand, is just an extension of the
91 | corresponding array library namespaces with changes needed to be compliant
92 | with the array API. It includes all additional library functions not mentioned
93 | in the spec, and allows any library behaviors not explicitly disallowed by it,
94 | such as cross-kind casting.
95 | 
96 | In particular, unlike `array_api_strict`, this package does not use a separate
97 | `Array` object, but rather just uses the corresponding array library array
98 | objects (`numpy.ndarray`, `cupy.ndarray`, `torch.Tensor`, etc.) directly. This
99 | is because those are the objects that are going to be passed as inputs to
100 | functions by end users. This does mean that a few behaviors cannot be wrapped
101 | (see below), but most of the array API is functional, so this does not affect
102 | most things.
103 | 
104 | Array consuming library authors coding against the array API may wish to test
105 | against `array_api_strict` to ensure they are not using functionality outside
106 | of the standard, but prefer this implementation for the default behavior for
107 | end-users.
108 | 
109 | (vendoring)=
110 | ## Vendoring
111 | 
112 | This library supports vendoring as an installation method. To vendor the
113 | library, simply copy `array_api_compat` into the appropriate place in the
114 | library, like
115 | 
116 | ```
117 | cp -R array_api_compat/ mylib/vendored/array_api_compat
118 | ```
119 | 
120 | You may also rename it to something else if you like (nowhere in the code
121 | references the name "array_api_compat").
122 | 
123 | Alternatively, the library may be installed as a dependency from PyPI.
124 | 
125 | (scope)=
126 | ## Scope
127 | 
128 | At this time, the scope of array-api-compat is limited to wrapping array
129 | libraries so that they can comply with the [array API
130 | standard](https://data-apis.org/array-api/latest/API_specification/index.html).
131 | This includes a small set of [helper functions](helper-functions.rst) which may
132 | be useful to most users of array-api-compat, for instance, functions that
133 | provide meta-functionality to aid in supporting the array API, or functions
134 | that are necessary to work around wrapping limitations for certain libraries.
135 | 
136 | Things that are out-of-scope include:
137 | 
138 | - functions that have not yet been
139 |   standardized (although note that functions that are in a draft version of the
140 |   standard are *in scope*),
141 | 
142 | - functions that are complicated to implement correctly/maintain,
143 | 
144 | - anything that requires the use of non-Python code.
145 | 
146 | If you want a function that is not in array-api-compat and isn't part of the
147 | standard, you should request it either for [inclusion in the
148 | standard](https://github.com/data-apis/array-api/issues) or in specific array
149 | libraries.
150 | 
151 | Why is the scope limited in this way? Firstly, we want to keep
152 | array-api-compat as primarily a
153 | [polyfill](https://en.wikipedia.org/wiki/Polyfill_(programming)) compatibility
154 | shim. The goal is to let consuming libraries use the array API today, even
155 | with array libraries that do not yet fully support it. In an ideal world, one that we hope to eventually see, array-api-compat would be
156 | unnecessary, because every array library would fully support the standard.
157 | 
158 | The inclusion of non-standardized functions in array-api-compat would
159 | undermine this goal. But much more importantly, it would also undermine the
160 | goals of the [Data APIs Consortium](https://data-apis.org/). The Consortium
161 | creates the array API standard via the consensus of stakeholders from various
162 | array libraries and users. If a not-yet-standardized function were included in
163 | array-api-compat, it would become a *de facto* standard, bypassing the decision
164 | making processes of the Consortium.
165 | 
166 | Secondly, we want to keep array-api-compat as minimal as possible, so that it
167 | is easy for libraries to add as a (possibly vendored) dependency.
168 | 
169 | Thirdly, array-api-compat has a relatively small development team. Pull
170 | requests to array-api-compat would not necessarily receive the same stringent
171 | level of scrutiny that changes to established array libraries like NumPy or
172 | PyTorch would. For wrapped standard functions, this is fine, since the
173 | wrappers typically just clean up a few small inconsistencies from the
174 | standard, leaving the complexity of the implementation to the base array
175 | library function. Furthermore, standard functions are tested by the rigorous
176 | [array-api-tests](https://github.com/data-apis/array-api-tests) test suite.
177 | For this reason, functions that require complex implementations are generally
178 | out-of-scope and are better implemented in upstream array
179 | libraries.
180 | 
181 | ```{toctree}
182 | :titlesonly:
183 | :hidden:
184 | 
185 | helper-functions.rst
186 | supported-array-libraries.md
187 | changelog.md
188 | dev/index.md
189 | ```
190 | 
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 | 
3 | pushd %~dp0
4 | 
5 | REM Command file for Sphinx documentation
6 | 
7 | if "%SPHINXBUILD%" == "" (
8 | 	set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 | 
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | 	echo.
16 | 	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | 	echo.installed, then set the SPHINXBUILD environment variable to point
18 | 	echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | 	echo.may add the Sphinx directory to PATH.
20 | 	echo.
21 | 	echo.If you don't have Sphinx installed, grab it from
22 | 	echo.https://www.sphinx-doc.org/
23 | 	exit /b 1
24 | )
25 | 
26 | if "%1" == "" goto help
27 | 
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 | 
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 | 
34 | :end
35 | popd
36 | 
--------------------------------------------------------------------------------
/docs/supported-array-libraries.md:
--------------------------------------------------------------------------------
1 | # Supported Array Libraries
2 | 
3 | The following array libraries are supported. This page outlines the known
4 | differences between this library and the array API specification for the
5 | supported packages.
6 | 
7 | Note that the {func}`~.array_namespace()` helper will also support any array
8 | library that explicitly supports the array API by defining
9 | [`__array_namespace__`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__array_namespace__.html).
10 | 
11 | Any reasonably popular array library is in-scope for array-api-compat,
12 | assuming it is possible to wrap it to support the array API without too much
13 | complexity. If your favorite library is not supported, feel free to open an
14 | [issue or pull request](https://github.com/data-apis/array-api-compat/issues).
15 | 
16 | ## [NumPy](https://numpy.org/) and [CuPy](https://cupy.dev/)
17 | 
18 | NumPy 2.0 has full array API compatibility. This package is not strictly
19 | necessary for NumPy 2.0 support, but may still be useful for the support of
20 | other libraries, as well as for the [helper functions](helper-functions.rst).
21 | 
22 | For NumPy 1.26, as well as corresponding versions of CuPy, the following
23 | deviations from the standard should be noted:
24 | 
25 | - The array methods `__array_namespace__`, `device` (for NumPy), `to_device`,
26 |   and `mT` are not defined. This is because we reuse `np.ndarray` and `cp.ndarray` and
27 |   don't want to monkey patch or wrap them. The [helper
28 |   functions](helper-functions.rst) {func}`~.device()` and {func}`~.to_device()`
29 |   are provided to work around these missing methods. `x.mT` can be replaced
30 |   with `xp.linalg.matrix_transpose(x)`. {func}`~.array_namespace()` should be
31 |   used instead of `x.__array_namespace__`.
32 | 
33 | - Value-based casting for scalars will be in effect unless explicitly disabled
34 |   with the environment variable `NPY_PROMOTION_STATE=weak` or
35 |   `np._set_promotion_state('weak')` (requires NumPy 1.24 or newer, see [NEP
36 |   50](https://numpy.org/neps/nep-0050-scalar-promotion.html) and
37 |   https://github.com/numpy/numpy/issues/22341)
38 | 
39 | - Functions which are not wrapped may not have the same type annotations
40 |   as the spec.
41 | 
42 | - Functions which are not wrapped may not use positional-only arguments.
43 | 
44 | The minimum supported NumPy version is 1.22. However, this older version of
45 | NumPy has a few issues:
46 | 
47 | - `unique_*` will not compare nans as unequal.
48 | - No `from_dlpack` or `__dlpack__`.
49 | - Type promotion behavior will be value based for 0-D arrays (and there is no
50 |   `NPY_PROMOTION_STATE=weak` to disable this).
51 | 
52 | If any of these are an issue, it is recommended to bump your minimum NumPy
53 | version.
54 | 
55 | ## [PyTorch](https://pytorch.org/)
56 | 
57 | - Like NumPy/CuPy, we do not wrap the `torch.Tensor` object. It is missing the
58 |   `__array_namespace__` and `to_device` methods, so the corresponding helper
59 |   functions {func}`~.array_namespace()` and {func}`~.to_device()` in this
60 |   library should be used instead.
61 | 
62 | - The {external+torch:meth}`x.size() <torch.Tensor.size>` attribute on
63 |   `torch.Tensor` is a method that behaves differently from the
64 |   [`x.size`](https://data-apis.org/array-api/draft/API_specification/generated/array_api.array.size.html)
65 |   attribute in the spec. Use the {func}`~.size()` helper function as a
66 |   portable workaround.
67 | 
68 | - PyTorch has incomplete support for unsigned integer types other
69 |   than `uint8`, and no attempt is made to implement them here.
70 | 
71 | - PyTorch has type promotion semantics that differ from the array API
72 |   specification for 0-D tensor objects.
The array functions in this wrapper 73 | library do work around this, but the operators on the Tensor object do not, 74 | as no operators or methods on the Tensor object are modified. If this is a 75 | concern, use the functional form instead of the operator form, e.g., `add(x, 76 | y)` instead of `x + y`. 77 | 78 | - [`unique_all()`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_all.html#array_api.unique_all) 79 | is not implemented, due to the fact that `torch.unique` does not support 80 | returning the `indices` array. The other 81 | [`unique_*`](https://data-apis.org/array-api/latest/API_specification/set_functions.html) 82 | functions are implemented. 83 | 84 | - Slices do not support negative steps. 85 | 86 | - [`std()`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.std.html#array_api.std) 87 | and 88 | [`var()`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.var.html#array_api.var) 89 | do not support floating-point `correction`. 90 | 91 | - The `stream` argument of the {func}`~.to_device()` helper is not supported. 92 | 93 | - As with NumPy, type annotations and positional-only arguments may not 94 | exactly match the spec for functions that are not wrapped at all. 95 | 96 | (jax-support)= 97 | ## [JAX](https://jax.readthedocs.io/en/latest/) 98 | 99 | Unlike the other libraries supported here, JAX array API support is contained 100 | entirely in the JAX library. The JAX array API support is tracked at 101 | https://github.com/google/jax/issues/18353. 102 | 103 | ## [Dask](https://www.dask.org/) 104 | 105 | If you're using dask with numpy, many of the same limitations that apply to numpy 106 | will also apply to dask. Besides those differences, other limitations include missing 107 | sort functionality (no `sort` or `argsort`), and limited support for the optional `linalg` 108 | and `fft` extensions. 109 | 110 | In particular, the `fft` namespace is not compliant with the array API spec. Any functions 111 | that you find under the `fft` namespace are the original, unwrapped functions under [`dask.array.fft`](https://docs.dask.org/en/latest/array-api.html#fast-fourier-transforms), which may or may not be Array API compliant. Use at your own risk! 112 | 113 | For `linalg`, several methods are missing, for example: 114 | - `cross` 115 | - `det` 116 | - `eigh` 117 | - `eigvalsh` 118 | - `matrix_power` 119 | - `pinv` 120 | - `slogdet` 121 | - `matrix_norm` 122 | - `matrix_rank` 123 | Other methods may only be partially implemented or return incorrect results at times. 124 | 125 | (sparse-support)= 126 | ## [Sparse](https://sparse.pydata.org/en/stable/) 127 | 128 | Similar to JAX, `sparse` Array API support is contained directly in `sparse`. 129 | 130 | (ndonnx-support)= 131 | ## [ndonnx](https://github.com/quantco/ndonnx) 132 | 133 | Similar to JAX, `ndonnx` Array API support is contained directly in `ndonnx`. 134 | 135 | (array-api-strict-support)= 136 | ## [array-api-strict](https://data-apis.org/array-api-strict/) 137 | 138 | array-api-strict exists only to test support for the Array API, so it does not need any wrappers. 
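
For instance, a minimal sketch of using array-api-strict to smoke-test
portable code (the `demean` function here is hypothetical):

```py
import array_api_strict as xp

def demean(x):
    # A hypothetical portable function that only uses standard array API calls.
    return x - xp.mean(x, axis=0)

# array-api-strict raises on any non-standard usage inside demean(),
# making it a good target for portability tests.
x = xp.asarray([[1.0, 2.0], [3.0, 4.0]])
print(demean(x))
```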
139 | -------------------------------------------------------------------------------- /numpy-1-26-xfails.txt: -------------------------------------------------------------------------------- 1 | # attributes are np.float32 instead of float 2 | # (see also https://github.com/data-apis/array-api/issues/405) 3 | array_api_tests/test_data_type_functions.py::test_finfo[float32] 4 | array_api_tests/test_data_type_functions.py::test_finfo[complex64] 5 | 6 | # Array methods and attributes not already on np.ndarray cannot be wrapped 7 | array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] 8 | array_api_tests/test_has_names.py::test_has_names[array_method-to_device] 9 | array_api_tests/test_has_names.py::test_has_names[array_attribute-device] 10 | array_api_tests/test_has_names.py::test_has_names[array_attribute-mT] 11 | 12 | # Array methods and attributes not already on np.ndarray cannot be wrapped 13 | array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__] 14 | array_api_tests/test_signatures.py::test_array_method_signature[to_device] 15 | 16 | # NumPy deviates in some special cases for floordiv 17 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] 18 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] 19 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] 20 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] 21 | array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] 22 | array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] 23 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] 24 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] 25 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] 26 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] 27 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] 28 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] 29 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] 30 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] 31 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] 32 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] 33 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] 34 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] 35 | 36 | # 
https://github.com/numpy/numpy/issues/21213 37 | array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices 38 | 39 | # 2023.12 support 40 | array_api_tests/test_signatures.py::test_func_signature[from_dlpack] 41 | array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] 42 | # uint64 repeats not supported 43 | array_api_tests/test_manipulation_functions.py::test_repeat 44 | 45 | # 2024.12 support 46 | array_api_tests/test_signatures.py::test_func_signature[bitwise_and] 47 | array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] 48 | array_api_tests/test_signatures.py::test_func_signature[bitwise_or] 49 | array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] 50 | array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] 51 | array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars 52 | 53 | array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars 54 | 55 | # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that 56 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] 57 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] 58 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] 59 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] 60 | array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] 61 | array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] 62 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] 63 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] 64 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] 65 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] 66 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] 67 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] 68 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] 69 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] 70 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] 71 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] 72 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] 73 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] 74 | 
-------------------------------------------------------------------------------- /numpy-dev-xfails.txt: -------------------------------------------------------------------------------- 1 | # attributes are np.float32 instead of float 2 | # (see also https://github.com/data-apis/array-api/issues/405) 3 | array_api_tests/test_data_type_functions.py::test_finfo[float32] 4 | array_api_tests/test_data_type_functions.py::test_finfo[complex64] 5 | 6 | # The test suite cannot properly get the signature for vecdot 7 | # https://github.com/numpy/numpy/pull/26237 8 | array_api_tests/test_signatures.py::test_func_signature[vecdot] 9 | array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] 10 | 11 | # 2023.12 support 12 | # uint64 repeats not supported 13 | array_api_tests/test_manipulation_functions.py::test_repeat 14 | 15 | # 2024.12 support 16 | array_api_tests/test_signatures.py::test_func_signature[bitwise_and] 17 | array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] 18 | array_api_tests/test_signatures.py::test_func_signature[bitwise_or] 19 | array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] 20 | array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] 21 | 22 | # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that 23 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] 24 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] 25 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] 26 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] 27 | array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] 28 | array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] 29 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] 30 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] 31 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] 32 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] 33 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] 34 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] 35 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] 36 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] 37 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] 38 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] 39 | 
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] 40 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] 41 | -------------------------------------------------------------------------------- /numpy-skips.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/data-apis/array-api-compat/cddc9ef8a19b453b09884987ca6a0626408a1478/numpy-skips.txt -------------------------------------------------------------------------------- /numpy-xfails.txt: -------------------------------------------------------------------------------- 1 | # attributes are np.float32 instead of float 2 | # (see also https://github.com/data-apis/array-api/issues/405) 3 | array_api_tests/test_data_type_functions.py::test_finfo[float32] 4 | array_api_tests/test_data_type_functions.py::test_finfo[complex64] 5 | 6 | # The test suite cannot properly get the signature for vecdot 7 | # https://github.com/numpy/numpy/pull/26237 8 | array_api_tests/test_signatures.py::test_func_signature[vecdot] 9 | array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] 10 | 11 | # 2023.12 support 12 | # uint64 repeats not supported 13 | array_api_tests/test_manipulation_functions.py::test_repeat 14 | 15 | # 2024.12 support 16 | array_api_tests/test_signatures.py::test_func_signature[bitwise_and] 17 | array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] 18 | array_api_tests/test_signatures.py::test_func_signature[bitwise_or] 19 | array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] 20 | array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] 21 | 22 | # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that 23 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] 24 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] 25 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] 26 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] 27 | array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] 28 | array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] 29 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] 30 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] 31 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] 32 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] 33 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] 34 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] 35 | 
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] 36 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] 37 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] 38 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] 39 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] 40 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] 41 | 42 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "setuptools-scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "array-api-compat" 7 | dynamic = ["version"] 8 | description = "A wrapper around NumPy and other array libraries to make them compatible with the Array API standard" 9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | license = "MIT" 12 | authors = [{name = "Consortium for Python Data API Standards"}] 13 | classifiers = [ 14 | "Operating System :: OS Independent", 15 | "Programming Language :: Python :: 3", 16 | "Programming Language :: Python :: 3.10", 17 | "Programming Language :: Python :: 3.11", 18 | "Programming Language :: Python :: 3.12", 19 | "Programming Language :: Python :: 3.13", 20 | "Topic :: Software Development :: Libraries :: Python Modules", 21 | "Typing :: Typed", 22 | ] 23 | 24 | [project.optional-dependencies] 25 | cupy = ["cupy"] 26 | dask = ["dask>=2024.9.0"] 27 | jax = ["jax"] 28 | # Note: array-api-compat follows scikit-learn minimum dependencies, which support 29 | # much older versions of NumPy than what SPEC0 recommends. 
30 | numpy = ["numpy>=1.22"] 31 | pytorch = ["torch"] 32 | sparse = ["sparse>=0.15.1"] 33 | ndonnx = ["ndonnx"] 34 | docs = [ 35 | "furo", 36 | "linkify-it-py", 37 | "myst-parser", 38 | "sphinx", 39 | "sphinx-copybutton", 40 | "sphinx-autobuild", 41 | ] 42 | dev = [ 43 | "array-api-strict", 44 | "dask[array]>=2024.9.0", 45 | "jax[cpu]", 46 | "ndonnx", 47 | "numpy>=1.22", 48 | "pytest", 49 | "torch", 50 | "sparse>=0.15.1", 51 | ] 52 | 53 | [project.urls] 54 | homepage = "https://data-apis.org/array-api-compat/" 55 | repository = "https://github.com/data-apis/array-api-compat/" 56 | 57 | [tool.setuptools.dynamic] 58 | version = {attr = "array_api_compat.__version__"} 59 | 60 | [tool.setuptools.packages.find] 61 | include = ["array_api_compat*"] 62 | namespaces = false 63 | 64 | [tool.ruff.lint] 65 | preview = true 66 | select = [ 67 | # Defaults 68 | "E4", "E7", "E9", "F", 69 | # Undefined export 70 | "F822", 71 | # Useless import alias 72 | "PLC0414" 73 | ] 74 | 75 | ignore = [ 76 | # Module import not at top of file 77 | "E402", 78 | # Do not use bare `except` 79 | "E722" 80 | ] 81 | 82 | 83 | [tool.mypy] 84 | files = ["array_api_compat"] 85 | disallow_incomplete_defs = true 86 | disallow_untyped_decorators = true 87 | disallow_untyped_defs = false # TODO 88 | ignore_missing_imports = false 89 | no_implicit_optional = true 90 | show_error_codes = true 91 | warn_redundant_casts = true 92 | warn_unused_ignores = true 93 | warn_unreachable = true 94 | 95 | [[tool.mypy.overrides]] 96 | module = ["cupy.*", "cupy_backends.*", "dask.*", "jax.*", "ndonnx.*", "sparse.*", "torch.*"] 97 | ignore_missing_imports = true 98 | 99 | 100 | [tool.pyright] 101 | include = ["src", "tests"] 102 | pythonPlatform = "All" 103 | 104 | reportAny = false 105 | reportExplicitAny = false 106 | # missing type stubs 107 | reportAttributeAccessIssue = false 108 | reportUnknownMemberType = false 109 | reportUnknownVariableType = false 110 | # Redundant with mypy checks 111 | reportMissingImports = false 112 | reportMissingTypeStubs = false 113 | # false positives for input validation 114 | reportUnreachable = false 115 | # ruff handles this 116 | reportUnusedParameter = false 117 | 118 | executionEnvironments = [ 119 | { root = "array_api_compat" }, 120 | ] 121 | -------------------------------------------------------------------------------- /test_cupy.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # We cannot test cupy on CI so this script will test it manually. 
Assumes it 3 | # is being run in an environment that has cupy and the array-api-tests 4 | # dependencies installed. 5 | set -x 6 | set -e 7 | 8 | # Run the vendoring tests in this repo 9 | pytest 10 | 11 | tmpdir=$(mktemp -d) 12 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) 13 | export PYTHONPATH="$PYTHONPATH:$SCRIPT_DIR" 14 | 15 | PYTEST_ARGS="--max-examples 200 -v -rxXfE --ci --hypothesis-disable-deadline" 16 | 17 | cd "$tmpdir" 18 | git clone https://github.com/data-apis/array-api-tests 19 | cd array-api-tests 20 | 21 | git submodule update --init 22 | 23 | # Store the hypothesis examples database in this directory, so that failures 24 | # will be remembered across runs 25 | mkdir -p "$SCRIPT_DIR/.hypothesis" 26 | ln -s "$SCRIPT_DIR/.hypothesis" .hypothesis 27 | 28 | export ARRAY_API_TESTS_MODULE=array_api_compat.cupy 29 | export ARRAY_API_TESTS_VERSION=2024.12 30 | pytest array_api_tests/ ${PYTEST_ARGS} --xfails-file "$SCRIPT_DIR/cupy-xfails.txt" "$@" 31 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Basic tests for the compat library 3 | 4 | This only tests basic things like that vendoring works. The extensive tests 5 | are done by the array API test suite https://github.com/data-apis/array-api-tests 6 | 7 | """ 8 | -------------------------------------------------------------------------------- /tests/_helpers.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | 3 | import pytest 4 | 5 | wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"] 6 | all_libraries = wrapped_libraries + [ 7 | "array_api_strict", "jax.numpy", "ndonnx", "sparse" 8 | ] 9 | 10 | def import_(library, wrapper=False): 11 | pytest.importorskip(library) 12 | if wrapper: 13 | if 'jax' in library: 14 | # JAX v0.4.32 implements the array API directly in jax.numpy 15 | # Older jax versions use jax.experimental.array_api 16 | jax_numpy = import_module("jax.numpy") 17 | if not hasattr(jax_numpy, "__array_api_version__"): 18 | library = 'jax.experimental.array_api' 19 | elif library in wrapped_libraries: 20 | library = 'array_api_compat.' + library 21 | 22 | return import_module(library) 23 | 24 | 25 | def xfail(request: pytest.FixtureRequest, reason: str) -> None: 26 | """ 27 | XFAIL the currently running test. 28 | 29 | Unlike ``pytest.xfail``, allow the rest of the test to execute instead of immediately 30 | halting it, so that it may result in an XPASS.
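This helper is typically called conditionally inside a test body, so that only some parametrizations of a test are xfailed.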
31 | xref https://github.com/pandas-dev/pandas/issues/38902 32 | """ 33 | request.node.add_marker(pytest.mark.xfail(reason=reason)) 34 | -------------------------------------------------------------------------------- /tests/test_all.py: -------------------------------------------------------------------------------- 1 | """Test exported names""" 2 | 3 | import builtins 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | from array_api_compat._internal import clone_module 9 | 10 | from ._helpers import wrapped_libraries 11 | 12 | NAMES = { 13 | "": [ 14 | # Inspection 15 | "__array_api_version__", 16 | "__array_namespace_info__", 17 | # Submodules 18 | "fft", 19 | "linalg", 20 | # Constants 21 | "e", 22 | "inf", 23 | "nan", 24 | "newaxis", 25 | "pi", 26 | # Creation Functions 27 | "arange", 28 | "asarray", 29 | "empty", 30 | "empty_like", 31 | "eye", 32 | "from_dlpack", 33 | "full", 34 | "full_like", 35 | "linspace", 36 | "meshgrid", 37 | "ones", 38 | "ones_like", 39 | "tril", 40 | "triu", 41 | "zeros", 42 | "zeros_like", 43 | # Data Type Functions 44 | "astype", 45 | "can_cast", 46 | "finfo", 47 | "iinfo", 48 | "isdtype", 49 | "result_type", 50 | # Data Types 51 | "bool", 52 | "int8", 53 | "int16", 54 | "int32", 55 | "int64", 56 | "uint8", 57 | "uint16", 58 | "uint32", 59 | "uint64", 60 | "float32", 61 | "float64", 62 | "complex64", 63 | "complex128", 64 | # Elementwise Functions 65 | "abs", 66 | "acos", 67 | "acosh", 68 | "add", 69 | "asin", 70 | "asinh", 71 | "atan", 72 | "atan2", 73 | "atanh", 74 | "bitwise_and", 75 | "bitwise_left_shift", 76 | "bitwise_invert", 77 | "bitwise_or", 78 | "bitwise_right_shift", 79 | "bitwise_xor", 80 | "ceil", 81 | "clip", 82 | "conj", 83 | "copysign", 84 | "cos", 85 | "cosh", 86 | "divide", 87 | "equal", 88 | "exp", 89 | "expm1", 90 | "floor", 91 | "floor_divide", 92 | "greater", 93 | "greater_equal", 94 | "hypot", 95 | "imag", 96 | "isfinite", 97 | "isinf", 98 | "isnan", 99 | "less", 100 | "less_equal", 101 | "log", 102 | "log1p", 103 | "log2", 104 | "log10", 105 | "logaddexp", 106 | "logical_and", 107 | "logical_not", 108 | "logical_or", 109 | "logical_xor", 110 | "maximum", 111 | "minimum", 112 | "multiply", 113 | "negative", 114 | "nextafter", 115 | "not_equal", 116 | "positive", 117 | "pow", 118 | "real", 119 | "reciprocal", 120 | "remainder", 121 | "round", 122 | "sign", 123 | "signbit", 124 | "sin", 125 | "sinh", 126 | "square", 127 | "sqrt", 128 | "subtract", 129 | "tan", 130 | "tanh", 131 | "trunc", 132 | # Indexing Functions 133 | "take", 134 | "take_along_axis", 135 | # Linear Algebra Functions 136 | "matmul", 137 | "matrix_transpose", 138 | "tensordot", 139 | "vecdot", 140 | # Manipulation Functions 141 | "broadcast_arrays", 142 | "broadcast_to", 143 | "concat", 144 | "expand_dims", 145 | "flip", 146 | "moveaxis", 147 | "permute_dims", 148 | "repeat", 149 | "reshape", 150 | "roll", 151 | "squeeze", 152 | "stack", 153 | "tile", 154 | "unstack", 155 | # Searching Functions 156 | "argmax", 157 | "argmin", 158 | "count_nonzero", 159 | "nonzero", 160 | "searchsorted", 161 | "where", 162 | # Set Functions 163 | "unique_all", 164 | "unique_counts", 165 | "unique_inverse", 166 | "unique_values", 167 | # Sorting Functions 168 | "argsort", 169 | "sort", 170 | # Statistical Functions 171 | "cumulative_prod", 172 | "cumulative_sum", 173 | "max", 174 | "mean", 175 | "min", 176 | "prod", 177 | "std", 178 | "sum", 179 | "var", 180 | # Utility Functions 181 | "all", 182 | "any", 183 | "diff", 184 | ], 185 | "fft": [ 186 | "fft", 187 | "ifft", 188 | "fftn", 
189 | "ifftn", 190 | "rfft", 191 | "irfft", 192 | "rfftn", 193 | "irfftn", 194 | "hfft", 195 | "ihfft", 196 | "fftfreq", 197 | "rfftfreq", 198 | "fftshift", 199 | "ifftshift", 200 | ], 201 | "linalg": [ 202 | "cholesky", 203 | "cross", 204 | "det", 205 | "diagonal", 206 | "eigh", 207 | "eigvalsh", 208 | "inv", 209 | "matmul", 210 | "matrix_norm", 211 | "matrix_power", 212 | "matrix_rank", 213 | "matrix_transpose", 214 | "outer", 215 | "pinv", 216 | "qr", 217 | "slogdet", 218 | "solve", 219 | "svd", 220 | "svdvals", 221 | "tensordot", 222 | "trace", 223 | "vecdot", 224 | "vector_norm", 225 | ], 226 | } 227 | 228 | XFAILS = { 229 | ("numpy", ""): ["from_dlpack"] if np.__version__ < "1.23" else [], 230 | ("dask.array", ""): ["from_dlpack", "take_along_axis"], 231 | ("dask.array", "linalg"): [ 232 | "cross", 233 | "det", 234 | "eigh", 235 | "eigvalsh", 236 | "matrix_power", 237 | "pinv", 238 | "slogdet", 239 | ], 240 | } 241 | 242 | 243 | def all_names(mod): 244 | """Return all names available in a module.""" 245 | objs = {} 246 | clone_module(mod.__name__, objs) 247 | return set(objs) 248 | 249 | 250 | def get_mod(library, module, *, compat): 251 | if compat: 252 | library = f"array_api_compat.{library}" 253 | xp = pytest.importorskip(library) 254 | return getattr(xp, module) if module else xp 255 | 256 | 257 | @pytest.mark.parametrize("module", list(NAMES)) 258 | @pytest.mark.parametrize("library", wrapped_libraries) 259 | def test_array_api_names(library, module): 260 | """Test that __all__ isn't missing any exports 261 | dictated by the Standard. 262 | """ 263 | mod = get_mod(library, module, compat=True) 264 | missing = set(NAMES[module]) - all_names(mod) 265 | xfail = set(XFAILS.get((library, module), [])) 266 | xpass = xfail - missing 267 | fails = missing - xfail 268 | assert not xpass, f"Names in XFAILS are defined: {xpass}" 269 | assert not fails, f"Missing exports: {fails}" 270 | 271 | 272 | @pytest.mark.parametrize("module", list(NAMES)) 273 | @pytest.mark.parametrize("library", wrapped_libraries) 274 | def test_compat_doesnt_hide_names(library, module): 275 | """The base namespace can have more names than the ones explicitly exported 276 | by array-api-compat. Test that we're not suppressing them. 277 | """ 278 | bare_mod = get_mod(library, module, compat=False) 279 | compat_mod = get_mod(library, module, compat=True) 280 | 281 | missing = all_names(bare_mod) - all_names(compat_mod) 282 | missing = {name for name in missing if not name.startswith("_")} 283 | assert not missing, f"Non-Array API names have been hidden: {missing}" 284 | 285 | 286 | @pytest.mark.parametrize("module", list(NAMES)) 287 | @pytest.mark.parametrize("library", wrapped_libraries) 288 | def test_compat_doesnt_add_names(library, module): 289 | """Test that array-api-compat isn't adding names to the namespace 290 | besides those defined by the Array API Standard. 
291 | """ 292 | bare_mod = get_mod(library, module, compat=False) 293 | compat_mod = get_mod(library, module, compat=True) 294 | 295 | aapi_names = set(NAMES[module]) 296 | spurious = all_names(compat_mod) - all_names(bare_mod) - aapi_names 297 | # Quietly ignore *Result dataclasses 298 | spurious = {name for name in spurious if not name.endswith("Result")} 299 | assert not spurious, ( 300 | f"array-api-compat is adding non-Array API names: {spurious}" 301 | ) 302 | 303 | 304 | @pytest.mark.parametrize( 305 | "name", [name for name in NAMES[""] if hasattr(builtins, name)] 306 | ) 307 | @pytest.mark.parametrize("library", wrapped_libraries) 308 | def test_builtins_collision(library, name): 309 | """Test that xp.bool is not accidentally builtins.bool, etc.""" 310 | xp = pytest.importorskip(f"array_api_compat.{library}") 311 | assert getattr(xp, name) is not getattr(builtins, name) 312 | -------------------------------------------------------------------------------- /tests/test_array_namespace.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import sys 3 | import warnings 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | import array_api_compat 9 | from array_api_compat import array_namespace 10 | 11 | from ._helpers import all_libraries, wrapped_libraries, xfail 12 | 13 | 14 | @pytest.mark.parametrize("use_compat", [True, False, None]) 15 | @pytest.mark.parametrize( 16 | "api_version", [None, "2021.12", "2022.12", "2023.12", "2024.12"] 17 | ) 18 | @pytest.mark.parametrize("library", all_libraries) 19 | def test_array_namespace(request, library, api_version, use_compat): 20 | xp = pytest.importorskip(library) 21 | 22 | array = xp.asarray([1.0, 2.0, 3.0]) 23 | if use_compat and library not in wrapped_libraries: 24 | pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat)) 25 | return 26 | if (library == "sparse" and api_version in ("2023.12", "2024.12")) or ( 27 | library == "jax.numpy" and api_version in ("2021.12", "2022.12", "2023.12") 28 | ): 29 | xfail(request, "Unsupported API version") 30 | 31 | with warnings.catch_warnings(): 32 | warnings.simplefilter('ignore', UserWarning) 33 | namespace = array_namespace(array, api_version=api_version, use_compat=use_compat) 34 | 35 | if use_compat is False or use_compat is None and library not in wrapped_libraries: 36 | if library == "jax.numpy" and not hasattr(xp, "__array_api_version__"): 37 | # Backwards compatibility for JAX <0.4.32 38 | import jax.experimental.array_api 39 | assert namespace == jax.experimental.array_api 40 | else: 41 | assert namespace == xp 42 | elif library == "dask.array": 43 | assert namespace == array_api_compat.dask.array 44 | else: 45 | assert namespace == getattr(array_api_compat, library) 46 | 47 | if library == "numpy": 48 | # check that the same namespace is returned for NumPy scalars 49 | with warnings.catch_warnings(): 50 | warnings.simplefilter('ignore', UserWarning) 51 | 52 | scalar_namespace = array_namespace( 53 | xp.float64(0.0), api_version=api_version, use_compat=use_compat 54 | ) 55 | assert scalar_namespace == namespace 56 | 57 | 58 | def test_jax_backwards_compat(): 59 | """On JAX <0.4.32, test that array_namespace works even if 60 | jax.experimental.array_api has not been imported yet. 
61 | """ 62 | pytest.importorskip("jax") 63 | code = """\ 64 | import sys 65 | import jax.numpy 66 | import array_api_compat 67 | 68 | array = jax.numpy.asarray([1.0, 2.0, 3.0]) 69 | assert 'jax.experimental.array_api' not in sys.modules 70 | namespace = array_api_compat.array_namespace(array) 71 | 72 | if hasattr(jax.numpy, '__array_api_version__'): 73 | assert namespace == jax.numpy 74 | else: 75 | import jax.experimental.array_api 76 | assert namespace == jax.experimental.array_api 77 | """ 78 | subprocess.check_call([sys.executable, "-c", code]) 79 | 80 | 81 | def test_jax_zero_gradient(): 82 | jax = pytest.importorskip("jax") 83 | jx = jax.numpy.arange(4) 84 | jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx) 85 | assert array_namespace(jax_zero) is array_namespace(jx) 86 | 87 | 88 | def test_array_namespace_errors(): 89 | pytest.raises(TypeError, lambda: array_namespace([1])) 90 | pytest.raises(TypeError, lambda: array_namespace()) 91 | 92 | x = np.asarray([1, 2]) 93 | pytest.raises(TypeError, lambda: array_namespace((x, x))) 94 | pytest.raises(TypeError, lambda: array_namespace(x, (x, x))) 95 | 96 | 97 | @pytest.mark.parametrize("library", all_libraries) 98 | def test_array_namespace_many_args(library): 99 | xp = pytest.importorskip(library) 100 | a = xp.asarray(1) 101 | b = xp.asarray(2) 102 | assert array_namespace(a, b) is array_namespace(a) 103 | 104 | 105 | def test_array_namespace_mismatch(): 106 | xp = pytest.importorskip("array_api_strict") 107 | with pytest.raises(TypeError, match="Multiple namespaces"): 108 | array_namespace(np.asarray(1), xp.asarray(1)) 109 | 110 | 111 | def test_get_namespace(): 112 | # Backwards compatible wrapper 113 | assert array_api_compat.get_namespace is array_namespace 114 | 115 | 116 | @pytest.mark.parametrize("library", all_libraries) 117 | def test_python_scalars(library): 118 | xp = pytest.importorskip(library) 119 | a = xp.asarray([1, 2]) 120 | xp = array_namespace(a) 121 | 122 | pytest.raises(TypeError, lambda: array_namespace(1)) 123 | pytest.raises(TypeError, lambda: array_namespace(1.0)) 124 | pytest.raises(TypeError, lambda: array_namespace(1j)) 125 | pytest.raises(TypeError, lambda: array_namespace(True)) 126 | pytest.raises(TypeError, lambda: array_namespace(None)) 127 | 128 | assert array_namespace(a, 1) is xp 129 | assert array_namespace(a, 1.0) is xp 130 | assert array_namespace(a, 1j) is xp 131 | assert array_namespace(a, True) is xp 132 | assert array_namespace(a, None) is xp 133 | -------------------------------------------------------------------------------- /tests/test_cupy.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from array_api_compat import device, to_device 3 | 4 | xp = pytest.importorskip("array_api_compat.cupy") 5 | from cupy.cuda import Stream 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "make_stream", 10 | [ 11 | lambda: Stream(), 12 | lambda: Stream(non_blocking=True), 13 | lambda: Stream(null=True), 14 | lambda: Stream(ptds=True), 15 | ], 16 | ) 17 | def test_to_device_with_stream(make_stream): 18 | devices = xp.__array_namespace_info__().devices() 19 | 20 | a = xp.asarray([1, 2, 3]) 21 | for dev in devices: 22 | # Streams are device-specific and must be created within 23 | # the context of the device... 24 | with dev: 25 | stream = make_stream() 26 | # ... however, to_device() does not need to be inside the 27 | # device context. 
28 | b = to_device(a, dev, stream=stream) 29 | assert device(b) == dev 30 | 31 | 32 | def test_to_device_with_dlpack_stream(): 33 | devices = xp.__array_namespace_info__().devices() 34 | 35 | a = xp.asarray([1, 2, 3]) 36 | for dev in devices: 37 | # Streams are device-specific and must be created within 38 | # the context of the device... 39 | with dev: 40 | s1 = Stream() 41 | 42 | # ... however, to_device() does not need to be inside the 43 | # device context. 44 | b = to_device(a, dev, stream=s1.ptr) 45 | assert device(b) == dev 46 | -------------------------------------------------------------------------------- /tests/test_dask.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | try: 7 | import dask 8 | import dask.array as da 9 | except ImportError: 10 | pytestmark = pytest.skip(allow_module_level=True, reason="dask not found") 11 | 12 | from array_api_compat import array_namespace 13 | 14 | 15 | @pytest.fixture 16 | def xp(): 17 | """Fixture returning the wrapped dask namespace""" 18 | return array_namespace(da.empty(0)) 19 | 20 | 21 | @contextmanager 22 | def assert_no_compute(): 23 | """ 24 | Context manager that raises if at any point inside it anything calls compute() 25 | or persist(), e.g. as it can be triggered implicitly by __bool__, __array__, etc. 26 | """ 27 | 28 | def get(dsk, *args, **kwargs): 29 | raise AssertionError("Called compute() or persist()") 30 | 31 | with dask.config.set(scheduler=get): 32 | yield 33 | 34 | 35 | def test_assert_no_compute(): 36 | """Test the assert_no_compute context manager""" 37 | a = da.asarray(True) 38 | with pytest.raises(AssertionError, match="Called compute"): 39 | with assert_no_compute(): 40 | bool(a) 41 | 42 | # Exiting the context manager restores the original scheduler 43 | assert bool(a) is True 44 | 45 | 46 | # Test no_compute for functions that use generic _aliases with xp=np 47 | 48 | 49 | def test_unary_ops_no_compute(xp): 50 | with assert_no_compute(): 51 | a = xp.asarray([1.5, -1.5]) 52 | xp.ceil(a) 53 | xp.floor(a) 54 | xp.trunc(a) 55 | xp.sign(a) 56 | 57 | 58 | def test_matmul_tensordot_no_compute(xp): 59 | A = da.ones((4, 4), chunks=2) 60 | B = da.zeros((4, 4), chunks=2) 61 | with assert_no_compute(): 62 | xp.matmul(A, B) 63 | xp.tensordot(A, B) 64 | 65 | 66 | # Test no_compute for functions that are fully bespoke for dask 67 | 68 | 69 | def test_asarray_no_compute(xp): 70 | with assert_no_compute(): 71 | a = xp.arange(10) 72 | xp.asarray(a) 73 | xp.asarray(a, dtype=np.int16) 74 | xp.asarray(a, dtype=a.dtype) 75 | xp.asarray(a, copy=True) 76 | xp.asarray(a, copy=True, dtype=np.int16) 77 | xp.asarray(a, copy=True, dtype=a.dtype) 78 | xp.asarray(a, copy=False) 79 | xp.asarray(a, copy=False, dtype=a.dtype) 80 | 81 | 82 | @pytest.mark.parametrize("copy", [True, False]) 83 | def test_astype_no_compute(xp, copy): 84 | with assert_no_compute(): 85 | a = xp.arange(10) 86 | xp.astype(a, np.int16, copy=copy) 87 | xp.astype(a, a.dtype, copy=copy) 88 | 89 | 90 | def test_clip_no_compute(xp): 91 | with assert_no_compute(): 92 | a = xp.arange(10) 93 | xp.clip(a) 94 | xp.clip(a, 1) 95 | xp.clip(a, 1, 8) 96 | 97 | 98 | @pytest.mark.parametrize("chunks", (5, 10)) 99 | def test_sort_argsort_nocompute(xp, chunks): 100 | with assert_no_compute(): 101 | a = xp.arange(10, chunks=chunks) 102 | xp.sort(a) 103 | xp.argsort(a) 104 | 105 | 106 | def test_generators_are_lazy(xp): 107 | """ 108 | Test that generator 
functions are fully lazy, e.g. that 109 | da.ones(n) is not implemented as da.asarray(np.ones(n)) 110 | """ 111 | size = 100_000_000_000 # 800 GB 112 | chunks = size // 10 # 10x 80 GB chunks 113 | 114 | with assert_no_compute(): 115 | xp.zeros(size, chunks=chunks) 116 | xp.ones(size, chunks=chunks) 117 | xp.empty(size, chunks=chunks) 118 | xp.full(size, fill_value=123, chunks=chunks) 119 | a = xp.arange(size, chunks=chunks) 120 | xp.zeros_like(a) 121 | xp.ones_like(a) 122 | xp.empty_like(a) 123 | xp.full_like(a, fill_value=123) 124 | 125 | 126 | @pytest.mark.parametrize("axis", [0, 1]) 127 | @pytest.mark.parametrize("func", ["sort", "argsort"]) 128 | def test_sort_argsort_chunks(xp, func, axis): 129 | """Test that sort and argsort are functionally correct when 130 | the array is chunked along the sort axis, e.g. the sort is 131 | not just local to each chunk. 132 | """ 133 | a = da.random.random((10, 10), chunks=(5, 5)) 134 | actual = getattr(xp, func)(a, axis=axis) 135 | expect = getattr(np, func)(a.compute(), axis=axis) 136 | np.testing.assert_array_equal(actual, expect) 137 | 138 | 139 | @pytest.mark.parametrize( 140 | "shape,chunks", 141 | [ 142 | # 3 GiB; 128 MiB per chunk; must rechunk before sorting. 143 | # Sort chunks can be 128 MiB each; no need for final rechunk. 144 | ((20_000, 20_000), "auto"), 145 | # 3 GiB; 128 MiB per chunk; must rechunk before sorting. 146 | # Must sort on two 1.5 GiB chunks; benefits from final rechunk. 147 | ((2, 2**30 * 3 // 16), "auto"), 148 | # 3 GiB; 1.5 GiB per chunk; no need to rechunk before sorting. 149 | # Surely the user must know what they're doing, so don't 150 | # perform the final rechunk. 151 | ((2, 2**30 * 3 // 16), (1, -1)), 152 | ], 153 | ) 154 | @pytest.mark.parametrize("func", ["sort", "argsort"]) 155 | def test_sort_argsort_chunk_size(xp, func, shape, chunks): 156 | """ 157 | Test that sort and argsort produce reasonably-sized chunks 158 | in the output array, even if they had to go through a singular 159 | huge one to perform the operation. 
160 | """ 161 | a = da.random.random(shape, chunks=chunks) 162 | b = getattr(xp, func)(a) 163 | max_chunk_size = max(b.chunks[0]) * max(b.chunks[1]) * b.dtype.itemsize 164 | assert ( 165 | max_chunk_size <= 128 * 1024 * 1024 # 128 MiB 166 | or b.chunks == a.chunks 167 | ) 168 | 169 | 170 | @pytest.mark.parametrize("func", ["sort", "argsort"]) 171 | def test_sort_argsort_meta(xp, func): 172 | """Test meta-namespace other than numpy""" 173 | mxp = pytest.importorskip("array_api_strict") 174 | typ = type(mxp.asarray(0)) 175 | a = da.random.random(10) 176 | b = a.map_blocks(mxp.asarray) 177 | assert isinstance(b._meta, typ) 178 | c = getattr(xp, func)(b) 179 | assert isinstance(c._meta, typ) 180 | d = c.compute() 181 | # Note: np.sort(array_api_strict.asarray(0)) would return a numpy array 182 | assert isinstance(d, typ) 183 | np.testing.assert_array_equal(d, getattr(np, func)(a.compute())) 184 | -------------------------------------------------------------------------------- /tests/test_isdtype.py: -------------------------------------------------------------------------------- 1 | """ 2 | isdtype is not yet tested in the test suite, and it should extend properly to 3 | non-spec dtypes 4 | """ 5 | 6 | import pytest 7 | 8 | from ._helpers import import_, wrapped_libraries 9 | 10 | # Check the known dtypes by their string names 11 | 12 | def _spec_dtypes(library): 13 | if library == 'torch': 14 | # torch does not have unsigned integer dtypes 15 | return { 16 | 'bool', 17 | 'complex64', 18 | 'complex128', 19 | 'uint8', 20 | 'int8', 21 | 'int16', 22 | 'int32', 23 | 'int64', 24 | 'float32', 25 | 'float64', 26 | } 27 | else: 28 | return { 29 | 'bool', 30 | 'complex64', 31 | 'complex128', 32 | 'float32', 33 | 'float64', 34 | 'int16', 35 | 'int32', 36 | 'int64', 37 | 'int8', 38 | 'uint16', 39 | 'uint32', 40 | 'uint64', 41 | 'uint8', 42 | } 43 | 44 | dtype_categories = { 45 | 'bool': lambda d: d == 'bool', 46 | 'signed integer': lambda d: d.startswith('int'), 47 | 'unsigned integer': lambda d: d.startswith('uint'), 48 | 'integral': lambda d: dtype_categories['signed integer'](d) or 49 | dtype_categories['unsigned integer'](d), 50 | 'real floating': lambda d: 'float' in d, 51 | 'complex floating': lambda d: d.startswith('complex'), 52 | 'numeric': lambda d: dtype_categories['integral'](d) or 53 | dtype_categories['real floating'](d) or 54 | dtype_categories['complex floating'](d), 55 | } 56 | 57 | def isdtype_(dtype_, kind): 58 | # Check a dtype_ string against kind. Note that 'bool' technically has two 59 | # meanings here but they are both the same. 
60 | if kind in dtype_categories: 61 | res = dtype_categories[kind](dtype_) 62 | else: 63 | res = dtype_ == kind 64 | assert type(res) is bool # noqa: E721 65 | return res 66 | 67 | @pytest.mark.parametrize("library", wrapped_libraries) 68 | def test_isdtype_spec_dtypes(library): 69 | xp = import_(library, wrapper=True) 70 | 71 | isdtype = xp.isdtype 72 | 73 | for dtype_ in _spec_dtypes(library): 74 | for dtype2_ in _spec_dtypes(library): 75 | dtype = getattr(xp, dtype_) 76 | dtype2 = getattr(xp, dtype2_) 77 | res = isdtype_(dtype_, dtype2_) 78 | assert isdtype(dtype, dtype2) is res, (dtype_, dtype2_) 79 | 80 | for cat in dtype_categories: 81 | res = isdtype_(dtype_, cat) 82 | assert isdtype(dtype, cat) == res, (dtype_, cat) 83 | 84 | # Basic tuple testing (the array-api testsuite will be more complete here) 85 | for kind1_ in [*_spec_dtypes(library), *dtype_categories]: 86 | for kind2_ in [*_spec_dtypes(library), *dtype_categories]: 87 | kind1 = kind1_ if kind1_ in dtype_categories else getattr(xp, kind1_) 88 | kind2 = kind2_ if kind2_ in dtype_categories else getattr(xp, kind2_) 89 | kind = (kind1, kind2) 90 | 91 | res = isdtype_(dtype_, kind1_) or isdtype_(dtype_, kind2_) 92 | assert isdtype(dtype, kind) == res, (dtype_, (kind1_, kind2_)) 93 | 94 | additional_dtypes = [ 95 | 'float16', 96 | 'float128', 97 | 'complex256', 98 | 'bfloat16', 99 | ] 100 | 101 | @pytest.mark.parametrize("library", wrapped_libraries) 102 | @pytest.mark.parametrize("dtype_", additional_dtypes) 103 | def test_isdtype_additional_dtypes(library, dtype_): 104 | xp = import_(library, wrapper=True) 105 | 106 | isdtype = xp.isdtype 107 | 108 | if not hasattr(xp, dtype_): 109 | return 110 | # pytest.skip(f"{library} doesn't have dtype {dtype_}") 111 | 112 | dtype = getattr(xp, dtype_) 113 | for cat in dtype_categories: 114 | res = isdtype_(dtype_, cat) 115 | assert isdtype(dtype, cat) == res, (dtype_, cat) 116 | -------------------------------------------------------------------------------- /tests/test_jax.py: -------------------------------------------------------------------------------- 1 | from numpy.testing import assert_equal 2 | import pytest 3 | 4 | from array_api_compat import device, to_device 5 | 6 | try: 7 | import jax 8 | import jax.numpy as jnp 9 | except ImportError: 10 | pytestmark = pytest.skip(allow_module_level=True, reason="jax not found") 11 | 12 | HAS_JAX_0_4_31 = jax.__version__ >= "0.4.31" 13 | 14 | 15 | @pytest.mark.parametrize( 16 | "func", 17 | [ 18 | lambda x: jnp.zeros(1, device=device(x)), 19 | lambda x: jnp.zeros_like(jnp.ones(1, device=device(x))), 20 | lambda x: jnp.zeros_like(jnp.empty(1, device=device(x))), 21 | lambda x: jnp.full(1, fill_value=0, device=device(x)), 22 | pytest.param( 23 | lambda x: jnp.asarray([0], device=device(x)), 24 | marks=pytest.mark.skipif( 25 | not HAS_JAX_0_4_31, reason="asarray() has no device= parameter" 26 | ), 27 | ), 28 | lambda x: to_device(jnp.zeros(1), device(x)), 29 | ] 30 | ) 31 | def test_device_jit(func): 32 | # Test work around to https://github.com/jax-ml/jax/issues/26000 33 | # Also test missing to_device() method in JAX < 0.4.31 34 | # when inside jax.jit, even after importing jax.experimental.array_api 35 | 36 | x = jnp.ones(1) 37 | assert_equal(func(x), jnp.asarray([0])) 38 | assert_equal(jax.jit(func)(x), jnp.asarray([0])) 39 | -------------------------------------------------------------------------------- /tests/test_no_dependencies.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test 
that array_api_compat has no "hard" dependencies. 3 | 4 | Libraries like NumPy should only be imported if a numpy array is passed to 5 | array_namespace or if array_api_compat.numpy is explicitly imported. 6 | 7 | We have to test this in a subprocess because these libraries have already been 8 | imported from the other tests. 9 | """ 10 | 11 | import sys 12 | import subprocess 13 | 14 | import pytest 15 | 16 | class Array: 17 | # Dummy array namespace that doesn't depend on any array library 18 | def __array_namespace__(self, api_version=None): 19 | class Namespace: 20 | pass 21 | return Namespace() 22 | 23 | def _test_dependency(mod): 24 | assert mod not in sys.modules 25 | 26 | # Run various functions that shouldn't depend on mod and check that they 27 | # don't import it. 28 | 29 | import array_api_compat 30 | assert mod not in sys.modules 31 | 32 | a = Array() 33 | 34 | # array-api-strict is an example of an array API library that isn't 35 | # wrapped by array-api-compat. 36 | if "strict" not in mod and mod != "sparse": 37 | is_mod_array = getattr(array_api_compat, f"is_{mod.split('.')[0]}_array") 38 | assert not is_mod_array(a) 39 | assert mod not in sys.modules 40 | 41 | is_array_api_obj = getattr(array_api_compat, "is_array_api_obj") 42 | assert is_array_api_obj(a) 43 | assert mod not in sys.modules 44 | 45 | array_namespace = getattr(array_api_compat, "array_namespace") 46 | array_namespace(Array()) 47 | assert mod not in sys.modules 48 | 49 | # TODO: Test that wrapper for library X doesn't depend on wrappers for library 50 | # Y (except most array libraries actually do themselves depend on numpy). 51 | 52 | @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", 53 | "jax.numpy", "sparse", "array_api_strict"]) 54 | def test_numpy_dependency(library): 55 | # This import is here because it imports numpy 56 | from ._helpers import import_ 57 | 58 | # This unfortunately won't go through any of the pytest machinery. We 59 | # reraise the exception as an AssertionError so that pytest will show it 60 | # in a semi-reasonable way 61 | 62 | # Import (in this process) to make sure 'library' is actually installed and 63 | # so that cupy can be skipped. 64 | import_(library) 65 | 66 | try: 67 | subprocess.run([sys.executable, '-c', f'''\ 68 | from tests.test_no_dependencies import _test_dependency 69 | 70 | _test_dependency({library!r})'''], check=True, capture_output=True, encoding='utf-8') 71 | except subprocess.CalledProcessError as e: 72 | print(e.stdout, end='') 73 | raise AssertionError(e.stderr) 74 | -------------------------------------------------------------------------------- /tests/test_torch.py: -------------------------------------------------------------------------------- 1 | """Test "unspecified" behavior which we cannot easily test in the Array API test suite. 
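These currently cover result_type() promotion involving Python scalars, dtypes, and tensors, plus the default indexing mode of meshgrid().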
2 | """ 3 | import itertools 4 | 5 | import pytest 6 | 7 | try: 8 | import torch 9 | except ImportError: 10 | pytestmark = pytest.skip(allow_module_level=True, reason="pytorch not found") 11 | 12 | from array_api_compat import torch as xp 13 | 14 | 15 | class TestResultType: 16 | def test_empty(self): 17 | with pytest.raises(ValueError): 18 | xp.result_type() 19 | 20 | def test_one_arg(self): 21 | for x in [1, 1.0, 1j, '...', None]: 22 | with pytest.raises((ValueError, AttributeError)): 23 | xp.result_type(x) 24 | 25 | for x in [xp.float32, xp.int64, torch.complex64]: 26 | assert xp.result_type(x) == x 27 | 28 | for x in [xp.asarray(True, dtype=xp.bool), xp.asarray(1, dtype=xp.complex64)]: 29 | assert xp.result_type(x) == x.dtype 30 | 31 | def test_two_args(self): 32 | # Only include here things "unspecified" in the spec 33 | 34 | # scalar, tensor or tensor,tensor 35 | for x, y in [ 36 | (1., 1j), 37 | (1j, xp.arange(3)), 38 | (True, xp.asarray(3.)), 39 | (xp.ones(3) == 1, 1j*xp.ones(3)), 40 | ]: 41 | assert xp.result_type(x, y) == torch.result_type(x, y) 42 | 43 | # dtype, scalar 44 | for x, y in [ 45 | (1j, xp.int64), 46 | (True, xp.float64), 47 | ]: 48 | assert xp.result_type(x, y) == torch.result_type(x, xp.empty([], dtype=y)) 49 | 50 | # dtype, dtype 51 | for x, y in [ 52 | (xp.bool, xp.complex64) 53 | ]: 54 | xt, yt = xp.empty([], dtype=x), xp.empty([], dtype=y) 55 | assert xp.result_type(x, y) == torch.result_type(xt, yt) 56 | 57 | def test_multi_arg(self): 58 | torch.set_default_dtype(torch.float32) 59 | 60 | args = [1., 5, 3, torch.asarray([3], dtype=torch.float16), 5, 6, 1.] 61 | assert xp.result_type(*args) == torch.float16 62 | 63 | args = [1, 2, 3j, xp.arange(3, dtype=xp.float32), 4, 5, 6] 64 | assert xp.result_type(*args) == xp.complex64 65 | 66 | args = [1, 2, 3j, xp.float64, 4, 5, 6] 67 | assert xp.result_type(*args) == xp.complex128 68 | 69 | args = [1, 2, 3j, xp.float64, 4, xp.asarray(3, dtype=xp.int16), 5, 6, False] 70 | assert xp.result_type(*args) == xp.complex128 71 | 72 | i64 = xp.ones(1, dtype=xp.int64) 73 | f16 = xp.ones(1, dtype=xp.float16) 74 | for i in itertools.permutations([i64, f16, 1.0, 1.0]): 75 | assert xp.result_type(*i) == xp.float16, f"{i}" 76 | 77 | with pytest.raises(ValueError): 78 | xp.result_type(1, 2, 3, 4) 79 | 80 | 81 | @pytest.mark.parametrize("default_dt", ['float32', 'float64']) 82 | @pytest.mark.parametrize("dtype_a", 83 | (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128) 84 | ) 85 | @pytest.mark.parametrize("dtype_b", 86 | (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128) 87 | ) 88 | def test_gh_273(self, default_dt, dtype_a, dtype_b): 89 | # Regression test for https://github.com/data-apis/array-api-compat/issues/273 90 | 91 | try: 92 | prev_default = torch.get_default_dtype() 93 | default_dtype = getattr(torch, default_dt) 94 | torch.set_default_dtype(default_dtype) 95 | 96 | a = xp.asarray([2, 1], dtype=dtype_a) 97 | b = xp.asarray([1, -1], dtype=dtype_b) 98 | dtype_1 = xp.result_type(a, b, 1.0) 99 | dtype_2 = xp.result_type(b, a, 1.0) 100 | assert dtype_1 == dtype_2 101 | finally: 102 | torch.set_default_dtype(prev_default) 103 | 104 | 105 | def test_meshgrid(): 106 | """Verify that array_api_compat.torch.meshgrid defaults to indexing='xy'.""" 107 | 108 | x, y = xp.asarray([1, 2]), xp.asarray([4]) 109 | 110 | X, Y = xp.meshgrid(x, y) 111 | 112 | # output of torch.meshgrid(x, y, indexing='xy') -- indexing='ij' is different 113 | X_xy, Y_xy = xp.asarray([[1, 2]]), xp.asarray([[4, 4]]) 114 | 
115 | assert X.shape == X_xy.shape 116 | assert xp.all(X == X_xy) 117 | 118 | assert Y.shape == Y_xy.shape 119 | assert xp.all(Y == Y_xy) 120 | -------------------------------------------------------------------------------- /tests/test_vendoring.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | def test_vendoring_numpy(): 5 | from vendor_test import uses_numpy 6 | 7 | uses_numpy._test_numpy() 8 | 9 | 10 | def test_vendoring_cupy(): 11 | pytest.importorskip("cupy") 12 | 13 | from vendor_test import uses_cupy 14 | 15 | uses_cupy._test_cupy() 16 | 17 | 18 | def test_vendoring_torch(): 19 | pytest.importorskip("torch") 20 | from vendor_test import uses_torch 21 | 22 | uses_torch._test_torch() 23 | 24 | 25 | def test_vendoring_dask(): 26 | pytest.importorskip("dask") 27 | from vendor_test import uses_dask 28 | uses_dask._test_dask() 29 | -------------------------------------------------------------------------------- /torch-skips.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/data-apis/array-api-compat/cddc9ef8a19b453b09884987ca6a0626408a1478/torch-skips.txt -------------------------------------------------------------------------------- /vendor_test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/data-apis/array-api-compat/cddc9ef8a19b453b09884987ca6a0626408a1478/vendor_test/__init__.py -------------------------------------------------------------------------------- /vendor_test/uses_cupy.py: -------------------------------------------------------------------------------- 1 | # Basic test that vendoring works 2 | 3 | from .vendored._compat import ( 4 | cupy as cp_compat, 5 | is_cupy_array, 6 | is_cupy_namespace, 7 | ) 8 | 9 | import cupy as cp 10 | 11 | def _test_cupy(): 12 | a = cp_compat.asarray([1., 2., 3.]) 13 | b = cp_compat.arange(3, dtype=cp_compat.float32) 14 | 15 | # cp.pow does not exist. Update this to use something else if it is added 16 | res = cp_compat.pow(a, b) 17 | assert res.dtype == cp_compat.float64 == cp.float64 18 | assert isinstance(a, cp.ndarray) 19 | assert isinstance(b, cp.ndarray) 20 | assert isinstance(res, cp.ndarray) 21 | 22 | cp.testing.assert_allclose(res, [1., 2., 9.]) 23 | 24 | assert is_cupy_array(res) 25 | assert is_cupy_namespace(cp) and is_cupy_namespace(cp_compat) 26 | -------------------------------------------------------------------------------- /vendor_test/uses_dask.py: -------------------------------------------------------------------------------- 1 | # Basic test that vendoring works 2 | 3 | from .vendored._compat.dask import array as dask_compat 4 | from .vendored._compat import is_dask_array, is_dask_namespace 5 | 6 | import dask.array as da 7 | import numpy as np 8 | 9 | def _test_dask(): 10 | a = dask_compat.asarray([1., 2., 3.]) 11 | b = dask_compat.arange(3, dtype=dask_compat.float32) 12 | 13 | # np.pow does not exist. 
Update this to use something else if it is added 14 | res = dask_compat.pow(a, b) 15 | assert res.dtype == dask_compat.float64 == np.float64 16 | assert isinstance(a, da.Array) 17 | assert isinstance(b, da.Array) 18 | assert isinstance(res, da.Array) 19 | 20 | np.testing.assert_allclose(res, [1., 2., 9.]) 21 | 22 | assert is_dask_array(res) 23 | assert is_dask_namespace(da) and is_dask_namespace(dask_compat) 24 | -------------------------------------------------------------------------------- /vendor_test/uses_numpy.py: -------------------------------------------------------------------------------- 1 | # Basic test that vendoring works 2 | 3 | from .vendored._compat import ( 4 | is_numpy_array, 5 | is_numpy_namespace, 6 | numpy as np_compat, 7 | ) 8 | 9 | 10 | import numpy as np 11 | 12 | def _test_numpy(): 13 | a = np_compat.asarray([1., 2., 3.]) 14 | b = np_compat.arange(3, dtype=np_compat.float32) 15 | 16 | # np.pow does not exist. Update this to use something else if it is added 17 | res = np_compat.pow(a, b) 18 | assert res.dtype == np_compat.float64 == np.float64 19 | assert isinstance(a, np.ndarray) 20 | assert isinstance(b, np.ndarray) 21 | assert isinstance(res, np.ndarray) 22 | 23 | np.testing.assert_allclose(res, [1., 2., 9.]) 24 | 25 | assert is_numpy_array(res) 26 | assert is_numpy_namespace(np) and is_numpy_namespace(np_compat) 27 | -------------------------------------------------------------------------------- /vendor_test/uses_torch.py: -------------------------------------------------------------------------------- 1 | # Basic test that vendoring works 2 | 3 | from .vendored._compat import ( 4 | is_torch_array, 5 | is_torch_namespace, 6 | torch as torch_compat, 7 | ) 8 | 9 | import torch 10 | 11 | def _test_torch(): 12 | a = torch_compat.asarray([1., 2., 3.]) 13 | b = torch_compat.arange(3, dtype=torch_compat.float64) 14 | assert a.dtype == torch_compat.float32 == torch.float32 15 | assert b.dtype == torch_compat.float64 == torch.float64 16 | 17 | # torch.expand_dims does not exist. Update this to use something else if it is added 18 | res = torch_compat.expand_dims(a, axis=0) 19 | assert res.dtype == torch_compat.float32 == torch.float32 20 | assert res.shape == (1, 3) 21 | assert isinstance(res.shape, torch.Size) 22 | assert isinstance(a, torch.Tensor) 23 | assert isinstance(b, torch.Tensor) 24 | assert isinstance(res, torch.Tensor) 25 | 26 | torch.testing.assert_close(res, torch.as_tensor([[1., 2., 3.]])) 27 | 28 | assert is_torch_array(res) 29 | assert is_torch_namespace(torch) and is_torch_namespace(torch_compat) 30 | 31 | -------------------------------------------------------------------------------- /vendor_test/vendored/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/data-apis/array-api-compat/cddc9ef8a19b453b09884987ca6a0626408a1478/vendor_test/vendored/__init__.py -------------------------------------------------------------------------------- /vendor_test/vendored/_compat: -------------------------------------------------------------------------------- 1 | ../../array_api_compat/ --------------------------------------------------------------------------------