├── tests ├── __init__.py ├── ir │ ├── __init__.py │ └── test_peephole.py ├── codegen │ └── __init__.py ├── test_sum.py ├── test_numpy.py ├── test_operators.py ├── test_tensor_method.py ├── test_cli.py ├── test_format.py ├── test_desugar.py ├── test_evaluate.py ├── test_expression.py └── test_combinatorically.py ├── .python-version ├── fuzz_tests ├── __init__.py ├── test_cli.py ├── test_generate.py ├── test_parsing.py └── strategies.py ├── .gitattributes ├── docs ├── CNAME ├── index.md ├── getting-started.md ├── contributing.md ├── tensors.md ├── evaluate.md └── creation.md ├── src └── tensora │ ├── generate │ ├── __init__.py │ ├── _base.py │ ├── _tensora.py │ └── _deparse_to_taco.py │ ├── iteration_graph │ ├── __init__.py │ ├── outputs │ │ ├── __init__.py │ │ ├── _base.py │ │ ├── _bucket.py │ │ └── _append.py │ ├── identifiable_expression │ │ ├── __init__.py │ │ ├── ast.py │ │ ├── _to_ir.py │ │ ├── _exhaust_tensor.py │ │ ├── _tensor_layer.py │ │ └── _extract_context.py │ ├── _definition.py │ ├── _names.py │ ├── _write_sparse_ir.py │ └── iteration_graph.py │ ├── codegen │ ├── __init__.py │ ├── _type_to_c.py │ ├── _hoist_declarations.py │ ├── _type_to_llvm.py │ └── _ir_to_c.py │ ├── ir │ ├── __init__.py │ ├── types.py │ ├── _builder.py │ ├── ast.py │ └── _peephole.py │ ├── __init__.py │ ├── expression │ ├── __init__.py │ ├── _exceptions.py │ ├── _parser.py │ └── ast.py │ ├── format │ ├── __init__.py │ ├── _exceptions.py │ ├── _format.py │ └── _parser.py │ ├── desugar │ ├── __init__.py │ ├── _to_identifiable.py │ ├── _best_algorithm.py │ ├── _exceptions.py │ ├── ast.py │ ├── _index_dimensions.py │ └── _desugar_expression.py │ ├── compile │ ├── _initialize_llvm.py │ ├── __init__.py │ ├── _compile_llvm.py │ ├── _compile_cffi.py │ ├── _porcelain.py │ └── _tensor_method.py │ ├── kernel_type.py │ ├── _stable_set.py │ ├── cli.py │ └── problem.py ├── tests_cffi ├── __init__.py ├── test_sum.py ├── test_evaluate.py └── test_combinatorically.py ├── .gitignore ├── .github └── 
workflows │ ├── release.yml │ └── ci.yml ├── LICENSE ├── mkdocs.yml ├── noxfile.py └── pyproject.toml /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.13.3 -------------------------------------------------------------------------------- /fuzz_tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/ir/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/codegen/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto 2 | -------------------------------------------------------------------------------- /docs/CNAME: -------------------------------------------------------------------------------- 1 | tensora.drhagen.com -------------------------------------------------------------------------------- /src/tensora/generate/__init__.py: -------------------------------------------------------------------------------- 1 | from ._base import Language, generate_code 2 | from ._tensora import generate_module_tensora 3 | -------------------------------------------------------------------------------- /src/tensora/iteration_graph/__init__.py: -------------------------------------------------------------------------------- 1 | from ._definition import Definition, TensorDimension 2 | from ._generate_ir import generate_ir 3 | 
-------------------------------------------------------------------------------- /src/tensora/codegen/__init__.py: -------------------------------------------------------------------------------- 1 | from ._ir_to_c import ir_to_c, ir_to_c_function_definition, ir_to_c_statement 2 | from ._ir_to_llvm import ir_to_llvm 3 | -------------------------------------------------------------------------------- /src/tensora/ir/__init__.py: -------------------------------------------------------------------------------- 1 | from ._builder import SourceBuilder 2 | from ._peephole import peephole, peephole_function_definition, peephole_statement 3 | -------------------------------------------------------------------------------- /src/tensora/iteration_graph/outputs/__init__.py: -------------------------------------------------------------------------------- 1 | from ._append import AppendOutput 2 | from ._base import Output 3 | from ._bucket import BucketOutput 4 | -------------------------------------------------------------------------------- /src/tensora/__init__.py: -------------------------------------------------------------------------------- 1 | from .compile import BackendCompiler, evaluate, evaluate_tensora, tensor_method 2 | from .format import Format, Mode 3 | from .tensor import Tensor 4 | -------------------------------------------------------------------------------- /src/tensora/expression/__init__.py: -------------------------------------------------------------------------------- 1 | from ._exceptions import InconsistentDimensionsError, MutatingAssignmentError, NameConflictError 2 | from ._parser import parse_assignment 3 | -------------------------------------------------------------------------------- /src/tensora/format/__init__.py: -------------------------------------------------------------------------------- 1 | from ._exceptions import InvalidModeOrderingError 2 | from ._format import Format, Mode 3 | from ._parser import parse_format, parse_named_format 
4 | -------------------------------------------------------------------------------- /src/tensora/iteration_graph/identifiable_expression/__init__.py: -------------------------------------------------------------------------------- 1 | from ._exhaust_tensor import exhaust_tensor 2 | from ._extract_context import Context, extract_context 3 | from ._tensor_layer import TensorLayer 4 | from ._to_ir import to_ir 5 | -------------------------------------------------------------------------------- /tests_cffi/__init__.py: -------------------------------------------------------------------------------- 1 | # This folder cannot be nested until the import needs of the base packages and 2 | # the cffi extra are more differentiated. Right now, the cffi extra needs only 3 | # the setuptools package and only in Python 3.12. 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # python 2 | **/__pycache__ 3 | 4 | # poetry 5 | /.venv 6 | /dist 7 | 8 | # nox 9 | /.nox 10 | 11 | # coverage 12 | /.coverage* 13 | /htmlcov 14 | /coverage.xml 15 | 16 | # hypothesis 17 | /.hypothesis 18 | 19 | # mkdocs 20 | /site 21 | -------------------------------------------------------------------------------- /src/tensora/desugar/__init__.py: -------------------------------------------------------------------------------- 1 | from ._best_algorithm import best_algorithm 2 | from ._desugar_expression import desugar_assignment 3 | from ._exceptions import DiagonalAccessError, NoKernelFoundError 4 | from ._index_dimensions import index_dimensions 5 | from ._to_identifiable import to_identifiable 6 | -------------------------------------------------------------------------------- /tests/test_sum.py: -------------------------------------------------------------------------------- 1 | from tensora import Tensor, evaluate 2 | 3 | 4 | def test_sum_non_adjacent(): 5 | b = 
Tensor.from_lol([1, 2, 3]) 6 | c = Tensor.from_lol([[1, 3, 5], [2, 4, 6]]) 7 | d = Tensor.from_lol([7, 8, 9]) 8 | actual = evaluate("a(i) = b(i) + c(j,i) + d(i)", "d", b=b, c=c, d=d) 9 | expected = Tensor.from_lol([11, 17, 23]) 10 | assert actual == expected 11 | -------------------------------------------------------------------------------- /src/tensora/compile/_initialize_llvm.py: -------------------------------------------------------------------------------- 1 | __all__ = ["target"] 2 | 3 | import llvmlite.binding as llvm 4 | 5 | # Initialize the LLVM 6 | # https://llvmlite.readthedocs.io/en/latest/user-guide/binding/examples.html 7 | llvm.initialize_native_target() 8 | llvm.initialize_native_asmprinter() 9 | 10 | # Create the target representing the current host 11 | target = llvm.Target.from_default_triple() 12 | -------------------------------------------------------------------------------- /src/tensora/compile/__init__.py: -------------------------------------------------------------------------------- 1 | from ._cffi_ownership import ( 2 | allocate_taco_structure, 3 | taco_structure_to_cffi, 4 | take_ownership_of_arrays, 5 | take_ownership_of_tensor, 6 | take_ownership_of_tensor_members, 7 | tensor_cdefs, 8 | ) 9 | from ._initialize_llvm import target 10 | from ._porcelain import evaluate, evaluate_cffi, evaluate_tensora, tensor_method 11 | from ._tensor_method import BackendCompiler, BroadcastTargetIndexError, TensorMethod 12 | -------------------------------------------------------------------------------- /fuzz_tests/test_cli.py: -------------------------------------------------------------------------------- 1 | from hypothesis import assume, given 2 | from hypothesis import strategies as st 3 | from typer.testing import CliRunner 4 | 5 | from tensora.cli import app 6 | 7 | runner = CliRunner() 8 | 9 | 10 | @given(command=st.lists(st.text())) 11 | def test_cli_cannot_crash(command): 12 | # Arguments to a CLI cannot contain null bytes. 
13 | assume(not any("\0" in string for string in command)) 14 | 15 | _ = runner.invoke(app, command, catch_exceptions=False) 16 | -------------------------------------------------------------------------------- /tests_cffi/test_sum.py: -------------------------------------------------------------------------------- 1 | from tensora import BackendCompiler, Tensor, tensor_method 2 | 3 | 4 | def test_sum_non_adjacent(): 5 | b = Tensor.from_lol([1, 2, 3]) 6 | c = Tensor.from_lol([[1, 3, 5], [2, 4, 6]]) 7 | d = Tensor.from_lol([7, 8, 9]) 8 | function = tensor_method( 9 | "a(i) = b(i) + c(j,i) + d(i)", formats={}, backend=BackendCompiler.cffi 10 | ) 11 | actual = function(b=b, c=c, d=d) 12 | expected = Tensor.from_lol([11, 17, 23]) 13 | assert actual == expected 14 | -------------------------------------------------------------------------------- /src/tensora/iteration_graph/_definition.py: -------------------------------------------------------------------------------- 1 | __all__ = ["Definition", "TensorDimension"] 2 | 3 | from dataclasses import dataclass 4 | 5 | from ..format import Format 6 | from .identifiable_expression.ast import Tensor 7 | 8 | 9 | @dataclass(frozen=True, slots=True) 10 | class TensorDimension: 11 | name: str 12 | dimension: int 13 | 14 | 15 | @dataclass(frozen=True, slots=True) 16 | class Definition: 17 | output_variable: Tensor 18 | formats: dict[str, Format] 19 | indexes: dict[str, TensorDimension] 20 | -------------------------------------------------------------------------------- /src/tensora/desugar/_to_identifiable.py: -------------------------------------------------------------------------------- 1 | __all__ = ["to_identifiable"] 2 | 3 | from ..format import Format 4 | from ..iteration_graph.identifiable_expression import ast as id 5 | from . 
import ast as desugar 6 | 7 | 8 | def to_identifiable(self: desugar.Tensor, formats: dict[str, Format]) -> id.Tensor: 9 | format = formats[self.name] 10 | return id.Tensor( 11 | f"{self.id}_{self.name}", 12 | self.name, 13 | tuple(self.indexes[i_index] for i_index in format.ordering), 14 | format.modes, 15 | ) 16 | -------------------------------------------------------------------------------- /src/tensora/format/_exceptions.py: -------------------------------------------------------------------------------- 1 | __all__ = ["InvalidModeOrderingError"] 2 | 3 | from dataclasses import dataclass 4 | 5 | from ._format import Mode 6 | 7 | 8 | @dataclass(frozen=True, slots=True) 9 | class InvalidModeOrderingError(Exception): 10 | modes: tuple[Mode, ...] 11 | ordering: tuple[int, ...] 12 | 13 | def __str__(self): 14 | return ( 15 | f"Expected ordering to have be some order of integers 0 until length of modes, " 16 | f"but got modes={self.modes} and ordering={self.ordering}" 17 | ) 18 | -------------------------------------------------------------------------------- /src/tensora/kernel_type.py: -------------------------------------------------------------------------------- 1 | __all__ = ["KernelType"] 2 | 3 | from enum import Enum 4 | 5 | 6 | class KernelType(str, Enum): 7 | # Python 3.10 does not support StrEnum, so do it manually 8 | assemble = "assemble" 9 | compute = "compute" 10 | evaluate = "evaluate" 11 | 12 | def is_assemble(self): 13 | return self == KernelType.assemble or self == KernelType.evaluate 14 | 15 | def is_compute(self): 16 | return self == KernelType.compute or self == KernelType.evaluate 17 | 18 | def __str__(self) -> str: 19 | return self.name 20 | -------------------------------------------------------------------------------- /fuzz_tests/test_generate.py: -------------------------------------------------------------------------------- 1 | from hypothesis import given 2 | 3 | from tensora.compile import BroadcastTargetIndexError, TensorMethod 4 | 
from tensora.desugar import DiagonalAccessError, NoKernelFoundError 5 | 6 | from .strategies import problem_and_tensors 7 | 8 | 9 | @given(problem_and_tensors()) 10 | def test_generate_cannot_crash(problem_inputs): 11 | problem, input_tensors = problem_inputs 12 | 13 | try: 14 | method = TensorMethod(problem) 15 | except (BroadcastTargetIndexError, DiagonalAccessError, NoKernelFoundError): 16 | return 17 | 18 | _ = method(**input_tensors) 19 | -------------------------------------------------------------------------------- /tests/test_numpy.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tensora import Tensor 4 | 5 | 6 | @pytest.mark.parametrize( 7 | "array", 8 | [ 9 | 0.0, 10 | 4.5, 11 | [], 12 | [[], []], 13 | [0, 0, 0], 14 | [[0, 1, 2], [0, 4, 5]], 15 | [[[0, 0, 3], [4, 5, 0]], [[0, 0, 0], [4, 5, 6]]], 16 | ], 17 | ) 18 | @pytest.mark.parametrize("format", ["d", "s"]) 19 | def test_to_from_numpy(array, format): 20 | numpy = pytest.importorskip("numpy") 21 | 22 | expected = numpy.array(array) 23 | 24 | tensor = Tensor.from_numpy(expected, format=format * expected.ndim) 25 | actual = Tensor.to_numpy(tensor) 26 | 27 | assert numpy.array_equal(actual, expected) 28 | -------------------------------------------------------------------------------- /tests/test_operators.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tensora import Tensor 4 | 5 | 6 | @pytest.fixture 7 | def a_ds(): 8 | return Tensor.from_aos( 9 | [[1, 0], [0, 1], [1, 2]], [2.0, -2.0, 4.0], dimensions=(2, 3), format="ds" 10 | ) 11 | 12 | 13 | @pytest.fixture 14 | def b_ds(): 15 | return Tensor.from_aos( 16 | [[1, 1], [1, 2], [0, 2]], [-3.0, 4.0, 3.5], dimensions=(2, 3), format="ds" 17 | ) 18 | 19 | 20 | @pytest.fixture 21 | def c_ds_add(): 22 | return Tensor.from_aos( 23 | [[1, 0], [0, 1], [1, 2], [1, 1], [0, 2]], 24 | [2.0, -2.0, 8.0, -3.0, 3.5], 25 | dimensions=(2, 
3), 26 | format="ds", 27 | ) 28 | 29 | 30 | def test_add_ds_ds(a_ds, b_ds, c_ds_add): 31 | actual = a_ds + b_ds 32 | 33 | assert actual == c_ds_add 34 | -------------------------------------------------------------------------------- /src/tensora/iteration_graph/outputs/_base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | __all__ = ["Output"] 4 | 5 | from abc import abstractmethod 6 | 7 | from ...ir import SourceBuilder 8 | from ...ir.ast import Expression 9 | from ...kernel_type import KernelType 10 | from ..identifiable_expression import TensorLayer 11 | 12 | 13 | class Output: 14 | __slots__ = () 15 | 16 | @abstractmethod 17 | def write_assignment( 18 | self, right_hand_side: Expression, kernel_type: KernelType 19 | ) -> SourceBuilder: 20 | raise NotImplementedError() 21 | 22 | @abstractmethod 23 | def next_output( 24 | self, iteration_output: TensorLayer | None, kernel_type: KernelType 25 | ) -> tuple[Output, SourceBuilder, SourceBuilder]: 26 | raise NotImplementedError() 27 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | 8 | env: 9 | uv-version: "0.9.0" 10 | 11 | jobs: 12 | pypi-publish: 13 | name: Publish release to PyPI 14 | runs-on: ubuntu-22.04 15 | environment: release 16 | permissions: 17 | id-token: write 18 | steps: 19 | - name: Checkout code 20 | uses: actions/checkout@v4 21 | - name: Install uv 22 | uses: astral-sh/setup-uv@v5 23 | with: 24 | version: ${{ env.uv-version }} 25 | - name: Build release with uv 26 | run: uv build 27 | - name: Check that tag version and project version match 28 | run: '[[ "v$(uv version --short)" == "${{ github.ref_name }}" ]]' 29 | - name: Upload distribution to PyPI 30 | uses: pypa/gh-action-pypi-publish@release/v1 31 | 
-------------------------------------------------------------------------------- /src/tensora/desugar/_best_algorithm.py: -------------------------------------------------------------------------------- 1 | __all__ = ["best_algorithm"] 2 | 3 | from returns.result import Failure, Result, Success 4 | 5 | from ..format import Format 6 | from ..iteration_graph.iteration_graph import IterationGraph 7 | from . import ast 8 | from ._exceptions import DiagonalAccessError, NoKernelFoundError 9 | from ._to_iteration_graphs import to_iteration_graphs 10 | 11 | 12 | def best_algorithm( 13 | assignment: ast.Assignment, formats: dict[str, Format] 14 | ) -> Result[IterationGraph, DiagonalAccessError | NoKernelFoundError]: 15 | try: 16 | match next(to_iteration_graphs(assignment, formats), None): 17 | case None: 18 | return Failure(NoKernelFoundError()) 19 | case graph: 20 | return Success(graph) 21 | except DiagonalAccessError as e: 22 | return Failure(e) 23 | -------------------------------------------------------------------------------- /src/tensora/desugar/_exceptions.py: -------------------------------------------------------------------------------- 1 | __all__ = ["DiagonalAccessError", "NoKernelFoundError"] 2 | 3 | from dataclasses import dataclass 4 | 5 | from . import ast 6 | 7 | 8 | @dataclass(frozen=True, slots=True) 9 | class DiagonalAccessError(Exception): 10 | tensor: ast.Tensor 11 | 12 | def __str__(self) -> str: 13 | return ( 14 | f"Diagonal access to a tensor (i.e. repeating the same index within a tensor) is not " 15 | f"currently supported: {self.tensor.name}({', '.join(self.tensor.indexes)})" 16 | ) 17 | 18 | 19 | @dataclass(frozen=True, slots=True) 20 | class NoKernelFoundError(Exception): 21 | def __str__(self) -> str: 22 | return ( 23 | "Tensora's tensor algebra compiler was unable to find a kernel for the given problem. " 24 | "This is likely due to sparse tensors needing to be iterated in opposite orders." 
25 | ) 26 | -------------------------------------------------------------------------------- /src/tensora/iteration_graph/identifiable_expression/ast.py: -------------------------------------------------------------------------------- 1 | __all__ = ["Add", "Expression", "Float", "Integer", "Literal", "Multiply", "Tensor"] 2 | 3 | from dataclasses import dataclass 4 | 5 | from ...format import Mode 6 | 7 | 8 | class Expression: 9 | __slots__ = () 10 | 11 | 12 | class Literal(Expression): 13 | __slots__ = () 14 | 15 | 16 | @dataclass(frozen=True, slots=True) 17 | class Integer(Literal): 18 | value: int 19 | 20 | 21 | @dataclass(frozen=True, slots=True) 22 | class Float(Literal): 23 | value: float 24 | 25 | 26 | @dataclass(frozen=True, slots=True) 27 | class Tensor(Expression): 28 | id: str 29 | name: str 30 | indexes: tuple[str, ...] 31 | modes: tuple[Mode, ...] 32 | 33 | @property 34 | def order(self): 35 | return len(self.indexes) 36 | 37 | 38 | @dataclass(frozen=True, slots=True) 39 | class Add(Expression): 40 | left: Expression 41 | right: Expression 42 | 43 | 44 | @dataclass(frozen=True, slots=True) 45 | class Multiply(Expression): 46 | left: Expression 47 | right: Expression 48 | -------------------------------------------------------------------------------- /src/tensora/compile/_compile_llvm.py: -------------------------------------------------------------------------------- 1 | __all__ = ["compile_module"] 2 | 3 | import llvmlite.binding as llvm 4 | 5 | from ..codegen import ir_to_llvm 6 | from ..ir.ast import Module 7 | from ._initialize_llvm import target 8 | 9 | 10 | def compile_module(module: Module) -> llvm.ExecutionEngine: 11 | llvm_ir = ir_to_llvm(module) 12 | 13 | # Compile the module 14 | llvm_module = llvm.parse_assembly(str(llvm_ir)) 15 | llvm_module.verify() 16 | 17 | # Create target machine 18 | # We have to recreate this for every module because create_mcjit_compiler 19 | # takes ownership of target_machine and frees it when the engine goes out 
of scope 20 | target_machine = target.create_target_machine() 21 | 22 | # Create execution engine 23 | backing_mod = llvm.parse_assembly("") 24 | engine = llvm.create_mcjit_compiler(backing_mod, target_machine) 25 | 26 | # Add the module to the engine and make sure it is ready for execution 27 | engine.add_module(llvm_module) 28 | engine.finalize_object() 29 | engine.run_static_constructors() 30 | 31 | return engine 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019-2024 David R Hagen 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 
20 | -------------------------------------------------------------------------------- /src/tensora/ir/types.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "Array", 3 | "FixedArray", 4 | "Float", 5 | "Integer", 6 | "Mode", 7 | "Pointer", 8 | "Tensor", 9 | "Type", 10 | "float", 11 | "integer", 12 | "mode", 13 | "tensor", 14 | ] 15 | 16 | from dataclasses import dataclass 17 | 18 | 19 | class Type: 20 | __slots__ = () 21 | 22 | 23 | @dataclass(frozen=True, slots=True) 24 | class Integer(Type): 25 | pass 26 | 27 | 28 | integer = Integer() 29 | 30 | 31 | @dataclass(frozen=True, slots=True) 32 | class Float(Type): 33 | pass 34 | 35 | 36 | float = Float() 37 | 38 | 39 | @dataclass(frozen=True, slots=True) 40 | class Tensor(Type): 41 | pass 42 | 43 | 44 | tensor = Tensor() 45 | 46 | 47 | @dataclass(frozen=True, slots=True) 48 | class Mode(Type): 49 | pass 50 | 51 | 52 | mode = Mode() 53 | 54 | 55 | @dataclass(frozen=True, slots=True) 56 | class Pointer(Type): 57 | target: Type 58 | 59 | 60 | @dataclass(frozen=True, slots=True) 61 | class Array(Type): 62 | element: Type 63 | 64 | 65 | @dataclass(frozen=True, slots=True) 66 | class FixedArray(Type): 67 | element: Type 68 | n: int 69 | -------------------------------------------------------------------------------- /tests/test_tensor_method.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tensora import Tensor, tensor_method 4 | from tensora.compile import BroadcastTargetIndexError 5 | from tensora.desugar import DiagonalAccessError, NoKernelFoundError 6 | 7 | 8 | def test_tensor_method(): 9 | A = Tensor.from_aos([(1, 0), (0, 1), (1, 2)], [2.0, -2.0, 4.0], dimensions=(2, 3), format="ds") 10 | 11 | x = Tensor.from_aos([(0,), (1,), (2,)], [3.0, 2.5, 2.0], dimensions=(3,), format="d") 12 | 13 | expected = Tensor.from_aos([(0,), (1,)], [-5.0, 14.0], dimensions=(2,), format="d") 14 | 15 | function = 
tensor_method("y(i) = A(i,j) * x(j)", {"y": "d", "A": "ds", "x": "d"}) 16 | 17 | actual = function(A=A, x=x) 18 | 19 | assert actual == expected 20 | 21 | 22 | def test_broadcast_target_index_error(): 23 | with pytest.raises(BroadcastTargetIndexError): 24 | tensor_method("A(i,j) = a(i)", {}) 25 | 26 | 27 | def test_diagonal_error(): 28 | with pytest.raises(DiagonalAccessError): 29 | tensor_method("a(i) = A(i,i)", {"a": "d", "A": "dd"}) 30 | 31 | 32 | def test_no_solution(): 33 | with pytest.raises(NoKernelFoundError): 34 | tensor_method("A(i,j) = B(i,j) + C(j,i)", {"A": "ds", "B": "ds", "C": "ds"}) 35 | -------------------------------------------------------------------------------- /src/tensora/desugar/ast.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "Add", 3 | "Assignment", 4 | "Contract", 5 | "Expression", 6 | "Float", 7 | "Integer", 8 | "Literal", 9 | "Multiply", 10 | "Tensor", 11 | ] 12 | 13 | from dataclasses import dataclass 14 | 15 | 16 | class Expression: 17 | __slots__ = () 18 | 19 | 20 | class Literal(Expression): 21 | __slots__ = () 22 | 23 | 24 | @dataclass(frozen=True, slots=True) 25 | class Integer(Literal): 26 | value: int 27 | 28 | 29 | @dataclass(frozen=True, slots=True) 30 | class Float(Literal): 31 | value: float 32 | 33 | 34 | @dataclass(frozen=True, slots=True) 35 | class Tensor(Expression): 36 | id: int 37 | name: str 38 | indexes: tuple[str, ...] 
39 | 40 | 41 | @dataclass(frozen=True, slots=True) 42 | class Add(Expression): 43 | left: Expression 44 | right: Expression 45 | 46 | 47 | @dataclass(frozen=True, slots=True) 48 | class Multiply(Expression): 49 | left: Expression 50 | right: Expression 51 | 52 | 53 | @dataclass(frozen=True, slots=True) 54 | class Contract(Expression): 55 | index: str 56 | expression: Expression 57 | 58 | 59 | @dataclass(frozen=True, slots=True) 60 | class Assignment: 61 | target: Tensor 62 | expression: Expression 63 | -------------------------------------------------------------------------------- /src/tensora/iteration_graph/identifiable_expression/_to_ir.py: -------------------------------------------------------------------------------- 1 | __all__ = ["to_ir"] 2 | 3 | from functools import singledispatch 4 | 5 | from ...ir import ast as ir 6 | from .ast import Add, Expression, Float, Integer, Multiply, Tensor 7 | 8 | 9 | @singledispatch 10 | def to_ir(self: Expression) -> ir.Expression: 11 | raise NotImplementedError(f"to_ir not implemented for {type(self)}: {self}") 12 | 13 | 14 | @to_ir.register(Integer) 15 | def to_ir_integer(self: Integer): 16 | # This is sensible as long as we only support floating point values and don't support division. If either of those 17 | # ceases to be true, this will need to be updated. 
18 | return ir.IntegerLiteral(self.value) 19 | 20 | 21 | @to_ir.register(Float) 22 | def to_ir_float(self: Float): 23 | return ir.FloatLiteral(self.value) 24 | 25 | 26 | @to_ir.register(Tensor) 27 | def to_ir_tensor(self: Tensor): 28 | from .._names import previous_layer_pointer, vals_name 29 | 30 | return vals_name(self.name).idx(previous_layer_pointer(self.id, self.order)) 31 | 32 | 33 | @to_ir.register(Add) 34 | def to_ir_add(self: Add): 35 | return ir.Add(to_ir(self.left), to_ir(self.right)) 36 | 37 | 38 | @to_ir.register(Multiply) 39 | def to_ir_multiply(self: Multiply): 40 | return ir.Multiply(to_ir(self.left), to_ir(self.right)) 41 | -------------------------------------------------------------------------------- /src/tensora/generate/_base.py: -------------------------------------------------------------------------------- 1 | __all__ = ["Language", "generate_code"] 2 | 3 | from enum import Enum 4 | 5 | from returns.result import Failure, Result, Success 6 | 7 | from ..codegen import ir_to_c, ir_to_llvm 8 | from ..kernel_type import KernelType 9 | from ..problem import Problem 10 | from ._tensora import generate_module_tensora 11 | 12 | 13 | class Language(str, Enum): 14 | """The language to be generated. 15 | 16 | Attributes 17 | ---------- 18 | c 19 | C language. 20 | llvm 21 | LLVM IR. 
22 | """ 23 | 24 | # Python 3.10 does not support StrEnum, so do it manually 25 | c = "c" 26 | llvm = "llvm" 27 | 28 | def __str__(self) -> str: 29 | return self.name 30 | 31 | 32 | def generate_code( 33 | problem: Problem, 34 | kernel_types: list[KernelType], 35 | language: Language, 36 | ) -> Result[str, Exception]: 37 | match generate_module_tensora(problem, kernel_types): 38 | case Failure(error): 39 | return Failure(error) 40 | case Success(module): 41 | match language: 42 | case Language.c: 43 | return Success(ir_to_c(module)) 44 | case Language.llvm: 45 | return Success(str(ir_to_llvm(module))) 46 | case _: 47 | raise NotImplementedError() 48 | -------------------------------------------------------------------------------- /src/tensora/generate/_tensora.py: -------------------------------------------------------------------------------- 1 | __all__ = ["generate_module_tensora"] 2 | 3 | from returns.result import Failure, Result, Success 4 | 5 | from ..desugar import ( 6 | DiagonalAccessError, 7 | NoKernelFoundError, 8 | best_algorithm, 9 | desugar_assignment, 10 | index_dimensions, 11 | to_identifiable, 12 | ) 13 | from ..ir import peephole 14 | from ..ir.ast import Module 15 | from ..iteration_graph import Definition, generate_ir 16 | from ..kernel_type import KernelType 17 | from ..problem import Problem 18 | 19 | 20 | def generate_module_tensora( 21 | problem: Problem, kernel_types: list[KernelType] 22 | ) -> Result[Module, DiagonalAccessError | NoKernelFoundError]: 23 | formats = problem.formats 24 | 25 | desugar = desugar_assignment(problem.assignment) 26 | 27 | output_variable = to_identifiable(desugar.target, formats) 28 | 29 | definition = Definition(output_variable, formats, index_dimensions(desugar)) 30 | 31 | match best_algorithm(desugar, formats): 32 | case Failure() as result: 33 | return result 34 | case Success(graph): 35 | pass 36 | case _: 37 | raise NotImplementedError() 38 | 39 | functions = [generate_ir(definition, graph, kernel_type) 
for kernel_type in kernel_types] 40 | module = Module(functions) 41 | 42 | return Success(peephole(module)) 43 | -------------------------------------------------------------------------------- /src/tensora/expression/_exceptions.py: -------------------------------------------------------------------------------- 1 | __all__ = ["InconsistentDimensionsError", "MutatingAssignmentError", "NameConflictError"] 2 | 3 | from dataclasses import dataclass 4 | 5 | from .ast import Assignment, Tensor 6 | 7 | 8 | @dataclass(frozen=True, slots=True) 9 | class MutatingAssignmentError(Exception): 10 | assignment: Assignment 11 | 12 | def __str__(self): 13 | return ( 14 | f"Expected assignment target to never appear on the right hand side, " 15 | f"but found {self.assignment.target.name} on both sides of {self.assignment}" 16 | ) 17 | 18 | 19 | @dataclass(frozen=True, slots=True) 20 | class InconsistentDimensionsError(Exception): 21 | assignment: Assignment 22 | first: Tensor 23 | second: Tensor 24 | 25 | def __str__(self): 26 | return ( 27 | f"Expected each tensor in an assignment to be referenced with the same number of " 28 | f"indexes, but found parameter {self.first.name} referenced as {self.first} and then " 29 | f"as {self.second} in {self.assignment}" 30 | ) 31 | 32 | 33 | @dataclass(frozen=True, slots=True) 34 | class NameConflictError(Exception): 35 | name: str 36 | assignment: Assignment 37 | 38 | def __str__(self): 39 | return ( 40 | f"Expected no tensor and index to have the same name, but found {self.name} as both a " 41 | f"tensor and an index in {self.assignment}" 42 | ) 43 | -------------------------------------------------------------------------------- /tests_cffi/test_evaluate.py: -------------------------------------------------------------------------------- 1 | from tensora import Tensor 2 | from tensora.compile import evaluate_cffi as evaluate 3 | 4 | 5 | def test_csr_matrix_vector_product(): 6 | A = Tensor.from_aos([(1, 0), (0, 1), (1, 2)], [2.0, -2.0, 
4.0], dimensions=(2, 3), format="ds") 7 | 8 | x = Tensor.from_aos([(0,), (1,), (2,)], [3.0, 2.5, 2.0], dimensions=(3,), format="d") 9 | 10 | expected = Tensor.from_aos([(0,), (1,)], [-5.0, 14.0], dimensions=(2,), format="d") 11 | 12 | actual = evaluate("y(i) = A(i,j) * x(j)", "d", A=A, x=x) 13 | 14 | assert actual == expected 15 | 16 | 17 | def test_csc_matrix_vector_product(): 18 | A = Tensor.from_aos( 19 | [(1, 0), (0, 1), (1, 2)], [2.0, -2.0, 4.0], dimensions=(2, 3), format="d1s0" 20 | ) 21 | 22 | x = Tensor.from_aos([(0,), (1,), (2,)], [3.0, 2.5, 2.0], dimensions=(3,), format="d") 23 | 24 | expected = Tensor.from_aos([(0,), (1,)], [-5.0, 14.0], dimensions=(2,), format="d") 25 | 26 | actual = evaluate("y(i) = A(i,j) * x(j)", "d", A=A, x=x) 27 | 28 | assert actual == expected 29 | 30 | 31 | def test_csr_matrix_plus_csr_matrix(): 32 | A = Tensor.from_aos([(1, 0), (0, 1), (1, 2)], [2.0, -2.0, 4.0], dimensions=(2, 3), format="ds") 33 | 34 | B = Tensor.from_aos([(1, 1), (1, 2), (0, 2)], [-3.0, 4.0, 3.5], dimensions=(2, 3), format="ds") 35 | 36 | expected = Tensor.from_aos( 37 | [(1, 0), (0, 1), (1, 2), (1, 1), (0, 2)], 38 | [2.0, -2.0, 8.0, -3.0, 3.5], 39 | dimensions=(2, 3), 40 | format="ds", 41 | ) 42 | 43 | actual = evaluate("C(i,j) = A(i,j) + B(i,j)", "ds", A=A, B=B) 44 | 45 | assert actual == expected 46 | -------------------------------------------------------------------------------- /src/tensora/format/_format.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | __all__ = ["Format", "Mode"] 4 | 5 | from dataclasses import dataclass 6 | from enum import Enum 7 | 8 | 9 | class Mode(Enum): 10 | # Manually map these to the entries in .compile.taco_type_header.taco_mode_t 11 | dense = (0, "d") 12 | compressed = (1, "s") 13 | 14 | def __init__(self, c_int: int, character: "str"): 15 | self.c_int = c_int 16 | self.character = character 17 | 18 | @staticmethod 19 | def from_c_int(value: 
int) -> Mode: 20 | for member in Mode: 21 | if member.value[0] == value: 22 | return member 23 | raise ValueError(f"No member of Mode has the integer value {value}") 24 | 25 | def __repr__(self) -> str: 26 | return f"Mode.{self.name}" 27 | 28 | 29 | @dataclass(frozen=True, slots=True) 30 | class Format: 31 | modes: tuple[Mode, ...] 32 | ordering: tuple[int, ...] 33 | 34 | def __post_init__(self): 35 | from ._exceptions import InvalidModeOrderingError 36 | 37 | if set(self.ordering) != set(range(len(self.modes))): 38 | raise InvalidModeOrderingError(self.modes, self.ordering) 39 | 40 | @property 41 | def order(self): 42 | return len(self.modes) 43 | 44 | def deparse(self): 45 | if self.ordering == tuple(range(self.order)): 46 | return "".join(mode.character for mode in self.modes) 47 | else: 48 | return "".join( 49 | mode.character + str(ordering) 50 | for mode, ordering in zip(self.modes, self.ordering, strict=True) 51 | ) 52 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | icon: material/home 3 | --- 4 | 5 | # Tensora 6 | 7 | Tensora is a tensor algebra library for Python. You can create `Tensor` objects in a variety of sparse and dense formats. You can do calculations with these tensors by passing them to the `evaluate` function along with an expression (e.g. `y = evaluate('y(i) = A(i,j) * x(j)', A=A, x=x)`). The expression is parsed, a kernel is generated, the C code is compiled on the fly, the binary is invoked, and the result is packaged into an output `Tensor`. 8 | 9 | Tensora also comes with the `tensora` command line tool that can be used to generate the kernel code for external use. 10 | 11 | Tensora is based on the [Tensor Algebra Compiler](http://tensor-compiler.org/) (TACO). 12 | 13 | ## Installation 14 | 15 | The recommended means of installation is with `pip` from PyPI. 
16 | 17 | ```bash 18 | pip install tensora 19 | ``` 20 | 21 | By default, Tensora uses its own code to generate the kernels and LLVM to compile them (via `llvmlite`). The `tensora[cffi]` extra makes available the option to compile the kernels with CFFI; this requires a system C compiler available to CFFI. 22 | 23 | Tensora is tested on Linux, Mac, and Windows. The CFFI backend is not available on Windows. 24 | 25 | ## Hello world 26 | 27 | Here is an example of multiplying a sparse matrix in CSR format with a dense vector: 28 | 29 | ```python 30 | from tensora import Tensor, evaluate 31 | 32 | elements = { 33 | (1,0): 2.0, 34 | (0,1): -2.0, 35 | (1,2): 4.0, 36 | } 37 | 38 | A = Tensor.from_dok(elements, dimensions=(2,3), format='ds') 39 | x = Tensor.from_lol([0, -1, 2]) 40 | 41 | y = evaluate('y(i) = A(i,j) * x(j)', 'd', A=A, x=x) 42 | 43 | assert y == Tensor.from_lol([2,8]) 44 | ``` 45 | -------------------------------------------------------------------------------- /fuzz_tests/test_parsing.py: -------------------------------------------------------------------------------- 1 | import hypothesis.strategies as st 2 | from hypothesis import given 3 | from parsita import ParseError 4 | from returns.result import Failure, Success 5 | 6 | from tensora.expression import ( 7 | InconsistentDimensionsError, 8 | MutatingAssignmentError, 9 | NameConflictError, 10 | ast, 11 | parse_assignment, 12 | ) 13 | from tensora.format import Format, InvalidModeOrderingError, parse_format 14 | 15 | from .strategies import assignments, formats 16 | 17 | 18 | @given(st.text()) 19 | def test_format_parsing_cannot_crash(string): 20 | match parse_format(string): 21 | case Success(Format()): 22 | pass 23 | case Failure(ParseError() | InvalidModeOrderingError()): 24 | pass 25 | case _: 26 | raise RuntimeError("Unexpected result") 27 | 28 | 29 | @given(formats()) 30 | def test_format_parsing_round_trips(format): 31 | text = format.deparse() 32 | new_format = parse_format(text).unwrap() 33 | 
assert format == new_format 34 | 35 | 36 | @given(st.text()) 37 | def test_expression_parsing_cannot_crash(string): 38 | match parse_assignment(string): 39 | case Success(ast.Assignment()): 40 | pass 41 | case Failure( 42 | ParseError() 43 | | MutatingAssignmentError() 44 | | InconsistentDimensionsError() 45 | | NameConflictError() 46 | ): 47 | pass 48 | case _: 49 | raise RuntimeError("Unexpected result") 50 | 51 | 52 | @given(assignments()) 53 | def test_expression_parsing_round_trips(assignment): 54 | text = assignment.deparse() 55 | new_assignment = parse_assignment(text).unwrap() 56 | assert assignment == new_assignment 57 | -------------------------------------------------------------------------------- /src/tensora/format/_parser.py: -------------------------------------------------------------------------------- 1 | __all__ = ["parse_format", "parse_named_format"] 2 | 3 | from parsita import ParseError, ParserContext, lit, reg, rep 4 | from parsita.util import constant 5 | from returns import result 6 | 7 | from ._exceptions import InvalidModeOrderingError 8 | from ._format import Format, Mode 9 | 10 | 11 | def make_format_with_orderings(dims): 12 | modes = [] 13 | orderings = [] 14 | for mode, ordering in dims: 15 | modes.append(mode) 16 | orderings.append(ordering) 17 | return Format(tuple(modes), tuple(orderings)) 18 | 19 | 20 | class FormatParsers(ParserContext): 21 | integer = reg(r"[0-9]+") > int 22 | dense = lit("d") > constant(Mode.dense) 23 | compressed = lit("s") > constant(Mode.compressed) 24 | mode = dense | compressed 25 | 26 | format_without_orderings = rep(mode) > ( 27 | lambda modes: Format(tuple(modes), tuple(range(len(modes)))) 28 | ) 29 | format_with_orderings = rep(mode & integer) > make_format_with_orderings 30 | 31 | format = format_without_orderings | format_with_orderings 32 | 33 | variable = reg(r"[a-zA-Z_][a-zA-Z0-9_]*") 34 | named_format = variable << ":" & format > tuple 35 | 36 | 37 | def parse_format(string: str, /) -> 
result.Result[Format, ParseError | InvalidModeOrderingError]: 38 | try: 39 | return FormatParsers.format.parse(string) 40 | except InvalidModeOrderingError as e: 41 | return result.Failure(e) 42 | 43 | 44 | def parse_named_format( 45 | string: str, / 46 | ) -> result.Result[tuple[str, Format], ParseError | InvalidModeOrderingError]: 47 | try: 48 | return FormatParsers.named_format.parse(string) 49 | except InvalidModeOrderingError as e: 50 | return result.Failure(e) 51 | -------------------------------------------------------------------------------- /src/tensora/codegen/_type_to_c.py: -------------------------------------------------------------------------------- 1 | __all__ = ["type_to_c"] 2 | 3 | from functools import singledispatch 4 | 5 | from ..ir.types import Array, FixedArray, Float, Integer, Mode, Pointer, Tensor, Type 6 | 7 | 8 | def space_variable(variable: str | None = None) -> str: 9 | if variable is None: 10 | return "" 11 | else: 12 | return f" {variable}" 13 | 14 | 15 | @singledispatch 16 | def type_to_c(self: Type, variable: str | None = None) -> str: 17 | raise NotImplementedError(f"type_to_c not implemented for {type(self)}: {self}") 18 | 19 | 20 | @type_to_c.register(Integer) 21 | def type_to_c_integer(self: Integer, variable: str | None = None) -> str: 22 | return "int32_t" + space_variable(variable) 23 | 24 | 25 | @type_to_c.register(Float) 26 | def type_to_c_float(self: Float, variable: str | None = None) -> str: 27 | return "double" + space_variable(variable) 28 | 29 | 30 | @type_to_c.register(Tensor) 31 | def type_to_c_tensor(self: Tensor, variable: str | None = None) -> str: 32 | return "taco_tensor_t" + space_variable(variable) 33 | 34 | 35 | @type_to_c.register(Mode) 36 | def type_to_c_mode(self: Mode, variable: str | None = None) -> str: 37 | return "taco_mode_t" + space_variable(variable) 38 | 39 | 40 | @type_to_c.register(Pointer) 41 | def type_to_c_pointer(self: Pointer, variable: str | None = None) -> str: 42 | return 
f"{type_to_c(self.target)}* restrict" + space_variable(variable) 43 | 44 | 45 | @type_to_c.register(Array) 46 | def type_to_c_array(self: Array, variable: str | None = None) -> str: 47 | return f"{type_to_c(self.element, variable)}[]" 48 | 49 | 50 | @type_to_c.register(FixedArray) 51 | def type_to_c_fixed_array(self: FixedArray, variable: str | None = None) -> str: 52 | return f"{type_to_c(self.element, variable)}[{self.n}]" 53 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Tensora 2 | site_description: Sparse/dense tensor library for Python 3 | site_author: David Hagen 4 | 5 | repo_name: drhagen/tensora 6 | repo_url: https://github.com/drhagen/tensora/ 7 | edit_uri: '' 8 | 9 | copyright: > 10 | © David R Hagen. Content available under an 11 | MIT license. 12 | 13 | theme: 14 | name: material 15 | 16 | icon: 17 | logo: fontawesome/solid/braille 18 | repo: fontawesome/brands/github 19 | 20 | palette: 21 | - media: "(prefers-color-scheme)" 22 | toggle: 23 | icon: material/link 24 | name: Switch to light mode 25 | - media: "(prefers-color-scheme: light)" 26 | scheme: default 27 | primary: black 28 | accent: yellow 29 | toggle: 30 | icon: material/weather-sunny 31 | name: Switch to dark mode 32 | - media: "(prefers-color-scheme: dark)" 33 | scheme: slate 34 | primary: yellow 35 | accent: deep orange 36 | toggle: 37 | icon: material/weather-night 38 | name: Switch to system preference 39 | 40 | markdown_extensions: 41 | - pymdownx.highlight 42 | - pymdownx.superfences 43 | 44 | extra: 45 | social: 46 | - icon: fontawesome/brands/github 47 | link: https://github.com/drhagen 48 | - icon: fontawesome/brands/stack-overflow 49 | link: https://stackoverflow.com/users/1485877/drhagen 50 | - icon: fontawesome/brands/twitter 51 | link: https://twitter.com/drhagen 52 | - icon: fontawesome/brands/linkedin 53 | link: 
https://www.linkedin.com/in/davidrhagen/ 54 | 55 | nav: 56 | - Home: index.md 57 | - Getting started: getting-started.md 58 | - Creation: creation.md 59 | - Tensors: tensors.md 60 | - Evaluate: evaluate.md 61 | - Contributing: contributing.md 62 | -------------------------------------------------------------------------------- /noxfile.py: -------------------------------------------------------------------------------- 1 | import platform 2 | 3 | from nox import Session, options, parametrize 4 | from nox_uv import session 5 | 6 | options.default_venv_backend = "uv" 7 | options.sessions = ["test", "test_cffi", "test_numpy", "coverage", "lint"] 8 | 9 | 10 | @session(python=["3.10", "3.11", "3.12", "3.13"], uv_groups=["test"]) 11 | def test(s: Session): 12 | coverage_file = f".coverage.{platform.machine()}.{platform.system()}.{s.python}" 13 | s.run("coverage", "run", "--data-file", coverage_file, "-m", "pytest", "tests") 14 | 15 | 16 | @session(python=["3.10", "3.11", "3.12", "3.13"], uv_groups=["test"], uv_extras=["numpy"]) 17 | def test_numpy(s: Session): 18 | coverage_file = f".coverage.{platform.machine()}.{platform.system()}.{s.python}.numpy" 19 | s.run("coverage", "run", "--data-file", coverage_file, "-m", "pytest", "tests/test_numpy.py") 20 | 21 | 22 | @session(python=["3.10", "3.11", "3.12", "3.13"], uv_groups=["test"], uv_extras=["cffi"]) 23 | def test_cffi(s: Session): 24 | coverage_file = f".coverage.{platform.machine()}.{platform.system()}.{s.python}.cffi" 25 | s.run("coverage", "run", "--data-file", coverage_file, "-m", "pytest", "tests_cffi") 26 | 27 | 28 | @session(venv_backend="none") 29 | def coverage(s: Session): 30 | s.run("coverage", "combine") 31 | s.run("coverage", "html") 32 | s.run("coverage", "xml") 33 | 34 | 35 | @session(venv_backend="none") 36 | def fuzz(s: Session): 37 | s.run("hypothesis", "fuzz", "fuzz_tests") 38 | 39 | 40 | @session(venv_backend="none") 41 | @parametrize("command", [["ruff", "check", "."], ["ruff", "format", "--check", 
"."]]) 42 | def lint(s: Session, command: list[str]): 43 | s.run(*command) 44 | 45 | 46 | @session(venv_backend="none") 47 | def format(s: Session) -> None: 48 | s.run("ruff", "check", ".", "--select", "I", "--fix") 49 | s.run("ruff", "format", ".") 50 | -------------------------------------------------------------------------------- /src/tensora/expression/_parser.py: -------------------------------------------------------------------------------- 1 | __all__ = ["parse_assignment"] 2 | 3 | from functools import reduce 4 | 5 | from parsita import ParseError, ParserContext, lit, reg, rep, rep1sep, repsep 6 | from parsita.util import splat 7 | from returns import result 8 | 9 | from ._exceptions import InconsistentDimensionsError, MutatingAssignmentError, NameConflictError 10 | from .ast import Add, Assignment, Float, Integer, Multiply, Subtract, Tensor 11 | 12 | 13 | def make_expression(first, rest): 14 | value = first 15 | for op, term in rest: 16 | match op: 17 | case "+": 18 | value = Add(value, term) 19 | case "-": 20 | value = Subtract(value, term) 21 | return value 22 | 23 | 24 | class TensorExpressionParsers(ParserContext, whitespace=r"[ ]*"): 25 | name = reg(r"[A-Za-z][A-Za-z0-9]*") 26 | 27 | floating_point = reg(r"\d+((\.\d+([Ee][+-]?\d+)?)|((\.\d+)?[Ee][+-]?\d+))") > ( 28 | lambda x: Float(float(x)) 29 | ) 30 | integer = reg(r"[0-9]+") > (lambda x: Integer(int(x))) 31 | number = floating_point | integer 32 | 33 | tensor = name & "(" >> (repsep(name, ",") > tuple) << ")" > splat(Tensor) 34 | 35 | parentheses = "(" >> expression << ")" # noqa: F821 36 | factor = tensor | number | parentheses 37 | 38 | term = rep1sep(factor, "*") > (lambda x: reduce(Multiply, x)) 39 | expression = term & rep(lit("+", "-") & term) > splat(make_expression) 40 | 41 | assignment = tensor & "=" >> expression > splat(Assignment) 42 | 43 | 44 | def parse_assignment( 45 | string: str, 46 | ) -> result.Result[ 47 | Assignment, 48 | ParseError | MutatingAssignmentError | 
InconsistentDimensionsError | NameConflictError, 49 | ]: 50 | try: 51 | return TensorExpressionParsers.assignment.parse(string) 52 | except (MutatingAssignmentError, InconsistentDimensionsError, NameConflictError) as e: 53 | return result.Failure(e) 54 | -------------------------------------------------------------------------------- /src/tensora/iteration_graph/identifiable_expression/_exhaust_tensor.py: -------------------------------------------------------------------------------- 1 | __all__ = ["exhaust_tensor"] 2 | 3 | from functools import singledispatch 4 | 5 | from .ast import Add, Expression, Integer, Literal, Multiply, Tensor 6 | 7 | 8 | @singledispatch 9 | def exhaust_tensor(self, reference: str) -> Expression: 10 | raise NotImplementedError(f"exhaust_tensor not implemented for {type(self)}: {self}") 11 | 12 | 13 | @exhaust_tensor.register(Literal) 14 | def exhaust_tensor_literal(self: Literal, reference: str): 15 | return self 16 | 17 | 18 | @exhaust_tensor.register(Tensor) 19 | def exhaust_tensor_tensor(self: Tensor, reference: str): 20 | if self.id == reference: 21 | return Integer(0) 22 | else: 23 | return self 24 | 25 | 26 | @exhaust_tensor.register(Add) 27 | def exhaust_tensor_add(self: Add, reference: str): 28 | left_exhausted = exhaust_tensor(self.left, reference) 29 | right_exhausted = exhaust_tensor(self.right, reference) 30 | if left_exhausted is self.left and right_exhausted is self.right: 31 | # Short circuit when there are no changes 32 | return self 33 | elif left_exhausted == Integer(0): 34 | # Covers the case where both are exhausted 35 | return right_exhausted 36 | elif right_exhausted == Integer(0): 37 | return left_exhausted 38 | else: 39 | return Add(left_exhausted, right_exhausted) 40 | 41 | 42 | @exhaust_tensor.register(Multiply) 43 | def exhaust_tensor_multiply(self: Multiply, reference: str): 44 | left_exhausted = exhaust_tensor(self.left, reference) 45 | right_exhausted = exhaust_tensor(self.right, reference) 46 | if 
left_exhausted is self.left and right_exhausted is self.right: 47 | # Short circuit when there are no changes 48 | return self 49 | elif left_exhausted == Integer(0) or right_exhausted == Integer(0): 50 | return Integer(0) 51 | else: 52 | return Multiply(left_exhausted, right_exhausted) 53 | -------------------------------------------------------------------------------- /src/tensora/iteration_graph/identifiable_expression/_tensor_layer.py: -------------------------------------------------------------------------------- 1 | __all__ = ["TensorLayer"] 2 | 3 | from dataclasses import dataclass 4 | 5 | from ...ir import ast as ir 6 | from .._names import ( 7 | crd_capacity_name, 8 | crd_name, 9 | layer_begin_name, 10 | layer_pointer, 11 | pos_capacity_name, 12 | pos_name, 13 | previous_layer_pointer, 14 | sparse_end_name, 15 | vals_capacity_name, 16 | vals_name, 17 | value_from_crd, 18 | ) 19 | from . import ast as id 20 | 21 | 22 | @dataclass(frozen=True, slots=True) 23 | class TensorLayer: 24 | tensor: id.Tensor 25 | layer: int 26 | 27 | @property 28 | def mode(self): 29 | return self.tensor.modes[self.layer] 30 | 31 | def pos_name(self) -> ir.Variable: 32 | return pos_name(self.tensor.name, self.layer) 33 | 34 | def crd_name(self) -> ir.Variable: 35 | return crd_name(self.tensor.name, self.layer) 36 | 37 | def vals_name(self) -> ir.Variable: 38 | return vals_name(self.tensor.name) 39 | 40 | def pos_capacity_name(self) -> ir.Variable: 41 | return pos_capacity_name(self.tensor.name, self.layer) 42 | 43 | def crd_capacity_name(self) -> ir.Variable: 44 | return crd_capacity_name(self.tensor.name, self.layer) 45 | 46 | def vals_capacity_name(self) -> ir.Variable: 47 | return vals_capacity_name(self.tensor.name) 48 | 49 | def layer_pointer(self) -> ir.Variable: 50 | return layer_pointer(self.tensor.id, self.layer) 51 | 52 | def previous_layer_pointer(self) -> ir.Expression: 53 | return previous_layer_pointer(self.tensor.id, self.layer) 54 | 55 | def sparse_end_name(self) 
-> ir.Variable: 56 | return sparse_end_name(self.tensor.id, self.layer) 57 | 58 | def layer_begin_name(self) -> ir.Variable: 59 | return layer_begin_name(self.tensor.id, self.layer) 60 | 61 | def value_from_crd(self) -> ir.Variable: 62 | return value_from_crd(self.tensor.id, self.layer) 63 | -------------------------------------------------------------------------------- /src/tensora/iteration_graph/_names.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | __all__ = [ 4 | "crd_capacity_name", 5 | "crd_name", 6 | "dimension_name", 7 | "layer_pointer", 8 | "pos_capacity_name", 9 | "pos_name", 10 | "previous_layer_pointer", 11 | "sparse_end_name", 12 | "vals_capacity_name", 13 | "vals_name", 14 | "value_from_crd", 15 | ] 16 | 17 | 18 | from ..ir.ast import Expression, IntegerLiteral, Variable 19 | 20 | 21 | def dimension_name(index_variable: str) -> Variable: 22 | return Variable(f"{index_variable}_dim") 23 | 24 | 25 | def pos_name(tensor: str, layer: int) -> Variable: 26 | return Variable(f"{tensor}_{layer}_pos") 27 | 28 | 29 | def crd_name(tensor: str, layer: int) -> Variable: 30 | return Variable(f"{tensor}_{layer}_crd") 31 | 32 | 33 | def vals_name(tensor: str) -> Variable: 34 | return Variable(f"{tensor}_vals") 35 | 36 | 37 | def pos_capacity_name(tensor: str, layer: int) -> Variable: 38 | return Variable(f"{tensor}_{layer}_pos_capacity") 39 | 40 | 41 | def crd_capacity_name(tensor: str, layer: int) -> Variable: 42 | return Variable(f"{tensor}_{layer}_crd_capacity") 43 | 44 | 45 | def vals_capacity_name(tensor: str) -> Variable: 46 | return Variable(f"{tensor}_vals_capacity") 47 | 48 | 49 | def layer_pointer(reference: str, layer: int) -> Variable: 50 | return Variable(f"p_{reference}_{layer}") 51 | 52 | 53 | def previous_layer_pointer(reference: str, layer: int) -> Expression: 54 | if layer == 0: 55 | return IntegerLiteral(0) 56 | else: 57 | return layer_pointer(reference, layer - 
1) 58 | 59 | 60 | def sparse_end_name(reference: str, layer: int) -> Variable: 61 | return Variable(f"p_{reference}_{layer}_end") 62 | 63 | 64 | def layer_begin_name(reference: str, layer: int) -> Variable: 65 | return Variable(f"p_{reference}_{layer}_begin") 66 | 67 | 68 | def value_from_crd(reference: str, layer: int) -> Variable: 69 | return Variable(f"i_{reference}_{layer}") 70 | -------------------------------------------------------------------------------- /src/tensora/compile/_compile_cffi.py: -------------------------------------------------------------------------------- 1 | __all__ = ["compile_evaluate"] 2 | 3 | import re 4 | import tempfile 5 | import threading 6 | from typing import Any 7 | 8 | from cffi import FFI 9 | 10 | from ._cffi_ownership import taco_type_header, tensor_cdefs 11 | 12 | lock = threading.Lock() 13 | 14 | taco_define_header = """ 15 | #ifndef TACO_C_HEADERS 16 | #define TACO_C_HEADERS 17 | #define TACO_MIN(_a,_b) ((_a) < (_b) ? (_a) : (_b)) 18 | #define TACO_MAX(_a,_b) ((_a) > (_b) ? (_a) : (_b)) 19 | #endif 20 | """ 21 | 22 | 23 | def compile_evaluate(source: str) -> Any: 24 | """Compile evaluate kernel in C code using CFFI. 25 | 26 | Args: 27 | source: C code containing one evaluate function 28 | 29 | Returns: 30 | The compiled FFILibrary which has a single method `evaluate` which 31 | expects cffi pointers to taco_tensor_t instances. 32 | """ 33 | # Extract signature 34 | # This needs to be provided alone to cdef 35 | signature_match = re.search(r"int(32_t)? 
evaluate\(([^)]*)\)", source) 36 | signature = signature_match.group(0) 37 | 38 | # Use cffi to compile the kernels 39 | ffibuilder = FFI() 40 | ffibuilder.include(tensor_cdefs) 41 | ffibuilder.cdef(signature + ";") 42 | ffibuilder.set_source( 43 | "taco_kernel", 44 | taco_define_header + taco_type_header + source, 45 | extra_compile_args=["-Wno-unused-variable", "-Wno-unknown-pragmas"], 46 | ) 47 | 48 | with tempfile.TemporaryDirectory() as temp_dir: 49 | # Lock because FFI.compile is not thread safe: https://foss.heptapod.net/pypy/cffi/-/issues/490 50 | with lock: 51 | # Create shared object in temporary directory 52 | lib_path = ffibuilder.compile(tmpdir=temp_dir) 53 | 54 | # Load the shared object 55 | lib = ffibuilder.dlopen(lib_path) 56 | 57 | # Return the entire library rather than just the function because it appears that the memory containing the compiled 58 | # code is freed as soon as the library goes out of scope: https://stackoverflow.com/q/55323592/1485877 59 | return lib 60 | -------------------------------------------------------------------------------- /src/tensora/generate/_deparse_to_taco.py: -------------------------------------------------------------------------------- 1 | __all__ = ["deparse_to_taco"] 2 | 3 | from functools import singledispatch 4 | 5 | from tensora.expression import ast 6 | 7 | 8 | @singledispatch 9 | def deparse_to_taco_expression(self: ast.Expression) -> str: 10 | raise NotImplementedError( 11 | f"deparse_to_taco_expression not implemented for {type(self)}: {self}" 12 | ) 13 | 14 | 15 | @deparse_to_taco_expression.register(ast.Integer) 16 | def deparse_to_taco_integer(self: ast.Integer) -> str: 17 | return str(self.value) 18 | 19 | 20 | @deparse_to_taco_expression.register(ast.Float) 21 | def deparse_to_taco_float(self: ast.Float) -> str: 22 | return str(self.value) 23 | 24 | 25 | @deparse_to_taco_expression.register(ast.Tensor) 26 | def deparse_to_taco_tensor(self: ast.Tensor) -> str: 27 | if len(self.indexes) == 0: 28 | 
# Taco represents zero-dimensional tensors as scalars 29 | return self.name 30 | else: 31 | return f"{self.name}({', '.join(self.indexes)})" 32 | 33 | 34 | @deparse_to_taco_expression.register(ast.Add) 35 | def deparse_to_taco_add(self: ast.Add) -> str: 36 | return f"{deparse_to_taco_expression(self.left)} + {deparse_to_taco_expression(self.right)}" 37 | 38 | 39 | @deparse_to_taco_expression.register(ast.Subtract) 40 | def deparse_to_taco_subtract(self: ast.Subtract) -> str: 41 | return f"{deparse_to_taco_expression(self.left)} - {deparse_to_taco_expression(self.right)}" 42 | 43 | 44 | @deparse_to_taco_expression.register(ast.Multiply) 45 | def deparse_to_taco_multiply(self: ast.Multiply) -> str: 46 | left_string = deparse_to_taco_expression(self.left) 47 | if isinstance(self.left, (ast.Add, ast.Subtract)): 48 | left_string = f"({left_string})" 49 | 50 | right_string = deparse_to_taco_expression(self.right) 51 | if isinstance(self.right, (ast.Add, ast.Subtract)): 52 | right_string = f"({right_string})" 53 | 54 | return f"{left_string} * {right_string}" 55 | 56 | 57 | def deparse_to_taco(self: ast.Assignment) -> str: 58 | return f"{deparse_to_taco_expression(self.target)} = {deparse_to_taco_expression(self.expression)}" 59 | -------------------------------------------------------------------------------- /docs/getting-started.md: -------------------------------------------------------------------------------- 1 | --- 2 | icon: material/sign-direction 3 | --- 4 | 5 | # Getting started 6 | 7 | Tensors are n-dimensional generalizations of matrices. Instead of being limited to 2 dimensions, tensors may have 3, 4, or more dimensions. They may also have 0 or 1 dimensions. The number of dimensions is called the order. NumPy is the best known tensor library in Python; its central `ndarray` object is an example of a dense tensor. 8 | 9 | Each dimension of a tensor has a size. This determines, conceptually, the number of elements in the tensor. 
"Conceptually" because the number of stored elements and the amount of memory required for the tensor may be smaller than that if the tensor is sparse. 10 | 11 | Tensors also have a format. The format has a list of modes, which determines the internal layout of the tensor, and a mode ordering, which maps each dimension to each mode. Each mode can be either sparse or dense. An example of two different formats with the same internal layout would be CSR, which has format `ds` in Tensora, and CSC, which has format `d1s0`. 12 | 13 | Here is a list of common formats: 14 | 15 | | common name | Tensora format | 16 | |---------------|----------------| 17 | | scalar | `''` | 18 | | dense vector | `'d'` | 19 | | sparse vector | `'s'` | 20 | | row-major | `'dd'` | 21 | | column-major | `'d1d0'` | 22 | | CSR | `'ds'` | 23 | | CSC | `'d1s0'` | 24 | | CSF | `'sd'` | 25 | | DCSR | `'ss'` | 26 | 27 | There are formats for higher order tensors, but they do not have common names. That is one of the goals of Tensora, to give access to the creation and use of new formats. 28 | 29 | Tensors are [created](./creation.md) via one of several static methods on the `Tensor` class. The key [attributes](./tensors.md), `order`, `dimensions`, and `format`, are available on every `Tensor`. While basic [arithmetic](./tensors.md#arithmetic) (`+`, `-`, `*`, `@`) is available as well, it is generally better to use the `evaluate` function, which makes much more complex operations available and will fuse the loops of multiple arithmetic operators. 
@singledispatch
def hoist_declarations_statement(self: Statement) -> dict[str, Type]:
    """Collect the variables declared anywhere inside an IR statement.

    Returns a mapping from each declared variable name to its declared type.
    This is the fallback used for statement types without a registered handler.

    Raises:
        NotImplementedError: Always; the message names the unsupported type.
    """
    message = f"hoist_declarations_statement not implemented for {type(self).__name__}"
    raise NotImplementedError(message)


@hoist_declarations_statement.register
def hoist_declarations_expression(self: Expression) -> dict[str, Type]:
    """Return an empty mapping: a bare expression declares nothing."""
    return {}


@hoist_declarations_statement.register
def hoist_declarations_declaration(self: Declaration) -> dict[str, Type]:
    """Return the single name/type pair introduced by a declaration."""
    variable = self.name
    return {variable.name: self.type}


@hoist_declarations_statement.register
def hoist_declarations_assignment(self: Assignment) -> dict[str, Type]:
    """Return an empty mapping: a plain assignment declares nothing."""
    return {}


@hoist_declarations_statement.register
def hoist_declarations_declaration_assignment(self: DeclarationAssignment) -> dict[str, Type]:
    """Return the name/type pair of the declaration on the left-hand side."""
    declaration = self.target
    return {declaration.name.name: declaration.type}


@hoist_declarations_statement.register
def hoist_declarations_block(self: Block) -> dict[str, Type]:
    """Merge the declarations of every statement in the block, in order.

    When the same name is declared more than once, the later statement's type wins.
    """
    declared: dict[str, Type] = {}
    for statement in self.statements:
        declared |= hoist_declarations_statement(statement)
    return declared


@hoist_declarations_statement.register
def hoist_declarations_branch(self: Branch) -> dict[str, Type]:
    """Merge the declarations of both arms; the false arm wins on a name clash."""
    from_true = hoist_declarations_statement(self.if_true)
    from_false = hoist_declarations_statement(self.if_false)
    return {**from_true, **from_false}


@hoist_declarations_statement.register
def hoist_declarations_loop(self: Loop) -> dict[str, Type]:
    """Return the declarations found in the loop body."""
    body = self.body
    return hoist_declarations_statement(body)


@hoist_declarations_statement.register
def hoist_declarations_return(self: Return) -> dict[str, Type]:
    """Return an empty mapping: a return statement declares nothing."""
    return {}


def hoist_declarations(fn: FunctionDefinition) -> dict[str, Type]:
    """Return every variable declared in the body of `fn`, keyed by name."""
    body = fn.body
    return hoist_declarations_statement(body)
# Command lines that the CLI is expected to reject.
bad_commands = [
    ["a(i) = b(i) +"],
    ["y(i) = A(i,j) * x(j)", "-f=ds"],
    ["y(i) = A(i,j) * x(j)", "-f=A:d1s2"],
    ["y(i) = A(i,j) * x(j)", "-f=A:ds", "-f=A:dd"],
    ["y(i) = A(i,j) * x(j)", "-f=A:d"],
    ["y(i) = A(i,j) * x(j)", "-f=B:ds"],
    ["a(i) = A(i,i)"],
    ["A(i,j) = B(i,j) + C(j,i)", "-f=A:ds", "-f=B:ds", "-f=C:ds"],
]


@pytest.mark.parametrize("command", bad_commands)
def test_bad_input(command):
    """Every rejected command line must make the CLI exit with status 1.

    `catch_exceptions=False` is passed to the runner; presumably this lets an
    unexpected exception fail the test directly instead of being captured in
    the result (confirm against the CliRunner documentation).
    """
    outcome = runner.invoke(app, command, catch_exceptions=False)

    assert outcome.exit_code == 1
# Pairs of (format string, equivalent Format); used for both parsing and deparsing.
format_strings = [
    ("", Format((), ())),
    ("d", Format((Mode.dense,), (0,))),
    ("s", Format((Mode.compressed,), (0,))),
    ("ds", Format((Mode.dense, Mode.compressed), (0, 1))),
    ("sd", Format((Mode.compressed, Mode.dense), (0, 1))),
    ("d1s0", Format((Mode.dense, Mode.compressed), (1, 0))),
    ("d1s0s2", Format((Mode.dense, Mode.compressed, Mode.compressed), (1, 0, 2))),
]


@pytest.mark.parametrize(("string", "format"), format_strings)
def test_parse_format(string, format):
    """Parsing a valid format string yields the expected Format."""
    assert parse_format(string).unwrap() == format


@pytest.mark.parametrize(("string", "format"), format_strings)
def test_deparse_format(string, format):
    """Deparsing a Format reproduces the string it was parsed from."""
    assert format.deparse() == string


@pytest.mark.parametrize("string", ["df", "1d0s", "d0s", "d0s1s1", "d1s2s3", "d3d1d2"])
def test_parse_bad_format(string):
    """Malformed format strings are reported as a Failure rather than raised."""
    outcome = parse_format(string)
    assert isinstance(outcome, Failure)


def test_parse_named_format():
    """A `name:format` string parses to a (name, Format) pair."""
    expected_format = Format((Mode.dense, Mode.compressed, Mode.compressed), (1, 0, 2))
    assert parse_named_format("A:d1s0s2").unwrap() == ("A", expected_format)


def test_parse_bad_named_format():
    """A format string without the `name:` prefix is a Failure."""
    outcome = parse_named_format("d1s0s2s3")
    assert isinstance(outcome, Failure)


@pytest.mark.parametrize("string", ["A:d0s", "A:d3d1d2"])
def test_parse_bad_ordering_in_named_format(string):
    """An invalid mode ordering inside a named format is a Failure."""
    outcome = parse_named_format(string)
    assert isinstance(outcome, Failure)


def test_format_attributes():
    """`order` and `modes` are readable on a constructed Format."""
    two_mode_format = Format((Mode.dense, Mode.compressed), (1, 0))

    assert two_mode_format.order == 2
    assert two_mode_format.modes[0] == Mode.dense


def test_mode_dense_attributes():
    """C integer 0 maps to the mode whose character is "d"."""
    mode_dense = Mode.from_c_int(0)
    assert mode_dense.c_int == 0
    assert mode_dense.character == "d"


def test_mode_sparse_attributes():
    """C integer 1 maps to the mode whose character is "s"."""
    mode_sparse = Mode.from_c_int(1)
    assert mode_sparse.c_int == 1
    assert mode_sparse.character == "s"


def test_mode_from_illegal_int():
    """An integer with no corresponding mode raises ValueError."""
    with pytest.raises(ValueError, match="No member of Mode"):
        Mode.from_c_int(3)


def test_differing_sizes():
    """A Format whose modes and ordering differ in length is rejected."""
    with pytest.raises(InvalidModeOrderingError):
        Format((Mode.dense,), (0, 1))
def _sugared(name, *indexes):
    """Build a parsed (sugared) tensor reference with the given index names."""
    return sugar.Tensor(name, indexes)


def _desugared(tensor_id, name, *indexes):
    """Build a desugared tensor reference carrying its numeric id."""
    return desugar.Tensor(tensor_id, name, indexes)


# Each entry pairs a sugared assignment with the desugared assignment it must produce.
desugar_cases = [
    (
        sugar.Assignment(
            _sugared("a", "i"),
            sugar.Add(_sugared("b", "i"), _sugared("c", "i")),
        ),
        desugar.Assignment(
            _desugared(0, "a", "i"),
            desugar.Add(_desugared(1, "b", "i"), _desugared(2, "c", "i")),
        ),
    ),
    (
        sugar.Assignment(
            _sugared("A", "i", "j"),
            sugar.Multiply(_sugared("B", "i", "k"), _sugared("C", "k", "i")),
        ),
        desugar.Assignment(
            _desugared(0, "A", "i", "j"),
            desugar.Contract(
                "k",
                desugar.Multiply(
                    _desugared(1, "B", "i", "k"),
                    _desugared(2, "C", "k", "i"),
                ),
            ),
        ),
    ),
    (
        sugar.Assignment(
            _sugared("A", "i", "j"),
            sugar.Add(
                _sugared("K", "i", "j"),
                sugar.Multiply(_sugared("B", "i", "k"), _sugared("C", "k", "i")),
            ),
        ),
        desugar.Assignment(
            _desugared(0, "A", "i", "j"),
            desugar.Add(
                _desugared(1, "K", "i", "j"),
                desugar.Contract(
                    "k",
                    desugar.Multiply(
                        _desugared(2, "B", "i", "k"),
                        _desugared(3, "C", "k", "i"),
                    ),
                ),
            ),
        ),
    ),
    (
        sugar.Assignment(_sugared("a", "i"), _sugared("b", "i", "j")),
        desugar.Assignment(
            _desugared(0, "a", "i"),
            desugar.Contract("j", _desugared(1, "b", "i", "j")),
        ),
    ),
]


@pytest.mark.parametrize(("expression", "expected"), desugar_cases)
def test_desugar(expression, expected):
    """Desugaring a parsed assignment yields the expected explicit form."""
    assert desugar_assignment(expression) == expected
def _unwrap_or_raise(result):
    """Return the success value of `result`, raising its failure value otherwise."""
    return result.alt(raise_exception).unwrap()


@lru_cache()
def cachable_tensor_method(problem: Problem, backend: BackendCompiler) -> TensorMethod:
    """Build a `TensorMethod`, memoized on the (problem, backend) pair."""
    return TensorMethod(problem, backend=backend)


def tensor_method(
    assignment: str,
    formats: dict[str, str],
    backend: BackendCompiler = BackendCompiler.llvm,
) -> TensorMethod:
    """Create a reusable tensor method from an assignment string and format strings.

    Args:
        assignment: Assignment expression in source form.
        formats: Format string for each tensor name.
        backend: Compiler backend used to build the method.

    Raises:
        Whatever failure the assignment parser, the format parser, or
        `make_problem` reports; the assignment is parsed before the formats.
    """
    parsed_assignment = _unwrap_or_raise(parse_assignment(assignment))
    parsed_formats = {}
    for name, format in formats.items():
        parsed_formats[name] = _unwrap_or_raise(parse_format(format))

    problem = _unwrap_or_raise(make_problem(parsed_assignment, parsed_formats))

    return cachable_tensor_method(problem, backend)


def _evaluate_with_backend(
    backend: BackendCompiler, assignment: str, output_format: str, inputs: dict[str, Tensor]
) -> Tensor:
    """Parse, compile (with caching), and run `assignment` on `inputs` using `backend`."""
    parsed_assignment = _unwrap_or_raise(parse_assignment(assignment))
    input_formats = {name: tensor.format for name, tensor in inputs.items()}
    parsed_output_format = _unwrap_or_raise(parse_format(output_format))

    # Input formats are merged last, so an input sharing the target's name overrides it.
    formats = {parsed_assignment.target.name: parsed_output_format} | input_formats

    problem = _unwrap_or_raise(make_problem(parsed_assignment, formats))

    function = cachable_tensor_method(problem, backend)

    return function(**inputs)


def evaluate_cffi(assignment: str, output_format: str, **inputs: Tensor) -> Tensor:
    """Evaluate `assignment` on `inputs` using the cffi backend."""
    return _evaluate_with_backend(BackendCompiler.cffi, assignment, output_format, inputs)


def evaluate_tensora(assignment: str, output_format: str, **inputs: Tensor) -> Tensor:
    """Evaluate `assignment` on `inputs` using the llvm backend."""
    return _evaluate_with_backend(BackendCompiler.llvm, assignment, output_format, inputs)


def evaluate(assignment: str, output_format: str, **inputs: Tensor) -> Tensor:
    """Evaluate `assignment` on `inputs`; currently delegates to `evaluate_tensora`."""
    return evaluate_tensora(assignment, output_format, **inputs)
@index_dimensions_expression.register(Contract)
def index_dimensions_contract(self: Contract) -> dict[str, TensorDimension]:
    """Return the index dimensions of the expression wrapped by the contraction."""
    inner = self.expression
    return index_dimensions_expression(inner)


def index_dimensions(self: Assignment) -> dict[str, TensorDimension]:
    """Pick one tensor dimension to supply the size of each index in the assignment.

    A kernel can only learn the size of an index from a tensor that has a
    dimension indexed by it. Several tensors usually qualify for a given
    index, and they are expected to agree on the size, so the first one
    encountered is kept. The target is searched before the right-hand side.
    """
    chosen = dict(index_dimensions_expression(self.target))
    for index_name, dimension in index_dimensions_expression(self.expression).items():
        # setdefault keeps an entry already supplied by the target.
        chosen.setdefault(index_name, dimension)
    return chosen
Element = TypeVar("Element", bound=Hashable, covariant=True)


class StableSet(MutableSet[Element]):
    """Mutable set that iterates in insertion order.

    Elements are stored as the keys of a dict, so membership tests are hash
    based and iteration follows the order in which elements were first added.
    Equality ignores order.
    """

    def __init__(self, *items: Element):
        # Rely on stable dictionary
        self._items = dict.fromkeys(items, None)

    def __len__(self) -> int:
        return len(self._items)

    def __contains__(self, x: Element) -> bool:
        return x in self._items

    def __iter__(self) -> Iterator[Element]:
        return iter(self._items)

    def __reversed__(self):
        # Returns a new set (not a bare iterator) holding the elements in reverse order.
        return StableSet(*reversed(self._items))

    def __or__(self, other: StableSet[Element]) -> StableSet[Element]:
        # Elements of self come first; elements only in other follow in other's order.
        return StableSet(*self._items, *other._items)

    def __eq__(self, other: object):
        if isinstance(other, StableSet):
            return self._items == other._items
        else:
            return NotImplemented

    def __repr__(self) -> str:
        return f"StableSet({', '.join(repr(item) for item in self._items)})"

    def add(self, element: Element, /) -> None:
        """Add `element`; an element already present keeps its original position."""
        self._items[element] = None

    def discard(self, element: Element, /) -> None:
        """Remove `element` if it is present; do nothing otherwise.

        The class defines no `__delitem__`, so deletion has to go through the
        backing dict. `pop` with a default keeps a missing element a no-op.
        """
        self._items.pop(element, None)


class StableFrozenSet(AbstractSet[Element]):
    """Immutable, hashable set that iterates in first-occurrence order.

    Equality and hashing ignore order; only iteration and `repr` preserve it.
    """

    def __init__(self, *items: Element):
        # dict.fromkeys drops duplicates while keeping the first occurrence of each item.
        self._items = tuple(dict.fromkeys(items))
        self._set = frozenset(self._items)

    def __len__(self) -> int:
        return len(self._items)

    def __contains__(self, x: Element) -> bool:
        return x in self._set

    def __iter__(self) -> Iterator[Element]:
        return iter(self._items)

    def __reversed__(self):
        # Returns a new set (not a bare iterator) holding the elements in reverse order.
        return StableFrozenSet(*reversed(self._items))

    def __or__(self, other: StableFrozenSet[Element]) -> StableFrozenSet[Element]:
        # Elements of self come first; elements only in other follow in other's order.
        return StableFrozenSet(*self._items, *other._items)

    def __eq__(self, other):
        if isinstance(other, StableFrozenSet):
            return self._set == other._set
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self._set)

    def __repr__(self) -> str:
        return f"StableFrozenSet({', '.join(repr(item) for item in self._items)})"
