├── .flake8
├── .gitattributes
├── .github
│   └── workflows
│       ├── lint.yml
│       └── test.yml
├── .gitignore
├── .gitmodules
├── .pre-commit-config.yaml
├── LICENSE
├── MANIFEST.in
├── README.md
├── _config.yml
├── array-api-strict-skips.txt
├── array_api_tests
│   ├── __init__.py
│   ├── _array_module.py
│   ├── _version.py
│   ├── algos.py
│   ├── array_helpers.py
│   ├── dtype_helpers.py
│   ├── hypothesis_helpers.py
│   ├── pytest_helpers.py
│   ├── shape_helpers.py
│   ├── stubs.py
│   ├── test_array_object.py
│   ├── test_constants.py
│   ├── test_creation_functions.py
│   ├── test_data_type_functions.py
│   ├── test_fft.py
│   ├── test_has_names.py
│   ├── test_indexing_functions.py
│   ├── test_inspection_functions.py
│   ├── test_linalg.py
│   ├── test_manipulation_functions.py
│   ├── test_operators_and_elementwise_functions.py
│   ├── test_searching_functions.py
│   ├── test_set_functions.py
│   ├── test_signatures.py
│   ├── test_sorting_functions.py
│   ├── test_special_cases.py
│   ├── test_statistical_functions.py
│   ├── test_utility_functions.py
│   └── typing.py
├── conftest.py
├── meta_tests
│   ├── README.md
│   ├── __init__.py
│   ├── test_array_helpers.py
│   ├── test_broadcasting.py
│   ├── test_equality_mapping.py
│   ├── test_hypothesis_helpers.py
│   ├── test_linalg.py
│   ├── test_partial_adopters.py
│   ├── test_pytest_helpers.py
│   ├── test_signatures.py
│   ├── test_special_cases.py
│   └── test_utils.py
├── pytest.ini
├── reporting.py
├── requirements.txt
└── setup.cfg
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | select = F
3 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | array_api_tests/_version.py export-subst
2 |
--------------------------------------------------------------------------------
/.github/workflows/lint.yml:
--------------------------------------------------------------------------------
1 | name: Linting
2 |
3 | on: [push, pull_request]
4 |
5 | jobs:
6 | build:
7 |
8 | runs-on: ubuntu-latest
9 | steps:
10 | - uses: actions/checkout@v4
11 | - name: Set up Python ${{ matrix.python-version }}
12 | uses: actions/setup-python@v5
13 | with:
14 | python-version: "3.10"
15 | - name: Run pre-commit hook
16 | uses: pre-commit/action@v3.0.1
17 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Test Array API Strict
2 |
3 | on: [push, pull_request]
4 |
5 | jobs:
6 | build:
7 |
8 | runs-on: ubuntu-latest
9 | strategy:
10 | matrix:
11 | python-version: ["3.10", "3.11"]
12 |
13 | steps:
14 | - name: Checkout array-api-tests
15 | uses: actions/checkout@v1
16 | with:
17 | submodules: 'true'
18 | - name: Set up Python ${{ matrix.python-version }}
19 | uses: actions/setup-python@v1
20 | with:
21 | python-version: ${{ matrix.python-version }}
22 | - name: Install dependencies
23 | run: |
24 | python -m pip install --upgrade pip
25 | python -m pip install array-api-strict
26 | python -m pip install -r requirements.txt
27 | - name: Run the test suite
28 | env:
29 | ARRAY_API_TESTS_MODULE: array_api_strict
30 | ARRAY_API_STRICT_API_VERSION: 2024.12
31 | run: |
32 | pytest -v -rxXfE --skips-file array-api-strict-skips.txt array_api_tests/
33 | # We also have internal tests that isn't really necessary for adopters
34 | pytest -v -rxXfE meta_tests/
35 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # IDE
121 | .idea/
122 | .vscode/
123 |
124 | # mkdocs documentation
125 | /site
126 |
127 | # mypy
128 | .mypy_cache/
129 | .dmypy.json
130 | dmypy.json
131 |
132 | # Pyre type checker
133 | .pyre/
134 |
135 | # pytest-json-report
136 | .report.json
137 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "array_api_tests/array-api"]
2 | path = array-api
3 | url = https://github.com/data-apis/array-api/
4 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pycqa/flake8
3 | rev: '4.0.1'
4 | hooks:
5 | - id: flake8
6 | args: [--select, F]
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Consortium for Python Data API Standards contributors
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include versioneer.py
2 | include array_api_tests/_version.py
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Test Suite for Array API Compliance
2 |
3 | This is the test suite for array libraries adopting the [Python Array API
4 | standard](https://data-apis.org/array-api/latest).
5 |
6 | Keeping full coverage of the spec is an ongoing priority as the Array API evolves.
7 | Feedback and contributions are welcome!
8 |
9 | ## Quickstart
10 |
11 | ### Setup
12 |
13 | Currently we pin the Array API specification repo [`array-api`](https://github.com/data-apis/array-api/)
14 | as a git submodule. This might change in the future to better support vendoring
15 | use cases (see [#107](https://github.com/data-apis/array-api-tests/issues/107)),
16 | but for now be sure submodules are pulled too, e.g.
17 |
18 | ```bash
19 | $ git submodule update --init
20 | ```
21 |
22 | To run the tests, install the testing dependencies.
23 |
24 | ```bash
25 | $ pip install -r requirements.txt
26 | ```
27 |
28 | Ensure you have the array library that you want to test installed.
29 |
30 | ### Specifying the array module
31 |
32 | You need to specify the array library to test. It can be specified via the
33 | `ARRAY_API_TESTS_MODULE` environment variable, e.g.
34 |
35 | ```bash
36 | $ export ARRAY_API_TESTS_MODULE=array_api_strict
37 | ```
38 |
39 | To specify a runtime-defined module, define `xp` using the `exec('...')` syntax:
40 |
41 | ```bash
42 | $ export ARRAY_API_TESTS_MODULE="exec('import quantity_array, numpy; xp = quantity_array.quantity_namespace(numpy)')"
43 | ```
44 |
45 | Alternately, import/define the `xp` variable in `array_api_tests/__init__.py`.
46 |
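A minimal sketch of that alternative (assuming `array_api_strict` is the library
under test): comment out the `ARRAY_API_TESTS_MODULE` handling in
`array_api_tests/__init__.py` and import the module directly, e.g.

```python
# array_api_tests/__init__.py (sketch only -- replaces the env-var lookup)
import array_api_strict as xp
```
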
47 | ### Specifying the API version
48 |
49 | You can specify the API version to use when testing via the
50 | `ARRAY_API_TESTS_VERSION` environment variable, e.g.
51 |
52 | ```bash
53 | $ export ARRAY_API_TESTS_VERSION="2023.12"
54 | ```
55 |
56 | Currently this defaults to the array module's `__array_api_version__` value, and
57 | if that attribute doesn't exist then we fall back to `"2024.12"`.
58 |
59 | ### Run the suite
60 |
61 | Simply run `pytest` against the `array_api_tests/` folder to run the full suite.
62 |
63 | ```bash
64 | $ pytest array_api_tests/
65 | ```
66 |
67 | The suite tries to logically organise its tests. `pytest` allows you to only run
68 | a specific test case, which is useful when developing functions.
69 |
70 | ```bash
71 | $ pytest array_api_tests/test_creation_functions.py::test_zeros
72 | ```
73 |
74 | ## What the test suite covers
75 |
76 | We are interested in array libraries conforming to the
77 | [spec](https://data-apis.org/array-api/latest/API_specification/index.html).
78 | Ideally this means that if a library has fully adopted the Array API, the test
79 | suite passes. We take great care to _not_ test things which are out-of-scope,
80 | so as to not unexpectedly fail the suite.
81 |
82 | ### Primary tests
83 |
84 | Every function—including array object methods—has a respective test
85 | method<sup>1</sup>. We use
86 | [Hypothesis](https://hypothesis.readthedocs.io/en/latest/)
87 | to generate a diverse set of valid inputs. This means array inputs will cover
88 | different dtypes and shapes, as well as contain interesting elements. These
89 | examples are generated with interesting arrangements of non-array positional
90 | arguments and keyword arguments.
91 |
92 | Each test case will cover the following areas if relevant:
93 |
94 | * **Smoking**: We pass our generated examples to all functions. As these
95 | examples solely consist of *valid* inputs, we are testing that functions can
96 | be called using their documented inputs without raising errors.
97 |
98 | * **Data type**: For functions returning/modifying arrays, we assert that output
99 | arrays have the correct data types. Most functions
100 | [type-promote](https://data-apis.org/array-api/latest/API_specification/type_promotion.html)
101 | input arrays and some functions have bespoke rules—in both cases we simulate
102 | the correct behaviour to find the expected data types.
103 |
104 | * **Shape**: For functions returning/modifying arrays, we assert that output
105 | arrays have the correct shape. Most functions
106 | [broadcast](https://data-apis.org/array-api/latest/API_specification/broadcasting.html)
107 | input arrays and some functions have bespoke rules—in both cases we simulate
108 | the correct behaviour to find the expected shapes.
109 |
110 | * **Values**: We assert output values (including the elements of
111 | returned/modified arrays) are as expected. Except for manipulation functions
112 | or special cases, the spec allows floating-point inputs to have inexact
113 | outputs, so with such examples we only assert values are roughly as expected.
114 |
115 | ### Additional tests
116 |
117 | In addition to having one test case for each function, we test other properties
118 | of the functions and some miscellaneous things.
119 |
120 | * **Special cases**: For functions with special case behaviour, we assert that
121 | these functions return the correct values.
122 |
123 | * **Signatures**: We assert functions have the correct signatures.
124 |
125 | * **Constants**: We assert that
126 | [constants](https://data-apis.org/array-api/latest/API_specification/constants.html)
127 | behave as expected, are roughly the expected value, and that any related
128 | functions interact with them correctly.
129 |
130 | Be aware that some aspects of the spec are impractical or impossible to actually
131 | test, so they are not covered in the suite.
132 |
133 | ## Interpreting errors
134 |
135 | First and foremost, note that most tests have to assume that certain aspects of
136 | the Array API have been correctly adopted, as fundamental APIs such as array
137 | creation and equalities are hard requirements for many assertions. This means a
138 | test case for one function might fail because another function has bugs or even
139 | no implementation.
140 |
141 | This means libraries in the early stages of adoption will see a vast number of
142 | failures due to cascading errors. The granular nature of the spec also means
143 | that details such as type promotion are likely to fail even for
144 | nearly-conforming functions.
145 |
146 | We hope to improve user experience in regards to "noisy" errors in
147 | [#51](https://github.com/data-apis/array-api-tests/issues/51). For now, if an
148 | error message involves `_UndefinedStub`, it means an attribute of the array
149 | library (including functions) or of its objects (e.g. the array) is missing.
150 |
151 | The spec is the suite's source of truth. If the suite appears to assume
152 | behaviour different from the spec, or test something that is not documented,
153 | this is a bug—please [report such
154 | issues](https://github.com/data-apis/array-api-tests/issues/) to us.
155 |
156 |
157 | ## Running on CI
158 |
159 | See our existing [GitHub Actions workflow for `array-api-strict`](https://github.com/data-apis/array-api-tests/blob/master/.github/workflows/test.yml)
160 | for an example of using the test suite on CI. Note [`array-api-strict`](https://github.com/data-apis/array-api-strict)
161 | is an implementation of the array API that uses NumPy under the hood.
162 |
163 | ### Releases
164 |
165 | We recommend pinning against a [release tag](https://github.com/data-apis/array-api-tests/releases)
166 | when running on CI.
167 |
168 | We use [calendar versioning](https://calver.org/) for the releases. You should
169 | expect that any version may be "breaking" compared to the previous one, in that
170 | new tests (or improvements to existing tests) may cause a previously passing
171 | library to fail.
172 |
173 | ### Configuration
174 |
175 | #### Data-dependent shapes
176 |
177 | Use the `--disable-data-dependent-shapes` flag to skip testing functions which have
178 | [data-dependent shapes](https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html).
179 |
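For example, a run that skips those tests might look like:

```bash
$ pytest array_api_tests/ --disable-data-dependent-shapes
```
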
180 | #### Extensions
181 |
182 | By default, tests for the optional Array API extensions such as
183 | [`linalg`](https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html)
184 | will be skipped if not present in the specified array module. You can purposely
185 | skip testing extension(s) via the `--disable-extension` option.
186 |
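For example, to skip the `linalg` extension tests even if the array module
provides them, something like the following should work:

```bash
$ pytest array_api_tests/ --disable-extension linalg
```
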
187 | #### Skip or XFAIL test cases
188 |
189 | Test cases you want to skip can be specified in a skips or XFAILS file. The
190 | difference between skip and XFAIL is that XFAIL tests are still run and
191 | reported as XPASS if they pass.
192 |
193 | By default, the skips and xfails files are `skips.txt` and `xfails.txt` in the root
194 | of this repository, but any file can be specified with the `--skips-file` and
195 | `--xfails-file` command line flags.
196 |
197 | The files should list the test ids to be skipped/xfailed. Empty lines and
198 | lines starting with `#` are ignored. The test id can be any substring of the
199 | test ids to skip/xfail.
200 |
201 | ```
202 | # skips.txt or xfails.txt
203 | # Line comments can be denoted with the hash symbol (#)
204 |
205 | # Skip specific test case, e.g. when argsort() does not respect relative order
206 | # https://github.com/numpy/numpy/issues/20778
207 | array_api_tests/test_sorting_functions.py::test_argsort
208 |
209 | # Skip specific test case parameter, e.g. you forgot to implement in-place adds
210 | array_api_tests/test_add[__iadd__(x1, x2)]
211 | array_api_tests/test_add[__iadd__(x, s)]
212 |
213 | # Skip module, e.g. when your set functions treat NaNs as non-distinct
214 | # https://github.com/numpy/numpy/issues/20326
215 | array_api_tests/test_set_functions.py
216 | ```
217 |
218 | Here is an example GitHub Actions workflow file, where the xfails are stored
219 | in `array-api-tests.xfails.txt` in the base of the `your-array-library` repo.
220 |
221 | If you want, you can use `-o xfail_strict=True`, which causes XPASS tests (XFAIL
222 | tests that actually pass) to fail the test suite. However, be aware that
223 | XFAILures can be flaky (see below), so this may not be a good idea unless you
224 | use some other mitigation of such flakiness.
225 |
226 | If you don't want this behavior, you can remove it, or use `--skips-file`
227 | instead of `--xfails-file`.
228 |
229 | ```yaml
230 | # ./.github/workflows/array_api.yml
231 | jobs:
232 | tests:
233 | runs-on: ubuntu-latest
234 | strategy:
235 | matrix:
236 | python-version: ['3.8', '3.9', '3.10', '3.11']
237 |
238 | steps:
239 | - name: Checkout
240 | uses: actions/checkout@v3
241 | with:
242 | path: your-array-library
243 |
244 | - name: Checkout array-api-tests
245 | uses: actions/checkout@v3
246 | with:
247 | repository: data-apis/array-api-tests
248 | submodules: 'true'
249 | path: array-api-tests
250 |
251 | - name: Run the array API test suite
252 | env:
253 | ARRAY_API_TESTS_MODULE: your.array.api.namespace
254 | run: |
255 | export PYTHONPATH="${GITHUB_WORKSPACE}/your-array-library"
256 | cd ${GITHUB_WORKSPACE}/array-api-tests
257 | pytest -v -rxXfE --ci --xfails-file ${GITHUB_WORKSPACE}/your-array-library/array-api-tests-xfails.txt array_api_tests/
258 | ```
259 |
260 | > **Warning**
261 | >
262 | > XFAIL tests that use Hypothesis (basically every test in the test suite except
263 | > those in test_has_names.py) can be flaky, due to the fact that Hypothesis
264 | > might not always run the test with an input that causes the test to fail.
265 | > There are several ways to avoid this problem:
266 | >
267 | > - Increase the maximum number of examples, e.g., by adding `--max-examples
268 | > 200` to the test command (the default is `20`, see below). This will
269 | > make it more likely that the failing case will be found, but it will also
270 | > make the tests take longer to run.
271 | > - Don't use `-o xfail_strict=True`. This will make it so that if an XFAIL
272 | > test passes, it will alert you in the test summary but will not cause the
273 | > test run to register as failed.
274 | > - Use skips instead of XFAILS. The difference between XFAIL and skip is that
275 | > a skipped test is never run at all, whereas an XFAIL test is always run
276 | > but ignored if it fails.
277 | > - Save the [Hypothesis examples
278 | > database](https://hypothesis.readthedocs.io/en/latest/database.html)
279 | > persistently on CI. That way as soon as a run finds one failing example,
280 | > it will always re-run future runs with that example. But note that the
281 | > Hypothesis examples database may be cleared when a new version of
282 | > Hypothesis or the test suite is released.
283 |
284 | #### Max examples
285 |
286 | The tests make heavy use of
287 | [Hypothesis](https://hypothesis.readthedocs.io/en/latest/). You can configure
288 | how many examples are generated using the `--max-examples` flag, which
289 | defaults to `20`. Lower values can be useful for quick checks, and larger
290 | values should result in more rigorous runs. For example, `--max-examples
291 | 10_000` may find bugs where default runs don't but will take much longer to
292 | run.
293 |
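For example, a more rigorous (and slower) run might look like:

```bash
$ pytest array_api_tests/ --max-examples 500
```
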
294 | #### Skipping Dtypes
295 |
296 | The test suite will automatically skip testing of inessential dtypes if they
297 | are not present on the array module namespace, but dtypes can also be skipped
298 | manually by setting the environment variable `ARRAY_API_TESTS_SKIP_DTYPES` to
299 | a comma separated list of dtypes to skip. For example
300 |
301 | ```
302 | ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64 pytest array_api_tests/
303 | ```
304 |
305 | Note that skipping certain essential dtypes such as `bool` and the default
306 | floating-point dtype is not supported.
307 |
308 | #### Turning xfails into skips
309 |
310 | Keeping a large number of ``xfails`` can have drastic effects on the run time. This is due
311 | to the way `hypothesis` works: when it detects a failure, it does a large amount
312 | of work to simplify the failing example.
313 | If the run time of the test suite becomes a problem, you can use the
314 | ``ARRAY_API_TESTS_XFAIL_MARK`` environment variable: setting it to ``skip`` skips the
315 | entries from the xfails file instead of xfailing them. Anecdotally, we saw
316 | speed-ups by a factor of 4-5, which allowed us to use 4-5× larger values of
317 | ``--max-examples`` within the same time budget.
318 |
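For example, combining this with an xfails file might look like:

```bash
$ ARRAY_API_TESTS_XFAIL_MARK=skip pytest --xfails-file xfails.txt array_api_tests/
```
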
319 | #### Limiting the array sizes
320 |
321 | The test suite generates random arrays as inputs to the functions it tests.
322 | "Unvectorized" tests iterate over elements of arrays, which might be slow. If the run time becomes
323 | a problem, you can limit the maximum number of elements in generated arrays by
324 | setting the environment variable ``ARRAY_API_TESTS_MAX_ARRAY_SIZE`` to the
325 | desired value. By default, it is set to 1024.
326 |
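For example, to cap generated arrays at 128 elements:

```bash
$ ARRAY_API_TESTS_MAX_ARRAY_SIZE=128 pytest array_api_tests/
```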
327 |
328 | ## Contributing
329 |
330 | ### Remain in-scope
331 |
332 | It is important that every test only uses APIs that are part of the standard.
333 | For instance, when creating input arrays you should only use the [array creation
334 | functions](https://data-apis.org/array-api/latest/API_specification/creation_functions.html)
335 | that are documented in the spec. The same goes for testing arrays—you'll find
336 | many utilities that parallel NumPy's own test utils in the `*_helpers.py` files.
337 |
338 | ### Tools
339 |
340 | Hypothesis should almost always be used for the primary tests, and can be useful
341 | elsewhere. Effort should be made so that drawn arguments are labeled with their
342 | respective names. For
343 | [`st.data()`](https://hypothesis.readthedocs.io/en/latest/data.html#hypothesis.strategies.data),
344 | draws should be accompanied by the `label` kwarg, i.e.
345 | `data.draw(<strategy>, label=<label>)`.
346 |
347 | [`pytest.mark.parametrize`](https://docs.pytest.org/en/latest/how-to/parametrize.html)
348 | should be used to run tests over multiple arguments. Parameterization should be
349 | preferred over using Hypothesis when there are a small number of possible
350 | inputs, as this allows better failure reporting. Note using both parametrize and
351 | Hypothesis for a single test method is possible and can be quite useful.
352 |
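To illustrate both points, here is a rough sketch in the style of the suite's
tests (not an actual test from the suite; `test_example_sort` and its assertion
message are made up for illustration):

```python
import pytest
from hypothesis import given, strategies as st

from . import _array_module as xp
from . import xps


@pytest.mark.parametrize("descending", [False, True])  # small, fixed set of inputs
@given(data=st.data())
def test_example_sort(descending, data):
    # Label each draw so Hypothesis failure reports name the argument
    x = data.draw(xps.arrays(dtype=xps.integer_dtypes(), shape=xps.array_shapes()), label="x")
    out = xp.sort(x, descending=descending)
    assert out.shape == x.shape, f"{out.shape=}, but should be {x.shape} [sort({descending=})]"
```
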
353 | ### Error messages
354 |
355 | Any assertion should be accompanied with a descriptive error message, including
356 | the relevant values. Error messages should be self-explanatory as to why a given
357 | test fails, as one should not need prior knowledge of how the test is
358 | implemented.
359 |
360 | ### Generated files
361 |
362 | Some files in the suite are automatically generated from the spec, and should
363 | not be edited directly. To regenerate these files, run the script
364 |
365 | ./generate_stubs.py path/to/array-api
366 |
367 | where `path/to/array-api` is the path to a local clone of the [`array-api`
368 | repo](https://github.com/data-apis/array-api/). Edit `generate_stubs.py` to make
369 | changes to the generated files.
370 |
371 |
372 | ### Release
373 |
374 | To make a release, first make an annotated tag with the version, e.g.:
375 |
376 | ```
377 | git tag -a 2022.01.01
378 | ```
379 |
380 | Be sure to use the calver version number for the tag name. Don't worry too much
381 | about the tag message, e.g. just write "2022.01.01".
382 |
383 | Versioneer will automatically set the version number of the `array_api_tests`
384 | package based on the git tag. Push the tag to GitHub:
385 |
386 | ```
387 | git push --tags upstream 2022.01.01
388 | ```
389 |
390 | Then go to the [tags page on
391 | GitHub](https://github.com/data-apis/array-api-tests/tags) and convert the tag
392 | into a release. If you want, you can add release notes, which GitHub can
393 | generate for you.
394 |
395 |
396 | ---
397 |
398 | <sup>1</sup> The only exceptions to having just one primary test per function are:
399 |
400 | * [`asarray()`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.creation_functions.asarray.html),
401 | which is tested by `test_asarray_scalars` and `test_asarray_arrays` in
402 | `test_creation_functions.py`. Testing `asarray()` works with scalars (and
403 | nested sequences of scalars) is fundamental to testing that it works with
404 | arrays, as said arrays can only be generated by passing scalar sequences to
405 | `asarray()`.
406 |
407 | * Indexing methods
408 | ([`__getitem__()`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.__getitem__.html)
409 | and
410 | [`__setitem__()`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.__setitem__.html)),
411 | which respectively have both a test for non-array indices and a test for
412 | boolean array indices. This is because [masking is
413 | opt-in](https://data-apis.org/array-api/latest/API_specification/indexing.html#boolean-array-indexing)
414 | (and boolean arrays need to be generated by indexing arrays anyway).
415 |
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-slate
--------------------------------------------------------------------------------
/array-api-strict-skips.txt:
--------------------------------------------------------------------------------
1 | # Known special case issue in NumPy. Not worth working around here
2 | # https://github.com/numpy/numpy/issues/21213
3 | array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
4 | array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
5 |
6 | # The test suite is incorrectly checking sums that have loss of significance
7 | # (https://github.com/data-apis/array-api-tests/issues/168)
8 | array_api_tests/test_statistical_functions.py::test_sum
9 |
10 | # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, all libraries do just that
11 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
12 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
13 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
14 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
15 | array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
16 | array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
17 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
18 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
19 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
20 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
21 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
22 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
23 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
24 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
25 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
26 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
27 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
28 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
29 |
30 | # FIXME needs array-api-strict >=2.3.2
31 | array_api_tests/test_data_type_functions.py::test_finfo
32 | array_api_tests/test_data_type_functions.py::test_finfo_dtype
33 | array_api_tests/test_data_type_functions.py::test_iinfo
34 | array_api_tests/test_data_type_functions.py::test_iinfo_dtype
35 |
--------------------------------------------------------------------------------
/array_api_tests/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from functools import wraps
3 | from importlib import import_module
4 |
5 | from hypothesis import strategies as st
6 | from hypothesis.extra import array_api
7 |
8 | from . import _version
9 |
10 | __all__ = ["xp", "api_version", "xps"]
11 |
12 |
13 | # You can comment the following out and instead import the specific array module
14 | # you want to test, e.g. `import array_api_strict as xp`.
15 | if "ARRAY_API_TESTS_MODULE" in os.environ:
16 | env_var = os.environ["ARRAY_API_TESTS_MODULE"]
17 | if env_var.startswith("exec('") and env_var.endswith("')"):
18 | script = env_var[6:][:-2]
19 | namespace = {}
20 | exec(script, namespace)
21 | xp = namespace["xp"]
22 | xp_name = xp.__name__
23 | else:
24 | xp_name = env_var
25 | _module, _sub = xp_name, None
26 | if "." in xp_name:
27 | _module, _sub = xp_name.split(".", 1)
28 | xp = import_module(_module)
29 | if _sub:
30 | try:
31 | xp = getattr(xp, _sub)
32 | except AttributeError:
33 | # _sub may be a submodule that needs to be imported. We can't
34 | # do this in every case because some array modules are not
35 | # submodules that can be imported (like mxnet.nd).
36 | xp = import_module(xp_name)
37 | else:
38 | raise RuntimeError(
39 | "No array module specified - either edit __init__.py or set the "
40 | "ARRAY_API_TESTS_MODULE environment variable."
41 | )
42 |
43 |
44 | # If xp.bool is not available, like in some versions of NumPy and CuPy, try
45 | # patching in xp.bool_.
46 | try:
47 | xp.bool
48 | except AttributeError as e:
49 | if hasattr(xp, "bool_"):
50 | xp.bool = xp.bool_
51 | else:
52 | raise e
53 |
54 |
55 | # We monkey patch floats() to always disable subnormals as they are out-of-scope
56 |
57 | _floats = st.floats
58 |
59 |
60 | @wraps(_floats)
61 | def floats(*a, **kw):
62 | kw["allow_subnormal"] = False
63 | return _floats(*a, **kw)
64 |
65 |
66 | st.floats = floats
67 |
68 |
69 | # We do the same with xps.from_dtype() - this is not strictly necessary, as
70 | # the underlying floats() will never generate subnormals. We only do this
71 | # because internal logic in xps.from_dtype() assumes xp.finfo() has its
72 | # attributes as scalar floats, which is expected behaviour but disrupts many
73 | # unrelated tests.
74 | try:
75 | __from_dtype = array_api._from_dtype
76 |
77 | @wraps(__from_dtype)
78 | def _from_dtype(*a, **kw):
79 | kw["allow_subnormal"] = False
80 | return __from_dtype(*a, **kw)
81 |
82 | array_api._from_dtype = _from_dtype
83 | except AttributeError:
84 | # Ignore monkey patching if Hypothesis changes the private API
85 | pass
86 |
87 |
88 | api_version = os.getenv(
89 | "ARRAY_API_TESTS_VERSION", getattr(xp, "__array_api_version__", "2024.12")
90 | )
91 | xps = array_api.make_strategies_namespace(xp, api_version=api_version)
92 |
93 | __version__ = _version.get_versions()["version"]
94 |
--------------------------------------------------------------------------------
/array_api_tests/_array_module.py:
--------------------------------------------------------------------------------
1 | from . import stubs, xp
2 |
3 |
4 | class _UndefinedStub:
5 | """
6 | Standing for undefined names, so the tests can be imported even if they
7 | fail
8 |
9 | If this object appears in a test failure, it means a name is not defined
10 | in a function. This typically happens for things like dtype literals not
11 | being defined.
12 |
13 | """
14 | def __init__(self, name):
15 | self.name = name
16 |
17 | def _raise(self, *args, **kwargs):
18 | raise AssertionError(f"{self.name} is not defined in {xp.__name__}")
19 |
20 | def __repr__(self):
21 |         return f"<undefined stub for {self.name!r}>"
22 |
23 | __call__ = _raise
24 | __getattr__ = _raise
25 |
26 | _dtypes = [
27 | "bool",
28 | "uint8", "uint16", "uint32", "uint64",
29 | "int8", "int16", "int32", "int64",
30 | "float32", "float64",
31 | "complex64", "complex128",
32 | ]
33 | _constants = ["e", "inf", "nan", "pi"]
34 | _funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs]
35 | _funcs += ["take", "isdtype", "conj", "imag", "real"] # TODO: bump spec and update array-api-tests to new spec layout
36 | _top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS + ["fft"]
37 |
38 | for attr in _top_level_attrs:
39 | try:
40 | globals()[attr] = getattr(xp, attr)
41 | except AttributeError:
42 | globals()[attr] = _UndefinedStub(attr)
43 |
--------------------------------------------------------------------------------
/array_api_tests/algos.py:
--------------------------------------------------------------------------------
1 | __all__ = ["broadcast_shapes"]
2 |
3 |
4 | from .typing import Shape
5 |
6 |
7 | # We use a custom exception to differentiate from potential bugs
8 | class BroadcastError(ValueError):
9 | pass
10 |
11 |
12 | def _broadcast_shapes(shape1: Shape, shape2: Shape) -> Shape:
13 | """Broadcasts `shape1` and `shape2`"""
14 | N1 = len(shape1)
15 | N2 = len(shape2)
16 | N = max(N1, N2)
17 | shape = [None for _ in range(N)]
18 | i = N - 1
19 | while i >= 0:
20 | n1 = N1 - N + i
21 | if N1 - N + i >= 0:
22 | d1 = shape1[n1]
23 | else:
24 | d1 = 1
25 | n2 = N2 - N + i
26 | if N2 - N + i >= 0:
27 | d2 = shape2[n2]
28 | else:
29 | d2 = 1
30 |
31 | if d1 == 1:
32 | shape[i] = d2
33 | elif d2 == 1:
34 | shape[i] = d1
35 | elif d1 == d2:
36 | shape[i] = d1
37 | else:
38 | raise BroadcastError
39 |
40 | i = i - 1
41 |
42 | return tuple(shape)
43 |
44 |
45 | def broadcast_shapes(*shapes: Shape):
46 | if len(shapes) == 0:
47 | raise ValueError("shapes=[] must be non-empty")
48 | elif len(shapes) == 1:
49 | return shapes[0]
50 | result = _broadcast_shapes(shapes[0], shapes[1])
51 | for i in range(2, len(shapes)):
52 | result = _broadcast_shapes(result, shapes[i])
53 | return result
54 |
--------------------------------------------------------------------------------
/array_api_tests/array_helpers.py:
--------------------------------------------------------------------------------
1 | from ._array_module import (isnan, all, any, equal, not_equal, logical_and,
2 | logical_or, isfinite, greater, less_equal,
3 | zeros, ones, full, bool, int8, int16, int32,
4 | int64, uint8, uint16, uint32, uint64, float32,
5 | float64, nan, inf, pi, remainder, divide, isinf,
6 | negative, asarray)
7 | # These are exported here so that they can be included in the special cases
8 | # tests from this file.
9 | from ._array_module import logical_not, subtract, floor, ceil, where
10 | from . import _array_module as xp
11 | from . import dtype_helpers as dh
12 |
13 | __all__ = ['all', 'any', 'logical_and', 'logical_or', 'logical_not', 'less',
14 | 'less_equal', 'greater', 'subtract', 'negative', 'floor', 'ceil',
15 | 'where', 'isfinite', 'equal', 'not_equal', 'zero', 'one', 'NaN',
16 | 'infinity', 'π', 'isnegzero', 'non_zero', 'isposzero',
17 | 'exactly_equal', 'assert_exactly_equal', 'notequal',
18 | 'assert_finite', 'assert_non_zero', 'ispositive',
19 | 'assert_positive', 'isnegative', 'assert_negative', 'isintegral',
20 | 'assert_integral', 'isodd', 'iseven', "assert_iseven",
21 | 'assert_isinf', 'positive_mathematical_sign',
22 | 'assert_positive_mathematical_sign', 'negative_mathematical_sign',
23 | 'assert_negative_mathematical_sign', 'same_sign',
24 | 'assert_same_sign', 'float64',
25 | 'asarray', 'full', 'true', 'false', 'isnan']
26 |
27 | def zero(shape, dtype):
28 | """
29 | Returns a full 0 array of the given dtype.
30 |
31 | This should be used in place of the literal "0" in the test suite, as the
32 | spec does not require any behavior with Python literals (and in
33 | particular, it does not specify how the integer 0 and the float 0.0 work
34 | with type promotion).
35 |
36 | To get -0, use -zero(dtype) (note that -0 is only defined for floating
37 | point dtypes).
38 | """
39 | return zeros(shape, dtype=dtype)
40 |
41 | def one(shape, dtype):
42 | """
43 | Returns a full 1 array of the given dtype.
44 |
45 | This should be used in place of the literal "1" in the test suite, as the
46 | spec does not require any behavior with Python literals (and in
47 | particular, it does not specify how the integer 1 and the float 1.0 work
48 | with type promotion).
49 |
50 | To get -1, use -one(dtype).
51 | """
52 | return ones(shape, dtype=dtype)
53 |
54 | def NaN(shape, dtype):
55 | """
56 | Returns a full nan array of the given dtype.
57 |
58 | Note that this is only defined for floating point dtypes.
59 | """
60 | if dtype not in [float32, float64]:
61 | raise RuntimeError(f"Unexpected dtype {dtype} in NaN().")
62 | return full(shape, nan, dtype=dtype)
63 |
64 | def infinity(shape, dtype):
65 | """
66 | Returns a full positive infinity array of the given dtype.
67 |
68 | Note that this is only defined for floating point dtypes.
69 |
70 | To get negative infinity, use -infinity(dtype).
71 |
72 | """
73 | if dtype not in [float32, float64]:
74 | raise RuntimeError(f"Unexpected dtype {dtype} in infinity().")
75 | return full(shape, inf, dtype=dtype)
76 |
77 | def π(shape, dtype):
78 | """
79 | Returns a full π array of the given dtype.
80 |
81 | Note that this function is only defined for floating point dtype.
82 |
83 | To get rational multiples of π, use, e.g., 3*π(dtype)/2.
84 |
85 | """
86 | if dtype not in [float32, float64]:
87 | raise RuntimeError(f"Unexpected dtype {dtype} in π().")
88 | return full(shape, pi, dtype=dtype)
89 |
90 | def true(shape):
91 | """
92 | Returns a full True array with dtype=bool.
93 | """
94 | return full(shape, True, dtype=bool)
95 |
96 | def false(shape):
97 | """
98 | Returns a full False array with dtype=bool.
99 | """
100 | return full(shape, False, dtype=bool)
101 |
102 | def isnegzero(x):
103 | """
104 | Returns a mask where x is -0. Is all False if x has integer dtype.
105 | """
106 | # TODO: If copysign or signbit are added to the spec, use those instead.
107 | shape = x.shape
108 | dtype = x.dtype
109 | if dh.is_int_dtype(dtype):
110 | return false(shape)
111 | return equal(divide(one(shape, dtype), x), -infinity(shape, dtype))
112 |
113 | def isposzero(x):
114 | """
115 | Returns a mask where x is +0 (but not -0). Is all True if x has integer dtype.
116 | """
117 | # TODO: If copysign or signbit are added to the spec, use those instead.
118 | shape = x.shape
119 | dtype = x.dtype
120 | if dh.is_int_dtype(dtype):
121 | return true(shape)
122 | return equal(divide(one(shape, dtype), x), infinity(shape, dtype))
123 |
124 | def exactly_equal(x, y):
125 | """
126 | Same as equal(x, y) except it gives True where both values are nan, and
127 | distinguishes +0 and -0.
128 |
129 | This function implicitly assumes x and y have the same shape and dtype.
130 | """
131 | if x.dtype in [float32, float64]:
132 | xnegzero = isnegzero(x)
133 | ynegzero = isnegzero(y)
134 |
135 | xposzero = isposzero(x)
136 | yposzero = isposzero(y)
137 |
138 | xnan = isnan(x)
139 | ynan = isnan(y)
140 |
141 |         # (x == y OR x == y == NaN) AND xnegzero == ynegzero AND xposzero == yposzero
142 | return logical_and(logical_and(
143 | logical_or(equal(x, y), logical_and(xnan, ynan)),
144 | equal(xnegzero, ynegzero)),
145 | equal(xposzero, yposzero))
146 |
147 | return equal(x, y)
148 |
149 | def notequal(x, y):
150 | """
151 | Same as not_equal(x, y) except it gives False when both values are nan.
152 |
153 | Note: this function does NOT distinguish +0 and -0.
154 |
155 | This function implicitly assumes x and y have the same shape and dtype.
156 | """
157 | if x.dtype in [float32, float64]:
158 | xnan = isnan(x)
159 | ynan = isnan(y)
160 |
161 | both_nan = logical_and(xnan, ynan)
162 | # NOT both nan AND (both nan OR x != y)
163 | return logical_and(logical_not(both_nan), not_equal(x, y))
164 |
165 | return not_equal(x, y)
166 |
167 | def less(x, y):
168 | """
169 | Same as less(x, y) except it allows comparing uint64 with signed int dtypes
170 | """
171 | if x.dtype == uint64 and dh.dtype_signed[y.dtype]:
172 | return xp.where(y < 0, xp.asarray(False), xp.less(x, xp.astype(y, uint64)))
173 | if y.dtype == uint64 and dh.dtype_signed[x.dtype]:
174 | return xp.where(x < 0, xp.asarray(True), xp.less(xp.astype(x, uint64), y))
175 | return xp.less(x, y)
176 |
177 | def assert_exactly_equal(x, y, msg_extra=None):
178 | """
179 | Test that the arrays x and y are exactly equal.
180 |
181 | If x and y do not have the same shape and dtype, they are not considered
182 | equal.
183 |
184 | """
185 | extra = '' if not msg_extra else f' ({msg_extra})'
186 |
187 | assert x.shape == y.shape, f"The input arrays do not have the same shapes ({x.shape} != {y.shape}){extra}"
188 |
189 | assert x.dtype == y.dtype, f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype}){extra}"
190 |
191 | assert all(exactly_equal(x, y)), f"The input arrays have different values ({x!r} != {y!r}){extra}"
192 |
193 | def assert_finite(x):
194 | """
195 | Test that the array x is finite
196 | """
197 | assert all(isfinite(x)), "The input array is not finite"
198 |
199 | def non_zero(x):
200 | return not_equal(x, zero(x.shape, x.dtype))
201 |
202 | def assert_non_zero(x):
203 | assert all(non_zero(x)), "The input array is not nonzero"
204 |
205 | def ispositive(x):
206 | return greater(x, zero(x.shape, x.dtype))
207 |
208 | def assert_positive(x):
209 | assert all(ispositive(x)), "The input array is not positive"
210 |
211 | def isnegative(x):
212 | return less(x, zero(x.shape, x.dtype))
213 |
214 | def assert_negative(x):
215 | assert all(isnegative(x)), "The input array is not negative"
216 |
217 | def inrange(x, a, b, epsilon=0, open=False):
218 | """
219 |     Returns a mask for values of x in the range [a-epsilon, b+epsilon], inclusive
220 |
221 |     If open=True, the range is (a-epsilon, b+epsilon) (i.e., not inclusive).
222 | """
223 | eps = full(x.shape, epsilon, dtype=x.dtype)
224 | l = less if open else less_equal
225 | return logical_and(l(a-eps, x), l(x, b+eps))
226 |
227 | def isintegral(x):
228 | """
229 | Returns a mask on x where the values are integral
230 |
231 | x is integral if its dtype is an integer dtype, or if it is a floating
232 | point value that can be exactly represented as an integer.
233 | """
234 | if x.dtype in [int8, int16, int32, int64, uint8, uint16, uint32, uint64]:
235 | return full(x.shape, True, dtype=bool)
236 | elif x.dtype in [float32, float64]:
237 | return equal(remainder(x, one(x.shape, x.dtype)), zero(x.shape, x.dtype))
238 | else:
239 | return full(x.shape, False, dtype=bool)
240 |
241 | def assert_integral(x):
242 | """
243 | Check that x has only integer values
244 | """
245 | assert all(isintegral(x)), "The input array has nonintegral values"
246 |
247 | def isodd(x):
248 | return logical_and(
249 | isintegral(x),
250 | equal(
251 | remainder(x, 2*one(x.shape, x.dtype)),
252 | one(x.shape, x.dtype)))
253 |
254 | def iseven(x):
255 | return logical_and(
256 | isintegral(x),
257 | equal(
258 | remainder(x, 2*one(x.shape, x.dtype)),
259 | zero(x.shape, x.dtype)))
260 |
261 | def assert_iseven(x):
262 | """
263 | Check that x is an even integer
264 | """
265 | assert all(iseven(x)), "The input array is not even"
266 |
267 | def assert_isinf(x):
268 | """
269 | Check that x is an infinity
270 | """
271 | assert all(isinf(x)), "The input array is not infinite"
272 |
273 | def positive_mathematical_sign(x):
274 | """
275 | Check if x has a positive "mathematical sign"
276 |
277 | The "mathematical sign" here means the sign bit is 0. This includes 0,
278 | positive finite numbers, and positive infinity. It does not include any
279 | nans, as signed nans are not required by the spec.
280 |
281 | """
282 | z = zero(x.shape, x.dtype)
283 | return logical_or(greater(x, z), isposzero(x))
284 |
285 | def assert_positive_mathematical_sign(x):
286 | assert all(positive_mathematical_sign(x)), "The input arrays do not have a positive mathematical sign"
287 |
288 | def negative_mathematical_sign(x):
289 | """
290 | Check if x has a negative "mathematical sign"
291 |
292 | The "mathematical sign" here means the sign bit is 1. This includes -0,
293 | negative finite numbers, and negative infinity. It does not include any
294 | nans, as signed nans are not required by the spec.
295 |
296 | """
297 | z = zero(x.shape, x.dtype)
298 | if x.dtype in [float32, float64]:
299 | return logical_or(less(x, z), isnegzero(x))
300 | return less(x, z)
301 |
302 | def assert_negative_mathematical_sign(x):
303 | assert all(negative_mathematical_sign(x)), "The input arrays do not have a negative mathematical sign"
304 |
305 | def same_sign(x, y):
306 | """
307 | Check if x and y have the "same sign"
308 |
309 | x and y have the same sign if they are both nonnegative or both negative.
310 | For the purposes of this function 0 and 1 have the same sign and -0 and -1
311 | have the same sign. The value of this function is False if either x or y
312 | is nan, as signed nans are not required by the spec.
313 | """
314 | return logical_or(
315 | logical_and(positive_mathematical_sign(x), positive_mathematical_sign(y)),
316 | logical_and(negative_mathematical_sign(x), negative_mathematical_sign(y)))
317 |
318 | def assert_same_sign(x, y):
319 | assert all(same_sign(x, y)), "The input arrays do not have the same sign"
320 |
321 | def _matrix_transpose(x):
322 | if not isinstance(xp.matrix_transpose, xp._UndefinedStub):
323 | return xp.matrix_transpose(x)
324 | if hasattr(x, 'mT'):
325 | return x.mT
326 | if not isinstance(xp.permute_dims, xp._UndefinedStub):
327 | perm = list(range(x.ndim))
328 | perm[-1], perm[-2] = perm[-2], perm[-1]
329 | return xp.permute_dims(x, axes=tuple(perm))
330 | raise NotImplementedError("No way to compute matrix transpose")
331 |
--------------------------------------------------------------------------------
/array_api_tests/shape_helpers.py:
--------------------------------------------------------------------------------
1 | import math
2 | from itertools import product
3 | from typing import Iterator, List, Optional, Sequence, Tuple, Union
4 |
5 | from ndindex import iter_indices as _iter_indices
6 |
7 | from .typing import AtomicIndex, Index, Scalar, Shape
8 |
9 | __all__ = [
10 | "broadcast_shapes",
11 | "normalize_axis",
12 | "ndindex",
13 | "axis_ndindex",
14 | "axes_ndindex",
15 | "reshape",
16 | "fmt_idx",
17 | ]
18 |
19 |
20 | class BroadcastError(ValueError):
21 |     """Shapes do not broadcast with each other"""
22 |
23 |
24 | def _broadcast_shapes(shape1: Shape, shape2: Shape) -> Shape:
25 | """Broadcasts `shape1` and `shape2`"""
26 | N1 = len(shape1)
27 | N2 = len(shape2)
28 | N = max(N1, N2)
29 | shape = [None for _ in range(N)]
30 | i = N - 1
31 | while i >= 0:
32 | n1 = N1 - N + i
33 | if N1 - N + i >= 0:
34 | d1 = shape1[n1]
35 | else:
36 | d1 = 1
37 | n2 = N2 - N + i
38 | if N2 - N + i >= 0:
39 | d2 = shape2[n2]
40 | else:
41 | d2 = 1
42 |
43 | if d1 == 1:
44 | shape[i] = d2
45 | elif d2 == 1:
46 | shape[i] = d1
47 | elif d1 == d2:
48 | shape[i] = d1
49 | else:
50 | raise BroadcastError()
51 |
52 | i = i - 1
53 |
54 | return tuple(shape)
55 |
56 |
57 | def broadcast_shapes(*shapes: Shape):
58 | if len(shapes) == 0:
59 | raise ValueError("shapes=[] must be non-empty")
60 | elif len(shapes) == 1:
61 | return shapes[0]
62 | result = _broadcast_shapes(shapes[0], shapes[1])
63 | for i in range(2, len(shapes)):
64 | result = _broadcast_shapes(result, shapes[i])
65 | return result
66 |
67 |
68 | def normalize_axis(
69 | axis: Optional[Union[int, Sequence[int]]], ndim: int
70 | ) -> Tuple[int, ...]:
71 | if axis is None:
72 | return tuple(range(ndim))
73 | elif isinstance(axis, Sequence) and not isinstance(axis, tuple):
74 | axis = tuple(axis)
75 | axes = axis if isinstance(axis, tuple) else (axis,)
76 | axes = tuple(axis if axis >= 0 else ndim + axis for axis in axes)
77 | return axes
78 |
79 |
80 | def ndindex(shape: Shape) -> Iterator[Index]:
81 | """Yield every index of a shape"""
82 | return (indices[0] for indices in iter_indices(shape))
83 |
84 |
85 | def iter_indices(
86 | *shapes: Shape, skip_axes: Tuple[int, ...] = ()
87 | ) -> Iterator[Tuple[Index, ...]]:
88 | """Wrapper for ndindex.iter_indices()"""
89 | # Prevent iterations if any shape has 0-sides
90 | for shape in shapes:
91 | if 0 in shape:
92 | return
93 | for indices in _iter_indices(*shapes, skip_axes=skip_axes):
94 | yield tuple(i.raw for i in indices) # type: ignore
95 |
96 |
97 | def axis_ndindex(
98 | shape: Shape, axis: int
99 | ) -> Iterator[Tuple[Tuple[Union[int, slice], ...], ...]]:
100 | """Generate indices that index all elements in dimensions beyond `axis`"""
101 | assert axis >= 0 # sanity check
102 | axis_indices = [range(side) for side in shape[:axis]]
103 | for _ in range(axis, len(shape)):
104 | axis_indices.append([slice(None, None)])
105 | yield from product(*axis_indices)
106 |
107 |
108 | def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]:
109 | """Generate indices that index all elements except in `axes` dimensions"""
110 | base_indices = []
111 | axes_indices = []
112 | for axis, side in enumerate(shape):
113 | if axis in axes:
114 | base_indices.append([None])
115 | axes_indices.append(range(side))
116 | else:
117 | base_indices.append(range(side))
118 | axes_indices.append([None])
119 | for base_idx in product(*base_indices):
120 | indices = []
121 | for idx in product(*axes_indices):
122 | idx = list(idx)
123 | for axis, side in enumerate(idx):
124 | if axis not in axes:
125 | idx[axis] = base_idx[axis]
126 | idx = tuple(idx)
127 | indices.append(idx)
128 | yield list(indices)
129 |
130 |
131 | def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List]:
132 | """Reshape a flat sequence"""
133 | if any(s == 0 for s in shape):
134 | raise ValueError(
135 | f"{shape=} contains 0-sided dimensions, "
136 | f"but that's not representable in lists"
137 | )
138 | if len(shape) == 0:
139 | assert len(flat_seq) == 1 # sanity check
140 | return flat_seq[0]
141 | elif len(shape) == 1:
142 | return flat_seq
143 | size = len(flat_seq)
144 | n = math.prod(shape[1:])
145 | return [reshape(flat_seq[i * n : (i + 1) * n], shape[1:]) for i in range(size // n)]
146 |
147 |
148 | def fmt_i(i: AtomicIndex) -> str:
149 | if isinstance(i, int):
150 | return str(i)
151 | elif isinstance(i, slice):
152 | res = ""
153 | if i.start is not None:
154 | res += str(i.start)
155 | res += ":"
156 | if i.stop is not None:
157 | res += str(i.stop)
158 | if i.step is not None:
159 | res += f":{i.step}"
160 | return res
161 | elif i is None:
162 | return "None"
163 | else:
164 | return "..."
165 |
166 |
167 | def fmt_idx(sym: str, idx: Index) -> str:
168 | if idx == ():
169 | return sym
170 | res = f"{sym}["
171 | _idx = idx if isinstance(idx, tuple) else (idx,)
172 | if len(_idx) == 1:
173 | res += fmt_i(_idx[0])
174 | else:
175 | res += ", ".join(fmt_i(i) for i in _idx)
176 | res += "]"
177 | return res
178 |
--------------------------------------------------------------------------------
/array_api_tests/stubs.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import sys
3 | from importlib import import_module
4 | from importlib.util import find_spec
5 | from pathlib import Path
6 | from types import FunctionType, ModuleType
7 | from typing import Dict, List
8 |
9 | from . import api_version
10 |
11 | __all__ = [
12 | "name_to_func",
13 | "array_methods",
14 | "array_attributes",
15 | "category_to_funcs",
16 | "EXTENSIONS",
17 | "extension_to_funcs",
18 | ]
19 |
20 | spec_module = "_" + api_version.replace('.', '_')
21 |
22 | spec_dir = Path(__file__).parent.parent / "array-api" / "spec" / api_version / "API_specification"
23 | assert spec_dir.exists(), f"{spec_dir} not found - try `git submodule update --init`"
24 | sigs_dir = Path(__file__).parent.parent / "array-api" / "src" / "array_api_stubs" / spec_module
25 | assert sigs_dir.exists()
26 |
27 | sigs_abs_path: str = str(sigs_dir.parent.parent.resolve())
28 | sys.path.append(sigs_abs_path)
29 | assert find_spec(f"array_api_stubs.{spec_module}") is not None
30 |
31 | name_to_mod: Dict[str, ModuleType] = {}
32 | for path in sigs_dir.glob("*.py"):
33 | name = path.name.replace(".py", "")
34 | name_to_mod[name] = import_module(f"array_api_stubs.{spec_module}.{name}")
35 |
36 | array = name_to_mod["array_object"].array
37 | array_methods = [
38 | f for n, f in inspect.getmembers(array, predicate=inspect.isfunction)
39 | if n != "__init__" # probably exists for Sphinx
40 | ]
41 | array_attributes = [
42 | n for n, f in inspect.getmembers(array, predicate=lambda x: isinstance(x, property))
43 | ]
44 |
45 | category_to_funcs: Dict[str, List[FunctionType]] = {}
46 | for name, mod in name_to_mod.items():
47 | if name.endswith("_functions"):
48 | category = name.replace("_functions", "")
49 | objects = [getattr(mod, name) for name in mod.__all__]
50 | assert all(isinstance(o, FunctionType) for o in objects) # sanity check
51 | category_to_funcs[category] = objects
52 |
53 | all_funcs = []
54 | for funcs in [array_methods, *category_to_funcs.values()]:
55 | all_funcs.extend(funcs)
56 | name_to_func: Dict[str, FunctionType] = {f.__name__: f for f in all_funcs}
57 |
58 | info_funcs = []
59 | if api_version >= "2023.12":
60 | # The info functions in the stubs are in info.py, but this is not a name
61 | # in the standard.
62 | info_mod = name_to_mod["info"]
63 |
64 | # Note that __array_namespace_info__ is in info.__all__ but it is in the
65 | # top-level namespace, not the info namespace.
66 | info_funcs = [getattr(info_mod, name) for name in info_mod.__all__
67 | if name != '__array_namespace_info__']
68 | assert all(isinstance(f, FunctionType) for f in info_funcs)
69 | name_to_func.update({f.__name__: f for f in info_funcs})
70 |
71 | all_funcs.append(info_mod.__array_namespace_info__)
72 | name_to_func['__array_namespace_info__'] = info_mod.__array_namespace_info__
73 | category_to_funcs['info'] = [info_mod.__array_namespace_info__]
74 |
75 | EXTENSIONS: List[str] = ["linalg"]
76 | if api_version >= "2022.12":
77 | EXTENSIONS.append("fft")
78 | extension_to_funcs: Dict[str, List[FunctionType]] = {}
79 | for ext in EXTENSIONS:
80 | mod = name_to_mod[ext]
81 | objects = [getattr(mod, name) for name in mod.__all__]
82 | assert all(isinstance(o, FunctionType) for o in objects) # sanity check
83 | funcs = []
84 | for func in objects:
85 | if "Alias" in func.__doc__:
86 | funcs.append(name_to_func[func.__name__])
87 | else:
88 | funcs.append(func)
89 | extension_to_funcs[ext] = funcs
90 |
91 | for funcs in extension_to_funcs.values():
92 | for func in funcs:
93 | if func.__name__ not in name_to_func.keys():
94 | name_to_func[func.__name__] = func
95 |
96 | # sanity check public attributes are not empty
97 | for attr in __all__:
98 | assert len(locals()[attr]) != 0, f"{attr} is empty"
99 |
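# Illustrative sketch (not part of the upstream module): how downstream tests
# typically consume these mappings. The names assumed here ("creation", "asarray")
# exist in the vendored stubs for every supported api_version.
if __name__ == "__main__":  # pragma: no cover
    creation_names = [f.__name__ for f in category_to_funcs["creation"]]
    assert "asarray" in creation_names
    assert "asarray" in name_to_func
    assert set(extension_to_funcs) == set(EXTENSIONS)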
--------------------------------------------------------------------------------
/array_api_tests/test_array_object.py:
--------------------------------------------------------------------------------
1 | import cmath
2 | import math
3 | from itertools import product
4 | from typing import List, Sequence, Tuple, Union, get_args
5 |
6 | import pytest
7 | from hypothesis import assume, given, note
8 | from hypothesis import strategies as st
9 |
10 | from . import _array_module as xp
11 | from . import dtype_helpers as dh
12 | from . import hypothesis_helpers as hh
13 | from . import pytest_helpers as ph
14 | from . import shape_helpers as sh
15 | from . import xps
16 | from .typing import DataType, Index, Param, Scalar, ScalarType, Shape
17 |
18 |
19 | def scalar_objects(
20 | dtype: DataType, shape: Shape
21 | ) -> st.SearchStrategy[Union[Scalar, List[Scalar]]]:
22 | """Generates scalars or nested sequences which are valid for xp.asarray()"""
23 | size = math.prod(shape)
24 | return st.lists(hh.from_dtype(dtype), min_size=size, max_size=size).map(
25 | lambda l: sh.reshape(l, shape)
26 | )
27 |
28 |
29 | def normalize_key(key: Index, shape: Shape) -> Tuple[Union[int, slice], ...]:
30 | """
31 | Normalize an indexing key.
32 |
33 | * If a non-tuple index, wrap as a tuple.
34 | * Represent ellipsis as equivalent slices.
35 | """
36 | _key = tuple(key) if isinstance(key, tuple) else (key,)
37 | if Ellipsis in _key:
38 | nonexpanding_key = tuple(i for i in _key if i is not None)
39 | start_a = nonexpanding_key.index(Ellipsis)
40 | stop_a = start_a + (len(shape) - (len(nonexpanding_key) - 1))
41 | slices = tuple(slice(None) for _ in range(start_a, stop_a))
42 | start_pos = _key.index(Ellipsis)
43 | _key = _key[:start_pos] + slices + _key[start_pos + 1 :]
44 | return _key
45 |
46 |
47 | def get_indexed_axes_and_out_shape(
48 | key: Tuple[Union[int, slice, None], ...], shape: Shape
49 | ) -> Tuple[Tuple[Sequence[int], ...], Shape]:
50 | """
51 | From the (normalized) key and input shape, calculates:
52 |
53 | * indexed_axes: For each dimension, the axes which the key indexes.
54 | * out_shape: The resulting shape of indexing an array (of the input shape)
55 | with the key.
56 | """
57 | axes_indices = []
58 | out_shape = []
59 | a = 0
60 | for i in key:
61 | if i is None:
62 | out_shape.append(1)
63 | else:
64 | side = shape[a]
65 | if isinstance(i, int):
66 | if i < 0:
67 | i += side
68 | axes_indices.append((i,))
69 | else:
70 | indices = range(side)[i]
71 | axes_indices.append(indices)
72 | out_shape.append(len(indices))
73 | a += 1
74 | return tuple(axes_indices), tuple(out_shape)
75 |
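# Worked example (illustrative sketch, not executed): for shape=(2, 3, 4) and
# key=(1, slice(None), None, Ellipsis),
#
#   normalize_key(key, shape) == (1, slice(None), None, slice(None))
#   get_indexed_axes_and_out_shape(_key, shape) returns
#       axes_indices == ((1,), range(0, 3), range(0, 4))
#       out_shape    == (3, 1, 4)
#
# which matches x[1, :, None, :] having shape (3, 1, 4) for x of shape (2, 3, 4).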
76 |
77 | @given(shape=hh.shapes(), dtype=hh.all_dtypes, data=st.data())
78 | def test_getitem(shape, dtype, data):
79 | zero_sided = any(side == 0 for side in shape)
80 | if zero_sided:
81 | x = xp.zeros(shape, dtype=dtype)
82 | else:
83 | obj = data.draw(scalar_objects(dtype, shape), label="obj")
84 | x = xp.asarray(obj, dtype=dtype)
85 | note(f"{x=}")
86 | key = data.draw(xps.indices(shape=shape, allow_newaxis=True), label="key")
87 |
88 | out = x[key]
89 |
90 | ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype)
91 | _key = normalize_key(key, shape)
92 | axes_indices, expected_shape = get_indexed_axes_and_out_shape(_key, shape)
93 | ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape)
94 | out_zero_sided = any(side == 0 for side in expected_shape)
95 | if not zero_sided and not out_zero_sided:
96 | out_obj = []
97 | for idx in product(*axes_indices):
98 | val = obj
99 | for i in idx:
100 | val = val[i]
101 | out_obj.append(val)
102 | out_obj = sh.reshape(out_obj, expected_shape)
103 | expected = xp.asarray(out_obj, dtype=dtype)
104 | ph.assert_array_elements("__getitem__", out=out, expected=expected)
105 |
106 |
107 | @pytest.mark.unvectorized
108 | @given(
109 | shape=hh.shapes(),
110 | dtypes=hh.oneway_promotable_dtypes(dh.all_dtypes),
111 | data=st.data(),
112 | )
113 | def test_setitem(shape, dtypes, data):
114 | zero_sided = any(side == 0 for side in shape)
115 | if zero_sided:
116 | x = xp.zeros(shape, dtype=dtypes.result_dtype)
117 | else:
118 | obj = data.draw(scalar_objects(dtypes.result_dtype, shape), label="obj")
119 | x = xp.asarray(obj, dtype=dtypes.result_dtype)
120 | note(f"{x=}")
121 | key = data.draw(xps.indices(shape=shape), label="key")
122 | _key = normalize_key(key, shape)
123 | axes_indices, out_shape = get_indexed_axes_and_out_shape(_key, shape)
124 | value_strat = hh.arrays(dtype=dtypes.result_dtype, shape=out_shape)
125 | if out_shape == ():
126 | # We can pass scalars if we're only indexing one element
127 | value_strat |= hh.from_dtype(dtypes.result_dtype)
128 | value = data.draw(value_strat, label="value")
129 |
130 | res = xp.asarray(x, copy=True)
131 | res[key] = value
132 |
133 | ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype")
134 | ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.shape")
135 | f_res = sh.fmt_idx("x", key)
136 | if isinstance(value, get_args(Scalar)):
137 | msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]"
138 | if cmath.isnan(value):
139 | assert xp.isnan(res[key]), msg
140 | else:
141 | assert res[key] == value, msg
142 | else:
143 | ph.assert_array_elements("__setitem__", out=res[key], expected=value, out_repr=f_res)
144 | unaffected_indices = set(sh.ndindex(res.shape)) - set(product(*axes_indices))
145 | for idx in unaffected_indices:
146 | ph.assert_0d_equals(
147 | "__setitem__",
148 | x_repr=f"old {f_res}",
149 | x_val=x[idx],
150 | out_repr=f"modified {f_res}",
151 | out_val=res[idx],
152 | )
153 |
154 |
155 | @pytest.mark.unvectorized
156 | @pytest.mark.data_dependent_shapes
157 | @given(hh.shapes(), st.data())
158 | def test_getitem_masking(shape, data):
159 | x = data.draw(hh.arrays(hh.all_dtypes, shape=shape), label="x")
160 | mask_shapes = st.one_of(
161 | st.sampled_from([x.shape, ()]),
162 | st.lists(st.booleans(), min_size=x.ndim, max_size=x.ndim).map(
163 | lambda l: tuple(s if b else 0 for s, b in zip(x.shape, l))
164 | ),
165 | hh.shapes(),
166 | )
167 | key = data.draw(hh.arrays(dtype=xp.bool, shape=mask_shapes), label="key")
168 |
169 | if key.ndim > x.ndim or not all(
170 | ks in (xs, 0) for xs, ks in zip(x.shape, key.shape)
171 | ):
172 | with pytest.raises(IndexError):
173 | x[key]
174 | return
175 |
176 | out = x[key]
177 |
178 | ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype)
179 | if key.ndim == 0:
180 | expected_shape = (1,) if key else (0,)
181 | expected_shape += x.shape
182 | else:
183 | size = int(xp.sum(xp.astype(key, xp.uint8)))
184 | expected_shape = (size,) + x.shape[key.ndim :]
185 | ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape)
186 | if not any(s == 0 for s in key.shape):
187 | assume(key.ndim == x.ndim) # TODO: test key.ndim < x.ndim scenarios
188 | out_indices = sh.ndindex(out.shape)
189 | for x_idx in sh.ndindex(x.shape):
190 | if key[x_idx]:
191 | out_idx = next(out_indices)
192 | ph.assert_0d_equals(
193 | "__getitem__",
194 | x_repr=f"x[{x_idx}]",
195 | x_val=x[x_idx],
196 | out_repr=f"out[{out_idx}]",
197 | out_val=out[out_idx],
198 | )
199 |
200 |
201 | @pytest.mark.unvectorized
202 | @given(hh.shapes(), st.data())
203 | def test_setitem_masking(shape, data):
204 | x = data.draw(hh.arrays(hh.all_dtypes, shape=shape), label="x")
205 | key = data.draw(hh.arrays(dtype=xp.bool, shape=shape), label="key")
206 | value = data.draw(
207 | hh.from_dtype(x.dtype) | hh.arrays(dtype=x.dtype, shape=()), label="value"
208 | )
209 |
210 | res = xp.asarray(x, copy=True)
211 | res[key] = value
212 |
213 | ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype")
214 | ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.shape")
215 | scalar_type = dh.get_scalar_type(x.dtype)
216 | for idx in sh.ndindex(x.shape):
217 | if key[idx]:
218 | if isinstance(value, get_args(Scalar)):
219 | ph.assert_scalar_equals(
220 | "__setitem__",
221 | type_=scalar_type,
222 | idx=idx,
223 | out=scalar_type(res[idx]),
224 | expected=value,
225 | repr_name="modified x",
226 | )
227 | else:
228 | ph.assert_0d_equals(
229 | "__setitem__",
230 | x_repr="value",
231 | x_val=value,
232 | out_repr=f"modified x[{idx}]",
233 | out_val=res[idx]
234 | )
235 | else:
236 | ph.assert_0d_equals(
237 | "__setitem__",
238 | x_repr=f"old x[{idx}]",
239 | x_val=x[idx],
240 | out_repr=f"modified x[{idx}]",
241 | out_val=res[idx]
242 | )
243 |
244 |
245 | # ### Fancy indexing ###
246 |
247 | @pytest.mark.min_version("2024.12")
248 | @pytest.mark.unvectorized
249 | @pytest.mark.parametrize("idx_max_dims", [1, None])
250 | @given(shape=hh.shapes(min_dims=2), data=st.data())
251 | def test_getitem_arrays_and_ints_1(shape, data, idx_max_dims):
252 | # min_dims=2 : test multidim `x` arrays
253 | # index arrays are 1D for idx_max_dims=1 and multidim for idx_max_dims=None
254 | _test_getitem_arrays_and_ints(shape, data, idx_max_dims)
255 |
256 |
257 | @pytest.mark.min_version("2024.12")
258 | @pytest.mark.unvectorized
259 | @pytest.mark.parametrize("idx_max_dims", [1, None])
260 | @given(shape=hh.shapes(min_dims=1), data=st.data())
261 | def test_getitem_arrays_and_ints_2(shape, data, idx_max_dims):
262 | # min_dims=1 : favor 1D `x` arrays
263 | # index arrays are 1D for idx_max_dims=1 and multidim for idx_max_dims=None
264 | _test_getitem_arrays_and_ints(shape, data, idx_max_dims)
265 |
266 |
267 | def _test_getitem_arrays_and_ints(shape, data, idx_max_dims):
268 | assume((len(shape) > 0) and all(side > 0 for side in shape))
269 |
270 | dtype = xp.int32
271 | obj = data.draw(scalar_objects(dtype, shape), label="obj")
272 | x = xp.asarray(obj, dtype=dtype)
273 |
274 | # draw a mix of ints and index arrays
275 | arr_index = [data.draw(st.booleans()) for _ in range(len(shape))]
276 | assume(sum(arr_index) > 0)
277 |
278 | # draw shapes for index arrays: max_dims=1 ==> 1D indexing arrays ONLY
279 | # max_dims=None ==> multidim indexing arrays
280 | if sum(arr_index) > 0:
281 | index_shapes = data.draw(
282 | hh.mutually_broadcastable_shapes(
283 | sum(arr_index), min_dims=1, max_dims=idx_max_dims, min_side=1
284 | )
285 | )
286 | index_shapes = list(index_shapes)
287 |
288 | # prepare the indexing tuple, a mix of integer indices and index arrays
289 | key = []
290 | for i,typ in enumerate(arr_index):
291 | if typ:
292 | # draw an array index
293 | this_idx = data.draw(
294 | xps.arrays(
295 | dtype,
296 | shape=index_shapes.pop(),
297 | elements=st.integers(0, shape[i]-1)
298 | )
299 | )
300 | key.append(this_idx)
301 |
302 | else:
303 | # draw an integer
304 | key.append(data.draw(st.integers(-shape[i], shape[i]-1)))
305 |
306 | key = tuple(key)
307 | out = x[key]
308 |
309 | arrays = [xp.asarray(k) for k in key]
310 | bcast_shape = sh.broadcast_shapes(*[arr.shape for arr in arrays])
311 | bcast_key = [xp.broadcast_to(arr, bcast_shape) for arr in arrays]
312 |
313 | for idx in sh.ndindex(bcast_shape):
314 | tpl = tuple(k[idx] for k in bcast_key)
315 | assert out[idx] == x[tpl], f"failing at {idx = } w/ {key = }"
316 |
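# Worked example (illustrative sketch): for x of shape (3, 4) and
# key == (xp.asarray([0, 2]), 1), the key entries broadcast to shape (2,), so the
# loop above checks out[0] == x[0, 1] and out[1] == x[2, 1].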
317 |
318 | def make_scalar_casting_param(
319 | method_name: str, dtype: DataType, stype: ScalarType
320 | ) -> Param:
321 | dtype_name = dh.dtype_to_name[dtype]
322 | return pytest.param(
323 | method_name, dtype, stype, id=f"{method_name}({dtype_name})"
324 | )
325 |
326 |
327 | @pytest.mark.parametrize(
328 | "method_name, dtype, stype",
329 | [make_scalar_casting_param("__bool__", xp.bool, bool)]
330 | + [make_scalar_casting_param("__int__", n, int) for n in dh.all_int_dtypes]
331 | + [make_scalar_casting_param("__index__", n, int) for n in dh.all_int_dtypes]
332 | + [make_scalar_casting_param("__float__", n, float) for n in dh.real_float_dtypes],
333 | )
334 | @given(data=st.data())
335 | def test_scalar_casting(method_name, dtype, stype, data):
336 | x = data.draw(hh.arrays(dtype, shape=()), label="x")
337 | method = getattr(x, method_name)
338 | out = method()
339 | assert isinstance(
340 | out, stype
341 | ), f"{method_name}({x})={out}, which is not a {stype.__name__} scalar"
342 |
--------------------------------------------------------------------------------
/array_api_tests/test_constants.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Any, SupportsFloat
3 |
4 | import pytest
5 |
6 | from . import dtype_helpers as dh
7 | from . import xp
8 | from .typing import Array
9 |
10 |
11 | def assert_scalar_float(name: str, c: Any):
12 | assert isinstance(c, SupportsFloat), f"{name}={c!r} does not look like a float"
13 |
14 |
15 | def assert_0d_float(name: str, x: Array):
16 | assert dh.is_float_dtype(
17 | x.dtype
18 | ), f"xp.asarray(xp.{name})={x!r}, but should have float dtype"
19 |
20 |
21 | @pytest.mark.parametrize("name, n", [("e", math.e), ("pi", math.pi)])
22 | def test_irrational_numbers(name, n):
23 | assert hasattr(xp, name)
24 | c = getattr(xp, name)
25 | assert_scalar_float(name, c)
26 | floor = math.floor(n)
27 | assert c > floor, f"xp.{name}={c!r} <= {floor}"
28 | ceil = math.ceil(n)
29 | assert c < ceil, f"xp.{name}={c!r} >= {ceil}"
30 | x = xp.asarray(c)
31 | assert_0d_float("name", x)
32 |
33 |
34 | def test_inf():
35 | assert hasattr(xp, "inf")
36 | assert_scalar_float("inf", xp.inf)
37 | assert math.isinf(xp.inf)
38 | assert xp.inf > 0, "xp.inf not greater than 0"
39 | x = xp.asarray(xp.inf)
40 | assert_0d_float("inf", x)
41 | assert xp.isinf(x), "xp.isinf(xp.asarray(xp.inf))=False"
42 |
43 |
44 | def test_nan():
45 | assert hasattr(xp, "nan")
46 | assert_scalar_float("nan", xp.nan)
47 | assert math.isnan(xp.nan)
48 | assert xp.nan != xp.nan, "xp.nan should not have equality with itself"
49 | x = xp.asarray(xp.nan)
50 | assert_0d_float("nan", x)
51 | assert xp.isnan(x), "xp.isnan(xp.asarray(xp.nan))=False"
52 |
53 |
54 | def test_newaxis():
55 | assert hasattr(xp, "newaxis")
56 | assert xp.newaxis is None
57 |
--------------------------------------------------------------------------------
/array_api_tests/test_data_type_functions.py:
--------------------------------------------------------------------------------
1 | import struct
2 | from typing import Union
3 |
4 | import pytest
5 | from hypothesis import given, assume
6 | from hypothesis import strategies as st
7 |
8 | from . import _array_module as xp
9 | from . import dtype_helpers as dh
10 | from . import hypothesis_helpers as hh
11 | from . import pytest_helpers as ph
12 | from . import shape_helpers as sh
13 | from . import xps
14 | from .typing import DataType
15 |
16 |
17 | # TODO: test with complex dtypes
18 | def non_complex_dtypes():
19 | return xps.boolean_dtypes() | hh.real_dtypes
20 |
21 |
22 | def float32(n: Union[int, float]) -> float:
23 | return struct.unpack("!f", struct.pack("!f", float(n)))[0]
24 |
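# Illustrative sketch: float32() round-trips a Python float through a big-endian
# 32-bit packing, i.e. it returns the nearest binary32 value as a Python float.
# For example (values verifiable with struct in any CPython):
#
#   float32(0.1)       == 0.10000000149011612   # 0.1 is not exact in binary32
#   float32(2**24 + 1) == 16777216.0            # integers above 2**24 lose precision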
25 |
26 | def _float_match_complex(complex_dtype):
27 | if complex_dtype == xp.complex64:
28 | return xp.float32
29 | elif complex_dtype == xp.complex128:
30 | return xp.float64
31 | else:
32 | return dh.default_float
33 |
34 |
35 | @given(
36 | x_dtype=hh.all_dtypes,
37 | dtype=hh.all_dtypes,
38 | kw=hh.kwargs(copy=st.booleans()),
39 | data=st.data(),
40 | )
41 | def test_astype(x_dtype, dtype, kw, data):
42 | _complex_dtypes = (xp.complex64, xp.complex128)
43 |
44 | if xp.bool in (x_dtype, dtype):
45 | elements_strat = hh.from_dtype(x_dtype)
46 | else:
47 |
48 | if dh.is_int_dtype(x_dtype):
49 | cast = int
50 | elif x_dtype in (xp.float32, xp.complex64):
51 | cast = float32
52 | else:
53 | cast = float
54 |
55 | real_dtype = x_dtype
56 | if x_dtype in _complex_dtypes:
57 | real_dtype = _float_match_complex(x_dtype)
58 | m1, M1 = dh.dtype_ranges[real_dtype]
59 |
60 | real_dtype = dtype
61 | if dtype in _complex_dtypes:
62 | real_dtype = _float_match_complex(dtype)
63 | m2, M2 = dh.dtype_ranges[real_dtype]
64 |
65 | min_value = cast(max(m1, m2))
66 | max_value = cast(min(M1, M2))
67 |
68 | elements_strat = hh.from_dtype(
69 | x_dtype,
70 | min_value=min_value,
71 | max_value=max_value,
72 | allow_nan=False,
73 | allow_infinity=False,
74 | )
75 | x = data.draw(
76 | hh.arrays(dtype=x_dtype, shape=hh.shapes(), elements=elements_strat), label="x"
77 | )
78 |
79 | # according to the spec, "Casting a complex floating-point array to a real-valued
80 | # data type should not be permitted."
81 | # https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html#astype
82 | assume(not ((x_dtype in _complex_dtypes) and (dtype not in _complex_dtypes)))
83 |
84 | out = xp.astype(x, dtype, **kw)
85 |
86 | ph.assert_kw_dtype("astype", kw_dtype=dtype, out_dtype=out.dtype)
87 | ph.assert_shape("astype", out_shape=out.shape, expected=x.shape, kw=kw)
88 | # TODO: test values
89 | # TODO: test copy
90 |
91 |
92 | @given(
93 | shapes=st.integers(1, 5).flatmap(hh.mutually_broadcastable_shapes), data=st.data()
94 | )
95 | def test_broadcast_arrays(shapes, data):
96 | arrays = []
97 | for c, shape in enumerate(shapes, 1):
98 | x = data.draw(hh.arrays(dtype=hh.all_dtypes, shape=shape), label=f"x{c}")
99 | arrays.append(x)
100 |
101 | out = xp.broadcast_arrays(*arrays)
102 |
103 | expected_shape = sh.broadcast_shapes(*shapes)
104 | for i, x in enumerate(arrays):
105 | ph.assert_dtype(
106 | "broadcast_arrays",
107 | in_dtype=x.dtype,
108 | out_dtype=out[i].dtype,
109 | repr_name=f"out[{i}].dtype"
110 | )
111 | ph.assert_result_shape(
112 | "broadcast_arrays",
113 | in_shapes=shapes,
114 | out_shape=out[i].shape,
115 | expected=expected_shape,
116 | repr_name=f"out[{i}].shape",
117 | )
118 | # TODO: test values
119 |
120 |
121 | @given(x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()), data=st.data())
122 | def test_broadcast_to(x, data):
123 | shape = data.draw(
124 | hh.mutually_broadcastable_shapes(1, base_shape=x.shape)
125 | .map(lambda S: S[0])
126 | .filter(lambda s: sh.broadcast_shapes(x.shape, s) == s),
127 | label="shape",
128 | )
129 |
130 | out = xp.broadcast_to(x, shape)
131 |
132 | ph.assert_dtype("broadcast_to", in_dtype=x.dtype, out_dtype=out.dtype)
133 | ph.assert_shape("broadcast_to", out_shape=out.shape, expected=shape)
134 | # TODO: test values
135 |
136 |
137 | @given(_from=hh.all_dtypes, to=hh.all_dtypes)
138 | def test_can_cast(_from, to):
139 | out = xp.can_cast(_from, to)
140 |
141 | expected = False
142 | for other in dh.all_dtypes:
143 | if dh.promotion_table.get((_from, other)) == to:
144 | expected = True
145 | break
146 |
147 | f_func = f"[can_cast({dh.dtype_to_name[_from]}, {dh.dtype_to_name[to]})]"
148 | if expected:
149 | # cross-kind casting is not explicitly disallowed. We can only test
150 | # the cases where it should return True. TODO: if expected=False,
151 | # check that the array library actually allows such casts.
152 | assert out == expected, f"{out=}, but should be {expected} {f_func}"
153 |
154 |
155 | @pytest.mark.parametrize("dtype", dh.real_float_dtypes + dh.complex_dtypes)
156 | def test_finfo(dtype):
157 | for arg in (
158 | dtype,
159 | xp.asarray(1, dtype=dtype),
160 | # np.float64 and np.asarray(1, dtype=np.float64).dtype are different
161 | xp.asarray(1, dtype=dtype).dtype,
162 | ):
163 | out = xp.finfo(arg)
164 | assert isinstance(out.bits, int)
165 | assert isinstance(out.eps, float)
166 | assert isinstance(out.max, float)
167 | assert isinstance(out.min, float)
168 | assert isinstance(out.smallest_normal, float)
169 |
170 |
171 | @pytest.mark.min_version("2022.12")
172 | @pytest.mark.parametrize("dtype", dh.real_float_dtypes + dh.complex_dtypes)
173 | def test_finfo_dtype(dtype):
174 | out = xp.finfo(dtype)
175 |
176 | if dtype == xp.complex64:
177 | assert out.dtype == xp.float32
178 | elif dtype == xp.complex128:
179 | assert out.dtype == xp.float64
180 | else:
181 | assert out.dtype == dtype
182 |
183 | # Guard vs. numpy.dtype.__eq__ lax comparison
184 | assert not isinstance(out.dtype, str)
185 | assert out.dtype is not float
186 | assert out.dtype is not complex
187 |
188 |
189 | @pytest.mark.parametrize("dtype", dh.int_dtypes + dh.uint_dtypes)
190 | def test_iinfo(dtype):
191 | for arg in (
192 | dtype,
193 | xp.asarray(1, dtype=dtype),
194 | # np.int64 and np.asarray(1, dtype=np.int64).dtype are different
195 | xp.asarray(1, dtype=dtype).dtype,
196 | ):
197 | out = xp.iinfo(arg)
198 | assert isinstance(out.bits, int)
199 | assert isinstance(out.max, int)
200 | assert isinstance(out.min, int)
201 |
202 |
203 | @pytest.mark.min_version("2022.12")
204 | @pytest.mark.parametrize("dtype", dh.int_dtypes + dh.uint_dtypes)
205 | def test_iinfo_dtype(dtype):
206 | out = xp.iinfo(dtype)
207 | assert out.dtype == dtype
208 | # Guard vs. numpy.dtype.__eq__ lax comparison
209 | assert not isinstance(out.dtype, str)
210 | assert out.dtype is not int
211 |
212 |
213 | def atomic_kinds() -> st.SearchStrategy[Union[DataType, str]]:
214 | return hh.all_dtypes | st.sampled_from(list(dh.kind_to_dtypes.keys()))
215 |
216 |
217 | @pytest.mark.min_version("2022.12")
218 | @given(
219 | dtype=hh.all_dtypes,
220 | kind=atomic_kinds() | st.lists(atomic_kinds(), min_size=1).map(tuple),
221 | )
222 | def test_isdtype(dtype, kind):
223 | out = xp.isdtype(dtype, kind)
224 |
225 | assert isinstance(out, bool), f"{type(out)=}, but should be bool [isdtype()]"
226 | _kinds = kind if isinstance(kind, tuple) else (kind,)
227 | expected = False
228 | for _kind in _kinds:
229 | if isinstance(_kind, str):
230 | if dtype in dh.kind_to_dtypes[_kind]:
231 | expected = True
232 | break
233 | else:
234 | if dtype == _kind:
235 | expected = True
236 | break
237 | assert out == expected, f"{out=}, but should be {expected} [isdtype()]"
238 |
239 |
240 | @pytest.mark.min_version("2024.12")
241 | class TestResultType:
242 | @given(dtypes=hh.mutually_promotable_dtypes(None))
243 | def test_result_type(self, dtypes):
244 | out = xp.result_type(*dtypes)
245 | ph.assert_dtype("result_type", in_dtype=dtypes, out_dtype=out, repr_name="out")
246 |
247 | @given(pair=hh.pair_of_mutually_promotable_dtypes(None))
248 | def test_shuffled(self, pair):
249 | """Test that result_type is insensitive to the order of arguments."""
250 | s1, s2 = pair
251 | out1 = xp.result_type(*s1)
252 | out2 = xp.result_type(*s2)
253 | assert out1 == out2
254 |
255 | @given(pair=hh.pair_of_mutually_promotable_dtypes(2), data=st.data())
256 | def test_arrays_and_dtypes(self, pair, data):
257 | s1, s2 = pair
258 | a2 = tuple(xp.empty(1, dtype=dt) for dt in s2)
259 | a_and_dt = data.draw(st.permutations(s1 + a2))
260 | out = xp.result_type(*a_and_dt)
261 | ph.assert_dtype("result_type", in_dtype=s1+s2, out_dtype=out, repr_name="out")
262 |
263 | @given(dtypes=hh.mutually_promotable_dtypes(2), data=st.data())
264 | def test_with_scalars(self, dtypes, data):
265 | out = xp.result_type(*dtypes)
266 |
267 | if out == xp.bool:
268 | scalars = [True]
269 | elif out in dh.all_int_dtypes:
270 | scalars = [1]
271 | elif out in dh.real_dtypes:
272 | scalars = [1, 1.0]
273 | elif out in dh.numeric_dtypes:
274 | scalars = [1, 1.0, 1j] # numeric_types - real_types == complex_types
275 | else:
276 | raise ValueError(f"unknown dtype {out = }.")
277 |
278 | scalar = data.draw(st.sampled_from(scalars))
279 | inputs = data.draw(st.permutations(dtypes + (scalar,)))
280 |
281 | out_scalar = xp.result_type(*inputs)
282 | assert out_scalar == out
283 |
284 | # retry with arrays
285 | arrays = tuple(xp.empty(1, dtype=dt) for dt in dtypes)
286 | inputs = data.draw(st.permutations(arrays + (scalar,)))
287 | out_scalar = xp.result_type(*inputs)
288 | assert out_scalar == out
289 |
290 |
--------------------------------------------------------------------------------
/array_api_tests/test_fft.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List, Optional
3 |
4 | import pytest
5 | from hypothesis import assume, given
6 | from hypothesis import strategies as st
7 |
8 | from array_api_tests.typing import Array
9 |
10 | from . import dtype_helpers as dh
11 | from . import hypothesis_helpers as hh
12 | from . import pytest_helpers as ph
13 | from . import shape_helpers as sh
14 | from . import xp
15 |
16 | pytestmark = [
17 | pytest.mark.xp_extension("fft"),
18 | pytest.mark.min_version("2022.12"),
19 | ]
20 |
21 | fft_shapes_strat = hh.shapes(min_dims=1).filter(lambda s: math.prod(s) > 1)
22 |
23 |
24 | def draw_n_axis_norm_kwargs(x: Array, data: st.DataObject, *, size_gt_1=False) -> tuple:
25 | size = math.prod(x.shape)
26 | n = data.draw(
27 | st.none() | st.integers((size // 2), math.ceil(size * 1.5)), label="n"
28 | )
29 | axis = data.draw(st.integers(-1, x.ndim - 1), label="axis")
30 | if size_gt_1:
31 | _axis = x.ndim - 1 if axis == -1 else axis
32 | assume(x.shape[_axis] > 1)
33 | norm = data.draw(st.sampled_from(["backward", "ortho", "forward"]), label="norm")
34 | kwargs = data.draw(
35 | hh.specified_kwargs(
36 | ("n", n, None),
37 | ("axis", axis, -1),
38 | ("norm", norm, "backward"),
39 | ),
40 | label="kwargs",
41 | )
42 | return n, axis, norm, kwargs
43 |
44 |
45 | def draw_s_axes_norm_kwargs(x: Array, data: st.DataObject, *, size_gt_1=False) -> tuple:
46 | all_axes = list(range(x.ndim))
47 | axes = data.draw(
48 | st.none() | st.lists(st.sampled_from(all_axes), min_size=1, unique=True),
49 | label="axes",
50 | )
51 | _axes = all_axes if axes is None else axes
52 | axes_sides = [x.shape[axis] for axis in _axes]
53 | s_strat = st.tuples(
54 | *[st.integers(max(side // 2, 1), math.ceil(side * 1.5)) for side in axes_sides]
55 | )
56 | if axes is None:
57 | s_strat = st.none() | s_strat
58 | s = data.draw(s_strat, label="s")
59 |
60 | # Passing `s` without `axes` (i.e. `axes is None and s is not None`) is disallowed by the spec
61 | assume(axes is not None or s is None)
62 |
63 | norm = data.draw(st.sampled_from(["backward", "ortho", "forward"]), label="norm")
64 | kwargs = data.draw(
65 | hh.specified_kwargs(
66 | ("s", s, None),
67 | ("axes", axes, None),
68 | ("norm", norm, "backward"),
69 | ),
70 | label="kwargs",
71 | )
72 | return s, axes, norm, kwargs
73 |
74 |
75 | def assert_n_axis_shape(
76 | func_name: str,
77 | *,
78 | x: Array,
79 | n: Optional[int],
80 | axis: int,
81 | out: Array,
82 | ):
83 | _axis = len(x.shape) - 1 if axis == -1 else axis
84 | if n is None:
85 | axis_side = x.shape[_axis]
86 | else:
87 | axis_side = n
88 | expected = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :]
89 | ph.assert_shape(func_name, out_shape=out.shape, expected=expected)
90 |
91 |
92 | def assert_s_axes_shape(
93 | func_name: str,
94 | *,
95 | x: Array,
96 | s: Optional[List[int]],
97 | axes: Optional[List[int]],
98 | out: Array,
99 | ):
100 | _axes = sh.normalize_axis(axes, x.ndim)
101 | _s = x.shape if s is None else s
102 | expected = []
103 | for i in range(x.ndim):
104 | if i in _axes:
105 | side = _s[_axes.index(i)]
106 | else:
107 | side = x.shape[i]
108 | expected.append(side)
109 | ph.assert_shape(func_name, out_shape=out.shape, expected=tuple(expected))
110 |
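# Worked example (illustrative sketch): for x.shape == (4, 5, 6), s == [8, 3] and
# axes == [0, 2], the loop above builds expected == (8, 5, 3) -- transformed axes
# take their size from s, untouched axes keep the input side.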
111 |
112 | @given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
113 | def test_fft(x, data):
114 | n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)
115 |
116 | out = xp.fft.fft(x, **kwargs)
117 |
118 | ph.assert_dtype("fft", in_dtype=x.dtype, out_dtype=out.dtype)
119 | assert_n_axis_shape("fft", x=x, n=n, axis=axis, out=out)
120 |
121 |
122 | @given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
123 | def test_ifft(x, data):
124 | n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)
125 |
126 | out = xp.fft.ifft(x, **kwargs)
127 |
128 | ph.assert_dtype("ifft", in_dtype=x.dtype, out_dtype=out.dtype)
129 | assert_n_axis_shape("ifft", x=x, n=n, axis=axis, out=out)
130 |
131 |
132 | @given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
133 | def test_fftn(x, data):
134 | s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)
135 |
136 | out = xp.fft.fftn(x, **kwargs)
137 |
138 | ph.assert_dtype("fftn", in_dtype=x.dtype, out_dtype=out.dtype)
139 | assert_s_axes_shape("fftn", x=x, s=s, axes=axes, out=out)
140 |
141 |
142 | @given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
143 | def test_ifftn(x, data):
144 | s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)
145 |
146 | out = xp.fft.ifftn(x, **kwargs)
147 |
148 | ph.assert_dtype("ifftn", in_dtype=x.dtype, out_dtype=out.dtype)
149 | assert_s_axes_shape("ifftn", x=x, s=s, axes=axes, out=out)
150 |
151 |
152 | @given(x=hh.arrays(dtype=hh.real_floating_dtypes, shape=fft_shapes_strat), data=st.data())
153 | def test_rfft(x, data):
154 | n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)
155 |
156 | out = xp.fft.rfft(x, **kwargs)
157 |
158 | ph.assert_float_to_complex_dtype("rfft", in_dtype=x.dtype, out_dtype=out.dtype)
159 |
160 | _axis = x.ndim - 1 if axis == -1 else axis
161 | if n is None:
162 | axis_side = x.shape[_axis] // 2 + 1
163 | else:
164 | axis_side = n // 2 + 1
165 | expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :]
166 | ph.assert_shape("rfft", out_shape=out.shape, expected=expected_shape)
167 |
168 |
169 | @given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
170 | def test_irfft(x, data):
171 | n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True)
172 |
173 | out = xp.fft.irfft(x, **kwargs)
174 |
175 | ph.assert_dtype(
176 | "irfft",
177 | in_dtype=x.dtype,
178 | out_dtype=out.dtype,
179 | expected=dh.dtype_components[x.dtype],
180 | )
181 |
182 | _axis = x.ndim - 1 if axis == -1 else axis
183 | if n is None:
184 | axis_side = 2 * (x.shape[_axis] - 1)
185 | else:
186 | axis_side = n
187 | expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :]
188 | ph.assert_shape("irfft", out_shape=out.shape, expected=expected_shape)
189 |
190 |
191 | @given(x=hh.arrays(dtype=hh.real_floating_dtypes, shape=fft_shapes_strat), data=st.data())
192 | def test_rfftn(x, data):
193 | s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)
194 |
195 | out = xp.fft.rfftn(x, **kwargs)
196 |
197 | ph.assert_float_to_complex_dtype("rfftn", in_dtype=x.dtype, out_dtype=out.dtype)
198 |
199 | _axes = sh.normalize_axis(axes, x.ndim)
200 | _s = x.shape if s is None else s
201 | expected = []
202 | for i in range(x.ndim):
203 | if i in _axes:
204 | side = _s[_axes.index(i)]
205 | else:
206 | side = x.shape[i]
207 | expected.append(side)
208 | expected[_axes[-1]] = _s[-1] // 2 + 1
209 | ph.assert_shape("rfftn", out_shape=out.shape, expected=tuple(expected))
210 |
211 |
212 | @given(
213 | x=hh.arrays(
214 | dtype=hh.complex_dtypes, shape=fft_shapes_strat.filter(lambda s: s[-1] > 1)
215 | ),
216 | data=st.data(),
217 | )
218 | def test_irfftn(x, data):
219 | s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)
220 |
221 | out = xp.fft.irfftn(x, **kwargs)
222 |
223 | ph.assert_dtype(
224 | "irfftn",
225 | in_dtype=x.dtype,
226 | out_dtype=out.dtype,
227 | expected=dh.dtype_components[x.dtype],
228 | )
229 |
230 | _axes = sh.normalize_axis(axes, x.ndim)
231 | _s = x.shape if s is None else s
232 | expected = []
233 | for i in range(x.ndim):
234 | if i in _axes:
235 | side = _s[_axes.index(i)]
236 | else:
237 | side = x.shape[i]
238 | expected.append(side)
239 | expected[_axes[-1]] = 2*(_s[-1] - 1) if s is None else _s[-1]
240 | ph.assert_shape("irfftn", out_shape=out.shape, expected=tuple(expected))
241 |
242 |
243 | @given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
244 | def test_hfft(x, data):
245 | n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True)
246 |
247 | out = xp.fft.hfft(x, **kwargs)
248 |
249 | ph.assert_dtype(
250 | "hfft",
251 | in_dtype=x.dtype,
252 | out_dtype=out.dtype,
253 | expected=dh.dtype_components[x.dtype],
254 | )
255 |
256 | _axis = x.ndim - 1 if axis == -1 else axis
257 | if n is None:
258 | axis_side = 2 * (x.shape[_axis] - 1)
259 | else:
260 | axis_side = n
261 | expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :]
262 | ph.assert_shape("hfft", out_shape=out.shape, expected=expected_shape)
263 |
264 |
265 | @given(x=hh.arrays(dtype=hh.real_floating_dtypes, shape=fft_shapes_strat), data=st.data())
266 | def test_ihfft(x, data):
267 | n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)
268 |
269 | out = xp.fft.ihfft(x, **kwargs)
270 |
271 | ph.assert_float_to_complex_dtype("ihfft", in_dtype=x.dtype, out_dtype=out.dtype)
272 |
273 | _axis = x.ndim - 1 if axis == -1 else axis
274 | if n is None:
275 | axis_side = x.shape[_axis] // 2 + 1
276 | else:
277 | axis_side = n // 2 + 1
278 | expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :]
279 | ph.assert_shape("ihfft", out_shape=out.shape, expected=expected_shape)
280 |
281 |
282 | @given(n=st.integers(1, 100), kw=hh.kwargs(d=st.floats(0.1, 5)))
283 | def test_fftfreq(n, kw):
284 | out = xp.fft.fftfreq(n, **kw)
285 | ph.assert_shape("fftfreq", out_shape=out.shape, expected=(n,), kw={"n": n})
286 |
287 |
288 | @given(n=st.integers(1, 100), kw=hh.kwargs(d=st.floats(0.1, 5)))
289 | def test_rfftfreq(n, kw):
290 | out = xp.fft.rfftfreq(n, **kw)
291 | ph.assert_shape(
292 | "rfftfreq", out_shape=out.shape, expected=(n // 2 + 1,), kw={"n": n}
293 | )
294 |
295 |
296 | @pytest.mark.parametrize("func_name", ["fftshift", "ifftshift"])
297 | @given(x=hh.arrays(hh.floating_dtypes, fft_shapes_strat), data=st.data())
298 | def test_shift_func(func_name, x, data):
299 | func = getattr(xp.fft, func_name)
300 | axes = data.draw(
301 | st.none()
302 | | st.lists(st.sampled_from(list(range(x.ndim))), min_size=1, unique=True),
303 | label="axes",
304 | )
305 | out = func(x, axes=axes)
306 | ph.assert_dtype(func_name, in_dtype=x.dtype, out_dtype=out.dtype)
307 | ph.assert_shape(func_name, out_shape=out.shape, expected=x.shape)
308 |
--------------------------------------------------------------------------------
/array_api_tests/test_has_names.py:
--------------------------------------------------------------------------------
1 | """
2 | This is a very basic test to see what names are defined in a library. It
3 | does not even require functioning hypothesis array_api support.
4 | """
5 |
6 | import pytest
7 |
8 | from . import xp
9 | from .stubs import (array_attributes, array_methods, category_to_funcs,
10 | extension_to_funcs, EXTENSIONS)
11 |
12 | has_name_params = []
13 | for ext, stubs in extension_to_funcs.items():
14 | for stub in stubs:
15 | has_name_params.append(pytest.param(ext, stub.__name__))
16 | for cat, stubs in category_to_funcs.items():
17 | for stub in stubs:
18 | has_name_params.append(pytest.param(cat, stub.__name__))
19 | for meth in array_methods:
20 | has_name_params.append(pytest.param('array_method', meth.__name__))
21 | for attr in array_attributes:
22 | has_name_params.append(pytest.param('array_attribute', attr))
23 |
24 | @pytest.mark.parametrize("category, name", has_name_params)
25 | def test_has_names(category, name):
26 | if category in EXTENSIONS:
27 | ext_mod = getattr(xp, category)
28 | assert hasattr(ext_mod, name), f"{xp.__name__} is missing the {category} extension function {name}()"
29 | elif category.startswith('array_'):
30 | # TODO: This would fail if ones() is missing.
31 | arr = xp.ones((1, 1))
32 | if category == 'array_attribute':
33 | assert hasattr(arr, name), f"The {xp.__name__} array object is missing the attribute {name}"
34 | else:
35 | assert hasattr(arr, name), f"The {xp.__name__} array object is missing the method {name}()"
36 | else:
37 | assert hasattr(xp, name), f"{xp.__name__} is missing the {category} function {name}()"
38 |
--------------------------------------------------------------------------------
/array_api_tests/test_indexing_functions.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from hypothesis import given, note
3 | from hypothesis import strategies as st
4 |
5 | from . import _array_module as xp
6 | from . import dtype_helpers as dh
7 | from . import hypothesis_helpers as hh
8 | from . import pytest_helpers as ph
9 | from . import shape_helpers as sh
10 |
11 |
12 | @pytest.mark.unvectorized
13 | @pytest.mark.min_version("2022.12")
14 | @given(
15 | x=hh.arrays(hh.all_dtypes, hh.shapes(min_dims=1, min_side=1)),
16 | data=st.data(),
17 | )
18 | def test_take(x, data):
19 | # TODO:
20 | # * negative axis
21 | # * negative indices
22 | # * different dtypes for indices
23 | axis = data.draw(st.integers(0, max(x.ndim - 1, 0)), label="axis")
24 | _indices = data.draw(
25 | st.lists(st.integers(0, x.shape[axis] - 1), min_size=1, unique=True),
26 | label="_indices",
27 | )
28 | indices = xp.asarray(_indices, dtype=dh.default_int)
29 | note(f"{indices=}")
30 |
31 | out = xp.take(x, indices, axis=axis)
32 |
33 | ph.assert_dtype("take", in_dtype=x.dtype, out_dtype=out.dtype)
34 | ph.assert_shape(
35 | "take",
36 | out_shape=out.shape,
37 | expected=x.shape[:axis] + (len(_indices),) + x.shape[axis + 1 :],
38 | kw=dict(
39 | x=x,
40 | indices=indices,
41 | axis=axis,
42 | ),
43 | )
44 | out_indices = sh.ndindex(out.shape)
45 | axis_indices = list(sh.axis_ndindex(x.shape, axis))
46 | for axis_idx in axis_indices:
47 | f_axis_idx = sh.fmt_idx("x", axis_idx)
48 | for i in _indices:
49 | f_take_idx = sh.fmt_idx(f_axis_idx, i)
50 | indexed_x = x[axis_idx][i, ...]
51 | for at_idx in sh.ndindex(indexed_x.shape):
52 | out_idx = next(out_indices)
53 | ph.assert_0d_equals(
54 | "take",
55 | x_repr=sh.fmt_idx(f_take_idx, at_idx),
56 | x_val=indexed_x[at_idx],
57 | out_repr=sh.fmt_idx("out", out_idx),
58 | out_val=out[out_idx],
59 | )
60 | # sanity check
61 | with pytest.raises(StopIteration):
62 | next(out_indices)
63 |
64 |
65 |
66 | @pytest.mark.unvectorized
67 | @pytest.mark.min_version("2024.12")
68 | @given(
69 | x=hh.arrays(hh.all_dtypes, hh.shapes(min_dims=1, min_side=1)),
70 | data=st.data(),
71 | )
72 | def test_take_along_axis(x, data):
73 | # TODO:
74 | # * negative indices
75 | # * different dtypes for indices
76 | # * "broadcast-compatible" indices
77 | axis = data.draw(
78 | st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(),
79 | label="axis"
80 | )
81 | if axis is None:
82 | axis_kw = {}
83 | n_axis = x.ndim - 1
84 | else:
85 | axis_kw = {"axis": axis}
86 | n_axis = axis + x.ndim if axis < 0 else axis
87 |
88 | new_len = data.draw(st.integers(0, 2*x.shape[n_axis]), label="new_len")
89 | idx_shape = x.shape[:n_axis] + (new_len,) + x.shape[n_axis+1:]
90 | indices = data.draw(
91 | hh.arrays(
92 | shape=idx_shape,
93 | dtype=dh.default_int,
94 | elements={"min_value": 0, "max_value": x.shape[n_axis]-1}
95 | ),
96 | label="indices"
97 | )
98 | note(f"{indices=} {idx_shape=}")
99 |
100 | out = xp.take_along_axis(x, indices, **axis_kw)
101 |
102 | ph.assert_dtype("take_along_axis", in_dtype=x.dtype, out_dtype=out.dtype)
103 | ph.assert_shape(
104 | "take_along_axis",
105 | out_shape=out.shape,
106 | expected=x.shape[:n_axis] + (new_len,) + x.shape[n_axis+1:],
107 | kw=dict(
108 | x=x,
109 | indices=indices,
110 | axis=axis,
111 | ),
112 | )
113 |
114 | # value test: notation is from `np.take_along_axis` docstring
115 | Ni, Nk = x.shape[:n_axis], x.shape[n_axis+1:]
116 | for ii in sh.ndindex(Ni):
117 | for kk in sh.ndindex(Nk):
118 | a_1d = x[ii + (slice(None),) + kk]
119 | i_1d = indices[ii + (slice(None),) + kk]
120 | o_1d = out[ii + (slice(None),) + kk]
121 | for j in range(new_len):
122 | assert o_1d[j] == a_1d[i_1d[j]], f'{ii=}, {kk=}, {j=}'
123 |
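# Worked example (illustrative sketch): for x.shape == (2, 3, 4), n_axis == 1 and
# new_len == 5, Ni == (2,) and Nk == (4,); each (ii, kk) pair selects the 1-D
# slices a_1d = x[i0, :, k0] (length 3) and i_1d = indices[i0, :, k0],
# o_1d = out[i0, :, k0] (both length 5), and the loop checks o_1d[j] == a_1d[i_1d[j]].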
--------------------------------------------------------------------------------
/array_api_tests/test_inspection_functions.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from hypothesis import given, strategies as st
3 | from array_api_tests.dtype_helpers import available_kinds, dtype_names
4 |
5 | from . import xp
6 |
7 | pytestmark = pytest.mark.min_version("2023.12")
8 |
9 |
10 | class TestInspection:
11 | def test_capabilities(self):
12 | out = xp.__array_namespace_info__()
13 |
14 | capabilities = out.capabilities()
15 | assert isinstance(capabilities, dict)
16 |
17 | expected_attr = {"boolean indexing": bool, "data-dependent shapes": bool}
18 | if xp.__array_api_version__ >= "2024.12":
19 | expected_attr.update(**{"max dimensions": type(None) | int})
20 |
21 | for attr, typ in expected_attr.items():
22 | assert attr in capabilities, f'capabilities is missing "{attr}".'
23 | assert isinstance(capabilities[attr], typ)
24 |
25 | max_dims = capabilities.get("max dimensions", 100500)
26 | assert (max_dims is None) or (max_dims > 0)
27 |
28 | def test_devices(self):
29 | out = xp.__array_namespace_info__()
30 |
31 | assert hasattr(out, "devices")
32 | assert hasattr(out, "default_device")
33 |
34 | assert isinstance(out.devices(), list)
35 | if out.default_device() is not None:
36 | # Per https://github.com/data-apis/array-api/issues/923
37 | # default_device() can return None. Otherwise, it must be a valid device.
38 | assert out.default_device() in out.devices()
39 |
40 | def test_default_dtypes(self):
41 | out = xp.__array_namespace_info__()
42 |
43 | for device in xp.__array_namespace_info__().devices():
44 | default_dtypes = out.default_dtypes(device=device)
45 | assert isinstance(default_dtypes, dict)
46 | expected_subset = (
47 | {"real floating", "complex floating", "integral"}
48 | & available_kinds()
49 | | {"indexing"}
50 | )
51 | assert expected_subset.issubset(set(default_dtypes.keys()))
52 |
53 |
54 | atomic_kinds = [
55 | "bool",
56 | "signed integer",
57 | "unsigned integer",
58 | "real floating",
59 | "complex floating",
60 | ]
61 |
62 |
63 | @given(
64 | kind=st.one_of(
65 | st.none(),
66 | st.sampled_from(atomic_kinds + ["integral", "numeric"]),
67 | st.lists(st.sampled_from(atomic_kinds), unique=True, min_size=1).map(tuple),
68 | ),
69 | device=st.one_of(
70 | st.none(),
71 | st.sampled_from(xp.__array_namespace_info__().devices())
72 | )
73 | )
74 | def test_array_namespace_info_dtypes(kind, device):
75 | out = xp.__array_namespace_info__().dtypes(kind=kind, device=device)
76 | assert isinstance(out, dict)
77 |
78 | for name, dtyp in out.items():
79 | assert name in dtype_names
80 | xp.empty(1, dtype=dtyp, device=device) # check `dtyp` is a valid dtype
81 |
82 |
--------------------------------------------------------------------------------
/array_api_tests/test_manipulation_functions.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import deque
3 | from typing import Iterable, Iterator, Tuple, Union
4 |
5 | import pytest
6 | from hypothesis import assume, given
7 | from hypothesis import strategies as st
8 |
9 | from . import _array_module as xp
10 | from . import dtype_helpers as dh
11 | from . import hypothesis_helpers as hh
12 | from . import pytest_helpers as ph
13 | from . import shape_helpers as sh
14 | from . import xps
15 | from .typing import Array, Shape
16 |
17 |
18 | def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]:
19 | key = "shape"
20 | if args:
21 | key += " " + " ".join(args)
22 | if kwargs:
23 | key += " " + ph.fmt_kw(kwargs)
24 | return st.shared(hh.shapes(*args, **kwargs), key=key)
25 |
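# Note (illustrative, not part of the upstream module): st.shared draws one value
# per example for a given key and hands it to every strategy registered under that
# key; e.g. in test_permute_dims below both `x` and `axes` build on
# shared_shapes(min_dims=1) and therefore see the same shape within a test case.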
26 |
27 | def assert_array_ndindex(
28 | func_name: str,
29 | x: Array,
30 | *,
31 | x_indices: Iterable[Union[int, Shape]],
32 | out: Array,
33 | out_indices: Iterable[Union[int, Shape]],
34 | kw: dict = {},
35 | ):
36 | msg_suffix = f" [{func_name}({ph.fmt_kw(kw)})]\n {x=}\n{out=}"
37 | for x_idx, out_idx in zip(x_indices, out_indices):
38 | msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}"
39 | msg += msg_suffix
40 | if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]):
41 | assert xp.isnan(out[out_idx]), msg
42 | else:
43 | assert out[out_idx] == x[x_idx], msg
44 |
45 |
46 | @pytest.mark.unvectorized
47 | @given(
48 | dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes),
49 | base_shape=hh.shapes(),
50 | data=st.data(),
51 | )
52 | def test_concat(dtypes, base_shape, data):
53 | axis_strat = st.none()
54 | ndim = len(base_shape)
55 | if ndim > 0:
56 | axis_strat |= st.integers(-ndim, ndim - 1)
57 | kw = data.draw(
58 | axis_strat.flatmap(lambda a: hh.specified_kwargs(("axis", a, 0))), label="kw"
59 | )
60 | axis = kw.get("axis", 0)
61 | if axis is None:
62 | _axis = None
63 | shape_strat = hh.shapes()
64 | else:
65 | _axis = axis if axis >= 0 else len(base_shape) + axis
66 | shape_strat = st.integers(0, hh.MAX_SIDE).map(
67 | lambda i: base_shape[:_axis] + (i,) + base_shape[_axis + 1 :]
68 | )
69 | arrays = []
70 | for i, dtype in enumerate(dtypes, 1):
71 | x = data.draw(hh.arrays(dtype=dtype, shape=shape_strat), label=f"x{i}")
72 | arrays.append(x)
73 |
74 | out = xp.concat(arrays, **kw)
75 |
76 | ph.assert_dtype("concat", in_dtype=dtypes, out_dtype=out.dtype)
77 |
78 | shapes = tuple(x.shape for x in arrays)
79 | if _axis is None:
80 | size = sum(math.prod(s) for s in shapes)
81 | shape = (size,)
82 | else:
83 | shape = list(shapes[0])
84 | for other_shape in shapes[1:]:
85 | shape[_axis] += other_shape[_axis]
86 | shape = tuple(shape)
87 | ph.assert_result_shape("concat", in_shapes=shapes, out_shape=out.shape, expected=shape, kw=kw)
88 |
89 | if _axis is None:
90 | out_indices = (i for i in range(math.prod(out.shape)))
91 | for x_num, x in enumerate(arrays, 1):
92 | for x_idx in sh.ndindex(x.shape):
93 | out_i = next(out_indices)
94 | ph.assert_0d_equals(
95 | "concat",
96 | x_repr=f"x{x_num}[{x_idx}]",
97 | x_val=x[x_idx],
98 | out_repr=f"out[{out_i}]",
99 | out_val=out[out_i],
100 | kw=kw,
101 | )
102 | else:
103 | out_indices = sh.ndindex(out.shape)
104 | for idx in sh.axis_ndindex(shapes[0], _axis):
105 | f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx)
106 | for x_num, x in enumerate(arrays, 1):
107 | indexed_x = x[idx]
108 | for x_idx in sh.ndindex(indexed_x.shape):
109 | out_idx = next(out_indices)
110 | ph.assert_0d_equals(
111 | "concat",
112 | x_repr=f"x{x_num}[{f_idx}][{x_idx}]",
113 | x_val=indexed_x[x_idx],
114 | out_repr=f"out[{out_idx}]",
115 | out_val=out[out_idx],
116 | kw=kw,
117 | )
118 |
119 |
120 | @pytest.mark.unvectorized
121 | @given(
122 | x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes()),
123 | axis=shared_shapes().flatmap(
124 | # Generate both valid and invalid axis
125 | lambda s: st.integers(2 * (-len(s) - 1), 2 * len(s))
126 | ),
127 | )
128 | def test_expand_dims(x, axis):
129 | if axis < -x.ndim - 1 or axis > x.ndim:
130 | with pytest.raises(IndexError):
131 | xp.expand_dims(x, axis=axis)
132 | return
133 |
134 | out = xp.expand_dims(x, axis=axis)
135 |
136 | ph.assert_dtype("expand_dims", in_dtype=x.dtype, out_dtype=out.dtype)
137 |
138 | shape = [side for side in x.shape]
139 | index = axis if axis >= 0 else x.ndim + axis + 1
140 | shape.insert(index, 1)
141 | shape = tuple(shape)
142 | ph.assert_result_shape("expand_dims", in_shapes=[x.shape], out_shape=out.shape, expected=shape)
143 |
144 | assert_array_ndindex(
145 | "expand_dims", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape)
146 | )
147 |
148 |
149 | @pytest.mark.min_version("2023.12")
150 | @given(x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_dims=1)), data=st.data())
151 | def test_moveaxis(x, data):
152 | source = data.draw(
153 | st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim), label="source"
154 | )
155 | if isinstance(source, int):
156 | destination = data.draw(st.integers(-x.ndim, x.ndim - 1), label="destination")
157 | else:
158 | assert isinstance(source, tuple) # sanity check
159 | destination = data.draw(
160 | st.lists(
161 | st.integers(-x.ndim, x.ndim - 1),
162 | min_size=len(source),
163 | max_size=len(source),
164 | unique_by=lambda n: n if n >= 0 else x.ndim + n,
165 | ).map(tuple),
166 | label="destination"
167 | )
168 |
169 | out = xp.moveaxis(x, source, destination)
170 |
171 | ph.assert_dtype("moveaxis", in_dtype=x.dtype, out_dtype=out.dtype)
172 |
173 |
174 | _source = sh.normalize_axis(source, x.ndim)
175 | _destination = sh.normalize_axis(destination, x.ndim)
176 |
177 | new_axes = [n for n in range(x.ndim) if n not in _source]
178 |
179 | for dest, src in sorted(zip(_destination, _source)):
180 | new_axes.insert(dest, src)
181 |
182 | expected_shape = tuple(x.shape[i] for i in new_axes)
183 |
184 | ph.assert_result_shape("moveaxis", in_shapes=[x.shape],
185 | out_shape=out.shape, expected=expected_shape,
186 | kw={"source": source, "destination": destination})
187 |
188 | indices = list(sh.ndindex(x.shape))
189 | permuted_indices = [tuple(idx[axis] for axis in new_axes) for idx in indices]
190 | assert_array_ndindex(
191 | "moveaxis", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=permuted_indices
192 | )
193 |
194 | @pytest.mark.unvectorized
195 | @given(
196 | x=hh.arrays(
197 | dtype=hh.all_dtypes, shape=hh.shapes(min_side=1).filter(lambda s: 1 in s)
198 | ),
199 | data=st.data(),
200 | )
201 | def test_squeeze(x, data):
202 | axes = st.integers(-x.ndim, x.ndim - 1)
203 | axis = data.draw(
204 | axes
205 | | st.lists(axes, unique_by=lambda i: i if i >= 0 else i + x.ndim).map(tuple),
206 | label="axis",
207 | )
208 |
209 | axes = (axis,) if isinstance(axis, int) else axis
210 | axes = sh.normalize_axis(axes, x.ndim)
211 |
212 | squeezable_axes = [i for i, side in enumerate(x.shape) if side == 1]
213 | if any(i not in squeezable_axes for i in axes):
214 | with pytest.raises(ValueError):
215 | xp.squeeze(x, axis)
216 | return
217 |
218 | out = xp.squeeze(x, axis)
219 |
220 | ph.assert_dtype("squeeze", in_dtype=x.dtype, out_dtype=out.dtype)
221 |
222 | shape = []
223 | for i, side in enumerate(x.shape):
224 | if i not in axes:
225 | shape.append(side)
226 | shape = tuple(shape)
227 | ph.assert_result_shape("squeeze", in_shapes=[x.shape], out_shape=out.shape, expected=shape, kw=dict(axis=axis))
228 |
229 | assert_array_ndindex("squeeze", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape))
230 |
231 |
232 | @pytest.mark.unvectorized
233 | @given(
234 | x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()),
235 | data=st.data(),
236 | )
237 | def test_flip(x, data):
238 | if x.ndim == 0:
239 | axis_strat = st.none()
240 | else:
241 | axis_strat = (
242 | st.none() | st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim)
243 | )
244 | kw = data.draw(hh.kwargs(axis=axis_strat), label="kw")
245 |
246 | out = xp.flip(x, **kw)
247 |
248 | ph.assert_dtype("flip", in_dtype=x.dtype, out_dtype=out.dtype)
249 |
250 | _axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
251 | for indices in sh.axes_ndindex(x.shape, _axes):
252 | reverse_indices = indices[::-1]
253 | assert_array_ndindex("flip", x, x_indices=indices, out=out,
254 | out_indices=reverse_indices, kw=kw)
255 |
256 |
257 | @pytest.mark.unvectorized
258 | @given(
259 | x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes(min_dims=1)),
260 | axes=shared_shapes(min_dims=1).flatmap(
261 | lambda s: st.lists(
262 | st.integers(0, len(s) - 1),
263 | min_size=len(s),
264 | max_size=len(s),
265 | unique=True,
266 | ).map(tuple)
267 | ),
268 | )
269 | def test_permute_dims(x, axes):
270 | out = xp.permute_dims(x, axes)
271 |
272 | ph.assert_dtype("permute_dims", in_dtype=x.dtype, out_dtype=out.dtype)
273 |
274 | shape = [None for _ in range(len(axes))]
275 | for i, dim in enumerate(axes):
276 | side = x.shape[dim]
277 | shape[i] = side
278 | shape = tuple(shape)
279 | ph.assert_result_shape("permute_dims", in_shapes=[x.shape], out_shape=out.shape, expected=shape, kw=dict(axes=axes))
280 |
281 | indices = list(sh.ndindex(x.shape))
282 | permuted_indices = [tuple(idx[axis] for axis in axes) for idx in indices]
283 | assert_array_ndindex("permute_dims", x, x_indices=indices, out=out,
284 | out_indices=permuted_indices)
285 |
286 |
287 | @pytest.mark.min_version("2023.12")
288 | @given(
289 | x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes(min_dims=1)),
290 | kw=hh.kwargs(
291 | axis=st.none() | shared_shapes(min_dims=1).flatmap(
292 | lambda s: st.integers(-len(s), len(s) - 1)
293 | )
294 | ),
295 | data=st.data(),
296 | )
297 | def test_repeat(x, kw, data):
298 | shape = x.shape
299 | axis = kw.get("axis", None)
300 | size = math.prod(shape) if axis is None else shape[axis]
301 | repeat_strat = st.integers(1, 10)
302 | repeats = data.draw(repeat_strat
303 | | hh.arrays(dtype=hh.int_dtypes, elements=repeat_strat,
304 | shape=st.sampled_from([(1,), (size,)])),
305 | label="repeats")
306 | if isinstance(repeats, int):
307 | n_repetitions = size*repeats
308 | else:
309 | if repeats.shape == (1,):
310 | n_repetitions = size*int(repeats[0])
311 | else:
312 | n_repetitions = int(xp.sum(repeats))
313 |
314 | assume(n_repetitions <= hh.SQRT_MAX_ARRAY_SIZE)
315 |
316 | out = xp.repeat(x, repeats, **kw)
317 | ph.assert_dtype("repeat", in_dtype=x.dtype, out_dtype=out.dtype)
318 | if axis is None:
319 | expected_shape = (n_repetitions,)
320 | else:
321 | expected_shape = list(shape)
322 | expected_shape[axis] = n_repetitions
323 | expected_shape = tuple(expected_shape)
324 | ph.assert_shape("repeat", out_shape=out.shape, expected=expected_shape)
325 |
326 | # Test values
327 |
328 | if isinstance(repeats, int):
329 | repeats_array = xp.full(size, repeats, dtype=xp.int32)
330 | else:
331 | repeats_array = repeats
332 |
333 | if kw.get("axis") is None:
334 | x = xp.reshape(x, (-1,))
335 | axis = 0
336 |
337 | for idx, in sh.iter_indices(x.shape, skip_axes=axis):
338 | x_slice = x[idx]
339 | out_slice = out[idx]
340 | start = 0
341 | for i, count in enumerate(repeats_array):
342 | end = start + count
343 | ph.assert_array_elements("repeat", out=out_slice[start:end],
344 | expected=xp.full((count,), x_slice[i], dtype=x.dtype),
345 | kw=kw)
346 | start = end
347 |
348 | reshape_shape = st.shared(hh.shapes(), key="reshape_shape")
349 |
350 | @pytest.mark.unvectorized
351 | @given(
352 | x=hh.arrays(dtype=hh.all_dtypes, shape=reshape_shape),
353 | shape=hh.reshape_shapes(reshape_shape),
354 | )
355 | def test_reshape(x, shape):
356 | out = xp.reshape(x, shape)
357 |
358 | ph.assert_dtype("reshape", in_dtype=x.dtype, out_dtype=out.dtype)
359 |
360 | _shape = list(shape)
361 | if any(side == -1 for side in shape):
362 | size = math.prod(x.shape)
363 | rsize = math.prod(shape) * -1
364 | _shape[shape.index(-1)] = size // rsize
365 | _shape = tuple(_shape)
366 | ph.assert_result_shape("reshape", in_shapes=[x.shape], out_shape=out.shape, expected=_shape, kw=dict(shape=shape))
367 |
368 | assert_array_ndindex("reshape", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape))
369 |
370 |
371 | def roll_ndindex(shape: Shape, shifts: Tuple[int], axes: Tuple[int]) -> Iterator[Shape]:
372 | assert len(shifts) == len(axes) # sanity check
373 | all_shifts = [0 for _ in shape]
374 | for s, a in zip(shifts, axes):
375 | all_shifts[a] = s
376 | for idx in sh.ndindex(shape):
377 | yield tuple((i + sh) % si for i, sh, si in zip(idx, all_shifts, shape))
378 |
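# Worked example (illustrative sketch): roll_ndindex((2, 3), shifts=(1,), axes=(1,))
# yields (0, 1), (0, 2), (0, 0), (1, 1), (1, 2), (1, 0): for each source index of x
# it gives the destination index in the rolled output, so test_roll below can check
# out[dst] == x[src] pairwise.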
379 |
380 | @pytest.mark.unvectorized
381 | @given(hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes()), st.data())
382 | def test_roll(x, data):
383 | shift_strat = st.integers(-hh.MAX_ARRAY_SIZE, hh.MAX_ARRAY_SIZE)
384 | if x.ndim > 0:
385 | shift_strat = shift_strat | st.lists(
386 | shift_strat, min_size=1, max_size=x.ndim
387 | ).map(tuple)
388 | shift = data.draw(shift_strat, label="shift")
389 | if isinstance(shift, tuple):
390 | axis_strat = xps.valid_tuple_axes(x.ndim).filter(lambda t: len(t) == len(shift))
391 | kw_strat = axis_strat.map(lambda t: {"axis": t})
392 | else:
393 | axis_strat = st.none()
394 | if x.ndim != 0:
395 | axis_strat |= st.integers(-x.ndim, x.ndim - 1)
396 | kw_strat = hh.kwargs(axis=axis_strat)
397 | kw = data.draw(kw_strat, label="kw")
398 |
399 | out = xp.roll(x, shift, **kw)
400 |
401 | kw = {"shift": shift, **kw} # for error messages
402 |
403 | ph.assert_dtype("roll", in_dtype=x.dtype, out_dtype=out.dtype)
404 |
405 | ph.assert_result_shape("roll", in_shapes=[x.shape], out_shape=out.shape, kw=kw)
406 |
407 | if kw.get("axis", None) is None:
408 | assert isinstance(shift, int) # sanity check
409 | indices = list(sh.ndindex(x.shape))
410 | shifted_indices = deque(indices)
411 | shifted_indices.rotate(-shift)
412 | assert_array_ndindex("roll", x, x_indices=indices, out=out, out_indices=shifted_indices, kw=kw)
413 | else:
414 | shifts = (shift,) if isinstance(shift, int) else shift
415 | axes = sh.normalize_axis(kw["axis"], x.ndim)
416 | shifted_indices = roll_ndindex(x.shape, shifts, axes)
417 | assert_array_ndindex("roll", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=shifted_indices, kw=kw)
418 |
419 |
420 | @pytest.mark.unvectorized
421 | @given(
422 | shape=shared_shapes(min_dims=1),
423 | dtypes=hh.mutually_promotable_dtypes(None),
424 | kw=hh.kwargs(
425 | axis=shared_shapes(min_dims=1).flatmap(
426 | lambda s: st.integers(-len(s), len(s) - 1)
427 | )
428 | ),
429 | data=st.data(),
430 | )
431 | def test_stack(shape, dtypes, kw, data):
432 | arrays = []
433 | for i, dtype in enumerate(dtypes, 1):
434 | x = data.draw(hh.arrays(dtype=dtype, shape=shape), label=f"x{i}")
435 | arrays.append(x)
436 |
437 | out = xp.stack(arrays, **kw)
438 |
439 | ph.assert_dtype("stack", in_dtype=dtypes, out_dtype=out.dtype)
440 |
441 | axis = kw.get("axis", 0)
442 | _axis = axis if axis >= 0 else len(shape) + axis + 1
443 | _shape = list(shape)
444 | _shape.insert(_axis, len(arrays))
445 | _shape = tuple(_shape)
446 | ph.assert_result_shape(
447 | "stack", in_shapes=tuple(x.shape for x in arrays), out_shape=out.shape, expected=_shape, kw=kw
448 | )
449 |
450 | out_indices = sh.ndindex(out.shape)
451 | for idx in sh.axis_ndindex(arrays[0].shape, axis=_axis):
452 | f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx)
453 | for x_num, x in enumerate(arrays, 1):
454 | indexed_x = x[idx]
455 | for x_idx in sh.ndindex(indexed_x.shape):
456 | out_idx = next(out_indices)
457 | ph.assert_0d_equals(
458 | "stack",
459 | x_repr=f"x{x_num}[{f_idx}][{x_idx}]",
460 | x_val=indexed_x[x_idx],
461 | out_repr=f"out[{out_idx}]",
462 | out_val=out[out_idx],
463 | kw=kw,
464 | )
465 |
466 |
467 | @pytest.mark.min_version("2023.12")
468 | @given(x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()), data=st.data())
469 | def test_tile(x, data):
470 | repetitions = data.draw(
471 | st.lists(st.integers(1, 4), min_size=1, max_size=x.ndim + 1).map(tuple),
472 | label="repetitions"
473 | )
474 | out = xp.tile(x, repetitions)
475 | ph.assert_dtype("tile", in_dtype=x.dtype, out_dtype=out.dtype)
476 | # TODO: values testing
477 |
478 | # shape check; the notation is from the Array API docs
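    # e.g. x.shape == (2, 3) with repetitions == (2,): N=2 > M=1, so S=(2, 3) and
    # R=(1, 2), giving an expected out.shape of (2, 6)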
479 | N, M = len(x.shape), len(repetitions)
480 | if N > M:
481 | S = x.shape
482 | R = (1,)*(N - M) + repetitions
483 | else:
484 | S = (1,)*(M - N) + x.shape
485 | R = repetitions
486 |
487 | assert out.shape == tuple(r*s for r, s in zip(R, S))
488 |
489 |
490 | @pytest.mark.min_version("2023.12")
491 | @given(x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_dims=1)), data=st.data())
492 | def test_unstack(x, data):
493 | axis = data.draw(st.integers(min_value=-x.ndim, max_value=x.ndim - 1), label="axis")
494 | kw = data.draw(hh.specified_kwargs(("axis", axis, 0)), label="kw")
495 | out = xp.unstack(x, **kw)
496 |
497 | assert isinstance(out, tuple)
498 | assert len(out) == x.shape[axis]
499 | expected_shape = list(x.shape)
500 | expected_shape.pop(axis)
501 | expected_shape = tuple(expected_shape)
502 | for i in range(x.shape[axis]):
503 | arr = out[i]
504 | ph.assert_result_shape("unstack", in_shapes=[x.shape],
505 | out_shape=arr.shape, expected=expected_shape,
506 | kw=kw, repr_name=f"out[{i}].shape")
507 |
508 | ph.assert_dtype("unstack", in_dtype=x.dtype, out_dtype=arr.dtype,
509 | repr_name=f"out[{i}].dtype")
510 |
511 | idx = [slice(None)] * x.ndim
512 | idx[axis] = i
513 | ph.assert_array_elements("unstack", out=arr, expected=x[tuple(idx)], kw=kw, out_repr=f"out[{i}]")
514 |
--------------------------------------------------------------------------------
/array_api_tests/test_searching_functions.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import pytest
4 | from hypothesis import given, note, assume
5 | from hypothesis import strategies as st
6 |
7 | from . import _array_module as xp
8 | from . import dtype_helpers as dh
9 | from . import hypothesis_helpers as hh
10 | from . import pytest_helpers as ph
11 | from . import shape_helpers as sh
12 | from . import xps
13 |
14 |
15 | pytestmark = pytest.mark.unvectorized
16 |
17 |
18 | @given(
19 | x=hh.arrays(
20 | dtype=hh.real_dtypes,
21 | shape=hh.shapes(min_dims=1, min_side=1),
22 | elements={"allow_nan": False},
23 | ),
24 | data=st.data(),
25 | )
26 | def test_argmax(x, data):
27 | kw = data.draw(
28 | hh.kwargs(
29 | axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)),
30 | keepdims=st.booleans(),
31 | ),
32 | label="kw",
33 | )
34 | keepdims = kw.get("keepdims", False)
35 |
36 | out = xp.argmax(x, **kw)
37 |
38 | ph.assert_default_index("argmax", out.dtype)
39 | axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
40 | ph.assert_keepdimable_shape(
41 | "argmax", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw
42 | )
43 | scalar_type = dh.get_scalar_type(x.dtype)
44 | for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)):
45 | max_i = int(out[out_idx])
46 | elements = []
47 | for idx in indices:
48 | s = scalar_type(x[idx])
49 | elements.append(s)
50 | expected = max(range(len(elements)), key=elements.__getitem__)
51 | ph.assert_scalar_equals("argmax", type_=int, idx=out_idx, out=max_i,
52 | expected=expected, kw=kw)
53 |
54 |
55 | @given(
56 | x=hh.arrays(
57 | dtype=hh.real_dtypes,
58 | shape=hh.shapes(min_dims=1, min_side=1),
59 | elements={"allow_nan": False},
60 | ),
61 | data=st.data(),
62 | )
63 | def test_argmin(x, data):
64 | kw = data.draw(
65 | hh.kwargs(
66 | axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)),
67 | keepdims=st.booleans(),
68 | ),
69 | label="kw",
70 | )
71 | keepdims = kw.get("keepdims", False)
72 |
73 | out = xp.argmin(x, **kw)
74 |
75 | ph.assert_default_index("argmin", out.dtype)
76 | axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
77 | ph.assert_keepdimable_shape(
78 | "argmin", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw
79 | )
80 | scalar_type = dh.get_scalar_type(x.dtype)
81 | for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)):
82 | min_i = int(out[out_idx])
83 | elements = []
84 | for idx in indices:
85 | s = scalar_type(x[idx])
86 | elements.append(s)
87 | expected = min(range(len(elements)), key=elements.__getitem__)
88 | ph.assert_scalar_equals("argmin", type_=int, idx=out_idx, out=min_i, expected=expected)
89 |
90 |
91 | # XXX: the strategy for x is problematic on JAX unless JAX_ENABLE_X64 is on
92 | # the problem is that for ints > iinfo(int32) it runs into essentially this:
93 | # >>> jnp.asarray([2147483648], dtype=jnp.int64)
94 | # .... https://github.com/jax-ml/jax/pull/6047 ...
95 | # Explicitly limiting the range in elements(...) runs into problems with
96 | # hypothesis where floating-point numbers are not exactly representable.
97 | @pytest.mark.min_version("2024.12")
98 | @given(
99 | x=hh.arrays(
100 | dtype=hh.all_dtypes,
101 | shape=hh.shapes(min_dims=1, min_side=1),
102 | elements={"allow_nan": False},
103 | ),
104 | data=st.data(),
105 | )
106 | def test_count_nonzero(x, data):
107 | kw = data.draw(
108 | hh.kwargs(
109 | axis=hh.axes(x.ndim),
110 | keepdims=st.booleans(),
111 | ),
112 | label="kw",
113 | )
114 | keepdims = kw.get("keepdims", False)
115 |
116 | assume(kw.get("axis", None) != ()) # TODO clarify in the spec
117 |
118 | out = xp.count_nonzero(x, **kw)
119 |
120 | ph.assert_default_index("count_nonzero", out.dtype)
121 | axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
122 | ph.assert_keepdimable_shape(
123 | "count_nonzero", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw
124 | )
125 | scalar_type = dh.get_scalar_type(x.dtype)
126 |
127 | for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)):
128 | count = int(out[out_idx])
129 | elements = []
130 | for idx in indices:
131 | s = scalar_type(x[idx])
132 | elements.append(s)
133 | expected = sum(el != 0 for el in elements)
134 | ph.assert_scalar_equals("count_nonzero", type_=int, idx=out_idx, out=count, expected=expected)
135 |
136 |
137 | @given(hh.arrays(dtype=hh.all_dtypes, shape=()))
138 | def test_nonzero_zerodim_error(x):
139 | with pytest.raises(Exception):
140 | xp.nonzero(x)
141 |
142 |
143 | @pytest.mark.data_dependent_shapes
144 | @given(hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_dims=1, min_side=1)))
145 | def test_nonzero(x):
146 | out = xp.nonzero(x)
147 | assert len(out) == x.ndim, f"{len(out)=}, but should be {x.ndim=}"
148 | out_size = math.prod(out[0].shape)
149 | for i in range(len(out)):
150 |         assert out[i].ndim == 1, f"out[{i}].ndim={out[i].ndim}, but should be 1"
151 | size_at = math.prod(out[i].shape)
152 | assert size_at == out_size, (
153 | f"prod(out[{i}].shape)={size_at}, "
154 | f"but should be prod(out[0].shape)={out_size}"
155 | )
156 | ph.assert_default_index("nonzero", out[i].dtype, repr_name=f"out[{i}].dtype")
157 | indices = []
158 | if x.dtype == xp.bool:
159 | for idx in sh.ndindex(x.shape):
160 | if x[idx]:
161 | indices.append(idx)
162 | else:
163 | for idx in sh.ndindex(x.shape):
164 | if x[idx] != 0:
165 | indices.append(idx)
166 | if x.ndim == 0:
167 | assert out_size == len(
168 | indices
169 | ), f"prod(out[0].shape)={out_size}, but should be {len(indices)}"
170 | else:
171 | for i in range(out_size):
172 | idx = tuple(int(x[i]) for x in out)
173 | f_idx = f"Extrapolated index (x[{i}] for x in out)={idx}"
174 | f_element = f"x[{idx}]={x[idx]}"
175 | assert idx in indices, f"{f_idx} results in {f_element}, a zero element"
176 | assert (
177 | idx == indices[i]
178 | ), f"{f_idx} is in the wrong position, should be {indices.index(idx)}"
179 |
180 |
181 | @given(
182 | shapes=hh.mutually_broadcastable_shapes(3),
183 | dtypes=hh.mutually_promotable_dtypes(),
184 | data=st.data(),
185 | )
186 | def test_where(shapes, dtypes, data):
187 | cond = data.draw(hh.arrays(dtype=xp.bool, shape=shapes[0]), label="condition")
188 | x1 = data.draw(hh.arrays(dtype=dtypes[0], shape=shapes[1]), label="x1")
189 | x2 = data.draw(hh.arrays(dtype=dtypes[1], shape=shapes[2]), label="x2")
190 |
191 | out = xp.where(cond, x1, x2)
192 |
193 | shape = sh.broadcast_shapes(*shapes)
194 | ph.assert_shape("where", out_shape=out.shape, expected=shape)
195 | # TODO: generate indices without broadcasting arrays
196 | _cond = xp.broadcast_to(cond, shape)
197 | _x1 = xp.broadcast_to(x1, shape)
198 | _x2 = xp.broadcast_to(x2, shape)
199 | for idx in sh.ndindex(shape):
200 | if _cond[idx]:
201 | ph.assert_0d_equals(
202 | "where",
203 | x_repr=f"_x1[{idx}]",
204 | x_val=_x1[idx],
205 | out_repr=f"out[{idx}]",
206 | out_val=out[idx]
207 | )
208 | else:
209 | ph.assert_0d_equals(
210 | "where",
211 | x_repr=f"_x2[{idx}]",
212 | x_val=_x2[idx],
213 | out_repr=f"out[{idx}]",
214 | out_val=out[idx]
215 | )
216 |
217 |
218 | @pytest.mark.min_version("2023.12")
219 | @given(data=st.data())
220 | def test_searchsorted(data):
221 | # TODO: test side="right"
222 | # TODO: Allow different dtypes for x1 and x2
223 | _x1 = data.draw(
224 | st.lists(xps.from_dtype(dh.default_float), min_size=1, unique=True),
225 | label="_x1",
226 | )
227 | x1 = xp.asarray(_x1, dtype=dh.default_float)
228 | if data.draw(st.booleans(), label="use sorter?"):
229 | sorter = xp.argsort(x1)
230 | else:
231 | sorter = None
232 | x1 = xp.sort(x1)
233 | note(f"{x1=}")
234 | x2 = data.draw(
235 | st.lists(st.sampled_from(_x1), unique=True, min_size=1).map(
236 | lambda o: xp.asarray(o, dtype=dh.default_float)
237 | ),
238 | label="x2",
239 | )
240 |
241 | out = xp.searchsorted(x1, x2, sorter=sorter)
242 |
243 | ph.assert_dtype(
244 | "searchsorted",
245 | in_dtype=[x1.dtype, x2.dtype],
246 | out_dtype=out.dtype,
247 | expected=xp.__array_namespace_info__().default_dtypes()["indexing"],
248 | )
249 | # TODO: shapes and values testing
250 |
--------------------------------------------------------------------------------
/array_api_tests/test_set_functions.py:
--------------------------------------------------------------------------------
1 | # TODO: disable if opted out, refactor things
2 | import cmath
3 | import math
4 | from collections import Counter, defaultdict
5 |
6 | import pytest
7 | from hypothesis import assume, given
8 |
9 | from . import _array_module as xp
10 | from . import dtype_helpers as dh
11 | from . import hypothesis_helpers as hh
12 | from . import pytest_helpers as ph
13 | from . import shape_helpers as sh
14 |
15 | pytestmark = [pytest.mark.data_dependent_shapes, pytest.mark.unvectorized]
16 |
17 |
18 | @given(hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_side=1)))
19 | def test_unique_all(x):
20 | out = xp.unique_all(x)
21 |
22 | assert hasattr(out, "values")
23 | assert hasattr(out, "indices")
24 | assert hasattr(out, "inverse_indices")
25 | assert hasattr(out, "counts")
26 |
27 | ph.assert_dtype(
28 | "unique_all", in_dtype=x.dtype, out_dtype=out.values.dtype, repr_name="out.values.dtype"
29 | )
30 | ph.assert_default_index(
31 | "unique_all", out.indices.dtype, repr_name="out.indices.dtype"
32 | )
33 | ph.assert_default_index(
34 | "unique_all", out.inverse_indices.dtype, repr_name="out.inverse_indices.dtype"
35 | )
36 | ph.assert_default_index(
37 | "unique_all", out.counts.dtype, repr_name="out.counts.dtype"
38 | )
39 |
40 | assert (
41 | out.indices.shape == out.values.shape
42 | ), f"{out.indices.shape=}, but should be {out.values.shape=}"
43 | ph.assert_shape(
44 | "unique_all",
45 | out_shape=out.inverse_indices.shape,
46 | expected=x.shape,
47 | repr_name="out.inverse_indices.shape",
48 | )
49 | assert (
50 | out.counts.shape == out.values.shape
51 | ), f"{out.counts.shape=}, but should be {out.values.shape=}"
52 |
53 | scalar_type = dh.get_scalar_type(out.values.dtype)
54 | counts = defaultdict(int)
55 | firsts = {}
56 | for i, idx in enumerate(sh.ndindex(x.shape)):
57 | val = scalar_type(x[idx])
58 | if counts[val] == 0:
59 | firsts[val] = i
60 | counts[val] += 1
61 |
62 | for idx in sh.ndindex(out.indices.shape):
63 | val = scalar_type(out.values[idx])
64 | if cmath.isnan(val):
65 | break
66 | i = int(out.indices[idx])
67 | expected = firsts[val]
68 | assert i == expected, (
69 | f"out.values[{idx}]={val} and out.indices[{idx}]={i}, "
 70 |             f"but first occurrence of {val} is at {expected}"
71 | )
72 |
73 | for idx in sh.ndindex(out.inverse_indices.shape):
74 | ridx = int(out.inverse_indices[idx])
75 | val = out.values[ridx]
76 | expected = x[idx]
77 | msg = (
78 | f"out.inverse_indices[{idx}]={ridx} results in out.values[{ridx}]={val}, "
79 | f"but should result in x[{idx}]={expected}"
80 | )
81 | if dh.is_float_dtype(out.values.dtype) and xp.isnan(expected):
82 | assert xp.isnan(val), msg
83 | else:
84 | assert val == expected, msg
85 |
86 | vals_idx = {}
87 | nans = 0
88 | for idx in sh.ndindex(out.values.shape):
89 | val = scalar_type(out.values[idx])
90 | count = int(out.counts[idx])
91 | if cmath.isnan(val):
92 | nans += 1
93 | assert count == 1, (
94 | f"out.counts[{idx}]={count} for out.values[{idx}]={val}, "
95 | "but count should be 1 as NaNs are distinct"
96 | )
97 | else:
98 | expected = counts[val]
99 | assert (
100 | expected > 0
101 | ), f"out.values[{idx}]={val}, but {val} not in input array"
102 | count = int(out.counts[idx])
103 | assert count == expected, (
104 | f"out.counts[{idx}]={count} for out.values[{idx}]={val}, "
105 | f"but should be {expected}"
106 | )
107 | assert (
108 | val not in vals_idx.keys()
109 | ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]"
110 | vals_idx[val] = idx
111 |
112 | if dh.is_float_dtype(out.values.dtype):
113 | assume(math.prod(x.shape) <= 128) # may not be representable
114 | expected = sum(v for k, v in counts.items() if cmath.isnan(k))
115 | assert nans == expected, f"{nans} NaNs in out, but should be {expected}"
116 |
117 |
118 | @given(hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_side=1)))
119 | def test_unique_counts(x):
120 | out = xp.unique_counts(x)
121 | assert hasattr(out, "values")
122 | assert hasattr(out, "counts")
123 | ph.assert_dtype(
124 | "unique_counts", in_dtype=x.dtype, out_dtype=out.values.dtype, repr_name="out.values.dtype"
125 | )
126 | ph.assert_default_index(
127 | "unique_counts", out.counts.dtype, repr_name="out.counts.dtype"
128 | )
129 | assert (
130 | out.counts.shape == out.values.shape
131 | ), f"{out.counts.shape=}, but should be {out.values.shape=}"
132 | scalar_type = dh.get_scalar_type(out.values.dtype)
133 | counts = Counter(scalar_type(x[idx]) for idx in sh.ndindex(x.shape))
134 | vals_idx = {}
135 | nans = 0
136 | for idx in sh.ndindex(out.values.shape):
137 | val = scalar_type(out.values[idx])
138 | count = int(out.counts[idx])
139 | if cmath.isnan(val):
140 | nans += 1
141 | assert count == 1, (
142 | f"out.counts[{idx}]={count} for out.values[{idx}]={val}, "
143 | "but count should be 1 as NaNs are distinct"
144 | )
145 | else:
146 | expected = counts[val]
147 | assert (
148 | expected > 0
149 | ), f"out.values[{idx}]={val}, but {val} not in input array"
150 | count = int(out.counts[idx])
151 | assert count == expected, (
152 | f"out.counts[{idx}]={count} for out.values[{idx}]={val}, "
153 | f"but should be {expected}"
154 | )
155 | assert (
156 | val not in vals_idx.keys()
157 | ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]"
158 | vals_idx[val] = idx
159 | if dh.is_float_dtype(out.values.dtype):
160 | assume(math.prod(x.shape) <= 128) # may not be representable
161 | expected = sum(v for k, v in counts.items() if cmath.isnan(k))
162 | assert nans == expected, f"{nans} NaNs in out, but should be {expected}"
163 |
164 |
165 | @given(hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_side=1)))
166 | def test_unique_inverse(x):
167 | out = xp.unique_inverse(x)
168 | assert hasattr(out, "values")
169 | assert hasattr(out, "inverse_indices")
170 | ph.assert_dtype(
171 | "unique_inverse", in_dtype=x.dtype, out_dtype=out.values.dtype, repr_name="out.values.dtype"
172 | )
173 | ph.assert_default_index(
174 | "unique_inverse",
175 | out.inverse_indices.dtype,
176 | repr_name="out.inverse_indices.dtype",
177 | )
178 | ph.assert_shape(
179 | "unique_inverse",
180 | out_shape=out.inverse_indices.shape,
181 | expected=x.shape,
182 | repr_name="out.inverse_indices.shape",
183 | )
184 | scalar_type = dh.get_scalar_type(out.values.dtype)
185 | distinct = set(scalar_type(x[idx]) for idx in sh.ndindex(x.shape))
186 | vals_idx = {}
187 | nans = 0
188 | for idx in sh.ndindex(out.values.shape):
189 | val = scalar_type(out.values[idx])
190 | if cmath.isnan(val):
191 | nans += 1
192 | else:
193 | assert (
194 | val in distinct
195 | ), f"out.values[{idx}]={val}, but {val} not in input array"
196 | assert (
197 | val not in vals_idx.keys()
198 | ), f"out.values[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]"
199 | vals_idx[val] = idx
200 | for idx in sh.ndindex(out.inverse_indices.shape):
201 | ridx = int(out.inverse_indices[idx])
202 | val = out.values[ridx]
203 | expected = x[idx]
204 | msg = (
205 | f"out.inverse_indices[{idx}]={ridx} results in out.values[{ridx}]={val}, "
206 | f"but should result in x[{idx}]={expected}"
207 | )
208 | if dh.is_float_dtype(out.values.dtype) and xp.isnan(expected):
209 | assert xp.isnan(val), msg
210 | else:
211 | assert val == expected, msg
212 | if dh.is_float_dtype(out.values.dtype):
213 | assume(math.prod(x.shape) <= 128) # may not be representable
214 | expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8))
215 | assert nans == expected, f"{nans} NaNs in out.values, but should be {expected}"
216 |
217 |
218 | @given(hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_side=1)))
219 | def test_unique_values(x):
220 | out = xp.unique_values(x)
221 | ph.assert_dtype("unique_values", in_dtype=x.dtype, out_dtype=out.dtype)
222 | scalar_type = dh.get_scalar_type(x.dtype)
223 | distinct = set(scalar_type(x[idx]) for idx in sh.ndindex(x.shape))
224 | vals_idx = {}
225 | nans = 0
226 | for idx in sh.ndindex(out.shape):
227 | val = scalar_type(out[idx])
228 | if cmath.isnan(val):
229 | nans += 1
230 | else:
231 | assert val in distinct, f"out[{idx}]={val}, but {val} not in input array"
232 | assert (
233 | val not in vals_idx.keys()
234 | ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]"
235 | vals_idx[val] = idx
236 | if dh.is_float_dtype(out.dtype):
237 | assume(math.prod(x.shape) <= 128) # may not be representable
238 | expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8))
239 | assert nans == expected, f"{nans} NaNs in out, but should be {expected}"
240 |
--------------------------------------------------------------------------------
/array_api_tests/test_signatures.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for function/method signatures compliance
3 |
4 | We're not interested in being 100% strict - instead we focus on areas which
5 | could affect interop, e.g. with
6 |
7 | def add(x1, x2, /):
8 | ...
9 |
10 | x1 and x2 don't need to be pos-only for the purposes of interoperability, but with
11 |
12 | def squeeze(x, /, axis):
13 | ...
14 |
15 | axis has to be pos-or-keyword to support both styles
16 |
17 | >>> squeeze(x, 0)
18 | ...
19 | >>> squeeze(x, axis=0)
20 | ...
21 |
22 | """
23 | from collections import defaultdict
24 | from copy import copy
25 | from inspect import Parameter, Signature, signature
26 | from types import FunctionType
27 | from typing import Any, Callable, Dict, Literal, get_args
28 | from warnings import warn
29 |
30 | import pytest
31 |
32 | from . import dtype_helpers as dh
33 | from . import xp
34 | from .stubs import (array_methods, category_to_funcs, extension_to_funcs,
35 | name_to_func, info_funcs)
36 |
37 | ParameterKind = Literal[
38 | Parameter.POSITIONAL_ONLY,
39 | Parameter.VAR_POSITIONAL,
40 | Parameter.POSITIONAL_OR_KEYWORD,
41 | Parameter.KEYWORD_ONLY,
42 | Parameter.VAR_KEYWORD,
43 | ]
44 | ALL_KINDS = get_args(ParameterKind)
45 | VAR_KINDS = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)
46 | kind_to_str: Dict[ParameterKind, str] = {
47 | Parameter.POSITIONAL_OR_KEYWORD: "pos or kw argument",
48 | Parameter.POSITIONAL_ONLY: "pos-only argument",
49 | Parameter.KEYWORD_ONLY: "keyword-only argument",
50 | Parameter.VAR_POSITIONAL: "star-args (i.e. *args) argument",
51 | Parameter.VAR_KEYWORD: "star-kwargs (i.e. **kwargs) argument",
52 | }
53 |
54 |
55 | def _test_inspectable_func(sig: Signature, stub_sig: Signature):
56 | params = list(sig.parameters.values())
57 | stub_params = list(stub_sig.parameters.values())
58 |
59 | non_kwonly_stub_params = [
60 | p for p in stub_params if p.kind != Parameter.KEYWORD_ONLY
61 | ]
62 | # sanity check
63 | assert non_kwonly_stub_params == stub_params[: len(non_kwonly_stub_params)]
64 | # We're not interested if the array module has additional arguments, so we
65 | # only iterate through the arguments listed in the spec.
66 | for i, stub_param in enumerate(non_kwonly_stub_params):
67 | assert (
68 | len(params) >= i + 1
69 | ), f"Argument '{stub_param.name}' missing from signature"
70 | param = params[i]
71 |
72 | # We're not interested in the name if it isn't actually used
73 | if stub_param.kind not in [Parameter.POSITIONAL_ONLY, *VAR_KINDS]:
74 | assert (
75 | param.name == stub_param.name
76 | ), f"Expected argument '{param.name}' to be named '{stub_param.name}'"
77 |
78 | if stub_param.kind in [Parameter.POSITIONAL_OR_KEYWORD, *VAR_KINDS]:
79 | f_stub_kind = kind_to_str[stub_param.kind]
80 | assert param.kind == stub_param.kind, (
81 | f"{param.name} is a {kind_to_str[param.kind]}, "
82 | f"but should be a {f_stub_kind}"
83 | )
84 |
85 | kwonly_stub_params = stub_params[len(non_kwonly_stub_params) :]
86 | for stub_param in kwonly_stub_params:
87 | assert (
88 | stub_param.name in sig.parameters.keys()
89 | ), f"Argument '{stub_param.name}' missing from signature"
90 | param = next(p for p in params if p.name == stub_param.name)
91 | f_stub_kind = kind_to_str[stub_param.kind]
92 | assert param.kind in [stub_param.kind, Parameter.POSITIONAL_OR_KEYWORD,], (
93 | f"{param.name} is a {kind_to_str[param.kind]}, "
94 | f"but should be a {f_stub_kind} "
 95 |             f"(or at least a {kind_to_str[Parameter.POSITIONAL_OR_KEYWORD]})"
96 | )
97 |
98 |
99 | def make_pretty_func(func_name: str, *args: Any, **kwargs: Any) -> str:
100 | f_sig = f"{func_name}("
101 | f_sig += ", ".join(str(a) for a in args)
102 | if len(kwargs) != 0:
103 | if len(args) != 0:
104 | f_sig += ", "
105 | f_sig += ", ".join(f"{k}={v}" for k, v in kwargs.items())
106 | f_sig += ")"
107 | return f_sig
108 |
109 |
110 | # We test uninspectable signatures by passing valid, manually-defined arguments
111 | # to the signature's function/method.
112 | #
113 | # Arguments which require use of the array module are specified as string
114 | # expressions to be eval()'d at runtime. This is as opposed to just using the
115 | # array module whilst setting up the tests, which is prone to halt the entire
116 | # test suite if an array module doesn't support a given expression.
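# For example, the "iinfo" entry below maps {"type": "xp.int64"}; it is only evaluated
# inside _test_uninspectable_func via eval("xp.int64", {"xp": xp}), so an unsupported
# expression skips just the affected test rather than breaking collection.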
117 | func_to_specified_args = defaultdict(
118 | dict,
119 | {
120 | "permute_dims": {"axes": 0},
121 | "reshape": {"shape": (1, 5)},
122 | "broadcast_to": {"shape": (1, 5)},
123 | "asarray": {"obj": [0, 1, 2, 3, 4]},
124 | "full_like": {"fill_value": 42},
125 | "matrix_power": {"n": 2},
126 | },
127 | )
128 | func_to_specified_arg_exprs = defaultdict(
129 | dict,
130 | {
131 | "stack": {"arrays": "[xp.ones((5,)), xp.ones((5,))]"},
132 | "iinfo": {"type": "xp.int64"},
133 | "finfo": {"type": "xp.float64"},
134 | "cholesky": {"x": "xp.asarray([[1, 0], [0, 1]], dtype=xp.float64)"},
135 | "inv": {"x": "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)"},
136 | "solve": {
137 | a: "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)" for a in ["x1", "x2"]
138 | },
139 | "outer": {"x1": "xp.ones((5,))", "x2": "xp.ones((5,))"},
140 | },
141 | )
142 | # We default most array arguments heuristically. As functions/methods work only
143 | # with arrays of certain dtypes and shapes, we specify only supported arrays
144 | # respective to the function.
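# Concretely, the loop below falls back to xp.ones((5,)) for most functions, a 0-d
# array for the __bool__/__int__/__float__/__complex__/__index__ methods, and a
# (3, 3) array for linalg-style functions, using the first assumed dtype available.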
145 | casty_names = ["__bool__", "__int__", "__float__", "__complex__", "__index__"]
146 | matrixy_names = [
147 | f.__name__
148 | for f in category_to_funcs["linear_algebra"] + extension_to_funcs["linalg"]
149 | ]
150 | matrixy_names += ["__matmul__", "triu", "tril"]
151 | for func_name, func in name_to_func.items():
152 | stub_sig = signature(func)
153 | array_argnames = set(stub_sig.parameters.keys()) & {"x", "x1", "x2", "other"}
154 | if func in array_methods:
155 | array_argnames.add("self")
156 | array_argnames -= set(func_to_specified_arg_exprs[func_name].keys())
157 | if len(array_argnames) > 0:
158 | in_dtypes = dh.func_in_dtypes[func_name]
159 | for dtype_name in ["float64", "bool", "int64", "complex128"]:
160 | # We try float64 first because uninspectable numerical functions
161 |             # tend to support float inputs first and foremost (e.g. PyTorch)
162 | try:
163 | dtype = getattr(xp, dtype_name)
164 | except AttributeError:
165 | pass
166 | else:
167 | if dtype in in_dtypes:
168 | if func_name in casty_names:
169 | shape = ()
170 | elif func_name in matrixy_names:
171 | shape = (3, 3)
172 | else:
173 | shape = (5,)
174 | fallback_array_expr = f"xp.ones({shape}, dtype=xp.{dtype_name})"
175 | break
176 | else:
177 | warn(
178 |                 f"dh.func_in_dtypes[{func_name!r}]={in_dtypes} seemingly does "
179 | "not contain any assumed dtypes, so skipping specifying fallback array."
180 | )
181 | continue
182 | for argname in array_argnames:
183 | func_to_specified_arg_exprs[func_name][argname] = fallback_array_expr
184 |
185 |
186 | def _test_uninspectable_func(func_name: str, func: Callable, stub_sig: Signature):
187 | params = list(stub_sig.parameters.values())
188 |
189 | if len(params) == 0:
190 | func()
191 | return
192 |
193 | uninspectable_msg = (
194 | f"Note {func_name}() is not inspectable so arguments are passed "
195 | "manually to test the signature."
196 | )
197 |
198 | argname_to_arg = copy(func_to_specified_args[func_name])
199 | argname_to_expr = func_to_specified_arg_exprs[func_name]
200 | for argname, expr in argname_to_expr.items():
201 | assert argname not in argname_to_arg.keys() # sanity check
202 | try:
203 | argname_to_arg[argname] = eval(expr, {"xp": xp})
204 | except Exception as e:
205 | pytest.skip(
206 |                 f"Exception occurred when evaluating {argname}={expr}: {e}\n"
207 | f"{uninspectable_msg}"
208 | )
209 |
210 | posargs = []
211 | posorkw_args = {}
212 | kwargs = {}
213 | no_arg_msg = (
214 | "We have no argument specified for '{}'. Please ensure you're using "
215 | "the latest version of array-api-tests, then open an issue if one "
216 | f"doesn't already exist. {uninspectable_msg}"
217 | )
218 | for param in params:
219 | if param.kind == Parameter.POSITIONAL_ONLY:
220 | try:
221 | posargs.append(argname_to_arg[param.name])
222 | except KeyError:
223 | pytest.skip(no_arg_msg.format(param.name))
224 | elif param.kind == Parameter.POSITIONAL_OR_KEYWORD:
225 | if param.default == Parameter.empty:
226 | try:
227 | posorkw_args[param.name] = argname_to_arg[param.name]
228 | except KeyError:
229 | pytest.skip(no_arg_msg.format(param.name))
230 | else:
231 | assert argname_to_arg[param.name]
232 | posorkw_args[param.name] = param.default
233 | elif param.kind == Parameter.KEYWORD_ONLY:
234 | assert param.default != Parameter.empty # sanity check
235 | kwargs[param.name] = param.default
236 | else:
237 | assert param.kind in VAR_KINDS # sanity check
238 | pytest.skip(no_arg_msg.format(param.name))
239 | if len(posorkw_args) == 0:
240 | func(*posargs, **kwargs)
241 | else:
242 | posorkw_name_to_arg_pairs = list(posorkw_args.items())
243 | for i in range(len(posorkw_name_to_arg_pairs), -1, -1):
244 | extra_posargs = [arg for _, arg in posorkw_name_to_arg_pairs[:i]]
245 | extra_kwargs = dict(posorkw_name_to_arg_pairs[i:])
246 | func(*posargs, *extra_posargs, **kwargs, **extra_kwargs)
247 |
248 |
249 | def _test_func_signature(func: Callable, stub: FunctionType, is_method=False):
250 | stub_sig = signature(stub)
251 | # If testing against array, ignore 'self' arg in stub as it won't be present
252 | # in func (which should be a method).
253 | if is_method:
254 | stub_params = list(stub_sig.parameters.values())
255 | if stub_params[0].name == "self":
256 | del stub_params[0]
257 | stub_sig = Signature(
258 | parameters=stub_params, return_annotation=stub_sig.return_annotation
259 | )
260 |
261 | try:
262 | sig = signature(func)
263 | except ValueError:
264 | try:
265 | _test_uninspectable_func(stub.__name__, func, stub_sig)
266 | except Exception as e:
267 | raise e from None # suppress parent exception for cleaner pytest output
268 | else:
269 | _test_inspectable_func(sig, stub_sig)
270 |
271 |
272 | @pytest.mark.parametrize(
273 | "stub",
274 | [s for stubs in category_to_funcs.values() for s in stubs],
275 | ids=lambda f: f.__name__,
276 | )
277 | def test_func_signature(stub: FunctionType):
278 | assert hasattr(xp, stub.__name__), f"{stub.__name__} not found in array module"
279 | func = getattr(xp, stub.__name__)
280 | _test_func_signature(func, stub)
281 |
282 |
283 | extension_and_stub_params = []
284 | for ext, stubs in extension_to_funcs.items():
285 | for stub in stubs:
286 | p = pytest.param(
287 | ext, stub, id=f"{ext}.{stub.__name__}", marks=pytest.mark.xp_extension(ext)
288 | )
289 | extension_and_stub_params.append(p)
290 |
291 |
292 | @pytest.mark.parametrize("extension, stub", extension_and_stub_params)
293 | def test_extension_func_signature(extension: str, stub: FunctionType):
294 | mod = getattr(xp, extension)
295 | assert hasattr(
296 | mod, stub.__name__
297 | ), f"{stub.__name__} not found in {extension} extension"
298 | func = getattr(mod, stub.__name__)
299 | _test_func_signature(func, stub)
300 |
301 |
302 | @pytest.mark.parametrize("stub", array_methods, ids=lambda f: f.__name__)
303 | def test_array_method_signature(stub: FunctionType):
304 | x_expr = func_to_specified_arg_exprs[stub.__name__]["self"]
305 | try:
306 | x = eval(x_expr, {"xp": xp})
307 | except Exception as e:
308 |         pytest.skip(f"Exception occurred when evaluating x={x_expr}: {e}")
309 | assert hasattr(x, stub.__name__), f"{stub.__name__} not found in array object {x!r}"
310 | method = getattr(x, stub.__name__)
311 | _test_func_signature(method, stub, is_method=True)
312 |
313 | if info_funcs: # pytest fails collecting if info_funcs is empty
314 | @pytest.mark.min_version("2023.12")
315 | @pytest.mark.parametrize("stub", info_funcs, ids=lambda f: f.__name__)
316 | def test_info_func_signature(stub: FunctionType):
317 | try:
318 | info_namespace = xp.__array_namespace_info__()
319 | except Exception as e:
320 | raise AssertionError(f"Could not get info namespace from xp.__array_namespace_info__(): {e}")
321 |
322 | func = getattr(info_namespace, stub.__name__)
323 | _test_func_signature(func, stub)
324 |
--------------------------------------------------------------------------------
/array_api_tests/test_sorting_functions.py:
--------------------------------------------------------------------------------
1 | import cmath
2 | from typing import Set
3 |
4 | import pytest
5 | from hypothesis import given
6 | from hypothesis import strategies as st
7 | from hypothesis.control import assume
8 |
9 | from . import _array_module as xp
10 | from . import dtype_helpers as dh
11 | from . import hypothesis_helpers as hh
12 | from . import pytest_helpers as ph
13 | from . import shape_helpers as sh
14 | from .typing import Scalar, Shape
15 |
16 |
17 | def assert_scalar_in_set(
18 | func_name: str,
19 | idx: Shape,
20 | out: Scalar,
21 | set_: Set[Scalar],
22 | kw={},
23 | ):
24 | out_repr = "out" if idx == () else f"out[{idx}]"
25 | if cmath.isnan(out):
26 | raise NotImplementedError()
27 | msg = f"{out_repr}={out}, but should be in {set_} [{func_name}({ph.fmt_kw(kw)})]"
28 | assert out in set_, msg
29 |
30 |
31 | # TODO: Test with signed zeros and NaNs (and ignore them somehow)
32 | @pytest.mark.unvectorized
33 | @given(
34 | x=hh.arrays(
35 | dtype=hh.real_dtypes,
36 | shape=hh.shapes(min_dims=1, min_side=1),
37 | elements={"allow_nan": False},
38 | ),
39 | data=st.data(),
40 | )
41 | def test_argsort(x, data):
42 | if dh.is_float_dtype(x.dtype):
43 | assume(not xp.any(x == -0.0) and not xp.any(x == +0.0))
44 |
45 | kw = data.draw(
46 | hh.kwargs(
47 | axis=st.integers(-x.ndim, x.ndim - 1),
48 | descending=st.booleans(),
49 | stable=st.booleans(),
50 | ),
51 | label="kw",
52 | )
53 |
54 | out = xp.argsort(x, **kw)
55 |
56 | ph.assert_default_index("argsort", out.dtype)
57 | ph.assert_shape("argsort", out_shape=out.shape, expected=x.shape, kw=kw)
58 | axis = kw.get("axis", -1)
59 | axes = sh.normalize_axis(axis, x.ndim)
60 | scalar_type = dh.get_scalar_type(x.dtype)
61 | for indices in sh.axes_ndindex(x.shape, axes):
62 | elements = [scalar_type(x[idx]) for idx in indices]
63 | orders = list(range(len(elements)))
64 | sorders = sorted(
65 | orders, key=elements.__getitem__, reverse=kw.get("descending", False)
66 | )
67 | if kw.get("stable", True):
68 | for idx, o in zip(indices, sorders):
69 | ph.assert_scalar_equals("argsort", type_=int, idx=idx, out=int(out[idx]), expected=o, kw=kw)
70 | else:
71 | idx_elements = dict(zip(indices, elements))
72 | idx_orders = dict(zip(indices, orders))
73 | element_orders = {}
74 | for e in set(elements):
75 | element_orders[e] = [
76 | idx_orders[idx] for idx in indices if idx_elements[idx] == e
77 | ]
78 | selements = [elements[o] for o in sorders]
79 | for idx, e in zip(indices, selements):
80 | expected_orders = element_orders[e]
81 | out_o = int(out[idx])
82 | if len(expected_orders) == 1:
83 | ph.assert_scalar_equals(
84 | "argsort", type_=int, idx=idx, out=out_o, expected=expected_orders[0], kw=kw
85 | )
86 | else:
87 | assert_scalar_in_set(
88 | "argsort", idx=idx, out=out_o, set_=set(expected_orders), kw=kw
89 | )
90 |
91 |
92 | @pytest.mark.unvectorized
93 | # TODO: Test with signed zeros and NaNs (and ignore them somehow)
94 | @given(
95 | x=hh.arrays(
96 | dtype=hh.real_dtypes,
97 | shape=hh.shapes(min_dims=1, min_side=1),
98 | elements={"allow_nan": False},
99 | ),
100 | data=st.data(),
101 | )
102 | def test_sort(x, data):
103 | if dh.is_float_dtype(x.dtype):
104 | assume(not xp.any(x == -0.0) and not xp.any(x == +0.0))
105 |
106 | kw = data.draw(
107 | hh.kwargs(
108 | axis=st.integers(-x.ndim, x.ndim - 1),
109 | descending=st.booleans(),
110 | stable=st.booleans(),
111 | ),
112 | label="kw",
113 | )
114 |
115 | out = xp.sort(x, **kw)
116 |
117 | ph.assert_dtype("sort", out_dtype=out.dtype, in_dtype=x.dtype)
118 | ph.assert_shape("sort", out_shape=out.shape, expected=x.shape, kw=kw)
119 | axis = kw.get("axis", -1)
120 | axes = sh.normalize_axis(axis, x.ndim)
121 | scalar_type = dh.get_scalar_type(x.dtype)
122 | for indices in sh.axes_ndindex(x.shape, axes):
123 | elements = [scalar_type(x[idx]) for idx in indices]
124 | size = len(elements)
125 | orders = sorted(
126 | range(size), key=elements.__getitem__, reverse=kw.get("descending", False)
127 | )
128 | for out_idx, o in zip(indices, orders):
129 | x_idx = indices[o]
130 | # TODO: error message when unstable should not imply just one idx
131 | ph.assert_0d_equals(
132 | "sort",
133 | x_repr=f"x[{x_idx}]",
134 | x_val=x[x_idx],
135 | out_repr=f"out[{out_idx}]",
136 | out_val=out[out_idx],
137 | kw=kw,
138 | )
139 |
--------------------------------------------------------------------------------
/array_api_tests/test_statistical_functions.py:
--------------------------------------------------------------------------------
1 | import cmath
2 | import math
3 | from typing import Optional
4 |
5 | import pytest
6 | from hypothesis import assume, given
7 | from hypothesis import strategies as st
8 | from ndindex import iter_indices
9 |
10 | from . import _array_module as xp
11 | from . import dtype_helpers as dh
12 | from . import hypothesis_helpers as hh
13 | from . import pytest_helpers as ph
14 | from . import shape_helpers as sh
15 | from ._array_module import _UndefinedStub
16 | from .typing import DataType
17 |
18 |
19 | @pytest.mark.min_version("2023.12")
20 | @pytest.mark.unvectorized
21 | @given(
22 | x=hh.arrays(
23 | dtype=hh.numeric_dtypes,
24 | shape=hh.shapes(min_dims=1)),
25 | data=st.data(),
26 | )
27 | def test_cumulative_sum(x, data):
28 | axes = st.integers(-x.ndim, x.ndim - 1)
29 | if x.ndim == 1:
30 | axes = axes | st.none()
31 | axis = data.draw(axes, label='axis')
32 | _axis, = sh.normalize_axis(axis, x.ndim)
33 | dtype = data.draw(kwarg_dtypes(x.dtype))
34 | include_initial = data.draw(st.booleans(), label="include_initial")
35 |
36 | kw = data.draw(
37 | hh.specified_kwargs(
38 | ("axis", axis, None),
39 | ("dtype", dtype, None),
40 | ("include_initial", include_initial, False),
41 | ),
42 | label="kw",
43 | )
44 |
45 | out = xp.cumulative_sum(x, **kw)
46 |
47 | expected_shape = list(x.shape)
48 | if include_initial:
49 | expected_shape[_axis] += 1
50 | expected_shape = tuple(expected_shape)
51 | ph.assert_shape("cumulative_sum", out_shape=out.shape, expected=expected_shape)
52 |
53 | expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype)
54 | if expected_dtype is None:
 55 |         # If a default uint cannot exist (e.g. in PyTorch, which doesn't support
56 | # uint32 or uint64), we skip testing the output dtype.
57 | # See https://github.com/data-apis/array-api-tests/issues/106
58 | if x.dtype in dh.uint_dtypes:
59 | assert dh.is_int_dtype(out.dtype) # sanity check
60 | else:
61 | ph.assert_dtype("cumulative_sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype)
62 |
63 | scalar_type = dh.get_scalar_type(out.dtype)
64 |
65 | for x_idx, out_idx, in iter_indices(x.shape, expected_shape, skip_axes=_axis):
66 | x_arr = x[x_idx.raw]
67 | out_arr = out[out_idx.raw]
68 |
69 | if include_initial:
70 | ph.assert_scalar_equals("cumulative_sum", type_=scalar_type, idx=out_idx.raw, out=out_arr[0], expected=0)
71 |
72 | for n in range(x.shape[_axis]):
73 | start = 1 if include_initial else 0
74 | out_val = out_arr[n + start]
75 | assume(cmath.isfinite(out_val))
76 | elements = []
77 | for idx in range(n + 1):
78 | s = scalar_type(x_arr[idx])
79 | elements.append(s)
80 | expected = sum(elements)
81 | if dh.is_int_dtype(out.dtype):
82 | m, M = dh.dtype_ranges[out.dtype]
83 | assume(m <= expected <= M)
84 | ph.assert_scalar_equals("cumulative_sum", type_=scalar_type,
85 | idx=out_idx.raw, out=out_val,
86 | expected=expected)
87 | else:
88 | condition_number = _sum_condition_number(elements)
89 | assume(condition_number < 1e6)
90 | ph.assert_scalar_isclose("cumulative_sum", type_=scalar_type,
91 | idx=out_idx.raw, out=out_val,
92 | expected=expected)
93 |
94 |
95 |
96 | @pytest.mark.min_version("2024.12")
97 | @pytest.mark.unvectorized
98 | @given(
99 | x=hh.arrays(
100 | dtype=hh.numeric_dtypes,
101 | shape=hh.shapes(min_dims=1)),
102 | data=st.data(),
103 | )
104 | def test_cumulative_prod(x, data):
105 | axes = st.integers(-x.ndim, x.ndim - 1)
106 | if x.ndim == 1:
107 | axes = axes | st.none()
108 | axis = data.draw(axes, label='axis')
109 | _axis, = sh.normalize_axis(axis, x.ndim)
110 | dtype = data.draw(kwarg_dtypes(x.dtype))
111 | include_initial = data.draw(st.booleans(), label="include_initial")
112 |
113 | kw = data.draw(
114 | hh.specified_kwargs(
115 | ("axis", axis, None),
116 | ("dtype", dtype, None),
117 | ("include_initial", include_initial, False),
118 | ),
119 | label="kw",
120 | )
121 |
122 | out = xp.cumulative_prod(x, **kw)
123 |
124 | expected_shape = list(x.shape)
125 | if include_initial:
126 | expected_shape[_axis] += 1
127 | expected_shape = tuple(expected_shape)
128 | ph.assert_shape("cumulative_prod", out_shape=out.shape, expected=expected_shape)
129 |
130 | expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype)
131 | if expected_dtype is None:
132 |         # If a default uint cannot exist (e.g. in PyTorch, which doesn't support
133 | # uint32 or uint64), we skip testing the output dtype.
134 | # See https://github.com/data-apis/array-api-tests/issues/106
135 | if x.dtype in dh.uint_dtypes:
136 | assert dh.is_int_dtype(out.dtype) # sanity check
137 | else:
138 | ph.assert_dtype("cumulative_prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype)
139 |
140 | scalar_type = dh.get_scalar_type(out.dtype)
141 |
142 | for x_idx, out_idx, in iter_indices(x.shape, expected_shape, skip_axes=_axis):
143 | #x_arr = x[x_idx.raw]
144 | out_arr = out[out_idx.raw]
145 |
146 | if include_initial:
147 | ph.assert_scalar_equals("cumulative_prod", type_=scalar_type, idx=out_idx.raw, out=out_arr[0], expected=1)
148 |
149 | #TODO: add value testing of cumulative_prod
150 |
151 |
152 | def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]:
153 | dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype]
154 | dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]
155 | assert len(dtypes) > 0 # sanity check
156 | return st.none() | st.sampled_from(dtypes)
157 |
158 |
159 | @pytest.mark.unvectorized
160 | @given(
161 | x=hh.arrays(
162 | dtype=hh.real_dtypes,
163 | shape=hh.shapes(min_side=1),
164 | elements={"allow_nan": False},
165 | ),
166 | data=st.data(),
167 | )
168 | def test_max(x, data):
169 | kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw")
170 | keepdims = kw.get("keepdims", False)
171 |
172 | out = xp.max(x, **kw)
173 |
174 | ph.assert_dtype("max", in_dtype=x.dtype, out_dtype=out.dtype)
175 | _axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
176 | ph.assert_keepdimable_shape(
177 | "max", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw
178 | )
179 | scalar_type = dh.get_scalar_type(out.dtype)
180 | for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
181 | max_ = scalar_type(out[out_idx])
182 | elements = []
183 | for idx in indices:
184 | s = scalar_type(x[idx])
185 | elements.append(s)
186 | expected = max(elements)
187 | ph.assert_scalar_equals("max", type_=scalar_type, idx=out_idx, out=max_, expected=expected)
188 |
189 |
190 | @given(
191 | x=hh.arrays(
192 | dtype=hh.real_floating_dtypes,
193 | shape=hh.shapes(min_side=1),
194 | elements={"allow_nan": False},
195 | ),
196 | data=st.data(),
197 | )
198 | def test_mean(x, data):
199 | kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw")
200 | keepdims = kw.get("keepdims", False)
201 |
202 | out = xp.mean(x, **kw)
203 |
204 | ph.assert_dtype("mean", in_dtype=x.dtype, out_dtype=out.dtype)
205 | _axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
206 | ph.assert_keepdimable_shape(
207 | "mean", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw
208 | )
209 | # Values testing mean is too finicky
210 |
211 |
212 | @pytest.mark.unvectorized
213 | @given(
214 | x=hh.arrays(
215 | dtype=hh.real_dtypes,
216 | shape=hh.shapes(min_side=1),
217 | elements={"allow_nan": False},
218 | ),
219 | data=st.data(),
220 | )
221 | def test_min(x, data):
222 | kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw")
223 | keepdims = kw.get("keepdims", False)
224 |
225 | out = xp.min(x, **kw)
226 |
227 | ph.assert_dtype("min", in_dtype=x.dtype, out_dtype=out.dtype)
228 | _axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
229 | ph.assert_keepdimable_shape(
230 | "min", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw
231 | )
232 | scalar_type = dh.get_scalar_type(out.dtype)
233 | for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
234 | min_ = scalar_type(out[out_idx])
235 | elements = []
236 | for idx in indices:
237 | s = scalar_type(x[idx])
238 | elements.append(s)
239 | expected = min(elements)
240 | ph.assert_scalar_equals("min", type_=scalar_type, idx=out_idx, out=min_, expected=expected)
241 |
242 |
243 | def _prod_condition_number(elements):
244 | # Relative condition number using the infinity norm
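    # e.g. [1e-6, 1e6] gives 1e12 (accepted by test_prod's assume(condition_number < 1e15)),
    # while [1e-16, 1.0] gives 1e16 and is rejected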
245 | abs_max = max([abs(i) for i in elements])
246 | abs_min = min([abs(i) for i in elements])
247 |
248 | if abs_min == 0:
249 | return float('inf')
250 |
251 | return abs_max / abs_min
252 |
253 | @pytest.mark.unvectorized
254 | @given(
255 | x=hh.arrays(
256 | dtype=hh.numeric_dtypes,
257 | shape=hh.shapes(min_side=1),
258 | elements={"allow_nan": False},
259 | ),
260 | data=st.data(),
261 | )
262 | def test_prod(x, data):
263 | kw = data.draw(
264 | hh.kwargs(
265 | axis=hh.axes(x.ndim),
266 | dtype=kwarg_dtypes(x.dtype),
267 | keepdims=st.booleans(),
268 | ),
269 | label="kw",
270 | )
271 | keepdims = kw.get("keepdims", False)
272 |
273 | with hh.reject_overflow():
274 | out = xp.prod(x, **kw)
275 |
276 | dtype = kw.get("dtype", None)
277 | expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype)
278 | if expected_dtype is None:
279 |         # If a default uint cannot exist (e.g. in PyTorch, which doesn't support
280 | # uint32 or uint64), we skip testing the output dtype.
281 | # See https://github.com/data-apis/array-api-tests/issues/106
282 | if x.dtype in dh.uint_dtypes:
283 | assert dh.is_int_dtype(out.dtype) # sanity check
284 | else:
285 | ph.assert_dtype("prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype)
286 | _axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
287 | ph.assert_keepdimable_shape(
288 | "prod", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw
289 | )
290 | scalar_type = dh.get_scalar_type(out.dtype)
291 | for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
292 | prod = scalar_type(out[out_idx])
293 | assume(cmath.isfinite(prod))
294 | elements = []
295 | for idx in indices:
296 | s = scalar_type(x[idx])
297 | elements.append(s)
298 | expected = math.prod(elements)
299 | if dh.is_int_dtype(out.dtype):
300 | m, M = dh.dtype_ranges[out.dtype]
301 | assume(m <= expected <= M)
302 | ph.assert_scalar_equals("prod", type_=scalar_type, idx=out_idx,
303 | out=prod, expected=expected)
304 | else:
305 | condition_number = _prod_condition_number(elements)
306 | assume(condition_number < 1e15)
307 | ph.assert_scalar_isclose("prod", type_=scalar_type, idx=out_idx,
308 | out=prod, expected=expected)
309 |
310 |
311 | @pytest.mark.skip(reason="flaky") # TODO: fix!
312 | @given(
313 | x=hh.arrays(
314 | dtype=hh.real_floating_dtypes,
315 | shape=hh.shapes(min_side=1),
316 | elements={"allow_nan": False},
317 | ).filter(lambda x: math.prod(x.shape) >= 2),
318 | data=st.data(),
319 | )
320 | def test_std(x, data):
321 | axis = data.draw(hh.axes(x.ndim), label="axis")
322 | _axes = sh.normalize_axis(axis, x.ndim)
323 | N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes)
324 | correction = data.draw(
325 | st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N),
326 | label="correction",
327 | )
328 | _keepdims = data.draw(st.booleans(), label="keepdims")
329 | kw = data.draw(
330 | hh.specified_kwargs(
331 | ("axis", axis, None),
332 | ("correction", correction, 0.0),
333 | ("keepdims", _keepdims, False),
334 | ),
335 | label="kw",
336 | )
337 | keepdims = kw.get("keepdims", False)
338 |
339 | out = xp.std(x, **kw)
340 |
341 | ph.assert_dtype("std", in_dtype=x.dtype, out_dtype=out.dtype)
342 | ph.assert_keepdimable_shape(
343 | "std", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw
344 | )
345 | # We can't easily test the result(s) as standard deviation methods vary a lot
346 |
347 |
348 | def _sum_condition_number(elements):
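    # Relative condition number of a sum: sum(|x_i|) / |sum(x_i)|. For example,
    # [1.0, 2.0, 3.0] gives 6/6 = 1, while [1e8, 1.0, -1e8] gives roughly 2e8 and is
    # rejected by the callers' assume(condition_number < 1e6).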
349 | sum_abs = sum([abs(i) for i in elements])
350 | abs_sum = abs(sum(elements))
351 |
352 | if abs_sum == 0:
353 | return float('inf')
354 |
355 | return sum_abs / abs_sum
356 |
357 | # @pytest.mark.unvectorized
358 | @given(
359 | x=hh.arrays(
360 | dtype=hh.numeric_dtypes,
361 | shape=hh.shapes(min_side=1),
362 | elements={"allow_nan": False},
363 | ),
364 | data=st.data(),
365 | )
366 | def test_sum(x, data):
367 | kw = data.draw(
368 | hh.kwargs(
369 | axis=hh.axes(x.ndim),
370 | dtype=kwarg_dtypes(x.dtype),
371 | keepdims=st.booleans(),
372 | ),
373 | label="kw",
374 | )
375 | keepdims = kw.get("keepdims", False)
376 |
377 | with hh.reject_overflow():
378 | out = xp.sum(x, **kw)
379 |
380 | dtype = kw.get("dtype", None)
381 | expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype)
382 | if expected_dtype is None:
383 |         # If a default uint cannot exist (e.g. in PyTorch, which doesn't support
384 | # uint32 or uint64), we skip testing the output dtype.
385 | # See https://github.com/data-apis/array-api-tests/issues/160
386 | if x.dtype in dh.uint_dtypes:
387 | assert dh.is_int_dtype(out.dtype) # sanity check
388 | else:
389 | ph.assert_dtype("sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype)
390 | _axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
391 | ph.assert_keepdimable_shape(
392 | "sum", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw
393 | )
394 | scalar_type = dh.get_scalar_type(out.dtype)
395 | for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
396 | sum_ = scalar_type(out[out_idx])
397 | assume(cmath.isfinite(sum_))
398 | elements = []
399 | for idx in indices:
400 | s = scalar_type(x[idx])
401 | elements.append(s)
402 | expected = sum(elements)
403 | if dh.is_int_dtype(out.dtype):
404 | m, M = dh.dtype_ranges[out.dtype]
405 | assume(m <= expected <= M)
406 | ph.assert_scalar_equals("sum", type_=scalar_type, idx=out_idx,
407 | out=sum_, expected=expected)
408 | else:
409 | # Avoid value testing for ill conditioned summations. See
410 | # https://en.wikipedia.org/wiki/Kahan_summation_algorithm#Accuracy and
411 | # https://en.wikipedia.org/wiki/Condition_number.
412 | condition_number = _sum_condition_number(elements)
413 | assume(condition_number < 1e6)
414 | ph.assert_scalar_isclose("sum", type_=scalar_type, idx=out_idx, out=sum_, expected=expected)
415 |
416 |
417 | @pytest.mark.unvectorized
418 | @pytest.mark.skip(reason="flaky") # TODO: fix!
419 | @given(
420 | x=hh.arrays(
421 | dtype=hh.real_floating_dtypes,
422 | shape=hh.shapes(min_side=1),
423 | elements={"allow_nan": False},
424 | ).filter(lambda x: math.prod(x.shape) >= 2),
425 | data=st.data(),
426 | )
427 | def test_var(x, data):
428 | axis = data.draw(hh.axes(x.ndim), label="axis")
429 | _axes = sh.normalize_axis(axis, x.ndim)
430 | N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes)
431 | correction = data.draw(
432 | st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N),
433 | label="correction",
434 | )
435 | _keepdims = data.draw(st.booleans(), label="keepdims")
436 | kw = data.draw(
437 | hh.specified_kwargs(
438 | ("axis", axis, None),
439 | ("correction", correction, 0.0),
440 | ("keepdims", _keepdims, False),
441 | ),
442 | label="kw",
443 | )
444 | keepdims = kw.get("keepdims", False)
445 |
446 | out = xp.var(x, **kw)
447 |
448 | ph.assert_dtype("var", in_dtype=x.dtype, out_dtype=out.dtype)
449 | ph.assert_keepdimable_shape(
450 | "var", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw
451 | )
452 | # We can't easily test the result(s) as variance methods vary a lot
453 |
--------------------------------------------------------------------------------
/array_api_tests/test_utility_functions.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from hypothesis import given
3 | from hypothesis import strategies as st
4 |
5 | from . import _array_module as xp
6 | from . import dtype_helpers as dh
7 | from . import hypothesis_helpers as hh
8 | from . import pytest_helpers as ph
9 | from . import shape_helpers as sh
10 |
11 |
12 | @pytest.mark.unvectorized
13 | @given(
14 | x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_side=1)),
15 | data=st.data(),
16 | )
17 | def test_all(x, data):
18 | kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw")
19 | keepdims = kw.get("keepdims", False)
20 |
21 | out = xp.all(x, **kw)
22 |
23 | ph.assert_dtype("all", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool)
24 | _axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
25 | ph.assert_keepdimable_shape(
26 | "all", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw
27 | )
28 | scalar_type = dh.get_scalar_type(x.dtype)
29 | for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
30 | result = bool(out[out_idx])
31 | elements = []
32 | for idx in indices:
33 | s = scalar_type(x[idx])
34 | elements.append(s)
35 | expected = all(elements)
36 | ph.assert_scalar_equals("all", type_=scalar_type, idx=out_idx,
37 | out=result, expected=expected, kw=kw)
38 |
39 |
40 | @pytest.mark.unvectorized
41 | @given(
42 | x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()),
43 | data=st.data(),
44 | )
45 | def test_any(x, data):
46 | kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw")
47 | keepdims = kw.get("keepdims", False)
48 |
49 | out = xp.any(x, **kw)
50 |
51 | ph.assert_dtype("any", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool)
52 | _axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
53 | ph.assert_keepdimable_shape(
54 | "any", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw,
55 | )
56 | scalar_type = dh.get_scalar_type(x.dtype)
57 | for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
58 | result = bool(out[out_idx])
59 | elements = []
60 | for idx in indices:
61 | s = scalar_type(x[idx])
62 | elements.append(s)
63 | expected = any(elements)
64 | ph.assert_scalar_equals("any", type_=scalar_type, idx=out_idx,
65 | out=result, expected=expected, kw=kw)
66 |
67 |
68 | @pytest.mark.unvectorized
69 | @pytest.mark.min_version("2024.12")
70 | @given(
71 | x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)),
72 | data=st.data(),
73 | )
74 | def test_diff(x, data):
75 | axis = data.draw(
76 | st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(),
77 | label="axis"
78 | )
79 | if axis is None:
80 | axis_kw = {"axis": -1}
81 | n_axis = x.ndim - 1
82 | else:
83 | axis_kw = {"axis": axis}
84 | n_axis = axis + x.ndim if axis < 0 else axis
85 |
86 | n = data.draw(st.integers(1, min(x.shape[n_axis], 3)))
87 |
88 | out = xp.diff(x, **axis_kw, n=n)
89 |
90 | expected_shape = list(x.shape)
91 | expected_shape[n_axis] -= n
92 |
93 | assert out.shape == tuple(expected_shape)
94 |
95 | # value test
96 | if n == 1:
97 | for idx in sh.ndindex(out.shape):
98 | l = list(idx)
99 | l[n_axis] += 1
100 | assert out[idx] == x[tuple(l)] - x[idx], f"diff failed with {idx = }"
101 |
102 |
103 | @pytest.mark.min_version("2024.12")
104 | @pytest.mark.unvectorized
105 | @given(
106 | x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)),
107 | data=st.data(),
108 | )
109 | def test_diff_append_prepend(x, data):
110 | axis = data.draw(
111 | st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(),
112 | label="axis"
113 | )
114 | if axis is None:
115 | axis_kw = {"axis": -1}
116 | n_axis = x.ndim - 1
117 | else:
118 | axis_kw = {"axis": axis}
119 | n_axis = axis + x.ndim if axis < 0 else axis
120 |
121 | n = data.draw(st.integers(1, min(x.shape[n_axis], 3)))
122 |
123 | append_shape = list(x.shape)
124 | append_axis_len = data.draw(st.integers(1, 2*append_shape[n_axis]), label="append_axis")
125 | append_shape[n_axis] = append_axis_len
126 | append = data.draw(hh.arrays(dtype=x.dtype, shape=tuple(append_shape)), label="append")
127 |
128 | prepend_shape = list(x.shape)
129 | prepend_axis_len = data.draw(st.integers(1, 2*prepend_shape[n_axis]), label="prepend_axis")
130 | prepend_shape[n_axis] = prepend_axis_len
131 | prepend = data.draw(hh.arrays(dtype=x.dtype, shape=tuple(prepend_shape)), label="prepend")
132 |
133 | out = xp.diff(x, **axis_kw, n=n, append=append, prepend=prepend)
134 |
135 | in_1 = xp.concat((prepend, x, append), **axis_kw)
136 | out_1 = xp.diff(in_1, **axis_kw, n=n)
137 |
138 | assert out.shape == out_1.shape
139 | for idx in sh.ndindex(out.shape):
140 | assert out[idx] == out_1[idx], f"{idx = }"
141 |
142 |
--------------------------------------------------------------------------------
/array_api_tests/typing.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Tuple, Type, Union
2 |
3 | __all__ = [
4 | "DataType",
5 | "Scalar",
6 | "ScalarType",
7 | "Array",
8 | "Shape",
9 | "AtomicIndex",
10 | "Index",
11 | "Param",
12 | ]
13 |
14 | DataType = Type[Any]
15 | Scalar = Union[bool, int, float, complex]
16 | ScalarType = Union[Type[bool], Type[int], Type[float], Type[complex]]
17 | Array = Any
18 | Shape = Tuple[int, ...]
19 | AtomicIndex = Union[int, "ellipsis", slice, None] # noqa
20 | Index = Union[AtomicIndex, Tuple[AtomicIndex, ...]]
21 | Param = Tuple
22 |
--------------------------------------------------------------------------------
/conftest.py:
--------------------------------------------------------------------------------
1 | from functools import lru_cache
2 | from pathlib import Path
3 | import argparse
4 | import warnings
5 | import os
6 |
7 | from hypothesis import settings
8 | from hypothesis.errors import InvalidArgument
9 | from pytest import mark
10 |
11 | from array_api_tests import _array_module as xp
12 | from array_api_tests import api_version
13 | from array_api_tests._array_module import _UndefinedStub
14 | from array_api_tests.stubs import EXTENSIONS
15 | from array_api_tests import xp_name, xp as array_module
16 |
17 | from reporting import pytest_metadata, pytest_json_modifyreport, add_extra_json_metadata # noqa
18 |
19 | def pytest_report_header(config):
20 | disabled_extensions = config.getoption("--disable-extension")
21 | enabled_extensions = sorted({
22 | ext for ext in EXTENSIONS + ['fft'] if ext not in disabled_extensions and xp_has_ext(ext)
23 | })
24 |
25 | try:
26 | array_module_version = array_module.__version__
27 | except AttributeError:
28 | array_module_version = "version unknown"
29 |
30 | # make it easier to catch typos in environment variables (ARRAY_API_*** instead of ARRAY_API_TESTS_*** etc)
31 | env_vars = "\n".join([f"{k} = {v}" for k, v in os.environ.items() if 'ARRAY_API' in k])
32 | env_vars = f"Environment variables:\n{'-'*22}\n{env_vars}\n\n"
33 |
34 | header1 = f"Array API Tests Module: {xp_name} ({array_module_version}). API Version: {api_version}. Enabled Extensions: {', '.join(enabled_extensions)}"
35 | return env_vars + header1
36 |
37 | def pytest_addoption(parser):
38 | # Hypothesis max examples
39 | # See https://github.com/HypothesisWorks/hypothesis/issues/2434
40 | parser.addoption(
41 | "--hypothesis-max-examples",
42 | "--max-examples",
43 | action="store",
44 | default=100,
45 | type=int,
46 | help="set the Hypothesis max_examples setting",
47 | )
48 | # Hypothesis deadline
49 | parser.addoption(
50 | "--hypothesis-disable-deadline",
51 | "--disable-deadline",
52 | action="store_true",
53 | help="disable the Hypothesis deadline",
54 | )
55 | # Hypothesis derandomize
56 | parser.addoption(
57 | "--hypothesis-derandomize",
58 | "--derandomize",
59 | action="store_true",
60 | help="set the Hypothesis derandomize parameter",
61 | )
62 | # disable extensions
63 | parser.addoption(
64 | "--disable-extension",
65 | metavar="ext",
66 | nargs="+",
67 | default=[],
68 | help="disable testing for Array API extension(s)",
69 | )
70 | # data-dependent shape
71 | parser.addoption(
72 | "--disable-data-dependent-shapes",
73 | "--disable-dds",
74 | action="store_true",
75 | help="disable testing functions with output shapes dependent on input",
76 | )
77 | # CI
78 | parser.addoption("--ci", action="store_true", help=argparse.SUPPRESS ) # deprecated
79 | parser.addoption(
80 | "--skips-file",
81 | action="store",
82 | help="file with tests to skip. Defaults to skips.txt"
83 | )
84 | parser.addoption(
85 | "--xfails-file",
86 | action="store",
87 | help="file with tests to skip. Defaults to xfails.txt"
88 | )
89 |
90 |
91 | def pytest_configure(config):
92 | config.addinivalue_line(
93 | "markers", "xp_extension(ext): tests an Array API extension"
94 | )
95 | config.addinivalue_line(
96 | "markers", "data_dependent_shapes: output shapes are dependent on inputs"
97 | )
98 | config.addinivalue_line(
99 | "markers",
100 | "min_version(api_version): run when greater or equal to api_version",
101 | )
102 | config.addinivalue_line(
103 | "markers",
104 | "unvectorized: asserts against values via element-wise iteration (not performative!)",
105 | )
106 | # Hypothesis
107 | deadline = None if config.getoption("--hypothesis-disable-deadline") else 800
108 | settings.register_profile(
109 | "array-api-tests",
110 | max_examples=config.getoption("--hypothesis-max-examples"),
111 | derandomize=config.getoption("--hypothesis-derandomize"),
112 | deadline=deadline,
113 | )
114 | settings.load_profile("array-api-tests")
115 | # CI
116 | if config.getoption("--ci"):
117 | warnings.warn(
118 | "Custom pytest option --ci is deprecated as any tests not for CI "
119 | "are now located in meta_tests/"
120 | )
121 |
122 |
123 | @lru_cache
124 | def xp_has_ext(ext: str) -> bool:
125 | try:
126 | return not isinstance(getattr(xp, ext), _UndefinedStub)
127 | except AttributeError:
128 | return False
129 |
130 |
131 | def check_id_match(id_, pattern):
132 |     id_ = id_.removeprefix('array-api-tests/')
133 |
134 |     if id_ == pattern:  # exact match of the full test id
135 |         return True
136 |
137 |     if id_.startswith(pattern.removesuffix("/") + "/"):  # pattern is a parent directory
138 |         return True
139 |
140 |     if pattern.endswith(".py") and id_.startswith(pattern):  # pattern is a test file
141 |         return True
142 |
143 |     if id_.split("::", maxsplit=2)[0] == pattern:  # pattern matches the module component
144 |         return True
145 |
146 |     if id_.split("[", maxsplit=2)[0] == pattern:  # pattern matches, ignoring parametrization
147 |         return True
148 |
149 |     return False
150 |
151 |
152 | def get_xfail_mark():
153 | """Skip or xfail tests from the xfails-file.txt."""
154 | m = os.environ.get("ARRAY_API_TESTS_XFAIL_MARK", "xfail")
155 | if m == "xfail":
156 | return mark.xfail
157 | elif m == "skip":
158 | return mark.skip
159 | else:
160 | raise ValueError(
161 |             f'ARRAY_API_TESTS_XFAIL_MARK value should be one of "skip" or "xfail", '
162 | f'got {m} instead.'
163 | )
164 |
165 |
166 | def pytest_collection_modifyitems(config, items):
167 | # 1. Prepare for iterating over items
168 | # -----------------------------------
169 |
170 | skips_file = skips_path = config.getoption('--skips-file')
171 | if skips_file is None:
172 | skips_file = Path(__file__).parent / "skips.txt"
173 | if skips_file.exists():
174 | skips_path = skips_file
175 |
176 | skip_ids = []
177 | if skips_path:
178 | with open(os.path.expanduser(skips_path)) as f:
179 | for line in f:
180 | if line.startswith("array_api_tests"):
181 | id_ = line.strip("\n")
182 | skip_ids.append(id_)
183 |
184 | xfails_file = xfails_path = config.getoption('--xfails-file')
185 | if xfails_file is None:
186 | xfails_file = Path(__file__).parent / "xfails.txt"
187 | if xfails_file.exists():
188 | xfails_path = xfails_file
189 |
190 | xfail_ids = []
191 | if xfails_path:
192 | with open(os.path.expanduser(xfails_path)) as f:
193 | for line in f:
194 | if not line.strip() or line.startswith('#'):
195 | continue
196 | id_ = line.strip("\n")
197 | xfail_ids.append(id_)
198 |
199 | skip_id_matched = {id_: False for id_ in skip_ids}
200 | xfail_id_matched = {id_: False for id_ in xfail_ids}
201 |
202 | disabled_exts = config.getoption("--disable-extension")
203 | disabled_dds = config.getoption("--disable-data-dependent-shapes")
204 | unvectorized_max_examples = max(1, config.getoption("--hypothesis-max-examples")//10)
205 |
206 | # 2. Iterate through items and apply markers accordingly
207 | # ------------------------------------------------------
208 |
209 | xfail_mark = get_xfail_mark()
210 |
211 | for item in items:
212 | markers = list(item.iter_markers())
213 | # skip if specified in skips file
214 | for id_ in skip_ids:
215 | if check_id_match(item.nodeid, id_):
216 | item.add_marker(mark.skip(reason=f"--skips-file ({skips_file})"))
217 | skip_id_matched[id_] = True
218 | break
219 | # xfail if specified in xfails file
220 | for id_ in xfail_ids:
221 | if check_id_match(item.nodeid, id_):
222 | item.add_marker(xfail_mark(reason=f"--xfails-file ({xfails_file})"))
223 | xfail_id_matched[id_] = True
224 | break
225 | # skip if disabled or non-existent extension
226 | ext_mark = next((m for m in markers if m.name == "xp_extension"), None)
227 | if ext_mark is not None:
228 | ext = ext_mark.args[0]
229 | if ext in disabled_exts:
230 | item.add_marker(
231 |                     mark.skip(reason=f"{ext} disabled in --disable-extension")
232 | )
233 | elif not xp_has_ext(ext):
234 | item.add_marker(mark.skip(reason=f"{ext} not found in array module"))
235 | # skip if disabled by dds flag
236 | if disabled_dds:
237 | for m in markers:
238 | if m.name == "data_dependent_shapes":
239 | item.add_marker(
240 | mark.skip(reason="disabled via --disable-data-dependent-shapes")
241 | )
242 | break
243 | # skip if test is for greater api_version
244 | ver_mark = next((m for m in markers if m.name == "min_version"), None)
245 | if ver_mark is not None:
246 | min_version = ver_mark.args[0]
247 | if api_version < min_version:
248 | item.add_marker(
249 | mark.skip(
250 | reason=f"requires ARRAY_API_TESTS_VERSION >= {min_version}"
251 | )
252 | )
253 | # reduce max generated Hypothesis example for unvectorized tests
254 | if any(m.name == "unvectorized" for m in markers):
255 | # TODO: limit generated examples when settings already applied
256 | if not hasattr(item.obj, "_hypothesis_internal_settings_applied"):
257 | try:
258 | item.obj = settings(max_examples=unvectorized_max_examples)(item.obj)
259 | except InvalidArgument as e:
260 | warnings.warn(
261 | f"Tried decorating {item.name} with settings() but got "
262 | f"hypothesis.errors.InvalidArgument: {e}"
263 | )
264 |
265 |
266 | # 3. Warn on bad skipped/xfailed ids
267 | # ----------------------------------
268 |
269 | bad_ids_end_msg = (
270 | "Note the relevant tests might not have been collected by pytest, or "
271 | "another specified id might have already matched a test."
272 | )
273 | bad_skip_ids = [id_ for id_, matched in skip_id_matched.items() if not matched]
274 | if bad_skip_ids:
275 | f_bad_ids = "\n".join(f" {id_}" for id_ in bad_skip_ids)
276 | warnings.warn(
277 | f"{len(bad_skip_ids)} ids in skips file don't match any collected tests: \n"
278 | f"{f_bad_ids}\n"
279 | f"(skips file: {skips_file})\n"
280 | f"{bad_ids_end_msg}"
281 | )
282 | bad_xfail_ids = [id_ for id_, matched in xfail_id_matched.items() if not matched]
283 | if bad_xfail_ids:
284 | f_bad_ids = "\n".join(f" {id_}" for id_ in bad_xfail_ids)
285 | warnings.warn(
286 | f"{len(bad_xfail_ids)} ids in xfails file don't match any collected tests: \n"
287 | f"{f_bad_ids}\n"
288 | f"(xfails file: {xfails_file})\n"
289 | f"{bad_ids_end_msg}"
290 | )
291 |
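292 |
293 | # Illustrative sketch (not part of the suite) of how ids listed in a --skips-file or
294 | # --xfails-file are matched against collected tests by check_id_match() above: an id
295 | # may name a whole test module, a single test, or a test without its parametrization
296 | # (the "[example-param]" id below is hypothetical).
297 | if __name__ == "__main__":
298 |     assert check_id_match(
299 |         "array_api_tests/test_utility_functions.py::test_all",
300 |         "array_api_tests/test_utility_functions.py",
301 |     )
302 |     assert check_id_match(
303 |         "array_api_tests/test_utility_functions.py::test_all[example-param]",
304 |         "array_api_tests/test_utility_functions.py::test_all",
305 |     )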
--------------------------------------------------------------------------------
/meta_tests/README.md:
--------------------------------------------------------------------------------
1 | Testing the utilities used in `array_api_tests/`
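2 |
3 | These meta-tests exercise the suite's own helpers (e.g. `hypothesis_helpers`,
4 | `shape_helpers`) rather than the conformance of an array library, and are run with
5 | pytest like the main suite (`pytest meta_tests/` from the repository root). A minimal
6 | sketch of invoking them programmatically, assuming the packages in `requirements.txt`
7 | are installed:
8 |
9 | ```python
10 | import pytest
11 |
12 | # Collect and run only the meta-tests.
13 | pytest.main(["-v", "meta_tests/"])
14 | ```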
--------------------------------------------------------------------------------
/meta_tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/data-apis/array-api-tests/8c8cb69913d53ca9c2c54be7910d7b0ab0c5f46e/meta_tests/__init__.py
--------------------------------------------------------------------------------
/meta_tests/test_array_helpers.py:
--------------------------------------------------------------------------------
1 | from hypothesis import given
2 | from hypothesis import strategies as st
3 |
4 | from array_api_tests import _array_module as xp
5 | from array_api_tests.hypothesis_helpers import (int_dtypes, arrays,
6 | two_mutually_broadcastable_shapes)
7 | from array_api_tests.shape_helpers import iter_indices, broadcast_shapes
8 | from array_api_tests.array_helpers import exactly_equal, notequal, less
9 |
10 | # TODO: These meta-tests currently only work with NumPy
11 |
12 | def test_exactly_equal():
13 | a = xp.asarray([0, 0., -0., -0., xp.nan, xp.nan, 1, 1])
14 | b = xp.asarray([0, -1, -0., 0., xp.nan, 1, 1, 2])
15 |
16 | res = xp.asarray([True, False, True, False, True, False, True, False])
17 | assert xp.all(xp.equal(exactly_equal(a, b), res))
18 |
19 | def test_notequal():
20 | a = xp.asarray([0, 0., -0., -0., xp.nan, xp.nan, 1, 1])
21 | b = xp.asarray([0, -1, -0., 0., xp.nan, 1, 1, 2])
22 |
23 | res = xp.asarray([False, True, False, False, False, True, False, True])
24 | assert xp.all(xp.equal(notequal(a, b), res))
25 |
26 |
27 | @given(two_mutually_broadcastable_shapes, int_dtypes, int_dtypes, st.data())
28 | def test_less(shapes, dtype1, dtype2, data):
29 | x = data.draw(arrays(shape=shapes[0], dtype=dtype1))
30 | y = data.draw(arrays(shape=shapes[1], dtype=dtype2))
31 |
32 | res = less(x, y)
33 |
34 | for i, j, k in iter_indices(x.shape, y.shape, broadcast_shapes(x.shape, y.shape)):
35 | assert res[k] == (int(x[i]) < int(y[j]))
36 |
--------------------------------------------------------------------------------
/meta_tests/test_broadcasting.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/data-apis/array-api/blob/master/spec/API_specification/broadcasting.md
3 | """
4 |
5 | import pytest
6 |
7 | from array_api_tests import shape_helpers as sh
8 |
9 |
10 | @pytest.mark.parametrize(
11 | "shape1, shape2, expected",
12 | [
13 | [(8, 1, 6, 1), (7, 1, 5), (8, 7, 6, 5)],
14 | [(5, 4), (1,), (5, 4)],
15 | [(5, 4), (4,), (5, 4)],
16 | [(15, 3, 5), (15, 1, 5), (15, 3, 5)],
17 | [(15, 3, 5), (3, 5), (15, 3, 5)],
18 | [(15, 3, 5), (3, 1), (15, 3, 5)],
19 | ],
20 | )
21 | def test_broadcast_shapes(shape1, shape2, expected):
22 | assert sh._broadcast_shapes(shape1, shape2) == expected
23 |
24 |
25 | @pytest.mark.parametrize(
26 | "shape1, shape2",
27 | [
28 | [(3,), (4,)], # dimension does not match
29 | [(2, 1), (8, 4, 3)], # second dimension does not match
30 | [(15, 3, 5), (15, 3)], # singleton dimensions can only be prepended
31 | ],
32 | )
33 | def test_broadcast_shapes_fails_on_bad_shapes(shape1, shape2):
34 | with pytest.raises(sh.BroadcastError):
35 | sh._broadcast_shapes(shape1, shape2)
36 |
--------------------------------------------------------------------------------
/meta_tests/test_equality_mapping.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from array_api_tests.dtype_helpers import EqualityMapping
4 |
5 |
6 | def test_raises_on_distinct_eq_key():
7 | with pytest.raises(ValueError):
8 | EqualityMapping([(float("nan"), "value")])
9 |
10 |
11 | def test_raises_on_indistinct_eq_keys():
12 | class AlwaysEq:
13 | def __init__(self, hash):
14 | self._hash = hash
15 |
16 | def __eq__(self, other):
17 | return True
18 |
19 | def __hash__(self):
20 | return self._hash
21 |
22 | with pytest.raises(ValueError):
23 | EqualityMapping([(AlwaysEq(0), "value1"), (AlwaysEq(1), "value2")])
24 |
25 |
26 | def test_key_error():
27 | mapping = EqualityMapping([("key", "value")])
28 | with pytest.raises(KeyError):
29 | mapping["nonexistent key"]
30 |
31 |
32 | def test_iter():
33 | mapping = EqualityMapping([("key", "value")])
34 | it = iter(mapping)
35 | assert next(it) == "key"
36 | with pytest.raises(StopIteration):
37 | next(it)
38 |
--------------------------------------------------------------------------------
/meta_tests/test_hypothesis_helpers.py:
--------------------------------------------------------------------------------
1 | from math import prod
2 | from typing import Type
3 |
4 | import pytest
5 | from hypothesis import given, settings
6 | from hypothesis import strategies as st
7 | from hypothesis.errors import Unsatisfiable
8 |
9 | from array_api_tests import _array_module as xp
10 | from array_api_tests import array_helpers as ah
11 | from array_api_tests import dtype_helpers as dh
12 | from array_api_tests import hypothesis_helpers as hh
13 | from array_api_tests import shape_helpers as sh
14 | from array_api_tests import xps
15 | from array_api_tests._array_module import _UndefinedStub
16 |
17 | UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dh.all_dtypes)
18 | pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]
19 |
20 | @given(hh.mutually_promotable_dtypes(dtypes=dh.real_float_dtypes))
21 | def test_mutually_promotable_dtypes(pair):
22 | assert pair in (
23 | (xp.float32, xp.float32),
24 | (xp.float32, xp.float64),
25 | (xp.float64, xp.float32),
26 | (xp.float64, xp.float64),
27 | )
28 |
29 |
30 | @given(
31 | hh.mutually_promotable_dtypes(
32 | dtypes=[xp.uint8, _UndefinedStub("uint16"), xp.uint32]
33 | )
34 | )
35 | def test_partial_mutually_promotable_dtypes(pair):
36 | assert pair in (
37 | (xp.uint8, xp.uint8),
38 | (xp.uint8, xp.uint32),
39 | (xp.uint32, xp.uint8),
40 | (xp.uint32, xp.uint32),
41 | )
42 |
43 |
44 | def valid_shape(shape) -> bool:
45 | return (
46 | all(isinstance(side, int) for side in shape)
47 | and all(side >= 0 for side in shape)
48 | and prod(shape) < hh.MAX_ARRAY_SIZE
49 | )
50 |
51 |
52 | @given(hh.shapes())
53 | def test_shapes(shape):
54 | assert valid_shape(shape)
55 |
56 |
57 | @given(hh.two_mutually_broadcastable_shapes)
58 | def test_two_mutually_broadcastable_shapes(pair):
59 | for shape in pair:
60 | assert valid_shape(shape)
61 |
62 |
63 | @given(hh.two_broadcastable_shapes())
64 | def test_two_broadcastable_shapes(pair):
65 | for shape in pair:
66 | assert valid_shape(shape)
67 | assert sh.broadcast_shapes(pair[0], pair[1]) == pair[0]
68 |
69 |
70 | @given(*hh.two_mutual_arrays())
71 | def test_two_mutual_arrays(x1, x2):
72 | assert (x1.dtype, x2.dtype) in dh.promotion_table.keys()
73 |
74 |
75 | def test_two_mutual_arrays_raises_on_bad_dtypes():
76 | with pytest.raises(TypeError):
77 | hh.two_mutual_arrays(dtypes=xps.scalar_dtypes())
78 |
79 |
80 | def test_kwargs():
81 | results = []
82 |
83 | @given(hh.kwargs(n=st.integers(0, 10), c=st.from_regex("[a-f]")))
84 | @settings(max_examples=100)
85 | def run(kw):
86 | results.append(kw)
87 | run()
88 |
89 | assert all(isinstance(kw, dict) for kw in results)
90 | for size in [0, 1, 2]:
91 | assert any(len(kw) == size for kw in results)
92 |
93 | n_results = [kw for kw in results if "n" in kw]
94 | assert len(n_results) > 0
95 | assert all(isinstance(kw["n"], int) for kw in n_results)
96 |
97 | c_results = [kw for kw in results if "c" in kw]
98 | assert len(c_results) > 0
99 | assert all(isinstance(kw["c"], str) for kw in c_results)
100 |
101 |
102 | def test_specified_kwargs():
103 | results = []
104 |
105 | @given(n=st.integers(0, 10), d=st.none() | xps.scalar_dtypes(), data=st.data())
106 | @settings(max_examples=100)
107 | def run(n, d, data):
108 | kw = data.draw(
109 | hh.specified_kwargs(
110 | hh.KVD("n", n, 0),
111 | hh.KVD("d", d, None),
112 | ),
113 | label="kw",
114 | )
115 | results.append(kw)
116 | run()
117 |
118 | assert all(isinstance(kw, dict) for kw in results)
119 |
120 | assert any(len(kw) == 0 for kw in results)
121 |
122 | assert any("n" not in kw.keys() for kw in results)
123 | assert any("n" in kw.keys() and kw["n"] == 0 for kw in results)
124 | assert any("n" in kw.keys() and kw["n"] != 0 for kw in results)
125 |
126 | assert any("d" not in kw.keys() for kw in results)
127 | assert any("d" in kw.keys() and kw["d"] is None for kw in results)
128 | assert any("d" in kw.keys() and kw["d"] is xp.float64 for kw in results)
129 |
130 |
131 | @given(finite=st.booleans(), dtype=xps.floating_dtypes(), data=st.data())
132 | def test_symmetric_matrices(finite, dtype, data):
133 | m = data.draw(hh.symmetric_matrices(st.just(dtype), finite=finite), label="m")
134 | assert m.dtype == dtype
135 | # TODO: This part of this test should be part of the .mT test
136 | ah.assert_exactly_equal(m, m.mT)
137 |
138 | if finite:
139 | ah.assert_finite(m)
140 |
141 |
142 | @given(dtype=xps.floating_dtypes(), data=st.data())
143 | def test_positive_definite_matrices(dtype, data):
144 | m = data.draw(hh.positive_definite_matrices(st.just(dtype)), label="m")
145 | assert m.dtype == dtype
146 | # TODO: Test that it actually is positive definite
147 |
148 |
149 | def make_raising_func(cls: Type[Exception], msg: str):
150 | def raises():
151 | raise cls(msg)
152 |
153 | return raises
154 |
155 | @pytest.mark.parametrize(
156 | "func",
157 | [
158 | make_raising_func(OverflowError, "foo"),
159 | make_raising_func(RuntimeError, "Overflow when unpacking long"),
160 | make_raising_func(Exception, "Got an overflow"),
161 | ]
162 | )
163 | def test_reject_overflow(func):
164 | @given(data=st.data())
165 | def test_case(data):
166 | with hh.reject_overflow():
167 | func()
168 |
169 | with pytest.raises(Unsatisfiable):
170 | test_case()
171 |
--------------------------------------------------------------------------------
/meta_tests/test_linalg.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from hypothesis import given
4 |
5 | from array_api_tests.hypothesis_helpers import symmetric_matrices
6 | from array_api_tests import array_helpers as ah
7 | from array_api_tests import _array_module as xp
8 |
9 | @pytest.mark.xp_extension('linalg')
10 | @given(x=symmetric_matrices(finite=True))
11 | def test_symmetric_matrices(x):
12 | upper = xp.triu(x)
13 | lower = xp.tril(x)
14 | lowerT = ah._matrix_transpose(lower)
15 |
16 | ah.assert_exactly_equal(upper, lowerT)
17 |
--------------------------------------------------------------------------------
/meta_tests/test_partial_adopters.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from hypothesis import given
3 |
4 | from array_api_tests import dtype_helpers as dh
5 | from array_api_tests import hypothesis_helpers as hh
6 | from array_api_tests import _array_module as xp
7 | from array_api_tests._array_module import _UndefinedStub
8 |
9 |
10 | # e.g. PyTorch only supports uint8 currently
11 | @pytest.mark.skipif(isinstance(xp.uint8, _UndefinedStub), reason="uint8 not defined")
12 | @pytest.mark.skipif(
13 | not all(isinstance(d, _UndefinedStub) for d in dh.uint_dtypes[1:]),
14 | reason="uints defined",
15 | )
16 | @given(hh.mutually_promotable_dtypes(dtypes=dh.uint_dtypes))
17 | def test_mutually_promotable_dtypes(pair):
18 | assert pair == (xp.uint8, xp.uint8)
19 |
--------------------------------------------------------------------------------
/meta_tests/test_pytest_helpers.py:
--------------------------------------------------------------------------------
1 | from pytest import raises
2 |
3 | from array_api_tests import xp as _xp
4 | from array_api_tests import _array_module as xp
5 | from array_api_tests import pytest_helpers as ph
6 |
7 |
8 | def test_assert_dtype():
9 | ph.assert_dtype("promoted_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.int16)
10 | with raises(AssertionError):
11 | ph.assert_dtype("bad_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.float32)
12 | ph.assert_dtype("bool_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.bool, expected=xp.bool)
13 | ph.assert_dtype("single_promoted_func", in_dtype=[xp.uint8], out_dtype=xp.uint8)
14 | ph.assert_dtype("single_bool_func", in_dtype=[xp.uint8], out_dtype=xp.bool, expected=xp.bool)
15 |
16 |
17 | def test_assert_array_elements():
18 | ph.assert_array_elements("int zeros", out=xp.asarray(0), expected=xp.asarray(0))
19 | ph.assert_array_elements("pos zeros", out=xp.asarray(0.0), expected=xp.asarray(0.0))
20 | ph.assert_array_elements("neg zeros", out=xp.asarray(-0.0), expected=xp.asarray(-0.0))
21 | if hasattr(_xp, "signbit"):
22 | with raises(AssertionError):
23 | ph.assert_array_elements("mixed sign zeros", out=xp.asarray(0.0), expected=xp.asarray(-0.0))
24 | with raises(AssertionError):
25 | ph.assert_array_elements("mixed sign zeros", out=xp.asarray(-0.0), expected=xp.asarray(0.0))
26 |
27 | ph.assert_array_elements("nans", out=xp.asarray(float("nan")), expected=xp.asarray(float("nan")))
28 | with raises(AssertionError):
29 | ph.assert_array_elements("nan and zero", out=xp.asarray(float("nan")), expected=xp.asarray(0.0))
30 |
--------------------------------------------------------------------------------
/meta_tests/test_signatures.py:
--------------------------------------------------------------------------------
1 | from inspect import Parameter, Signature, signature
2 |
3 | import pytest
4 |
5 | from array_api_tests.test_signatures import _test_inspectable_func
6 |
7 |
8 | def stub(foo, /, bar=None, *, baz=None):
9 | pass
10 |
11 |
12 | stub_sig = signature(stub)
13 |
14 |
15 | @pytest.mark.parametrize(
16 | "sig",
17 | [
18 | Signature(
19 | [
20 | Parameter("foo", Parameter.POSITIONAL_ONLY),
21 | Parameter("bar", Parameter.POSITIONAL_OR_KEYWORD),
22 | Parameter("baz", Parameter.KEYWORD_ONLY),
23 | ]
24 | ),
25 | Signature(
26 | [
27 | Parameter("foo", Parameter.POSITIONAL_ONLY),
28 | Parameter("bar", Parameter.POSITIONAL_OR_KEYWORD),
29 | Parameter("baz", Parameter.POSITIONAL_OR_KEYWORD),
30 | ]
31 | ),
32 | Signature(
33 | [
34 | Parameter("foo", Parameter.POSITIONAL_ONLY),
35 | Parameter("bar", Parameter.POSITIONAL_OR_KEYWORD),
36 | Parameter("qux", Parameter.KEYWORD_ONLY),
37 | Parameter("baz", Parameter.KEYWORD_ONLY),
38 | ]
39 | ),
40 | ],
41 | )
42 | def test_good_sig_passes(sig):
43 | _test_inspectable_func(sig, stub_sig)
44 |
45 |
46 | @pytest.mark.parametrize(
47 | "sig",
48 | [
49 | Signature(
50 | [
51 | Parameter("foo", Parameter.POSITIONAL_ONLY),
52 | Parameter("bar", Parameter.POSITIONAL_ONLY),
53 | Parameter("baz", Parameter.KEYWORD_ONLY),
54 | ]
55 | ),
56 | Signature(
57 | [
58 | Parameter("foo", Parameter.POSITIONAL_ONLY),
59 | Parameter("bar", Parameter.KEYWORD_ONLY),
60 | Parameter("baz", Parameter.KEYWORD_ONLY),
61 | ]
62 | ),
63 | ],
64 | )
65 | def test_raises_on_bad_sig(sig):
66 | with pytest.raises(AssertionError):
67 | _test_inspectable_func(sig, stub_sig)
68 |
--------------------------------------------------------------------------------
/meta_tests/test_special_cases.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from array_api_tests.test_special_cases import parse_result
4 |
5 |
6 | def test_parse_result():
7 | check_result, _ = parse_result(
8 | "an implementation-dependent approximation to ``+3π/4``"
9 | )
10 | assert check_result(3 * math.pi / 4)
11 |
--------------------------------------------------------------------------------
/meta_tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from hypothesis import given
3 | from hypothesis import strategies as st
4 |
5 | from array_api_tests import _array_module as xp
6 | from array_api_tests import dtype_helpers as dh
7 | from array_api_tests import hypothesis_helpers as hh
8 | from array_api_tests import shape_helpers as sh
9 | from array_api_tests import xps
10 | from array_api_tests.test_creation_functions import frange
11 | from array_api_tests.test_manipulation_functions import roll_ndindex
12 | from array_api_tests.test_operators_and_elementwise_functions import mock_int_dtype
13 |
14 |
15 | @pytest.mark.parametrize(
16 | "r, size, elements",
17 | [
18 | (frange(0, 1, 1), 1, [0]),
19 | (frange(1, 0, -1), 1, [1]),
20 | (frange(0, 1, -1), 0, []),
21 | (frange(0, 1, 2), 1, [0]),
22 | ],
23 | )
24 | def test_frange(r, size, elements):
25 | assert len(r) == size
26 | assert list(r) == elements
27 |
28 |
29 | @pytest.mark.parametrize(
30 | "shape, expected",
31 | [((), [()])],
32 | )
33 | def test_ndindex(shape, expected):
34 | assert list(sh.ndindex(shape)) == expected
35 |
36 |
37 | @pytest.mark.parametrize(
38 | "shape, axis, expected",
39 | [
40 | ((1,), 0, [(slice(None, None),)]),
41 | ((1, 2), 0, [(slice(None, None), slice(None, None))]),
42 | (
43 | (2, 4),
44 | 1,
45 | [(0, slice(None, None)), (1, slice(None, None))],
46 | ),
47 | ],
48 | )
49 | def test_axis_ndindex(shape, axis, expected):
50 | assert list(sh.axis_ndindex(shape, axis)) == expected
51 |
52 |
53 | @pytest.mark.parametrize(
54 | "shape, axes, expected",
55 | [
56 | ((), (), [[()]]),
57 | ((1,), (0,), [[(0,)]]),
58 | (
59 | (2, 2),
60 | (0,),
61 | [
62 | [(0, 0), (1, 0)],
63 | [(0, 1), (1, 1)],
64 | ],
65 | ),
66 | ],
67 | )
68 | def test_axes_ndindex(shape, axes, expected):
69 | assert list(sh.axes_ndindex(shape, axes)) == expected
70 |
71 |
72 | @pytest.mark.parametrize(
73 | "shape, shifts, axes, expected",
74 | [
75 | ((1, 1), (0,), (0,), [(0, 0)]),
76 | ((2, 1), (1, 1), (0, 1), [(1, 0), (0, 0)]),
77 | ((2, 2), (1, 1), (0, 1), [(1, 1), (1, 0), (0, 1), (0, 0)]),
78 | ((2, 2), (-1, 1), (0, 1), [(1, 1), (1, 0), (0, 1), (0, 0)]),
79 | ],
80 | )
81 | def test_roll_ndindex(shape, shifts, axes, expected):
82 | assert list(roll_ndindex(shape, shifts, axes)) == expected
83 |
84 |
85 | @pytest.mark.parametrize(
86 | "idx, expected",
87 | [
88 | ((), "x"),
89 | (42, "x[42]"),
90 | ((42,), "x[42]"),
91 | ((42, 7), "x[42, 7]"),
92 | (slice(None, 2), "x[:2]"),
93 | (slice(2, None), "x[2:]"),
94 | (slice(0, 2), "x[0:2]"),
95 | (slice(0, 2, -1), "x[0:2:-1]"),
96 | (slice(None, None, -1), "x[::-1]"),
97 | (slice(None, None), "x[:]"),
98 | (..., "x[...]"),
99 | ((None, 42), "x[None, 42]"),
100 | ],
101 | )
102 | def test_fmt_idx(idx, expected):
103 | assert sh.fmt_idx("x", idx) == expected
104 |
105 |
106 | @given(x=st.integers(), dtype=xps.unsigned_integer_dtypes() | xps.integer_dtypes())
107 | def test_int_to_dtype(x, dtype):
108 | with hh.reject_overflow():
109 | d = xp.asarray(x, dtype=dtype)
110 | assert mock_int_dtype(x, dtype) == d
111 |
112 |
113 | @given(hh.oneway_promotable_dtypes(dh.all_dtypes))
114 | def test_oneway_promotable_dtypes(D):
115 | assert D.result_dtype == dh.result_type(*D)
116 |
117 |
118 | @given(hh.oneway_broadcastable_shapes())
119 | def test_oneway_broadcastable_shapes(S):
120 | assert S.result_shape == sh.broadcast_shapes(*S)
121 |
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | filterwarnings =
3 | # Ignore floating-point warnings from NumPy
4 | ignore:invalid value encountered in:RuntimeWarning
5 | ignore:overflow encountered in:RuntimeWarning
6 | ignore:divide by zero encountered in:RuntimeWarning
7 |
8 |
9 |
--------------------------------------------------------------------------------
/reporting.py:
--------------------------------------------------------------------------------
1 | from array_api_tests.dtype_helpers import dtype_to_name
2 | from array_api_tests import _array_module as xp
3 | from array_api_tests import __version__
4 |
5 | from collections import Counter
6 | from types import BuiltinFunctionType, FunctionType
7 | import dataclasses
8 | import json
9 | import warnings
10 |
11 | from hypothesis.strategies import SearchStrategy
12 |
13 | from pytest import hookimpl, fixture
14 | try:
15 | import pytest_jsonreport # noqa
16 | except ImportError:
17 | raise ImportError("pytest-json-report is required to run the array API tests")
18 |
19 | def to_json_serializable(o):
20 | if o in dtype_to_name:
21 | return dtype_to_name[o]
22 | if isinstance(o, (BuiltinFunctionType, FunctionType, type)):
23 | return o.__name__
24 | if dataclasses.is_dataclass(o):
25 | return to_json_serializable(dataclasses.asdict(o))
26 | if isinstance(o, SearchStrategy):
27 | return repr(o)
28 | if isinstance(o, dict):
29 | return {to_json_serializable(k): to_json_serializable(v) for k, v in o.items()}
30 | if isinstance(o, tuple):
31 | if hasattr(o, '_asdict'): # namedtuple
32 | return to_json_serializable(o._asdict())
33 | return tuple(to_json_serializable(i) for i in o)
34 | if isinstance(o, list):
35 | return [to_json_serializable(i) for i in o]
36 | if callable(o):
37 | return repr(o)
38 |
39 | # Ensure everything is JSON serializable. If this warning is issued, it
40 | # means the given type needs to be added above if possible.
41 | try:
42 | json.dumps(o)
43 | except TypeError:
44 | warnings.warn(f"{o!r} (of type {type(o)}) is not JSON-serializable. Using the repr instead.")
45 | return repr(o)
46 |
47 | return o
48 |
49 | @hookimpl(optionalhook=True)
50 | def pytest_metadata(metadata):
51 | """
52 | Additional global metadata for --json-report.
53 | """
54 | metadata['array_api_tests_module'] = xp.__name__
55 | metadata['array_api_tests_version'] = __version__
56 |
57 | @fixture(autouse=True)
58 | def add_extra_json_metadata(request, json_metadata):
59 | """
60 | Additional per-test metadata for --json-report
61 | """
62 | def add_metadata(name, obj):
63 | obj = to_json_serializable(obj)
64 | json_metadata[name] = obj
65 |
66 | test_module = request.module.__name__
67 | if test_module.startswith('array_api_tests.meta'):
68 | return
69 |
70 | test_function = request.function.__name__
71 | assert test_function.startswith('test_'), 'unexpected test function name'
72 |
73 | if test_module == 'array_api_tests.test_has_names':
74 | array_api_function_name = None
75 | else:
76 | array_api_function_name = test_function[len('test_'):]
77 |
78 | add_metadata('test_module', test_module)
79 | add_metadata('test_function', test_function)
80 | add_metadata('array_api_function_name', array_api_function_name)
81 |
82 | if hasattr(request.node, 'callspec'):
83 | params = request.node.callspec.params
84 | add_metadata('params', params)
85 |
86 | def finalizer():
87 | # TODO: This metadata is all in the form of error strings. It might be
88 | # nice to extract the hypothesis failing inputs directly somehow.
89 | if hasattr(request.node, 'hypothesis_report_information'):
90 | add_metadata('hypothesis_report_information', request.node.hypothesis_report_information)
91 | if hasattr(request.node, 'hypothesis_statistics'):
92 | add_metadata('hypothesis_statistics', request.node.hypothesis_statistics)
93 |
94 | request.addfinalizer(finalizer)
95 |
96 | @hookimpl(optionalhook=True)
97 | def pytest_json_modifyreport(json_report):
98 | # Deduplicate warnings. These duplicate warnings can cause the file size
99 | # to become huge. For instance, a warning from np.bool which is emitted
100 | # every time hypothesis runs (over a million times) causes the warnings
101 | # JSON for a plain numpy namespace run to be over 500MB.
102 |
103 | # This will lose information about what order the warnings were issued in,
104 | # but that isn't particularly helpful anyway since the warning metadata
105 | # doesn't store a full stack of where it was issued from. The resulting
106 | # warnings will be in order of the first time each warning is issued since
107 | # collections.Counter is ordered just like dict().
108 | counted_warnings = Counter([frozenset(i.items()) for i in json_report.get('warnings', dict())])
109 | deduped_warnings = [{**dict(i), 'count': counted_warnings[i]} for i in counted_warnings]
110 |
111 | json_report['warnings'] = deduped_warnings
112 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytest
2 | pytest-json-report
3 | hypothesis>=6.130.5
4 | ndindex>=1.8
5 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 |
2 | # See the docstring in versioneer.py for instructions. Note that you must
3 | # re-run 'versioneer.py setup' after changing this section, and commit the
4 | # resulting files.
5 |
6 | [versioneer]
7 | VCS = git
8 | style = pep440
9 | versionfile_source = array_api_tests/_version.py
10 | versionfile_build = array_api_tests/_version.py
11 | tag_prefix =
12 | parentdir_prefix =
13 |
--------------------------------------------------------------------------------