├── .dockerignore ├── data ├── releases │ ├── 2025-1-problems.tar.gz │ ├── 2025-2-problems.tar.gz │ └── 2025-3-problems.tar.gz └── LICENSE ├── compute_eval ├── utils │ ├── language_data │ │ ├── languages.yml.gz │ │ └── update-linguist-languages.py │ ├── eval_utils.py │ └── parsing.py ├── __init__.py ├── token_provider.py ├── models │ ├── openAI_model.py │ └── model_interface.py ├── data │ ├── utils.py │ ├── data_pack.py │ └── data_model.py ├── main.py ├── prompts.py ├── execution.py ├── evaluation.py └── generate_completions.py ├── .vscode └── settings.json ├── .pre-commit-config.yaml ├── LICENSE ├── .gitignore ├── Dockerfile ├── pyproject.toml ├── CONTRIBUTING.md ├── DATASET_CARD.md └── README.md /.dockerignore: -------------------------------------------------------------------------------- 1 | .venv 2 | .git 3 | __pycache__ 4 | *.pyc 5 | *.pyo 6 | *.pyd 7 | .pytest_cache 8 | .ruff_cache 9 | -------------------------------------------------------------------------------- /data/releases/2025-1-problems.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/compute-eval/HEAD/data/releases/2025-1-problems.tar.gz -------------------------------------------------------------------------------- /data/releases/2025-2-problems.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/compute-eval/HEAD/data/releases/2025-2-problems.tar.gz -------------------------------------------------------------------------------- /data/releases/2025-3-problems.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/compute-eval/HEAD/data/releases/2025-3-problems.tar.gz -------------------------------------------------------------------------------- /compute_eval/utils/language_data/languages.yml.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/compute-eval/HEAD/compute_eval/utils/language_data/languages.yml.gz -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.associations": { 3 | "iostream": "cpp", 4 | "cstdlib": "cpp" 5 | }, 6 | "[python]": { 7 | "editor.defaultFormatter": "charliermarsh.ruff", 8 | "editor.formatOnSave": true, 9 | "editor.codeActionsOnSave": { 10 | "source.fixAll": "explicit", 11 | "source.organizeImports": "explicit" 12 | } 13 | }, 14 | "ruff.organizeImports": true 15 | } -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # .pre-commit-config.yaml 2 | exclude: > 3 | (?x)^( 4 | \.vscode/.*| 5 | compute_eval/temp/.*| 6 | data/.* 7 | )$ 8 | 9 | repos: 10 | - repo: https://github.com/pre-commit/pre-commit-hooks 11 | rev: v4.0.1 # Use the latest version 12 | hooks: 13 | - id: trailing-whitespace 14 | - id: end-of-file-fixer 15 | - id: check-yaml 16 | args: [--unsafe] 17 | - id: check-added-large-files 18 | - id: check-json 19 | 20 | - repo: https://github.com/astral-sh/ruff-pre-commit 21 | rev: v0.12.2 22 | hooks: 23 | - id: ruff-check 24 | args: [--fix] 25 | - id: ruff-format 26 | -------------------------------------------------------------------------------- 
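Note on the `data/releases/*-problems.tar.gz` entries above: each is a versioned problem datapack. Below is a minimal sketch for inspecting one outside the library, assuming only the layout documented in DATASET_CARD.md (a gzipped tarball holding `metadata.json` plus `problems.jsonl`); `inspect_datapack` is a hypothetical helper, not part of the package:

```python
import json
import tarfile

def inspect_datapack(path: str) -> None:
    # The datapack is a .tar.gz with two members: metadata.json and problems.jsonl.
    with tarfile.open(path, "r:gz") as tar:
        meta = tar.extractfile("metadata.json")
        if meta is None:
            raise ValueError("metadata.json not found in datapack")
        metadata = json.loads(meta.read().decode("utf-8"))
        print(f"release={metadata['release']} problems={metadata['total_count']}")

        problems = tar.extractfile("problems.jsonl")
        if problems is None:
            raise ValueError("problems.jsonl not found in datapack")
        for line in problems:  # one JSON object per line
            problem = json.loads(line)
            print(problem["task_id"], problem["type"])

# Usage (against a downloaded release):
# inspect_datapack("data/releases/2025-3-problems.tar.gz")
```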
/compute_eval/utils/language_data/update-linguist-languages.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | from pathlib import Path 3 | 4 | import requests 5 | 6 | LINGUIST_URL = "https://raw.githubusercontent.com/github-linguist/linguist/HEAD/lib/linguist/languages.yml" 7 | 8 | response = requests.get(LINGUIST_URL) 9 | 10 | # If the file exists (status code is 200), write the content to a new gzipped file 11 | if response.status_code == 200 and response.text: 12 | print("Writing languages.yml.gz file") 13 | script_dir = Path(__file__).parent 14 | output_path = script_dir / "languages.yml.gz" 15 | with gzip.open(output_path, "wt") as f: 16 | f.write(response.text) 17 | else: 18 | print("Failed to fetch languages.yml file") 19 | -------------------------------------------------------------------------------- /compute_eval/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | # Load Nvidia internal extensions if available 4 | try: 5 | from compute_eval.internal import MODEL_CLASSES 6 | 7 | def _lazy_load_class(class_path: str): 8 | module_path, class_name = class_path.rsplit(".", 1) 9 | module = importlib.import_module(module_path) 10 | return getattr(module, class_name) 11 | 12 | def get_model_class(model: str): 13 | class_path = MODEL_CLASSES.get(model, "compute_eval.models.openAI_model.OpenAIModel") 14 | return _lazy_load_class(class_path) 15 | 16 | except ImportError: 17 | from compute_eval.models.openAI_model import OpenAIModel 18 | 19 | def get_model_class(model: str): 20 | return OpenAIModel 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | SPDX-License-Identifier: Apache-2.0 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | 16 | NOTE: 17 | The data in the data/ directory is licensed separately under the NVIDIA 18 | Evaluation Dataset License Agreement. See data/LICENSE for details. 19 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | share/python-wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | 26 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 27 | __pypackages__/ 28 | 29 | # Celery stuff 30 | celerybeat-schedule 31 | celerybeat.pid 32 | 33 | # Environments 34 | .env 35 | .venv 36 | env/ 37 | 38 | 39 | # Ignore any generated sample files 40 | **/*samples*.jsonl 41 | **/*sample*.jsonl 42 | 43 | # Ignore the generated results file 44 | **/*results.jsonl 45 | 46 | # Ignore pycharm configurations 47 | .idea/ 48 | 49 | # Ignore Claude files 50 | .claude/ 51 | CLAUDE.md 52 | -------------------------------------------------------------------------------- /compute_eval/utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import shutil 3 | import subprocess 4 | 5 | 6 | # noinspection PyBroadException 7 | def _run(cmd) -> str | None: 8 | try: 9 | p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, check=True) 10 | return p.stdout.strip() 11 | except Exception: 12 | return None 13 | 14 | 15 | def _parse_nvcc_version(text): 16 | m = re.search(r"(?i)\bV(\d+\.\d+\.\d+)\b", text) 17 | return m.group(1) if m else None 18 | 19 | 20 | def get_nvcc_version() -> str | None: 21 | nvcc = shutil.which("nvcc") 22 | if not nvcc: 23 | return None 24 | out = _run([nvcc, "--version"]) 25 | return _parse_nvcc_version(out) if out else None 26 | 27 | 28 | def parse_semver(version: str | None) -> tuple[int, int, int] | None: 29 | if version is None: 30 | return None 31 | 32 | m = re.match(r"^(\d+)(?:\.(\d+))?(?:\.(\d+))?", version) 33 | if not m: 34 | return None 35 | 36 | major, minor, patch = m.groups() 37 | return int(major), int(minor or 0), int(patch or 0) 38 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.3.2-cudnn9-devel-ubi8 2 | 3 | # Install Python 3.11 as the default python3.9 is not compatible with compute-eval 4 | RUN yum install -y python3.11 && \ 5 | rm -f /usr/bin/python3 && \ 6 | ln -s /usr/bin/python3.11 /usr/bin/python3 && \ 7 | curl -sS https://bootstrap.pypa.io/get-pip.py | python3 && \ 8 | rm -rf /var/lib/apt/lists/* && \ 9 | yum clean all && \ 10 | rm -rf /var/cache/yum 11 | 12 | # Install uv 13 | COPY --from=ghcr.io/astral-sh/uv:0.9.5 /uv /usr/local/bin/uv 14 | 15 | WORKDIR /compute-eval 16 | 17 | # Copy dependency files and install dependencies only (no project) 18 | COPY pyproject.toml uv.lock ./ 19 | # Note: To enable Python CUDA support, add --extra python-cuda to both RUN commands 20 | RUN uv sync --frozen --no-dev --no-install-project 21 | 22 | # Copy source code and install project 23 | COPY . 
./ 24 | ADD ./data /compute-eval-data 25 | RUN uv sync --frozen --no-dev 26 | 27 | ENV PATH="/compute-eval/.venv/bin:$PATH:/compute-eval" 28 | 29 | #set entry point 30 | ENTRYPOINT ["compute_eval"] 31 | 32 | # To generate samples do this 33 | # docker run -it --runtime nvidia -v /home/ubuntu/compute-eval/data:/data -e NEMO_API_KEY=$APIKEY compute-eval generate_samples 34 | 35 | 36 | # To verify correctness do this 37 | # docker run -it --runtime nvidia -v /home/ubuntu/compute-eval/data:/data compute-eval evaluate_functional_correctness 38 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "compute-eval" 3 | version = "0.1.0" 4 | description = "Library for evaluating Large Language Models on CUDA code" 5 | readme = "README.md" 6 | authors = [ 7 | { name = "NVIDIA" } 8 | ] 9 | requires-python = ">=3.11" 10 | dependencies = [ 11 | "anthropic>=0.72.0", 12 | "fire>=0.7.1", 13 | "h11>=0.16.0", 14 | "numpy>=2.3.4", 15 | "openai>=2.6.1", 16 | "psutil>=7.1.2", 17 | "pydantic>=2.12.3", 18 | "python-dotenv>=1.2.1", 19 | "pyyaml>=6.0", 20 | "requests>=2.32.5", 21 | "tabulate>=0.9.0", 22 | "tqdm>=4.67.1", 23 | "tree-sitter>=0.25.2", 24 | "tree-sitter-language-pack>=0.10.0", 25 | "urllib3>=2.5.0", 26 | ] 27 | 28 | [project.scripts] 29 | compute_eval = "compute_eval.main:main" 30 | 31 | [build-system] 32 | requires = ["hatchling"] 33 | build-backend = "hatchling.build" 34 | 35 | [dependency-groups] 36 | dev = [ 37 | "pre-commit>=4.3.0", 38 | "pytest>=8.4.2", 39 | "ruff>=0.8.4", 40 | ] 41 | 42 | [tool.pytest.ini_options] 43 | testpaths = ["tests"] 44 | 45 | [tool.ruff] 46 | line-length = 120 47 | target-version = "py311" 48 | exclude = ["data/"] 49 | 50 | [tool.ruff.lint] 51 | # --select: 52 | # E - pycodestyle ; F - Pyflakes ; UP - pyupgrade ; B - flake8-bugbear ; SIM - flake8-simplify ; I - isort 53 | select = ["E", "F", "UP", "B", "SIM", "I"] 54 | # --ignore: 55 | # E501 - line too long 56 | # UP015 - redundant open mode 57 | # SIM118 - checks for key-existence checks against dict.keys() calls 58 | # SIM300 - constant on the left-hand side of the comparison operator 59 | # SIM910 - checks for dict.get() calls that pass None as the default value 60 | ignore = ["E501","UP015","SIM118","SIM300","SIM910"] 61 | -------------------------------------------------------------------------------- /compute_eval/token_provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import ABC, abstractmethod 3 | 4 | 5 | class TokenProvider(ABC): 6 | """Abstract base class for token providers.""" 7 | 8 | @abstractmethod 9 | def get_token(self, base_url: str) -> str | None: 10 | """Get authentication token for the given base URL.""" 11 | pass 12 | 13 | @abstractmethod 14 | def handles_url(self, base_url: str) -> bool: 15 | """Check if this provider handles the given URL.""" 16 | pass 17 | 18 | 19 | class EnvTokenProvider(TokenProvider): 20 | """Default provider that reads from environment variables.""" 21 | 22 | def __init__(self, env_var: str): 23 | self.env_var = env_var 24 | 25 | def get_token(self, base_url: str) -> str | None: 26 | return os.getenv(self.env_var) 27 | 28 | def handles_url(self, base_url: str) -> bool: 29 | # Handles all URLs by default 30 | return True 31 | 32 | 33 | # Global registry for token providers 34 | _token_providers: list[TokenProvider] = [] 35 | 36 | 37 | def 
register_token_provider(provider: TokenProvider): 38 | """Register a custom token provider (most recently registered providers are checked first).""" 39 | _token_providers.insert(0, provider) # Prepend so custom providers take precedence 40 | 41 | 42 | def get_token_for_url(base_url: str, default_env_var: str) -> str | None: 43 | """Get token from the first provider that handles this URL.""" 44 | for provider in _token_providers: 45 | if provider.handles_url(base_url): 46 | token = provider.get_token(base_url) 47 | if token: 48 | return token 49 | 50 | # Fallback to environment variable 51 | return os.getenv(default_env_var) 52 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | ## Development 4 | 5 | ### Prerequisites 6 | 7 | - Python 3.11+ 8 | - [uv](https://docs.astral.sh/uv/) for dependency management 9 | 10 | ### Installation 11 | 12 | Install uv if you haven't already: 13 | 14 | ```bash 15 | curl -LsSf https://astral.sh/uv/install.sh | sh 16 | ``` 17 | 18 | Clone the repository and install dependencies: 19 | 20 | ```bash 21 | git clone https://github.com/NVIDIA/compute-eval.git 22 | cd compute-eval 23 | uv sync 24 | ``` 25 | 26 | This will create a virtual environment and install all dependencies. To also install development dependencies: 27 | 28 | ```bash 29 | uv sync --group dev 30 | ``` 31 | 32 | ### Environment Setup 33 | 34 | Create a `.env` file in the `compute-eval` directory: 35 | 36 | ```env 37 | NEMO_API_KEY="" 38 | ``` 39 | 40 | or 41 | 42 | ```env 43 | OPENAI_API_KEY="" 44 | ``` 45 | 46 | if using a custom model with OpenAI API compatibility. 47 | 48 | ### Linting 49 | 50 | You will need to install the Ruff Python formatter and linter. To do this in VSCode is simple: install [Ruff](https://marketplace.visualstudio.com/items?itemName=charliermarsh.ruff) from the Marketplace 51 | and then add these lines to either your workspace settings.json or your global settings.json: 52 | 53 | ```json 54 | "[python]": { 55 | "editor.defaultFormatter": "charliermarsh.ruff", 56 | "editor.formatOnSave": true, 57 | "editor.codeActionsOnSave": { 58 | "source.fixAll": "explicit", 59 | "source.organizeImports": "explicit" 60 | }, 61 | }, 62 | "ruff.organizeImports": true, 63 | ``` 64 | 65 | Every time you save a file, the linter will automatically lint it for you. Depending on your workflow, you might prefer to have it check and report, and then ask for permission to format the files. 66 | 67 | ## Sharing your contributions 68 | 69 | For any additional contributions that are made, please include a DCO in your commit message: https://wiki.linuxfoundation.org/dco 70 | -------------------------------------------------------------------------------- /compute_eval/models/openAI_model.py: -------------------------------------------------------------------------------- 1 | import dotenv 2 | 3 | from compute_eval.models.model_interface import ModelInterface 4 | from compute_eval.token_provider import get_token_for_url 5 | 6 | # Check API keys in order of preference 7 | _api_key_names = ["OPENAI_API_KEY", "ANTHROPIC_API_KEY", "NEMO_API_KEY"] 8 | 9 | 10 | class OpenAIModel(ModelInterface): 11 | """ 12 | Generate code completions using OpenAI models. 13 | 14 | Args: 15 | model_name (str): Name of the model to use for generating completions. 16 | base_url (str): Base URL for the OpenAI-compatible API endpoint.
17 | """ 18 | 19 | _api_key_printed = False 20 | 21 | def __init__( 22 | self, 23 | model_name: str, 24 | base_url: str | None, 25 | reasoning: str | None = None, 26 | ): 27 | dotenv.load_dotenv() 28 | 29 | self._model_name = model_name 30 | self._base_url = base_url or "https://api.openai.com/v1" 31 | self.reasoning = reasoning 32 | 33 | self._api_key_name = None 34 | 35 | for key_name in _api_key_names: 36 | if get_token_for_url(self._base_url, key_name) is not None: 37 | self._api_key_name = key_name 38 | break 39 | 40 | if self._api_key_name is None: 41 | raise Exception( 42 | f"Could not find any of: {', '.join(_api_key_names)}. Please set one of these environment variables." 43 | ) 44 | 45 | if not OpenAIModel._api_key_printed: 46 | print(f"Using {self._api_key_name} for authentication") 47 | OpenAIModel._api_key_printed = True 48 | 49 | @property 50 | def api_key(self) -> str: 51 | url = get_token_for_url(self.base_url, self._api_key_name) 52 | if url is None: 53 | raise Exception(f"Could not get {self._api_key_name}.") 54 | return url 55 | 56 | @property 57 | def base_url(self) -> str: 58 | return self._base_url 59 | 60 | @property 61 | def model_name(self) -> str: 62 | return self._model_name 63 | -------------------------------------------------------------------------------- /compute_eval/data/utils.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import json 3 | import os 4 | from collections.abc import Generator, Iterable 5 | from pathlib import Path 6 | from typing import Annotated 7 | 8 | from pydantic import BaseModel, Field, TypeAdapter 9 | 10 | from compute_eval.data.data_model import ( 11 | CudaCppProblem, 12 | CudaPythonProblem, 13 | FileSolution, 14 | GradedSolution, 15 | PatchSolution, 16 | Problem, 17 | Solution, 18 | ) 19 | 20 | 21 | def _open_file(file_path: str | Path, mode: str): 22 | file_path = os.path.expanduser(file_path) 23 | if file_path.endswith(".gz"): 24 | return gzip.open(file_path, mode) 25 | else: 26 | return open(file_path, mode) 27 | 28 | 29 | def stream_jsonl(path: str | Path) -> Iterable[dict]: 30 | """ 31 | Yield each parsed JSON object from JSONL file(s). 
32 | 33 | Args: 34 | path: Path to a single .jsonl/.jsonl.gz file or directory containing such files 35 | 36 | Yields: 37 | Parsed JSON objects from all matching files (in sorted filename order for directories) 38 | """ 39 | 40 | def _is_jsonl_file(p: Path) -> bool: 41 | """Check if file has .jsonl or .jsonl.gz extension.""" 42 | return p.name.endswith(".jsonl") or p.name.endswith(".jsonl.gz") 43 | 44 | def _stream_file(p: str | Path) -> Iterable[dict]: 45 | """Yield each parsed JSON object from a JSONL file (gzip supported).""" 46 | with _open_file(p, "rt") as fp: 47 | for line in fp: 48 | if any(not ch.isspace() for ch in line): 49 | yield json.loads(line, strict=False) 50 | 51 | path = Path(os.path.expanduser(path)) 52 | 53 | # Determine files to process 54 | if path.is_file(): 55 | if not _is_jsonl_file(path): 56 | raise ValueError(f"File must end with .jsonl or .jsonl.gz, got: {path}") 57 | 58 | yield from _stream_file(path) 59 | elif path.is_dir(): 60 | # Find all .jsonl and .jsonl.gz files in directory (sorted) 61 | files = sorted(f for f in path.iterdir() if f.is_file() and _is_jsonl_file(f)) 62 | if not files: 63 | raise ValueError(f"No .jsonl or .jsonl.gz files found in directory: {path}") 64 | 65 | for file in files: 66 | yield from _stream_file(str(file)) 67 | else: 68 | raise ValueError(f"Path does not exist: {path}") 69 | 70 | 71 | def write_jsonl(file_path: str, data: list[dict | BaseModel], append: bool = False): 72 | """Write iterable of dicts or Pydantic model instances to a JSONL file.""" 73 | mode = "at" if append else "wt" 74 | with _open_file(file_path, mode) as fp: 75 | for item in data: 76 | if isinstance(item, BaseModel): 77 | fp.write(item.model_dump_json(serialize_as_any=True) + "\n") 78 | elif isinstance(item, dict): 79 | fp.write(json.dumps(item) + "\n") 80 | else: 81 | raise ValueError(f"Cannot write object of type {type(item)}") 82 | 83 | 84 | def read_problems(file_path: str) -> Generator[Problem, None, None]: 85 | adapter = TypeAdapter(Annotated[CudaCppProblem | CudaPythonProblem, Field(discriminator="type")]) 86 | yield from (adapter.validate_python(data) for data in stream_jsonl(file_path)) 87 | 88 | 89 | def write_problems(file_path: str, problems: list[Problem], append: bool = False): 90 | write_jsonl(file_path, problems, append=append) 91 | 92 | 93 | def read_solutions(file_path: str) -> Generator[Solution, None, None]: 94 | adapter = TypeAdapter(Annotated[FileSolution | PatchSolution, Field(discriminator="type")]) 95 | yield from (adapter.validate_python(data) for data in stream_jsonl(file_path)) 96 | 97 | 98 | def write_solutions(file_path: str, solutions: list[Solution], append: bool = False): 99 | write_jsonl(file_path, solutions, append=append) 100 | 101 | 102 | def read_graded_solutions(file_path: str) -> Generator[GradedSolution, None, None]: 103 | yield from (GradedSolution.model_validate(data) for data in stream_jsonl(file_path)) 104 | 105 | 106 | def write_graded_solutions(file_path: str, graded_solutions: list[GradedSolution], append: bool = False): 107 | write_jsonl(file_path, graded_solutions, append=append) 108 | -------------------------------------------------------------------------------- /compute_eval/main.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import yaml 3 | from pydantic import BaseModel, ConfigDict, Field 4 | 5 | from compute_eval.data.data_model import ReleaseVersion 6 | from compute_eval.evaluation import evaluate_functional_correctness 7 | from 
compute_eval.generate_completions import generate_samples 8 | from compute_eval.prompts import SYSTEM_PROMPT 9 | 10 | 11 | class EvaluateConfig(BaseModel): 12 | """Configuration for functional correctness evaluation.""" 13 | 14 | model_config = ConfigDict(extra="forbid") 15 | 16 | solutions_datapack: str = Field( 17 | default=..., 18 | description="Path to the solutions datapack", 19 | ) 20 | problems_datapack_dir: str = Field( 21 | default="data/releases/", 22 | description="Directory where released problem datapacks are stored", 23 | ) 24 | allow_execution: bool = Field( 25 | default=False, 26 | description="Whether to allow execution of untrusted code. This must be set to True.", 27 | ) 28 | k: int | tuple[int, ...] = Field( 29 | default=1, 30 | description="K value(s) for pass@k evaluation", 31 | ) 32 | n_workers: int = Field( 33 | default=4, 34 | description="Number of worker threads", 35 | ) 36 | results_file: str | None = Field( 37 | default=None, 38 | description="Path to output results file (defaults to {solutions_datapack's release}-graded-solutions.jsonl if not provided)", 39 | ) 40 | 41 | 42 | class GenerateConfig(BaseModel): 43 | """Configuration for solution generation.""" 44 | 45 | model_config = ConfigDict(extra="forbid") 46 | 47 | release: ReleaseVersion = Field( 48 | default=ReleaseVersion.V2025_2, 49 | description="Release version to generate solutions for", 50 | ) 51 | problems_datapack_dir: str = Field( 52 | default="data/releases/", 53 | description="Directory where released problem datapacks are stored", 54 | ) 55 | solutions_per_problem: int = Field( 56 | default=1, 57 | description="Number of solutions to generate per problem", 58 | ) 59 | n_workers: int = Field( 60 | default=10, 61 | description="Number of worker threads", 62 | ) 63 | system_prompt: str = Field( 64 | default=SYSTEM_PROMPT, 65 | description="System prompt for the model", 66 | ) 67 | model: str = Field( 68 | default=..., 69 | description="Model to use for generation", 70 | ) 71 | base_url: str | None = Field( 72 | default=None, 73 | description="Base URL for the model API", 74 | ) 75 | reasoning: str | None = Field( 76 | default=None, 77 | description="Reasoning mode for the model (e.g., 'low', 'medium', 'high' for GPT models, or any value for Claude models to enable extended thinking)", 78 | ) 79 | temperature: float | None = Field( 80 | default=None, 81 | description="Temperature for generation", 82 | ) 83 | top_p: float | None = Field( 84 | default=None, 85 | description="Top-p for generation", 86 | ) 87 | max_tokens: int | None = Field( 88 | default=None, 89 | description="Maximum tokens for generation", 90 | ) 91 | temp_dir: str | None = Field( 92 | default=None, 93 | description="Temporary directory for intermediate results", 94 | ) 95 | debug: bool = Field( 96 | default=False, 97 | description="Include system prompt, prompt, and completion in the output solutions file for debugging", 98 | ) 99 | 100 | 101 | def _build_config( 102 | config_file: str | None, 103 | model_class: type[BaseModel], 104 | cli_kwargs: dict, 105 | ) -> BaseModel: 106 | """Merge config file and CLI args, with CLI taking precedence.""" 107 | if config_file: 108 | with open(config_file, "r") as file: 109 | config_data = yaml.safe_load(file) or {} 110 | else: 111 | config_data = {} 112 | 113 | config_data.update({k: v for k, v in cli_kwargs.items() if v is not None}) 114 | return model_class(**config_data) 115 | 116 | 117 | def generate_samples_with_config(config_file: str | None = None, **cli_kwargs): 118 | config = 
_build_config(config_file, GenerateConfig, cli_kwargs) 119 | generate_samples(**config.model_dump()) 120 | 121 | 122 | def evaluate_functional_correctness_with_config(config_file: str | None = None, **cli_kwargs): 123 | config = _build_config(config_file, EvaluateConfig, cli_kwargs) 124 | evaluate_functional_correctness(**config.model_dump()) 125 | 126 | 127 | def main(): 128 | fire.Fire( 129 | { 130 | "evaluate_functional_correctness": evaluate_functional_correctness_with_config, 131 | "generate_samples": generate_samples_with_config, 132 | } 133 | ) 134 | 135 | 136 | if __name__ == "__main__": 137 | main() 138 | -------------------------------------------------------------------------------- /compute_eval/prompts.py: -------------------------------------------------------------------------------- 1 | from compute_eval.data.data_model import Problem, SourceFile 2 | 3 | SYSTEM_PROMPT = """ 4 | You are a senior CUDA/C/C++ engineer. Produce complete, compilable solutions from a structured problem specification. Follow these rules: 5 | 6 | General 7 | - You will be given: a problem description, context files (editable), and build environment details (e.g., build command). 8 | - Hidden tests exist but are not shown. Do not mention tests, do not write test code, and do not add I/O used only for testing. 9 | - Use only the APIs and contracts specified in the problem and context files. Preserve all provided function signatures exactly. 10 | - Prefer using only headers already present in the provided codebase. Avoid adding new headers unless strictly necessary and supported by the build command. Do not introduce third-party dependencies. 11 | 12 | Context files policy 13 | - You may modify provided context files when necessary. If you include any file in your solution output (new or modified), emit its full and final contents; your output will overwrite the provided version. 14 | - Only emit files you add or modify. Do not output files that are unchanged, and do not include placeholder blocks saying "no changes" or similar. 15 | 16 | Build command 17 | - You should pay careful attention to the build command or any context files about the build process. 18 | - The build command and/or context build files may include important hints about required files or expected project structure. This likely includes the name of the expected solution file, important macros, standards, or linked libraries. 19 | - Pay special attention to -I or -isystem flags -- they indicate important include paths. Remember, if a -I or -isystem flag is present you do not need to include the relative path in your #include statements. 20 | 21 | 22 | Output format 23 | - Output only source files needed for the solution. No explanations or commentary. 24 | - Each file must be in its own fenced code block, with the first line indicating its path as a comment. 25 | Example: 26 | ``` 27 | // file: geodistance.cu 28 | #include "geodistance.h" 29 | ... 30 | ``` 31 | 32 | Code quality and constraints 33 | 34 | The solution must compile cleanly with the provided build command and target architectures. 35 | Avoid unnecessary heap allocations, environment access, and global mutable state. Keep deterministic behavior. 36 | Honor all contracts, constants, and macros defined in provided headers. 37 | 38 | For CUDA: 39 | Implement kernels with correct global signatures and parameter types. 40 | Bounds-check all memory accesses; consider grid-stride loops when appropriate for scalability. 
41 | Favor coalesced memory access and avoid undefined behavior. 42 | Apply appropriate numerical stability practices when needed (e.g., clamp arguments before acos/asin). 43 | 44 | Reasoning discipline 45 | 46 | Think through edge cases and performance internally, but output only the final code files, no analysis or explanations. 47 | """ 48 | 49 | _USER_PROMPT = """ 50 | Produce the complete solution as one or more source files that compile with the provided build command. Do not output anything except the code files. 51 | 52 | Problem 53 | Description: 54 | {prompt} 55 | 56 | Build command: 57 | {build_command} 58 | 59 | Context files: 60 | {context_files_block} 61 | 62 | Output requirements 63 | 64 | Emit only the source files necessary to satisfy the problem (new or modified). 65 | Only emit files you add or modify. Do not output files that are unchanged, and do not include placeholder blocks saying "no changes" or similar. 66 | Do not include any test code or references to tests. 67 | If an interface header is provided (e.g., declares functions to implement), place implementations in a corresponding .cu/.cc source file and include that header. 68 | Begin your response with the first code block. 69 | """ 70 | 71 | _CONTEXT_FILES_BLOCK_TEMPLATE = """ 72 | --- file: {path} 73 | ```{fence} 74 | {content} 75 | ``` 76 | """ 77 | 78 | 79 | def _fence_for_path(path: str) -> str: 80 | p = path.lower() 81 | if p.endswith((".cu", ".cuh")): 82 | return "cuda" 83 | if p.endswith((".cc", ".cpp", ".cxx")): 84 | return "cpp" 85 | if p.endswith(".c"): 86 | return "c" 87 | if p.endswith(".h") or p.endswith(".hpp"): 88 | return "h" 89 | # Default to plaintext if unknown 90 | return "" 91 | 92 | 93 | def _format_context_files_block(context_files: list[SourceFile]) -> str: 94 | blocks: list[str] = [] 95 | for source in context_files: 96 | fence = _fence_for_path(source.path) 97 | blocks.append(_CONTEXT_FILES_BLOCK_TEMPLATE.format(path=source.path, fence=fence, content=source.content)) 98 | return "".join(blocks) 99 | 100 | 101 | def to_user_message(problem: Problem) -> str: 102 | return _USER_PROMPT.format( 103 | prompt=problem.prompt, 104 | build_command=problem.build_command, 105 | context_files_block=_format_context_files_block(problem.context_files), 106 | ) 107 | -------------------------------------------------------------------------------- /compute_eval/models/model_interface.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | from abc import ABC, abstractmethod 4 | 5 | from openai import OpenAI 6 | 7 | RETRIABLE_STATUS_CODES = [ 8 | # These are server-side errors where we can get a correct response if we try again later 9 | 429, # Too many requests; happens when you exceed the rate limit 10 | 500, # Internal Server Error 11 | 502, # Bad Gateway 12 | 503, # Service Unavailable 13 | 504, # Gateway Timeout Error 14 | ] 15 | 16 | 17 | # define a retry decorator 18 | def retry_with_exponential_backoff( 19 | func, 20 | initial_delay: float = 1, 21 | exponential_base: float = 2, 22 | jitter: bool = True, 23 | max_retries: int = 10, 24 | ): 25 | """Retry a function with exponential backoff.""" 26 | 27 | def wrapper(*args, **kwargs): 28 | # Initialize variables 29 | num_retries = 0 30 | delay = initial_delay 31 | 32 | # Loop until a successful response or max_retries is hit or an exception is raised 33 | while True: 34 | try: 35 | return func(*args, **kwargs) 36 | 37 | # Retry on specified errors 38 | except Exception as e: 39 | 
status_code = getattr(e, "status_code", None) 40 | 41 | # Check if the status code is retriable status code 42 | if status_code in RETRIABLE_STATUS_CODES: 43 | # Increment retries 44 | num_retries += 1 45 | 46 | # Check if max retries has been reached 47 | if num_retries > max_retries: 48 | raise Exception(f"Maximum number of retries ({max_retries}) exceeded.") from None 49 | 50 | # Increment the delay 51 | delay *= exponential_base * (1 + jitter * random.random()) 52 | 53 | # Print error message 54 | print(f"Error occurred {str(e)}, retrying after {delay:.2f} seconds.") 55 | 56 | # Sleep for the delay 57 | time.sleep(delay) 58 | elif status_code == 400: 59 | raise Exception("Invalid request was made. Check the headers and payload") from None 60 | elif status_code == 401: 61 | raise Exception("Unauthorized HTTP request. Check your headers and API key") from None 62 | elif status_code == 403: 63 | raise Exception("You are forbidden from accessing this resource") from None 64 | else: 65 | raise Exception( 66 | f"An error occurred when accessing the model API. Check your headers and payload. Error: {str(e)}" 67 | ) from None 68 | 69 | return wrapper 70 | 71 | 72 | class ModelInterface(ABC): 73 | """ 74 | Base class for generating code completions. 75 | """ 76 | 77 | @property 78 | @abstractmethod 79 | def api_key(self) -> str: 80 | pass 81 | 82 | @property 83 | @abstractmethod 84 | def base_url(self) -> str: 85 | pass 86 | 87 | @property 88 | @abstractmethod 89 | def model_name(self) -> str: 90 | pass 91 | 92 | @retry_with_exponential_backoff 93 | def call_chat_completions_endpoint(self, **kwargs): 94 | """ 95 | Call the chat completions endpoint of the model API. 96 | """ 97 | client = OpenAI(base_url=self.base_url, api_key=self.api_key) 98 | return client.chat.completions.create(**kwargs) 99 | 100 | def generate_response(self, system_prompt, prompt, params): 101 | """ 102 | Generate code completions by communicating with the OpenAI API. 103 | 104 | Args: 105 | system_prompt (str, optional): The system prompt to use for generating completions. 106 | prompt (str): The user prompt to use for generating completions. 107 | params (dict, optional): Additional parameters for the API call. 108 | 109 | Returns: 110 | str: The generation result from the model. 111 | """ 112 | 113 | messages = [] 114 | 115 | if system_prompt is not None: 116 | messages.append({"role": "system", "content": system_prompt}) 117 | 118 | messages.append({"role": "user", "content": prompt}) 119 | 120 | params_dict = { 121 | "model": self.model_name, 122 | "messages": messages, 123 | "stream": False, 124 | } 125 | 126 | if (reasoning := getattr(self, "reasoning", None)) is not None: 127 | params_dict["reasoning_effort"] = reasoning 128 | 129 | # Add optional parameters only if they're not None 130 | # Some models (e.g. 
o1-mini) don't support passing some args 131 | # We need to exclude them 132 | temperature = get_parameter_value("temperature", params, None) 133 | if temperature is not None: 134 | params_dict["temperature"] = temperature 135 | 136 | top_p = get_parameter_value("top_p", params, None) 137 | if top_p is not None: 138 | params_dict["top_p"] = top_p 139 | 140 | max_tokens = get_parameter_value("max_tokens", params, 2048) 141 | if max_tokens is not None: 142 | params_dict["max_tokens"] = max_tokens 143 | 144 | response = self.call_chat_completions_endpoint(**params_dict) 145 | 146 | try: 147 | completion = response.choices[0].message.content 148 | except KeyError as e: 149 | print(f"WARNING: The completion object is invalid. Could not find the key {str(e)}") 150 | completion = "" 151 | except Exception: 152 | raise Exception("There was an error when accessing the completion") from None 153 | 154 | return completion 155 | 156 | 157 | def get_parameter_value(parameter, parameters, default_value): 158 | if parameters is not None and parameter in parameters: 159 | return parameters[parameter] 160 | else: 161 | return default_value 162 | -------------------------------------------------------------------------------- /compute_eval/data/data_pack.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import io 3 | import json 4 | import os 5 | import tarfile 6 | import tempfile 7 | from abc import ABC, abstractmethod 8 | from collections.abc import Generator, Iterable 9 | from datetime import datetime 10 | from pathlib import Path 11 | from typing import Annotated 12 | 13 | from pydantic import BaseModel, Field, TypeAdapter 14 | 15 | from compute_eval.data.data_model import ( 16 | CudaCppProblem, 17 | CudaPythonProblem, 18 | FileSolution, 19 | PatchSolution, 20 | Problem, 21 | ReleaseVersion, 22 | Solution, 23 | ) 24 | 25 | 26 | class DatapackMetadata(BaseModel): 27 | release: ReleaseVersion 28 | total_count: int 29 | created_at: str = Field(default_factory=lambda: datetime.now().isoformat()) 30 | description: str | None = None 31 | 32 | 33 | class ProblemDatapackMetadata(DatapackMetadata): 34 | task_id_hashes: dict[str, str] = Field(default_factory=dict) 35 | 36 | 37 | class Datapack(ABC): 38 | _metadata_class: type[DatapackMetadata] = DatapackMetadata 39 | _data_filename = "data.jsonl" 40 | 41 | def __init__(self, path: str | Path): 42 | self._path = Path(os.path.expanduser(path)) 43 | self._tar = None 44 | self._metadata = None 45 | 46 | def __enter__(self): 47 | if self._tar is not None: 48 | raise RuntimeError("DatapackReader is already open") 49 | 50 | self._tar = tarfile.open(self._path, "r:gz") 51 | return self 52 | 53 | def __exit__(self, *args): 54 | if self._tar: 55 | self._tar.close() 56 | self._tar = None 57 | 58 | @property 59 | def metadata(self) -> DatapackMetadata: 60 | if self._metadata is None: 61 | with tarfile.open(self._path, "r:gz") as tar: 62 | try: 63 | f = tar.extractfile("metadata.json") 64 | if f is None: 65 | raise ValueError("metadata.json not found in data pack") 66 | 67 | self._metadata = self._metadata_class.model_validate_json(f.read().decode("utf-8")) 68 | except KeyError as e: 69 | raise ValueError(f"Invalid data pack: missing metadata.json in {self._path}") from e 70 | return self._metadata 71 | 72 | @abstractmethod 73 | def read_items(self) -> Generator[BaseModel, None, None]: 74 | pass 75 | 76 | def _stream(self) -> Generator[dict, None, None]: 77 | if self._tar is None: 78 | raise 
RuntimeError("DatapackReader must be used as a context manager") 79 | 80 | try: 81 | f = self._tar.extractfile(self._data_filename) 82 | if f is None: 83 | raise ValueError(f"{self._data_filename} not found in data pack") 84 | except KeyError as e: 85 | raise ValueError(f"Invalid data pack: missing {self._data_filename} in {self._path}") from e 86 | 87 | text_stream = io.TextIOWrapper(f, encoding="utf-8") 88 | for line in text_stream: 89 | if any(not ch.isspace() for ch in line): 90 | yield json.loads(line, strict=False) 91 | 92 | @classmethod 93 | def _write_item(cls, item: BaseModel, file, metadata: DatapackMetadata) -> str: 94 | """ 95 | Write a single item and update metadata as needed. 96 | 97 | Returns the written line for use by subclasses. 98 | """ 99 | line = item.model_dump_json(serialize_as_any=True) + "\n" 100 | file.write(line) 101 | metadata.total_count += 1 102 | 103 | return line 104 | 105 | @classmethod 106 | def create( 107 | cls, 108 | file_path: str | Path, 109 | items: Iterable[BaseModel], 110 | release: ReleaseVersion, 111 | description: str | None = None, 112 | ): 113 | file_path = Path(os.path.expanduser(file_path)) 114 | 115 | with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as tmp: 116 | tmp_path = tmp.name 117 | 118 | try: 119 | # Create metadata instance that will be built up 120 | metadata = cls._metadata_class( 121 | release=release, 122 | total_count=0, 123 | created_at=datetime.now().isoformat(), 124 | description=description, 125 | ) 126 | 127 | for item in items: 128 | cls._write_item(item, tmp, metadata) 129 | 130 | tmp.flush() 131 | 132 | # Create the tar.gz data pack 133 | with tarfile.open(file_path, "w:gz") as tar: 134 | metadata_bytes = metadata.model_dump_json(indent=2).encode("utf-8") 135 | metadata_info = tarfile.TarInfo(name="metadata.json") 136 | metadata_info.size = len(metadata_bytes) 137 | tar.addfile(metadata_info, io.BytesIO(metadata_bytes)) 138 | tar.add(tmp_path, arcname=cls._data_filename) 139 | 140 | finally: 141 | if os.path.exists(tmp_path): 142 | os.unlink(tmp_path) 143 | 144 | 145 | class ProblemDatapack(Datapack): 146 | _metadata_class = ProblemDatapackMetadata 147 | _data_filename = "problems.jsonl" 148 | 149 | def __init__(self, path: str | Path): 150 | super().__init__(path) 151 | 152 | @property 153 | def metadata(self) -> ProblemDatapackMetadata: 154 | return super().metadata # type: ignore 155 | 156 | def read_items(self) -> Generator[Problem, None, None]: 157 | adapter = TypeAdapter(Annotated[CudaCppProblem | CudaPythonProblem, Field(discriminator="type")]) 158 | for item in self._stream(): 159 | yield adapter.validate_python(item) 160 | 161 | @classmethod 162 | def _write_item(cls, item: Problem, file, metadata: ProblemDatapackMetadata) -> str: 163 | if item.task_id in metadata.task_id_hashes: 164 | raise ValueError(f"Duplicate task_id found when writing problem datapack: {item.task_id}") 165 | 166 | line = super()._write_item(item, file, metadata) 167 | problem_hash = hashlib.md5(line.encode("utf-8")).hexdigest() 168 | metadata.task_id_hashes[item.task_id] = problem_hash 169 | 170 | return line 171 | 172 | 173 | class SolutionDatapack(Datapack): 174 | _data_filename = "solutions.jsonl" 175 | 176 | def __init__(self, path: str | Path): 177 | super().__init__(path) 178 | 179 | def read_items(self) -> Generator[Solution, None, None]: 180 | adapter = TypeAdapter(Annotated[FileSolution | PatchSolution, Field(discriminator="type")]) 181 | for item in self._stream(): 182 | yield 
adapter.validate_python(item) 183 | -------------------------------------------------------------------------------- /DATASET_CARD.md: -------------------------------------------------------------------------------- 1 | # Dataset Card for ComputeEval 2 | 3 | **ComputeEval** is a benchmark dataset for evaluating Large Language Models on **CUDA code generation** tasks. Each problem provides a self-contained programming challenge designed to test various aspects of CUDA development, including kernel launches, memory management, parallel algorithms, and CUDA libraries (Thrust, CUB, etc.). 4 | 5 | **Homepage:** [github.com/NVIDIA/compute-eval](https://github.com/NVIDIA/compute-eval) 6 | 7 | --- 8 | 9 | ## Dataset Format 10 | 11 | Problems are distributed as **datapacks** - versioned releases stored as compressed tarballs (`.tar.gz`). Each datapack contains: 12 | 13 | - **`metadata.json`** - Release version, creation timestamp, problem count, and integrity hashes 14 | - **`problems.jsonl`** - One JSON object per line representing each problem 15 | 16 | **Storage Format:** JSON Lines (`.jsonl`) 17 | **Encoding:** UTF-8 18 | 19 | --- 20 | 21 | ## Data Schema 22 | 23 | ### Problem Structure 24 | 25 | Each problem is a JSON object with the following schema: 26 | 27 | #### Core Fields 28 | 29 | | Field | Type | Required | Description | 30 | |-------|------|----------|-------------| 31 | | `task_id` | `string` | ✓ | Unique identifier (e.g., `"CUDA/0"`, `"CUDA/42"`) | 32 | | `type` | `string` | ✓ | Problem type: `"cuda_cpp"` or `"cuda_python"` | 33 | | `schema_version` | `integer` | ✓ | Data schema version (currently `2`) | 34 | | `date` | `string` | ✓ | Problem creation date (ISO 8601 format: `"YYYY-MM-DD"`) | 35 | | `prompt` | `string` | ✓ | Natural language instruction for the programming task | 36 | | `context_files` | `array` | ✓ | Files visible to the model/system (headers, stubs, helpers) | 37 | | `test_files` | `array` | ✓ | Held-out test harness for evaluation (not shown to model/system) | 38 | | `build_command` | `string` | — | Command to compile the solution (e.g., `"nvcc -I include ..."`) | 39 | | `test_command` | `string` | ✓ | Command to execute tests (e.g., `"./test.out"`) | 40 | | `min_cuda_toolkit` | `string` | — | Minimum CUDA Toolkit version required (e.g., `"12.0"`) | 41 | | `timeout_seconds` | `float` | — | Maximum execution time allowed for tests | 42 | | `source_references` | `string` or `array` | — | Required API calls to verify in solution (e.g., `["cudaMalloc", "cudaFree"]`) | 43 | | `metadata` | `object` | — | Additional problem metadata | 44 | | `arch_list` | `array` | — | GPU architectures required (e.g., `["sm_80", "sm_89"]`) | 45 | 46 | #### File Objects (context_files & test_files) 47 | 48 | Each file in `context_files` and `test_files` is an object with: 49 | 50 | | Field | Type | Description | 51 | |-------|------|-------------| 52 | | `path` | `string` | Relative file path (e.g., `"include/kernel.h"`) | 53 | | `content` | `string` | Complete file contents (UTF-8 encoded) | 54 | 55 | --- 56 | 57 | ## Evaluation Protocol 58 | 59 | ComputeEval follows a strict separation between generation and evaluation: 60 | 61 | ### What Models/Systems See (Generation Time) 62 | 63 | - Problem `prompt` - describes the task and requirements 64 | - `context_files` - interface definitions and helper utilities 65 | - `build_command` - compilation instructions (if provided) 66 | - Minimum CUDA toolkit version and architecture requirements 67 | 68 | ### What Models/Systems Do NOT See 69 | 
70 | - `test_files` - held-out test harness 71 | - Reference solutions 72 | 73 | ### Evaluation Process 74 | 75 | 1. Create temporary workspace 76 | 2. Write `context_files` to workspace 77 | 3. Write model-generated solution files to workspace 78 | 4. Write `test_files` to workspace (now visible) 79 | 5. Execute `build_command` to compile (if provided) 80 | 6. If compilation succeeds (or no build step required), execute `test_command` 81 | 7. Test exit code determines pass/fail (exit code 0 = pass) 82 | 83 | This ensures models cannot overfit to test cases and must solve problems based solely on the natural language description and interface contracts. 84 | 85 | --- 86 | 87 | ## Versioning and Maintenance 88 | 89 | ComputeEval maintains all previous release versions to enable longitudinal tracking of model progress. Users can benchmark against any release version to track improvements over time. 90 | 91 | **Important:** We are committed to maintaining backward compatibility, but not bit-for-bit immutability. If we discover bugs in problems (e.g., unsolvable test cases, incorrect specifications), we reserve the right to fix them and update the corresponding datapacks in future releases. For exact historical versions, users can download specific releases from the git repository history. 92 | 93 | This approach ensures users can continue using previous benchmark versions while benefiting from bug fixes and improvements. 94 | 95 | --- 96 | 97 | ## Example Problem 98 | 99 | ```json 100 | { 101 | "task_id": "CUDA/3", 102 | "type": "cuda_cpp", 103 | "date": "2025-10-31", 104 | "prompt": "Implement a function called `launch` that launches a kernel function named `kernel` without using triple chevrons. The x, y, z grid and block dimensions will be provided as parameters to the `launch` function.\n\nThe function signature is defined in `include/kernel_launch.h`:\n```cuda\nvoid launch(int gridSizeX, int blockSizeX, int gridSizeY = 1, int blockSizeY = 1,\n int gridSizeZ = 1, int blockSizeZ = 1);\n```\n\nThe `kernel` function is already defined with the following signature:\n```cuda\n__global__ void kernel(int *output, const int *input);\n```\n\nYour implementation should use the CUDA runtime API to launch the kernel with the specified grid and block dimensions.", 105 | "context_files": [ 106 | { 107 | "path": "include/kernel_launch.h", 108 | "content": "#pragma once\n\nvoid launch(int gridSizeX, int blockSizeX, int gridSizeY = 1, int blockSizeY = 1,\n int gridSizeZ = 1, int blockSizeZ = 1);\n" 109 | }, 110 | { 111 | "path": "src/kernel.cu", 112 | "content": "#include \n\n__global__ void kernel(int *output, const int *input) {\n int idx = blockIdx.x * blockDim.x + threadIdx.x;\n output[idx] = input[idx] * 2;\n}\n" 113 | } 114 | ], 115 | "test_files": [ 116 | { 117 | "path": "test/test_main.cu", 118 | "content": "#include \n#include \n#include \"../include/kernel_launch.h\"\n\nint main() {\n // Test implementation\n int *d_input, *d_output;\n cudaMalloc(&d_input, 256 * sizeof(int));\n cudaMalloc(&d_output, 256 * sizeof(int));\n \n launch(4, 64); // Launch with 4 blocks, 64 threads each\n \n cudaFree(d_input);\n cudaFree(d_output);\n return 0;\n}\n" 119 | } 120 | ], 121 | "build_command": "nvcc -I include -o test.out solution.cu src/kernel.cu test/*.cu -arch=sm_80", 122 | "test_command": "./test.out", 123 | "min_cuda_toolkit": "12.0", 124 | "timeout_seconds": 30.0, 125 | "arch_list": ["sm_80", "sm_89", "sm_90"], 126 | "source_references": ["cudaLaunchKernelEx"] 127 | } 128 | ``` 129 | 130 | --- 131 | 
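For reference, here is a minimal sketch of the seven-step evaluation process described above. The production harness lives in `compute_eval/execution.py`; the `grade` helper and dict-based file access below are illustrative assumptions, not the library API:

```python
import subprocess
import tempfile
from pathlib import Path

def grade(problem: dict, solution_files: dict[str, str]) -> bool:
    # Step 1: create a temporary workspace.
    with tempfile.TemporaryDirectory() as tmp:
        workdir = Path(tmp)

        def write(rel_path: str, content: str) -> None:
            target = workdir / rel_path
            target.parent.mkdir(parents=True, exist_ok=True)
            target.write_text(content)

        # Steps 2-4: context files, then the generated solution, then the held-out tests.
        for f in problem["context_files"]:
            write(f["path"], f["content"])
        for rel_path, content in solution_files.items():
            write(rel_path, content)
        for f in problem["test_files"]:
            write(f["path"], f["content"])

        # Step 5: compile, if a build command is provided.
        if problem.get("build_command"):
            build = subprocess.run(problem["build_command"], shell=True, cwd=workdir)
            if build.returncode != 0:
                return False

        # Steps 6-7: run the tests; exit code 0 means pass.
        try:
            test = subprocess.run(
                problem["test_command"],
                shell=True,
                cwd=workdir,
                timeout=problem.get("timeout_seconds"),
            )
        except subprocess.TimeoutExpired:
            return False
        return test.returncode == 0
```

The real harness additionally enforces schema-version checks, minimum CUDA toolkit requirements, and source-reference verification before grading.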
132 | ## License 133 | 134 | **SPDX-FileCopyrightText:** Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 135 | **SPDX-License-Identifier:** LicenseRef-NVIDIA-Evaluation 136 | 137 | This dataset is licensed under the **NVIDIA Evaluation Dataset License Agreement**. 138 | See the full license text in [`data/LICENSE`](data/LICENSE). 139 | 140 | --- 141 | 142 | ## Citation 143 | 144 | If you use ComputeEval in your research, please cite: 145 | 146 | ```bibtex 147 | @misc{computeeval2025, 148 | title={ComputeEval: A Benchmark for Evaluating Large Language Models on CUDA Code Generation}, 149 | author={NVIDIA Corporation}, 150 | year={2025}, 151 | url={https://github.com/NVIDIA/compute-eval} 152 | } 153 | ``` 154 | 155 | --- 156 | 157 | ## Additional Resources 158 | 159 | - **Full Documentation:** [README.md](README.md) 160 | - **Contributing Guidelines:** [CONTRIBUTING.md](CONTRIBUTING.md) 161 | - **Issue Tracker:** [GitHub Issues](https://github.com/NVIDIA/compute-eval/issues) 162 | -------------------------------------------------------------------------------- /compute_eval/execution.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | # Portions of this file from human-eval (https://github.com/openai/human-eval/). 17 | # 18 | # The MIT License 19 | # 20 | # Copyright (c) OpenAI (https://openai.com) 21 | # 22 | # Permission is hereby granted, free of charge, to any person obtaining a copy 23 | # of this software and associated documentation files (the "Software"), to deal 24 | # in the Software without restriction, including without limitation the rights 25 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 26 | # copies of the Software, and to permit persons to whom the Software is 27 | # furnished to do so, subject to the following conditions: 28 | # 29 | # The above copyright notice and this permission notice shall be included in 30 | # all copies or substantial portions of the Software. 31 | # 32 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 33 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 34 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 35 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 36 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 37 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 38 | # THE SOFTWARE. 
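# This module implements the grading harness: _work_dir_context yields an
# isolated working directory (persistent and named after the task when DEBUG=1),
# and evaluate_solution materializes the problem's context and test files plus
# the candidate solution into it, runs the build command, verifies required
# source references, runs the test command, and returns a GradedSolution.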
39 | import os 40 | import subprocess 41 | import tempfile 42 | import time 43 | from contextlib import contextmanager 44 | from pathlib import Path 45 | 46 | from compute_eval.data.data_model import ( 47 | SOLUTION_SCHEMA_VERSION, 48 | GradedSolution, 49 | Problem, 50 | Solution, 51 | ) 52 | from compute_eval.utils.eval_utils import parse_semver 53 | 54 | 55 | @contextmanager 56 | def _work_dir_context(task_id: str): 57 | """ 58 | Context manager for a work directory that in DEBUG mode is persistent 59 | and named exactly as `task_id` in the current working directory. 60 | 61 | Non-DEBUG mode uses a TemporaryDirectory with automatic cleanup. 62 | 63 | Args: 64 | task_id (str): Identifier for the work directory, typically the task ID. 65 | """ 66 | debug_mode = os.environ.get("DEBUG", "0") == "1" 67 | 68 | if debug_mode: 69 | base_path = Path.cwd() 70 | tmpdir_path = base_path / task_id 71 | 72 | if tmpdir_path.exists(): 73 | print(f"ERROR: debug temp directory already exists: {tmpdir_path}\n") 74 | raise FileExistsError(f"Refusing to overwrite existing debug temp directory: {tmpdir_path}") 75 | 76 | tmpdir_path.mkdir(parents=True, exist_ok=False) 77 | yield tmpdir_path 78 | else: 79 | with tempfile.TemporaryDirectory() as tmpdir: 80 | yield Path(tmpdir) 81 | 82 | 83 | def evaluate_solution( 84 | installed_ctk_major: int, 85 | installed_ctk_minor: int, 86 | problem: Problem, 87 | solution: Solution, 88 | ) -> GradedSolution: 89 | start_time = time.time() 90 | # Verify that the Solution is generated with the current schema version 91 | if solution.schema_version != SOLUTION_SCHEMA_VERSION: 92 | elapsed = time.time() - start_time 93 | return GradedSolution( 94 | task_id=solution.task_id, 95 | solution=solution, 96 | problem=problem, 97 | passed=False, 98 | skipped=True, 99 | elapsed_time=elapsed, 100 | build_output=f"[SCHEMA VERSION MISMATCH] Solution schema version {solution.schema_version} does not match expected version {SOLUTION_SCHEMA_VERSION}.", 101 | ) 102 | 103 | # Check CUDA toolkit version compatibility (if applicable) 104 | if (required_ctk := parse_semver(problem.min_cuda_toolkit)) is not None: 105 | required_ctk_major, required_ctk_minor, _ = required_ctk 106 | if (installed_ctk_major, installed_ctk_minor) < (required_ctk_major, required_ctk_minor): 107 | elapsed = time.time() - start_time 108 | return GradedSolution( 109 | task_id=solution.task_id, 110 | solution=solution, 111 | problem=problem, 112 | passed=False, 113 | skipped=True, 114 | elapsed_time=elapsed, 115 | ) 116 | 117 | # Check preconditions -- task ids must match, the solution must have source files, and the solution must not 118 | # attempt to overwrite or modify the unseen test files. 
119 | if not solution.validate(problem): 120 | elapsed = time.time() - start_time 121 | return GradedSolution( 122 | task_id=solution.task_id, 123 | solution=solution, 124 | problem=problem, 125 | passed=False, 126 | skipped=False, 127 | elapsed_time=elapsed, 128 | build_output="[VALIDATION ERROR] Solution failed validation checks.", 129 | ) 130 | 131 | with _work_dir_context(problem.task_id.replace("/", "-")) as workdir_path: 132 | # Write context files (public) from Problem to workdir 133 | for cf in problem.context_files: 134 | file_path = workdir_path / cf.path 135 | file_path.parent.mkdir(parents=True, exist_ok=True) 136 | file_path.write_text(cf.content) 137 | 138 | # Write test files (private) from Problem to workdir 139 | for tf in problem.test_files: 140 | file_path = workdir_path / tf.path 141 | file_path.parent.mkdir(parents=True, exist_ok=True) 142 | file_path.write_text(tf.content) 143 | 144 | # Apply the Solution to the workdir. Note that these may intentionally overwrite context files. 145 | solution.setup_workspace(workdir_path) 146 | 147 | build_output = None 148 | # Run build command (if set) 149 | if problem.build_command: 150 | try: 151 | result = subprocess.run( 152 | problem.build_command, 153 | shell=True, 154 | cwd=workdir_path, 155 | capture_output=True, 156 | text=True, 157 | check=True, 158 | ) 159 | build_output = result.stdout + "\n" + result.stderr 160 | except subprocess.CalledProcessError as e: 161 | elapsed = time.time() - start_time 162 | return GradedSolution( 163 | task_id=solution.task_id, 164 | solution=solution, 165 | problem=problem, 166 | passed=False, 167 | skipped=False, 168 | elapsed_time=elapsed, 169 | build_output=f"[BUILD ERROR]\n{e.stdout}\n{e.stderr}", 170 | ) 171 | 172 | # Validate the Solution passes the Problem's source_references (if any) 173 | if not solution.verify_source_references(problem.source_references): 174 | elapsed = time.time() - start_time 175 | return GradedSolution( 176 | task_id=solution.task_id, 177 | solution=solution, 178 | problem=problem, 179 | passed=False, 180 | skipped=False, 181 | elapsed_time=elapsed, 182 | build_output=build_output, 183 | test_output="[VALIDATION ERROR] Solution does not include required source references.", 184 | ) 185 | 186 | # Run test command 187 | try: 188 | result = subprocess.run( 189 | problem.test_command, 190 | shell=True, 191 | cwd=workdir_path, 192 | capture_output=True, 193 | text=True, 194 | check=True, 195 | timeout=problem.timeout_seconds, 196 | ) 197 | passed = True 198 | test_output = result.stdout + "\n" + result.stderr 199 | except subprocess.CalledProcessError as e: 200 | passed = False 201 | test_output = e.stdout + "\n" + e.stderr 202 | except subprocess.TimeoutExpired as e: 203 | passed = False 204 | test_output = f"[TIMEOUT EXPIRED after {e.timeout} seconds]\n{e.stdout}\n{e.stderr}" 205 | 206 | elapsed_time = time.time() - start_time 207 | 208 | # Return graded result 209 | return GradedSolution( 210 | task_id=solution.task_id, 211 | solution=solution, 212 | problem=problem, 213 | passed=passed, 214 | skipped=False, 215 | elapsed_time=elapsed_time, 216 | build_output=build_output, 217 | test_output=test_output, 218 | ) 219 | -------------------------------------------------------------------------------- /compute_eval/data/data_model.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Portions of this file from human-eval (https://github.com/openai/human-eval/). 17 | # 18 | # The MIT License 19 | # 20 | # Copyright (c) OpenAI (https://openai.com) 21 | # 22 | # Permission is hereby granted, free of charge, to any person obtaining a copy 23 | # of this software and associated documentation files (the "Software"), to deal 24 | # in the Software without restriction, including without limitation the rights 25 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 26 | # copies of the Software, and to permit persons to whom the Software is 27 | # furnished to do so, subject to the following conditions: 28 | # 29 | # The above copyright notice and this permission notice shall be included in 30 | # all copies or substantial portions of the Software. 31 | # 32 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 33 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 34 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 35 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 36 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 37 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 38 | # THE SOFTWARE. 39 | 40 | from abc import ABC, abstractmethod 41 | from enum import Enum 42 | from pathlib import Path 43 | from typing import Annotated, Literal 44 | 45 | from pydantic import ( 46 | BaseModel, 47 | ConfigDict, 48 | Field, 49 | TypeAdapter, 50 | model_validator, 51 | ) 52 | 53 | from compute_eval.utils.parsing import get_most_likely_language 54 | 55 | PROBLEM_SCHEMA_VERSION = 2 56 | SOLUTION_SCHEMA_VERSION = 1 57 | SOURCE_ENCODING = "utf-8" 58 | 59 | 60 | class ReleaseVersion(str, Enum): 61 | INTERNAL = "internal" 62 | V2025_1 = "2025-1" 63 | V2025_2 = "2025-2" 64 | V2025_3 = "2025-3" 65 | 66 | 67 | class SourceFile(BaseModel): 68 | path: str 69 | content: str 70 | 71 | 72 | class Metadata(BaseModel): 73 | model_config = ConfigDict(extra="allow") 74 | 75 | difficulty: str | None = None 76 | tags: list[str] = Field(default_factory=list) 77 | releases: list[ReleaseVersion] = Field(default_factory=list) 78 | 79 | do_not_release: bool = Field(default=False) 80 | 81 | 82 | class SourceReferences(BaseModel): 83 | """Model for source references with any/all semantics. 
84 | 85 | Can be: 86 | - {'any': ['ref1', 'ref2']} - at least one reference must be present 87 | - {'all': ['ref1', 'ref2']} - all references must be present 88 | - {'any': ['ref1', 'ref2'], 'all': ['ref3', 'ref4']} - combined logic: 89 | ALL references in 'all' must be present AND at least ONE from 'any' must be present 90 | """ 91 | 92 | any: list[str] | None = None 93 | all: list[str] | None = None 94 | 95 | @model_validator(mode="after") 96 | def validate_at_least_one_key(self): 97 | if self.any is None and self.all is None: 98 | raise ValueError("Must specify at least one of 'any' or 'all' in source_references") 99 | return self 100 | 101 | 102 | class Problem(BaseModel, ABC): 103 | type: Literal["cuda_cpp", "cuda_python"] 104 | schema_version: int = Field(default=PROBLEM_SCHEMA_VERSION) 105 | 106 | task_id: str 107 | date: str 108 | prompt: str 109 | 110 | context_files: list[SourceFile] = Field(default_factory=list) 111 | test_files: list[SourceFile] = Field(default_factory=list) 112 | 113 | source_references: str | list[str] | SourceReferences | None = None 114 | 115 | build_command: str | None = None 116 | test_command: str 117 | 118 | min_cuda_toolkit: str | None = None 119 | timeout_seconds: float | None = None 120 | metadata: Metadata | None = None 121 | 122 | @model_validator(mode="before") 123 | @classmethod 124 | def _upgrade_to_concrete(cls, data): 125 | if isinstance(data, cls) or cls is not Problem: 126 | return data 127 | 128 | adapter = TypeAdapter(Annotated[CudaCppProblem | CudaPythonProblem, Field(discriminator="type")]) 129 | return adapter.validate_python(data) 130 | 131 | 132 | class CudaCppProblem(Problem): 133 | type: Literal["cuda_cpp"] = "cuda_cpp" 134 | 135 | arch_list: list[str] = Field(default_factory=list) 136 | 137 | 138 | class CudaPythonProblem(Problem): 139 | type: Literal["cuda_python"] = "cuda_python" 140 | 141 | python_version: str | None = None 142 | 143 | 144 | class Solution(BaseModel, ABC): 145 | model_config = ConfigDict(extra="allow") 146 | schema_version: int = Field(default=SOLUTION_SCHEMA_VERSION) 147 | 148 | type: Literal["file", "patch"] 149 | task_id: str 150 | 151 | @abstractmethod 152 | def validate(self, problem: Problem) -> bool: 153 | pass 154 | 155 | @abstractmethod 156 | def setup_workspace(self, work_dir: Path): 157 | pass 158 | 159 | @abstractmethod 160 | def verify_source_references(self, source_references: str | list[str] | SourceReferences | None) -> bool: 161 | pass 162 | 163 | @model_validator(mode="before") 164 | @classmethod 165 | def _upgrade_to_concrete(cls, data): 166 | if isinstance(data, cls) or cls is not Solution: 167 | return data 168 | 169 | adapter = TypeAdapter(Annotated[FileSolution | PatchSolution, Field(discriminator="type")]) 170 | return adapter.validate_python(data) 171 | 172 | 173 | class FileSolution(Solution): 174 | type: Literal["file"] = "file" 175 | 176 | files: list[SourceFile] = Field(default_factory=list) 177 | 178 | def validate(self, problem: Problem) -> bool: 179 | if self.task_id != problem.task_id: 180 | return False 181 | if not self.files: 182 | return False 183 | return not {f.path for f in self.files} & {tf.path for tf in problem.test_files} 184 | 185 | def setup_workspace(self, work_dir: Path): 186 | for file in self.files: 187 | file_path = work_dir / file.path 188 | file_path.parent.mkdir(parents=True, exist_ok=True) 189 | file_path.write_text(file.content) 190 | 191 | def verify_source_references(self, source_references: str | list[str] | SourceReferences | None) -> bool: 192 | if 
source_references is None: 193 | return True 194 | 195 | # Parse the source_references format 196 | all_refs: list[str] = [] 197 | any_refs: list[str] = [] 198 | 199 | if isinstance(source_references, str): 200 | all_refs = [source_references] 201 | elif isinstance(source_references, list): 202 | all_refs = source_references 203 | elif isinstance(source_references, SourceReferences): 204 | if source_references.all is not None: 205 | all_refs = source_references.all 206 | if source_references.any is not None: 207 | any_refs = source_references.any 208 | 209 | if not all_refs and not any_refs: 210 | return True 211 | 212 | all_remaining = {ref.encode(SOURCE_ENCODING) for ref in all_refs} 213 | any_remaining = {ref.encode(SOURCE_ENCODING) for ref in any_refs} 214 | 215 | for file in self.files: 216 | encoded_content = file.content.encode(SOURCE_ENCODING) 217 | language = get_most_likely_language(file.path, encoded_content) 218 | if language is not None: 219 | for ref, _ in language.find_matching_subtrees(encoded_content, all_remaining | any_remaining): 220 | all_remaining.discard(ref) 221 | if ref in any_remaining: 222 | any_remaining.clear() 223 | 224 | if not all_remaining and not any_remaining: 225 | return True 226 | 227 | return not all_remaining and not any_remaining 228 | 229 | 230 | class PatchSolution(Solution): 231 | type: Literal["patch"] = "patch" 232 | 233 | patch: str 234 | 235 | def validate(self, problem: Problem) -> bool: 236 | if self.task_id != problem.task_id: 237 | return False 238 | return self.patch is not None 239 | 240 | def setup_workspace(self, work_dir: Path): 241 | # TODO: Apply the patch to the files in work_dir 242 | pass 243 | 244 | def verify_source_references(self, source_references: str | list[str] | SourceReferences | None) -> bool: 245 | # TODO: Need to fully implement patch solutions. 246 | return True 247 | 248 | 249 | class GradedSolution(BaseModel): 250 | task_id: str 251 | passed: bool 252 | skipped: bool 253 | elapsed_time: float 254 | solution: Solution 255 | problem: Problem 256 | build_output: str | None = None 257 | test_output: str | None = None 258 | 259 | @model_validator(mode="before") 260 | @classmethod 261 | def _upgrade_nested_union(cls, data): 262 | if "problem" in data and data["problem"] is not None: 263 | problem_adapter = TypeAdapter(Annotated[CudaCppProblem | CudaPythonProblem, Field(discriminator="type")]) 264 | data["problem"] = problem_adapter.validate_python(data["problem"]) 265 | if "solution" in data and data["solution"] is not None: 266 | solution_adapter = TypeAdapter(Annotated[FileSolution | PatchSolution, Field(discriminator="type")]) 267 | data["solution"] = solution_adapter.validate_python(data["solution"]) 268 | return data 269 | -------------------------------------------------------------------------------- /compute_eval/evaluation.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Portions of this file from human-eval (https://github.com/openai/human-eval/). 17 | # 18 | # The MIT License 19 | # 20 | # Copyright (c) OpenAI (https://openai.com) 21 | # 22 | # Permission is hereby granted, free of charge, to any person obtaining a copy 23 | # of this software and associated documentation files (the "Software"), to deal 24 | # in the Software without restriction, including without limitation the rights 25 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 26 | # copies of the Software, and to permit persons to whom the Software is 27 | # furnished to do so, subject to the following conditions: 28 | # 29 | # The above copyright notice and this permission notice shall be included in 30 | # all copies or substantial portions of the Software. 31 | # 32 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 33 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 34 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 35 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 36 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 37 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 38 | # THE SOFTWARE. 39 | 40 | import itertools 41 | import json 42 | import os 43 | from collections import defaultdict 44 | from concurrent.futures import ThreadPoolExecutor, as_completed 45 | 46 | import numpy as np 47 | import tqdm 48 | 49 | from compute_eval.data.data_model import ( 50 | PROBLEM_SCHEMA_VERSION, 51 | SOLUTION_SCHEMA_VERSION, 52 | GradedSolution, 53 | Problem, 54 | ) 55 | from compute_eval.data.data_pack import ProblemDatapack, SolutionDatapack 56 | from compute_eval.data.utils import write_graded_solutions 57 | from compute_eval.execution import evaluate_solution 58 | from compute_eval.utils.eval_utils import get_nvcc_version, parse_semver 59 | 60 | WARNING_MSG = """=================== 61 | WARNING 62 | =================== 63 | 64 | Evaluation of correctness or performance will execute untrusted model-generated 65 | code. 66 | 67 | Although it is highly unlikely that model-generated code will do something 68 | overtly malicious in response to this test suite, model-generated code may act 69 | destructively due to a lack of model capability or alignment. 70 | 71 | Users are strongly encouraged to sandbox this evaluation suite so that it does 72 | not perform destructive actions on their host or network. 73 | 74 | In order to execute this code you must explicitly pass the --allow-execution flag. 75 | """ 76 | 77 | 78 | def estimate_pass_at_k( 79 | num_samples: int | list[int] | np.ndarray, 80 | num_correct: list[int] | np.ndarray, 81 | k: int, 82 | ) -> np.ndarray: 83 | """ 84 | Estimates pass@k of each problem and returns them in an array. 
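For example, with n=5 samples of which c=2 are correct, the k=1 estimate is 1 - C(5-2, 1)/C(5, 1) = 1 - 3/5 = 0.4; for k=1 the estimator always reduces to the plain fraction c/n, and whenever n - c < k it is exactly 1.0, since every draw of k samples must then include at least one correct sample.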
85 | 86 | Args: 87 | num_samples: Number of samples for each problem 88 | num_correct: Number of correct samples for each problem 89 | k: The k value for pass@k calculation 90 | 91 | Returns: 92 | Array of pass@k estimates for each problem 93 | """ 94 | if k <= 0: 95 | raise ValueError("k must be positive") 96 | 97 | def estimator(n: int, c: int, k: int) -> float: 98 | """ 99 | Calculates 1 - comb(n - c, k) / comb(n, k). 100 | """ 101 | if n - c < k: 102 | return 1.0 103 | return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)) 104 | 105 | if isinstance(num_samples, int): 106 | num_samples_it = itertools.repeat(num_samples, len(num_correct)) 107 | else: 108 | assert len(num_samples) == len(num_correct) 109 | num_samples_it = iter(num_samples) 110 | 111 | return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct, strict=False)]) 112 | 113 | 114 | def evaluate_functional_correctness( 115 | solutions_datapack: str, 116 | problems_datapack_dir: str, 117 | allow_execution: bool, 118 | k: tuple[int] | int, 119 | n_workers: int, 120 | results_file: str | None, 121 | ): 122 | """ 123 | Evaluates the functional correctness of generated solutions and writes results. 124 | 125 | Args: 126 | solutions_datapack (str): Path to the solution datapack. 127 | problems_datapack_dir (str): Directory containing problem datapacks. 128 | allow_execution (bool): Whether to allow execution of untrusted code. 129 | k (Tuple[int] | int): Tuple of k values for evaluation or single k value (default: 1). 130 | n_workers (int): Number of worker threads. 131 | results_file (str | None): Path to output results file. 132 | 133 | Returns: 134 | None 135 | """ 136 | if not allow_execution: 137 | raise RuntimeError(WARNING_MSG) 138 | 139 | if (installed_ctk_version := parse_semver(get_nvcc_version())) is None: 140 | raise RuntimeError("Could not determine CUDA toolkit version from nvcc.") 141 | 142 | installed_ctk_major, installed_ctk_minor, _ = installed_ctk_version 143 | 144 | # Check if only one k value was passed in (as an integer) 145 | # Multiple k values (tuple) is converted to a list of int 146 | k_vals = [k] if isinstance(k, int) else list(k) 147 | 148 | with SolutionDatapack(solutions_datapack) as datapack: 149 | release = datapack.metadata.release 150 | 151 | print("Reading solutions...") 152 | solutions = list(datapack.read_items()) 153 | 154 | # Verify that all solutions are for the current schema version 155 | if any(s.schema_version != SOLUTION_SCHEMA_VERSION for s in solutions): 156 | raise ValueError( 157 | f"One or more solutions in {solutions_datapack} do not match the expected schema version {SOLUTION_SCHEMA_VERSION}." 158 | ) 159 | 160 | problems_file = os.path.join(problems_datapack_dir, f"{release.value}-problems.tar.gz") 161 | with ProblemDatapack(problems_file) as datapack: 162 | # Sanity check: ensure the problems datapack matches the solutions datapack release 163 | if datapack.metadata.release != release: 164 | raise ValueError( 165 | f"Problems datapack release {datapack.metadata.release} does not match solutions datapack release {release}." 
166 | ) 167 | 168 | print("Reading problems...") 169 | problems = list(datapack.read_items()) 170 | keyed_problems: dict[str, Problem] = {p.task_id: p for p in problems} 171 | 172 | # Verify that all problems are for the current schema version 173 | if any(p.schema_version != PROBLEM_SCHEMA_VERSION for p in keyed_problems.values()): 174 | raise ValueError( 175 | f"One or more problems in {problems_file} do not match the expected schema version {PROBLEM_SCHEMA_VERSION}." 176 | ) 177 | 178 | # Verify that each problem is attempted at least once 179 | task_ids = set(p.task_id for p in problems) 180 | test_ids = set(s.task_id for s in solutions) 181 | 182 | missing_ids = task_ids - test_ids 183 | if missing_ids: 184 | raise ValueError(f"The following task_ids are missing in the solutions: {missing_ids}") 185 | 186 | # Check the generated solutions against test suites. 187 | with ThreadPoolExecutor(max_workers=n_workers) as executor: 188 | futures = [] 189 | results: list[GradedSolution] = [] 190 | 191 | for solution in tqdm.tqdm(solutions): 192 | task_id = solution.task_id 193 | problem = keyed_problems.get(task_id) 194 | 195 | args = (installed_ctk_major, installed_ctk_minor, problem, solution) 196 | future = executor.submit(evaluate_solution, *args) 197 | futures.append(future) 198 | 199 | for future in tqdm.tqdm(as_completed(futures), total=len(futures)): 200 | results.append(future.result()) 201 | 202 | pass_at_k = estimate_metrics(results, k_vals) 203 | write_metrics( 204 | results, 205 | pass_at_k, 206 | results_file or f"{release.value}-graded-solutions.jsonl", 207 | ) 208 | 209 | 210 | def estimate_metrics(results: list[GradedSolution], k_vals: list[int]) -> dict[str, float]: 211 | """ 212 | Estimates the metrics for the given solutions. 213 | 214 | Args: 215 | results (List[GradedSolution]): List of graded solutions. 216 | k_vals (List[int]): List of k values for evaluation. 217 | 218 | Returns: 219 | Dict[str, float]: A dictionary containing the estimated metrics. 220 | """ 221 | # Calculate pass@k. 222 | total, correct = [], [] 223 | 224 | # Group results by task_id 225 | results_by_task = defaultdict(list) 226 | for result in results: 227 | results_by_task[result.solution.task_id].append(result) 228 | 229 | skipped = 0 230 | for _, task_results in results_by_task.items(): 231 | total.append(len(task_results)) 232 | correct.append(sum(r.passed for r in task_results)) 233 | skipped += all(r.skipped for r in task_results) 234 | total = np.array(total) 235 | correct = np.array(correct) 236 | 237 | return { 238 | "skipped": float(skipped), 239 | **{f"pass@{k}": estimate_pass_at_k(total, correct, k).mean() for k in k_vals if (total >= k).all()}, 240 | } 241 | 242 | 243 | def write_metrics( 244 | results: list[GradedSolution], 245 | pass_at_k: dict[str, float], 246 | results_file: str, 247 | ) -> None: 248 | """ 249 | Writes the metrics to a file and prints consolidated output. 250 | 251 | Args: 252 | results (list[GradedSolution]): List of graded solutions. 253 | pass_at_k (Dict[str, float]): Pass@k metrics. 254 | results_file (str): Path to the output results file.
255 | 256 | Returns: 257 | None 258 | """ 259 | # Finally, save the results in one file: 260 | print(f"Writing results to {results_file}...") 261 | write_graded_solutions(results_file, results) 262 | 263 | # Output structured JSON to stdout 264 | output = { 265 | "pass_at_k": {k: float(v) for k, v in pass_at_k.items()}, 266 | "problem_count": len(set(r.solution.task_id for r in results)), 267 | } 268 | print(json.dumps(output, indent=2)) 269 | -------------------------------------------------------------------------------- /data/LICENSE: -------------------------------------------------------------------------------- 1 | SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | SPDX-License-Identifier: LicenseRef-NVIDIA-Evaluation 3 | 4 | ================================================================================ 5 | NVIDIA Evaluation Dataset License Agreement 6 | ================================================================================ 7 | 8 | This NVIDIA Evaluation Dataset License Agreement ("Agreement") is a legal 9 | agreement between you, whether an individual or entity ("you") and NVIDIA 10 | Corporation with an address 2788 San Tomas Expressway, Santa Clara, California 11 | 95051 ("NVIDIA") and governs the use of certain datasets, including any 12 | annotations and metadata accompanying the datasets, provided by NVIDIA 13 | ("Dataset"). 14 | 15 | This Agreement can be accepted only by an adult of legal age of majority in the 16 | country in which the Dataset is used. 17 | 18 | If you don't have the required age or authority to accept this Agreement or if 19 | you don't accept all the terms and conditions of this Agreement, do not use the 20 | Dataset. 21 | 22 | You agree to use the Dataset only for purposes expressly permitted by this 23 | Agreement and in accordance with any applicable law or regulation in the 24 | relevant jurisdictions. 25 | 26 | 1. License Grant 27 | 28 | Subject to the terms of this Agreement, NVIDIA grants you a limited, 29 | non-exclusive, revocable, non-transferable, non-sublicensable, license to 30 | download, use, reproduce, modify, and create derivative works of the Dataset, 31 | in each case solely for your internal evaluation and benchmarking of AI 32 | Solutions ("Purpose"). "AI Solutions" means any artificial intelligence ("AI") 33 | based models or machine learning algorithm and associated parameters and 34 | associated weights. You may publish or otherwise disclose the results of your 35 | evaluation or benchmarking of AI Solutions using the Dataset. 36 | 37 | 2. Authorized Users 38 | 39 | You may allow your Affiliates' employees and contractors (all such users 40 | collectively "Authorized Users") to access and use the Dataset from your 41 | secure network for the Purpose on your behalf. You are responsible for the 42 | compliance with the terms of this Agreement by your Authorized Users. Any act 43 | or omission by your Authorized Users that if committed by you would 44 | constitute a breach of this Agreement will be deemed to constitute a breach 45 | of this Agreement. "Affiliates" means an entity that owns or controls, is 46 | owned or controlled by, or is under common ownership or control with you, 47 | where "control" is the possession, directly or indirectly, of the power to 48 | direct or cause the direction of the management and policies of an entity, 49 | whether through ownership of voting securities, by contract or otherwise. 50 | 51 | 3. 
Limitations 52 | 53 | Your license to use the Dataset is restricted as follows: 54 | 55 | 3.1 The rights granted to you in Section 1 and 2 are for the Purpose only. 56 | You may not use the Dataset for any other purpose, including the training 57 | of AI Solutions. 58 | 59 | 3.2 You may not sell, rent, sublicense, transfer, distribute, embed, or host 60 | the Dataset (in whole or in part), or otherwise make the Dataset (in 61 | whole or in part) available to others. 62 | 63 | 3.3 You may not change or remove copyright or other proprietary notices in 64 | the Dataset. 65 | 66 | 3.4 You may not use the Dataset in any manner that would cause it to become 67 | subject to an open source license. 68 | 69 | 3.5 You may not use the Dataset to identify any individuals or Personal Data. 70 | "Personal Data" means any information relating to an identified or 71 | identifiable natural person and any other information that constitutes 72 | personal data or personal information under any applicable law. 73 | 74 | 4. AI Ethics 75 | 76 | Your use of the Dataset must be consistent with NVIDIA's Trustworthy AI 77 | Terms. 78 | 79 | 5. Ownership 80 | 81 | As between you and NVIDIA and to the maximum extent under applicable law, the 82 | Dataset, including all intellectual property rights, is and will remain the 83 | sole and exclusive property of NVIDIA or its licensors. Except as expressly 84 | granted in this Agreement, (a) NVIDIA reserves all rights, interests and 85 | remedies in connection with the Dataset, and (b) no other license or right is 86 | granted to you by implication, estoppel or otherwise. 87 | 88 | 6. Feedback 89 | 90 | You may, but are not obligated to, provide suggestions, requests, fixes, 91 | modifications, enhancements, or other feedback regarding or in connection 92 | with your use of the Dataset (collectively, "Feedback"). Feedback, even if 93 | designated as confidential by you, will not create any confidentiality 94 | obligation for NVIDIA or its affiliates. If you provide Feedback, you hereby 95 | grant NVIDIA, its affiliates and its designees a non-exclusive, perpetual, 96 | irrevocable, sublicensable, worldwide, royalty-free, fully paid-up and 97 | transferable license, under your intellectual property rights, to publicly 98 | perform, publicly display, reproduce, use, make, have made, sell, offer for 99 | sale, distribute (through multiple tiers of distribution), import, create 100 | derivative works of and otherwise commercialize and exploit the Feedback at 101 | NVIDIA's discretion. 102 | 103 | 7. Termination 104 | 105 | This Agreement will automatically terminate (a) if you fail to comply with 106 | any of the terms in this Agreement, or (b) if you commence or participate in 107 | any legal proceeding against NVIDIA with respect to the Dataset. Upon 108 | termination, you must stop using and destroy all copies of the Dataset. Upon 109 | written request, you will certify in writing that you have complied with your 110 | commitments under this section. All provisions will survive termination, 111 | except for the licenses granted to you. 112 | 113 | 8. Disclaimer of Warranties 114 | 115 | THE DATASET IS PROVIDED BY NVIDIA AS-IS AND WITH ALL FAULTS. 
TO THE MAXIMUM 116 | EXTENT PERMITTED BY APPLICABLE LAW, NVIDIA DISCLAIMS ALL WARRANTIES AND 117 | REPRESENTATIONS OF ANY KIND, WHETHER EXPRESS, IMPLIED OR STATUTORY, RELATING 118 | TO OR ARISING UNDER THIS AGREEMENT, INCLUDING, WITHOUT LIMITATION, THE 119 | WARRANTIES OF TITLE, NONINFRINGEMENT, MERCHANTABILITY, FITNESS FOR A 120 | PARTICULAR PURPOSE, USAGE OF TRADE AND COURSE OF DEALING. 121 | 122 | 9. Limitations of Liability 123 | 124 | TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO 125 | LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE, 126 | WILL NVIDIA BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, 127 | SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES OF ANY TYPE ARISING OUT OF OR 128 | AS A RESULT OF THIS AGREEMENT OR THE USE OR INABILITY TO USE THE DATASET 129 | (INCLUDING BUT NOT LIMITED TO DAMAGES FOR LOSS OF GOODWILL, WORK STOPPAGE, 130 | COMPUTER FAILURE OR MALFUNCTION, OR ANY AND ALL OTHER DAMAGES OR LOSSES), 131 | EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 132 | 133 | 10. Governing Law and Jurisdiction 134 | 135 | This Agreement will be governed in all respects by the laws of the United 136 | States and the laws of the State of Delaware, without regard to conflict of 137 | laws principles or the United Nations Convention on Contracts for the 138 | International Sale of Goods. The state and federal courts residing in Santa 139 | Clara County, California will have exclusive jurisdiction over any dispute 140 | or claim arising out of or related to this Agreement, and you and NVIDIA 141 | irrevocably consent to personal jurisdiction and venue in those courts; 142 | except that either you or NVIDIA may apply for injunctive remedies or an 143 | equivalent type of urgent legal relief in any jurisdiction. 144 | 145 | 11. Indemnity 146 | 147 | You will indemnify and hold harmless NVIDIA and its affiliates from and 148 | against any and all claims, damages, obligations, losses, liabilities, costs 149 | and expenses (including but not limited to attorney's fees and costs of 150 | establishing the right of indemnification) arising out of or related to your 151 | use of the Dataset. 152 | 153 | 12. General 154 | 155 | 12.1 No Assignment 156 | NVIDIA may assign, delegate or transfer its rights or obligations under 157 | this Agreement by any means or operation of law. You may not, without 158 | NVIDIA's prior written consent, assign, delegate or transfer any of 159 | your rights or obligations under this Agreement by any means or 160 | operation of law, and any attempt to do so is null and void. 161 | 162 | 12.2 No Waiver 163 | No waiver of any term of the Agreement will be deemed a further or 164 | continuing waiver of such term or any other term, and NVIDIA's failure 165 | to assert any right or provision under the Agreement will not 166 | constitute a waiver of such right or provision. 167 | 168 | 12.3 Trade Compliance 169 | You agree to comply with all applicable export, import, trade and 170 | economic sanctions laws and regulations, including the Export 171 | Administration Regulations and Office of Foreign Assets Control 172 | regulations. These laws include restrictions on destinations, end-users 173 | and end-use. 
174 | 175 | 12.4 Notices 176 | Please direct your legal notices or other correspondence to NVIDIA 177 | Corporation, 2788 San Tomas Expressway, Santa Clara, California 95051, 178 | United States of America, Attention: Legal Department, with a copy 179 | emailed to legalnotices@nvidia.com. 180 | 181 | 12.5 Independent Contractors 182 | You and NVIDIA are independent contractors, and this Agreement does not 183 | create a joint venture, partnership, agency or other form of business 184 | association between you and NVIDIA. Neither you nor NVIDIA will have 185 | the power to bind the other or incur any obligation on its behalf 186 | without the other's prior written consent. 187 | 188 | 12.6 Severability 189 | If a court of competent jurisdiction rules that a provision of this 190 | Agreement is unenforceable, that provision will be deemed modified to 191 | the extent necessary to make it enforceable and the remainder of this 192 | Agreement will continue in full force and effect. 193 | 194 | 12.7 Construction 195 | The headings in the Agreement are included solely for convenience and 196 | are not intended to affect the meaning or interpretation of the 197 | Agreement. As required by the context of the Agreement, the singular of 198 | a term includes the plural and vice versa. 199 | 200 | 12.8 Entire Agreement 201 | Regarding the subject matter of this Agreement, you and NVIDIA agree 202 | that (a) this Agreement constitutes the entire and exclusive agreement 203 | between you and NVIDIA and supersedes all prior and contemporaneous 204 | communications and (b) any additional or different terms or conditions, 205 | whether contained in purchase orders, order acknowledgments, invoices 206 | or otherwise, will not be binding and are null and void. 207 | 208 | ================================================================================ 209 | (v November 7, 2025) -------------------------------------------------------------------------------- /compute_eval/generate_completions.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import gzip 16 | import os 17 | import re 18 | from concurrent.futures import ThreadPoolExecutor, as_completed 19 | 20 | import tqdm 21 | 22 | from compute_eval.data.data_model import FileSolution, Problem, ReleaseVersion, Solution, SourceFile 23 | from compute_eval.data.utils import read_solutions, write_solutions 24 | 25 | from .
import get_model_class 26 | from .data.data_pack import ProblemDatapack, SolutionDatapack 27 | from .models.model_interface import ModelInterface 28 | from .prompts import to_user_message 29 | 30 | 31 | def generate_model_completions( 32 | system_prompt: str, 33 | problem: Problem, 34 | model: str, 35 | base_url: str | None = None, 36 | reasoning: str | None = None, 37 | params: dict | None = None, 38 | debug: bool = False, 39 | ) -> Solution: 40 | """ 41 | Orchestrate the generation of code completions using the specified model. 42 | 43 | Args: 44 | system_prompt (str): The system prompt to use for generating completions. 45 | problem (Problem): The problem object containing the task details. 46 | model (str): The name of the model to use for generating completions. 47 | base_url (str, optional): The base URL for the custom model API endpoint. 48 | reasoning (str, optional): Reasoning mode for the model (e.g., 'low', 'medium', 'high' for GPT models, or any value for Claude models to enable extended thinking). 49 | params (dict, optional): Additional parameters to pass to the model invocation. 50 | debug (bool, optional): Whether to include the system prompt, prompt, and generated completion in the output solution for debugging. 51 | 52 | Returns: 53 | solution (Solution): The generated solution object containing the completions. 54 | """ 55 | 56 | model_class = get_model_class(model) 57 | model_instance: ModelInterface = model_class( 58 | model_name=model, 59 | base_url=base_url, 60 | reasoning=reasoning, 61 | ) 62 | 63 | if params is None: 64 | params = {} 65 | 66 | prompt = to_user_message(problem) 67 | 68 | completion = model_instance.generate_response(system_prompt, prompt, params) 69 | 70 | debug_info = ( 71 | { 72 | "system_prompt": system_prompt, 73 | "prompt": prompt, 74 | "generated_completion": completion, 75 | } 76 | if debug 77 | else {} 78 | ) 79 | 80 | return FileSolution( 81 | task_id=problem.task_id, 82 | files=_parse_solution(completion), 83 | **debug_info, 84 | ) 85 | 86 | 87 | def generate_samples( 88 | release: ReleaseVersion, 89 | problems_datapack_dir: str, 90 | solutions_per_problem: int, 91 | n_workers: int, 92 | system_prompt: str, 93 | model: str, 94 | base_url: str | None, 95 | reasoning: str | None, 96 | temperature: float | None, 97 | top_p: float | None, 98 | max_tokens: int | None, 99 | temp_dir: str | None, 100 | debug: bool, 101 | ): 102 | """ 103 | Generates code completions for a set of problems using a specified model and writes them to a solutions datapack. 104 | Args: 105 | release (ReleaseVersion): The release version to generate solutions for. 106 | problems_datapack_dir (str): Directory where released problem datapacks are stored. 107 | solutions_per_problem (int): Number of solutions to generate per problem. 108 | n_workers (int): Number of worker threads to use for parallel generation. 109 | system_prompt (str): The system prompt to use for generating completions. 110 | model (str): The name of the model to use for generating completions. 111 | base_url (str | None): Base URL for the custom model API endpoint. 112 | reasoning (str | None): Reasoning mode for the model (e.g., 'low', 'medium', 'high' for GPT models, or any value for Claude models to enable extended thinking). 113 | temperature (float | None): Temperature for generation. 114 | top_p (float | None): Top-p for generation. 115 | max_tokens (int | None): Maximum tokens for generation. 116 | temp_dir (str | None): Temporary directory to store intermediate results. 
117 | debug (bool): Whether to include the system prompt, prompt, and generated completion in the output solution for debugging. 118 | """ 119 | 120 | def _task_id_to_filename(directory: str, _id: str) -> str: 121 | return f"{directory}/{_id.replace('/', '_')}.jsonl" 122 | 123 | if temp_dir is None: 124 | temp_dir = model if model else "temp_results" 125 | 126 | if not os.path.exists(temp_dir): 127 | os.makedirs(temp_dir) 128 | 129 | problem_file = os.path.join(problems_datapack_dir, f"{release.value}-problems.tar.gz") 130 | with ProblemDatapack(problem_file) as datapack: 131 | if datapack.metadata.release != release: 132 | raise ValueError( 133 | f"Problems datapack release {datapack.metadata.release} does not match expected release {release}." 134 | ) 135 | problems = list(datapack.read_items()) 136 | 137 | task_count = {p.task_id: _count_lines(_task_id_to_filename(temp_dir, p.task_id)) for p in problems} 138 | 139 | print("Started generating the model completions") 140 | with ThreadPoolExecutor(max_workers=n_workers) as executor: 141 | futures = [] 142 | for problem in problems: 143 | existing_sample_count = task_count.get(problem.task_id, 0) 144 | solutions_to_generate = solutions_per_problem - existing_sample_count 145 | 146 | if solutions_to_generate <= 0: 147 | print(f"Skipping {problem.task_id}, already have {existing_sample_count} solutions") 148 | continue 149 | 150 | for _ in range(solutions_to_generate): 151 | params = { 152 | "temperature": temperature, 153 | "top_p": top_p, 154 | "max_tokens": max_tokens, 155 | } 156 | args = { 157 | "system_prompt": system_prompt, 158 | "problem": problem, 159 | "model": model, 160 | "base_url": base_url, 161 | "reasoning": reasoning, 162 | "params": params, 163 | "debug": debug, 164 | } 165 | future = executor.submit(generate_model_completions, **args) 166 | futures.append(future) 167 | 168 | print("Waiting for all the model completions") 169 | for future in tqdm.tqdm(as_completed(futures), total=len(futures)): 170 | try: 171 | solution = future.result() 172 | write_solutions( 173 | file_path=_task_id_to_filename(temp_dir, solution.task_id), 174 | solutions=[solution], 175 | append=True, 176 | ) 177 | except Exception as e: 178 | print(f"Error processing future: {e}") 179 | 180 | all_results = [] 181 | for task_file in sorted(os.listdir(temp_dir)): 182 | task_file_path = os.path.join(temp_dir, task_file) 183 | all_results.extend(read_solutions(task_file_path)) 184 | 185 | if len(all_results) != len(problems) * solutions_per_problem: 186 | print(f"Error: Expected {len(problems) * solutions_per_problem} samples, but got {len(all_results)}") 187 | raise ValueError("Sample generation incomplete") 188 | 189 | model = model.replace("/", "-") if model else None 190 | SolutionDatapack.create( 191 | file_path=f"{release.value}-{model}-solutions.tar.gz" if model else f"{release.value}-solutions.tar.gz", 192 | items=all_results, 193 | release=release, 194 | ) 195 | 196 | # Clean up temporary files 197 | for task_file in os.listdir(temp_dir): 198 | os.remove(os.path.join(temp_dir, task_file)) 199 | os.rmdir(temp_dir) 200 | 201 | print("Completed generating all the samples for the problems. 
Written to the solutions datapack") 202 | 203 | 204 | def _count_lines(filename: str) -> int: 205 | """ 206 | Counts the number of lines in a file 207 | """ 208 | if not os.path.exists(filename): 209 | return 0 210 | 211 | count = 0 212 | if filename.endswith(".gz"): 213 | with open(filename, "rb") as gzfp, gzip.open(gzfp, "rt") as fp: 214 | for _ in fp: 215 | count += 1 216 | else: 217 | with open(filename, "r") as fp: 218 | for _ in fp: 219 | count += 1 220 | return count 221 | 222 | 223 | _CODE_BLOCK_RE = re.compile(r"```([^\n`]*)\n(.*?)```", re.DOTALL | re.MULTILINE) 224 | _FIRST_LINE_PATH_RE = re.compile(r"^(?://|#|;)\s*file:\s*([A-Za-z0-9._/\-]+)\s*$", re.IGNORECASE) 225 | _FIRST_LINE_BLOCK_COMMENT_PATH_RE = re.compile(r"^\s*/\*\s*file:\s*(.+?)\s*\*/\s*$", re.IGNORECASE) 226 | 227 | 228 | def _normalize_newlines(s: str) -> str: 229 | return s.replace("\r\n", "\n").replace("\r", "\n") 230 | 231 | 232 | def _guess_ext_from_lang(lang: str) -> str: 233 | lang = (lang or "").strip().lower() 234 | if lang == "cuda": 235 | return ".cu" 236 | if lang in ("cpp", "c++"): 237 | return ".cc" 238 | if lang == "c": 239 | return ".c" 240 | if lang in ("h", "hpp", "header"): 241 | return ".h" 242 | return ".txt" 243 | 244 | 245 | def _parse_solution(response: str) -> list[SourceFile]: 246 | if not isinstance(response, str): 247 | raise TypeError("response must be a string") 248 | 249 | text = _normalize_newlines(response) 250 | matches = list(_CODE_BLOCK_RE.finditer(text)) 251 | 252 | # If no fenced code blocks found, treat entire response as raw code 253 | if not matches: 254 | source_file = _process_code_block(text) 255 | return [source_file] if source_file else [] 256 | 257 | # Process each fenced code block 258 | files: list[SourceFile] = [] 259 | for m in matches: 260 | block = _normalize_newlines(m.group(2)) 261 | source_file = _process_code_block(block) 262 | if source_file: 263 | files.append(source_file) 264 | 265 | return files 266 | 267 | 268 | def _process_code_block(block: str) -> SourceFile | None: 269 | """Process a single code block and extract path + content; the first-line path marker is dropped from the file body.""" 270 | block_stripped = block.lstrip("\n") 271 | lines = block_stripped.split("\n") 272 | if not lines: 273 | return None 274 | 275 | first_line = lines[0].strip("\ufeff").strip() 276 | path = _extract_path_from_line(first_line) 277 | 278 | if not path: 279 | return None 280 | 281 | return SourceFile( 282 | path=path, 283 | content="\n".join(lines[1:]), 284 | ) 285 | 286 | 287 | def _extract_path_from_line(line: str) -> str | None: 288 | """Extract file path from a line using various comment formats.""" 289 | # Try regular path format first 290 | m1 = _FIRST_LINE_PATH_RE.match(line) 291 | if m1: 292 | return m1.group(1).strip() 293 | 294 | # Try block comment format 295 | m2 = _FIRST_LINE_BLOCK_COMMENT_PATH_RE.match(line) 296 | if m2: 297 | return m2.group(1).strip() 298 | 299 | return None 300 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # compute-eval 2 | 3 | ComputeEval: Evaluating Large Language Models for CUDA Code Generation 4 | 5 | ComputeEval is a framework designed to generate and evaluate CUDA code from Large Language Models.
6 | It features: 7 | 8 | - A set of handcrafted CUDA programming challenges designed to evaluate an LLM's capability at writing reliable CUDA code 9 | - Utilities for generating multiple solutions to each challenge 10 | - Utilities for functional correctness evaluation of generated solutions 11 | 12 | ComputeEval is currently in Alpha. We plan to refine the evaluation framework 13 | and make frequent updates to the dataset with additional problems spanning all 14 | aspects of CUDA development. 15 | 16 | ## Benchmark Structure and Evaluation 17 | 18 | ### Problem Organization 19 | 20 | Each problem in ComputeEval is stored as a directory under `data`, containing: 21 | 22 | ``` 23 | CUDA-0/ 24 | ├── problem-spec.yaml # Problem metadata and configuration 25 | ├── context/ # Files visible to the tested model/system (headers, helpers) 26 | │ ├── include/ 27 | │ │ └── kernel.h # Interface contract to implement 28 | │ └── helpers/ 29 | │ └── helpers.cu # Optional helper utilities 30 | ├── solution/ # Reference implementation (not shown to tested model/system) 31 | │ └── solution.cu 32 | └── test/ # Test harness (not shown to tested model/system) 33 | └── test/ 34 | └── test_main.cu 35 | ``` 36 | 37 | #### Problem Specification Format 38 | 39 | The `problem-spec.yaml` file defines each problem's metadata and configuration: 40 | 41 | ```yaml 42 | task_id: "CUDA/0" # Unique identifier (generally matches directory name) 43 | date: "2024-12-19" # Problem creation date 44 | problem_type: cuda_cpp # Type: cuda_cpp or cuda_python 45 | prompt: "Implement a CUDA kernel..." # Problem description shown to model 46 | 47 | # Build and test configuration 48 | build_command: "nvcc -I include -o test.out solution.cu test/*.cu" 49 | test_command: "./test.out" 50 | timeout_seconds: 30.0 51 | 52 | # Requirements 53 | min_cuda_toolkit: "12.0" # Minimum CUDA version required 54 | arch_list: [] # Optional: specific GPU architectures 55 | 56 | # Optional metadata 57 | metadata: 58 | difficulty: medium # Problem difficulty level 59 | tags: [kernels, memory] # Classification tags 60 | releases: [2025-1, 2025-2] # Which releases include this problem 61 | do_not_release: false # Internal-only flag to skip CI 62 | 63 | source_references: null # Optional: required API calls/symbols to verify 64 | # - string: single item must be present 65 | # - list of strings: all must be present 66 | # - {any: [...]} at least one must be present 67 | # - {all: [...]} all must be present 68 | # - {all: [...], any: [...]} combines both 69 | ``` 70 | 71 | Example with source references requiring specific CUDA APIs: 72 | 73 | ```yaml 74 | source_references: 75 | all: [cudaMalloc, cudaFree] # Must use both malloc and free 76 | any: [cudaMemcpy, cudaMemcpyAsync] # Must use at least one copy method 77 | ``` 78 | 79 | ### Evaluation Rules of Engagement 80 | 81 | ComputeEval follows a strict separation between what systems/models see during generation versus what is used during evaluation: 82 | 83 | **What the system/model sees (generation time):** 84 | - Problem `prompt` - describes the task and requirements 85 | - `context_files` - headers defining interfaces, optional helper utilities 86 | - `build_command` - compilation instructions and flags 87 | - Minimum CUDA toolkit version and architecture requirements 88 | 89 | **What the system/model does NOT see:** 90 | - `test_files` - held-out test harness that validates correctness 91 | - `solution` - reference implementation directory 92 | 93 | **During evaluation:** 94 | 1. 
A temporary workspace is created 95 | 2. `context_files` are written to the workspace 96 | 3. `test_files` are written to the workspace (now visible) 97 | 4. The model-generated solution files are written to the workspace 98 | 5. The `build_command` is executed to compile the unified workspace 99 | 6. If compilation succeeds, the `test_command` is executed 100 | 7. Test exit code determines pass/fail 101 | 102 | This ensures models cannot overfit to test cases and must solve problems based solely on the problem description and interface contracts. 103 | 104 | ### Continuous Integration Validation 105 | 106 | Every problem in the repository includes a known-good reference solution. Our CI pipeline continuously validates the integrity of the benchmark by: 107 | 108 | 1. Running the evaluation procedure on each problem's reference solution 109 | 2. Verifying that build commands compile successfully 110 | 3. Ensuring test harnesses execute correctly and pass 111 | 4. Validating that problem specifications are well-formed 112 | 113 | This guarantees that all released problems are solvable and correctly specified. 114 | 115 | ### Release Datapacks 116 | 117 | For production use, ComputeEval distributes problems as **datapacks** - versioned, immutable releases stored as compressed tarballs (`.tar.gz`): 118 | 119 | ``` 120 | data/releases/ 121 | ├── 2025-1-problems.tar.gz 122 | ├── 2025-2-problems.tar.gz 123 | ``` 124 | 125 | #### Datapack Structure 126 | 127 | Each datapack contains: 128 | - **`metadata.json`** - Release version, creation timestamp, problem count, and integrity hashes 129 | - **`problems.jsonl`** or **`solutions.jsonl`** - One JSON object per line representing each problem/solution 130 | 131 | Problems in datapacks are serialized as JSON objects rather than directories. Each problem includes: 132 | - All fields from `problem-spec.yaml` 133 | - Embedded `context_files` (list of `{path, content}` objects) 134 | - Embedded `test_files` (held-out, for evaluation only) 135 | 136 | This format provides: 137 | - **Immutability** - Released benchmarks never change 138 | - **Integrity** - MD5 hashes verify problem consistency 139 | - **Portability** - Self-contained archives easy to distribute 140 | - **Versioning** - Clear separation between releases 141 | 142 | #### Release Strategy 143 | 144 | ComputeEval follows a regular release schedule: 145 | 146 | - **2025.1** (Released) - Initial benchmark with 231 problems 147 | - **2025.2** (Released) - Second release with expanded coverage 148 | - **2025.3** (Upcoming) - Third release scheduled November 2025 149 | 150 | We are committed to **permanently supporting all previous releases**. Model developers can benchmark against any release version to: 151 | - Track progress over time against a fixed baseline 152 | - Compare results with published benchmarks 153 | - Ensure reproducibility of evaluation results 154 | 155 | 156 | ## Setup 157 | 158 | ### Prerequisites 159 | 160 | - Python 3.11+ 161 | - NVIDIA GPU with CUDA Toolkit 12 or greater (for evaluation) 162 | 163 | ### Installation 164 | 165 | Install the package using uv: 166 | 167 | ```bash 168 | uv sync 169 | ``` 170 | 171 | ### Pre-commit Hooks 172 | 173 | Set up pre-commit hooks for code quality: 174 | 175 | ```bash 176 | uv sync --group dev 177 | uv run pre-commit install 178 | ``` 179 | 180 | ### API Keys 181 | 182 | To query an LLM, you must first obtain an API key from the respective service. 
183 | 184 | #### NVIDIA NIM 185 | 186 | To use ComputeEval with NVIDIA-hosted models, you need an API key from 187 | [build.nvidia.com](https://build.nvidia.com). 188 | 189 | 1. Go to [build.nvidia.com](https://build.nvidia.com) 190 | 1. Sign in with your account 191 | 1. Verify that you have sufficient credits to call hosted models 192 | 1. Navigate to the desired model and click on it 193 | 1. Click on `Get API Key` 194 | 1. Copy the generated API key 195 | 1. Export it as an environment variable: 196 | 197 | ```bash 198 | export NEMO_API_KEY="" 199 | ``` 200 | 201 | #### OpenAI 202 | 203 | Follow the instructions in the [OpenAI docs](https://openai.com/index/openai-api), 204 | then: 205 | 206 | ```bash 207 | export OPENAI_API_KEY="" 208 | ``` 209 | 210 | #### Anthropic (Claude) 211 | 212 | Follow the instructions in the [Anthropic docs](https://www.anthropic.com/api), then: 213 | 214 | ```bash 215 | export ANTHROPIC_API_KEY="" 216 | ``` 217 | 218 | ## Usage 219 | 220 | **Note:** This repository executes machine-generated CUDA C++ code. 221 | While it's unlikely that the code is malicious, it could still pose potential risks. 222 | Therefore, all code execution requires the `--allow_execution` flag. 223 | We strongly recommend using a sandbox environment (e.g., a Docker container or virtual machine) when running generated code to minimize security risks. 224 | 225 | ### Using Preset NIM Models 226 | 227 | To generate solutions using NVIDIA-hosted models: 228 | 229 | ```bash 230 | uv run compute_eval generate_samples \ 231 | --release=2025-2 \ 232 | --base_url=https://integrate.api.nvidia.com/v1 \ 233 | --model=openai/gpt-oss-120b \ 234 | --solutions_per_problem=3 \ 235 | --n_workers=10 236 | ``` 237 | 238 | **Note:** Set the `NEMO_API_KEY` environment variable when using preset NIM models. 239 | 240 | This will: 241 | - Read problems from the 2025-2 release datapack 242 | - Generate 3 solutions per problem using the `openai/gpt-oss-120b` model 243 | - Write all solutions to: `2025-2-openai-gpt-oss-120b-solutions.tar.gz` 244 | 245 | You can find the list of available models at [NVIDIA NIM Model Catalog](https://build.nvidia.com/models). 246 | 247 | ### Using OpenAI-Compatible APIs 248 | 249 | For models with OpenAI-compatible API endpoints: 250 | 251 | ```bash 252 | uv run compute_eval generate_samples \ 253 | --release=2025-2 \ 254 | --model=gpt-5 \ 255 | --solutions_per_problem=3 \ 256 | --n_workers=10 257 | ``` 258 | 259 | **Note:** Set the `OPENAI_API_KEY` environment variable when using custom OpenAI-compatible endpoints. 260 | 261 | This will: 262 | - Read problems from the 2025-2 release datapack 263 | - Generate 3 solutions per problem using the `gpt-5` model 264 | - Write all solutions to: `2025-2-gpt-5-solutions.tar.gz` 265 | 266 | ### Using Configuration Files 267 | 268 | You can also use YAML configuration files for convenience: 269 | 270 | ```yaml 271 | # config.yaml 272 | release: 2025-2 273 | model: gpt-5 274 | solutions_per_problem: 3 275 | n_workers: 10 276 | ``` 277 | 278 | ```bash 279 | uv run compute_eval generate_samples --config_file=config.yaml 280 | ``` 281 | 282 | CLI arguments override config file values.
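For example, you can keep `config.yaml` fixed and switch models on the command line; the flags below (taken from the NIM example above, shown purely for illustration) take precedence over the corresponding values in the config file:

```bash
uv run compute_eval generate_samples \
    --config_file=config.yaml \
    --base_url=https://integrate.api.nvidia.com/v1 \
    --model=openai/gpt-oss-120b
```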
283 | 284 | ### Generating and Evaluating Solutions 285 | 286 | After generating solutions (see examples above), evaluate them with: 287 | 288 | ```bash 289 | uv run compute_eval evaluate_functional_correctness \ 290 | --solutions_datapack=2025-2-gpt-5-solutions.tar.gz \ 291 | --allow_execution=true \ 292 | --k='(1, 3)' \ 293 | --n_workers=4 294 | ``` 295 | 296 | **Security Note:** You must pass `--allow_execution=true` to run the evaluation. As described in the Evaluation Rules of Engagement section, this executes untrusted model-generated code, so use appropriate sandboxing. 297 | 298 | This will: 299 | - Read the problems and solutions datapacks 300 | - Build and execute each solution in an isolated workspace with the test harness 301 | - Output structured JSON with `pass@k` metrics and problem count 302 | - Write results to a graded solutions file (e.g., `2025-2-graded-solutions.jsonl`) 303 | 304 | **Note:** The `k` parameter can be a single integer (`--k=1`) or a tuple (`--k='(1, 3)'`). For accurate pass@k estimates, ensure `max(k) <= solutions_per_problem`. 305 | 306 | ## Command Reference 307 | 308 | ### `generate_samples` 309 | 310 | Generates solutions for all problems in a release datapack using a specified model or custom API endpoint. 311 | 312 | #### Configuration Parameters 313 | 314 | All parameters can be specified in a YAML config file or passed as CLI arguments (CLI arguments take precedence). 315 | 316 | - `release` (str): Release version to generate solutions for (e.g., "2025-2") (default: "2025-2") 317 | - `problems_datapack_dir` (str): Directory where released problem datapacks are stored (default: "data/releases/") 318 | - `solutions_per_problem` (int): Number of solutions to generate per problem (default: 1) 319 | - `n_workers` (int): Number of worker threads to use (default: 10) 320 | - `system_prompt` (str): System prompt for the model (default: predefined CUDA programming prompt) 321 | - `model` (str): Model to use (use an appropriate NIM or use an OpenAI model name) (required) 322 | - `base_url` (str | None): Custom API base URL (default: None) 323 | - `reasoning` (str | None): Reasoning level for OpenAI models (e.g., "low", "medium", "high") (default: None) 324 | - `temperature` (float): Sampling temperature for generation (default: 1.0) 325 | - `top_p` (float): Nucleus sampling parameter (default: Model dependent) 326 | - `max_tokens` (int): Maximum tokens to generate (default: Model dependent) 327 | - `temp_dir` (str | None): Temporary directory for intermediate results (default: None) 328 | - `debug` (bool): Include system prompt, prompt, and completion in output for debugging (default: False) 329 | 330 | **Note**: `model` must be specified. 331 | 332 | ### `evaluate_functional_correctness` 333 | 334 | Evaluates the functional correctness of generated solutions by compiling and executing them against held-out test suites. Outputs structured JSON with `pass@k` metrics. 335 | 336 | #### Configuration Parameters 337 | 338 | All parameters can be specified in a YAML config file or passed as CLI arguments (CLI arguments take precedence). 
339 | 340 | - `solutions_datapack` (str): Path to the solutions datapack file (required) 341 | - `problems_datapack_dir` (str): Directory where released problem datapacks are stored (default: "data/releases/") 342 | - `allow_execution` (bool): Whether to allow execution of untrusted code - must be set to True (default: False) 343 | - `k` (int | tuple[int, ...]): K value(s) for pass@k evaluation (default: 1) 344 | - `n_workers` (int): Number of worker threads (default: 4) 345 | - `results_file` (str | None): Path to output results file (default: auto-generated from release name) 346 | 347 | ## Dataset 348 | 349 | For more information about the dataset see [`DATASET_CARD.md`](DATASET_CARD.md). 350 | 351 | ## Contributing 352 | 353 | See [`CONTRIBUTING.md`](CONTRIBUTING.md) for development instructions. 354 | -------------------------------------------------------------------------------- /compute_eval/utils/parsing.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import importlib.resources as pkg_resources 3 | import os 4 | from collections.abc import Collection 5 | from pathlib import Path 6 | 7 | import yaml 8 | from tree_sitter import Language, Node, Parser 9 | from tree_sitter_language_pack import get_language, get_parser 10 | 11 | # (tree sitter name, linguist name, codemirror mime type, popularity rank) tuples. 12 | # Source of tree sitter names: 13 | # https://github.com/Goldziher/tree-sitter-language-pack 14 | # Source of linguist names: 15 | # https://raw.githubusercontent.com/github-linguist/linguist/master/lib/linguist/languages.yml 16 | # (mirrored in encoding/language_data/languages.yml) 17 | # Source of CodeMirror MIME types: 18 | # From linguist languages.yml codemirror_mime_type field 19 | # Source of popularity ranks: 20 | # https://innovationgraph.github.com/global-metrics/programming-languages 21 | # https://raw.githubusercontent.com/github/innovationgraph/main/data/languages.csv 22 | RAW_LANGUAGE_DATA = [ 23 | ("actionscript", "ActionScript", None, 159), 24 | ("ada", "Ada", None, 138), 25 | ("agda", "Agda", None, 270), 26 | ("apex", "Apex", "text/x-java", 154), 27 | ("asm", "Assembly", None, 27), 28 | ("astro", "Astro", "text/jsx", 89), 29 | ("bash", "Shell", "text/x-sh", 5), 30 | ("bibtex", "BibTeX", "text/x-stex", 300), 31 | ("bicep", "Bicep", None, 135), 32 | ("bitbake", "BitBake", None, 139), 33 | ("c", "C", "text/x-csrc", 11), 34 | ("cairo", "Cairo", None, 262), 35 | ("capnp", "Cap'n Proto", None, 194), 36 | ("clarity", "Clarity", None, 365), 37 | ("clojure", "Clojure", "text/x-clojure", 86), 38 | ("cmake", "CMake", "text/x-cmake", 18), 39 | ("commonlisp", "Common Lisp", "text/x-common-lisp", 97), 40 | ("cpp", "C++", "text/x-c++src", 9), 41 | ("csharp", "C#", "text/x-csharp", 15), 42 | ("css", "CSS", "text/css", 3), 43 | ("csv", "CSV", None, 400), 44 | ("cuda", "Cuda", "text/x-c++src", 62), 45 | ("d", "D", "text/x-d", 115), 46 | ("dart", "Dart", "application/dart", 29), 47 | ("dockerfile", "Dockerfile", "text/x-dockerfile", 8), 48 | ("elisp", "Emacs Lisp", "text/x-common-lisp", 56), 49 | ("elixir", "Elixir", None, 94), 50 | ("elm", "Elm", "text/x-elm", 158), 51 | ("erlang", "Erlang", "text/x-erlang", 109), 52 | ("fennel", "Fennel", None, 333), 53 | ("firrtl", "FIRRTL", None, 750), 54 | ("fish", "fish", None, 220), 55 | ("fortran", "Fortran Free Form", "text/x-fortran", 355), 56 | ("fsharp", "F#", "text/x-fsharp", 116), 57 | ("gdscript", "GDScript", None, 98), 58 | ("gitattributes", "Git Attributes", "text/x-sh", 
600), 59 | ("gitcommit", "Git Commit", None, 1000), 60 | ("gitignore", "Ignore List", "text/x-sh", 500), 61 | ("gleam", "Gleam", None, 328), 62 | ("glsl", "GLSL", None, 46), 63 | ("gn", "GN", "text/x-python", 500), 64 | ("go", "Go", "text/x-go", 20), 65 | ("gomod", "Go Module", None, 355), 66 | ("gosum", "Go Checksums", None, 600), 67 | ("graphql", "GraphQL", None, 130), 68 | ("groovy", "Groovy", "text/x-groovy", 52), 69 | ("hack", "Hack", "application/x-httpd-php", 31), 70 | ("hare", "Hare", None, 800), 71 | ("haskell", "Haskell", "text/x-haskell", 73), 72 | ("haxe", "Haxe", "text/x-haxe", 166), 73 | ("hcl", "HCL", "text/x-ruby", 38), 74 | ("heex", "HTML+EEX", "text/html", 400), 75 | ("hlsl", "HLSL", None, 41), 76 | ("html", "HTML", "text/html", 1), 77 | ("ispc", "ISPC", "text/x-csrc", 550), 78 | ("ini", "INI", "text/x-properties", 300), 79 | ("janet", "Janet", "text/x-scheme", 375), 80 | ("java", "Java", "text/x-java", 6), 81 | ("javascript", "JavaScript", "text/javascript", 2), 82 | ("json", "JSON", "application/json", 190), 83 | ("jsonnet", "Jsonnet", None, 127), 84 | ("julia", "Julia", "text/x-julia", 85), 85 | ("kdl", "KDL", "text/x-yacas", 800), 86 | ("kotlin", "Kotlin", "text/x-kotlin", 19), 87 | ("latex", "TeX", "text/x-stex", 30), 88 | ("linkerscript", "Linker Script", None, 201), 89 | ("llvm", "LLVM", None, 126), 90 | ("lua", "Lua", "text/x-lua", 25), 91 | ("luau", "Luau", "text/x-lua", 300), 92 | ("make", "Makefile", "text/x-cmake", 12), 93 | ("markdown", "Markdown", "text/x-gfm", 146), 94 | ("matlab", "MATLAB", "text/x-octave", 45), 95 | ("mermaid", "Mermaid", None, 192), 96 | ("meson", "Meson", None, 79), 97 | ("netlinx", "NetLinx", None, 1000), 98 | ("nim", "Nim", None, 140), 99 | ("ninja", "Ninja", None, 450), 100 | ("nix", "Nix", None, 50), 101 | ("objc", "Objective-C", "text/x-objectivec", 21), 102 | ("ocaml", "OCaml", "text/x-ocaml", 107), 103 | ("odin", "Odin", None, 314), 104 | ("org", "Org", None, 320), 105 | ("pascal", "Pascal", "text/x-pascal", 77), 106 | ("perl", "Perl", "text/x-perl", 28), 107 | ("php", "PHP", "application/x-httpd-php", 14), 108 | ("po", "Gettext Catalog", None, 500), 109 | ("pony", "Pony", None, 326), 110 | ("powershell", "PowerShell", "application/x-powershell", 24), 111 | ("prisma", "Prisma", None, 220), 112 | ("properties", "INI", "text/x-properties", 300), 113 | ("proto", "Protocol Buffer", "text/x-protobuf", 340), 114 | ("puppet", "Puppet", "text/x-puppet", 133), 115 | ("purescript", "PureScript", "text/x-haskell", 165), 116 | ("python", "Python", "text/x-python", 4), 117 | ("qmljs", "QML", None, 114), 118 | ("query", "Tree-sitter Query", None, 1000), 119 | ("r", "R", "text/x-rsrc", 32), 120 | ("racket", "Racket", None, 170), 121 | ("re2c", "RenderScript", None, 141), 122 | ("readline", "Readline Config", None, 1000), 123 | ("rego", "Open Policy Agent", None, 172), 124 | ("requirements", "Pip Requirements", None, 400), 125 | ("ron", "RON", None, 650), 126 | ("rst", "reStructuredText", "text/x-rst", 374), 127 | ("ruby", "Ruby", "text/x-ruby", 16), 128 | ("rust", "Rust", "text/x-rustsrc", 26), 129 | ("scala", "Scala", "text/x-scala", 59), 130 | ("scheme", "Scheme", "text/x-scheme", 84), 131 | ("scss", "SCSS", "text/x-scss", 10), 132 | ("smali", "Smali", None, 351), 133 | ("smithy", "Smithy", "text/x-csrc", 353), 134 | ("solidity", "Solidity", None, 57), 135 | ("sparql", "SPARQL", "application/sparql-query", 600), 136 | ("sql", "SQL", "text/x-sql", 339), 137 | ("squirrel", "Squirrel", "text/x-squirrel", 299), 138 | ("starlark", "Starlark", 
"text/x-python", 51), 139 | ("svelte", "Svelte", "text/html", 67), 140 | ("swift", "Swift", "text/x-swift", 23), 141 | ("tcl", "Tcl", "text/x-tcl", 68), 142 | ("thrift", "Thrift", None, 130), 143 | ("toml", "TOML", "text/x-toml", 355), 144 | ("tsv", "TSV", None, 550), 145 | ("tsx", "TSX", "text/typescript-jsx", 6), 146 | ("twig", "Twig", "text/x-twig", 72), 147 | ("typescript", "TypeScript", "application/typescript", 7), 148 | ("typst", "Typst", None, 500), 149 | ("v", "V", "text/x-go", 205), 150 | ("verilog", "Verilog", "text/x-verilog", 92), 151 | ("vhdl", "VHDL", "text/x-vhdl", 113), 152 | ("vim", "Vim Script", None, 37), 153 | ("vue", "Vue", "text/x-vue", 22), 154 | ("wast", "WebAssembly", "text/webassembly", 160), 155 | ("wat", "WebAssembly", "text/webassembly", 160), 156 | ("wgsl", "WGSL", None, 550), 157 | ("xcompose", "XCompose", None, 900), 158 | ("xml", "XML", "text/xml", 92), 159 | ("yaml", "YAML", "text/x-yaml", 180), 160 | ("zig", "Zig", None, 149), 161 | ] 162 | 163 | RESOURCE_PATH = "language_data" 164 | LINGUIST_FILE_NAME = "languages.yml.gz" 165 | 166 | 167 | class ParseableLanguage: 168 | """ 169 | Represents a programming language that can be parsed using tree-sitter. 170 | 171 | This class encapsulates language metadata (names, file extensions, known files) 172 | and provides parsing capabilities through tree-sitter. It lazily initializes 173 | the tree-sitter language and parser objects on first use. 174 | """ 175 | 176 | def __init__( 177 | self, 178 | canonical_name: str, 179 | treesitter_name: str, 180 | file_extensions: Collection[str], 181 | alt_file_extensions: Collection[str], 182 | known_files: Collection[str], 183 | alt_known_files: Collection[str], 184 | popularity_rank: int, 185 | ): 186 | """ 187 | Initialize a ParseableLanguage instance. 188 | 189 | Args: 190 | canonical_name: The canonical/linguist name for the language (e.g., "Python", "C++"). 191 | treesitter_name: The tree-sitter language identifier (e.g., "python", "cpp"). 192 | file_extensions: Primary file extensions for this language (e.g., [".py", ".pyw"]). 193 | alt_file_extensions: Alternative/secondary file extensions. 194 | known_files: Known filenames that identify this language (e.g., ["Makefile"]). 195 | alt_known_files: Alternative known filenames. 196 | popularity_rank: Numeric rank indicating language popularity (lower is more popular). 197 | """ 198 | self.canonical_name = canonical_name 199 | self.treesitter_name = treesitter_name 200 | self.file_extensions = frozenset(file_extensions) 201 | self.alt_file_extensions = frozenset(alt_file_extensions) 202 | self.known_files = frozenset(known_files) 203 | self.alt_known_files = frozenset(alt_known_files) 204 | self.popularity_rank = popularity_rank 205 | self._language = None 206 | self._parser = None 207 | 208 | def get_language(self) -> Language: 209 | """ 210 | Get the tree-sitter Language object for this language. 211 | 212 | The Language object is lazily initialized on first access and cached 213 | for subsequent calls. 214 | 215 | Returns: 216 | The tree-sitter Language object for parsing this language. 217 | """ 218 | if self._language is None: 219 | self._language = get_language(self.treesitter_name) 220 | return self._language 221 | 222 | def get_parser(self) -> Parser: 223 | """ 224 | Get the tree-sitter Parser object for this language. 225 | 226 | The Parser object is lazily initialized on first access and cached 227 | for subsequent calls. 228 | 229 | Returns: 230 | The tree-sitter Parser object configured for this language. 
231 | """ 232 | if self._parser is None: 233 | self._parser = get_parser(self.treesitter_name) 234 | return self._parser 235 | 236 | @staticmethod 237 | def _accumulate_errors(node: Node, errors: list[tuple[int, int]]): 238 | if node.is_error: 239 | errors.append((node.start_point[0] + 1, node.end_point[0] + 1)) 240 | elif node.has_error: 241 | for child in node.children: 242 | ParseableLanguage._accumulate_errors(child, errors) 243 | 244 | def parse_errors(self, contents: bytes) -> list[tuple[int, int]]: 245 | """ 246 | Parse the given contents and return a list of syntax errors. 247 | 248 | Each error is represented as a tuple of (start_line, end_line) where 249 | line numbers are 1-indexed. 250 | 251 | Args: 252 | contents: The source code to parse as bytes. 253 | 254 | Returns: 255 | A list of tuples (start_line, end_line) representing error ranges. 256 | Returns an empty list if there are no syntax errors. 257 | """ 258 | tree = self.get_parser().parse(contents) 259 | errors = [] 260 | self._accumulate_errors(tree.root_node, errors) 261 | return errors 262 | 263 | def parse_error_line_count(self, contents: bytes) -> int: 264 | """ 265 | Count the total number of lines containing syntax errors. 266 | 267 | This method parses the contents and calculates the total number of lines 268 | that are part of error ranges (inclusive of start and end lines). 269 | 270 | Args: 271 | contents: The source code to parse as bytes. 272 | 273 | Returns: 274 | The total count of lines with syntax errors. 275 | """ 276 | error_lines = 0 277 | for start_line, end_line in self.parse_errors(contents): 278 | error_lines += end_line - start_line + 1 279 | return error_lines 280 | 281 | def parse_tree_spans(self, contents: bytes) -> dict[tuple[int, int], str]: 282 | """ 283 | Parse contents and return a mapping of byte spans to node types. 284 | 285 | This method creates a dictionary mapping byte ranges to their corresponding 286 | tree-sitter node types. Only includes complete, non-error nodes that are not 287 | "extra" (comments, whitespace, etc.) or missing. Uses tree.walk() for 288 | efficient linear-time traversal via native C code. 289 | 290 | Args: 291 | contents: The source code to parse as bytes. 292 | 293 | Returns: 294 | A dictionary mapping (start_byte, end_byte) tuples to node type strings. 295 | The byte positions are 0-indexed offsets into the contents. 
296 | """ 297 | tree = self.get_parser().parse(contents) 298 | spans = {} 299 | 300 | cursor = tree.walk() 301 | visited_children = False 302 | 303 | # preorder: matches a left-to-right, top-to-bottom scan of the file 304 | while True: 305 | node = cursor.node 306 | if not node.is_error and not visited_children: 307 | if not node.is_extra and not node.is_missing: 308 | spans[(node.start_byte, node.end_byte)] = node.type 309 | if cursor.goto_first_child(): 310 | visited_children = False 311 | continue 312 | 313 | if cursor.goto_next_sibling(): 314 | visited_children = False 315 | continue 316 | if not cursor.goto_parent(): 317 | break 318 | visited_children = True 319 | 320 | return spans 321 | 322 | @staticmethod 323 | def _find_all(content: bytes, target: bytes) -> list[int]: 324 | matches = [] 325 | start_byte = 0 326 | while True: 327 | start_byte = content.find(target, start_byte) 328 | if start_byte == -1: 329 | break 330 | matches.append(start_byte) 331 | start_byte += len(target) 332 | return matches 333 | 334 | def find_matching_subtrees(self, content: bytes, targets: list[bytes]) -> list[tuple[bytes, int]]: 335 | """ 336 | Find exact byte sequence matches that correspond to complete parse tree nodes. 337 | 338 | This method searches for exact byte sequences in the content and verifies that 339 | each match corresponds to a complete, valid subtree in the parse tree (not an 340 | error, extra, or missing node, and with exact byte boundaries). 341 | 342 | Args: 343 | content: The source code to search in as bytes. 344 | targets: A list of byte sequences to search for. 345 | 346 | Returns: 347 | A list of tuples (target, byte_offset) where: 348 | - target: The matched byte sequence from the targets list 349 | - byte_offset: The 0-indexed byte position where the match starts 350 | Only returns matches that are valid complete subtrees. 
351 | """ 352 | tree = self.get_parser().parse(content) 353 | root = tree.root_node 354 | matches = [] 355 | for target in targets: 356 | target_len = len(target) 357 | for match in self._find_all(content, target): 358 | subtree = root.descendant_for_byte_range(match, match + target_len) 359 | if ( 360 | subtree is not None 361 | and not subtree.is_error 362 | and not subtree.is_extra 363 | and not subtree.is_missing 364 | and subtree.start_byte == match 365 | and subtree.end_byte == match + target_len 366 | ): 367 | matches.append((target, match)) 368 | 369 | return matches 370 | 371 | def __str__(self): 372 | return self.canonical_name 373 | 374 | def __repr__(self): 375 | return f"ParseableLanguage({self.canonical_name})" 376 | 377 | def __hash__(self): 378 | return hash(self.canonical_name) 379 | 380 | def __eq__(self, other): 381 | if not isinstance(other, ParseableLanguage): 382 | return False 383 | return self.canonical_name == other.canonical_name 384 | 385 | def __ne__(self, other): 386 | return not self.__eq__(other) 387 | 388 | 389 | def _collect_exts_and_files(linguist_data, ling_language, cm_mime): 390 | file_extensions, alt_file_extensions = set(), set() 391 | known_files, alt_known_files = set(), set() 392 | for language, data in linguist_data.items(): 393 | if language == ling_language: 394 | file_extensions.update(data.get("extensions", [])) 395 | known_files.update(data.get("filenames", [])) 396 | else: 397 | if ling_language and data.get("group") == ling_language: 398 | alt_file_extensions.update(data.get("extensions", [])) 399 | alt_known_files.update(data.get("filenames", [])) 400 | if cm_mime and data.get("codemirror_mime_type") == cm_mime: 401 | alt_file_extensions.update(data.get("extensions", [])) 402 | alt_known_files.update(data.get("filenames", [])) 403 | return file_extensions, alt_file_extensions, known_files, alt_known_files 404 | 405 | 406 | def _build_parseable_languages() -> list[ParseableLanguage]: 407 | try: 408 | if __package__: 409 | resources = pkg_resources.files(__package__) / RESOURCE_PATH 410 | else: 411 | resources = Path(__file__).parent / RESOURCE_PATH 412 | except Exception: 413 | resources = Path(__file__).parent / RESOURCE_PATH 414 | 415 | linguist_file = resources / LINGUIST_FILE_NAME 416 | with gzip.open(linguist_file, "rt") as f: 417 | linguist_data = yaml.safe_load(f) 418 | 419 | parseable_languages = [] 420 | 421 | for ts_language, ling_language, cm_mime, popularity_rank in RAW_LANGUAGE_DATA: 422 | if ling_language is not None: 423 | file_extensions, alt_file_extensions, known_files, alt_known_files = _collect_exts_and_files( 424 | linguist_data, ling_language, cm_mime 425 | ) 426 | 427 | parseable_languages.append( 428 | ParseableLanguage( 429 | canonical_name=ling_language, 430 | treesitter_name=ts_language, 431 | file_extensions=file_extensions, 432 | alt_file_extensions=alt_file_extensions, 433 | known_files=known_files, 434 | alt_known_files=alt_known_files, 435 | popularity_rank=popularity_rank, 436 | ) 437 | ) 438 | 439 | return parseable_languages 440 | 441 | 442 | def _build_file_and_extension_maps( 443 | languages: list[ParseableLanguage], 444 | ) -> tuple[ 445 | dict[str, set[ParseableLanguage]], 446 | dict[str, set[ParseableLanguage]], 447 | dict[str, set[ParseableLanguage]], 448 | dict[str, set[ParseableLanguage]], 449 | ]: 450 | extension_map, alt_extension_map = {}, {} 451 | file_map, alt_file_map = {}, {} 452 | for language in languages: 453 | for extension in language.file_extensions: 454 | 
extension_map.setdefault(extension, set()).add(language) 455 | for extension in language.alt_file_extensions: 456 | alt_extension_map.setdefault(extension, set()).add(language) 457 | for filename in language.known_files: 458 | file_map.setdefault(filename, set()).add(language) 459 | for filename in language.alt_known_files: 460 | alt_file_map.setdefault(filename, set()).add(language) 461 | return extension_map, alt_extension_map, file_map, alt_file_map 462 | 463 | 464 | ALL_LANGUAGES = _build_parseable_languages() 465 | 466 | ( 467 | LANGUAGES_BY_EXTENSION, 468 | LANGUAGES_BY_ALT_EXTENSION, 469 | LANGUAGES_BY_FILENAME, 470 | LANGUAGES_BY_ALT_FILENAME, 471 | ) = _build_file_and_extension_maps(ALL_LANGUAGES) 472 | 473 | 474 | def _get_languages_by_path( 475 | path: str, 476 | ext_map: dict[str, set[ParseableLanguage]], 477 | fname_map: dict[str, set[ParseableLanguage]], 478 | ) -> Collection[ParseableLanguage]: 479 | parts = path.split(os.sep) 480 | if not parts: 481 | return [] 482 | 483 | file = parts[-1] 484 | if file in fname_map: 485 | return fname_map[file] 486 | 487 | ext_idx = file.find(".") 488 | while ext_idx != -1: 489 | extension = file[ext_idx:] 490 | if extension in ext_map: 491 | return ext_map[extension] 492 | ext_idx = file.find(".", ext_idx + 1) 493 | return [] 494 | 495 | 496 | def get_language_by_path(path: str) -> ParseableLanguage | None: 497 | languages = _get_languages_by_path(path, LANGUAGES_BY_EXTENSION, LANGUAGES_BY_FILENAME) 498 | if not languages: 499 | languages = _get_languages_by_path(path, LANGUAGES_BY_ALT_EXTENSION, LANGUAGES_BY_ALT_FILENAME) 500 | return min(languages, key=lambda lang: lang.popularity_rank, default=None) 501 | 502 | 503 | def _get_most_likely_language( 504 | path: str, 505 | contents: bytes, 506 | ext_map: dict[str, set[ParseableLanguage]], 507 | fname_map: dict[str, set[ParseableLanguage]], 508 | ) -> ParseableLanguage | None: 509 | languages = _get_languages_by_path(path, ext_map, fname_map) 510 | if len(languages) == 1: 511 | return next(iter(languages)) 512 | elif len(languages) > 1: 513 | return min( 514 | languages, 515 | key=lambda lang: ( 516 | lang.parse_error_line_count(contents), 517 | lang.popularity_rank, 518 | ), 519 | ) 520 | return None 521 | 522 | 523 | def get_most_likely_language(path: str, contents: bytes | str) -> ParseableLanguage | None: 524 | if isinstance(contents, str): 525 | contents = contents.encode("utf-8") 526 | language = _get_most_likely_language(path, contents, LANGUAGES_BY_EXTENSION, LANGUAGES_BY_FILENAME) 527 | if language is None: 528 | language = _get_most_likely_language(path, contents, LANGUAGES_BY_ALT_EXTENSION, LANGUAGES_BY_ALT_FILENAME) 529 | return language 530 | 531 | 532 | if __name__ == "__main__": 533 | for language in ALL_LANGUAGES: 534 | if language.get_language() is None or language.get_parser() is None: 535 | print(f"Failed to load tree-sitter language or parser for {language}") 536 | exit(1) 537 | 538 | sample_file = Path(__file__).parent.parent / "temp/cudaTestProgram.cu" 539 | if not sample_file.exists(): 540 | print(f"Sample file not found: {sample_file}") 541 | print("Please provide a valid file path to test the parser.") 542 | exit(1) 543 | 544 | with open(sample_file, "rb") as f: 545 | sample_contents = f.read() 546 | 547 | language = get_most_likely_language(str(sample_file), sample_contents) 548 | if language is None: 549 | print(f"Could not determine language for {sample_file}") 550 | exit(1) 551 | 552 | print(f"Detected language: {language}") 553 | print(f"Parsing 
{sample_file}...") 554 | errors = language.parse_errors(sample_contents) 555 | if errors: 556 | print(f"Parse errors: {errors}") 557 | else: 558 | print("No parse errors") 559 | 560 | matches = language.find_matching_subtrees(sample_contents, [b"cudaCheckErrors", b"cuda", b"d_", b"d_A"]) 561 | if matches: 562 | print(f"Found matches: {matches}") 563 | else: 564 | print("No matches found") 565 | --------------------------------------------------------------------------------