├── pytest.ini ├── array_api_tests ├── meta │ ├── __init__.py │ ├── test_special_cases.py │ ├── test_partial_adopters.py │ ├── test_array_helpers.py │ ├── test_equality_mapping.py │ ├── test_broadcasting.py │ ├── test_pytest_helpers.py │ ├── test_signatures.py │ ├── test_utils.py │ └── test_hypothesis_helpers.py ├── typing.py ├── algos.py ├── _array_module.py ├── test_has_names.py ├── test_constants.py ├── test_indexing_functions.py ├── __init__.py ├── test_utility_functions.py ├── stubs.py ├── test_sorting_functions.py ├── shape_helpers.py ├── test_searching_functions.py ├── test_data_type_functions.py ├── test_fft.py ├── test_set_functions.py ├── test_array_object.py ├── array_helpers.py ├── test_signatures.py ├── test_statistical_functions.py ├── test_manipulation_functions.py ├── pytest_helpers.py └── dtype_helpers.py ├── _config.yml ├── .gitattributes ├── requirements.txt ├── MANIFEST.in ├── .gitmodules ├── .pre-commit-config.yaml ├── setup.cfg ├── .github └── workflows │ ├── lint.yml │ └── numpy.yml ├── LICENSE ├── .gitignore ├── numpy-skips.txt ├── reporting.py ├── conftest.py └── README.md /pytest.ini: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /array_api_tests/meta/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-slate -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | array_api_tests/_version.py export-subst 2 | -------------------------------------------------------------------------------- /requirements.txt:
-------------------------------------------------------------------------------- 1 | pytest 2 | pytest-json-report 3 | hypothesis>=6.68.0 4 | ndindex>=1.6 5 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include versioneer.py 2 | include array_api_tests/_version.py 3 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "array_api_tests/array-api"] 2 | path = array-api 3 | url = https://github.com/data-apis/array-api/ 4 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pycqa/flake8 3 | rev: '4.0.1' 4 | hooks: 5 | - id: flake8 6 | args: [--select, F] 7 | -------------------------------------------------------------------------------- /array_api_tests/meta/test_special_cases.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from ..test_special_cases import parse_result 4 | 5 | 6 | def test_parse_result(): 7 | check_result, _ = parse_result( 8 | "an implementation-dependent approximation to ``+3π/4``" 9 | ) 10 | assert check_result(3 * math.pi / 4) 11 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | 2 | # See the docstring in versioneer.py for instructions. Note that you must 3 | # re-run 'versioneer.py setup' after changing this section, and commit the 4 | # resulting files.
5 | 6 | [versioneer] 7 | VCS = git 8 | style = pep440 9 | versionfile_source = array_api_tests/_version.py 10 | versionfile_build = array_api_tests/_version.py 11 | tag_prefix = 12 | parentdir_prefix = 13 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Linting 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: [3.8, 3.9] 12 | 13 | steps: 14 | - uses: actions/checkout@v1 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: actions/setup-python@v1 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | - name: Run pre-commit hook 20 | uses: pre-commit/action@v2.0.3 21 | -------------------------------------------------------------------------------- /array_api_tests/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple, Type, Union 2 | 3 | __all__ = [ 4 | "DataType", 5 | "Scalar", 6 | "ScalarType", 7 | "Array", 8 | "Shape", 9 | "AtomicIndex", 10 | "Index", 11 | "Param", 12 | ] 13 | 14 | DataType = Type[Any] 15 | Scalar = Union[bool, int, float, complex] 16 | ScalarType = Union[Type[bool], Type[int], Type[float], Type[complex]] 17 | Array = Any 18 | Shape = Tuple[int, ...] 19 | AtomicIndex = Union[int, "ellipsis", slice, None] # noqa 20 | Index = Union[AtomicIndex, Tuple[AtomicIndex, ...]] 21 | Param = Tuple 22 | -------------------------------------------------------------------------------- /array_api_tests/meta/test_partial_adopters.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from hypothesis import given 3 | 4 | from .. import dtype_helpers as dh 5 | from .. import hypothesis_helpers as hh 6 | from .. 
import _array_module as xp 7 | from .._array_module import _UndefinedStub 8 | 9 | 10 | # e.g. PyTorch only supports uint8 currently 11 | @pytest.mark.skipif(isinstance(xp.uint8, _UndefinedStub), reason="uint8 not defined") 12 | @pytest.mark.skipif( 13 | not all(isinstance(d, _UndefinedStub) for d in dh.uint_dtypes[1:]), 14 | reason="uints defined", 15 | ) 16 | @given(hh.mutually_promotable_dtypes(dtypes=dh.uint_dtypes)) 17 | def test_mutually_promotable_dtypes(pair): 18 | assert pair == (xp.uint8, xp.uint8) 19 | -------------------------------------------------------------------------------- /array_api_tests/meta/test_array_helpers.py: -------------------------------------------------------------------------------- 1 | from .. import _array_module as xp 2 | from ..array_helpers import exactly_equal, notequal 3 | 4 | # TODO: These meta-tests currently only work with NumPy 5 | 6 | def test_exactly_equal(): 7 | a = xp.asarray([0, 0., -0., -0., xp.nan, xp.nan, 1, 1]) 8 | b = xp.asarray([0, -1, -0., 0., xp.nan, 1, 1, 2]) 9 | 10 | res = xp.asarray([True, False, True, False, True, False, True, False]) 11 | assert xp.all(xp.equal(exactly_equal(a, b), res)) 12 | 13 | def test_notequal(): 14 | a = xp.asarray([0, 0., -0., -0., xp.nan, xp.nan, 1, 1]) 15 | b = xp.asarray([0, -1, -0., 0., xp.nan, 1, 1, 2]) 16 | 17 | res = xp.asarray([False, True, False, False, False, True, False, True]) 18 | assert xp.all(xp.equal(notequal(a, b), res)) 19 | 20 | -------------------------------------------------------------------------------- /.github/workflows/numpy.yml: -------------------------------------------------------------------------------- 1 | name: NumPy Array API 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: ["3.10", "3.11"] 12 | 13 | steps: 14 | - name: Checkout array-api-tests 15 | uses: actions/checkout@v1 16 | with: 17 | submodules: 'true' 18 | - name: Set up Python ${{ 
matrix.python-version }} 19 | uses: actions/setup-python@v1 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | python -m pip install numpy==1.26.2 26 | python -m pip install -r requirements.txt 27 | - name: Run the test suite 28 | env: 29 | ARRAY_API_TESTS_MODULE: numpy.array_api 30 | run: | 31 | pytest -v -rxXfE --ci --skips-file numpy-skips.txt 32 | -------------------------------------------------------------------------------- /array_api_tests/meta/test_equality_mapping.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from ..dtype_helpers import EqualityMapping 4 | 5 | 6 | def test_raises_on_distinct_eq_key(): 7 | with pytest.raises(ValueError): 8 | EqualityMapping([(float("nan"), "value")]) 9 | 10 | 11 | def test_raises_on_indistinct_eq_keys(): 12 | class AlwaysEq: 13 | def __init__(self, hash): 14 | self._hash = hash 15 | 16 | def __eq__(self, other): 17 | return True 18 | 19 | def __hash__(self): 20 | return self._hash 21 | 22 | with pytest.raises(ValueError): 23 | EqualityMapping([(AlwaysEq(0), "value1"), (AlwaysEq(1), "value2")]) 24 | 25 | 26 | def test_key_error(): 27 | mapping = EqualityMapping([("key", "value")]) 28 | with pytest.raises(KeyError): 29 | mapping["nonexistent key"] 30 | 31 | 32 | def test_iter(): 33 | mapping = EqualityMapping([("key", "value")]) 34 | it = iter(mapping) 35 | assert next(it) == "key" 36 | with pytest.raises(StopIteration): 37 | next(it) 38 | -------------------------------------------------------------------------------- /array_api_tests/meta/test_broadcasting.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/data-apis/array-api/blob/master/spec/API_specification/broadcasting.md 3 | """ 4 | 5 | import pytest 6 | 7 | from .. 
import shape_helpers as sh 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "shape1, shape2, expected", 12 | [ 13 | [(8, 1, 6, 1), (7, 1, 5), (8, 7, 6, 5)], 14 | [(5, 4), (1,), (5, 4)], 15 | [(5, 4), (4,), (5, 4)], 16 | [(15, 3, 5), (15, 1, 5), (15, 3, 5)], 17 | [(15, 3, 5), (3, 5), (15, 3, 5)], 18 | [(15, 3, 5), (3, 1), (15, 3, 5)], 19 | ], 20 | ) 21 | def test_broadcast_shapes(shape1, shape2, expected): 22 | assert sh._broadcast_shapes(shape1, shape2) == expected 23 | 24 | 25 | @pytest.mark.parametrize( 26 | "shape1, shape2", 27 | [ 28 | [(3,), (4,)], # dimension does not match 29 | [(2, 1), (8, 4, 3)], # second dimension does not match 30 | [(15, 3, 5), (15, 3)], # singleton dimensions can only be prepended 31 | ], 32 | ) 33 | def test_broadcast_shapes_fails_on_bad_shapes(shape1, shape2): 34 | with pytest.raises(sh.BroadcastError): 35 | sh._broadcast_shapes(shape1, shape2) 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Consortium for Python Data API Standards contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /array_api_tests/meta/test_pytest_helpers.py: -------------------------------------------------------------------------------- 1 | from pytest import raises 2 | 3 | from .. import _array_module as xp 4 | from .. import pytest_helpers as ph 5 | 6 | 7 | def test_assert_dtype(): 8 | ph.assert_dtype("promoted_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.int16) 9 | with raises(AssertionError): 10 | ph.assert_dtype("bad_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.float32) 11 | ph.assert_dtype("bool_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.bool, expected=xp.bool) 12 | ph.assert_dtype("single_promoted_func", in_dtype=[xp.uint8], out_dtype=xp.uint8) 13 | ph.assert_dtype("single_bool_func", in_dtype=[xp.uint8], out_dtype=xp.bool, expected=xp.bool) 14 | 15 | 16 | def test_assert_array_elements(): 17 | ph.assert_array_elements("int zeros", out=xp.asarray(0), expected=xp.asarray(0)) 18 | ph.assert_array_elements("pos zeros", out=xp.asarray(0.0), expected=xp.asarray(0.0)) 19 | with raises(AssertionError): 20 | ph.assert_array_elements("mixed sign zeros", out=xp.asarray(0.0), expected=xp.asarray(-0.0)) 21 | with raises(AssertionError): 22 | ph.assert_array_elements("mixed sign zeros", out=xp.asarray(-0.0), expected=xp.asarray(0.0)) 23 | -------------------------------------------------------------------------------- /array_api_tests/algos.py: -------------------------------------------------------------------------------- 1 | __all__ = ["broadcast_shapes"] 2 | 3 | 4 | from .typing import Shape 5 | 6 | 7 | # We use a custom exception to differentiate from potential bugs 8 | class 
BroadcastError(ValueError): 9 | pass 10 | 11 | 12 | def _broadcast_shapes(shape1: Shape, shape2: Shape) -> Shape: 13 | """Broadcasts `shape1` and `shape2`""" 14 | N1 = len(shape1) 15 | N2 = len(shape2) 16 | N = max(N1, N2) 17 | shape = [None for _ in range(N)] 18 | i = N - 1 19 | while i >= 0: 20 | n1 = N1 - N + i 21 | if N1 - N + i >= 0: 22 | d1 = shape1[n1] 23 | else: 24 | d1 = 1 25 | n2 = N2 - N + i 26 | if N2 - N + i >= 0: 27 | d2 = shape2[n2] 28 | else: 29 | d2 = 1 30 | 31 | if d1 == 1: 32 | shape[i] = d2 33 | elif d2 == 1: 34 | shape[i] = d1 35 | elif d1 == d2: 36 | shape[i] = d1 37 | else: 38 | raise BroadcastError 39 | 40 | i = i - 1 41 | 42 | return tuple(shape) 43 | 44 | 45 | def broadcast_shapes(*shapes: Shape): 46 | if len(shapes) == 0: 47 | raise ValueError("shapes=[] must be non-empty") 48 | elif len(shapes) == 1: 49 | return shapes[0] 50 | result = _broadcast_shapes(shapes[0], shapes[1]) 51 | for i in range(2, len(shapes)): 52 | result = _broadcast_shapes(result, shapes[i]) 53 | return result 54 | -------------------------------------------------------------------------------- /array_api_tests/_array_module.py: -------------------------------------------------------------------------------- 1 | from . import stubs, xp 2 | 3 | 4 | class _UndefinedStub: 5 | """ 6 | Standing for undefined names, so the tests can be imported even if they 7 | fail 8 | 9 | If this object appears in a test failure, it means a name is not defined 10 | in a function. This typically happens for things like dtype literals not 11 | being defined. 
12 | 13 | """ 14 | def __init__(self, name): 15 | self.name = name 16 | 17 | def _raise(self, *args, **kwargs): 18 | raise AssertionError(f"{self.name} is not defined in {xp.__name__}") 19 | 20 | def __repr__(self): 21 | return f"" 22 | 23 | __call__ = _raise 24 | __getattr__ = _raise 25 | 26 | _dtypes = [ 27 | "bool", 28 | "uint8", "uint16", "uint32", "uint64", 29 | "int8", "int16", "int32", "int64", 30 | "float32", "float64", 31 | "complex64", "complex128", 32 | ] 33 | _constants = ["e", "inf", "nan", "pi"] 34 | _funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs] 35 | _funcs += ["take", "isdtype", "conj", "imag", "real"] # TODO: bump spec and update array-api-tests to new spec layout 36 | _top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS + ["fft"] 37 | 38 | for attr in _top_level_attrs: 39 | try: 40 | globals()[attr] = getattr(xp, attr) 41 | except AttributeError: 42 | globals()[attr] = _UndefinedStub(attr) 43 | -------------------------------------------------------------------------------- /array_api_tests/test_has_names.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a very basic test to see what names are defined in a library. It 3 | does not even require functioning hypothesis array_api support. 4 | """ 5 | 6 | import pytest 7 | 8 | from . 
import xp 9 | from .stubs import (array_attributes, array_methods, category_to_funcs, 10 | extension_to_funcs, EXTENSIONS) 11 | 12 | pytestmark = pytest.mark.ci 13 | 14 | has_name_params = [] 15 | for ext, stubs in extension_to_funcs.items(): 16 | for stub in stubs: 17 | has_name_params.append(pytest.param(ext, stub.__name__)) 18 | for cat, stubs in category_to_funcs.items(): 19 | for stub in stubs: 20 | has_name_params.append(pytest.param(cat, stub.__name__)) 21 | for meth in array_methods: 22 | has_name_params.append(pytest.param('array_method', meth.__name__)) 23 | for attr in array_attributes: 24 | has_name_params.append(pytest.param('array_attribute', attr)) 25 | 26 | @pytest.mark.parametrize("category, name", has_name_params) 27 | def test_has_names(category, name): 28 | if category in EXTENSIONS: 29 | ext_mod = getattr(xp, category) 30 | assert hasattr(ext_mod, name), f"{xp.__name__} is missing the {category} extension function {name}()" 31 | elif category.startswith('array_'): 32 | # TODO: This would fail if ones() is missing. 33 | arr = xp.ones((1, 1)) 34 | if category == 'array_attribute': 35 | assert hasattr(arr, name), f"The {xp.__name__} array object is missing the attribute {name}" 36 | else: 37 | assert hasattr(arr, name), f"The {xp.__name__} array object is missing the method {name}()" 38 | else: 39 | assert hasattr(xp, name), f"{xp.__name__} is missing the {category} function {name}()" 40 | -------------------------------------------------------------------------------- /array_api_tests/test_constants.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Any, SupportsFloat 3 | 4 | import pytest 5 | 6 | from . import dtype_helpers as dh 7 | from . 
import xp 8 | from .typing import Array 9 | 10 | pytestmark = pytest.mark.ci 11 | 12 | 13 | def assert_scalar_float(name: str, c: Any): 14 | assert isinstance(c, SupportsFloat), f"{name}={c!r} does not look like a float" 15 | 16 | 17 | def assert_0d_float(name: str, x: Array): 18 | assert dh.is_float_dtype( 19 | x.dtype 20 | ), f"xp.asarray(xp.{name})={x!r}, but should have float dtype" 21 | 22 | 23 | @pytest.mark.parametrize("name, n", [("e", math.e), ("pi", math.pi)]) 24 | def test_irrational_numbers(name, n): 25 | assert hasattr(xp, name) 26 | c = getattr(xp, name) 27 | assert_scalar_float(name, c) 28 | floor = math.floor(n) 29 | assert c > floor, f"xp.{name}={c!r} <= {floor}" 30 | ceil = math.ceil(n) 31 | assert c < ceil, f"xp.{name}={c!r} >= {ceil}" 32 | x = xp.asarray(c) 33 | assert_0d_float("name", x) 34 | 35 | 36 | def test_inf(): 37 | assert hasattr(xp, "inf") 38 | assert_scalar_float("inf", xp.inf) 39 | assert math.isinf(xp.inf) 40 | assert xp.inf > 0, "xp.inf not greater than 0" 41 | x = xp.asarray(xp.inf) 42 | assert_0d_float("inf", x) 43 | assert xp.isinf(x), "xp.isinf(xp.asarray(xp.inf))=False" 44 | 45 | 46 | def test_nan(): 47 | assert hasattr(xp, "nan") 48 | assert_scalar_float("nan", xp.nan) 49 | assert math.isnan(xp.nan) 50 | assert xp.nan != xp.nan, "xp.nan should not have equality with itself" 51 | x = xp.asarray(xp.nan) 52 | assert_0d_float("nan", x) 53 | assert xp.isnan(x), "xp.isnan(xp.asarray(xp.nan))=False" 54 | 55 | 56 | def test_newaxis(): 57 | assert hasattr(xp, "newaxis") 58 | assert xp.newaxis is None 59 | -------------------------------------------------------------------------------- /array_api_tests/meta/test_signatures.py: -------------------------------------------------------------------------------- 1 | from inspect import Parameter, Signature, signature 2 | 3 | import pytest 4 | 5 | from ..test_signatures import _test_inspectable_func 6 | 7 | 8 | def stub(foo, /, bar=None, *, baz=None): 9 | pass 10 | 11 | 12 | stub_sig = 
signature(stub) 13 | 14 | 15 | @pytest.mark.parametrize( 16 | "sig", 17 | [ 18 | Signature( 19 | [ 20 | Parameter("foo", Parameter.POSITIONAL_ONLY), 21 | Parameter("bar", Parameter.POSITIONAL_OR_KEYWORD), 22 | Parameter("baz", Parameter.KEYWORD_ONLY), 23 | ] 24 | ), 25 | Signature( 26 | [ 27 | Parameter("foo", Parameter.POSITIONAL_ONLY), 28 | Parameter("bar", Parameter.POSITIONAL_OR_KEYWORD), 29 | Parameter("baz", Parameter.POSITIONAL_OR_KEYWORD), 30 | ] 31 | ), 32 | Signature( 33 | [ 34 | Parameter("foo", Parameter.POSITIONAL_ONLY), 35 | Parameter("bar", Parameter.POSITIONAL_OR_KEYWORD), 36 | Parameter("qux", Parameter.KEYWORD_ONLY), 37 | Parameter("baz", Parameter.KEYWORD_ONLY), 38 | ] 39 | ), 40 | ], 41 | ) 42 | def test_good_sig_passes(sig): 43 | _test_inspectable_func(sig, stub_sig) 44 | 45 | 46 | @pytest.mark.parametrize( 47 | "sig", 48 | [ 49 | Signature( 50 | [ 51 | Parameter("foo", Parameter.POSITIONAL_ONLY), 52 | Parameter("bar", Parameter.POSITIONAL_ONLY), 53 | Parameter("baz", Parameter.KEYWORD_ONLY), 54 | ] 55 | ), 56 | Signature( 57 | [ 58 | Parameter("foo", Parameter.POSITIONAL_ONLY), 59 | Parameter("bar", Parameter.KEYWORD_ONLY), 60 | Parameter("baz", Parameter.KEYWORD_ONLY), 61 | ] 62 | ), 63 | ], 64 | ) 65 | def test_raises_on_bad_sig(sig): 66 | with pytest.raises(AssertionError): 67 | _test_inspectable_func(sig, stub_sig) 68 | -------------------------------------------------------------------------------- /array_api_tests/test_indexing_functions.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from hypothesis import given, note 3 | from hypothesis import strategies as st 4 | 5 | from . import _array_module as xp 6 | from . import dtype_helpers as dh 7 | from . import hypothesis_helpers as hh 8 | from . import pytest_helpers as ph 9 | from . import shape_helpers as sh 10 | from . 
import xps 11 | 12 | pytestmark = pytest.mark.ci 13 | 14 | 15 | @pytest.mark.min_version("2022.12") 16 | @given( 17 | x=xps.arrays(xps.scalar_dtypes(), hh.shapes(min_dims=1, min_side=1)), 18 | data=st.data(), 19 | ) 20 | def test_take(x, data): 21 | # TODO: 22 | # * negative axis 23 | # * negative indices 24 | # * different dtypes for indices 25 | axis = data.draw(st.integers(0, max(x.ndim - 1, 0)), label="axis") 26 | _indices = data.draw( 27 | st.lists(st.integers(0, x.shape[axis] - 1), min_size=1, unique=True), 28 | label="_indices", 29 | ) 30 | indices = xp.asarray(_indices, dtype=dh.default_int) 31 | note(f"{indices=}") 32 | 33 | out = xp.take(x, indices, axis=axis) 34 | 35 | ph.assert_dtype("take", in_dtype=x.dtype, out_dtype=out.dtype) 36 | ph.assert_shape( 37 | "take", 38 | out_shape=out.shape, 39 | expected=x.shape[:axis] + (len(_indices),) + x.shape[axis + 1 :], 40 | kw=dict( 41 | x=x, 42 | indices=indices, 43 | axis=axis, 44 | ), 45 | ) 46 | out_indices = sh.ndindex(out.shape) 47 | axis_indices = list(sh.axis_ndindex(x.shape, axis)) 48 | for axis_idx in axis_indices: 49 | f_axis_idx = sh.fmt_idx("x", axis_idx) 50 | for i in _indices: 51 | f_take_idx = sh.fmt_idx(f_axis_idx, i) 52 | indexed_x = x[axis_idx][i, ...] 53 | for at_idx in sh.ndindex(indexed_x.shape): 54 | out_idx = next(out_indices) 55 | ph.assert_0d_equals( 56 | "take", 57 | x_repr=sh.fmt_idx(f_take_idx, at_idx), 58 | x_val=indexed_x[at_idx], 59 | out_repr=sh.fmt_idx("out", out_idx), 60 | out_val=out[out_idx], 61 | ) 62 | # sanity check 63 | with pytest.raises(StopIteration): 64 | next(out_indices) 65 | -------------------------------------------------------------------------------- /array_api_tests/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import wraps 3 | from importlib import import_module 4 | 5 | from hypothesis import strategies as st 6 | from hypothesis.extra import array_api 7 | 8 | from . 
import _version 9 | 10 | __all__ = ["xp", "api_version", "xps"] 11 | 12 | 13 | # You can comment the following out and instead import the specific array module 14 | # you want to test, e.g. `import numpy.array_api as xp`. 15 | if "ARRAY_API_TESTS_MODULE" in os.environ: 16 | xp_name = os.environ["ARRAY_API_TESTS_MODULE"] 17 | _module, _sub = xp_name, None 18 | if "." in xp_name: 19 | _module, _sub = xp_name.split(".", 1) 20 | xp = import_module(_module) 21 | if _sub: 22 | try: 23 | xp = getattr(xp, _sub) 24 | except AttributeError: 25 | # _sub may be a submodule that needs to be imported. WE can't 26 | # do this in every case because some array modules are not 27 | # submodules that can be imported (like mxnet.nd). 28 | xp = import_module(xp_name) 29 | else: 30 | raise RuntimeError( 31 | "No array module specified - either edit __init__.py or set the " 32 | "ARRAY_API_TESTS_MODULE environment variable." 33 | ) 34 | 35 | 36 | # We monkey patch floats() to always disable subnormals as they are out-of-scope 37 | 38 | _floats = st.floats 39 | 40 | 41 | @wraps(_floats) 42 | def floats(*a, **kw): 43 | kw["allow_subnormal"] = False 44 | return _floats(*a, **kw) 45 | 46 | 47 | st.floats = floats 48 | 49 | 50 | # We do the same with xps.from_dtype() - this is not strictly necessary, as 51 | # the underlying floats() will never generate subnormals. We only do this 52 | # because internal logic in xps.from_dtype() assumes xp.finfo() has its 53 | # attributes as scalar floats, which is expected behaviour but disrupts many 54 | # unrelated tests. 
55 | try: 56 | __from_dtype = array_api._from_dtype 57 | 58 | @wraps(__from_dtype) 59 | def _from_dtype(*a, **kw): 60 | kw["allow_subnormal"] = False 61 | return __from_dtype(*a, **kw) 62 | 63 | array_api._from_dtype = _from_dtype 64 | except AttributeError: 65 | # Ignore monkey patching if Hypothesis changes the private API 66 | pass 67 | 68 | 69 | api_version = os.getenv( 70 | "ARRAY_API_TESTS_VERSION", getattr(xp, "__array_api_version__", "2021.12") 71 | ) 72 | xps = array_api.make_strategies_namespace(xp, api_version=api_version) 73 | 74 | __version__ = _version.get_versions()["version"] 75 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # pytest-json-report 132 | .report.json 133 | -------------------------------------------------------------------------------- /array_api_tests/test_utility_functions.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from hypothesis import given 3 | from hypothesis import strategies as st 4 | 5 | from . import _array_module as xp 6 | from . import dtype_helpers as dh 7 | from . import hypothesis_helpers as hh 8 | from . import pytest_helpers as ph 9 | from . import shape_helpers as sh 10 | from . 
import xps 11 | 12 | pytestmark = pytest.mark.ci 13 | 14 | 15 | @given( 16 | x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)), 17 | data=st.data(), 18 | ) 19 | def test_all(x, data): 20 | kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") 21 | keepdims = kw.get("keepdims", False) 22 | 23 | out = xp.all(x, **kw) 24 | 25 | ph.assert_dtype("all", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) 26 | _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) 27 | ph.assert_keepdimable_shape( 28 | "all", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw 29 | ) 30 | scalar_type = dh.get_scalar_type(x.dtype) 31 | for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): 32 | result = bool(out[out_idx]) 33 | elements = [] 34 | for idx in indices: 35 | s = scalar_type(x[idx]) 36 | elements.append(s) 37 | expected = all(elements) 38 | ph.assert_scalar_equals("all", type_=scalar_type, idx=out_idx, 39 | out=result, expected=expected, kw=kw) 40 | 41 | 42 | @given( 43 | x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), 44 | data=st.data(), 45 | ) 46 | def test_any(x, data): 47 | kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") 48 | keepdims = kw.get("keepdims", False) 49 | 50 | out = xp.any(x, **kw) 51 | 52 | ph.assert_dtype("any", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) 53 | _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) 54 | ph.assert_keepdimable_shape( 55 | "any", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw, 56 | ) 57 | scalar_type = dh.get_scalar_type(x.dtype) 58 | for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): 59 | result = bool(out[out_idx]) 60 | elements = [] 61 | for idx in indices: 62 | s = scalar_type(x[idx]) 63 | elements.append(s) 64 | expected = any(elements) 65 | ph.assert_scalar_equals("any", type_=scalar_type, 
idx=out_idx, 66 | out=result, expected=expected, kw=kw) 67 | -------------------------------------------------------------------------------- /array_api_tests/stubs.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import sys 3 | from importlib import import_module 4 | from importlib.util import find_spec 5 | from pathlib import Path 6 | from types import FunctionType, ModuleType 7 | from typing import Dict, List 8 | 9 | from . import api_version 10 | 11 | __all__ = [ 12 | "name_to_func", 13 | "array_methods", 14 | "array_attributes", 15 | "category_to_funcs", 16 | "EXTENSIONS", 17 | "extension_to_funcs", 18 | ] 19 | 20 | spec_module = "_" + api_version.replace('.', '_') 21 | 22 | spec_dir = Path(__file__).parent.parent / "array-api" / "spec" / api_version / "API_specification" 23 | assert spec_dir.exists(), f"{spec_dir} not found - try `git submodule update --init`" 24 | sigs_dir = Path(__file__).parent.parent / "array-api" / "src" / "array_api_stubs" / spec_module 25 | assert sigs_dir.exists() 26 | 27 | sigs_abs_path: str = str(sigs_dir.parent.parent.resolve()) 28 | sys.path.append(sigs_abs_path) 29 | assert find_spec(f"array_api_stubs.{spec_module}") is not None 30 | 31 | name_to_mod: Dict[str, ModuleType] = {} 32 | for path in sigs_dir.glob("*.py"): 33 | name = path.name.replace(".py", "") 34 | name_to_mod[name] = import_module(f"array_api_stubs.{spec_module}.{name}") 35 | 36 | array = name_to_mod["array_object"].array 37 | array_methods = [ 38 | f for n, f in inspect.getmembers(array, predicate=inspect.isfunction) 39 | if n != "__init__" # probably exists for Sphinx 40 | ] 41 | array_attributes = [ 42 | n for n, f in inspect.getmembers(array, predicate=lambda x: isinstance(x, property)) 43 | ] 44 | 45 | category_to_funcs: Dict[str, List[FunctionType]] = {} 46 | for name, mod in name_to_mod.items(): 47 | if name.endswith("_functions"): 48 | category = name.replace("_functions", "") 49 | objects = 
[getattr(mod, name) for name in mod.__all__] 50 | assert all(isinstance(o, FunctionType) for o in objects) # sanity check 51 | category_to_funcs[category] = objects 52 | 53 | all_funcs = [] 54 | for funcs in [array_methods, *category_to_funcs.values()]: 55 | all_funcs.extend(funcs) 56 | name_to_func: Dict[str, FunctionType] = {f.__name__: f for f in all_funcs} 57 | 58 | EXTENSIONS: List[str] = ["linalg"] # TODO: add "fft" once stubs available 59 | extension_to_funcs: Dict[str, List[FunctionType]] = {} 60 | for ext in EXTENSIONS: 61 | mod = name_to_mod[ext] 62 | objects = [getattr(mod, name) for name in mod.__all__] 63 | assert all(isinstance(o, FunctionType) for o in objects) # sanity check 64 | funcs = [] 65 | for func in objects: 66 | if "Alias" in func.__doc__: 67 | funcs.append(name_to_func[func.__name__]) 68 | else: 69 | funcs.append(func) 70 | extension_to_funcs[ext] = funcs 71 | 72 | for funcs in extension_to_funcs.values(): 73 | for func in funcs: 74 | if func.__name__ not in name_to_func.keys(): 75 | name_to_func[func.__name__] = func 76 | 77 | # sanity check public attributes are not empty 78 | for attr in __all__: 79 | assert len(locals()[attr]) != 0, f"{attr} is empty" 80 | -------------------------------------------------------------------------------- /array_api_tests/meta/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from hypothesis import given 3 | from hypothesis import strategies as st 4 | 5 | from .. import _array_module as xp 6 | from .. import dtype_helpers as dh 7 | from .. import hypothesis_helpers as hh 8 | from .. import shape_helpers as sh 9 | from .. 
import xps 10 | from ..test_creation_functions import frange 11 | from ..test_manipulation_functions import roll_ndindex 12 | from ..test_operators_and_elementwise_functions import mock_int_dtype 13 | 14 | 15 | @pytest.mark.parametrize( 16 | "r, size, elements", 17 | [ 18 | (frange(0, 1, 1), 1, [0]), 19 | (frange(1, 0, -1), 1, [1]), 20 | (frange(0, 1, -1), 0, []), 21 | (frange(0, 1, 2), 1, [0]), 22 | ], 23 | ) 24 | def test_frange(r, size, elements): 25 | assert len(r) == size 26 | assert list(r) == elements 27 | 28 | 29 | @pytest.mark.parametrize( 30 | "shape, expected", 31 | [((), [()])], 32 | ) 33 | def test_ndindex(shape, expected): 34 | assert list(sh.ndindex(shape)) == expected 35 | 36 | 37 | @pytest.mark.parametrize( 38 | "shape, axis, expected", 39 | [ 40 | ((1,), 0, [(slice(None, None),)]), 41 | ((1, 2), 0, [(slice(None, None), slice(None, None))]), 42 | ( 43 | (2, 4), 44 | 1, 45 | [(0, slice(None, None)), (1, slice(None, None))], 46 | ), 47 | ], 48 | ) 49 | def test_axis_ndindex(shape, axis, expected): 50 | assert list(sh.axis_ndindex(shape, axis)) == expected 51 | 52 | 53 | @pytest.mark.parametrize( 54 | "shape, axes, expected", 55 | [ 56 | ((), (), [[()]]), 57 | ((1,), (0,), [[(0,)]]), 58 | ( 59 | (2, 2), 60 | (0,), 61 | [ 62 | [(0, 0), (1, 0)], 63 | [(0, 1), (1, 1)], 64 | ], 65 | ), 66 | ], 67 | ) 68 | def test_axes_ndindex(shape, axes, expected): 69 | assert list(sh.axes_ndindex(shape, axes)) == expected 70 | 71 | 72 | @pytest.mark.parametrize( 73 | "shape, shifts, axes, expected", 74 | [ 75 | ((1, 1), (0,), (0,), [(0, 0)]), 76 | ((2, 1), (1, 1), (0, 1), [(1, 0), (0, 0)]), 77 | ((2, 2), (1, 1), (0, 1), [(1, 1), (1, 0), (0, 1), (0, 0)]), 78 | ((2, 2), (-1, 1), (0, 1), [(1, 1), (1, 0), (0, 1), (0, 0)]), 79 | ], 80 | ) 81 | def test_roll_ndindex(shape, shifts, axes, expected): 82 | assert list(roll_ndindex(shape, shifts, axes)) == expected 83 | 84 | 85 | @pytest.mark.parametrize( 86 | "idx, expected", 87 | [ 88 | ((), "x"), 89 | (42, "x[42]"), 90 | 
((42,), "x[42]"), 91 | ((42, 7), "x[42, 7]"), 92 | (slice(None, 2), "x[:2]"), 93 | (slice(2, None), "x[2:]"), 94 | (slice(0, 2), "x[0:2]"), 95 | (slice(0, 2, -1), "x[0:2:-1]"), 96 | (slice(None, None, -1), "x[::-1]"), 97 | (slice(None, None), "x[:]"), 98 | (..., "x[...]"), 99 | ((None, 42), "x[None, 42]"), 100 | ], 101 | ) 102 | def test_fmt_idx(idx, expected): 103 | assert sh.fmt_idx("x", idx) == expected 104 | 105 | 106 | @given(x=st.integers(), dtype=xps.unsigned_integer_dtypes() | xps.integer_dtypes()) 107 | def test_int_to_dtype(x, dtype): 108 | with hh.reject_overflow(): 109 | d = xp.asarray(x, dtype=dtype) 110 | assert mock_int_dtype(x, dtype) == d 111 | 112 | 113 | @given(hh.oneway_promotable_dtypes(dh.all_dtypes)) 114 | def test_oneway_promotable_dtypes(D): 115 | assert D.result_dtype == dh.result_type(*D) 116 | 117 | 118 | @given(hh.oneway_broadcastable_shapes()) 119 | def test_oneway_broadcastable_shapes(S): 120 | assert S.result_shape == sh.broadcast_shapes(*S) 121 | -------------------------------------------------------------------------------- /numpy-skips.txt: -------------------------------------------------------------------------------- 1 | # copy not implemented 2 | array_api_tests/test_creation_functions.py::test_asarray_arrays 3 | # https://github.com/numpy/numpy/issues/20870 4 | array_api_tests/test_data_type_functions.py::test_can_cast 5 | # The return dtype for trace is not consistent in the spec 6 | # https://github.com/data-apis/array-api/issues/202#issuecomment-952529197 7 | array_api_tests/test_linalg.py::test_trace 8 | # waiting on NumPy to allow/revert distinct NaNs for np.unique 9 | # https://github.com/numpy/numpy/issues/20326#issuecomment-1012380448 10 | array_api_tests/test_set_functions.py 11 | 12 | # newaxis not included in numpy namespace as of v1.26.2 13 | array_api_tests/test_constants.py::test_newaxis 14 | 15 | # linalg.solve issue in numpy.array_api as of v1.26.2 (see numpy#25146) 16 | 
array_api_tests/test_linalg.py::test_solve 17 | 18 | # https://github.com/numpy/numpy/issues/21373 19 | array_api_tests/test_array_object.py::test_getitem 20 | 21 | # missing copy arg 22 | array_api_tests/test_signatures.py::test_func_signature[reshape] 23 | 24 | # https://github.com/numpy/numpy/issues/21211 25 | array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0] 26 | # https://github.com/numpy/numpy/issues/21213 27 | array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] 28 | array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] 29 | # noted diversions from spec 30 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] 31 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] 32 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] 33 | array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] 34 | array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] 35 | array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] 36 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] 37 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] 38 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] 39 | 
array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] 40 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] 41 | array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] 42 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] 43 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] 44 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] 45 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] 46 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] 47 | array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] 48 | -------------------------------------------------------------------------------- /reporting.py: -------------------------------------------------------------------------------- 1 | from array_api_tests.dtype_helpers import dtype_to_name 2 | from array_api_tests import _array_module as xp 3 | from array_api_tests import __version__ 4 | 5 | from collections import Counter 6 | from types import BuiltinFunctionType, FunctionType 7 | import dataclasses 8 | import json 9 | import warnings 10 | 11 | from hypothesis.strategies import SearchStrategy 12 | 13 | from pytest import hookimpl, fixture 14 | try: 15 | import pytest_jsonreport # noqa 16 | except ImportError: 17 | raise ImportError("pytest-json-report is required to run the array API tests") 18 | 19 | def to_json_serializable(o): 20 | if o in dtype_to_name: 21 | return 
dtype_to_name[o] 22 | if isinstance(o, (BuiltinFunctionType, FunctionType, type)): 23 | return o.__name__ 24 | if dataclasses.is_dataclass(o): 25 | return to_json_serializable(dataclasses.asdict(o)) 26 | if isinstance(o, SearchStrategy): 27 | return repr(o) 28 | if isinstance(o, dict): 29 | return {to_json_serializable(k): to_json_serializable(v) for k, v in o.items()} 30 | if isinstance(o, tuple): 31 | if hasattr(o, '_asdict'): # namedtuple 32 | return to_json_serializable(o._asdict()) 33 | return tuple(to_json_serializable(i) for i in o) 34 | if isinstance(o, list): 35 | return [to_json_serializable(i) for i in o] 36 | 37 | # Ensure everything is JSON serializable. If this warning is issued, it 38 | # means the given type needs to be added above if possible. 39 | try: 40 | json.dumps(o) 41 | except TypeError: 42 | warnings.warn(f"{o!r} (of type {type(o)}) is not JSON-serializable. Using the repr instead.") 43 | return repr(o) 44 | 45 | return o 46 | 47 | @hookimpl(optionalhook=True) 48 | def pytest_metadata(metadata): 49 | """ 50 | Additional global metadata for --json-report. 
51 | """ 52 | metadata['array_api_tests_module'] = xp.__name__ 53 | metadata['array_api_tests_version'] = __version__ 54 | 55 | @fixture(autouse=True) 56 | def add_extra_json_metadata(request, json_metadata): 57 | """ 58 | Additional per-test metadata for --json-report 59 | """ 60 | def add_metadata(name, obj): 61 | obj = to_json_serializable(obj) 62 | json_metadata[name] = obj 63 | 64 | test_module = request.module.__name__ 65 | if test_module.startswith('array_api_tests.meta'): 66 | return 67 | 68 | test_function = request.function.__name__ 69 | assert test_function.startswith('test_'), 'unexpected test function name' 70 | 71 | if test_module == 'array_api_tests.test_has_names': 72 | array_api_function_name = None 73 | else: 74 | array_api_function_name = test_function[len('test_'):] 75 | 76 | add_metadata('test_module', test_module) 77 | add_metadata('test_function', test_function) 78 | add_metadata('array_api_function_name', array_api_function_name) 79 | 80 | if hasattr(request.node, 'callspec'): 81 | params = request.node.callspec.params 82 | add_metadata('params', params) 83 | 84 | def finalizer(): 85 | # TODO: This metadata is all in the form of error strings. It might be 86 | # nice to extract the hypothesis failing inputs directly somehow. 87 | if hasattr(request.node, 'hypothesis_report_information'): 88 | add_metadata('hypothesis_report_information', request.node.hypothesis_report_information) 89 | if hasattr(request.node, 'hypothesis_statistics'): 90 | add_metadata('hypothesis_statistics', request.node.hypothesis_statistics) 91 | 92 | request.addfinalizer(finalizer) 93 | 94 | @hookimpl(optionalhook=True) 95 | def pytest_json_modifyreport(json_report): 96 | # Deduplicate warnings. These duplicate warnings can cause the file size 97 | # to become huge. For instance, a warning from np.bool which is emitted 98 | # every time hypothesis runs (over a million times) causes the warnings 99 | # JSON for a plain numpy namespace run to be over 500MB. 
100 | 101 | # This will lose information about what order the warnings were issued in, 102 | # but that isn't particularly helpful anyway since the warning metadata 103 | # doesn't store a full stack of where it was issued from. The resulting 104 | # warnings will be in order of the first time each warning is issued since 105 | # collections.Counter is ordered just like dict(). 106 | counted_warnings = Counter([frozenset(i.items()) for i in json_report['warnings']]) 107 | deduped_warnings = [{**dict(i), 'count': counted_warnings[i]} for i in counted_warnings] 108 | 109 | json_report['warnings'] = deduped_warnings 110 | -------------------------------------------------------------------------------- /array_api_tests/test_sorting_functions.py: -------------------------------------------------------------------------------- 1 | import cmath 2 | from typing import Set 3 | 4 | import pytest 5 | from hypothesis import given 6 | from hypothesis import strategies as st 7 | from hypothesis.control import assume 8 | 9 | from . import _array_module as xp 10 | from . import dtype_helpers as dh 11 | from . import hypothesis_helpers as hh 12 | from . import pytest_helpers as ph 13 | from . import shape_helpers as sh 14 | from . 
import xps 15 | from .typing import Scalar, Shape 16 | 17 | pytestmark = pytest.mark.ci 18 | 19 | 20 | def assert_scalar_in_set( 21 | func_name: str, 22 | idx: Shape, 23 | out: Scalar, 24 | set_: Set[Scalar], 25 | kw={}, 26 | ): 27 | out_repr = "out" if idx == () else f"out[{idx}]" 28 | if cmath.isnan(out): 29 | raise NotImplementedError() 30 | msg = f"{out_repr}={out}, but should be in {set_} [{func_name}({ph.fmt_kw(kw)})]" 31 | assert out in set_, msg 32 | 33 | 34 | # TODO: Test with signed zeros and NaNs (and ignore them somehow) 35 | @given( 36 | x=xps.arrays( 37 | dtype=xps.real_dtypes(), 38 | shape=hh.shapes(min_dims=1, min_side=1), 39 | elements={"allow_nan": False}, 40 | ), 41 | data=st.data(), 42 | ) 43 | def test_argsort(x, data): 44 | if dh.is_float_dtype(x.dtype): 45 | assume(not xp.any(x == -0.0) and not xp.any(x == +0.0)) 46 | 47 | kw = data.draw( 48 | hh.kwargs( 49 | axis=st.integers(-x.ndim, x.ndim - 1), 50 | descending=st.booleans(), 51 | stable=st.booleans(), 52 | ), 53 | label="kw", 54 | ) 55 | 56 | out = xp.argsort(x, **kw) 57 | 58 | ph.assert_default_index("argsort", out.dtype) 59 | ph.assert_shape("argsort", out_shape=out.shape, expected=x.shape, kw=kw) 60 | axis = kw.get("axis", -1) 61 | axes = sh.normalise_axis(axis, x.ndim) 62 | scalar_type = dh.get_scalar_type(x.dtype) 63 | for indices in sh.axes_ndindex(x.shape, axes): 64 | elements = [scalar_type(x[idx]) for idx in indices] 65 | orders = list(range(len(elements))) 66 | sorders = sorted( 67 | orders, key=elements.__getitem__, reverse=kw.get("descending", False) 68 | ) 69 | if kw.get("stable", True): 70 | for idx, o in zip(indices, sorders): 71 | ph.assert_scalar_equals("argsort", type_=int, idx=idx, out=int(out[idx]), expected=o, kw=kw) 72 | else: 73 | idx_elements = dict(zip(indices, elements)) 74 | idx_orders = dict(zip(indices, orders)) 75 | element_orders = {} 76 | for e in set(elements): 77 | element_orders[e] = [ 78 | idx_orders[idx] for idx in indices if idx_elements[idx] == e 79 | 
] 80 | selements = [elements[o] for o in sorders] 81 | for idx, e in zip(indices, selements): 82 | expected_orders = element_orders[e] 83 | out_o = int(out[idx]) 84 | if len(expected_orders) == 1: 85 | ph.assert_scalar_equals( 86 | "argsort", type_=int, idx=idx, out=out_o, expected=expected_orders[0], kw=kw 87 | ) 88 | else: 89 | assert_scalar_in_set( 90 | "argsort", idx=idx, out=out_o, set_=set(expected_orders), kw=kw 91 | ) 92 | 93 | 94 | # TODO: Test with signed zeros and NaNs (and ignore them somehow) 95 | @given( 96 | x=xps.arrays( 97 | dtype=xps.real_dtypes(), 98 | shape=hh.shapes(min_dims=1, min_side=1), 99 | elements={"allow_nan": False}, 100 | ), 101 | data=st.data(), 102 | ) 103 | def test_sort(x, data): 104 | if dh.is_float_dtype(x.dtype): 105 | assume(not xp.any(x == -0.0) and not xp.any(x == +0.0)) 106 | 107 | kw = data.draw( 108 | hh.kwargs( 109 | axis=st.integers(-x.ndim, x.ndim - 1), 110 | descending=st.booleans(), 111 | stable=st.booleans(), 112 | ), 113 | label="kw", 114 | ) 115 | 116 | out = xp.sort(x, **kw) 117 | 118 | ph.assert_dtype("sort", out_dtype=out.dtype, in_dtype=x.dtype) 119 | ph.assert_shape("sort", out_shape=out.shape, expected=x.shape, kw=kw) 120 | axis = kw.get("axis", -1) 121 | axes = sh.normalise_axis(axis, x.ndim) 122 | scalar_type = dh.get_scalar_type(x.dtype) 123 | for indices in sh.axes_ndindex(x.shape, axes): 124 | elements = [scalar_type(x[idx]) for idx in indices] 125 | size = len(elements) 126 | orders = sorted( 127 | range(size), key=elements.__getitem__, reverse=kw.get("descending", False) 128 | ) 129 | for out_idx, o in zip(indices, orders): 130 | x_idx = indices[o] 131 | # TODO: error message when unstable should not imply just one idx 132 | ph.assert_0d_equals( 133 | "sort", 134 | x_repr=f"x[{x_idx}]", 135 | x_val=x[x_idx], 136 | out_repr=f"out[{out_idx}]", 137 | out_val=out[out_idx], 138 | kw=kw, 139 | ) 140 | -------------------------------------------------------------------------------- 
/array_api_tests/meta/test_hypothesis_helpers.py: -------------------------------------------------------------------------------- 1 | from math import prod 2 | from typing import Type 3 | 4 | import pytest 5 | from hypothesis import given, settings 6 | from hypothesis import strategies as st 7 | from hypothesis.errors import Unsatisfiable 8 | 9 | from .. import _array_module as xp 10 | from .. import array_helpers as ah 11 | from .. import dtype_helpers as dh 12 | from .. import hypothesis_helpers as hh 13 | from .. import shape_helpers as sh 14 | from .. import xps 15 | from .._array_module import _UndefinedStub 16 | 17 | UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dh.all_dtypes) 18 | pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")] 19 | 20 | @given(hh.mutually_promotable_dtypes(dtypes=dh.real_float_dtypes)) 21 | def test_mutually_promotable_dtypes(pair): 22 | assert pair in ( 23 | (xp.float32, xp.float32), 24 | (xp.float32, xp.float64), 25 | (xp.float64, xp.float32), 26 | (xp.float64, xp.float64), 27 | ) 28 | 29 | 30 | @given( 31 | hh.mutually_promotable_dtypes( 32 | dtypes=[xp.uint8, _UndefinedStub("uint16"), xp.uint32] 33 | ) 34 | ) 35 | def test_partial_mutually_promotable_dtypes(pair): 36 | assert pair in ( 37 | (xp.uint8, xp.uint8), 38 | (xp.uint8, xp.uint32), 39 | (xp.uint32, xp.uint8), 40 | (xp.uint32, xp.uint32), 41 | ) 42 | 43 | 44 | def valid_shape(shape) -> bool: 45 | return ( 46 | all(isinstance(side, int) for side in shape) 47 | and all(side >= 0 for side in shape) 48 | and prod(shape) < hh.MAX_ARRAY_SIZE 49 | ) 50 | 51 | 52 | @given(hh.shapes()) 53 | def test_shapes(shape): 54 | assert valid_shape(shape) 55 | 56 | 57 | @given(hh.two_mutually_broadcastable_shapes) 58 | def test_two_mutually_broadcastable_shapes(pair): 59 | for shape in pair: 60 | assert valid_shape(shape) 61 | 62 | 63 | @given(hh.two_broadcastable_shapes()) 64 | def test_two_broadcastable_shapes(pair): 65 | for shape in pair: 66 | 
assert valid_shape(shape) 67 | assert sh.broadcast_shapes(pair[0], pair[1]) == pair[0] 68 | 69 | 70 | @given(*hh.two_mutual_arrays()) 71 | def test_two_mutual_arrays(x1, x2): 72 | assert (x1.dtype, x2.dtype) in dh.promotion_table.keys() 73 | 74 | 75 | def test_two_mutual_arrays_raises_on_bad_dtypes(): 76 | with pytest.raises(TypeError): 77 | hh.two_mutual_arrays(dtypes=xps.scalar_dtypes()) 78 | 79 | 80 | def test_kwargs(): 81 | results = [] 82 | 83 | @given(hh.kwargs(n=st.integers(0, 10), c=st.from_regex("[a-f]"))) 84 | @settings(max_examples=100) 85 | def run(kw): 86 | results.append(kw) 87 | run() 88 | 89 | assert all(isinstance(kw, dict) for kw in results) 90 | for size in [0, 1, 2]: 91 | assert any(len(kw) == size for kw in results) 92 | 93 | n_results = [kw for kw in results if "n" in kw] 94 | assert len(n_results) > 0 95 | assert all(isinstance(kw["n"], int) for kw in n_results) 96 | 97 | c_results = [kw for kw in results if "c" in kw] 98 | assert len(c_results) > 0 99 | assert all(isinstance(kw["c"], str) for kw in c_results) 100 | 101 | 102 | def test_specified_kwargs(): 103 | results = [] 104 | 105 | @given(n=st.integers(0, 10), d=st.none() | xps.scalar_dtypes(), data=st.data()) 106 | @settings(max_examples=100) 107 | def run(n, d, data): 108 | kw = data.draw( 109 | hh.specified_kwargs( 110 | hh.KVD("n", n, 0), 111 | hh.KVD("d", d, None), 112 | ), 113 | label="kw", 114 | ) 115 | results.append(kw) 116 | run() 117 | 118 | assert all(isinstance(kw, dict) for kw in results) 119 | 120 | assert any(len(kw) == 0 for kw in results) 121 | 122 | assert any("n" not in kw.keys() for kw in results) 123 | assert any("n" in kw.keys() and kw["n"] == 0 for kw in results) 124 | assert any("n" in kw.keys() and kw["n"] != 0 for kw in results) 125 | 126 | assert any("d" not in kw.keys() for kw in results) 127 | assert any("d" in kw.keys() and kw["d"] is None for kw in results) 128 | assert any("d" in kw.keys() and kw["d"] is xp.float64 for kw in results) 129 | 130 | 131 | 132 
| @given(m=hh.symmetric_matrices(hh.shared_floating_dtypes, 133 | finite=st.shared(st.booleans(), key='finite')), 134 | dtype=hh.shared_floating_dtypes, 135 | finite=st.shared(st.booleans(), key='finite')) 136 | def test_symmetric_matrices(m, dtype, finite): 137 | assert m.dtype == dtype 138 | # TODO: This part of this test should be part of the .mT test 139 | ah.assert_exactly_equal(m, m.mT) 140 | 141 | if finite: 142 | ah.assert_finite(m) 143 | 144 | @given(m=hh.positive_definite_matrices(hh.shared_floating_dtypes), 145 | dtype=hh.shared_floating_dtypes) 146 | def test_positive_definite_matrices(m, dtype): 147 | assert m.dtype == dtype 148 | # TODO: Test that it actually is positive definite 149 | 150 | 151 | def make_raising_func(cls: Type[Exception], msg: str): 152 | def raises(): 153 | raise cls(msg) 154 | 155 | return raises 156 | 157 | @pytest.mark.parametrize( 158 | "func", 159 | [ 160 | make_raising_func(OverflowError, "foo"), 161 | make_raising_func(RuntimeError, "Overflow when unpacking long"), 162 | make_raising_func(Exception, "Got an overflow"), 163 | ] 164 | ) 165 | def test_reject_overflow(func): 166 | @given(data=st.data()) 167 | def test_case(data): 168 | with hh.reject_overflow(): 169 | func() 170 | 171 | with pytest.raises(Unsatisfiable): 172 | test_case() 173 | -------------------------------------------------------------------------------- /array_api_tests/shape_helpers.py: -------------------------------------------------------------------------------- 1 | import math 2 | from itertools import product 3 | from typing import Iterator, List, Optional, Sequence, Tuple, Union 4 | 5 | from ndindex import iter_indices as _iter_indices 6 | 7 | from .typing import AtomicIndex, Index, Scalar, Shape 8 | 9 | __all__ = [ 10 | "broadcast_shapes", 11 | "normalise_axis", 12 | "ndindex", 13 | "axis_ndindex", 14 | "axes_ndindex", 15 | "reshape", 16 | "fmt_idx", 17 | ] 18 | 19 | 20 | class BroadcastError(ValueError): 21 | """Shapes do not broadcast with 
eachother""" 22 | 23 | 24 | def _broadcast_shapes(shape1: Shape, shape2: Shape) -> Shape: 25 | """Broadcasts `shape1` and `shape2`""" 26 | N1 = len(shape1) 27 | N2 = len(shape2) 28 | N = max(N1, N2) 29 | shape = [None for _ in range(N)] 30 | i = N - 1 31 | while i >= 0: 32 | n1 = N1 - N + i 33 | if N1 - N + i >= 0: 34 | d1 = shape1[n1] 35 | else: 36 | d1 = 1 37 | n2 = N2 - N + i 38 | if N2 - N + i >= 0: 39 | d2 = shape2[n2] 40 | else: 41 | d2 = 1 42 | 43 | if d1 == 1: 44 | shape[i] = d2 45 | elif d2 == 1: 46 | shape[i] = d1 47 | elif d1 == d2: 48 | shape[i] = d1 49 | else: 50 | raise BroadcastError() 51 | 52 | i = i - 1 53 | 54 | return tuple(shape) 55 | 56 | 57 | def broadcast_shapes(*shapes: Shape): 58 | if len(shapes) == 0: 59 | raise ValueError("shapes=[] must be non-empty") 60 | elif len(shapes) == 1: 61 | return shapes[0] 62 | result = _broadcast_shapes(shapes[0], shapes[1]) 63 | for i in range(2, len(shapes)): 64 | result = _broadcast_shapes(result, shapes[i]) 65 | return result 66 | 67 | 68 | def normalise_axis( 69 | axis: Optional[Union[int, Sequence[int]]], ndim: int 70 | ) -> Tuple[int, ...]: 71 | if axis is None: 72 | return tuple(range(ndim)) 73 | elif isinstance(axis, Sequence) and not isinstance(axis, tuple): 74 | axis = tuple(axis) 75 | axes = axis if isinstance(axis, tuple) else (axis,) 76 | axes = tuple(axis if axis >= 0 else ndim + axis for axis in axes) 77 | return axes 78 | 79 | 80 | def ndindex(shape: Shape) -> Iterator[Index]: 81 | """Yield every index of a shape""" 82 | return (indices[0] for indices in iter_indices(shape)) 83 | 84 | 85 | def iter_indices( 86 | *shapes: Shape, skip_axes: Tuple[int, ...] 
= () 87 | ) -> Iterator[Tuple[Index, ...]]: 88 | """Wrapper for ndindex.iter_indices()""" 89 | # Prevent iterations if any shape has 0-sides 90 | for shape in shapes: 91 | if 0 in shape: 92 | return 93 | for indices in _iter_indices(*shapes, skip_axes=skip_axes): 94 | yield tuple(i.raw for i in indices) # type: ignore 95 | 96 | 97 | def axis_ndindex( 98 | shape: Shape, axis: int 99 | ) -> Iterator[Tuple[Tuple[Union[int, slice], ...], ...]]: 100 | """Generate indices that index all elements in dimensions beyond `axis`""" 101 | assert axis >= 0 # sanity check 102 | axis_indices = [range(side) for side in shape[:axis]] 103 | for _ in range(axis, len(shape)): 104 | axis_indices.append([slice(None, None)]) 105 | yield from product(*axis_indices) 106 | 107 | 108 | def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]: 109 | """Generate indices that index all elements except in `axes` dimensions""" 110 | base_indices = [] 111 | axes_indices = [] 112 | for axis, side in enumerate(shape): 113 | if axis in axes: 114 | base_indices.append([None]) 115 | axes_indices.append(range(side)) 116 | else: 117 | base_indices.append(range(side)) 118 | axes_indices.append([None]) 119 | for base_idx in product(*base_indices): 120 | indices = [] 121 | for idx in product(*axes_indices): 122 | idx = list(idx) 123 | for axis, side in enumerate(idx): 124 | if axis not in axes: 125 | idx[axis] = base_idx[axis] 126 | idx = tuple(idx) 127 | indices.append(idx) 128 | yield list(indices) 129 | 130 | 131 | def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List]: 132 | """Reshape a flat sequence""" 133 | if any(s == 0 for s in shape): 134 | raise ValueError( 135 | f"{shape=} contains 0-sided dimensions, " 136 | f"but that's not representable in lists" 137 | ) 138 | if len(shape) == 0: 139 | assert len(flat_seq) == 1 # sanity check 140 | return flat_seq[0] 141 | elif len(shape) == 1: 142 | return flat_seq 143 | size = len(flat_seq) 144 | n = 
math.prod(shape[1:]) 145 | return [reshape(flat_seq[i * n : (i + 1) * n], shape[1:]) for i in range(size // n)] 146 | 147 | 148 | def fmt_i(i: AtomicIndex) -> str: 149 | if isinstance(i, int): 150 | return str(i) 151 | elif isinstance(i, slice): 152 | res = "" 153 | if i.start is not None: 154 | res += str(i.start) 155 | res += ":" 156 | if i.stop is not None: 157 | res += str(i.stop) 158 | if i.step is not None: 159 | res += f":{i.step}" 160 | return res 161 | elif i is None: 162 | return "None" 163 | else: 164 | return "..." 165 | 166 | 167 | def fmt_idx(sym: str, idx: Index) -> str: 168 | if idx == (): 169 | return sym 170 | res = f"{sym}[" 171 | _idx = idx if isinstance(idx, tuple) else (idx,) 172 | if len(_idx) == 1: 173 | res += fmt_i(_idx[0]) 174 | else: 175 | res += ", ".join(fmt_i(i) for i in _idx) 176 | res += "]" 177 | return res 178 | -------------------------------------------------------------------------------- /array_api_tests/test_searching_functions.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import pytest 4 | from hypothesis import given 5 | from hypothesis import strategies as st 6 | 7 | from . import _array_module as xp 8 | from . import dtype_helpers as dh 9 | from . import hypothesis_helpers as hh 10 | from . import pytest_helpers as ph 11 | from . import shape_helpers as sh 12 | from . 
import xps 13 | 14 | pytestmark = pytest.mark.ci 15 | 16 | 17 | @given( 18 | x=xps.arrays( 19 | dtype=xps.real_dtypes(), 20 | shape=hh.shapes(min_dims=1, min_side=1), 21 | elements={"allow_nan": False}, 22 | ), 23 | data=st.data(), 24 | ) 25 | def test_argmax(x, data): 26 | kw = data.draw( 27 | hh.kwargs( 28 | axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)), 29 | keepdims=st.booleans(), 30 | ), 31 | label="kw", 32 | ) 33 | keepdims = kw.get("keepdims", False) 34 | 35 | out = xp.argmax(x, **kw) 36 | 37 | ph.assert_default_index("argmax", out.dtype) 38 | axes = sh.normalise_axis(kw.get("axis", None), x.ndim) 39 | ph.assert_keepdimable_shape( 40 | "argmax", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw 41 | ) 42 | scalar_type = dh.get_scalar_type(x.dtype) 43 | for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)): 44 | max_i = int(out[out_idx]) 45 | elements = [] 46 | for idx in indices: 47 | s = scalar_type(x[idx]) 48 | elements.append(s) 49 | expected = max(range(len(elements)), key=elements.__getitem__) 50 | ph.assert_scalar_equals("argmax", type_=int, idx=out_idx, out=max_i, 51 | expected=expected, kw=kw) 52 | 53 | 54 | @given( 55 | x=xps.arrays( 56 | dtype=xps.real_dtypes(), 57 | shape=hh.shapes(min_dims=1, min_side=1), 58 | elements={"allow_nan": False}, 59 | ), 60 | data=st.data(), 61 | ) 62 | def test_argmin(x, data): 63 | kw = data.draw( 64 | hh.kwargs( 65 | axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)), 66 | keepdims=st.booleans(), 67 | ), 68 | label="kw", 69 | ) 70 | keepdims = kw.get("keepdims", False) 71 | 72 | out = xp.argmin(x, **kw) 73 | 74 | ph.assert_default_index("argmin", out.dtype) 75 | axes = sh.normalise_axis(kw.get("axis", None), x.ndim) 76 | ph.assert_keepdimable_shape( 77 | "argmin", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw 78 | ) 79 | scalar_type = dh.get_scalar_type(x.dtype) 80 | for indices, out_idx in 
zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)): 81 | min_i = int(out[out_idx]) 82 | elements = [] 83 | for idx in indices: 84 | s = scalar_type(x[idx]) 85 | elements.append(s) 86 | expected = min(range(len(elements)), key=elements.__getitem__) 87 | ph.assert_scalar_equals("argmin", type_=int, idx=out_idx, out=min_i, expected=expected) 88 | 89 | 90 | @pytest.mark.data_dependent_shapes 91 | @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) 92 | def test_nonzero(x): 93 | out = xp.nonzero(x) 94 | if x.ndim == 0: 95 | assert len(out) == 1, f"{len(out)=}, but should be 1 for 0-dimensional arrays" 96 | else: 97 | assert len(out) == x.ndim, f"{len(out)=}, but should be {x.ndim=}" 98 | out_size = math.prod(out[0].shape) 99 | for i in range(len(out)): 100 | assert out[i].ndim == 1, f"out[{i}].ndim={x.ndim}, but should be 1" 101 | size_at = math.prod(out[i].shape) 102 | assert size_at == out_size, ( 103 | f"prod(out[{i}].shape)={size_at}, " 104 | f"but should be prod(out[0].shape)={out_size}" 105 | ) 106 | ph.assert_default_index("nonzero", out[i].dtype, repr_name=f"out[{i}].dtype") 107 | indices = [] 108 | if x.dtype == xp.bool: 109 | for idx in sh.ndindex(x.shape): 110 | if x[idx]: 111 | indices.append(idx) 112 | else: 113 | for idx in sh.ndindex(x.shape): 114 | if x[idx] != 0: 115 | indices.append(idx) 116 | if x.ndim == 0: 117 | assert out_size == len( 118 | indices 119 | ), f"prod(out[0].shape)={out_size}, but should be {len(indices)}" 120 | else: 121 | for i in range(out_size): 122 | idx = tuple(int(x[i]) for x in out) 123 | f_idx = f"Extrapolated index (x[{i}] for x in out)={idx}" 124 | f_element = f"x[{idx}]={x[idx]}" 125 | assert idx in indices, f"{f_idx} results in {f_element}, a zero element" 126 | assert ( 127 | idx == indices[i] 128 | ), f"{f_idx} is in the wrong position, should be {indices.index(idx)}" 129 | 130 | 131 | @given( 132 | shapes=hh.mutually_broadcastable_shapes(3), 133 | dtypes=hh.mutually_promotable_dtypes(), 
134 | data=st.data(), 135 | ) 136 | def test_where(shapes, dtypes, data): 137 | cond = data.draw(xps.arrays(dtype=xp.bool, shape=shapes[0]), label="condition") 138 | x1 = data.draw(xps.arrays(dtype=dtypes[0], shape=shapes[1]), label="x1") 139 | x2 = data.draw(xps.arrays(dtype=dtypes[1], shape=shapes[2]), label="x2") 140 | 141 | out = xp.where(cond, x1, x2) 142 | 143 | shape = sh.broadcast_shapes(*shapes) 144 | ph.assert_shape("where", out_shape=out.shape, expected=shape) 145 | # TODO: generate indices without broadcasting arrays 146 | _cond = xp.broadcast_to(cond, shape) 147 | _x1 = xp.broadcast_to(x1, shape) 148 | _x2 = xp.broadcast_to(x2, shape) 149 | for idx in sh.ndindex(shape): 150 | if _cond[idx]: 151 | ph.assert_0d_equals( 152 | "where", 153 | x_repr=f"_x1[{idx}]", 154 | x_val=_x1[idx], 155 | out_repr=f"out[{idx}]", 156 | out_val=out[idx] 157 | ) 158 | else: 159 | ph.assert_0d_equals( 160 | "where", 161 | x_repr=f"_x2[{idx}]", 162 | x_val=_x2[idx], 163 | out_repr=f"out[{idx}]", 164 | out_val=out[idx] 165 | ) 166 | -------------------------------------------------------------------------------- /array_api_tests/test_data_type_functions.py: -------------------------------------------------------------------------------- 1 | import struct 2 | from typing import Union 3 | 4 | import pytest 5 | from hypothesis import given 6 | from hypothesis import strategies as st 7 | 8 | from . import _array_module as xp 9 | from . import dtype_helpers as dh 10 | from . import hypothesis_helpers as hh 11 | from . import pytest_helpers as ph 12 | from . import shape_helpers as sh 13 | from . import xps 14 | from . 
import xp as _xp 15 | from .typing import DataType 16 | 17 | pytestmark = pytest.mark.ci 18 | 19 | 20 | # TODO: test with complex dtypes 21 | def non_complex_dtypes(): 22 | return xps.boolean_dtypes() | xps.real_dtypes() 23 | 24 | 25 | def float32(n: Union[int, float]) -> float: 26 | return struct.unpack("!f", struct.pack("!f", float(n)))[0] 27 | 28 | 29 | @given( 30 | x_dtype=non_complex_dtypes(), 31 | dtype=non_complex_dtypes(), 32 | kw=hh.kwargs(copy=st.booleans()), 33 | data=st.data(), 34 | ) 35 | def test_astype(x_dtype, dtype, kw, data): 36 | if xp.bool in (x_dtype, dtype): 37 | elements_strat = xps.from_dtype(x_dtype) 38 | else: 39 | m1, M1 = dh.dtype_ranges[x_dtype] 40 | m2, M2 = dh.dtype_ranges[dtype] 41 | if dh.is_int_dtype(x_dtype): 42 | cast = int 43 | elif x_dtype == xp.float32: 44 | cast = float32 45 | else: 46 | cast = float 47 | min_value = cast(max(m1, m2)) 48 | max_value = cast(min(M1, M2)) 49 | elements_strat = xps.from_dtype( 50 | x_dtype, 51 | min_value=min_value, 52 | max_value=max_value, 53 | allow_nan=False, 54 | allow_infinity=False, 55 | ) 56 | x = data.draw( 57 | xps.arrays(dtype=x_dtype, shape=hh.shapes(), elements=elements_strat), label="x" 58 | ) 59 | 60 | out = xp.astype(x, dtype, **kw) 61 | 62 | ph.assert_kw_dtype("astype", kw_dtype=dtype, out_dtype=out.dtype) 63 | ph.assert_shape("astype", out_shape=out.shape, expected=x.shape, kw=kw) 64 | # TODO: test values 65 | # TODO: test copy 66 | 67 | 68 | @given( 69 | shapes=st.integers(1, 5).flatmap(hh.mutually_broadcastable_shapes), data=st.data() 70 | ) 71 | def test_broadcast_arrays(shapes, data): 72 | arrays = [] 73 | for c, shape in enumerate(shapes, 1): 74 | x = data.draw(xps.arrays(dtype=xps.scalar_dtypes(), shape=shape), label=f"x{c}") 75 | arrays.append(x) 76 | 77 | out = xp.broadcast_arrays(*arrays) 78 | 79 | expected_shape = sh.broadcast_shapes(*shapes) 80 | for i, x in enumerate(arrays): 81 | ph.assert_dtype( 82 | "broadcast_arrays", 83 | in_dtype=x.dtype, 84 | 
out_dtype=out[i].dtype, 85 | repr_name=f"out[{i}].dtype" 86 | ) 87 | ph.assert_result_shape( 88 | "broadcast_arrays", 89 | in_shapes=shapes, 90 | out_shape=out[i].shape, 91 | expected=expected_shape, 92 | repr_name=f"out[{i}].shape", 93 | ) 94 | # TODO: test values 95 | 96 | 97 | @given(x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), data=st.data()) 98 | def test_broadcast_to(x, data): 99 | shape = data.draw( 100 | hh.mutually_broadcastable_shapes(1, base_shape=x.shape) 101 | .map(lambda S: S[0]) 102 | .filter(lambda s: sh.broadcast_shapes(x.shape, s) == s), 103 | label="shape", 104 | ) 105 | 106 | out = xp.broadcast_to(x, shape) 107 | 108 | ph.assert_dtype("broadcast_to", in_dtype=x.dtype, out_dtype=out.dtype) 109 | ph.assert_shape("broadcast_to", out_shape=out.shape, expected=shape) 110 | # TODO: test values 111 | 112 | 113 | @given(_from=non_complex_dtypes(), to=non_complex_dtypes(), data=st.data()) 114 | def test_can_cast(_from, to, data): 115 | from_ = data.draw( 116 | st.just(_from) | xps.arrays(dtype=_from, shape=hh.shapes()), label="from_" 117 | ) 118 | 119 | out = xp.can_cast(from_, to) 120 | 121 | f_func = f"[can_cast({dh.dtype_to_name[_from]}, {dh.dtype_to_name[to]})]" 122 | assert isinstance(out, bool), f"{type(out)=}, but should be bool {f_func}" 123 | if _from == xp.bool: 124 | expected = to == xp.bool 125 | else: 126 | same_family = None 127 | for dtypes in [dh.all_int_dtypes, dh.real_float_dtypes, dh.complex_dtypes]: 128 | if _from in dtypes: 129 | same_family = to in dtypes 130 | break 131 | assert same_family is not None # sanity check 132 | if same_family: 133 | from_min, from_max = dh.dtype_ranges[_from] 134 | to_min, to_max = dh.dtype_ranges[to] 135 | expected = from_min >= to_min and from_max <= to_max 136 | else: 137 | expected = False 138 | if expected: 139 | # cross-kind casting is not explicitly disallowed. We can only test 140 | # the cases where it should return True. 
TODO: if expected=False, 141 | # check that the array library actually allows such casts. 142 | assert out == expected, f"{out=}, but should be {expected} {f_func}" 143 | 144 | 145 | @pytest.mark.parametrize("dtype_name", dh.real_float_names) 146 | def test_finfo(dtype_name): 147 | try: 148 | dtype = getattr(_xp, dtype_name) 149 | except AttributeError as e: 150 | pytest.skip(str(e)) 151 | out = xp.finfo(dtype) 152 | f_func = f"[finfo({dh.dtype_to_name[dtype]})]" 153 | for attr, stype in [ 154 | ("bits", int), 155 | ("eps", float), 156 | ("max", float), 157 | ("min", float), 158 | ("smallest_normal", float), 159 | ]: 160 | assert hasattr(out, attr), f"out has no attribute '{attr}' {f_func}" 161 | value = getattr(out, attr) 162 | assert isinstance( 163 | value, stype 164 | ), f"type(out.{attr})={type(value)!r}, but should be {stype.__name__} {f_func}" 165 | assert hasattr(out, "dtype"), f"out has no attribute 'dtype' {f_func}" 166 | # TODO: test values 167 | 168 | 169 | @pytest.mark.parametrize("dtype_name", dh.all_int_names) 170 | def test_iinfo(dtype_name): 171 | try: 172 | dtype = getattr(_xp, dtype_name) 173 | except AttributeError as e: 174 | pytest.skip(str(e)) 175 | out = xp.iinfo(dtype) 176 | f_func = f"[iinfo({dh.dtype_to_name[dtype]})]" 177 | for attr in ["bits", "max", "min"]: 178 | assert hasattr(out, attr), f"out has no attribute '{attr}' {f_func}" 179 | value = getattr(out, attr) 180 | assert isinstance( 181 | value, int 182 | ), f"type(out.{attr})={type(value)!r}, but should be int {f_func}" 183 | assert hasattr(out, "dtype"), f"out has no attribute 'dtype' {f_func}" 184 | # TODO: test values 185 | 186 | 187 | def atomic_kinds() -> st.SearchStrategy[Union[DataType, str]]: 188 | return xps.scalar_dtypes() | st.sampled_from(list(dh.kind_to_dtypes.keys())) 189 | 190 | 191 | @pytest.mark.min_version("2022.12") 192 | @given( 193 | dtype=xps.scalar_dtypes(), 194 | kind=atomic_kinds() | st.lists(atomic_kinds(), min_size=1).map(tuple), 195 | ) 196 | def 
test_isdtype(dtype, kind): 197 | out = xp.isdtype(dtype, kind) 198 | 199 | assert isinstance(out, bool), f"{type(out)=}, but should be bool [isdtype()]" 200 | _kinds = kind if isinstance(kind, tuple) else (kind,) 201 | expected = False 202 | for _kind in _kinds: 203 | if isinstance(_kind, str): 204 | if dtype in dh.kind_to_dtypes[_kind]: 205 | expected = True 206 | break 207 | else: 208 | if dtype == _kind: 209 | expected = True 210 | break 211 | assert out == expected, f"{out=}, but should be {expected} [isdtype()]" 212 | 213 | 214 | @given(hh.mutually_promotable_dtypes(None)) 215 | def test_result_type(dtypes): 216 | out = xp.result_type(*dtypes) 217 | ph.assert_dtype("result_type", in_dtype=dtypes, out_dtype=out, repr_name="out") 218 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | from pathlib import Path 3 | import warnings 4 | import os 5 | 6 | from hypothesis import settings 7 | from pytest import mark 8 | 9 | from array_api_tests import _array_module as xp 10 | from array_api_tests import api_version 11 | from array_api_tests._array_module import _UndefinedStub 12 | 13 | from reporting import pytest_metadata, pytest_json_modifyreport, add_extra_json_metadata # noqa 14 | 15 | settings.register_profile("xp_default", deadline=800) 16 | 17 | def pytest_addoption(parser): 18 | # Hypothesis max examples 19 | # See https://github.com/HypothesisWorks/hypothesis/issues/2434 20 | parser.addoption( 21 | "--hypothesis-max-examples", 22 | "--max-examples", 23 | action="store", 24 | default=None, 25 | help="set the Hypothesis max_examples setting", 26 | ) 27 | # Hypothesis deadline 28 | parser.addoption( 29 | "--hypothesis-disable-deadline", 30 | "--disable-deadline", 31 | action="store_true", 32 | help="disable the Hypothesis deadline", 33 | ) 34 | # Hypothesis derandomize 35 | parser.addoption( 36 | 
"--hypothesis-derandomize", 37 | "--derandomize", 38 | action="store_true", 39 | help="set the Hypothesis derandomize parameter", 40 | ) 41 | # disable extensions 42 | parser.addoption( 43 | "--disable-extension", 44 | metavar="ext", 45 | nargs="+", 46 | default=[], 47 | help="disable testing for Array API extension(s)", 48 | ) 49 | # data-dependent shape 50 | parser.addoption( 51 | "--disable-data-dependent-shapes", 52 | "--disable-dds", 53 | action="store_true", 54 | help="disable testing functions with output shapes dependent on input", 55 | ) 56 | # CI 57 | parser.addoption( 58 | "--ci", 59 | action="store_true", 60 | help="run just the tests appropriate for CI", 61 | ) 62 | parser.addoption( 63 | "--skips-file", 64 | action="store", 65 | help="file with tests to skip. Defaults to skips.txt" 66 | ) 67 | parser.addoption( 68 | "--xfails-file", 69 | action="store", 70 | help="file with tests to skip. Defaults to xfails.txt" 71 | ) 72 | 73 | 74 | def pytest_configure(config): 75 | config.addinivalue_line( 76 | "markers", "xp_extension(ext): tests an Array API extension" 77 | ) 78 | config.addinivalue_line( 79 | "markers", "data_dependent_shapes: output shapes are dependent on inputs" 80 | ) 81 | config.addinivalue_line("markers", "ci: primary test") 82 | config.addinivalue_line( 83 | "markers", 84 | "min_version(api_version): run when greater or equal to api_version", 85 | ) 86 | # Hypothesis 87 | hypothesis_max_examples = config.getoption("--hypothesis-max-examples") 88 | disable_deadline = config.getoption("--hypothesis-disable-deadline") 89 | derandomize = config.getoption("--hypothesis-derandomize") 90 | profile_settings = {} 91 | if hypothesis_max_examples is not None: 92 | profile_settings["max_examples"] = int(hypothesis_max_examples) 93 | if disable_deadline: 94 | profile_settings["deadline"] = None 95 | if derandomize: 96 | profile_settings["derandomize"] = True 97 | if profile_settings: 98 | settings.register_profile("xp_override", **profile_settings) 99 
| settings.load_profile("xp_override") 100 | else: 101 | settings.load_profile("xp_default") 102 | 103 | 104 | @lru_cache 105 | def xp_has_ext(ext: str) -> bool: 106 | try: 107 | return not isinstance(getattr(xp, ext), _UndefinedStub) 108 | except AttributeError: 109 | return False 110 | 111 | 112 | def pytest_collection_modifyitems(config, items): 113 | skips_file = skips_path = config.getoption('--skips-file') 114 | if skips_file is None: 115 | skips_file = Path(__file__).parent / "skips.txt" 116 | if skips_file.exists(): 117 | skips_path = skips_file 118 | 119 | skip_ids = [] 120 | if skips_path: 121 | with open(os.path.expanduser(skips_path)) as f: 122 | for line in f: 123 | if line.startswith("array_api_tests"): 124 | id_ = line.strip("\n") 125 | skip_ids.append(id_) 126 | 127 | xfails_file = xfails_path = config.getoption('--xfails-file') 128 | if xfails_file is None: 129 | xfails_file = Path(__file__).parent / "xfails.txt" 130 | if xfails_file.exists(): 131 | xfails_path = xfails_file 132 | 133 | xfail_ids = [] 134 | if xfails_path: 135 | with open(os.path.expanduser(xfails_path)) as f: 136 | for line in f: 137 | if not line.strip() or line.startswith('#'): 138 | continue 139 | id_ = line.strip("\n") 140 | xfail_ids.append(id_) 141 | 142 | skip_id_matched = {id_: False for id_ in skip_ids} 143 | xfail_id_matched = {id_: False for id_ in xfail_ids} 144 | 145 | disabled_exts = config.getoption("--disable-extension") 146 | disabled_dds = config.getoption("--disable-data-dependent-shapes") 147 | ci = config.getoption("--ci") 148 | 149 | for item in items: 150 | markers = list(item.iter_markers()) 151 | # skip if specified in skips file 152 | for id_ in skip_ids: 153 | if id_ in item.nodeid: 154 | item.add_marker(mark.skip(reason=f"--skips-file ({skips_file})")) 155 | skip_id_matched[id_] = True 156 | break 157 | # xfail if specified in xfails file 158 | for id_ in xfail_ids: 159 | if id_ in item.nodeid: 160 | item.add_marker(mark.xfail(reason=f"--xfails-file 
({xfails_file})")) 161 | xfail_id_matched[id_] = True 162 | break 163 | # skip if disabled or non-existent extension 164 | ext_mark = next((m for m in markers if m.name == "xp_extension"), None) 165 | if ext_mark is not None: 166 | ext = ext_mark.args[0] 167 | if ext in disabled_exts: 168 | item.add_marker( 169 | mark.skip(reason=f"{ext} disabled in --disable-extensions") 170 | ) 171 | elif not xp_has_ext(ext): 172 | item.add_marker(mark.skip(reason=f"{ext} not found in array module")) 173 | # skip if disabled by dds flag 174 | if disabled_dds: 175 | for m in markers: 176 | if m.name == "data_dependent_shapes": 177 | item.add_marker( 178 | mark.skip(reason="disabled via --disable-data-dependent-shapes") 179 | ) 180 | break 181 | # skip if test not appropriate for CI 182 | if ci: 183 | ci_mark = next((m for m in markers if m.name == "ci"), None) 184 | if ci_mark is None: 185 | item.add_marker(mark.skip(reason="disabled via --ci")) 186 | # skip if test is for greater api_version 187 | ver_mark = next((m for m in markers if m.name == "min_version"), None) 188 | if ver_mark is not None: 189 | min_version = ver_mark.args[0] 190 | if api_version < min_version: 191 | item.add_marker( 192 | mark.skip( 193 | reason=f"requires ARRAY_API_TESTS_VERSION=>{min_version}" 194 | ) 195 | ) 196 | 197 | bad_ids_end_msg = ( 198 | "Note the relevant tests might not of been collected by pytest, or " 199 | "another specified id might have already matched a test." 
200 | ) 201 | bad_skip_ids = [id_ for id_, matched in skip_id_matched.items() if not matched] 202 | if bad_skip_ids: 203 | f_bad_ids = "\n".join(f" {id_}" for id_ in bad_skip_ids) 204 | warnings.warn( 205 | f"{len(bad_skip_ids)} ids in skips file don't match any collected tests: \n" 206 | f"{f_bad_ids}\n" 207 | f"(skips file: {skips_file})\n" 208 | f"{bad_ids_end_msg}" 209 | ) 210 | bad_xfail_ids = [id_ for id_, matched in xfail_id_matched.items() if not matched] 211 | if bad_xfail_ids: 212 | f_bad_ids = "\n".join(f" {id_}" for id_ in bad_xfail_ids) 213 | warnings.warn( 214 | f"{len(bad_xfail_ids)} ids in xfails file don't match any collected tests: \n" 215 | f"{f_bad_ids}\n" 216 | f"(xfails file: {xfails_file})\n" 217 | f"{bad_ids_end_msg}" 218 | ) 219 | -------------------------------------------------------------------------------- /array_api_tests/test_fft.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional 3 | from unittest.mock import MagicMock 4 | 5 | import pytest 6 | from hypothesis import assume, given 7 | from hypothesis import strategies as st 8 | 9 | from array_api_tests.typing import Array, DataType 10 | 11 | from . import api_version 12 | from . import dtype_helpers as dh 13 | from . import hypothesis_helpers as hh 14 | from . import pytest_helpers as ph 15 | from . import shape_helpers as sh 16 | from . import xps 17 | from . import xp 18 | 19 | pytestmark = [ 20 | pytest.mark.ci, 21 | pytest.mark.xp_extension("fft"), 22 | pytest.mark.min_version("2022.12"), 23 | ] 24 | 25 | 26 | # Using xps.complex_dtypes() raises an AttributeError for 2021.12 instances of 27 | # xps, hence this hack. TODO: figure out a better way to manage this! 
28 | if api_version < "2022.12": 29 | xps = MagicMock(xps) 30 | 31 | fft_shapes_strat = hh.shapes(min_dims=1).filter(lambda s: math.prod(s) > 1) 32 | 33 | 34 | def draw_n_axis_norm_kwargs(x: Array, data: st.DataObject, *, size_gt_1=False) -> tuple: 35 | size = math.prod(x.shape) 36 | n = data.draw( 37 | st.none() | st.integers((size // 2), math.ceil(size * 1.5)), label="n" 38 | ) 39 | axis = data.draw(st.integers(-1, x.ndim - 1), label="axis") 40 | if size_gt_1: 41 | _axis = x.ndim - 1 if axis == -1 else axis 42 | assume(x.shape[_axis] > 1) 43 | norm = data.draw(st.sampled_from(["backward", "ortho", "forward"]), label="norm") 44 | kwargs = data.draw( 45 | hh.specified_kwargs( 46 | ("n", n, None), 47 | ("axis", axis, -1), 48 | ("norm", norm, "backward"), 49 | ), 50 | label="kwargs", 51 | ) 52 | return n, axis, norm, kwargs 53 | 54 | 55 | def draw_s_axes_norm_kwargs(x: Array, data: st.DataObject, *, size_gt_1=False) -> tuple: 56 | all_axes = list(range(x.ndim)) 57 | axes = data.draw( 58 | st.none() | st.lists(st.sampled_from(all_axes), min_size=1, unique=True), 59 | label="axes", 60 | ) 61 | _axes = all_axes if axes is None else axes 62 | axes_sides = [x.shape[axis] for axis in _axes] 63 | s_strat = st.tuples( 64 | *[st.integers(max(side // 2, 1), math.ceil(side * 1.5)) for side in axes_sides] 65 | ) 66 | if axes is None: 67 | s_strat = st.none() | s_strat 68 | s = data.draw(s_strat, label="s") 69 | if size_gt_1: 70 | _s = x.shape if s is None else s 71 | for i in range(x.ndim): 72 | if i in _axes: 73 | side = _s[_axes.index(i)] 74 | else: 75 | side = x.shape[i] 76 | assume(side > 1) 77 | norm = data.draw(st.sampled_from(["backward", "ortho", "forward"]), label="norm") 78 | kwargs = data.draw( 79 | hh.specified_kwargs( 80 | ("s", s, None), 81 | ("axes", axes, None), 82 | ("norm", norm, "backward"), 83 | ), 84 | label="kwargs", 85 | ) 86 | return s, axes, norm, kwargs 87 | 88 | 89 | def assert_fft_dtype(func_name: str, *, in_dtype: DataType, out_dtype: DataType): 90 | 
if in_dtype == xp.float32: 91 | expected = xp.complex64 92 | elif in_dtype == xp.float64: 93 | expected = xp.complex128 94 | else: 95 | assert dh.is_float_dtype(in_dtype) # sanity check 96 | expected = in_dtype 97 | ph.assert_dtype( 98 | func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected 99 | ) 100 | 101 | 102 | def assert_n_axis_shape( 103 | func_name: str, 104 | *, 105 | x: Array, 106 | n: Optional[int], 107 | axis: int, 108 | out: Array, 109 | size_gt_1: bool = False, 110 | ): 111 | _axis = len(x.shape) - 1 if axis == -1 else axis 112 | if n is None: 113 | if size_gt_1: 114 | axis_side = 2 * (x.shape[_axis] - 1) 115 | else: 116 | axis_side = x.shape[_axis] 117 | else: 118 | axis_side = n 119 | expected = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :] 120 | ph.assert_shape(func_name, out_shape=out.shape, expected=expected) 121 | 122 | 123 | def assert_s_axes_shape( 124 | func_name: str, 125 | *, 126 | x: Array, 127 | s: Optional[List[int]], 128 | axes: Optional[List[int]], 129 | out: Array, 130 | size_gt_1: bool = False, 131 | ): 132 | _axes = sh.normalise_axis(axes, x.ndim) 133 | _s = x.shape if s is None else s 134 | expected = [] 135 | for i in range(x.ndim): 136 | if i in _axes: 137 | side = _s[_axes.index(i)] 138 | else: 139 | side = x.shape[i] 140 | expected.append(side) 141 | if size_gt_1: 142 | last_axis = _axes[-1] 143 | expected[last_axis] = 2 * (expected[last_axis] - 1) 144 | assume(expected[last_axis] > 0) # TODO: generate valid examples 145 | ph.assert_shape(func_name, out_shape=out.shape, expected=tuple(expected)) 146 | 147 | 148 | @given( 149 | x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat), 150 | data=st.data(), 151 | ) 152 | def test_fft(x, data): 153 | n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data) 154 | 155 | out = xp.fft.fft(x, **kwargs) 156 | 157 | assert_fft_dtype("fft", in_dtype=x.dtype, out_dtype=out.dtype) 158 | assert_n_axis_shape("fft", x=x, n=n, axis=axis, out=out) 159 | 160 | 
161 | @given( 162 | x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat), 163 | data=st.data(), 164 | ) 165 | def test_ifft(x, data): 166 | n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data) 167 | 168 | out = xp.fft.ifft(x, **kwargs) 169 | 170 | assert_fft_dtype("ifft", in_dtype=x.dtype, out_dtype=out.dtype) 171 | assert_n_axis_shape("ifft", x=x, n=n, axis=axis, out=out) 172 | 173 | 174 | @given( 175 | x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat), 176 | data=st.data(), 177 | ) 178 | def test_fftn(x, data): 179 | s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data) 180 | 181 | out = xp.fft.fftn(x, **kwargs) 182 | 183 | assert_fft_dtype("fftn", in_dtype=x.dtype, out_dtype=out.dtype) 184 | assert_s_axes_shape("fftn", x=x, s=s, axes=axes, out=out) 185 | 186 | 187 | @given( 188 | x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat), 189 | data=st.data(), 190 | ) 191 | def test_ifftn(x, data): 192 | s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data) 193 | 194 | out = xp.fft.ifftn(x, **kwargs) 195 | 196 | assert_fft_dtype("ifftn", in_dtype=x.dtype, out_dtype=out.dtype) 197 | assert_s_axes_shape("ifftn", x=x, s=s, axes=axes, out=out) 198 | 199 | 200 | @given( 201 | x=xps.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat), 202 | data=st.data(), 203 | ) 204 | def test_rfft(x, data): 205 | n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data) 206 | 207 | out = xp.fft.rfft(x, **kwargs) 208 | 209 | assert_fft_dtype("rfft", in_dtype=x.dtype, out_dtype=out.dtype) 210 | assert_n_axis_shape("rfft", x=x, n=n, axis=axis, out=out) 211 | 212 | 213 | @given( 214 | x=xps.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat), 215 | data=st.data(), 216 | ) 217 | def test_irfft(x, data): 218 | n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True) 219 | 220 | out = xp.fft.irfft(x, **kwargs) 221 | 222 | assert_fft_dtype("irfft", in_dtype=x.dtype, out_dtype=out.dtype) 223 | 224 | _axis = 
x.ndim - 1 if axis == -1 else axis 225 | if n is None: 226 | axis_side = 2 * (x.shape[_axis] - 1) 227 | else: 228 | axis_side = n 229 | expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :] 230 | ph.assert_shape("irfft", out_shape=out.shape, expected=expected_shape) 231 | 232 | 233 | @given( 234 | x=xps.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat), 235 | data=st.data(), 236 | ) 237 | def test_rfftn(x, data): 238 | s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data) 239 | 240 | out = xp.fft.rfftn(x, **kwargs) 241 | 242 | assert_fft_dtype("rfftn", in_dtype=x.dtype, out_dtype=out.dtype) 243 | assert_s_axes_shape("rfftn", x=x, s=s, axes=axes, out=out) 244 | 245 | 246 | @given( 247 | x=xps.arrays( 248 | dtype=xps.complex_dtypes(), shape=fft_shapes_strat.filter(lambda s: s[-1] > 1) 249 | ), 250 | data=st.data(), 251 | ) 252 | def test_irfftn(x, data): 253 | s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data, size_gt_1=True) 254 | 255 | out = xp.fft.irfftn(x, **kwargs) 256 | 257 | assert_fft_dtype("irfftn", in_dtype=x.dtype, out_dtype=out.dtype) 258 | assert_s_axes_shape("rfftn", x=x, s=s, axes=axes, out=out, size_gt_1=True) 259 | 260 | 261 | @given( 262 | x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat), 263 | data=st.data(), 264 | ) 265 | def test_hfft(x, data): 266 | n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True) 267 | 268 | out = xp.fft.hfft(x, **kwargs) 269 | 270 | assert_fft_dtype("hfft", in_dtype=x.dtype, out_dtype=out.dtype) 271 | 272 | _axis = x.ndim - 1 if axis == -1 else axis 273 | if n is None: 274 | axis_side = 2 * (x.shape[_axis] - 1) 275 | else: 276 | axis_side = n 277 | expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :] 278 | ph.assert_shape("hfft", out_shape=out.shape, expected=expected_shape) 279 | 280 | 281 | @given( 282 | x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat), 283 | data=st.data(), 284 | ) 285 | def test_ihfft(x, 
data): 286 | n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data) 287 | 288 | out = xp.fft.ihfft(x, **kwargs) 289 | 290 | assert_fft_dtype("ihfft", in_dtype=x.dtype, out_dtype=out.dtype) 291 | assert_n_axis_shape("ihfft", x=x, n=n, axis=axis, out=out, size_gt_1=True) 292 | 293 | 294 | # TODO: 295 | # fftfreq 296 | # rfftfreq 297 | # fftshift 298 | # ifftshift 299 | -------------------------------------------------------------------------------- /array_api_tests/test_set_functions.py: -------------------------------------------------------------------------------- 1 | # TODO: disable if opted out, refactor things 2 | import cmath 3 | import math 4 | from collections import Counter, defaultdict 5 | 6 | import pytest 7 | from hypothesis import assume, given 8 | 9 | from . import _array_module as xp 10 | from . import dtype_helpers as dh 11 | from . import hypothesis_helpers as hh 12 | from . import pytest_helpers as ph 13 | from . import shape_helpers as sh 14 | from . import xps 15 | 16 | pytestmark = [pytest.mark.ci, pytest.mark.data_dependent_shapes] 17 | 18 | 19 | @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) 20 | def test_unique_all(x): 21 | out = xp.unique_all(x) 22 | 23 | assert hasattr(out, "values") 24 | assert hasattr(out, "indices") 25 | assert hasattr(out, "inverse_indices") 26 | assert hasattr(out, "counts") 27 | 28 | ph.assert_dtype( 29 | "unique_all", in_dtype=x.dtype, out_dtype=out.values.dtype, repr_name="out.values.dtype" 30 | ) 31 | ph.assert_default_index( 32 | "unique_all", out.indices.dtype, repr_name="out.indices.dtype" 33 | ) 34 | ph.assert_default_index( 35 | "unique_all", out.inverse_indices.dtype, repr_name="out.inverse_indices.dtype" 36 | ) 37 | ph.assert_default_index( 38 | "unique_all", out.counts.dtype, repr_name="out.counts.dtype" 39 | ) 40 | 41 | assert ( 42 | out.indices.shape == out.values.shape 43 | ), f"{out.indices.shape=}, but should be {out.values.shape=}" 44 | ph.assert_shape( 45 | "unique_all", 
46 | out_shape=out.inverse_indices.shape, 47 | expected=x.shape, 48 | repr_name="out.inverse_indices.shape", 49 | ) 50 | assert ( 51 | out.counts.shape == out.values.shape 52 | ), f"{out.counts.shape=}, but should be {out.values.shape=}" 53 | 54 | scalar_type = dh.get_scalar_type(out.values.dtype) 55 | counts = defaultdict(int) 56 | firsts = {} 57 | for i, idx in enumerate(sh.ndindex(x.shape)): 58 | val = scalar_type(x[idx]) 59 | if counts[val] == 0: 60 | firsts[val] = i 61 | counts[val] += 1 62 | 63 | for idx in sh.ndindex(out.indices.shape): 64 | val = scalar_type(out.values[idx]) 65 | if cmath.isnan(val): 66 | break 67 | i = int(out.indices[idx]) 68 | expected = firsts[val] 69 | assert i == expected, ( 70 | f"out.values[{idx}]={val} and out.indices[{idx}]={i}, " 71 | f"but first occurence of {val} is at {expected}" 72 | ) 73 | 74 | for idx in sh.ndindex(out.inverse_indices.shape): 75 | ridx = int(out.inverse_indices[idx]) 76 | val = out.values[ridx] 77 | expected = x[idx] 78 | msg = ( 79 | f"out.inverse_indices[{idx}]={ridx} results in out.values[{ridx}]={val}, " 80 | f"but should result in x[{idx}]={expected}" 81 | ) 82 | if dh.is_float_dtype(out.values.dtype) and xp.isnan(expected): 83 | assert xp.isnan(val), msg 84 | else: 85 | assert val == expected, msg 86 | 87 | vals_idx = {} 88 | nans = 0 89 | for idx in sh.ndindex(out.values.shape): 90 | val = scalar_type(out.values[idx]) 91 | count = int(out.counts[idx]) 92 | if cmath.isnan(val): 93 | nans += 1 94 | assert count == 1, ( 95 | f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " 96 | "but count should be 1 as NaNs are distinct" 97 | ) 98 | else: 99 | expected = counts[val] 100 | assert ( 101 | expected > 0 102 | ), f"out.values[{idx}]={val}, but {val} not in input array" 103 | count = int(out.counts[idx]) 104 | assert count == expected, ( 105 | f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " 106 | f"but should be {expected}" 107 | ) 108 | assert ( 109 | val not in vals_idx.keys() 110 | ), 
f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" 111 | vals_idx[val] = idx 112 | 113 | if dh.is_float_dtype(out.values.dtype): 114 | assume(math.prod(x.shape) <= 128) # may not be representable 115 | expected = sum(v for k, v in counts.items() if cmath.isnan(k)) 116 | assert nans == expected, f"{nans} NaNs in out, but should be {expected}" 117 | 118 | 119 | @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) 120 | def test_unique_counts(x): 121 | out = xp.unique_counts(x) 122 | assert hasattr(out, "values") 123 | assert hasattr(out, "counts") 124 | ph.assert_dtype( 125 | "unique_counts", in_dtype=x.dtype, out_dtype=out.values.dtype, repr_name="out.values.dtype" 126 | ) 127 | ph.assert_default_index( 128 | "unique_counts", out.counts.dtype, repr_name="out.counts.dtype" 129 | ) 130 | assert ( 131 | out.counts.shape == out.values.shape 132 | ), f"{out.counts.shape=}, but should be {out.values.shape=}" 133 | scalar_type = dh.get_scalar_type(out.values.dtype) 134 | counts = Counter(scalar_type(x[idx]) for idx in sh.ndindex(x.shape)) 135 | vals_idx = {} 136 | nans = 0 137 | for idx in sh.ndindex(out.values.shape): 138 | val = scalar_type(out.values[idx]) 139 | count = int(out.counts[idx]) 140 | if cmath.isnan(val): 141 | nans += 1 142 | assert count == 1, ( 143 | f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " 144 | "but count should be 1 as NaNs are distinct" 145 | ) 146 | else: 147 | expected = counts[val] 148 | assert ( 149 | expected > 0 150 | ), f"out.values[{idx}]={val}, but {val} not in input array" 151 | count = int(out.counts[idx]) 152 | assert count == expected, ( 153 | f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " 154 | f"but should be {expected}" 155 | ) 156 | assert ( 157 | val not in vals_idx.keys() 158 | ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" 159 | vals_idx[val] = idx 160 | if dh.is_float_dtype(out.values.dtype): 161 | assume(math.prod(x.shape) <= 128) # may not be 
representable 162 | expected = sum(v for k, v in counts.items() if cmath.isnan(k)) 163 | assert nans == expected, f"{nans} NaNs in out, but should be {expected}" 164 | 165 | 166 | @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) 167 | def test_unique_inverse(x): 168 | out = xp.unique_inverse(x) 169 | assert hasattr(out, "values") 170 | assert hasattr(out, "inverse_indices") 171 | ph.assert_dtype( 172 | "unique_inverse", in_dtype=x.dtype, out_dtype=out.values.dtype, repr_name="out.values.dtype" 173 | ) 174 | ph.assert_default_index( 175 | "unique_inverse", 176 | out.inverse_indices.dtype, 177 | repr_name="out.inverse_indices.dtype", 178 | ) 179 | ph.assert_shape( 180 | "unique_inverse", 181 | out_shape=out.inverse_indices.shape, 182 | expected=x.shape, 183 | repr_name="out.inverse_indices.shape", 184 | ) 185 | scalar_type = dh.get_scalar_type(out.values.dtype) 186 | distinct = set(scalar_type(x[idx]) for idx in sh.ndindex(x.shape)) 187 | vals_idx = {} 188 | nans = 0 189 | for idx in sh.ndindex(out.values.shape): 190 | val = scalar_type(out.values[idx]) 191 | if cmath.isnan(val): 192 | nans += 1 193 | else: 194 | assert ( 195 | val in distinct 196 | ), f"out.values[{idx}]={val}, but {val} not in input array" 197 | assert ( 198 | val not in vals_idx.keys() 199 | ), f"out.values[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" 200 | vals_idx[val] = idx 201 | for idx in sh.ndindex(out.inverse_indices.shape): 202 | ridx = int(out.inverse_indices[idx]) 203 | val = out.values[ridx] 204 | expected = x[idx] 205 | msg = ( 206 | f"out.inverse_indices[{idx}]={ridx} results in out.values[{ridx}]={val}, " 207 | f"but should result in x[{idx}]={expected}" 208 | ) 209 | if dh.is_float_dtype(out.values.dtype) and xp.isnan(expected): 210 | assert xp.isnan(val), msg 211 | else: 212 | assert val == expected, msg 213 | if dh.is_float_dtype(out.values.dtype): 214 | assume(math.prod(x.shape) <= 128) # may not be representable 215 | expected = 
xp.sum(xp.astype(xp.isnan(x), xp.uint8)) 216 | assert nans == expected, f"{nans} NaNs in out.values, but should be {expected}" 217 | 218 | 219 | @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) 220 | def test_unique_values(x): 221 | out = xp.unique_values(x) 222 | ph.assert_dtype("unique_values", in_dtype=x.dtype, out_dtype=out.dtype) 223 | scalar_type = dh.get_scalar_type(x.dtype) 224 | distinct = set(scalar_type(x[idx]) for idx in sh.ndindex(x.shape)) 225 | vals_idx = {} 226 | nans = 0 227 | for idx in sh.ndindex(out.shape): 228 | val = scalar_type(out[idx]) 229 | if cmath.isnan(val): 230 | nans += 1 231 | else: 232 | assert val in distinct, f"out[{idx}]={val}, but {val} not in input array" 233 | assert ( 234 | val not in vals_idx.keys() 235 | ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" 236 | vals_idx[val] = idx 237 | if dh.is_float_dtype(out.dtype): 238 | assume(math.prod(x.shape) <= 128) # may not be representable 239 | expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8)) 240 | assert nans == expected, f"{nans} NaNs in out, but should be {expected}" 241 | -------------------------------------------------------------------------------- /array_api_tests/test_array_object.py: -------------------------------------------------------------------------------- 1 | import cmath 2 | import math 3 | from itertools import product 4 | from typing import List, Sequence, Tuple, Union, get_args 5 | 6 | import pytest 7 | from hypothesis import assume, given, note 8 | from hypothesis import strategies as st 9 | 10 | from . import _array_module as xp 11 | from . import dtype_helpers as dh 12 | from . import hypothesis_helpers as hh 13 | from . import pytest_helpers as ph 14 | from . import shape_helpers as sh 15 | from . import xps 16 | from . 
import xp as _xp 17 | from .typing import DataType, Index, Param, Scalar, ScalarType, Shape 18 | 19 | pytestmark = pytest.mark.ci 20 | 21 | 22 | def scalar_objects( 23 | dtype: DataType, shape: Shape 24 | ) -> st.SearchStrategy[Union[Scalar, List[Scalar]]]: 25 | """Generates scalars or nested sequences which are valid for xp.asarray()""" 26 | size = math.prod(shape) 27 | return st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map( 28 | lambda l: sh.reshape(l, shape) 29 | ) 30 | 31 | 32 | def normalise_key(key: Index, shape: Shape) -> Tuple[Union[int, slice], ...]: 33 | """ 34 | Normalise an indexing key. 35 | 36 | * If a non-tuple index, wrap as a tuple. 37 | * Represent ellipsis as equivalent slices. 38 | """ 39 | _key = tuple(key) if isinstance(key, tuple) else (key,) 40 | if Ellipsis in _key: 41 | nonexpanding_key = tuple(i for i in _key if i is not None) 42 | start_a = nonexpanding_key.index(Ellipsis) 43 | stop_a = start_a + (len(shape) - (len(nonexpanding_key) - 1)) 44 | slices = tuple(slice(None) for _ in range(start_a, stop_a)) 45 | start_pos = _key.index(Ellipsis) 46 | _key = _key[:start_pos] + slices + _key[start_pos + 1 :] 47 | return _key 48 | 49 | 50 | def get_indexed_axes_and_out_shape( 51 | key: Tuple[Union[int, slice, None], ...], shape: Shape 52 | ) -> Tuple[Tuple[Sequence[int], ...], Shape]: 53 | """ 54 | From the (normalised) key and input shape, calculates: 55 | 56 | * indexed_axes: For each dimension, the axes which the key indexes. 57 | * out_shape: The resulting shape of indexing an array (of the input shape) 58 | with the key. 
59 | """ 60 | axes_indices = [] 61 | out_shape = [] 62 | a = 0 63 | for i in key: 64 | if i is None: 65 | out_shape.append(1) 66 | else: 67 | side = shape[a] 68 | if isinstance(i, int): 69 | if i < 0: 70 | i += side 71 | axes_indices.append((i,)) 72 | else: 73 | indices = range(side)[i] 74 | axes_indices.append(indices) 75 | out_shape.append(len(indices)) 76 | a += 1 77 | return tuple(axes_indices), tuple(out_shape) 78 | 79 | 80 | @given(shape=hh.shapes(), dtype=xps.scalar_dtypes(), data=st.data()) 81 | def test_getitem(shape, dtype, data): 82 | zero_sided = any(side == 0 for side in shape) 83 | if zero_sided: 84 | x = xp.zeros(shape, dtype=dtype) 85 | else: 86 | obj = data.draw(scalar_objects(dtype, shape), label="obj") 87 | x = xp.asarray(obj, dtype=dtype) 88 | note(f"{x=}") 89 | key = data.draw(xps.indices(shape=shape, allow_newaxis=True), label="key") 90 | 91 | out = x[key] 92 | 93 | ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype) 94 | _key = normalise_key(key, shape) 95 | axes_indices, expected_shape = get_indexed_axes_and_out_shape(_key, shape) 96 | ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape) 97 | out_zero_sided = any(side == 0 for side in expected_shape) 98 | if not zero_sided and not out_zero_sided: 99 | out_obj = [] 100 | for idx in product(*axes_indices): 101 | val = obj 102 | for i in idx: 103 | val = val[i] 104 | out_obj.append(val) 105 | out_obj = sh.reshape(out_obj, expected_shape) 106 | expected = xp.asarray(out_obj, dtype=dtype) 107 | ph.assert_array_elements("__getitem__", out=out, expected=expected) 108 | 109 | 110 | @given( 111 | shape=hh.shapes(), 112 | dtypes=hh.oneway_promotable_dtypes(dh.all_dtypes), 113 | data=st.data(), 114 | ) 115 | def test_setitem(shape, dtypes, data): 116 | zero_sided = any(side == 0 for side in shape) 117 | if zero_sided: 118 | x = xp.zeros(shape, dtype=dtypes.result_dtype) 119 | else: 120 | obj = data.draw(scalar_objects(dtypes.result_dtype, shape), label="obj") 
121 | x = xp.asarray(obj, dtype=dtypes.result_dtype) 122 | note(f"{x=}") 123 | key = data.draw(xps.indices(shape=shape), label="key") 124 | _key = normalise_key(key, shape) 125 | axes_indices, out_shape = get_indexed_axes_and_out_shape(_key, shape) 126 | value_strat = xps.arrays(dtype=dtypes.result_dtype, shape=out_shape) 127 | if out_shape == (): 128 | # We can pass scalars if we're only indexing one element 129 | value_strat |= xps.from_dtype(dtypes.result_dtype) 130 | value = data.draw(value_strat, label="value") 131 | 132 | res = xp.asarray(x, copy=True) 133 | res[key] = value 134 | 135 | ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype") 136 | ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.shape") 137 | f_res = sh.fmt_idx("x", key) 138 | if isinstance(value, get_args(Scalar)): 139 | msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]" 140 | if cmath.isnan(value): 141 | assert xp.isnan(res[key]), msg 142 | else: 143 | assert res[key] == value, msg 144 | else: 145 | ph.assert_array_elements("__setitem__", out=res[key], expected=value, out_repr=f_res) 146 | unaffected_indices = set(sh.ndindex(res.shape)) - set(product(*axes_indices)) 147 | for idx in unaffected_indices: 148 | ph.assert_0d_equals( 149 | "__setitem__", 150 | x_repr=f"old {f_res}", 151 | x_val=x[idx], 152 | out_repr=f"modified {f_res}", 153 | out_val=res[idx], 154 | ) 155 | 156 | 157 | @pytest.mark.data_dependent_shapes 158 | @given(hh.shapes(), st.data()) 159 | def test_getitem_masking(shape, data): 160 | x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x") 161 | mask_shapes = st.one_of( 162 | st.sampled_from([x.shape, ()]), 163 | st.lists(st.booleans(), min_size=x.ndim, max_size=x.ndim).map( 164 | lambda l: tuple(s if b else 0 for s, b in zip(x.shape, l)) 165 | ), 166 | hh.shapes(), 167 | ) 168 | key = data.draw(xps.arrays(dtype=xp.bool, shape=mask_shapes), label="key") 169 | 170 | if 
key.ndim > x.ndim or not all( 171 | ks in (xs, 0) for xs, ks in zip(x.shape, key.shape) 172 | ): 173 | with pytest.raises(IndexError): 174 | x[key] 175 | return 176 | 177 | out = x[key] 178 | 179 | ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype) 180 | if key.ndim == 0: 181 | expected_shape = (1,) if key else (0,) 182 | expected_shape += x.shape 183 | else: 184 | size = int(xp.sum(xp.astype(key, xp.uint8))) 185 | expected_shape = (size,) + x.shape[key.ndim :] 186 | ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape) 187 | if not any(s == 0 for s in key.shape): 188 | assume(key.ndim == x.ndim) # TODO: test key.ndim < x.ndim scenarios 189 | out_indices = sh.ndindex(out.shape) 190 | for x_idx in sh.ndindex(x.shape): 191 | if key[x_idx]: 192 | out_idx = next(out_indices) 193 | ph.assert_0d_equals( 194 | "__getitem__", 195 | x_repr=f"x[{x_idx}]", 196 | x_val=x[x_idx], 197 | out_repr=f"out[{out_idx}]", 198 | out_val=out[out_idx], 199 | ) 200 | 201 | 202 | @given(hh.shapes(), st.data()) 203 | def test_setitem_masking(shape, data): 204 | x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x") 205 | key = data.draw(xps.arrays(dtype=xp.bool, shape=shape), label="key") 206 | value = data.draw( 207 | xps.from_dtype(x.dtype) | xps.arrays(dtype=x.dtype, shape=()), label="value" 208 | ) 209 | 210 | res = xp.asarray(x, copy=True) 211 | res[key] = value 212 | 213 | ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype") 214 | ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.dtype") 215 | scalar_type = dh.get_scalar_type(x.dtype) 216 | for idx in sh.ndindex(x.shape): 217 | if key[idx]: 218 | if isinstance(value, get_args(Scalar)): 219 | ph.assert_scalar_equals( 220 | "__setitem__", 221 | type_=scalar_type, 222 | idx=idx, 223 | out=scalar_type(res[idx]), 224 | expected=value, 225 | repr_name="modified x", 226 | ) 227 | else: 228 | ph.assert_0d_equals( 229 
| "__setitem__", 230 | x_repr="value", 231 | x_val=value, 232 | out_repr=f"modified x[{idx}]", 233 | out_val=res[idx] 234 | ) 235 | else: 236 | ph.assert_0d_equals( 237 | "__setitem__", 238 | x_repr=f"old x[{idx}]", 239 | x_val=x[idx], 240 | out_repr=f"modified x[{idx}]", 241 | out_val=res[idx] 242 | ) 243 | 244 | 245 | def make_scalar_casting_param( 246 | method_name: str, dtype_name: DataType, stype: ScalarType 247 | ) -> Param: 248 | return pytest.param( 249 | method_name, dtype_name, stype, id=f"{method_name}({dtype_name})" 250 | ) 251 | 252 | 253 | @pytest.mark.parametrize( 254 | "method_name, dtype_name, stype", 255 | [make_scalar_casting_param("__bool__", "bool", bool)] 256 | + [make_scalar_casting_param("__int__", n, int) for n in dh.all_int_names] 257 | + [make_scalar_casting_param("__index__", n, int) for n in dh.all_int_names] 258 | + [make_scalar_casting_param("__float__", n, float) for n in dh.real_float_names], 259 | ) 260 | @given(data=st.data()) 261 | def test_scalar_casting(method_name, dtype_name, stype, data): 262 | try: 263 | dtype = getattr(_xp, dtype_name) 264 | except AttributeError as e: 265 | pytest.skip(str(e)) 266 | x = data.draw(xps.arrays(dtype, shape=()), label="x") 267 | method = getattr(x, method_name) 268 | out = method() 269 | assert isinstance( 270 | out, stype 271 | ), f"{method_name}({x})={out}, which is not a {stype.__name__} scalar" 272 | -------------------------------------------------------------------------------- /array_api_tests/array_helpers.py: -------------------------------------------------------------------------------- 1 | from ._array_module import (isnan, all, any, equal, not_equal, logical_and, 2 | logical_or, isfinite, greater, less, less_equal, 3 | zeros, ones, full, bool, int8, int16, int32, 4 | int64, uint8, uint16, uint32, uint64, float32, 5 | float64, nan, inf, pi, remainder, divide, isinf, 6 | negative, asarray) 7 | # These are exported here so that they can be included in the special cases 8 | # tests 
from this file. 9 | from ._array_module import logical_not, subtract, floor, ceil, where 10 | from . import dtype_helpers as dh 11 | 12 | 13 | __all__ = ['all', 'any', 'logical_and', 'logical_or', 'logical_not', 'less', 14 | 'less_equal', 'greater', 'subtract', 'negative', 'floor', 'ceil', 15 | 'where', 'isfinite', 'equal', 'not_equal', 'zero', 'one', 'NaN', 16 | 'infinity', 'π', 'isnegzero', 'non_zero', 'isposzero', 17 | 'exactly_equal', 'assert_exactly_equal', 'notequal', 18 | 'assert_finite', 'assert_non_zero', 'ispositive', 19 | 'assert_positive', 'isnegative', 'assert_negative', 'isintegral', 20 | 'assert_integral', 'isodd', 'iseven', "assert_iseven", 21 | 'assert_isinf', 'positive_mathematical_sign', 22 | 'assert_positive_mathematical_sign', 'negative_mathematical_sign', 23 | 'assert_negative_mathematical_sign', 'same_sign', 24 | 'assert_same_sign', 'float64', 25 | 'asarray', 'full', 'true', 'false', 'isnan'] 26 | 27 | def zero(shape, dtype): 28 | """ 29 | Returns a full 0 array of the given dtype. 30 | 31 | This should be used in place of the literal "0" in the test suite, as the 32 | spec does not require any behavior with Python literals (and in 33 | particular, it does not specify how the integer 0 and the float 0.0 work 34 | with type promotion). 35 | 36 | To get -0, use -zero(dtype) (note that -0 is only defined for floating 37 | point dtypes). 38 | """ 39 | return zeros(shape, dtype=dtype) 40 | 41 | def one(shape, dtype): 42 | """ 43 | Returns a full 1 array of the given dtype. 44 | 45 | This should be used in place of the literal "1" in the test suite, as the 46 | spec does not require any behavior with Python literals (and in 47 | particular, it does not specify how the integer 1 and the float 1.0 work 48 | with type promotion). 49 | 50 | To get -1, use -one(dtype). 51 | """ 52 | return ones(shape, dtype=dtype) 53 | 54 | def NaN(shape, dtype): 55 | """ 56 | Returns a full nan array of the given dtype. 
57 | 58 | Note that this is only defined for floating point dtypes. 59 | """ 60 | if dtype not in [float32, float64]: 61 | raise RuntimeError(f"Unexpected dtype {dtype} in NaN().") 62 | return full(shape, nan, dtype=dtype) 63 | 64 | def infinity(shape, dtype): 65 | """ 66 | Returns a full positive infinity array of the given dtype. 67 | 68 | Note that this is only defined for floating point dtypes. 69 | 70 | To get negative infinity, use -infinity(dtype). 71 | 72 | """ 73 | if dtype not in [float32, float64]: 74 | raise RuntimeError(f"Unexpected dtype {dtype} in infinity().") 75 | return full(shape, inf, dtype=dtype) 76 | 77 | def π(shape, dtype): 78 | """ 79 | Returns a full π array of the given dtype. 80 | 81 | Note that this function is only defined for floating point dtype. 82 | 83 | To get rational multiples of π, use, e.g., 3*π(dtype)/2. 84 | 85 | """ 86 | if dtype not in [float32, float64]: 87 | raise RuntimeError(f"Unexpected dtype {dtype} in π().") 88 | return full(shape, pi, dtype=dtype) 89 | 90 | def true(shape): 91 | """ 92 | Returns a full True array with dtype=bool. 93 | """ 94 | return full(shape, True, dtype=bool) 95 | 96 | def false(shape): 97 | """ 98 | Returns a full False array with dtype=bool. 99 | """ 100 | return full(shape, False, dtype=bool) 101 | 102 | def isnegzero(x): 103 | """ 104 | Returns a mask where x is -0. Is all False if x has integer dtype. 105 | """ 106 | # TODO: If copysign or signbit are added to the spec, use those instead. 107 | shape = x.shape 108 | dtype = x.dtype 109 | if dh.is_int_dtype(dtype): 110 | return false(shape) 111 | return equal(divide(one(shape, dtype), x), -infinity(shape, dtype)) 112 | 113 | def isposzero(x): 114 | """ 115 | Returns a mask where x is +0 (but not -0). Is all True if x has integer dtype. 116 | """ 117 | # TODO: If copysign or signbit are added to the spec, use those instead. 
118 | shape = x.shape 119 | dtype = x.dtype 120 | if dh.is_int_dtype(dtype): 121 | return true(shape) 122 | return equal(divide(one(shape, dtype), x), infinity(shape, dtype)) 123 | 124 | def exactly_equal(x, y): 125 | """ 126 | Same as equal(x, y) except it gives True where both values are nan, and 127 | distinguishes +0 and -0. 128 | 129 | This function implicitly assumes x and y have the same shape and dtype. 130 | """ 131 | if x.dtype in [float32, float64]: 132 | xnegzero = isnegzero(x) 133 | ynegzero = isnegzero(y) 134 | 135 | xposzero = isposzero(x) 136 | yposzero = isposzero(y) 137 | 138 | xnan = isnan(x) 139 | ynan = isnan(y) 140 | 141 | # (x == y OR x == y == NaN) AND xnegzero == ynegzero AND xposzero == y poszero 142 | return logical_and(logical_and( 143 | logical_or(equal(x, y), logical_and(xnan, ynan)), 144 | equal(xnegzero, ynegzero)), 145 | equal(xposzero, yposzero)) 146 | 147 | return equal(x, y) 148 | 149 | def notequal(x, y): 150 | """ 151 | Same as not_equal(x, y) except it gives False when both values are nan. 152 | 153 | Note: this function does NOT distinguish +0 and -0. 154 | 155 | This function implicitly assumes x and y have the same shape and dtype. 156 | """ 157 | if x.dtype in [float32, float64]: 158 | xnan = isnan(x) 159 | ynan = isnan(y) 160 | 161 | both_nan = logical_and(xnan, ynan) 162 | # NOT both nan AND (both nan OR x != y) 163 | return logical_and(logical_not(both_nan), not_equal(x, y)) 164 | 165 | return not_equal(x, y) 166 | 167 | def assert_exactly_equal(x, y): 168 | """ 169 | Test that the arrays x and y are exactly equal. 170 | 171 | If x and y do not have the same shape and dtype, they are not considered 172 | equal. 
173 | 174 | """ 175 | assert x.shape == y.shape, f"The input arrays do not have the same shapes ({x.shape} != {y.shape})" 176 | 177 | assert x.dtype == y.dtype, f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype})" 178 | 179 | assert all(exactly_equal(x, y)), "The input arrays have different values" 180 | 181 | def assert_finite(x): 182 | """ 183 | Test that the array x is finite 184 | """ 185 | assert all(isfinite(x)), "The input array is not finite" 186 | 187 | def non_zero(x): 188 | return not_equal(x, zero(x.shape, x.dtype)) 189 | 190 | def assert_non_zero(x): 191 | assert all(non_zero(x)), "The input array is not nonzero" 192 | 193 | def ispositive(x): 194 | return greater(x, zero(x.shape, x.dtype)) 195 | 196 | def assert_positive(x): 197 | assert all(ispositive(x)), "The input array is not positive" 198 | 199 | def isnegative(x): 200 | return less(x, zero(x.shape, x.dtype)) 201 | 202 | def assert_negative(x): 203 | assert all(isnegative(x)), "The input array is not negative" 204 | 205 | def inrange(x, a, b, epsilon=0, open=False): 206 | """ 207 | Returns a mask for values of x in the range [a-epsilon, a+epsilon], inclusive 208 | 209 | If open=True, the range is (a-epsilon, a+epsilon) (i.e., not inclusive). 210 | """ 211 | eps = full(x.shape, epsilon, dtype=x.dtype) 212 | l = less if open else less_equal 213 | return logical_and(l(a-eps, x), l(x, b+eps)) 214 | 215 | def isintegral(x): 216 | """ 217 | Returns a mask on x where the values are integral 218 | 219 | x is integral if its dtype is an integer dtype, or if it is a floating 220 | point value that can be exactly represented as an integer. 
221 | """ 222 | if x.dtype in [int8, int16, int32, int64, uint8, uint16, uint32, uint64]: 223 | return full(x.shape, True, dtype=bool) 224 | elif x.dtype in [float32, float64]: 225 | return equal(remainder(x, one(x.shape, x.dtype)), zero(x.shape, x.dtype)) 226 | else: 227 | return full(x.shape, False, dtype=bool) 228 | 229 | def assert_integral(x): 230 | """ 231 | Check that x has only integer values 232 | """ 233 | assert all(isintegral(x)), "The input array has nonintegral values" 234 | 235 | def isodd(x): 236 | return logical_and( 237 | isintegral(x), 238 | equal( 239 | remainder(x, 2*one(x.shape, x.dtype)), 240 | one(x.shape, x.dtype))) 241 | 242 | def iseven(x): 243 | return logical_and( 244 | isintegral(x), 245 | equal( 246 | remainder(x, 2*one(x.shape, x.dtype)), 247 | zero(x.shape, x.dtype))) 248 | 249 | def assert_iseven(x): 250 | """ 251 | Check that x is an even integer 252 | """ 253 | assert all(iseven(x)), "The input array is not even" 254 | 255 | def assert_isinf(x): 256 | """ 257 | Check that x is an infinity 258 | """ 259 | assert all(isinf(x)), "The input array is not infinite" 260 | 261 | def positive_mathematical_sign(x): 262 | """ 263 | Check if x has a positive "mathematical sign" 264 | 265 | The "mathematical sign" here means the sign bit is 0. This includes 0, 266 | positive finite numbers, and positive infinity. It does not include any 267 | nans, as signed nans are not required by the spec. 268 | 269 | """ 270 | z = zero(x.shape, x.dtype) 271 | return logical_or(greater(x, z), isposzero(x)) 272 | 273 | def assert_positive_mathematical_sign(x): 274 | assert all(positive_mathematical_sign(x)), "The input arrays do not have a positive mathematical sign" 275 | 276 | def negative_mathematical_sign(x): 277 | """ 278 | Check if x has a negative "mathematical sign" 279 | 280 | The "mathematical sign" here means the sign bit is 1. This includes -0, 281 | negative finite numbers, and negative infinity. 
It does not include any 282 | nans, as signed nans are not required by the spec. 283 | 284 | """ 285 | z = zero(x.shape, x.dtype) 286 | if x.dtype in [float32, float64]: 287 | return logical_or(less(x, z), isnegzero(x)) 288 | return less(x, z) 289 | 290 | def assert_negative_mathematical_sign(x): 291 | assert all(negative_mathematical_sign(x)), "The input arrays do not have a negative mathematical sign" 292 | 293 | def same_sign(x, y): 294 | """ 295 | Check if x and y have the "same sign" 296 | 297 | x and y have the same sign if they are both nonnegative or both negative. 298 | For the purposes of this function 0 and 1 have the same sign and -0 and -1 299 | have the same sign. The value of this function is False if either x or y 300 | is nan, as signed nans are not required by the spec. 301 | """ 302 | return logical_or( 303 | logical_and(positive_mathematical_sign(x), positive_mathematical_sign(y)), 304 | logical_and(negative_mathematical_sign(x), negative_mathematical_sign(y))) 305 | 306 | def assert_same_sign(x, y): 307 | assert all(same_sign(x, y)), "The input arrays do not have the same sign" 308 | 309 | -------------------------------------------------------------------------------- /array_api_tests/test_signatures.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for function/method signatures compliance 3 | 4 | We're not interested in being 100% strict - instead we focus on areas which 5 | could affect interop, e.g. with 6 | 7 | def add(x1, x2, /): 8 | ... 9 | 10 | x1 and x2 don't need to be pos-only for the purposes of interoperability, but with 11 | 12 | def squeeze(x, /, axis): 13 | ... 14 | 15 | axis has to be pos-or-keyword to support both styles 16 | 17 | >>> squeeze(x, 0) 18 | ... 19 | >>> squeeze(x, axis=0) 20 | ... 
21 | 22 | """ 23 | from collections import defaultdict 24 | from copy import copy 25 | from inspect import Parameter, Signature, signature 26 | from types import FunctionType 27 | from typing import Any, Callable, Dict, Literal, get_args 28 | from warnings import warn 29 | 30 | import pytest 31 | 32 | from . import dtype_helpers as dh 33 | from . import xp 34 | from .stubs import array_methods, category_to_funcs, extension_to_funcs, name_to_func 35 | 36 | pytestmark = pytest.mark.ci 37 | 38 | ParameterKind = Literal[ 39 | Parameter.POSITIONAL_ONLY, 40 | Parameter.VAR_POSITIONAL, 41 | Parameter.POSITIONAL_OR_KEYWORD, 42 | Parameter.KEYWORD_ONLY, 43 | Parameter.VAR_KEYWORD, 44 | ] 45 | ALL_KINDS = get_args(ParameterKind) 46 | VAR_KINDS = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD) 47 | kind_to_str: Dict[ParameterKind, str] = { 48 | Parameter.POSITIONAL_OR_KEYWORD: "pos or kw argument", 49 | Parameter.POSITIONAL_ONLY: "pos-only argument", 50 | Parameter.KEYWORD_ONLY: "keyword-only argument", 51 | Parameter.VAR_POSITIONAL: "star-args (i.e. *args) argument", 52 | Parameter.VAR_KEYWORD: "star-kwargs (i.e. **kwargs) argument", 53 | } 54 | 55 | 56 | def _test_inspectable_func(sig: Signature, stub_sig: Signature): 57 | params = list(sig.parameters.values()) 58 | stub_params = list(stub_sig.parameters.values()) 59 | 60 | non_kwonly_stub_params = [ 61 | p for p in stub_params if p.kind != Parameter.KEYWORD_ONLY 62 | ] 63 | # sanity check 64 | assert non_kwonly_stub_params == stub_params[: len(non_kwonly_stub_params)] 65 | # We're not interested if the array module has additional arguments, so we 66 | # only iterate through the arguments listed in the spec. 
67 | for i, stub_param in enumerate(non_kwonly_stub_params): 68 | assert ( 69 | len(params) >= i + 1 70 | ), f"Argument '{stub_param.name}' missing from signature" 71 | param = params[i] 72 | 73 | # We're not interested in the name if it isn't actually used 74 | if stub_param.kind not in [Parameter.POSITIONAL_ONLY, *VAR_KINDS]: 75 | assert ( 76 | param.name == stub_param.name 77 | ), f"Expected argument '{param.name}' to be named '{stub_param.name}'" 78 | 79 | if stub_param.kind in [Parameter.POSITIONAL_OR_KEYWORD, *VAR_KINDS]: 80 | f_stub_kind = kind_to_str[stub_param.kind] 81 | assert param.kind == stub_param.kind, ( 82 | f"{param.name} is a {kind_to_str[param.kind]}, " 83 | f"but should be a {f_stub_kind}" 84 | ) 85 | 86 | kwonly_stub_params = stub_params[len(non_kwonly_stub_params) :] 87 | for stub_param in kwonly_stub_params: 88 | assert ( 89 | stub_param.name in sig.parameters.keys() 90 | ), f"Argument '{stub_param.name}' missing from signature" 91 | param = next(p for p in params if p.name == stub_param.name) 92 | f_stub_kind = kind_to_str[stub_param.kind] 93 | assert param.kind in [stub_param.kind, Parameter.POSITIONAL_OR_KEYWORD,], ( 94 | f"{param.name} is a {kind_to_str[param.kind]}, " 95 | f"but should be a {f_stub_kind} " 96 | f"(or at least a {kind_to_str[ParameterKind.POSITIONAL_OR_KEYWORD]})" 97 | ) 98 | 99 | 100 | def make_pretty_func(func_name: str, *args: Any, **kwargs: Any) -> str: 101 | f_sig = f"{func_name}(" 102 | f_sig += ", ".join(str(a) for a in args) 103 | if len(kwargs) != 0: 104 | if len(args) != 0: 105 | f_sig += ", " 106 | f_sig += ", ".join(f"{k}={v}" for k, v in kwargs.items()) 107 | f_sig += ")" 108 | return f_sig 109 | 110 | 111 | # We test uninspectable signatures by passing valid, manually-defined arguments 112 | # to the signature's function/method. 113 | # 114 | # Arguments which require use of the array module are specified as string 115 | # expressions to be eval()'d on runtime. 
This is as opposed to just using the 116 | # array module whilst setting up the tests, which is prone to halt the entire 117 | # test suite if an array module doesn't support a given expression. 118 | func_to_specified_args = defaultdict( 119 | dict, 120 | { 121 | "permute_dims": {"axes": 0}, 122 | "reshape": {"shape": (1, 5)}, 123 | "broadcast_to": {"shape": (1, 5)}, 124 | "asarray": {"obj": [0, 1, 2, 3, 4]}, 125 | "full_like": {"fill_value": 42}, 126 | "matrix_power": {"n": 2}, 127 | }, 128 | ) 129 | func_to_specified_arg_exprs = defaultdict( 130 | dict, 131 | { 132 | "stack": {"arrays": "[xp.ones((5,)), xp.ones((5,))]"}, 133 | "iinfo": {"type": "xp.int64"}, 134 | "finfo": {"type": "xp.float64"}, 135 | "cholesky": {"x": "xp.asarray([[1, 0], [0, 1]], dtype=xp.float64)"}, 136 | "inv": {"x": "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)"}, 137 | "solve": { 138 | a: "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)" for a in ["x1", "x2"] 139 | }, 140 | }, 141 | ) 142 | # We default most array arguments heuristically. As functions/methods work only 143 | # with arrays of certain dtypes and shapes, we specify only supported arrays 144 | # respective to the function. 
casty_names = ["__bool__", "__int__", "__float__", "__complex__", "__index__"]
matrixy_names = [
    f.__name__
    for f in category_to_funcs["linear_algebra"] + extension_to_funcs["linalg"]
]
matrixy_names += ["__matmul__", "triu", "tril"]
# For every stub with array arguments ("x", "x1", "x2", "other", or "self" for
# methods) that lacks an explicit expression, register a fallback xp.ones(...)
# expression using the first assumed dtype the function accepts.
for func_name, func in name_to_func.items():
    stub_sig = signature(func)
    array_argnames = set(stub_sig.parameters.keys()) & {"x", "x1", "x2", "other"}
    if func in array_methods:
        array_argnames.add("self")
    array_argnames -= set(func_to_specified_arg_exprs[func_name].keys())
    if len(array_argnames) > 0:
        in_dtypes = dh.func_in_dtypes[func_name]
        for dtype_name in ["float64", "bool", "int64", "complex128"]:
            # We try float64 first because uninspectable numerical functions
            # tend to support float inputs first-and-foremost (i.e. PyTorch)
            try:
                dtype = getattr(xp, dtype_name)
            except AttributeError:
                pass
            else:
                if dtype in in_dtypes:
                    if func_name in casty_names:
                        shape = ()
                    elif func_name in matrixy_names:
                        shape = (3, 3)
                    else:
                        shape = (5,)
                    fallback_array_expr = f"xp.ones({shape}, dtype=xp.{dtype_name})"
                    break
        else:
            # The lookup is spelled out literally in the message; it must not be
            # evaluated inside the f-string, where '{func_name}' would be used as
            # a literal dictionary key.
            warn(
                f"dh.func_in_dtypes['{func_name}']={in_dtypes} seemingly does "
                "not contain any assumed dtypes, so skipping specifying fallback array."
            )
            continue
        for argname in array_argnames:
            func_to_specified_arg_exprs[func_name][argname] = fallback_array_expr


