├── docs └── img │ ├── loss.png │ ├── perf.png │ └── scaling_law.png ├── .gitignore ├── BackendBench ├── scripts │ ├── __init__.py │ ├── dataset_filters.py │ ├── create_simple_test_ops.py │ ├── setup_operator_directories.py │ ├── get_tests_stat.py │ ├── create_watermarked_operators.py │ └── parquet_trace_converter.py ├── backends │ ├── base.py │ ├── aten.py │ ├── __init__.py │ ├── directory.py │ └── kernel_agent.py ├── suite │ ├── smoke.py │ ├── base.py │ ├── __init__.py │ ├── torchbench.py │ ├── opinfo.py │ └── facto.py ├── op_categories.py ├── opregistry.py ├── __init__.py ├── llm_client.py ├── kernel_templates.py ├── data_loaders.py └── eval.py ├── .pre-commit-config.yaml ├── .github ├── workflows │ ├── ruff.yml │ ├── smoke-test.yml │ └── failure-test.yml └── scripts │ └── check_license_headers.py ├── pytest.ini ├── LICENSE.md ├── CONTRIBUTING.md ├── test ├── test_torchbench_suite.py ├── test_smoke.py ├── test_adverse_cases.py ├── test_directory_backend.py ├── test_facto_suite.py ├── fixtures │ └── llm_response │ │ ├── add_good.txt │ │ ├── add_missing_target_functions.txt │ │ └── add_missing_python_code_block.txt ├── test_backend_evaluation.py ├── test_backends.py ├── test_llm_backend.py ├── test_suite.py ├── test_monkey_patch.py ├── test_eval.py └── test_output.py ├── pyproject.toml ├── README.md └── CODE_OF_CONDUCT.md /docs/img/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meta-pytorch/BackendBench/HEAD/docs/img/loss.png -------------------------------------------------------------------------------- /docs/img/perf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meta-pytorch/BackendBench/HEAD/docs/img/perf.png -------------------------------------------------------------------------------- /docs/img/scaling_law.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meta-pytorch/BackendBench/HEAD/docs/img/scaling_law.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | __pycache__/ 3 | .claude/ 4 | .vscode/ 5 | .ruff_cache/ 6 | backendbench.egg-info/ 7 | CLAUDE.md 8 | venv/ 9 | ops/ 10 | datasets/ 11 | uv.lock 12 | .pre-commit-cache/ 13 | logs/ 14 | generated_kernels/ 15 | *.csv 16 | backendbench_output* 17 | .DS_Store 18 | *.bak -------------------------------------------------------------------------------- /BackendBench/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Scripts module for BackendBench 8 | -------------------------------------------------------------------------------- /BackendBench/backends/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | class Backend: 9 | def __init__(self, name): 10 | self.name = name 11 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | rev: v0.12.1 4 | hooks: 5 | - id: ruff 6 | args: [--fix] 7 | - id: ruff-format 8 | - repo: local 9 | hooks: 10 | - id: check-license-headers 11 | name: Check license headers 12 | entry: uv run python .github/scripts/check_license_headers.py 13 | language: system 14 | files: \.py$ 15 | pass_filenames: true -------------------------------------------------------------------------------- /BackendBench/backends/aten.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .base import Backend 8 | 9 | 10 | class AtenBackend(Backend): 11 | def __init__(self) -> None: 12 | super().__init__("aten") 13 | 14 | def __getitem__(self, key): 15 | return key 16 | 17 | def __contains__(self, key): 18 | return True 19 | -------------------------------------------------------------------------------- /.github/workflows/ruff.yml: -------------------------------------------------------------------------------- 1 | name: Ruff 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | jobs: 11 | ruff: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | 16 | - name: Install uv 17 | uses: astral-sh/setup-uv@v3 18 | 19 | - name: Set up Python 20 | run: uv python install 3.13 21 | 22 | - name: Install package with dev dependencies 23 | run: uv sync --dev 24 | 25 | - name: Run ruff check 26 | run: uv run ruff check . 27 | 28 | - name: Run ruff format check 29 | run: uv run ruff format --check . 30 | 31 | - name: Check license headers 32 | run: python .github/scripts/check_license_headers.py $(find . -name "*.py" -type f -not -path "./.venv/*" -not -path "./__pycache__/*") -------------------------------------------------------------------------------- /BackendBench/suite/smoke.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from BackendBench.opregistry import get_operator 10 | 11 | from .base import OpTest, Test, TestSuite 12 | 13 | 14 | def randn(*args, **kwargs): 15 | return lambda: torch.randn(*args, **kwargs) 16 | 17 | 18 | SmokeTestSuite = TestSuite( 19 | "smoke", 20 | [ 21 | OpTest( 22 | get_operator(torch.ops.aten.relu.default), 23 | [ 24 | Test( 25 | randn(2, device="cpu"), 26 | ), 27 | ], 28 | [ 29 | Test( 30 | randn(2**10, 2**10, device="cpu"), 31 | ), 32 | ], 33 | ) 34 | ], 35 | ) 36 | -------------------------------------------------------------------------------- /BackendBench/suite/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | class Test: 9 | def __init__(self, *args, **kwargs): 10 | self._args = args 11 | self._kwargs = kwargs 12 | 13 | @property 14 | def args(self): 15 | return [arg() for arg in self._args] 16 | 17 | @property 18 | def kwargs(self): 19 | return {k: v() for k, v in self._kwargs.items()} 20 | 21 | 22 | class OpTest: 23 | def __init__(self, op, correctness_tests, performance_tests): 24 | self.op = op 25 | self.correctness_tests = correctness_tests 26 | self.performance_tests = performance_tests 27 | 28 | 29 | class TestSuite: 30 | def __init__(self, name, optests): 31 | self.name = name 32 | self.optests = optests 33 | 34 | def __iter__(self): 35 | for optest in self.optests: 36 | yield optest 37 | -------------------------------------------------------------------------------- /BackendBench/suite/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | BackendBench suites submodule. 9 | 10 | This module provides various test suite implementations for benchmarking 11 | PyTorch operations across different backends. Each test suite defines a 12 | collection of tests to evaluate the correctness and/or performance of 13 | backend implementations by comparing them against PyTorch operations. 14 | """ 15 | 16 | from .base import OpTest, Test, TestSuite 17 | from .facto import FactoTestSuite 18 | from .opinfo import OpInfoTestSuite 19 | from .smoke import randn, SmokeTestSuite 20 | from .torchbench import TorchBenchOpTest, TorchBenchTestSuite 21 | 22 | __all__ = [ 23 | "Test", 24 | "OpTest", 25 | "TestSuite", 26 | "FactoTestSuite", 27 | "OpInfoTestSuite", 28 | "SmokeTestSuite", 29 | "randn", 30 | "TorchBenchOpTest", 31 | "TorchBenchTestSuite", 32 | ] 33 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | # Pytest configuration for BackendBench 3 | 4 | # Test discovery patterns 5 | python_files = test_*.py 6 | python_classes = Test* 7 | python_functions = test_* 8 | 9 | # Test directories 10 | testpaths = test 11 | 12 | # Output options 13 | addopts = 14 | -v 15 | --tb=short 16 | --strict-markers 17 | --disable-warnings 18 | -p no:warnings 19 | 20 | # Markers for categorizing tests 21 | markers = 22 | smoke: Basic smoke tests that should always pass 23 | unit: Unit tests for individual components 24 | integration: Integration tests that test multiple components 25 | slow: Tests that take a long time to run 26 | requires_cuda: Tests that require CUDA/GPU 27 | requires_api_key: Tests that require API keys (e.g., for LLM backends) 28 | 29 | # Coverage settings (if pytest-cov is installed) 30 | [coverage:run] 31 | source = BackendBench 32 | omit = 33 | */test/* 34 | */tests/* 35 | setup.py 36 | 37 | [coverage:report] 38 | exclude_lines = 39 | pragma: no cover 40 | def __repr__ 41 | raise AssertionError 42 | raise NotImplementedError 43 | if __name__ == .__main__.: -------------------------------------------------------------------------------- /BackendBench/backends/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | BackendBench backends submodule. 9 | 10 | This module provides various backend implementations for PyTorch operations. 11 | Each backend implements a different strategy for mapping PyTorch operations 12 | to alternative implementations. 13 | """ 14 | 15 | import importlib.util 16 | 17 | from .aten import AtenBackend 18 | from .base import Backend 19 | from .directory import DirectoryBackend 20 | from .flag_gems import FlagGemsBackend 21 | from .kernel_agent import KernelAgentBackend 22 | from .llm import LLMBackend 23 | 24 | __all__ = [ 25 | "Backend", 26 | "DirectoryBackend", 27 | "AtenBackend", 28 | "FlagGemsBackend", 29 | "LLMBackend", 30 | ] 31 | 32 | if importlib.util.find_spec("triton_kernel_agent") is not None: 33 | from .kernel_agent import KernelAgentBackend 34 | 35 | __all__.append("KernelAgentBackend") 36 | else: 37 | KernelAgentBackend = None 38 | -------------------------------------------------------------------------------- /.github/workflows/smoke-test.yml: -------------------------------------------------------------------------------- 1 | name: Smoke Test 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | smoke-test: 13 | runs-on: 4-core-ubuntu-gpu-t4 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | 18 | - name: Install uv 19 | uses: astral-sh/setup-uv@v3 20 | 21 | - name: Set up Python 22 | run: uv python install 3.13 23 | 24 | - name: Install package and dependencies 25 | run: uv sync --dev 26 | 27 | - name: Clone FlagGems source 28 | run: git clone https://github.com/FlagOpen/FlagGems.git 29 | 30 | - name: Build and install FlagGems 31 | run: uv pip install FlagGems/ 32 | 33 | - name: Clone FACTO source 34 | run: git clone https://github.com/pytorch-labs/FACTO.git 35 | 36 | - name: Build and install FACTO 37 | run: uv pip install FACTO/ 38 | 39 | - name: Run smoke test 40 | run: uv run python -m BackendBench.scripts.main --suite smoke --backend aten 41 | 42 | - name: Run FACTO test 43 | run: uv run python -m BackendBench.scripts.main --suite facto --backend aten --ops "add.Tensor" 44 | 45 | - name: Run pytest tests 46 | run: uv run pytest test/ 47 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright 2025 Meta 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 
10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 12 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to BackendBench 2 | 3 | ## License 4 | 5 | BackendBench is BSD-3-Clause licensed, as found in the LICENSE file. 6 | 7 | ## Our Development Process 8 | 9 | BackendBench is actively developed internally at Meta and synced to GitHub regularly. External contributions are welcomed and will be reviewed by the Meta team. 10 | 11 | ## Code Quality 12 | 13 | We use [ruff](https://docs.astral.sh/ruff/) for linting and code formatting. 14 | 15 | ## Pre-commit Hooks 16 | 17 | To make development easier, we provide pre-commit hooks that automatically run ruff on your changes: 18 | 19 | ```bash 20 | pip install pre-commit 21 | pre-commit install 22 | ``` 23 | 24 | This will automatically lint your code before each commit, ensuring consistent code quality across the project. 25 | 26 | ## Pull Requests 27 | 28 | We actively welcome your pull requests. 29 | 30 | 1. Fork the repo and create your branch from `main`. 31 | 2. If you've added code that should be tested, add tests. 32 | 3. If you've changed APIs, update the documentation. 33 | 4. Ensure the test suite passes. 34 | 5. Make sure your code lints. 35 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 36 | 37 | ## Contributor License Agreement ("CLA") 38 | 39 | In order to accept your pull request, we need you to submit a CLA. You only need 40 | to do this once to work on any of Meta's open source projects. 41 | 42 | Complete your CLA here: 43 | 44 | ## Issues 45 | 46 | We use GitHub issues to track public bugs. Please ensure your description is 47 | clear and has sufficient instructions to be able to reproduce the issue. -------------------------------------------------------------------------------- /test/test_torchbench_suite.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | 9 | from BackendBench.suite import TorchBenchOpTest 10 | 11 | 12 | class TestOpTest: 13 | def test_op_test(self): 14 | op_test = TorchBenchOpTest( 15 | "aten.relu.default", ["((T([32, 128, 512], f16, None, 'cpu'),), {})"], None 16 | ) 17 | for test in op_test.correctness_tests: 18 | args, kwargs = test.args, test.kwargs 19 | arg, *extras = args 20 | assert arg.shape == torch.Size([32, 128, 512]) 21 | assert arg.dtype == torch.float16 22 | assert kwargs == {} 23 | assert extras == [] 24 | 25 | torch.testing.assert_close(torch.relu(arg), op_test.op(arg)) 26 | 27 | def test_topn(self): 28 | op_test = TorchBenchOpTest( 29 | "aten.relu.default", 30 | [ 31 | "((T([32, 128, 512], f16, None, 'cpu'),), {})", 32 | "((T([32, 256, 512], f16, None, 'cpu'),), {})", 33 | ], 34 | 1, 35 | ) 36 | assert len(op_test.tests()) == 1 37 | for test in op_test.correctness_tests: 38 | args, kwargs = test.args, test.kwargs 39 | arg, *extras = args 40 | assert arg.shape == torch.Size([32, 256, 512]) 41 | assert arg.dtype == torch.float16 42 | assert kwargs == {} 43 | assert extras == [] 44 | 45 | torch.testing.assert_close(torch.relu(arg), op_test.op(arg)) 46 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "backendbench" 7 | version = "0.1.0" 8 | description = "A PyTorch backend evaluation suite" 9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | classifiers = [ 12 | "Development Status :: 3 - Alpha", 13 | "Intended Audience :: Developers", 14 | "License :: OSI Approved :: MIT License", 15 | "Programming Language :: Python :: 3", 16 | "Programming Language :: Python :: 3.10", 17 | "Programming Language :: Python :: 3.11", 18 | ] 19 | dependencies = [ 20 | "torch", 21 | "click", 22 | "numpy", 23 | "expecttest", 24 | "anthropic>=0.34.0", 25 | "pytest", 26 | "requests", 27 | "huggingface_hub", 28 | "pandas", 29 | "datasets", 30 | "tenacity", 31 | "nvidia-cutlass-dsl", 32 | ] 33 | 34 | [project.optional-dependencies] 35 | flaggems = [ 36 | # flag_gems must be installed from source: https://github.com/FlagOpen/FlagGems 37 | ] 38 | facto = [ 39 | # facto must be installed from source: https://github.com/pytorch-labs/FACTO 40 | ] 41 | 42 | [project.scripts] 43 | backendbench = "BackendBench.scripts.main:cli" 44 | 45 | [tool.hatch.build.targets.wheel] 46 | packages = ["BackendBench"] 47 | 48 | [tool.uv] 49 | dev-dependencies = [ 50 | "pytest", 51 | "pytest-cov", 52 | "pytest-mock", 53 | "pytest-timeout", 54 | "ruff==0.12.1", 55 | "pre-commit", 56 | "torch", 57 | "numpy", 58 | "pyarrow", 59 | # cupy-cuda12x is platform specific, install manually if needed 60 | ] 61 | 62 | [tool.ruff] 63 | line-length = 100 64 | 65 | [tool.ruff.lint] 66 | extend-select = ["I", "W292", "W291"] # Enable isort rules, final newline rule, and trailing whitespace rule 67 | 68 | [tool.ruff.lint.isort] 69 | case-sensitive = false 70 | combine-as-imports = true 71 | detect-same-package = false 72 | order-by-type = false 73 | -------------------------------------------------------------------------------- /.github/scripts/check_license_headers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | #!/usr/bin/env python3 8 | 9 | import argparse 10 | import sys 11 | 12 | REQUIRED_LICENSE_TEXT = "Copyright (c) Meta Platforms, Inc. and affiliates." 13 | 14 | 15 | def check_license_header(file_path): 16 | """Check if a Python file has the required license header.""" 17 | try: 18 | with open(file_path, "r", encoding="utf-8") as f: 19 | content = f.read() 20 | return REQUIRED_LICENSE_TEXT in content 21 | except Exception as e: 22 | print(f"Error reading {file_path}: {e}") 23 | return False 24 | 25 | 26 | def main(): 27 | parser = argparse.ArgumentParser(description="Check license headers in Python files") 28 | parser.add_argument("files", nargs="*", help="Files to check") 29 | args = parser.parse_args() 30 | 31 | if not args.files: 32 | return 0 33 | 34 | missing_headers = [] 35 | 36 | for file_path in args.files: 37 | if file_path.endswith(".py"): 38 | if not check_license_header(file_path): 39 | missing_headers.append(file_path) 40 | 41 | if missing_headers: 42 | print("Missing license headers in the following files:") 43 | for file_path in missing_headers: 44 | print(f" - {file_path}") 45 | print("\nPlease add the following license header to the top of each file:") 46 | print("# Copyright (c) Meta Platforms, Inc. and affiliates.") 47 | print("# All rights reserved.") 48 | print("#") 49 | print("# This source code is licensed under the BSD 3-Clause license found in the") 50 | print("# LICENSE file in the root directory of this source tree.") 51 | return 1 52 | 53 | return 0 54 | 55 | 56 | if __name__ == "__main__": 57 | sys.exit(main()) 58 | -------------------------------------------------------------------------------- /test/test_smoke.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | import pytest 9 | import torch 10 | 11 | import BackendBench.backends as backends 12 | from BackendBench.eval import eval_one_op 13 | from BackendBench.suite import SmokeTestSuite 14 | 15 | 16 | class TestSmoke: 17 | @pytest.fixture 18 | def aten_backend(self): 19 | return backends.AtenBackend() 20 | 21 | def test_smoke_suite_aten_backend(self, aten_backend): 22 | overall_correctness = [] 23 | overall_performance = [] 24 | 25 | for test in SmokeTestSuite: 26 | if test.op not in aten_backend: 27 | pytest.skip(f"Operation {test.op} not in backend") 28 | 29 | correctness, perf, correctness_results, performance_results = eval_one_op( 30 | test.op, 31 | aten_backend[test.op], 32 | test.correctness_tests, 33 | test.performance_tests, 34 | ) 35 | 36 | is_correct = all(result.is_correct for result in correctness_results) 37 | overall_correctness.append(is_correct) 38 | overall_performance.append(perf) 39 | 40 | assert len(correctness_results) == len(test.correctness_tests) 41 | assert len(performance_results) == len(test.performance_tests) 42 | 43 | assert correctness > 0, f"Operation {test.op} failed all correctness tests" 44 | assert perf > 0.1, f"Operation {test.op} is more than 10x slower than reference" 45 | 46 | mean_correctness = torch.tensor(overall_correctness).float().mean().item() 47 | geomean_perf = torch.tensor(overall_performance).log().mean().exp().item() 48 | 49 | assert mean_correctness >= 0.8, ( 50 | f"Mean correctness {mean_correctness:.2f} is below threshold of 0.8" 51 | ) 52 | assert geomean_perf >= 0.5, ( 53 | f"Geomean performance {geomean_perf:.2f} is below threshold of 0.5" 54 | ) 55 | 56 | print(f"Correctness score (mean pass rate): {mean_correctness:.2f}") 57 | print(f"Performance score (geomean speedup): {geomean_perf:.2f}") 58 | -------------------------------------------------------------------------------- /BackendBench/op_categories.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | TENSOR_CREATION_AND_MANIPULATION_OPS = [ 8 | "cat.default", 9 | "cat.out", 10 | "cat.names", 11 | "cat.names_out", "clone.default", 12 | "clone.out", 13 | "copy_.default", 14 | "elu_backward.default", 15 | "elu_backward.grad_input", 16 | "masked_fill_.Scalar", "masked_fill_.Tensor", 17 | "new_empty.default", 18 | "new_empty.out", 19 | "new_empty_strided.default", 20 | "new_empty_strided.out", 21 | "new_full.default", 22 | "new_full.out", 23 | "new_ones.default", 24 | "new_ones.out", 25 | "new_zeros.default", 26 | "new_zeros.out", 27 | "nonzero.default", 28 | "nonzero.out", 29 | "repeat.default", 30 | "repeat.out", 31 | "split.Tensor", 32 | "split_with_sizes.default", 33 | "unsqueeze_.default", 34 | ] 35 | 36 | RANDOM_OPS = [ 37 | "bernoulli.default", 38 | "bernoulli.out", 39 | "bernoulli.Tensor", 40 | "bernoulli.Tensor_out", 41 | ] 42 | 43 | # Operators to skip for indexing ops that need valid indices 44 | UNSUPPORTED_OPERATORS = [ 45 | "embedding.default", 46 | "embedding.out", 47 | "scatter.src", 48 | "scatter.src_out", 49 | "scatter.reduce", 50 | "scatter.reduce_out", 51 | "scatter.value", 52 | "scatter.value_out", 53 | "scatter.value_reduce", 54 | "scatter.value_reduce_out", "gather.default", 55 | "gather.out", 56 | "gather.dimname", 57 | "gather.dimname_out", "index.Tensor", 58 | "index.Tensor_out", "nll_loss.default", 59 | "nll_loss.out", 60 | "im2col_backward.default", 61 | "im2col_backward.grad_input", 62 | "col2im_backward.default", 63 | "col2im_backward.grad_input", 64 | "native_layer_norm_backward.default", 65 | "native_layer_norm_backward.out", 66 | "upsample_nearest2d_backward.default", 67 | "upsample_nearest2d_backward.grad_input", 68 | "upsample_bilinear2d_backward.default", 69 | "upsample_bilinear2d_backward.grad_input", 70 | "_cudnn_rnn_backward.default", # RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM 71 | "_cudnn_rnn_backward.out", 72 | "_fft_c2c.default", # cuFFT only supports dimensions whose sizes are powers of two when computing in half precision 73 | "_fft_c2c.out", 74 | "_cudnn_rnn.default", # We are running into numerical stability issues with running the forward pass multiple times 75 | "_cudnn_rnn.out", 76 | ] 77 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## BackendBench 2 | 3 | BackendBench is an evaluation suite for testing how well LLMs and humans can write PyTorch backends. It lets developers add custom kernels in an organized directory structure and dynamically override PyTorch's core operators at runtime, resulting in a fully functional PyTorch backend you can pip install and use with existing models, no modeling code changes required. 4 | 5 | Features: 6 | 1. Comprehensive edge case correctness testing via PyTorch's OpInfo and FACTO test suites 7 | 2. Performance benchmarks using real tensor shapes from popular Hugging Face models 8 | 3. Clean path to upstream your kernels to PyTorch (if they pass our tests, they're likely correct enough to merge) 9 | 10 | Many kernel optimization efforts struggle with correctness. Our approach ensures your kernels are production-ready by meeting PyTorch's own standards. You can learn about correctness in our [launch blog](docs/correctness.md) and [launch video](https://www.youtube.com/watch?v=BTfjdyZOKww). 11 | 12 | ## Installation: 13 | 14 | ```bash 15 | pip install . 16 | ``` 17 | 18 | ## LLM Kernel Development Workflow 19 | 20 | 1. 
**Create operator directories**: 21 | ```bash 22 | python -m BackendBench.scripts.setup_operator_directories 23 | ``` 24 | 25 | 2. **Implement kernels** in each directory you'll see an empty op implementation. Please get your LLM to fill it out! 26 | 27 | 3. **Test your implementations**: 28 | 29 | ```bash 30 | # smoke test to make sure everything is in check 31 | python BackendBench/scripts/main.py --suite smoke --backend aten 32 | 33 | # OpInfo correctness tests 34 | python BackendBench/scripts/main.py --suite opinfo --backend directory 35 | 36 | # TorchBench performance tests 37 | python BackendBench/scripts/main.py --suite torchbench --backend directory 38 | ``` 39 | 40 | ## Example: Train nanoGPT using BackendBench with LLM generated kernels 41 | 42 | See [BackendBench Example](https://github.com/jiannanWang/BackendBenchExamples) for a practical demonstration of how to use BackendBench for model convergence testing. 43 | 44 | ## Citation 45 | 46 | If you use BackendBench in your research or projects, please cite it as: 47 | 48 | ```bibtex 49 | @software{saroufim2025backendbench, 50 | author = {Mark Saroufim and Jiannan Wang and Bert Maher and Sahan Paliskara and Laura Wang and Shahin Sefati and Manuel Candales}, 51 | title = {BackendBench: An Evaluation Suite for Testing How Well LLMs and Humans Can Write PyTorch Backends}, 52 | year = {2025}, 53 | url = {https://github.com/meta-pytorch/BackendBench} 54 | } 55 | ``` 56 | 57 | ## License 58 | 59 | Source code is made available under a [BSD 3 license](LICENSE.md) 60 | -------------------------------------------------------------------------------- /.github/workflows/failure-test.yml: -------------------------------------------------------------------------------- 1 | name: Failure Test 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | smoke-failure-test: 11 | runs-on: 4-core-ubuntu-gpu-t4 12 | steps: 13 | - uses: actions/checkout@v4 14 | - name: Install uv 15 | uses: astral-sh/setup-uv@v3 16 | - name: Set up Python 17 | run: uv python install 3.13 18 | - name: Install package and dependencies 19 | run: uv sync --dev 20 | - name: Run smoke test 21 | run: | 22 | uv run python -m BackendBench.scripts.main --suite smoke --backend aten --log-dir smoke_output 23 | - name: Check smoke test failures 24 | run: | 25 | if [ -s "smoke_output/failed_ops.json" ] && [ "$(jq 'length' smoke_output/failed_ops.json)" -ne 0 ]; then 26 | echo "Some operations failed in the smoke test." 27 | cat "smoke_output/failed_ops.json" 28 | exit 1 29 | else 30 | echo "All operations passed the smoke test." 31 | fi 32 | 33 | opinfo-failure-test: 34 | runs-on: 4-core-ubuntu-gpu-t4 35 | steps: 36 | - uses: actions/checkout@v4 37 | - name: Install uv 38 | uses: astral-sh/setup-uv@v3 39 | - name: Set up Python 40 | run: uv python install 3.13 41 | - name: Install package and dependencies 42 | run: uv sync --dev 43 | - name: Run opinfo test 44 | run: | 45 | uv run python -m BackendBench.scripts.main --suite opinfo --backend aten --log-dir opinfo_output 46 | - name: Check opinfo test failures 47 | run: | 48 | if [ -s "opinfo_output/failed_ops.json" ] && [ "$(jq 'length' opinfo_output/failed_ops.json)" -ne 0 ]; then 49 | echo "Some operations failed in the opinfo test." 50 | cat "opinfo_output/failed_ops.json" 51 | exit 1 52 | else 53 | echo "All operations passed the opinfo test." 
54 | fi 55 | 56 | facto-failure-test: 57 | runs-on: 4-core-ubuntu-gpu-t4 58 | steps: 59 | - uses: actions/checkout@v4 60 | - name: Install uv 61 | uses: astral-sh/setup-uv@v3 62 | - name: Set up Python 63 | run: uv python install 3.13 64 | - name: Install package and dependencies 65 | run: uv sync --dev 66 | - name: Clone FACTO source 67 | run: git clone https://github.com/pytorch-labs/FACTO.git 68 | - name: Build and install FACTO 69 | run: cd FACTO && uv pip install . 70 | - name: Run FACTO test 71 | run: | 72 | uv run python -m BackendBench.scripts.main --suite facto --backend aten --log-dir facto_output 73 | - name: Check FACTO test failures 74 | run: | 75 | if [ -s "facto_output/failed_ops.json" ] && [ "$(jq 'length' facto_output/failed_ops.json)" -ne 0 ]; then 76 | echo "Some operations failed in the FACTO test." 77 | cat "facto_output/failed_ops.json" 78 | exit 1 79 | else 80 | echo "All operations passed the FACTO test." 81 | fi -------------------------------------------------------------------------------- /BackendBench/scripts/dataset_filters.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import torch 9 | import tqdm 10 | from triton.testing import do_bench 11 | 12 | from BackendBench.op_categories import ( 13 | RANDOM_OPS, 14 | TENSOR_CREATION_AND_MANIPULATION_OPS, 15 | UNSUPPORTED_OPERATORS, 16 | ) 17 | from BackendBench.utils import cleanup_memory_and_gpu, deserialize_args 18 | 19 | # We get this threshhold from the analysis here 20 | # https://github.com/meta-pytorch/BackendBench/issues/108 21 | RELATIVE_RUNTIME_THRESHOLD = 1.3 22 | 23 | 24 | def apply_skip_ops_filter(ops): 25 | for op in tqdm.tqdm(ops, desc="Filtering ops by skip and synthetic ops"): 26 | op_name = op["op_name"] 27 | if any(s in op_name for s in UNSUPPORTED_OPERATORS): 28 | op["included_in_benchmark"] = False 29 | op["why_excluded"].append("We cannot run this op on backendbench yet") 30 | op["runnable"] = False 31 | 32 | if any(s in op_name for s in RANDOM_OPS): 33 | op["included_in_benchmark"] = False 34 | op["why_excluded"].append( 35 | "BackendBench does not support correctness testing for random ops yet" 36 | ) 37 | 38 | if any(s in op_name for s in TENSOR_CREATION_AND_MANIPULATION_OPS): 39 | op["included_in_benchmark"] = False 40 | op["why_excluded"].append( 41 | "BackendBench does not support correctness testing for tensor creation and manipulation ops yet" 42 | ) 43 | 44 | if op["is_synthetic"]: 45 | op["included_in_benchmark"] = False 46 | op["why_excluded"].append( 47 | "Synthetic ops are not supported in the official benchmark yet" 48 | ) 49 | op["runnable"] = False 50 | return ops 51 | 52 | 53 | def apply_runtime_filter(ops): 54 | def _overhead_benchmark(): 55 | return torch.empty(0, device="cuda") 56 | 57 | runtime_threshold_ms = do_bench(_overhead_benchmark, warmup=25, rep=100) 58 | 59 | for op in tqdm.tqdm(ops, desc="Filtering ops by runtime"): 60 | if op["runnable"]: 61 | args, kwargs = deserialize_args(op["args"]) 62 | try: 63 | op_name = op["op_name"] 64 | op_func = eval(f"torch.ops.{op_name}") 65 | ms = do_bench(lambda: op_func(*args, **kwargs), warmup=25, rep=100) 66 | del args, kwargs 67 | cleanup_memory_and_gpu() 68 | except Exception as e: 69 | # if we can't run the op, we cannot expect others to run it either 
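# so we mark it as not runnable and exclude it from the benchmark below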
70 | op["why_excluded"].append(f"Failed to run: {e}") 71 | op["runnable"] = False 72 | op["included_in_benchmark"] = False 73 | del args, kwargs 74 | cleanup_memory_and_gpu() 75 | continue 76 | op["runtime_ms"] = ms 77 | relative_runtime = ms / runtime_threshold_ms 78 | op["relative_runtime_to_kernel_launch"] = relative_runtime 79 | if relative_runtime < RELATIVE_RUNTIME_THRESHOLD: 80 | op["is_overhead_dominated_op"] = True 81 | return ops 82 | -------------------------------------------------------------------------------- /test/test_adverse_cases.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import pytest 8 | import torch 9 | 10 | import BackendBench.backends as backends 11 | import BackendBench.multiprocessing_eval as multiprocessing_eval 12 | from BackendBench.suite import TorchBenchOpTest 13 | 14 | 15 | class TestAdaptiveAvgPool2dBackward: 16 | @pytest.mark.skip(reason="Skipped due to tensor size causing CUDA OOM in smoke test.") 17 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU") 18 | def test_adaptive_avg_pool2d_backward_gpu(self): 19 | """Test on GPU with eval_one_op.""" 20 | op_test_should_error = TorchBenchOpTest( 21 | "aten._adaptive_avg_pool2d_backward.default", 22 | ["((T([512, 4096, 56, 56], f16), T([512, 4096, 56, 56], f16)), {})"], 23 | None, 24 | ) 25 | 26 | op_test_should_succeed = TorchBenchOpTest( 27 | "aten.addmm.default", 28 | ["((T([14, 14], f32), T([14, 14], f32), T([14, 14], f32)), {})"], 29 | None, 30 | ) 31 | 32 | # run test that should brick the gpu due to an illegal memory access 33 | backend = backends.AtenBackend() 34 | with multiprocessing_eval.MultiprocessingEvaluator() as evaluator: 35 | evaluator.submit_task( 36 | op_test_should_error.op, 37 | backend[op_test_should_error.op], 38 | list(op_test_should_error.correctness_tests), 39 | list(op_test_should_error.performance_tests), 40 | ) 41 | evaluator.submit_task( 42 | op_test_should_succeed.op, 43 | backend[op_test_should_succeed.op], 44 | list(op_test_should_succeed.correctness_tests), 45 | list(op_test_should_succeed.performance_tests), 46 | ) 47 | evaluator.start_evaluation() 48 | 49 | results = evaluator.get_results() 50 | 51 | assert len(results) == 1 52 | assert results[0].correctness_score == 1.0 53 | 54 | 55 | class TestCase: 56 | def __init__(self, args, kwargs): 57 | self.args = args 58 | self.kwargs = kwargs 59 | 60 | 61 | class TestMultiprocessingEval: 62 | def test_multiprocessing_evaluator(self): 63 | op = torch.relu 64 | impl = torch.relu # Same implementation 65 | 66 | correctness_tests = [TestCase([torch.tensor([-1.0, 0.0, 1.0])], {}) for _ in range(3)] 67 | performance_tests = [TestCase([torch.tensor([-1.0, 0.0, 1.0])], {}) for _ in range(2)] 68 | 69 | with multiprocessing_eval.MultiprocessingEvaluator() as evaluator: 70 | evaluator.submit_task(op, impl, correctness_tests, performance_tests) 71 | 72 | evaluator.start_evaluation() 73 | 74 | results = evaluator.get_results() 75 | 76 | assert len(results) == 1 77 | # Should have perfect correctness since using same implementation 78 | assert results[0].correctness_score == 1.0 79 | # Performance should be around 1.0 (same speed) 80 | assert results[0].performance_score.item() > 0 81 | 82 | 83 | if __name__ == "__main__": 84 | 
pytest.main([__file__, "-v", "-s"]) 85 | -------------------------------------------------------------------------------- /test/test_directory_backend.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Test DirectoryBackend with 5 kernel implementations. 9 | """ 10 | 11 | import os 12 | import sys 13 | 14 | sys.path.insert(0, ".") 15 | 16 | import pytest 17 | import torch 18 | 19 | from BackendBench.backends import DirectoryBackend 20 | from BackendBench.utils import op_name_to_folder_name 21 | 22 | 23 | @pytest.fixture(scope="module") 24 | def backend(): 25 | # Always create correct test implementations, overriding any watermarked ones 26 | import subprocess 27 | 28 | subprocess.run( 29 | [sys.executable, "-m", "BackendBench.scripts.create_simple_test_ops"], check=True 30 | ) 31 | 32 | return DirectoryBackend(ops_dir="generated_kernels") 33 | 34 | 35 | def test_relu_operation(backend): 36 | relu_op = torch.ops.aten.relu.default 37 | assert relu_op in backend 38 | 39 | our_impl = backend[relu_op] 40 | x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]) 41 | result = our_impl(x) 42 | expected = relu_op(x) 43 | 44 | assert torch.allclose(result, expected) 45 | 46 | 47 | def test_add_operation(backend): 48 | add_op = torch.ops.aten.add.Tensor 49 | assert add_op in backend 50 | 51 | our_impl = backend[add_op] 52 | a = torch.tensor([1.0, 2.0, 3.0]) 53 | b = torch.tensor([4.0, 5.0, 6.0]) 54 | result = our_impl(a, b) 55 | expected = add_op(a, b) 56 | 57 | assert torch.allclose(result, expected) 58 | 59 | 60 | def test_mul_operation(backend): 61 | mul_op = torch.ops.aten.mul.Tensor 62 | assert mul_op in backend 63 | 64 | our_impl = backend[mul_op] 65 | a = torch.tensor([1.0, 2.0, 3.0]) 66 | b = torch.tensor([4.0, 5.0, 6.0]) 67 | result = our_impl(a, b) 68 | expected = mul_op(a, b) 69 | 70 | assert torch.allclose(result, expected) 71 | 72 | 73 | def test_abs_operation(backend): 74 | abs_op = torch.ops.aten.abs.default 75 | assert abs_op in backend 76 | 77 | our_impl = backend[abs_op] 78 | x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]) 79 | result = our_impl(x) 80 | expected = abs_op(x) 81 | 82 | assert torch.allclose(result, expected) 83 | 84 | 85 | def test_sum_operation(backend): 86 | sum_op = torch.ops.aten.sum.default 87 | assert sum_op in backend 88 | 89 | our_impl = backend[sum_op] 90 | x = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) 91 | result = our_impl(x) 92 | expected = sum_op(x) 93 | 94 | assert torch.allclose(result, expected) 95 | 96 | 97 | def test_backend_loading(backend): 98 | loaded_ops = set(backend.compiled_kernels.keys()) 99 | assert len(loaded_ops) > 0 100 | 101 | if os.path.exists("generated_kernels"): 102 | dirs = [ 103 | d 104 | for d in os.listdir("generated_kernels") 105 | if os.path.isdir(os.path.join("generated_kernels", d)) 106 | ] 107 | assert len(dirs) > 0 108 | 109 | 110 | def test_kernel_directories_exist(backend): 111 | assert os.path.exists("generated_kernels") 112 | 113 | expected_ops = ["relu.default", "add.Tensor", "mul.Tensor", "abs.default", "sum.default"] 114 | for expected_op in expected_ops: 115 | expected_dir = op_name_to_folder_name(expected_op) 116 | dir_path = os.path.join("generated_kernels", expected_dir) 117 | assert os.path.isdir(dir_path) 118 | 119 | py_files = [f for f in 
os.listdir(dir_path) if f.endswith(".py")] 120 | assert len(py_files) > 0 121 | -------------------------------------------------------------------------------- /BackendBench/opregistry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | 9 | import torch 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def _extract_spec_name_from_op(op_obj): 15 | try: 16 | # PyTorch operator objects have _name attribute that contains the full name 17 | if hasattr(op_obj, "_name"): 18 | full_name = op_obj._name 19 | # full_name is typically like "aten::add.Tensor" 20 | if "::" in full_name: 21 | # Remove the "aten::" prefix 22 | spec_name = full_name.split("::", 1)[1] 23 | return spec_name 24 | return None 25 | 26 | except Exception as e: 27 | logger.debug(f"Failed to extract spec name from operator {op_obj}: {e}") 28 | return None 29 | 30 | 31 | class OpRegistry: 32 | def __init__(self): 33 | self._registry = {} 34 | 35 | def get_operator(self, input_obj): 36 | if isinstance(input_obj, str): 37 | return self._get_operator_from_spec_name(input_obj) 38 | else: 39 | return self._get_operator_from_object(input_obj) 40 | 41 | def _get_operator_from_spec_name(self, spec_name): 42 | # Return cached operator if available 43 | if spec_name in self._registry: 44 | return self._registry[spec_name] 45 | 46 | # Parse spec name 47 | op_parts = spec_name.split(".") 48 | op_name = op_parts[0] 49 | overload = op_parts[1] if len(op_parts) > 1 else "default" 50 | 51 | try: 52 | # Resolve operator using PyTorch's API 53 | op = getattr(torch.ops.aten, op_name).__getattr__(overload) 54 | 55 | # Cache the resolved operator 56 | self._registry[spec_name] = op 57 | # logger.debug(f"Registered operator: {spec_name} -> {op}") 58 | return op 59 | 60 | except AttributeError as e: 61 | logger.warning(f"Failed to resolve operator {spec_name}: {e}") 62 | return None 63 | 64 | def _get_operator_from_object(self, op_obj): 65 | # Extract spec name from the operator object 66 | spec_name = _extract_spec_name_from_op(op_obj) 67 | 68 | # Check if we already have this operator registered 69 | if spec_name in self._registry: 70 | return self._registry[spec_name] 71 | 72 | # Register the provided operator object 73 | self._registry[spec_name] = op_obj 74 | # logger.debug(f"Registered operator from object: {spec_name} -> {op_obj}") 75 | return op_obj 76 | 77 | def register_operator(self, op_obj): 78 | return self._get_operator_from_object(op_obj) 79 | 80 | def get_all_registered_ops(self): 81 | return self._registry.copy() 82 | 83 | def clear(self): 84 | self._registry.clear() 85 | 86 | def __len__(self): 87 | return len(self._registry) 88 | 89 | def __contains__(self, spec_name): 90 | """Check if operator is registered.""" 91 | return spec_name in self._registry 92 | 93 | def __repr__(self): 94 | return f"OpRegistry({len(self._registry)} ops)" 95 | 96 | 97 | # Global operator registry instance 98 | _op_registry = OpRegistry() 99 | 100 | 101 | def get_operator(input_obj): 102 | return _op_registry.get_operator(input_obj) 103 | 104 | 105 | def register_operator(op_obj): 106 | return _op_registry.register_operator(op_obj) 107 | 108 | 109 | def get_registry(): 110 | return _op_registry 111 | 
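# Minimal usage sketch of the helpers above (illustrative only): a spec string
# such as "relu.default" and the matching torch.ops.aten OpOverload object both
# resolve to the same cached operator instance.
if __name__ == "__main__":
    op_from_name = get_operator("relu.default")
    op_from_obj = get_operator(torch.ops.aten.relu.default)
    assert op_from_name is op_from_obj
    print(get_registry())  # OpRegistry(1 ops)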
-------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 
67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /test/test_facto_suite.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import importlib.util 8 | 9 | import pytest 10 | import torch 11 | 12 | import BackendBench.backends as backends 13 | from BackendBench.eval import eval_one_op 14 | from BackendBench.suite import FactoTestSuite 15 | 16 | HAS_FACTO_DEPS = importlib.util.find_spec("facto") is not None 17 | 18 | pytestmark = pytest.mark.skipif(not HAS_FACTO_DEPS, reason="facto dependencies not available") 19 | 20 | 21 | class TestFactoSuite: 22 | def test_facto_suite_relu_default_correctness_not_empty(self): 23 | ops = ["relu.default"] 24 | num_runs = 10 25 | empty = False 26 | probability = 1.0 27 | 28 | suite = FactoTestSuite( 29 | name="facto_relu_test", 30 | device="cuda", 31 | dtype=torch.bfloat16, 32 | filter=ops, 33 | num_runs=num_runs, 34 | empty=empty, 35 | probability=probability, 36 | ) 37 | 38 | backend = backends.AtenBackend() 39 | 40 | # Track overall correctness and performance 41 | overall_correctness = [] 42 | 43 | # Iterate through the test suite (should contain relu operations) 44 | for test in suite: 45 | ctest_count = 0 46 | for ctest in test.correctness_tests: 47 | ctest_count += 1 48 | for arg in ctest.args: 49 | if isinstance(arg, torch.Tensor): 50 | # assert args not empty 51 | assert arg.numel() > 0, f"Tensor arg is empty for {test.op}" 52 | for key, value in ctest.kwargs.items(): 53 | if isinstance(value, torch.Tensor): 54 | # assert kwargs not empty 55 | assert value.numel() > 0, f"Tensor kwarg is empty for {test.op}" 56 | 57 | # Evaluate the operation 58 | correctness, _, correctness_results, _ = eval_one_op( 59 | test.op, 60 | backend[test.op], # AtenBackend returns the original op 61 | test.correctness_tests, 62 | test.performance_tests, 63 | ) 64 | 65 | assert len(correctness_results) == ctest_count, ( 66 | f"Number of correctness results for {test.op} is not {ctest_count}" 67 | ) 68 | is_correct = all(result.is_correct for result in correctness_results) 69 | overall_correctness.append(is_correct) 70 | 71 | # Individual test assertions 72 | assert correctness > 0, f"Operation {test.op} failed all correctness tests" 73 | 74 | # Calculate mean correctness 75 | mean_correctness = torch.tensor(overall_correctness).float().mean().item() 76 | 77 | # Main assertion: correctness should be > 0.8 78 | assert mean_correctness > 0.8, ( 79 | f"Mean correctness {mean_correctness:.2f} is not > 0.8 for relu.default" 80 | ) 81 | 82 | def test_facto_suite_num_run(self): 83 | ops = ["relu.default"] 84 | num_runs = 10 85 | empty = False 86 | probability = 1.0 87 | 88 | suite = 
FactoTestSuite( 89 | name="facto_relu_test", 90 | device="cuda", 91 | dtype=torch.bfloat16, 92 | filter=ops, 93 | num_runs=num_runs, 94 | empty=empty, 95 | probability=probability, 96 | ) 97 | 98 | for test in suite: 99 | assert len(list(test.correctness_tests)) == num_runs, ( 100 | f"Number of correctness tests for {test.op} is not {num_runs}" 101 | ) 102 | -------------------------------------------------------------------------------- /test/fixtures/llm_response/add_good.txt: -------------------------------------------------------------------------------- 1 | ```python 2 | import torch 3 | import triton 4 | import triton.language as tl 5 | 6 | @triton.jit 7 | def add_triton_kernel( 8 | x_ptr, 9 | y_ptr, 10 | output_ptr, 11 | n_elements, 12 | BLOCK_SIZE: tl.constexpr, 13 | ): 14 | pid = tl.program_id(axis=0) 15 | block_start = pid * BLOCK_SIZE 16 | offsets = block_start + tl.arange(0, BLOCK_SIZE) 17 | mask = offsets < n_elements 18 | 19 | x = tl.load(x_ptr + offsets, mask=mask) 20 | y = tl.load(y_ptr + offsets, mask=mask) 21 | output = x + y 22 | tl.store(output_ptr + offsets, output, mask=mask) 23 | 24 | def add__Tensor_kernel_impl(*args, **kwargs): 25 | # Handle both positional and keyword arguments 26 | if len(args) >= 2: 27 | input_tensor = args[0] 28 | other = args[1] 29 | alpha = kwargs.get('alpha', 1.0) 30 | out = kwargs.get('out', None) 31 | elif len(args) == 1: 32 | input_tensor = args[0] 33 | other = kwargs.get('other', kwargs.get('input', None)) 34 | if other is None: 35 | raise TypeError("add() missing required argument: 'other'") 36 | alpha = kwargs.get('alpha', 1.0) 37 | out = kwargs.get('out', None) 38 | else: 39 | input_tensor = kwargs.get('input', None) 40 | other = kwargs.get('other', None) 41 | if input_tensor is None or other is None: 42 | raise TypeError("add() missing required arguments") 43 | alpha = kwargs.get('alpha', 1.0) 44 | out = kwargs.get('out', None) 45 | 46 | # Store original devices 47 | input_device = input_tensor.device 48 | other_device = other.device if torch.is_tensor(other) else input_device 49 | 50 | # Check CUDA availability 51 | if not torch.cuda.is_available(): 52 | raise RuntimeError("CUDA is not available") 53 | 54 | # Move tensors to GPU if needed 55 | if input_tensor.device.type == 'cpu': 56 | input_tensor = input_tensor.cuda() 57 | if torch.is_tensor(other) and other.device.type == 'cpu': 58 | other = other.cuda() 59 | 60 | # Handle scalar other 61 | if not torch.is_tensor(other): 62 | other = torch.tensor(other, device=input_tensor.device, dtype=input_tensor.dtype) 63 | 64 | # Broadcast tensors to same shape 65 | broadcasted_shape = torch.broadcast_shapes(input_tensor.shape, other.shape) 66 | input_tensor = input_tensor.broadcast_to(broadcasted_shape) 67 | other = other.broadcast_to(broadcasted_shape) 68 | 69 | # Apply alpha scaling 70 | if alpha != 1.0: 71 | other = other * alpha 72 | 73 | # Ensure contiguous tensors 74 | input_tensor = input_tensor.contiguous() 75 | other = other.contiguous() 76 | 77 | # Create output tensor 78 | if out is not None: 79 | if out.device.type == 'cpu': 80 | out = out.cuda() 81 | output = out.contiguous() 82 | if output.shape != broadcasted_shape: 83 | raise RuntimeError(f"Output tensor shape {output.shape} doesn't match broadcast shape {broadcasted_shape}") 84 | else: 85 | output = torch.empty(broadcasted_shape, dtype=input_tensor.dtype, device=input_tensor.device) 86 | 87 | n_elements = input_tensor.numel() 88 | 89 | if n_elements == 0: 90 | # Handle empty tensors 91 | result = output 92 | else: 93 | # 
Launch kernel 94 | BLOCK_SIZE = 1024 95 | grid = (triton.cdiv(n_elements, BLOCK_SIZE),) 96 | 97 | add_triton_kernel[grid]( 98 | input_tensor, 99 | other, 100 | output, 101 | n_elements, 102 | BLOCK_SIZE=BLOCK_SIZE, 103 | ) 104 | 105 | result = output 106 | 107 | # Move result back to original device 108 | target_device = input_device 109 | if result.device != target_device: 110 | result = result.to(target_device) 111 | 112 | return result 113 | ``` 114 | -------------------------------------------------------------------------------- /test/fixtures/llm_response/add_missing_target_functions.txt: -------------------------------------------------------------------------------- 1 | # This file deliberately breaks the naming convention (e.g. {op_name}_triton_kernel & {op_name}_kernel_impl) 2 | 3 | ```python 4 | import torch 5 | import triton 6 | import triton.language as tl 7 | 8 | @triton.jit 9 | def XYZ_triton_kernel( 10 | x_ptr, 11 | y_ptr, 12 | output_ptr, 13 | n_elements, 14 | BLOCK_SIZE: tl.constexpr, 15 | ): 16 | pid = tl.program_id(axis=0) 17 | block_start = pid * BLOCK_SIZE 18 | offsets = block_start + tl.arange(0, BLOCK_SIZE) 19 | mask = offsets < n_elements 20 | 21 | x = tl.load(x_ptr + offsets, mask=mask) 22 | y = tl.load(y_ptr + offsets, mask=mask) 23 | output = x + y 24 | tl.store(output_ptr + offsets, output, mask=mask) 25 | 26 | def XYZ_kernel_impl(*args, **kwargs): 27 | # Handle both positional and keyword arguments 28 | if len(args) >= 2: 29 | input_tensor = args[0] 30 | other = args[1] 31 | alpha = kwargs.get('alpha', 1.0) 32 | out = kwargs.get('out', None) 33 | elif len(args) == 1: 34 | input_tensor = args[0] 35 | other = kwargs.get('other', kwargs.get('input', None)) 36 | if other is None: 37 | raise TypeError("add() missing required argument: 'other'") 38 | alpha = kwargs.get('alpha', 1.0) 39 | out = kwargs.get('out', None) 40 | else: 41 | input_tensor = kwargs.get('input', None) 42 | other = kwargs.get('other', None) 43 | if input_tensor is None or other is None: 44 | raise TypeError("add() missing required arguments") 45 | alpha = kwargs.get('alpha', 1.0) 46 | out = kwargs.get('out', None) 47 | 48 | # Store original devices 49 | input_device = input_tensor.device 50 | other_device = other.device if torch.is_tensor(other) else input_device 51 | 52 | # Check CUDA availability 53 | if not torch.cuda.is_available(): 54 | raise RuntimeError("CUDA is not available") 55 | 56 | # Move tensors to GPU if needed 57 | if input_tensor.device.type == 'cpu': 58 | input_tensor = input_tensor.cuda() 59 | if torch.is_tensor(other) and other.device.type == 'cpu': 60 | other = other.cuda() 61 | 62 | # Handle scalar other 63 | if not torch.is_tensor(other): 64 | other = torch.tensor(other, device=input_tensor.device, dtype=input_tensor.dtype) 65 | 66 | # Broadcast tensors to same shape 67 | broadcasted_shape = torch.broadcast_shapes(input_tensor.shape, other.shape) 68 | input_tensor = input_tensor.broadcast_to(broadcasted_shape) 69 | other = other.broadcast_to(broadcasted_shape) 70 | 71 | # Apply alpha scaling 72 | if alpha != 1.0: 73 | other = other * alpha 74 | 75 | # Ensure contiguous tensors 76 | input_tensor = input_tensor.contiguous() 77 | other = other.contiguous() 78 | 79 | # Create output tensor 80 | if out is not None: 81 | if out.device.type == 'cpu': 82 | out = out.cuda() 83 | output = out.contiguous() 84 | if output.shape != broadcasted_shape: 85 | raise RuntimeError(f"Output tensor shape {output.shape} doesn't match broadcast shape {broadcasted_shape}") 86 | else: 87 | output = 
torch.empty(broadcasted_shape, dtype=input_tensor.dtype, device=input_tensor.device) 88 | 89 | n_elements = input_tensor.numel() 90 | 91 | if n_elements == 0: 92 | # Handle empty tensors 93 | result = output 94 | else: 95 | # Launch kernel 96 | BLOCK_SIZE = 1024 97 | grid = (triton.cdiv(n_elements, BLOCK_SIZE),) 98 | 99 | XYZ_triton_kernel[grid]( 100 | input_tensor, 101 | other, 102 | output, 103 | n_elements, 104 | BLOCK_SIZE=BLOCK_SIZE, 105 | ) 106 | 107 | result = output 108 | 109 | # Move result back to original device 110 | target_device = input_device 111 | if result.device != target_device: 112 | result = result.to(target_device) 113 | 114 | return result 115 | ``` 116 | -------------------------------------------------------------------------------- /BackendBench/suite/torchbench.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Test suite that runs real-world PyTorch operation traces from serialized data files. 9 | 10 | Data Source: 11 | - Dataset: https://huggingface.co/datasets/GPUMODE/backendbench_tests 12 | - Configuration: Set in data_loaders.py: 13 | - HUGGINGFACE_REPO: HF repository name 14 | - TORCHBENCH_SUITE_FILE: Specific file name in the repo 15 | - TORCHBENCH_SUITE_HF_COMMIT: Git commit hash for reproducibility 16 | 17 | Updating the Test Set: 18 | 1. Choose a test file from https://huggingface.co/datasets/GPUMODE/backendbench_tests (it will likely be the same) 19 | 2. Update TORCHBENCH_SUITE_FILE in data_loaders.py with the file name (it will likely be the same) 20 | 3. Get the current commit hash: 21 | python -c "from huggingface_hub import HfApi; print(HfApi().dataset_info('GPUMODE/backendbench_tests', revision='main').sha)" 22 | 4. Update TORCHBENCH_SUITE_HF_COMMIT in data_loaders.py with the hash 23 | 24 | Creating New Test Sets: 25 | Use scripts/parquet_to_trace.py to generate and upload new datasets to HuggingFace. 
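Illustrative sketch of steps 2 and 4 (editorial example: the file name and commit hash below are placeholders, not real dataset revisions; only the constant names are taken from data_loaders.py as listed above):

    # BackendBench/data_loaders.py
    TORCHBENCH_SUITE_FILE = "backendbench_suite.parquet"   # hypothetical file name (step 2)
    TORCHBENCH_SUITE_HF_COMMIT = "0123abcd..."             # hash printed by step 3 (step 4)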
26 | """ 27 | 28 | import torch # noqa: F401 29 | 30 | from BackendBench.data_loaders import ( 31 | _args_size, 32 | load_ops_from_source, 33 | op_list_to_benchmark_dict, 34 | ) 35 | from BackendBench.op_categories import UNSUPPORTED_OPERATORS 36 | from BackendBench.utils import deserialize_args 37 | 38 | 39 | class TorchBenchTest: 40 | def __init__(self, *args, **kwargs): 41 | self.args = args 42 | self.kwargs = kwargs 43 | 44 | 45 | class TorchBenchOpTest: 46 | def __init__(self, op, inputs, topn): 47 | self.op = eval(f"torch.ops.{op}") 48 | self.inputs = inputs 49 | self.topn = topn 50 | 51 | def tests(self): 52 | inputs_and_sizes = [] 53 | for inp in self.inputs: 54 | args, kwargs = deserialize_args(inp) 55 | size = _args_size(args) + _args_size(list(kwargs.values())) 56 | inputs_and_sizes.append((size, inp)) 57 | ret = [x[1] for x in sorted(inputs_and_sizes, reverse=True)] 58 | return ret if self.topn is None else ret[: self.topn] 59 | 60 | @property 61 | def correctness_tests(self): 62 | for inp in self.tests(): 63 | args, kwargs = deserialize_args(inp) 64 | yield TorchBenchTest(*args, **kwargs) 65 | 66 | @property 67 | def performance_tests(self): 68 | for inp in self.tests(): 69 | args, kwargs = deserialize_args(inp) 70 | yield TorchBenchTest(*args, **kwargs) 71 | 72 | 73 | class TorchBenchTestSuite: 74 | def __init__( 75 | self, 76 | name, 77 | filename=None, 78 | filter=None, 79 | topn=None, 80 | check_overhead_dominated_ops=False, 81 | ): 82 | self.name = name 83 | self.topn = topn 84 | # Load operations using the shared data loader 85 | ops_list = load_ops_from_source( 86 | source=filename, 87 | format="auto", # Auto-detect based on file extension 88 | filter=filter, 89 | ) 90 | if check_overhead_dominated_ops: 91 | # Only include ops which are overhead dominated (this is useful as a performance canary) 92 | ops_list = [op for op in ops_list if op.get("is_overhead_dominated_op", False)] 93 | 94 | # Convert to dictionary format using utility function 95 | self.optests = op_list_to_benchmark_dict(ops_list) 96 | 97 | # Deduplicate the strings in self.optests 98 | for op in self.optests: 99 | self.optests[op] = list(set(self.optests[op])) 100 | 101 | def __iter__(self): 102 | for op, inputs in self.optests.items(): 103 | if any(s in op for s in UNSUPPORTED_OPERATORS): 104 | continue 105 | yield TorchBenchOpTest(op, inputs, self.topn) 106 | -------------------------------------------------------------------------------- /test/fixtures/llm_response/add_missing_python_code_block.txt: -------------------------------------------------------------------------------- 1 | # This file deliberately mess up with the response payload 2 | # The kernel implementation is good but the response is missing the python code block 3 | 4 | import torch 5 | import triton 6 | import triton.language as tl 7 | 8 | @triton.jit 9 | def add_triton_kernel( 10 | x_ptr, 11 | y_ptr, 12 | output_ptr, 13 | n_elements, 14 | BLOCK_SIZE: tl.constexpr, 15 | ): 16 | pid = tl.program_id(axis=0) 17 | block_start = pid * BLOCK_SIZE 18 | offsets = block_start + tl.arange(0, BLOCK_SIZE) 19 | mask = offsets < n_elements 20 | 21 | x = tl.load(x_ptr + offsets, mask=mask) 22 | y = tl.load(y_ptr + offsets, mask=mask) 23 | output = x + y 24 | tl.store(output_ptr + offsets, output, mask=mask) 25 | 26 | def add__Tensor_kernel_impl(*args, **kwargs): 27 | # Handle both positional and keyword arguments 28 | if len(args) >= 2: 29 | input_tensor = args[0] 30 | other = args[1] 31 | alpha = kwargs.get('alpha', 1.0) 32 | out = 
kwargs.get('out', None) 33 | elif len(args) == 1: 34 | input_tensor = args[0] 35 | other = kwargs.get('other', kwargs.get('input', None)) 36 | if other is None: 37 | raise TypeError("add() missing required argument: 'other'") 38 | alpha = kwargs.get('alpha', 1.0) 39 | out = kwargs.get('out', None) 40 | else: 41 | input_tensor = kwargs.get('input', None) 42 | other = kwargs.get('other', None) 43 | if input_tensor is None or other is None: 44 | raise TypeError("add() missing required arguments") 45 | alpha = kwargs.get('alpha', 1.0) 46 | out = kwargs.get('out', None) 47 | 48 | # Store original devices 49 | input_device = input_tensor.device 50 | other_device = other.device if torch.is_tensor(other) else input_device 51 | 52 | # Check CUDA availability 53 | if not torch.cuda.is_available(): 54 | raise RuntimeError("CUDA is not available") 55 | 56 | # Move tensors to GPU if needed 57 | if input_tensor.device.type == 'cpu': 58 | input_tensor = input_tensor.cuda() 59 | if torch.is_tensor(other) and other.device.type == 'cpu': 60 | other = other.cuda() 61 | 62 | # Handle scalar other 63 | if not torch.is_tensor(other): 64 | other = torch.tensor(other, device=input_tensor.device, dtype=input_tensor.dtype) 65 | 66 | # Broadcast tensors to same shape 67 | broadcasted_shape = torch.broadcast_shapes(input_tensor.shape, other.shape) 68 | input_tensor = input_tensor.broadcast_to(broadcasted_shape) 69 | other = other.broadcast_to(broadcasted_shape) 70 | 71 | # Apply alpha scaling 72 | if alpha != 1.0: 73 | other = other * alpha 74 | 75 | # Ensure contiguous tensors 76 | input_tensor = input_tensor.contiguous() 77 | other = other.contiguous() 78 | 79 | # Create output tensor 80 | if out is not None: 81 | if out.device.type == 'cpu': 82 | out = out.cuda() 83 | output = out.contiguous() 84 | if output.shape != broadcasted_shape: 85 | raise RuntimeError(f"Output tensor shape {output.shape} doesn't match broadcast shape {broadcasted_shape}") 86 | else: 87 | output = torch.empty(broadcasted_shape, dtype=input_tensor.dtype, device=input_tensor.device) 88 | 89 | n_elements = input_tensor.numel() 90 | 91 | if n_elements == 0: 92 | # Handle empty tensors 93 | result = output 94 | else: 95 | # Launch kernel 96 | BLOCK_SIZE = 1024 97 | grid = (triton.cdiv(n_elements, BLOCK_SIZE),) 98 | 99 | add_triton_kernel[grid]( 100 | input_tensor, 101 | other, 102 | output, 103 | n_elements, 104 | BLOCK_SIZE=BLOCK_SIZE, 105 | ) 106 | 107 | result = output 108 | 109 | # Move result back to original device 110 | target_device = input_device 111 | if result.device != target_device: 112 | result = result.to(target_device) 113 | 114 | return result 115 | -------------------------------------------------------------------------------- /BackendBench/suite/opinfo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 
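# Editorial note (summary of the suite defined below): build_op_tests runs each
# op_db sample input under a TorchDispatchMode tracer and keeps only samples
# that dispatch to exactly one aten overload, so every OpInfoOpTest maps to a
# single operator such as torch.ops.aten.add.Tensor.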
6 | 7 | import logging 8 | from collections import defaultdict 9 | 10 | from torch.testing._internal.common_methods_invocations import op_db 11 | from torch.utils._python_dispatch import TorchDispatchMode 12 | 13 | from BackendBench.eval import allclose 14 | from BackendBench.op_categories import ( 15 | RANDOM_OPS, 16 | TENSOR_CREATION_AND_MANIPULATION_OPS, 17 | UNSUPPORTED_OPERATORS, 18 | ) 19 | from BackendBench.utils import extract_operator_name 20 | 21 | from .base import OpTest, TestSuite 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class OpInfoTest: 27 | def __init__(self, *args, **kwargs): 28 | self.args = args 29 | self.kwargs = kwargs 30 | 31 | 32 | class OpInfoOpTest(OpTest): 33 | def __init__(self, op, correctness_tests, indices): 34 | self.op = op 35 | self._correctness_tests = correctness_tests 36 | self.indices = set(indices) 37 | self.performance_tests = [] 38 | 39 | @property 40 | def correctness_tests(self): 41 | for idx, test in enumerate(self._correctness_tests): 42 | if idx in self.indices: 43 | # print(f"{idx} {test.input=} {test.args=} {test.kwargs=}") 44 | yield OpInfoTest(test.input, *test.args, **test.kwargs) 45 | 46 | 47 | class OpTracerMode(TorchDispatchMode): 48 | def __init__(self): 49 | self.ops = [] 50 | self.args = [] 51 | self.kwargs = [] 52 | 53 | def __torch_dispatch__(self, fn, types, args=(), kwargs={}): 54 | self.ops.append(fn) 55 | self.args.append(args) 56 | self.kwargs.append(kwargs) 57 | return fn(*args, **kwargs) 58 | 59 | 60 | def get_op_base_name(op_name): 61 | if "." in op_name: 62 | return op_name.split(".")[0] 63 | return op_name 64 | 65 | 66 | def build_op_tests(device, dtype, filter=None): 67 | op_info_op_tests = [] 68 | 69 | for op in op_db: 70 | if "." in op.name and "nn.functional" not in op.name: 71 | continue 72 | if dtype not in op.supported_dtypes(device): 73 | continue 74 | if op.name in ["nonzero_static"]: 75 | continue 76 | 77 | op_indices = defaultdict(list) 78 | try: 79 | sample_inputs = list(op.sample_inputs(device, dtype)) 80 | except Exception: 81 | continue 82 | 83 | for idx, test in enumerate(sample_inputs): 84 | # print(f"{idx=} {test.input=} {test.args=} {test.kwargs=}") 85 | try: 86 | with OpTracerMode() as tracer: 87 | ref = op.op(test.input, *test.args, **test.kwargs) 88 | if len(tracer.ops) == 1: 89 | res = tracer.ops[0](test.input, *test.args, **test.kwargs) 90 | if allclose(ref, res): 91 | if filter and extract_operator_name(str(tracer.ops[0])) not in filter: 92 | continue 93 | if ( 94 | extract_operator_name(str(tracer.ops[0])) 95 | in UNSUPPORTED_OPERATORS 96 | + RANDOM_OPS 97 | + TENSOR_CREATION_AND_MANIPULATION_OPS 98 | ): 99 | continue 100 | op_indices[tracer.ops[0]].append(idx) 101 | else: 102 | logger.debug(f"opinfo {op.name} has {len(tracer.ops)} ops") 103 | except Exception: 104 | continue 105 | 106 | for overload, indices in op_indices.items(): 107 | if len(indices) > 0: 108 | op_info_op_tests.append(OpInfoOpTest(overload, sample_inputs, indices)) 109 | return op_info_op_tests 110 | 111 | 112 | class OpInfoTestSuite(TestSuite): 113 | def __init__(self, name, device, dtype, filter=None): 114 | super().__init__(name, build_op_tests(device, dtype, filter)) 115 | -------------------------------------------------------------------------------- /test/test_backend_evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 
5 | # 6 | # This source code is licensed under the BSD 3-Clause license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import subprocess 10 | import sys 11 | import unittest 12 | from pathlib import Path 13 | 14 | import torch 15 | 16 | sys.path.insert(0, str(Path(__file__).parent.parent)) 17 | 18 | from BackendBench.backends import DirectoryBackend 19 | from BackendBench.eval import eval_correctness 20 | from BackendBench.suite import Test 21 | 22 | 23 | class TestBackendEvaluation(unittest.TestCase): 24 | """Comprehensive test for backend evaluation system.""" 25 | 26 | @classmethod 27 | def setUpClass(cls): 28 | """Generate required directory structure and operators.""" 29 | from pathlib import Path 30 | 31 | base_dir = Path("generated_kernels") 32 | test_ops = [ 33 | "bitwise_and__Tensor", 34 | "fmod__Tensor", 35 | "relu__default", 36 | "add__Tensor", 37 | "mul__Tensor", 38 | "div__Tensor", 39 | ] 40 | 41 | for op_name in test_ops: 42 | op_dir = base_dir / op_name 43 | op_dir.mkdir(parents=True, exist_ok=True) 44 | 45 | subprocess.run( 46 | [ 47 | sys.executable, 48 | "-m", 49 | "BackendBench.scripts.create_watermarked_operators", 50 | "--overwrite", 51 | "--unique-watermarks", 52 | ], 53 | check=True, 54 | ) 55 | 56 | def test_1_directory_backend_loads_operators(self): 57 | """Verify DirectoryBackend loads operators correctly.""" 58 | backend = DirectoryBackend("generated_kernels") 59 | operator_count = len(backend.compiled_kernels) 60 | 61 | self.assertGreater(operator_count, 0, "Should load operators from generated_kernels") 62 | self.assertIsInstance(backend.compiled_kernels, dict) 63 | 64 | def test_2_watermarked_implementations_fail_correctness(self): 65 | """Verify watermarked operators fail eval_correctness (proving monkey patching).""" 66 | backend = DirectoryBackend("generated_kernels") 67 | 68 | failed_count = 0 69 | total_tested = 0 70 | 71 | test_ops = [ 72 | ( 73 | torch.ops.aten.bitwise_and.Tensor, 74 | lambda: torch.tensor([1, 2, 3]), 75 | lambda: torch.tensor([2, 3, 4]), 76 | ), 77 | ( 78 | torch.ops.aten.fmod.Tensor, 79 | lambda: torch.tensor([5.0, 7.0]), 80 | lambda: torch.tensor([2.0, 3.0]), 81 | ), 82 | ] 83 | 84 | for op, *arg_generators in test_ops: 85 | if op in backend: 86 | impl = backend[op] 87 | test = Test(*arg_generators) 88 | correctness, correctness_results = eval_correctness(op, impl, [test]) 89 | assert len(correctness_results) == 1 90 | total_tested += 1 91 | if correctness == 0.0: 92 | failed_count += 1 93 | 94 | self.assertGreater(total_tested, 0, "Should test at least one operator") 95 | self.assertGreater(failed_count, 0, "At least some watermarked ops should fail") 96 | 97 | def test_3_main_script_evaluation(self): 98 | """Verify main.py script works with DirectoryBackend.""" 99 | cmd = [ 100 | sys.executable, 101 | "-m", 102 | "BackendBench.scripts.main", 103 | "--backend", 104 | "directory", 105 | "--suite", 106 | "smoke", 107 | "--log-level", 108 | "ERROR", 109 | ] 110 | 111 | result = subprocess.run(cmd, capture_output=True, text=True, timeout=120) 112 | 113 | self.assertEqual(result.returncode, 0, "Main script should complete successfully") 114 | self.assertIsInstance(result.stdout, str) 115 | self.assertIsInstance(result.stderr, str) 116 | 117 | 118 | if __name__ == "__main__": 119 | unittest.main() 120 | -------------------------------------------------------------------------------- /test/test_backends.py: -------------------------------------------------------------------------------- 1 | # Copyright 
(c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import importlib.util 8 | 9 | import pytest 10 | import torch 11 | 12 | from BackendBench.backends import ( 13 | AtenBackend, 14 | FlagGemsBackend, 15 | KernelAgentBackend, 16 | ) 17 | 18 | HAS_KERNEL_AGENT = KernelAgentBackend is not None 19 | 20 | HAS_FLAG_GEMS = importlib.util.find_spec("flag_gems") is not None 21 | 22 | 23 | class TestAtenBackend: 24 | def test_aten_backend_initialization(self): 25 | backend = AtenBackend() 26 | assert backend.name == "aten" 27 | 28 | def test_aten_backend_contains_op(self): 29 | backend = AtenBackend() 30 | 31 | assert torch.ops.aten.relu.default in backend 32 | assert torch.ops.aten.add.Tensor in backend 33 | assert torch.ops.aten.mul.Tensor in backend 34 | 35 | def test_aten_backend_getitem(self): 36 | backend = AtenBackend() 37 | 38 | relu_op = torch.ops.aten.relu.default 39 | assert backend[relu_op] == relu_op 40 | 41 | add_op = torch.ops.aten.add.Tensor 42 | assert backend[add_op] == add_op 43 | 44 | 45 | class TestFlagGemsBackend: 46 | @pytest.mark.skipif(not HAS_FLAG_GEMS, reason="flag_gems not available") 47 | def test_flag_gems_backend_initialization(self): 48 | backend = FlagGemsBackend() 49 | assert backend.name == "flaggems" 50 | assert isinstance(backend.ops, dict) 51 | 52 | @pytest.mark.skipif(not HAS_FLAG_GEMS, reason="flag_gems not available") 53 | def test_flag_gems_backend_contains_op(self): 54 | backend = FlagGemsBackend() 55 | 56 | # Test with actual ops that flag_gems supports 57 | if hasattr(torch.ops.aten, "abs"): 58 | if torch.ops.aten.abs.default in backend: 59 | assert torch.ops.aten.abs.default in backend 60 | 61 | # Test with an op that might not be in flag_gems 62 | unsupported_op = ( 63 | torch.ops.aten.special_log_ndtr.default 64 | if hasattr(torch.ops.aten, "special_log_ndtr") 65 | else None 66 | ) 67 | if unsupported_op: 68 | assert unsupported_op not in backend 69 | 70 | @pytest.mark.skipif(not HAS_FLAG_GEMS, reason="flag_gems not available") 71 | def test_flag_gems_backend_getitem(self): 72 | backend = FlagGemsBackend() 73 | 74 | # Test with an op that should exist 75 | if hasattr(torch.ops.aten, "abs") and torch.ops.aten.abs.default in backend: 76 | impl = backend[torch.ops.aten.abs.default] 77 | assert impl is not None 78 | 79 | # Test with an op that doesn't exist in flag_gems 80 | unsupported_op = ( 81 | torch.ops.aten.special_log_ndtr.default 82 | if hasattr(torch.ops.aten, "special_log_ndtr") 83 | else None 84 | ) 85 | if unsupported_op and unsupported_op not in backend: 86 | with pytest.raises(KeyError): 87 | _ = backend[unsupported_op] 88 | 89 | 90 | class TestKernelAgentBackend: 91 | @pytest.mark.skipif(not HAS_KERNEL_AGENT, reason="KernelAgent not available") 92 | def test_kernel_agent_backend_initialization(self): 93 | backend = KernelAgentBackend() 94 | assert backend.name == "kernel_agent" 95 | assert "kernel_agent_run_" in backend.kernels_dir 96 | assert backend.num_workers == 4 # default value 97 | assert backend.max_rounds == 10 # default value 98 | 99 | @pytest.mark.skipif(not HAS_KERNEL_AGENT, reason="KernelAgent not available") 100 | def test_kernel_agent_backend_set_config(self): 101 | backend = KernelAgentBackend() 102 | 103 | backend.set_config(num_workers=8, max_rounds=20) 104 | 105 | assert backend.num_workers == 8 106 | assert backend.max_rounds == 20 107 | 108 | 109 | 
class TestBackendIntegration: 110 | def test_backend_polymorphism(self): 111 | backends = [] 112 | backends.append(AtenBackend()) 113 | 114 | if HAS_FLAG_GEMS: 115 | backends.append(FlagGemsBackend()) 116 | 117 | if HAS_KERNEL_AGENT: 118 | backends.append(KernelAgentBackend()) 119 | 120 | for backend in backends: 121 | assert hasattr(backend, "name") 122 | assert hasattr(backend, "__contains__") 123 | assert hasattr(backend, "__getitem__") 124 | assert isinstance(backend.name, str) 125 | -------------------------------------------------------------------------------- /BackendBench/backends/directory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import importlib.util 8 | import logging 9 | import os 10 | from typing import Callable, Dict 11 | 12 | from ..utils import folder_name_to_op_name, get_pytorch_op 13 | from .base import Backend 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class DirectoryBackend(Backend): 19 | def __init__(self, ops_dir="generated_kernels"): 20 | super().__init__("directory") 21 | self.ops_dir = ops_dir 22 | self.compiled_kernels: Dict[str, Callable] = {} 23 | self._load_kernels() 24 | 25 | def _load_kernels(self): 26 | """ 27 | Discovers and loads kernel implementations from the operator directory structure. 28 | 29 | This method scans the ops_dir for subdirectories named after PyTorch operator 30 | overloads (e.g., "add__Tensor" for add.Tensor and "add__Scalar" for add.Scalar). 31 | Each subdirectory should contain Python files with kernel implementations 32 | following the naming pattern: {op_name}_implementation*.py 33 | 34 | This method uses the op overload format (e.g., "add__Tensor" for "add.Tensor") and 35 | registers the kernel for ONLY that specific overload. 36 | """ 37 | if not os.path.exists(self.ops_dir): 38 | logger.warning(f"ops directory {self.ops_dir} does not exist") 39 | return 40 | 41 | loaded_count = 0 42 | for folder_name in os.listdir(self.ops_dir): 43 | op_dir = os.path.join(self.ops_dir, folder_name) 44 | if not os.path.isdir(op_dir): 45 | continue 46 | 47 | impl_files = [ 48 | f 49 | for f in os.listdir(op_dir) 50 | if f.endswith(".py") and f.startswith(f"{folder_name}_implementation") 51 | ] 52 | if not impl_files: 53 | logger.debug(f"No implementation files found in {op_dir}") 54 | continue 55 | 56 | impl_file = sorted(impl_files)[0] 57 | impl_path = os.path.join(op_dir, impl_file) 58 | 59 | try: 60 | op_name = folder_name_to_op_name(folder_name) 61 | kernel_func = self._load_kernel_from_file(impl_path, folder_name) 62 | 63 | pytorch_op = get_pytorch_op(op_name) 64 | if pytorch_op: 65 | self.compiled_kernels[pytorch_op] = kernel_func 66 | logger.info(f"Loaded {op_name} from {impl_file} -> {op_name}") 67 | loaded_count += 1 68 | 69 | except Exception as e: 70 | logger.error(f"Error loading {op_name} from {impl_file}: {e}") 71 | 72 | logger.info(f"DirectoryBackend loaded {loaded_count} kernels from {self.ops_dir}/") 73 | 74 | def _load_kernel_from_file(self, file_path: str, folder_name: str) -> Callable: 75 | """ 76 | Dynamically load a kernel implementation function from a Python file. 77 | 78 | Each operator directory should contain implementation files that export a function 79 | named {op_name}_kernel_impl. 
This function becomes the kernel implementation 80 | that gets registered for all variants of the operator. 81 | 82 | Args: 83 | file_path: Path to the Python implementation file 84 | op_name: Base name of the operator (e.g., "add", "mul", "conv2d") 85 | 86 | Returns: 87 | Callable kernel implementation function 88 | 89 | Raises: 90 | ValueError: If the expected kernel function is not found in the file 91 | """ 92 | spec = importlib.util.spec_from_file_location(f"op_{folder_name}", file_path) 93 | module = importlib.util.module_from_spec(spec) 94 | spec.loader.exec_module(module) 95 | 96 | kernel_func_name = f"{folder_name}_kernel_impl" 97 | if hasattr(module, kernel_func_name): 98 | return getattr(module, kernel_func_name) 99 | else: 100 | raise ValueError(f"No function named {kernel_func_name} found in {file_path}") 101 | 102 | def __getitem__(self, key): 103 | if key in self.compiled_kernels: 104 | return self.compiled_kernels[key] 105 | raise KeyError( 106 | f"Operator {key} not implemented in DirectoryBackend - add implementation to {self.ops_dir}/" 107 | ) 108 | 109 | def __contains__(self, key): 110 | return key in self.compiled_kernels 111 | -------------------------------------------------------------------------------- /BackendBench/scripts/create_simple_test_ops.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD 3-Clause license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | """ 10 | Create simple kernel implementations for 5 common operations. 11 | Each just calls the original PyTorch function. 12 | """ 13 | 14 | import logging 15 | import os 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def create_relu(): 21 | os.makedirs("generated_kernels/relu__default", exist_ok=True) 22 | with open("generated_kernels/relu__default/relu__default_implementation_v1.py", "w") as f: 23 | f.write('''import torch 24 | 25 | def relu__default_kernel_impl(input): 26 | """Simple ReLU implementation.""" 27 | return torch.ops.aten.relu.default(input) 28 | 29 | if __name__ == "__main__": 30 | x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]) 31 | result = relu__default_kernel_impl(x) 32 | expected = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0]) 33 | print(f"ReLU test passed: {torch.allclose(result, expected)}") 34 | ''') 35 | logger.info("Created relu implementation") 36 | 37 | 38 | def create_add(): 39 | os.makedirs("generated_kernels/add__Tensor", exist_ok=True) 40 | with open("generated_kernels/add__Tensor/add__Tensor_implementation_v1.py", "w") as f: 41 | f.write('''import torch 42 | 43 | def add__Tensor_kernel_impl(input, other): 44 | """Simple add implementation.""" 45 | return torch.ops.aten.add.Tensor(input, other) 46 | 47 | if __name__ == "__main__": 48 | a = torch.tensor([1.0, 2.0, 3.0]) 49 | b = torch.tensor([4.0, 5.0, 6.0]) 50 | result = add__Tensor_kernel_impl(a, b) 51 | expected = torch.tensor([5.0, 7.0, 9.0]) 52 | print(f"Add test passed: {torch.allclose(result, expected)}") 53 | ''') 54 | logger.info("Created add implementation") 55 | 56 | 57 | def create_mul(): 58 | os.makedirs("generated_kernels/mul__Tensor", exist_ok=True) 59 | with open("generated_kernels/mul__Tensor/mul__Tensor_implementation_v1.py", "w") as f: 60 | f.write('''import torch 61 | 62 | def mul__Tensor_kernel_impl(input, other): 63 | """Simple mul implementation.""" 64 | return 
torch.ops.aten.mul.Tensor(input, other) 65 | 66 | if __name__ == "__main__": 67 |     a = torch.tensor([1.0, 2.0, 3.0]) 68 |     b = torch.tensor([4.0, 5.0, 6.0]) 69 |     result = mul__Tensor_kernel_impl(a, b) 70 |     expected = torch.tensor([4.0, 10.0, 18.0]) 71 |     print(f"Mul test passed: {torch.allclose(result, expected)}") 72 | ''') 73 |     logger.info("Created mul implementation") 74 | 75 | 76 | def create_abs(): 77 |     os.makedirs("generated_kernels/abs__default", exist_ok=True) 78 |     with open("generated_kernels/abs__default/abs__default_implementation_v1.py", "w") as f: 79 |         f.write('''import torch 80 | 81 | def abs__default_kernel_impl(input): 82 |     """Simple abs implementation.""" 83 |     return torch.ops.aten.abs.default(input) 84 | 85 | if __name__ == "__main__": 86 |     x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]) 87 |     result = abs__default_kernel_impl(x) 88 |     expected = torch.tensor([2.0, 1.0, 0.0, 1.0, 2.0]) 89 |     print(f"Abs test passed: {torch.allclose(result, expected)}") 90 | ''') 91 |     logger.info("Created abs implementation") 92 | 93 | 94 | def create_sum(): 95 |     os.makedirs("generated_kernels/sum__default", exist_ok=True) 96 |     with open("generated_kernels/sum__default/sum__default_implementation_v1.py", "w") as f: 97 |         f.write('''import torch 98 | 99 | def sum__default_kernel_impl(input, *args, **kwargs): 100 |     """Simple sum implementation.""" 101 |     return torch.ops.aten.sum.default(input, *args, **kwargs) 102 | 103 | if __name__ == "__main__": 104 |     x = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) 105 |     result = sum__default_kernel_impl(x) 106 |     expected = torch.tensor(10.0) 107 |     print(f"Sum test passed: {torch.allclose(result, expected)}") 108 | ''') 109 |     logger.info("Created sum implementation") 110 | 111 | 112 | def main(): 113 |     """Create 5 simple test operations.""" 114 |     logging.basicConfig(level=logging.INFO, format="%(message)s") 115 |     logger.info("Creating simple test implementations...") 116 | 117 |     create_relu() 118 |     create_add() 119 |     create_mul() 120 |     create_abs() 121 |     create_sum() 122 | 123 |     logger.info("Created 5 simple kernel implementations in generated_kernels/") 124 |     logger.info("Test them individually:") 125 |     logger.info("  python generated_kernels/relu__default/relu__default_implementation_v1.py") 126 |     logger.info("  python generated_kernels/add__Tensor/add__Tensor_implementation_v1.py") 127 |     logger.info("  etc.") 128 |     logger.info("Or test all with the backend:") 129 |     logger.info("  python test/test_simple_directory_backend.py") 130 | 131 | 132 | if __name__ == "__main__": 133 |     main() 134 | -------------------------------------------------------------------------------- /test/test_llm_backend.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree.
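# Editorial note: the tests below exercise LLMBackend end to end using
# MockLLMKernelGenerator, which replays canned responses from
# test/fixtures/llm_response/ instead of calling a real LLM, and then asserts
# on the per-op summary files written by generate_kernels.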
6 | 7 | import os 8 | 9 | import torch 10 | 11 | from BackendBench.backends import LLMBackend 12 | from BackendBench.llm_client import KernelTemplateManager, LLMKernelGenerator 13 | from BackendBench.suite import OpInfoTestSuite 14 | 15 | 16 | class MockLLMKernelGenerator(LLMKernelGenerator): 17 | def __init__( 18 | self, 19 | mock_response_files: list[str], 20 | ): 21 | self.model = "mock_model" 22 | self.template_manager = KernelTemplateManager() 23 | self.mock_response_files = mock_response_files 24 | self.attempts = 0 25 | 26 | def call_llm(self, prompt: str) -> str: 27 | file = ( 28 | self.mock_response_files[self.attempts] 29 | if self.attempts < len(self.mock_response_files) 30 | else self.mock_response_files[-1] 31 | ) 32 | self.attempts += 1 33 | 34 | file_path = os.path.join(os.path.dirname(__file__), "fixtures", "llm_response", file) 35 | with open(file_path, "r") as f: 36 | return f.read() 37 | 38 | 39 | class TestLLMBackend: 40 | suite = OpInfoTestSuite( 41 | "opinfo_cpu_bfloat16", 42 | "cpu", 43 | torch.bfloat16, 44 | filter=["add.Tensor"], 45 | ) 46 | 47 | def test_generate_kernels_good(self): 48 | mock_response_files = ["add_good.txt"] 49 | attempts = 5 50 | 51 | backend = LLMBackend( 52 | model="mock_model", 53 | llm_client=MockLLMKernelGenerator(mock_response_files), 54 | ) 55 | backend.generate_kernels(self.suite, attempts) 56 | 57 | summary_file = os.path.join(backend.kernels_dir, "add__Tensor", "add__Tensor_summary.txt") 58 | assert os.path.exists(summary_file) 59 | 60 | with open(summary_file, "r") as f: 61 | summary = f.read() 62 | assert "Final Status: ✓ Success" in summary 63 | assert f"Best Kernel Attempt: 1/{attempts}" in summary 64 | 65 | def test_retry(self): 66 | mock_response_files = ["add_missing_target_functions.txt", "add_good.txt"] 67 | attempts = 5 68 | 69 | backend = LLMBackend( 70 | model="mock_model", 71 | llm_client=MockLLMKernelGenerator(mock_response_files), 72 | ) 73 | backend.generate_kernels(self.suite, attempts) 74 | 75 | summary_file = os.path.join(backend.kernels_dir, "add__Tensor", "add__Tensor_summary.txt") 76 | assert os.path.exists(summary_file) 77 | 78 | with open(summary_file, "r") as f: 79 | summary = f.read() 80 | assert "Final Status: ✓ Success" in summary 81 | assert f"Best Kernel Attempt: 2/{attempts}" in summary 82 | 83 | def test_missing_target_functions(self): 84 | mock_response_files = ["add_missing_target_functions.txt"] 85 | attempts = 1 86 | 87 | backend = LLMBackend( 88 | model="mock_model", 89 | llm_client=MockLLMKernelGenerator(mock_response_files), 90 | ) 91 | backend.generate_kernels(self.suite, attempts) 92 | 93 | summary_file = os.path.join(backend.kernels_dir, "add__Tensor", "add__Tensor_summary.txt") 94 | assert os.path.exists(summary_file) 95 | 96 | with open(summary_file, "r") as f: 97 | summary = f.read() 98 | assert "Final Status: ✗ Failure" in summary 99 | assert f"Best Kernel Attempt: 1/{attempts}" in summary 100 | 101 | def test_missing_python_code_block(self): 102 | mock_response_files = ["add_missing_python_code_block.txt"] 103 | attempts = 1 104 | 105 | backend = LLMBackend( 106 | model="mock_model", 107 | llm_client=MockLLMKernelGenerator(mock_response_files), 108 | ) 109 | backend.generate_kernels(self.suite, attempts) 110 | 111 | summary_file = os.path.join(backend.kernels_dir, "add__Tensor", "add__Tensor_summary.txt") 112 | assert os.path.exists(summary_file) 113 | 114 | with open(summary_file, "r") as f: 115 | summary = f.read() 116 | assert "Final Status: ✗ Failure" in summary 117 | assert f"Best 
Kernel Attempt: 1/{attempts}" in summary 118 | 119 | def test_chooses_best_kernel(self): 120 | mock_response_files = [ 121 | "add_missing_target_functions.txt", 122 | "add_good.txt", 123 | "add_missing_python_code_block.txt", 124 | ] 125 | attempts = 3 126 | 127 | backend = LLMBackend( 128 | model="mock_model", 129 | llm_client=MockLLMKernelGenerator(mock_response_files), 130 | ) 131 | backend.generate_kernels(self.suite, attempts) 132 | 133 | summary_file = os.path.join(backend.kernels_dir, "add__Tensor", "add__Tensor_summary.txt") 134 | assert os.path.exists(summary_file) 135 | 136 | with open(summary_file, "r") as f: 137 | summary = f.read() 138 | assert "Final Status: ✓ Success" in summary 139 | # we should choose the best kernel which is the second one in this case as it's the only "correct" one 140 | assert f"Best Kernel Attempt: 2/{attempts}" in summary 141 | -------------------------------------------------------------------------------- /BackendBench/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | BackendBench: A PyTorch backend evaluation framework. 9 | """ 10 | 11 | import logging 12 | from pathlib import Path 13 | from typing import Optional, Union 14 | 15 | import torch 16 | 17 | # Import the existing DirectoryBackend implementation 18 | from BackendBench.backends.directory import DirectoryBackend 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | __version__ = "0.1.0" 23 | __all__ = ["enable", "disable", "BackendBench"] 24 | 25 | # Global state 26 | _lib = None 27 | 28 | 29 | class _BackendBenchContext: 30 | """Context manager for BackendBench that enables on entry and disables on exit.""" 31 | 32 | def __init__(self, kernel_dir=None, namespace="aten", dispatch_key="CUDA"): 33 | self.kernel_dir = kernel_dir 34 | self.namespace = namespace 35 | self.dispatch_key = dispatch_key 36 | self.lib = torch.library.Library(namespace, "IMPL", dispatch_key) 37 | 38 | def __enter__(self): 39 | enable(self.kernel_dir, self.namespace, self.dispatch_key, self.lib) 40 | return self 41 | 42 | def __exit__(self, exc_type, exc_val, exc_tb): 43 | self.lib = None 44 | 45 | 46 | class BackendBench: 47 | """BackendBench main class with context manager support.""" 48 | 49 | @classmethod 50 | def enable(cls, kernel_dir=None, namespace="aten", dispatch_key="CUDA"): 51 | """ 52 | Return a context manager that enables BackendBench on entry and disables on exit. 53 | 54 | Args: 55 | kernel_dir: Path to the directory containing custom kernels 56 | namespace: PyTorch namespace to patch (default: "aten") 57 | dispatch_key: Dispatch key for the kernels (default: "CUDA") 58 | 59 | Returns: 60 | Context manager that can be used with 'with' statement 61 | 62 | Example: 63 | with BackendBench.enable(kernel_dir="generated_kernels/"): 64 | model.forward() # uses LLM kernels 65 | # On exit, uses aten kernels 66 | """ 67 | return _BackendBenchContext(kernel_dir, namespace, dispatch_key) 68 | 69 | 70 | def _monkey_patch_operators(lib, op_custom_impl, namespace="aten", dispatch_key="CUDA"): 71 | """ 72 | Replace PyTorch operators with custom implementations using torch.library. 
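    Registration sketch (editorial illustration; my_add_kernel is a placeholder
    name, and the real loop below derives the overload name from each op's
    schema):

        lib = torch.library.Library("aten", "IMPL", "CUDA")
        lib.impl("add.Tensor", my_add_kernel, "CUDA")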
73 |     """ 74 | 75 |     assert dispatch_key in ["CPU", "CUDA"], "Only CPU and CUDA dispatch keys are supported" 76 | 77 |     patched_count = 0 78 |     for op, custom_impl in op_custom_impl.items(): 79 |         try: 80 |             # Extract operator name and overload from the OpOverload 81 |             op_name = op._schema.name 82 |             overload_name = op._schema.overload_name 83 | 84 |             # Create the full name for torch.library 85 |             if overload_name: 86 |                 full_name = f"{op_name}.{overload_name}" 87 |             else: 88 |                 full_name = op_name 89 | 90 |             # Register the custom implementation 91 |             lib.impl(full_name, custom_impl, dispatch_key) 92 |             patched_count += 1 93 | 94 |         except Exception as e: 95 |             # Some operators might not be patchable 96 |             logger.warning(f"Could not register {op}: {e}") 97 | 98 |     if patched_count > 0: 99 |         print(f"Successfully registered {patched_count} custom operators") 100 |     else: 101 |         print("No custom operators registered") 102 | 103 | 104 | def enable( 105 |     kernel_dir: Optional[Union[str, Path]] = None, 106 |     namespace: str = "aten", 107 |     dispatch_key: str = "CUDA", 108 |     lib=None, 109 | ) -> None: 110 |     """ 111 |     Enable the DirectoryBackend to use custom operator implementations. 112 |     """ 113 |     print("Enabling DirectoryBackend") 114 |     # Set default kernel directory 115 |     if kernel_dir is None: 116 |         kernel_dir = Path(__file__).parents[1] / "generated_kernels" 117 | 118 |     kernel_dir = Path(kernel_dir) 119 | 120 |     # Check if kernel directory exists 121 |     if not kernel_dir.exists(): 122 |         logger.warning( 123 |             f"Kernel directory {kernel_dir} does not exist. Call " 124 |             f"directory_backend.setup_operators('{kernel_dir}') manually." 125 |         ) 126 |         return 127 | 128 |     # Initialize the backend 129 |     try: 130 |         _current_backend = DirectoryBackend(str(kernel_dir)) 131 | 132 |         # Actually monkey-patch PyTorch operators 133 |         if lib: 134 |             _monkey_patch_operators(lib, _current_backend.compiled_kernels, namespace, dispatch_key) 135 |         else: 136 |             global _lib 137 |             if _lib is None: 138 |                 _lib = torch.library.Library(namespace, "IMPL", dispatch_key) 139 |             _monkey_patch_operators( 140 |                 _lib, _current_backend.compiled_kernels, namespace, dispatch_key 141 |             ) 142 |     except Exception as e: 143 |         logger.warning(f"Failed to enable DirectoryBackend: {e}") 144 | 145 | 146 | def disable() -> None: 147 |     """ 148 |     Disable the DirectoryBackend and restore original PyTorch operators. 149 |     """ 150 |     global _lib 151 | 152 |     if _lib is None: 153 |         logger.warning("DirectoryBackend is not currently enabled") 154 |         return 155 | 156 |     # Restore original operators 157 |     _lib = None 158 |     print("DirectoryBackend disabled") 159 | -------------------------------------------------------------------------------- /BackendBench/suite/facto.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree.
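# Editorial note (summary of the suite defined below): FACTO specs from
# SpecDictDB generate argument tuples for each operator; every generated call
# is traced with a TorchDispatchMode so that a test can be attributed to a
# single aten overload, mirroring the opinfo suite.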
6 | 7 | import logging 8 | from collections import defaultdict 9 | 10 | import torch 11 | from torch.utils._python_dispatch import TorchDispatchMode 12 | 13 | try: 14 | from facto.inputgen.argtuple.gen import ArgumentTupleGenerator 15 | from facto.inputgen.utils.config import TensorConfig 16 | from facto.specdb.db import SpecDictDB 17 | except ImportError: 18 | ArgumentTupleGenerator = None 19 | TensorConfig = None 20 | SpecDictDB = None 21 | 22 | 23 | from BackendBench.eval import allclose 24 | from BackendBench.op_categories import ( 25 | RANDOM_OPS, 26 | TENSOR_CREATION_AND_MANIPULATION_OPS, 27 | UNSUPPORTED_OPERATORS, 28 | ) 29 | from BackendBench.opregistry import get_operator 30 | 31 | from .base import OpTest, TestSuite 32 | 33 | logger = logging.getLogger(__name__) 34 | 35 | 36 | class FactoTest: 37 | def __init__(self, *args, **kwargs): 38 | self.args = args 39 | self.kwargs = kwargs 40 | 41 | 42 | class FactoOpTest(OpTest): 43 | def __init__(self, op, correctness_tests): 44 | self.op = op 45 | self._correctness_tests = correctness_tests 46 | self.performance_tests = [] 47 | 48 | @property 49 | def correctness_tests(self): 50 | for test in self._correctness_tests: 51 | yield FactoTest(*test.args, **test.kwargs) 52 | 53 | 54 | class OpTracerMode(TorchDispatchMode): 55 | def __init__(self): 56 | self.ops = [] 57 | self.args = [] 58 | self.kwargs = [] 59 | 60 | def __torch_dispatch__(self, fn, types, args=(), kwargs={}): 61 | self.ops.append(fn) 62 | self.args.append(args) 63 | self.kwargs.append(kwargs) 64 | return fn(*args, **kwargs) 65 | 66 | 67 | def build_facto_op_tests(device, dtype, filter=None, num_runs=10, empty=False, probability=1.0): 68 | facto_op_tests = [] 69 | failed = [] 70 | for spec_name in SpecDictDB: 71 | try: 72 | if filter and spec_name not in filter: 73 | continue 74 | if ( 75 | spec_name 76 | in UNSUPPORTED_OPERATORS + RANDOM_OPS + TENSOR_CREATION_AND_MANIPULATION_OPS 77 | ): 78 | continue 79 | 80 | # Get canonical operator from registry 81 | op = get_operator(spec_name) 82 | if op is None: 83 | logger.debug(f"Skipping {spec_name}: operator resolution failed") 84 | continue 85 | 86 | config = TensorConfig( 87 | empty=empty, 88 | ).set_probability(probability) 89 | 90 | spec = SpecDictDB[spec_name] 91 | generator = ArgumentTupleGenerator(spec, config) 92 | 93 | op_tests = defaultdict(list) 94 | 95 | for idx, (posargs, inkwargs, outargs) in enumerate(generator.gen()): 96 | if idx >= num_runs: 97 | break 98 | 99 | # Filter arguments to target device/dtype 100 | filtered_posargs = [] 101 | for arg in posargs: 102 | if isinstance(arg, torch.Tensor): 103 | arg = arg.to(device=device, dtype=dtype) 104 | filtered_posargs.append(arg) 105 | 106 | filtered_inkwargs = {} 107 | for k, v in inkwargs.items(): 108 | if isinstance(v, torch.Tensor): 109 | v = v.to(device=device, dtype=dtype) 110 | filtered_inkwargs[k] = v 111 | 112 | filtered_outargs = {} 113 | for k, v in outargs.items(): 114 | if isinstance(v, torch.Tensor): 115 | v = v.to(device=device, dtype=dtype) 116 | filtered_outargs[k] = v 117 | 118 | all_kwargs = {**filtered_inkwargs, **filtered_outargs} 119 | 120 | try: 121 | # Trace execution to find underlying PyTorch ops 122 | with OpTracerMode() as tracer: 123 | ref = op(*filtered_posargs, **all_kwargs) 124 | except Exception: 125 | logger.debug(f"FACTO spec {spec_name} couldn't run underlying op {op}") 126 | continue 127 | 128 | # Check if we captured exactly one op (clean mapping) 129 | if len(tracer.ops) == 1: 130 | try: 131 | # Verify the traced op produces 
the same result 132 | res = tracer.ops[0](*filtered_posargs, **all_kwargs) 133 | if allclose(ref, res): 134 | op_tests[tracer.ops[0]].append( 135 | FactoTest(*filtered_posargs, **all_kwargs) 136 | ) 137 | except Exception: 138 | logger.debug( 139 | f"FACTO spec {spec_name} couldn't run underlying op {tracer.ops[0]}" 140 | ) 141 | else: 142 | logger.debug(f"FACTO spec {spec_name} has {len(tracer.ops)} ops") 143 | 144 | for traced_op, tests in op_tests.items(): 145 | if len(tests) > 0: 146 | facto_op_tests.append(FactoOpTest(traced_op, tests)) 147 | except Exception: 148 | logger.debug(f"FACTO spec {spec_name} failed") 149 | failed.append(spec_name) 150 | 151 | logger.debug(f"Failed specs: {failed}") 152 | 153 | return facto_op_tests 154 | 155 | 156 | class FactoTestSuite(TestSuite): 157 | def __init__(self, name, device, dtype, filter=None, num_runs=10, empty=False, probability=1.0): 158 | super().__init__( 159 | name, build_facto_op_tests(device, dtype, filter, num_runs, empty, probability) 160 | ) 161 | -------------------------------------------------------------------------------- /BackendBench/scripts/setup_operator_directories.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD 3-Clause license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | """ 10 | Setup script to create directory structure for PyTorch operators in op_map. 11 | This creates directories for operators that are actually used in evaluation suites 12 | (opinfo, torchbench) so LLM researchers can fill them with generated kernels. 13 | """ 14 | 15 | import argparse 16 | from pathlib import Path 17 | from typing import Set 18 | 19 | from BackendBench.scripts.op_map import op_map_data 20 | from BackendBench.utils import extract_operator_name, op_name_to_folder_name 21 | 22 | 23 | def extract_aten_ops(op_strings): 24 | """Extract unique aten operator names from a list of operation strings.""" 25 | return [extract_operator_name(op_str) for op_str in op_strings] 26 | 27 | 28 | def get_all_operators_from_op_map(): 29 | """Extract all unique folder names from the authoritative op_map.""" 30 | folder_names = set() 31 | 32 | for line in op_map_data.strip().split("\n"): 33 | if line.startswith("canonical:"): 34 | # Extract canonical name from line like "canonical:add.Tensor func:add.Tensor ..." 
35 | canonical_part = line.split()[0] # Get "canonical:add.Tensor" 36 | canonical_name = canonical_part.split(":", 1)[1] # Get "add.Tensor" 37 | 38 | folder_names.add(canonical_name) 39 | 40 | return sorted(folder_names) 41 | 42 | 43 | def get_torchbench_operators() -> Set[str]: 44 | """Get operators used in TorchBench suite.""" 45 | try: 46 | from BackendBench.suite import TorchBenchTestSuite 47 | 48 | suite = TorchBenchTestSuite("torchbench", None) 49 | ops = set() 50 | for optest in suite: 51 | op_str = str(optest.op) 52 | op_name = extract_operator_name(op_str) 53 | ops.add(op_name) 54 | return ops 55 | except Exception as e: 56 | print(f"Warning: Could not load TorchBench operators: {e}") 57 | return set() 58 | 59 | 60 | def get_opinfo_operators() -> Set[str]: 61 | """Get operators available in OpInfo suite.""" 62 | try: 63 | import torch 64 | 65 | from BackendBench.suite import OpInfoTestSuite 66 | 67 | suite = OpInfoTestSuite("opinfo", "cpu", torch.float32) 68 | opinfo_ops = [str(optest.op) for optest in suite] 69 | return set(extract_aten_ops(opinfo_ops)) 70 | except Exception as e: 71 | print(f"Warning: Could not load OpInfo operators: {e}") 72 | return set() 73 | 74 | 75 | def setup_operator_directories( 76 | base_dir: str = "generated_kernels", verbose: bool = False, suite: str = "all" 77 | ): 78 | """ 79 | Set up directory structure for operators based on test suite selection. 80 | 81 | Args: 82 | base_dir: Base directory for operator implementations 83 | verbose: Show verbose output for each directory created/skipped 84 | suite: Which operators to include ('torchbench', 'opinfo', 'all') 85 | """ 86 | 87 | # Get all operators from op_map first 88 | all_op_map_operators = set(get_all_operators_from_op_map()) 89 | print(f"Found {len(all_op_map_operators)} unique operators in op_map") 90 | 91 | # Filter based on suite selection 92 | if suite == "torchbench": 93 | torchbench_ops = get_torchbench_operators() 94 | selected_ops = all_op_map_operators & torchbench_ops 95 | print(f"TorchBench operators in op_map: {len(selected_ops)} total") 96 | elif suite == "opinfo": 97 | opinfo_ops = get_opinfo_operators() 98 | selected_ops = all_op_map_operators & opinfo_ops 99 | print(f"OpInfo operators in op_map: {len(selected_ops)} total") 100 | elif suite == "all": 101 | selected_ops = all_op_map_operators 102 | print(f"All operators from op_map: {len(selected_ops)} total") 103 | else: 104 | raise ValueError(f"Invalid suite '{suite}'. 
Must be one of: torchbench, opinfo, all") 105 | 106 | folder_names = [op_name_to_folder_name(op) for op in sorted(selected_ops)] 107 | print(f"Creating directories for {len(folder_names)} operators") 108 | 109 | base_path = Path(base_dir) 110 | base_path.mkdir(exist_ok=True) 111 | 112 | created_count = 0 113 | skipped_count = 0 114 | 115 | for folder_name in folder_names: 116 | op_dir = base_path / folder_name 117 | 118 | if op_dir.exists(): 119 | if verbose: 120 | print(f"Directory already exists: {folder_name}") 121 | skipped_count += 1 122 | continue 123 | 124 | op_dir.mkdir(exist_ok=True) 125 | if verbose: 126 | print(f"Created directory: {folder_name}") 127 | created_count += 1 128 | 129 | print("\nDirectory setup complete:") 130 | print(f"- Created {created_count} new directories") 131 | print(f"- Skipped {skipped_count} existing directories") 132 | print(f"- Total operators for {suite} suite: {len(folder_names)}") 133 | print(f"- Base directory: {base_path.absolute()}") 134 | 135 | 136 | def main(): 137 | parser = argparse.ArgumentParser( 138 | description="Set up directory structure for PyTorch operators based on test suite selection" 139 | ) 140 | parser.add_argument( 141 | "--base-dir", 142 | default="generated_kernels", 143 | help="Base directory for operator implementations (default: generated_kernels)", 144 | ) 145 | parser.add_argument( 146 | "--verbose", 147 | "-v", 148 | action="store_true", 149 | help="Show verbose output for each directory created/skipped", 150 | ) 151 | parser.add_argument( 152 | "--suite", 153 | choices=["torchbench", "opinfo", "all"], 154 | default="torchbench", 155 | help="Which test suite operators to include (default: torchbench)", 156 | ) 157 | 158 | args = parser.parse_args() 159 | setup_operator_directories(args.base_dir, verbose=args.verbose, suite=args.suite) 160 | 161 | 162 | if __name__ == "__main__": 163 | main() 164 | -------------------------------------------------------------------------------- /test/test_suite.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 
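# Editorial note: several tests below rely on Test storing its args/kwargs as
# callables that are re-evaluated on every access (see
# test_test_args_evaluation_timing), which is why randn(...) returns a lambda
# rather than a tensor.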
6 | 7 | import pytest 8 | import torch 9 | 10 | from BackendBench.opregistry import get_operator 11 | from BackendBench.suite import OpTest, randn, SmokeTestSuite, Test, TestSuite 12 | 13 | 14 | class TestRandnFunction: 15 | def test_randn_returns_callable(self): 16 | fn = randn(2, 3) 17 | assert callable(fn) 18 | 19 | tensor = fn() 20 | assert isinstance(tensor, torch.Tensor) 21 | assert tensor.shape == (2, 3) 22 | 23 | def test_randn_with_kwargs(self): 24 | fn = randn(2, 3, device="cpu", dtype=torch.float32) 25 | tensor = fn() 26 | 27 | assert tensor.device.type == "cpu" 28 | assert tensor.dtype == torch.float32 29 | assert tensor.shape == (2, 3) 30 | 31 | 32 | class TestTestClass: 33 | def test_test_initialization(self): 34 | test = Test(1, 2, 3, key1="value1", key2="value2") 35 | 36 | assert test._args == (1, 2, 3) 37 | assert test._kwargs == {"key1": "value1", "key2": "value2"} 38 | 39 | @pytest.mark.skip(reason="Test expects mixed callable/non-callable args - needs clarification") 40 | def test_test_args_property(self): 41 | def fn1(): 42 | return 10 43 | 44 | def fn2(): 45 | return 20 46 | 47 | test = Test(fn1, 5, fn2) 48 | 49 | args = test.args 50 | assert args == [10, 5, 20] 51 | 52 | @pytest.mark.skip( 53 | reason="Test expects mixed callable/non-callable kwargs - needs clarification" 54 | ) 55 | def test_test_kwargs_property(self): 56 | def fn1(): 57 | return "computed" 58 | 59 | test = Test(key1=fn1, key2="static") 60 | 61 | kwargs = test.kwargs 62 | assert kwargs == {"key1": "computed", "key2": "static"} 63 | 64 | def test_test_with_randn(self): 65 | test = Test(randn(2, 3), randn(3, 4), device=lambda: "cpu") 66 | 67 | args = test.args 68 | kwargs = test.kwargs 69 | 70 | assert len(args) == 2 71 | assert isinstance(args[0], torch.Tensor) 72 | assert isinstance(args[1], torch.Tensor) 73 | assert args[0].shape == (2, 3) 74 | assert args[1].shape == (3, 4) 75 | assert kwargs == {"device": "cpu"} 76 | 77 | 78 | class TestOpTest: 79 | def test_optest_initialization(self): 80 | op = torch.ops.aten.relu.default 81 | correctness_tests = [Test(randn(2, 2))] 82 | performance_tests = [Test(randn(100, 100))] 83 | 84 | optest = OpTest(op, correctness_tests, performance_tests) 85 | 86 | assert optest.op == op 87 | assert optest.correctness_tests == correctness_tests 88 | assert optest.performance_tests == performance_tests 89 | 90 | def test_optest_attributes(self): 91 | op = torch.ops.aten.add.Tensor 92 | correctness_tests = [Test(randn(2, 2), randn(2, 2)), Test(randn(3, 3), randn(3, 3))] 93 | performance_tests = [Test(randn(1000, 1000), randn(1000, 1000))] 94 | 95 | optest = OpTest(op, correctness_tests, performance_tests) 96 | 97 | assert len(optest.correctness_tests) == 2 98 | assert len(optest.performance_tests) == 1 99 | 100 | 101 | class TestTestSuite: 102 | def test_testsuite_initialization(self): 103 | optests = [ 104 | OpTest(torch.ops.aten.relu.default, [Test(randn(2))], [Test(randn(100))]), 105 | OpTest( 106 | torch.ops.aten.add.Tensor, 107 | [Test(randn(2), randn(2))], 108 | [Test(randn(100), randn(100))], 109 | ), 110 | ] 111 | 112 | suite = TestSuite("test_suite", optests) 113 | 114 | assert suite.name == "test_suite" 115 | assert suite.optests == optests 116 | 117 | def test_testsuite_iteration(self): 118 | optests = [ 119 | OpTest(torch.ops.aten.relu.default, [Test(randn(2))], [Test(randn(100))]), 120 | OpTest( 121 | torch.ops.aten.add.Tensor, 122 | [Test(randn(2), randn(2))], 123 | [Test(randn(100), randn(100))], 124 | ), 125 | ] 126 | 127 | suite = TestSuite("test_suite", 
optests) 128 | 129 | collected = list(suite) 130 | assert len(collected) == 2 131 | assert collected[0].op == torch.ops.aten.relu.default 132 | assert collected[1].op == torch.ops.aten.add.Tensor 133 | 134 | 135 | class TestSmokeTestSuiteStructure: 136 | def test_smoke_test_suite_exists(self): 137 | assert isinstance(SmokeTestSuite, TestSuite) 138 | assert SmokeTestSuite.name == "smoke" 139 | 140 | def test_smoke_test_suite_contains_relu(self): 141 | optests = list(SmokeTestSuite) 142 | 143 | assert len(optests) >= 1 144 | assert optests[0].op == get_operator(torch.ops.aten.relu.default) 145 | 146 | # Check correctness tests 147 | assert len(optests[0].correctness_tests) >= 1 148 | correctness_test = optests[0].correctness_tests[0] 149 | args = correctness_test.args 150 | assert len(args) == 1 151 | assert isinstance(args[0], torch.Tensor) 152 | 153 | # Check performance tests 154 | assert len(optests[0].performance_tests) >= 1 155 | perf_test = optests[0].performance_tests[0] 156 | perf_args = perf_test.args 157 | assert len(perf_args) == 1 158 | assert isinstance(perf_args[0], torch.Tensor) 159 | assert perf_args[0].numel() > args[0].numel() 160 | 161 | 162 | class TestSuiteIntegration: 163 | def test_suite_with_multiple_operations(self): 164 | optests = [ 165 | OpTest( 166 | torch.ops.aten.relu.default, 167 | [Test(randn(2, 2)), Test(randn(3, 3))], 168 | [Test(randn(100, 100))], 169 | ), 170 | OpTest(torch.ops.aten.sigmoid.default, [Test(randn(4, 4))], [Test(randn(200, 200))]), 171 | OpTest( 172 | torch.ops.aten.add.Tensor, 173 | [Test(randn(2, 2), randn(2, 2))], 174 | [Test(randn(100, 100), randn(100, 100))], 175 | ), 176 | ] 177 | 178 | suite = TestSuite("integration_test", optests) 179 | 180 | ops_found = [optest.op for optest in suite] 181 | assert torch.ops.aten.relu.default in ops_found 182 | assert torch.ops.aten.sigmoid.default in ops_found 183 | assert torch.ops.aten.add.Tensor in ops_found 184 | 185 | def test_test_args_evaluation_timing(self): 186 | counter = 0 187 | 188 | def counting_fn(): 189 | nonlocal counter 190 | counter += 1 191 | return counter 192 | 193 | test = Test(counting_fn) 194 | 195 | assert test.args == [1] 196 | assert test.args == [2] 197 | assert test.args == [3] 198 | -------------------------------------------------------------------------------- /BackendBench/llm_client.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | from typing import Optional 9 | 10 | import anthropic 11 | import requests 12 | from tenacity import retry 13 | from tenacity.wait import wait_random_exponential 14 | 15 | from .kernel_templates import KernelTemplateManager 16 | 17 | 18 | class LLMKernelGenerator: 19 | """ 20 | LLM Kernel Generator that uses direct Anthropic API. 
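    Usage sketch (editorial example; assumes ANTHROPIC_API_KEY is exported and
    that op_signature / op_description strings are supplied by the caller):

        generator = LLMKernelGenerator()
        kernel_code = generator.generate_kernel("add", op_signature, op_description, dsl="triton")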
21 | """ 22 | 23 | def __init__( 24 | self, 25 | model: str = "claude-sonnet-4-20250514", 26 | ): 27 | self.model = model 28 | self.template_manager = KernelTemplateManager() 29 | self.api_key = os.getenv("ANTHROPIC_API_KEY") 30 | if not self.api_key: 31 | raise ValueError( 32 | "ANTHROPIC_API_KEY must be set in environment or passed to constructor" 33 | ) 34 | assert "claude" in self.model, "Only Claude (Anthropic) models are supported for now" 35 | 36 | self.client = anthropic.Anthropic(api_key=self.api_key) 37 | # check connection to the server 38 | try: 39 | self.client.messages.create( 40 | model=self.model, 41 | max_tokens=8000, 42 | temperature=0.2, 43 | timeout=120.0, 44 | messages=[{"role": "user", "content": "Hello, how are you?"}], 45 | ) 46 | except anthropic.AnthropicError as e: 47 | raise ConnectionError(f"Cannot connect to Anthropic server: {e}") 48 | 49 | @property 50 | def readme_server_description(self) -> str: 51 | return "Direct Anthropic API" 52 | 53 | @property 54 | def readme_setup_section(self) -> str: 55 | return """## Setup 56 | This backend uses the direct Anthropic API and requires: 57 | ```bash 58 | export ANTHROPIC_API_KEY=your_api_key_here 59 | ```""" 60 | 61 | @retry(wait=wait_random_exponential(multiplier=2, min=1, max=60, exp_base=2)) 62 | def call_llm(self, prompt: str) -> str: 63 | response = self.client.messages.create( 64 | model=self.model, 65 | max_tokens=8000, 66 | temperature=0.2, 67 | timeout=120.0, 68 | messages=[{"role": "user", "content": prompt}], 69 | ) 70 | content = response.content[0].text 71 | return content 72 | 73 | def generate_kernel( 74 | self, 75 | op_name: str, 76 | op_signature: str, 77 | op_description: str, 78 | dsl: str = "triton", 79 | feedback: Optional[str] = None, 80 | ) -> str: 81 | if feedback: 82 | prompt = self.template_manager.create_refinement_prompt( 83 | op_name, op_signature, op_description, dsl, feedback 84 | ) 85 | else: 86 | prompt = self.template_manager.create_prompt(op_name, op_signature, op_description, dsl) 87 | 88 | print("\n=== DEBUG: PROMPT SENT TO LLM RELAY ===") 89 | print(prompt) 90 | print("=== END PROMPT ===\n") 91 | 92 | try: 93 | content = self.call_llm(prompt) 94 | if not content: 95 | raise RuntimeError("Empty response from LLM relay server") 96 | 97 | extracted_code = self._extract_code_from_response(content) 98 | 99 | print("\n=== DEBUG: RAW LLM RELAY RESPONSE ===") 100 | print(content) 101 | print("=== DEBUG: EXTRACTED CODE ===") 102 | print(extracted_code) 103 | print("=== END DEBUG ===\n") 104 | 105 | return extracted_code 106 | 107 | except requests.exceptions.RequestException as e: 108 | raise RuntimeError( 109 | f"Failed to communicate with LLM relay server for {op_name}: {str(e)}" 110 | ) 111 | except Exception as e: 112 | raise RuntimeError(f"Failed to generate kernel for {op_name}: {str(e)}") 113 | 114 | def _extract_code_from_response(self, response: str) -> str: 115 | if "```python" not in response: 116 | raise ValueError( 117 | "No Python code block found in LLM response. Response should contain ```python...``` block." 118 | ) 119 | 120 | start = response.find("```python") + len("```python") 121 | end = response.find("```", start) 122 | 123 | if end == -1: 124 | raise ValueError("Unclosed Python code block in LLM response.") 125 | 126 | return response[start:end].strip() 127 | 128 | 129 | class LLMRelayKernelGenerator(LLMKernelGenerator): 130 | """ 131 | LLM Kernel Generator that uses local plugboard server. 132 | Inherits from LLMKernelGenerator and overrides call_llm method. 
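    Example (illustrative sketch; assumes a plugboard relay server is
    already listening at the default URL, and uses placeholder operator
    metadata):

        from BackendBench.llm_client import LLMRelayKernelGenerator

        generator = LLMRelayKernelGenerator(server_url="http://127.0.0.1:11434")
        kernel_code = generator.generate_kernel(
            op_name="add",
            op_signature="add.Tensor(Tensor self, Tensor other) -> Tensor",
            op_description="Elementwise addition of two tensors.",
        )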
133 | """ 134 | 135 | def __init__( 136 | self, 137 | server_url: str = "http://127.0.0.1:11434", 138 | model: str = "gcp-claude-4-sonnet", 139 | ): 140 | self.server_url = server_url 141 | self.model = model 142 | self.template_manager = KernelTemplateManager() 143 | # Test connection to the server 144 | try: 145 | requests.get(f"{self.server_url}/", timeout=5) 146 | except requests.exceptions.ConnectionError: 147 | raise ConnectionError(f"Cannot connect to LLM relay server at {self.server_url}. ") 148 | except requests.exceptions.Timeout: 149 | raise TimeoutError(f"Timeout connecting to LLM relay server at {self.server_url}. ") 150 | 151 | @property 152 | def readme_server_description(self) -> str: 153 | return "Local plugboard server (localhost:11434)" 154 | 155 | @property 156 | def readme_setup_section(self) -> str: 157 | return """## Server Setup 158 | This backend requires the plugboard server to be running: 159 | ``` 160 | buck run @//mode/inplace run_plugboard_server -- --model gcp-claude-4-sonnet --pipeline usecase-dev-ai-user 161 | ```""" 162 | 163 | @retry(wait=wait_random_exponential(multiplier=2, min=1, max=60, exp_base=2)) 164 | def call_llm(self, prompt: str) -> str: 165 | # Prepare request data for the plugboard server 166 | request_data = { 167 | "messages": [{"role": "user", "content": prompt}], 168 | "model": self.model, 169 | "temperature": 0.2, 170 | "max_tokens": 8000, 171 | "top_p": 0.95, 172 | } 173 | 174 | # Bypass proxy for localhost connections 175 | proxies = ( 176 | {"http": None, "https": None} 177 | if "127.0.0.1" in self.server_url or "localhost" in self.server_url 178 | else None 179 | ) 180 | 181 | response = requests.post( 182 | self.server_url, 183 | json=request_data, 184 | headers={"Content-Type": "application/json"}, 185 | timeout=120.0, 186 | proxies=proxies, 187 | ) 188 | 189 | if response.status_code != 200: 190 | raise RuntimeError(f"Server returned status {response.status_code}: {response.text}") 191 | 192 | response_data = response.json() 193 | content = response_data.get("output", "") 194 | return content 195 | -------------------------------------------------------------------------------- /BackendBench/kernel_templates.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Kernel code templates and prompt engineering for LLM-based kernel generation. 
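Example (illustrative sketch; the operator strings and feedback text below are
placeholders, not values used elsewhere in the repo):

    from BackendBench.kernel_templates import KernelTemplateManager

    manager = KernelTemplateManager()
    prompt = manager.create_prompt(
        "relu", "relu(Tensor self) -> Tensor",
        "Applies the rectified linear unit elementwise.", dsl="triton",
    )
    retry_prompt = manager.create_refinement_prompt(
        "relu", "relu(Tensor self) -> Tensor",
        "Applies the rectified linear unit elementwise.", dsl="triton",
        feedback="Previous kernel failed on non-contiguous inputs.",
    )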
9 | """ 10 | 11 | from typing import Dict 12 | 13 | from .prompts import ( 14 | CUTEDSL_EXAMPLE_TEMPLATES, 15 | CUTEDSL_KERNEL_PROMPT, 16 | CUTEDSL_OPTIMIZATIONS, 17 | HELION_EXAMPLE_TEMPLATES, 18 | HELION_KERNEL_PROMPT, 19 | HELION_OPTIMIZATIONS, 20 | PYTORCH_KERNEL_PROMPT, 21 | TRITON_EXAMPLE_TEMPLATES, 22 | TRITON_KERNEL_PROMPT, 23 | TRITON_OPTIMIZATIONS, 24 | ) 25 | from .utils import op_name_to_folder_name 26 | 27 | 28 | class KernelTemplate: 29 | """Base class for kernel templates.""" 30 | 31 | def __init__(self, name: str, dsl: str): 32 | self.name = name 33 | self.dsl = dsl 34 | 35 | def create_prompt(self, op_name: str, op_signature: str, op_description: str) -> str: 36 | """Create a prompt for kernel generation.""" 37 | raise NotImplementedError 38 | 39 | 40 | class TritonKernelTemplate(KernelTemplate): 41 | """Template for Triton kernel generation.""" 42 | 43 | def __init__(self): 44 | super().__init__("triton", "triton") 45 | 46 | def create_prompt(self, op_name: str, op_signature: str, op_description: str) -> str: 47 | """Create a specialized prompt for Triton kernel generation.""" 48 | 49 | # Get operation-specific optimizations 50 | optimizations = self._get_optimizations(op_name) 51 | 52 | # Get example template 53 | example = self._get_example_template(op_name) 54 | 55 | return TRITON_KERNEL_PROMPT.format( 56 | op_name=op_name, 57 | folder_name=op_name_to_folder_name(op_name), 58 | op_signature=op_signature, 59 | op_description=op_description, 60 | optimizations=optimizations, 61 | example=example, 62 | ) 63 | 64 | def _get_optimizations(self, op_name: str) -> str: 65 | """Get operation-specific optimization guidelines.""" 66 | return TRITON_OPTIMIZATIONS.get(op_name, TRITON_OPTIMIZATIONS["default"]) 67 | 68 | def _get_example_template(self, op_name: str) -> str: 69 | """Get operation-specific code template.""" 70 | return TRITON_EXAMPLE_TEMPLATES["default"] 71 | 72 | 73 | class PyTorchKernelTemplate(KernelTemplate): 74 | """Template for pure PyTorch kernel generation.""" 75 | 76 | def __init__(self): 77 | super().__init__("pytorch", "pytorch") 78 | 79 | def create_prompt(self, op_name: str, op_signature: str, op_description: str) -> str: 80 | """Create a prompt for PyTorch kernel generation.""" 81 | 82 | return PYTORCH_KERNEL_PROMPT.format( 83 | op_name=op_name, op_signature=op_signature, op_description=op_description 84 | ) 85 | 86 | 87 | class CuTeDSLKernelTemplate(KernelTemplate): 88 | """Template for CuTeDSL kernel generation.""" 89 | 90 | def __init__(self): 91 | super().__init__("cutedsl", "cutedsl") 92 | 93 | def create_prompt(self, op_name: str, op_signature: str, op_description: str) -> str: 94 | """Create a specialized prompt for CuTeDSL kernel generation.""" 95 | 96 | # Get operation-specific optimizations 97 | optimizations = self._get_optimizations(op_name) 98 | 99 | # Get example template 100 | example = self._get_example_template(op_name) 101 | 102 | return CUTEDSL_KERNEL_PROMPT.format( 103 | op_name=op_name, 104 | folder_name=op_name_to_folder_name(op_name), 105 | op_signature=op_signature, 106 | op_description=op_description, 107 | optimizations=optimizations, 108 | example=example, 109 | ) 110 | 111 | def _get_optimizations(self, op_name: str) -> str: 112 | """Get operation-specific optimization guidelines.""" 113 | return CUTEDSL_OPTIMIZATIONS.get(op_name, CUTEDSL_OPTIMIZATIONS["default"]) 114 | 115 | def _get_example_template(self, op_name: str) -> str: 116 | """Get operation-specific code template.""" 117 | return CUTEDSL_EXAMPLE_TEMPLATES["default"] 
118 | 119 | 120 | class HelionKernelTemplate(KernelTemplate): 121 | """Template for Helion kernel generation.""" 122 | 123 | def __init__(self): 124 | super().__init__("helion", "helion") 125 | 126 | def create_prompt(self, op_name: str, op_signature: str, op_description: str) -> str: 127 | optimizations = self._get_optimizations(op_name) 128 | 129 | example = self._get_example_template(op_name) 130 | 131 | return HELION_KERNEL_PROMPT.format( 132 | op_name=op_name, 133 | folder_name=op_name_to_folder_name(op_name), 134 | op_signature=op_signature, 135 | op_description=op_description, 136 | optimizations=optimizations, 137 | example=example, 138 | ) 139 | 140 | def _get_optimizations(self, op_name: str) -> str: 141 | return HELION_OPTIMIZATIONS.get(op_name, HELION_OPTIMIZATIONS["default"]) 142 | 143 | def _get_example_template(self, op_name: str) -> str: 144 | return HELION_EXAMPLE_TEMPLATES["default"] 145 | 146 | 147 | class KernelTemplateManager: 148 | """Manages kernel templates for different dsls.""" 149 | 150 | def __init__(self): 151 | self.templates: Dict[str, KernelTemplate] = { 152 | "triton": TritonKernelTemplate(), 153 | "pytorch": PyTorchKernelTemplate(), 154 | "cutedsl": CuTeDSLKernelTemplate(), 155 | "helion": HelionKernelTemplate(), 156 | # TODO: Add cuda, cutile, whatever we want 157 | } 158 | 159 | def get_template(self, dsl: str) -> KernelTemplate: 160 | """Get template for specified dsl.""" 161 | if dsl not in self.templates: 162 | raise ValueError(f"Unknown dsl: {dsl}") 163 | return self.templates[dsl] 164 | 165 | def create_prompt( 166 | self, 167 | op_name: str, 168 | op_signature: str, 169 | op_description: str, 170 | dsl: str = "triton", 171 | ) -> str: 172 | """Create a prompt using the specified template.""" 173 | template = self.get_template(dsl) 174 | return template.create_prompt(op_name, op_signature, op_description) 175 | 176 | def create_refinement_prompt( 177 | self, 178 | op_name: str, 179 | op_signature: str, 180 | op_description: str, 181 | dsl: str = "triton", 182 | feedback: str = "", 183 | ) -> str: 184 | """Create a refinement prompt with feedback from previous attempts.""" 185 | base_prompt = self.create_prompt(op_name, op_signature, op_description, dsl) 186 | 187 | if feedback and feedback.strip(): 188 | refinement_prompt = f"""{feedback} 189 | 190 | {base_prompt} 191 | 192 | Use the above feedback and generate improved code.""" 193 | else: 194 | # Fallback if no feedback 195 | refinement_prompt = f"""{base_prompt} 196 | 197 | The previous attempt failed. Please generate a corrected version.""" 198 | 199 | return refinement_prompt 200 | -------------------------------------------------------------------------------- /BackendBench/scripts/get_tests_stat.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | This is a helper script to analyze the test suite and provide statistics about the number of tests per operation. 
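Example invocation (illustrative; run from the repository root. The opinfo and
facto suites here are built for CUDA bfloat16, so a GPU machine is assumed):

    python -m BackendBench.scripts.get_tests_stat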
9 | """ 10 | 11 | import statistics 12 | 13 | import torch 14 | 15 | from BackendBench.suite import FactoTestSuite, OpInfoTestSuite, TorchBenchTestSuite 16 | 17 | from .setup_operator_directories import extract_operator_name 18 | 19 | 20 | def analyze_test_suite(suite): 21 | test_counts = {} 22 | total_correctness_tests = 0 23 | total_performance_tests = 0 24 | 25 | print(f"Analyzing suite: {suite.name}") 26 | 27 | for op_test in suite: 28 | op_str = str(op_test.op) 29 | op_name = extract_operator_name(op_str) 30 | 31 | num_correctness_tests = 0 32 | for _ in op_test.correctness_tests: 33 | num_correctness_tests += 1 34 | 35 | # In case later we have different tests for performance 36 | num_performance_tests = 0 37 | for _ in op_test.performance_tests: 38 | num_performance_tests += 1 39 | 40 | if op_name not in test_counts: 41 | test_counts[op_name] = { 42 | "full_op_str": [], 43 | "num_correctness_tests": 0, 44 | "num_performance_tests": 0, 45 | } 46 | test_counts[op_name]["full_op_str"].append(op_str) 47 | test_counts[op_name]["num_correctness_tests"] += num_correctness_tests 48 | test_counts[op_name]["num_performance_tests"] += num_performance_tests 49 | 50 | total_correctness_tests += num_correctness_tests 51 | total_performance_tests += num_performance_tests 52 | # print(f" {op_name}: {num_correctness_tests} correctness tests and {num_performance_tests} performance tests") 53 | 54 | # Calculate statistics 55 | test_numbers_dict = { 56 | "correctness": [info["num_correctness_tests"] for info in test_counts.values()], 57 | "performance": [info["num_performance_tests"] for info in test_counts.values()], 58 | } 59 | 60 | results = {"operations": test_counts} 61 | for correct_or_perf, test_numbers in test_numbers_dict.items(): 62 | if correct_or_perf == "correctness": 63 | total_tests = total_correctness_tests 64 | elif correct_or_perf == "performance": 65 | total_tests = total_performance_tests 66 | if test_numbers: 67 | stats = { 68 | "total_operations": len(test_counts), 69 | "total_tests": total_tests, 70 | "min_tests": min(test_numbers), 71 | "max_tests": max(test_numbers), 72 | "mean_tests": statistics.mean(test_numbers), 73 | "median_tests": statistics.median(test_numbers), 74 | "stdev_tests": statistics.stdev(test_numbers) if len(test_numbers) > 1 else 0.0, 75 | } 76 | else: 77 | stats = { 78 | "total_operations": 0, 79 | "total_tests": 0, 80 | "min_tests": 0, 81 | "max_tests": 0, 82 | "mean_tests": 0.0, 83 | "median_tests": 0.0, 84 | "stdev_tests": 0.0, 85 | } 86 | results[f"{correct_or_perf}_stats"] = stats 87 | 88 | return results 89 | 90 | 91 | def print_summary(analysis_results, suite_name, correct_or_perf="correctness"): 92 | """Print a formatted summary of the analysis results.""" 93 | if correct_or_perf not in ["correctness", "performance"]: 94 | raise ValueError(f"Invalid value for 'correct_or_perf': {correct_or_perf}") 95 | stats = analysis_results[f"{correct_or_perf}_stats"] 96 | num_tests_str = f"num_{correct_or_perf}_tests" 97 | 98 | operations = analysis_results["operations"] 99 | 100 | print(f"\n{'=' * 60}") 101 | print(f"TEST STATISTICS SUMMARY FOR {suite_name.upper()} SUITE {correct_or_perf.upper()} TESTS") 102 | print(f"{'=' * 60}") 103 | 104 | print("\nOverall Statistics:") 105 | print(f" Total operations: {stats['total_operations']}") 106 | print(f" Total tests: {stats['total_tests']}") 107 | print(f" Average tests per operation: {stats['mean_tests']:.2f}") 108 | print(f" Median tests per operation: {stats['median_tests']:.2f}") 109 | print(f" Min tests per 
operation: {stats['min_tests']}") 110 | print(f" Max tests per operation: {stats['max_tests']}") 111 | print(f" Standard deviation: {stats['stdev_tests']:.2f}") 112 | 113 | if operations: 114 | test_counts = [info[num_tests_str] for info in operations.values()] 115 | bins = [1, 5, 10, 25, 50, 100, float("inf")] 116 | bin_labels = ["1-4", "5-9", "10-24", "25-49", "50-99", "100+"] 117 | distribution = [0] * (len(bins)) 118 | 119 | for count in test_counts: 120 | for i in range(len(bins) - 1): 121 | if bins[i] <= count < bins[i + 1]: 122 | distribution[i] += 1 123 | break 124 | 125 | print("\nTest Count Distribution:") 126 | for i, (label, count) in enumerate(zip(bin_labels, distribution)): 127 | percentage = (count / len(operations)) * 100 128 | print(f" {label:>6} tests: {count:3d} operations ({percentage:5.1f}%)") 129 | 130 | 131 | def main(): 132 | suite_dict = { 133 | "opinfo": OpInfoTestSuite( 134 | "opinfo_cuda_bfloat16", 135 | "cuda", 136 | torch.bfloat16, 137 | ), 138 | "torchbench": TorchBenchTestSuite( 139 | "torchbench", 140 | ), 141 | "facto": FactoTestSuite( 142 | "facto_cuda_bfloat16", 143 | "cuda", 144 | torch.bfloat16, 145 | ), 146 | } 147 | 148 | # Analyze each suite 149 | all_results = {} 150 | for suite_name, suite in suite_dict.items(): 151 | print(f"\n{'=' * 60}") 152 | print(f"ANALYZING {suite_name.upper()} SUITE") 153 | print(f"{'=' * 60}") 154 | 155 | results = analyze_test_suite(suite) 156 | all_results[suite_name] = results 157 | 158 | print_summary(results, suite_name, "correctness") 159 | print_summary(results, suite_name, "performance") 160 | 161 | # If analyzing all suites, provide a comparison summary 162 | if len(all_results) > 1: 163 | print(f"\n{'=' * 60}") 164 | print("SUITE COMPARISON SUMMARY") 165 | print(f"{'=' * 60}") 166 | 167 | print( 168 | f"{'Suite':<15} {'Ops':<8} {'Tests':<8} {'Avg':<8} {'Min':<6} {'Max':<6} {'StdDev':<8}" 169 | ) 170 | print("-" * 65) 171 | 172 | for correct_or_perf in ["correctness", "performance"]: 173 | print(f"\nFor {correct_or_perf} tests") 174 | for suite_name, results in all_results.items(): 175 | stats = results[f"{correct_or_perf}_stats"] 176 | print( 177 | f"{suite_name:<15} " 178 | f"{stats['total_operations']:<8} " 179 | f"{stats['total_tests']:<8} " 180 | f"{stats['mean_tests']:<8.1f} " 181 | f"{stats['min_tests']:<6} " 182 | f"{stats['max_tests']:<6} " 183 | f"{stats['stdev_tests']:<8.1f}" 184 | ) 185 | 186 | 187 | if __name__ == "__main__": 188 | main() 189 | -------------------------------------------------------------------------------- /BackendBench/scripts/create_watermarked_operators.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD 3-Clause license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | """ 10 | Create watermarked operator implementations that return constant tensors. 11 | These implementations will verify monkey patching works but will fail correctness tests. 
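Example invocation (illustrative; assumes the operator directories were already
created by setup_operator_directories.py):

    python -m BackendBench.scripts.create_watermarked_operators \
        --base-dir generated_kernels --unique-watermarks --overwrite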
12 | """ 13 | 14 | import argparse 15 | import hashlib 16 | import os 17 | from pathlib import Path 18 | 19 | from BackendBench.utils import op_name_to_folder_name 20 | 21 | WATERMARK_BASE = 42.0 22 | 23 | 24 | def get_operator_watermark_value(op_name: str, base_value: float = WATERMARK_BASE) -> float: 25 | """Generate a unique watermark value for each operator to catch cross-contamination.""" 26 | op_name = op_name_to_folder_name(op_name) # if op_name is add.Tensor, convert to add__Tensor 27 | op_hash = hashlib.md5(op_name.encode("utf-8")).hexdigest() 28 | hash_int = int(op_hash[:8], 16) # Use first 8 hex chars 29 | return base_value + (hash_int % 100) 30 | 31 | 32 | def create_watermarked_impl( 33 | op_name: str, watermark_value: float = None, use_unique_watermarks: bool = True 34 | ) -> str: 35 | """Generate a watermarked implementation that returns a constant tensor.""" 36 | 37 | if watermark_value is None: 38 | watermark_value = WATERMARK_BASE 39 | if use_unique_watermarks: 40 | watermark_value = get_operator_watermark_value(op_name, watermark_value) 41 | 42 | return f'''# Watermarked implementation for {op_name} operator 43 | # This implementation returns a constant tensor to verify monkey patching 44 | # Watermark value: {watermark_value} 45 | 46 | import torch 47 | 48 | def {op_name}_kernel_impl(*args, **kwargs): 49 | """Watermarked implementation of {op_name}. 50 | 51 | Returns a tensor filled with {watermark_value} to verify the operator 52 | is being called through DirectoryBackend. This will fail correctness 53 | tests but confirms the monkey patching mechanism is working. 54 | """ 55 | # Find the first tensor argument to determine output shape and device 56 | tensor_arg = None 57 | for arg in args: 58 | if isinstance(arg, torch.Tensor): 59 | tensor_arg = arg 60 | break 61 | 62 | if tensor_arg is not None: 63 | # Return a tensor with same shape, dtype, and device as input 64 | result = torch.full_like(tensor_arg, {watermark_value}) 65 | return result 66 | else: 67 | # Fallback for operators without tensor inputs 68 | # Return a scalar tensor 69 | return torch.tensor({watermark_value}) 70 | ''' 71 | 72 | 73 | def create_watermarked_operators( 74 | base_dir: str = "generated_kernels", 75 | watermark_value: float = None, 76 | overwrite: bool = False, 77 | use_unique_watermarks: bool = False, 78 | ): 79 | """Create watermarked implementations for all operators in the directory structure.""" 80 | 81 | base_path = Path(base_dir) 82 | if not base_path.exists(): 83 | print(f"Error: Directory {base_path} does not exist.") 84 | print("Please run setup_operator_directories.py first.") 85 | return 86 | 87 | created_count = 0 88 | skipped_count = 0 89 | 90 | # Iterate through all operator directories 91 | for op_dir in base_path.iterdir(): 92 | if not op_dir.is_dir() or op_dir.name == "__pycache__": 93 | continue 94 | 95 | op_name = op_dir.name 96 | impl_file = op_dir / f"{op_name}_implementation_v1.py" 97 | 98 | # Skip if file exists and overwrite is False 99 | if impl_file.exists() and not overwrite: 100 | skipped_count += 1 101 | continue 102 | 103 | # Create watermarked implementation 104 | impl_content = create_watermarked_impl(op_name, watermark_value, use_unique_watermarks) 105 | impl_file.write_text(impl_content) 106 | created_count += 1 107 | 108 | print("\nWatermarked operator creation complete:") 109 | print(f"- Created {created_count} watermarked implementations") 110 | print(f"- Skipped {skipped_count} existing implementations") 111 | if use_unique_watermarks: 112 | print(f"- 
Using unique watermarks per operator (base: {watermark_value or WATERMARK_BASE})") 113 | else: 114 | print(f"- Using uniform watermark value: {watermark_value or WATERMARK_BASE}") 115 | print(f"- Base directory: {base_path.absolute()}") 116 | 117 | # Create a verification script 118 | verification_script = base_path / "verify_watermarks.py" 119 | 120 | # Generate some sample expected values for verification 121 | sample_ops = ["relu.default", "add.Tensor", "mul.Tensor", "sub.Tensor", "div.Tensor"] 122 | expected_values = {} 123 | for op in sample_ops: 124 | if use_unique_watermarks: 125 | expected_values[op] = get_operator_watermark_value( 126 | op, watermark_value or WATERMARK_BASE 127 | ) 128 | else: 129 | expected_values[op] = watermark_value or WATERMARK_BASE 130 | 131 | verification_content = f'''#!/usr/bin/env python3 132 | """Verify that watermarked operators are being loaded correctly.""" 133 | 134 | import torch 135 | from BackendBench.backends import DirectoryBackend 136 | 137 | # Expected watermark values (unique per operator: {use_unique_watermarks}) 138 | EXPECTED_VALUES = {expected_values} 139 | USE_UNIQUE_WATERMARKS = {use_unique_watermarks} 140 | 141 | def get_expected_watermark(op_name): 142 | """Get expected watermark value for an operator.""" 143 | if USE_UNIQUE_WATERMARKS: 144 | import hashlib 145 | op_hash = hashlib.md5(op_name.encode('utf-8')).hexdigest() 146 | hash_int = int(op_hash[:8], 16) 147 | return {watermark_value or WATERMARK_BASE} + (hash_int % 100) 148 | else: 149 | return {watermark_value or WATERMARK_BASE} 150 | 151 | # Load the backend 152 | backend = DirectoryBackend("{base_dir}") 153 | 154 | # Test operators 155 | test_ops = list(EXPECTED_VALUES.keys()) 156 | 157 | print(f"Testing watermarked operators...") 158 | print(f"Unique watermarks per operator: {{USE_UNIQUE_WATERMARKS}}") 159 | print(f"Loaded {{len(backend.compiled_kernels)}} operators\\n") 160 | 161 | for op_name in test_ops: 162 | expected_value = get_expected_watermark(op_name) 163 | 164 | # Try to find the operator 165 | found = False 166 | for torch_op in backend.compiled_kernels: 167 | if op_name in str(torch_op): 168 | # Test the operator 169 | try: 170 | x = torch.tensor([1.0, 2.0, 3.0]) 171 | result = backend[torch_op](x) 172 | 173 | if torch.allclose(result, torch.full_like(x, expected_value)): 174 | print(f"✓ {{op_name}}: Watermark {{expected_value}} detected correctly") 175 | else: 176 | print(f"✗ {{op_name}}: Expected {{expected_value}}, got {{result}}") 177 | 178 | found = True 179 | break 180 | except Exception as e: 181 | print(f"✗ {{op_name}}: Error - {{e}}") 182 | found = True 183 | break 184 | 185 | if not found: 186 | print(f"? 
{{op_name}}: Not found in loaded operators") 187 | ''' 188 | 189 | verification_script.write_text(verification_content) 190 | os.chmod(verification_script, 0o755) 191 | 192 | print(f"\nCreated verification script: {verification_script}") 193 | print("\nTo verify watermarks are working:") 194 | print(f" python {verification_script}") 195 | print("\nTo test with evaluation harness (should fail correctness):") 196 | print(" python -m BackendBench.scripts.main --backend directory --ops relu,add --suite smoke") 197 | 198 | 199 | def main(): 200 | parser = argparse.ArgumentParser( 201 | description="Create watermarked operator implementations for testing" 202 | ) 203 | parser.add_argument( 204 | "--base-dir", 205 | default="generated_kernels", 206 | help="Base directory containing operator subdirectories", 207 | ) 208 | parser.add_argument( 209 | "--watermark-value", 210 | type=float, 211 | default=None, 212 | help=f"Base value to use for watermarking (default: {WATERMARK_BASE})", 213 | ) 214 | parser.add_argument( 215 | "--overwrite", action="store_true", help="Overwrite existing implementation files" 216 | ) 217 | parser.add_argument( 218 | "--unique-watermarks", 219 | action="store_true", 220 | help="Use unique watermark values per operator (default: uniform 42.0)", 221 | ) 222 | 223 | args = parser.parse_args() 224 | 225 | use_unique_watermarks = args.unique_watermarks 226 | 227 | create_watermarked_operators( 228 | args.base_dir, args.watermark_value, args.overwrite, use_unique_watermarks 229 | ) 230 | 231 | 232 | if __name__ == "__main__": 233 | main() 234 | -------------------------------------------------------------------------------- /test/test_monkey_patch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD 3-Clause license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | """ 10 | Test monkey patching of directory backend. 
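Example invocation (illustrative; the CUDA-parameterized cases are skipped
automatically when no GPU is available):

    pytest test/test_monkey_patch.py -v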
11 | """ 12 | 13 | import os 14 | import shutil 15 | import subprocess 16 | import sys 17 | 18 | import pytest 19 | import torch 20 | 21 | import BackendBench 22 | from BackendBench.scripts.create_watermarked_operators import get_operator_watermark_value 23 | from BackendBench.utils import op_name_to_folder_name 24 | 25 | 26 | class TestMonkeyPatch: 27 | """Verify monkey patching of directory backend.""" 28 | 29 | kernel_dir_relu = "generated_kernels_test_monkey_patch_relu" 30 | kernel_dir_leaky_relu = "generated_kernels_test_monkey_path_leaky_relu" 31 | 32 | def setup_watermarked_kernel_dir(self, kernel_dir, ops=None): 33 | """Generate required directory structure and operators.""" 34 | # Generate the directory structure 35 | command_list = [ 36 | sys.executable, 37 | "-m", 38 | "BackendBench.scripts.setup_operator_directories", 39 | "--base-dir", 40 | kernel_dir, 41 | ] 42 | subprocess.run( 43 | command_list, 44 | check=True, 45 | ) 46 | # Clean up directory structure and only keep the specified ops 47 | if ops: 48 | ops = [op_name_to_folder_name(op) for op in ops] 49 | for directory in os.listdir(kernel_dir): 50 | if directory not in ops and os.path.isdir(os.path.join(kernel_dir, directory)): 51 | shutil.rmtree(os.path.join(kernel_dir, directory)) 52 | 53 | command_list = [ 54 | sys.executable, 55 | "-m", 56 | "BackendBench.scripts.create_watermarked_operators", 57 | "--base-dir", 58 | kernel_dir, 59 | "--overwrite", 60 | "--unique-watermarks", 61 | ] 62 | 63 | subprocess.run( 64 | command_list, 65 | check=True, 66 | ) 67 | 68 | def cleanup_kernel_dir(self, kernel_dir): 69 | shutil.rmtree(kernel_dir) 70 | 71 | @pytest.fixture(scope="module") 72 | def setup_dir_relu(self): 73 | """Generate required directory structure and operators.""" 74 | self.setup_watermarked_kernel_dir( 75 | self.kernel_dir_relu, ["relu.default", "add.Tensor", "add.Scalar"] 76 | ) 77 | self.setup_watermarked_kernel_dir(self.kernel_dir_leaky_relu, ["leaky_relu.default"]) 78 | 79 | yield 80 | 81 | self.cleanup_kernel_dir(self.kernel_dir_relu) 82 | self.cleanup_kernel_dir(self.kernel_dir_leaky_relu) 83 | 84 | @pytest.mark.parametrize( 85 | "device", 86 | [ 87 | "cpu", 88 | pytest.param( 89 | "cuda", 90 | marks=pytest.mark.skipif( 91 | not torch.cuda.is_available(), reason="CUDA not available" 92 | ), 93 | ), 94 | ], 95 | ) 96 | def test_monkey_patch_relu(self, setup_dir_relu, device): 97 | BackendBench.disable() # In case monkey patching is enabled from previous test 98 | relu = torch.ops.aten.relu.default 99 | x = torch.tensor([-1.0, 0.0, 1.0], device=device) 100 | expected = torch.tensor([0.0, 0.0, 1.0], device=device) 101 | watermarked = torch.full_like(x, get_operator_watermark_value("relu.default")) 102 | 103 | torch.testing.assert_close(relu(x), expected) 104 | 105 | # Enable monkey patching 106 | BackendBench.enable(kernel_dir=self.kernel_dir_relu, dispatch_key=device.upper()) 107 | 108 | torch.testing.assert_close(relu(x), watermarked) 109 | 110 | # Disable monkey patching 111 | BackendBench.disable() 112 | torch.testing.assert_close(relu(x), expected) 113 | 114 | @pytest.mark.parametrize( 115 | "device", 116 | [ 117 | "cpu", 118 | pytest.param( 119 | "cuda", 120 | marks=pytest.mark.skipif( 121 | not torch.cuda.is_available(), reason="CUDA not available" 122 | ), 123 | ), 124 | ], 125 | ) 126 | def test_monkey_patch_add(self, setup_dir_relu, device): 127 | # This test ensures that monkey patching is applied only to the add.Tensor overload, 128 | # and not to add.Scalar. 
129 | BackendBench.disable() # In case monkey patching is enabled from previous test 130 | add__Scalar = torch.ops.aten.add.Scalar 131 | add__Tensor = torch.ops.aten.add.Tensor 132 | x = torch.tensor([-1.0, 0.0, 1.0], device=device) 133 | y_tensor = torch.tensor([1.0, 1.0, 1.0], device=device) 134 | y_scalar = 1.0 135 | expected = torch.tensor([0.0, 1.0, 2.0], device=device) 136 | watermarked = torch.full_like(x, get_operator_watermark_value("add.Tensor")) 137 | 138 | torch.testing.assert_close(add__Tensor(x, y_tensor), expected) 139 | torch.testing.assert_close(add__Scalar(x, y_scalar), expected) 140 | 141 | # Enable monkey patching 142 | BackendBench.enable(kernel_dir=self.kernel_dir_relu, dispatch_key=device.upper()) 143 | 144 | torch.testing.assert_close(add__Tensor(x, y_tensor), watermarked) 145 | torch.testing.assert_close(add__Scalar(x, y_scalar), expected) 146 | 147 | # Disable monkey patching 148 | BackendBench.disable() 149 | torch.testing.assert_close(add__Tensor(x, y_tensor), expected) 150 | torch.testing.assert_close(add__Scalar(x, y_scalar), expected) 151 | 152 | @pytest.mark.parametrize( 153 | "device", 154 | [ 155 | "cpu", 156 | pytest.param( 157 | "cuda", 158 | marks=pytest.mark.skipif( 159 | not torch.cuda.is_available(), reason="CUDA not available" 160 | ), 161 | ), 162 | ], 163 | ) 164 | def test_context_manager_relu(self, setup_dir_relu, device): 165 | """Test that context manager enables and disables correctly.""" 166 | BackendBench.disable() 167 | relu = torch.ops.aten.relu.default 168 | x = torch.tensor([-1.0, 0.0, 1.0], device=device) 169 | expected = torch.tensor([0.0, 0.0, 1.0], device=device) 170 | watermarked = torch.full_like(x, get_operator_watermark_value("relu.default")) 171 | 172 | torch.testing.assert_close(relu(x), expected) 173 | 174 | with BackendBench.BackendBench.enable( 175 | kernel_dir=self.kernel_dir_relu, dispatch_key=device.upper() 176 | ): 177 | torch.testing.assert_close(relu(x), watermarked) 178 | 179 | torch.testing.assert_close(relu(x), expected) 180 | 181 | def test_context_manager_nested_behavior(self, setup_dir_relu): 182 | """Test context manager behavior when BackendBench is already enabled.""" 183 | BackendBench.disable() 184 | relu = torch.ops.aten.relu.default 185 | leaky_relu = torch.ops.aten.leaky_relu.default 186 | x = torch.tensor([-1.0, 0.0, 1.0]) 187 | expected_relu = torch.tensor([0.0, 0.0, 1.0]) 188 | expected_leaky_relu = torch.tensor([-0.01, 0.0, 1.0]) 189 | watermarked_relu = torch.full_like(x, get_operator_watermark_value("relu.default")) 190 | watermarked_leaky_relu = torch.full_like( 191 | x, get_operator_watermark_value("leaky_relu.default") 192 | ) 193 | 194 | BackendBench.enable(kernel_dir=self.kernel_dir_relu, dispatch_key="CPU") 195 | 196 | torch.testing.assert_close(relu(x), watermarked_relu) 197 | torch.testing.assert_close(leaky_relu(x), expected_leaky_relu) 198 | 199 | with BackendBench.BackendBench.enable( 200 | kernel_dir=self.kernel_dir_leaky_relu, dispatch_key="CPU" 201 | ): 202 | torch.testing.assert_close(relu(x), watermarked_relu) 203 | torch.testing.assert_close(leaky_relu(x), watermarked_leaky_relu) 204 | 205 | torch.testing.assert_close(relu(x), watermarked_relu) 206 | torch.testing.assert_close(leaky_relu(x), expected_leaky_relu) 207 | 208 | BackendBench.disable() 209 | torch.testing.assert_close(relu(x), expected_relu) 210 | torch.testing.assert_close(leaky_relu(x), expected_leaky_relu) 211 | 212 | def test_context_manager_with_exception(self, setup_dir_relu): 213 | """Test that context manager 
properly disables even when exception occurs.""" 214 | BackendBench.disable() 215 | relu = torch.ops.aten.relu.default 216 | x = torch.tensor([-1.0, 0.0, 1.0]) 217 | expected = torch.tensor([0.0, 0.0, 1.0]) 218 | 219 | torch.testing.assert_close(relu(x), expected) 220 | 221 | try: 222 | with BackendBench.BackendBench.enable( 223 | kernel_dir=self.kernel_dir_relu, dispatch_key="CPU" 224 | ): 225 | raise ValueError("Test exception") 226 | except ValueError: 227 | pass 228 | 229 | torch.testing.assert_close(relu(x), expected) 230 | -------------------------------------------------------------------------------- /test/test_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import importlib.util 8 | 9 | import numpy as np 10 | import pytest 11 | import torch 12 | 13 | from BackendBench.eval import ( 14 | allclose, 15 | cpu_bench, 16 | eval_correctness, 17 | eval_correctness_test, 18 | eval_one_op, 19 | format_exception, 20 | perf_at_p, 21 | ) 22 | 23 | HAS_TRITON = importlib.util.find_spec("triton") is not None 24 | 25 | pytestmark = pytest.mark.skipif(not HAS_TRITON, reason="triton not available") 26 | 27 | 28 | class TestFormatFunctions: 29 | def test_format_exception(self): 30 | op = torch.ops.aten.relu.default 31 | args = [torch.randn(2, 3)] 32 | kwargs = {"dim": 1} 33 | exc = ValueError("Test error") 34 | 35 | formatted = format_exception(exc, op, args, kwargs) 36 | assert "relu.default" in formatted 37 | assert "(T([2, 3], f32)" in formatted 38 | assert "dim" in formatted 39 | assert "Test error" in formatted 40 | 41 | 42 | class TestAllclose: 43 | def test_allclose_tensors(self): 44 | tensor1 = torch.tensor([1.0, 2.0, 3.0]) 45 | tensor2 = torch.tensor([1.0, 2.0, 3.0]) 46 | 47 | assert allclose(tensor1, tensor2) is True 48 | 49 | tensor3 = torch.tensor([1.0, 2.0, 3.01]) 50 | assert allclose(tensor1, tensor3) is True 51 | 52 | tensor_nan1 = torch.tensor([1.0, float("nan"), 3.0]) 53 | tensor_nan2 = torch.tensor([1.0, float("nan"), 3.0]) 54 | assert allclose(tensor_nan1, tensor_nan2) is True 55 | 56 | def test_allclose_scalars(self): 57 | assert allclose(1, 1) is True 58 | assert allclose(1.0, 1.0) is True 59 | assert allclose("test", "test") is True 60 | assert allclose(1, 2) is False 61 | 62 | def test_allclose_tuples_lists_with_tolerances(self): 63 | """Test tuple/list comparison with specified tolerances""" 64 | atol, rtol = 1e-2, 1e-2 65 | 66 | # Lists of tensors - exact match 67 | list1 = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])] 68 | list2 = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])] 69 | assert allclose(list1, list2, atol=atol, rtol=rtol) is True 70 | 71 | # Lists of tensors - within tolerance 72 | list1 = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])] 73 | list2 = [torch.tensor([1.01, 2.01]), torch.tensor([3.01, 4.01])] 74 | assert allclose(list1, list2, atol=atol, rtol=rtol) is True 75 | 76 | # Lists of tensors - outside tolerance 77 | list1 = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])] 78 | list2 = [torch.tensor([1.1, 2.1]), torch.tensor([3.1, 4.1])] 79 | assert allclose(list1, list2, atol=atol, rtol=rtol) is False 80 | 81 | # Tuples of tensors 82 | tuple1 = (torch.tensor([1.0]), torch.tensor([2.0])) 83 | tuple2 = (torch.tensor([1.01]), torch.tensor([2.01])) 84 | assert 
allclose(tuple1, tuple2, atol=atol, rtol=rtol) is True 85 | 86 | # Nested structures 87 | nested1 = [[torch.tensor([1.0])], (torch.tensor([2.0]), torch.tensor([3.0]))] 88 | nested2 = [[torch.tensor([1.01])], (torch.tensor([2.01]), torch.tensor([3.01]))] 89 | assert allclose(nested1, nested2, atol=atol, rtol=rtol) is True 90 | 91 | # Length mismatch 92 | list1 = [torch.tensor([1.0]), torch.tensor([2.0])] 93 | list2 = [torch.tensor([1.0])] 94 | assert allclose(list1, list2, atol=atol, rtol=rtol) is False 95 | 96 | 97 | class TestEvalCorrectness: 98 | def test_eval_correctness_test_pass(self): 99 | # Use actual torch operations 100 | op = torch.relu 101 | impl = torch.relu # Same implementation should pass 102 | 103 | class TestCase: 104 | def __init__(self, args, kwargs): 105 | self.args = args 106 | self.kwargs = kwargs 107 | 108 | test = TestCase([torch.tensor([-1.0, 0.0, 1.0])], {}) 109 | 110 | result = eval_correctness_test(op, impl, test) 111 | assert result.is_correct is True 112 | 113 | def test_eval_correctness_test_fail(self): 114 | # Use different operations that produce different results 115 | op = torch.relu 116 | 117 | def impl(x): 118 | return x * 2 # Different implementation 119 | 120 | class TestCase: 121 | def __init__(self, args, kwargs): 122 | self.args = args 123 | self.kwargs = kwargs 124 | 125 | test = TestCase([torch.tensor([1.0, 2.0, 3.0])], {}) 126 | 127 | result = eval_correctness_test(op, impl, test) 128 | assert result.is_correct is False 129 | 130 | def test_eval_correctness_test_exception(self): 131 | op = torch.relu 132 | 133 | def impl_with_error(x): 134 | raise RuntimeError("Test error") 135 | 136 | class TestCase: 137 | def __init__(self, args, kwargs): 138 | self.args = args 139 | self.kwargs = kwargs 140 | 141 | test = TestCase([torch.tensor([1.0])], {}) 142 | 143 | # Just test that it returns False on exception 144 | result = eval_correctness_test(op, impl_with_error, test) 145 | assert result.is_correct is False 146 | assert result.error_msg is not None # Should have an error message 147 | 148 | def test_eval_correctness_multiple_tests(self): 149 | op = torch.abs 150 | impl = torch.abs # Same implementation 151 | 152 | class TestCase: 153 | def __init__(self, args, kwargs): 154 | self.args = args 155 | self.kwargs = kwargs 156 | 157 | tests = [] 158 | for i in range(5): 159 | test = TestCase([torch.tensor([float(i) - 2.5])], {}) 160 | tests.append(test) 161 | 162 | score, correctness_results = eval_correctness(op, impl, tests) 163 | assert score == 1.0 164 | assert len(correctness_results) == len(tests) 165 | 166 | 167 | class TestEvalPerformance: 168 | def test_cpu_bench(self): 169 | counter = 0 170 | 171 | def test_fn(): 172 | nonlocal counter 173 | counter += 1 174 | 175 | # Actually run the benchmark 176 | time_per_run = cpu_bench(test_fn, num_runs=10) 177 | 178 | # Should have run 10 warmup runs + 10 actual runs = 20 total 179 | assert counter == 20 180 | assert time_per_run > 0 181 | 182 | 183 | class TestEvalOneOp: 184 | def test_eval_one_op(self): 185 | op = torch.relu 186 | impl = torch.relu # Same implementation 187 | 188 | class TestCase: 189 | def __init__(self, args, kwargs): 190 | self.args = args 191 | self.kwargs = kwargs 192 | 193 | correctness_tests = [TestCase([torch.tensor([-1.0, 0.0, 1.0])], {}) for _ in range(3)] 194 | performance_tests = [TestCase([torch.tensor([-1.0, 0.0, 1.0])], {}) for _ in range(2)] 195 | 196 | correctness, performance, correctness_results, performance_results = eval_one_op( 197 | op, impl, correctness_tests, 
performance_tests 198 | ) 199 | 200 | # Should have perfect correctness since using same implementation 201 | assert correctness == 1.0 202 | # Performance should be around 1.0 (same speed) 203 | assert performance.item() > 0 204 | # Verbose data should be populated 205 | assert len(correctness_results) == len(correctness_tests) 206 | assert len(performance_results) == len(performance_tests) 207 | 208 | 209 | def fastp_kernel_bench( 210 | is_correct: np.ndarray, 211 | baseline_speed: np.ndarray, 212 | actual_speed: np.ndarray, 213 | n: int, 214 | p: float, 215 | ) -> float: 216 | """ 217 | Original fastp implementation from kernelBench 218 | """ 219 | filtered_baseline_speed = np.array([x for i, x in enumerate(baseline_speed) if is_correct[i]]) 220 | filtered_actual_speed = np.array([x for i, x in enumerate(actual_speed) if is_correct[i]]) 221 | speed_up = filtered_baseline_speed / filtered_actual_speed 222 | fast_p_score = np.sum(speed_up > p) 223 | return fast_p_score / n if n > 0 else 0 224 | 225 | 226 | class TestPerfAtP: 227 | def get_results(self, num_tests=100): 228 | overall_correctness = np.random.randint(0, 2, size=num_tests) 229 | overall_performance = np.random.uniform(0.5, 2, size=num_tests) 230 | return overall_correctness, overall_performance 231 | 232 | def test_perf_at_p(self): 233 | for num_tests in [5, 10, 50, 100]: 234 | for p in [0, 1, 1.5, 2]: 235 | overall_correctness, overall_performance = self.get_results(num_tests) 236 | 237 | actual_speed = np.random.randint(1, 101, size=num_tests) 238 | baseline_speed = actual_speed * overall_performance 239 | fastp_score_orig = fastp_kernel_bench( 240 | overall_correctness, baseline_speed, actual_speed, num_tests, p 241 | ) 242 | 243 | # Note: The perf@p score calculation here differs subtly from the original fastp score in 244 | # kernel bench. The original fastp score filters correct samples first, then averages. 245 | # Here, perf@p averages first, then filters correct samples. Despite this difference, 246 | # both methods produce equivalent results, so the test remains valid. 247 | perf_at_p_score = perf_at_p( 248 | overall_correctness.tolist(), overall_performance.tolist(), p 249 | ) 250 | 251 | assert torch.allclose( 252 | perf_at_p_score, torch.tensor(fastp_score_orig, dtype=torch.float32) 253 | ) 254 | -------------------------------------------------------------------------------- /BackendBench/data_loaders.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Shared data loading utilities for reading trace and parquet files. 
9 | """ 10 | 11 | import hashlib 12 | import logging 13 | import re 14 | from pathlib import Path 15 | from typing import Dict, List, Optional, Union 16 | 17 | import pyarrow.parquet as pq 18 | import requests 19 | import torch 20 | from datasets import load_dataset 21 | from tqdm import tqdm 22 | 23 | # constants for downloading the test set from huggingface 24 | # you can explore the dataset here 25 | # https://huggingface.co/datasets/GPUMODE/backendbench_tests 26 | HUGGINGFACE_REPO = "GPUMODE/backendbench_tests" 27 | TORCHBENCH_SUITE_HF_COMMIT = "ca7b1361b162d1499cb22ea4ad589dae506ead5d" 28 | TORCHBENCH_SUITE_FILE = "backend_bench_problems.parquet" 29 | 30 | 31 | def _args_size(args): 32 | """Calculate the size of arguments in bytes.""" 33 | 34 | size = 0 35 | for arg in args: 36 | if isinstance(arg, torch.Tensor): 37 | size += arg.numel() * arg.element_size() 38 | elif isinstance(arg, (tuple, list)): 39 | size += _args_size(arg) 40 | return size 41 | 42 | 43 | def _parse_trace_file( 44 | filename: str, filter: Optional[List[str]] = None, limit: Optional[int] = None 45 | ) -> List[Dict]: 46 | """ 47 | Parse a single trace file and return a list of operation dictionaries. 48 | 49 | Args: 50 | filename: Path to trace file 51 | filter: Optional list of operation name filters 52 | """ 53 | op_inputs = [] 54 | op = None 55 | num_ops = 0 56 | 57 | with open(filename, "r") as f: 58 | lines = list(f) 59 | print(f"parsing {len(lines)} lines from {filename}") 60 | iterator = tqdm(lines, desc=f"Parsing {Path(filename).name}") 61 | for line in iterator: 62 | if m := re.match("Operator: (.*)", line): 63 | num_ops += 1 64 | if limit: 65 | if num_ops > limit: 66 | break 67 | op = m.group(1) 68 | # this is due to a version skew error of the pytorch version we're 69 | # using for developing BackendBench and what was used in tritonbench where 70 | # SymInt didn't exist. 71 | # @todo: see if we can remove this before releasing 72 | if op == "aten.sum.SymInt": 73 | op = "aten.sum.dim_IntList" 74 | if m := re.match("cnt: \\d+, (.*)", line): 75 | assert op is not None 76 | args_str = m.group(1) 77 | cnt = int(m.group(0).split(",")[0].split(":")[1]) 78 | 79 | if filter is None or any(f in op for f in filter): 80 | is_synthetic = cnt == 0 81 | 82 | op_inputs.append( 83 | { 84 | "uuid": hashlib.sha256(args_str.encode() + op.encode()).hexdigest(), 85 | "op_name": op, 86 | "args": args_str, 87 | "count": cnt, 88 | "is_synthetic": is_synthetic, 89 | } 90 | ) 91 | return op_inputs 92 | 93 | 94 | def _parse_trace_stream( 95 | stream, 96 | filter: Optional[List[str]] = None, 97 | desc: str = "Parsing stream", 98 | limit: Optional[int] = None, 99 | ) -> List[Dict]: 100 | """ 101 | Parse trace data from a text stream (e.g., from requests.Response.iter_lines()). 
102 | 103 | Args: 104 | stream: Iterable of lines (strings or bytes) 105 | filter: Optional list of operation name filters 106 | desc: Description for progress bar 107 | """ 108 | op_inputs = [] 109 | op = None 110 | num_ops = 0 111 | 112 | iterator = tqdm(stream, desc=desc, total=len(stream)) 113 | 114 | for line in iterator: 115 | # Handle bytes from response stream 116 | if isinstance(line, bytes): 117 | line = line.decode("utf-8") 118 | 119 | if m := re.match("Operator: (.*)", line): 120 | num_ops += 1 121 | if limit: 122 | if num_ops > limit: 123 | break 124 | op = m.group(1) 125 | if op == "aten.sum.SymInt": 126 | op = "aten.sum.dim_IntList" 127 | if m := re.match("cnt: \\d+, (.*)", line): 128 | assert op is not None 129 | args_str = m.group(1) 130 | cnt = int(m.group(0).split(",")[0].split(":")[1]) 131 | 132 | if filter is None or any(f in op for f in filter): 133 | is_synthetic = cnt == 0 134 | 135 | op_inputs.append( 136 | { 137 | "uuid": hashlib.sha256(args_str.encode() + op.encode()).hexdigest(), 138 | "op_name": op, 139 | "args": args_str, 140 | "count": cnt, 141 | "is_synthetic": is_synthetic, 142 | } 143 | ) 144 | return op_inputs 145 | 146 | 147 | def _detect_format(source: Union[str, Path, None]) -> str: 148 | """Detect format based on source type and extension.""" 149 | if source is None: 150 | return "parquet" 151 | 152 | if not isinstance(source, (str, Path)): 153 | raise ValueError(f"Unsupported source type: {type(source)}, should be str, Path, or None") 154 | 155 | source_str = str(source) 156 | if source_str.endswith(".parquet"): 157 | return "parquet" 158 | elif source_str.endswith(".txt"): 159 | return "trace" 160 | else: 161 | raise ValueError( 162 | f"Cannot auto-detect format for source: {source}. Please specify format explicitly." 163 | ) 164 | 165 | 166 | def load_ops_from_source( 167 | source: Union[str, Path, None], 168 | format: str = "auto", 169 | filter: Optional[List[str]] = None, 170 | ) -> List[Dict]: 171 | """ 172 | Load operation data from various sources and formats. 173 | 174 | Args: 175 | source: File path or URL (only trace) or None. If None, use huggingface dataset for parquet mode (default). 176 | format: "trace", "parquet", or "auto" (detect from file extension) 177 | filter: Optional list of operation name filters 178 | 179 | Returns: 180 | List of dictionaries with detailed operation info 181 | 182 | Auto-detection behavior: 183 | - None → parquet format test set from huggingface (default) 184 | - *.parquet → parquet format 185 | - *.txt → trace format 186 | - http*.txt → trace format 187 | - Other extensions → error (must specify format explicitly) 188 | """ 189 | # Format detection/validation 190 | if format == "auto": 191 | format = _detect_format(source) 192 | elif format not in ("parquet", "trace"): 193 | raise ValueError(f"Unsupported format: {format}") 194 | 195 | # Dispatch to appropriate loader 196 | loaders = {"parquet": _load_from_parquet, "trace": _load_from_trace} 197 | 198 | return loaders[format](source, filter) 199 | 200 | 201 | def _load_from_parquet( 202 | source: Optional[Union[str, Path]] = None, filter: Optional[List[str]] = None 203 | ): 204 | """ 205 | Load operations from parquet file or URL. 206 | 207 | Args: 208 | source: Local file path or None. If None, use huggingface dataset (default). 
209 | filter: Optional list of strings to filter operation names 210 | 211 | Returns: 212 | List of dictionaries containing the data 213 | """ 214 | 215 | if source is None: 216 | # read parquet file from huggingface 217 | table = load_dataset( 218 | HUGGINGFACE_REPO, 219 | data_files=TORCHBENCH_SUITE_FILE, 220 | revision=TORCHBENCH_SUITE_HF_COMMIT, 221 | )["train"] 222 | else: 223 | # read parquet file directly 224 | table = pq.read_table(source) 225 | 226 | df = table.to_pandas() 227 | # Apply filter if provided 228 | if filter: 229 | mask = df["op_name"].apply(lambda op: any(f in op for f in filter)) 230 | df = df[mask] 231 | 232 | return df.to_dict("records") 233 | 234 | 235 | def op_list_to_benchmark_dict(ops_list: List[Dict]) -> Dict[str, List[str]]: 236 | """ 237 | Convert a list of operation dictionaries to a dictionary format which can be used for benchmarking. 238 | 239 | Args: 240 | ops_list: List of dicts with 'op_name' and 'args' keys 241 | 242 | Returns: 243 | Dictionary mapping op_name to list of args strings 244 | """ 245 | result = {} 246 | for op_data in ops_list: 247 | if not op_data["included_in_benchmark"]: 248 | continue 249 | op_name = op_data["op_name"] 250 | args = op_data["args"] 251 | if op_name not in result: 252 | result[op_name] = [] 253 | result[op_name].append(args) 254 | return result 255 | 256 | 257 | def _load_from_trace( 258 | source: Union[str, Path], filter: Optional[List[str]], limit: Optional[int] = None 259 | ) -> List[Dict]: 260 | """Load operations from trace file(s) and return list of dicts.""" 261 | op_inputs = [] 262 | 263 | # Handle URLs - stream directly without saving to disk 264 | if isinstance(source, str) and (source.startswith("http://") or source.startswith("https://")): 265 | logging.info(f"Downloading trace from {source}") 266 | with requests.get(source) as response: 267 | response.raise_for_status() 268 | 269 | # Download entire content 270 | content = response.text 271 | 272 | # Create an iterator from the lines for the progress bar 273 | lines = content.splitlines() 274 | 275 | # Now parse with accurate progress (tqdm will know total lines) 276 | op_inputs = _parse_trace_stream(lines, filter, "Parsing", limit=limit) 277 | 278 | # Handle single files 279 | else: 280 | op_inputs = _parse_trace_file(source, filter, limit=limit) 281 | 282 | return op_inputs 283 | -------------------------------------------------------------------------------- /BackendBench/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 
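# Correctness and performance evaluation helpers: per-test result dataclasses,
# tolerance-based output comparison, and benchmarking of candidate kernel
# implementations against the reference aten operator.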
6 | 7 | import logging 8 | import math 9 | import traceback 10 | from dataclasses import dataclass 11 | from typing import List, Tuple 12 | 13 | import torch 14 | 15 | from BackendBench.utils import compute_errors, serialize_args, uses_cuda_stream 16 | 17 | 18 | @dataclass 19 | class CorrectnessTestResult: 20 | op_name: str 21 | args: str 22 | is_correct: bool = False 23 | error_msg: str = "" 24 | error_type: str = "" 25 | traceback: str = "" 26 | max_abs_error: float = -math.inf 27 | max_rel_error: float = -math.inf 28 | test_type: str = "correctness" 29 | 30 | 31 | @dataclass 32 | class PerformanceTestResult: 33 | op_name: str 34 | args: str 35 | speedup: float 36 | benchmark_time_ms: float 37 | reference_time_ms: float 38 | error_msg: str = "" 39 | successfully_ran: bool = False 40 | test_type: str = "performance" 41 | 42 | 43 | try: 44 | if torch.cuda.is_available(): 45 | import triton.testing 46 | 47 | TRITON_AVAILABLE = True 48 | else: 49 | TRITON_AVAILABLE = False 50 | except ImportError: 51 | TRITON_AVAILABLE = False 52 | 53 | logger = logging.getLogger(__name__) 54 | 55 | EXC_MSG = """ 56 | Exception raised for {op}: 57 | args: {args} 58 | exc: {exc} 59 | traceback: {traceback} 60 | """ 61 | 62 | 63 | def format_exception(e, op, args, kwargs, traceback=None): 64 | op_name = getattr(op, "__name__", str(op)) 65 | return EXC_MSG.format(op=op_name, args=serialize_args(args, kwargs), exc=e, traceback=traceback) 66 | 67 | 68 | def _allclose(a, b, atol=1e-2, rtol=1e-2): 69 | # using a stack to avoid recursion overflow issues 70 | stack = [(a, b)] 71 | 72 | while len(stack) > 0: 73 | curr_a, curr_b = stack.pop() 74 | 75 | if isinstance(curr_a, torch.Tensor): 76 | torch.testing.assert_close(curr_a, curr_b, equal_nan=True, atol=atol, rtol=rtol) 77 | elif isinstance(curr_a, (list, tuple)): 78 | assert len(curr_a) == len(curr_b) 79 | # Add pairs to stack in reverse order to maintain left-to-right checking 80 | stack.extend(reversed(list(zip(curr_a, curr_b)))) 81 | else: 82 | assert curr_a == curr_b 83 | 84 | 85 | def allclose(a, b, atol=1e-2, rtol=1e-2): 86 | try: 87 | _allclose(a, b) 88 | return True 89 | except Exception: 90 | return False 91 | 92 | 93 | def eval_correctness_test(op, impl, test) -> CorrectnessTestResult: 94 | """Evaluate impl of op against test. 
95 | 96 | Returns: 97 | Tuple of (is_correct, error_message, absolute_error, relative_error) 98 | """ 99 | args, kwargs = test.args, test.kwargs 100 | ref = op(*args, **kwargs) 101 | try: 102 | res = impl(*args, **kwargs) 103 | is_correct = allclose(ref, res) 104 | 105 | abs_error, rel_error = compute_errors(ref, res) 106 | result = CorrectnessTestResult( 107 | op_name=op.__name__, 108 | args=serialize_args(args, kwargs), 109 | is_correct=is_correct, 110 | max_abs_error=abs_error, 111 | max_rel_error=rel_error, 112 | ) 113 | return result 114 | except Exception as e: 115 | error_msg = format_exception(e, op, args, kwargs, traceback.format_exc()) 116 | result = CorrectnessTestResult( 117 | op_name=op.__name__, 118 | args=serialize_args(args, kwargs), 119 | is_correct=False, 120 | error_msg=error_msg, 121 | error_type=str(type(e)), 122 | traceback=traceback.format_exc(), 123 | ) 124 | logger.warning(error_msg) 125 | return result 126 | 127 | 128 | def eval_correctness(op, impl, tests) -> Tuple[float, List[CorrectnessTestResult]]: 129 | """Evaluate correctness of impl against tests.""" 130 | correct, total = 0, 0 131 | test_results: List[CorrectnessTestResult] = [] 132 | for test in tests: 133 | args_str = serialize_args(test.args, test.kwargs) 134 | logging.debug(f"Testing {op.__name__} with args {args_str}") 135 | result = eval_correctness_test(op, impl, test) 136 | test_results.append(result) 137 | if result.is_correct: 138 | correct += 1 139 | total += 1 140 | 141 | # Handle the case where no tests are available 142 | if total == 0: 143 | logger.warning(f"No correctness tests available for {str(op)}") 144 | return 0.0, [] 145 | 146 | return correct / total, test_results 147 | 148 | 149 | def cpu_bench(fn, num_runs=100): 150 | """Simple CPU benchmarking using time.perf_counter.""" 151 | import time 152 | 153 | for _ in range(10): 154 | fn() 155 | 156 | start = time.perf_counter() 157 | for _ in range(num_runs): 158 | fn() 159 | return (time.perf_counter() - start) / num_runs 160 | 161 | 162 | def eval_performance(op, impl, tests) -> Tuple[float, List[PerformanceTestResult]]: 163 | """Evaluate performance of impl against tests.""" 164 | bench_fn = ( 165 | triton.testing.do_bench if TRITON_AVAILABLE and torch.cuda.is_available() else cpu_bench 166 | ) 167 | base_times = [] 168 | test_times = [] 169 | args_strs = [] 170 | performance_results: List[PerformanceTestResult] = [] 171 | 172 | for test in tests: 173 | # Cache the arguments to ensure consistency between reference and implementation 174 | cached_args = test.args 175 | cached_kwargs = test.kwargs 176 | args_str = serialize_args(cached_args, cached_kwargs) 177 | args_strs.append(args_str) 178 | logging.debug(f"Benchmarking {op.__name__} with args {args_str}") 179 | base_time = bench_fn(lambda: op(*cached_args, **cached_kwargs)) 180 | base_times.append(base_time) 181 | # Note: If the test fails we consider the speedup to be 1.0 182 | # TODO: We should make this more explicit, by having an if resolving it in the except and removing the finally block 183 | test_time = base_time 184 | try: 185 | ref = op(*cached_args, **cached_kwargs) 186 | res = impl(*cached_args, **cached_kwargs) 187 | if not allclose( 188 | ref, 189 | res, 190 | ): 191 | abs_error, rel_error = compute_errors(ref, res) 192 | raise ValueError( 193 | f"Reference and result tensors are not close: max absolute error {abs_error}, max relative error {rel_error}" 194 | ) 195 | test_time = bench_fn(lambda: impl(*cached_args, **cached_kwargs)) 196 | performance_results.append( 197 
| PerformanceTestResult( 198 | op_name=op.__name__, 199 | args=args_str, 200 | speedup=base_time / test_time, 201 | successfully_ran=True, 202 | benchmark_time_ms=test_time, 203 | reference_time_ms=base_time, 204 | ) 205 | ) 206 | except Exception as e: 207 | error_msg = format_exception(e, op, test.args, test.kwargs, traceback.format_exc()) 208 | performance_results.append( 209 | PerformanceTestResult( 210 | op_name=op.__name__, 211 | args=args_str, 212 | successfully_ran=False, 213 | speedup=None, 214 | benchmark_time_ms=None, 215 | reference_time_ms=base_time, 216 | error_msg=error_msg, 217 | ) 218 | ) 219 | finally: 220 | test_times.append(test_time) 221 | 222 | speedups = torch.tensor(base_times) / torch.tensor(test_times) 223 | 224 | return speedups.log().mean().exp(), performance_results 225 | 226 | 227 | def eval_one_op( 228 | op, impl, correctness_tests, performance_tests 229 | ) -> Tuple[float, float, List[CorrectnessTestResult], List[PerformanceTestResult]]: 230 | """Evaluate impl of op against correctness_tests and performance_tests. 231 | 232 | Returns: 233 | Tuple of (correctness_score, performance_score, correctness_results, performance_results) 234 | """ 235 | 236 | if uses_cuda_stream(impl): 237 | logger.warning(f"Skipping {op.__name__} because it uses CUDA stream") 238 | performance_results = [] 239 | correctness_results = [] 240 | for test in correctness_tests: 241 | args_str = serialize_args(test.args, test.kwargs) 242 | correctness_results.append( 243 | CorrectnessTestResult( 244 | op_name=op.__name__, 245 | args=args_str, 246 | is_correct=False, 247 | error_msg="Skipped: uses CUDA stream", 248 | ) 249 | ) 250 | for test in performance_tests: 251 | args_str = serialize_args(test.args, test.kwargs) 252 | performance_results.append( 253 | PerformanceTestResult( 254 | op_name=op.__name__, 255 | args=args_str, 256 | speedup=0, 257 | benchmark_time_ms=0, 258 | reference_time_ms=0, 259 | error_msg="Skipped: uses CUDA stream", 260 | ) 261 | ) 262 | return 0, 1.0, correctness_results, performance_results 263 | 264 | correctness_score, correctness_results = eval_correctness(op, impl, correctness_tests) 265 | performance_score, performance_results = eval_performance(op, impl, performance_tests) 266 | return ( 267 | correctness_score, 268 | performance_score, 269 | correctness_results, 270 | performance_results, 271 | ) 272 | 273 | 274 | def perf_at_p(correctness, performance, p=1.0): 275 | assert len(correctness) == len(performance), ( 276 | "correctness and performance must have the same length" 277 | ) 278 | return ( 279 | torch.where(torch.tensor(correctness).bool(), torch.tensor(performance) > p, 0) 280 | .float() 281 | .mean() 282 | ) 283 | -------------------------------------------------------------------------------- /BackendBench/scripts/parquet_trace_converter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 
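# Usage sketch (illustrative; file names here are hypothetical examples). The click
# CLI defined at the bottom of this file converts between trace and parquet formats,
# and local parquet outputs are placed under datasets/ by _validate_parquet_name():
#
#   python BackendBench/scripts/parquet_trace_converter.py \
#       --mode trace-to-parquet \
#       --trace-file datasets/example_trace.txt \
#       --parquet-name example_problems.parquet \
#       --limit 100
#
# The same conversion can be driven programmatically via convert_trace_to_parquet()
# below, e.g. convert_trace_to_parquet("datasets/example_trace.txt",
# "datasets/example_problems.parquet", limit=100).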
6 | 7 | # utility functions to convert parquet and trace files back and forth 8 | 9 | import hashlib 10 | import logging 11 | import os 12 | from collections import defaultdict 13 | from pathlib import Path 14 | 15 | import click 16 | import numpy as np 17 | import pyarrow as pa 18 | import pyarrow.parquet as pq 19 | from huggingface_hub import HfApi 20 | 21 | from BackendBench.data_loaders import _load_from_trace 22 | from BackendBench.scripts.dataset_filters import ( 23 | apply_runtime_filter, 24 | apply_skip_ops_filter, 25 | ) 26 | 27 | DEFAULT_TRACE_URL = "https://huggingface.co/datasets/GPUMODE/huggingface_op_trace/resolve/main/augmented_hf_op_traces.txt" 28 | DEFAULT_PARQUET_URL = "https://huggingface.co/datasets/GPUMODE/huggingface_op_trace/resolve/main/backend_bench_problems.parquet" 29 | 30 | 31 | """ 32 | Columns for the parquet dataset: 33 | - uuid (int) (hash of op + args) 34 | - op_name (string) 35 | - args (string) 36 | - count (int) (number of times this op + set of args was called in real models) 37 | - is_synthetic (boolean) (did we generate this op or is it from a real model) 38 | - included_in_benchmark (boolean) 39 | - why_excluded (list of strings) (empty if included) 40 | - runtime_ms (float) (timings on H100 gpu) 41 | - runnable (bool) (does this op + test work) [we may remove this column later after we solve for special ops] 42 | - in_models (string) (which models did we include this op in) [@TODO add this] 43 | """ 44 | 45 | logger = logging.getLogger(__name__) 46 | 47 | 48 | def _upload_to_hf(file_path: str) -> None: 49 | """Upload file to GPUMODE/huggingface_op_trace.""" 50 | try: 51 | api = HfApi() 52 | api.upload_file( 53 | path_or_fileobj=file_path, 54 | path_in_repo=Path(file_path).name, 55 | repo_id="GPUMODE/huggingface_op_trace", 56 | repo_type="dataset", 57 | ) 58 | logger.info(f"Uploaded {Path(file_path).name} to Hugging Face") 59 | except Exception as e: 60 | logger.warning(f"Failed to upload {Path(file_path).name}: {e}") 61 | 62 | 63 | def setup_logging(log_level): 64 | """Configure logging with the specified level.""" 65 | numeric_level = getattr(logging, log_level.upper(), None) 66 | if not isinstance(numeric_level, int): 67 | raise ValueError(f"Invalid log level: {log_level}") 68 | 69 | logging.basicConfig( 70 | level=numeric_level, 71 | format="[%(asctime)s][%(levelname)s][%(filename)s] %(message)s", 72 | datefmt="%Y-%m-%d %H:%M:%S", 73 | handlers=[ 74 | logging.FileHandler("logs/parquet_trace_converter.log"), 75 | logging.StreamHandler(), # Also print to console 76 | ], 77 | ) 78 | 79 | 80 | def convert_trace_to_parquet(trace_file, parquet_file, limit: int = None): 81 | """ 82 | Convert a trace file to a parquet file 83 | """ 84 | 85 | # Load operations using local trace parsing function 86 | ops = _load_from_trace(trace_file, filter=None, limit=limit) 87 | 88 | # Add additional metadata fields required for the parquet format 89 | for op in ops: 90 | op["uuid"] = hashlib.sha256(op["args"].encode() + op["op_name"].encode()).hexdigest() 91 | op["included_in_benchmark"] = True 92 | op["why_excluded"] = [] 93 | op["runtime_ms"] = np.nan 94 | op["relative_runtime_to_kernel_launch"] = np.nan 95 | op["runnable"] = True 96 | op["is_overhead_dominated_op"] = False 97 | 98 | # apply filters 99 | ops = apply_skip_ops_filter(ops) 100 | ops = apply_runtime_filter(ops) 101 | 102 | exclusion_dict = defaultdict(lambda: 0) 103 | exclusion_mapping = defaultdict(lambda: set()) 104 | testable_ops = set() 105 | all_ops = set() 106 | for op in ops: 107 | for reason in 
op["why_excluded"]: 108 | exclusion_dict[reason] += 1 109 | exclusion_mapping[reason].add(op["op_name"]) 110 | if op["included_in_benchmark"]: 111 | testable_ops.add(op["op_name"]) 112 | all_ops.add(op["op_name"]) 113 | non_testable_ops = all_ops - testable_ops 114 | 115 | for reason, count in exclusion_dict.items(): 116 | logger.info(f"Excluded tests from {count} / {len(ops)} ops due to {reason}") 117 | for reason in exclusion_mapping.keys(): 118 | no_op_set = exclusion_mapping[reason].intersection(non_testable_ops) 119 | list_str = "\n".join(no_op_set) 120 | logger.info( 121 | f"Excluded the following {len(no_op_set)}/{len(all_ops)} ops and input combinations at least partially due to the reason: {reason}:\n {list_str}" 122 | ) 123 | list_str = "\n".join(non_testable_ops) 124 | logger.info( 125 | f"Excluded {len(non_testable_ops)} / {len(all_ops)} ops due to not having tests. They are as follows: {list_str}" 126 | ) 127 | 128 | # Some logging about performance canaries 129 | overhead_dominated_ops = [op for op in ops if op["is_overhead_dominated_op"]] 130 | overhead_dominated_op_names = {op["op_name"] for op in overhead_dominated_ops} 131 | logger.info( 132 | f"Found {len(overhead_dominated_ops)} / {len(ops)} tests that are dominated by overhead" 133 | ) 134 | logger.info( 135 | f"Found {len(overhead_dominated_op_names)} / {len(all_ops)} unique ops that are dominated by overhead" 136 | ) 137 | 138 | # Create parquet table with all metadata (formerly "dev" version) 139 | table = pa.Table.from_pylist(ops) 140 | 141 | # Write parquet file 142 | pq.write_table(table, parquet_file) 143 | 144 | logger.info(f"Wrote {len(ops)} ops and inputs to {parquet_file}") 145 | 146 | # Log column information for verification 147 | logger.debug(f"Parquet columns: {table.column_names}") 148 | 149 | 150 | def convert_parquet_to_trace(parquet_file, trace_file, limit: int = None): 151 | """ 152 | Convert a parquet file to a trace file 153 | """ 154 | table = pq.read_table(parquet_file) 155 | op_inputs = {} 156 | 157 | for row in table.to_pylist(): 158 | formatted_entry = f"cnt: {row['count']}, {row['args']}" 159 | 160 | if row["op_name"] not in op_inputs: 161 | op_inputs[row["op_name"]] = [] 162 | op_inputs[row["op_name"]].append(formatted_entry) 163 | if limit: 164 | op_inputs = op_inputs[:limit] 165 | 166 | # write to trace file 167 | with open(trace_file, "w") as f: 168 | for op, args in op_inputs.items(): 169 | f.write(f"Operator: {op}\n") 170 | for arg in args: 171 | f.write(f"{arg}\n") 172 | total_args = sum(len(op_inputs[op]) for op in op_inputs) 173 | logging.info(f"Wrote {total_args} ops and inputs to {trace_file}") 174 | 175 | 176 | def _validate_parquet_name(parquet_name: str) -> str: 177 | """Validate parquet filename. URLs allowed only for inputs.""" 178 | # URLs are allowed only if this is an input file 179 | if parquet_name.startswith(("http://", "https://")): 180 | raise click.BadParameter("Output parquet file cannot be a URL") 181 | 182 | if not parquet_name.endswith(".parquet"): 183 | raise click.BadParameter("Parquet file must end with .parquet suffix") 184 | 185 | # Ensure local files are in datasets directory 186 | if not parquet_name.startswith("datasets/"): 187 | parquet_name = os.path.join("datasets", parquet_name) 188 | 189 | return parquet_name 190 | 191 | 192 | def _validate_trace_file(trace_file: str, is_input: bool = True) -> str: 193 | """Validate trace file. 
URLs allowed only for inputs.""" 194 | # URLs are allowed only if this is an input file 195 | if trace_file.startswith(("http://", "https://")): 196 | if is_input: 197 | return trace_file 198 | else: 199 | raise click.BadParameter("Output trace file cannot be a URL") 200 | 201 | # For local files, check extension 202 | if not (trace_file.endswith(".txt") or Path(trace_file).is_dir()): 203 | raise click.BadParameter("Local trace file must end with .txt or be a directory") 204 | 205 | if Path(trace_file).is_dir() and not is_input: 206 | raise click.BadParameter("Output trace file cannot be a directory") 207 | 208 | return trace_file 209 | 210 | 211 | @click.command() 212 | @click.option( 213 | "--log-level", 214 | default=os.getenv("LOG_LEVEL", "INFO"), 215 | type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False), 216 | help="Set the logging level", 217 | ) 218 | @click.option( 219 | "--mode", 220 | default="trace-to-parquet", 221 | type=click.Choice(["trace-to-parquet", "parquet-to-trace"]), 222 | help="Conversion mode", 223 | ) 224 | @click.option( 225 | "--trace-file", 226 | default=DEFAULT_TRACE_URL, 227 | type=str, 228 | help="Input trace file: URL (for downloads), local .txt file, or directory. Output trace files cannot be URLs", 229 | ) 230 | @click.option( 231 | "--parquet-name", 232 | default="backend_bench_problems.parquet", 233 | type=str, 234 | help="Parquet filename: URL allowed as input in parquet-to-trace mode, local files in datasets/.", 235 | ) 236 | @click.option( 237 | "--upload-to-hf", 238 | is_flag=True, 239 | default=False, 240 | help="Upload generated parquet files to Hugging Face (GPUMODE/huggingface_op_trace) in trace-to-parquet mode", 241 | ) 242 | @click.option( 243 | "--limit", 244 | default=None, 245 | type=int, 246 | help="Limit the number of operators to convert. 
(Useful for testing)", 247 | ) 248 | def main(log_level, mode, trace_file, parquet_name, upload_to_hf, limit): 249 | """Convert trace files to parquet format or vice versa.""" 250 | setup_logging(log_level) 251 | 252 | # Create datasets directory 253 | os.makedirs("datasets", exist_ok=True) 254 | 255 | if mode == "trace-to-parquet": 256 | # Validate inputs/outputs 257 | trace_file = _validate_trace_file(trace_file, is_input=True) # Input: URLs allowed 258 | parquet_name = _validate_parquet_name(parquet_name) # Output: URLs not allowed 259 | 260 | logger.info(f"Converting trace file {trace_file} to parquet file {parquet_name}") 261 | 262 | convert_trace_to_parquet(trace_file, parquet_name, limit=limit) 263 | logger.info("Conversion completed successfully") 264 | 265 | if upload_to_hf: 266 | # Upload to Hugging Face 267 | _upload_to_hf(os.path.abspath(parquet_name)) 268 | 269 | elif mode == "parquet-to-trace": 270 | # Validate parquet input (URLs allowed for input in this mode) 271 | parquet_input = _validate_parquet_name(parquet_name) 272 | # Validate trace output (URLs not allowed for output) 273 | trace_output = _validate_trace_file(trace_file, is_input=False) # Output: URLs not allowed 274 | 275 | logger.info(f"Converting parquet file {parquet_input} to trace file {trace_output}") 276 | convert_parquet_to_trace(parquet_input, trace_output, limit=limit) 277 | logger.info("Conversion completed successfully") 278 | 279 | 280 | if __name__ == "__main__": 281 | main() 282 | -------------------------------------------------------------------------------- /BackendBench/backends/kernel_agent.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | from typing import Callable, Dict 10 | 11 | from BackendBench.utils import compile_kernel_from_string, op_name_to_folder_name 12 | 13 | from .base import Backend 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class KernelAgentBackend(Backend): 19 | """ 20 | Backend that uses KernelAgent for sophisticated parallel kernel generation. 21 | 22 | This backend leverages KernelAgent's advanced features: 23 | - Parallel workers with iterative refinement 24 | - Multi-turn conversation history 25 | - Comprehensive prompt engineering with Triton guidelines 26 | - Automatic test generation 27 | """ 28 | 29 | def __init__(self) -> None: 30 | super().__init__("kernel_agent") 31 | self.compiled_kernels: Dict[str, Callable] = {} 32 | 33 | # Create generated_kernels directory 34 | import datetime 35 | 36 | timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 37 | self.kernels_dir = f"generated_kernels/kernel_agent_run_{timestamp}" 38 | os.makedirs(self.kernels_dir, exist_ok=True) 39 | 40 | # Create README for this run 41 | readme_path = os.path.join(self.kernels_dir, "README.md") 42 | with open(readme_path, "w") as f: 43 | f.write( 44 | f"""# Generated Kernels - KernelAgent - {timestamp} 45 | 46 | This directory contains PyTorch/Triton kernels generated by the KernelAgent Backend. 47 | 48 | ## Run Info 49 | - Timestamp: {timestamp} 50 | - Backend: KernelAgent 51 | - Features: Parallel workers, iterative refinement, conversation history 52 | 53 | ## Files 54 | Each `_kernel.py` file contains the complete generated kernel code for that operation. 
55 | KernelAgent session directories contain detailed logs, worker outputs, and generation artifacts. 56 | 57 | ## KernelAgent Features Used 58 | - Parallel workers for increased success rate 59 | - Iterative refinement with multi-turn dialogue 60 | - Comprehensive Triton programming guidelines 61 | - Automatic test generation and validation 62 | - Session logging and artifact preservation 63 | 64 | ## Usage 65 | You can inspect these files to debug kernel generation, analyze the parallel worker outputs, 66 | or understand the sophisticated generation process used by KernelAgent. 67 | """ 68 | ) 69 | 70 | print(f"Saving KernelAgent generated kernels to: {self.kernels_dir}") 71 | 72 | self.kernel_agent = None 73 | self.num_workers = 4 74 | self.max_rounds = 10 75 | 76 | def set_config(self, num_workers: int, max_rounds: int): 77 | """Set configuration for KernelAgent.""" 78 | self.num_workers = num_workers 79 | self.max_rounds = max_rounds 80 | 81 | def _get_kernel_agent(self): 82 | """Lazy initialization of KernelAgent to avoid import issues.""" 83 | if self.kernel_agent is None: 84 | try: 85 | from triton_kernel_agent import TritonKernelAgent 86 | 87 | agent_log_dir = os.path.join(self.kernels_dir, "agent_logs") 88 | os.makedirs(agent_log_dir, exist_ok=True) 89 | 90 | self.kernel_agent = TritonKernelAgent( 91 | log_dir=agent_log_dir, 92 | num_workers=self.num_workers, 93 | max_rounds=self.max_rounds, 94 | ) 95 | 96 | print(f"✓ KernelAgent initialized with log directory: {agent_log_dir}") 97 | 98 | except ImportError: 99 | raise ImportError( 100 | "triton_kernel_agent package not found. Install it to use KernelAgent backend." 101 | ) 102 | 103 | return self.kernel_agent 104 | 105 | def _create_problem_description_from_op(self, op, op_name: str) -> str: 106 | """ 107 | Create a problem description for KernelAgent based on the PyTorch operation. 108 | 109 | Args: 110 | op: PyTorch operation 111 | op_name: Operation name extracted from op 112 | 113 | Returns: 114 | Problem description string for KernelAgent 115 | """ 116 | # Create a comprehensive problem description that KernelAgent can understand 117 | problem_description = f""" 118 | Implement a high-performance Triton kernel for the PyTorch operation: {op_name} 119 | 120 | Operation details: 121 | - PyTorch operation: {op} 122 | - Operation name: {op_name} 123 | - Framework target: OpenAI Triton 124 | 125 | Requirements: 126 | 1. The kernel must be functionally equivalent to the PyTorch operation 127 | 2. Implement using Triton language primitives (tl.load, tl.store, etc.) 128 | 3. Handle all tensor shapes and data types that the original operation supports 129 | 4. Optimize for GPU performance with proper memory coalescing 130 | 5. Include proper boundary condition handling 131 | 6. Follow Triton best practices for kernel design 132 | 133 | The generated kernel should: 134 | - Take the same input arguments as the PyTorch operation 135 | - Return outputs with identical shapes, dtypes, and numerical values 136 | - Be optimized for common tensor shapes and memory layouts 137 | - Handle edge cases gracefully 138 | 139 | Please generate a complete, production-ready Triton kernel implementation. 140 | """ 141 | return problem_description 142 | 143 | def _adapt_kernel_function_name(self, kernel_code: str, op_name: str) -> str: 144 | """ 145 | Adapt KernelAgent's 'kernel_function' to BackendBench's expected naming convention. 146 | 147 | KernelAgent generates kernels with 'kernel_function' as the main entry point. 
148 | BackendBench expects '{op_name}_kernel_impl' as the function name. 149 | 150 | Args: 151 | kernel_code: Original kernel code from KernelAgent 152 | op_name: Operation name for the expected function name 153 | 154 | Returns: 155 | Modified kernel code with correct function name 156 | """ 157 | folder_name = os.path.basename(os.path.dirname(kernel_code)) 158 | expected_name = f"{folder_name}_kernel_impl" 159 | 160 | # Replace 'def kernel_function' with 'def {op_name}_kernel_impl' 161 | if "def kernel_function(" in kernel_code: 162 | adapted_code = kernel_code.replace("def kernel_function(", f"def {expected_name}(") 163 | 164 | # Also replace any docstring references 165 | adapted_code = adapted_code.replace( 166 | '"""Wrapper function that handles kernel launch."""', 167 | f'"""{op_name} kernel implementation using Triton."""', 168 | ) 169 | 170 | return adapted_code 171 | else: 172 | # If kernel_function is not found, add a wrapper that calls the existing function 173 | wrapper_code = f''' 174 | 175 | def {expected_name}(*args, **kwargs): 176 | """{op_name} kernel implementation using Triton - BackendBench adapter.""" 177 | # Call the original kernel_function from KernelAgent 178 | return kernel_function(*args, **kwargs) 179 | ''' 180 | return kernel_code + wrapper_code 181 | 182 | def compile_kernel_from_string( 183 | self, kernel_code: str, op_name: str, attempt: int = 1 184 | ) -> Callable: 185 | """Compile a kernel from string code and return a callable.""" 186 | folder_name = op_name_to_folder_name(op_name) 187 | adapted_code = self._adapt_kernel_function_name(kernel_code, op_name) 188 | kernel_file_path = os.path.join(self.kernels_dir, f"{folder_name}_kernel.py") 189 | expected_fn_name = f"{folder_name}_kernel_impl" 190 | module_name = f"kernel_agent_{folder_name}" 191 | 192 | try: 193 | kernel = compile_kernel_from_string( 194 | kernel_code=adapted_code, 195 | op_name=op_name, 196 | kernel_file_path=kernel_file_path, 197 | expected_fn_name=expected_fn_name, 198 | module_name=module_name, 199 | ) 200 | except Exception as e: 201 | raise e 202 | return kernel 203 | 204 | def add_kernel(self, op, kernel_code: str, op_name: str): 205 | """Add a kernel implementation for a specific operator.""" 206 | compiled_kernel = self.compile_kernel_from_string(kernel_code, op_name, attempt=1) 207 | self.compiled_kernels[op] = compiled_kernel 208 | 209 | # Save the original KernelAgent code as well 210 | folder_name = op_name_to_folder_name(op_name) 211 | original_file = os.path.join(self.kernels_dir, f"{folder_name}_original_kernel_agent.py") 212 | with open(original_file, "w") as f: 213 | f.write(kernel_code) 214 | 215 | def generate_kernel_with_agent(self, op, op_name: str) -> tuple[str, bool]: 216 | """ 217 | Generate a kernel using KernelAgent's sophisticated generation system. 
218 | 219 | Args: 220 | op: PyTorch operation 221 | op_name: Operation name 222 | 223 | Returns: 224 | tuple: (kernel_code, success) 225 | """ 226 | try: 227 | agent = self._get_kernel_agent() 228 | 229 | # Create problem description 230 | problem_description = self._create_problem_description_from_op(op, op_name) 231 | 232 | print( 233 | f"🚀 Generating {op_name} kernel with KernelAgent (parallel workers + refinement)" 234 | ) 235 | 236 | # Generate kernel using KernelAgent 237 | result = agent.generate_kernel( 238 | problem_description=problem_description, 239 | test_code=None, # Let KernelAgent auto-generate the test 240 | ) 241 | 242 | if result["success"]: 243 | print(f"✅ KernelAgent succeeded for {op_name}!") 244 | print( 245 | f" Worker {result['worker_id']} found solution in {result['rounds']} rounds" 246 | ) 247 | print(f" Session: {result['session_dir']}") 248 | 249 | # Copy the session directory to our kernels directory for preservation 250 | import shutil 251 | 252 | session_name = os.path.basename(result["session_dir"]) 253 | folder_name = op_name_to_folder_name(op_name) 254 | preserved_session = os.path.join( 255 | self.kernels_dir, f"{folder_name}_session_{session_name}" 256 | ) 257 | try: 258 | shutil.copytree(result["session_dir"], preserved_session) 259 | print(f" Session preserved: {preserved_session}") 260 | except Exception as e: 261 | print(f" Warning: Could not preserve session: {e}") 262 | 263 | return result["kernel_code"], True 264 | else: 265 | print(f"❌ KernelAgent failed for {op_name}: {result['message']}") 266 | return "", False 267 | 268 | except Exception as e: 269 | print(f"❌ KernelAgent error for {op_name}: {e}") 270 | return "", False 271 | 272 | def __getitem__(self, key): 273 | if key in self.compiled_kernels: 274 | return self.compiled_kernels[key] 275 | raise KeyError(f"No KernelAgent kernel implementation found for {key}") 276 | 277 | def __contains__(self, key): 278 | return key in self.compiled_kernels 279 | -------------------------------------------------------------------------------- /test/test_output.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 
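# Usage sketch (hypothetical values, mirroring the call exercised in
# test_save_results_integration below): BackendBench.output.save_results writes
# full_results.json, operator_summary.csv, failed_tests.json and OVERALL_SUMMARY.md
# under the given output path, e.g.
#
#   save_results(
#       correctness_results=correctness_results,
#       performance_results=performance_results,
#       output_path=Path("backendbench_output"),
#       command="backendbench --suite smoke",
#       mean_correctness=0.75,
#       geomean_perf=1.8,
#       perf_at_p_score=0.6,
#       p=1.2,
#   )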
6 | 7 | import json 8 | import math 9 | import tempfile 10 | from pathlib import Path 11 | 12 | from expecttest import assert_expected_inline 13 | 14 | from BackendBench.eval import CorrectnessTestResult, PerformanceTestResult 15 | from BackendBench.output import ( 16 | _get_summary_op_results, 17 | _prepare_results_data, 18 | save_overall_summary, 19 | save_results, 20 | ) 21 | 22 | 23 | class TestOutputFunctions: 24 | def _create_test_fixtures(self): 25 | """Create test fixtures for correctness and performance results.""" 26 | correctness_results = [ 27 | CorrectnessTestResult( 28 | op_name="torch.ops.aten.add.Tensor", 29 | args="[tensor([1, 2]), tensor([3, 4])]", 30 | is_correct=True, 31 | max_abs_error=0.001, 32 | max_rel_error=0.0001, 33 | ), 34 | CorrectnessTestResult( 35 | op_name="torch.ops.aten.add.Tensor", 36 | args="[tensor([5, 6]), tensor([7, 8])]", 37 | is_correct=True, 38 | max_abs_error=0.002, 39 | max_rel_error=0.0002, 40 | ), 41 | CorrectnessTestResult( 42 | op_name="torch.ops.aten.mul.Tensor", 43 | args="[tensor([1, 2]), tensor([3, 4])]", 44 | is_correct=False, 45 | error_msg="Tensor mismatch", 46 | error_type="AssertionError", 47 | ), 48 | CorrectnessTestResult( 49 | op_name="torch.ops.aten.sin.default", 50 | args="[tensor([0.5])]", 51 | is_correct=True, 52 | max_abs_error=0.0, 53 | max_rel_error=0.0, 54 | ), 55 | ] 56 | 57 | performance_results = [ 58 | PerformanceTestResult( 59 | op_name="torch.ops.aten.add.Tensor", 60 | args="[tensor([1, 2]), tensor([3, 4])]", 61 | speedup=1.5, 62 | benchmark_time_ms=10.0, 63 | reference_time_ms=15.0, 64 | successfully_ran=True, 65 | ), 66 | PerformanceTestResult( 67 | op_name="torch.ops.aten.add.Tensor", 68 | args="[tensor([5, 6]), tensor([7, 8])]", 69 | speedup=2.0, 70 | benchmark_time_ms=8.0, 71 | reference_time_ms=16.0, 72 | successfully_ran=True, 73 | ), 74 | PerformanceTestResult( 75 | op_name="torch.ops.aten.mul.Tensor", 76 | args="[tensor([1, 2]), tensor([3, 4])]", 77 | speedup=1.0, 78 | benchmark_time_ms=20.0, 79 | reference_time_ms=20.0, 80 | successfully_ran=True, 81 | ), 82 | PerformanceTestResult( 83 | op_name="torch.ops.aten.sin.default", 84 | args="[tensor([0.5])]", 85 | speedup=None, 86 | benchmark_time_ms=None, 87 | reference_time_ms=20.0, 88 | successfully_ran=False, 89 | error_msg="Compilation failed", 90 | ), 91 | ] 92 | 93 | return correctness_results, performance_results 94 | 95 | def test_prepare_results_data(self): 96 | """Test the _prepare_results_data function.""" 97 | correctness_results, performance_results = self._create_test_fixtures() 98 | 99 | all_results, failed_tests, op_summaries = _prepare_results_data( 100 | correctness_results, performance_results 101 | ) 102 | 103 | # Check that all results are properly converted to dicts and sorted 104 | assert len(all_results) == 8 # 4 correctness + 4 performance 105 | 106 | # Check failed tests 107 | assert len(failed_tests) == 2 # 1 correctness + 1 performance failure 108 | failed_tests = [test["op_name"] for test in failed_tests] 109 | assert "torch.ops.aten.mul.Tensor" in failed_tests 110 | assert "torch.ops.aten.sin.default" in failed_tests 111 | 112 | # Check operator summaries 113 | assert len(op_summaries) == 3 # add, mul, sin 114 | 115 | # Test add operator summary 116 | add_summary = op_summaries["torch.ops.aten.add.Tensor"] 117 | assert_expected_inline( 118 | str(add_summary), 119 | """{'operator': 'torch.ops.aten.add.Tensor', 'total_tests': 3, 'correctness_tests': 2, 'performance_tests': 2, 'passed_correctness_tests': 2, 'passed_performance_tests': 
2, 'failed_correctness_tests': 0, 'failed_performance_tests': 0, 'correctness_rate': 1.0, 'avg_speedup': 1.75, 'geomean_speedup': 1.7320507764816284, 'max_absolute_error': 0.002, 'max_relative_error': 0.0002}""", 120 | ) 121 | 122 | # Test mul operator summary (should have failed correctness and performance) 123 | mul_summary = op_summaries["torch.ops.aten.mul.Tensor"] 124 | assert_expected_inline( 125 | str(mul_summary), 126 | """{'operator': 'torch.ops.aten.mul.Tensor', 'total_tests': 3, 'correctness_tests': 1, 'performance_tests': 1, 'passed_correctness_tests': 0, 'passed_performance_tests': 1, 'failed_correctness_tests': 1, 'failed_performance_tests': 0, 'correctness_rate': 0.0, 'avg_speedup': 1.0, 'geomean_speedup': 1.0, 'max_absolute_error': -inf, 'max_relative_error': -inf}""", 127 | ) 128 | 129 | sin_summary = op_summaries["torch.ops.aten.sin.default"] 130 | assert_expected_inline( 131 | str(sin_summary), 132 | """{'operator': 'torch.ops.aten.sin.default', 'total_tests': 3, 'correctness_tests': 1, 'performance_tests': 1, 'passed_correctness_tests': 1, 'passed_performance_tests': 0, 'failed_correctness_tests': 0, 'failed_performance_tests': 1, 'correctness_rate': 1.0, 'avg_speedup': 0.0, 'geomean_speedup': 0.0, 'max_absolute_error': 0.0, 'max_relative_error': 0.0}""", 133 | ) 134 | 135 | def test_get_summary_op_results(self): 136 | """Test the _get_summary_op_results function.""" 137 | correctness_results, performance_results = self._create_test_fixtures() 138 | 139 | op_results = _get_summary_op_results(performance_results, correctness_results) 140 | 141 | # Should return list of tuples (op_name, correctness_str, speedup_str) 142 | assert len(op_results) == 3 143 | 144 | # Check that results are sorted properly (by speedup descending, then correctness) 145 | assert_expected_inline( 146 | str(op_results), 147 | """[('torch.ops.aten.add.Tensor', '100.0000%', '1.7321x'), ('torch.ops.aten.sin.default', '100.0000%', '1.0000x'), ('torch.ops.aten.mul.Tensor', '0.0000%', '1.0000x')]""", 148 | ) 149 | 150 | def test_save_results_integration(self): 151 | """Test the full save_results function with file I/O.""" 152 | correctness_results, performance_results = self._create_test_fixtures() 153 | 154 | with tempfile.TemporaryDirectory() as tmpdir: 155 | output_path = Path(tmpdir) / "test_output" 156 | 157 | save_results( 158 | correctness_results=correctness_results, 159 | performance_results=performance_results, 160 | output_path=output_path, 161 | command="backendbench --suite test_suite", 162 | mean_correctness=0.75, 163 | geomean_perf=1.8, 164 | perf_at_p_score=0.6, 165 | p=1.2, 166 | ) 167 | 168 | # Check that all expected files were created 169 | assert (output_path / "full_results.json").exists() 170 | assert (output_path / "operator_summary.csv").exists() 171 | assert (output_path / "failed_tests.json").exists() 172 | assert (output_path / "OVERALL_SUMMARY.md").exists() 173 | 174 | # Check full_results.json content 175 | with open(output_path / "full_results.json") as f: 176 | full_results = json.load(f) 177 | assert len(full_results) == 8 178 | 179 | # Check failed_tests.json content 180 | with open(output_path / "failed_tests.json") as f: 181 | failed_tests = json.load(f) 182 | assert len(failed_tests) == 2 183 | 184 | # Check that CSV has correct number of rows (header + 3 operators) 185 | with open(output_path / "operator_summary.csv") as f: 186 | csv_content = f.read() 187 | # Should have header + 3 data rows 188 | assert len(csv_content.strip().split("\n")) == 4 189 | 190 | # Check 
overall summary exists and has expected content 191 | with open(output_path / "OVERALL_SUMMARY.md") as f: 192 | summary_content = f.read() 193 | assert "# BackendBench Run Summary" in summary_content 194 | assert "backendbench --suite test_suite" in summary_content 195 | assert "0.75" in summary_content # mean_correctness 196 | assert "1.80" in summary_content # geomean_perf 197 | 198 | def test_save_overall_summary_standalone(self): 199 | """Test the save_overall_summary function independently.""" 200 | correctness_results, performance_results = self._create_test_fixtures() 201 | 202 | with tempfile.TemporaryDirectory() as tmpdir: 203 | output_path = Path(tmpdir) / "test_summary" 204 | 205 | save_overall_summary( 206 | output_path=output_path, 207 | command="backendbench --ops add,mul", 208 | mean_correctness=0.8, 209 | geomean_perf=2.1, 210 | perf_at_p_score=0.7, 211 | p=1.5, 212 | performance_results=performance_results, 213 | correctness_results=correctness_results, 214 | ) 215 | 216 | # Check that the summary file was created 217 | summary_path = output_path / "OVERALL_SUMMARY.md" 218 | assert summary_path.exists() 219 | 220 | # Check content 221 | with open(summary_path) as f: 222 | content = f.read() 223 | 224 | assert "backendbench --ops add,mul" in content 225 | assert "| Correctness Score | 0.80 |" in content 226 | assert "| Performance Score (geomean speedup) | 2.10 |" in content 227 | assert "| Perf@1.5 Score | 0.70 |" in content 228 | 229 | def test_empty_results(self): 230 | """Test functions with empty input data.""" 231 | empty_correctness = [] 232 | empty_performance = [] 233 | 234 | all_results, failed_tests, op_summaries = _prepare_results_data( 235 | empty_correctness, empty_performance 236 | ) 237 | 238 | assert len(all_results) == 0 239 | assert len(failed_tests) == 0 240 | assert len(op_summaries) == 0 241 | 242 | # Test with empty results in summary function 243 | op_results = _get_summary_op_results(empty_performance, empty_correctness) 244 | assert len(op_results) == 0 245 | 246 | def test_edge_cases(self): 247 | """Test edge cases and error conditions.""" 248 | # Test with NaN and infinite values 249 | edge_case_results = [ 250 | CorrectnessTestResult( 251 | op_name="edge_case_op", 252 | args="[tensor([nan])]", 253 | is_correct=True, 254 | max_abs_error=float("inf"), 255 | max_rel_error=-math.inf, 256 | ), 257 | PerformanceTestResult( 258 | op_name="edge_case_op", 259 | args="[tensor([nan])]", 260 | speedup=float("inf"), 261 | benchmark_time_ms=0.0, 262 | reference_time_ms=1.0, 263 | successfully_ran=True, 264 | ), 265 | ] 266 | 267 | all_results, failed_tests, op_summaries = _prepare_results_data( 268 | [edge_case_results[0]], [edge_case_results[1]] 269 | ) 270 | 271 | # Should handle infinite values gracefully 272 | assert len(all_results) == 2 273 | assert len(op_summaries) == 1 274 | 275 | # Check that infinite speedup is handled 276 | edge_summary = op_summaries["edge_case_op"] 277 | assert math.isinf(edge_summary["avg_speedup"]) 278 | --------------------------------------------------------------------------------