@dataclass(frozen=True, slots=True, kw_only=True)
class Context:
    """Information gathered about one index while walking an expression.

    Contexts of sub-expressions are combined with `add` or `multiply`; the
    two differ only in how `is_sparse` is derived.
    """

    is_sparse: bool
    sparse_leaves: list[TensorLayer] = field(default_factory=list)
    dense_leaves: list[TensorLayer] = field(default_factory=list)
    indexes: frozenset[str] = frozenset()
    has_output: bool = False
    has_assemble: bool = False

    def _merged_with(self, other: Context, *, is_sparse: bool) -> Context:
        """Concatenate/union every collected field, using the given `is_sparse`."""
        return Context(
            is_sparse=is_sparse,
            sparse_leaves=[*self.sparse_leaves, *other.sparse_leaves],
            dense_leaves=[*self.dense_leaves, *other.dense_leaves],
            indexes=self.indexes | other.indexes,
            has_output=self.has_output or other.has_output,
            has_assemble=self.has_assemble or other.has_assemble,
        )

    def add(self, other: Context) -> Context:
        """Combine as operands of a sum: sparse only when both operands are sparse."""
        return self._merged_with(other, is_sparse=self.is_sparse and other.is_sparse)

    def multiply(self, other: Context) -> Context:
        """Combine as operands of a product: sparse when either operand is sparse."""
        return self._merged_with(other, is_sparse=self.is_sparse or other.is_sparse)