def _test_uninspectable_func(func_name: str, func: Callable, stub_sig: Signature):
    """Check an uninspectable ``func`` against ``stub_sig`` by calling it.

    Arguments come from ``func_to_specified_args`` and from
    ``func_to_specified_arg_exprs``; the latter are ``eval()``'d with ``xp`` in
    scope. Positional-or-keyword arguments are passed in every split between
    positional and keyword form, from all-positional to all-keyword. The test
    is skipped when a needed argument is unspecified or an expression fails to
    evaluate.
    """
    params = list(stub_sig.parameters.values())

    if len(params) == 0:
        func()
        return

    uninspectable_msg = (
        f"Note {func_name}() is not inspectable so arguments are passed "
        "manually to test the signature."
    )

    argname_to_arg = copy(func_to_specified_args[func_name])
    argname_to_expr = func_to_specified_arg_exprs[func_name]
    for argname, expr in argname_to_expr.items():
        assert argname not in argname_to_arg.keys()  # sanity check
        try:
            argname_to_arg[argname] = eval(expr, {"xp": xp})
        except Exception as e:
            pytest.skip(
                f"Exception occured when evaluating {argname}={expr}: {e}\n"
                f"{uninspectable_msg}"
            )

    posargs = []
    posorkw_args = {}
    kwargs = {}
    no_arg_msg = (
        "We have no argument specified for '{}'. Please ensure you're using "
        "the latest version of array-api-tests, then open an issue if one "
        f"doesn't already exist. {uninspectable_msg}"
    )
    for param in params:
        if param.kind == Parameter.POSITIONAL_ONLY:
            try:
                posargs.append(argname_to_arg[param.name])
            except KeyError:
                pytest.skip(no_arg_msg.format(param.name))
        elif param.kind == Parameter.POSITIONAL_OR_KEYWORD:
            if param.default == Parameter.empty:
                try:
                    posorkw_args[param.name] = argname_to_arg[param.name]
                except KeyError:
                    pytest.skip(no_arg_msg.format(param.name))
            else:
                # Prefer an explicitly specified argument; otherwise the stub's
                # default is always a valid value to pass.
                posorkw_args[param.name] = argname_to_arg.get(param.name, param.default)
        elif param.kind == Parameter.KEYWORD_ONLY:
            assert param.default != Parameter.empty  # sanity check
            kwargs[param.name] = param.default
        else:
            assert param.kind in VAR_KINDS  # sanity check
            pytest.skip(no_arg_msg.format(param.name))
    if len(posorkw_args) == 0:
        func(*posargs, **kwargs)
    else:
        posorkw_name_to_arg_pairs = list(posorkw_args.items())
        # i positional-or-keyword args go positionally, the rest by keyword.
        for i in range(len(posorkw_name_to_arg_pairs), -1, -1):
            extra_posargs = [arg for _, arg in posorkw_name_to_arg_pairs[:i]]
            extra_kwargs = dict(posorkw_name_to_arg_pairs[i:])
            func(*posargs, *extra_posargs, **kwargs, **extra_kwargs)


