├── .flake8 ├── stablehlo_coreml ├── __init__.py ├── passes │ ├── utils.py │ └── remove_noop_slice_update.py ├── padding.py ├── ops_register.py ├── translation_context.py ├── reductions.py ├── sort_utils.py ├── utils.py └── converter.py ├── .coveragerc ├── .github └── workflows │ ├── publish-to-pypi.yml │ └── run-tests.yml ├── LICENSE ├── README.md ├── pyproject.toml ├── tests ├── passes │ └── test_remove_noop_slice_update.py ├── flax_blocks.py ├── utils.py ├── pytorch │ └── test_pytorch.py ├── test_flax.py ├── test_equinox.py └── test_jax.py └── .gitignore /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = env 3 | max-line-length = 127 4 | count = true 5 | statistics = true 6 | show-source = true 7 | -------------------------------------------------------------------------------- /stablehlo_coreml/__init__.py: -------------------------------------------------------------------------------- 1 | from .converter import convert 2 | from .passes.utils import register_optimizations, DEFAULT_HLO_PIPELINE 3 | 4 | __version__ = "0.0.0" 5 | __all__ = ['convert', 'register_optimizations', 'DEFAULT_HLO_PIPELINE'] 6 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | exclude_lines = 3 | # Have to re-enable the standard pragma 4 | pragma: no cover 5 | 6 | # Don't complain about missing debug-only code: 7 | def __repr__ 8 | 9 | # Don't complain if tests don't hit defensive assertion code: 10 | raise ValueError 11 | raise NotImplementedError 12 | -------------------------------------------------------------------------------- /stablehlo_coreml/passes/utils.py: -------------------------------------------------------------------------------- 1 | import coremltools as ct 2 | 3 | DEFAULT_HLO_PIPELINE: ct.PassPipeline = ct.PassPipeline.DEFAULT 4 | 5 | 6 | def register_optimizations(): 7 | from .remove_noop_slice_update import remove_noop_slice_update 8 | custom_passes = [remove_noop_slice_update] 9 | 10 | for custom_pass in custom_passes: 11 | DEFAULT_HLO_PIPELINE.append_pass(f"common::{custom_pass.__name__}") 12 | -------------------------------------------------------------------------------- /.github/workflows/publish-to-pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPi 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Set up Python 18 | uses: actions/setup-python@v3 19 | with: 20 | python-version: '3.13' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install build 25 | - name: Install Hatch 26 | uses: pypa/hatch@install 27 | - name: Update version 28 | run: | 29 | hatch version $(git describe --tags) 30 | - name: Build 31 | run: | 32 | hatch build 33 | - name: Publish to PyPi 34 | run: | 35 | hatch publish -n -u __token__ -a ${{ secrets.PYPI_API_TOKEN }} 36 | -------------------------------------------------------------------------------- /stablehlo_coreml/padding.py: -------------------------------------------------------------------------------- 1 | from coremltools.converters.mil.mil import Builder as mb 2 | from coremltools.converters.mil.mil import types 3 | from .utils import get_mil_type, dtype_str 4 | 5 | 6 | def pad_with_cast(x, pad, mode="constant", constant_val=None): 7 | """ 8 | Helper function to handle padding for integer tensors. 9 | mb.pad only supports fp16 and fp32 inputs, so we cast to float, pad, and cast back. 10 | """ 11 | mil_type = get_mil_type(x) 12 | is_int_input = types.is_int(mil_type) 13 | if is_int_input: 14 | x = mb.cast(x=x, dtype="fp32") 15 | if constant_val is not None: 16 | constant_val = mb.cast(x=constant_val, dtype="fp32") 17 | 18 | padded = mb.pad(x=x, pad=pad, mode=mode, constant_val=constant_val) 19 | 20 | if is_int_input: 21 | padded = mb.cast(x=padded, dtype=dtype_str(mil_type)) 22 | 23 | return padded 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Kasper Nielsen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Convert StableHLO models into Apple Core ML format 2 | 3 | **This repo is currently experimental!** 4 | 5 | Only a subset of the StableHLO operations have been implemented, and some of them may have restrictions. 6 | 7 | Due to the current _dot_general_ op implementation, it is only possible to target iOS >= 18. 8 | 9 | Look in the `tests` directory, to see what has currently been tested. 10 | 11 | The package is published to PyPi as `stablehlo-coreml-experimental`. 12 | 13 | ## Converting a model 14 | 15 | To convert a StableHLO module, do the following: 16 | 17 | ```python 18 | import coremltools as ct 19 | from stablehlo_coreml.converter import convert 20 | from stablehlo_coreml import DEFAULT_HLO_PIPELINE 21 | 22 | mil_program = convert(hlo_module, minimum_deployment_target=ct.target.iOS18) 23 | cml_model = ct.convert( 24 | mil_program, 25 | source="milinternal", 26 | minimum_deployment_target=ct.target.iOS18, 27 | pass_pipeline=DEFAULT_HLO_PIPELINE, 28 | ) 29 | ``` 30 | 31 | For a Jax project, the `hlo_module` can be obtained the following way: 32 | 33 | ```python 34 | import jax 35 | from jax._src.lib.mlir import ir 36 | from jax._src.interpreters import mlir as jax_mlir 37 | from jax.export import export 38 | 39 | import jax.numpy as jnp 40 | 41 | def jax_function(a, b): 42 | return jnp.einsum("ij,jk -> ik", a, b) 43 | 44 | context = jax_mlir.make_ir_context() 45 | input_shapes = (jnp.zeros((2, 4)), jnp.zeros((4, 3))) 46 | jax_exported = export(jax.jit(jax_function))(*input_shapes) 47 | hlo_module = ir.Module.parse(jax_exported.mlir_module(), context=context) 48 | ``` 49 | 50 | For the Jax example to work, you will additionally need to install `absl-py` and `flatbuffers` as dependencies. 51 | 52 | For additional examples see the `tests` directory. 53 | 54 | ## Notes 55 | * `coremltools` supports up to python 3.13. Do not run hatch with a newer version. 56 | Can be controlled using fx `export HATCH_PYTHON=python3.13` 57 | * Run tests using `hatch run test:pytest tests` 58 | -------------------------------------------------------------------------------- /stablehlo_coreml/ops_register.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | from .translation_context import TranslationContext 4 | 5 | 6 | def register_stablehlo_op(func): 7 | # Check the signature 8 | sig = inspect.signature(func) 9 | params = list(sig.parameters.values()) 10 | 11 | # Exclude 'self' from the parameters 12 | params = params[1:] 13 | 14 | error_msg = "HLO op implementations should take parameters of exactly " \ 15 | "(context: TranscriptionContext, op: )" 16 | if len(params) != 2: 17 | raise ValueError(error_msg) 18 | 19 | if not issubclass(params[0].annotation, TranslationContext): 20 | raise ValueError(error_msg) 21 | 22 | # We identify the function by the type of operation it implements 23 | func._implements_hlo_op = params[1].annotation 24 | return func 25 | 26 | 27 | class StableHloOpsRegistry(type): 28 | def __init__(cls, name, bases, clsdict): 29 | super().__init__(name, bases, clsdict) 30 | 31 | cls._stablehlo_ops_registry = {} 32 | for name, method in clsdict.items(): 33 | op_type = getattr(method, '_implements_hlo_op', False) 34 | if callable(method) and op_type: 35 | if op_type in cls._stablehlo_ops_registry: 36 | raise ValueError(f"StableHLO op {op_type} has been registered more than once!") 37 | cls._stablehlo_ops_registry[op_type] = method 38 | 39 | def _dispatch_op(cls, self, context: TranslationContext, op): 40 | if type(op) not in self._stablehlo_ops_registry: 41 | raise ValueError(f"The StableHLO op {type(op)} has not been implemented!") 42 | 43 | op_method = self._stablehlo_ops_registry[type(op)] 44 | return op_method(self, context, op) 45 | 46 | def __call__(cls, *args, **kwargs): 47 | # Register the dispatch_op method 48 | instance = super().__call__(*args, **kwargs) 49 | setattr(instance, 'dispatch_op', cls._dispatch_op) 50 | return instance 51 | -------------------------------------------------------------------------------- /.github/workflows/run-tests.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Run Tests 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | jobs: 13 | test-core: 14 | runs-on: macos-26 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | python-version: ["3.12", "3.13"] 19 | 20 | steps: 21 | - uses: actions/checkout@v4 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v3 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | python -m pip install flake8 pytest 30 | - name: Install Hatch 31 | uses: pypa/hatch@install 32 | - name: Lint with flake8 33 | run: | 34 | # stop the build if there are Python syntax errors or undefined names 35 | # The GitHub editor is 127 chars wide 36 | flake8 . --count --show-source --statistics --max-line-length=127 37 | - name: Test with hatch 38 | run: | 39 | hatch run +py=${{ matrix.python-version }} test:test-with-cov 40 | 41 | test-pytorch: 42 | runs-on: macos-26 43 | strategy: 44 | fail-fast: false 45 | matrix: 46 | python-version: ["3.12", "3.13"] 47 | 48 | steps: 49 | - uses: actions/checkout@v4 50 | - name: Set up Python ${{ matrix.python-version }} 51 | uses: actions/setup-python@v3 52 | with: 53 | python-version: ${{ matrix.python-version }} 54 | - name: Cache Hugging Face and Torch models 55 | uses: actions/cache@v3 56 | with: 57 | path: | 58 | ~/.cache/huggingface 59 | ~/.cache/torch 60 | key: ${{ runner.os }}-models-${{ hashFiles('tests/pytorch/test_pytorch.py') }} 61 | restore-keys: | 62 | ${{ runner.os }}-models- 63 | - name: Install dependencies 64 | run: | 65 | python -m pip install --upgrade pip 66 | python -m pip install flake8 pytest 67 | - name: Install Hatch 68 | uses: pypa/hatch@install 69 | - name: Test pytorch export 70 | run: | 71 | hatch run +py=${{ matrix.python-version }} test-pytorch:pytest -vv tests/pytorch/ 72 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "stablehlo-coreml-experimental" 7 | dynamic = ["version"] 8 | authors = [ 9 | { name="Kasper Nielsen", email="kasper0406@gmail.com" }, 10 | ] 11 | description = "Convert StableHLO models into Apple Core ML format" 12 | readme = "README.md" 13 | classifiers = [ 14 | "Development Status :: 3 - Alpha", 15 | "License :: OSI Approved :: MIT License", 16 | "Intended Audience :: Developers", 17 | "Intended Audience :: End Users/Desktop", 18 | "Operating System :: MacOS :: MacOS X", 19 | "Programming Language :: Python", 20 | "Topic :: Scientific/Engineering", 21 | "Topic :: Software Development" 22 | ] 23 | keywords=[ "stablehlo", "hlo", "xla", "coreml", "machinelearning", "ml", "coremltools", "converter", "neural" ] 24 | requires-python = ">=3.9" 25 | 26 | dependencies = [ 27 | 'coremltools>=9.0; python_version >= "3.10" and python_version <= "3.13"', 28 | "numpy~=2.0", 29 | 30 | # Jax is not actually a strict requirement for the main library. 31 | # However, the code relies on the mlir StableHLO python bindings, and currently they are not published to pip 32 | # and the only pre-built stand-alone library is only built for linux. 33 | # Onces https://github.com/openxla/stablehlo/issues/2346 is resolved, this dependency can be switch to stablehlo instead. 34 | "jax>=0.8.1", 35 | ] 36 | 37 | [tool.hatch.version] 38 | path = "stablehlo_coreml/__init__.py" 39 | 40 | [tool.hatch.build.targets.wheel] 41 | # This setting is needed as long as we publish with the `_experimental` suffix 42 | packages = ["stablehlo_coreml"] 43 | 44 | [[tool.hatch.envs.test.matrix]] 45 | python = ["3.12", "3.13"] 46 | 47 | [tool.hatch.envs.test] 48 | randomize = true 49 | parallel = true 50 | 51 | extra-dependencies = [ 52 | "pytest", 53 | "flax>=0.12.0", 54 | "flatbuffers", 55 | "einops", 56 | "pillow", 57 | "equinox>=0.13.2", 58 | "pytest-cov", 59 | ] 60 | 61 | [tool.hatch.envs.test.scripts] 62 | test-with-cov = "pytest --cov=stablehlo_coreml --cov-report=term-missing" 63 | 64 | [[tool.hatch.envs.test-pytorch.matrix]] 65 | python = ["3.12", "3.13"] 66 | 67 | [tool.hatch.envs.test-pytorch] 68 | randomize = true 69 | parallel = true 70 | 71 | extra-dependencies = [ 72 | "pytest", 73 | "torch>=2.9.1", 74 | "torchvision", 75 | "torchax", 76 | "flax", # torchax wants flax to be installed 77 | "transformers", 78 | ] 79 | 80 | [project.urls] 81 | Homepage = "https://github.com/kasper0406/stablehlo-coreml" 82 | Issues = "https://github.com/kasper0406/stablehlo-coreml/issues" 83 | -------------------------------------------------------------------------------- /stablehlo_coreml/translation_context.py: -------------------------------------------------------------------------------- 1 | from coremltools.converters.mil import mil 2 | 3 | 4 | class TranslationContext: 5 | def __init__(self): 6 | self._path = [] 7 | self.seen_paths = set() 8 | self.variables = {} # Nested map: path -> variable -> mil var 9 | 10 | def push_function(self, name: str): 11 | counter = 0 12 | ctx_name = name 13 | while True: 14 | new_path = self._path + [ctx_name] 15 | if "/".join(new_path) in self.seen_paths: 16 | # Ensure that the new context name is in fact unique 17 | # A collision can happen if the same function is called twice 18 | ctx_name = f"{name}_{counter}" 19 | counter += 1 20 | else: 21 | self._path.append(ctx_name) 22 | self.seen_paths.add(self.path()) 23 | return ctx_name 24 | 25 | def pop_function(self): 26 | self.variables.pop(self.path()) 27 | self._path.pop() 28 | 29 | def add_variable(self, name: str, mil_var: mil.Var): 30 | path = self.path() 31 | if path not in self.variables: 32 | self.variables[path] = {} 33 | 34 | if name in self.variables[path]: 35 | raise ValueError(f"Variable {name} is already defined in path {path}") 36 | self.variables[path][name] = mil_var 37 | 38 | def add_result(self, hlo_result, result: mil.Var): 39 | result_name = hlo_result.get_name() 40 | self.add_variable(result_name, result) 41 | 42 | def validate_shapes(hlo_shape: tuple, mil_shape: tuple): 43 | if hlo_shape == tuple() and (mil_shape == tuple() or mil_shape == (1, )): 44 | return True 45 | if hlo_shape == mil_shape: 46 | return True 47 | 48 | raise ValueError(f"The HLO result shape `{hlo_shape}` is different from the actual MIL result shape `{mil_shape}`") 49 | 50 | hlo_shape = tuple(hlo_result.type.shape) 51 | mil_shape = tuple(result.shape) 52 | validate_shapes(hlo_shape=hlo_shape, mil_shape=mil_shape) 53 | 54 | def __getitem__(self, name: str): 55 | # Walk up along the path list to find the first correctly named variable in scope 56 | path = self._path.copy() 57 | while True: 58 | ctx = self.variables["/".join(path)] 59 | if name in ctx: 60 | return ctx[name] 61 | if len(path) == 0: 62 | raise ValueError(f"Variable with name {name} is not defined in path {path}") 63 | path.pop() 64 | 65 | def path(self) -> str: 66 | return "/".join(self._path) 67 | -------------------------------------------------------------------------------- /stablehlo_coreml/passes/remove_noop_slice_update.py: -------------------------------------------------------------------------------- 1 | from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass 2 | from coremltools.converters.mil.mil.passes.helper import block_context_manager 3 | from coremltools.converters.mil.mil.passes.pass_registry import register_pass 4 | 5 | import numpy as np 6 | 7 | 8 | def _match_pattern(op): 9 | if op.op_type == "slice_update": 10 | x_rank = len(op.x.shape) 11 | 12 | x_and_update_shape_matches = op.x.shape == op.update.shape 13 | 14 | all_zeros_start_indices_array = np.array([0] * x_rank, dtype=np.int32) 15 | start_values_all_zero = np.array_equal(op.begin.val, all_zeros_start_indices_array) 16 | 17 | end_values_matches_x_shape = np.array_equal(op.end.val, op.x.shape) 18 | 19 | all_one_strides_array = np.array([1] * x_rank, dtype=np.int32) 20 | strides_all_one = not op.stride or np.array_equal(op.stride.val, all_one_strides_array) 21 | no_extra_options = strides_all_one and not op.begin_mask and not op.end_mask 22 | 23 | return x_and_update_shape_matches and start_values_all_zero and end_values_matches_x_shape and no_extra_options 24 | 25 | return False 26 | 27 | 28 | def _try_to_transform(slice_update_op): 29 | # Replace occurences of the `slice_update_op` output with the `slice_update_op.update` variable 30 | slice_update_op.enclosing_block.replace_uses_of_var_after_op( 31 | anchor_op=slice_update_op, old_var=slice_update_op.outputs[0], new_var=slice_update_op.update 32 | ) 33 | slice_update_op.remove_from_block() 34 | return True 35 | 36 | 37 | @block_context_manager 38 | def _remove_noop_slice_update(block): 39 | did_optimize = False 40 | for op in list(block.operations): 41 | if op.enclosing_block is None: 42 | continue 43 | 44 | for b in op.blocks: 45 | block_changed = True 46 | while block_changed: 47 | block_changed = _remove_noop_slice_update(b) 48 | if len(op.blocks) > 0: 49 | continue 50 | 51 | if _match_pattern(op): 52 | if _try_to_transform(op): 53 | did_optimize = True 54 | return did_optimize 55 | 56 | 57 | @register_pass(namespace="common") 58 | class remove_noop_slice_update(AbstractGraphPass): 59 | """ 60 | If a slice_update is called on the full tensor with an update of the same shape, 61 | simply use the update tensor going forward. 62 | 63 | This optimization is very useful for the way the HLO DotGeneralOp is implemented, 64 | in case the DotGeneralOp reduces to a single matrix multiplication. 65 | 66 | Given: 67 | %1 = 68 | %2 = 69 | %2 = slice_update(x=%buffer, update=%2, begin=[0] * rank(%1), end=S, stride=[1] * rank(%1)) 70 | %3 = some_op(%2) 71 | 72 | Result: 73 | %1 = 74 | %3 = some_op(%1) 75 | ... 76 | """ 77 | def apply(self, prog): 78 | for f in prog.functions.values(): 79 | block_changed = True 80 | while block_changed: 81 | block_changed = _remove_noop_slice_update(f) 82 | -------------------------------------------------------------------------------- /tests/passes/test_remove_noop_slice_update.py: -------------------------------------------------------------------------------- 1 | from coremltools.converters.mil.mil import Builder as mb 2 | from coremltools.converters.mil.testing_utils import ( 3 | apply_pass_and_basic_check, 4 | assert_model_is_valid, 5 | get_op_types_in_program, 6 | ) 7 | import coremltools as ct 8 | 9 | import numpy as np 10 | 11 | from stablehlo_coreml import register_optimizations 12 | 13 | register_optimizations() 14 | 15 | 16 | class TestRemoveNoopSliceUpdate: 17 | 18 | def test_is_removed(self): 19 | @mb.program(input_specs=[mb.TensorSpec(shape=(10, 20))]) 20 | def prog(x): 21 | buffer = np.zeros((10, 20)) 22 | # Because this function ends up being a complete no-op, we need to ensure the naming of inputs and outputs 23 | x = mb.slice_update(x=buffer, update=x, begin=[0, 0], end=buffer.shape, name="x") 24 | return x 25 | self.__test_program(prog, should_remove=True) 26 | 27 | def test_not_removed_if_non_zero_begin_shape(self): 28 | @mb.program(input_specs=[mb.TensorSpec(shape=(10, 20))]) 29 | def prog(x): 30 | buffer = np.zeros((11, 20)) 31 | x = mb.slice_update(x=buffer, update=x, begin=[1, 0], end=buffer.shape) 32 | return x 33 | self.__test_program(prog, should_remove=False) 34 | 35 | def test_not_removed_if_end_not_matching(self): 36 | @mb.program(input_specs=[mb.TensorSpec(shape=(10, 20))]) 37 | def prog(x): 38 | buffer = np.zeros((11, 20)) 39 | x = mb.slice_update(x=buffer, update=x, begin=[0, 0], end=[10, 20]) 40 | return x 41 | self.__test_program(prog, should_remove=False) 42 | 43 | def test_not_removed_if_strided(self): 44 | @mb.program(input_specs=[mb.TensorSpec(shape=(10, 20))]) 45 | def prog(x): 46 | buffer = np.zeros((20, 20)) 47 | x = mb.slice_update(x=buffer, update=x, begin=[0, 0], end=buffer.shape, stride=[2, 1]) 48 | return x 49 | self.__test_program(prog, should_remove=False) 50 | 51 | def test_not_removed_if_begin_mask(self): 52 | @mb.program(input_specs=[mb.TensorSpec(shape=(10, 20))]) 53 | def prog(x): 54 | buffer = np.zeros((10, 20)) 55 | x = mb.slice_update(x=buffer, update=x, begin=[0, 0], end=buffer.shape, begin_mask=[True, False]) 56 | return x 57 | self.__test_program(prog, should_remove=False) 58 | 59 | def test_not_removed_if_end_mask(self): 60 | @mb.program(input_specs=[mb.TensorSpec(shape=(10, 20))]) 61 | def prog(x): 62 | buffer = np.zeros((10, 20)) 63 | x = mb.slice_update(x=buffer, update=x, begin=[0, 0], end=buffer.shape, end_mask=[True, False]) 64 | return x 65 | self.__test_program(prog, should_remove=False) 66 | 67 | def __test_program(self, prog, should_remove: bool): 68 | assert get_op_types_in_program(prog) == ["slice_update"] 69 | 70 | apply_pass_and_basic_check( 71 | prog, "common::remove_noop_slice_update" 72 | ) 73 | _, _, _ = apply_pass_and_basic_check(prog, "common::dead_code_elimination") 74 | 75 | if should_remove: 76 | assert get_op_types_in_program(prog) == [] 77 | else: 78 | assert get_op_types_in_program(prog) == ["slice_update"] 79 | 80 | assert_model_is_valid( 81 | prog, 82 | {"x": (10, 20)}, 83 | minimum_deployment_target=ct.target.iOS18, 84 | backend=("mlprogram", "fp32") 85 | ) 86 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | .vscode/ 165 | -------------------------------------------------------------------------------- /tests/flax_blocks.py: -------------------------------------------------------------------------------- 1 | from flax import nnx 2 | import jax.numpy as jnp 3 | 4 | from typing import List 5 | 6 | 7 | class ResidualConv(nnx.Module): 8 | scale_conv: nnx.Conv 9 | conv: nnx.Conv 10 | normalization_1: nnx.Module 11 | normalization_2: nnx.Module 12 | normalization_3: nnx.Module 13 | shortcut: nnx.Conv 14 | 15 | def __init__(self, in_channels: int, out_channels: int, rngs: nnx.Rngs, stride: int = 2): 16 | conv_type = nnx.Conv if in_channels <= out_channels else nnx.ConvTranspose 17 | 18 | kernel_size = 4 19 | self.scale_conv = conv_type( 20 | in_features=in_channels, 21 | out_features=out_channels, 22 | kernel_size=(kernel_size,), 23 | strides=(stride,), 24 | rngs=rngs 25 | ) 26 | self.conv = nnx.Conv( 27 | in_features=out_channels, 28 | out_features=out_channels, 29 | kernel_size=kernel_size, 30 | rngs=rngs, 31 | ) 32 | 33 | self.normalization_1 = nnx.BatchNorm(num_features=out_channels, rngs=rngs) 34 | self.normalization_2 = nnx.BatchNorm(num_features=out_channels, rngs=rngs) 35 | 36 | self.shortcut = conv_type( 37 | in_features=in_channels, 38 | out_features=out_channels, 39 | kernel_size=(stride,), 40 | strides=(stride,), 41 | rngs=rngs 42 | ) 43 | 44 | def __call__(self, x): 45 | out = self.scale_conv(x) 46 | out = self.normalization_1(out) 47 | out = nnx.silu(out) 48 | 49 | out = self.conv(out) 50 | out = nnx.silu(out) 51 | 52 | # Residual 53 | out = out + self.shortcut(x) 54 | out = self.normalization_2(out) 55 | 56 | return out 57 | 58 | 59 | class Encoder(nnx.Module): 60 | cnn_layers: List[ResidualConv] 61 | normalization: nnx.Module 62 | 63 | def __init__(self, num_layers: int, rngs: nnx.Rngs): 64 | self.cnn_layers = nnx.List() 65 | 66 | for i in range(num_layers): 67 | in_channels = (2 ** i) 68 | out_channels = 2 ** (i + 1) 69 | 70 | self.cnn_layers.append(ResidualConv( 71 | in_channels=in_channels, 72 | out_channels=out_channels, 73 | rngs=rngs, 74 | )) 75 | 76 | last_layer_features = 2 ** num_layers 77 | self.normalization = nnx.BatchNorm(num_features=last_layer_features, rngs=rngs) 78 | 79 | def __call__(self, x): 80 | out = x 81 | skip_connections = [] 82 | for layer in self.cnn_layers: 83 | out = layer(out) 84 | skip_connections.append(out) 85 | 86 | out = self.normalization(out) 87 | out = nnx.tanh(out) 88 | 89 | return out, skip_connections 90 | 91 | 92 | class Decoder(nnx.Module): 93 | cnn_layers: List[ResidualConv] 94 | residual_norm_layers: List[nnx.Module] 95 | output_polling: nnx.Conv 96 | 97 | def __init__(self, num_layers: int, rngs: nnx.Rngs): 98 | self.cnn_layers = nnx.List() 99 | self.residual_norm_layers = nnx.List() 100 | 101 | input_features = 2 ** num_layers 102 | for i in range(num_layers): 103 | # Times two to handle residual connections 104 | in_channels = 2 * (input_features // (2 ** i)) 105 | out_channels = input_features // (2 ** (i + 1)) 106 | 107 | self.residual_norm_layers.append(nnx.BatchNorm(in_channels, rngs=rngs)) 108 | self.cnn_layers.append(ResidualConv( 109 | in_channels=in_channels, 110 | out_channels=out_channels, 111 | rngs=rngs, 112 | )) 113 | 114 | last_layer_features = input_features // (2 ** num_layers) 115 | self.output_polling = nnx.Conv( 116 | in_features=last_layer_features, 117 | out_features=1, 118 | kernel_size=3, 119 | rngs=rngs, 120 | ) 121 | 122 | def __call__(self, x, skip_connections): 123 | skip_connections = list(reversed(skip_connections)) 124 | 125 | out = x 126 | for i, (cnn_layer, residual_norm) in enumerate(zip(self.cnn_layers, self.residual_norm_layers)): 127 | residual = skip_connections[i] 128 | out = residual_norm(jnp.concatenate([out, residual], axis=-1)) 129 | out = cnn_layer(out) 130 | 131 | out = self.output_polling(out) 132 | return out 133 | 134 | 135 | class UNet(nnx.Module): 136 | encoder: Encoder 137 | decoder: Decoder 138 | 139 | def __init__(self, num_layers: int, rngs: nnx.Rngs): 140 | self.audio_encoding = Encoder(num_layers=num_layers, rngs=rngs) 141 | self.audio_decoding = Decoder(num_layers=num_layers, rngs=rngs) 142 | 143 | def __call__(self, x): 144 | def compress_dynamic_range(samples): 145 | mu = jnp.array(255.0, dtype=jnp.float16) 146 | return jnp.sign(samples) * jnp.log1p(mu * jnp.abs(samples)) / jnp.log1p(mu) 147 | x = compress_dynamic_range(x) 148 | 149 | hidden, skip_connections = self.audio_encoding(x) 150 | out = self.audio_decoding(hidden, skip_connections) 151 | 152 | return out 153 | -------------------------------------------------------------------------------- /stablehlo_coreml/reductions.py: -------------------------------------------------------------------------------- 1 | from coremltools import _logger as logger 2 | from coremltools.converters.mil.mil import Builder as mb 3 | import numpy as np 4 | from jaxlib.mlir.dialects.stablehlo import ( 5 | AddOp, MulOp, MinOp, MaxOp, ReturnOp, SubtractOp, DivOp 6 | ) 7 | 8 | from .utils import ( 9 | index_by_slices, update_tensor_by_slice, iterate_indexes_in_shapes, 10 | get_numpy_type 11 | ) 12 | from .translation_context import TranslationContext 13 | 14 | 15 | def match_computation(hlo_body): 16 | if len(hlo_body.blocks) != 1: 17 | return None, None, None 18 | args = list(hlo_body.blocks[0].arguments) 19 | ops = list(hlo_body.blocks[0].operations) 20 | 21 | # Check for the special "update" mode (overwrite) 22 | # This corresponds to returning the second argument (the update value) 23 | if len(ops) == 1 and isinstance(ops[0], ReturnOp) and ops[0].operands[0] == args[1]: 24 | # This is the "update" mode: return args[1] (the update value) 25 | # We define a lambda that just returns the update value 26 | def mil_binary_op(x, y): 27 | return y 28 | mode = "update" 29 | return None, mil_binary_op, mode 30 | 31 | # Simple matches are where the `hlo_body` is on the form 32 | # return _generic_reduction_op_type_(`args`) 33 | # In that case, if MIL has an equivalent of `_generic_reduction_op_`, we simply delegate to that 34 | simple_matches = { 35 | MaxOp: (mb.reduce_max, mb.maximum, "max"), 36 | MinOp: (mb.reduce_min, mb.minimum, "min"), 37 | AddOp: (mb.reduce_sum, mb.add, "add"), 38 | MulOp: (mb.reduce_prod, mb.mul, "mul"), 39 | SubtractOp: (None, mb.sub, "sub"), 40 | DivOp: (None, mb.real_div, "div"), 41 | } 42 | 43 | for generic_reduce_op_type, mil_equivalents in simple_matches.items(): 44 | if len(ops) == 2 and isinstance(ops[0], generic_reduce_op_type) and isinstance(ops[1], ReturnOp): 45 | if list(ops[0].operands) == args and list(ops[1].operands) == list(ops[0].results): 46 | return mil_equivalents 47 | 48 | return None, None, None 49 | 50 | 51 | def compute_reduction(converter, context: TranslationContext, inputs, dimensions, body, init_values, result_types): 52 | mil_reduction, mil_single_reduction, _ = match_computation(body) 53 | if mil_reduction and mil_single_reduction and len(inputs) == 1: 54 | res = mil_reduction(x=inputs[0], axes=np.array(dimensions, dtype=np.int32)) 55 | # Handle initial value 56 | res = mil_single_reduction(x=res, y=init_values[0]) 57 | return [res] 58 | 59 | # Fall back to loop implementation 60 | logger.warning("Falling back to while-loop implementation for reduction. This may be slower than expected!") 61 | 62 | input_rank = len(inputs[0].shape) 63 | # Notice for the loops we treat both `reduce_shape` and `result_shape` as being 64 | # of the input rank. This is to make computing element indexes easier. 65 | # When updating the result, we later pick out just the result indices 66 | # we care about in the actual result. 67 | reduce_shape = [inputs[0].shape[dim] if dim in dimensions else 1 for dim in range(input_rank)] 68 | result_shape = [inputs[0].shape[dim] if dim not in dimensions else 1 for dim in range(input_rank)] 69 | 70 | def compute_reduction_loop(result_idx, *partial_results): 71 | def compute_inner(element_idx, *acc): 72 | element_idx = mb.add(x=result_idx, y=element_idx) 73 | elements = [mb.reshape(x=index_by_slices(input, [element_idx]), shape=(1,)) for input in inputs] 74 | 75 | args = list(acc) + elements 76 | hlo_params = list(body.blocks[0].arguments) 77 | outputs = converter.invoke_hlo_function(context, "reduce_body", hlo_params, body, args) 78 | 79 | return outputs 80 | 81 | reduction_results = iterate_indexes_in_shapes(compute_inner, [reduce_shape], init_values) 82 | 83 | # The result rank is likely less than the input shape. 84 | # We need to pick the indexes in the result shape we want to update 85 | result_indices = [dim for dim in range(input_rank) if dim not in dimensions] 86 | if len(result_indices) != 0: 87 | result_idx = [mb.gather(x=result_idx, indices=result_indices)] 88 | else: 89 | result_idx = [] 90 | 91 | return [ 92 | update_tensor_by_slice(acc, result_idx, result) 93 | for acc, result in zip(partial_results, reduction_results) 94 | ] 95 | 96 | mil_results = [ 97 | np.zeros(result_type.shape, dtype=get_numpy_type(result_type)) 98 | if len(result_type.shape) == 0 else 99 | mb.tile(x=np.zeros((1,) * len(result_type.shape), dtype=get_numpy_type(result_type)), reps=result_type.shape) 100 | for result_type in result_types 101 | ] 102 | mil_results = iterate_indexes_in_shapes(compute_reduction_loop, [result_shape], mil_results, unroll_limit=5) 103 | return mil_results 104 | 105 | 106 | def compute_windowed_reduction( 107 | converter, 108 | context: TranslationContext, 109 | inputs, 110 | window_dimensions, 111 | window_strides, 112 | body, 113 | init_values, 114 | result_types 115 | ): 116 | def move_axis_last(arr, axis): 117 | permutation = list(range(len(arr.shape))) 118 | permutation.append(permutation.pop(axis)) 119 | return mb.transpose(x=arr, perm=permutation) 120 | 121 | # First group all the dimensions being reduced over in a group at the end 122 | inputs_rank = len(window_dimensions) 123 | partitioned_inputs = [] 124 | for input in inputs: 125 | transformed = mb.sliding_windows( 126 | x=input, 127 | axis=0, 128 | size=window_dimensions[0], 129 | stride=window_strides[0] 130 | ) 131 | transformed = move_axis_last(transformed, 1) 132 | for axis in range(1, inputs_rank): 133 | transformed = mb.sliding_windows( 134 | x=transformed, axis=axis, size=window_dimensions[axis], stride=window_strides[axis]) 135 | transformed = move_axis_last(transformed, axis + 1) 136 | # Contract the two last dimensions into one 137 | transformed_rank = len(transformed.shape) 138 | new_shape = mb.concat(values=[ 139 | mb.slice_by_size(x=mb.shape(x=transformed), begin=[0], size=[transformed_rank - 2]), 140 | np.array([-1], dtype=np.int32) 141 | ], axis=0) 142 | transformed = mb.reshape(x=transformed, shape=new_shape) 143 | partitioned_inputs.append(transformed) 144 | 145 | # Then use the normal reduce implementation to compute the result 146 | reduction_dimension = len(partitioned_inputs[0].shape) - 1 147 | reduction_results = compute_reduction( 148 | converter=converter, 149 | context=context, 150 | inputs=partitioned_inputs, 151 | dimensions=[reduction_dimension], 152 | body=body, 153 | init_values=init_values, 154 | result_types=result_types, 155 | ) 156 | return reduction_results 157 | -------------------------------------------------------------------------------- /stablehlo_coreml/sort_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from jaxlib.mlir import ir 3 | from jaxlib.mlir.dialects.stablehlo import ( 4 | CompareOp, SelectOp, ConstantOp 5 | ) 6 | from jax._src.lib.mlir.dialects import hlo 7 | 8 | 9 | def match_sort(comparator_root, args, inputs): 10 | """ 11 | Analyzes the comparator region of a SortOp to determine if it implements a supported sorting pattern. 12 | We try to analyze the comparator logic for multi-key sorts. 13 | 14 | Multi-key sort compares multiple keys in sequence. If the primary keys are equal, 15 | it moves to the secondary keys, and so on. 16 | 17 | The comparator logic is expected to look like a chain of SelectOps: 18 | if (k1 == k2) then compare(next_keys) else compare(k1 < k2) 19 | 20 | Args: 21 | comparator_root: The current operation in the comparator region being analyzed (starts at the return value). 22 | args: The arguments of the comparator block (representing the two elements being compared). 23 | inputs: The list of input tensors to the SortOp. 24 | 25 | Returns: 26 | A list of (tensor, ascending) tuples representing the sort keys, or None if the pattern doesn't match. 27 | """ 28 | def get_op(val): 29 | if isinstance(val.owner, ir.Block): 30 | return None 31 | return val.owner.opview 32 | 33 | def get_arg_index(value): 34 | if value in args: 35 | return args.index(value) 36 | op = get_op(value) 37 | if op is None: 38 | return None 39 | return match_nan_and_zero_handling(op, args) 40 | 41 | def identify_comparison_args(compare_op: CompareOp) -> tuple[int | None, bool | None]: 42 | lhs = get_arg_index(compare_op.lhs) 43 | rhs = get_arg_index(compare_op.rhs) 44 | if lhs is None or rhs is None: 45 | return None, None 46 | 47 | # According to StableHLO sort spec, arguments are guaranteed to be interleaved: 48 | # (lhs_0, rhs_0, lhs_1, rhs_1, ...) 49 | # Therefore the input pair being compared should be adjacent, and the corresponding 50 | # key index can be derived by integer division by 2. 51 | # Reference: https://openxla.org/stablehlo/spec#sort 52 | if (lhs // 2) != (rhs // 2): 53 | return None, None 54 | 55 | direction = hlo.ComparisonDirectionAttr(compare_op.comparison_direction).value 56 | is_ascending = lhs < rhs and direction == "LT" or lhs > rhs and direction == "GT" 57 | 58 | return lhs // 2, is_ascending 59 | 60 | def match_comparison(op, expected_direction=None): 61 | if not isinstance(op, CompareOp): 62 | return None 63 | 64 | direction = hlo.ComparisonDirectionAttr(op.comparison_direction).value 65 | if expected_direction and direction != expected_direction: 66 | return None 67 | 68 | if expected_direction is None and direction not in ("LT", "GT"): 69 | return None 70 | 71 | return op 72 | 73 | def match_select_chain(op): 74 | # Matches: select(pred, on_true, on_false) 75 | # where pred is (k1 == k2) and on_false is (k1 < k2) 76 | if not isinstance(op, SelectOp): 77 | return None 78 | 79 | pred = get_op(op.pred) 80 | on_false = get_op(op.on_false) 81 | 82 | # 1. Check pred: k1 == k2 83 | pred_cmp = match_comparison(pred, "EQ") 84 | if not pred_cmp: 85 | return None 86 | 87 | # 2. Check on_false: k1 < k2 (or >) 88 | false_cmp = match_comparison(on_false) 89 | if not false_cmp: 90 | return None 91 | 92 | # 3. Verify operands match between pred and on_false 93 | if {pred_cmp.lhs, pred_cmp.rhs} != {false_cmp.lhs, false_cmp.rhs}: 94 | return None 95 | 96 | # 4. Identify key 97 | key_info = identify_comparison_args(pred_cmp) 98 | if key_info[0] is None: 99 | return None 100 | 101 | return key_info, get_op(op.on_true) 102 | 103 | def match_leaf(op): 104 | # Matches: k1 < k2 105 | cmp = match_comparison(op) 106 | if not cmp: 107 | return None 108 | return identify_comparison_args(cmp) 109 | 110 | # Walk through the operations graph to match the sort pattern 111 | sort_keys = [] 112 | current_op = comparator_root 113 | while current_op: 114 | # Try to match a chain node (SelectOp) 115 | chain_result = match_select_chain(current_op) 116 | if chain_result: 117 | (key_idx, is_asc), next_op = chain_result 118 | sort_keys.append((inputs[key_idx], is_asc)) 119 | current_op = next_op 120 | continue 121 | 122 | # Try to match a leaf node (CompareOp) 123 | leaf_result = match_leaf(current_op) 124 | if leaf_result: 125 | key_idx, is_asc = leaf_result 126 | if key_idx is None: 127 | return None 128 | sort_keys.append((inputs[key_idx], is_asc)) 129 | return sort_keys 130 | 131 | # If neither matched, it's not a valid sort pattern 132 | return None 133 | 134 | return sort_keys 135 | 136 | 137 | def match_nan_and_zero_handling(op, args): 138 | """ 139 | Jax generates nan-checks and +0/-0 merges in the comparator. This function matches this pattern: 140 | select( != , NaN, select( == 0, 0, )) 141 | 142 | It will extract the argument index of if the pattern matches. 143 | 144 | Notice that this is technically non correct, as MIL does not handle NaN's in the same way as StableHLO. 145 | +0/-0 looks to be correctly handled. 146 | """ 147 | def get_op(val): 148 | if isinstance(val.owner, ir.Block): 149 | return None 150 | return val.owner.opview 151 | 152 | def match_constant(val, check_fn): 153 | const_op = get_op(val) 154 | if not isinstance(const_op, ConstantOp): 155 | return False 156 | return check_fn(const_op.value) 157 | 158 | def match_isnan(val): 159 | # Matches: != 160 | compare_op = get_op(val) 161 | if not isinstance(compare_op, CompareOp): 162 | return None 163 | if hlo.ComparisonDirectionAttr(compare_op.comparison_direction).value != "NE": 164 | return None 165 | if compare_op.lhs != compare_op.rhs: 166 | return None 167 | return compare_op.lhs 168 | 169 | def match_is_zero(val): 170 | # Matches: == 0 171 | compare_op = get_op(val) 172 | if not isinstance(compare_op, CompareOp): 173 | return None 174 | if hlo.ComparisonDirectionAttr(compare_op.comparison_direction).value != "EQ": 175 | return None 176 | if not match_constant(compare_op.rhs, lambda x: np.array(x) == 0): 177 | return None 178 | return compare_op.lhs 179 | 180 | # 1. Match outer select: select(pred, NaN, on_false) 181 | if not isinstance(op, SelectOp): 182 | return None 183 | 184 | if not match_constant(op.on_true, np.isnan): 185 | return None 186 | 187 | # 2. Match NaN check: pred is (x != x) 188 | matched_arg_idx = match_isnan(op.pred) 189 | if matched_arg_idx is None: 190 | return None 191 | 192 | # 3. Match inner select: select(pred, 0, x) 193 | inner_op = get_op(op.on_false) 194 | if not isinstance(inner_op, SelectOp): 195 | return None 196 | 197 | if not match_constant(inner_op.on_true, lambda x: np.array(x) == 0): 198 | return None 199 | 200 | if inner_op.on_false != matched_arg_idx: 201 | return None 202 | 203 | # 4. Match Zero check: pred is (x == 0) 204 | x_val_zero = match_is_zero(inner_op.pred) 205 | if x_val_zero != matched_arg_idx: 206 | return None 207 | 208 | # 5. Return index if found in args 209 | if matched_arg_idx in args: 210 | return args.index(matched_arg_idx) 211 | 212 | return None 213 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax.export import export as _jax_export 4 | from jax._src.lib.mlir import ir 5 | from jax._src.interpreters import mlir as jax_mlir 6 | 7 | import numpy as np 8 | 9 | from stablehlo_coreml.converter import convert 10 | from stablehlo_coreml import DEFAULT_HLO_PIPELINE 11 | 12 | import coremltools as ct 13 | from coremltools.converters.mil.testing_utils import compare_backend 14 | from coremltools.converters.mil.mil import Program, Block 15 | 16 | from typing import List 17 | 18 | 19 | def jax_export(jax_func, input_spec): 20 | def compute_input_shapes(input_specs): 21 | shapes = [] 22 | for input_spec in input_specs: 23 | if isinstance(input_spec, (list, tuple)): 24 | # We only unwrap the shapes for one level 25 | shapes.append(input_spec) 26 | else: 27 | shapes.append(jax.ShapeDtypeStruct(input_spec.shape, input_spec.dtype)) 28 | return shapes 29 | input_shapes = compute_input_shapes(input_spec) 30 | jax_exported = _jax_export(jax.jit(jax_func))(*input_shapes) 31 | return jax_exported 32 | 33 | 34 | def generate_random_from_shape(input_spec, key=jax.random.PRNGKey): 35 | shape = input_spec.shape 36 | dtype = input_spec.dtype 37 | if jnp.issubdtype(dtype, jnp.integer): 38 | output = jax.random.randint(key=key, shape=shape, minval=-100, maxval=100, dtype=dtype) 39 | elif jnp.issubdtype(dtype, jnp.bool_): 40 | output = jax.random.bernoulli(key=key, shape=shape).astype(dtype) 41 | else: 42 | output = jax.random.normal(key=key, shape=shape, dtype=dtype) 43 | return output 44 | 45 | 46 | def flatten(nested_list): 47 | def visit(lst): 48 | flat = [] 49 | for element in lst: 50 | if isinstance(element, (list, tuple)): 51 | flat += visit(element) 52 | else: 53 | flat.append(element) 54 | return flat 55 | return visit(nested_list) 56 | 57 | 58 | def __nest_flat_jax_input_to_input_spec(input_spec, flat_input): 59 | idx = 0 60 | 61 | def visit(lst): 62 | nonlocal idx 63 | result = [] 64 | for element in lst: 65 | if isinstance(element, (list, tuple)): 66 | result.append(visit(element)) 67 | else: 68 | if idx >= len(flat_input): 69 | raise ValueError( 70 | "flat_input had too many inputs to fit input_spec. " 71 | f"Input spec: {input_spec}, Flat input: {flat_input}") 72 | result.append(flat_input[idx]) 73 | idx += 1 74 | return result 75 | 76 | structured_input = visit(input_spec) 77 | if idx != len(flat_input): 78 | raise ValueError("flat_input had too few inputs to fill input_spec. " 79 | f"Input spec: {input_spec}, Flat input: {flat_input}") 80 | 81 | return structured_input 82 | 83 | 84 | def _count_program_complexity(mil_program: Program): 85 | """ 86 | Counts the number of instructions in the given `mil_program` 87 | This is used to ensure we don't generate crazy big programs 88 | """ 89 | def count_block(block: Block): 90 | complexity = 0 91 | for op in block.operations: 92 | for child_block in op.blocks: 93 | complexity += count_block(child_block) 94 | complexity += 1 95 | return complexity 96 | 97 | total_complexity = 0 98 | for func in mil_program.functions.values(): 99 | total_complexity += count_block(func) 100 | return total_complexity 101 | 102 | 103 | def run_and_compare_hlo_module( 104 | hlo_module, 105 | inputs, 106 | expected_outputs, 107 | *, 108 | max_complexity: int = 10_000, 109 | atol=1e-04, 110 | rtol=1e-05, 111 | ): 112 | mil_program = convert(hlo_module, minimum_deployment_target=ct.target.iOS18) 113 | program_complexity = _count_program_complexity(mil_program) 114 | if program_complexity > max_complexity: 115 | raise ValueError( 116 | f"Generated a MIL program with complexity {program_complexity}, " 117 | f"max allowed complexity is {max_complexity}" 118 | ) 119 | 120 | pipeline = DEFAULT_HLO_PIPELINE 121 | # We temporarily avoid fp16 conversions in tests because of https://github.com/apple/coremltools/issues/2324 122 | passes_to_remove = [ 123 | 'common::add_fp16_cast' 124 | ] 125 | pipeline.remove_passes(passes_to_remove) 126 | 127 | cml_model = ct.convert( 128 | mil_program, 129 | source="milinternal", 130 | minimum_deployment_target=ct.target.iOS18, 131 | pass_pipeline=pipeline, 132 | compute_units=ct.ComputeUnit.CPU_ONLY, 133 | ) 134 | 135 | # Generate random inputs that matches cml_model input spec 136 | cml_input_key_values = {} 137 | for input_name, input_value in zip(cml_model.input_description, flatten(inputs)): 138 | cml_input_key_values[input_name] = input_value 139 | 140 | # TODO(knielsen): Is there a nicer way of doing this? 141 | if not isinstance(expected_outputs, (list, tuple)): 142 | expected_outputs = (expected_outputs, ) 143 | 144 | # Prepare the output for comparison 145 | cml_expected_outputs = {} 146 | for output_name, output_value in zip(cml_model.output_description, flatten(expected_outputs)): 147 | cml_expected_outputs[output_name] = np.asarray(output_value) 148 | 149 | compare_backend(cml_model, cml_input_key_values, cml_expected_outputs, atol=atol, rtol=rtol) 150 | 151 | return cml_model 152 | 153 | 154 | def run_and_compare_specific_input(jax_func, inputs, max_complexity: int = 10_000): 155 | """ 156 | Converts the given `jax_func` to a CoreML model. 157 | If the CoreML model and `jax_func` does not agree on the output, an error will be raised. 158 | The resulting CoreML model will be returned. 159 | """ 160 | 161 | jax_func = jax.jit(jax_func) 162 | exported = jax_export(jax_func, inputs) 163 | context = jax_mlir.make_ir_context() 164 | hlo_module = ir.Module.parse(exported.mlir_module(), context=context) 165 | # print(f"HLO module: {hlo_module}") 166 | 167 | # Transfor the input to match the Jax model, and call it 168 | jax_input_values = __nest_flat_jax_input_to_input_spec(inputs, flatten(inputs)) 169 | expected_output = jax_func(*jax_input_values) 170 | 171 | return run_and_compare_hlo_module(hlo_module, inputs, expected_output, max_complexity=max_complexity) 172 | 173 | 174 | def run_and_compare(jax_func, input_specification, max_complexity: int = 10_000): 175 | """ 176 | Converts the given `jax_func` to a CoreML model. 177 | The model will be tested with randomly generated data with the shapes of `input_specification`. 178 | If the CoreML model and `jax_func` does not agree on the output, an error will be raised. 179 | The resulting CoreML model will be returned. 180 | """ 181 | flat_inputs = [] 182 | key = jax.random.PRNGKey(0) 183 | for input_spec in flatten(input_specification): 184 | key, value_key = jax.random.split(key, num=2) 185 | input_value = generate_random_from_shape(input_spec, value_key) 186 | flat_inputs.append(input_value) 187 | 188 | inputs = __nest_flat_jax_input_to_input_spec(input_specification, flat_inputs) 189 | return run_and_compare_specific_input(jax_func, inputs, max_complexity=max_complexity) 190 | 191 | 192 | def get_model_instruction_types(cml_model) -> List[str]: 193 | def collect_ops(ops: List) -> List[str]: 194 | collected_ops = [] 195 | for op in ops: 196 | collected_ops.append(op.op_type) 197 | for block in op.blocks: 198 | collected_ops += collect_ops(block.operations) 199 | 200 | return collected_ops 201 | 202 | mil_program = cml_model._mil_program 203 | all_ops = [] 204 | for func in mil_program.functions.values(): 205 | all_ops += collect_ops(func.operations) 206 | return all_ops 207 | -------------------------------------------------------------------------------- /stablehlo_coreml/utils.py: -------------------------------------------------------------------------------- 1 | from coremltools.converters.mil.mil import Builder as mb 2 | from coremltools.converters.mil.mil.var import Var 3 | 4 | from dataclasses import dataclass 5 | from typing import List 6 | from functools import reduce 7 | import itertools 8 | import numpy as np 9 | from coremltools.converters.mil.mil import types 10 | from jaxlib.mlir import ir 11 | 12 | 13 | @dataclass 14 | class ResolvedSliceSpec: 15 | start_indices: List[int] | Var 16 | end_indices: List[int] | Var 17 | strides: List[int] 18 | shape: List[int] 19 | 20 | 21 | def index_by_slices(tensor, slice_spec): 22 | tensor = fix_scalar_tensor(tensor) 23 | resolved_slices = _resolve_slice_spec(tensor, slice_spec) 24 | 25 | return mb.slice_by_index( 26 | x=tensor, 27 | begin=resolved_slices.start_indices, 28 | end=resolved_slices.end_indices, 29 | stride=resolved_slices.strides 30 | ) 31 | 32 | 33 | def update_tensor_by_slice(tensor, slice_spec, value): 34 | tensor = fix_scalar_tensor(tensor) 35 | resolved_slices = _resolve_slice_spec(tensor, slice_spec) 36 | 37 | value = mb.reshape(x=value, shape=resolved_slices.shape) 38 | return mb.slice_update( 39 | x=tensor, 40 | update=value, 41 | begin=resolved_slices.start_indices, 42 | end=resolved_slices.end_indices, 43 | stride=resolved_slices.strides 44 | ) 45 | 46 | 47 | def fix_scalar_tensor(tensor): 48 | """ 49 | From a numpy scalar type, CoreML will create a rank 0 tensor, which it will 50 | later struggle to do operations on. We will re-shape it to a rank 1 tensor 51 | with dimension 1. 52 | """ 53 | if len(tensor.shape) == 0: 54 | tensor = mb.reshape(x=tensor, shape=(1,)) 55 | return tensor 56 | 57 | 58 | def _flatten_list(lst): 59 | flat_list = [] 60 | for item in lst: 61 | if isinstance(item, (list, tuple)): 62 | flat_list += _flatten_list(item) 63 | else: 64 | flat_list.append(item) 65 | return flat_list 66 | 67 | 68 | def _count_dimensions(slice_spec): 69 | dim_count = 0 70 | for spec in slice_spec: 71 | if isinstance(spec, Var): 72 | if spec.rank != 1: 73 | raise ValueError("The Var spec must have rank 1!") 74 | dim_count += spec.shape[0] 75 | elif isinstance(spec, type(Ellipsis)): 76 | raise ValueError("Can not count dimensions for slice spec containing Ellipsis") 77 | else: 78 | dim_count += 1 79 | return dim_count 80 | 81 | 82 | def _resolve_slice_spec(tensor, slice_spec) -> ResolvedSliceSpec: 83 | start_indices = [] 84 | end_indices = [] 85 | strides = [] 86 | shape = [] 87 | 88 | # We allow the slice_spec to have nested lists. In that case we flatten it 89 | slice_spec = _flatten_list(slice_spec) 90 | if len(slice_spec) == 0: 91 | # Special case for scalar 92 | slice_spec = [slice(None)] 93 | 94 | tensor_rank = len(tensor.shape) 95 | contains_var_type = False 96 | dim_counter = 0 97 | for i, spec in enumerate(slice_spec): 98 | if isinstance(spec, type(slice(None))): 99 | start_indices.append(spec.start or 0) 100 | end_indices.append(spec.stop or tensor.shape[dim_counter]) 101 | strides.append(spec.step or 1) 102 | shape.append(end_indices[-1] - start_indices[-1] // strides[-1]) 103 | dim_counter += 1 104 | elif isinstance(spec, type(Ellipsis)): 105 | if any([isinstance(s, type(Ellipsis)) for s in slice_spec[i+1:]]): 106 | raise ValueError("Only supports one ellipsis when indexing") 107 | 108 | dims_before = dim_counter 109 | dims_after = _count_dimensions(slice_spec[i+1:]) 110 | num_ellipsis_dims = tensor_rank - (dims_before + dims_after) 111 | 112 | ellipsis_starts = [0] * num_ellipsis_dims 113 | ellipsis_ends = [tensor.shape[dim] for dim in range(dim_counter, dim_counter + num_ellipsis_dims)] 114 | ellipsis_strides = [1] * num_ellipsis_dims 115 | ellipsis_shape = [ 116 | (end - start) // stride for start, end, stride 117 | in zip(ellipsis_starts, ellipsis_ends, ellipsis_strides) 118 | ] 119 | 120 | start_indices += ellipsis_starts 121 | end_indices += ellipsis_ends 122 | strides += ellipsis_strides 123 | shape += ellipsis_shape 124 | 125 | dim_counter += num_ellipsis_dims 126 | elif isinstance(spec, Var): 127 | if spec.rank != 1: 128 | raise ValueError("The Var spec must have rank 1!") 129 | contains_var_type = True 130 | start_indices.append(spec) 131 | end_indices.append(mb.add(x=spec, y=1)) 132 | strides += [1] * spec.shape[0] 133 | shape += [1] * spec.shape[0] 134 | dim_counter += spec.shape[0] 135 | else: 136 | # Assume it is an integer index 137 | idx = int(spec) 138 | start_indices.append(idx) 139 | end_indices.append(idx + 1) 140 | strides.append(1) 141 | shape.append(1) 142 | dim_counter += 1 143 | 144 | # If slice_spec contained any Var types, we will need to concatenate the full result 145 | # to be one big Var type 146 | if contains_var_type: 147 | def partition_list(lst): 148 | parts = [[]] 149 | for element in lst: 150 | if isinstance(element, Var): 151 | if len(parts[-1]) == 0: 152 | parts.pop() 153 | parts.append(element) 154 | parts.append([]) 155 | else: 156 | parts[-1].append(element) 157 | if len(parts[-1]) == 0: 158 | # The last partition may be empty, if so we skip it 159 | return parts[:-1] 160 | return parts 161 | 162 | def concat_to_var(lst): 163 | parts = partition_list(lst) 164 | return mb.concat(values=parts, axis=0) 165 | 166 | start_indices = concat_to_var(start_indices) 167 | end_indices = concat_to_var(end_indices) 168 | 169 | if len(strides) != tensor_rank or len(shape) != tensor_rank: 170 | raise ValueError("Slice does not line up!") 171 | 172 | return ResolvedSliceSpec( 173 | start_indices=start_indices, 174 | end_indices=end_indices, 175 | strides=strides, 176 | shape=shape, 177 | ) 178 | 179 | 180 | def iterate_indexes_in_shapes(func, shapes: List, init_values: List, unroll_limit: int = 25): 181 | """ 182 | Given a list of `shapes`, fx [(3, 2, 3), (5, 2, 3)] this method will iterate 183 | the product of all valid indexes into the given shapes. 184 | The function `func: Idx1, Idx2, ..., Idxn, Acc1, ..., Acck -> Res1, ..., Resk` is expected to given the 185 | list of indexes and the accumulated result so far, to update the result based 186 | on the index. 187 | The init_values is a list of [InitVal1, ..., InitValk], and the function `func` 188 | must return a list of `k` values [Res1, ..., Resk]. 189 | The indexes, Idx1, ..., Idn, may be either a mil mb.Var type of a python 190 | tuple depending on if the loop is unrolled or not. `func` is expected to be 191 | able to handle this. 192 | 193 | If the total number of traversed indexes is <`unroll_limit`, the loop will be 194 | fully un-rolled into MIL instructions. Otherwise it will be constructed as 195 | a dynamic while loop executed at runtime. 196 | """ 197 | shapes_elements = [reduce(lambda a, b: a * b, shape, 1) for shape in shapes] 198 | total_iterations = reduce(lambda a, b: a * b, shapes_elements, 1) 199 | 200 | results = init_values 201 | if total_iterations <= unroll_limit: 202 | # Fully unroll the loop 203 | ranges = [itertools.product(*[range(dim) for dim in shape]) for shape in shapes] 204 | for indexes in itertools.product(*ranges): 205 | results = func(*indexes, *results) 206 | else: 207 | # Dynamically compute the loop 208 | def suffix_product(lst): 209 | res = [] 210 | acc = 1 211 | for i in reversed(lst): 212 | res.append(acc) 213 | acc *= i 214 | return list(reversed(res)) 215 | integer_index_strides = suffix_product(shapes_elements) 216 | index_strides = [suffix_product(shape) for shape in shapes] 217 | 218 | # Attempt at looping over result indexes without fully unrolling 219 | def cond(i, *acc): 220 | return mb.less(x=i, y=total_iterations) 221 | 222 | def body(i, *acc): 223 | # Split out the index i to an integer index into the individual shapes. 224 | integer_indexes = [ 225 | mb.mod(x=mb.floor_div(x=i, y=stride), y=elements) 226 | for stride, elements in zip(integer_index_strides, shapes_elements) 227 | ] 228 | # Map the integer index in the shapes to an actual shaped index 229 | indexes = [ 230 | mb.concat(values=[ 231 | mb.mod(x=mb.floor_div(x=idx, y=stride), y=dim) for stride, dim in zip(strides, shape) 232 | ], axis=0) 233 | if len(shape) > 0 else [] 234 | for idx, strides, shape in zip(integer_indexes, index_strides, shapes) 235 | ] 236 | 237 | results = func(*indexes, *acc) 238 | return [mb.add(x=i, y=1)] + results 239 | 240 | fixed_init_values = [fix_scalar_tensor(init_value) for init_value in init_values] 241 | results = mb.while_loop(_cond=cond, _body=body, loop_vars=[0] + fixed_init_values)[1:] # Skip the counter 242 | 243 | return results 244 | 245 | 246 | def inverse_permutation(perm): 247 | """ 248 | Given a permutation `perm`, compute the inverse of the permutation 249 | """ 250 | inv = [0] * len(perm) 251 | for i, j in enumerate(perm): 252 | inv[j] = i 253 | return inv 254 | 255 | 256 | def get_mil_type_from_ir(element_type): 257 | if isinstance(element_type, ir.IntegerType): 258 | match (element_type.width, element_type.is_unsigned): 259 | case (32, False): 260 | return types.int32 261 | case (32, True): 262 | return types.uint32 263 | case (16, False): 264 | return types.int16 265 | case (16, True): 266 | return types.uint16 267 | case (8, False): 268 | return types.int8 269 | case (8, True): 270 | return types.uint8 271 | case (1, _): 272 | return types.bool 273 | if isinstance(element_type, ir.F16Type): 274 | return types.fp16 275 | if isinstance(element_type, ir.F32Type): 276 | return types.fp32 277 | raise ValueError(f"Unsupported type {element_type}") 278 | 279 | 280 | def get_mil_type(obj): 281 | if isinstance(obj, ir.Type): 282 | if hasattr(obj, 'element_type'): 283 | return get_mil_type_from_ir(obj.element_type) 284 | return get_mil_type_from_ir(obj) 285 | if isinstance(obj, np.ndarray): 286 | return types.numpy_type_to_builtin_type(obj.dtype) 287 | return obj.dtype 288 | 289 | 290 | def get_numpy_type(obj): 291 | return types.nptype_from_builtin(get_mil_type(obj)) 292 | 293 | 294 | def dtype_str(type): 295 | # TODO(knielsen): Add additional types 296 | return { 297 | types.int32: "int32", 298 | types.uint32: "uint32", 299 | types.int16: "int16", 300 | types.uint16: "uint16", 301 | types.int8: "int8", 302 | types.uint8: "uint8", 303 | types.fp16: "fp16", 304 | types.fp32: "fp32", 305 | types.bool: "bool", 306 | }[type] 307 | 308 | 309 | def clamp_index(index, shape, size): 310 | """ 311 | Clamps start indices to ensure they are within bounds: [0, operand_dim - slice_size] 312 | This is required by the StableHLO specification 313 | """ 314 | max_start_indices = mb.sub(x=shape, y=size) 315 | index = mb.minimum(x=index, y=max_start_indices) 316 | index = mb.maximum(x=index, y=0) 317 | return index 318 | 319 | 320 | def range_along_dim(shape, axis, dtype): 321 | axis = len(shape) + axis if axis < 0 else axis 322 | vec_shape = [shape[dim] if dim == axis else 1 for dim in range(len(shape))] 323 | vec_reps = [1 if dim == axis else shape[dim] for dim in range(len(shape))] 324 | arange = mb.range_1d(start=dtype(0), end=dtype(shape[axis]), step=dtype(1)) 325 | return mb.tile(x=mb.reshape(x=arange, shape=vec_shape), reps=vec_reps) 326 | -------------------------------------------------------------------------------- /tests/pytorch/test_pytorch.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from contextlib import contextmanager 3 | 4 | import jax 5 | import numpy as np 6 | from jax._src.lib.mlir import ir 7 | from jax._src.interpreters import mlir as jax_mlir 8 | 9 | from tests.utils import run_and_compare_hlo_module, flatten 10 | 11 | torch = pytest.importorskip("torch") 12 | torchvision = pytest.importorskip("torchvision") 13 | tx = pytest.importorskip("torchax") 14 | tx_export = pytest.importorskip("torchax.export") 15 | 16 | 17 | def export_to_stablehlo_module(pytorch_model, inputs): 18 | pytorch_model.eval() 19 | weights, jax_func = tx.extract_jax(pytorch_model) 20 | 21 | def wrapped_weights_func(inputs): 22 | out = jax_func(weights, inputs) 23 | 24 | # This is slightly hacky, but sometimes the output is a dict-like object 25 | # which is not registered with jax for jitting. 26 | # We will try to convert it to a dict first. 27 | try: 28 | out_dict = dict(out) 29 | return {k: v for k, v in out_dict.items() if isinstance(v, jax.Array)} 30 | except (TypeError, ValueError): 31 | return out 32 | 33 | numpy_inputs = tuple([input.detach().numpy() for input in inputs]) 34 | 35 | # Export the JIT-ed function 36 | jax_exported = jax.export.export(jax.jit(wrapped_weights_func))(numpy_inputs) 37 | stablehlo = jax_exported.mlir_module() 38 | 39 | context = jax_mlir.make_ir_context() 40 | hlo_module = ir.Module.parse(stablehlo, context=context) 41 | 42 | # Use jaxpr to find which inputs are actually used 43 | # We analyze the un-jitted function to see actual usage inside the body 44 | jaxpr = jax.make_jaxpr(wrapped_weights_func)(numpy_inputs) 45 | filtered_inputs = _filter_unused_inputs(jaxpr, inputs) 46 | 47 | return hlo_module, filtered_inputs 48 | 49 | 50 | def _filter_unused_inputs(jaxpr, inputs): 51 | """ 52 | Filters inputs based on their usage in the jaxpr. 53 | JAX export drops unused arguments from the MLIR module, so we need to align our inputs. 54 | """ 55 | used_input_indices = [] 56 | for i, invar in enumerate(jaxpr.jaxpr.invars): 57 | # Check if invar is used in any equation 58 | is_used = False 59 | for eqn in jaxpr.jaxpr.eqns: 60 | if invar in eqn.invars: 61 | is_used = True 62 | break 63 | 64 | # Check if invar is used as an output 65 | if not is_used: 66 | for outvar in jaxpr.jaxpr.outvars: 67 | if invar == outvar: 68 | is_used = True 69 | break 70 | 71 | if is_used: 72 | used_input_indices.append(i) 73 | 74 | return [inputs[i].detach().numpy() for i in used_input_indices] 75 | 76 | 77 | def evaluate_pytorch_model(model, inputs): 78 | hlo_module, module_inputs = export_to_stablehlo_module(model, inputs) 79 | 80 | model_outputs = model(*inputs) 81 | if isinstance(model_outputs, torch.Tensor): 82 | expected_outputs = [model_outputs.detach().numpy()] 83 | else: 84 | # Sort keys to match JAX's behavior (JAX sorts dict keys) 85 | keys = sorted(model_outputs.keys()) 86 | expected_outputs = [model_outputs[k] for k in keys] 87 | 88 | expected_outputs = [ 89 | output_tensor.detach().numpy() for output_tensor 90 | in flatten(expected_outputs) 91 | if isinstance(output_tensor, torch.Tensor) 92 | ] 93 | 94 | # Sanity check expected outputs to catch uninitialized weights issues 95 | for i, out in enumerate(expected_outputs): 96 | abs_max = np.abs(out).max() 97 | if abs_max > 1e9: 98 | raise ValueError(f"Output {i} has insanely large values (max: {abs_max:.2e}). " 99 | "This likely means the model has uninitialized weights (batch norm explosion)") 100 | 101 | output_range = out.max() - out.min() 102 | if output_range < 1e-5 and out.size > 1: 103 | raise ValueError(f"Output {i} has effectively zero range (range: {output_range:.2e}). " 104 | "This likely means the model has uninitialized weights") 105 | 106 | # These models are quite big, so tolerances are relaxed 107 | run_and_compare_hlo_module(hlo_module, module_inputs, expected_outputs, max_complexity=50_000, atol=5e-01, rtol=5e-02) 108 | 109 | 110 | @contextmanager 111 | def patch_transformers_compiling(): 112 | # Currently the transformers package is not aware of torchax static compilation / tracing. 113 | # This causes the jax-export to fail: https://github.com/google/torchax/issues/56 114 | # For now, we patch the transformers package to indicate that we are compiling. 115 | from unittest.mock import patch 116 | patches = [] 117 | targets = [ 118 | "transformers.modeling_attn_mask_utils.is_torchdynamo_compiling", 119 | "transformers.utils.is_torchdynamo_compiling", 120 | "transformers.modeling_utils.is_torchdynamo_compiling", 121 | ] 122 | 123 | for target in targets: 124 | try: 125 | # Check if module exists before patching 126 | module_name = target.rsplit(".", 1)[0] 127 | __import__(module_name) 128 | p = patch(target, return_value=True) 129 | p.start() 130 | patches.append(p) 131 | except (ImportError, AttributeError): 132 | pass 133 | 134 | try: 135 | yield 136 | finally: 137 | for p in patches: 138 | p.stop() 139 | 140 | 141 | # ============================================================================== 142 | # LLM / NLP Models 143 | # ============================================================================== 144 | 145 | def test_tinyllama(): 146 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig 147 | 148 | model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" 149 | tokenizer = AutoTokenizer.from_pretrained(model_name) 150 | config = AutoConfig.from_pretrained(model_name) 151 | 152 | # Use a much smaller config to avoid OOM in CI 153 | config.num_hidden_layers = 2 154 | config.hidden_size = 128 155 | config.intermediate_size = 512 156 | config.num_attention_heads = 4 157 | config.num_key_value_heads = 4 158 | config.use_cache = False 159 | config.torch_dtype = "float16" 160 | 161 | model = AutoModelForCausalLM.from_config(config) 162 | 163 | prompt = "Hello, my name is" 164 | inputs = tokenizer(prompt, return_tensors="pt") 165 | 166 | evaluate_pytorch_model(model, (inputs.input_ids, )) 167 | 168 | 169 | def test_t5_small(): 170 | from transformers import AutoTokenizer, T5Model, AutoConfig 171 | 172 | # Use AutoTokenizer which might fallback to fast tokenizer (no sentencepiece needed if available) 173 | tokenizer = AutoTokenizer.from_pretrained("t5-small") 174 | config = AutoConfig.from_pretrained("t5-small") 175 | config.num_layers = 2 176 | config.num_decoder_layers = 2 177 | config.d_model = 128 178 | config.d_kv = 32 179 | config.d_ff = 512 180 | config.num_heads = 4 181 | config.use_cache = False 182 | model = T5Model(config) 183 | 184 | input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids 185 | decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids 186 | attention_mask = torch.ones_like(input_ids) 187 | 188 | # T5Model forward: (input_ids, attention_mask, decoder_input_ids, ...) 189 | with patch_transformers_compiling(): 190 | evaluate_pytorch_model(model, (input_ids, attention_mask, decoder_input_ids)) 191 | 192 | 193 | def test_distilbert(): 194 | from transformers import AutoModel, AutoTokenizer, AutoConfig 195 | 196 | model_name = "distilbert-base-uncased" 197 | tokenizer = AutoTokenizer.from_pretrained(model_name) 198 | config = AutoConfig.from_pretrained(model_name) 199 | config.n_layers = 2 200 | config.dim = 128 201 | config.hidden_dim = 512 202 | config.n_heads = 4 203 | model = AutoModel.from_config(config) 204 | 205 | inputs = tokenizer("this is a test of distilbert", return_tensors="pt") 206 | with patch_transformers_compiling(): 207 | evaluate_pytorch_model(model, (inputs.input_ids, inputs.attention_mask)) 208 | 209 | 210 | def test_gpt2(): 211 | from transformers import AutoModel, AutoTokenizer, AutoConfig 212 | 213 | model_name = "gpt2" 214 | tokenizer = AutoTokenizer.from_pretrained(model_name) 215 | config = AutoConfig.from_pretrained(model_name) 216 | config.n_layer = 2 217 | config.n_embd = 128 218 | config.n_head = 4 219 | config.use_cache = False 220 | model = AutoModel.from_config(config) 221 | 222 | input_ids = tokenizer("this is a test of gpt2", return_tensors="pt").input_ids 223 | evaluate_pytorch_model(model, (input_ids, )) 224 | 225 | 226 | def test_bert(): 227 | from transformers import AutoModel, AutoTokenizer, AutoConfig 228 | 229 | model_name = "bert-base-uncased" 230 | tokenizer = AutoTokenizer.from_pretrained(model_name) 231 | config = AutoConfig.from_pretrained(model_name) 232 | config.num_hidden_layers = 2 233 | config.hidden_size = 128 234 | config.intermediate_size = 512 235 | config.num_attention_heads = 4 236 | model = AutoModel.from_config(config) 237 | 238 | inputs = tokenizer("this is a test of bert", return_tensors="pt") 239 | with patch_transformers_compiling(): 240 | evaluate_pytorch_model(model, (inputs.input_ids, inputs.attention_mask)) 241 | 242 | 243 | # ============================================================================== 244 | # Audio Models 245 | # ============================================================================== 246 | 247 | def test_whisper_tiny(): 248 | from transformers import AutoModelForSpeechSeq2Seq, AutoConfig 249 | import torch 250 | 251 | model_name = "openai/whisper-tiny" 252 | config = AutoConfig.from_pretrained(model_name) 253 | config.encoder_layers = 2 254 | config.decoder_layers = 2 255 | config.d_model = 128 256 | config.encoder_attention_heads = 4 257 | config.decoder_attention_heads = 4 258 | config.use_cache = False 259 | model = AutoModelForSpeechSeq2Seq.from_config(config) 260 | 261 | # Workaround for torchax issue with tied weights 262 | for module in model.modules(): 263 | if hasattr(module, 'weight') and isinstance(module.weight, torch.nn.Parameter): 264 | module.weight = torch.nn.Parameter(module.weight.clone()) 265 | 266 | # Generate dummy audio input 267 | # Whisper expects input_features of shape (batch, feature_size, sequence_length) 268 | # feature_size=80, sequence_length=3000 (for 30s audio at 100Hz frame rate roughly) 269 | input_features = torch.randn(1, 80, 3000) 270 | decoder_input_ids = torch.tensor([[50258]]) 271 | attention_mask = torch.ones((1, 3000)) 272 | 273 | # Whisper forward: (input_features, attention_mask, decoder_input_ids, ...) 274 | with patch_transformers_compiling(): 275 | evaluate_pytorch_model(model, (input_features, attention_mask, decoder_input_ids)) 276 | 277 | 278 | # ============================================================================== 279 | # Vision Models 280 | # ============================================================================== 281 | 282 | def test_convnext_tiny(): 283 | inputs = (torch.randn(1, 3, 224, 224), ) 284 | model = torchvision.models.convnext_tiny(weights="DEFAULT") 285 | evaluate_pytorch_model(model, inputs) 286 | 287 | 288 | def test_vit_b_16(): 289 | inputs = (torch.randn(1, 3, 224, 224), ) 290 | model = torchvision.models.vit_b_16(weights="DEFAULT") 291 | evaluate_pytorch_model(model, inputs) 292 | 293 | 294 | def test_efficientnet_b0(): 295 | inputs = (torch.randn(1, 3, 224, 224), ) 296 | model = torchvision.models.efficientnet_b0(weights="DEFAULT") 297 | evaluate_pytorch_model(model, inputs) 298 | 299 | 300 | def test_mobilenet_v3_small(): 301 | inputs = (torch.randn(1, 3, 224, 224), ) 302 | model = torchvision.models.mobilenet_v3_small(weights="DEFAULT") 303 | evaluate_pytorch_model(model, inputs) 304 | 305 | 306 | def test_densenet121(): 307 | inputs = (torch.randn(1, 3, 224, 224), ) 308 | model = torchvision.models.densenet121(weights="DEFAULT") 309 | evaluate_pytorch_model(model, inputs) 310 | 311 | 312 | def test_resnet50(): 313 | inputs = (torch.randn(4, 3, 224, 224), ) 314 | model = torchvision.models.resnet50() 315 | evaluate_pytorch_model(model, inputs) 316 | 317 | 318 | def test_resnet18(): 319 | inputs = (torch.randn(1, 3, 224, 224), ) 320 | model = torchvision.models.resnet18(weights="DEFAULT") 321 | evaluate_pytorch_model(model, inputs) 322 | 323 | 324 | def test_inception_v3(): 325 | inputs = (torch.randn(2, 3, 299, 299), ) 326 | model = torchvision.models.inception_v3(weights="DEFAULT") 327 | evaluate_pytorch_model(model, inputs) 328 | 329 | 330 | def test_vgg11(): 331 | inputs = (torch.randn(1, 3, 224, 224), ) 332 | model = torchvision.models.vgg11(weights="DEFAULT") 333 | evaluate_pytorch_model(model, inputs) 334 | -------------------------------------------------------------------------------- /tests/test_flax.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from flax import nnx 3 | import jax.numpy as jnp 4 | 5 | from tests.utils import run_and_compare, run_and_compare_specific_input 6 | 7 | from tests.flax_blocks import ResidualConv, Encoder, UNet 8 | 9 | from functools import partial 10 | 11 | 12 | def test_flax_nnx_linear(): 13 | class TestLinear(nnx.Module): 14 | def __init__(self, rngs=nnx.Rngs): 15 | self.layer = nnx.Linear(in_features=2, out_features=4, rngs=rngs) 16 | 17 | def __call__(self, x): 18 | return self.layer(x) 19 | 20 | model = TestLinear(nnx.Rngs(0)) 21 | run_and_compare(nnx.jit(model), (jnp.zeros((4, 2)), )) 22 | 23 | 24 | def test_flax_stacked_linear(): 25 | class TestStackedLinear(nnx.Module): 26 | def __init__(self, rngs=nnx.Rngs): 27 | self.upscale_layer = nnx.Linear(in_features=2, out_features=4, bias_init=nnx.initializers.ones, rngs=rngs) 28 | 29 | self.hidden_layers = nnx.List() 30 | for _ in range(3): # 3 hidden layers 31 | self.hidden_layers.append(nnx.Linear( 32 | in_features=4, 33 | out_features=4, 34 | bias_init=nnx.initializers.ones, 35 | rngs=rngs 36 | )) 37 | self.downscale_layer = nnx.Linear( 38 | in_features=4, 39 | out_features=2, 40 | bias_init=nnx.initializers.ones, 41 | rngs=rngs 42 | ) 43 | 44 | def __call__(self, x): 45 | out = self.upscale_layer(x) 46 | for layer in self.hidden_layers: 47 | out = layer(out) 48 | out = self.downscale_layer(out) 49 | return out 50 | 51 | model = TestStackedLinear(nnx.Rngs(0)) 52 | run_and_compare(nnx.jit(model), (jnp.zeros((2, 2)), )) 53 | 54 | 55 | def test_flax_stacked_lax_scan(): 56 | class TestStackedLaxScanLinear(nnx.Module): 57 | def __init__(self, rngs=nnx.Rngs): 58 | @nnx.split_rngs(splits=3) # 3 hidden layers 59 | @nnx.vmap(axis_size=3) 60 | def create_hidden_layers(rngs: nnx.Rngs): 61 | return nnx.Linear(in_features=4, out_features=4, bias_init=nnx.initializers.ones, rngs=rngs) 62 | self.hidden_layers = create_hidden_layers(rngs) 63 | 64 | self.upscale_layer = nnx.Linear(in_features=2, out_features=4, bias_init=nnx.initializers.ones, rngs=rngs) 65 | self.downscale_layer = nnx.Linear(in_features=4, out_features=2, bias_init=nnx.initializers.ones, rngs=rngs) 66 | 67 | def __call__(self, x): 68 | out = self.upscale_layer(x) 69 | 70 | layer_def, layer_states = nnx.split(self.hidden_layers) 71 | 72 | def forward(x, layer_state): 73 | layer = nnx.merge(layer_def, layer_state) 74 | x = layer(x) 75 | return x, None 76 | out, _ = jax.lax.scan(forward, out, layer_states) 77 | 78 | out = self.downscale_layer(out) 79 | return out 80 | 81 | model = TestStackedLaxScanLinear(nnx.Rngs(0)) 82 | run_and_compare(nnx.jit(model), (jnp.zeros((2, 2)), )) 83 | 84 | 85 | def test_flax_convolution(): 86 | class TestConvolution(nnx.Module): 87 | def __init__(self, rngs=nnx.Rngs): 88 | self.conv = nnx.Conv(in_features=2, out_features=1, kernel_size=3, rngs=rngs) 89 | 90 | def __call__(self, x): 91 | return self.conv(x) 92 | 93 | model = TestConvolution(nnx.Rngs(0)) 94 | run_and_compare(nnx.jit(model), (jnp.zeros((2, 8, 2)), )) 95 | 96 | 97 | def test_flax_grouped_convolution(): 98 | class TestConvolution(nnx.Module): 99 | def __init__(self, in_features: int, feature_groups: int, rngs=nnx.Rngs): 100 | self.conv = nnx.Conv( 101 | in_features=in_features, 102 | out_features=2 * feature_groups, 103 | kernel_size=3, 104 | feature_group_count=feature_groups, 105 | rngs=rngs 106 | ) 107 | 108 | def __call__(self, x): 109 | return self.conv(x) 110 | 111 | run_and_compare(nnx.jit(TestConvolution(4, 2, nnx.Rngs(0))), (jnp.zeros((2, 8, 4)), )) 112 | run_and_compare(nnx.jit(TestConvolution(9, 3, nnx.Rngs(0))), (jnp.zeros((2, 8, 9)), )) 113 | 114 | 115 | def test_flax_2d_convolution(): 116 | class TestConvolution(nnx.Module): 117 | def __init__(self, rngs=nnx.Rngs): 118 | self.conv = nnx.Conv(in_features=3, out_features=1, kernel_size=(3, 3), rngs=rngs) 119 | 120 | def __call__(self, x): 121 | return self.conv(x) 122 | 123 | model = TestConvolution(nnx.Rngs(0)) 124 | run_and_compare(nnx.jit(model), (jnp.zeros((2, 8, 8, 3)), )) 125 | 126 | 127 | def test_flax_3d_convolution(): 128 | class TestConvolution(nnx.Module): 129 | def __init__(self, rngs=nnx.Rngs): 130 | self.conv = nnx.Conv(in_features=3, out_features=1, kernel_size=(3, 3, 3), rngs=rngs) 131 | 132 | def __call__(self, x): 133 | return self.conv(x) 134 | 135 | model = TestConvolution(nnx.Rngs(0)) 136 | run_and_compare(nnx.jit(model), (jnp.zeros((2, 8, 8, 8, 3)), )) 137 | 138 | 139 | def test_flax_stacked_convolution(): 140 | class TestStackedConvolution(nnx.Module): 141 | def __init__(self, rngs=nnx.Rngs): 142 | @nnx.split_rngs(splits=3) # 3 hidden layers 143 | @nnx.vmap(axis_size=3) 144 | def create_convs(rngs: nnx.Rngs): 145 | return nnx.Conv(in_features=2, out_features=2, kernel_size=3, rngs=rngs) 146 | self.conv_layers = create_convs(rngs) 147 | 148 | def __call__(self, x): 149 | layer_def, layer_states = nnx.split(self.conv_layers) 150 | 151 | def forward(x, layer_state): 152 | layer = nnx.merge(layer_def, layer_state) 153 | x = layer(x) 154 | x = nnx.relu(x) 155 | return x, None 156 | out, _ = jax.lax.scan(forward, x, layer_states) 157 | return out 158 | 159 | model = TestStackedConvolution(nnx.Rngs(0)) 160 | run_and_compare(nnx.jit(model), (jnp.zeros((3, 8, 2)), )) 161 | 162 | 163 | def test_flax_transposed_convolution(): 164 | class TestTransposedConvolution(nnx.Module): 165 | def __init__(self, rngs=nnx.Rngs): 166 | self.conv = nnx.Conv(in_features=2, out_features=3, kernel_size=4, rngs=rngs) 167 | self.conv_transpose = nnx.ConvTranspose(in_features=3, out_features=2, kernel_size=3, rngs=rngs) 168 | 169 | def __call__(self, x): 170 | x = self.conv(x) 171 | x = self.conv_transpose(x) 172 | return x 173 | 174 | model = TestTransposedConvolution(nnx.Rngs(0)) 175 | run_and_compare(nnx.jit(model), (jnp.zeros((4, 8, 2)), )) 176 | 177 | 178 | def test_flax_transposed_2d_convolution(): 179 | class TestTransposedConvolution(nnx.Module): 180 | def __init__(self, rngs=nnx.Rngs): 181 | self.conv = nnx.Conv(in_features=2, out_features=3, kernel_size=(4, 4), rngs=rngs) 182 | self.conv_transpose = nnx.ConvTranspose(in_features=3, out_features=2, kernel_size=(3, 4), rngs=rngs) 183 | 184 | def __call__(self, x): 185 | x = self.conv(x) 186 | x = self.conv_transpose(x) 187 | return x 188 | 189 | model = TestTransposedConvolution(nnx.Rngs(0)) 190 | run_and_compare(nnx.jit(model), (jnp.zeros((4, 10, 8, 2)), )) 191 | 192 | 193 | def test_flax_transposed_3d_convolution(): 194 | class TestTransposedConvolution(nnx.Module): 195 | def __init__(self, rngs=nnx.Rngs): 196 | self.conv = nnx.Conv(in_features=2, out_features=3, kernel_size=(4, 4, 4), rngs=rngs) 197 | self.conv_transpose = nnx.ConvTranspose(in_features=3, out_features=2, kernel_size=(3, 4, 2), rngs=rngs) 198 | 199 | def __call__(self, x): 200 | x = self.conv(x) 201 | x = self.conv_transpose(x) 202 | return x 203 | 204 | model = TestTransposedConvolution(nnx.Rngs(0)) 205 | run_and_compare(nnx.jit(model), (jnp.zeros((4, 10, 8, 7, 2)), )) 206 | 207 | 208 | def test_kernel_dilated_conv(): 209 | class DilatedConvolution(nnx.Module): 210 | def __init__(self, rngs=nnx.Rngs): 211 | self.conv = nnx.Conv(in_features=4, out_features=2, kernel_size=4, kernel_dilation=2, rngs=rngs) 212 | 213 | def __call__(self, x): 214 | return self.conv(x) 215 | 216 | model = DilatedConvolution(nnx.Rngs(0)) 217 | run_and_compare(nnx.jit(model), (jnp.zeros((4, 4, 4)), )) 218 | 219 | 220 | def test_strided_conv_transpose(): 221 | class StridedConvTranspose(nnx.Module): 222 | def __init__(self, rngs=nnx.Rngs): 223 | self.conv = nnx.ConvTranspose(in_features=4, out_features=2, kernel_size=3, strides=2, rngs=rngs) 224 | 225 | def __call__(self, x): 226 | return self.conv(x) 227 | 228 | model = StridedConvTranspose(nnx.Rngs(0)) 229 | run_and_compare(nnx.jit(model), (jnp.zeros((4, 4, 4)), )) 230 | 231 | 232 | def test_convolution_ranges(): 233 | class ConvModel(nnx.Module): 234 | def __init__( 235 | self, 236 | conv_type, 237 | in_features: int, 238 | out_features: int, 239 | kernel_size: int, 240 | strides: int, 241 | dilation: int, 242 | rngs=nnx.Rngs 243 | ): 244 | self.conv = conv_type( 245 | in_features=in_features, 246 | out_features=out_features, 247 | kernel_size=kernel_size, 248 | strides=strides, 249 | kernel_dilation=dilation, 250 | rngs=rngs 251 | ) 252 | 253 | def __call__(self, x): 254 | return self.conv(x) 255 | 256 | for conv_type in [nnx.Conv, nnx.ConvTranspose]: 257 | for in_features in [1, 3]: 258 | for out_features in [1, 3]: 259 | for kernel_size in [2, 3]: 260 | for strides in [2, 3]: 261 | for dilation in [2, 3]: 262 | model = ConvModel( 263 | conv_type=conv_type, 264 | in_features=in_features, 265 | out_features=out_features, 266 | kernel_size=kernel_size, 267 | strides=strides, 268 | dilation=dilation, 269 | rngs=nnx.Rngs(0) 270 | ) 271 | run_and_compare(nnx.jit(model), (jnp.zeros((2, 8, in_features)), )) 272 | 273 | 274 | def test_flax_residual_conv_module(): 275 | model_upscale = ResidualConv(in_channels=2, out_channels=4, rngs=nnx.Rngs(0)) 276 | model_upscale.eval() 277 | run_and_compare(nnx.jit(model_upscale), (jnp.zeros((4, 8, 2)), )) 278 | 279 | model_downscale = ResidualConv(in_channels=4, out_channels=2, rngs=nnx.Rngs(0)) 280 | model_downscale.eval() 281 | run_and_compare(nnx.jit(model_downscale), (jnp.zeros((4, 4, 4)), )) 282 | 283 | 284 | def test_encoder(): 285 | model = Encoder(num_layers=3, rngs=nnx.Rngs(0)) 286 | model.eval() 287 | run_and_compare(nnx.jit(model), (jnp.zeros((4, 8, 1)), )) 288 | 289 | 290 | def test_unet(): 291 | model = UNet(num_layers=2, rngs=nnx.Rngs(0)) 292 | model.eval() 293 | run_and_compare(nnx.jit(model), (jnp.zeros((4, 8, 1)), )) 294 | 295 | 296 | def test_activations(): 297 | example_input = (jnp.zeros((20,)),) 298 | run_and_compare(nnx.celu, example_input) 299 | run_and_compare(nnx.elu, example_input) 300 | run_and_compare(nnx.gelu, example_input) 301 | run_and_compare(nnx.glu, example_input) 302 | run_and_compare(nnx.hard_sigmoid, example_input) 303 | run_and_compare(nnx.hard_silu, example_input) 304 | run_and_compare(nnx.hard_swish, example_input) 305 | run_and_compare(nnx.hard_tanh, example_input) 306 | run_and_compare(nnx.leaky_relu, example_input) 307 | run_and_compare(nnx.log_sigmoid, example_input) 308 | run_and_compare(nnx.log_softmax, example_input) 309 | run_and_compare(nnx.logsumexp, example_input) 310 | run_and_compare(nnx.relu, example_input) 311 | run_and_compare(nnx.selu, example_input) 312 | run_and_compare(nnx.sigmoid, example_input) 313 | run_and_compare(nnx.silu, example_input) 314 | run_and_compare(nnx.soft_sign, example_input) 315 | run_and_compare(nnx.softmax, example_input) 316 | run_and_compare(nnx.softplus, example_input) 317 | run_and_compare(nnx.standardize, example_input) 318 | run_and_compare(nnx.swish, example_input) 319 | run_and_compare(nnx.tanh, example_input) 320 | 321 | run_and_compare_specific_input(partial(nnx.one_hot, num_classes=3), (jnp.array([0, 1, 2]), )) 322 | run_and_compare_specific_input(partial(nnx.one_hot, num_classes=5), (jnp.array([4, 0, 1, 0]), )) 323 | 324 | 325 | def test_attantion(): 326 | class TestAttention(nnx.Module): 327 | def __init__(self, rngs=nnx.Rngs): 328 | self.layer = nnx.MultiHeadAttention( 329 | num_heads=4, 330 | in_features=5, 331 | qkv_features=16, 332 | decode=False, 333 | rngs=rngs, 334 | ) 335 | 336 | def __call__(self, q, k, v): 337 | return self.layer(q, k, v) 338 | 339 | shape = (4, 3, 2, 5) 340 | input_spec = (jnp.zeros(shape), jnp.zeros(shape), jnp.zeros(shape)) 341 | run_and_compare(nnx.jit(TestAttention(nnx.Rngs(0))), input_spec) 342 | 343 | @nnx.jit 344 | def create_masks(length): 345 | attention_mask = nnx.make_attention_mask(length, length) 346 | causal_mask = nnx.make_causal_mask(length) 347 | return nnx.combine_masks(attention_mask, causal_mask) 348 | 349 | run_and_compare(create_masks, (jnp.zeros((5, 20)), )) 350 | 351 | 352 | def test_embed(): 353 | model = nnx.Embed(num_embeddings=10, features=5, rngs=nnx.Rngs(0)) 354 | example_input = (jnp.array([[1, 5, 3], [9, 3, 0]], dtype=jnp.int32), ) 355 | run_and_compare_specific_input(nnx.jit(model), example_input) 356 | 357 | 358 | def test_nnx_einsum(): 359 | layer = nnx.Einsum('nta,hab->nthb', (8, 2, 4), (8, 4), rngs=nnx.Rngs(0)) 360 | example_input = (jnp.zeros((16, 11, 2)), ) 361 | run_and_compare(nnx.jit(layer), example_input) 362 | 363 | 364 | def test_batch_norm_infer(): 365 | layer = nnx.BatchNorm(num_features=10, momentum=0.9, epsilon=1e-5, rngs=nnx.Rngs(0)) 366 | layer.eval() 367 | example_input = (jnp.zeros((20, 10)), ) 368 | run_and_compare(nnx.jit(layer), example_input) 369 | 370 | 371 | def test_layer_norm_infer(): 372 | layer = nnx.LayerNorm(num_features=10, rngs=nnx.Rngs(0)) 373 | layer.eval() 374 | example_input = (jnp.zeros((20, 10)), ) 375 | run_and_compare(nnx.jit(layer), example_input) 376 | 377 | 378 | def test_rms_norm_infer(): 379 | layer = nnx.RMSNorm(num_features=10, rngs=nnx.Rngs(0)) 380 | layer.eval() 381 | example_input = (jnp.zeros((20, 10)), ) 382 | run_and_compare(nnx.jit(layer), example_input) 383 | 384 | 385 | def test_group_norm_infer(): 386 | layer = nnx.GroupNorm(num_features=10, num_groups=2, rngs=nnx.Rngs(0)) 387 | layer.eval() 388 | example_input = (jnp.zeros((20, 10)), ) 389 | run_and_compare(nnx.jit(layer), example_input) 390 | -------------------------------------------------------------------------------- /tests/test_equinox.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import equinox as eqx 4 | import equinox.internal as eqxi 5 | 6 | from tests.utils import run_and_compare, run_and_compare_specific_input 7 | 8 | from functools import partial 9 | 10 | 11 | def run_and_compare_eqx_specific_input(model, inputs): 12 | return run_and_compare_specific_input( 13 | eqxi.finalise_fn(eqx.nn.inference_mode(model)), 14 | inputs 15 | ) 16 | 17 | 18 | def run_and_compare_eqx(model, input_spec): 19 | return run_and_compare( 20 | eqxi.finalise_fn(eqx.nn.inference_mode(model)), 21 | input_spec 22 | ) 23 | 24 | 25 | def test_conv_1d(): 26 | class SimpleConv(eqx.Module): 27 | conv: eqx.nn.Conv1d 28 | 29 | def __init__(self, key): 30 | self.conv = eqx.nn.Conv1d(in_channels=2, out_channels=5, kernel_size=3, key=key) 31 | 32 | def __call__(self, x): 33 | return self.conv(x) 34 | 35 | input_spec = (jnp.zeros((2, 30)), ) 36 | model = SimpleConv(jax.random.PRNGKey(0)) 37 | run_and_compare_eqx(model, input_spec) 38 | 39 | 40 | def test_conv_transpose_1d(): 41 | class SimpleConv(eqx.Module): 42 | conv: eqx.nn.ConvTranspose1d 43 | 44 | def __init__(self, key): 45 | self.conv = eqx.nn.ConvTranspose1d(in_channels=5, out_channels=2, kernel_size=3, key=key) 46 | 47 | def __call__(self, x): 48 | return self.conv(x) 49 | 50 | input_spec = (jnp.zeros((5, 30)), ) 51 | model = SimpleConv(jax.random.PRNGKey(0)) 52 | run_and_compare_eqx(model, input_spec) 53 | 54 | 55 | def test_conv_2d(): 56 | class SimpleConv(eqx.Module): 57 | conv: eqx.nn.Conv2d 58 | 59 | def __init__(self, key): 60 | self.conv = eqx.nn.Conv2d(in_channels=2, out_channels=5, kernel_size=(3, 4), key=key) 61 | 62 | def __call__(self, x): 63 | return self.conv(x) 64 | 65 | input_spec = (jnp.zeros((2, 40, 30)), ) 66 | model = SimpleConv(jax.random.PRNGKey(0)) 67 | run_and_compare_eqx(model, input_spec) 68 | 69 | 70 | def test_conv_transpose_2d(): 71 | class SimpleConv(eqx.Module): 72 | conv: eqx.nn.ConvTranspose2d 73 | 74 | def __init__(self, key): 75 | self.conv = eqx.nn.ConvTranspose2d(in_channels=5, out_channels=2, kernel_size=(3, 4), key=key) 76 | 77 | def __call__(self, x): 78 | return self.conv(x) 79 | 80 | input_spec = (jnp.zeros((5, 40, 30)), ) 81 | model = SimpleConv(jax.random.PRNGKey(0)) 82 | run_and_compare_eqx(model, input_spec) 83 | 84 | 85 | def test_conv_3d(): 86 | class SimpleConv(eqx.Module): 87 | conv: eqx.nn.Conv3d 88 | 89 | def __init__(self, key): 90 | self.conv = eqx.nn.Conv3d(in_channels=2, out_channels=5, kernel_size=(3, 4, 2), key=key) 91 | 92 | def __call__(self, x): 93 | return self.conv(x) 94 | 95 | input_spec = (jnp.zeros((2, 40, 30, 15)), ) 96 | model = SimpleConv(jax.random.PRNGKey(0)) 97 | run_and_compare_eqx(model, input_spec) 98 | 99 | 100 | def test_conv_transpose_3d(): 101 | class SimpleConv(eqx.Module): 102 | conv: eqx.nn.ConvTranspose3d 103 | 104 | def __init__(self, key): 105 | self.conv = eqx.nn.ConvTranspose3d(in_channels=5, out_channels=2, kernel_size=(3, 4, 2), key=key) 106 | 107 | def __call__(self, x): 108 | return self.conv(x) 109 | 110 | input_spec = (jnp.zeros((5, 40, 30, 15)), ) 111 | model = SimpleConv(jax.random.PRNGKey(0)) 112 | run_and_compare_eqx(model, input_spec) 113 | 114 | 115 | def test_odd_batch_dimension(): 116 | class SimpleConv(eqx.Module): 117 | conv: eqx.nn.Conv1d 118 | 119 | def __init__(self, key): 120 | self.conv = eqx.nn.Conv1d(in_channels=2, out_channels=5, kernel_size=3, key=key) 121 | 122 | def __call__(self, x): 123 | return self.conv(x) 124 | 125 | model = SimpleConv(jax.random.PRNGKey(0)) 126 | batched_model = jax.vmap(model, axis_name="batch", in_axes=2, out_axes=2) 127 | 128 | input_spec = (jnp.zeros((2, 30, 5)), ) 129 | run_and_compare_eqx(batched_model, input_spec) 130 | 131 | 132 | def test_linear(): 133 | model = jax.vmap(eqx.nn.Linear(in_features=10, out_features=20, key=jax.random.PRNGKey(0))) 134 | input_spec = (jnp.zeros((20, 10)), ) 135 | run_and_compare_eqx(model, input_spec) 136 | 137 | 138 | def test_identity(): 139 | model = jax.vmap(eqx.nn.Identity(key=jax.random.PRNGKey(0))) 140 | input_spec = (jnp.zeros((20, 10)), ) 141 | run_and_compare_eqx(model, input_spec) 142 | 143 | 144 | def test_gru_cell(): 145 | model = jax.vmap(eqx.nn.GRUCell(input_size=10, hidden_size=24, key=jax.random.PRNGKey(0))) 146 | input_spec = (jnp.zeros((20, 10)), jnp.zeros((20, 24))) 147 | run_and_compare_eqx(model, input_spec) 148 | 149 | 150 | def test_lstm_cell(): 151 | class Model(eqx.Module): 152 | cell: eqx.nn.LSTMCell 153 | 154 | def __init__(self): 155 | self.cell = eqx.nn.LSTMCell(input_size=10, hidden_size=24, key=jax.random.PRNGKey(0)) 156 | 157 | def __call__(self, xs): 158 | def scan_fn(state, input): 159 | return (self.cell(input, state), None) 160 | init_state = (jnp.zeros(self.cell.hidden_size), 161 | jnp.zeros(self.cell.hidden_size)) 162 | (h, c), _ = jax.lax.scan(scan_fn, init_state, xs) 163 | return h, c 164 | 165 | model = jax.vmap(Model()) 166 | input_spec = (jnp.zeros((20, 30, 10)), ) 167 | run_and_compare_eqx(model, input_spec) 168 | 169 | 170 | def test_rotary_attention(): 171 | class Model(eqx.Module): 172 | mha_attention: eqx.nn.MultiheadAttention 173 | rope_embeddings: eqx.nn.RotaryPositionalEmbedding 174 | 175 | def __init__(self, key: jax.random.PRNGKey): 176 | attention_key, rope_key = jax.random.split(key, 2) 177 | self.mha_attention = eqx.nn.MultiheadAttention( 178 | num_heads=4, 179 | query_size=24, 180 | key=attention_key, 181 | ) 182 | self.rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size=6) 183 | 184 | def __call__(self, q, k, v): 185 | def process_heads(key_heads, query_heads, value_heads): 186 | query_heads = jax.vmap(self.rope_embeddings, 187 | in_axes=1, 188 | out_axes=1)(query_heads) 189 | key_heads = jax.vmap(self.rope_embeddings, 190 | in_axes=1, 191 | out_axes=1)(key_heads) 192 | 193 | return query_heads, key_heads, value_heads 194 | 195 | x = self.mha_attention(q, k, v, process_heads=process_heads) 196 | return x 197 | 198 | model = jax.vmap(Model(jax.random.PRNGKey(0))) 199 | input_spec = (jnp.zeros((5, 15, 24)), jnp.zeros((5, 15, 24)), jnp.zeros((5, 15, 24))) 200 | run_and_compare_eqx(model, input_spec) 201 | 202 | 203 | def test_prelu(): 204 | model = jax.vmap(eqx.nn.PReLU()) 205 | input_spec = (jnp.zeros((5, 20)), ) 206 | run_and_compare_eqx(model, input_spec) 207 | 208 | 209 | def test_1d_polling(): 210 | channels = 3 211 | run_and_compare_eqx(eqx.nn.AvgPool1d(kernel_size=3), (jnp.zeros((channels, 41, )), )) 212 | run_and_compare_eqx(eqx.nn.AvgPool1d(kernel_size=3, stride=2), (jnp.zeros((channels, 41, )), )) 213 | run_and_compare_eqx(eqx.nn.AvgPool1d(kernel_size=3, stride=3), (jnp.zeros((channels, 41, )), )) 214 | 215 | run_and_compare_eqx(eqx.nn.AvgPool1d(kernel_size=3, padding=2), (jnp.zeros((channels, 41, )), )) 216 | run_and_compare_eqx(eqx.nn.AvgPool1d(kernel_size=3, stride=2, padding=3), (jnp.zeros((channels, 41, )), )) 217 | 218 | run_and_compare_eqx(eqx.nn.MaxPool1d(kernel_size=3), (jnp.zeros((channels, 41, )), )) 219 | run_and_compare_eqx(eqx.nn.MaxPool1d(kernel_size=3, stride=2), (jnp.zeros((channels, 41, )), )) 220 | run_and_compare_eqx(eqx.nn.MaxPool1d(kernel_size=3, stride=3), (jnp.zeros((channels, 41, )), )) 221 | run_and_compare_eqx(eqx.nn.MaxPool1d(kernel_size=3, stride=4), (jnp.zeros((channels, 41, )), )) 222 | 223 | run_and_compare_eqx(eqx.nn.MaxPool1d(kernel_size=4, padding=3), (jnp.zeros((channels, 41, )), )) 224 | run_and_compare_eqx(eqx.nn.MaxPool1d(kernel_size=3, stride=3, padding=2), (jnp.zeros((channels, 41, )), )) 225 | 226 | run_and_compare_eqx(eqx.nn.AdaptiveAvgPool1d(target_shape=4), (jnp.zeros((channels, 41, )), )) 227 | run_and_compare_eqx(eqx.nn.AdaptiveMaxPool1d(target_shape=5), (jnp.zeros((channels, 41, )), )) 228 | 229 | batch_size = 10 230 | run_and_compare_eqx(jax.vmap(eqx.nn.AvgPool1d(kernel_size=3)), (jnp.zeros((batch_size, channels, 41, )), )) 231 | run_and_compare_eqx(jax.vmap(eqx.nn.AvgPool1d(kernel_size=3, stride=2)), (jnp.zeros((batch_size, channels, 41, )), )) 232 | run_and_compare_eqx(jax.vmap(eqx.nn.AvgPool1d(kernel_size=3, stride=3)), (jnp.zeros((batch_size, channels, 41, )), )) 233 | 234 | 235 | def test_2d_polling(): 236 | channels = 3 237 | run_and_compare_eqx(eqx.nn.AvgPool2d(kernel_size=(3, 4)), (jnp.zeros((channels, 41, 21)), )) 238 | run_and_compare_eqx(eqx.nn.AvgPool2d(kernel_size=(3, 4), stride=(2, 3)), (jnp.zeros((channels, 41, 21)), )) 239 | run_and_compare_eqx(eqx.nn.AvgPool2d(kernel_size=(3, 4), stride=(3, 2)), (jnp.zeros((channels, 41, 21)), )) 240 | 241 | run_and_compare_eqx(eqx.nn.AvgPool2d(kernel_size=(3, 4), padding=(2, 4)), (jnp.zeros((channels, 41, 21)), )) 242 | run_and_compare_eqx(eqx.nn.AvgPool2d(kernel_size=(3, 4), padding=(3, 1), stride=(1, 2)), (jnp.zeros((channels, 41, 21)), )) 243 | run_and_compare_eqx(eqx.nn.AvgPool2d(kernel_size=(3, 4), padding=(1, 2), stride=(3, 2)), (jnp.zeros((channels, 41, 21)), )) 244 | 245 | run_and_compare_eqx(eqx.nn.AdaptiveAvgPool2d(target_shape=(4, 2)), (jnp.zeros((channels, 41, 21)), )) 246 | run_and_compare_eqx(eqx.nn.AdaptiveMaxPool2d(target_shape=(3, 5)), (jnp.zeros((channels, 41, 21)), )) 247 | 248 | batch_size = 10 249 | run_and_compare_eqx(jax.vmap(eqx.nn.AvgPool2d(kernel_size=(3, 4))), (jnp.zeros((batch_size, channels, 41, 21)), )) 250 | 251 | 252 | def test_3d_polling(): 253 | channels = 3 254 | run_and_compare_eqx(eqx.nn.AvgPool3d(kernel_size=(5, 4, 3)), (jnp.zeros((channels, 41, 21, 10)), )) 255 | run_and_compare_eqx(eqx.nn.MaxPool3d(kernel_size=(5, 4, 3)), (jnp.zeros((channels, 41, 21, 10)), )) 256 | 257 | # Due to the CoreML rank <= 5 condition, the result can unfortunately not fit in a tensor 258 | # run_and_compare_eqx(jax.vmap(eqx.nn.AvgPool3d(kernel_size=(5, 4, 3))), (jnp.zeros((10, channels, 41, 21, 10)), )) 259 | 260 | 261 | def test_layernorm(): 262 | batch_size = 3 263 | input_shape = (10, 3) 264 | run_and_compare_eqx(jax.vmap(eqx.nn.LayerNorm(shape=input_shape)), (jnp.zeros((batch_size, *input_shape)), )) 265 | 266 | 267 | def test_rmsnorm(): 268 | batch_size = 3 269 | input_shape = (10, 3) 270 | run_and_compare_eqx(jax.vmap(eqx.nn.RMSNorm(shape=input_shape)), (jnp.zeros((batch_size, *input_shape)), )) 271 | 272 | 273 | def test_groupnorm(): 274 | batch_size = 3 275 | input_shape = (4, 12) 276 | run_and_compare_eqx( 277 | jax.vmap(eqx.nn.GroupNorm(groups=4, channelwise_affine=False)), 278 | (jnp.zeros((batch_size, *input_shape)), ) 279 | ) 280 | run_and_compare_eqx( 281 | jax.vmap(eqx.nn.GroupNorm(groups=2, channels=4)), 282 | (jnp.zeros((batch_size, *input_shape)), ) 283 | ) 284 | 285 | 286 | def test_batchnorm(): 287 | batch_size = 3 288 | input_shape = (4, 12) 289 | 290 | class Model(eqx.Module): 291 | batch_norm: eqx.nn.BatchNorm 292 | 293 | def __init__(self, input_size: int, axis_name: str): 294 | self.batch_norm = eqx.nn.BatchNorm( 295 | input_size=input_size, 296 | axis_name=axis_name, 297 | ) 298 | 299 | def __call__(self, x, state): 300 | out, _state = self.batch_norm(x, state) 301 | return out 302 | 303 | model, state = eqx.nn.make_with_state(Model)(input_size=4, axis_name="batch") 304 | batched_model = jax.vmap(partial(model, state=state), axis_name="batch") 305 | run_and_compare_eqx(batched_model, (jnp.zeros((batch_size, *input_shape)), )) 306 | 307 | 308 | def test_spectralnorm(): 309 | batch_size = 5 310 | 311 | class Model(eqx.Module): 312 | spectral_norm: eqx.nn.SpectralNorm[eqx.Module] 313 | 314 | def __init__(self, wrapping_layer: eqx.Module, key: jax.random.PRNGKey): 315 | self.spectral_norm = eqx.nn.SpectralNorm( 316 | layer=wrapping_layer, 317 | weight_name="weight", 318 | key=key, 319 | ) 320 | 321 | def __call__(self, x, state): 322 | out, _state = self.spectral_norm(x, state) 323 | return out 324 | 325 | wrapping_key, model_key = jax.random.split(jax.random.PRNGKey(0), 2) 326 | 327 | # Linear wrapping layer 328 | model, state = eqx.nn.make_with_state(Model)( 329 | wrapping_layer=eqx.nn.Linear(in_features=12, out_features=24, key=wrapping_key), 330 | key=model_key, 331 | ) 332 | batched_model = jax.vmap(partial(model, state=state)) 333 | run_and_compare_eqx(batched_model, (jnp.zeros((batch_size, 12)), )) 334 | 335 | # Convolutional 1d wrapping layer 336 | model, state = eqx.nn.make_with_state(Model)( 337 | wrapping_layer=eqx.nn.Conv1d(in_channels=12, out_channels=24, kernel_size=3, key=wrapping_key), 338 | key=model_key, 339 | ) 340 | batched_model = jax.vmap(partial(model, state=state)) 341 | run_and_compare_eqx(batched_model, (jnp.zeros((batch_size, 12, 31)), )) 342 | 343 | # Convolutional 2d wrapping layer 344 | model, state = eqx.nn.make_with_state(Model)( 345 | wrapping_layer=eqx.nn.Conv2d(in_channels=12, out_channels=24, kernel_size=3, key=wrapping_key), 346 | key=model_key, 347 | ) 348 | batched_model = jax.vmap(partial(model, state=state)) 349 | run_and_compare_eqx(batched_model, (jnp.zeros((batch_size, 12, 31, 15)), )) 350 | 351 | 352 | def test_weightnorm(): 353 | batch_size = 5 354 | key = jax.random.PRNGKey(0) 355 | 356 | class Model(eqx.Module): 357 | weight_norm: eqx.nn.WeightNorm[eqx.Module] 358 | 359 | def __init__(self, wrapping_layer: eqx.Module): 360 | self.weight_norm = eqx.nn.WeightNorm( 361 | layer=wrapping_layer, 362 | weight_name="weight", 363 | ) 364 | 365 | def __call__(self, x): 366 | return self.weight_norm(x) 367 | 368 | # Linear wrapping layer 369 | model = jax.vmap(Model( 370 | wrapping_layer=eqx.nn.Linear(in_features=12, out_features=24, key=key), 371 | )) 372 | run_and_compare_eqx(model, (jnp.zeros((batch_size, 12)), )) 373 | 374 | # Convolutional 1d wrapping layer 375 | model = jax.vmap(Model( 376 | wrapping_layer=eqx.nn.Conv1d(in_channels=12, out_channels=24, kernel_size=3, key=key), 377 | )) 378 | run_and_compare_eqx(model, (jnp.zeros((batch_size, 12, 31)), )) 379 | 380 | # Convolutional 2d wrapping layer 381 | model = jax.vmap(Model( 382 | wrapping_layer=eqx.nn.Conv2d(in_channels=12, out_channels=24, kernel_size=3, key=key), 383 | )) 384 | run_and_compare_eqx(model, (jnp.zeros((batch_size, 12, 31, 15)), )) 385 | 386 | 387 | def test_embedding(): 388 | key = jax.random.PRNGKey(0) 389 | run_and_compare_eqx_specific_input( 390 | jax.vmap(eqx.nn.Embedding(num_embeddings=5, embedding_size=10, key=key)), 391 | (jnp.array([0, 1, 2, 3, 4], dtype=jnp.int32), ) 392 | ) 393 | 394 | 395 | def test_mlp(): 396 | model = jax.vmap(eqx.nn.MLP( 397 | in_size=10, 398 | out_size=20, 399 | width_size=30, 400 | depth=3, 401 | key=jax.random.PRNGKey(0)) 402 | ) 403 | input_spec = (jnp.zeros((20, 10)), ) 404 | run_and_compare_eqx(model, input_spec) 405 | 406 | 407 | def test_sequential(): 408 | model = jax.vmap(eqx.nn.Sequential( 409 | [ 410 | eqx.nn.Linear(in_features=10, out_features=20, key=jax.random.PRNGKey(0)), 411 | eqx.nn.Lambda(jax.nn.relu), 412 | ] 413 | )) 414 | input_spec = (jnp.zeros((20, 10)), ) 415 | run_and_compare_eqx(model, input_spec) 416 | -------------------------------------------------------------------------------- /tests/test_jax.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from functools import partial 4 | import pytest 5 | 6 | from tests.utils import run_and_compare, run_and_compare_specific_input, get_model_instruction_types 7 | 8 | 9 | def test_addition(): 10 | run_and_compare(jnp.add, (jnp.float32(1), jnp.float32(1))) 11 | run_and_compare(jnp.add, (jnp.zeros((2, 2, 2)), jnp.zeros((2, 2, 2)))) 12 | 13 | 14 | def test_div(): 15 | run_and_compare(jnp.divide, (jnp.float32(1), jnp.float32(1))) 16 | 17 | dim_size = 20 18 | run_and_compare(jnp.divide, (jnp.zeros((dim_size, dim_size)), jnp.zeros((dim_size, dim_size)))) 19 | run_and_compare(jnp.divide, (jnp.zeros((dim_size, dim_size)), jnp.float32(1))) 20 | run_and_compare(jnp.divide, (jnp.float32(1), jnp.zeros((dim_size, dim_size)))) 21 | run_and_compare(jnp.divide, (jnp.zeros((dim_size, dim_size)), jnp.zeros((dim_size, dim_size)))) 22 | 23 | run_and_compare(jnp.divide, (jnp.int32(1), jnp.int32(1))) 24 | run_and_compare(jnp.divide, ( 25 | jnp.zeros((dim_size, dim_size), dtype=jnp.int32), 26 | jnp.zeros((dim_size, dim_size), dtype=jnp.int32) 27 | )) 28 | run_and_compare(jnp.divide, (jnp.zeros((dim_size, dim_size), dtype=jnp.int32), jnp.int32(1))) 29 | run_and_compare(jnp.divide, (jnp.int32(1), jnp.zeros((dim_size, dim_size), dtype=jnp.int32))) 30 | 31 | 32 | def test_tensor_multiplication(): 33 | def scalar_product(lhs, rhs): 34 | return jnp.einsum("a,a", lhs, rhs) 35 | 36 | def scalar_with_vector(lhs, rhs): 37 | return jnp.einsum("a,b->ab", lhs, rhs) 38 | 39 | def scalar_with_matrix(lhs, rhs): 40 | return jnp.einsum("a,bc->abc", lhs, rhs) 41 | 42 | def vector_with_matrix(lhs, rhs): 43 | return jnp.einsum("a,ab->b", lhs, rhs) 44 | 45 | def matrix_multiplication(lhs, rhs): 46 | return jnp.einsum("ij,jk -> ik", lhs, rhs) 47 | 48 | def outer_product_with_single_batch_dim(lhs, rhs): 49 | return jnp.einsum("abc,ajk->abcjk", lhs, rhs) 50 | 51 | def single_contraction_single_batch(lhs, rhs): 52 | return jnp.einsum("abcd,ackl->abdkl", lhs, rhs) 53 | 54 | def two_contractions_single_batch(lhs, rhs): 55 | return jnp.einsum("abcd,ackd->abk", lhs, rhs) 56 | 57 | def three_contractions_single_batch(lhs, rhs): 58 | return jnp.einsum("abcd,acbd->a", lhs, rhs) 59 | 60 | def contract_all(lhs, rhs): 61 | return jnp.einsum("abcd,acbd", lhs, rhs) 62 | 63 | def full_tensor_product(lhs, rhs): 64 | return jnp.einsum("ab,ihj->abihj", lhs, rhs) 65 | 66 | def full_tensor_product_1_4(lhs, rhs): 67 | return jnp.einsum("a,ihjk->aihjk", lhs, rhs) 68 | 69 | def full_tensor_product_3_2(lhs, rhs): 70 | return jnp.einsum("abc,ih->abcih", lhs, rhs) 71 | 72 | def full_tensor_product_4_1(lhs, rhs): 73 | return jnp.einsum("abcd,i->abcdi", lhs, rhs) 74 | 75 | run_and_compare(scalar_product, (jnp.zeros((1)), jnp.zeros((1)))) 76 | run_and_compare(scalar_with_vector, (jnp.zeros((1)), jnp.zeros((5)))) 77 | run_and_compare(scalar_with_matrix, (jnp.zeros((1)), jnp.zeros((5, 3)))) 78 | run_and_compare(vector_with_matrix, (jnp.zeros((5)), jnp.zeros((5, 3)))) 79 | run_and_compare(matrix_multiplication, (jnp.zeros((3, 4)), jnp.zeros((4, 5)))) 80 | run_and_compare(outer_product_with_single_batch_dim, (jnp.zeros((2, 3, 4)), jnp.zeros((2, 4, 5)))) 81 | run_and_compare(single_contraction_single_batch, (jnp.zeros((2, 3, 4, 5)), jnp.zeros((2, 4, 2, 5)))) 82 | run_and_compare(two_contractions_single_batch, (jnp.zeros((2, 3, 4, 5)), jnp.zeros((2, 4, 2, 5)))) 83 | run_and_compare(three_contractions_single_batch, (jnp.zeros((2, 3, 4, 5)), jnp.zeros((2, 4, 3, 5)))) 84 | run_and_compare(full_tensor_product, (jnp.zeros((2, 3)), jnp.zeros((2, 4, 3)))) 85 | run_and_compare(contract_all, (jnp.zeros((2, 3, 4, 5)), jnp.zeros((2, 4, 3, 5)))) 86 | 87 | # # Test the full tensor product with a big dimensions, and ensure that the program gets handled by a dynamic loop 88 | run_and_compare(full_tensor_product, (jnp.zeros((10, 3)), jnp.zeros((15, 20, 3)))) 89 | run_and_compare(full_tensor_product_1_4, (jnp.zeros((10,)), jnp.zeros((15, 20, 5, 3)))) 90 | run_and_compare(full_tensor_product_1_4, (jnp.zeros((2,)), jnp.zeros((2, 2, 2, 3)))) 91 | run_and_compare(full_tensor_product_3_2, (jnp.zeros((20, 10, 3)), jnp.zeros((15, 3)))) 92 | run_and_compare(full_tensor_product_3_2, (jnp.zeros((2, 2, 3)), jnp.zeros((2, 3)))) 93 | run_and_compare(full_tensor_product_4_1, (jnp.zeros(((15, 20, 5, 3))), jnp.zeros((10,)))) 94 | run_and_compare(full_tensor_product_4_1, (jnp.zeros(((2, 2, 2, 3))), jnp.zeros((2,)))) 95 | 96 | 97 | def test_simple_reductions(): 98 | def compare_and_ensure_no_loops(jax_func, input_spec): 99 | cml_model = run_and_compare(jax_func, input_spec) 100 | assert "while_loop" not in get_model_instruction_types(cml_model) 101 | 102 | compare_and_ensure_no_loops(partial(jnp.max, axis=1), (jnp.zeros((2, 3, 4)),)) 103 | compare_and_ensure_no_loops(partial(jnp.max, axis=1, keepdims=True), (jnp.zeros((2, 3, 4)),)) 104 | compare_and_ensure_no_loops(partial(jnp.sum, axis=0), (jnp.zeros((2, 3, 4)),)) 105 | compare_and_ensure_no_loops(partial(jnp.sum, axis=1), (jnp.zeros((2, 3, 4)),)) 106 | compare_and_ensure_no_loops(partial(jnp.sum, axis=2), (jnp.zeros((2, 3, 4)),)) 107 | compare_and_ensure_no_loops(partial(jnp.sum, axis=(0, 2)), (jnp.zeros((2, 3, 4)),)) 108 | compare_and_ensure_no_loops(partial(jnp.sum, axis=(0, 1, 2)), (jnp.zeros((2, 3, 4)),)) 109 | compare_and_ensure_no_loops(partial(jnp.min, axis=0), (jnp.zeros((2, 3, 4)),)) 110 | compare_and_ensure_no_loops(partial(jnp.min, axis=(1, 2)), (jnp.zeros((2, 3, 4)),)) 111 | compare_and_ensure_no_loops(partial(jnp.mean, axis=0), (jnp.zeros((2, 3, 4)),)) 112 | compare_and_ensure_no_loops(partial(jnp.prod, axis=1), (jnp.zeros((2, 3, 4)),)) 113 | 114 | 115 | def test_complex_reductions(): 116 | """ 117 | These reductions are complicated, and will be handled using while loops (potentially unrolled) 118 | """ 119 | run_and_compare(jnp.argmax, (jnp.zeros((2, 3, 3)),)) 120 | run_and_compare(partial(jnp.argmax, keepdims=True), (jnp.zeros((2, 3, 3)),)) 121 | run_and_compare(jnp.argmax, (jnp.zeros((20, 30, 40)),)) 122 | run_and_compare(partial(jnp.argmax, keepdims=True), (jnp.zeros((20, 30, 40)),)) 123 | run_and_compare(partial(jnp.argmax, axis=1), (jnp.zeros((2, 3, 3)),)) 124 | run_and_compare(partial(jnp.argmax, axis=0, keepdims=True), (jnp.zeros((2, 3, 3)),)) 125 | run_and_compare(partial(jnp.argmax, axis=2), (jnp.zeros((2, 3, 3)),)) 126 | run_and_compare(partial(jnp.argmax, axis=1), (jnp.zeros((20, 30, 40)),)) 127 | 128 | run_and_compare(jnp.argmin, (jnp.zeros((2, 3, 3)),)) 129 | run_and_compare(partial(jnp.argmin, axis=1), (jnp.zeros((2, 3, 3)),)) 130 | run_and_compare(partial(jnp.argmin, axis=1), (jnp.zeros((20, 100, 40)),)) 131 | run_and_compare(partial(jnp.argmin, axis=1, keepdims=True), (jnp.zeros((20, 100, 40)),)) 132 | 133 | 134 | def test_topk(): 135 | input_shape = (3, 5, 10) 136 | run_and_compare(partial(jax.lax.top_k, k=3), (jnp.zeros(input_shape),)) 137 | 138 | 139 | def test_reverse(): 140 | run_and_compare(jnp.flip, (jnp.zeros((5,)),)) 141 | run_and_compare(jnp.flip, (jnp.zeros((5, 5, 5)),)) 142 | run_and_compare(partial(jnp.flip, axis=(0, 2)), (jnp.zeros((5, 5, 5)),)) 143 | run_and_compare(partial(jnp.flip, axis=(1,)), (jnp.zeros((5, 5, 5)),)) 144 | 145 | 146 | def test_trigonmetry(): 147 | run_and_compare(jnp.sin, (jnp.zeros((5, 6)),)) 148 | run_and_compare(jnp.cos, (jnp.zeros((5, 6)),)) 149 | run_and_compare(jnp.tan, (jnp.zeros((5, 6)),)) 150 | 151 | run_and_compare(jnp.arcsin, (jnp.zeros((5, 6)),)) 152 | run_and_compare(jnp.arccos, (jnp.zeros((5, 6)),)) 153 | run_and_compare(jnp.arctan, (jnp.zeros((5, 6)),)) 154 | 155 | run_and_compare(jnp.sinh, (jnp.zeros((5, 6)),)) 156 | run_and_compare(jnp.cosh, (jnp.zeros((5, 6)),)) 157 | run_and_compare(jnp.tanh, (jnp.zeros((5, 6)),)) 158 | 159 | run_and_compare(jnp.arcsinh, (jnp.zeros((5, 6)),)) 160 | run_and_compare(jnp.arccosh, (jnp.zeros((5, 6)),)) 161 | run_and_compare(jnp.arctanh, (jnp.zeros((5, 6)),)) 162 | 163 | run_and_compare(jnp.atan2, (jnp.zeros((50, 20)), jnp.zeros((50, 20)),)) 164 | 165 | 166 | def test_is_finite(): 167 | input = (jnp.array([20.0, -12.23, jnp.inf, -jnp.inf, jnp.nan], dtype=jnp.float16), ) 168 | run_and_compare_specific_input(jnp.isfinite, input) 169 | run_and_compare_specific_input(jnp.isinf, input) 170 | run_and_compare_specific_input(jnp.isnan, input) 171 | 172 | 173 | def test_take(): 174 | run_and_compare_specific_input(jnp.take, (jnp.reshape(jnp.arange(24), (4, 6)), jnp.array([ 175 | [[0, 0], [1, 1], [2, 2]] 176 | ], dtype=jnp.int32))) 177 | 178 | 179 | def test_gather(): 180 | from jax.lax import GatherDimensionNumbers 181 | 182 | def wrapped_gather(dimension_numbers, slice_sizes): 183 | @jax.jit 184 | def internal_gather(operand, start_indices): 185 | return jax.lax.gather( 186 | operand=operand, 187 | start_indices=start_indices, 188 | dimension_numbers=dimension_numbers, 189 | slice_sizes=slice_sizes, 190 | ) 191 | return internal_gather 192 | 193 | operand = jnp.reshape(jnp.arange(8000), (10, 8, 5, 20)) 194 | start_indices = jnp.array([ 195 | [1, 1], [3, 1], [1, 10], [4, 15], 196 | ], dtype=jnp.int32) 197 | 198 | dimension_numbers = GatherDimensionNumbers( 199 | offset_dims=(0, 1), 200 | collapsed_slice_dims=(0, 2,), 201 | start_index_map=(1, 3, ) 202 | ) 203 | 204 | run_and_compare_specific_input(wrapped_gather(dimension_numbers, (1, 2, 1, 3)), (operand, start_indices)) 205 | run_and_compare_specific_input(wrapped_gather(dimension_numbers, (1, 3, 1, 4)), (operand, start_indices)) 206 | run_and_compare_specific_input(wrapped_gather(dimension_numbers, (1, 3, 1, 7)), (operand, start_indices)) 207 | run_and_compare_specific_input(wrapped_gather(dimension_numbers, (1, 8, 1, 2)), (operand, start_indices)) 208 | 209 | dimension_numbers = GatherDimensionNumbers( 210 | offset_dims=(1, 2), 211 | collapsed_slice_dims=(0, 2,), 212 | start_index_map=(1, 3, ) 213 | ) 214 | run_and_compare_specific_input(wrapped_gather(dimension_numbers, (1, 3, 1, 4)), (operand, start_indices)) 215 | 216 | dimension_numbers = GatherDimensionNumbers( 217 | offset_dims=(0, 2), 218 | collapsed_slice_dims=(0, 2,), 219 | start_index_map=(1, 3, ) 220 | ) 221 | run_and_compare_specific_input(wrapped_gather(dimension_numbers, (1, 3, 1, 4)), (operand, start_indices)) 222 | 223 | operand = jnp.reshape(jnp.arange(50), (5, 10)) 224 | start_indices = jnp.array([ 225 | [0, 1], [1, 0], [0, 0], [2, 6], [4, 2] 226 | ], dtype=jnp.int32) 227 | dimension_numbers = GatherDimensionNumbers( 228 | offset_dims=(1, 2), 229 | collapsed_slice_dims=tuple(), 230 | start_index_map=(0, 1) 231 | ) 232 | run_and_compare_specific_input(wrapped_gather(dimension_numbers, (1, 10)), (operand, start_indices)) 233 | 234 | operand = jnp.reshape(jnp.arange(500), (5, 2, 5, 10)) 235 | start_indices = jnp.array([ 236 | [0, 1, 4], [1, 0, 8], [0, 0, 2], [2, 6, 1], [4, 2, 0] 237 | ], dtype=jnp.int32) 238 | dimension_numbers = GatherDimensionNumbers( 239 | offset_dims=(0, 1, 2, 3), 240 | collapsed_slice_dims=tuple(), 241 | start_index_map=(0, 1, 2) 242 | ) 243 | run_and_compare_specific_input(wrapped_gather(dimension_numbers, (1, 1, 2, 4)), (operand, start_indices)) 244 | 245 | operand = jnp.reshape(jnp.arange(50), (10, 5)) 246 | start_indices = jnp.array([ 247 | [[3], [1], [7]], 248 | [[4], [0], [9]] 249 | ], dtype=jnp.int32) # (2, 3, 1) 250 | 251 | dimension_numbers = GatherDimensionNumbers( 252 | offset_dims=(2,), 253 | collapsed_slice_dims=(0,), 254 | start_index_map=(0,) 255 | ) 256 | 257 | run_and_compare_specific_input(wrapped_gather(dimension_numbers, (1, 5)), (operand, start_indices)) 258 | 259 | 260 | def test_complex_gather(): 261 | from jax.lax import GatherDimensionNumbers 262 | 263 | def wrapped_gather(dimension_numbers, slice_sizes): 264 | @jax.jit 265 | def internal_gather(operand, start_indices): 266 | return jax.lax.gather( 267 | operand=operand, 268 | start_indices=start_indices, 269 | dimension_numbers=dimension_numbers, 270 | slice_sizes=slice_sizes, 271 | ) 272 | return internal_gather 273 | 274 | start_indices = [ 275 | [ 276 | [[0, 0], [1, 0], [2, 1]], 277 | [[0, 1], [1, 1], [0, 9]] 278 | ], 279 | [ 280 | [[0, 0], [2, 1], [2, 2]], 281 | [[1, 2], [0, 1], [1, 0]] 282 | ] 283 | ] 284 | start_indices = jnp.array(start_indices, dtype=jnp.int32) 285 | operand = jnp.arange(1, 49).reshape((2, 3, 4, 2)) 286 | dimension_numbers = GatherDimensionNumbers( 287 | offset_dims=(3, 4), 288 | collapsed_slice_dims=(1,), 289 | operand_batching_dims=(0,), 290 | start_indices_batching_dims=(0,), 291 | start_index_map=(2, 1), 292 | ) 293 | run_and_compare_specific_input(wrapped_gather(dimension_numbers, (1, 1, 1, 2)), (operand, start_indices)) 294 | 295 | operand = jnp.arange(1, 25).reshape((3, 4, 2)) 296 | dimension_numbers = GatherDimensionNumbers( 297 | offset_dims=(2, 3), 298 | collapsed_slice_dims=(0,), 299 | start_index_map=(1, 0), 300 | ) 301 | run_and_compare_specific_input(wrapped_gather(dimension_numbers, (1, 2, 2)), (operand, start_indices[0])) 302 | 303 | operand = jnp.arange(1, 49).reshape((2, 3, 4, 2)) 304 | dimension_numbers = GatherDimensionNumbers( 305 | offset_dims=(3, 4), 306 | collapsed_slice_dims=(1,), 307 | operand_batching_dims=(0,), 308 | start_indices_batching_dims=(1,), 309 | start_index_map=(2, 1), 310 | ) 311 | run_and_compare_specific_input(wrapped_gather(dimension_numbers, (1, 1, 1, 2)), (operand, start_indices)) 312 | 313 | start_indices = jnp.concatenate((start_indices, start_indices[::-1, 1:]), 1) 314 | dimension_numbers = GatherDimensionNumbers( 315 | offset_dims=(3, 4), 316 | collapsed_slice_dims=(), 317 | operand_batching_dims=(0, 1), 318 | start_indices_batching_dims=(0, 1), 319 | start_index_map=(3, 2), 320 | ) 321 | run_and_compare_specific_input(wrapped_gather(dimension_numbers, (1, 1, 1, 1)), (operand, start_indices)) 322 | 323 | 324 | def test_large_gather(): 325 | # Test gather with large indices to verify if optimization is necessary 326 | # and if it works correctly. 327 | from jax.lax import GatherDimensionNumbers 328 | 329 | def wrapped_gather(dimension_numbers, slice_sizes): 330 | @jax.jit 331 | def internal_gather(operand, start_indices): 332 | return jax.lax.gather( 333 | operand=operand, 334 | start_indices=start_indices, 335 | dimension_numbers=dimension_numbers, 336 | slice_sizes=slice_sizes, 337 | ) 338 | return internal_gather 339 | 340 | # Create a large operand and indices 341 | # Operand: (10, 20, 30) 342 | operand = jnp.reshape(jnp.arange(6000), (10, 20, 30)) 343 | 344 | # Indices: (100, 2) -> gathering 100 slices 345 | # We want enough indices to potentially trigger unroll limits if not optimized 346 | num_indices = 100 347 | start_indices = jnp.zeros((num_indices, 2), dtype=jnp.int32) 348 | # Fill with some valid indices 349 | for i in range(num_indices): 350 | start_indices = start_indices.at[i, 0].set(i % 10) 351 | start_indices = start_indices.at[i, 1].set(i % 20) 352 | 353 | dimension_numbers = GatherDimensionNumbers( 354 | offset_dims=(1,), 355 | collapsed_slice_dims=(0, 1), 356 | start_index_map=(0, 1) 357 | ) 358 | 359 | # slice_sizes: (1, 1, 30) 360 | # We gather from dim 0 and 1. Dim 2 is kept. 361 | # Result shape: (100, 30) 362 | 363 | run_and_compare_specific_input(wrapped_gather(dimension_numbers, (1, 1, 30)), (operand, start_indices)) 364 | 365 | 366 | def test_simple_scatter(): 367 | def scatter_set(arr): 368 | indices = jnp.arange(arr.shape[0] // 2) * 2 369 | updates = jnp.arange(indices.shape[0]) 370 | return arr.at[indices].set(updates) 371 | run_and_compare(scatter_set, (jnp.zeros((30,)),)) 372 | 373 | def scatter_add(arr): 374 | indices = jnp.arange(arr.shape[0] // 2) * 2 375 | updates = jnp.arange(indices.shape[0]) 376 | return arr.at[indices].add(updates) 377 | run_and_compare(scatter_add, (jnp.zeros((30,)),)) 378 | 379 | def scatter_sub(arr): 380 | indices = jnp.arange(arr.shape[0] // 2) * 2 381 | updates = jnp.arange(indices.shape[0]) 382 | return arr.at[indices].subtract(updates) 383 | run_and_compare(scatter_sub, (jnp.zeros((30,)),)) 384 | 385 | def scatter_mul(arr): 386 | indices = jnp.arange(arr.shape[0] // 2) * 2 387 | updates = jnp.arange(indices.shape[0]) 388 | return arr.at[indices].multiply(updates) 389 | run_and_compare(scatter_mul, (jnp.zeros((30,)),)) 390 | 391 | def scatter_div(arr): 392 | indices = jnp.arange(arr.shape[0] // 2) * 2 393 | updates = jnp.arange(indices.shape[0]) 394 | return arr.at[indices].divide(updates) 395 | run_and_compare(scatter_div, (jnp.zeros((30,)),)) 396 | 397 | def scatter_max(arr): 398 | indices = jnp.arange(arr.shape[0] // 2) * 2 399 | updates = jnp.arange(indices.shape[0]) 400 | return arr.at[indices].max(updates) 401 | run_and_compare(scatter_max, (jnp.zeros((30,)),)) 402 | 403 | def scatter_min(arr): 404 | indices = jnp.arange(arr.shape[0] // 2) * 2 405 | updates = jnp.arange(indices.shape[0]) 406 | return arr.at[indices].min(updates) 407 | run_and_compare(scatter_min, (jnp.zeros((30,)),)) 408 | 409 | 410 | def test_scatter_with_dimension_numbers(): 411 | from jax.lax import ScatterDimensionNumbers 412 | 413 | def wrapped_scatter_add(dimension_numbers): 414 | @jax.jit 415 | def internal_scatter_add(operand, scatter_indices, updates): 416 | return jax.lax.scatter_add( 417 | operand=operand, 418 | scatter_indices=scatter_indices, 419 | updates=updates, 420 | dimension_numbers=dimension_numbers, 421 | ) 422 | return internal_scatter_add 423 | 424 | # https://raw.githubusercontent.com/openxla/stablehlo/bd8d708/docs/images/spec/scatter.svg 425 | # original test case features partially filled update dimension windows 426 | 427 | scatter_indices = [[[0, 2], [1, 0], [2, 1]], [[0, 1], [1, 0], [0, 9]]] 428 | scatter_indices = jnp.array(scatter_indices) 429 | operand = jnp.arange(1, 25).reshape((3, 4, 2)) 430 | update = jnp.ones((2, 3, 2), dtype=jnp.int32) 431 | dimension_numbers = ScatterDimensionNumbers( 432 | update_window_dims=(2,), 433 | inserted_window_dims=(0, 1), 434 | scatter_dims_to_operand_dims=(1, 0), 435 | ) 436 | 437 | run_and_compare_specific_input(wrapped_scatter_add(dimension_numbers), (operand, scatter_indices, update)) 438 | 439 | 440 | @pytest.mark.parametrize("op_fn,op_name", [ 441 | (lambda x, y, z, dnums: jax.lax.scatter_add(x, y, z, dnums), "add"), 442 | (lambda x, y, z, dnums: jax.lax.scatter_mul(x, y, z, dnums), "mul"), 443 | (lambda x, y, z, dnums: jax.lax.scatter_min(x, y, z, dnums), "min"), 444 | (lambda x, y, z, dnums: jax.lax.scatter_max(x, y, z, dnums), "max"), 445 | # scatter_apply (set) is slightly different in JAX, usually just scatter 446 | (lambda x, y, z, dnums: jax.lax.scatter(x, y, z, dnums), "set"), 447 | ], ids=["add", "mul", "min", "max", "set"]) 448 | def test_scatter_middle_update_rank1(op_fn, op_name): 449 | # Update a slice in the middle, leaving the end untouched. 450 | # Operand: [0, 1, 2, 3, 4] 451 | # Indices: [1] (shape (1,)) 452 | # Updates: [10, 11] (shape (2,)) 453 | # Expected (add): [0, 11, 13, 3, 4] 454 | 455 | def scatter_func(operand, indices, updates): 456 | dnums = jax.lax.ScatterDimensionNumbers( 457 | update_window_dims=(0,), 458 | inserted_window_dims=(), 459 | scatter_dims_to_operand_dims=(0,) 460 | ) 461 | return op_fn(operand, indices, updates, dnums) 462 | 463 | operand = jnp.array([0, 1, 2, 3, 4], dtype=jnp.float32) 464 | indices = jnp.array([1], dtype=jnp.int32) # Shape (1,) 465 | updates = jnp.array([10, 11], dtype=jnp.float32) # Shape (2,) 466 | 467 | run_and_compare_specific_input(scatter_func, (operand, indices, updates)) 468 | 469 | 470 | @pytest.mark.parametrize("op_fn,op_name", [ 471 | (lambda x, y: x.at[y].set, "set"), 472 | (lambda x, y: x.at[y].add, "add"), 473 | (lambda x, y: x.at[y].subtract, "subtract"), 474 | (lambda x, y: x.at[y].multiply, "multiply"), 475 | (lambda x, y: x.at[y].divide, "divide"), 476 | (lambda x, y: x.at[y].max, "max"), 477 | (lambda x, y: x.at[y].min, "min"), 478 | ], ids=["set", "add", "subtract", "multiply", "divide", "max", "min"]) 479 | @pytest.mark.parametrize("shape", [ 480 | (30,), 481 | (10, 3), 482 | (5, 2, 3), 483 | ], ids=["rank1", "rank2", "rank3"]) 484 | def test_scatter(op_fn, op_name, shape): 485 | arr = jnp.zeros(shape) 486 | 487 | axis_len = arr.shape[0] 488 | indices = jnp.arange(axis_len // 2) * 2 489 | updates_shape = (indices.shape[0],) + arr.shape[1:] 490 | updates = jnp.arange(jnp.prod(jnp.array(updates_shape))).reshape(updates_shape) 491 | 492 | def scatter_op(arr): 493 | return op_fn(arr, indices)(updates) 494 | 495 | run_and_compare(scatter_op, (arr,)) 496 | 497 | 498 | @pytest.mark.parametrize("op_fn,op_name", [ 499 | (lambda x, y: x.at[y].set, "set"), 500 | (lambda x, y: x.at[y].add, "add"), 501 | (lambda x, y: x.at[y].subtract, "subtract"), 502 | (lambda x, y: x.at[y].multiply, "multiply"), 503 | (lambda x, y: x.at[y].divide, "divide"), 504 | (lambda x, y: x.at[y].max, "max"), 505 | (lambda x, y: x.at[y].min, "min"), 506 | ], ids=["set", "add", "subtract", "multiply", "divide", "max", "min"]) 507 | def test_scatter_2d_indices(op_fn, op_name): 508 | arr = jnp.zeros((5, 5)) 509 | 510 | indices = jnp.array([ 511 | [0, 1], 512 | [1, 2], 513 | [2, 3], 514 | [3, 4], 515 | [4, 0], 516 | ]) 517 | updates = jnp.arange(arr.shape[1]) 518 | 519 | def scatter_op(arr): 520 | return op_fn(arr, indices)(updates) 521 | 522 | run_and_compare(scatter_op, (arr,)) 523 | 524 | 525 | @pytest.mark.parametrize("op_fn,op_name", [ 526 | (lambda x, y: x.at[y].set, "set"), 527 | (lambda x, y: x.at[y].add, "add"), 528 | (lambda x, y: x.at[y].subtract, "subtract"), 529 | (lambda x, y: x.at[y].multiply, "multiply"), 530 | (lambda x, y: x.at[y].divide, "divide"), 531 | (lambda x, y: x.at[y].max, "max"), 532 | (lambda x, y: x.at[y].min, "min"), 533 | ], ids=["set", "add", "subtract", "multiply", "divide", "max", "min"]) 534 | def test_scatter_3d_indices(op_fn, op_name): 535 | arr = jnp.zeros((4, 3, 2)) 536 | 537 | indices = jnp.array([ 538 | [0, 0, 0], 539 | [1, 1, 1], 540 | [2, 2, 0], 541 | [3, 0, 1], 542 | ]) 543 | updates = jnp.arange(arr.shape[2]) 544 | 545 | def scatter_op(arr): 546 | return op_fn(arr, indices)(updates) 547 | 548 | run_and_compare(scatter_op, (arr,)) 549 | 550 | 551 | @pytest.mark.parametrize("op_fn,op_name", [ 552 | (lambda x, y: x.at[y].set, "set"), 553 | (lambda x, y: x.at[y].add, "add"), 554 | (lambda x, y: x.at[y].subtract, "subtract"), 555 | (lambda x, y: x.at[y].multiply, "multiply"), 556 | (lambda x, y: x.at[y].divide, "divide"), 557 | (lambda x, y: x.at[y].max, "max"), 558 | (lambda x, y: x.at[y].min, "min"), 559 | ], ids=["set", "add", "subtract", "multiply", "divide", "max", "min"]) 560 | def test_scatter_replace_vector(op_fn, op_name): 561 | arr = jnp.zeros((3, 4, 5, 6)) 562 | 563 | indices = jnp.array([ 564 | [0, 1, 3], 565 | [1, 2, 4], 566 | ]) 567 | 568 | updates = jnp.reshape(jnp.arange(arr.shape[3] * indices.shape[0]), (indices.shape[0], 1, 1, 1, arr.shape[3])) 569 | 570 | def scatter_op(arr): 571 | return op_fn(arr, indices)(updates) 572 | 573 | run_and_compare(scatter_op, (arr,)) 574 | 575 | 576 | @pytest.mark.parametrize("op_fn,op_name", [ 577 | (lambda x, y: x.at[y].set, "set"), 578 | (lambda x, y: x.at[y].add, "add"), 579 | (lambda x, y: x.at[y].subtract, "subtract"), 580 | (lambda x, y: x.at[y].multiply, "multiply"), 581 | (lambda x, y: x.at[y].divide, "divide"), 582 | (lambda x, y: x.at[y].max, "max"), 583 | (lambda x, y: x.at[y].min, "min"), 584 | ], ids=["set", "add", "subtract", "multiply", "divide", "max", "min"]) 585 | def test_scatter_replace_matrix(op_fn, op_name): 586 | arr = jnp.zeros((3, 4, 5, 6)) 587 | 588 | indices = jnp.array([ 589 | [0, 1], 590 | [1, 2], 591 | ]) 592 | 593 | updates = jnp.arange(arr.shape[2])[None, :] @ jnp.arange(arr.shape[2])[:, None] 594 | 595 | def scatter_op(arr): 596 | return op_fn(arr, indices)(updates) 597 | 598 | run_and_compare(scatter_op, (arr,)) 599 | 600 | 601 | def test_scatter_empty_indices(): 602 | def scatter_add(operand, indices, updates): 603 | dnums = jax.lax.ScatterDimensionNumbers( 604 | update_window_dims=(1,), 605 | inserted_window_dims=(0,), 606 | scatter_dims_to_operand_dims=(0,), 607 | ) 608 | return jax.lax.scatter_add(operand, indices, updates, dnums) 609 | 610 | operand = jnp.zeros((4, 4), dtype=jnp.float32) 611 | indices = jnp.zeros((0, 1), dtype=jnp.int32) 612 | updates = jnp.ones((0, 4), dtype=jnp.float32) 613 | 614 | run_and_compare_specific_input(scatter_add, (operand, indices, updates)) 615 | 616 | 617 | def test_pad(): 618 | run_and_compare(partial(jnp.pad, pad_width=((0, 0), (10, 5))), (jnp.zeros((10, 20)),)) 619 | run_and_compare(partial(jnp.pad, pad_width=((0, 10), (5, 0), (2, 1))), (jnp.zeros((10, 20, 15)),)) 620 | run_and_compare(partial(jnp.pad, pad_width=((5, 10), (10, 5)), mode="empty"), (jnp.zeros((10, 20)),)) 621 | run_and_compare(partial(jnp.pad, pad_width=((1, 2), (3, 4)), constant_values=12.3), (jnp.zeros((10, 20)),)) 622 | run_and_compare(partial(jnp.pad, pad_width=((5, 10), (10, 5)), mode="reflect"), (jnp.zeros((10, 20)),)) 623 | run_and_compare(partial(jnp.pad, pad_width=((5, 10), (10, 5)), mode="wrap"), (jnp.zeros((10, 20)),)) 624 | run_and_compare(partial(jnp.pad, pad_width=((5, 10), (10, 5)), mode="edge"), (jnp.zeros((10, 20)),)) 625 | run_and_compare(partial(jnp.pad, pad_width=((5, 10), (10, 5)), mode="linear_ramp"), (jnp.zeros((10, 20)),)) 626 | run_and_compare(partial(jnp.pad, pad_width=((5, 10), (10, 5)), mode="maximum"), (jnp.zeros((10, 20)),)) 627 | run_and_compare(partial(jnp.pad, pad_width=((5, 10), (10, 5)), mode="mean"), (jnp.zeros((10, 20)),)) 628 | run_and_compare(partial(jnp.pad, pad_width=((5, 10), (10, 5)), mode="median"), (jnp.zeros((10, 20)),)) 629 | run_and_compare(partial(jnp.pad, pad_width=((5, 10), (10, 5)), mode="minimum"), (jnp.zeros((10, 20)),)) 630 | run_and_compare(partial(jnp.pad, pad_width=((5, 10), (10, 5)), mode="symmetric"), (jnp.zeros((10, 20)),)) 631 | 632 | 633 | def test_pad_int32(): 634 | run_and_compare(partial(jnp.pad, pad_width=((1, 1), (2, 2)), constant_values=10), (jnp.zeros((5, 5), dtype=jnp.int32),)) 635 | run_and_compare(partial(jnp.pad, pad_width=((1, 1), (2, 2))), (jnp.zeros((5, 5), dtype=jnp.int32),)) 636 | 637 | 638 | def test_remainder(): 639 | run_and_compare(jnp.remainder, ( 640 | jnp.array([10, 20, 30], dtype=jnp.int32), jnp.array([3, 7, 11], dtype=jnp.int32) 641 | )) 642 | run_and_compare(jnp.remainder, ( 643 | jnp.array([10.5, 20.2, 30.1], dtype=jnp.float32), jnp.array([3.1, 7.2, 11.3], dtype=jnp.float32) 644 | )) 645 | 646 | 647 | def test_floor(): 648 | run_and_compare(jnp.floor, (jnp.array([1.1, 2.9, -1.1, -2.9], dtype=jnp.float32),)) 649 | 650 | 651 | def test_ceil(): 652 | run_and_compare(jnp.ceil, (jnp.array([1.1, 2.9, -1.1, -2.9], dtype=jnp.float32),)) 653 | 654 | 655 | def test_clamp(): 656 | run_and_compare(partial(jnp.clip, a_min=0.0, a_max=1.0), (jnp.array([-1.0, 0.5, 2.0], dtype=jnp.float32),)) 657 | run_and_compare(partial(jnp.clip, a_min=-5, a_max=5), (jnp.array([-10, 0, 10], dtype=jnp.int32),)) 658 | 659 | 660 | def test_sort(): 661 | run_and_compare(jnp.sort, (jnp.array([3, 1, 2], dtype=jnp.int32),)) 662 | run_and_compare(jnp.sort, (jnp.array([[3, 1, 2], [6, 5, 4]], dtype=jnp.float32),)) 663 | run_and_compare(partial(jnp.sort, axis=0), (jnp.array([[3, 1, 2], [6, 5, 4]], dtype=jnp.float32),)) 664 | 665 | # Test with larger random input 666 | run_and_compare(jnp.sort, (jnp.zeros((100, 50), dtype=jnp.float32),)) 667 | run_and_compare(partial(jnp.sort, axis=0), (jnp.zeros((100, 50), dtype=jnp.float32),)) 668 | run_and_compare(partial(jnp.sort, descending=True), (jnp.zeros((100, 50), dtype=jnp.float32),)) 669 | run_and_compare(partial(jnp.sort, axis=0, descending=True), (jnp.zeros((100, 50), dtype=jnp.float32),)) 670 | 671 | # Test with NaNs and negative zeros to trigger total sort logic 672 | # Total sort order for floats: NaN < -Inf < ... < -0.0 < 0.0 < ... < Inf 673 | # (or similar, depending on implementation, but it must be total) 674 | # Note: CoreML handles NaNs differently than JAX (CoreML puts them at the beginning, JAX at the end) 675 | # So we exclude NaNs from this test to ensure we test the rest of the total sort logic (signed zeros etc) 676 | data = jnp.array([0.0, -0.0, 1.0, -1.0, jnp.inf, -jnp.inf], dtype=jnp.float32) 677 | run_and_compare_specific_input(jnp.sort, (data,)) 678 | 679 | 680 | def test_multikey_sort_fails_due_to_stability(): 681 | # Test lexicographical sort with multiple keys 682 | def sort_dim_0(k1, k2): 683 | return jax.lax.sort([k1, k2], dimension=0, num_keys=2) 684 | 685 | k1 = jnp.array([1, 3, 2, 4], dtype=jnp.int32) 686 | k2 = jnp.array([3, 1, 2, 4], dtype=jnp.int32) 687 | with pytest.raises(Exception): 688 | run_and_compare_specific_input(sort_dim_0, (k1, k2)) 689 | 690 | k1 = jnp.array([1, 5, 1, 4, 3, 4, 4], dtype=jnp.int32) 691 | k2 = jnp.array([9, 4, 0, 4, 0, 2, 1], dtype=jnp.int32) 692 | with pytest.raises(Exception): 693 | run_and_compare_specific_input(sort_dim_0, (k1, k2)) 694 | 695 | k1 = jnp.array([1, 3, 1, 4, 3, 5, 4], dtype=jnp.int32) 696 | k2 = jnp.array([0, 4, 0, 4, 0, -21, -12], dtype=jnp.int32) 697 | with pytest.raises(Exception): 698 | run_and_compare_specific_input(sort_dim_0, (k1, k2)) 699 | 700 | k1_2d = jnp.array([[1, 2], [3, 4]], dtype=jnp.int32) 701 | k2_2d = jnp.array([[3, 1], [2, 4]], dtype=jnp.int32) 702 | 703 | def sort_dim_1(k1, k2): 704 | return jax.lax.sort([k1, k2], dimension=1, num_keys=2) 705 | 706 | with pytest.raises(Exception): 707 | run_and_compare_specific_input(sort_dim_1, (k1_2d, k2_2d)) 708 | 709 | # Larger random inputs 710 | def sort_dim_0_large(k1, k2): 711 | return jax.lax.sort([k1, k2], dimension=0, num_keys=2) 712 | 713 | with pytest.raises(Exception): 714 | run_and_compare(sort_dim_0_large, (jnp.zeros((100, 50), dtype=jnp.int32), jnp.zeros((100, 50), dtype=jnp.int32))) 715 | with pytest.raises(Exception): 716 | run_and_compare(sort_dim_0_large, (jnp.zeros((100, 50), dtype=jnp.float32), jnp.zeros((100, 50), dtype=jnp.float32))) 717 | 718 | 719 | def test_unstable_argsort(): 720 | def unstable_argsort(x, **kwargs): 721 | return jnp.argsort(x, stable=False, **kwargs) 722 | 723 | run_and_compare_specific_input(unstable_argsort, (jnp.array([3, 1, 2], dtype=jnp.int32),)) 724 | run_and_compare_specific_input(unstable_argsort, (jnp.array([[3, 1, 2], [6, 5, 4]], dtype=jnp.float32),)) 725 | run_and_compare_specific_input(partial(unstable_argsort, axis=0), (jnp.array([[3, 1, 2], [6, 5, 4]], dtype=jnp.float32),)) 726 | 727 | # Test with larger random input 728 | run_and_compare(unstable_argsort, (jnp.zeros((100, 50), dtype=jnp.float32),)) 729 | run_and_compare(partial(unstable_argsort, axis=0), (jnp.zeros((100, 50), dtype=jnp.float32),)) 730 | 731 | 732 | def test_multi_input_argsort(): 733 | # Because argsort is unstable, we cannot directly compare the output indices. 734 | # Instead, we perform argsort followed by gather to retrieve the sorted values, 735 | # which can then be compared. 736 | def unstable_argsort_and_lookup(sort_array, lookup_array, lookup_values): 737 | _sorted_array, ordered_lookup_idx = jax.lax.sort([sort_array, lookup_array], dimension=0, num_keys=1, is_stable=False) 738 | gathered = jnp.take(lookup_values, ordered_lookup_idx) 739 | return gathered 740 | 741 | run_and_compare_specific_input(unstable_argsort_and_lookup, ( 742 | jnp.array([3, 1, 2, 3, 1, 2, 3, 1, 2], dtype=jnp.int32), 743 | jnp.array([2, 0, 1, 2, 0, 1, 2, 0, 1], dtype=jnp.int32), 744 | jnp.array([0, 1, 2], dtype=jnp.int32) 745 | )) 746 | 747 | run_and_compare_specific_input(unstable_argsort_and_lookup, ( 748 | jnp.array([3, 1, 2, 3, 1, 2, 3, 1, 2], dtype=jnp.float32), 749 | jnp.array([2, 0, 1, 2, 0, 1, 2, 0, 1], dtype=jnp.int32), 750 | jnp.array([0, 1, 2], dtype=jnp.float32) 751 | )) 752 | 753 | 754 | def test_case(): 755 | def switch_fn(index, x): 756 | return jax.lax.switch(index, [ 757 | lambda x: x + 1, 758 | lambda x: x * 2, 759 | lambda x: x - 1 760 | ], x) 761 | 762 | run_and_compare_specific_input(switch_fn, ( 763 | jnp.array(0, dtype=jnp.int32), jnp.array(10.0, dtype=jnp.float32) 764 | )) 765 | run_and_compare_specific_input(switch_fn, ( 766 | jnp.array(1, dtype=jnp.int32), jnp.array(10.0, dtype=jnp.float32) 767 | )) 768 | run_and_compare_specific_input(switch_fn, ( 769 | jnp.array(2, dtype=jnp.int32), jnp.array(10.0, dtype=jnp.float32) 770 | )) 771 | 772 | 773 | def test_reshape_scalar(): 774 | # Test reshaping to scalar (0-rank tensor) 775 | def reshape_to_scalar(x): 776 | return jnp.reshape(x, ()) 777 | 778 | run_and_compare(reshape_to_scalar, (jnp.array([5.0], dtype=jnp.float32),)) 779 | 780 | 781 | def test_compare_bool(): 782 | run_and_compare_specific_input(jnp.equal, ( 783 | jnp.array([True, False, True], dtype=jnp.bool_), 784 | jnp.array([True, True, False], dtype=jnp.bool_) 785 | )) 786 | run_and_compare_specific_input(jnp.not_equal, ( 787 | jnp.array([True, False, True], dtype=jnp.bool_), 788 | jnp.array([True, True, False], dtype=jnp.bool_) 789 | )) 790 | 791 | 792 | def test_logical_not(): 793 | run_and_compare(jnp.logical_not, (jnp.array([True, False]),)) 794 | 795 | 796 | def test_power(): 797 | run_and_compare(jnp.power, (jnp.array([2.0, 3.0]), jnp.array([3.0, 2.0]))) 798 | 799 | 800 | def test_dynamic_slice_oob(): 801 | # Test dynamic slice with out of bounds indices 802 | # StableHLO spec requires that the start indices are clamped to ensure the slice remains within bounds 803 | # start_index = clamp(start_index, 0, operand_dim - slice_size) 804 | def dynamic_slice(operand, start_indices): 805 | return jax.lax.dynamic_slice(operand, start_indices, slice_sizes=(2, 2)) 806 | 807 | operand = jnp.zeros((5, 5)) 808 | # Valid index 809 | run_and_compare_specific_input(dynamic_slice, (operand, jnp.array([1, 1], dtype=jnp.int32))) 810 | # Out of bounds index (too large) -> should be clamped to 5-2 = 3 811 | run_and_compare_specific_input(dynamic_slice, (operand, jnp.array([4, 4], dtype=jnp.int32))) 812 | # Out of bounds index (negative) -> should be clamped to 0 813 | run_and_compare_specific_input(dynamic_slice, (operand, jnp.array([10, 10], dtype=jnp.int32))) 814 | 815 | 816 | def test_dynamic_update_slice_oob(): 817 | # Test dynamic update slice with out of bounds indices 818 | # StableHLO spec requires that the start indices are clamped to ensure the slice remains within bounds 819 | # start_index = clamp(start_index, 0, operand_dim - update_dim) 820 | def dynamic_update_slice(operand, update, start_indices): 821 | return jax.lax.dynamic_update_slice(operand, update, start_indices) 822 | 823 | operand = jnp.zeros((5, 5)) 824 | update = jnp.ones((2, 2)) 825 | # Valid index 826 | run_and_compare_specific_input(dynamic_update_slice, (operand, update, jnp.array([1, 1], dtype=jnp.int32))) 827 | # Out of bounds index (too large) -> should be clamped to 5-2 = 3 828 | run_and_compare_specific_input(dynamic_update_slice, (operand, update, jnp.array([4, 4], dtype=jnp.int32))) 829 | # Out of bounds index (negative) -> should be clamped to 0 830 | run_and_compare_specific_input(dynamic_update_slice, (operand, update, jnp.array([10, 10], dtype=jnp.int32))) 831 | 832 | 833 | def test_transposed_conv_large_padding(): 834 | input_shape = (1, 1, 4, 4) 835 | kernel_shape = (1, 1, 3, 3) 836 | 837 | def transposed_conv(img, kernel): 838 | return jax.lax.conv_general_dilated( 839 | lhs=img, 840 | rhs=kernel, 841 | window_strides=(1, 1), 842 | padding=((3, 3), (3, 3)), 843 | lhs_dilation=(2, 2), 844 | dimension_numbers=('NCHW', 'OIHW', 'NCHW') 845 | ) 846 | 847 | run_and_compare(transposed_conv, (jnp.zeros(input_shape), jnp.zeros(kernel_shape))) 848 | -------------------------------------------------------------------------------- /stablehlo_coreml/converter.py: -------------------------------------------------------------------------------- 1 | from coremltools import _logger as logger 2 | from coremltools.converters.mil import mil 3 | from coremltools.converters.mil.mil import Builder as mb 4 | from coremltools.converters.mil.mil import Function, Program, types 5 | from coremltools.converters.mil._deployment_compatibility import AvailableTarget 6 | from coremltools.converters.mil.mil.ops.defs._utils import ( 7 | promote_input_dtypes, 8 | ) 9 | from .utils import ( 10 | index_by_slices, update_tensor_by_slice, iterate_indexes_in_shapes, 11 | inverse_permutation, get_mil_type, dtype_str, get_mil_type_from_ir, get_numpy_type, 12 | clamp_index, range_along_dim 13 | ) 14 | from .passes.utils import register_optimizations 15 | from .translation_context import TranslationContext 16 | from .ops_register import StableHloOpsRegistry, register_stablehlo_op 17 | from .sort_utils import match_sort 18 | from .reductions import ( 19 | compute_reduction, compute_windowed_reduction, match_computation 20 | ) 21 | from .padding import pad_with_cast 22 | 23 | from jaxlib.mlir import ir 24 | from jaxlib.mlir.dialects.func import FuncOp, CallOp, ReturnOp as FuncReturnOp 25 | from jaxlib.mlir.dialects.stablehlo import ( 26 | AddOp, SubtractOp, MulOp, DivOp, NegOp, SignOp, AbsOp, ExpOp, Expm1Op, LogOp, 27 | Log1pOp, SqrtOp, ConstantOp, DotGeneralOp, ReshapeOp, BroadcastInDimOp, WhileOp, 28 | CompareOp, ConvertOp, SelectOp, DynamicSliceOp, ReturnOp, ConvolutionOp, MinOp, 29 | MaxOp, RsqrtOp, TanhOp, SineOp, CosineOp, TanOp, Atan2Op, ConcatenateOp, TransposeOp, 30 | DynamicUpdateSliceOp, SliceOp, CustomCallOp, IotaOp, ReduceOp, ReduceWindowOp, 31 | OrOp, AndOp, NotOp, ReverseOp, IsFiniteOp, GatherOp, PowOp, PadOp, RemOp, 32 | ScatterOp, FloorOp, CeilOp, SortOp, ClampOp, CaseOp, 33 | ) 34 | from jaxlib.mlir.dialects.mhlo import (TopKOp, AsinOp, AcosOp, SinhOp, CoshOp, AsinhOp, AcoshOp, AtanhOp) 35 | from jax._src.lib.mlir.dialects import hlo 36 | 37 | import numpy as np 38 | 39 | from typing import List, Optional 40 | from functools import partial, reduce 41 | 42 | 43 | def convert(module, minimum_deployment_target: AvailableTarget): 44 | if minimum_deployment_target < AvailableTarget.iOS18: 45 | raise ValueError("Converting to Program: 61 | logger.info("Converting graph.") 62 | 63 | # Build function index to resolve/inline HLO function calls 64 | for func in module.body: 65 | self.func_index[func.name.value] = func 66 | 67 | for func in module.body: 68 | if func.sym_visibility is None or "public" == func.sym_visibility.value: 69 | self.build_func(func) 70 | 71 | return self.prog 72 | 73 | def build_func(self, hlo_func: FuncOp): 74 | context = TranslationContext() # Map from results to created variables 75 | 76 | func_inputs = {} 77 | for arg in hlo_func.arguments: 78 | shape = arg.type.shape 79 | if shape == []: 80 | shape = [1] 81 | 82 | func_inputs[arg.get_name()] = mb.placeholder( 83 | shape=shape, dtype=get_mil_type_from_ir(arg.type.element_type) 84 | ) 85 | 86 | with Function(func_inputs, opset_version=self.opset_version) as ssa_func: 87 | for name in func_inputs.keys(): 88 | context.add_variable(name, ssa_func.inputs[name]) 89 | 90 | ssa_func.set_outputs(self.process_block(context, hlo_func.body.blocks[0])) 91 | self.prog.add_function(hlo_func.name.value, ssa_func) 92 | 93 | def process_block(self, context: TranslationContext, block: ir.Block): 94 | outputs = None 95 | for op in block: 96 | # Convention: Only the "return" op is returning from its building function 97 | # TODO: Check that "return" is always the last node! 98 | ret = self.dispatch_op(self, context, op) 99 | if ret is not None: 100 | if outputs is not None: 101 | raise ValueError("More than 1 return op in block!") 102 | outputs = ret 103 | return outputs 104 | 105 | @register_stablehlo_op 106 | def op_call(self, context: TranslationContext, op: CallOp): 107 | # We can not do function calls in MIL, so we have to inline the function 108 | 109 | # Get the argument mapping prior to entering the function context 110 | context_args = [] 111 | 112 | for arg in op.operands: 113 | context_args.append(context[arg.get_name()]) 114 | 115 | func_name = op.callee.value 116 | hlo_func = self.func_index[op.callee.value] 117 | params = hlo_func.arguments 118 | outputs = self.invoke_hlo_function(context, func_name, params, hlo_func.body, context_args) 119 | 120 | # Configure return value 121 | for result, output in zip(op.results, outputs): 122 | context.add_result(result, output) 123 | 124 | @register_stablehlo_op 125 | def op_return(self, context: TranslationContext, op: ReturnOp): 126 | return [context[result.get_name()] for result in op.operands] 127 | 128 | @register_stablehlo_op 129 | def op_func_return(self, context: TranslationContext, op: FuncReturnOp): 130 | # The HLO / MLIR types for function return ops seem to be both in use 131 | # The behaviour and fields of the two types should be similar, so we 132 | # simply delegate to the HLO version 133 | return self.op_return(context, op) 134 | 135 | @register_stablehlo_op 136 | def op_add(self, context: TranslationContext, op: AddOp): 137 | self.__simple_binary_op(context, mb.add, op) 138 | 139 | @register_stablehlo_op 140 | def op_or(self, context: TranslationContext, op: OrOp): 141 | self.__simple_binary_op(context, mb.logical_or, op) 142 | 143 | @register_stablehlo_op 144 | def op_and(self, context: TranslationContext, op: AndOp): 145 | self.__simple_binary_op(context, mb.logical_and, op) 146 | 147 | @register_stablehlo_op 148 | def op_not(self, context: TranslationContext, op: NotOp): 149 | self.__simple_unary_op(context, mb.logical_not, op) 150 | 151 | @register_stablehlo_op 152 | def op_subtract(self, context: TranslationContext, op: SubtractOp): 153 | self.__simple_binary_op(context, mb.sub, op) 154 | 155 | @register_stablehlo_op 156 | def op_mul(self, context: TranslationContext, op: MulOp): 157 | self.__simple_binary_op(context, mb.mul, op) 158 | 159 | @register_stablehlo_op 160 | def op_div(self, context: TranslationContext, op: DivOp): 161 | lhs = context[op.lhs.get_name()] 162 | rhs = context[op.rhs.get_name()] 163 | 164 | # From HLO constraints we know the base-types should line up 165 | lhs_type = get_mil_type(lhs) 166 | rhs_type = get_mil_type(rhs) 167 | if lhs_type != rhs_type: 168 | raise ValueError(f"Division not supported for different types. lhs type: {lhs_type}, rhs type: {rhs_type}") 169 | if types.is_complex(lhs_type): 170 | raise ValueError("Complex numbers are not supported in MIL") 171 | 172 | if types.is_float(lhs_type): 173 | cml_op = mb.real_div(x=lhs, y=rhs) 174 | elif types.is_int(lhs_type): 175 | cml_op = mb.floor_div(x=lhs, y=rhs) 176 | else: 177 | raise ValueError(f"Unknown dtype {lhs_type}") 178 | 179 | context.add_result(op.result, cml_op) 180 | 181 | @register_stablehlo_op 182 | def op_neg(self, context: TranslationContext, op: NegOp): 183 | # TODO(knielsen): Consider unsigned and more exotic types 184 | operand = context[op.operand.get_name()] 185 | minus_one = np.array([-1], dtype=get_numpy_type(operand)) 186 | cml_op = mb.mul(x=minus_one, y=operand) 187 | context.add_result(op.result, cml_op) 188 | 189 | @register_stablehlo_op 190 | def op_sign(self, context: TranslationContext, op: SignOp): 191 | self.__simple_unary_op(context, mb.sign, op) 192 | 193 | @register_stablehlo_op 194 | def op_abs(self, context: TranslationContext, op: AbsOp): 195 | self.__simple_unary_op(context, mb.abs, op) 196 | 197 | @register_stablehlo_op 198 | def op_log(self, context: TranslationContext, op: LogOp): 199 | self.__simple_unary_op(context, mb.log, op) 200 | 201 | @register_stablehlo_op 202 | def op_log1p(self, context: TranslationContext, op: Log1pOp): 203 | operand = context[op.operand.get_name()] 204 | one = np.array([1], dtype=get_numpy_type(operand)) 205 | x_plus_one = mb.add(x=one, y=operand) 206 | cml_op = mb.log(x=x_plus_one) 207 | context.add_result(op.result, cml_op) 208 | 209 | @register_stablehlo_op 210 | def op_exp(self, context: TranslationContext, op: ExpOp): 211 | self.__simple_unary_op(context, mb.exp, op) 212 | 213 | @register_stablehlo_op 214 | def op_pow(self, context: TranslationContext, op: PowOp): 215 | self.__simple_binary_op(context, mb.pow, op) 216 | 217 | @register_stablehlo_op 218 | def op_expm1(self, context: TranslationContext, op: Expm1Op): 219 | operand = context[op.operand.get_name()] 220 | cml_op = mb.add(x=mb.exp(x=operand), y=-1.0) 221 | context.add_result(op.result, cml_op) 222 | 223 | @register_stablehlo_op 224 | def op_transpose(self, context: TranslationContext, op: TransposeOp): 225 | operand = context[op.operand.get_name()] 226 | perm = np.array(op.permutation, dtype=np.int32) 227 | cml_op = mb.transpose(x=operand, perm=perm) 228 | context.add_result(op.result, cml_op) 229 | 230 | @register_stablehlo_op 231 | def op_pad(self, context: TranslationContext, op: PadOp): 232 | operand = context[op.operand.get_name()] 233 | 234 | if not np.all(np.array(op.interior_padding) == 0): 235 | raise ValueError("Interior padding is not supported") 236 | 237 | operand_rank = len(op.operand.type.shape) 238 | indices = np.arange(2 * operand_rank, dtype=np.int32) 239 | pad = np.zeros_like(indices) 240 | pad = mb.scatter_along_axis( 241 | data=pad, 242 | indices=indices[::2], 243 | mode="update", 244 | updates=np.array(op.edge_padding_low, dtype=np.int32) 245 | ) 246 | pad = mb.scatter_along_axis( 247 | data=pad, 248 | indices=indices[1::2], 249 | mode="update", 250 | updates=np.array(op.edge_padding_high, dtype=np.int32) 251 | ) 252 | 253 | cml_padding_value = context[op.padding_value.get_name()] 254 | cml_op = pad_with_cast(x=operand, pad=pad, mode="constant", constant_val=cml_padding_value) 255 | context.add_result(op.result, cml_op) 256 | 257 | @register_stablehlo_op 258 | def op_sqrt(self, context: TranslationContext, op: SqrtOp): 259 | self.__simple_unary_op(context, mb.sqrt, op) 260 | 261 | @register_stablehlo_op 262 | def op_constant(self, context: TranslationContext, op: ConstantOp): 263 | constant = np.array(op.value) 264 | constant = np.reshape(constant, op.result.type.shape) 265 | context.add_result(op.result, constant) 266 | 267 | @register_stablehlo_op 268 | def op_dot_general(self, context: TranslationContext, op: DotGeneralOp): 269 | # This roughly follows the steps from https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dot_general 270 | # but uses that we have a matrix multiplication primitive, instead of just a dot-product primitive. 271 | lhs_rank = len(op.lhs.type.shape) 272 | rhs_rank = len(op.rhs.type.shape) 273 | dot_dim_numbers = hlo.DotDimensionNumbers(op.dot_dimension_numbers) 274 | 275 | lhs_contracting_dim = dot_dim_numbers.lhs_contracting_dimensions 276 | rhs_contracting_dim = dot_dim_numbers.rhs_contracting_dimensions 277 | lhs_batching_dim = dot_dim_numbers.lhs_batching_dimensions 278 | rhs_batching_dim = dot_dim_numbers.rhs_batching_dimensions 279 | 280 | lhs = context[op.lhs.get_name()] 281 | rhs = context[op.rhs.get_name()] 282 | 283 | def multiply(lst: List): 284 | return reduce(lambda a, b: int(a) * int(b), lst, 1) 285 | 286 | def last_column_dot(lhs, rhs): 287 | # TODO: Figure out if we need to special case broadcasting dims 288 | return mb.matmul(x=lhs, y=rhs, transpose_y=True) 289 | 290 | # Remark: There is a potential performance optimization here: 291 | # If we move the largest result dimensions of the tensor towards 292 | # the end of the array, we may save a lot of work when iterating 293 | # over the result indexes later, as the last dims will be handled 294 | # by matrix multiplication 295 | lhs_result_dim = [dim for dim in range(lhs_rank) if dim not in lhs_batching_dim + lhs_contracting_dim] 296 | rhs_result_dim = [dim for dim in range(rhs_rank) if dim not in rhs_batching_dim + rhs_contracting_dim] 297 | 298 | # For both the lhs and rhs, put the dimensions being contracted last 299 | transposed_lhs = mb.transpose(x=lhs, perm=lhs_batching_dim + lhs_result_dim + lhs_contracting_dim) 300 | transposed_rhs = mb.transpose(x=rhs, perm=rhs_batching_dim + rhs_result_dim + rhs_contracting_dim) 301 | 302 | # Calculate the result by looping over the contracting dims in order 303 | result_shape = [lhs.shape[dim] for dim in lhs_batching_dim] 304 | result_shape += [lhs.shape[dim] for dim in lhs_result_dim] 305 | result_shape += [rhs.shape[dim] for dim in rhs_result_dim] 306 | if len(result_shape) == 0: 307 | # Special case for scalar result 308 | result_shape = [1] 309 | 310 | # Allocate memory of the correct type for the result 311 | result_dtype = get_mil_type_from_ir(op.result.type.element_type) 312 | result = mb.fill(shape=result_shape, value=mb.cast(x=0, dtype=dtype_str(result_dtype))) 313 | 314 | def calculate_result_index(lhs_idx, rhs_idx, acc): 315 | contracted_element_count = multiply([lhs.shape[dim] for dim in lhs_contracting_dim]) 316 | # print(f"contracted_element_count = {contracted_element_count}") 317 | batch_selector = tuple([slice(None) for _i in range(len(lhs_batching_dim))]) 318 | batch_shape = tuple([lhs.shape[dim] for dim in lhs_batching_dim]) 319 | 320 | # Reshape the lhs and rhs to have all the contracting dimensions in the end. 321 | # We will always make them have the shape `(batch_shape, last_dim_shape, contraction_count)`` 322 | # where we may have to set `last_dim_shape` to 1, if the dimension does not exist. 323 | lhs_for_result_idx = index_by_slices(transposed_lhs, list(batch_selector) + [lhs_idx, ...]) 324 | if len(lhs_result_dim) > 0: 325 | lhs_reshape_shape = batch_shape + (lhs.shape[lhs_result_dim[-1]],) + (contracted_element_count, ) 326 | else: 327 | lhs_reshape_shape = batch_shape + (1, contracted_element_count) 328 | contracted_lhs = mb.reshape(x=lhs_for_result_idx, shape=lhs_reshape_shape) 329 | 330 | rhs_for_result_idx = index_by_slices(transposed_rhs, list(batch_selector) + [rhs_idx, ...]) 331 | if len(rhs_result_dim) > 0: 332 | rhs_reshape_shape = batch_shape + (rhs.shape[rhs_result_dim[-1]],) + (contracted_element_count, ) 333 | else: 334 | rhs_reshape_shape = batch_shape + (1, contracted_element_count) 335 | contracted_rhs = mb.reshape(x=rhs_for_result_idx, shape=rhs_reshape_shape) 336 | 337 | # print(f"contracted_lhs shape: {contracted_lhs.shape}") 338 | # print(f"contracted_rhs shape: {contracted_rhs.shape}") 339 | 340 | idx_result = last_column_dot(contracted_lhs, contracted_rhs) 341 | 342 | # If we added a fake dimension, we will make sure to squeeze it away 343 | if len(lhs_result_dim) == 0 and len(rhs_result_dim) == 0: 344 | if len(idx_result.shape) == 2: 345 | assert idx_result.shape == (1, 1) 346 | # This is a special case, where the result is a scalar of shape (1, 1) 347 | # In order to not end up with a 0-rank tensor, we only contract one dimension 348 | idx_result = mb.reshape(x=idx_result, shape=(1,)) 349 | else: 350 | idx_result = mb.squeeze(x=idx_result, axes=(-1, -2)) 351 | elif len(lhs_result_dim) == 0: 352 | idx_result = mb.squeeze(x=idx_result, axes=(-2,)) 353 | elif len(rhs_result_dim) == 0: 354 | idx_result = mb.squeeze(x=idx_result, axes=(-1,)) 355 | 356 | # TODO: Consider making this work on iOS<18 by using concatenation 357 | # We may have to add an extra slice for the skipped dimension 358 | result_idx = [] 359 | result_idx.append(lhs_idx) 360 | if len(lhs_result_dim) > 0: 361 | result_idx.append(slice(None)) 362 | result_idx.append(rhs_idx) 363 | if len(rhs_result_dim) > 0: 364 | result_idx.append(slice(None)) 365 | 366 | return [update_tensor_by_slice(acc, list(batch_selector) + result_idx, idx_result)] 367 | 368 | # We can utilize that we have a full matrix multiply primitive available, compared to having only 369 | # a dot-product primitive. Therefore we can avoid iterating over the last dimension in respectively 370 | # the lhs and rhs tensors 371 | lhs_shape = [lhs.shape[dim] for dim in lhs_result_dim[:-1]] 372 | rhs_shape = [rhs.shape[dim] for dim in rhs_result_dim[:-1]] 373 | # In principle all of the matrix multiplications generated here, could be done in parallel. 374 | # MIL does not seem to support this. 375 | # We could try to combine the matrix multiplications when the shapes allow it, but for now 376 | # we will just loop through them sequentially. 377 | result, = iterate_indexes_in_shapes(calculate_result_index, [lhs_shape, rhs_shape], [result]) 378 | 379 | context.add_result(op.result, result) 380 | 381 | @register_stablehlo_op 382 | def op_rem(self, context: TranslationContext, op: RemOp): 383 | self.__simple_binary_op(context, mb.mod, op) 384 | 385 | @register_stablehlo_op 386 | def op_floor(self, context: TranslationContext, op: FloorOp): 387 | self.__simple_unary_op(context, mb.floor, op) 388 | 389 | @register_stablehlo_op 390 | def op_ceil(self, context: TranslationContext, op: CeilOp): 391 | self.__simple_unary_op(context, mb.ceil, op) 392 | 393 | @register_stablehlo_op 394 | def op_clamp(self, context: TranslationContext, op: ClampOp): 395 | min = context[op.min.get_name()] 396 | max = context[op.max.get_name()] 397 | operand = context[op.operand.get_name()] 398 | result = mb.minimum(x=mb.maximum(x=operand, y=min), y=max) 399 | context.add_result(op.results[0], result) 400 | 401 | @register_stablehlo_op 402 | def op_sort(self, context: TranslationContext, op: SortOp): 403 | # StableHLO defines sorting via a comparator region (a small function) that returns true if 404 | # element A < element B. CoreML, however, uses high-level primitives. 405 | # To bridge this gap, we must analyze the comparator's structure to reverse-engineer 406 | # the sorting criteria (which keys to sort by and in what direction). 407 | inputs = [context[operand.get_name()] for operand in op.inputs] 408 | if op.is_stable and len(inputs) > 1: 409 | raise ValueError("Stable sorting is not supported for multi-input sorting") 410 | 411 | if len(op.comparator.blocks) != 1: 412 | raise ValueError("Unsupported comparator format: must have exactly one block") 413 | 414 | comparator_block = op.comparator.blocks[0] 415 | return_op = comparator_block.operations[-1] 416 | 417 | if not isinstance(return_op, ReturnOp): 418 | raise ValueError("Unsupported comparator format: last operation must be a return") 419 | 420 | # We start tracing from the return value of the comparator to understand the logic 421 | comparator_root = return_op.operands[0].owner.opview 422 | args = list(comparator_block.arguments) 423 | 424 | # Try to match known sorting patterns 425 | sort_keys = match_sort(comparator_root, args, inputs) 426 | if sort_keys is None: 427 | raise ValueError("Unrecognized comparator format") 428 | 429 | # Apply the sort 430 | sort_dim, (key, ascending) = op.dimension.value, sort_keys[-1] 431 | indices = mb.argsort(x=key, axis=sort_dim, ascending=ascending) 432 | 433 | # Given CoreML's argsort is unstable we are not able to handle multiple sort keys 434 | if len(sort_keys) > 1: 435 | raise ValueError("Having more than one sort key is not supported because MIL's argsort is not supported") 436 | # The following code would be used if CoreML had a stable argsort 437 | # for key, ascending in sort_keys[-2::-1]: 438 | # gathered_key = mb.gather_along_axis(x=key, indices=indices, axis=sort_dim) 439 | # relative_indices = mb.argsort(x=gathered_key, axis=sort_dim, ascending=ascending) 440 | # indices = mb.gather_along_axis(x=indices, indices=relative_indices, axis=sort_dim) 441 | 442 | for i, tensor in enumerate(inputs): 443 | context.add_result(op.results[i], mb.gather_along_axis(x=tensor, indices=indices, axis=sort_dim)) 444 | 445 | @register_stablehlo_op 446 | def op_case(self, context: TranslationContext, op: CaseOp): 447 | index = context[op.index.get_name()] 448 | 449 | def params(i): 450 | closure, args = [], [] 451 | for j in op.branches[i].blocks[0].operations: 452 | for k in j.operands: 453 | if k.get_name() in context.variables[context.path()]: 454 | closure.append(k) 455 | args.append(context[k.get_name()]) 456 | return (closure, op.branches[i], args) 457 | 458 | def build_branch(i): 459 | if i == len(op.branches) - 1: 460 | # Default/Last branch 461 | return self.invoke_hlo_function(context, "branch_default", *params(i)) 462 | 463 | def true_fn(): 464 | return self.invoke_hlo_function(context, f"branch_{i}", *params(i)) 465 | 466 | def false_fn(): 467 | return build_branch(i + 1) 468 | 469 | return mb.cond( 470 | pred=mb.equal(x=index, y=i), 471 | _true_fn=true_fn, 472 | _false_fn=false_fn 473 | ) 474 | 475 | results = build_branch(0) 476 | if not isinstance(results, (list, tuple)): 477 | results = [results] 478 | for i, result in enumerate(results): 479 | context.add_result(op.results[i], result) 480 | 481 | @register_stablehlo_op 482 | def op_reshape(self, context: TranslationContext, op: ReshapeOp): 483 | x = context[op.operand.get_name()] 484 | new_shape = op.result.type.shape 485 | if len(new_shape) == 0: 486 | reshape_res = mb.squeeze(x=x) 487 | else: 488 | reshape_res = mb.reshape(x=x, shape=new_shape) 489 | context.add_result(op.result, reshape_res) 490 | 491 | @register_stablehlo_op 492 | def op_broadcast_in_dim(self, context: TranslationContext, op: BroadcastInDimOp): 493 | x = context[op.operand.get_name()] 494 | 495 | result_shape = op.result.type.shape 496 | if len(result_shape) == 0: 497 | # Cast a scalar shape to a (1,) shape 498 | result_shape = [1] 499 | result_shape_rank = len(result_shape) 500 | 501 | reshaped_operand_shape = [1] * result_shape_rank 502 | for i, op_shape in enumerate(op.operand.type.shape): 503 | result_idx = op.broadcast_dimensions[i] 504 | reshaped_operand_shape[result_idx] = op_shape 505 | 506 | x = mb.reshape(x=x, shape=reshaped_operand_shape) 507 | 508 | result_tiling = [1] * result_shape_rank 509 | for result_dim, current_shape in enumerate(reshaped_operand_shape): 510 | # Replicate data along dimension `dim` until the result dimension matches 511 | assert result_shape[result_dim] % current_shape == 0 512 | result_tiling[result_dim] = result_shape[result_dim] // current_shape 513 | x = mb.tile(x=x, reps=result_tiling) 514 | 515 | context.add_result(op.result, x) 516 | 517 | @register_stablehlo_op 518 | def op_while(self, context: TranslationContext, op: WhileOp): 519 | def cond(*loop_args): 520 | params = [param for param in op.cond.blocks[0].arguments] 521 | outputs = self.invoke_hlo_function(context, "while_cond", params, op.cond, loop_args) 522 | if len(outputs) != 1: 523 | raise ValueError("The output of while_cond should always be a single boolean!") 524 | # TODO(knielsen): Add a check that the output is in fact a single boolean value 525 | 526 | return outputs[0] 527 | 528 | def body(*body_args): 529 | params = [param for param in op.body.blocks[0].arguments] 530 | return self.invoke_hlo_function(context, "while_body", params, op.body, body_args) 531 | 532 | loop_vars = [context[arg.get_name()] for arg in op.operands] 533 | while_results = mb.while_loop(_cond=cond, _body=body, loop_vars=loop_vars) 534 | 535 | for result_var, while_result in zip(op.results, while_results): 536 | context.add_result(result_var, while_result) 537 | 538 | @register_stablehlo_op 539 | def op_compare(self, context: TranslationContext, op: CompareOp): 540 | comparison_direction = hlo.ComparisonDirectionAttr(op.comparison_direction).value 541 | cml_op_builder = { 542 | "EQ": mb.equal, 543 | "NE": mb.not_equal, 544 | "GE": mb.greater_equal, 545 | "GT": mb.greater, 546 | "LE": mb.less_equal, 547 | "LT": mb.less, 548 | }[comparison_direction] 549 | 550 | lhs = context[op.lhs.get_name()] 551 | rhs = context[op.rhs.get_name()] 552 | if types.is_bool(lhs.dtype): 553 | if comparison_direction == "EQ": 554 | cml_op = mb.logical_not(x=mb.logical_xor(x=lhs, y=rhs)) 555 | elif comparison_direction == "NE": 556 | cml_op = mb.logical_xor(x=lhs, y=rhs) 557 | else: 558 | raise ValueError( 559 | f"Boolean comparison operations other than EQ and NE (such as GT, LT, GE, LE) are not supported! " 560 | f"Attempted operation: {comparison_direction}" 561 | ) 562 | else: 563 | cml_op = cml_op_builder(x=lhs, y=rhs) 564 | context.add_result(op.result, cml_op) 565 | 566 | @register_stablehlo_op 567 | def op_convert(self, context: TranslationContext, op: ConvertOp): 568 | x = context[op.operand.get_name()] 569 | new_dtype = get_mil_type_from_ir(op.result.type.element_type) 570 | cml_op = mb.cast(x=x, dtype=dtype_str(new_dtype)) 571 | context.add_result(op.result, cml_op) 572 | 573 | @register_stablehlo_op 574 | def op_select(self, context: TranslationContext, op: SelectOp): 575 | cond = context[op.pred.get_name()] 576 | a = context[op.on_true.get_name()] 577 | b = context[op.on_false.get_name()] 578 | cml_op = mb.select(cond=cond, a=a, b=b) 579 | context.add_result(op.result, cml_op) 580 | 581 | @register_stablehlo_op 582 | def op_dynamic_slice(self, context: TranslationContext, op: DynamicSliceOp): 583 | x = context[op.operand.get_name()] 584 | 585 | # The HLO DynamicSliceOp gives the start indices as seperate 0-dimensional integer variables 586 | # We need to convert them to a tensor to be compatible with mb.slice_by_size 587 | start_idx_variables = [context[i.get_name()] for i in op.start_indices] 588 | begin = mb.concat(values=start_idx_variables, axis=0) 589 | 590 | # The slice sizes in HLO are given by a signed integer with 64 bits 591 | # This is not supported by MIL, so we convert it to a MIL int32 type 592 | # TODO(knielsen): Overflow check? 593 | sizes = np.array(op.slice_sizes, dtype=np.int32) 594 | 595 | # Clamp start indices to ensure they are within bounds: [0, operand_dim - slice_size] 596 | # This is required by the StableHLO specification 597 | shape = mb.shape(x=x) 598 | begin = clamp_index(begin, shape, sizes) 599 | 600 | cml_op = mb.slice_by_size(x=x, begin=begin, size=sizes) 601 | context.add_result(op.result, cml_op) 602 | 603 | @register_stablehlo_op 604 | def op_slice(self, context: TranslationContext, op: SliceOp): 605 | x = context[op.operand.get_name()] 606 | 607 | begin = np.array(op.start_indices, dtype=np.int32) 608 | end = np.array(op.limit_indices, dtype=np.int32) 609 | stride = np.array(op.strides, dtype=np.int32) 610 | 611 | cml_op = mb.slice_by_index( 612 | x=x, 613 | begin=begin, 614 | end=end, 615 | stride=stride, 616 | ) 617 | context.add_result(op.result, cml_op) 618 | 619 | @register_stablehlo_op 620 | def op_dynamic_update_slice(self, context: TranslationContext, op: DynamicUpdateSliceOp): 621 | x = context[op.operand.get_name()] 622 | updates = context[op.update.get_name()] 623 | 624 | start_indices = [context[i.get_name()] for i in op.start_indices] 625 | start_indices = mb.concat(values=start_indices, axis=0) 626 | 627 | # Clamp start indices to ensure they are within bounds: [0, operand_dim - update_dim] 628 | # This is required by the StableHLO specification 629 | shape = mb.shape(x=x) 630 | update_shape = mb.shape(x=updates) 631 | start_indices = clamp_index(start_indices, shape, update_shape) 632 | 633 | end_indices = mb.add(x=start_indices, y=op.update.type.shape) 634 | 635 | update_res = mb.slice_update( 636 | x=x, 637 | update=updates, 638 | begin=start_indices, 639 | end=end_indices, 640 | ) 641 | context.add_result(op.result, update_res) 642 | 643 | @register_stablehlo_op 644 | def op_convolution(self, context: TranslationContext, op: ConvolutionOp): 645 | dim_spec = hlo.ConvDimensionNumbers(op.dimension_numbers) 646 | # TODO(knielsen): It should be possible to remove this batch dimension check, but 647 | # there should be a unit test testing it. 648 | if dim_spec.input_batch_dimension != 0 or dim_spec.output_batch_dimension != 0: 649 | raise ValueError(f"Only the first dimension is currently supported for batch dimension. Got {dim_spec}") 650 | if len(dim_spec.input_spatial_dimensions) > 3 or len(dim_spec.output_spatial_dimensions) > 3: 651 | raise ValueError("MIL only supports convolutions with dim <= 3") 652 | 653 | if op.batch_group_count.value != 1: 654 | raise ValueError(f"Only a batch group count of 1 is supported. Got {op.batch_group_count.value}") 655 | 656 | # MIL expects it on the form [input_batch_dimension, input_feature_dimension, spatial_dimensions*] 657 | input_permutation = [ 658 | dim_spec.input_batch_dimension, 659 | dim_spec.input_feature_dimension, 660 | *dim_spec.input_spatial_dimensions 661 | ] 662 | x = context[op.lhs.get_name()] # The inputs comes from vars 663 | x = mb.transpose(x=x, perm=input_permutation) 664 | 665 | strides = None 666 | if op.window_strides is not None: 667 | strides = np.array(op.window_strides, dtype=np.int32) 668 | 669 | kernel_dilation = None 670 | if op.rhs_dilation is not None: 671 | kernel_dilation = np.array(op.rhs_dilation, dtype=np.int32) 672 | 673 | groups = op.feature_group_count.value 674 | 675 | # Handle padding 676 | # TODO(knielsen): Consider moving splat/non-splat handling to some utility 677 | in_rank = x.rank - 2 678 | if op.padding is None: 679 | pad = np.zeros((2 * in_rank), dtype=np.int32) 680 | elif op.padding.is_splat: 681 | pad = op.padding.get_splat_value().value * np.ones((2 * in_rank), dtype=np.int32) 682 | else: 683 | # We need to reshape the array to a linear array to match MILs expectation 684 | pad = np.reshape(np.array(op.padding, dtype=np.int32), (2 * in_rank, )) 685 | 686 | # We switch the convolution to a transposed convolution if we have lhs_dilation 687 | conv_type = mb.conv 688 | if op.lhs_dilation: 689 | lhs_dilations = np.array(op.lhs_dilation, dtype=np.int32) 690 | if np.any(lhs_dilations > 1): 691 | # This is a transpoed convolution 692 | if strides is not None: 693 | raise ValueError("For a conv with lhs dilation we expect the stride to be not set! " 694 | "Because convolution with input dilation d is equivalent to transposed " 695 | "convolution with stride d.") 696 | # Convolution with input dilation d is equivalent to transposed convolution with stride d 697 | strides = lhs_dilations 698 | 699 | output_shape = [op.result.type.shape[dim_spec.output_batch_dimension], 700 | op.result.type.shape[dim_spec.output_feature_dimension]] 701 | for d in dim_spec.output_spatial_dimensions: 702 | output_shape.append(op.result.type.shape[d]) 703 | 704 | conv_type = partial( 705 | mb.conv_transpose, 706 | output_shape=output_shape 707 | ) 708 | 709 | # Calculate the padding for the transposed convolution 710 | # We need to invert the padding: p_transpose = K - 1 - p_original 711 | # If the target padding is negative, we need to pad the input x 712 | kernel_spatial_dims = dim_spec.kernel_spatial_dimensions 713 | raw_weight_shape = context[op.rhs.get_name()].shape 714 | kernel_sizes = [raw_weight_shape[d] for d in kernel_spatial_dims] 715 | 716 | new_pad_out = [] 717 | pad_in = [] 718 | 719 | for i in range(len(kernel_sizes)): 720 | k = kernel_sizes[i] 721 | s = strides[i] 722 | d = kernel_dilation[i] if kernel_dilation is not None else 1 723 | k_eff = (k - 1) * d + 1 724 | 725 | p_low = pad[2*i] 726 | p_high = pad[2*i+1] 727 | 728 | # Target crop 729 | t_low = k_eff - 1 - p_low 730 | t_high = k_eff - 1 - p_high 731 | 732 | # Calculate input padding needed to satisfy non-negative crop 733 | # pad_in >= ceil(-t / s) 734 | pi_low = max(0, (-t_low + s - 1) // s) 735 | pi_high = max(0, (-t_high + s - 1) // s) 736 | 737 | # Calculate output crop 738 | po_low = t_low + pi_low * s 739 | po_high = t_high + pi_high * s 740 | 741 | new_pad_out.extend([po_low, po_high]) 742 | pad_in.extend([pi_low, pi_high]) 743 | 744 | pad = np.array(new_pad_out, dtype=np.int32) 745 | pad_in = np.array(pad_in, dtype=np.int32) 746 | 747 | if np.any(pad_in > 0): 748 | # Apply padding to x 749 | # x is [batch, channel, spatial...] 750 | x_rank = len(x.shape) 751 | full_pad_in = np.zeros(2 * x_rank, dtype=np.int32) 752 | # Fill spatial padding starting at dimension 2 753 | for i in range(len(pad_in)): 754 | full_pad_in[4 + i] = pad_in[i] 755 | x = pad_with_cast(x=x, pad=full_pad_in) 756 | 757 | if np.any(pad < 0): 758 | raise ValueError("The case where the padding turns negative when translating to a " 759 | "transposed convolution is not supported.") 760 | 761 | # The MIL weights should be on form: 762 | # - normal convolutions: [output_features, input_features / groups, spatial kernels*] 763 | # - transposed convolutions: [input_features, output_features / groups, spatial kernels*] 764 | weight = context[op.rhs.get_name()] # The weights are numpy arrays 765 | weight_permutation = [] 766 | if conv_type == mb.conv: 767 | weight_permutation = [ 768 | dim_spec.kernel_output_feature_dimension, 769 | dim_spec.kernel_input_feature_dimension, 770 | *dim_spec.kernel_spatial_dimensions 771 | ] 772 | else: 773 | weight_permutation = [ 774 | dim_spec.kernel_input_feature_dimension, 775 | dim_spec.kernel_output_feature_dimension, 776 | *dim_spec.kernel_spatial_dimensions 777 | ] 778 | weight = mb.transpose(x=weight, perm=weight_permutation) 779 | 780 | # TODO(knielsen): Make this check more readable! 781 | # It is executed for conv transpose 782 | if conv_type != mb.conv: 783 | # MIL expects the weights to be reversed along the kernel dimensions 784 | kernel_dimensions = [i + 2 for i in range(len(weight.shape) - 2)] 785 | weight = mb.reverse(x=weight, axes=kernel_dimensions) 786 | 787 | cml_conv = conv_type( 788 | x=x, 789 | weight=weight, 790 | strides=strides, 791 | pad_type="custom", 792 | pad=pad, 793 | dilations=kernel_dilation, 794 | groups=groups, 795 | ) 796 | 797 | # Re-arrange output dimensions to match expectation 798 | # MIL outputs on the form [batch, features, spatial dims*] 799 | output_permutation = inverse_permutation([ 800 | dim_spec.output_batch_dimension, 801 | dim_spec.output_feature_dimension, 802 | *dim_spec.output_spatial_dimensions 803 | ]) 804 | cml_conv = mb.transpose(x=cml_conv, perm=output_permutation) 805 | 806 | context.add_result(op.result, cml_conv) 807 | 808 | @register_stablehlo_op 809 | def op_max(self, context: TranslationContext, op: MaxOp): 810 | self.__simple_binary_op(context, mb.maximum, op) 811 | 812 | @register_stablehlo_op 813 | def op_min(self, context: TranslationContext, op: MinOp): 814 | self.__simple_binary_op(context, mb.minimum, op) 815 | 816 | @register_stablehlo_op 817 | def op_rsqrt(self, context: TranslationContext, op: RsqrtOp): 818 | self.__simple_unary_op(context, mb.rsqrt, op) 819 | 820 | @register_stablehlo_op 821 | def op_tanh(self, context: TranslationContext, op: TanhOp): 822 | self.__simple_unary_op(context, mb.tanh, op) 823 | 824 | @register_stablehlo_op 825 | def op_sine(self, context: TranslationContext, op: SineOp): 826 | self.__simple_unary_op(context, mb.sin, op) 827 | 828 | @register_stablehlo_op 829 | def op_cosine(self, context: TranslationContext, op: CosineOp): 830 | self.__simple_unary_op(context, mb.cos, op) 831 | 832 | @register_stablehlo_op 833 | def op_tan(self, context: TranslationContext, op: TanOp): 834 | self.__simple_unary_op(context, mb.tan, op) 835 | 836 | @register_stablehlo_op 837 | def op_atan2(self, context: TranslationContext, op: Atan2Op): 838 | y = context[op.lhs.get_name()] 839 | x = context[op.rhs.get_name()] 840 | # Notice the fraction may be +-inf 841 | fraction = mb.real_div(x=y, y=x) 842 | atan2_res = mb.atan(x=fraction) 843 | # We need to adjust for negative x, based on the sign of y 844 | atan2_res_adjusted = mb.add(x=atan2_res, y=mb.mul(x=mb.sign(x=y), y=np.pi)) 845 | atan2_res = mb.select( 846 | cond=mb.less(x=x, y=0.0), 847 | a=atan2_res_adjusted, 848 | b=atan2_res, 849 | ) 850 | context.add_result(op.result, atan2_res) 851 | 852 | @register_stablehlo_op 853 | def op_concatenate(self, context: TranslationContext, op: ConcatenateOp): 854 | values = [context[input.get_name()] for input in op.inputs] 855 | values = promote_input_dtypes(values) 856 | mil_res = mb.concat(values=values, axis=op.dimension.value) 857 | context.add_result(op.result, mil_res) 858 | 859 | @register_stablehlo_op 860 | def op_reverse(self, context: TranslationContext, op: ReverseOp): 861 | x = context[op.operand.get_name()] 862 | mil_res = mb.reverse(x=x, axes=np.array(op.dimensions, dtype=np.int32)) 863 | context.add_result(op.result, mil_res) 864 | 865 | @register_stablehlo_op 866 | def op_isfinite(self, context: TranslationContext, op: IsFiniteOp): 867 | x = context[op.x.get_name()] 868 | # All finite numbers will have abs(x) < inf 869 | infinity = np.array(np.inf, dtype=get_numpy_type(x)) 870 | mil_res = mb.less(x=mb.abs(x=x), y=infinity) 871 | context.add_result(op.result, mil_res) 872 | 873 | @register_stablehlo_op 874 | def op_reduce(self, context: TranslationContext, op: ReduceOp): 875 | # HLO reductions can be arbitrarily complex and defines a custom function 876 | # specifying the reduction. 877 | # Unforunately this level of granularity is not supported through MIL. 878 | # We try to detect some simple cases for reductions mapping to native MIL 879 | # instructions, and otherwise fall back to a MIL while-loop based implementation. 880 | inputs = [context[input.get_name()] for input in op.inputs] 881 | init_values = [context[init_value.get_name()] for init_value in op.init_values] 882 | result_types = [result.type for result in op.results] 883 | 884 | mil_results = compute_reduction(self, context, inputs, op.dimensions, op.body, init_values, result_types) 885 | for (res, mil_res) in zip(op.results, mil_results): 886 | context.add_result(res, mil_res) 887 | 888 | @register_stablehlo_op 889 | def op_reduce_window(self, context: TranslationContext, op: ReduceWindowOp): 890 | if op.window_dilations and not np.all(op.window_dilations == 1): 891 | raise ValueError("Window dilations are currently unsupported for windowed reduce") 892 | if op.base_dilations and not np.all(op.base_dilations == 1): 893 | raise ValueError("Base dilations are currently unsupported for windowed reduce") 894 | 895 | inputs_rank = len(op.window_dimensions) 896 | window_strides = op.window_strides 897 | if not window_strides: 898 | window_strides = np.ones((inputs_rank,), dtype=np.int32) 899 | 900 | inputs = [context[input.get_name()] for input in op.inputs] 901 | init_values = [context[init_value.get_name()] for init_value in op.init_values] 902 | 903 | # Pad the inputs if required 904 | if op.padding: 905 | padding = np.reshape(np.array(op.padding, dtype=np.int32), (2 * inputs_rank,)) 906 | inputs = [ 907 | pad_with_cast(x=input, pad=padding, constant_val=mb.reduce_max(x=init_value)) 908 | for input, init_value in zip(inputs, init_values) 909 | ] 910 | 911 | # Unfortunately CoreML only supports tensors with rank <= 6. 912 | # Due to the re-shaping and windowing operations inside `__compute_windowed_reduction`, this 913 | # means the function can not be called with tensors of rank >= 4. 914 | # To work around this problem, we have to iterate over the leading dimensions not being 915 | # windowed over, and calculate the result values incrementally. 916 | fixed_dimensions = [] 917 | reduction_dimensions = [] 918 | for axis in range(inputs_rank): 919 | if op.window_dimensions[axis] == 1 and window_strides[axis] == 1: 920 | fixed_dimensions.append(axis) 921 | else: 922 | reduction_dimensions.append(axis) 923 | permutation = fixed_dimensions + reduction_dimensions 924 | 925 | # We will put as few dimensions as possible in the loop_dimensions (i.e. we may 926 | # choose to put some of the `fixedf_dimensions` inside the reduction itself) 927 | max_dims = 3 928 | if len(reduction_dimensions) > max_dims: 929 | raise ValueError("Due to CoreML's rank <= 5 restriction, it is not supported to reduce on more then 3 dimensions!") 930 | loop_dimensions = fixed_dimensions[:max(0, inputs_rank - max_dims)] 931 | loop_shapes = [inputs[0].shape[dim] for dim in loop_dimensions] 932 | loop_shape_rank = len(loop_shapes) 933 | 934 | # Transpose the input so they are easily indexable inside the loop 935 | transposed_inputs = [mb.transpose(x=input, perm=permutation) for input in inputs] 936 | 937 | def compute_reduction(result_idx, *partial_results): 938 | # Pick out the attributes from the dimensions we are reducing over for this index 939 | idx_dims = permutation[loop_shape_rank:] 940 | idx_inputs = [index_by_slices(input, [result_idx] + [...]) for input in transposed_inputs] 941 | idx_window_dimensions = [op.window_dimensions[dim] for dim in idx_dims] 942 | idx_window_strides = [window_strides[dim] for dim in idx_dims] 943 | idx_result_types = [ 944 | index_by_slices(partial_result, [result_idx] + [...]) 945 | for partial_result in partial_results 946 | ] 947 | 948 | if loop_shape_rank > 0: 949 | # We need to squeeze out the loop (result_idx) dimensions 950 | idx_inputs = [ 951 | mb.reshape(x=input, shape=mb.slice_by_size(x=mb.shape(x=input), begin=[loop_shape_rank], size=[-1])) 952 | for input in idx_inputs 953 | ] 954 | idx_result_types = [ 955 | mb.reshape(x=result, shape=mb.slice_by_size(x=mb.shape(x=result), begin=[loop_shape_rank], size=[-1])) 956 | for result in idx_result_types 957 | ] 958 | 959 | results = compute_windowed_reduction( 960 | converter=self, 961 | context=context, 962 | inputs=idx_inputs, 963 | window_dimensions=idx_window_dimensions, 964 | window_strides=idx_window_strides, 965 | body=op.body, 966 | init_values=init_values, 967 | result_types=idx_result_types, 968 | ) 969 | 970 | result_rank = inputs_rank - loop_shape_rank 971 | return [ 972 | update_tensor_by_slice(acc, [result_idx] + [slice(None)] * result_rank, result) 973 | for acc, result in zip(partial_results, results) 974 | ] 975 | 976 | result_types = [result.type for result in op.results] 977 | reduction_results = [ 978 | mb.transpose( 979 | x=np.zeros(result_type.shape, dtype=get_numpy_type(result_type.element_type)), 980 | perm=permutation, 981 | ) 982 | for result_type in result_types 983 | ] 984 | reduction_results = iterate_indexes_in_shapes(compute_reduction, [loop_shapes], reduction_results, unroll_limit=5) 985 | reduction_results = [ 986 | mb.transpose(x=reduction_result, perm=inverse_permutation(permutation)) 987 | for reduction_result in reduction_results 988 | ] 989 | 990 | for (res, mil_res) in zip(op.results, reduction_results): 991 | context.add_result(res, mil_res) 992 | 993 | @register_stablehlo_op 994 | def op_iota(self, context: TranslationContext, op: IotaOp): 995 | res = range_along_dim(op.result.type.shape, int(op.iota_dimension), get_numpy_type(op.result.type.element_type)) 996 | context.add_result(op.result, res) 997 | 998 | @register_stablehlo_op 999 | def op_gather(self, context: TranslationContext, op: GatherOp): 1000 | """ 1001 | Calculates special cases of the GatherOp. Assumes no backing dims, and 1002 | that the index_vector_dim is always the last indexing dimension. 1003 | 1004 | TODO(knielsen): Consider if this can be done in a more efficient way 1005 | """ 1006 | start_indices = context[op.start_indices.get_name()] 1007 | operand = context[op.operand.get_name()] 1008 | 1009 | operand_rank = len(operand.shape) 1010 | start_indices_rank = len(start_indices.shape) 1011 | 1012 | dim_numbers = hlo.GatherDimensionNumbers(op.dimension_numbers) 1013 | dim_mapping = dim_numbers.start_index_map 1014 | dim_batches = dim_numbers.operand_batching_dims 1015 | 1016 | if dim_numbers.index_vector_dim != start_indices_rank - 1: 1017 | raise ValueError("The `index_vector_dim` is only supported to be the last dimension") 1018 | 1019 | # Handle simple gather cases directly, avoiding the while-loop below 1020 | inferred_sizes = np.array([ 1021 | 1 if i in dim_mapping or i in dim_batches else 1022 | operand.shape[i] for i in range(operand_rank)] 1023 | ) 1024 | if dim_batches == dim_numbers.start_indices_batching_dims and \ 1025 | (not dim_batches or np.max(dim_batches) < len(dim_batches)) and \ 1026 | np.all(np.array(op.slice_sizes) == inferred_sizes): 1027 | upper, lower = [operand.shape[i] - 1 for i in dim_mapping], [0] * len(dim_mapping) 1028 | 1029 | def broadcastable(x): 1030 | return np.array(x)[(None,) * (start_indices_rank - 1)] 1031 | clamped_indices = mb.minimum(x=mb.maximum(x=start_indices, y=broadcastable(lower)), y=broadcastable(upper)) 1032 | clamped_indices = mb.gather(x=clamped_indices, indices=np.argsort(dim_mapping), axis=-1) 1033 | if len(dim_mapping) == 1: 1034 | if start_indices_rank > 1: 1035 | clamped_indices = mb.squeeze(x=clamped_indices, axes=(start_indices_rank - 1,)) 1036 | result = mb.gather(x=operand, indices=clamped_indices, axis=dim_mapping[0], batch_dims=len(dim_batches)) 1037 | context.add_result(op.result, result) 1038 | return 1039 | elif np.max(dim_mapping) < len(dim_mapping) + len(dim_batches): 1040 | result = mb.gather_nd(x=operand, indices=clamped_indices, batch_dims=len(dim_batches)) 1041 | window_outputs = [ 1042 | i for i in range(operand_rank) 1043 | if i not in dim_batches and i not in dim_numbers.collapsed_slice_dims 1044 | ] 1045 | window_outputs = [j for i, j in zip(window_outputs, dim_numbers.offset_dims) if op.slice_sizes[i] == 1] 1046 | if window_outputs: 1047 | result = mb.expand_dims(x=result, axes=window_outputs) 1048 | context.add_result(op.result, result) 1049 | return 1050 | 1051 | result_rank = len(op.result.type.shape) 1052 | slice_sizes = op.slice_sizes 1053 | result_iteration_axes = [axis for axis in range(result_rank) if axis not in dim_numbers.offset_dims] 1054 | 1055 | def compute_index_slice(slice_idx, *partial_results): 1056 | partial_results = partial_results[0] 1057 | 1058 | slice_start = [] 1059 | slice_end = [] 1060 | 1061 | for operand_dim in range(operand_rank): 1062 | if operand_dim in dim_numbers.start_index_map: 1063 | start_index_dim = dim_numbers.start_index_map.index(operand_dim) 1064 | elements = operand.shape[operand_dim] 1065 | 1066 | start_index = index_by_slices(start_indices, [slice_idx] + [start_index_dim]) 1067 | start_index = mb.reshape(x=start_index, shape=(1,)) 1068 | 1069 | actual_start_index = mb.maximum(x=mb.minimum(x=start_index, y=elements - slice_sizes[operand_dim]), y=0) 1070 | end_index = mb.add(x=actual_start_index, y=slice_sizes[operand_dim]) 1071 | slice_start.append(actual_start_index) 1072 | slice_end.append(end_index) 1073 | elif operand_dim in dim_numbers.operand_batching_dims: 1074 | batch_index = dim_numbers.operand_batching_dims.index(operand_dim) 1075 | slice_batch = dim_numbers.start_indices_batching_dims[batch_index] 1076 | start_index = mb.slice_by_size(x=slice_idx, begin=(slice_batch,), size=(1,)) 1077 | slice_start.append(start_index) 1078 | slice_end.append(mb.add(x=start_index, y=1)) 1079 | elif operand_dim in dim_numbers.collapsed_slice_dims: 1080 | slice_start.append(mb.reshape(x=0, shape=(1,))) 1081 | slice_end.append(mb.reshape(x=1, shape=(1,))) 1082 | else: 1083 | slice_start.append(mb.reshape(x=0, shape=(1,))) 1084 | slice_end.append(mb.reshape(x=slice_sizes[operand_dim], shape=(1,))) 1085 | 1086 | selected_slice = mb.slice_by_index( 1087 | x=operand, 1088 | begin=mb.concat(values=slice_start, axis=0), 1089 | end=mb.concat(values=slice_end, axis=0), 1090 | ) 1091 | if len(dim_numbers.collapsed_slice_dims) > 0: 1092 | selected_slice = mb.squeeze(x=selected_slice, axes=dim_numbers.collapsed_slice_dims) 1093 | 1094 | # Figure out which result to update 1095 | update_slice_spec = [] 1096 | stack_axes_idx = 0 1097 | for output_dim in range(result_rank): 1098 | if output_dim in result_iteration_axes: 1099 | result_idx = mb.gather(x=slice_idx, indices=[stack_axes_idx]) 1100 | update_slice_spec.append(result_idx) 1101 | stack_axes_idx += 1 1102 | else: 1103 | update_slice_spec.append(slice(None)) 1104 | return [update_tensor_by_slice(partial_results, update_slice_spec, selected_slice)] 1105 | 1106 | result_dtype = get_mil_type_from_ir(op.result.type.element_type) 1107 | result = mb.fill(shape=op.result.type.shape, value=mb.cast(x=0, dtype=dtype_str(result_dtype))) 1108 | result_iteration_shape = [result.shape[stack_axis] for stack_axis in result_iteration_axes] 1109 | result, = iterate_indexes_in_shapes(compute_index_slice, [result_iteration_shape], [result], unroll_limit=5) 1110 | 1111 | context.add_result(op.result, result) 1112 | 1113 | @register_stablehlo_op 1114 | def op_scatter(self, context: TranslationContext, op: ScatterOp): 1115 | dim_numbers = hlo.ScatterDimensionNumbers(op.scatter_dimension_numbers) 1116 | dim_mapping = dim_numbers.scattered_dims_to_operand_dims 1117 | operand = context[op.inputs[0].get_name()] 1118 | scatter_indices = context[op.scatter_indices.get_name()] 1119 | updates = context[op.updates[0].get_name()] 1120 | 1121 | if len(dim_numbers.input_batching_dims) > 0: 1122 | raise ValueError("Scatter batching index is not supported!") 1123 | if len(op.inputs) != 1 or len(op.updates) != 1: 1124 | raise ValueError("Scatter with multiple operands is not supported!") 1125 | 1126 | scatter_indices_rank = len(scatter_indices.shape) 1127 | if scatter_indices_rank == 0 or 0 in scatter_indices.shape: 1128 | # Special case for empty scatter indices 1129 | context.add_result(op.results[0], operand) 1130 | return 1131 | 1132 | if np.max(dim_mapping) >= len(dim_mapping): 1133 | raise ValueError("Scatter windows are only supported with dimension numbers contiguous with the rank!") 1134 | # MIL only supports scatter window update sizes that match the operand shape 1135 | # updates must be the shape as `indices.shape[:-1] + data.shape[indices.shape[-1]:]` 1136 | # [sic] via 1137 | # https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.scatter_gather.scatter_nd 1138 | if scatter_indices.shape != (1,) and \ 1139 | updates.shape != scatter_indices.shape[:-1] + operand.shape[scatter_indices.shape[-1]:]: 1140 | raise ValueError("Scatter windows that only partially fill dimensions are not supported!") 1141 | 1142 | # this can be done pre-emptively because of the constraint on scatter windows 1143 | scatter_indices = mb.gather(x=scatter_indices, indices=np.argsort(dim_mapping), axis=-1) 1144 | 1145 | # StableHLO supports arbitrary scatter computations, but MIL has a fixed set 1146 | # We try to match the update computation to a known binary operation 1147 | _, mil_binary_op, mode = match_computation(op.update_computation) 1148 | 1149 | if mil_binary_op is None: 1150 | raise ValueError("Unsupported update mode for scatter operation") 1151 | 1152 | upper_bound = np.array(operand.shape[:len(dim_mapping)], dtype=np.int32)[(None,) * (scatter_indices_rank - 1)] 1153 | valid = mb.logical_and( 1154 | x=mb.greater_equal(x=scatter_indices, y=0), 1155 | y=mb.less(x=scatter_indices, y=upper_bound) 1156 | ) 1157 | 1158 | def along(n): 1159 | return mb.slice_by_index( 1160 | x=valid, begin=(0,) * (scatter_indices_rank - 1) + (n,), 1161 | end=scatter_indices.shape[:-1] + (n + 1,) 1162 | ) 1163 | 1164 | # unrolling O(scatter_indices.rank) 1165 | reduction = along(0) 1166 | for i in range(1, scatter_indices.shape[-1]): 1167 | reduction = mb.logical_and(x=reduction, y=along(i)) 1168 | reduction = mb.squeeze(x=reduction, axes=(scatter_indices_rank - 1,)) 1169 | 1170 | # Special handling for rank-0 reduction (single index update). 1171 | # It supports updating a window that is a subset of the dimension (partial update), 1172 | # which `scatter_nd` does not support (it only supports full slice updates). 1173 | if reduction.rank == 0: 1174 | assert scatter_indices.shape == (1,), \ 1175 | f"unexpected input shape for scatter indices of {scatter_indices.shape}" 1176 | assert updates.rank == operand.rank 1177 | 1178 | # The index to update 1179 | update_index = scatter_indices 1180 | 1181 | # Helper to construct end indices for slicing 1182 | # If rank <= 1, it's just the bound. Otherwise, it's [bound, dim1, dim2, ...] 1183 | def get_end_indices(operand, bound): 1184 | if operand.rank <= 1: 1185 | return bound 1186 | return mb.concat(values=(bound, operand.shape[1:]), axis=0) 1187 | 1188 | # 1. Slice before the update index 1189 | # operand[:update_index] 1190 | before = mb.slice_by_index( 1191 | x=operand, 1192 | begin=(0,) * operand.rank, 1193 | end=get_end_indices(operand, update_index) 1194 | ) 1195 | 1196 | # 2. Slice after the update index 1197 | # operand[update_index+window_size:] 1198 | # We need to clamp the start index for 'after' to be at most operand.shape[0] 1199 | # to avoid out of bounds. 1200 | update_window_size = updates.shape[0] 1201 | update_end_index = mb.minimum( 1202 | x=mb.add(x=update_index, y=update_window_size), 1203 | y=operand.shape[0] 1204 | ) 1205 | after = mb.slice_by_index( 1206 | x=operand, 1207 | begin=get_end_indices(operand, update_end_index), 1208 | end=operand.shape 1209 | ) 1210 | 1211 | # 3. The update value itself 1212 | # We need to extract the current value at the update index to apply the update operation 1213 | # operand[update_index:update_index+window_size] 1214 | current_value_slice = mb.slice_by_index( 1215 | x=operand, 1216 | begin=get_end_indices(operand, update_index), 1217 | end=get_end_indices(operand, update_end_index) 1218 | ) 1219 | 1220 | # Apply the update computation (add, mul, etc.) 1221 | new_value_slice = mil_binary_op(x=current_value_slice, y=updates) 1222 | 1223 | # 4. Concatenate parts to form the result 1224 | # [before, new_value, after] 1225 | # We only do this if the index is valid (reduction condition) 1226 | # 'reduction' here is actually a boolean scalar indicating if the index is valid 1227 | is_valid_index = reduction 1228 | 1229 | result_if_valid = mb.concat(values=(before, new_value_slice, after), axis=0) 1230 | result = mb.select( 1231 | cond=is_valid_index, 1232 | a=result_if_valid, 1233 | b=operand 1234 | ) 1235 | else: 1236 | where = mb.non_zero(x=reduction) 1237 | scatter_indices = mb.gather_nd(x=scatter_indices, indices=where) 1238 | updates = mb.gather_nd(x=updates, indices=where) 1239 | result = mb.scatter_nd(data=operand, indices=scatter_indices, updates=updates, mode=mode) 1240 | context.add_result(op.results[0], result) 1241 | 1242 | @register_stablehlo_op 1243 | def op_custom_call(self, context: TranslationContext, op: CustomCallOp): 1244 | if op.call_target_name.value.startswith("mhlo."): 1245 | mapped_op = None 1246 | op_impl = None 1247 | match op.call_target_name.value: 1248 | case "mhlo.topk": 1249 | mapped_op = TopKOp 1250 | op_impl = self._op_mhlo_topk 1251 | case "mhlo.asin": 1252 | mapped_op = AsinOp 1253 | op_impl = self._op_mhlo_asin 1254 | case "mhlo.sinh": 1255 | mapped_op = SinhOp 1256 | op_impl = self._op_mhlo_sinh 1257 | case "mhlo.asinh": 1258 | mapped_op = AsinhOp 1259 | op_impl = self._op_mhlo_asinh 1260 | case "mhlo.acos": 1261 | mapped_op = AcosOp 1262 | op_impl = self._op_mhlo_acos 1263 | case "mhlo.cosh": 1264 | mapped_op = CoshOp 1265 | op_impl = self._op_mhlo_cosh 1266 | case "mhlo.acosh": 1267 | mapped_op = AcoshOp 1268 | op_impl = self._op_mhlo_acosh 1269 | case "mhlo.atanh": 1270 | mapped_op = AtanhOp 1271 | op_impl = self._op_mhlo_atanh 1272 | 1273 | if not mapped_op: 1274 | raise ValueError(f"mhlo op '{op.call_target_name.value}' is not implemented") 1275 | if not op_impl: 1276 | raise ValueError(f"mhlo op '{op.call_target_name.value}' does not have an implementation") 1277 | 1278 | mhlo_attributes = {attr.name: attr.attr for attr in list(op.attributes["mhlo.attributes"])} 1279 | delegate_op = partial(mapped_op, **mhlo_attributes, loc=op.location)(*op.operands) 1280 | 1281 | # We manually have to handle the results, as the current API does not allow naming 1282 | # the `delegate_op` results according to the custom call results 1283 | mil_results = op_impl(context, delegate_op) 1284 | for (custom_call_result, mil_result) in zip(op.results, mil_results): 1285 | context.add_result(custom_call_result, mil_result) 1286 | 1287 | return 1288 | 1289 | raise ValueError(f"Custom call is not supported: {op.call_target_name}") 1290 | 1291 | def _op_mhlo_topk(self, context: TranslationContext, op: TopKOp): 1292 | """ 1293 | This is a MHLO op, and follows a slightly different pattern, since it is unvoked by a 1294 | custom call. It will return the results, as we currently can not rename the results 1295 | in the TopKOp 1296 | """ 1297 | x = context[op.operand.get_name()] 1298 | descending = op.largest is None or op.largest.value 1299 | mil_res = mb.topk(x=x, k=op.k.value, ascending=not descending) 1300 | return mil_res 1301 | 1302 | def _op_mhlo_asin(self, context: TranslationContext, op: AsinOp): 1303 | x = context[op.operand.get_name()] 1304 | mil_res = mb.asin(x=x) 1305 | return [mil_res] 1306 | 1307 | def _op_mhlo_sinh(self, context: TranslationContext, op: SinhOp): 1308 | x = context[op.operand.get_name()] 1309 | mil_res = mb.sinh(x=x) 1310 | return [mil_res] 1311 | 1312 | def _op_mhlo_asinh(self, context: TranslationContext, op: AsinhOp): 1313 | x = context[op.operand.get_name()] 1314 | # asinh(x) = log(x + sqrt(x^2 + 1)) 1315 | x_sq = mb.mul(x=x, y=x) 1316 | x_sq_plus_1 = mb.add(x=x_sq, y=1.0) 1317 | sqrt_part = mb.sqrt(x=x_sq_plus_1) 1318 | log_arg = mb.add(x=x, y=sqrt_part) 1319 | mil_res = mb.log(x=log_arg) 1320 | return [mil_res] 1321 | 1322 | def _op_mhlo_acos(self, context: TranslationContext, op: AcosOp): 1323 | x = context[op.operand.get_name()] 1324 | mil_res = mb.acos(x=x) 1325 | return [mil_res] 1326 | 1327 | def _op_mhlo_cosh(self, context: TranslationContext, op: CoshOp): 1328 | x = context[op.operand.get_name()] 1329 | mil_res = mb.cosh(x=x) 1330 | return [mil_res] 1331 | 1332 | def _op_mhlo_acosh(self, context: TranslationContext, op: AcoshOp): 1333 | x = context[op.operand.get_name()] 1334 | # acosh(x) = log(x + sqrt(x^2 - 1)) 1335 | x_sq = mb.mul(x=x, y=x) 1336 | x_sq_minus_1 = mb.sub(x=x_sq, y=1.0) 1337 | sqrt_part = mb.sqrt(x=x_sq_minus_1) 1338 | log_arg = mb.add(x=x, y=sqrt_part) 1339 | mil_res = mb.log(x=log_arg) 1340 | return [mil_res] 1341 | 1342 | def _op_mhlo_atanh(self, context: TranslationContext, op: AtanhOp): 1343 | x = context[op.operand.get_name()] 1344 | # atanh(x) = 0.5 * log((1 + x) / (1 - x)) 1345 | one_plus_x = mb.add(x=1.0, y=x) 1346 | one_minus_x = mb.sub(x=1.0, y=x) 1347 | div_res = mb.real_div(x=one_plus_x, y=one_minus_x) 1348 | log_res = mb.log(x=div_res) 1349 | mil_res = mb.mul(x=0.5, y=log_res) 1350 | return [mil_res] 1351 | 1352 | def invoke_hlo_function(self, context: TranslationContext, func_name: str, hlo_params, hlo_func_body, cml_args): 1353 | # Enter variable context for the function call 1354 | context.push_function(func_name) 1355 | 1356 | # Setup arguments for the function 1357 | for hlo_func_param, actual_arg in zip(hlo_params, cml_args): 1358 | context.add_result(hlo_func_param, actual_arg) 1359 | 1360 | # Process the function 1361 | if len(hlo_func_body.blocks) != 1: 1362 | raise ValueError(f"Unsupported function with {len(hlo_func_body.blocks)} blocks") 1363 | outputs = self.process_block(context, hlo_func_body.blocks[0]) 1364 | 1365 | # Exit the function context 1366 | context.pop_function() 1367 | 1368 | return outputs 1369 | 1370 | def __simple_unary_op(self, context: TranslationContext, mil_op, hlo_op): 1371 | operand = context[hlo_op.operand.get_name()] 1372 | cml_op = mil_op(x=operand) 1373 | context.add_result(hlo_op.result, cml_op) 1374 | 1375 | def __simple_binary_op(self, context: TranslationContext, mil_op, hlo_op): 1376 | lhs = context[hlo_op.lhs.get_name()] 1377 | rhs = context[hlo_op.rhs.get_name()] 1378 | cml_op = mil_op(x=lhs, y=rhs) 1379 | context.add_result(hlo_op.result, cml_op) 1380 | --------------------------------------------------------------------------------