@singledispatch
def extract_context(self: ast.Expression, index: str) -> Context:
    """Compute the `Context` of `index` within an expression.

    This is the fallback for expression types without a registered handler.

    Raises:
        NotImplementedError: Always; the message names the unsupported type.
    """
    message = f"extract_context not implemented for {type(self)}: {self}"
    raise NotImplementedError(message)


@extract_context.register(ast.Literal)
def extract_context_literal(self: ast.Literal, index: str) -> Context:
    """Return a sparse context for a literal zero, a non-sparse one otherwise."""
    is_zero = self == ast.Integer(0) or self == ast.Float(0.0)
    return Context(is_sparse=True) if is_zero else Context(is_sparse=False)
return Context(is_sparse=False) 62 | 63 | if self.modes[maybe_layer] == Mode.dense: 64 | return Context(is_sparse=False, dense_leaves=[TensorLayer(self, maybe_layer)]) 65 | else: 66 | return Context(is_sparse=True, sparse_leaves=[TensorLayer(self, maybe_layer)]) 67 | 68 | 69 | @extract_context.register(ast.Add) 70 | def extract_context_add(self: ast.Add, index: str) -> Context: 71 | left = extract_context(self.left, index) 72 | right = extract_context(self.right, index) 73 | return left.add(right) 74 | 75 | 76 | @extract_context.register(ast.Multiply) 77 | def extract_context_multiply(self: ast.Multiply, index: str) -> Context: 78 | left = extract_context(self.left, index) 79 | right = extract_context(self.right, index) 80 | return left.multiply(right) 81 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | --- 2 | icon: material/hand-heart 3 | --- 4 | 5 | # Contributing 6 | 7 | Tensora is free and open source software developed under an MIT license. Development occurs at the [GitHub project](https://github.com/drhagen/tensora). Contributions, big and small, are welcome. 8 | 9 | Bug reports and feature requests may be made directly on the [issues](https://github.com/drhagen/tensora/issues) tab. 10 | 11 | To make a pull request, you will need to fork the repo, clone the repo, make the changes, run the tests, push the changes, and [open a PR](https://github.com/drhagen/tensora/pulls). 12 | 13 | ## Cloning the repo 14 | 15 | To make a local copy of Tensora, clone the repository with git: 16 | 17 | ```shell 18 | git clone https://github.com/drhagen/tensora.git 19 | ``` 20 | 21 | ## Installing from source 22 | 23 | Tensora uses uv for project and environment management with uv_build as its build backend.
In whatever environment you prefer, ensure [uv](https://github.com/astral-sh/uv) is installed and then use uv to install Tensora and its dependencies: 24 | 25 | ```shell 26 | uv sync 27 | ``` 28 | 29 | ## Testing 30 | 31 | Tensora uses pytest to run the tests in the `tests/` directory. The test command is encapsulated with Nox: 32 | 33 | ```shell 34 | uv run nox -s test test_cffi test_numpy 35 | ``` 36 | 37 | This will try to test with all compatible Python versions that `nox` can find. To run the tests with only a particular version, run something like this: 38 | 39 | ```shell 40 | uv run nox -s test-3.13 test_cffi-3.13 test_numpy-3.13 41 | ``` 42 | 43 | It is good to run the tests locally before making a PR, but it is not necessary to have all Python versions run. It is rare for a failure to appear in a single version, and the CI will catch it anyway. 44 | 45 | ## Code quality 46 | 47 | Tensora uses Ruff to ensure a minimum standard of code quality. The code quality commands are encapsulated with Nox: 48 | 49 | ```shell 50 | uv run nox -s format 51 | uv run nox -s lint 52 | ``` 53 | 54 | ## Generating the docs 55 | 56 | Tensora uses MkDocs to generate HTML docs from Markdown. For development purposes, they can be served locally without needing to build them first: 57 | 58 | ```shell 59 | uv run mkdocs serve 60 | ``` 61 | 62 | To deploy the current docs to GitHub Pages, Tensora uses the MkDocs `gh-deploy` command that builds the static site on the `gh-pages` branch, commits, and pushes to the origin: 63 | 64 | ```shell 65 | uv run mkdocs gh-deploy 66 | ``` 67 | 68 | ## Making a release 69 | 70 | 1. Bump 71 | 1. Increment version in `pyproject.toml` 72 | 2. Run `uv lock` 73 | 3. Commit with message "Bump version number to X.Y.Z" 74 | 4. Push commit to GitHub 75 | 5. Check [CI](https://github.com/drhagen/tensora/actions/workflows/ci.yml) to ensure all tests pass 76 | 2. Tag 77 | 1. Tag commit with "vX.Y.Z" 78 | 2. Push tag to GitHub 79 | 3. 
Wait for [build](https://github.com/drhagen/tensora/actions/workflows/release.yml) to finish 80 | 4. Check [PyPI](https://pypi.org/project/tensora/) for good upload 81 | 3. Document 82 | 1. Create [GitHub release](https://github.com/drhagen/tensora/releases) with name "Tensora X.Y.Z" and major changes in body 83 | 2. If appropriate, deploy updated docs 84 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["uv_build == 0.8.*"] 3 | build-backend = "uv_build" 4 | 5 | [project] 6 | name = "tensora" 7 | version = "0.4.3" 8 | description = "Library for dense and sparse tensors built on the tensor algebra compiler" 9 | authors = [{ name = "David Hagen", email = "david@drhagen.com" }] 10 | license = "MIT" 11 | readme = "README.md" 12 | keywords = ["tensor", "sparse", "matrix", "array"] 13 | requires-python = ">=3.10" 14 | classifiers = [ 15 | "Development Status :: 4 - Beta", 16 | "Intended Audience :: Science/Research", 17 | "Intended Audience :: Developers", 18 | "License :: OSI Approved :: MIT License", 19 | "Operating System :: POSIX :: Linux", 20 | "Operating System :: MacOS :: MacOS X", 21 | "Operating System :: Microsoft :: Windows", 22 | "Programming Language :: Python :: 3", 23 | "Programming Language :: Python :: 3.10", 24 | "Programming Language :: Python :: 3.11", 25 | "Programming Language :: Python :: 3.12", 26 | "Programming Language :: Python :: 3.13", 27 | "Topic :: Software Development :: Libraries :: Python Modules", 28 | ] 29 | dependencies = [ 30 | "returns >= 0.20,<0.27", 31 | "cffi ~= 1.16", 32 | "llvmlite == 0.45.*", 33 | "parsita == 2.*", 34 | "typer == 0.19.*", 35 | ] 36 | 37 | [project.optional-dependencies] 38 | cffi = [ 39 | # setuptools is required for cffi compilation on Python 3.12+ 40 | # https://cffi.readthedocs.io/en/latest/whatsnew.html#v1-16-0 41 | "setuptools >= 69; 
python_version >= '3.12'", 42 | ] 43 | numpy = [ 44 | "numpy >= 1.24,<3.0", 45 | ] 46 | scipy = [ 47 | "scipy ~= 1.7", 48 | ] 49 | 50 | [project.urls] 51 | Documentation = "https://tensora.drhagen.com" 52 | Repository = "https://github.com/drhagen/tensora" 53 | 54 | [project.scripts] 55 | tensora = "tensora.cli:app" 56 | 57 | [dependency-groups] 58 | nox = [ 59 | "nox-uv == 0.6.*", 60 | ] 61 | test = [ 62 | "pytest == 7.*", # Upgrading to pytest 8 breaks hypofuzz 63 | "coverage == 7.*" 64 | ] 65 | fuzz = [ 66 | "hypothesis == 6.*", 67 | "hypofuzz == 24.*" 68 | ] 69 | lint = [ 70 | "ruff == 0.14.*" 71 | ] 72 | docs = [ 73 | "mkdocs-material == 9.*" 74 | ] 75 | 76 | [tool.uv] 77 | default-groups = "all" 78 | python-preference = "only-managed" 79 | 80 | [tool.coverage.run] 81 | branch = true 82 | source_pkgs = ["tensora"] 83 | relative_files = true 84 | 85 | [tool.coverage.report] 86 | exclude_lines = [ 87 | "pragma: no cover", 88 | "pass", 89 | "raise NotImplementedError", 90 | "def __str__", 91 | "def __repr__", 92 | ] 93 | 94 | [tool.coverage.paths] 95 | source = [ 96 | "src/", 97 | "**/site-packages/", 98 | ] 99 | 100 | [tool.ruff] 101 | src = ["src"] 102 | line-length = 99 103 | 104 | [tool.ruff.lint] 105 | extend-select = [ 106 | "I", # isort 107 | "N", # pep8-naming 108 | "RUF", # ruff 109 | "B", # flake8-bugbear 110 | "N", # flake8-broken-line 111 | "C4", # flake8-comprehensions 112 | "PIE", # flake8-pie 113 | "PT", # flake8-pytest-style 114 | "PTH", # flake8-use-pathlib 115 | "ERA", # flake8-eradicate 116 | ] 117 | 118 | [tool.ruff.lint.flake8-bugbear] 119 | extend-immutable-calls = ["hypothesis.strategies.integers"] 120 | 121 | [tool.ruff.lint.per-file-ignores] 122 | # Allow glob imports and uppercase identifiers in tests 123 | "tests*/*" = ["F403", "F405", "N802", "N806"] 124 | 125 | # F401: Allow unused imports in __init__.py 126 | "__init__.py" = ["F401"] 127 | -------------------------------------------------------------------------------- 
@dataclass(frozen=True, slots=True)
class BucketOutput(Output):
    """Output that accumulates values into a flat, zero-initialized bucket array.

    The bucket covers the output layers listed in ``layers``; its length is the product of
    the dimensions of those layers and elements are addressed by a raveled index.
    """

    output: ie_ast.Tensor
    layers: list[int]
    unfulfilled: set[int]

    def __init__(
        self, output: ie_ast.Tensor, layers: list[int], unfulfilled: set[int] | None = None
    ):
        if unfulfilled is None:
            # By default, every bucketed layer stored in compressed mode is unfulfilled
            unfulfilled = {layer for layer in layers if output.modes[layer] == Mode.compressed}

        # The dataclass is frozen, so fields have to be set through object.__setattr__
        for field_name, value in (
            ("output", output),
            ("layers", layers),
            ("unfulfilled", unfulfilled),
        ):
            object.__setattr__(self, field_name, value)

    def write_declarations(self, right_hand_side: Expression) -> SourceBuilder:
        """Declare the bucket as ``right_hand_side`` and emit a loop that zeros every element."""
        bucket = self.name()
        counter = self.loop_name()
        bucket_size = Multiply.join(self.dimension_names())

        source = SourceBuilder("Bucket initialization")
        source.append(bucket.declare(types.Pointer(types.float)).assign(right_hand_side))
        source.append(counter.declare(types.integer).assign(0))
        with source.loop(LessThan(counter, bucket_size)):
            source.append(bucket.idx(counter).assign(0))
            source.append(counter.increment())
        return source

    def next_output(
        self, iteration_output: TensorLayer | None, kernel_type: KernelType
    ) -> tuple[Output, SourceBuilder, SourceBuilder]:
        """Return the output for the next iteration level along with two empty sources.

        When ``iteration_output`` is given, its layer is removed from ``unfulfilled`` in the
        returned copy; otherwise this output is returned unchanged.
        """
        if iteration_output is None:
            successor = self
        else:
            remaining = self.unfulfilled - {iteration_output.layer}
            successor = replace(self, unfulfilled=remaining)
        return successor, SourceBuilder(), SourceBuilder()

    def write_assignment(self, right_hand_side: Expression, kernel_type: KernelType):
        """Emit an increment of the bucket element addressed by the current loop indexes."""
        loop_indexes = [Variable(self.output.indexes[layer]) for layer in self.layers]
        position = self.ravel_indexes(self.dimension_names(), loop_indexes)

        source = SourceBuilder()
        source.append(self.name().idx(position).increment(right_hand_side))
        return source

    def ravel_indexes(self, dimensions: list[Variable], indexes: list[Variable]):
        """Build the flat index expression for ``indexes`` within ``dimensions``.

        Each index is multiplied by the dimensions that come after it, and the products are
        summed in the same order as ``indexes``.
        """
        trailing_dimensions: list[Variable] = []
        reversed_terms: list[Expression] = []
        # Walk from the last position to the first so the trailing dimensions accumulate
        for dimension, index in zip(reversed(dimensions), reversed(indexes), strict=True):
            reversed_terms.append(Multiply.join([index, *trailing_dimensions]))
            trailing_dimensions.append(dimension)

        # Put the terms back into the order of the indexes so the sum reads naturally
        return Add.join(reversed_terms[::-1])

    def _layer_suffix(self) -> str:
        # One "_<layer>" fragment per bucketed layer
        return "".join(f"_{layer}" for layer in self.layers)

    def name(self) -> Variable:
        """Variable holding the bucket array."""
        return Variable(f"bucket_{self.output.id}{self._layer_suffix()}")

    def loop_name(self) -> Variable:
        """Variable used as the counter of the bucket initialization loop."""
        return Variable(f"i_bucket_{self.output.id}{self._layer_suffix()}")

    def dimension_names(self):
        """Dimension variables of the bucketed layers, in ``layers`` order."""
        names = []
        for layer in self.layers:
            names.append(dimension_name(self.output.indexes[layer]))
        return names
def test_rhs():
    """Evaluate a right-hand side that sums a constant, a linear and a quadratic term."""
    # Keys are the tensor names referenced by the assignment below
    operands = {
        "A0": Tensor.from_lol([2, -3, 0]),
        "A1": Tensor.from_aos(
            [(0, 2), (1, 2), (2, 2)], [3, 3, -3], dimensions=(3, 3), format="ds"
        ),
        "A2": Tensor.from_aos(
            [(0, 0, 1), (1, 0, 1), (2, 0, 1)], [-2, -2, 2], dimensions=(3, 3, 3), format="dss"
        ),
        "x": Tensor.from_lol([2, 3, 5]),
    }
    expected = Tensor.from_lol([5, 0, -3])

    actual = evaluate(
        "f(i) = A0(i) + A1(i,j) * x(j) + A2(i,k,l) * x(k) * x(l)",
        "d",
        **operands,
    )

    assert actual == expected
def test_multithread_evaluation():
    # As of version 1.14.4 of cffi, the FFI.compile method is not thread safe. This test
    # evaluates different kernels from several threads at once to check that doing so is safe.
    matrix = Tensor.from_aos(
        [(1, 0), (0, 1), (1, 2)], [2.0, -2.0, 4.0], dimensions=(2, 3), format="ds"
    )
    vector = Tensor.from_aos([(0,), (1,), (2,)], [3.0, 2.5, 2.0], dimensions=(3,), format="d")

    def evaluate_fresh_kernel():
        # Randomize the name of the output so that the expression differs between calls and
        # the kernel cache is not hit
        return evaluate(f"y{randrange(1024)}(i) = A(i,j) * x(j)", "d", A=matrix, x=vector)

    thread_count = 4
    no_arguments = [()] * thread_count
    with ThreadPool(thread_count) as pool:
        threaded_results = pool.starmap(evaluate_fresh_kernel, no_arguments)

    # Reference value computed on the main thread after the pool is done
    reference = evaluate_fresh_kernel()

    for threaded_result in threaded_results:
        assert threaded_result == reference
def parse(string):
    """Parse ``string`` as an assignment and unwrap the result of the parser."""
    parse_result = parse_assignment(string)
    return parse_result.unwrap()
def write_sparse_initialization(leaf: TensorLayer) -> SourceBuilder:
    """Declare the position variables that bound the iteration over a sparse layer.

    The layer pointer is initialized from ``pos[previous]`` and the end marker from
    ``pos[previous + 1]``, where ``previous`` is the pointer of the previous layer.
    """
    pointer = leaf.layer_pointer()
    parent_pointer = leaf.previous_layer_pointer()
    end = leaf.sparse_end_name()
    pos = leaf.pos_name()

    # Each variable is declared as an integer and initialized from the pos array
    initial_values = (
        (pointer, pos.idx(parent_pointer)),
        (end, pos.idx(parent_pointer.plus(1))),
    )

    source = SourceBuilder()
    for variable, initial_value in initial_values:
        source.append(variable.declare(types.integer).assign(initial_value))
    return source
def write_crd_assembly(output: TensorLayer) -> SourceBuilder:
    """Emit IR that stores the current index of this layer into its crd array.

    If the layer pointer has reached the crd capacity, the capacity is doubled and the
    array is reallocated before the store.
    """
    crd_array = output.crd_name()
    crd_capacity = output.crd_capacity_name()
    position = output.layer_pointer()
    # The loop variable is named after the index of this layer
    index_variable = Variable(output.tensor.indexes[output.layer])

    source = SourceBuilder("crd assembly")

    is_full = GreaterThanOrEqual(position, crd_capacity)
    with source.branch(is_full):
        doubled = crd_capacity.times(2)
        source.append(crd_capacity.assign(doubled))
        grown = ArrayReallocate(crd_array, types.integer, crd_capacity)
        source.append(crd_array.assign(grown))

    source.append(crd_array.idx(position).assign(index_variable))

    return source
not seem consistent, double check them 85 | if len(dense_dimensions) == 0: 86 | # Peephole optimization cannot figure out that doubling is always bigger with no dense dimensions, so the 87 | # dropping of max() must be done manually. 88 | minimum_capacity = output.layer_pointer().plus(bonus) 89 | with source.branch(GreaterThanOrEqual(minimum_capacity, capacity)): 90 | source.append(capacity.assign(capacity.times(2))) 91 | source.append(array.assign(ArrayReallocate(array, type, capacity))) 92 | else: 93 | minimum_capacity = ( 94 | output.layer_pointer().plus(1).times(Multiply.join(dense_dimensions)).plus(bonus) 95 | ) 96 | 97 | with source.branch(GreaterThanOrEqual(minimum_capacity, capacity)): 98 | source.append(capacity.assign(Max(capacity.times(2), minimum_capacity))) 99 | source.append(array.assign(ArrayReallocate(array, type, capacity))) 100 | 101 | return source 102 | -------------------------------------------------------------------------------- /src/tensora/cli.py: -------------------------------------------------------------------------------- 1 | __all__ = ["app"] 2 | 3 | from pathlib import Path 4 | from typing import Annotated, Optional 5 | 6 | import typer 7 | from parsita import ParseError 8 | from returns.result import Failure, Success 9 | 10 | from .expression import parse_assignment 11 | from .format import parse_named_format 12 | from .generate import Language, generate_code 13 | from .kernel_type import KernelType 14 | from .problem import make_problem 15 | 16 | app = typer.Typer() 17 | 18 | 19 | @app.command() 20 | def tensora( 21 | assignment: Annotated[ 22 | str, 23 | typer.Argument( 24 | show_default=False, 25 | help="The assignment for which to generate code, e.g. y(i) = A(i,j) * x(j).", 26 | ), 27 | ], 28 | target_format_strings: Annotated[ 29 | list[str], 30 | typer.Option( 31 | "--format", 32 | "-f", 33 | help=( 34 | "A tensor and its format separated by a colon, e.g. A:d1s0 for CSC matrix. 
" 35 | "Unmentioned tensors are be assumed to be all dense." 36 | ), 37 | ), 38 | ] = [], # noqa: B006 Typer does not support Sequence or tuple 39 | kernel_types: Annotated[ 40 | list[KernelType], 41 | typer.Option( 42 | "--type", 43 | "-t", 44 | help="The type of kernel that will be generated. Can be mentioned multiple times.", 45 | ), 46 | ] = [KernelType.compute], # noqa: B006 Typer does not support Sequence or tuple 47 | language: Annotated[ 48 | Language, 49 | typer.Option( 50 | "--language", 51 | "-l", 52 | help="The language in which to generate the kernel.", 53 | ), 54 | ] = Language.c, 55 | output_path: Annotated[ 56 | Optional[Path], 57 | typer.Option( 58 | "--output", 59 | "-o", 60 | writable=True, 61 | help=( 62 | "The file to which the kernel will be written. If not specified, prints to " 63 | "standard out." 64 | ), 65 | ), 66 | ] = None, 67 | ): 68 | # Parse assignment 69 | match parse_assignment(assignment): 70 | case Failure(error): 71 | typer.echo(f"Failed to parse assignment:\n{error}", err=True) 72 | raise typer.Exit(1) 73 | case Success(parsed_assignment): 74 | pass 75 | case _: 76 | raise NotImplementedError() 77 | 78 | # Parse formats 79 | parsed_formats = {} 80 | for target_format_string in target_format_strings: 81 | match parse_named_format(target_format_string): 82 | case Failure(ParseError(_) as error): 83 | typer.echo(f"Failed to parse format:\n{error}", err=True) 84 | raise typer.Exit(1) 85 | case Failure(error): 86 | typer.echo(str(error), err=True) 87 | raise typer.Exit(1) 88 | case Success((target, format)): 89 | pass 90 | case _: 91 | raise NotImplementedError() 92 | 93 | if target in parsed_formats: 94 | typer.echo(f"Format for {target} was mentioned multiple times", err=True) 95 | raise typer.Exit(1) 96 | 97 | parsed_formats[target] = format 98 | 99 | # Validate and standardize assignment and formats 100 | match make_problem(parsed_assignment, parsed_formats): 101 | case Failure(error): 102 | typer.echo(str(error), err=True) 103 
| raise typer.Exit(1) 104 | case Success(problem): 105 | pass 106 | case _: 107 | raise NotImplementedError() 108 | 109 | # Generate code 110 | match generate_code(problem, kernel_types, language): 111 | case Failure(error): 112 | typer.echo(str(error), err=True) 113 | raise typer.Exit(1) 114 | case Success(code): 115 | pass 116 | case _: 117 | raise NotImplementedError() 118 | 119 | if output_path is None: 120 | typer.echo(code) 121 | else: 122 | output_path.write_text(code) 123 | -------------------------------------------------------------------------------- /fuzz_tests/strategies.py: -------------------------------------------------------------------------------- 1 | import hypothesis.strategies as st 2 | from hypothesis import assume 3 | 4 | from tensora import Tensor 5 | from tensora.expression import ( 6 | InconsistentDimensionsError, 7 | MutatingAssignmentError, 8 | NameConflictError, 9 | ast, 10 | ) 11 | from tensora.format import Format, Mode 12 | from tensora.problem import Problem 13 | 14 | names = st.from_regex(r"[A-Za-z][A-Za-z0-9]*", fullmatch=True) 15 | variables = st.builds( 16 | ast.Tensor, name=names, indexes=st.builds(tuple, st.lists(names, max_size=4)) 17 | ) 18 | expressions = st.deferred( 19 | lambda: st.builds(ast.Integer, st.integers(min_value=0, max_value=2**16)) 20 | | st.builds(ast.Float, st.floats(min_value=0, allow_infinity=False, allow_nan=False)) 21 | | variables 22 | | adds 23 | | subtracts 24 | | multiplies 25 | ) 26 | adds = st.builds(ast.Add, expressions, expressions) 27 | subtracts = st.builds(ast.Subtract, expressions, expressions) 28 | multiplies = st.builds(ast.Multiply, expressions, expressions) 29 | 30 | 31 | @st.composite 32 | def assignments(draw): 33 | target = draw(variables) 34 | expression = draw(expressions) 35 | try: 36 | return ast.Assignment(target, expression) 37 | except (InconsistentDimensionsError, MutatingAssignmentError, NameConflictError): 38 | assume(False) 39 | 40 | 41 | modes = st.sampled_from(Mode) 42 | 
@st.composite
def tensors(draw, format: Format, dimensions: tuple[int, ...] | None = None) -> Tensor:
    """Draw a ``Tensor`` of the given format filled with finite floats.

    Args:
        draw: The draw function supplied by ``st.composite``.
        format: The format of the tensor to build; its order fixes the number of dimensions.
        dimensions: The size of each dimension. If omitted, one size per dimension is drawn.

    Returns:
        A tensor built with ``Tensor.from_dok`` from the drawn coordinates and values.
    """
    if dimensions is None:
        # Sizes must be non-negative, otherwise the coordinate strategy below would be asked
        # for integers between 0 and a negative number. They are also capped to keep the
        # drawn tensors small.
        max_dimension_size = 64
        sizes = st.integers(min_value=0, max_value=max_dimension_size)
        dimensions = draw(st.tuples(*[sizes for _ in range(format.order)]))

    if any(dim == 0 for dim in dimensions):
        # Hypothesis does not like being forced to draw empty dictionaries by giving it an
        # empty set of possible keys
        dok = {}
    else:
        # Keys are coordinates within the dimensions; values are the elements stored there
        coordinates = st.tuples(
            *[st.integers(min_value=0, max_value=dim - 1) for dim in dimensions]
        )
        values = st.floats(allow_infinity=False, allow_nan=False)
        dok = draw(st.dictionaries(coordinates, values))

    return Tensor.from_dok(dok, format=format, dimensions=dimensions)
89 | participant_sizes: dict[tuple[str, int], int] = {} 90 | index_sizes: dict[str, int] = {} 91 | for index, participants in problem.assignment.index_participants().items(): 92 | for participant in participants: 93 | if participant in participant_sizes: 94 | index_sizes[index] = participant_sizes[participant] 95 | elif index in index_sizes: 96 | participant_sizes[participant] = index_sizes[index] 97 | else: 98 | # This number can conspire with the number of dimensions to 99 | # require a LOT of memory. 100 | size = draw(st.integers(min_value=0, max_value=64)) 101 | participant_sizes[participant] = size 102 | index_sizes[index] = size 103 | 104 | input_tensors = {} 105 | for name, variable in problem.assignment.expression.variables().items(): 106 | format = problem.formats[name] 107 | dimensions = tuple(index_sizes[index] for index in variable[0].indexes) 108 | input_tensors[name] = draw(tensors(format, dimensions)) 109 | 110 | return (problem, input_tensors) 111 | -------------------------------------------------------------------------------- /docs/tensors.md: -------------------------------------------------------------------------------- 1 | --- 2 | icon: material/cube-outline 3 | --- 4 | 5 | # Tensors 6 | 7 | The main type in Tensora is the `Tensor` class. `Tensor`s are immutable. New tensors may be constructed from operations on other `Tensor`s, but no property of a `Tensor` may change once it is constructed. This is different from NumPy arrays and Scipy matrices, which may be mutated in-place. 8 | 9 | ## Attributes 10 | 11 | The order, dimensions, and format are the fundamental structural properties of a tensor. These are available as attributes of a `Tensor`. 12 | 13 | ### `tensor.order` 14 | 15 | The order of a tensor is the number of dimensions it has. A scalar is a 0-order tensor, a vector is a 1-order tensor, a matrix is a 2-order tensor, and so on. 
Conceptually, the order may be any non-negative integer, but realistically, a large enough number of dimensions will cause a stack overflow or other resource error. 16 | 17 | ```python 18 | from tensora import Tensor 19 | 20 | tensor = Tensor.from_lol([[1,2,3], [4,5,6]]) 21 | assert tensor.order == 2 22 | ``` 23 | 24 | ### `tensor.dimensions` 25 | 26 | Each element of the `dimensions` tuple is the size of the corresponding dimension. 27 | 28 | ```python 29 | from tensora import Tensor 30 | 31 | tensor = Tensor.from_lol([[1,2,3], [4,5,6]]) 32 | assert tensor.dimensions == (2, 3) 33 | ``` 34 | 35 | ### `tensor.format` 36 | 37 | The type of `format` is a `tensora.Format` object, which has `modes` and `ordering` attributes. The `format.deparse()` method will give you a human-readable string. 38 | 39 | ```python 40 | from tensora import Tensor 41 | 42 | tensor = Tensor.from_lol([[1,2,3], [4,5,6]]) 43 | assert tensor.format.deparse() == 'dd' 44 | ``` 45 | 46 | ## Arithmetic 47 | 48 | The normal way to perform mathematical operations on `Tensor`s is to use the `evaluate` function. However, the `Tensor` class implements several of the standard arithmetic operations available in Python. Tensora makes some guesses on the format of the result. If more control is needed use `evaluate`. 49 | 50 | ### `tensor1 + tensor2` and `tensor1 - tensor2` 51 | 52 | Addition and subtraction are element-wise operations. If both operands are `Tensor`s, they must have the same order and dimensions. The result will be a `Tensor` where each dimension is dense if either operand is dense at that dimension. If one of the operands is a Python scalar, it will be broadcast to the dimensions of the other operand. The result will be a `Tensor` with the same order, dimensions, and format as the other operand. 
53 | 54 | ```python 55 | from tensora import Tensor 56 | 57 | tensor1 = Tensor.from_lol([[1,2,3], [4,5,6]]) 58 | tensor2 = Tensor.from_lol([[7,8,9], [10,11,12]]) 59 | 60 | assert tensor1 + tensor2 == Tensor.from_lol([[8,10,12], [14,16,18]]) 61 | assert tensor1 - tensor2 == Tensor.from_lol([[-6,-6,-6], [-6,-6,-6]]) 62 | ``` 63 | 64 | ### `tensor1 * tensor2` 65 | 66 | Multiplication is an element-wise operation. If both operands are `Tensor`s, they must have the same order and dimensions. The result will be a `Tensor` where each dimension is sparse if either operand is sparse at that dimension. If one of the operands is a Python scalar, it will be broadcast to the dimensions of the other operand. The result will be a `Tensor` with the same order, dimensions, and format as the other operand. 67 | 68 | ```python 69 | from tensora import Tensor 70 | 71 | tensor1 = Tensor.from_lol([[1,2,3], [4,5,6]]) 72 | tensor2 = Tensor.from_lol([[7,8,9], [10,11,12]]) 73 | 74 | assert tensor1 * tensor2 == Tensor.from_lol([[7,16,27], [40,55,72]]) 75 | ``` 76 | 77 | ### `tensor1 @ tensor2` 78 | 79 | Matrix multiplication is only permitted between vectors (order-1 tensors) and matrices (order-2 tensors). The dimensions of the operands must be compatible like normal and as in the table below. The result is a `Tensor` with the expected dimensions. The format of the result is determined by the format of the operand dimensions that give the result dimension its size. 
@singledispatch
def desugar_expression(
    self: sugar.Expression, contract_indexes: set[str], ids: Iterator[int]
) -> desugar.Expression:
    """Fallback of the single-dispatch function; reached when no handler is registered.

    Raises:
        NotImplementedError: Always, naming the type and value that had no handler.
    """
    message = f"desugar_expression not implemented for {type(self)}: {self}"
    raise NotImplementedError(message)
@desugar_expression.register(sugar.Subtract)
def desugar_subtract(
    self: sugar.Subtract, contract_indexes: set[str], ids: Iterator[int]
) -> desugar.Expression:
    """Desugar ``left - right`` into ``left + (-1 * right)``.

    Contracted indexes that appear in only one operand are passed down to that operand;
    indexes that appear in both are contracted around the resulting sum.
    """

    def contracted_within(operand):
        # The contracted indexes that actually appear in this operand
        return set(operand.index_participants().keys()).intersection(contract_indexes)

    left_indexes = contracted_within(self.left)
    right_indexes = contracted_within(self.right)
    shared_indexes = left_indexes.intersection(right_indexes)

    # The left operand is desugared first, so it draws its ids before the right operand
    minuend = desugar_expression(self.left, left_indexes - shared_indexes, ids)
    subtrahend = desugar_expression(self.right, right_indexes - shared_indexes, ids)

    difference = desugar.Add(minuend, desugar.Multiply(desugar.Integer(-1), subtrahend))

    # NOTE(review): shared_indexes is a set, so the nesting order of these Contract nodes
    # follows set iteration order
    for index in shared_indexes:
        difference = desugar.Contract(index, difference)

    return difference
def desugar_assignment(assignment: sugar.Assignment) -> desugar.Assignment:
    """Convert a parsed assignment into its desugared form.

    Tensor ids are drawn from a single counter, with the target taking the first one. Every
    index that takes part in the assignment but is not an index of the target is treated
    as contracted when desugaring the right-hand side.
    """
    target = assignment.target
    id_counter = count()

    # Must be created before the right-hand side so that the target gets the first id
    target_tensor = desugar.Tensor(next(id_counter), target.name, target.indexes)

    participating_indexes = set(assignment.index_participants().keys())
    contracted_indexes = participating_indexes - set(target.indexes)

    return desugar.Assignment(
        target_tensor,
        desugar_expression(assignment.expression, contracted_indexes, id_counter),
    )
class SourceBuilder:
    """Incrementally assembles statements into a ``Block``.

    Statements are appended to the innermost open scope. The ``block``, ``branch``,
    ``loop`` and ``function_definition`` context managers open a nested scope that is
    finalized and appended to its parent when the ``with`` body completes.
    """

    def __init__(self, comment: str | None = None):
        # The bottom of the stack is the root block; nested scopes are pushed on top of it
        self._stack: list[Builder] = [BlockBuilder(comment)]
        self._dependencies: StableSet[str] = StableSet()

    def append(self, statement: Statement | SourceBuilder):
        """Append a statement, or the content of another builder, to the innermost scope."""
        if isinstance(statement, Statement):
            self._stack[-1].lines.append(statement)
            return

        if not isinstance(statement, SourceBuilder):
            return

        # Appending another builder also takes over its dependencies
        for dependency in statement._dependencies:
            self.add_dependency(dependency)

        finalized = statement.finalize()
        if finalized.comment is None:
            # An uncommented block is flattened here because peephole only removes empty
            # blocks and does not inline blocks without comments, which would otherwise
            # still get newlines between them. As a result, appending a SourceBuilder
            # appends its lines, while appending a Block appends the block itself.
            self._stack[-1].lines.extend(finalized.statements)
        else:
            self._stack[-1].lines.append(finalized)

    def add_dependency(self, name: str):
        """Record ``name`` as a dependency of the source being built."""
        self._dependencies.add(name)

    def _open_scope(self, builder: Builder):
        # Statements appended from now on go into this builder
        self._stack.append(builder)

    def _close_scope(self):
        # Finalize the innermost scope and append the result to its parent
        finished = self._stack.pop()
        self.append(finished.finalize())

    @contextmanager
    def block(self, comment: str | None = None):
        """Open a nested block with an optional comment."""
        self._open_scope(BlockBuilder(comment))
        yield None
        self._close_scope()

    @contextmanager
    def branch(self, condition: Expression):
        """Open a branch whose body is executed when ``condition`` holds."""
        self._open_scope(BranchBuilder(condition))
        yield None
        self._close_scope()

    @contextmanager
    def loop(self, condition: Expression):
        """Open a loop governed by ``condition``."""
        self._open_scope(LoopBuilder(condition))
        yield None
        self._close_scope()

    @contextmanager
    def function_definition(self, name: str, parameters: Mapping[str, Type], return_type: Type):
        """Open the body of a function with the given name, parameters and return type."""
        self._open_scope(FunctionDefinitionBuilder(name, parameters, return_type))
        yield None
        self._close_scope()

    def finalize(self) -> Block:
        """Return the root block; every nested scope must have been closed."""
        assert len(self._stack) == 1
        root = self._stack[0]
        return root.finalize()
@dataclass(frozen=True, slots=True)
class Problem:
    """An assignment together with the format of every tensor it references.

    Construction validates the pair: each variable in the assignment must have a
    format, and the number of indexes used with that variable must equal the order
    of its format.

    Raises:
        UndefinedReferenceError: A variable in the assignment has no entry in `formats`.
        IncorrectDimensionsError: A variable is referenced with a number of indexes
            that differs from the order of its format.
    """

    assignment: Assignment
    formats: dict[str, Format]

    def __post_init__(self):
        # This intentionally allows for names in formats that are not referenced in the assignment.
        # The CLI and porcelain API will not allow this, but this is just as valid as defining a
        # function with unused parameters.

        for name, order in self.assignment.variable_orders().items():
            if name not in self.formats:
                raise UndefinedReferenceError(name, self.assignment, list(self.formats.keys()))

            format_order = self.formats[name].order
            if order != format_order:
                # `actual` is the order the variable has in the assignment and `expected`
                # is the order of its format, matching the wording of the error message
                raise IncorrectDimensionsError(
                    name=name,
                    actual=order,
                    expected=format_order,
                    assignment=self.assignment,
                )

    def __eq__(self, other: object):
        if isinstance(other, Problem):
            # Comparing the items as tuples makes the order of the formats significant:
            # problems with the same formats listed in a different order are not equal
            return self.assignment == other.assignment and tuple(self.formats.items()) == tuple(
                other.formats.items()
            )
        else:
            return NotImplemented

    def __hash__(self) -> int:
        # Hash the same data that `__eq__` compares
        return hash((self.assignment, tuple(self.formats.items())))
99 | """ 100 | 101 | tensor_orders = assignment.variable_orders() 102 | 103 | for name in formats.keys(): 104 | if name not in tensor_orders: 105 | return Failure(UnusedFormatError(name, assignment)) 106 | 107 | new_formats = {} 108 | for name, order in tensor_orders.items(): 109 | if name not in formats: 110 | new_formats[name] = Format(tuple([Mode.dense] * order), tuple(range(order))) 111 | else: 112 | new_formats[name] = formats[name] 113 | 114 | try: 115 | problem = Problem(assignment, new_formats) 116 | except (UndefinedReferenceError, IncorrectDimensionsError) as error: 117 | return Failure(error) 118 | 119 | return Success(problem) 120 | -------------------------------------------------------------------------------- /docs/evaluate.md: -------------------------------------------------------------------------------- 1 | --- 2 | icon: material/function-variant 3 | --- 4 | 5 | # Evaluate 6 | 7 | ``` 8 | evaluate( 9 | assignment: str, 10 | output_format: Format | str, 11 | *, 12 | **inputs: Tensor, 13 | ) -> Tensor 14 | ``` 15 | 16 | The main entry point for mathematical operations in Tensora is the `evaluate` function. It takes a tensor algebra assignment and a list of `Tensor` objects. It returns a new `Tensor`, having evaluated the expression according to the input `Tensor`s. 17 | 18 | * `assignment` is parsable as an algebraic tensor assignment. 19 | 20 | * `output_format` is the desired format of the output tensor. 21 | 22 | * `inputs` is all the inputs to the expression. There must be one named argument for each variable name in `assignment`. The dimensions of the tensors in `inputs` must be consistent with `assignment` and with each other. 23 | 24 | There is also `evaluate_tensora` that exposes the internal compiler explicitly. `evaluate` is an alias for the default, which is currently `evaluate_tensora`. 
25 | 26 | ```python 27 | from tensora import Tensor, evaluate 28 | 29 | A = Tensor.from_lol([[1,2,3], [4,5,6]]) 30 | x = Tensor.from_lol([1,2,3]) 31 | 32 | y = evaluate('y(i) = A(i,j) * x(j)', 'd', A=A, x=x) 33 | assert y == Tensor.from_lol([14, 32]) 34 | ``` 35 | 36 | ## Assignments 37 | 38 | In a loose sense, the assignment strings use Einstein notation. The assignments are made of tensor names, index names, and operations. A tensor with its indexes is the target of the assignment on the left-hand side. Various tensors with their indexes are connected by elementary operations on the right-hand side. 39 | 40 | ### Output indexes 41 | 42 | Indexes that appear on both sides match an output dimension to the input dimensions sharing that index. 43 | 44 | ```python 45 | from tensora import Tensor, evaluate 46 | 47 | a = Tensor.from_lol([1,2,3]) 48 | b = Tensor.from_lol([4,5,6]) 49 | 50 | c = evaluate('c(i) = a(i) * b(i)', 'd', a=a, b=b) 51 | assert c == Tensor.from_lol([4, 10, 18]) 52 | ``` 53 | 54 | ### Contraction indexes 55 | 56 | Indexes that appear only on the right-hand side are summed over, an operation also known as a contraction. 57 | 58 | ```python 59 | from tensora import Tensor, evaluate 60 | 61 | A = Tensor.from_lol([[1,2,3], [4,5,6]]) 62 | 63 | a = evaluate('a(i) = A(i,j)', 'd', A=A) 64 | assert a == Tensor.from_lol([6, 15]) 65 | 66 | b = evaluate('b(j) = A(i,j)', 'd', A=A) 67 | assert b == Tensor.from_lol([5, 7, 9]) 68 | ``` 69 | 70 | This commonly appears in the context of multiplication, in which case it is called an inner product. 71 | 72 | ```python 73 | from tensora import Tensor, evaluate 74 | 75 | a = Tensor.from_lol([1,2,3]) 76 | b = Tensor.from_lol([4,5,6]) 77 | 78 | c = evaluate('c() = a(i) * b(i)', '', a=a, b=b) 79 | assert c == 32 80 | ``` 81 | 82 | ### Broadcasting indexes 83 | 84 | Indexes that appear only on the left-hand side would be interpreted as broadcasting the value of the right-hand side along that dimension.
85 | 86 | This operation is not currently allowed by `evaluate` because that indicates that the expression should be broadcast along that target dimension, but there is currently no way to specify the size of that dimension. It is allowed by the `tensora` CLI, however. 87 | 88 | ```python 89 | from tensora import Tensor, evaluate 90 | 91 | a = Tensor.from_lol(1) 92 | 93 | b = evaluate('b(i) = a()', 'd', a=a) 94 | # BroadcastTargetIndexError: Expected index variable i on the target variable 95 | # to be mentioned on the right-hand side, but it was not: b(i) = a(). Such 96 | # broadcasting makes sense in a kernel and those kernels can be generated, but 97 | # they cannot be used in `evaluate` or `tensor_method` because those functions 98 | # get the output dimensions from the the dimensions of the input tensors. 99 | ``` 100 | 101 | ### Reusing tensors 102 | 103 | Tensor names may be repeated, possibly with different indexes. The tensor can and should only be provided once; it will be used for all occurrences of that tensor name in the assignment. 104 | 105 | ```python 106 | from tensora import Tensor, evaluate 107 | 108 | x = Tensor.from_lol([1,2,3]) 109 | V = Tensor.from_lol([[1,2,3], [4,5,6], [7,8,9]]) 110 | 111 | y = evaluate('y() = x(i) * V(i,j) * x(j)', '', x=x, V=V) 112 | assert y == 228 113 | ``` 114 | 115 | ### Diagonal indexes 116 | 117 | Indexes may *not* be repeated within a tensor. Such syntax would represent a diagonal operation, which is currently not supported. 118 | 119 | ```python 120 | from tensora import Tensor, evaluate 121 | 122 | V = Tensor.from_lol([[1,2,3], [4,5,4], [3,2,1]]) 123 | 124 | v = evaluate('v(i) = V(i,i)', 'd', V=V) 125 | # DiagonalAccessError: Diagonal access to a tensor (i.e. 
@dataclass(frozen=True)
class TerminalNode(IterationGraph):
    """Leaf of an iteration graph: wraps a single expression and has no next node."""

    expression: Expression

    def extract_context(self, index: str) -> Context:
        """Return the context of `index` extracted from the wrapped expression."""
        # Inside the method body this name resolves to the module-level function
        return extract_context(self.expression, index)

    def exhaust_tensor(self, reference: str) -> IterationGraph:
        """Return a new leaf whose expression has had `reference` exhausted."""
        exhausted_expression = exhaust_tensor(self.expression, reference)
        return TerminalNode(exhausted_expression)

    def is_sparse_output(self) -> bool:
        """A leaf never reports a sparse output."""
        return False

    def compressed_dimensions(self) -> StableFrozenSet[str]:
        """Return an empty set; needed when empty subgraphs simplify."""
        return StableFrozenSet()

    def later_indexes(self) -> frozenset[str]:
        """Return an empty set: a leaf contributes no indexes."""
        return frozenset()

    def has_output(self) -> bool:
        """A leaf never reports an output."""
        return False
@dataclass(frozen=True)
class SumNode(IterationGraph):
    """Iteration graph node that combines several term subgraphs under one name."""

    name: str
    terms: list[IterationGraph]

    def extract_context(self, index: str) -> Context:
        """Return the contexts of all terms for `index`, folded together with `Context.add`."""
        context = Context(is_sparse=True)
        for term in self.terms:
            context = context.add(term.extract_context(index))
        return context

    def exhaust_tensor(self, reference: str) -> IterationGraph:
        """Exhaust `reference` in every term and collapse trivial sums.

        Returns a zero-valued terminal when no terms remain, the sole term when
        exactly one remains, and otherwise a copy of this node with the new terms.
        """
        # TODO: Simplify empty terms
        new_terms = [term.exhaust_tensor(reference) for term in self.terms]

        if not new_terms:
            return TerminalNode(Integer(0))
        elif len(new_terms) == 1:
            return new_terms[0]
        else:
            return replace(self, terms=new_terms)

    def is_sparse_output(self) -> bool:
        """A sum node never reports a sparse output."""
        return False

    def compressed_dimensions(self) -> StableFrozenSet[str]:
        """Return an empty set; needed when empty subgraphs simplify."""
        return StableFrozenSet()

    def later_indexes(self) -> frozenset[str]:
        """Return the union of the later indexes of every term.

        Starting from an empty `frozenset` keeps this well defined for a node with
        no terms, a case `exhaust_tensor` already accounts for; calling the unbound
        `frozenset.union` with no arguments would raise `TypeError`.
        """
        return frozenset().union(*(term.later_indexes() for term in self.terms))

    def has_output(self) -> bool:
        """Return whether any term has an output."""
        return any(term.has_output() for term in self.terms)
@pytest.mark.parametrize("dense", [[3, 2, 4], [0, 0, 0]])
@pytest.mark.parametrize("format_in", ["s", "d"])
@pytest.mark.parametrize("format_out", ["s", "d"])
def test_copy_1(dense, format_in, format_out):
    """Copying a vector yields a tensor equal to the input for every format pair."""
    source = Tensor.from_lol(dense, format=format_in)
    copied = evaluate("b(i) = a(i)", format_out, a=source)
    assert copied == source
@pytest.mark.parametrize("dense1", [[0, 2, 4, 0], [0, 0, 0, 0]])
@pytest.mark.parametrize("dense2", [[-1, 3.5, 0, 0], [0, 0, 0, 0]])
@pytest.mark.parametrize("format1", ["s", "d"])
@pytest.mark.parametrize("format2", ["s", "d"])
@pytest.mark.parametrize("format_out", ["s", "d"])
@pytest.mark.parametrize("operator", ["+", "-", "*"])
def test_vector_binary(operator, dense1, dense2, format1, format2, format_out):
    """Elementwise vector operations agree with the all-dense evaluation."""
    expression = f"out(i) = in1(i) {operator} in2(i)"
    operands = {"in1": (dense1, format1), "in2": (dense2, format2)}
    assert_same_as_dense(expression, format_out, **operands)
@pytest.mark.parametrize(
    "dense_b",
    [
        [
            [[0, 2, 4, 0], [0, -1, 0, 3], [1, -1, 0, 0]],
            [[-2, 4, 0, 0], [0, 0, 0, 3], [1, 1, 0, 0]],
        ],
        [
            [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
            [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
        ],
    ],
)
@pytest.mark.parametrize(
    "dense_d",
    [
        [[-1, 3.5, 1, 2, 0], [0, 2, 6, 3, 0], [4, 0, 0, 1, -1], [0, 0, 3, 6, 9]],
        [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
    ],
)
@pytest.mark.parametrize(
    "dense_c",
    [
        [[0, 0, 1, 2, 7], [7, 0, 5, 2, 0], [-1, 0, 0, 2, 1]],
        [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
    ],
)
@pytest.mark.parametrize("format_b", ["ddd", "dss", "sss", "ssd", "d1d0s2", "s0d2d1"])
@pytest.mark.parametrize("format_d", ["dd", "ds", "ss"])
@pytest.mark.parametrize("format_c", ["dd", "ds", "ss"])
@pytest.mark.parametrize("format_out", ["dd", "d1d0"])
def test_mttkrp(dense_b, dense_d, dense_c, format_b, format_d, format_c, format_out):
    """The MTTKRP kernel agrees with the all-dense evaluation for every format mix."""
    operands = {
        "B": (dense_b, format_b),
        "D": (dense_d, format_d),
        "C": (dense_c, format_c),
    }
    assert_same_as_dense("A(i,j) = B(i,k,l) * D(l,j) * C(k,j)", format_out, **operands)
def assert_same_as_dense(expression, format_out, **tensor_pairs):
    """Assert that evaluating with the given formats matches the all-dense evaluation.

    Each keyword argument maps a tensor name to a `(data, format)` pair. The
    expression is evaluated once with every tensor in its stated format and the
    requested output format, and once with every tensor and the output built with
    default formats; the two results must compare equal.
    """
    formatted_inputs = {}
    for name, (data, tensor_format) in tensor_pairs.items():
        formatted_inputs[name] = Tensor.from_lol(data, format=tensor_format)

    dense_inputs = {}
    for name, (data, _) in tensor_pairs.items():
        dense_inputs[name] = Tensor.from_lol(data)

    # One dense mode for every "d" or "s" character in the requested output format
    output_order = format_out.count("d") + format_out.count("s")
    dense_format = "d" * output_order

    actual = evaluate(expression, format_out, **formatted_inputs)
    expected = evaluate(expression, dense_format, **dense_inputs)
    assert actual == expected
@pytest.mark.parametrize("dense1", [[0, 2, 4, 0], [0, 0, 0, 0]])
@pytest.mark.parametrize("dense2", [[-1, 3.5, 0, 0], [0, 0, 0, 0]])
@pytest.mark.parametrize("format1", ["s", "d"])
@pytest.mark.parametrize("format2", ["s", "d"])
def test_vector_dot(dense1, dense2, format1, format2):
    """The vector inner product agrees with the all-dense evaluation."""
    operands = {"in1": (dense1, format1), "in2": (dense2, format2)}
    assert_same_as_dense("out() = in1(i) * in2(i)", "", **operands)
@pytest.mark.parametrize("dense1", [[[0, 2, 4], [0, -1, 0]], [[0, 0, 0], [0, 0, 0]]])
@pytest.mark.parametrize("dense2", [[[-1, 3.5], [0, 0], [4, 0]], [[0, 0], [0, 0], [0, 0]]])
@pytest.mark.parametrize("dense3", [[[-3, 0], [7, 0]], [[0, 0], [0, 0]]])
@pytest.mark.parametrize("format1", ["dd", "ds"])
@pytest.mark.parametrize("format2", ["dd"])
@pytest.mark.parametrize("format3", ["dd", "ds"])
@pytest.mark.parametrize("format_out", ["dd"])
def test_matrix_multiply_add(dense1, dense2, dense3, format1, format2, format3, format_out):
    """A matrix product plus a matrix agrees with the all-dense evaluation."""
    operands = {
        "in1": (dense1, format1),
        "in2": (dense2, format2),
        "in3": (dense3, format3),
    }
    assert_same_as_dense("out(i,k) = in1(i,j) * in2(j,k) + in3(i,k)", format_out, **operands)
@dataclass(frozen=True, slots=True)
class AppendOutput(Output):
    """Output writer for the tensor `output`, positioned at layer `next_layer`.

    `next_output` advances the writer one layer at a time as matching output layers
    are encountered, and switches to a `BucketOutput` when the expected layer is not
    encountered and only dense modes remain.
    """

    # The output tensor being written
    output: ie_ast.Tensor
    # Index of the layer of `output` that must be encountered next
    next_layer: int

    def vals_pointer(self) -> Expression:
        """Return the expression that `write_assignment` uses to index the vals array."""
        return previous_layer_pointer(self.output.id, self.output.order)

    def write_declarations(self, kernel_type: KernelType) -> SourceBuilder:
        """Build the "Output initialization" source for this output.

        For every compressed mode a layer pointer is declared and set to 0. When
        `kernel_type.is_assemble()` holds, the pos and crd arrays of every compressed
        mode and the vals array are also allocated, each with a declared capacity.

        Raises:
            NotImplementedError: A mode is neither dense nor compressed.
        """
        source = SourceBuilder("Output initialization")

        target_name = self.output.name
        output_tensor = Variable(target_name)

        # Tracks whether every mode seen so far is dense; this decides if sizes are known
        all_dense = True
        for i, mode in enumerate(self.output.modes):
            if mode == Mode.dense:
                pass
            elif mode == Mode.compressed:
                if kernel_type.is_assemble():
                    # How pos is handled depends on what the previous modes were
                    if all_dense:
                        # If the previous dimensions were all dense, then the size of pos in this dimension is fixed
                        pos_size = Multiply.join(
                            [dimension_name(self.output.indexes[i_prev]) for i_prev in range(i)]
                        ).plus(1)
                    else:
                        # Otherwise, the value will change, so provide a good default
                        pos_size = default_array_size
                    pos_capacity = pos_capacity_name(target_name, i)
                    pos_array = pos_name(target_name, i)
                    source.append(pos_capacity.declare(types.integer).assign(pos_size))
                    source.append(pos_array.assign(ArrayAllocate(types.integer, pos_capacity)))
                    # The first entry of pos always starts at 0
                    source.append(pos_array.idx(0).assign(0))

                    # crd is always the same
                    crd_capacity = crd_capacity_name(target_name, i)
                    source.append(crd_capacity.declare(types.integer).assign(default_array_size))
                    source.append(
                        crd_name(target_name, i).assign(ArrayAllocate(types.integer, crd_capacity))
                    )

                # This is the only thing that gets written when doing compute
                source.append(layer_pointer(self.output.id, i).declare(types.integer).assign(0))

                all_dense = False
            else:
                raise NotImplementedError()

        if kernel_type.is_assemble():
            if all_dense:
                # With only dense modes, vals has exactly one entry per element
                vals_size = Multiply.join(
                    [output_tensor.attr("dimensions").idx(i) for i in range(self.output.order)]
                )
            else:
                vals_size = default_array_size
            vals_capacity = vals_capacity_name(target_name)
            source.append(vals_capacity.declare(types.integer).assign(vals_size))
            source.append(vals_name(target_name).assign(ArrayAllocate(types.float, vals_capacity)))

        return source

    def write_assignment(
        self, right_hand_side: Expression, kernel_type: KernelType
    ) -> SourceBuilder:
        """Build the source that stores `right_hand_side` at `vals_pointer()` in vals.

        `kernel_type` is not used by this implementation.

        Raises:
            RuntimeError: Not every layer of the output has been consumed, that is,
                `next_layer` differs from the order of the output.
        """
        source = SourceBuilder()

        if self.next_layer != self.output.order:
            raise RuntimeError()

        source.append(vals_name(self.output.name).idx(self.vals_pointer()).assign(right_hand_side))

        return source

    def write_cleanup(self, kernel_type: KernelType) -> SourceBuilder:
        """Build the source that stores the assembled arrays on the output tensor.

        When `kernel_type.is_assemble()` holds, the pos and crd arrays of every
        compressed mode are assigned to entries 0 and 1 of that mode's `indices`,
        and the vals array is assigned to `vals`. Otherwise the returned builder
        contains no statements.

        Raises:
            NotImplementedError: A mode is neither dense nor compressed.
        """
        source = SourceBuilder(f"Assembling output tensor {self.output.name}")

        if kernel_type.is_assemble():
            target_name = self.output.name
            output_tensor = Variable(target_name)

            for i, mode in enumerate(self.output.modes):
                if mode == Mode.dense:
                    pass
                elif mode == Mode.compressed:
                    source.append(
                        output_tensor.attr("indices")
                        .idx(i)
                        .idx(0)
                        .assign(pos_name(target_name, i))
                    )
                    source.append(
                        output_tensor.attr("indices")
                        .idx(i)
                        .idx(1)
                        .assign(crd_name(target_name, i))
                    )
                else:
                    raise NotImplementedError()
            source.append(output_tensor.attr("vals").assign(vals_name(target_name)))

        return source

    def next_output(
        self, iteration_output: TensorLayer | None, kernel_type: KernelType
    ) -> tuple[Output, SourceBuilder, SourceBuilder]:
        """Return the output to use below the current iteration layer.

        The result is the next output followed by two source builders; the second
        builder is always empty here.
        NOTE(review): the two builders presumably hold code emitted before and after
        the layer's loop — confirm against the `Output` base class.

        If `iteration_output` is the expected layer, the next output is this output
        advanced by one layer and both builders are empty. Otherwise, if all remaining
        modes are dense, the next output is a `BucketOutput` over the remaining
        layers, and the first builder holds its declarations.

        Raises:
            NotImplementedError: The expected layer was not encountered and a
                compressed mode remains.
        """
        if iteration_output is not None and self.next_layer == iteration_output.layer:
            return AppendOutput(self.output, self.next_layer + 1), SourceBuilder(), SourceBuilder()
        else:
            # No layer or wrong layer was encountered
            dense_only_remaining = all(
                mode == Mode.dense for mode in self.output.modes[self.next_layer :]
            )
            if dense_only_remaining:
                next_output = BucketOutput(
                    self.output, list(range(self.next_layer, len(self.output.modes)))
                )
                dimension_names = [
                    dimension_name(index) for index in self.output.indexes[self.next_layer :]
                ]
                # The bucket starts at vals offset by the previous layer pointer times
                # the product of the remaining dimensions
                bucket = vals_name(self.output.name).plus(
                    previous_layer_pointer(self.output.id, self.next_layer).times(
                        Multiply.join(dimension_names)
                    )
                )
                return next_output, next_output.write_declarations(bucket), SourceBuilder()
            else:
                raise NotImplementedError(
                    "Encountered a sparse output layer preceded by a contraction layer or a later "
                    "output layer. This requires a hash table to store intermediate outputs, "
                    "which is not currently implemented."
                )
Providing the dimensions is typically only useful when one or more non-final dimensions may have size zero. For example, `Tensor.from_lol([[], []])` has dimensions of `(2,0)`, while `Tensor.from_lol([[], []], dimensions=(2,0,3))` has dimensions of `(2,0,3)`. 25 | 26 | * `format` has a default value of all dense dimensions. 27 | 28 | ```python 29 | from tensora import Tensor 30 | 31 | tensor = Tensor.from_lol([[1,2,3], [4,5,6]]) 32 | 33 | assert tensor.dimensions == (2, 3) 34 | ``` 35 | 36 | This is also the best way to create a scalar `Tensor` because passing a single number to this method means the list nesting is 0 levels deep and is therefore a 0-order tensor. 37 | 38 | ```python 39 | from tensora import Tensor 40 | 41 | tensor = Tensor.from_lol(2.5) 42 | 43 | assert tensor.dimensions == () 44 | ``` 45 | 46 | ## `Tensor.from_dok`: dictionary of keys 47 | 48 | ``` 49 | Tensor.from_dok( 50 | dok: dict[tuple[int, ...], float], 51 | *, 52 | dimensions: tuple[int, ...] = None, 53 | format: Format | str = None, 54 | ) 55 | ``` 56 | 57 | Convert a dictionary of keys to a `Tensor`. 58 | 59 | * `dok` is a Python dictionary where each key is the coordinate of one nonzero value and the value of the entry is the value of the tensor at that coordinate. All coordinates not mentioned are implicitly zero. 60 | 61 | * `dimensions` has a default value that is the largest size in each dimension found among the coordinates. 62 | 63 | * `format` has a default value of dense dimensions as long as the number of nonzeros is larger than the product of those dimensions and then sparse dimensions after that. The default value is subject to change with experience. 
64 | 65 | ```python 66 | from tensora import Tensor 67 | 68 | tensor = Tensor.from_dok({ 69 | (1,0): 2.0, 70 | (0,1): -2.0, 71 | (1,2): 4.0, 72 | }, dimensions=(2,3), format='ds') 73 | 74 | assert tensor == Tensor.from_lol([[0,-2,0], [2,0,4]]) 75 | ``` 76 | 77 | ## `Tensor.from_aos`: array of structs 78 | 79 | ``` 80 | Tensor.from_aos( 81 | aos: Iterable[tuple[int, ...]], 82 | values: Iterable[float], 83 | *, 84 | dimensions: tuple[int, ...] = None, 85 | format: Format | str = None, 86 | ) 87 | ``` 88 | 89 | Convert a list of coordinates and a corresponding list of values to a `Tensor`. 90 | 91 | * `aos` is an iterable of the coordinates of the nonzero values. 92 | 93 | * `values` must be the same length as `aos` and each value is the value at the corresponding coordinate. 94 | 95 | * `dimensions` has the same default as `Tensor.from_dok`, the largest size in each dimension. 96 | 97 | * `format` has the same default as `Tensor.from_dok`, dense for as many dimensions as needed to fit the nonzeros. 98 | 99 | ```python 100 | from tensora import Tensor 101 | 102 | tensor = Tensor.from_aos( 103 | [(1,0), (0,1), (1,2)], 104 | [2.0, -2.0, 4.0], 105 | dimensions=(2,3), 106 | format='ds', 107 | ) 108 | 109 | assert tensor == Tensor.from_lol([[0,-2,0], [2,0,4]]) 110 | ``` 111 | 112 | ## `Tensor.from_soa`: struct of arrays 113 | 114 | ``` 115 | Tensor.from_soa( 116 | soa: tuple[Iterable[int], ...], 117 | values: Iterable[float], 118 | *, 119 | dimensions: tuple[int, ...] = None, 120 | format: Format | str = None, 121 | ) 122 | ``` 123 | 124 | Convert lists of indexes for each dimension and a corresponding list of values to a `Tensor`. 125 | 126 | * `soa` is a tuple of iterables, where each iterable is all the indexes of the corresponding dimension. All iterables must be the same length. 127 | 128 | * `values` must be the same length as the iterables in `soa` and each value is the nonzero value at the corresponding coordinate. 
129 | 130 | * `dimensions` has the same default as `Tensor.from_dok`, the largest size in each dimension. 131 | 132 | * `format` has the same default as `Tensor.from_dok`, dense for as many dimensions as needed to fit the nonzeros. 133 | 134 | ```python 135 | from tensora import Tensor 136 | 137 | tensor = Tensor.from_soa( 138 | ([1,0,1], [0,1,2]), 139 | [2.0, -2.0, 4.0], 140 | dimensions=(2,3), 141 | format='ds', 142 | ) 143 | 144 | assert tensor == Tensor.from_lol([[0,-2,0], [2,0,4]]) 145 | ``` 146 | 147 | ## `Tensor.from_numpy`: convert a NumPy array 148 | 149 | ``` 150 | Tensor.from_numpy( 151 | array: numpy.ndarray, 152 | *, 153 | format: Format | str = None, 154 | ) 155 | ``` 156 | 157 | Convert a NumPy array to a `Tensor`. 158 | 159 | * `array` is any `numpy.ndarray`. The resulting `Tensor` will have the same order, dimensions, and values as this array. 160 | 161 | * `format` has a default value of all dense dimensions. 162 | 163 | ```python 164 | import numpy as np 165 | from tensora import Tensor 166 | 167 | array = np.array([[1,2,3], [4,5,6]]) 168 | tensor = Tensor.from_numpy(array) 169 | 170 | assert tensor == Tensor.from_lol([[1,2,3], [4,5,6]]) 171 | ``` 172 | 173 | ## `Tensor.from_scipy_sparse`: convert a SciPy sparse matrix 174 | 175 | ``` 176 | Tensor.from_scipy_sparse( 177 | matrix: scipy.sparse.spmatrix, 178 | *, 179 | format: Format | str = None, 180 | ) 181 | ``` 182 | 183 | Convert a SciPy sparse matrix to a `Tensor`. 184 | 185 | * `matrix` is any `scipy.sparse.spmatrix`. The resulting `Tensor` will have the same order, dimensions, and values as this matrix. The tensor will always have order 2. 186 | 187 | * `format` has a default value of `ds` for `csr_matrix` and `d1s0` for `csc_matrix` and also `ds` for the other sparse matrix types, though that is subject to change as Tensora adds new format mode types. 
188 | 189 | ```python 190 | import scipy.sparse as sp 191 | from tensora import Tensor 192 | 193 | matrix = sp.csr_matrix(([2.0, -2.0, 4.0], ([1,0,1], [0,1,2])), shape=(2,3)) 194 | tensor = Tensor.from_scipy_sparse(matrix) 195 | 196 | assert tensor.format.deparse() == 'ds' 197 | assert tensor == Tensor.from_lol([[0,-2,0], [2,0,4]]) 198 | ``` 199 | -------------------------------------------------------------------------------- /src/tensora/compile/_tensor_method.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "BackendCompiler", 3 | "BroadcastTargetIndexError", 4 | "TensorMethod", 5 | ] 6 | 7 | from dataclasses import dataclass 8 | from enum import Enum 9 | from inspect import Parameter, Signature 10 | 11 | from returns.result import Failure, Success 12 | 13 | from ..expression.ast import Assignment 14 | from ..generate import Language, generate_code, generate_module_tensora 15 | from ..kernel_type import KernelType 16 | from ..problem import Problem 17 | from ..tensor import Tensor 18 | from ._cffi_ownership import allocate_taco_structure, take_ownership_of_arrays, tensor_cdefs 19 | 20 | 21 | @dataclass(frozen=True, slots=True) 22 | class BroadcastTargetIndexError(Exception): 23 | index: str 24 | assignment: Assignment 25 | 26 | def __str__(self): 27 | return ( 28 | f"Expected index variable {self.index} on the target variable to be mentioned on the " 29 | f"right-hand side, but it was not: {self.assignment}. Such broadcasting makes sense " 30 | f"in a kernel and those kernels can be generated, but they cannot be used in " 31 | f"`evaluate` or `tensor_method` because those functions get the output dimensions " 32 | f"from the the dimensions of the input tensors." 33 | ) 34 | 35 | 36 | class BackendCompiler(Enum): 37 | """The tool to generate the machine code. 38 | 39 | Attributes 40 | ---------- 41 | llvm 42 | Generate LLVM IR and compile with the llvmlite package. 
43 | cffi 44 | Generate C code and compile with the cffi package. 45 | Not available on Windows. 46 | """ 47 | 48 | llvm = "llvm" 49 | cffi = "cffi" 50 | 51 | 52 | class TensorMethod: 53 | """A function taking specific tensor arguments.""" 54 | 55 | def __init__( 56 | self, 57 | problem: Problem, 58 | backend: BackendCompiler = BackendCompiler.llvm, 59 | ): 60 | # Reject broadcasting to outputs because there is no way to specify output dimensions that 61 | # do not have a corresponding input dimension 62 | input_indexes = set(problem.assignment.expression.index_participants().keys()) 63 | for output_index in problem.assignment.target.indexes: 64 | if output_index not in input_indexes: 65 | raise BroadcastTargetIndexError(output_index, problem.assignment) 66 | 67 | # Store validated attributes 68 | self._problem = problem 69 | self._output_name = problem.assignment.target.name 70 | self._input_formats = { 71 | name: format for name, format in problem.formats.items() if name != self._output_name 72 | } 73 | self._output_format = problem.formats[self._output_name] 74 | 75 | # Create Python signature of the function 76 | self.signature = Signature( 77 | [ 78 | Parameter(parameter_name, Parameter.KEYWORD_ONLY, annotation=Tensor) 79 | for parameter_name in self._input_formats.keys() 80 | ] 81 | ) 82 | 83 | match backend: 84 | case BackendCompiler.llvm: 85 | match generate_module_tensora(problem, [KernelType.evaluate]): 86 | case Failure(error): 87 | raise error 88 | case Success(tensora_module): 89 | from ._compile_llvm import compile_module 90 | 91 | self._lib = compile_module(tensora_module) 92 | 93 | # Convert ctypes function to cffi function 94 | function_type = ( 95 | f"int32_t (*)({', '.join(['void *'] * len(problem.formats))})" 96 | ) 97 | function_pointer = self._lib.get_function_address("evaluate") 98 | self._evaluate = tensor_cdefs.cast(function_type, function_pointer) 99 | case BackendCompiler.cffi: 100 | match generate_code(problem, [KernelType.evaluate], 
Language.c): 101 | case Failure(error): 102 | raise error 103 | case Success(c_code): 104 | from ._compile_cffi import compile_evaluate 105 | 106 | self._lib = compile_evaluate(c_code) 107 | self._evaluate = self._lib.evaluate 108 | 109 | def __call__(self, *args, **kwargs): 110 | # Handle arguments like normal Python function 111 | bound_arguments = self.signature.bind(*args, **kwargs).arguments 112 | 113 | # Validate tensor arguments 114 | for name, argument, format in zip( 115 | bound_arguments.keys(), 116 | bound_arguments.values(), 117 | self._input_formats.values(), 118 | strict=True, 119 | ): 120 | if not isinstance(argument, Tensor): 121 | raise TypeError(f"Argument {name} must be a Tensor not {type(argument)}") 122 | 123 | if argument.order != format.order: 124 | raise ValueError( 125 | f"Argument {name} must have order {format.order} not {argument.order}" 126 | ) 127 | if tuple(argument.modes) != tuple(format.modes): 128 | raise ValueError( 129 | f"Argument {name} must have modes " 130 | f"{tuple(mode.name for mode in format.modes)} not " 131 | f"{tuple(mode.name for mode in argument.modes)}" 132 | ) 133 | if tuple(argument.mode_ordering) != tuple(format.ordering): 134 | raise ValueError( 135 | f"Argument {name} must have mode ordering " 136 | f"{format.ordering} not {argument.mode_ordering}" 137 | ) 138 | 139 | # Validate dimensions 140 | index_participants = self._problem.assignment.expression.index_participants() 141 | index_sizes = {} 142 | for index, participants in index_participants.items(): 143 | # Extract the size of dimension referenced by this index on each tensor that uses it; record the variable 144 | # name and dimension for a better error 145 | actual_sizes = [ 146 | (variable, dimension, bound_arguments[variable].dimensions[dimension]) 147 | for variable, dimension in participants 148 | ] 149 | 150 | reference_size = actual_sizes[0][2] 151 | index_sizes[index] = reference_size 152 | 153 | for _, _, size in actual_sizes[1:]: 154 | if size 
!= reference_size: 155 | expected = ", ".join( 156 | f"{variable}.dimensions[{dimension}] == {size}" 157 | for variable, dimension, size in actual_sizes 158 | ) 159 | raise ValueError( 160 | f"{self._problem.assignment} expected all these dimensions of these tensors to be the same " 161 | f"because they share the index {index}: {expected}" 162 | ) 163 | 164 | # Determine output dimensions 165 | output_dimensions = tuple( 166 | index_sizes[index] for index in self._problem.assignment.target.indexes 167 | ) 168 | 169 | cffi_output = allocate_taco_structure( 170 | tuple(mode.c_int for mode in self._output_format.modes), 171 | output_dimensions, 172 | self._output_format.ordering, 173 | ) 174 | 175 | output = Tensor(cffi_output) 176 | 177 | all_arguments = {self._output_name: output, **bound_arguments} 178 | 179 | cffi_args = [all_arguments[name].cffi_tensor for name in self._problem.formats.keys()] 180 | 181 | return_value = self._evaluate(*cffi_args) 182 | 183 | take_ownership_of_arrays(cffi_output) 184 | 185 | if return_value != 0: 186 | raise RuntimeError(f"Kernel evaluation failed with error code {return_value}") 187 | 188 | return output 189 | -------------------------------------------------------------------------------- /src/tensora/expression/ast.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | __all__ = [ 4 | "Add", 5 | "Assignment", 6 | "Expression", 7 | "Float", 8 | "Integer", 9 | "Literal", 10 | "Multiply", 11 | "Subtract", 12 | "Tensor", 13 | ] 14 | 15 | from abc import abstractmethod 16 | from dataclasses import dataclass 17 | 18 | 19 | class Expression: 20 | __slots__ = () 21 | 22 | @abstractmethod 23 | def variables(self) -> dict[str, list[Tensor]]: 24 | raise NotImplementedError() 25 | 26 | @abstractmethod 27 | def deparse(self) -> str: 28 | """Convert the assignment back into a string.""" 29 | raise NotImplementedError() 30 | 31 | @abstractmethod 32 | def 
class Literal(Expression):
    """Base class for literal expressions, which reference no tensors and no indexes."""

    __slots__ = ()

    def variables(self) -> dict[str, list[Tensor]]:
        """Return an empty mapping."""
        return dict()

    def index_participants(self) -> dict[str, set[tuple[str, int]]]:
        """Return an empty mapping."""
        return dict()
def merge_index_participants(
    left: Expression, right: Expression
) -> dict[str, set[tuple[str, int]]]:
    """Combine the index participants of two expressions.

    Args:
        left: Any object providing `index_participants()`.
        right: Any object providing `index_participants()`.

    Returns:
        A new mapping from index name to the union of the participants found in `left` and
        `right`. Keys appear in first-seen order: all indexes of `left` in their original order,
        followed by the indexes that only occur in `right`. The mappings returned by the operands
        are not mutated.
    """
    # Build the result by walking the operand dicts rather than a set of their keys; iterating a
    # set of strings depends on hash randomization and so would vary between interpreter runs
    merged = {
        index_name: set(participants)
        for index_name, participants in left.index_participants().items()
    }
    for index_name, participants in right.index_participants().items():
        merged[index_name] = merged.get(index_name, set()) | participants
    return merged
@dataclass(frozen=True, slots=True)
class Multiply(Expression):
    """Product of a left and a right expression."""

    left: Expression
    right: Expression

    def variables(self) -> dict[str, list[Tensor]]:
        """Map each tensor name to its occurrences, those in `left` before those in `right`."""
        merged = self.left.variables().copy()
        for name, occurrences in self.right.variables().items():
            earlier = merged.get(name)
            merged[name] = occurrences if earlier is None else [*earlier, *occurrences]
        return merged

    def deparse(self):
        """Convert the product back into a string, adding the parentheses the AST requires."""
        left_text = self.left.deparse()
        right_text = self.right.deparse()

        # A sum or difference on either side binds more loosely than the product
        if isinstance(self.left, (Add, Subtract)):
            left_text = "(" + left_text + ")"

        # A product on the right is also wrapped so that the original grouping is preserved even
        # though multiplication is associative
        if isinstance(self.right, (Add, Subtract, Multiply)):
            right_text = "(" + right_text + ")"

        return " * ".join((left_text, right_text))

    def index_participants(self) -> dict[str, set[tuple[str, int]]]:
        """Return the merged index participants of both operands."""
        return merge_index_participants(self.left, self.right)
238 | 239 | Returns: 240 | A mapping where each key is the string name of a variable and each value is the number 241 | of dimensions that variable has. 242 | """ 243 | 244 | return self._variable_orders 245 | 246 | def __str__(self) -> str: 247 | return self.deparse() 248 | -------------------------------------------------------------------------------- /src/tensora/ir/ast.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | __all__ = [ 4 | "Add", 5 | "And", 6 | "ArrayAllocate", 7 | "ArrayIndex", 8 | "ArrayReallocate", 9 | "Assignable", 10 | "Assignment", 11 | "AttributeAccess", 12 | "Block", 13 | "BooleanLiteral", 14 | "BooleanToInteger", 15 | "Branch", 16 | "Declaration", 17 | "DeclarationAssignment", 18 | "Equal", 19 | "Expression", 20 | "FloatLiteral", 21 | "FunctionDefinition", 22 | "GreaterThan", 23 | "GreaterThanOrEqual", 24 | "IntegerLiteral", 25 | "LessThan", 26 | "LessThanOrEqual", 27 | "Loop", 28 | "Max", 29 | "Min", 30 | "Module", 31 | "Multiply", 32 | "NotEqual", 33 | "Or", 34 | "Return", 35 | "Statement", 36 | "Subtract", 37 | "Variable", 38 | ] 39 | 40 | from dataclasses import dataclass 41 | from functools import reduce 42 | from typing import Sequence 43 | 44 | from .types import Type 45 | 46 | 47 | def to_expression(expression: Expression | float | int | str) -> Expression: 48 | match expression: 49 | case Expression(): 50 | return expression 51 | case float(): 52 | return FloatLiteral(expression) 53 | case int(): 54 | return IntegerLiteral(expression) 55 | case str(): 56 | return Variable(expression) 57 | 58 | 59 | class Statement: 60 | __slots__ = () 61 | 62 | 63 | class Expression(Statement): 64 | __slots__ = () 65 | 66 | def plus(self, term: Expression | int | str) -> Expression: 67 | term = to_expression(term) 68 | return Add(self, term) 69 | 70 | def minus(self, term: Expression | int | str) -> Expression: 71 | term = to_expression(term) 72 | return 
class Assignable(Expression):
    """Base class for expressions that can be the target of an assignment."""

    __slots__ = ()

    def attr(self, attribute: str) -> Assignable:
        """Return an `AttributeAccess` of `attribute` on this expression."""
        return AttributeAccess(target=self, attribute=attribute)

    def idx(self, index: Expression | int | str) -> Assignable:
        """Return an `ArrayIndex` into this expression at `index`."""
        return ArrayIndex(target=self, index=to_expression(index))

    def assign(self, value: Expression | float | int | str) -> Statement:
        """Return an `Assignment` of `value` to this expression."""
        return Assignment(target=self, value=to_expression(value))

    def increment(self, amount: Expression | float | int | str = 1) -> Statement:
        """Return a statement assigning this expression plus `amount` back to this expression."""
        step = to_expression(amount)
        incremented = self.plus(step)
        return self.assign(incremented)
@dataclass(frozen=True, slots=True)
class Add(Expression):
    """Sum of a left and a right expression."""

    left: Expression
    right: Expression

    @staticmethod
    def join(operands: Sequence[Expression | int | str]) -> Expression:
        """Fold `operands` from the left into nested `Add` nodes, starting from literal 0."""
        total: Expression = IntegerLiteral(0)
        for operand in operands:
            total = Add(total, to_expression(operand))
        return total
@dataclass(frozen=True, slots=True)
class Max(Expression):
    """Maximum of a left and a right expression."""

    left: Expression
    right: Expression

    @staticmethod
    def join(operands: Sequence[Expression | int | str]) -> Expression:
        """Fold `operands` from the left into nested `Max` nodes.

        There is no initial value, so a single operand is returned as is after conversion and an
        empty sequence makes `reduce` raise `TypeError`.
        """
        return reduce(Max, map(to_expression, operands))
@dataclass(frozen=True, slots=True)
class Branch(Statement):
    """Conditional statement with a branch for each outcome of `condition`."""

    condition: Expression
    if_true: Statement
    if_false: Statement

    @staticmethod
    def join(leaves: Sequence[tuple[Expression | int | str, Statement]]) -> Statement:
        """Chain `(condition, statement)` pairs into nested branches.

        The first leaf becomes the outermost `Branch`; each later leaf is nested in the
        `if_false` slot of the one before it, and the innermost `if_false` is an empty `Block`.
        With no leaves, the empty `Block` itself is returned.
        """
        # Build from the last leaf outward so each new branch wraps what was built so far
        chain: Statement = Block([])
        for leaf in reversed(leaves):
            chain = Branch(to_expression(leaf[0]), leaf[1], chain)
        return chain
-------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tensora.ir import peephole, peephole_function_definition, peephole_statement 4 | from tensora.ir.ast import * 5 | from tensora.ir.types import * 6 | 7 | changed = [ 8 | # Add zero 9 | (Add(IntegerLiteral(0), Variable("x")), Variable("x")), 10 | (Add(Variable("x"), IntegerLiteral(0)), Variable("x")), 11 | (Add(FloatLiteral(0.0), Variable("x")), Variable("x")), 12 | (Add(Variable("x"), FloatLiteral(0.0)), Variable("x")), 13 | # Minus zero 14 | (Subtract(Variable("x"), IntegerLiteral(0)), Variable("x")), 15 | (Subtract(Variable("x"), FloatLiteral(0.0)), Variable("x")), 16 | # Multiply 0 17 | (Multiply(IntegerLiteral(0), Variable("x")), IntegerLiteral(0)), 18 | (Multiply(Variable("x"), IntegerLiteral(0)), IntegerLiteral(0)), 19 | (Multiply(FloatLiteral(0.0), Variable("x")), FloatLiteral(0.0)), 20 | (Multiply(Variable("x"), FloatLiteral(0.0)), FloatLiteral(0.0)), 21 | # Multiply 1 22 | (Multiply(IntegerLiteral(1), Variable("x")), Variable("x")), 23 | (Multiply(Variable("x"), IntegerLiteral(1)), Variable("x")), 24 | (Multiply(FloatLiteral(1.0), Variable("x")), Variable("x")), 25 | (Multiply(Variable("x"), FloatLiteral(1.0)), Variable("x")), 26 | # Equals same 27 | (Equal(Variable("x"), Variable("x")), BooleanLiteral(True)), 28 | (LessThanOrEqual(Variable("x"), Variable("x")), BooleanLiteral(True)), 29 | (GreaterThanOrEqual(Variable("x"), Variable("x")), BooleanLiteral(True)), 30 | # Not equal same 31 | (NotEqual(Variable("x"), Variable("x")), BooleanLiteral(False)), 32 | (LessThan(Variable("x"), Variable("x")), BooleanLiteral(False)), 33 | (GreaterThan(Variable("x"), Variable("x")), BooleanLiteral(False)), 34 | # And true 35 | (And(BooleanLiteral(True), Variable("x")), Variable("x")), 36 | (And(Variable("x"), BooleanLiteral(True)), Variable("x")), 37 | # And false 38 | (And(BooleanLiteral(False), Variable("x")), BooleanLiteral(False)), 39 | (And(Variable("x"), 
BooleanLiteral(False)), BooleanLiteral(False)), 40 | # Or true 41 | (Or(BooleanLiteral(True), Variable("x")), BooleanLiteral(True)), 42 | (Or(Variable("x"), BooleanLiteral(True)), BooleanLiteral(True)), 43 | # Or false 44 | (Or(BooleanLiteral(False), Variable("x")), Variable("x")), 45 | (Or(Variable("x"), BooleanLiteral(False)), Variable("x")), 46 | # Boolean cast constant 47 | (BooleanToInteger(BooleanLiteral(False)), IntegerLiteral(0)), 48 | (BooleanToInteger(BooleanLiteral(True)), IntegerLiteral(1)), 49 | # Branch true 50 | (Branch(BooleanLiteral(True), Variable("x"), Variable("y")), Variable("x")), 51 | # Branch false 52 | (Branch(BooleanLiteral(False), Variable("x"), Variable("y")), Variable("y")), 53 | # Loop false 54 | (Loop(BooleanLiteral(False), Variable("x")), Block([])), 55 | # Empty block 56 | (Block([Block([], "comment"), Return(IntegerLiteral(0))]), Block([Return(IntegerLiteral(0))])), 57 | (Branch(Variable("t"), Block([], "comment1"), Block([], "comment2")), Block([])), 58 | (Loop(Variable("t"), Block([], "comment")), Block([])), 59 | # Redundant assignment 60 | (Assignment(Variable("x"), Variable("x")), Block([])), 61 | ( 62 | Assignment( 63 | ArrayIndex(Variable("x"), IntegerLiteral(0)), 64 | ArrayIndex(Variable("x"), IntegerLiteral(0)), 65 | ), 66 | Block([]), 67 | ), 68 | # Pass through 69 | ( 70 | ArrayIndex( 71 | ArrayIndex(Variable("x"), Add(IntegerLiteral(0), Variable("i"))), Variable("j") 72 | ), 73 | ArrayIndex(ArrayIndex(Variable("x"), Variable("i")), Variable("j")), 74 | ), 75 | ( 76 | ArrayIndex(Variable("x"), Add(IntegerLiteral(0), Variable("i"))), 77 | ArrayIndex(Variable("x"), Variable("i")), 78 | ), 79 | ( 80 | AttributeAccess(ArrayIndex(Variable("x"), Add(IntegerLiteral(0), Variable("i"))), "modes"), 81 | AttributeAccess(ArrayIndex(Variable("x"), Variable("i")), "modes"), 82 | ), 83 | ( 84 | BooleanToInteger(And(BooleanLiteral(True), Variable("t"))), 85 | BooleanToInteger(Variable("t")), 86 | ), 87 | ( 88 | ArrayAllocate(float, 
Add(IntegerLiteral(0), Variable("x"))), 89 | ArrayAllocate(float, Variable("x")), 90 | ), 91 | ( 92 | ArrayReallocate(Variable("x"), float, Add(IntegerLiteral(0), Variable("y"))), 93 | ArrayReallocate(Variable("x"), float, Variable("y")), 94 | ), 95 | ( 96 | Assignment(Variable("x"), Add(IntegerLiteral(0), Variable("y"))), 97 | Assignment(Variable("x"), Variable("y")), 98 | ), 99 | ( 100 | Assignment( 101 | ArrayIndex(Variable("x"), Add(IntegerLiteral(0), Variable("i"))), Variable("y") 102 | ), 103 | Assignment(ArrayIndex(Variable("x"), Variable("i")), Variable("y")), 104 | ), 105 | ( 106 | DeclarationAssignment( 107 | Declaration(Variable("x"), float), Add(FloatLiteral(0.0), Variable("y")) 108 | ), 109 | DeclarationAssignment(Declaration(Variable("x"), float), Variable("y")), 110 | ), 111 | ( 112 | Block( 113 | [ 114 | Assignment( 115 | Variable("x"), Add(Add(IntegerLiteral(0), Variable("x")), IntegerLiteral(1)) 116 | ) 117 | ], 118 | "test block", 119 | ), 120 | Block([Assignment(Variable("x"), Add(Variable("x"), IntegerLiteral(1)))], "test block"), 121 | ), 122 | ( 123 | Branch(Variable("t"), Variable("x"), Add(IntegerLiteral(0), Variable("y"))), 124 | Branch(Variable("t"), Variable("x"), Variable("y")), 125 | ), 126 | ( 127 | Branch( 128 | Variable("t"), 129 | Add(IntegerLiteral(0), Variable("x")), 130 | Add(IntegerLiteral(0), Variable("y")), 131 | ), 132 | Branch(Variable("t"), Variable("x"), Variable("y")), 133 | ), 134 | ( 135 | Branch( 136 | And(BooleanLiteral(True), Variable("t")), 137 | Variable("x"), 138 | Add(IntegerLiteral(0), Variable("y")), 139 | ), 140 | Branch(Variable("t"), Variable("x"), Variable("y")), 141 | ), 142 | ( 143 | Loop(Variable("t"), Add(IntegerLiteral(0), Variable("x"))), 144 | Loop(Variable("t"), Variable("x")), 145 | ), 146 | ( 147 | Loop(And(BooleanLiteral(True), Variable("t")), Variable("x")), 148 | Loop(Variable("t"), Variable("x")), 149 | ), 150 | ( 151 | Return(Add(IntegerLiteral(0), Variable("x"))), 152 | 
@pytest.mark.parametrize("cls", left_right_classes)
def test_pass_through_left_right(cls):
    """Check that simplification reaches the operands of every binary node.

    ``0 + x`` and ``0 + y`` must be reduced to ``x`` and ``y`` whether they sit
    on the left, on the right, or on both sides of ``cls``.
    """
    left = Add(IntegerLiteral(0), Variable("x"))
    right = Add(IntegerLiteral(0), Variable("y"))
    expected = cls(Variable("x"), Variable("y"))

    # Only one side needs simplification
    assert peephole_statement(cls(left, Variable("y"))) == expected
    assert peephole_statement(cls(Variable("x"), right)) == expected
    # Both sides need simplification in the same node
    assert peephole_statement(cls(left, right)) == expected
def parens(code: Expression, wrap_me: type | tuple[type, ...]):
    """Render ``code`` as C, wrapped in parentheses if it is an instance of ``wrap_me``.

    ``wrap_me`` is handed straight to ``isinstance``, so it may be a single
    class or a tuple of classes.
    """
    rendered = ir_to_c_expression(code)
    return f"({rendered})" if isinstance(code, wrap_me) else rendered
@ir_to_c_expression.register(ArrayIndex)
def ir_to_c_array_index(self: ArrayIndex) -> str:
    """Render an array subscript as ``target[index]``."""
    target = ir_to_c_expression(self.target)
    index = ir_to_c_expression(self.index)
    return f"{target}[{index}]"
@ir_to_c_expression.register(Multiply)
def ir_to_c_multiply(self: Multiply) -> str:
    """Render a product as ``left * right``.

    An operand that is an ``Add`` or ``Subtract`` is parenthesized so that the
    emitted C keeps the grouping of the IR tree.
    """
    needs_parentheses = (Add, Subtract)
    left = parens(self.left, needs_parentheses)
    right = parens(self.right, needs_parentheses)
    return f"{left} * {right}"
@ir_to_c_expression.register(ArrayReallocate)
def ir_to_c_array_reallocate(self: ArrayReallocate) -> str:
    """Render a reallocation as ``realloc(old, sizeof(element) * n_elements)``.

    The element count is parenthesized when it is an ``Add`` or ``Subtract`` so
    the multiplication by ``sizeof`` applies to the whole count.
    """
    pointer = ir_to_c_expression(self.old)
    element = type_to_c(self.element_type)
    count = parens(self.n_elements, (Add, Subtract))
    return f"realloc({pointer}, sizeof({element}) * {count})"
@ir_to_c_statement.register(Block)
def ir_to_c_block(self: Block) -> list[str]:
    """Render a block as a flat list of C source lines.

    The block's comment, if there is one, is emitted first as a ``//`` line.
    Nested blocks are rendered inline; a blank line is inserted before a nested
    block whenever any output precedes it, and before the first non-block
    statement that follows a nested block.
    """
    output = [] if self.comment is None else [f"// {self.comment}"]
    after_nested_block = False

    for statement in self.statements:
        if not isinstance(statement, Block):
            if after_nested_block:
                # Separate this statement from the nested block just emitted
                output.append("")
                after_nested_block = False
            output.extend(ir_to_c_statement(statement))
            continue

        if output:
            # Separate the nested block from whatever came before it
            output.append("")
        output.extend(ir_to_c_block(statement))
        after_nested_block = True

    return output
def ir_to_c_function_definition(self: FunctionDefinition) -> str:
    """Render a function definition as a single string of C source.

    The result is the signature line, the indented body lines, and a closing
    brace, joined with newlines.
    """
    return_type = type_to_c(self.return_type)
    name = ir_to_c_expression(self.name)
    parameters = ", ".join(ir_to_c_declaration(parameter) for parameter in self.parameters)

    header = f"{return_type} {name}({parameters}) {{"
    body = indent_lines(ir_to_c_statement(self.body))

    return "\n".join([header, *body, "}"])
@singledispatch
def peephole_assignable(self: Assignable) -> Assignable:
    """Fallback of the single-dispatch simplifier for assignable nodes.

    Reached only for node types without a registered handler.

    Raises:
        NotImplementedError: always, naming the unsupported type and value.
    """
    message = f"peephole_assignable not implemented for {type(self)}: {self}"
    raise NotImplementedError(message)
@peephole_expression.register(Add)
def peephole_add(self: Add) -> Expression:
    """Simplify an addition: ``0 + expr`` and ``expr + 0`` both become ``expr``.

    Both operands are simplified first. Integer and float zero literals are
    both recognized. Any other addition is rebuilt from the simplified operands.
    """

    def is_zero(operand: Expression) -> bool:
        return operand == IntegerLiteral(0) or operand == FloatLiteral(0.0)

    left = peephole_expression(self.left)
    right = peephole_expression(self.right)

    if is_zero(left):
        return right
    if is_zero(right):
        return left
    return Add(left, right)
@peephole_expression.register(And)
def peephole_and(self: And) -> Expression:
    """Simplify a conjunction after simplifying both operands.

    * A ``false`` literal on either side makes the whole expression ``false``.
    * A ``true`` literal on one side leaves just the other operand.
    * Anything else is rebuilt from the simplified operands.
    """
    always_false = BooleanLiteral(False)
    always_true = BooleanLiteral(True)

    left = peephole_expression(self.left)
    right = peephole_expression(self.right)

    if left == always_false or right == always_false:
        return always_false
    if left == always_true:
        return right
    if right == always_true:
        return left
    return And(left, right)
@peephole_expression.register(BooleanToInteger)
def peephole_boolean_to_integer(self: BooleanToInteger) -> Expression:
    """Simplify a boolean-to-integer cast.

    The operand is simplified first. A cast of the ``true`` literal becomes the
    integer literal 1 and a cast of the ``false`` literal becomes 0. Any other
    operand stays wrapped in a cast.
    """
    inner = peephole_expression(self.expression)

    if inner == BooleanLiteral(True):
        return IntegerLiteral(1)
    if inner == BooleanLiteral(False):
        return IntegerLiteral(0)
    return BooleanToInteger(inner)
@peephole_statement.register(Loop)
def peephole_loop(self: Loop) -> Statement:
    """Simplify a loop after simplifying its condition and body.

    * A loop whose condition is the ``false`` literal never runs and is
      replaced by an empty ``Block``.
    * A loop whose simplified body is an empty ``Block`` is replaced by an
      empty ``Block``.
    * Any other loop is rebuilt from the simplified condition and body.

    Returns:
        The simplified statement: either an empty ``Block`` or a ``Loop``.
    """
    condition = peephole_expression(self.condition)
    body = peephole_statement(self.body)

    if condition == BooleanLiteral(False):
        return Block([])
    elif isinstance(body, Block) and body.is_empty():
        # Inspect the simplified body, not self.body: a body can become empty
        # only after simplification (e.g. it held nothing but `a = a`), and
        # such a loop must be removed as well.
        return Block([])
    else:
        return Loop(condition, body)