def _test_func_signature(func: Callable, stub: FunctionType, is_method=False): 250 | stub_sig = signature(stub) 251 | # If testing against array, ignore 'self' arg in stub as it won't be present 252 | # in func (which should be a method). 253 | if is_method: 254 | stub_params = list(stub_sig.parameters.values()) 255 | if stub_params[0].name == "self": 256 | del stub_params[0] 257 | stub_sig = Signature( 258 | parameters=stub_params, return_annotation=stub_sig.return_annotation 259 | ) 260 | 261 | try: 262 | sig = signature(func) 263 | except ValueError: 264 | try: 265 | _test_uninspectable_func(stub.__name__, func, stub_sig) 266 | except Exception as e: 267 | raise e from None # suppress parent exception for cleaner pytest output 268 | else: 269 | _test_inspectable_func(sig, stub_sig) 270 | 271 | 272 | @pytest.mark.parametrize( 273 | "stub", 274 | [s for stubs in category_to_funcs.values() for s in stubs], 275 | ids=lambda f: f.__name__, 276 | ) 277 | def test_func_signature(stub: FunctionType): 278 | assert hasattr(xp, stub.__name__), f"{stub.__name__} not found in array module" 279 | func = getattr(xp, stub.__name__) 280 | _test_func_signature(func, stub) 281 | 282 | 283 | extension_and_stub_params = [] 284 | for ext, stubs in extension_to_funcs.items(): 285 | for stub in stubs: 286 | p = pytest.param( 287 | ext, stub, id=f"{ext}.{stub.__name__}", marks=pytest.mark.xp_extension(ext) 288 | ) 289 | extension_and_stub_params.append(p) 290 | 291 | 292 | @pytest.mark.parametrize("extension, stub", extension_and_stub_params) 293 | def test_extension_func_signature(extension: str, stub: FunctionType): 294 | mod = getattr(xp, extension) 295 | assert hasattr( 296 | mod, stub.__name__ 297 | ), f"{stub.__name__} not found in {extension} extension" 298 | func = getattr(mod, stub.__name__) 299 | _test_func_signature(func, stub) 300 | 301 | 302 | @pytest.mark.parametrize("stub", array_methods, ids=lambda f: f.__name__) 303 | def test_array_method_signature(stub: FunctionType): 
304 | x_expr = func_to_specified_arg_exprs[stub.__name__]["self"] 305 | try: 306 | x = eval(x_expr, {"xp": xp}) 307 | except Exception as e: 308 | pytest.skip(f"Exception occured when evaluating x={x_expr}: {e}") 309 | assert hasattr(x, stub.__name__), f"{stub.__name__} not found in array object {x!r}" 310 | method = getattr(x, stub.__name__) 311 | _test_func_signature(method, stub, is_method=True) 312 | -------------------------------------------------------------------------------- /array_api_tests/test_statistical_functions.py: -------------------------------------------------------------------------------- 1 | import cmath 2 | import math 3 | from typing import Optional 4 | 5 | import pytest 6 | from hypothesis import assume, given 7 | from hypothesis import strategies as st 8 | 9 | from . import _array_module as xp 10 | from . import dtype_helpers as dh 11 | from . import hypothesis_helpers as hh 12 | from . import pytest_helpers as ph 13 | from . import shape_helpers as sh 14 | from . 
import xps, api_version 15 | from ._array_module import _UndefinedStub 16 | from .typing import DataType 17 | 18 | pytestmark = pytest.mark.ci 19 | 20 | 21 | def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]: 22 | dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype] 23 | if hh.FILTER_UNDEFINED_DTYPES: 24 | dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)] 25 | return st.none() | st.sampled_from(dtypes) 26 | 27 | 28 | @given( 29 | x=xps.arrays( 30 | dtype=xps.real_dtypes(), 31 | shape=hh.shapes(min_side=1), 32 | elements={"allow_nan": False}, 33 | ), 34 | data=st.data(), 35 | ) 36 | def test_max(x, data): 37 | kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") 38 | keepdims = kw.get("keepdims", False) 39 | 40 | out = xp.max(x, **kw) 41 | 42 | ph.assert_dtype("max", in_dtype=x.dtype, out_dtype=out.dtype) 43 | _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) 44 | ph.assert_keepdimable_shape( 45 | "max", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw 46 | ) 47 | scalar_type = dh.get_scalar_type(out.dtype) 48 | for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): 49 | max_ = scalar_type(out[out_idx]) 50 | elements = [] 51 | for idx in indices: 52 | s = scalar_type(x[idx]) 53 | elements.append(s) 54 | expected = max(elements) 55 | ph.assert_scalar_equals("max", type_=scalar_type, idx=out_idx, out=max_, expected=expected) 56 | 57 | 58 | @given( 59 | x=xps.arrays( 60 | dtype=xps.floating_dtypes(), 61 | shape=hh.shapes(min_side=1), 62 | elements={"allow_nan": False}, 63 | ), 64 | data=st.data(), 65 | ) 66 | def test_mean(x, data): 67 | kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") 68 | keepdims = kw.get("keepdims", False) 69 | 70 | out = xp.mean(x, **kw) 71 | 72 | ph.assert_dtype("mean", in_dtype=x.dtype, out_dtype=out.dtype) 73 | _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) 74 
| ph.assert_keepdimable_shape( 75 | "mean", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw 76 | ) 77 | # Values testing mean is too finicky 78 | 79 | 80 | @given( 81 | x=xps.arrays( 82 | dtype=xps.real_dtypes(), 83 | shape=hh.shapes(min_side=1), 84 | elements={"allow_nan": False}, 85 | ), 86 | data=st.data(), 87 | ) 88 | def test_min(x, data): 89 | kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") 90 | keepdims = kw.get("keepdims", False) 91 | 92 | out = xp.min(x, **kw) 93 | 94 | ph.assert_dtype("min", in_dtype=x.dtype, out_dtype=out.dtype) 95 | _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) 96 | ph.assert_keepdimable_shape( 97 | "min", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw 98 | ) 99 | scalar_type = dh.get_scalar_type(out.dtype) 100 | for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): 101 | min_ = scalar_type(out[out_idx]) 102 | elements = [] 103 | for idx in indices: 104 | s = scalar_type(x[idx]) 105 | elements.append(s) 106 | expected = min(elements) 107 | ph.assert_scalar_equals("min", type_=scalar_type, idx=out_idx, out=min_, expected=expected) 108 | 109 | 110 | @given( 111 | x=xps.arrays( 112 | dtype=xps.numeric_dtypes(), 113 | shape=hh.shapes(min_side=1), 114 | elements={"allow_nan": False}, 115 | ), 116 | data=st.data(), 117 | ) 118 | def test_prod(x, data): 119 | kw = data.draw( 120 | hh.kwargs( 121 | axis=hh.axes(x.ndim), 122 | dtype=kwarg_dtypes(x.dtype), 123 | keepdims=st.booleans(), 124 | ), 125 | label="kw", 126 | ) 127 | keepdims = kw.get("keepdims", False) 128 | 129 | with hh.reject_overflow(): 130 | out = xp.prod(x, **kw) 131 | 132 | dtype = kw.get("dtype", None) 133 | if dtype is None: 134 | if dh.is_int_dtype(x.dtype): 135 | if x.dtype in dh.uint_dtypes: 136 | default_dtype = dh.default_uint 137 | else: 138 | default_dtype = dh.default_int 139 | if default_dtype is None: 140 | _dtype = None 141 | else: 142 
| m, M = dh.dtype_ranges[x.dtype] 143 | d_m, d_M = dh.dtype_ranges[default_dtype] 144 | if m < d_m or M > d_M: 145 | _dtype = x.dtype 146 | else: 147 | _dtype = default_dtype 148 | elif dh.is_float_dtype(x.dtype, include_complex=False): 149 | if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: 150 | _dtype = x.dtype 151 | else: 152 | _dtype = dh.default_float 153 | elif api_version > "2021.12": 154 | # Complex dtype 155 | if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_complex]: 156 | _dtype = x.dtype 157 | else: 158 | _dtype = dh.default_complex 159 | else: 160 | raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.") 161 | else: 162 | _dtype = dtype 163 | if _dtype is None: 164 | # If a default uint cannot exist (i.e. in PyTorch which doesn't support 165 | # uint32 or uint64), we skip testing the output dtype. 166 | # See https://github.com/data-apis/array-api-tests/issues/106 167 | if x.dtype in dh.uint_dtypes: 168 | assert dh.is_int_dtype(out.dtype) # sanity check 169 | else: 170 | ph.assert_dtype("prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=_dtype) 171 | _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) 172 | ph.assert_keepdimable_shape( 173 | "prod", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw 174 | ) 175 | scalar_type = dh.get_scalar_type(out.dtype) 176 | for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): 177 | prod = scalar_type(out[out_idx]) 178 | assume(cmath.isfinite(prod)) 179 | elements = [] 180 | for idx in indices: 181 | s = scalar_type(x[idx]) 182 | elements.append(s) 183 | expected = math.prod(elements) 184 | if dh.is_int_dtype(out.dtype): 185 | m, M = dh.dtype_ranges[out.dtype] 186 | assume(m <= expected <= M) 187 | ph.assert_scalar_equals("prod", type_=scalar_type, idx=out_idx, out=prod, expected=expected) 188 | 189 | 190 | @given( 191 | x=xps.arrays( 192 | dtype=xps.floating_dtypes(), 193 | shape=hh.shapes(min_side=1), 194 
| elements={"allow_nan": False}, 195 | ).filter(lambda x: math.prod(x.shape) >= 2), 196 | data=st.data(), 197 | ) 198 | def test_std(x, data): 199 | axis = data.draw(hh.axes(x.ndim), label="axis") 200 | _axes = sh.normalise_axis(axis, x.ndim) 201 | N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) 202 | correction = data.draw( 203 | st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N), 204 | label="correction", 205 | ) 206 | _keepdims = data.draw(st.booleans(), label="keepdims") 207 | kw = data.draw( 208 | hh.specified_kwargs( 209 | ("axis", axis, None), 210 | ("correction", correction, 0.0), 211 | ("keepdims", _keepdims, False), 212 | ), 213 | label="kw", 214 | ) 215 | keepdims = kw.get("keepdims", False) 216 | 217 | out = xp.std(x, **kw) 218 | 219 | ph.assert_dtype("std", in_dtype=x.dtype, out_dtype=out.dtype) 220 | ph.assert_keepdimable_shape( 221 | "std", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw 222 | ) 223 | # We can't easily test the result(s) as standard deviation methods vary a lot 224 | 225 | 226 | @given( 227 | x=xps.arrays( 228 | dtype=xps.numeric_dtypes(), 229 | shape=hh.shapes(min_side=1), 230 | elements={"allow_nan": False}, 231 | ), 232 | data=st.data(), 233 | ) 234 | def test_sum(x, data): 235 | kw = data.draw( 236 | hh.kwargs( 237 | axis=hh.axes(x.ndim), 238 | dtype=kwarg_dtypes(x.dtype), 239 | keepdims=st.booleans(), 240 | ), 241 | label="kw", 242 | ) 243 | keepdims = kw.get("keepdims", False) 244 | 245 | with hh.reject_overflow(): 246 | out = xp.sum(x, **kw) 247 | 248 | dtype = kw.get("dtype", None) 249 | if dtype is None: 250 | if dh.is_int_dtype(x.dtype): 251 | if x.dtype in dh.uint_dtypes: 252 | default_dtype = dh.default_uint 253 | else: 254 | default_dtype = dh.default_int 255 | if default_dtype is None: 256 | _dtype = None 257 | else: 258 | m, M = dh.dtype_ranges[x.dtype] 259 | d_m, d_M = dh.dtype_ranges[default_dtype] 260 | if m < d_m or M > d_M: 261 | _dtype 
= x.dtype 262 | else: 263 | _dtype = default_dtype 264 | elif dh.is_float_dtype(x.dtype, include_complex=False): 265 | if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: 266 | _dtype = x.dtype 267 | else: 268 | _dtype = dh.default_float 269 | elif api_version > "2021.12": 270 | # Complex dtype 271 | if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_complex]: 272 | _dtype = x.dtype 273 | else: 274 | _dtype = dh.default_complex 275 | else: 276 | raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.") 277 | else: 278 | _dtype = dtype 279 | if _dtype is None: 280 | # If a default uint cannot exist (i.e. in PyTorch which doesn't support 281 | # uint32 or uint64), we skip testing the output dtype. 282 | # See https://github.com/data-apis/array-api-tests/issues/160 283 | if x.dtype in dh.uint_dtypes: 284 | assert dh.is_int_dtype(out.dtype) # sanity check 285 | else: 286 | ph.assert_dtype("sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=_dtype) 287 | _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) 288 | ph.assert_keepdimable_shape( 289 | "sum", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw 290 | ) 291 | scalar_type = dh.get_scalar_type(out.dtype) 292 | for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): 293 | sum_ = scalar_type(out[out_idx]) 294 | assume(cmath.isfinite(sum_)) 295 | elements = [] 296 | for idx in indices: 297 | s = scalar_type(x[idx]) 298 | elements.append(s) 299 | expected = sum(elements) 300 | if dh.is_int_dtype(out.dtype): 301 | m, M = dh.dtype_ranges[out.dtype] 302 | assume(m <= expected <= M) 303 | ph.assert_scalar_equals("sum", type_=scalar_type, idx=out_idx, out=sum_, expected=expected) 304 | 305 | 306 | @given( 307 | x=xps.arrays( 308 | dtype=xps.floating_dtypes(), 309 | shape=hh.shapes(min_side=1), 310 | elements={"allow_nan": False}, 311 | ).filter(lambda x: math.prod(x.shape) >= 2), 312 | data=st.data(), 313 | ) 314 | def 
test_var(x, data): 315 | axis = data.draw(hh.axes(x.ndim), label="axis") 316 | _axes = sh.normalise_axis(axis, x.ndim) 317 | N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) 318 | correction = data.draw( 319 | st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N), 320 | label="correction", 321 | ) 322 | _keepdims = data.draw(st.booleans(), label="keepdims") 323 | kw = data.draw( 324 | hh.specified_kwargs( 325 | ("axis", axis, None), 326 | ("correction", correction, 0.0), 327 | ("keepdims", _keepdims, False), 328 | ), 329 | label="kw", 330 | ) 331 | keepdims = kw.get("keepdims", False) 332 | 333 | out = xp.var(x, **kw) 334 | 335 | ph.assert_dtype("var", in_dtype=x.dtype, out_dtype=out.dtype) 336 | ph.assert_keepdimable_shape( 337 | "var", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw 338 | ) 339 | # We can't easily test the result(s) as variance methods vary a lot 340 | -------------------------------------------------------------------------------- /array_api_tests/test_manipulation_functions.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import deque 3 | from typing import Iterable, Iterator, Tuple, Union 4 | 5 | import pytest 6 | from hypothesis import assume, given 7 | from hypothesis import strategies as st 8 | 9 | from . import _array_module as xp 10 | from . import dtype_helpers as dh 11 | from . import hypothesis_helpers as hh 12 | from . import pytest_helpers as ph 13 | from . import shape_helpers as sh 14 | from . 
import xps 15 | from .typing import Array, Shape 16 | 17 | pytestmark = pytest.mark.ci 18 | 19 | MAX_SIDE = hh.MAX_ARRAY_SIZE // 64 20 | MAX_DIMS = min(hh.MAX_ARRAY_SIZE // MAX_SIDE, 32) # NumPy only supports up to 32 dims 21 | 22 | 23 | def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]: 24 | key = "shape" 25 | if args: 26 | key += " " + " ".join(args) 27 | if kwargs: 28 | key += " " + ph.fmt_kw(kwargs) 29 | return st.shared(hh.shapes(*args, **kwargs), key="shape") 30 | 31 | 32 | def assert_array_ndindex( 33 | func_name: str, 34 | x: Array, 35 | *, 36 | x_indices: Iterable[Union[int, Shape]], 37 | out: Array, 38 | out_indices: Iterable[Union[int, Shape]], 39 | kw: dict = {}, 40 | ): 41 | msg_suffix = f" [{func_name}({ph.fmt_kw(kw)})]\n {x=}\n{out=}" 42 | for x_idx, out_idx in zip(x_indices, out_indices): 43 | msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" 44 | msg += msg_suffix 45 | if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): 46 | assert xp.isnan(out[out_idx]), msg 47 | else: 48 | assert out[out_idx] == x[x_idx], msg 49 | 50 | 51 | @given( 52 | dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes), 53 | base_shape=hh.shapes(), 54 | data=st.data(), 55 | ) 56 | def test_concat(dtypes, base_shape, data): 57 | axis_strat = st.none() 58 | ndim = len(base_shape) 59 | if ndim > 0: 60 | axis_strat |= st.integers(-ndim, ndim - 1) 61 | kw = data.draw( 62 | axis_strat.flatmap(lambda a: hh.specified_kwargs(("axis", a, 0))), label="kw" 63 | ) 64 | axis = kw.get("axis", 0) 65 | if axis is None: 66 | _axis = None 67 | shape_strat = hh.shapes() 68 | else: 69 | _axis = axis if axis >= 0 else len(base_shape) + axis 70 | shape_strat = st.integers(0, MAX_SIDE).map( 71 | lambda i: base_shape[:_axis] + (i,) + base_shape[_axis + 1 :] 72 | ) 73 | arrays = [] 74 | for i, dtype in enumerate(dtypes, 1): 75 | x = data.draw(xps.arrays(dtype=dtype, shape=shape_strat), label=f"x{i}") 76 | arrays.append(x) 77 | 78 | out = 
xp.concat(arrays, **kw) 79 | 80 | ph.assert_dtype("concat", in_dtype=dtypes, out_dtype=out.dtype) 81 | 82 | shapes = tuple(x.shape for x in arrays) 83 | if _axis is None: 84 | size = sum(math.prod(s) for s in shapes) 85 | shape = (size,) 86 | else: 87 | shape = list(shapes[0]) 88 | for other_shape in shapes[1:]: 89 | shape[_axis] += other_shape[_axis] 90 | shape = tuple(shape) 91 | ph.assert_result_shape("concat", in_shapes=shapes, out_shape=out.shape, expected=shape, kw=kw) 92 | 93 | if _axis is None: 94 | out_indices = (i for i in range(math.prod(out.shape))) 95 | for x_num, x in enumerate(arrays, 1): 96 | for x_idx in sh.ndindex(x.shape): 97 | out_i = next(out_indices) 98 | ph.assert_0d_equals( 99 | "concat", 100 | x_repr=f"x{x_num}[{x_idx}]", 101 | x_val=x[x_idx], 102 | out_repr=f"out[{out_i}]", 103 | out_val=out[out_i], 104 | kw=kw, 105 | ) 106 | else: 107 | out_indices = sh.ndindex(out.shape) 108 | for idx in sh.axis_ndindex(shapes[0], _axis): 109 | f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) 110 | for x_num, x in enumerate(arrays, 1): 111 | indexed_x = x[idx] 112 | for x_idx in sh.ndindex(indexed_x.shape): 113 | out_idx = next(out_indices) 114 | ph.assert_0d_equals( 115 | "concat", 116 | x_repr=f"x{x_num}[{f_idx}][{x_idx}]", 117 | x_val=indexed_x[x_idx], 118 | out_repr=f"out[{out_idx}]", 119 | out_val=out[out_idx], 120 | kw=kw, 121 | ) 122 | 123 | 124 | @given( 125 | x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), 126 | axis=shared_shapes().flatmap( 127 | # Generate both valid and invalid axis 128 | lambda s: st.integers(2 * (-len(s) - 1), 2 * len(s)) 129 | ), 130 | ) 131 | def test_expand_dims(x, axis): 132 | if axis < -x.ndim - 1 or axis > x.ndim: 133 | with pytest.raises(IndexError): 134 | xp.expand_dims(x, axis=axis) 135 | return 136 | 137 | out = xp.expand_dims(x, axis=axis) 138 | 139 | ph.assert_dtype("expand_dims", in_dtype=x.dtype, out_dtype=out.dtype) 140 | 141 | shape = [side for side in x.shape] 142 | 
index = axis if axis >= 0 else x.ndim + axis + 1 143 | shape.insert(index, 1) 144 | shape = tuple(shape) 145 | ph.assert_result_shape("expand_dims", in_shapes=[x.shape], out_shape=out.shape, expected=shape) 146 | 147 | assert_array_ndindex( 148 | "expand_dims", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape) 149 | ) 150 | 151 | 152 | @given( 153 | x=xps.arrays( 154 | dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1).filter(lambda s: 1 in s) 155 | ), 156 | data=st.data(), 157 | ) 158 | def test_squeeze(x, data): 159 | axes = st.integers(-x.ndim, x.ndim - 1) 160 | axis = data.draw( 161 | axes 162 | | st.lists(axes, unique_by=lambda i: i if i >= 0 else i + x.ndim).map(tuple), 163 | label="axis", 164 | ) 165 | 166 | axes = (axis,) if isinstance(axis, int) else axis 167 | axes = sh.normalise_axis(axes, x.ndim) 168 | 169 | squeezable_axes = [i for i, side in enumerate(x.shape) if side == 1] 170 | if any(i not in squeezable_axes for i in axes): 171 | with pytest.raises(ValueError): 172 | xp.squeeze(x, axis) 173 | return 174 | 175 | out = xp.squeeze(x, axis) 176 | 177 | ph.assert_dtype("squeeze", in_dtype=x.dtype, out_dtype=out.dtype) 178 | 179 | shape = [] 180 | for i, side in enumerate(x.shape): 181 | if i not in axes: 182 | shape.append(side) 183 | shape = tuple(shape) 184 | ph.assert_result_shape("squeeze", in_shapes=[x.shape], out_shape=out.shape, expected=shape, kw=dict(axis=axis)) 185 | 186 | assert_array_ndindex("squeeze", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape)) 187 | 188 | 189 | @given( 190 | x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), 191 | data=st.data(), 192 | ) 193 | def test_flip(x, data): 194 | if x.ndim == 0: 195 | axis_strat = st.none() 196 | else: 197 | axis_strat = ( 198 | st.none() | st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim) 199 | ) 200 | kw = data.draw(hh.kwargs(axis=axis_strat), label="kw") 201 | 202 | out = xp.flip(x, **kw) 203 | 204 | 
ph.assert_dtype("flip", in_dtype=x.dtype, out_dtype=out.dtype) 205 | 206 | _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) 207 | for indices in sh.axes_ndindex(x.shape, _axes): 208 | reverse_indices = indices[::-1] 209 | assert_array_ndindex("flip", x, x_indices=indices, out=out, 210 | out_indices=reverse_indices, kw=kw) 211 | 212 | 213 | @given( 214 | x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes(min_dims=1)), 215 | axes=shared_shapes(min_dims=1).flatmap( 216 | lambda s: st.lists( 217 | st.integers(0, len(s) - 1), 218 | min_size=len(s), 219 | max_size=len(s), 220 | unique=True, 221 | ).map(tuple) 222 | ), 223 | ) 224 | def test_permute_dims(x, axes): 225 | out = xp.permute_dims(x, axes) 226 | 227 | ph.assert_dtype("permute_dims", in_dtype=x.dtype, out_dtype=out.dtype) 228 | 229 | shape = [None for _ in range(len(axes))] 230 | for i, dim in enumerate(axes): 231 | side = x.shape[dim] 232 | shape[i] = side 233 | shape = tuple(shape) 234 | ph.assert_result_shape("permute_dims", in_shapes=[x.shape], out_shape=out.shape, expected=shape, kw=dict(axes=axes)) 235 | 236 | indices = list(sh.ndindex(x.shape)) 237 | permuted_indices = [tuple(idx[axis] for axis in axes) for idx in indices] 238 | assert_array_ndindex("permute_dims", x, x_indices=indices, out=out, 239 | out_indices=permuted_indices) 240 | 241 | 242 | @st.composite 243 | def reshape_shapes(draw, shape): 244 | size = 1 if len(shape) == 0 else math.prod(shape) 245 | rshape = draw(st.lists(st.integers(0)).filter(lambda s: math.prod(s) == size)) 246 | assume(all(side <= MAX_SIDE for side in rshape)) 247 | if len(rshape) != 0 and size > 0 and draw(st.booleans()): 248 | index = draw(st.integers(0, len(rshape) - 1)) 249 | rshape[index] = -1 250 | return tuple(rshape) 251 | 252 | 253 | @given( 254 | x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(max_side=MAX_SIDE)), 255 | data=st.data(), 256 | ) 257 | def test_reshape(x, data): 258 | shape = data.draw(reshape_shapes(x.shape)) 259 | 260 | out 
= xp.reshape(x, shape) 261 | 262 | ph.assert_dtype("reshape", in_dtype=x.dtype, out_dtype=out.dtype) 263 | 264 | _shape = list(shape) 265 | if any(side == -1 for side in shape): 266 | size = math.prod(x.shape) 267 | rsize = math.prod(shape) * -1 268 | _shape[shape.index(-1)] = size / rsize 269 | _shape = tuple(_shape) 270 | ph.assert_result_shape("reshape", in_shapes=[x.shape], out_shape=out.shape, expected=_shape, kw=dict(shape=shape)) 271 | 272 | assert_array_ndindex("reshape", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape)) 273 | 274 | 275 | def roll_ndindex(shape: Shape, shifts: Tuple[int], axes: Tuple[int]) -> Iterator[Shape]: 276 | assert len(shifts) == len(axes) # sanity check 277 | all_shifts = [0 for _ in shape] 278 | for s, a in zip(shifts, axes): 279 | all_shifts[a] = s 280 | for idx in sh.ndindex(shape): 281 | yield tuple((i + sh) % si for i, sh, si in zip(idx, all_shifts, shape)) 282 | 283 | 284 | @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), st.data()) 285 | def test_roll(x, data): 286 | shift_strat = st.integers(-hh.MAX_ARRAY_SIZE, hh.MAX_ARRAY_SIZE) 287 | if x.ndim > 0: 288 | shift_strat = shift_strat | st.lists( 289 | shift_strat, min_size=1, max_size=x.ndim 290 | ).map(tuple) 291 | shift = data.draw(shift_strat, label="shift") 292 | if isinstance(shift, tuple): 293 | axis_strat = xps.valid_tuple_axes(x.ndim).filter(lambda t: len(t) == len(shift)) 294 | kw_strat = axis_strat.map(lambda t: {"axis": t}) 295 | else: 296 | axis_strat = st.none() 297 | if x.ndim != 0: 298 | axis_strat |= st.integers(-x.ndim, x.ndim - 1) 299 | kw_strat = hh.kwargs(axis=axis_strat) 300 | kw = data.draw(kw_strat, label="kw") 301 | 302 | out = xp.roll(x, shift, **kw) 303 | 304 | kw = {"shift": shift, **kw} # for error messages 305 | 306 | ph.assert_dtype("roll", in_dtype=x.dtype, out_dtype=out.dtype) 307 | 308 | ph.assert_result_shape("roll", in_shapes=[x.shape], out_shape=out.shape, kw=kw) 309 | 310 | if kw.get("axis", 
None) is None: 311 | assert isinstance(shift, int) # sanity check 312 | indices = list(sh.ndindex(x.shape)) 313 | shifted_indices = deque(indices) 314 | shifted_indices.rotate(-shift) 315 | assert_array_ndindex("roll", x, x_indices=indices, out=out, out_indices=shifted_indices, kw=kw) 316 | else: 317 | shifts = (shift,) if isinstance(shift, int) else shift 318 | axes = sh.normalise_axis(kw["axis"], x.ndim) 319 | shifted_indices = roll_ndindex(x.shape, shifts, axes) 320 | assert_array_ndindex("roll", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=shifted_indices, kw=kw) 321 | 322 | 323 | @given( 324 | shape=shared_shapes(min_dims=1), 325 | dtypes=hh.mutually_promotable_dtypes(None), 326 | kw=hh.kwargs( 327 | axis=shared_shapes(min_dims=1).flatmap( 328 | lambda s: st.integers(-len(s), len(s) - 1) 329 | ) 330 | ), 331 | data=st.data(), 332 | ) 333 | def test_stack(shape, dtypes, kw, data): 334 | arrays = [] 335 | for i, dtype in enumerate(dtypes, 1): 336 | x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") 337 | arrays.append(x) 338 | 339 | out = xp.stack(arrays, **kw) 340 | 341 | ph.assert_dtype("stack", in_dtype=dtypes, out_dtype=out.dtype) 342 | 343 | axis = kw.get("axis", 0) 344 | _axis = axis if axis >= 0 else len(shape) + axis + 1 345 | _shape = list(shape) 346 | _shape.insert(_axis, len(arrays)) 347 | _shape = tuple(_shape) 348 | ph.assert_result_shape( 349 | "stack", in_shapes=tuple(x.shape for x in arrays), out_shape=out.shape, expected=_shape, kw=kw 350 | ) 351 | 352 | out_indices = sh.ndindex(out.shape) 353 | for idx in sh.axis_ndindex(arrays[0].shape, axis=_axis): 354 | f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) 355 | for x_num, x in enumerate(arrays, 1): 356 | indexed_x = x[idx] 357 | for x_idx in sh.ndindex(indexed_x.shape): 358 | out_idx = next(out_indices) 359 | ph.assert_0d_equals( 360 | "stack", 361 | x_repr=f"x{x_num}[{f_idx}][{x_idx}]", 362 | x_val=indexed_x[x_idx], 363 | 
out_repr=f"out[{out_idx}]", 364 | out_val=out[out_idx], 365 | kw=kw, 366 | ) 367 | -------------------------------------------------------------------------------- /array_api_tests/pytest_helpers.py: -------------------------------------------------------------------------------- 1 | import cmath 2 | import math 3 | from inspect import getfullargspec 4 | from typing import Any, Dict, Optional, Sequence, Tuple, Union 5 | 6 | from . import _array_module as xp 7 | from . import dtype_helpers as dh 8 | from . import shape_helpers as sh 9 | from . import stubs 10 | from .typing import Array, DataType, Scalar, ScalarType, Shape 11 | 12 | __all__ = [ 13 | "raises", 14 | "doesnt_raise", 15 | "nargs", 16 | "fmt_kw", 17 | "is_pos_zero", 18 | "is_neg_zero", 19 | "assert_dtype", 20 | "assert_kw_dtype", 21 | "assert_default_float", 22 | "assert_default_int", 23 | "assert_default_index", 24 | "assert_shape", 25 | "assert_result_shape", 26 | "assert_keepdimable_shape", 27 | "assert_0d_equals", 28 | "assert_fill", 29 | "assert_array_elements", 30 | ] 31 | 32 | 33 | def raises(exceptions, function, message=""): 34 | """ 35 | Like pytest.raises() except it allows custom error messages 36 | """ 37 | try: 38 | function() 39 | except exceptions: 40 | return 41 | except Exception as e: 42 | if message: 43 | raise AssertionError( 44 | f"Unexpected exception {e!r} (expected {exceptions}): {message}" 45 | ) 46 | raise AssertionError(f"Unexpected exception {e!r} (expected {exceptions})") 47 | raise AssertionError(message) 48 | 49 | 50 | def doesnt_raise(function, message=""): 51 | """ 52 | The inverse of raises(). 53 | 54 | Use doesnt_raise(function) to test that function() doesn't raise any 55 | exceptions. Returns the result of calling function. 
56 | """ 57 | if not callable(function): 58 | raise ValueError("doesnt_raise should take a lambda") 59 | try: 60 | return function() 61 | except Exception as e: 62 | if message: 63 | raise AssertionError(f"Unexpected exception {e!r}: {message}") 64 | raise AssertionError(f"Unexpected exception {e!r}") 65 | 66 | 67 | def nargs(func_name): 68 | return len(getfullargspec(stubs.name_to_func[func_name]).args) 69 | 70 | 71 | def fmt_kw(kw: Dict[str, Any]) -> str: 72 | return ", ".join(f"{k}={v}" for k, v in kw.items()) 73 | 74 | 75 | def is_pos_zero(n: float) -> bool: 76 | return n == 0 and math.copysign(1, n) == 1 77 | 78 | 79 | def is_neg_zero(n: float) -> bool: 80 | return n == 0 and math.copysign(1, n) == -1 81 | 82 | 83 | def assert_dtype( 84 | func_name: str, 85 | *, 86 | in_dtype: Union[DataType, Sequence[DataType]], 87 | out_dtype: DataType, 88 | expected: Optional[DataType] = None, 89 | repr_name: str = "out.dtype", 90 | ): 91 | """ 92 | Assert the output dtype is as expected. 93 | 94 | If expected=None, we infer the expected dtype as in_dtype, to test 95 | out_dtype, e.g. 96 | 97 | >>> x = xp.arange(5, dtype=xp.uint8) 98 | >>> out = xp.abs(x) 99 | >>> assert_dtype('abs', in_dtype=x.dtype, out_dtype=out.dtype) 100 | 101 | is equivalent to 102 | 103 | >>> assert out.dtype == xp.uint8 104 | 105 | Or for multiple input dtypes, the expected dtype is inferred from their 106 | resulting type promotion, e.g. 107 | 108 | >>> x1 = xp.arange(5, dtype=xp.uint8) 109 | >>> x2 = xp.arange(5, dtype=xp.uint16) 110 | >>> out = xp.add(x1, x2) 111 | >>> assert_dtype('add', in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) 112 | 113 | is equivalent to 114 | 115 | >>> assert out.dtype == xp.uint16 116 | 117 | We can also specify the expected dtype ourselves, e.g. 
118 | 119 | >>> x = xp.arange(5, dtype=xp.int8) 120 | >>> out = xp.sum(x) 121 | >>> default_int = xp.asarray(0).dtype 122 | >>> assert_dtype('sum', in_dtype=x.dtype, out_dtype=out.dtype, expected=default_int) 123 | 124 | """ 125 | __tracebackhide__ = True 126 | in_dtypes = in_dtype if isinstance(in_dtype, Sequence) and not isinstance(in_dtype, str) else [in_dtype] # a str is itself a Sequence, so treat it as one dtype rather than iterating its characters 127 | f_in_dtypes = dh.fmt_types(tuple(in_dtypes)) 128 | f_out_dtype = dh.dtype_to_name[out_dtype] 129 | if expected is None: 130 | expected = dh.result_type(*in_dtypes) # infer the expected dtype by promoting the input dtypes 131 | f_expected = dh.dtype_to_name[expected] 132 | msg = ( 133 | f"{repr_name}={f_out_dtype}, but should be {f_expected} " 134 | f"[{func_name}({f_in_dtypes})]" 135 | ) 136 | assert out_dtype == expected, msg 137 | 138 | 139 | def assert_kw_dtype( 140 | func_name: str, 141 | *, 142 | kw_dtype: DataType, 143 | out_dtype: DataType, 144 | ): 145 | """ 146 | Assert the output dtype is the passed keyword dtype, e.g. 147 | 148 | >>> kw = {'dtype': xp.uint8} 149 | >>> out = xp.ones(5, **kw) 150 | >>> assert_kw_dtype('ones', kw_dtype=kw['dtype'], out_dtype=out.dtype) 151 | 152 | """ 153 | __tracebackhide__ = True 154 | f_kw_dtype = dh.dtype_to_name[kw_dtype] 155 | f_out_dtype = dh.dtype_to_name[out_dtype] 156 | msg = ( 157 | f"out.dtype={f_out_dtype}, but should be {f_kw_dtype} " 158 | f"[{func_name}(dtype={f_kw_dtype})]" 159 | ) 160 | assert out_dtype == kw_dtype, msg 161 | 162 | 163 | def assert_default_float(func_name: str, out_dtype: DataType): 164 | """ 165 | Assert the output dtype is the default float, e.g.
166 | 167 | >>> out = xp.ones(5) 168 | >>> assert_default_float('ones', out.dtype) 169 | 170 | """ 171 | __tracebackhide__ = True 172 | f_dtype = dh.dtype_to_name[out_dtype] 173 | f_default = dh.dtype_to_name[dh.default_float] 174 | msg = ( 175 | f"out.dtype={f_dtype}, should be default " 176 | f"floating-point dtype {f_default} [{func_name}()]" 177 | ) 178 | assert out_dtype == dh.default_float, msg 179 | 180 | 181 | def assert_default_complex(func_name: str, out_dtype: DataType): 182 | """ 183 | Assert the output dtype is the default complex, e.g. 184 | 185 | >>> out = xp.asarray(4+2j) 186 | >>> assert_default_complex('asarray', out.dtype) 187 | 188 | """ 189 | __tracebackhide__ = True 190 | f_dtype = dh.dtype_to_name[out_dtype] 191 | f_default = dh.dtype_to_name[dh.default_complex] 192 | msg = ( 193 | f"out.dtype={f_dtype}, should be default " 194 | f"complex dtype {f_default} [{func_name}()]" 195 | ) 196 | assert out_dtype == dh.default_complex, msg 197 | 198 | 199 | def assert_default_int(func_name: str, out_dtype: DataType): 200 | """ 201 | Assert the output dtype is the default int, e.g. 202 | 203 | >>> out = xp.full(5, 42) 204 | >>> assert_default_int('full', out.dtype) 205 | 206 | """ 207 | __tracebackhide__ = True 208 | f_dtype = dh.dtype_to_name[out_dtype] 209 | f_default = dh.dtype_to_name[dh.default_int] 210 | msg = ( 211 | f"out.dtype={f_dtype}, should be default " 212 | f"integer dtype {f_default} [{func_name}()]" 213 | ) 214 | assert out_dtype == dh.default_int, msg 215 | 216 | 217 | def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dtype"): 218 | """ 219 | Assert the output dtype is the default index dtype, e.g. 
220 | 221 | >>> out = xp.argmax(xp.arange(5)) 222 | >>> assert_default_index('argmax', out.dtype) 223 | 224 | """ 225 | __tracebackhide__ = True 226 | f_dtype = dh.dtype_to_name[out_dtype] 227 | msg = ( 228 | f"{repr_name}={f_dtype}, should be the default index dtype, " 229 | f"which is either int32 or int64 [{func_name}()]" 230 | ) 231 | assert out_dtype in (xp.int32, xp.int64), msg # either width is accepted, so the library's default index dtype is not pinned 232 | 233 | 234 | def assert_shape( 235 | func_name: str, 236 | *, 237 | out_shape: Union[int, Shape], 238 | expected: Union[int, Shape], 239 | repr_name="out.shape", 240 | kw: dict = {}, 241 | ): 242 | """ 243 | Assert the output shape is as expected, e.g. 244 | 245 | >>> out = xp.ones((3, 3, 3)) 246 | >>> assert_shape('ones', out_shape=out.shape, expected=(3, 3, 3)) 247 | 248 | """ 249 | __tracebackhide__ = True 250 | if isinstance(out_shape, int): # accept a bare int as shorthand for a 1-d shape 251 | out_shape = (out_shape,) 252 | if isinstance(expected, int): 253 | expected = (expected,) 254 | msg = ( 255 | f"{repr_name}={out_shape}, but should be {expected} [{func_name}({fmt_kw(kw)})]" 256 | ) 257 | assert out_shape == expected, msg 258 | 259 | 260 | def assert_result_shape( 261 | func_name: str, 262 | in_shapes: Sequence[Shape], 263 | out_shape: Shape, 264 | expected: Optional[Shape] = None, 265 | *, 266 | repr_name="out.shape", 267 | kw: dict = {}, 268 | ): 269 | """ 270 | Assert the output shape is as expected. 271 | 272 | If expected=None, we infer the expected shape as the result of broadcasting 273 | in_shapes, to test against out_shape, e.g. 274 | 275 | >>> out = xp.add(xp.ones((3, 1)), xp.ones((1, 3))) 276 | >>> assert_result_shape('add', in_shapes=[(3, 1), (1, 3)], out_shape=out.shape) 277 | 278 | is equivalent to 279 | 280 | >>> assert out.shape == (3, 3) 281 | 282 | """ 283 | __tracebackhide__ = True 284 | if expected is None: 285 | expected = sh.broadcast_shapes(*in_shapes) # default expectation: the broadcast of all input shapes 286 | f_in_shapes = " .
".join(str(s) for s in in_shapes) 287 | f_sig = f" {f_in_shapes} " 288 | if kw: 289 | f_sig += f", {fmt_kw(kw)}" 290 | msg = f"{repr_name}={out_shape}, but should be {expected} [{func_name}({f_sig})]" 291 | assert out_shape == expected, msg 292 | 293 | 294 | def assert_keepdimable_shape( 295 | func_name: str, 296 | *, 297 | in_shape: Shape, 298 | out_shape: Shape, 299 | axes: Tuple[int, ...], 300 | keepdims: bool, 301 | kw: dict = {}, 302 | ): 303 | """ 304 | Assert the output shape from a keepdimable function is as expected, e.g. 305 | 306 | >>> x = xp.asarray([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) 307 | >>> out1 = xp.max(x, keepdims=False) 308 | >>> out2 = xp.max(x, keepdims=True) 309 | >>> assert_keepdimable_shape('max', in_shape=x.shape, out_shape=out1.shape, axes=(0, 1), keepdims=False) 310 | >>> assert_keepdimable_shape('max', in_shape=x.shape, out_shape=out2.shape, axes=(0, 1), keepdims=True) 311 | 312 | is equivalent to 313 | 314 | >>> assert out1.shape == () 315 | >>> assert out2.shape == (1, 1) 316 | 317 | """ 318 | __tracebackhide__ = True 319 | if keepdims: 320 | shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape)) 321 | else: 322 | shape = tuple(side for axis, side in enumerate(in_shape) if axis not in axes) 323 | assert_shape(func_name, out_shape=out_shape, expected=shape, kw=kw) 324 | 325 | 326 | def assert_0d_equals( 327 | func_name: str, 328 | *, 329 | x_repr: str, 330 | x_val: Array, 331 | out_repr: str, 332 | out_val: Array, 333 | kw: dict = {}, 334 | ): 335 | """ 336 | Assert a 0d array is as expected, e.g. 
337 | 338 | >>> x = xp.asarray([0, 1, 2]) 339 | >>> kw = {'copy': True} 340 | >>> res = xp.asarray(x, **kw) 341 | >>> res[0] = 42 342 | >>> assert_0d_equals('asarray', x_repr='x[0]', x_val=x[0], out_repr='x[0]', out_val=res[0], kw=kw) 343 | 344 | is equivalent to 345 | 346 | >>> assert res[0] == x[0] 347 | 348 | """ 349 | __tracebackhide__ = True 350 | msg = ( 351 | f"{out_repr}={out_val}, but should be {x_repr}={x_val} " 352 | f"[{func_name}({fmt_kw(kw)})]" 353 | ) 354 | if dh.is_float_dtype(out_val.dtype) and xp.isnan(out_val): 355 | assert xp.isnan(x_val), msg 356 | else: 357 | assert x_val == out_val, msg 358 | 359 | 360 | def assert_scalar_equals( 361 | func_name: str, 362 | *, 363 | type_: ScalarType, 364 | idx: Shape, 365 | out: Scalar, 366 | expected: Scalar, 367 | repr_name: str = "out", 368 | kw: dict = {}, 369 | ): 370 | """ 371 | Assert a 0d array, convered to a scalar, is as expected, e.g. 372 | 373 | >>> x = xp.ones(5, dtype=xp.uint8) 374 | >>> out = xp.sum(x) 375 | >>> assert_scalar_equals('sum', type_int, out=(), out=int(out), expected=5) 376 | 377 | is equivalent to 378 | 379 | >>> assert int(out) == 5 380 | 381 | """ 382 | __tracebackhide__ = True 383 | repr_name = repr_name if idx == () else f"{repr_name}[{idx}]" 384 | f_func = f"{func_name}({fmt_kw(kw)})" 385 | if type_ in [bool, int]: 386 | msg = f"{repr_name}={out}, but should be {expected} [{f_func}]" 387 | assert out == expected, msg 388 | elif cmath.isnan(expected): 389 | msg = f"{repr_name}={out}, but should be {expected} [{f_func}]" 390 | assert cmath.isnan(out), msg 391 | else: 392 | msg = f"{repr_name}={out}, but should be roughly {expected} [{f_func}]" 393 | assert cmath.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg 394 | 395 | 396 | def assert_fill( 397 | func_name: str, 398 | *, 399 | fill_value: Scalar, 400 | dtype: DataType, 401 | out: Array, 402 | kw: dict = {}, 403 | ): 404 | """ 405 | Assert all elements of an array is as expected, e.g. 
406 | 407 | >>> out = xp.full(5, 42, dtype=xp.uint8) 408 | >>> assert_fill('full', fill_value=42, dtype=xp.uint8, out=out, kw=dict(shape=5)) 409 | 410 | is equivalent to 411 | 412 | >>> assert xp.all(out == 42) 413 | 414 | """ 415 | __tracebackhide__ = True 416 | msg = f"out not filled with {fill_value} [{func_name}({fmt_kw(kw)})]\n{out=}" 417 | if cmath.isnan(fill_value): 418 | assert xp.all(xp.isnan(out)), msg 419 | else: 420 | assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg 421 | 422 | 423 | def _assert_float_element(at_out: Array, at_expected: Array, msg: str): 424 | if xp.isnan(at_expected): 425 | assert xp.isnan(at_out), msg 426 | elif at_expected == 0.0 or at_expected == -0.0: 427 | scalar_at_expected = float(at_expected) 428 | scalar_at_out = float(at_out) 429 | if is_pos_zero(scalar_at_expected): 430 | assert is_pos_zero(scalar_at_out), msg 431 | else: 432 | assert is_neg_zero(scalar_at_expected) # sanity check 433 | assert is_neg_zero(scalar_at_out), msg 434 | else: 435 | assert at_out == at_expected, msg 436 | 437 | 438 | def assert_array_elements( 439 | func_name: str, 440 | *, 441 | out: Array, 442 | expected: Array, 443 | out_repr: str = "out", 444 | kw: dict = {}, 445 | ): 446 | """ 447 | Assert array elements are (strictly) as expected, e.g. 
448 | 449 | >>> x = xp.arange(5) 450 | >>> out = xp.asarray(x) 451 | >>> assert_array_elements('asarray', out=out, expected=x) 452 | 453 | is equivalent to 454 | 455 | >>> assert xp.all(out == x) 456 | 457 | """ 458 | __tracebackhide__ = True 459 | dh.result_type(out.dtype, expected.dtype) # sanity check 460 | assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check 461 | f_func = f"[{func_name}({fmt_kw(kw)})]" 462 | if out.dtype in dh.real_float_dtypes: 463 | for idx in sh.ndindex(out.shape): 464 | at_out = out[idx] 465 | at_expected = expected[idx] 466 | msg = ( 467 | f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} " 468 | f"{f_func}" 469 | ) 470 | _assert_float_element(at_out, at_expected, msg) 471 | elif out.dtype in dh.complex_dtypes: 472 | assert (out.dtype in dh.complex_dtypes) == (expected.dtype in dh.complex_dtypes) 473 | for idx in sh.ndindex(out.shape): 474 | at_out = out[idx] 475 | at_expected = expected[idx] 476 | msg = ( 477 | f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} " 478 | f"{f_func}" 479 | ) 480 | _assert_float_element(xp.real(at_out), xp.real(at_expected), msg) 481 | _assert_float_element(xp.imag(at_out), xp.imag(at_expected), msg) 482 | else: 483 | assert xp.all( 484 | out == expected 485 | ), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}" 486 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Test Suite for Array API Compliance 2 | 3 | This is the test suite for array libraries adopting the [Python Array API 4 | standard](https://data-apis.org/array-api/latest). 5 | 6 | Note the suite is still a **work in progress**. Feedback and contributions are 7 | welcome! 
8 | 9 | ## Quickstart 10 | 11 | ### Setup 12 | 13 | Currently we pin the Array API specification repo [`array-api`](https://github.com/data-apis/array-api/) 14 | as a git submodule. This might change in the future to better support vendoring 15 | use cases (see [#107](https://github.com/data-apis/array-api-tests/issues/107)), 16 | but for now be sure submodules are pulled too, e.g. 17 | 18 | ```bash 19 | $ git submodule update --init 20 | ``` 21 | 22 | To run the tests, install the testing dependencies. 23 | 24 | ```bash 25 | $ pip install -r requirements.txt 26 | ``` 27 | 28 | Ensure you have the array library that you want to test installed. 29 | 30 | ### Specifying the array module 31 | 32 | You need to specify the array library to test. It can be specified via the 33 | `ARRAY_API_TESTS_MODULE` environment variable, e.g. 34 | 35 | ```bash 36 | $ export ARRAY_API_TESTS_MODULE=numpy.array_api 37 | ``` 38 | 39 | Alternatively, import/define the `xp` variable in `array_api_tests/__init__.py`. 40 | 41 | ### Run the suite 42 | 43 | Simply run `pytest` against the `array_api_tests/` folder to run the full suite. 44 | 45 | ```bash 46 | $ pytest array_api_tests/ 47 | ``` 48 | 49 | The suite tries to logically organise its tests. `pytest` allows you to only run 50 | a specific test case, which is useful when developing functions. 51 | 52 | ```bash 53 | $ pytest array_api_tests/test_creation_functions.py::test_zeros 54 | ``` 55 | 56 | ## What the test suite covers 57 | 58 | We are interested in array libraries conforming to the 59 | [spec](https://data-apis.org/array-api/latest/API_specification/index.html). 60 | Ideally this means that if a library has fully adopted the Array API, the test 61 | suite passes. We take great care to _not_ test things which are out-of-scope, 62 | so as to not unexpectedly fail the suite. 63 | 64 | ### Primary tests 65 | 66 | Every function—including array object methods—has a respective test 67 | method<sup>1</sup>.
We use 68 | [Hypothesis](https://hypothesis.readthedocs.io/en/latest/) 69 | to generate a diverse set of valid inputs. This means array inputs will cover 70 | different dtypes and shapes, as well as contain interesting elements. These 71 | examples are also generated with interesting arrangements of non-array positional 72 | arguments and keyword arguments. 73 | 74 | Each test case will cover the following areas if relevant: 75 | 76 | * **Smoking**: We pass our generated examples to all functions. As these 77 | examples solely consist of *valid* inputs, we are testing that functions can 78 | be called using their documented inputs without raising errors. 79 | 80 | * **Data type**: For functions returning/modifying arrays, we assert that output 81 | arrays have the correct data types. Most functions 82 | [type-promote](https://data-apis.org/array-api/latest/API_specification/type_promotion.html) 83 | input arrays and some functions have bespoke rules—in both cases we simulate 84 | the correct behaviour to find the expected data types. 85 | 86 | * **Shape**: For functions returning/modifying arrays, we assert that output 87 | arrays have the correct shape. Most functions 88 | [broadcast](https://data-apis.org/array-api/latest/API_specification/broadcasting.html) 89 | input arrays and some functions have bespoke rules—in both cases we simulate 90 | the correct behaviour to find the expected shapes. 91 | 92 | * **Values**: We assert output values (including the elements of 93 | returned/modified arrays) are as expected. Except for manipulation functions 94 | or special cases, the spec allows floating-point inputs to have inexact 95 | outputs, so with such examples we only assert values are roughly as expected. 96 | 97 | ### Additional tests 98 | 99 | In addition to having one test case for each function, we test other properties 100 | of the functions and some miscellaneous things.
101 | 102 | * **Special cases**: For functions with special case behaviour, we assert that 103 | these functions return the correct values. 104 | 105 | * **Signatures**: We assert functions have the correct signatures. 106 | 107 | * **Constants**: We assert that 108 | [constants](https://data-apis.org/array-api/latest/API_specification/constants.html) 109 | behave as expected, are roughly the expected value, and that any related 110 | functions interact with them correctly. 111 | 112 | Be aware that some aspects of the spec are impractical or impossible to actually 113 | test, so they are not covered in the suite. 114 | 115 | ## Interpreting errors 116 | 117 | First and foremost, note that most tests have to assume that certain aspects of 118 | the Array API have been correctly adopted, as fundamental APIs such as array 119 | creation and equalities are hard requirements for many assertions. This means a 120 | test case for one function might fail because another function has bugs or even 121 | no implementation. 122 | 123 | This means adopting libraries at first will result in a vast number of errors 124 | due to cascading errors. Generally the nature of the spec means many granular 125 | details such as type promotion are likely to also fail nearly-conforming 126 | functions. 127 | 128 | We hope to improve user experience with regard to "noisy" errors in 129 | [#51](https://github.com/data-apis/array-api-tests/issues/51). For now, if an 130 | error message involves `_UndefinedStub`, it means an attribute of the array 131 | library (including functions) and its objects (e.g. the array) is missing. 132 | 133 | The spec is the suite's source of truth. If the suite appears to assume 134 | behaviour different from the spec, or test something that is not documented, 135 | this is a bug—please [report such 136 | issues](https://github.com/data-apis/array-api-tests/issues/) to us.
137 | 138 | 139 | ## Running on CI 140 | 141 | See our existing [GitHub Actions workflow for 142 | Numpy](https://github.com/data-apis/array-api-tests/blob/master/.github/workflows/numpy.yml) 143 | for an example of using the test suite on CI. 144 | 145 | ### Releases 146 | 147 | We recommend pinning against a [release tag](https://github.com/data-apis/array-api-tests/releases) 148 | when running on CI. 149 | 150 | We use [calendar versioning](https://calver.org/) for the releases. You should 151 | expect that any version may be "breaking" compared to the previous one, in that 152 | new tests (or improvements to existing tests) may cause a previously passing 153 | library to fail. 154 | 155 | ### Configuration 156 | 157 | #### API version 158 | 159 | You can specify the API version to use when testing via the 160 | `ARRAY_API_TESTS_VERSION` environment variable. Currently this defaults to the 161 | array module's `__array_api_version__` value, and if that attribute doesn't 162 | exist then we fall back to `"2021.12"`. 163 | 164 | #### CI flag 165 | 166 | Use the `--ci` flag to run only the primary and special cases tests. You can 167 | ignore the other test cases as they are redundant for the purposes of checking 168 | compliance. 169 | 170 | #### Data-dependent shapes 171 | 172 | Use the `--disable-data-dependent-shapes` flag to skip testing functions which have 173 | [data-dependent shapes](https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html). 174 | 175 | #### Extensions 176 | 177 | By default, tests for the optional Array API extensions such as 178 | [`linalg`](https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html) 179 | will be skipped if not present in the specified array module. You can purposely 180 | skip testing extension(s) via the `--disable-extension` option. 181 | 182 | #### Skip or XFAIL test cases 183 | 184 | Test cases you want to skip can be specified in a skips or xfails file.
The 185 | difference between skip and XFAIL is that XFAIL tests are still run and 186 | reported as XPASS if they pass. 187 | 188 | By default, the skips and xfails files are `skips.txt` and `xfails.txt` in the root 189 | of this repository, but any file can be specified with the `--skips-file` and 190 | `--xfails-file` command line flags. 191 | 192 | The files should list the test ids to be skipped/xfailed. Empty lines and 193 | lines starting with `#` are ignored. The test id can be any substring of the 194 | test ids to skip/xfail. 195 | 196 | ``` 197 | # skips.txt or xfails.txt 198 | # Line comments can be denoted with the hash symbol (#) 199 | 200 | # Skip specific test case, e.g. when argsort() does not respect relative order 201 | # https://github.com/numpy/numpy/issues/20778 202 | array_api_tests/test_sorting_functions.py::test_argsort 203 | 204 | # Skip specific test case parameter, e.g. you forgot to implement in-place adds 205 | array_api_tests/test_add[__iadd__(x1, x2)] 206 | array_api_tests/test_add[__iadd__(x, s)] 207 | 208 | # Skip module, e.g. when your set functions treat NaNs as non-distinct 209 | # https://github.com/numpy/numpy/issues/20326 210 | array_api_tests/test_set_functions.py 211 | ``` 212 | 213 | Here is an example GitHub Actions workflow file, where the xfails are stored 214 | in `array-api-tests-xfails.txt` in the base of the `your-array-library` repo. 215 | 216 | If you want, you can use `-o xfail_strict=True`, which causes XPASS tests (XFAIL 217 | tests that actually pass) to fail the test suite. However, be aware that 218 | XFAILures can be flaky (see below), so this may not be a good idea unless you 219 | use some other mitigation of such flakiness. 220 | 221 | If you don't want this behavior, you can remove it, or use `--skips-file` 222 | instead of `--xfails-file`.
223 | 224 | ```yaml 225 | # ./.github/workflows/array_api.yml 226 | jobs: 227 | tests: 228 | runs-on: ubuntu-latest 229 | strategy: 230 | matrix: 231 | python-version: ['3.8', '3.9', '3.10', '3.11'] 232 | 233 | steps: 234 | - name: Checkout 235 | uses: actions/checkout@v3 236 | with: 237 | path: your-array-library 238 | 239 | - name: Checkout array-api-tests 240 | uses: actions/checkout@v3 241 | with: 242 | repository: data-apis/array-api-tests 243 | submodules: 'true' 244 | path: array-api-tests 245 | 246 | - name: Run the array API test suite 247 | env: 248 | ARRAY_API_TESTS_MODULE: your.array.api.namespace 249 | run: | 250 | export PYTHONPATH="${GITHUB_WORKSPACE}/your-array-library" 251 | cd ${GITHUB_WORKSPACE}/array-api-tests 252 | pytest -v -rxXfE --ci --xfails-file ${GITHUB_WORKSPACE}/your-array-library/array-api-tests-xfails.txt array_api_tests/ 253 | ``` 254 | 255 | > **Warning** 256 | > 257 | > XFAIL tests that use Hypothesis (basically every test in the test suite except 258 | > those in test_has_names.py) can be flaky, due to the fact that Hypothesis 259 | > might not always run the test with an input that causes the test to fail. 260 | > There are several ways to avoid this problem: 261 | > 262 | > - Increase the maximum number of examples, e.g., by adding `--max-examples 263 | > 200` to the test command (the default is `100`, see below). This will 264 | > make it more likely that the failing case will be found, but it will also 265 | > make the tests take longer to run. 266 | > - Don't use `-o xfail_strict=True`. This will make it so that if an XFAIL 267 | > test passes, it will alert you in the test summary but will not cause the 268 | > test run to register as failed. 269 | > - Use skips instead of XFAILS. The difference between XFAIL and skip is that 270 | > a skipped test is never run at all, whereas an XFAIL test is always run 271 | > but ignored if it fails. 
272 | > - Save the [Hypothesis examples 273 | > database](https://hypothesis.readthedocs.io/en/latest/database.html) 274 | > persistently on CI. That way as soon as a run finds one failing example, 275 | > all future runs will re-run that example. But note that the 276 | > Hypothesis examples database may be cleared when a new version of 277 | > Hypothesis or the test suite is released. 278 | 279 | #### Max examples 280 | 281 | The tests make heavy use of 282 | [Hypothesis](https://hypothesis.readthedocs.io/en/latest/). You can configure 283 | how many examples are generated using the `--max-examples` flag, which 284 | defaults to `100`. Lower values can be useful for quick checks, and larger 285 | values should result in more rigorous runs. For example, `--max-examples 286 | 10_000` may find bugs where default runs don't but will take much longer to 287 | run. 288 | 289 | 290 | ## Contributing 291 | 292 | ### Remain in-scope 293 | 294 | It is important that every test only uses APIs that are part of the standard. 295 | For instance, when creating input arrays you should only use the [array creation 296 | functions](https://data-apis.org/array-api/latest/API_specification/creation_functions.html) 297 | that are documented in the spec. The same goes for testing arrays—you'll find 298 | many utilities that parallel NumPy's own test utils in the `*_helpers.py` files. 299 | 300 | ### Tools 301 | 302 | Hypothesis should almost always be used for the primary tests, and can be useful 303 | elsewhere. Effort should be made so drawn arguments are labeled with their 304 | respective names. For 305 | [`st.data()`](https://hypothesis.readthedocs.io/en/latest/data.html#hypothesis.strategies.data), 306 | draws should be accompanied with the `label` kwarg i.e. `data.draw(, 307 | label=