├── .github └── workflows │ ├── build_kernel.yaml │ ├── build_kernel_rocm.yaml │ ├── check_variants.yaml │ ├── docker-build-push.yaml │ ├── nix_fmt.yaml │ └── rust.yaml ├── .gitignore ├── README.md ├── build-variants.json ├── build2cmake ├── Cargo.lock ├── Cargo.toml ├── build.rs ├── flake.lock ├── flake.nix └── src │ ├── config │ ├── mod.rs │ ├── v1.rs │ └── v2.rs │ ├── cuda_supported_archs.json │ ├── fileset.rs │ ├── main.rs │ ├── templates │ ├── _ops.py │ ├── cuda │ │ ├── dep-cutlass.cmake │ │ ├── hipify.py │ │ ├── kernel.cmake │ │ ├── preamble.cmake │ │ ├── setup.py │ │ ├── torch-binding.cmake │ │ └── torch-extension.cmake │ ├── metal │ │ ├── kernel.cmake │ │ ├── preamble.cmake │ │ ├── setup.py │ │ ├── torch-binding.cmake │ │ ├── torch-extension.cmake │ │ └── utils.cmake │ ├── pyproject.toml │ ├── registration.h │ ├── universal │ │ ├── _ops.py │ │ └── pyproject.toml │ └── utils.cmake │ └── torch │ ├── cuda.rs │ ├── metal.rs │ ├── mod.rs │ ├── ops_identifier.rs │ └── universal.rs ├── default.nix ├── dockerfiles ├── Dockerfile ├── Dockerfile.user └── README.md ├── docs ├── build-variants.md ├── docker.md ├── local-dev.md ├── nix.md ├── toolchain.md ├── why-nix.md └── writing-kernels.md ├── examples ├── activation │ ├── LICENSE │ ├── README.md │ ├── activation │ │ ├── activation_kernels.cu │ │ ├── cuda_compat.h │ │ └── dispatch_utils.h │ ├── build.toml │ ├── flake.nix │ ├── tests │ │ ├── __init__.py │ │ └── kernels │ │ │ ├── __init__.py │ │ │ ├── allclose_default.py │ │ │ ├── test_activation.py │ │ │ └── utils.py │ └── torch-ext │ │ ├── activation │ │ └── __init__.py │ │ ├── torch_binding.cpp │ │ └── torch_binding.h ├── cutlass-gemm │ ├── build.toml │ ├── flake.nix │ ├── gemm.cu │ ├── tests │ │ └── test_gemm.py │ └── torch-ext │ │ ├── cutlass_gemm │ │ └── __init__.py │ │ ├── registration.h │ │ ├── torch_binding.cpp │ │ └── torch_binding.h ├── relu │ ├── build.toml │ ├── flake.nix │ ├── relu_cuda │ │ └── relu.cu │ ├── relu_metal │ │ └── relu.mm │ ├── tests │ │ ├── 
__init__.py │ │ └── test_relu.py │ └── torch-ext │ │ ├── relu │ │ └── __init__.py │ │ ├── torch_binding.cpp │ │ └── torch_binding.h └── silu-and-mul-universal │ ├── build.toml │ ├── flake.nix │ ├── tests │ └── test_silu_and_mul.py │ └── torch-ext │ └── silu_and_mul_universal │ ├── __init__.py │ └── silu_and_mul.py ├── flake.lock ├── flake.nix ├── kernel-abi-check ├── Cargo.lock ├── Cargo.toml ├── flake.lock ├── flake.nix └── src │ ├── lib.rs │ ├── main.rs │ ├── manylinux │ ├── manylinux-policy.json │ └── mod.rs │ ├── python_abi │ ├── mod.rs │ └── stable_abi.toml │ └── version.rs ├── kernel-compliance-check ├── Cargo.lock ├── Cargo.toml └── src │ ├── build_variants.json │ ├── formatter.rs │ ├── lib.rs │ ├── main.rs │ └── models.rs ├── lib ├── build-version.nix ├── build.nix ├── buildsets.nix ├── deps.nix ├── join-paths │ └── default.nix ├── source-set.nix ├── torch-extension-noarch │ └── default.nix ├── torch-extension │ ├── _ops.py.in │ └── default.nix └── version-utils.nix ├── overlay.nix ├── pkgs ├── build2cmake │ └── default.nix ├── cmake-nvcc-threads-hook │ ├── cmake-nvcc-threads-hook.sh │ └── default.nix ├── kernel-abi-check │ ├── default.nix │ └── kernel-abi-check-hook.sh └── stdenv-glibc-2_27 │ └── default.nix ├── scripts ├── gen_variants_markdown.py └── init-kernel.py ├── tests └── Dockerfile.test-kernel └── versions.nix /.github/workflows/build_kernel.yaml: -------------------------------------------------------------------------------- 1 | name: "Build and test kernel" 2 | on: 3 | push: 4 | branches: [main] 5 | pull_request: 6 | branches: [main] 7 | types: [opened, synchronize, reopened] # trigger on PRs 8 | workflow_dispatch: 9 | 10 | jobs: 11 | build: 12 | name: Build kernel 13 | runs-on: 14 | group: aws-g6-12xlarge-plus 15 | steps: 16 | - uses: actions/checkout@v4 17 | - uses: cachix/install-nix-action@v27 18 | with: 19 | nix_path: nixpkgs=channel:nixos-unstable 20 | - uses: cachix/cachix-action@v14 21 | with: 22 | name: huggingface 23 | authToken: 
"${{ secrets.CACHIX_AUTH_TOKEN }}" 24 | env: 25 | USER: github_runner 26 | - name: Build activation kernel 27 | run: ( cd examples/activation && nix build .\#redistributable.torch26-cxx98-cu124-x86_64-linux ) 28 | - name: Copy activation kernel 29 | run: cp -rL examples/activation/result activation-kernel 30 | 31 | - name: Build cutlass GEMM kernel 32 | run: ( cd examples/cutlass-gemm && nix build .\#redistributable.torch26-cxx98-cu124-x86_64-linux ) 33 | - name: Copy cutlass GEMM kernel 34 | run: cp -rL examples/cutlass-gemm/result cutlass-gemm-kernel 35 | 36 | - name: Build relu kernel 37 | run: ( cd examples/relu && nix build .\#redistributable.torch26-cxx98-cu124-x86_64-linux ) 38 | - name: Copy relu kernel 39 | run: cp -rL examples/relu/result relu-kernel 40 | 41 | - name: Build silu-and-mul-universal kernel 42 | run: ( cd examples/silu-and-mul-universal && nix build .\#redistributable.torch26-cxx98-cu124-x86_64-linux ) 43 | - name: Copy silu-and-mul-universal kernel 44 | run: cp -rL examples/silu-and-mul-universal/result silu-and-mul-universal-kernel 45 | 46 | - name: Set up Docker Buildx 47 | uses: docker/setup-buildx-action@v3 48 | - name: Build Docker image 49 | uses: docker/build-push-action@v5 50 | with: 51 | context: . 
52 | file: tests/Dockerfile.test-kernel 53 | platforms: linux/amd64 54 | load: true 55 | push: false 56 | tags: kernel-builder:latest 57 | 58 | - name: Run Tests 59 | run: | 60 | docker run --gpus all kernel-builder:latest 61 | -------------------------------------------------------------------------------- /.github/workflows/build_kernel_rocm.yaml: -------------------------------------------------------------------------------- 1 | name: "Build and test kernel (ROCm)" 2 | on: 3 | push: 4 | branches: [main] 5 | pull_request: 6 | branches: [main] 7 | types: [opened, synchronize, reopened] # trigger on PRs 8 | workflow_dispatch: 9 | 10 | jobs: 11 | build: 12 | name: Build kernel 13 | runs-on: 14 | group: aws-g6-12xlarge-plus 15 | steps: 16 | - uses: actions/checkout@v4 17 | - uses: cachix/install-nix-action@v27 18 | with: 19 | nix_path: nixpkgs=channel:nixos-unstable 20 | - uses: cachix/cachix-action@v14 21 | with: 22 | name: huggingface 23 | authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}" 24 | env: 25 | USER: github_runner 26 | # For now we only test that there are no regressions in building ROCm 27 | # kernels. Also run tests once we have a ROCm runner. 
28 | - name: Build relu kernel 29 | run: ( cd examples/relu && nix build .\#redistributable.torch26-cxx11-rocm62-x86_64-linux -L ) 30 | -------------------------------------------------------------------------------- /.github/workflows/check_variants.yaml: -------------------------------------------------------------------------------- 1 | name: "Check build variants" 2 | on: 3 | push: 4 | branches: [main] 5 | pull_request: 6 | branches: [main] 7 | types: [opened, synchronize, reopened] # trigger on PRs 8 | workflow_dispatch: 9 | 10 | jobs: 11 | build: 12 | name: Check build variants 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | - uses: cachix/install-nix-action@v27 17 | with: 18 | nix_path: nixpkgs=channel:nixos-unstable 19 | - name: Generate variants JSON 20 | run: nix eval --raw .#lib.allBuildVariantsJSON | nix run nixpkgs#jq 'walk(if type == "array" then sort else . end)' > build-variants.json 21 | - name: Check if variants JSON is up-to-date 22 | run: | 23 | if git diff --exit-code build-variants.json; then 24 | echo "✅ variants.json is up-to-date" 25 | else 26 | echo "🛑 regenerate variants.json: nix eval --raw .#lib.allBuildVariantsJSON | nix run nixpkgs#jq 'walk(if type == "array" then sort else . 
end)' > build-variants.json" 27 | exit 1 28 | fi 29 | - name: Generate variants Markdown 30 | run: nix run nixpkgs#python3 scripts/gen_variants_markdown.py 31 | - name: Check if variants Markdown is up-to-date 32 | run: | 33 | if git diff --exit-code docs/build-variants.md; then 34 | echo "✅ docs/build-variants.md is up-to-date" 35 | else 36 | echo "🛑 regenerate docs/build-variants: nix run nixpkgs#python3 scripts/gen_variants_markdown.py" 37 | exit 1 38 | fi 39 | -------------------------------------------------------------------------------- /.github/workflows/docker-build-push.yaml: -------------------------------------------------------------------------------- 1 | name: Build and Push Docker Image 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | # Only run on changes to the Dockerfile or workflow file 9 | - "Dockerfile" 10 | - "dockerfiles/**" 11 | - ".github/workflows/docker-build-push.yaml" 12 | workflow_dispatch: # Allow manual triggering 13 | 14 | env: 15 | REGISTRY: ghcr.io 16 | IMAGE_NAME: ${{ github.repository_owner }}/kernel-builder 17 | 18 | jobs: 19 | build-and-push-user: 20 | name: Build and Push User Docker Image 21 | runs-on: ubuntu-latest 22 | permissions: 23 | contents: read 24 | packages: write 25 | 26 | steps: 27 | - name: Checkout repository 28 | uses: actions/checkout@v4 29 | 30 | - name: Set up Docker Buildx 31 | uses: docker/setup-buildx-action@v3 32 | 33 | - name: Log in to the Container registry 34 | uses: docker/login-action@v3 35 | with: 36 | registry: ${{ env.REGISTRY }} 37 | username: ${{ github.actor }} 38 | password: ${{ secrets.GITHUB_TOKEN }} 39 | 40 | - name: Extract metadata (tags, labels) for Docker 41 | id: meta 42 | uses: docker/metadata-action@v5 43 | with: 44 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 45 | tags: | 46 | type=raw,value=latest,enable=${{ github.ref == format('refs/heads/{0}', 'main') }} 47 | type=sha,prefix=user-,format=short 48 | type=ref,prefix=user-,event=branch 49 | 
type=semver,prefix=user-,pattern={{version}} 50 | type=semver,prefix=user-,pattern={{major}}.{{minor}} 51 | 52 | - name: Build and push Docker image 53 | uses: docker/build-push-action@v5 54 | with: 55 | context: . 56 | file: ./dockerfiles/Dockerfile.user 57 | push: true 58 | tags: ${{ steps.meta.outputs.tags }} 59 | labels: ${{ steps.meta.outputs.labels }} 60 | build-args: | 61 | MAX_JOBS=8 62 | CORES=8 63 | cache-from: type=gha 64 | cache-to: type=gha,mode=max 65 | 66 | build-and-push-root: 67 | name: Build and Push Root Docker Image 68 | runs-on: ubuntu-latest 69 | permissions: 70 | contents: read 71 | packages: write 72 | 73 | steps: 74 | - name: Checkout repository 75 | uses: actions/checkout@v4 76 | 77 | - name: Set up Docker Buildx 78 | uses: docker/setup-buildx-action@v3 79 | 80 | - name: Log in to the Container registry 81 | uses: docker/login-action@v3 82 | with: 83 | registry: ${{ env.REGISTRY }} 84 | username: ${{ github.actor }} 85 | password: ${{ secrets.GITHUB_TOKEN }} 86 | 87 | - name: Extract metadata (tags, labels) for Docker 88 | id: meta-root 89 | uses: docker/metadata-action@v5 90 | with: 91 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 92 | tags: | 93 | type=raw,value=root,enable=${{ github.ref == format('refs/heads/{0}', 'main') }} 94 | type=sha,format=short 95 | type=ref,event=branch 96 | type=semver,pattern={{version}} 97 | type=semver,pattern={{major}}.{{minor}} 98 | 99 | - name: Build and push Docker image 100 | uses: docker/build-push-action@v5 101 | with: 102 | context: . 
103 | file: ./dockerfiles/Dockerfile 104 | push: true 105 | tags: ${{ steps.meta-root.outputs.tags }} 106 | labels: ${{ steps.meta-root.outputs.labels }} 107 | build-args: | 108 | MAX_JOBS=8 109 | CORES=8 110 | cache-from: type=gha,scope=root 111 | cache-to: type=gha,mode=max,scope=root 112 | -------------------------------------------------------------------------------- /.github/workflows/nix_fmt.yaml: -------------------------------------------------------------------------------- 1 | name: "Check Nix formatting" 2 | on: 3 | push: 4 | branches: [main] 5 | pull_request: 6 | branches: [main] 7 | types: [opened, synchronize, reopened] # trigger on PRs 8 | workflow_dispatch: 9 | 10 | jobs: 11 | build: 12 | name: Check Nix formatting 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | - uses: cachix/install-nix-action@v27 17 | with: 18 | nix_path: nixpkgs=channel:nixos-unstable 19 | - name: Check formatting 20 | run: nix fmt -- --ci 21 | -------------------------------------------------------------------------------- /.github/workflows/rust.yaml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | fmt: 7 | name: Rustfmt 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v4 11 | - uses: dtolnay/rust-toolchain@stable 12 | - name: Cargo fmt (kernel-abi-check) 13 | run: ( cd kernel-abi-check && cargo fmt --all -- --check ) 14 | - name: Cargo fmt (build2cmake) 15 | run: ( cd build2cmake && cargo fmt --all -- --check ) 16 | 17 | clippy: 18 | name: Clippy 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: actions/checkout@v4 22 | - uses: actions/cache@v4 23 | with: 24 | path: | 25 | ~/.cargo/bin/ 26 | ~/.cargo/registry/index/ 27 | ~/.cargo/registry/cache/ 28 | ~/.cargo/git/db/ 29 | target/ 30 | key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} 31 | - uses: dtolnay/rust-toolchain@stable 32 | with: 33 | components: clippy 34 | 
- name: Clippy (kernel-abi-check) 35 | run: ( cd kernel-abi-check && cargo clippy -- -D warnings ) 36 | - name: Clippy (build2cmake) 37 | run: ( cd build2cmake && cargo clippy -- -D warnings ) 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .obsidian 2 | target -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # kernel-builder 2 | 3 |
4 | kernel-builder logo 5 |

6 | Build and Push Docker Image 7 | GitHub tag 8 | GitHub package 9 |

10 |
11 |
12 | 13 | This repo contains a Nix package that can be used to build custom machine learning kernels for PyTorch. The kernels are built using the [PyTorch C++ Frontend](https://pytorch.org/cppdocs/frontend.html) and can be loaded from the Hub with the [kernels](https://github.com/huggingface/kernels) 14 | Python package. 15 | 16 | This builder is a core component of the larger kernel build/distribution system. 17 | 18 | **Torch 2.7 note:** kernel-builder currently builds Torch 2.7 extensions based on 19 | the [final release candidate](https://dev-discuss.pytorch.org/t/pytorch-release-2-7-0-final-rc-is-available/2898). 20 | If you upload Torch 2.7 kernels, please validate them against 21 | the final Torch 2.7.0 release. In the unlikely case of an ABI-breaking 22 | change, you can rebuild and upload your kernel once kernel-builder 23 | is updated for the final release. 24 | 25 | ## 🚀 Quick Start 26 | 27 | We provide Docker containers for building kernels. For a quick build: 28 | 29 | ```bash 30 | # Using the prebuilt container 31 | docker run --mount type=bind,source=$(pwd),target=/kernelcode ghcr.io/huggingface/kernel-builder:{SHA} 32 | ``` 33 | 34 | or build the container locally: 35 | 36 | ```bash 37 | docker build -t kernel-builder:local -f dockerfiles/Dockerfile . 38 | docker run --mount type=bind,source=$(pwd),target=/kernelcode kernel-builder:local 39 | ``` 40 | 41 | See [dockerfiles/README.md](./dockerfiles/README.md) for more options, including a user-level container for CI/CD environments. 42 | 43 | ## 📚 Documentation 44 | 45 | - [Writing Hub kernels](./docs/writing-kernels.md) 46 | - [Building kernels with Docker](./docs/docker.md) 47 | - [Building kernels with Nix](./docs/nix.md) 48 | - [Local kernel development](./docs/local-dev.md) (IDE integration) 49 | - [Why Nix?](./docs/why-nix.md) 50 | 51 | ## Credits 52 | 53 | The generated CMake build files are based on the vLLM build infrastructure.
54 | -------------------------------------------------------------------------------- /build-variants.json: -------------------------------------------------------------------------------- 1 | { 2 | "aarch64-darwin": {}, 3 | "aarch64-linux": { 4 | "cuda": [ 5 | "torch26-cxx11-cu126-aarch64-linux", 6 | "torch26-cxx98-cu126-aarch64-linux", 7 | "torch27-cxx11-cu126-aarch64-linux", 8 | "torch27-cxx11-cu128-aarch64-linux" 9 | ] 10 | }, 11 | "x86_64-linux": { 12 | "cuda": [ 13 | "torch26-cxx11-cu118-x86_64-linux", 14 | "torch26-cxx11-cu124-x86_64-linux", 15 | "torch26-cxx11-cu126-x86_64-linux", 16 | "torch26-cxx98-cu118-x86_64-linux", 17 | "torch26-cxx98-cu124-x86_64-linux", 18 | "torch26-cxx98-cu126-x86_64-linux", 19 | "torch27-cxx11-cu118-x86_64-linux", 20 | "torch27-cxx11-cu126-x86_64-linux", 21 | "torch27-cxx11-cu128-x86_64-linux" 22 | ], 23 | "rocm": [ 24 | "torch26-cxx11-rocm62-x86_64-linux", 25 | "torch27-cxx11-rocm63-x86_64-linux" 26 | ] 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /build2cmake/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "build2cmake" 3 | version = "0.2.1" 4 | edition = "2021" 5 | description = "Generate CMake files for kernel-builder projects" 6 | homepage = "https://github.com/huggingface/kernel-builder" 7 | license = "Apache-2.0" 8 | documentation = "https://docs.rs/build2cmake" 9 | repository = "https://github.com/huggingface/kernel-builder" 10 | 11 | [dependencies] 12 | base32 = "0.5" 13 | clap = { version = "4", features = ["derive"] } 14 | eyre = "0.6.12" 15 | git2 = "0.20" 16 | itertools = "0.13" 17 | minijinja = "2.5" 18 | minijinja-embed = "2.5" 19 | rand = "0.8" 20 | serde = { version = "1", features = ["derive"] } 21 | serde_json = "1" 22 | serde-value = "0.7" 23 | toml = "0.8" 24 | 25 | [build-dependencies] 26 | minijinja-embed = "2.5" 27 | 
-------------------------------------------------------------------------------- /build2cmake/build.rs: -------------------------------------------------------------------------------- 1 | /* Cargo build script (minijinja-embed is listed under [build-dependencies]): asks minijinja_embed to embed the templates found under src/templates. */ fn main() { 2 | minijinja_embed::embed_templates!("src/templates"); 3 | } 4 | -------------------------------------------------------------------------------- /build2cmake/flake.lock: -------------------------------------------------------------------------------- 1 | { 2 | "nodes": { 3 | "flake-utils": { 4 | "inputs": { 5 | "systems": "systems" 6 | }, 7 | "locked": { 8 | "lastModified": 1731533236, 9 | "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", 10 | "owner": "numtide", 11 | "repo": "flake-utils", 12 | "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", 13 | "type": "github" 14 | }, 15 | "original": { 16 | "owner": "numtide", 17 | "repo": "flake-utils", 18 | "type": "github" 19 | } 20 | }, 21 | "nixpkgs": { 22 | "locked": { 23 | "lastModified": 1734424634, 24 | "narHash": "sha256-cHar1vqHOOyC7f1+tVycPoWTfKIaqkoe1Q6TnKzuti4=", 25 | "owner": "nixos", 26 | "repo": "nixpkgs", 27 | "rev": "d3c42f187194c26d9f0309a8ecc469d6c878ce33", 28 | "type": "github" 29 | }, 30 | "original": { 31 | "owner": "nixos", 32 | "ref": "nixos-unstable", 33 | "repo": "nixpkgs", 34 | "type": "github" 35 | } 36 | }, 37 | "root": { 38 | "inputs": { 39 | "flake-utils": "flake-utils", 40 | "nixpkgs": "nixpkgs" 41 | } 42 | }, 43 | "systems": { 44 | "locked": { 45 | "lastModified": 1681028828, 46 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", 47 | "owner": "nix-systems", 48 | "repo": "default", 49 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", 50 | "type": "github" 51 | }, 52 | "original": { 53 | "owner": "nix-systems", 54 | "repo": "default", 55 | "type": "github" 56 | } 57 | } 58 | }, 59 | "root": "root", 60 | "version": 7 61 | } 62 | -------------------------------------------------------------------------------- /build2cmake/flake.nix:
-------------------------------------------------------------------------------- 1 | { 2 | description = "A very basic flake"; 3 | 4 | inputs = { 5 | flake-utils.url = "github:numtide/flake-utils"; 6 | nixpkgs.url = "github:nixos/nixpkgs?ref=nixos-unstable"; 7 | }; 8 | 9 | outputs = 10 | { 11 | self, 12 | flake-utils, 13 | nixpkgs, 14 | }: 15 | flake-utils.lib.eachDefaultSystem ( 16 | system: 17 | let 18 | pkgs = nixpkgs.legacyPackages.${system}; 19 | in 20 | { 21 | /* Per-system development shell providing the Rust toolchain (cargo, rustc, clippy, rustfmt, rust-analyzer) plus openssl and pkg-config. */ 22 | devShells.default = 23 | with pkgs; 24 | mkShell { 25 | buildInputs = [ 26 | cargo 27 | clippy 28 | openssl.dev 29 | pkg-config 30 | rustc 31 | rustfmt 32 | rust-analyzer 33 | ]; 34 | 35 | RUST_SRC_PATH = "${rustPlatform.rustLibSrc}"; 36 | }; 37 | } 38 | ); 39 | } 40 | -------------------------------------------------------------------------------- /build2cmake/src/config/mod.rs: -------------------------------------------------------------------------------- 1 | use eyre::Result; 2 | use serde::Deserialize; 3 | 4 | pub mod v1; 5 | 6 | mod v2; 7 | use serde_value::Value; 8 | pub use v2::{Backend, Build, Dependencies, Kernel, Torch}; 9 | /* A build.toml parsed as either the deprecated V1 layout or the current V2 layout. */ 10 | #[derive(Debug)] 11 | pub enum BuildCompat { 12 | V1(v1::Build), 13 | V2(Build), 14 | } 15 | /* Buffers the input as a serde_value::Value, tries the V1 schema first and falls back to V2. NOTE(review): the V1 error is discarded, so input matching neither layout reports only the V2 deserialization error. */ 16 | impl<'de> Deserialize<'de> for BuildCompat { 17 | fn deserialize(deserializer: D) -> std::result::Result 18 | where 19 | D: serde::Deserializer<'de>, 20 | { 21 | let value = Value::deserialize(deserializer)?; 22 | 23 | match v1::Build::deserialize(value.clone()) { 24 | Ok(v1_build) => Ok(BuildCompat::V1(v1_build)), 25 | Err(_) => { 26 | let v2_build: Build = 27 | Build::deserialize(value).map_err(serde::de::Error::custom)?; 28 | Ok(BuildCompat::V2(v2_build)) 29 | } 30 | } 31 | } 32 | } 33 | /* Normalizes either layout to the V2 `Build`; V1 input goes through its fallible `try_into` conversion. */ 34 | impl TryFrom for Build { 35 | type Error = eyre::Error; 36 | 37 | fn try_from(compat: BuildCompat) -> Result { 38 | match compat { 39 | BuildCompat::V1(v1_build) => v1_build.try_into(), 40 | BuildCompat::V2(v2_build) => Ok(v2_build), 41 | } 42 | } 43 | } 44 | 
-------------------------------------------------------------------------------- /build2cmake/src/config/v1.rs: -------------------------------------------------------------------------------- 1 | use std::{collections::HashMap, fmt::Display, path::PathBuf}; 2 | 3 | use serde::Deserialize; 4 | 5 | use super::v2::Dependencies; 6 | /* Deprecated V1 build.toml schema. Unknown fields are rejected; a parsed V1 `Build` is converted to the V2 schema via `TryFrom` in v2.rs. */ 7 | #[derive(Debug, Deserialize)] 8 | #[serde(deny_unknown_fields)] 9 | pub struct Build { 10 | pub general: General, 11 | pub torch: Option, 12 | 13 | #[serde(rename = "kernel", default)] 14 | pub kernels: HashMap, 15 | } 16 | 17 | #[derive(Debug, Deserialize)] 18 | #[serde(deny_unknown_fields)] 19 | pub struct General { 20 | pub name: String, 21 | } 22 | /* In V1 the `universal` flag lives in the torch section; the V2 conversion moves it to `general.universal`. */ 23 | #[derive(Debug, Deserialize, Clone)] 24 | #[serde(deny_unknown_fields)] 25 | pub struct Torch { 26 | pub include: Option>, 27 | pub pyext: Option>, 28 | 29 | #[serde(default)] 30 | pub src: Vec, 31 | 32 | #[serde(default)] 33 | pub universal: bool, 34 | } 35 | /* A V1 kernel; `language` defaults to `cuda`. During V2 conversion a `cuda-hipify` kernel also yields a ROCm kernel named `{name}_rocm`. */ 36 | #[derive(Debug, Deserialize)] 37 | #[serde(deny_unknown_fields, rename_all = "kebab-case")] 38 | pub struct Kernel { 39 | pub cuda_capabilities: Option>, 40 | pub rocm_archs: Option>, 41 | #[serde(default)] 42 | pub language: Language, 43 | pub depends: Vec, 44 | pub include: Option>, 45 | pub src: Vec, 46 | } 47 | /* Kernel source language, spelled in kebab-case in build.toml (`cuda`, `cuda-hipify`, `metal`). */ 48 | #[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq)] 49 | #[serde(deny_unknown_fields, rename_all = "kebab-case")] 50 | pub enum Language { 51 | #[default] 52 | Cuda, 53 | CudaHipify, 54 | Metal, 55 | } 56 | /* Formats a language with the same kebab-case names used in build.toml. */ 57 | impl Display for Language { 58 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 59 | match self { 60 | Language::Cuda => f.write_str("cuda"), 61 | Language::CudaHipify => f.write_str("cuda-hipify"), 62 | Language::Metal => f.write_str("metal"), 63 | } 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /build2cmake/src/config/v2.rs: -------------------------------------------------------------------------------- 1
| use std::{ 2 | collections::{BTreeSet, HashMap}, 3 | fmt::Display, 4 | path::PathBuf, 5 | str::FromStr, 6 | }; 7 | 8 | use eyre::{bail, Result}; 9 | use itertools::Itertools; 10 | use serde::{Deserialize, Serialize}; 11 | 12 | use super::v1::{self, Language}; 13 | 14 | #[derive(Debug, Deserialize, Serialize)] 15 | #[serde(deny_unknown_fields)] 16 | pub struct Build { 17 | pub general: General, 18 | pub torch: Option, 19 | 20 | #[serde(rename = "kernel", default)] 21 | pub kernels: HashMap, 22 | } 23 | 24 | impl Build { 25 | pub fn has_kernel_with_backend(&self, backend: &Backend) -> bool { 26 | self.kernels 27 | .values() 28 | .any(|kernel| kernel.backend == *backend) 29 | } 30 | 31 | pub fn backends(&self) -> BTreeSet { 32 | self.kernels.values().map(|kernel| kernel.backend).collect() 33 | } 34 | } 35 | 36 | #[derive(Debug, Deserialize, Serialize)] 37 | #[serde(deny_unknown_fields, rename_all = "kebab-case")] 38 | pub struct General { 39 | pub name: String, 40 | #[serde(default)] 41 | pub universal: bool, 42 | } 43 | 44 | #[derive(Debug, Deserialize, Clone, Serialize)] 45 | #[serde(deny_unknown_fields)] 46 | pub struct Torch { 47 | pub include: Option>, 48 | pub pyext: Option>, 49 | 50 | #[serde(default)] 51 | pub src: Vec, 52 | } 53 | 54 | impl Torch { 55 | pub fn data_globs(&self) -> Option> { 56 | match self.pyext.as_ref() { 57 | Some(exts) => { 58 | let globs = exts 59 | .iter() 60 | .filter(|&ext| ext != "py" && ext != "pyi") 61 | .map(|ext| format!("\"**/*.{}\"", ext)) 62 | .collect_vec(); 63 | if globs.is_empty() { 64 | None 65 | } else { 66 | Some(globs) 67 | } 68 | } 69 | 70 | None => None, 71 | } 72 | } 73 | } 74 | 75 | #[derive(Debug, Deserialize, Serialize)] 76 | #[serde(deny_unknown_fields, rename_all = "kebab-case")] 77 | pub struct Kernel { 78 | pub backend: Backend, 79 | pub cuda_capabilities: Option>, 80 | pub rocm_archs: Option>, 81 | pub depends: Vec, 82 | pub include: Option>, 83 | pub src: Vec, 84 | } 85 | 86 | #[derive(Clone, Copy, Debug, 
Deserialize, Eq, Ord, PartialEq, PartialOrd, Serialize)] 87 | #[serde(deny_unknown_fields, rename_all = "kebab-case")] 88 | pub enum Backend { 89 | Cuda, 90 | Metal, 91 | Rocm, 92 | } 93 | 94 | impl Display for Backend { 95 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 96 | match self { 97 | Backend::Cuda => write!(f, "cuda"), 98 | Backend::Metal => write!(f, "metal"), 99 | Backend::Rocm => write!(f, "rocm"), 100 | } 101 | } 102 | } 103 | 104 | impl FromStr for Backend { 105 | type Err = String; 106 | 107 | fn from_str(s: &str) -> Result { 108 | match s.to_lowercase().as_str() { 109 | "cuda" => Ok(Backend::Cuda), 110 | "metal" => Ok(Backend::Metal), 111 | "rocm" => Ok(Backend::Rocm), 112 | _ => Err(format!("Unknown backend: {}", s)), 113 | } 114 | } 115 | } 116 | 117 | #[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] 118 | #[non_exhaustive] 119 | #[serde(rename_all = "lowercase")] 120 | pub enum Dependencies { 121 | #[serde(rename = "cutlass_2_10")] 122 | Cutlass2_10, 123 | #[serde(rename = "cutlass_3_5")] 124 | Cutlass3_5, 125 | #[serde(rename = "cutlass_3_6")] 126 | Cutlass3_6, 127 | #[serde(rename = "cutlass_3_8")] 128 | Cutlass3_8, 129 | Torch, 130 | } 131 | 132 | impl TryFrom for Build { 133 | type Error = eyre::Error; 134 | 135 | fn try_from(build: v1::Build) -> Result { 136 | let universal = build 137 | .torch 138 | .as_ref() 139 | .map(|torch| torch.universal) 140 | .unwrap_or(false); 141 | Ok(Self { 142 | general: General::from(build.general, universal), 143 | torch: build.torch.map(Into::into), 144 | kernels: convert_kernels(build.kernels)?, 145 | }) 146 | } 147 | } 148 | 149 | impl General { 150 | fn from(general: v1::General, universal: bool) -> Self { 151 | Self { 152 | name: general.name, 153 | universal, 154 | } 155 | } 156 | } 157 | 158 | fn convert_kernels(v1_kernels: HashMap) -> Result> { 159 | let mut kernels = HashMap::new(); 160 | 161 | for (name, kernel) in v1_kernels { 162 | if 
kernel.language == Language::CudaHipify { 163 | // We need to add an affix to avoid confflict with the CUDA kernel. 164 | let rocm_name = format!("{name}_rocm"); 165 | if kernels.contains_key(&rocm_name) { 166 | bail!("Found an existing kernel with name `{rocm_name}` while expanding `{name}`") 167 | } 168 | 169 | kernels.insert( 170 | format!("{name}_rocm"), 171 | Kernel { 172 | backend: Backend::Rocm, 173 | cuda_capabilities: None, 174 | rocm_archs: kernel.rocm_archs, 175 | depends: kernel.depends.clone(), 176 | include: kernel.include.clone(), 177 | src: kernel.src.clone(), 178 | }, 179 | ); 180 | } 181 | 182 | kernels.insert( 183 | name, 184 | Kernel { 185 | backend: Backend::Cuda, 186 | cuda_capabilities: kernel.cuda_capabilities, 187 | rocm_archs: None, 188 | depends: kernel.depends, 189 | include: kernel.include, 190 | src: kernel.src, 191 | }, 192 | ); 193 | } 194 | 195 | Ok(kernels) 196 | } 197 | 198 | impl From for Torch { 199 | fn from(torch: v1::Torch) -> Self { 200 | Self { 201 | include: torch.include, 202 | pyext: torch.pyext, 203 | src: torch.src, 204 | } 205 | } 206 | } 207 | -------------------------------------------------------------------------------- /build2cmake/src/cuda_supported_archs.json: -------------------------------------------------------------------------------- 1 | ["7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "10.0", "10.1", "12.0"] 2 | -------------------------------------------------------------------------------- /build2cmake/src/fileset.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::path::{Path, PathBuf}; 3 | 4 | use eyre::{bail, eyre, Context, Result}; 5 | 6 | pub struct FileSet(HashMap>); 7 | 8 | impl FileSet { 9 | pub fn new() -> FileSet { 10 | FileSet(HashMap::new()) 11 | } 12 | 13 | fn check_exist(&self, target_dir: &Path) -> Result<()> { 14 | let mut existing = Vec::new(); 15 | for path in self.0.keys() { 16 | let full_path = 
target_dir.join(path); 17 | if full_path.exists() { 18 | existing.push(path.to_string_lossy().into_owned()); 19 | } 20 | } 21 | 22 | if !existing.is_empty() { 23 | bail!( 24 | "File(s) already exists in target directory: {}\nUse `--force` to overwrite.", 25 | existing.join(", ") 26 | ); 27 | } 28 | 29 | Ok(()) 30 | } 31 | 32 | pub fn entry(&mut self, path: impl Into) -> &mut Vec { 33 | self.0.entry(path.into()).or_default() 34 | } 35 | 36 | pub fn write(&self, target_dir: &Path, force: bool) -> Result<()> { 37 | // Check that the paths do not exist and that we can write. 38 | if !force { 39 | self.check_exist(target_dir)?; 40 | } 41 | 42 | for (path, content) in &self.0 { 43 | let full_path = target_dir.join(path); 44 | write_to_file(&full_path, content)?; 45 | } 46 | 47 | Ok(()) 48 | } 49 | } 50 | 51 | impl Default for FileSet { 52 | fn default() -> Self { 53 | FileSet::new() 54 | } 55 | } 56 | 57 | fn write_to_file(path: impl AsRef, data: &[u8]) -> Result<()> { 58 | let path = path.as_ref(); 59 | 60 | let parent = path 61 | .parent() 62 | .ok_or_else(|| eyre!("Cannot get parent of `{}`", path.to_string_lossy()))?; 63 | std::fs::create_dir_all(parent) 64 | .wrap_err_with(|| format!("Cannot create directory `{}`", parent.to_string_lossy()))?; 65 | 66 | std::fs::write(path, data) 67 | .wrap_err_with(|| format!("Cannot create: {}", path.to_string_lossy())) 68 | } 69 | -------------------------------------------------------------------------------- /build2cmake/src/main.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | fs::File, 3 | io::{BufWriter, Read, Write}, 4 | path::{Path, PathBuf}, 5 | }; 6 | 7 | use clap::{Parser, Subcommand}; 8 | use eyre::{bail, ensure, Context, Result}; 9 | use minijinja::Environment; 10 | 11 | mod torch; 12 | use torch::{write_torch_ext, write_torch_ext_metal, write_torch_universal_ext}; 13 | 14 | mod config; 15 | use config::{Backend, Build, BuildCompat}; 16 | 17 | mod fileset; 18 | 
use fileset::FileSet; 19 | 20 | #[derive(Parser, Debug)] 21 | #[command(version, about, long_about = None)] 22 | struct Cli { 23 | #[command(subcommand)] 24 | command: Commands, 25 | } 26 | 27 | #[derive(Debug, Subcommand)] 28 | enum Commands { 29 | /// Generate CMake files for Torch extension builds. 30 | GenerateTorch { 31 | #[arg(name = "BUILD_TOML")] 32 | build_toml: PathBuf, 33 | 34 | /// The directory to write the generated files to 35 | /// (directory of `BUILD_TOML` when absent). 36 | #[arg(name = "TARGET_DIR")] 37 | target_dir: Option, 38 | 39 | /// Force-overwrite existing files. 40 | #[arg(short, long)] 41 | force: bool, 42 | 43 | /// This is an optional unique identifier that is suffixed to the 44 | /// kernel name to avoid name collisions. (e.g. Git SHA) 45 | #[arg(long)] 46 | ops_id: Option, 47 | 48 | #[arg(long)] 49 | backend: Option, 50 | }, 51 | 52 | /// Update a `build.toml` to the current format. 53 | UpdateBuild { 54 | #[arg(name = "BUILD_TOML")] 55 | build_toml: PathBuf, 56 | }, 57 | 58 | /// Validate the build.toml file. 
59 | Validate { 60 | #[arg(name = "BUILD_TOML")] 61 | build_toml: PathBuf, 62 | }, 63 | } 64 | 65 | fn main() -> Result<()> { 66 | let args = Cli::parse(); 67 | match args.command { 68 | Commands::GenerateTorch { 69 | backend, 70 | build_toml, 71 | force, 72 | target_dir, 73 | ops_id, 74 | } => generate_torch(backend, build_toml, target_dir, force, ops_id), 75 | Commands::UpdateBuild { build_toml } => update_build(build_toml), 76 | Commands::Validate { build_toml } => { 77 | parse_and_validate(build_toml)?; 78 | Ok(()) 79 | } 80 | } 81 | } 82 | 83 | fn generate_torch( 84 | backend: Option, 85 | build_toml: PathBuf, 86 | target_dir: Option, 87 | force: bool, 88 | ops_id: Option, 89 | ) -> Result<()> { 90 | let target_dir = check_or_infer_target_dir(&build_toml, target_dir)?; 91 | 92 | let build_compat = parse_and_validate(build_toml)?; 93 | 94 | if matches!(build_compat, BuildCompat::V1(_)) { 95 | eprintln!( 96 | "build.toml is in the deprecated V1 format, use `build2cmake update-build` to update." 
97 | ) 98 | } 99 | 100 | let build: Build = build_compat 101 | .try_into() 102 | .context("Cannot update build configuration")?; 103 | 104 | let mut env = Environment::new(); 105 | env.set_trim_blocks(true); 106 | minijinja_embed::load_templates!(&mut env); 107 | 108 | let backend = match (backend, build.general.universal) { 109 | (None, true) => return write_torch_universal_ext(&env, &build, target_dir, force, ops_id), 110 | (Some(backend), true) => bail!("Universal kernel, cannot generate for backend {}", backend), 111 | (Some(backend), false) => { 112 | if !build.has_kernel_with_backend(&backend) { 113 | bail!("No kernels found for backend {}", backend); 114 | } 115 | 116 | backend 117 | } 118 | (None, false) => { 119 | let mut kernel_backends = build.backends(); 120 | let backend = if let Some(backend) = kernel_backends.pop_first() { 121 | backend 122 | } else { 123 | bail!("No kernels found in build.toml"); 124 | }; 125 | 126 | if !kernel_backends.is_empty() { 127 | let kernel_backends: Vec<_> = build 128 | .backends() 129 | .into_iter() 130 | .map(|backend| backend.to_string()) 131 | .collect(); 132 | bail!( 133 | "Multiple supported backends found in build.toml: {}. 
Please specify one with --backend.", 134 | kernel_backends.join(", ") 135 | ); 136 | } 137 | 138 | backend 139 | } 140 | }; 141 | 142 | match backend { 143 | Backend::Cuda | Backend::Rocm => { 144 | write_torch_ext(&env, backend, &build, target_dir, force, ops_id) 145 | } 146 | Backend::Metal => write_torch_ext_metal(&env, &build, target_dir, force, ops_id), 147 | } 148 | } 149 | 150 | fn update_build(build_toml: PathBuf) -> Result<()> { 151 | let build_compat: BuildCompat = parse_and_validate(&build_toml)?; 152 | 153 | if matches!(build_compat, BuildCompat::V2(_)) { 154 | return Ok(()); 155 | } 156 | 157 | let build: Build = build_compat 158 | .try_into() 159 | .context("Cannot update build configuration")?; 160 | let pretty_toml = toml::to_string_pretty(&build)?; 161 | 162 | let mut writer = 163 | BufWriter::new(File::create(&build_toml).wrap_err_with(|| { 164 | format!("Cannot open {} for writing", build_toml.to_string_lossy()) 165 | })?); 166 | writer 167 | .write_all(pretty_toml.as_bytes()) 168 | .wrap_err_with(|| format!("Cannot write to {}", build_toml.to_string_lossy()))?; 169 | 170 | Ok(()) 171 | } 172 | 173 | fn check_or_infer_target_dir( 174 | build_toml: impl AsRef, 175 | target_dir: Option, 176 | ) -> Result { 177 | let build_toml = build_toml.as_ref(); 178 | match target_dir { 179 | Some(target_dir) => { 180 | ensure!( 181 | target_dir.is_dir(), 182 | "`{}` is not a directory", 183 | target_dir.to_string_lossy() 184 | ); 185 | Ok(target_dir) 186 | } 187 | None => { 188 | let absolute = std::path::absolute(build_toml)?; 189 | match absolute.parent() { 190 | Some(parent) => Ok(parent.to_owned()), 191 | None => bail!( 192 | "Cannot get parent path of `{}`", 193 | build_toml.to_string_lossy() 194 | ), 195 | } 196 | } 197 | } 198 | } 199 | 200 | fn parse_and_validate(build_toml: impl AsRef) -> Result { 201 | let build_toml = build_toml.as_ref(); 202 | let mut toml_data = String::new(); 203 | File::open(build_toml) 204 | .wrap_err_with(|| format!("Cannot 
open {} for reading", build_toml.to_string_lossy()))? 205 | .read_to_string(&mut toml_data) 206 | .wrap_err_with(|| format!("Cannot read from {}", build_toml.to_string_lossy()))?; 207 | 208 | let build_compat: BuildCompat = toml::from_str(&toml_data) 209 | .wrap_err_with(|| format!("Cannot parse TOML in {}", build_toml.to_string_lossy()))?; 210 | 211 | Ok(build_compat) 212 | } 213 | -------------------------------------------------------------------------------- /build2cmake/src/templates/_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . import {{ ops_name }} 3 | ops = torch.ops.{{ ops_name }} 4 | 5 | def add_op_namespace_prefix(op_name: str): 6 | """ 7 | Prefix op by namespace. 8 | """ 9 | return f"{{ ops_name }}::{op_name}" 10 | -------------------------------------------------------------------------------- /build2cmake/src/templates/cuda/dep-cutlass.cmake: -------------------------------------------------------------------------------- 1 | find_package(NvidiaCutlass) 2 | 3 | if (NOT NvidiaCutlass_FOUND) 4 | set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") 5 | 6 | # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case. 
7 | set(CUTLASS_REVISION "v{{ version }}" CACHE STRING "CUTLASS revision to use") 8 | 9 | 10 | # Use the specified CUTLASS source directory for compilation if CUTLASS_SRC_DIR is provided 11 | if (DEFINED ENV{CUTLASS_SRC_DIR}) 12 | set(CUTLASS_SRC_DIR $ENV{CUTLASS_SRC_DIR}) 13 | endif() 14 | 15 | if(CUTLASS_SRC_DIR) 16 | if(NOT IS_ABSOLUTE "${CUTLASS_SRC_DIR}") 17 | get_filename_component(CUTLASS_SRC_DIR "${CUTLASS_SRC_DIR}" ABSOLUTE) 18 | endif() 19 | message(STATUS "The CUTLASS_SRC_DIR is set, using ${CUTLASS_SRC_DIR} for compilation") 20 | FetchContent_Declare(cutlass SOURCE_DIR ${CUTLASS_SRC_DIR}) 21 | else() 22 | FetchContent_Declare( 23 | cutlass 24 | GIT_REPOSITORY https://github.com/nvidia/cutlass.git 25 | GIT_TAG ${CUTLASS_REVISION} 26 | GIT_PROGRESS TRUE 27 | 28 | # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. 29 | # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. 30 | # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE 31 | GIT_SHALLOW TRUE 32 | ) 33 | endif() 34 | FetchContent_MakeAvailable(cutlass) 35 | 36 | include_directories(${CUTLASS_INCLUDE_DIR}) 37 | else() 38 | message(STATUS "Using system cutlass with version: ${NvidiaCutlass_VERSION}") 39 | endif(NOT NvidiaCutlass_FOUND) 40 | -------------------------------------------------------------------------------- /build2cmake/src/templates/cuda/hipify.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # From vLLM: https://github.com/vllm-project/vllm/blob/main/cmake/hipify.py 5 | 6 | # 7 | # A command line tool for running pytorch's hipify preprocessor on CUDA 8 | # source files.
9 | # 10 | # See https://github.com/ROCm/hipify_torch 11 | # and /utils/hipify/hipify_python.py 12 | # 13 | 14 | import argparse 15 | import os 16 | import shutil 17 | 18 | from torch.utils.hipify.hipify_python import hipify 19 | 20 | if __name__ == '__main__': 21 | parser = argparse.ArgumentParser() 22 | 23 | # Project directory where all the source + include files live. 24 | parser.add_argument( 25 | "-p", 26 | "--project_dir", 27 | help="The project directory.", 28 | ) 29 | 30 | # Directory where hipified files are written. 31 | parser.add_argument( 32 | "-o", 33 | "--output_dir", 34 | help="The output directory.", 35 | ) 36 | 37 | # Source files to convert. 38 | parser.add_argument("sources", 39 | help="Source files to hipify.", 40 | nargs="*", 41 | default=[]) 42 | 43 | args = parser.parse_args() 44 | 45 | # Limit include scope to project_dir only 46 | includes = [os.path.join(args.project_dir, '*')] 47 | 48 | # Get absolute path for all source files. 49 | extra_files = [os.path.abspath(s) for s in args.sources] 50 | 51 | # Copy sources from project directory to output directory. 52 | # The directory might already exist to hold object files so we ignore that. 53 | shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True) 54 | 55 | hipify_result = hipify(project_directory=args.project_dir, 56 | output_directory=args.output_dir, 57 | header_include_dirs=[], 58 | includes=includes, 59 | extra_files=extra_files, 60 | show_detailed=True, 61 | is_pytorch_extension=True, 62 | hipify_extra_files_only=True) 63 | 64 | hipified_sources = [] 65 | for source in args.sources: 66 | s_abs = os.path.abspath(source) 67 | hipified_s_abs = (hipify_result[s_abs].hipified_path if 68 | (s_abs in hipify_result 69 | and hipify_result[s_abs].hipified_path is not None) 70 | else s_abs) 71 | hipified_sources.append(hipified_s_abs) 72 | 73 | assert (len(hipified_sources) == len(args.sources)) 74 | 75 | # Print hipified source files. 
76 | print("\n".join(hipified_sources)) 77 | -------------------------------------------------------------------------------- /build2cmake/src/templates/cuda/kernel.cmake: -------------------------------------------------------------------------------- 1 | set({{kernel_name}}_SRC 2 | {{ sources }} 3 | ) 4 | 5 | {% if includes %} 6 | # TODO: check if CLion support this: 7 | # https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories 8 | set_source_files_properties( 9 | {{'${' + kernel_name + '_SRC}'}} 10 | PROPERTIES INCLUDE_DIRECTORIES "{{ includes }}") 11 | {% endif %} 12 | 13 | if(GPU_LANG STREQUAL "CUDA") 14 | {% if cuda_capabilities %} 15 | cuda_archs_loose_intersection({{kernel_name}}_ARCHS "{{ cuda_capabilities|join(";") }}" "${CUDA_ARCHS}") 16 | {% else %} 17 | cuda_archs_loose_intersection({{kernel_name}}_ARCHS "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}") 18 | {% endif %} 19 | message(STATUS "Capabilities for kernel {{kernel_name}}: {{ '${' + kernel_name + '_ARCHS}'}}") 20 | set_gencode_flags_for_srcs(SRCS {{'"${' + kernel_name + '_SRC}"'}} CUDA_ARCHS "{{ '${' + kernel_name + '_ARCHS}'}}") 21 | list(APPEND SRC {{'"${' + kernel_name + '_SRC}"'}}) 22 | {% if supports_hipify %} 23 | elseif(GPU_LANG STREQUAL "HIP") 24 | hip_archs_loose_intersection({{kernel_name}}_ARCHS "{{ rocm_archs|join(";") }}" "${ROCM_ARCHS}") 25 | list(APPEND SRC {{'"${' + kernel_name + '_SRC}"'}}) 26 | {% endif %} 27 | endif() 28 | 29 | -------------------------------------------------------------------------------- /build2cmake/src/templates/cuda/preamble.cmake: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.26) 2 | project({{name}} LANGUAGES CXX) 3 | 4 | set(TARGET_DEVICE "cuda" CACHE STRING "Target device backend for kernel") 5 | 6 | install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) 7 | 8 | include(FetchContent) 9 | file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) #
Ensure the directory exists 10 | message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}") 11 | 12 | set(CUDA_SUPPORTED_ARCHS "{{ cuda_supported_archs }}") 13 | 14 | set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101") 15 | 16 | include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) 17 | 18 | if(DEFINED Python_EXECUTABLE) 19 | # Allow passing through the interpreter (e.g. from setup.py). 20 | find_package(Python COMPONENTS Development Development.SABIModule Interpreter) 21 | if (NOT Python_FOUND) 22 | message(FATAL_ERROR "Unable to find python matching: ${Python_EXECUTABLE}.") 23 | endif() 24 | else() 25 | find_package(Python REQUIRED COMPONENTS Development Development.SABIModule Interpreter) 26 | endif() 27 | 28 | append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path") 29 | 30 | find_package(Torch REQUIRED) 31 | 32 | if (NOT TARGET_DEVICE STREQUAL "cuda" AND 33 | NOT TARGET_DEVICE STREQUAL "rocm") 34 | return() 35 | endif() 36 | 37 | if (NOT HIP_FOUND AND CUDA_FOUND) 38 | set(GPU_LANG "CUDA") 39 | elseif(HIP_FOUND) 40 | set(GPU_LANG "HIP") 41 | 42 | # Importing torch recognizes and sets up some HIP/ROCm configuration but does 43 | # not let cmake recognize .hip files. In order to get cmake to understand the 44 | # .hip extension automatically, HIP must be enabled explicitly. 45 | enable_language(HIP) 46 | else() 47 | message(FATAL_ERROR "Can't find CUDA or HIP installation.") 48 | endif() 49 | 50 | 51 | if(GPU_LANG STREQUAL "CUDA") 52 | clear_cuda_arches(CUDA_ARCH_FLAGS) 53 | extract_unique_cuda_archs_ascending(CUDA_ARCHS "${CUDA_ARCH_FLAGS}") 54 | message(STATUS "CUDA target architectures: ${CUDA_ARCHS}") 55 | # Filter the target architectures by the supported archs 56 | # since for some files we will build for all CUDA_ARCHS.
57 | cuda_archs_loose_intersection(CUDA_ARCHS "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}") 58 | message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}") 59 | 60 | if(NVCC_THREADS AND GPU_LANG STREQUAL "CUDA") 61 | list(APPEND GPU_FLAGS "--threads=${NVCC_THREADS}") 62 | endif() 63 | 64 | add_compile_definitions(CUDA_KERNEL) 65 | elseif(GPU_LANG STREQUAL "HIP") 66 | set(ROCM_ARCHS "${HIP_SUPPORTED_ARCHS}") 67 | # TODO: remove this once we can set specific archs per source file set. 68 | override_gpu_arches(GPU_ARCHES 69 | ${GPU_LANG} 70 | "${${GPU_LANG}_SUPPORTED_ARCHS}") 71 | 72 | add_compile_definitions(ROCM_KERNEL) 73 | else() 74 | override_gpu_arches(GPU_ARCHES 75 | ${GPU_LANG} 76 | "${${GPU_LANG}_SUPPORTED_ARCHS}") 77 | endif() 78 | -------------------------------------------------------------------------------- /build2cmake/src/templates/cuda/setup.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from shutil import which, move 4 | import subprocess 5 | import sys 6 | from pathlib import Path 7 | 8 | from setuptools import Extension, find_packages, setup 9 | from setuptools.command.build_ext import build_ext 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def is_sccache_available() -> bool: 15 | return which("sccache") is not None 16 | 17 | 18 | def is_ccache_available() -> bool: 19 | return which("ccache") is not None 20 | 21 | 22 | def is_ninja_available() -> bool: 23 | return which("ninja") is not None 24 | 25 | 26 | class CMakeExtension(Extension): 27 | def __init__(self, name: str, sourcedir: str = "") -> None: 28 | super().__init__(name, sources=[], py_limited_api=True) 29 | self.sourcedir = os.fspath(Path(sourcedir).resolve()) 30 | 31 | 32 | class CMakeBuild(build_ext): 33 | def build_extension(self, ext: CMakeExtension) -> None: 34 | ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name) 35 | extdir = ext_fullpath.parent.resolve() 36 | 37 | debug = 
int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug 38 | cfg = "Debug" if debug else "Release" 39 | 40 | cmake_generator = os.environ.get("CMAKE_GENERATOR", "") 41 | 42 | # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON 43 | # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code 44 | # from Python. 45 | cmake_args = [ 46 | f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}", 47 | f"-DPython_EXECUTABLE={sys.executable}", 48 | f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm 49 | ] 50 | build_args = [] 51 | if "CMAKE_ARGS" in os.environ: 52 | cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item] 53 | 54 | if not cmake_generator or cmake_generator == "Ninja": 55 | try: 56 | import ninja 57 | 58 | ninja_executable_path = Path(ninja.BIN_DIR) / "ninja" 59 | cmake_args += [ 60 | "-GNinja", 61 | f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}", 62 | ] 63 | except ImportError: 64 | pass 65 | 66 | if is_sccache_available(): 67 | cmake_args += [ 68 | "-DCMAKE_C_COMPILER_LAUNCHER=sccache", 69 | "-DCMAKE_CXX_COMPILER_LAUNCHER=sccache", 70 | "-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache", 71 | "-DCMAKE_HIP_COMPILER_LAUNCHER=sccache", 72 | ] 73 | elif is_ccache_available(): 74 | cmake_args += [ 75 | "-DCMAKE_C_COMPILER_LAUNCHER=ccache", 76 | "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache", 77 | "-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache", 78 | "-DCMAKE_HIP_COMPILER_LAUNCHER=ccache", 79 | ] 80 | 81 | num_jobs = os.getenv("MAX_JOBS", None) 82 | if num_jobs is not None: 83 | num_jobs = int(num_jobs) 84 | logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs) 85 | else: 86 | try: 87 | # os.sched_getaffinity() isn't universally available, so fall 88 | # back to os.cpu_count() if we get an error here. 
89 | num_jobs = len(os.sched_getaffinity(0)) 90 | except AttributeError: 91 | num_jobs = os.cpu_count() 92 | 93 | nvcc_threads = os.getenv("NVCC_THREADS", None) 94 | if nvcc_threads is not None: 95 | nvcc_threads = int(nvcc_threads) 96 | logger.info( 97 | "Using NVCC_THREADS=%d as the number of nvcc threads.", nvcc_threads 98 | ) 99 | else: 100 | nvcc_threads = 1 101 | num_jobs = max(1, num_jobs // nvcc_threads) 102 | 103 | build_args += [f"-j{num_jobs}"] 104 | if sys.platform == "win32": 105 | build_args += ["--config", cfg] 106 | 107 | if nvcc_threads: 108 | cmake_args += ["-DNVCC_THREADS={}".format(nvcc_threads)] 109 | 110 | build_temp = Path(self.build_temp) / ext.name 111 | if not build_temp.exists(): 112 | build_temp.mkdir(parents=True) 113 | 114 | subprocess.run( 115 | ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True 116 | ) 117 | subprocess.run( 118 | ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True 119 | ) 120 | if sys.platform == "win32": 121 | # Move the dylib one folder up for discovery. 122 | for filename in os.listdir(extdir / cfg): 123 | move(extdir / cfg / filename, extdir / filename) 124 | 125 | 126 | 127 | setup( 128 | name="{{ name }}", 129 | # The version is just a stub, it's not used by the final build artefact. 
130 | version="0.1.0", 131 | ext_modules=[CMakeExtension("{{ name }}.{{ ops_name }}")], 132 | cmdclass={"build_ext": CMakeBuild}, 133 | packages=find_packages(where="torch-ext", include=["{{ name }}*"]), 134 | package_dir={"": "torch-ext"}, 135 | {% if data_globs %} 136 | package_data={"{{ name }}": [ {{ data_globs }} ]}, 137 | {% endif %} 138 | zip_safe=False, 139 | install_requires=["torch"], 140 | python_requires=">=3.9", 141 | ) 142 | -------------------------------------------------------------------------------- /build2cmake/src/templates/cuda/torch-binding.cmake: -------------------------------------------------------------------------------- 1 | get_torch_gpu_compiler_flags(TORCH_GPU_FLAGS ${GPU_LANG}) 2 | list(APPEND GPU_FLAGS ${TORCH_GPU_FLAGS}) 3 | 4 | set(TORCH_{{name}}_SRC 5 | {{ src|join(' ') }} 6 | ) 7 | 8 | {% if includes %} 9 | # TODO: check if CLion support this: 10 | # https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories 11 | set_source_files_properties( 12 | {{'${TORCH_' + name + '_SRC}'}} 13 | PROPERTIES INCLUDE_DIRECTORIES "{{ includes }}") 14 | {% endif %} 15 | 16 | list(APPEND SRC {{'"${TORCH_' + name + '_SRC}"'}}) 17 | 18 | -------------------------------------------------------------------------------- /build2cmake/src/templates/cuda/torch-extension.cmake: -------------------------------------------------------------------------------- 1 | define_gpu_extension_target( 2 | {{ ops_name }} 3 | DESTINATION {{ ops_name }} 4 | LANGUAGE ${GPU_LANG} 5 | SOURCES ${SRC} 6 | COMPILE_FLAGS ${GPU_FLAGS} 7 | ARCHITECTURES ${GPU_ARCHES} 8 | #INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} 9 | USE_SABI 3 10 | WITH_SOABI) 11 | 12 | target_link_options({{ ops_name }} PRIVATE -static-libstdc++) 13 | 14 | -------------------------------------------------------------------------------- /build2cmake/src/templates/metal/kernel.cmake: -------------------------------------------------------------------------------- 1 | 
set({{kernel_name}}_SRC 2 | {{ sources }} 3 | ) 4 | 5 | {% if includes %} 6 | # TODO: check if CLion support this: 7 | # https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories 8 | set_source_files_properties( 9 | {{'${' + kernel_name + '_SRC}'}} 10 | PROPERTIES INCLUDE_DIRECTORIES "{{ includes }}") 11 | {% endif %} 12 | 13 | list(APPEND SRC {{'"${' + kernel_name + '_SRC}"'}}) -------------------------------------------------------------------------------- /build2cmake/src/templates/metal/preamble.cmake: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.26) 2 | project({{name}} LANGUAGES CXX) 3 | 4 | set(CMAKE_OSX_DEPLOYMENT_TARGET "15.0" CACHE STRING "Minimum macOS deployment version") 5 | 6 | install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) 7 | 8 | include(FetchContent) 9 | file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists 10 | message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}") 11 | 12 | include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) 13 | 14 | if(DEFINED Python_EXECUTABLE) 15 | # Allow passing through the interpreter (e.g. from setup.py). 
16 | find_package(Python COMPONENTS Development Development.SABIModule Interpreter) 17 | if (NOT Python_FOUND) 18 | message(FATAL_ERROR "Unable to find python matching: ${Python_EXECUTABLE}.") 19 | endif() 20 | else() 21 | find_package(Python REQUIRED COMPONENTS Development Development.SABIModule Interpreter) 22 | endif() 23 | 24 | append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path") 25 | 26 | find_package(Torch REQUIRED) 27 | 28 | add_compile_definitions(METAL_KERNEL) 29 | -------------------------------------------------------------------------------- /build2cmake/src/templates/metal/setup.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from shutil import which, move 4 | import subprocess 5 | import sys 6 | from pathlib import Path 7 | 8 | from setuptools import Extension, find_packages, setup 9 | from setuptools.command.build_ext import build_ext 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def is_sccache_available() -> bool: 15 | return which("sccache") is not None 16 | 17 | 18 | def is_ccache_available() -> bool: 19 | return which("ccache") is not None 20 | 21 | 22 | def is_ninja_available() -> bool: 23 | return which("ninja") is not None 24 | 25 | 26 | class CMakeExtension(Extension): 27 | def __init__(self, name: str, sourcedir: str = "") -> None: 28 | super().__init__(name, sources=[], py_limited_api=True) 29 | self.sourcedir = os.fspath(Path(sourcedir).resolve()) 30 | 31 | 32 | class CMakeBuild(build_ext): 33 | def build_extension(self, ext: CMakeExtension) -> None: 34 | ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name) 35 | extdir = ext_fullpath.parent.resolve() 36 | 37 | debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug 38 | cfg = "Debug" if debug else "Release" 39 | 40 | cmake_generator = os.environ.get("CMAKE_GENERATOR", "") 41 | 42 | # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON 43 | #
EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code 44 | # from Python. 45 | cmake_args = [ 46 | f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}", 47 | f"-DPython_EXECUTABLE={sys.executable}", 48 | f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm 49 | ] 50 | build_args = [] 51 | if "CMAKE_ARGS" in os.environ: 52 | cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item] 53 | 54 | if not cmake_generator or cmake_generator == "Ninja": 55 | try: 56 | import ninja 57 | 58 | ninja_executable_path = Path(ninja.BIN_DIR) / "ninja" 59 | cmake_args += [ 60 | "-GNinja", 61 | f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}", 62 | ] 63 | except ImportError: 64 | pass 65 | 66 | if is_sccache_available(): 67 | cmake_args += [ 68 | "-DCMAKE_C_COMPILER_LAUNCHER=sccache", 69 | "-DCMAKE_CXX_COMPILER_LAUNCHER=sccache", 70 | "-DCMAKE_HIP_COMPILER_LAUNCHER=sccache", 71 | "-DCMAKE_OBJC_COMPILER_LAUNCHER=sccache", 72 | "-DCMAKE_OBJCXX_COMPILER_LAUNCHER=sccache", 73 | ] 74 | elif is_ccache_available(): 75 | cmake_args += [ 76 | "-DCMAKE_C_COMPILER_LAUNCHER=ccache", 77 | "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache", 78 | "-DCMAKE_HIP_COMPILER_LAUNCHER=ccache", 79 | "-DCMAKE_OBJC_COMPILER_LAUNCHER=ccache", 80 | "-DCMAKE_OBJCXX_COMPILER_LAUNCHER=ccache", 81 | ] 82 | 83 | num_jobs = os.getenv("MAX_JOBS", None) 84 | if num_jobs is not None: 85 | num_jobs = int(num_jobs) 86 | logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs) 87 | else: 88 | try: 89 | # os.sched_getaffinity() isn't universally available, so fall 90 | # back to os.cpu_count() if we get an error here. 
91 | num_jobs = len(os.sched_getaffinity(0)) 92 | except AttributeError: 93 | num_jobs = os.cpu_count() 94 | 95 | build_temp = Path(self.build_temp) / ext.name 96 | if not build_temp.exists(): 97 | build_temp.mkdir(parents=True) 98 | 99 | subprocess.run( 100 | ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True 101 | ) 102 | subprocess.run( 103 | ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True 104 | ) 105 | 106 | 107 | setup( 108 | name="{{ name }}", 109 | # The version is just a stub, it's not used by the final build artefact. 110 | version="0.1.0", 111 | ext_modules=[CMakeExtension("{{ name }}.{{ ops_name }}")], 112 | cmdclass={"build_ext": CMakeBuild}, 113 | packages=find_packages(where="torch-ext", include=["{{ name }}*"]), 114 | package_dir={"": "torch-ext"}, 115 | {% if data_globs %} 116 | package_data={"{{ name }}": [ {{ data_globs }} ]}, 117 | {% endif %} 118 | zip_safe=False, 119 | install_requires=["torch"], 120 | python_requires=">=3.9", 121 | ) 122 | -------------------------------------------------------------------------------- /build2cmake/src/templates/metal/torch-binding.cmake: -------------------------------------------------------------------------------- 1 | #get_torch_gpu_compiler_flags(TORCH_GPU_FLAGS ${GPU_LANG}) 2 | #list(APPEND GPU_FLAGS ${TORCH_GPU_FLAGS}) 3 | 4 | set(TORCH_{{name}}_SRC 5 | {{ src|join(' ') }} 6 | ) 7 | 8 | {% if includes %} 9 | # TODO: check if CLion support this: 10 | # https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories 11 | set_source_files_properties( 12 | {{'${TORCH_' + name + '_SRC}'}} 13 | PROPERTIES INCLUDE_DIRECTORIES "{{ includes }}") 14 | {% endif %} 15 | 16 | list(APPEND SRC {{'"${TORCH_' + name + '_SRC}"'}}) 17 | -------------------------------------------------------------------------------- /build2cmake/src/templates/metal/torch-extension.cmake: -------------------------------------------------------------------------------- 1 | 
define_gpu_extension_target( 2 | {{ ops_name }} 3 | DESTINATION {{ ops_name }} 4 | LANGUAGE ${GPU_LANG} 5 | SOURCES ${SRC} 6 | COMPILE_FLAGS ${GPU_FLAGS} 7 | ARCHITECTURES ${GPU_ARCHES} 8 | USE_SABI 3 9 | WITH_SOABI) -------------------------------------------------------------------------------- /build2cmake/src/templates/metal/utils.cmake: -------------------------------------------------------------------------------- 1 | # Run `EXPR` in python after importing `PKG`. Use the result of this to extend 2 | # `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported. 3 | macro (append_cmake_prefix_path PKG EXPR) 4 | run_python(_PREFIX_PATH 5 | "import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path") 6 | list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH}) 7 | endmacro() -------------------------------------------------------------------------------- /build2cmake/src/templates/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "cmake>=3.26", 4 | "ninja", 5 | "packaging", 6 | "setuptools>=61", 7 | "torch", 8 | "wheel", 9 | ] 10 | build-backend = "setuptools.build_meta" 11 | -------------------------------------------------------------------------------- /build2cmake/src/templates/registration.h: -------------------------------------------------------------------------------- 1 | // Registration macros from vLLM: 2 | // https://github.com/vllm-project/vllm/blob/main/csrc/core/registration.h 3 | 4 | #pragma once 5 | 6 | #include 7 | 8 | #define _CONCAT(A, B) A##B 9 | #define CONCAT(A, B) _CONCAT(A, B) 10 | 11 | #define _STRINGIFY(A) #A 12 | #define STRINGIFY(A) _STRINGIFY(A) 13 | 14 | // A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME 15 | // could be a macro instead of a literal token. 16 | #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) 17 | 18 | // A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. 
so NAME 19 | // could be a macro instead of a literal token. 20 | #define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \ 21 | TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE) 22 | 23 | // REGISTER_EXTENSION allows the shared library to be loaded and initialized 24 | // via python's import statement. 25 | #define REGISTER_EXTENSION(NAME) \ 26 | PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ 27 | static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \ 28 | STRINGIFY(NAME), nullptr, 0, nullptr}; \ 29 | return PyModule_Create(&module); \ 30 | } 31 | -------------------------------------------------------------------------------- /build2cmake/src/templates/universal/_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | ops = torch.ops.{{ ops_name }} 3 | 4 | def add_op_namespace_prefix(op_name: str): 5 | """ 6 | Prefix op by namespace. 7 | """ 8 | return f"{{ ops_name }}::{op_name}" 9 | -------------------------------------------------------------------------------- /build2cmake/src/templates/universal/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "{{ name }}" 3 | version = "0.0.1" 4 | requires-python = ">= 3.9" 5 | dependencies = ["torch>=2.4"] 6 | 7 | [tool.setuptools] 8 | package-dir = { "" = "torch-ext" } 9 | 10 | [tool.setuptools.packages.find] 11 | where = ["torch-ext"] 12 | include = ["{{name}}*"] 13 | 14 | {% if data_globs %} 15 | [tool.setuptools.package-data] 16 | {{ name }} = [ {{ data_globs }} ] 17 | {% endif %} 18 | -------------------------------------------------------------------------------- /build2cmake/src/torch/metal.rs: -------------------------------------------------------------------------------- 1 | use std::{io::Write, path::PathBuf}; 2 | 3 | use eyre::{bail, Context, Result}; 4 | use itertools::Itertools; 5 | use minijinja::{context, Environment}; 6 | 7 | use super::kernel_ops_identifier; 8 | use crate::{ 9 | 
config::{Build, Kernel, Torch}, 10 | fileset::FileSet, 11 | }; 12 | 13 | static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake"); 14 | static REGISTRATION_H: &str = include_str!("../templates/registration.h"); 15 | 16 | pub fn write_torch_ext_metal( 17 | env: &Environment, 18 | build: &Build, 19 | target_dir: PathBuf, 20 | force: bool, 21 | ops_id: Option, 22 | ) -> Result<()> { 23 | let torch_ext = match build.torch.as_ref() { 24 | Some(torch_ext) => torch_ext, 25 | None => bail!("Build configuration does not have `torch` section"), 26 | }; 27 | 28 | let mut file_set = FileSet::default(); 29 | 30 | let ops_name = kernel_ops_identifier(&target_dir, &build.general.name, ops_id); 31 | 32 | write_cmake( 33 | env, 34 | build, 35 | torch_ext, 36 | &build.general.name, 37 | &ops_name, 38 | &mut file_set, 39 | )?; 40 | 41 | write_setup_py( 42 | env, 43 | torch_ext, 44 | &build.general.name, 45 | &ops_name, 46 | &mut file_set, 47 | )?; 48 | 49 | write_ops_py(env, &build.general.name, &ops_name, &mut file_set)?; 50 | 51 | write_pyproject_toml(env, &mut file_set)?; 52 | 53 | write_torch_registration_macros(&mut file_set)?; 54 | 55 | file_set.write(&target_dir, force)?; 56 | 57 | Ok(()) 58 | } 59 | 60 | fn write_cmake( 61 | env: &Environment, 62 | build: &Build, 63 | torch: &Torch, 64 | name: &str, 65 | ops_name: &str, 66 | file_set: &mut FileSet, 67 | ) -> Result<()> { 68 | let mut utils_path = PathBuf::new(); 69 | utils_path.push("cmake"); 70 | utils_path.push("utils.cmake"); 71 | file_set 72 | .entry(utils_path.clone()) 73 | .extend_from_slice(CMAKE_UTILS.as_bytes()); 74 | 75 | let cmake_writer = file_set.entry("CMakeLists.txt"); 76 | 77 | render_preamble(env, name, cmake_writer)?; 78 | 79 | // Add deps once we have any non-CUDA deps. 
80 | // render_deps(env, build, cmake_writer)?; 81 | 82 | render_binding(env, torch, name, cmake_writer)?; 83 | 84 | for (kernel_name, kernel) in &build.kernels { 85 | render_kernel(env, kernel_name, kernel, cmake_writer)?; 86 | } 87 | 88 | render_extension(env, ops_name, cmake_writer)?; 89 | 90 | Ok(()) 91 | } 92 | 93 | fn render_binding( 94 | env: &Environment, 95 | torch: &Torch, 96 | name: &str, 97 | write: &mut impl Write, 98 | ) -> Result<()> { 99 | env.get_template("metal/torch-binding.cmake") 100 | .wrap_err("Cannot get Torch binding template")? 101 | .render_to_write( 102 | context! { 103 | includes => torch.include.as_ref().map(prefix_and_join_includes), 104 | name => name, 105 | src => torch.src 106 | }, 107 | &mut *write, 108 | ) 109 | .wrap_err("Cannot render Torch binding template")?; 110 | 111 | write.write_all(b"\n")?; 112 | 113 | Ok(()) 114 | } 115 | 116 | pub fn render_extension(env: &Environment, ops_name: &str, write: &mut impl Write) -> Result<()> { 117 | env.get_template("metal/torch-extension.cmake") 118 | .wrap_err("Cannot get Torch extension template")? 119 | .render_to_write( 120 | context! { 121 | ops_name => ops_name, 122 | }, 123 | &mut *write, 124 | ) 125 | .wrap_err("Cannot render Torch extension template")?; 126 | 127 | write.write_all(b"\n")?; 128 | 129 | Ok(()) 130 | } 131 | 132 | pub fn render_kernel( 133 | env: &Environment, 134 | kernel_name: &str, 135 | kernel: &Kernel, 136 | write: &mut impl Write, 137 | ) -> Result<()> { 138 | // Easier to do in Rust than Jinja. 139 | let sources = kernel 140 | .src 141 | .iter() 142 | .map(|src| format!("\"{src}\"")) 143 | .collect_vec() 144 | .join("\n"); 145 | 146 | env.get_template("metal/kernel.cmake") 147 | .wrap_err("Cannot get kernel template")? 148 | .render_to_write( 149 | context! 
{ 150 | includes => kernel.include.as_ref().map(prefix_and_join_includes), 151 | kernel_name => kernel_name, 152 | sources => sources, 153 | }, 154 | &mut *write, 155 | ) 156 | .wrap_err("Cannot render kernel template")?; 157 | 158 | write.write_all(b"\n")?; 159 | 160 | Ok(()) 161 | } 162 | 163 | fn render_preamble(env: &Environment, name: &str, write: &mut impl Write) -> Result<()> { 164 | env.get_template("metal/preamble.cmake") 165 | .wrap_err("Cannot get CMake prelude template")? 166 | .render_to_write( 167 | context! { 168 | name => name, 169 | }, 170 | &mut *write, 171 | ) 172 | .wrap_err("Cannot render CMake prelude template")?; 173 | 174 | write.write_all(b"\n")?; 175 | 176 | Ok(()) 177 | } 178 | 179 | fn write_ops_py( 180 | env: &Environment, 181 | name: &str, 182 | ops_name: &str, 183 | file_set: &mut FileSet, 184 | ) -> Result<()> { 185 | let mut path = PathBuf::new(); 186 | path.push("torch-ext"); 187 | path.push(name); 188 | path.push("_ops.py"); 189 | let writer = file_set.entry(path); 190 | 191 | env.get_template("_ops.py") 192 | .wrap_err("Cannot get _ops.py template")? 193 | .render_to_write( 194 | context! { 195 | ops_name => ops_name, 196 | }, 197 | writer, 198 | ) 199 | .wrap_err("Cannot render kernel template")?; 200 | 201 | Ok(()) 202 | } 203 | 204 | fn write_pyproject_toml(env: &Environment, file_set: &mut FileSet) -> Result<()> { 205 | let writer = file_set.entry("pyproject.toml"); 206 | 207 | env.get_template("pyproject.toml") 208 | .wrap_err("Cannot get pyproject.toml template")? 209 | .render_to_write(context! 
{}, writer) 210 | .wrap_err("Cannot render pyproject.toml template")?; 211 | 212 | Ok(()) 213 | } 214 | 215 | fn write_setup_py( 216 | env: &Environment, 217 | torch: &Torch, 218 | name: &str, 219 | ops_name: &str, 220 | file_set: &mut FileSet, 221 | ) -> Result<()> { 222 | let writer = file_set.entry("setup.py"); 223 | 224 | let data_globs = torch.data_globs().map(|globs| globs.join(", ")); 225 | 226 | env.get_template("metal/setup.py") 227 | .wrap_err("Cannot get setup.py template")? 228 | .render_to_write( 229 | context! { 230 | data_globs => data_globs, 231 | ops_name => ops_name, 232 | name => name, 233 | version => "0.1.0", 234 | }, 235 | writer, 236 | ) 237 | .wrap_err("Cannot render setup.py template")?; 238 | 239 | Ok(()) 240 | } 241 | 242 | fn write_torch_registration_macros(file_set: &mut FileSet) -> Result<()> { 243 | let mut path = PathBuf::new(); 244 | path.push("torch-ext"); 245 | path.push("registration.h"); 246 | file_set 247 | .entry(path) 248 | .extend_from_slice(REGISTRATION_H.as_bytes()); 249 | 250 | Ok(()) 251 | } 252 | 253 | fn prefix_and_join_includes(includes: impl AsRef<[S]>) -> String 254 | where 255 | S: AsRef, 256 | { 257 | includes 258 | .as_ref() 259 | .iter() 260 | .map(|include| format!("${{CMAKE_SOURCE_DIR}}/{}", include.as_ref())) 261 | .collect_vec() 262 | .join(";") 263 | } 264 | -------------------------------------------------------------------------------- /build2cmake/src/torch/mod.rs: -------------------------------------------------------------------------------- 1 | mod cuda; 2 | pub use cuda::write_torch_ext; 3 | 4 | mod metal; 5 | pub use metal::write_torch_ext_metal; 6 | 7 | mod ops_identifier; 8 | pub(crate) use ops_identifier::kernel_ops_identifier; 9 | 10 | mod universal; 11 | pub use universal::write_torch_universal_ext; 12 | -------------------------------------------------------------------------------- /build2cmake/src/torch/ops_identifier.rs:
-------------------------------------------------------------------------------- 1 | use std::path::Path; 2 | 3 | use eyre::{Result, WrapErr}; 4 | use git2::Repository; 5 | use rand::Rng; 6 | 7 | fn random_identifier() -> String { 8 | // Generate a random string when no ops_id is provided 9 | let mut rng = rand::thread_rng(); 10 | let build_id: u64 = rng.gen(); 11 | base32::encode( 12 | base32::Alphabet::Rfc4648Lower { padding: false }, 13 | &build_id.to_le_bytes(), 14 | ) 15 | } 16 | 17 | fn git_identifier(target_dir: impl AsRef) -> Result { 18 | let repo = Repository::discover(target_dir.as_ref()).context("Cannot open git repository")?; 19 | let head = repo.head()?; 20 | let commit = head.peel_to_commit()?; 21 | // Short hash of the HEAD commit's tree (its contents), not of the commit object itself. 22 | let rev = commit.tree_id().to_string().chars().take(7).collect(); 23 | // `statuses(None)` uses libgit2's default options, which include ignored files, 24 | // so ignored build outputs would mark every build as dirty. Exclude them explicitly. 25 | let mut status_opts = git2::StatusOptions::new(); 26 | status_opts.include_ignored(false).include_untracked(true); 27 | let dirty = !repo.statuses(Some(&mut status_opts))?.is_empty(); 28 | Ok(if dirty { format!("{rev}_dirty") } else { rev }) 29 | } 30 | 31 | pub fn kernel_ops_identifier( 32 | target_dir: impl AsRef, 33 | name: &str, 34 | ops_id: Option, 35 | ) -> String { 36 | let identifier = ops_id.unwrap_or_else(|| match git_identifier(target_dir.as_ref()) { 37 | Ok(rev) => rev, 38 | Err(_) => random_identifier(), 39 | }); 40 | 41 | format!("_{name}_{identifier}") 42 | } 43 | -------------------------------------------------------------------------------- /build2cmake/src/torch/universal.rs: -------------------------------------------------------------------------------- 1 | use std::path::PathBuf; 2 | 3 | use eyre::{Context, Result}; 4 | use minijinja::{context, Environment}; 5 | 6 | use crate::{ 7 | config::{Build, Torch}, 8 | fileset::FileSet, 9 | torch::kernel_ops_identifier, 10 | }; 11 | 12 | pub fn write_torch_universal_ext( 13 | env: &Environment, 14 | build: &Build, 15 | target_dir: PathBuf, 16 | force: bool, 17 | ops_id: Option, 18 | ) -> Result<()> { 19 | let mut file_set = FileSet::default(); 20 | 21 | let ops_name = kernel_ops_identifier(&target_dir, &build.general.name, ops_id); 22 |
23 | write_ops_py(env, &build.general.name, &ops_name, &mut file_set)?; 24 | write_pyproject_toml( 25 | env, 26 | build.torch.as_ref(), 27 | &build.general.name, 28 | &mut file_set, 29 | )?; 30 | 31 | file_set.write(&target_dir, force)?; 32 | 33 | Ok(()) 34 | } 35 | 36 | fn write_ops_py( 37 | env: &Environment, 38 | name: &str, 39 | ops_name: &str, 40 | file_set: &mut FileSet, 41 | ) -> Result<()> { 42 | let mut path = PathBuf::new(); 43 | path.push("torch-ext"); 44 | path.push(name); 45 | path.push("_ops.py"); 46 | let writer = file_set.entry(path); 47 | 48 | env.get_template("universal/_ops.py") 49 | .wrap_err("Cannot get universal _ops.py template")? 50 | .render_to_write( 51 | context! { 52 | ops_name => ops_name, 53 | }, 54 | writer, 55 | ) 56 | .wrap_err("Cannot render universal _ops.py template")?; 57 | 58 | Ok(()) 59 | } 60 | 61 | fn write_pyproject_toml( 62 | env: &Environment, 63 | torch: Option<&Torch>, 64 | name: &str, 65 | file_set: &mut FileSet, 66 | ) -> Result<()> { 67 | let writer = file_set.entry("pyproject.toml"); 68 | 69 | let data_globs = torch.and_then(|torch| torch.data_globs().map(|globs| globs.join(", "))); 70 | 71 | env.get_template("universal/pyproject.toml") 72 | .wrap_err("Cannot get universal pyproject.toml template")? 73 | .render_to_write( 74 | context!
{ 75 | data_globs => data_globs, 76 | name => name, 77 | }, 78 | writer, 79 | ) 80 | .wrap_err("Cannot render kernel template")?; 81 | 82 | Ok(()) 83 | } 84 | -------------------------------------------------------------------------------- /default.nix: -------------------------------------------------------------------------------- 1 | (import ( 2 | let 3 | lock = builtins.fromJSON (builtins.readFile ./flake.lock); 4 | in 5 | fetchTarball { 6 | url = 7 | lock.nodes.flake-compat.locked.url 8 | or "https://github.com/edolstra/flake-compat/archive/${lock.nodes.flake-compat.locked.rev}.tar.gz"; 9 | sha256 = lock.nodes.flake-compat.locked.narHash; 10 | } 11 | ) { src = ./.; }).defaultNix 12 | -------------------------------------------------------------------------------- /dockerfiles/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nixos/nix:2.18.8 2 | 3 | # default build args 4 | ARG MAX_JOBS=4 5 | ARG CORES=4 6 | 7 | # Combine RUN commands to reduce layers and improve caching 8 | RUN echo "experimental-features = nix-command flakes" >> /etc/nix/nix.conf \ 9 | && echo "max-jobs = $MAX_JOBS" >> /etc/nix/nix.conf \ 10 | && echo "cores = $CORES" >> /etc/nix/nix.conf \ 11 | && nix profile install nixpkgs#cachix nixpkgs#git-lfs \ 12 | && cachix use huggingface 13 | WORKDIR /app 14 | # Copy builder files 15 | COPY . 
/etc/kernel-builder/ 16 | # Set environment variables 17 | ENV MAX_JOBS=${MAX_JOBS} 18 | ENV CORES=${CORES} 19 | # Create directory and setup script 20 | RUN mkdir -p /etc/kernel-builder && \ 21 | cat <<'EOF' > /etc/kernel-builder/cli.sh 22 | #!/bin/sh 23 | set -e 24 | # Default values 25 | BUILD_URL="" 26 | DEV_SHELL=0 27 | HELP=0 28 | # CLI usage function 29 | function show_usage { 30 | echo "Kernel Builder CLI" 31 | echo "" 32 | echo "Usage: docker run [docker-options] kernel-builder:root [command] [options]" 33 | echo "" 34 | echo "Commands:" 35 | echo " build Build the kernel extension (default if no command specified)" 36 | echo " dev Start a development shell" 37 | echo " fetch [URL] Clone and build from a Git URL" 38 | echo " help Show this help message" 39 | echo "" 40 | echo "Options:" 41 | echo " --jobs, -j NUMBER Set maximum number of parallel jobs (default: $MAX_JOBS)" 42 | echo " --cores, -c NUMBER Set number of cores per job (default: $CORES)" 43 | echo "" 44 | echo "Examples:" 45 | echo " docker run --mount type=bind,source=\$(pwd),target=/kernelcode kernel-builder:root build" 46 | echo " docker run -it --mount type=bind,source=\$(pwd),target=/kernelcode kernel-builder:root dev" 47 | echo " docker run --mount type=bind,source=\$(pwd),target=/kernelcode kernel-builder:root fetch https://huggingface.co/user/repo.git" 48 | } 49 | 50 | # Function to generate a basic flake.nix if it doesn't exist 51 | function ensure_flake_exists { 52 | local work_dir=$1 53 | if [ ! -f "${work_dir}/flake.nix" ]; then 54 | echo "No flake.nix found, creating a basic one..."
55 | cat <<'FLAKE_EOF' > "${work_dir}/flake.nix" 56 | { 57 | description = "Flake for Torch kernel extension"; 58 | inputs = { 59 | kernel-builder.url = "github:huggingface/kernel-builder"; 60 | }; 61 | outputs = { self, kernel-builder, }: 62 | kernel-builder.lib.genFlakeOutputs { 63 | path = ./.; 64 | rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate; 65 | }; 66 | } 67 | FLAKE_EOF 68 | echo "flake.nix created. You can customize it as needed." 69 | else 70 | echo "flake.nix already exists, skipping creation." 71 | fi 72 | } 73 | # Function to build the extension 74 | function build_extension { 75 | local work_dir=$1 76 | local output_dir=$2 77 | 78 | echo "Building Torch Extension Bundle from ${work_dir}" 79 | cd "${work_dir}" 80 | 81 | # Check if work_dir is a git repo and get hash if possible 82 | if [ -d "${work_dir}/.git" ]; then 83 | # Mark git as safe to allow commands 84 | git config --global --add safe.directory "${work_dir}" 85 | # Try to get git revision 86 | REV=$(git rev-parse --short=8 HEAD) 87 | 88 | # Check if working directory is dirty (a bare "2" here would be a pathspec, not a redirect) 89 | if [ -n "$(git status --porcelain 2>/dev/null)" ]; then 90 | REV="${REV}-dirty" 91 | fi 92 | else 93 | # Generate random material if not a git repo 94 | REV=$(dd if=/dev/urandom status=none bs=1 count=10 2>/dev/null | base32 | tr '[:upper:]' '[:lower:]' | head -c 10) 95 | fi 96 | echo "Building with rev $REV" 97 | 98 | # Check for flake.nix or create one 99 | ensure_flake_exists "${work_dir}" 100 | 101 | # Make sure the build is up to date 102 | nix run github:huggingface/kernel-builder#update-build -- build.toml 103 | 104 | # Pure bundle build 105 | echo "Building with Nix..." 106 | nix build \ 107 | . \ 108 | --max-jobs $MAX_JOBS \ 109 | -j $CORES \ 110 | -L 111 | 112 | echo "Build completed.
Copying results to ${output_dir}" 113 | mkdir -p "${output_dir}" 114 | cp -r --dereference ./result/* "${output_dir}/" 115 | # As root, ensure proper permissions for host access 116 | chmod -R 777 "${output_dir}" 117 | echo "Done - results available in ${output_dir}" 118 | } 119 | # Function to start a dev shell 120 | function start_dev_shell { 121 | echo "Starting development shell..." 122 | # Check for flake.nix or create one 123 | ensure_flake_exists "/kernelcode" 124 | cd /kernelcode 125 | /root/.nix-profile/bin/nix develop 126 | } 127 | # Function to fetch and build from URL 128 | function fetch_and_build { 129 | if [ -z "$1" ]; then 130 | echo "Error: URL required for fetch command" 131 | show_usage 132 | exit 1 133 | fi 134 | 135 | local repo_url="$1" 136 | local src_dir="/tmp/kernel-src" 137 | local output_dir="/kernelcode/result" 138 | 139 | echo "Fetching code from ${repo_url} to ${src_dir}" 140 | # Create a temporary directory for the clone 141 | mkdir -p "${src_dir}" 142 | 143 | # Clone the repository to the temporary directory 144 | git lfs install 145 | git clone "${repo_url}" "${src_dir}" 146 | 147 | # Build from the temporary directory and copy results to mounted output 148 | build_extension "${src_dir}" "${output_dir}" 149 | } 150 | # Parse arguments 151 | COMMAND="build" # Default command 152 | ARGS=() 153 | 154 | while [[ $# -gt 0 ]]; do 155 | case $1 in 156 | build|dev|fetch|help) 157 | COMMAND="$1" 158 | shift 159 | ;; 160 | --jobs|-j) 161 | MAX_JOBS="$2" 162 | shift 2 163 | ;; 164 | --cores|-c) 165 | CORES="$2" 166 | shift 2 167 | ;; 168 | -*) 169 | echo "Unknown option: $1" 170 | show_usage 171 | exit 1 172 | ;; 173 | *) 174 | ARGS+=("$1") 175 | shift 176 | ;; 177 | esac 178 | done 179 | # Execute the command 180 | case $COMMAND in 181 | build) 182 | # When building existing code, use the mounted directory 183 | build_extension "/kernelcode" "/kernelcode/build" 184 | ;; 185 | dev) 186 | start_dev_shell 187 | ;; 188 | fetch) 189 | 
fetch_and_build "${ARGS[0]}" 190 | ;; 191 | help) 192 | show_usage 193 | ;; 194 | *) 195 | echo "Unknown command: $COMMAND" 196 | show_usage 197 | exit 1 198 | ;; 199 | esac 200 | EOF 201 | RUN chmod +x /etc/kernel-builder/cli.sh 202 | # Create output directory structure 203 | RUN mkdir -p /kernelcode/build 204 | # Set up volume for kernelcode 205 | VOLUME /kernelcode 206 | 207 | ENTRYPOINT ["/etc/kernel-builder/cli.sh"] 208 | -------------------------------------------------------------------------------- /dockerfiles/Dockerfile.user: -------------------------------------------------------------------------------- 1 | FROM nixos/nix:2.18.8 2 | 3 | # default build args 4 | ARG MAX_JOBS=1 5 | ARG CORES=1 6 | 7 | # Set up Nix configuration and user 8 | RUN echo "experimental-features = nix-command flakes" >> /etc/nix/nix.conf \ 9 | && echo "max-jobs = $MAX_JOBS" >> /etc/nix/nix.conf \ 10 | && echo "cores = $CORES" >> /etc/nix/nix.conf \ 11 | && echo "trusted-users = root nixuser" >> /etc/nix/nix.conf \ 12 | # Create user entries directly in password and group files 13 | && echo "nixuser:x:1000:1000:NixOS User:/home/nixuser:/bin/bash" >> /etc/passwd \ 14 | && echo "nixuser:x:1000:" >> /etc/group \ 15 | && mkdir -p /home/nixuser/kernelcode \ 16 | # Create Nix directories with proper permissions 17 | && mkdir -p /nix/var/nix/profiles/per-user/nixuser \ 18 | && mkdir -p /nix/var/nix/gcroots/per-user/nixuser \ 19 | && chown -R 1000:1000 /home/nixuser /nix/var/nix/profiles/per-user/nixuser /nix/var/nix/gcroots/per-user/nixuser \ 20 | # Install necessary packages 21 | && nix profile install nixpkgs#cachix nixpkgs#git-lfs nixpkgs#gawk \ 22 | && cachix use huggingface 23 | 24 | # Set permissions for Nix directories 25 | RUN chown -R nixuser:nixuser /nix 26 | 27 | # Set working directory and copy files 28 | WORKDIR /home/nixuser/kernelcode 29 | COPY --chown=nixuser:nixuser . 
/home/nixuser/kernel-builder/ 30 | 31 | # Set environment variables 32 | ENV MAX_JOBS=${MAX_JOBS} 33 | ENV CORES=${CORES} 34 | # HF_TOKEN must be supplied at runtime (docker run -e HF_TOKEN=...); never bake it into an image layer. 35 | ENV HOME=/home/nixuser 36 | ENV PUSH_REVISION=hfjob-build 37 | ENV REPO=kernels-community/job-build-test-repo 38 | 39 | # Set up CLI script in nixuser's home 40 | RUN mkdir -p /home/nixuser/bin && \ 41 | cat <<'EOF' > /home/nixuser/bin/cli.sh 42 | #!/bin/sh 43 | set -eo pipefail 44 | # pipefail: a failing `nix build ... | awk ...` pipeline must abort the script. 45 | # Default values 46 | BUILD_URL="" 47 | DEV_SHELL=0 48 | HELP=0 49 | 50 | # CLI usage function 51 | function show_usage { 52 | echo "Kernel Builder CLI" 53 | echo "" 54 | echo "Usage: docker run [docker-options] kernel-builder:dev [command] [options]" 55 | echo "" 56 | echo "Commands:" 57 | echo " build Build the kernel extension (default if no command specified)" 58 | echo " dev Start a development shell" 59 | echo " fetch [URL] Clone and build from a Git URL" 60 | echo " help Show this help message" 61 | echo "" 62 | echo "Options:" 63 | echo " --jobs, -j NUMBER Set maximum number of parallel jobs (default: $MAX_JOBS)" 64 | echo " --cores, -c NUMBER Set number of cores per job (default: $CORES)" 65 | echo "" 66 | echo "Examples:" 67 | echo " docker run -v \$(pwd):/home/nixuser/kernelcode kernel-builder:dev build" 68 | echo " docker run -it -v \$(pwd):/home/nixuser/kernelcode kernel-builder:dev dev" 69 | echo " docker run kernel-builder:dev fetch https://huggingface.co/user/repo.git" 70 | } 71 | 72 | # Function to generate a basic flake.nix if it doesn't exist 73 | function ensure_flake_exists { 74 | if [ ! -f "/home/nixuser/kernelcode/flake.nix" ]; then 75 | echo "No flake.nix found, creating a basic one..."
76 | cat <<'FLAKE_EOF' > /home/nixuser/kernelcode/flake.nix 77 | { 78 | description = "Flake for Torch kernel extension"; 79 | 80 | inputs = { 81 | kernel-builder.url = "github:huggingface/kernel-builder"; 82 | }; 83 | 84 | outputs = { self, kernel-builder, }: 85 | kernel-builder.lib.genFlakeOutputs { 86 | path = ./.; 87 | rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate; 88 | }; 89 | } 90 | FLAKE_EOF 91 | echo "flake.nix created. You can customize it as needed." 92 | else 93 | echo "flake.nix already exists, skipping creation." 94 | fi 95 | } 96 | 97 | # Function to build the extension 98 | function build_extension { 99 | echo "Building Torch Extension Bundle" 100 | # Check if kernelcode is a git repo and get hash if possible 101 | if [ -d "/home/nixuser/kernelcode/.git" ]; then 102 | # Mark git as safe to allow commands 103 | git config --global --add safe.directory /home/nixuser/kernelcode 104 | # Try to get git revision 105 | REV=$(git rev-parse --short=8 HEAD) 106 | 107 | # Check if working directory is dirty (a bare "2" here would be a pathspec, not a redirect) 108 | if [ -n "$(git status --porcelain 2>/dev/null)" ]; then 109 | REV="${REV}-dirty" 110 | fi 111 | else 112 | # Generate random material if not a git repo 113 | REV=$(dd if=/dev/urandom status=none bs=1 count=10 2>/dev/null | base32 | tr '[:upper:]' '[:lower:]' | head -c 10) 114 | fi 115 | echo "Building with rev $REV" 116 | 117 | # Check for flake.nix or create one 118 | ensure_flake_exists 119 | 120 | # Make sure the build is up to date 121 | nix run github:huggingface/kernel-builder#update-build -- build.toml 122 | 123 | # Pure bundle build 124 | # TODO: remove the "bundle" after resolving 125 | echo "Building with Nix..." 126 | nix build \ 127 | .\#bundle \ 128 | --max-jobs $MAX_JOBS \ 129 | -j $CORES \ 130 | -L 2>&1 | awk '{ print strftime("[%Y-%m-%d %H:%M:%S]"), $0; fflush(); }' 131 | 132 | echo "Build completed.
Copying results to /home/nixuser/kernelcode/build/" 133 | mkdir -p /home/nixuser/kernelcode/build 134 | cp -r --dereference ./result/* /home/nixuser/kernelcode/build/ 135 | chmod -R u+w /home/nixuser/kernelcode/build 136 | echo 'Done' 137 | } 138 | 139 | # Function to start a dev shell 140 | function start_dev_shell { 141 | echo "Starting development shell..." 142 | # Check for flake.nix or create one 143 | ensure_flake_exists 144 | nix develop 145 | } 146 | 147 | # Function to fetch and build from URL 148 | function fetch_and_build { 149 | if [ -z "$1" ]; then 150 | echo "Error: URL required for fetch command" 151 | show_usage 152 | exit 1 153 | fi 154 | 155 | echo "Fetching code from $1" 156 | rm -rf /home/nixuser/kernelcode/* /home/nixuser/kernelcode/.* 2>/dev/null || true 157 | git lfs install 158 | git clone "$1" /home/nixuser/kernelcode 159 | cd /home/nixuser/kernelcode 160 | build_extension 161 | echo "Build completed. Results are in /home/nixuser/kernelcode/build/" 162 | 163 | # skip login to huggingface since token is set in the env 164 | # check user 165 | nix shell nixpkgs#python3 nixpkgs#python3Packages.huggingface-hub -c huggingface-cli whoami 166 | 167 | # upload the build to the repo 168 | nix shell nixpkgs#python3 nixpkgs#python3Packages.huggingface-hub -c huggingface-cli \ 169 | upload \ 170 | --revision ${PUSH_REVISION} \ 171 | --commit-message "Build from kernel-builder job" \ 172 | ${REPO} \ 173 | /home/nixuser/kernelcode/build/ \ 174 | build/ 175 | } 176 | 177 | # Parse arguments 178 | COMMAND="build" # Default command 179 | ARGS=() 180 | 181 | while [[ $# -gt 0 ]]; do 182 | case $1 in 183 | build|dev|fetch|help) 184 | COMMAND="$1" 185 | shift 186 | ;; 187 | --jobs|-j) 188 | MAX_JOBS="$2" 189 | shift 2 190 | ;; 191 | --cores|-c) 192 | CORES="$2" 193 | shift 2 194 | ;; 195 | -*) 196 | echo "Unknown option: $1" 197 | show_usage 198 | exit 1 199 | ;; 200 | *) 201 | ARGS+=("$1") 202 | shift 203 | ;; 204 | esac 205 | done 206 | 207 | # Execute the 
command 208 | case $COMMAND in 209 | build) 210 | build_extension 211 | ;; 212 | dev) 213 | start_dev_shell 214 | ;; 215 | fetch) 216 | fetch_and_build "${ARGS[0]}" 217 | ;; 218 | help) 219 | show_usage 220 | ;; 221 | *) 222 | echo "Unknown command: $COMMAND" 223 | show_usage 224 | exit 1 225 | ;; 226 | esac 227 | EOF 228 | 229 | # Set permissions and make the script executable 230 | RUN chmod +x /home/nixuser/bin/cli.sh && \ 231 | chown -R nixuser:nixuser /home/nixuser 232 | 233 | # Switch to nixuser 234 | USER nixuser 235 | 236 | # Use the cli.sh script directly 237 | ENTRYPOINT ["/home/nixuser/bin/cli.sh"] -------------------------------------------------------------------------------- /dockerfiles/README.md: -------------------------------------------------------------------------------- 1 | # Kernel Builder Docker Containers 2 | 3 | This directory contains two Docker containers for different use cases: 4 | 5 | ## Root Container (`Dockerfile`) 6 | 7 | This container runs as root and can modify file permissions when mounting volumes. 8 | 9 | ```bash 10 | # Build the container 11 | docker build -t kernel-builder:root -f Dockerfile .. 12 | 13 | # Use the container 14 | docker run --mount type=bind,source=$(pwd),target=/kernelcode kernel-builder:root build 15 | ``` 16 | 17 | ## User Container (`Dockerfile.user`) 18 | 19 | This container runs as a non-root user (nixuser with UID 1000) for more secure environments. 20 | 21 | ```bash 22 | # Build the container 23 | docker build -t kernel-builder:user -f Dockerfile.user .. 
24 | 25 | # Important: Prepare a directory with correct permissions 26 | mkdir -p ./build 27 | chown -R 1000:1000 ./build # Match the UID:GID of nixuser in the container 28 | 29 | # Use with proper permissions for the build directory 30 | docker run --mount type=bind,source=$(pwd),target=/home/nixuser/kernelcode \ 31 | --mount type=bind,source=$(pwd)/build,target=/home/nixuser/kernelcode/build \ 32 | kernel-builder:user build 33 | ``` 34 | 35 | ## Environment Variables 36 | 37 | Both containers support these build options: 38 | 39 | ```bash 40 | # Set options at build time 41 | docker build -t kernel-builder:custom --build-arg MAX_JOBS=8 --build-arg CORES=2 -f Dockerfile .. 42 | 43 | # Or at runtime 44 | docker run -e MAX_JOBS=8 -e CORES=2 --mount type=bind,source=$(pwd),target=/kernelcode kernel-builder:root build 45 | ``` 46 | -------------------------------------------------------------------------------- /docs/build-variants.md: -------------------------------------------------------------------------------- 1 | # Build variants 2 | 3 | A kernel can be compliant for a specific compute framework (e.g. CUDA) or 4 | architecture (e.g. x86_64). For compliance with a compute framework and 5 | architecture combination, all the build variants listed below must be 6 | available. This list will be updated as new PyTorch versions are released. 
7 | 8 | ## CUDA aarch64-linux 9 | 10 | - `torch26-cxx11-cu126-aarch64-linux` 11 | - `torch26-cxx98-cu126-aarch64-linux` 12 | - `torch27-cxx11-cu126-aarch64-linux` 13 | - `torch27-cxx11-cu128-aarch64-linux` 14 | 15 | ## CUDA x86_64-linux 16 | 17 | - `torch26-cxx11-cu118-x86_64-linux` 18 | - `torch26-cxx11-cu124-x86_64-linux` 19 | - `torch26-cxx11-cu126-x86_64-linux` 20 | - `torch26-cxx98-cu118-x86_64-linux` 21 | - `torch26-cxx98-cu124-x86_64-linux` 22 | - `torch26-cxx98-cu126-x86_64-linux` 23 | - `torch27-cxx11-cu118-x86_64-linux` 24 | - `torch27-cxx11-cu126-x86_64-linux` 25 | - `torch27-cxx11-cu128-x86_64-linux` 26 | 27 | ## ROCm x86_64-linux 28 | 29 | - `torch26-cxx11-rocm62-x86_64-linux` 30 | - `torch27-cxx11-rocm63-x86_64-linux` 31 | 32 | ## Universal 33 | 34 | Kernels that are in pure Python (e.g. Triton kernels) only need to provide 35 | a single build variant: 36 | 37 | - `torch-universal` 38 | -------------------------------------------------------------------------------- /docs/local-dev.md: -------------------------------------------------------------------------------- 1 | # Local development of kernels 2 | 3 | ## Introduction 4 | 5 | `kernel-builder` builds kernels in a sandbox. This has various benefits, 6 | such as building kernels for a wide range of Torch versions, compatibility 7 | with older C library versions and avoiding accidental dependencies. 8 | 9 | However, this is not ideal during kernel development, since language 10 | servers and IDEs do not interpret the `build.toml` file. As a result, 11 | code completion will typically not work. `kernel-builder` provides the 12 | `build2cmake` utility to generate CMake files to build native code and 13 | setuptools files for building the kernel as a regular Python package. 14 | Since CMake and setuptools are widely supported by IDEs, this provides 15 | a much-improved development experience. 16 | 17 | ## Installing `build2cmake` 18 | 19 | `build2cmake` is available as a Rust crate. 
After [installing Rust](https://rustup.rs), 20 | it can be built and installed as follows: 21 | 22 | ```bash 23 | $ cargo install build2cmake 24 | ``` 25 | 26 | ## Generating a Python project with `build2cmake` 27 | 28 | `build2cmake` generates a CMake/Python project from a [`build.toml`](./writing-kernels.md) 29 | file. The invocation is as follows: 30 | 31 | ```bash 32 | $ build2cmake generate-torch build.toml -f 33 | ``` 34 | 35 | The `-f` flag is optional and instructs `build2cmake` to overwrite 36 | existing files. 37 | 38 | It is recommended to do an editable install of the generated project into 39 | your Python virtual environment for development: 40 | 41 | ```bash 42 | $ pip install wheel # Needed once to enable bdist_wheel. 43 | $ pip install --no-build-isolation -e . 44 | ``` 45 | 46 | **Warnings:** 47 | 48 | - Kernels built in this way should **not** be published on the Kernel 49 | Hub. They do not fulfill the [kernel requirements](https://github.com/huggingface/kernels/blob/main/docs/kernel-requirements.md). 50 | - Do not add the generated files to Git. `build2cmake` is updated regularly 51 | and you generally want to use files generated by the latest version. 52 | -------------------------------------------------------------------------------- /docs/nix.md: -------------------------------------------------------------------------------- 1 | # Using the kernel builder with Nix 2 | 3 | The kernel builder uses Nix for building kernels. You can build or 4 | run the kernels directly if you have [Nix installed](https://nixos.org/download/) 5 | on your system. On systems without Nix you can use the [Docker](./docker.md) 6 | image, which is a wrapper around Nix. 7 | 8 | ## Getting started 9 | 10 | The easiest way to get all the Nix functionality is by putting a 11 | `flake.nix` in your kernel repository. To do so, copy 12 | [`examples/relu/flake.nix`](../examples/relu/flake.nix) into the 13 | same directory as your `build.toml` file. Then run `nix flake update`.
14 | This generates a `flake.lock` file that pins the kernel builder 15 | and _all_ its transitive dependencies. Commit both `flake.nix` 16 | and `flake.lock` to your repository; this ensures that kernel 17 | builds are reproducible. 18 | 19 | Since the kernel builder depends on many packages (e.g. every supported 20 | PyTorch version), it is recommended to [enable the huggingface cache](https://app.cachix.org/cache/huggingface) 21 | to avoid expensive rebuilds. 22 | 23 | The kernel builder also provides Nix development shells with all Torch 24 | and CUDA/ROCm dependencies needed to develop kernels (see below). If 25 | you want to test your kernels inside a Nix development shell and you 26 | are not using NixOS, [make sure that the CUDA driver is visible](https://danieldk.eu/Nix-CUDA-on-non-NixOS-systems#make-runopengl-driverlib-and-symlink-the-driver-library) to Torch. 27 | 28 | ## Building kernels with Nix 29 | 30 | A kernel that has a `flake.nix` file can be built with `nix build`. 31 | For example: 32 | 33 | ```bash 34 | cd examples/activation 35 | nix build . -L 36 | ``` 37 | 38 | You can put this `flake.nix` in your own kernel's root directory to 39 | add Nix support to your kernel. 40 | 41 | ## Shell for local development 42 | 43 | `kernel-builder` provides shells for developing kernels. In such a shell, 44 | all required dependencies are available, as well as `build2cmake` for generating 45 | project files. For example: 46 | 47 | ```bash 48 | $ nix develop 49 | $ build2cmake generate-torch build.toml 50 | $ cmake -B build-ext 51 | $ cmake --build build-ext 52 | ``` 53 | 54 | If you want to test the kernel as a Python package, you can make a virtual 55 | environment inside the shell: 56 | 57 | ```bash 58 | $ nix develop 59 | $ build2cmake generate-torch build.toml 60 | $ python -m venv .venv 61 | $ source .venv/bin/activate 62 | $ pip install --no-build-isolation -e . 63 | ``` 64 | 65 | Development shells are available for every build configuration.
For 66 | instance, you can get a Torch 2.6 development shell for ROCm extensions 67 | using: 68 | 69 | ```bash 70 | $ nix develop .#devShells.torch26-cxx11-rocm62-x86_64-linux 71 | ``` 72 | 73 | ## Shell for testing a kernel 74 | 75 | You can also start a test shell. This will give you a Python interpreter 76 | with the kernel in Python's search path. This makes it more convenient to run 77 | tests: 78 | 79 | ```bash 80 | cd examples/activation 81 | nix develop -L .#test 82 | python -m pytest tests 83 | ``` 84 | 85 | ## Building a kernel without `flake.nix` 86 | 87 | If a kernel's source directory does not have a `flake.nix` file, you can build the 88 | kernel using the `buildTorchExtensionBundle` function from the kernel builder 89 | itself: 90 | 91 | ```bash 92 | cd examples/activation 93 | nix build --impure --expr 'with import ../..; lib.x86_64-linux.buildTorchExtensionBundle ./.' -L 94 | ``` 95 | -------------------------------------------------------------------------------- /docs/toolchain.md: -------------------------------------------------------------------------------- 1 | # Toolchain 2 | 3 | ## ABI compatibility 4 | 5 | Kernels and kernel extensions typically do not have any explicit external 6 | dependencies, except: 7 | 8 | - The CUDA library version that they were compiled against. 9 | - The Torch version that they were compiled against. 10 | - The C and C++ standard libraries (glibc and libstdc++ on Linux). 11 | - The Python library. 12 | 13 | Of course, the versions on a user's system can differ from the build 14 | system, so we have to account for these dependencies. 15 | 16 | ### Python 17 | 18 | In the case of Python, we use the [limited API](https://docs.python.org/3/c-api/stable.html#limited-c-api). 19 | For the limited API, ABI stability is guaranteed. This rules out the 20 | use of some dependencies like pybind11, but since it 21 | reduces the number of builds drastically, we use it anyway.
22 | 23 | ### CUDA/Torch 24 | 25 | Torch and CUDA only have limited ABI compatibility. Therefore, we 26 | compile extensions for all supported CUDA/Torch combinations. 27 | 28 | ### glibc/libstdc++ 29 | 30 | glibc and libstdc++ use symbol versioning. As a result, a binary or 31 | library compiled against an older version of these libraries works 32 | on newer versions (modulo the C++11 ABI change). It is, however, 33 | not possible to use only older symbol versions when building against 34 | a newer version of these libraries. 35 | 36 | The traditional solution to this problem is to build software in 37 | a container that uses an ancient Linux distribution with old 38 | glibc/libstdc++ versions. 39 | 40 | With Nix we can do better --- with some work we can compile for old 41 | versions of these libraries using a recent nixpkgs. There are several 42 | nuances: 43 | 44 | - libstdc++ is distributed with gcc, so it has the same version as gcc. 45 | - CUDA's nvcc uses a 'backend stdenv'. This stdenv has the latest 46 | gcc that is supported by nvcc. It can differ from the default gcc, 47 | because gcc in nixpkgs is sometimes newer than the version supported 48 | by CUDA. 49 | - gcc also links a binary against libgcc. libgcc must also be compiled 50 | against the target glibc; otherwise, the resulting extensions will 51 | still rely on symbols from newer glibc versions. 52 | 53 | With that in mind, there are (at least?) three ways to do this: 54 | 55 | 1. Override glibc and libstdc++ system-wide using an overlay. 56 | 2. Override the backend stdenv of CUDA with one that has older 57 | library versions. 58 | 3. Only override glibc and libstdc++ through the stdenv for 59 | the kernel/extension packages. 60 | 61 | (1) is the most principled approach -- it guarantees that all packages 62 | use the same library versions, making it impossible for a newer version 63 | to creep in.
Unfortunately, this has many issues: libraries and 64 | derivations from an old nixpkgs simply don't interoperate well with a package set from 2024. 65 | For instance, the build of older glibc versions hangs with GNU 66 | make >= 4.4 due to some dependency cycle. 67 | 68 | Overriding the backend stdenv of CUDA (2) has the issue that some 69 | derivations end up with two glibc versions. E.g. they would build using the 70 | headers of the latest glibc and then try to link against the old glibc. 71 | 72 | Finally, (3) seems to work really well. We build everything except 73 | the kernels and extensions using unmodified nixpkgs. Then we tell nvcc 74 | to use our modified stdenv through CMake. 75 | 76 | To make this possible, we import the glibc and libstdc++ derivations 77 | from an old nixpkgs. We then create an intermediate stdenv to rebuild 78 | gcc/libgcc against the old glibc. Then glibc, libstdc++, and the 79 | rebuilt gcc together form a new stdenv. We also link libstdc++ 80 | statically. 81 | -------------------------------------------------------------------------------- /docs/why-nix.md: -------------------------------------------------------------------------------- 1 | # Why Nix? 2 | 3 | The Kernel Builder project uses Nix to build custom kernels designed specifically for PyTorch. 4 | 5 | Here’s why we chose Nix and why it's particularly suited to our workflow: 6 | 7 | ## 1. Consistent and Reproducible Builds 8 | 9 | Nix guarantees deterministic evaluation, ensuring that every kernel is built identically, regardless of the host environment. This consistency prevents "it works on my machine" problems, making debugging and deployment straightforward. 10 | 11 | ## 2. Simplified Dependency Management 12 | 13 | Compiling PyTorch kernels often involves complex dependencies such as CUDA versions, PyTorch APIs, and various C++ toolchains. Nix explicitly defines and manages these dependencies, eliminating version conflicts and making maintenance easier. 14 | 15 | ## 3.
Declarative Configuration 16 | 17 | Nix’s declarative approach clearly specifies exactly what each kernel build needs. This transparency aids collaboration, speeds up troubleshooting, and makes it easy to document the build process. 18 | 19 | ## 4. Isolated, Reliable Builds 20 | 21 | Each kernel build with Nix runs in a fully isolated sandbox, removing any uncertainty about external state. This isolation ensures clean builds, free of unexpected side effects. 22 | 23 | ## 5. Efficient Caching and CI Integration 24 | 25 | Kernel compilation can be resource-intensive. Nix leverages efficient caching of build artifacts, significantly reducing build times and optimizing continuous integration workflows. 26 | 27 | ## 6. Easy Experimentation and Rollbacks 28 | 29 | Nix allows you to experiment with different kernel configurations, PyTorch versions, or CUDA toolkits easily. If a change introduces an issue, reverting to a previous state is quick and effortless. 30 | 31 | Overall, Nix streamlines the Kernel Builder workflow, allowing us to efficiently and reliably manage complex machine learning kernel builds. 32 | 33 | --- 34 | 35 | If you want to learn more about Nix, check out the following resources: 36 | 37 | ## References 38 | 39 | - **The Official Nix Manual:** 40 | - The definitive source for all things Nix, providing comprehensive coverage of its features, commands, and ecosystem. 41 | - Link: [Nix Manual (nixos.org)](https://nixos.org/manual/nix/stable/) 42 | - **Nix Pills:** 43 | - A series of blog posts breaking down complex Nix concepts into digestible pieces, ideal for a structured, tutorial-style approach. 44 | - Link: [Nix Pills (nixos.org)](https://nixos.org/guides/nix-pills/) 45 | - **nix.dev**: 46 | - Home of official documentation for the Nix ecosystem. 
47 | - Link [nix.dev](https://nix.dev/) 48 | - **NixOS Wiki:** 49 | - A community-driven wiki with a wealth of information, including tips, tricks, and tutorials, covering a wide range of topics, including NixOS-specific information. 50 | - Link: [NixOS Wiki](https://nixos.wiki/wiki/Main_Page) 51 | 52 | -------------------------------------------------------------------------------- /examples/activation/README.md: -------------------------------------------------------------------------------- 1 | ## Activation 2 | 3 | Activation kernels from [vLLM](https://github.com/vllm-project/vllm/blob/main/csrc/activation_kernels.cu). 4 | 5 | Copyright 2023-2024, the vLLM team. 6 | -------------------------------------------------------------------------------- /examples/activation/activation/cuda_compat.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef USE_ROCM 4 | #include 5 | #endif 6 | 7 | #ifndef USE_ROCM 8 | #define WARP_SIZE 32 9 | #else 10 | #define WARP_SIZE warpSize 11 | #endif 12 | 13 | #ifndef USE_ROCM 14 | #define VLLM_LDG(arg) __ldg(arg) 15 | #else 16 | #define VLLM_LDG(arg) *(arg) 17 | #endif 18 | 19 | #ifndef USE_ROCM 20 | #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \ 21 | __shfl_xor_sync(uint32_t(-1), var, lane_mask) 22 | #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ 23 | __shfl_xor_sync(uint32_t(-1), var, lane_mask, width) 24 | #else 25 | #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) 26 | #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ 27 | __shfl_xor(var, lane_mask, width) 28 | #endif 29 | 30 | #ifndef USE_ROCM 31 | #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane) 32 | #else 33 | #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) 34 | #endif 35 | 36 | #ifndef USE_ROCM 37 | #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \ 38 | __shfl_down_sync(uint32_t(-1), var, lane_delta) 39 | #else 40 | #define 
VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta) 41 | #endif 42 | 43 | #ifndef USE_ROCM 44 | #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ 45 | cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) 46 | #else 47 | #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ 48 | hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) 49 | #endif 50 | -------------------------------------------------------------------------------- /examples/activation/activation/dispatch_utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Adapted from 3 | * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h 4 | */ 5 | #pragma once 6 | 7 | #include 8 | 9 | #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ 10 | AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ 11 | AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ 12 | AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) 13 | 14 | #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ 15 | AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) 16 | 17 | #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ 18 | AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ 19 | AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ 20 | AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ 21 | AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) 22 | 23 | #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ 24 | AT_DISPATCH_SWITCH(TYPE, NAME, \ 25 | VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) 26 | 27 | #define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) 
\ 28 | AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ 29 | AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ 30 | AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ 31 | AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ 32 | AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) 33 | 34 | #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ 35 | AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) 36 | -------------------------------------------------------------------------------- /examples/activation/build.toml: -------------------------------------------------------------------------------- 1 | [general] 2 | name = "activation" 3 | universal = false 4 | 5 | [torch] 6 | src = [ 7 | "torch-ext/torch_binding.cpp", 8 | "torch-ext/torch_binding.h", 9 | ] 10 | 11 | [kernel.activation] 12 | backend = "cuda" 13 | depends = ["torch"] 14 | src = [ 15 | "activation/activation_kernels.cu", 16 | "activation/cuda_compat.h", 17 | "activation/dispatch_utils.h", 18 | ] 19 | -------------------------------------------------------------------------------- /examples/activation/flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | description = "Flake for activation kernels"; 3 | 4 | inputs = { 5 | kernel-builder.url = "path:../.."; 6 | }; 7 | 8 | outputs = 9 | { 10 | self, 11 | kernel-builder, 12 | }: 13 | 14 | kernel-builder.lib.genFlakeOutputs { 15 | path = ./.; 16 | rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate; 17 | }; 18 | } 19 | -------------------------------------------------------------------------------- /examples/activation/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/kernel-builder/924bce43ecb17b61e1cfa928443a5bed2f3fbc3c/examples/activation/tests/__init__.py -------------------------------------------------------------------------------- 
/examples/activation/tests/kernels/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/kernel-builder/924bce43ecb17b61e1cfa928443a5bed2f3fbc3c/examples/activation/tests/kernels/__init__.py -------------------------------------------------------------------------------- /examples/activation/tests/kernels/allclose_default.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # Reference default values of atol and rtol are from 4 | # https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67 5 | default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5} 6 | default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float: 1.3e-6} 7 | 8 | 9 | def get_default_atol(output) -> float: 10 | return default_atol[output.dtype] 11 | 12 | 13 | def get_default_rtol(output) -> float: 14 | return default_rtol[output.dtype] 15 | -------------------------------------------------------------------------------- /examples/activation/tests/kernels/test_activation.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | from typing import Type 4 | 5 | import activation 6 | import pytest 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from .utils import opcheck 11 | from .allclose_default import get_default_atol, get_default_rtol 12 | 13 | DTYPES = [torch.half, torch.bfloat16, torch.float] 14 | NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing 15 | D = [512, 13824] # Arbitrary values for testing 16 | SEEDS = [0] 17 | CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] 18 | 19 | 20 | def gelu_fast(x: torch.Tensor) -> torch.Tensor: 21 | return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))) 22 | 23 | 24 | def gelu_new(x: 
torch.Tensor) -> torch.Tensor: 25 | c = math.sqrt(2.0 / math.pi) 26 | return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0)))) 27 | 28 | 29 | def gelu_quick(x: torch.Tensor) -> torch.Tensor: 30 | return x * torch.sigmoid(1.702 * x) 31 | 32 | 33 | def fatrelu_and_mul(x: torch.Tensor, threshold: float) -> torch.Tensor: 34 | d = x.shape[-1] // 2 35 | x1 = x[..., :d] 36 | x2 = x[..., d:] 37 | x1 = F.threshold(x1, threshold, 0.0) 38 | return x1 * x2 39 | 40 | 41 | def silu_and_mul(x: torch.Tensor) -> torch.Tensor: 42 | d = x.shape[-1] // 2 43 | return F.silu(x[..., :d]) * x[..., d:] 44 | 45 | 46 | def gelu_and_mul(x: torch.Tensor, approximate: str) -> torch.Tensor: 47 | d = x.shape[-1] // 2 48 | return F.gelu(x[..., :d], approximate=approximate) * x[..., d:] 49 | 50 | 51 | @pytest.mark.parametrize("activation_name", ["silu", "gelu", "gelu_tanh", "fatrelu"]) 52 | @pytest.mark.parametrize("num_tokens", NUM_TOKENS) 53 | @pytest.mark.parametrize("d", D) 54 | @pytest.mark.parametrize("dtype", DTYPES) 55 | @pytest.mark.parametrize("seed", SEEDS) 56 | @pytest.mark.parametrize("device", CUDA_DEVICES) 57 | @torch.inference_mode() 58 | def test_act_and_mul( 59 | activation_name: str, 60 | num_tokens: int, 61 | d: int, 62 | dtype: torch.dtype, 63 | seed: int, 64 | device: str, 65 | ) -> None: 66 | random.seed(seed) 67 | torch.manual_seed(seed) 68 | torch.set_default_device(device) 69 | x = torch.randn(num_tokens, 2 * d, dtype=dtype) 70 | if activation_name == "silu": 71 | torch_fn = silu_and_mul 72 | fn = activation.silu_and_mul 73 | op = activation.ops.silu_and_mul 74 | elif activation_name == "gelu": 75 | torch_fn = lambda x: gelu_and_mul(x, "none") 76 | fn = activation.gelu_and_mul 77 | op = activation.ops.gelu_and_mul 78 | elif activation_name == "gelu_tanh": 79 | torch_fn = lambda x: gelu_and_mul(x, "tanh") 80 | fn = activation.gelu_tanh_and_mul 81 | op = activation.ops.gelu_tanh_and_mul 82 | elif activation_name == "fatrelu": 83 | threshold = 
random.uniform(0, 1) 84 | torch_fn = lambda x: fatrelu_and_mul(x, threshold) 85 | fn = lambda out, x: activation.fatrelu_and_mul(out, x, threshold) 86 | op = activation.ops.fatrelu_and_mul 87 | 88 | out_shape = x.shape[:-1] + (x.shape[-1] // 2,) 89 | out = torch.empty(out_shape, dtype=x.dtype, device=x.device) 90 | out = fn(out, x) 91 | ref_out = torch_fn(x) 92 | 93 | # The SiLU, GELU and FatReLU implementations are equivalent to the native 94 | # PyTorch implementations, so we can do exact comparison. 95 | torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0) 96 | 97 | d = x.shape[-1] // 2 98 | output_shape = x.shape[:-1] + (d,) 99 | out = torch.empty(output_shape, dtype=x.dtype, device=x.device) 100 | if activation_name == "fatrelu": 101 | opcheck(op, (out, x, threshold)) 102 | else: 103 | opcheck(op, (out, x)) 104 | 105 | 106 | @pytest.mark.parametrize( 107 | "activation_fns", 108 | [ 109 | (gelu_fast, activation.gelu_fast, activation.ops.gelu_fast), 110 | (gelu_new, activation.gelu_new, activation.ops.gelu_new), 111 | (gelu_quick, activation.gelu_quick, activation.ops.gelu_quick), 112 | ], 113 | ) 114 | @pytest.mark.parametrize("num_tokens", NUM_TOKENS) 115 | @pytest.mark.parametrize("d", D) 116 | @pytest.mark.parametrize("dtype", DTYPES) 117 | @pytest.mark.parametrize("seed", SEEDS) 118 | @pytest.mark.parametrize("device", CUDA_DEVICES) 119 | @torch.inference_mode() 120 | def test_activation( 121 | activation_fns, 122 | num_tokens: int, 123 | d: int, 124 | dtype: torch.dtype, 125 | seed: int, 126 | device: str, 127 | ) -> None: 128 | torch.manual_seed(seed) 129 | torch.set_default_device(device) 130 | x = torch.randn(num_tokens, d, dtype=dtype) 131 | torch_fn, fn, op = activation_fns 132 | out = fn(torch.empty_like(x), x) 133 | ref_out = torch_fn(x) 134 | torch.testing.assert_close( 135 | out, ref_out, atol=get_default_atol(out), rtol=get_default_rtol(out) 136 | ) 137 | 138 | out = torch.empty_like(x) 139 | opcheck(op, (out, x)) 140 | 
-------------------------------------------------------------------------------- /examples/activation/tests/kernels/utils.py: -------------------------------------------------------------------------------- 1 | """Kernel test utils""" 2 | 3 | import itertools 4 | import random 5 | import unittest 6 | from numbers import Number 7 | from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union 8 | 9 | import pytest 10 | import torch 11 | from torch._prims_common import TensorLikeType 12 | 13 | # For now, disable "test_aot_dispatch_dynamic" since there are some 14 | # bugs related to this test in PyTorch 2.4. 15 | DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = ( 16 | "test_schema", 17 | "test_autograd_registration", 18 | "test_faketensor", 19 | ) 20 | 21 | ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = ( 22 | "test_schema", 23 | "test_autograd_registration", 24 | "test_faketensor", 25 | "test_aot_dispatch_dynamic", 26 | ) 27 | 28 | 29 | # Copied/modified from torch._refs.__init__.py 30 | def fp8_allclose( 31 | a: TensorLikeType, 32 | b: TensorLikeType, 33 | rtol: float = 1e-05, 34 | atol: float = 1e-08, 35 | equal_nan: bool = False, 36 | ) -> bool: 37 | """ 38 | Reference implementation of torch.allclose 39 | """ 40 | torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol) 41 | 42 | return bool( 43 | torch.all( 44 | torch.isclose( 45 | a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan 46 | ) 47 | ).item() 48 | ) 49 | 50 | 51 | # A special version of op check that has a restricted default set of test_utils 52 | # and a patched version of allclose that supports fp8 types. 
53 | def opcheck( 54 | op: Union[ 55 | torch._ops.OpOverload, 56 | torch._ops.OpOverloadPacket, 57 | torch._library.custom_ops.CustomOpDef, 58 | ], 59 | args: Tuple[Any, ...], 60 | kwargs: Optional[Dict[str, Any]] = None, 61 | *, 62 | test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS, 63 | raise_exception: bool = True, 64 | cond: bool = True 65 | ) -> Dict[str, str]: 66 | with unittest.mock.patch("torch.allclose", new=fp8_allclose): 67 | return ( 68 | torch.library.opcheck( 69 | op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception 70 | ) 71 | if cond 72 | else {} 73 | ) 74 | -------------------------------------------------------------------------------- /examples/activation/torch-ext/activation/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | try: 4 | from ._ops import ops 5 | except ImportError as e: 6 | # Fallback for local development. 7 | try: 8 | import _activation 9 | 10 | ops = torch.ops._activition 11 | except ImportError: 12 | raise e 13 | 14 | 15 | def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: 16 | ops.silu_and_mul(out, x) 17 | return out 18 | 19 | 20 | def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: 21 | ops.gelu_and_mul(out, x) 22 | return out 23 | 24 | 25 | def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: 26 | ops.gelu_tanh_and_mul(out, x) 27 | return out 28 | 29 | 30 | def fatrelu_and_mul(out: torch.Tensor, x: torch.Tensor, threshold: float = 0.0) -> None: 31 | ops.fatrelu_and_mul(out, x, threshold) 32 | return out 33 | 34 | 35 | def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: 36 | ops.gelu_fast(out, x) 37 | return out 38 | 39 | 40 | def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None: 41 | ops.gelu_new(out, x) 42 | return out 43 | 44 | 45 | def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None: 46 | ops.gelu_quick(out, x) 47 | return out 48 | 
-------------------------------------------------------------------------------- /examples/activation/torch-ext/torch_binding.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "registration.h" 4 | #include "torch_binding.h" 5 | 6 | TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { 7 | // Activation ops 8 | // Activation function used in SwiGLU. 9 | ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); 10 | ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); 11 | 12 | // Activation function used in GeGLU with `none` approximation. 13 | ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); 14 | ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); 15 | 16 | // Activation function used in GeGLU with `tanh` approximation. 17 | ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); 18 | ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); 19 | 20 | // FATReLU implementation. 21 | ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()"); 22 | ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul); 23 | 24 | // GELU implementation used in GPT-2. 25 | ops.def("gelu_new(Tensor! out, Tensor input) -> ()"); 26 | ops.impl("gelu_new", torch::kCUDA, &gelu_new); 27 | 28 | // Approximate GELU implementation. 29 | ops.def("gelu_fast(Tensor! out, Tensor input) -> ()"); 30 | ops.impl("gelu_fast", torch::kCUDA, &gelu_fast); 31 | 32 | // Quick GELU implementation. 33 | ops.def("gelu_quick(Tensor! 
out, Tensor input) -> ()"); 34 | ops.impl("gelu_quick", torch::kCUDA, &gelu_quick); 35 | } 36 | 37 | REGISTER_EXTENSION(TORCH_EXTENSION_NAME) 38 | -------------------------------------------------------------------------------- /examples/activation/torch-ext/torch_binding.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | void silu_and_mul(torch::Tensor &out, torch::Tensor &input); 6 | 7 | void gelu_and_mul(torch::Tensor &out, torch::Tensor &input); 8 | 9 | void gelu_tanh_and_mul(torch::Tensor &out, torch::Tensor &input); 10 | 11 | void fatrelu_and_mul(torch::Tensor &out, torch::Tensor &input, 12 | double threshold); 13 | 14 | void gelu_new(torch::Tensor &out, torch::Tensor &input); 15 | 16 | void gelu_fast(torch::Tensor &out, torch::Tensor &input); 17 | 18 | void gelu_quick(torch::Tensor &out, torch::Tensor &input); 19 | -------------------------------------------------------------------------------- /examples/cutlass-gemm/build.toml: -------------------------------------------------------------------------------- 1 | [general] 2 | name = "cutlass_gemm" 3 | universal = false 4 | 5 | [torch] 6 | src = [ 7 | "torch-ext/torch_binding.cpp", 8 | "torch-ext/torch_binding.h", 9 | ] 10 | 11 | [kernel.gemm] 12 | backend = "cuda" 13 | depends = [ 14 | "torch", 15 | "cutlass_3_6", 16 | ] 17 | src = ["gemm.cu"] 18 | -------------------------------------------------------------------------------- /examples/cutlass-gemm/flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | description = "Flake for CUTLASS gemm test kernel"; 3 | 4 | inputs = { 5 | kernel-builder.url = "path:../.."; 6 | }; 7 | 8 | outputs = 9 | { 10 | self, 11 | kernel-builder, 12 | }: 13 | kernel-builder.lib.genFlakeOutputs { 14 | path = ./.; 15 | rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate; 16 | }; 17 | } 18 | 
-------------------------------------------------------------------------------- /examples/cutlass-gemm/gemm.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | void cutlass_gemm(torch::Tensor &out, torch::Tensor const &A, torch::Tensor const &B) { 5 | TORCH_CHECK(A.device().is_cuda(), "A must be a CUDA tensor"); 6 | TORCH_CHECK(B.device().is_cuda(), "B must be a CUDA tensor"); 7 | TORCH_CHECK(out.device().is_cuda(), "out must be a CUDA tensor"); 8 | 9 | TORCH_CHECK(A.is_contiguous(), "A must be contiguous"); 10 | TORCH_CHECK(B.is_contiguous(), "B must be contiguous"); 11 | TORCH_CHECK(out.is_contiguous(), "out must be contiguous"); 12 | 13 | // Only float32 is supported by this GEMM instantiation; reject other dtypes up front. 14 | TORCH_CHECK(A.scalar_type() == at::ScalarType::Float, "A must be a float32 tensor"); 15 | TORCH_CHECK(B.scalar_type() == at::ScalarType::Float, "B must be a float32 tensor"); 16 | TORCH_CHECK(out.scalar_type() == at::ScalarType::Float, "out must be a float32 tensor"); 17 | 18 | // The problem size below is derived from A and B only, so verify that all shapes agree before passing raw pointers on. 19 | TORCH_CHECK(A.dim() == 2 && B.dim() == 2 && out.dim() == 2, "A, B, and out must be 2D tensors"); 20 | TORCH_CHECK(A.size(1) == B.size(0), "A and B have incompatible shapes: ", A.sizes(), " and ", B.sizes()); 21 | TORCH_CHECK(out.size(0) == A.size(0) && out.size(1) == B.size(1), "out must have shape [", A.size(0), ", ", B.size(1), "], got ", out.sizes()); 22 | TORCH_CHECK(A.device() == B.device() && A.device() == out.device(), "A, B, and out must be on the same device"); 23 | 24 | // Define the GEMM operation 25 | using Gemm = cutlass::gemm::device::Gemm; 28 | 29 | // Create a GEMM object 30 | Gemm gemm_op; 31 | 32 | // Define the problem size 33 | cutlass::gemm::GemmCoord problem_size(A.size(0), B.size(1), A.size(1)); 34 | 35 | // Define the arguments for the GEMM operation 36 | typename Gemm::Arguments args( 37 | problem_size, 38 | {A.data_ptr(), A.size(1)}, 39 | {B.data_ptr(), B.size(1)}, 40 | {out.data_ptr(), out.size(1)}, 41 | {out.data_ptr(), out.size(1)}, 42 | {1.0f, 0.0f} 43 | ); 44 | 45 | // Launch the GEMM operation 46 | cutlass::Status status = gemm_op(args); 47 | 48 | // Check for errors 49 | if (status != cutlass::Status::kSuccess) { 50 | throw std::runtime_error("CUTLASS GEMM operation failed"); 51 | } 52 | } 53 | 54 | 55 | -------------------------------------------------------------------------------- /examples/cutlass-gemm/tests/test_gemm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from cutlass_gemm import cutlass_gemm 4 | 5 | def test_gemm(): 6 | A = torch.randn((10, 20), device="cuda", dtype=torch.float32) 7 | B = torch.randn((20, 30), device="cuda", dtype=torch.float32) 8 | out = torch.randn((10, 30), device="cuda", dtype=torch.float32) 9 | 10 |
cutlass_gemm(out, A, B) 11 | 12 | torch.testing.assert_close(out, torch.mm(A, B)) 13 | -------------------------------------------------------------------------------- /examples/cutlass-gemm/torch-ext/cutlass_gemm/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ._ops import ops 4 | 5 | def cutlass_gemm(out: torch.Tensor, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: 6 | ops.cutlass_gemm(out, A, B) 7 | return out 8 | -------------------------------------------------------------------------------- /examples/cutlass-gemm/torch-ext/registration.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #define _CONCAT(A, B) A##B 6 | #define CONCAT(A, B) _CONCAT(A, B) 7 | 8 | #define _STRINGIFY(A) #A 9 | #define STRINGIFY(A) _STRINGIFY(A) 10 | 11 | // A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME 12 | // could be a macro instead of a literal token. 13 | #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) 14 | 15 | // A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME 16 | // could be a macro instead of a literal token. 17 | #define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \ 18 | TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE) 19 | 20 | // REGISTER_EXTENSION allows the shared library to be loaded and initialized 21 | // via python's import statement.
22 | #define REGISTER_EXTENSION(NAME) \ 23 | PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ 24 | static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \ 25 | STRINGIFY(NAME), nullptr, 0, nullptr}; \ 26 | return PyModule_Create(&module); \ 27 | } 28 | -------------------------------------------------------------------------------- /examples/cutlass-gemm/torch-ext/torch_binding.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "registration.h" 4 | #include "torch_binding.h" 5 | 6 | TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { 7 | ops.def("cutlass_gemm(Tensor! out, Tensor A, Tensor B) -> ()"); 8 | ops.impl("cutlass_gemm", torch::kCUDA, &cutlass_gemm); 9 | } 10 | 11 | REGISTER_EXTENSION(TORCH_EXTENSION_NAME) 12 | -------------------------------------------------------------------------------- /examples/cutlass-gemm/torch-ext/torch_binding.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | void cutlass_gemm(torch::Tensor &out, torch::Tensor const &A, torch::Tensor const &B); 6 | -------------------------------------------------------------------------------- /examples/relu/build.toml: -------------------------------------------------------------------------------- 1 | [general] 2 | name = "relu" 3 | universal = false 4 | 5 | [torch] 6 | src = [ 7 | "torch-ext/torch_binding.cpp", 8 | "torch-ext/torch_binding.h", 9 | ] 10 | 11 | [kernel.activation] 12 | backend = "cuda" 13 | depends = ["torch"] 14 | src = ["relu_cuda/relu.cu"] 15 | 16 | [kernel.activation_metal] 17 | backend = "metal" 18 | src = [ 19 | "relu_metal/relu.mm", 20 | ] 21 | depends = [ "torch" ] 22 | 23 | [kernel.activation_rocm] 24 | backend = "rocm" 25 | rocm-archs = [ 26 | "gfx906", 27 | "gfx908", 28 | "gfx90a", 29 | "gfx940", 30 | "gfx941", 31 | "gfx942", 32 | "gfx1030", 33 | "gfx1100", 34 | "gfx1101", 35 | ] 36 | depends = ["torch"] 37 | src = 
["relu_cuda/relu.cu"] 38 | -------------------------------------------------------------------------------- /examples/relu/flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | description = "Flake for ReLU kernel"; 3 | 4 | inputs = { 5 | kernel-builder.url = "path:../.."; 6 | }; 7 | 8 | outputs = 9 | { 10 | self, 11 | kernel-builder, 12 | }: 13 | kernel-builder.lib.genFlakeOutputs { 14 | path = ./.; 15 | rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate; 16 | }; 17 | } 18 | -------------------------------------------------------------------------------- /examples/relu/relu_cuda/relu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include 6 | 7 | __global__ void relu_kernel(float *__restrict__ out, 8 | float const *__restrict__ input, const int d) { 9 | const int64_t token_idx = blockIdx.x; 10 | for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { 11 | auto x = input[token_idx * d + idx]; 12 | out[token_idx * d + idx] = x > 0.0f ? x : 0.0f; 13 | } 14 | } 15 | 16 | void relu(torch::Tensor &out, torch::Tensor const &input) { 17 | TORCH_CHECK(input.device().is_cuda(), "input must be a CUDA tensor"); 18 | TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); 19 | TORCH_CHECK(input.scalar_type() == at::ScalarType::Float && 20 | input.scalar_type() == at::ScalarType::Float, 21 | "relu_kernel only supports float32"); 22 | 23 | TORCH_CHECK(input.sizes() == out.sizes(), 24 | "Tensors must have the same shape. Got input shape: ", 25 | input.sizes(), " and output shape: ", out.sizes()); 26 | 27 | TORCH_CHECK(input.scalar_type() == out.scalar_type(), 28 | "Tensors must have the same data type. Got input dtype: ", 29 | input.scalar_type(), " and output dtype: ", out.scalar_type()); 30 | 31 | TORCH_CHECK(input.device() == out.device(), 32 | "Tensors must be on the same device. 
Got input device: ", 33 | input.device(), " and output device: ", out.device()); 34 | if (input.numel() == 0) return; // empty tensor: nothing to do, and avoids dividing by d == 0 or launching an empty grid 35 | int d = input.size(-1); 36 | int64_t num_tokens = input.numel() / d; 37 | dim3 grid(num_tokens); 38 | dim3 block(std::min(d, 1024)); 39 | const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); 40 | const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 41 | relu_kernel<<>>(out.data_ptr(), 42 | input.data_ptr(), d); 43 | } 44 | -------------------------------------------------------------------------------- /examples/relu/relu_metal/relu.mm: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #import 4 | #import 5 | #include 6 | 7 | char const *CUSTOM_KERNEL = R"( 8 | #include 9 | using namespace metal; 10 | 11 | kernel void relu_forward_kernel_float(device const float *inA [[buffer(0)]], 12 | device float *outC [[buffer(1)]], 13 | uint index [[thread_position_in_grid]]) { 14 | // Explicitly write to output 15 | outC[index] = max(0.0f, inA[index]); 16 | } 17 | 18 | kernel void relu_forward_kernel_half(device const half *inA [[buffer(0)]], 19 | device half *outC [[buffer(1)]], 20 | uint index [[thread_position_in_grid]]) { 21 | // Explicitly write to output 22 | outC[index] = max(static_cast(0.0), inA[index]); 23 | } 24 | )"; 25 | 26 | static inline id getMTLBufferStorage(const torch::Tensor &tensor) { 27 | return __builtin_bit_cast(id, tensor.storage().data()); 28 | } 29 | 30 | torch::Tensor &dispatchReluKernel(torch::Tensor const &input, 31 | torch::Tensor &output) { 32 | @autoreleasepool { 33 | id device = MTLCreateSystemDefaultDevice(); 34 | NSError *error = nil; 35 | 36 | int numThreads = input.numel(); 37 | 38 | id customKernelLibrary = [device 39 | newLibraryWithSource:[NSString stringWithUTF8String:CUSTOM_KERNEL] 40 | options:nil 41 | error:&error]; 42 | TORCH_CHECK(customKernelLibrary, 43 | "Failed to create custom kernel library, error: ", 44 | error.localizedDescription.UTF8String); 45 | 46 |
std::string kernel_name = 47 | std::string("relu_forward_kernel_") + 48 | (input.scalar_type() == torch::kFloat ? "float" : "half"); 49 | id customReluFunction = [customKernelLibrary 50 | newFunctionWithName:[NSString 51 | stringWithUTF8String:kernel_name.c_str()]]; 52 | TORCH_CHECK(customReluFunction, 53 | "Failed to create function state object for ", 54 | kernel_name.c_str()); 55 | 56 | id reluPSO = 57 | [device newComputePipelineStateWithFunction:customReluFunction 58 | error:&error]; 59 | TORCH_CHECK(reluPSO, error.localizedDescription.UTF8String); 60 | 61 | id commandBuffer = torch::mps::get_command_buffer(); 62 | TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference"); 63 | 64 | dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue(); 65 | 66 | dispatch_sync(serialQueue, ^() { 67 | id computeEncoder = 68 | [commandBuffer computeCommandEncoder]; 69 | TORCH_CHECK(computeEncoder, "Failed to create compute command encoder"); 70 | 71 | [computeEncoder setComputePipelineState:reluPSO]; 72 | [computeEncoder setBuffer:getMTLBufferStorage(input) 73 | offset:input.storage_offset() * input.element_size() 74 | atIndex:0]; 75 | [computeEncoder setBuffer:getMTLBufferStorage(output) 76 | offset:output.storage_offset() * output.element_size() 77 | atIndex:1]; 78 | 79 | MTLSize gridSize = MTLSizeMake(numThreads, 1, 1); 80 | 81 | NSUInteger threadGroupSize = reluPSO.maxTotalThreadsPerThreadgroup; 82 | if (threadGroupSize > numThreads) { 83 | threadGroupSize = numThreads; 84 | } 85 | MTLSize threadgroupSize = MTLSizeMake(threadGroupSize, 1, 1); 86 | 87 | [computeEncoder dispatchThreads:gridSize 88 | threadsPerThreadgroup:threadgroupSize]; 89 | 90 | [computeEncoder endEncoding]; 91 | 92 | torch::mps::commit(); 93 | }); 94 | } 95 | 96 | return output; 97 | } 98 | 99 | void relu(torch::Tensor &out, const torch::Tensor &input) { 100 | TORCH_CHECK(input.device().is_mps(), "input must be a MPS tensor"); 101 | TORCH_CHECK(input.is_contiguous(), "input must 
be contiguous"); 102 | TORCH_CHECK(input.scalar_type() == torch::kFloat || 103 | input.scalar_type() == torch::kHalf, 104 | "Unsupported data type: ", input.scalar_type()); 105 | 106 | TORCH_CHECK(input.sizes() == out.sizes(), 107 | "Tensors must have the same shape. Got input shape: ", 108 | input.sizes(), " and output shape: ", out.sizes()); 109 | 110 | TORCH_CHECK(input.scalar_type() == out.scalar_type(), 111 | "Tensors must have the same data type. Got input dtype: ", 112 | input.scalar_type(), " and output dtype: ", out.scalar_type()); 113 | 114 | TORCH_CHECK(input.device() == out.device(), 115 | "Tensors must be on the same device. Got input device: ", 116 | input.device(), " and output device: ", out.device()); 117 | if (input.numel() == 0) return; // empty tensor: a zero-thread dispatch would request a zero-width threadgroup 118 | dispatchReluKernel(input, out); 119 | } 120 | -------------------------------------------------------------------------------- /examples/relu/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/kernel-builder/924bce43ecb17b61e1cfa928443a5bed2f3fbc3c/examples/relu/tests/__init__.py -------------------------------------------------------------------------------- /examples/relu/tests/test_relu.py: -------------------------------------------------------------------------------- 1 | import platform 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import relu 7 | 8 | 9 | def test_relu(): 10 | if platform.system() == "Darwin": 11 | device = torch.device("mps") 12 | else: 13 | device = torch.device("cuda") 14 | x = torch.randn(1024, 1024, dtype=torch.float32, device=device) 15 | torch.testing.assert_close(relu.relu(x), F.relu(x)) 16 | -------------------------------------------------------------------------------- /examples/relu/torch-ext/relu/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | 5 | from ._ops import ops 6 | 7 | 8 | def
relu(x: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor: 9 | if out is None: 10 | out = torch.empty_like(x) 11 | ops.relu(out, x) 12 | return out 13 | -------------------------------------------------------------------------------- /examples/relu/torch-ext/torch_binding.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "registration.h" 4 | #include "torch_binding.h" 5 | 6 | TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { 7 | ops.def("relu(Tensor! out, Tensor input) -> ()"); 8 | #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL) 9 | ops.impl("relu", torch::kCUDA, &relu); 10 | #elif defined(METAL_KERNEL) 11 | ops.impl("relu", torch::kMPS, relu); 12 | #endif 13 | } 14 | 15 | REGISTER_EXTENSION(TORCH_EXTENSION_NAME) 16 | -------------------------------------------------------------------------------- /examples/relu/torch-ext/torch_binding.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | void relu(torch::Tensor &out, torch::Tensor const &input); 6 | -------------------------------------------------------------------------------- /examples/silu-and-mul-universal/build.toml: -------------------------------------------------------------------------------- 1 | [general] 2 | name = "silu_and_mul_universal" 3 | universal = true 4 | -------------------------------------------------------------------------------- /examples/silu-and-mul-universal/flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | description = "Flake for kernels tests"; 3 | 4 | inputs = { 5 | kernel-builder.url = "path:../.."; 6 | }; 7 | 8 | outputs = 9 | { 10 | self, 11 | kernel-builder, 12 | }: 13 | kernel-builder.lib.genFlakeOutputs { 14 | path = ./.; 15 | rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate; 16 | }; 17 | } 18 | 
-------------------------------------------------------------------------------- /examples/silu-and-mul-universal/tests/test_silu_and_mul.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.library import opcheck 5 | 6 | from silu_and_mul_universal import ops, silu_and_mul 7 | 8 | 9 | def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor: 10 | d = x.shape[-1] // 2 11 | return F.silu(x[..., :d]) * x[..., d:] 12 | 13 | 14 | @pytest.mark.parametrize("device", ["cpu", "cuda"]) 15 | @pytest.mark.parametrize("requires_grad", [False, True]) 16 | @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) 17 | def test_opcheck(device, requires_grad, dtype): 18 | torch.manual_seed(42) 19 | x = torch.randn(32, 128, device=device, requires_grad=requires_grad, dtype=dtype) 20 | opcheck(ops.silu_and_mul, (x,)) 21 | 22 | 23 | @pytest.mark.parametrize("device", ["cpu", "cuda"]) 24 | @pytest.mark.parametrize("requires_grad", [False, True]) 25 | # Only do float32, the numerical instabilities of float16 and bfloat16 26 | # are too large with the different orderings of computing the gradients. 
27 | @pytest.mark.parametrize("dtype", [torch.float32]) 28 | def test_silu_and_mul(device, requires_grad, dtype): 29 | torch.manual_seed(42) 30 | x_ref = torch.randn( 31 | 32, 128, device=device, requires_grad=requires_grad, dtype=dtype 32 | ) 33 | x = torch.empty(32, 128, device=device, requires_grad=requires_grad, dtype=dtype) 34 | with torch.no_grad(): 35 | x.copy_(x_ref) 36 | 37 | y_ref = silu_and_mul_ref(x_ref) 38 | y = silu_and_mul(x) 39 | 40 | torch.testing.assert_close(y_ref, y) 41 | 42 | if requires_grad: 43 | d_y = torch.randn((32, 64), device=device, dtype=dtype) 44 | y_ref.backward(d_y) 45 | y.backward(d_y) 46 | torch.testing.assert_close(x_ref.grad, x.grad) 47 | -------------------------------------------------------------------------------- /examples/silu-and-mul-universal/torch-ext/silu_and_mul_universal/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ._ops import ops 4 | from .silu_and_mul import _silu_and_mul 5 | 6 | 7 | def silu_and_mul(x: torch.Tensor) -> torch.Tensor: 8 | return ops.silu_and_mul(x) 9 | 10 | 11 | __all__ = ["silu_and_mul"] 12 | -------------------------------------------------------------------------------- /examples/silu-and-mul-universal/torch-ext/silu_and_mul_universal/silu_and_mul.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from ._ops import add_op_namespace_prefix 5 | 6 | 7 | @torch.library.custom_op(add_op_namespace_prefix("silu_and_mul"), mutates_args=()) 8 | def _silu_and_mul(x: torch.Tensor) -> torch.Tensor: 9 | d = x.shape[-1] // 2 10 | return F.silu(x[..., :d]) * x[..., d:] 11 | 12 | 13 | def backward(ctx, grad_output): 14 | x = ctx.saved_tensors[0] 15 | d = x.shape[-1] // 2 16 | x1, x2 = x[..., :d], x[..., d:] 17 | sigmoid_x1 = torch.sigmoid(x1) 18 | silu_x1 = F.silu(x1) 19 | dsilu_dx1 = sigmoid_x1 + silu_x1 * (1 - sigmoid_x1) 20 | dx1 = 
grad_output * x2 * dsilu_dx1 21 | dx2 = grad_output * silu_x1 22 | return torch.cat([dx1, dx2], dim=-1) 23 | 24 | 25 | def setup_context(ctx, inputs, output): 26 | (x,) = inputs 27 | ctx.save_for_backward(x) 28 | 29 | 30 | _silu_and_mul.register_autograd(backward, setup_context=setup_context) 31 | 32 | # Fake (shape-only) implementation: like _silu_and_mul, keep every leading dim and halve the last one, for inputs of any rank. 33 | @_silu_and_mul.register_fake 34 | def _(x: torch.Tensor) -> torch.Tensor: 35 | return x.new_empty(*x.shape[:-1], x.shape[-1] // 2) 36 | -------------------------------------------------------------------------------- /flake.lock: -------------------------------------------------------------------------------- 1 | { 2 | "nodes": { 3 | "flake-compat": { 4 | "locked": { 5 | "lastModified": 1747046372, 6 | "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", 7 | "owner": "edolstra", 8 | "repo": "flake-compat", 9 | "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", 10 | "type": "github" 11 | }, 12 | "original": { 13 | "owner": "edolstra", 14 | "repo": "flake-compat", 15 | "type": "github" 16 | } 17 | }, 18 | "flake-compat_2": { 19 | "locked": { 20 | "lastModified": 1733328505, 21 | "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=", 22 | "owner": "edolstra", 23 | "repo": "flake-compat", 24 | "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec", 25 | "type": "github" 26 | }, 27 | "original": { 28 | "owner": "edolstra", 29 | "repo": "flake-compat", 30 | "type": "github" 31 | } 32 | }, 33 | "flake-utils": { 34 | "inputs": { 35 | "systems": "systems" 36 | }, 37 | "locked": { 38 | "lastModified": 1731533236, 39 | "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", 40 | "owner": "numtide", 41 | "repo": "flake-utils", 42 | "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", 43 | "type": "github" 44 | }, 45 | "original": { 46 | "owner": "numtide", 47 | "repo": "flake-utils", 48 | "type": "github" 49 | } 50 | }, 51 | "flake-utils_2": { 52 | "inputs": { 53 | "systems": "systems_2" 54 | }, 55 | "locked": { 56 | "lastModified": 1731533236, 57 |
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", 58 | "owner": "numtide", 59 | "repo": "flake-utils", 60 | "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", 61 | "type": "github" 62 | }, 63 | "original": { 64 | "owner": "numtide", 65 | "repo": "flake-utils", 66 | "type": "github" 67 | } 68 | }, 69 | "hf-nix": { 70 | "inputs": { 71 | "flake-compat": "flake-compat_2", 72 | "flake-utils": "flake-utils_2", 73 | "nixpkgs": "nixpkgs" 74 | }, 75 | "locked": { 76 | "lastModified": 1748598786, 77 | "owner": "huggingface", 78 | "repo": "hf-nix", 79 | "rev": "6ca679441494139fde1f2355691ddb5dc8170269", 80 | "type": "github" 81 | }, 82 | "original": { 83 | "owner": "huggingface", 84 | "repo": "hf-nix", 85 | "type": "github" 86 | } 87 | }, 88 | "nixpkgs": { 89 | "locked": { 90 | "lastModified": 1747820358, 91 | "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=", 92 | "owner": "danieldk", 93 | "repo": "nixpkgs", 94 | "rev": "d3c1681180717528068082103bf323147de6ab0b", 95 | "type": "github" 96 | }, 97 | "original": { 98 | "owner": "danieldk", 99 | "ref": "cudatoolkit-12.9-kernel-builder", 100 | "repo": "nixpkgs", 101 | "type": "github" 102 | } 103 | }, 104 | "root": { 105 | "inputs": { 106 | "flake-compat": "flake-compat", 107 | "flake-utils": "flake-utils", 108 | "hf-nix": "hf-nix", 109 | "nixpkgs": [ 110 | "hf-nix", 111 | "nixpkgs" 112 | ] 113 | } 114 | }, 115 | "systems": { 116 | "locked": { 117 | "lastModified": 1681028828, 118 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", 119 | "owner": "nix-systems", 120 | "repo": "default", 121 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", 122 | "type": "github" 123 | }, 124 | "original": { 125 | "owner": "nix-systems", 126 | "repo": "default", 127 | "type": "github" 128 | } 129 | }, 130 | "systems_2": { 131 | "locked": { 132 | "lastModified": 1681028828, 133 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", 134 | "owner": "nix-systems", 135 | "repo": "default", 
136 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", 137 | "type": "github" 138 | }, 139 | "original": { 140 | "owner": "nix-systems", 141 | "repo": "default", 142 | "type": "github" 143 | } 144 | } 145 | }, 146 | "root": "root", 147 | "version": 7 148 | } 149 | -------------------------------------------------------------------------------- /flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | description = "Kernel builder"; 3 | 4 | inputs = { 5 | flake-utils.url = "github:numtide/flake-utils"; 6 | nixpkgs.follows = "hf-nix/nixpkgs"; 7 | flake-compat.url = "github:edolstra/flake-compat"; 8 | hf-nix.url = "github:huggingface/hf-nix"; 9 | }; 10 | 11 | outputs = 12 | { 13 | self, 14 | flake-compat, 15 | flake-utils, 16 | hf-nix, 17 | nixpkgs, 18 | }: 19 | let 20 | systems = with flake-utils.lib.system; [ 21 | aarch64-darwin 22 | aarch64-linux 23 | x86_64-linux 24 | ]; 25 | 26 | # Create an attrset { "" = [ ...]; ... }. 27 | buildSetPerSystem = builtins.listToAttrs ( 28 | builtins.map (system: { 29 | name = system; 30 | value = import ./lib/buildsets.nix { 31 | inherit nixpkgs system; 32 | hf-nix = hf-nix.overlays.default; 33 | }; 34 | }) systems 35 | ); 36 | 37 | libPerSystem = builtins.mapAttrs ( 38 | system: buildSet: 39 | import lib/build.nix { 40 | inherit (nixpkgs) lib; 41 | buildSets = buildSetPerSystem.${system}; 42 | } 43 | ) buildSetPerSystem; 44 | 45 | # The lib output consists of two parts: 46 | # 47 | # - Per-system build functions. 48 | # - `genFlakeOutputs`, which can be used by downstream flakes to make 49 | # standardized outputs (for all supported systems). 
50 | lib = { 51 | allBuildVariantsJSON = 52 | let 53 | buildVariants = (import ./versions.nix { inherit (nixpkgs) lib; }).buildVariants; 54 | in 55 | builtins.toJSON (nixpkgs.lib.foldl' (acc: system: acc // buildVariants system) { } systems); 56 | genFlakeOutputs = 57 | { path, rev }: 58 | flake-utils.lib.eachSystem systems ( 59 | system: 60 | let 61 | build = libPerSystem.${system}; 62 | revUnderscored = builtins.replaceStrings [ "-" ] [ "_" ] rev; 63 | pkgs = nixpkgs.legacyPackages.${system}; 64 | shellTorch = 65 | if system == "aarch64-darwin" then "torch27-metal-${system}" else "torch27-cxx11-cu126-${system}"; 66 | in 67 | { 68 | devShells = rec { 69 | default = devShells.${shellTorch}; 70 | test = testShells.${shellTorch}; 71 | devShells = build.torchDevShells { 72 | inherit path; 73 | rev = revUnderscored; 74 | }; 75 | testShells = build.torchExtensionShells { 76 | inherit path; 77 | rev = revUnderscored; 78 | }; 79 | }; 80 | packages = rec { 81 | default = bundle; 82 | bundle = build.buildTorchExtensionBundle { 83 | inherit path; 84 | rev = revUnderscored; 85 | }; 86 | redistributable = build.buildDistTorchExtensions { 87 | inherit path; 88 | buildSets = buildSetPerSystem.${system}; 89 | rev = revUnderscored; 90 | }; 91 | buildTree = 92 | let 93 | build2cmake = self.packages.${system}.build2cmake; 94 | src = build.mkSourceSet path; 95 | in 96 | pkgs.runCommand "torch-extension-build-tree" 97 | { 98 | nativeBuildInputs = [ build2cmake ]; 99 | inherit src; 100 | meta = { 101 | description = "Build tree for torch extension with source files and CMake configuration"; 102 | }; 103 | } 104 | '' 105 | # Copy sources 106 | install -dm755 $out/src 107 | cp -r $src/. 
$out/src/ 108 | 109 | # Generate cmake files 110 | build2cmake generate-torch --ops-id "${revUnderscored}" $src/build.toml $out --force 111 | ''; 112 | }; 113 | } 114 | ); 115 | } // libPerSystem; 116 | 117 | in 118 | flake-utils.lib.eachSystem systems ( 119 | system: 120 | let 121 | # Plain nixkpgs that we use to access utility funtions. 122 | pkgs = import nixpkgs { 123 | inherit system; 124 | }; 125 | inherit (nixpkgs) lib; 126 | 127 | buildVersion = import ./lib/build-version.nix; 128 | 129 | buildSets = buildSetPerSystem.${system}; 130 | 131 | in 132 | rec { 133 | formatter = pkgs.nixfmt-tree; 134 | 135 | packages = rec { 136 | build2cmake = pkgs.callPackage ./pkgs/build2cmake { }; 137 | 138 | update-build = pkgs.writeShellScriptBin "update-build" '' 139 | ${build2cmake}/bin/build2cmake update-build ''${1:-build.toml} 140 | ''; 141 | 142 | # This package set is exposed so that we can prebuild the Torch versions. 143 | torch = builtins.listToAttrs ( 144 | map (buildSet: { 145 | name = buildVersion buildSet; 146 | value = buildSet.torch; 147 | }) buildSets 148 | ); 149 | 150 | # Dependencies that should be cached. 151 | forCache = 152 | let 153 | filterDist = lib.filter (output: output != "dist"); 154 | # Get all `torch` outputs except for `dist`. Not all outputs 155 | # are dependencies of `out`, but we'll need the `cxxdev` and 156 | # `dev` outputs for kernel builds. 157 | torchOutputs = builtins.listToAttrs ( 158 | lib.flatten ( 159 | # Map over build sets. 160 | map ( 161 | buildSet: 162 | # Map over all outputs of `torch` in a buildset. 
163 | map (output: { 164 | name = "${buildVersion buildSet}-${output}"; 165 | value = buildSet.torch.${output}; 166 | }) (filterDist buildSet.torch.outputs) 167 | ) buildSets 168 | ) 169 | ); 170 | oldLinuxStdenvs = builtins.listToAttrs ( 171 | map (buildSet: { 172 | name = "stdenv-${buildVersion buildSet}"; 173 | value = buildSet.pkgs.stdenvGlibc_2_27; 174 | }) buildSets 175 | ); 176 | in 177 | pkgs.linkFarm "packages-for-cache" ( 178 | torchOutputs // lib.optionalAttrs nixpkgs.legacyPackages.${system}.stdenv.isLinux oldLinuxStdenvs 179 | ); 180 | }; 181 | } 182 | ) 183 | // { 184 | inherit lib; 185 | }; 186 | } 187 | -------------------------------------------------------------------------------- /kernel-abi-check/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "kernel-abi-check" 3 | version = "0.4.0" 4 | edition = "2021" 5 | description = "Check the ABI of Hub Kernels" 6 | homepage = "https://github.com/huggingface/kernel-builder" 7 | license = "Apache-2.0" 8 | documentation = "https://docs.rs/kernel-abi-check" 9 | repository = "https://github.com/huggingface/kernel-builder" 10 | 11 | 12 | [dependencies] 13 | clap = { version = "4", features = ["derive"] } 14 | color-eyre = "0.6" 15 | eyre = "0.6" 16 | itertools = "0.14.0" 17 | object = "0.36.7" 18 | once_cell = "1" 19 | serde = { version = "1", features = ["derive"] } 20 | serde_json = "1" 21 | toml = "0.8" 22 | 23 | -------------------------------------------------------------------------------- /kernel-abi-check/flake.lock: -------------------------------------------------------------------------------- 1 | { 2 | "nodes": { 3 | "flake-utils": { 4 | "inputs": { 5 | "systems": "systems" 6 | }, 7 | "locked": { 8 | "lastModified": 1731533236, 9 | "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", 10 | "owner": "numtide", 11 | "repo": "flake-utils", 12 | "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", 13 | "type": "github" 14 | 
}, 15 | "original": { 16 | "owner": "numtide", 17 | "repo": "flake-utils", 18 | "type": "github" 19 | } 20 | }, 21 | "nixpkgs": { 22 | "locked": { 23 | "lastModified": 1742889210, 24 | "narHash": "sha256-hw63HnwnqU3ZQfsMclLhMvOezpM7RSB0dMAtD5/sOiw=", 25 | "owner": "nixos", 26 | "repo": "nixpkgs", 27 | "rev": "698214a32beb4f4c8e3942372c694f40848b360d", 28 | "type": "github" 29 | }, 30 | "original": { 31 | "owner": "nixos", 32 | "ref": "nixos-unstable", 33 | "repo": "nixpkgs", 34 | "type": "github" 35 | } 36 | }, 37 | "root": { 38 | "inputs": { 39 | "flake-utils": "flake-utils", 40 | "nixpkgs": "nixpkgs" 41 | } 42 | }, 43 | "systems": { 44 | "locked": { 45 | "lastModified": 1681028828, 46 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", 47 | "owner": "nix-systems", 48 | "repo": "default", 49 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", 50 | "type": "github" 51 | }, 52 | "original": { 53 | "owner": "nix-systems", 54 | "repo": "default", 55 | "type": "github" 56 | } 57 | } 58 | }, 59 | "root": "root", 60 | "version": 7 61 | } 62 | -------------------------------------------------------------------------------- /kernel-abi-check/flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | description = "kernel-abi-check devenv"; 3 | 4 | inputs = { 5 | flake-utils.url = "github:numtide/flake-utils"; 6 | nixpkgs.url = "github:nixos/nixpkgs?ref=nixos-unstable"; 7 | }; 8 | 9 | outputs = 10 | { 11 | self, 12 | flake-utils, 13 | nixpkgs, 14 | }: 15 | flake-utils.lib.eachDefaultSystem ( 16 | system: 17 | let 18 | pkgs = nixpkgs.legacyPackages.${system}; 19 | in 20 | { 21 | 22 | devShells.default = 23 | with pkgs; 24 | mkShell { 25 | buildInputs = [ 26 | cargo 27 | clippy 28 | openssl.dev 29 | pkg-config 30 | rustc 31 | rustfmt 32 | rust-analyzer 33 | ]; 34 | 35 | RUST_SRC_PATH = "${rustPlatform.rustLibSrc}"; 36 | }; 37 | } 38 | ); 39 | } 40 | 
-------------------------------------------------------------------------------- /kernel-abi-check/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! Functions for checking kernel ABI compatibility. 2 | 3 | mod manylinux; 4 | pub use manylinux::{check_manylinux, ManylinuxViolation}; 5 | 6 | mod python_abi; 7 | pub use python_abi::{check_python_abi, PythonAbiViolation}; 8 | 9 | mod version; 10 | pub use version::Version; 11 | -------------------------------------------------------------------------------- /kernel-abi-check/src/main.rs: -------------------------------------------------------------------------------- 1 | use std::path::PathBuf; 2 | use std::{collections::BTreeSet, fs}; 3 | 4 | use clap::Parser; 5 | use eyre::{Context, Result}; 6 | use object::Object; 7 | 8 | use kernel_abi_check::{ 9 | check_manylinux, check_python_abi, ManylinuxViolation, PythonAbiViolation, Version, 10 | }; 11 | 12 | /// CLI tool to check library versions 13 | #[derive(Parser, Debug)] 14 | #[command(version, about, long_about = None)] 15 | struct Cli { 16 | /// Python extension library. 17 | object: PathBuf, 18 | 19 | /// Manylinux version. 20 | #[arg(short, long, value_name = "VERSION", default_value = "manylinux_2_28")] 21 | manylinux: String, 22 | 23 | /// Python ABI version. 
24 | #[arg(short, long, value_name = "VERSION", default_value = "3.9")] 25 | python_abi: Version, 26 | } 27 | 28 | fn main() -> Result<()> { 29 | // Initialize color_eyre error handling 30 | color_eyre::install()?; 31 | 32 | // Parse command-line arguments 33 | let args = Cli::parse(); 34 | 35 | eprintln!( 36 | "🐍 Checking for compatibility with {} and Python ABI version {}", 37 | args.manylinux, args.python_abi 38 | ); 39 | 40 | let binary_data = fs::read(args.object).context("Cannot open object file")?; 41 | let file = object::File::parse(&*binary_data).context("Cannot parse object")?; 42 | 43 | let many_linux_violations = check_manylinux( 44 | &args.manylinux, 45 | file.architecture(), 46 | file.endianness(), 47 | file.symbols(), 48 | )?; 49 | print_manylinux_violations(&many_linux_violations, &args.manylinux)?; 50 | 51 | let python_abi_violations = check_python_abi(&args.python_abi, file.symbols())?; 52 | print_python_abi_violations(&python_abi_violations, &args.python_abi); 53 | 54 | if !(many_linux_violations.is_empty() && python_abi_violations.is_empty()) { 55 | return Err(eyre::eyre!("Incompatible symbols found")); 56 | } else { 57 | eprintln!("✅ No compatibility issues found"); 58 | } 59 | 60 | Ok(()) 61 | } 62 | 63 | fn print_manylinux_violations( 64 | violations: &BTreeSet, 65 | manylinux_version: &str, 66 | ) -> Result<()> { 67 | if !violations.is_empty() { 68 | eprintln!( 69 | "\n⛔ Symbols incompatible with `{}` found:\n", 70 | manylinux_version 71 | ); 72 | for violation in violations { 73 | match violation { 74 | ManylinuxViolation::Symbol { name, dep, version } => { 75 | eprintln!("{}_{}: {}", name, dep, version); 76 | } 77 | } 78 | } 79 | } 80 | Ok(()) 81 | } 82 | 83 | fn print_python_abi_violations(violations: &BTreeSet, python_abi: &Version) { 84 | if !violations.is_empty() { 85 | let newer_abi3_symbols = violations 86 | .iter() 87 | .filter(|v| matches!(v, PythonAbiViolation::IncompatibleAbi3Symbol { .. 
})) 88 | .collect::>(); 89 | let non_abi3_symbols = violations 90 | .iter() 91 | .filter(|v| matches!(v, PythonAbiViolation::NonAbi3Symbol { .. })) 92 | .collect::>(); 93 | 94 | if !newer_abi3_symbols.is_empty() { 95 | eprintln!("\n⛔ Symbols >= Python ABI {} found:\n", python_abi); 96 | for violation in newer_abi3_symbols { 97 | if let PythonAbiViolation::IncompatibleAbi3Symbol { name, added } = violation { 98 | eprintln!("{}: {}", name, added); 99 | } 100 | } 101 | } 102 | 103 | if !non_abi3_symbols.is_empty() { 104 | eprintln!("\n⛔ Non-ABI3 symbols found:\n"); 105 | for violation in &non_abi3_symbols { 106 | if let PythonAbiViolation::NonAbi3Symbol { name } = violation { 107 | eprintln!("{}", name); 108 | } 109 | } 110 | } 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /kernel-abi-check/src/manylinux/mod.rs: -------------------------------------------------------------------------------- 1 | use std::collections::{BTreeSet, HashSet}; 2 | use std::str; 3 | use std::{collections::HashMap, str::FromStr}; 4 | 5 | use eyre::{bail, Context, ContextCompat, OptionExt, Result}; 6 | use object::{Architecture, Endianness, ObjectSymbol, Symbol}; 7 | use once_cell::sync::Lazy; 8 | use serde::Deserialize; 9 | 10 | use crate::Version; 11 | 12 | #[derive(Debug, Deserialize)] 13 | struct ManyLinux { 14 | name: String, 15 | #[allow(dead_code)] 16 | aliases: Vec, 17 | #[allow(dead_code)] 18 | priority: u32, 19 | symbol_versions: HashMap>>, 20 | #[allow(dead_code)] 21 | lib_whitelist: Vec, 22 | #[allow(dead_code)] 23 | blacklist: HashMap>, 24 | } 25 | 26 | static MANYLINUX_POLICY_JSON: &str = include_str!("manylinux-policy.json"); 27 | 28 | static MANYLINUX_VERSIONS: Lazy> = Lazy::new(|| { 29 | let deserialized: Vec = serde_json::from_str(MANYLINUX_POLICY_JSON).unwrap(); 30 | deserialized 31 | .into_iter() 32 | .map(|manylinux| (manylinux.name.clone(), manylinux)) 33 | .collect() 34 | }); 35 | 36 | /// A violation of the 
manylinux policy. 37 | #[derive(Debug, Clone, Eq, Ord, PartialEq, PartialOrd)] 38 | pub enum ManylinuxViolation { 39 | /// A symbol is not allowed in the manylinux version. 40 | Symbol { 41 | name: String, 42 | dep: String, 43 | version: String, 44 | }, 45 | } 46 | 47 | pub fn check_manylinux<'a>( 48 | manylinux_version: &str, 49 | architecture: Architecture, 50 | endianness: Endianness, 51 | symbols: impl IntoIterator>, 52 | ) -> Result> { 53 | let arch_str = architecture.arch_str(endianness)?; 54 | let symbol_versions = MANYLINUX_VERSIONS 55 | .get(manylinux_version) 56 | .with_context(|| format!("Unknown manylinux version: {}", manylinux_version))? 57 | .symbol_versions 58 | .get(&arch_str) 59 | .with_context(|| format!( 60 | "Cannot find arch `{}` for: `{}`", 61 | arch_str, manylinux_version 62 | ))?; 63 | 64 | let mut violations = BTreeSet::new(); 65 | 66 | for symbol in symbols { 67 | if symbol.is_undefined() { 68 | let symbol = symbol.name_bytes().context("Cannot get symbol name")?; 69 | let symbol = str::from_utf8(symbol).context("Cannot parse symbol name as UTF-8")?; 70 | 71 | let mut symbol_parts = symbol.split('@'); 72 | let symbol_name = symbol_parts.next().context("Cannot get symbol name")?; 73 | 74 | let version_info = match symbol_parts.next() { 75 | Some(version_info) => version_info, 76 | None => continue, 77 | }; 78 | 79 | let mut version_parts = version_info.split('_'); 80 | 81 | let dep = version_parts 82 | .next() 83 | .ok_or_eyre("Cannot get symbol version name")?; 84 | 85 | let version = match version_parts.next() { 86 | Some(version) if symbol_versions.contains_key(dep) => Version::from_str(version)?, 87 | // Skip versions like libcudart.so.12 and deps the policy does not cover (their tags may be non-numeric). 88 | _ => continue, 89 | }; 90 | 91 | if let Some(versions) = symbol_versions.get(dep) { 92 | if !versions.contains(&version.to_string()) { 93 | violations.insert(ManylinuxViolation::Symbol { 94 | name: symbol_name.to_string(), 95 | dep: dep.to_string(), 96 | version: version.to_string(), 97 | }); 98 | } 99 | } 100 | } 101 | } 102 | 103
| Ok(violations) 104 | } 105 | 106 | pub trait ArchStr { 107 | fn arch_str(&self, endiannes: Endianness) -> Result; 108 | } 109 | 110 | impl ArchStr for Architecture { 111 | fn arch_str(&self, endiannes: Endianness) -> Result { 112 | Ok(match self { 113 | Architecture::Aarch64 => "aarch64", 114 | Architecture::I386 => "i686", 115 | Architecture::PowerPc64 if matches!(endiannes, Endianness::Big) => "ppc64", 116 | Architecture::PowerPc64 if matches!(endiannes, Endianness::Little) => "ppc64le", 117 | Architecture::S390x => "s390x", 118 | Architecture::X86_64 => "x86_64", 119 | _ => bail!("Unsupported architecture: {:?}", self), 120 | } 121 | .to_string()) 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /kernel-abi-check/src/python_abi/mod.rs: -------------------------------------------------------------------------------- 1 | use std::collections::{BTreeSet, HashMap}; 2 | 3 | use eyre::Result; 4 | use object::{ObjectSymbol, Symbol}; 5 | use once_cell::sync::Lazy; 6 | use serde::Deserialize; 7 | 8 | use crate::version::Version; 9 | 10 | static ABI_TOML: &str = include_str!("stable_abi.toml"); 11 | 12 | #[derive(Deserialize)] 13 | struct AbiInfoSerde { 14 | added: Version, 15 | } 16 | 17 | #[derive(Deserialize)] 18 | struct StableAbiSerde { 19 | function: HashMap, 20 | data: HashMap, 21 | } 22 | 23 | #[derive(Clone, Copy, Debug)] 24 | pub enum SymbolType { 25 | Data, 26 | Function, 27 | } 28 | 29 | #[derive(Clone, Debug)] 30 | pub struct AbiInfo { 31 | #[allow(dead_code)] 32 | pub symbol_type: SymbolType, 33 | pub added: Version, 34 | } 35 | 36 | pub static PYTHON_STABLE_ABI: Lazy> = Lazy::new(|| { 37 | let deserialized: StableAbiSerde = toml::de::from_str(ABI_TOML).unwrap(); 38 | let mut symbols = HashMap::new(); 39 | for (name, abi) in deserialized.function { 40 | symbols.insert( 41 | name, 42 | AbiInfo { 43 | symbol_type: SymbolType::Function, 44 | added: abi.added, 45 | }, 46 | ); 47 | } 48 | for (name, abi) in 
deserialized.data { 49 | symbols.insert( 50 | name, 51 | AbiInfo { 52 | symbol_type: SymbolType::Data, 53 | added: abi.added, 54 | }, 55 | ); 56 | } 57 | symbols 58 | }); 59 | 60 | /// Python ABI violation. 61 | #[derive(Debug, Clone, Eq, Ord, PartialEq, PartialOrd)] 62 | pub enum PythonAbiViolation { 63 | /// Symbol is newer than the specified Python ABI version. 64 | IncompatibleAbi3Symbol { name: String, added: Version }, 65 | 66 | /// Symbol is not part of ABI3. 67 | NonAbi3Symbol { name: String }, 68 | } 69 | 70 | /// Check for violations of the Python ABI policy. 71 | pub fn check_python_abi<'a>( 72 | python_abi: &Version, 73 | symbols: impl IntoIterator>, 74 | ) -> Result> { 75 | let mut violations = BTreeSet::new(); 76 | for symbol in symbols { 77 | if !symbol.is_undefined() { 78 | continue; 79 | } 80 | 81 | let symbol_name = symbol.name()?; 82 | 83 | match PYTHON_STABLE_ABI.get(symbol_name) { 84 | Some(abi_info) => { 85 | if &abi_info.added > python_abi { 86 | violations.insert(PythonAbiViolation::IncompatibleAbi3Symbol { 87 | name: symbol_name.to_string(), 88 | added: abi_info.added.clone(), 89 | }); 90 | } 91 | } 92 | None => { 93 | if symbol_name.starts_with("Py") || symbol_name.starts_with("_Py") { 94 | violations.insert(PythonAbiViolation::NonAbi3Symbol { 95 | name: symbol_name.to_string(), 96 | }); 97 | } 98 | } 99 | } 100 | } 101 | 102 | Ok(violations) 103 | } 104 | -------------------------------------------------------------------------------- /kernel-abi-check/src/version.rs: -------------------------------------------------------------------------------- 1 | use std::{fmt::Display, str::FromStr}; 2 | 3 | use eyre::{ensure, Context, Result}; 4 | use serde::{de, Deserialize, Deserializer}; 5 | 6 | /// Symbol version. 
7 | #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] 8 | pub struct Version(pub Vec); 9 | 10 | impl<'de> Deserialize<'de> for Version { 11 | fn deserialize(deserializer: D) -> Result 12 | where 13 | D: Deserializer<'de>, 14 | { 15 | let s = String::deserialize(deserializer)?; 16 | FromStr::from_str(&s).map_err(de::Error::custom) 17 | } 18 | } 19 | 20 | impl Display for Version { 21 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 22 | write!( 23 | f, 24 | "{}", 25 | itertools::join(self.0.iter().map(|v| v.to_string()), ".") 26 | ) 27 | } 28 | } 29 | 30 | impl FromStr for Version { 31 | type Err = eyre::Report; 32 | 33 | fn from_str(version: &str) -> Result { 34 | let version = version.trim().to_owned(); 35 | ensure!(!version.is_empty(), "Empty version string"); 36 | let mut version_parts = Vec::new(); 37 | for part in version.split('.') { 38 | let version_part: usize = part.parse().context("Version must consist of numbers")?; 39 | version_parts.push(version_part); 40 | } 41 | 42 | Ok(Version(version_parts)) 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /kernel-compliance-check/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "kernel-compliance-check" 3 | version = "0.1.0" 4 | edition = "2021" 5 | description = "Command-line utility for validating kernel compliance standards" 6 | license = "Apache-2.0" 7 | repository = "https://github.com/huggingface/kernel-builder" 8 | documentation = "https://docs.rs/kernel-compliance-check" 9 | homepage = "https://github.com/huggingface/kernel-builder" 10 | 11 | [features] 12 | default = [] 13 | enable_rocm = [] 14 | 15 | [dependencies] 16 | eyre = "0.6.12" 17 | clap = { version = "4.5.35", features = ["derive"] } 18 | colored = "3.0.0" 19 | dirs = "6.0.0" 20 | futures = "0.3.31" 21 | hf-hub = { version = "0.4.2", features = ["tokio"] } 22 | kernel-abi-check = "0.4.0" 23 | object 
= { version = "0.36.7", default-features = false, features = ["read"] } 24 | once_cell = "1.18.0" 25 | reqwest = { version = "0.11", features = ["json"] } 26 | serde = { version = "1.0", features = ["derive"] } 27 | serde_json = "1.0.114" 28 | thiserror = "1.0.40" 29 | tokio = { version = "1.44.2", features = ["full"] } 30 | 31 | [build-dependencies] 32 | ureq = "2.7.1" -------------------------------------------------------------------------------- /kernel-compliance-check/src/build_variants.json: -------------------------------------------------------------------------------- 1 | ../../build_variants.json -------------------------------------------------------------------------------- /kernel-compliance-check/src/formatter.rs: -------------------------------------------------------------------------------- 1 | use crate::models::AbiCheckResult; 2 | use colored::Colorize; 3 | 4 | /// Struct for console output formatting 5 | pub struct Console; 6 | 7 | impl Console { 8 | pub fn format_repo_list(repos: &[String], count: usize) { 9 | println!("."); 10 | for repo_id in repos { 11 | println!("├── {repo_id}"); 12 | } 13 | println!("╰── {count} kernel repositories found\n"); 14 | } 15 | 16 | pub fn format_fetch_status(repo_id: &str, fetching: bool, result: Option<&str>) { 17 | println!("repository: {repo_id}"); 18 | if fetching { 19 | println!("status: not found locally, fetching..."); 20 | } 21 | if let Some(message) = result { 22 | println!("status: {message}"); 23 | } 24 | } 25 | 26 | #[allow(clippy::too_many_arguments)] 27 | pub fn format_repository_check_result( 28 | repo_id: &str, 29 | build_status: &str, 30 | cuda_compatible: bool, 31 | #[cfg(feature = "enable_rocm")] rocm_compatible: bool, 32 | #[cfg(not(feature = "enable_rocm"))] _rocm_compatible: bool, 33 | cuda_variants: &[String], 34 | #[cfg(feature = "enable_rocm")] rocm_variants: &[String], 35 | #[cfg(not(feature = "enable_rocm"))] _rocm_variants: &[String], 36 | cuda_variants_present: &[String], 37 | 
#[cfg(feature = "enable_rocm")] rocm_variants_present: Vec, 38 | #[cfg(not(feature = "enable_rocm"))] _rocm_variants_present: Vec, 39 | compact_output: bool, 40 | abi_output: &AbiCheckResult, 41 | abi_status: &str, 42 | ) { 43 | // Display console-formatted output 44 | let abi_mark = if abi_output.overall_compatible { 45 | "✓".green() 46 | } else { 47 | "✗".red() 48 | }; 49 | 50 | let cuda_mark = if cuda_compatible { 51 | "✓".green() 52 | } else { 53 | "✗".red() 54 | }; 55 | 56 | #[cfg(feature = "enable_rocm")] 57 | let rocm_mark = if rocm_compatible { 58 | "✓".green() 59 | } else { 60 | "✗".red() 61 | }; 62 | 63 | let label = format!(" {repo_id} ").black().on_bright_white().bold(); 64 | 65 | println!("\n{label}"); 66 | println!("├── build: {build_status}"); 67 | 68 | if compact_output { 69 | // Compact output 70 | #[cfg(feature = "enable_rocm")] 71 | { 72 | println!("│ ├── {} CUDA", cuda_mark); 73 | println!("│ ╰── {} ROCM", rocm_mark); 74 | } 75 | 76 | #[cfg(not(feature = "enable_rocm"))] 77 | { 78 | println!("│ ╰── {cuda_mark} CUDA"); 79 | } 80 | } else { 81 | println!("│ {} {}", cuda_mark, "CUDA".bold()); 82 | 83 | // Print variant list with proper tree characters 84 | let mut cuda_iter = cuda_variants.iter().peekable(); 85 | while let Some(cuda_variant) = cuda_iter.next() { 86 | let is_last = cuda_iter.peek().is_none(); 87 | let is_present = cuda_variants_present.contains(cuda_variant); 88 | let prefix = if is_last { 89 | "│ ╰── " 90 | } else { 91 | "│ ├── " 92 | }; 93 | 94 | if is_present { 95 | println!("{prefix}{cuda_variant}"); 96 | } else { 97 | println!("{}{}", prefix, cuda_variant.dimmed()); 98 | } 99 | } 100 | 101 | // Only show ROCm section if the feature is enabled 102 | #[cfg(feature = "enable_rocm")] 103 | { 104 | println!("│ {} {}", rocm_mark, "ROCM".bold()); 105 | 106 | let mut rocm_iter = rocm_variants.iter().peekable(); 107 | while let Some(rocm_variant) = rocm_iter.next() { 108 | let is_last = rocm_iter.peek().is_none(); 109 | let is_present = 
rocm_variants_present.contains(rocm_variant); 110 | let prefix = if is_last { 111 | "│ ╰── " 112 | } else { 113 | "│ ├── " 114 | }; 115 | 116 | if is_present { 117 | println!("{}{}", prefix, rocm_variant); 118 | } else { 119 | println!("{}{}", prefix, rocm_variant.dimmed()); 120 | } 121 | } 122 | } 123 | } 124 | 125 | // ABI status section 126 | println!("╰── abi: {abi_status}"); 127 | println!(" ├── {} {}", abi_mark, abi_output.manylinux_version); 128 | println!( 129 | " ╰── {} python {}", 130 | abi_mark, abi_output.python_abi_version 131 | ); 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /kernel-compliance-check/src/main.rs: -------------------------------------------------------------------------------- 1 | use core::str::FromStr as _; 2 | 3 | use clap::Parser as _; 4 | use eyre::{Context as _, Result}; 5 | use kernel_abi_check::Version; 6 | use kernel_compliance_check::{process_repository, Cli, Commands, Format}; 7 | 8 | fn main() -> Result<()> { 9 | // Parse CLI arguments 10 | let cli = Cli::parse(); 11 | 12 | // Respect KERNELS_CACHE if set 13 | let non_standard_cache: Option = std::env::var("KERNELS_CACHE").ok(); 14 | 15 | // Prefer the cli unless explicitly set to avoid it 16 | let prefer_hub_cli = !std::env::var("AVOID_HUB_CLI") 17 | .map(|v| v == "1" || v == "true") 18 | .unwrap_or(false); 19 | 20 | match cli.command { 21 | Commands::Check { 22 | repos, 23 | manylinux, 24 | python_abi, 25 | revision, 26 | long, 27 | force_fetch, 28 | show_violations, 29 | format, 30 | } => { 31 | eprintln!("Running kernel compliance check"); 32 | eprintln!("Repositories: {repos}"); 33 | eprintln!("Kernel Revision: {revision}"); 34 | 35 | // Check repositories for compliance 36 | check_repositories( 37 | &repos, 38 | &manylinux.to_string(), 39 | &python_abi, 40 | prefer_hub_cli, 41 | force_fetch, 42 | &revision, 43 | long, 44 | show_violations, 45 | format, 46 | non_standard_cache.as_ref(), 47 | )?; 48 | } 49 | } 50 
| 51 | Ok(()) 52 | } 53 | 54 | #[allow(clippy::fn_params_excessive_bools)] 55 | #[expect(clippy::too_many_arguments)] 56 | fn check_repositories( 57 | repos: &str, 58 | manylinux: &str, 59 | python_abi: &str, 60 | prefer_hub_cli: bool, 61 | force_fetch: bool, 62 | revision: &str, 63 | long: bool, 64 | show_violations: bool, 65 | format: Format, 66 | non_standard_cache: Option<&String>, 67 | ) -> Result<()> { 68 | let repositories: Vec = repos 69 | .split(',') 70 | .map(|s| s.trim().to_owned()) 71 | .filter(|s| !s.is_empty()) 72 | .collect(); 73 | 74 | if repositories.is_empty() { 75 | #[derive(serde::Serialize)] 76 | struct ErrorResponse { 77 | status: &'static str, 78 | error: &'static str, 79 | } 80 | 81 | if format.is_json() { 82 | let error = ErrorResponse { 83 | status: "error", 84 | error: "no repository ids provided", 85 | }; 86 | let json = serde_json::to_string_pretty(&error) 87 | .context("Failed to serialize error response")?; 88 | println!("{json}"); 89 | } else { 90 | eprintln!("no repository ids provided"); 91 | } 92 | return Ok(()); 93 | } 94 | 95 | let python_version = Version::from_str(python_abi) 96 | .map_err(|e| eyre::eyre!("Invalid Python ABI version {}: {}", python_abi, e))?; 97 | 98 | for repo_id in &repositories { 99 | if let Err(e) = process_repository( 100 | repo_id, 101 | revision, 102 | force_fetch, 103 | prefer_hub_cli, 104 | manylinux, 105 | &python_version, 106 | !long, 107 | show_violations, 108 | format, 109 | non_standard_cache, 110 | ) { 111 | eprintln!("Error processing repository {repo_id}: {e}"); 112 | 113 | // Continue processing other repositories rather than exiting early 114 | // This is more user-friendly for batch processing 115 | } 116 | } 117 | 118 | Ok(()) 119 | } 120 | -------------------------------------------------------------------------------- /kernel-compliance-check/src/models.rs: -------------------------------------------------------------------------------- 1 | use std::fmt; 2 | 3 | use clap::{Parser, 
Subcommand, ValueEnum}; 4 | use serde::{Deserialize, Serialize}; 5 | use thiserror::Error; 6 | 7 | #[derive(Error, Debug)] 8 | pub enum CompliantError { 9 | #[error("IO error: {0}")] 10 | Io(#[from] std::io::Error), 11 | 12 | #[error("Repository not found: {0}")] 13 | RepositoryNotFound(String), 14 | 15 | #[error("Build directory not found in repository: {0}")] 16 | BuildDirNotFound(String), 17 | 18 | #[error("Failed to fetch repository: {0}")] 19 | FetchError(String), 20 | 21 | #[error("Failed to parse object file: {0}")] 22 | ObjectParseError(String), 23 | 24 | #[error("Failed to check ABI compatibility: {0}")] 25 | AbiCheckError(String), 26 | 27 | #[error("Failed to serialize JSON: {0}")] 28 | SerializationError(String), 29 | 30 | #[error("Failed to fetch variants: {0}")] 31 | VariantsFetchError(String), 32 | 33 | #[error("Network error: {0}")] 34 | NetworkError(String), 35 | 36 | #[error("Unknown error: {0}")] 37 | Other(String), 38 | } 39 | 40 | /// Hugging Face kernel compliance checker 41 | #[derive(Parser)] 42 | #[command(author, version, about)] 43 | pub struct Cli { 44 | #[command(subcommand)] 45 | pub command: Commands, 46 | } 47 | 48 | #[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] 49 | pub enum Format { 50 | Console, 51 | Json, 52 | } 53 | 54 | impl Format { 55 | #[must_use] 56 | pub fn is_json(&self) -> bool { 57 | matches!(self, Format::Json) 58 | } 59 | } 60 | 61 | #[derive(Subcommand)] 62 | pub enum Commands { 63 | /// Check repository compliance and ABI compatibility 64 | Check { 65 | /// Repository IDs or names (comma-separated) 66 | #[arg(value_name = "REPOS")] 67 | repos: String, 68 | 69 | /// Manylinux version to check against 70 | #[arg(short, long, default_value = "manylinux_2_28")] 71 | manylinux: String, 72 | 73 | /// Python ABI version to check against 74 | #[arg(short, long, default_value = "3.9")] 75 | python_abi: String, 76 | 77 | /// Revision (branch, tag, or commit hash) to use when fetching 78 | #[arg(short, long, 
default_value = "main")] 79 | revision: String, 80 | 81 | /// Show all variants in a long format. Default is compact output. 82 | #[arg(long, default_value_t = false)] 83 | long: bool, 84 | 85 | /// Force fetch the repository if not found locally 86 | #[arg(long, alias = "force", default_value_t = false)] 87 | force_fetch: bool, 88 | 89 | /// Show ABI violations in the output. Default is to only show compatibility status. 90 | #[arg(long, default_value_t = false)] 91 | show_violations: bool, 92 | 93 | /// Format of the output. Default is console 94 | #[arg(long, default_value = "console")] 95 | format: Format, 96 | }, 97 | } 98 | 99 | /// Structured representation of build variants 100 | #[derive(Debug, Deserialize)] 101 | pub struct VariantsConfig { 102 | #[serde(rename = "x86_64-linux")] 103 | pub x86_64_linux: ArchConfig, 104 | #[serde(rename = "aarch64-linux")] 105 | pub aarch64_linux: ArchConfig, 106 | } 107 | 108 | #[derive(Debug, Deserialize)] 109 | pub struct ArchConfig { 110 | pub cuda: Vec, 111 | #[serde(default)] 112 | #[cfg(feature = "enable_rocm")] 113 | pub rocm: Vec, 114 | #[cfg(not(feature = "enable_rocm"))] 115 | #[serde(default, skip)] 116 | _rocm: Vec, 117 | } 118 | 119 | #[derive(Debug, Clone, Serialize, Deserialize)] 120 | pub struct Variant { 121 | pub torch_version: String, 122 | pub cxx_abi: String, 123 | pub compute_framework: String, 124 | pub arch: String, 125 | pub os: String, 126 | } 127 | 128 | impl fmt::Display for Variant { 129 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 130 | write!( 131 | f, 132 | "{}-{}-{}-{}-{}", 133 | self.torch_version, self.cxx_abi, self.compute_framework, self.arch, self.os 134 | ) 135 | } 136 | } 137 | 138 | impl Variant { 139 | #[must_use] 140 | pub fn from_name(name: &str) -> Option { 141 | let parts: Vec<&str> = name.split('-').collect(); 142 | if parts.len() < 5 { 143 | return None; 144 | } 145 | // Format: torch{major}{minor}-{cxxabi}-{compute_framework}-{arch}-{os} 146 | Some(Variant { 
147 | torch_version: parts[0].to_string(), 148 | cxx_abi: parts[1].to_string(), 149 | compute_framework: parts[2].to_string(), 150 | arch: parts[3].to_string(), 151 | os: parts[4].to_string(), 152 | }) 153 | } 154 | } 155 | 156 | #[derive(Serialize)] 157 | pub struct RepoErrorResponse { 158 | pub repository: String, 159 | pub status: String, 160 | pub error: String, 161 | } 162 | 163 | #[derive(Serialize)] 164 | pub struct RepositoryCheckResult { 165 | pub repository: String, 166 | pub status: String, 167 | pub build_status: BuildStatus, 168 | pub abi_status: AbiStatus, 169 | } 170 | 171 | #[derive(Serialize)] 172 | pub struct BuildStatus { 173 | pub summary: String, 174 | pub cuda: CudaStatus, 175 | #[serde(skip_serializing_if = "Option::is_none")] 176 | pub rocm: Option, 177 | } 178 | 179 | #[derive(Serialize)] 180 | pub struct CudaStatus { 181 | pub compatible: bool, 182 | pub present: Vec, 183 | pub missing: Vec, 184 | } 185 | 186 | #[derive(Serialize)] 187 | pub struct RocmStatus { 188 | pub compatible: bool, 189 | pub present: Vec, 190 | pub missing: Vec, 191 | } 192 | 193 | #[derive(Serialize)] 194 | pub struct AbiStatus { 195 | pub compatible: bool, 196 | pub manylinux_version: String, 197 | pub python_abi_version: String, 198 | pub variants: Vec, 199 | } 200 | 201 | #[derive(Serialize)] 202 | pub struct VariantCheckOutput { 203 | pub name: String, 204 | pub compatible: bool, 205 | pub has_shared_objects: bool, 206 | pub violations: Vec, 207 | } 208 | 209 | #[derive(Debug, Clone, Serialize, Deserialize)] 210 | pub struct SharedObjectViolation { 211 | pub message: String, 212 | } 213 | 214 | #[derive(Debug, Clone, Serialize, Deserialize)] 215 | pub struct VariantResult { 216 | pub name: String, 217 | pub is_compatible: bool, 218 | pub violations: Vec, 219 | pub has_shared_objects: bool, 220 | } 221 | 222 | #[derive(Debug, Clone)] 223 | pub struct AbiCheckResult { 224 | pub overall_compatible: bool, 225 | pub variants: Vec, 226 | pub manylinux_version: 
String, 227 | pub python_abi_version: kernel_abi_check::Version, 228 | } 229 | -------------------------------------------------------------------------------- /lib/build-version.nix: -------------------------------------------------------------------------------- 1 | { 2 | gpu, 3 | pkgs, 4 | torch, 5 | upstreamVariant, 6 | }: 7 | let 8 | inherit (pkgs) lib; 9 | inherit (import ./version-utils.nix { inherit lib; }) flattenVersion abiString; 10 | abi = torch: abiString torch.passthru.cxx11Abi; 11 | targetPlatform = pkgs.stdenv.targetPlatform.system; 12 | cudaVersion = torch: "cu${flattenVersion torch.cudaPackages.cudaMajorMinorVersion}"; 13 | rocmVersion = 14 | torch: "rocm${flattenVersion (lib.versions.majorMinor torch.rocmPackages.rocm.version)}"; 15 | gpuVersion = torch: (if torch.cudaSupport then cudaVersion else rocmVersion) torch; 16 | torchVersion = torch: flattenVersion torch.version; 17 | in 18 | if pkgs.stdenv.hostPlatform.isDarwin then 19 | "torch${torchVersion torch}-metal-${targetPlatform}" 20 | else 21 | "torch${torchVersion torch}-${abi torch}-${gpuVersion torch}-${targetPlatform}" 22 | -------------------------------------------------------------------------------- /lib/buildsets.nix: -------------------------------------------------------------------------------- 1 | { 2 | nixpkgs, 3 | system, 4 | hf-nix, 5 | }: 6 | 7 | let 8 | inherit (nixpkgs) lib; 9 | 10 | overlay = import ../overlay.nix; 11 | 12 | # Get versions. 13 | inherit (import ../versions.nix { inherit lib; }) buildConfigs cudaVersions rocmVersions; 14 | 15 | flattenVersion = version: lib.replaceStrings [ "." ] [ "_" ] (lib.versions.pad 2 version); 16 | 17 | # An overlay that overrides CUDA to the given version.
18 | overlayForCudaVersion = cudaVersion: self: super: { 19 | cudaPackages = super."cudaPackages_${flattenVersion cudaVersion}"; 20 | }; 21 | 22 | overlayForRocmVersion = rocmVersion: self: super: { 23 | rocmPackages = super."rocmPackages_${flattenVersion rocmVersion}"; 24 | }; 25 | 26 | # Construct the nixpkgs package set for the given versions. 27 | pkgsForVersions = 28 | pkgsByCudaVer: 29 | { 30 | gpu, 31 | cudaVersion ? "", 32 | metal ? false, 33 | rocmVersion ? "", 34 | torchVersion, 35 | cxx11Abi, 36 | upstreamVariant ? false, 37 | }: 38 | let 39 | pkgs = 40 | if gpu == "cuda" then 41 | pkgsByCudaVer.${cudaVersion} 42 | else if gpu == "rocm" then 43 | pkgsByRocmVer.${rocmVersion} 44 | else if gpu == "metal" then 45 | pkgsForMetal 46 | else 47 | throw "Unknown compute framework: ${gpu}"; 48 | torch = pkgs.python3.pkgs."torch_${flattenVersion torchVersion}".override { 49 | inherit cxx11Abi; 50 | }; 51 | in 52 | { 53 | inherit 54 | gpu 55 | pkgs 56 | torch 57 | upstreamVariant 58 | ; 59 | }; 60 | 61 | pkgsForMetal = import nixpkgs { 62 | inherit system; 63 | overlays = [ 64 | hf-nix 65 | overlay 66 | ]; 67 | }; 68 | 69 | pkgsForRocm = import nixpkgs { 70 | inherit system; 71 | config = { 72 | allowUnfree = true; 73 | rocmSupport = true; 74 | }; 75 | overlays = [ 76 | hf-nix 77 | overlay 78 | ]; 79 | }; 80 | 81 | # Instantiate nixpkgs for the given CUDA versions. Returns 82 | # an attribute set like `{ "12.4" = ; ... }`. 
83 | pkgsForCudaVersions = 84 | cudaVersions: 85 | builtins.listToAttrs ( 86 | map (cudaVersion: { 87 | name = cudaVersion; 88 | value = import nixpkgs { 89 | inherit system; 90 | config = { 91 | allowUnfree = true; 92 | cudaSupport = true; 93 | }; 94 | overlays = [ 95 | hf-nix 96 | overlay 97 | (overlayForCudaVersion cudaVersion) 98 | ]; 99 | }; 100 | }) cudaVersions 101 | ); 102 | 103 | pkgsByCudaVer = pkgsForCudaVersions cudaVersions; 104 | 105 | pkgsForRocmVersions = 106 | rocmVersions: 107 | builtins.listToAttrs ( 108 | map (rocmVersion: { 109 | name = rocmVersion; 110 | value = import nixpkgs { 111 | inherit system; 112 | config = { 113 | allowUnfree = true; 114 | rocmSupport = true; 115 | }; 116 | overlays = [ 117 | hf-nix 118 | overlay 119 | (overlayForRocmVersion rocmVersion) 120 | ]; 121 | }; 122 | }) rocmVersions 123 | ); 124 | 125 | pkgsByRocmVer = pkgsForRocmVersions rocmVersions; 126 | 127 | in 128 | map (pkgsForVersions pkgsByCudaVer) (buildConfigs system) 129 | -------------------------------------------------------------------------------- /lib/deps.nix: -------------------------------------------------------------------------------- 1 | { 2 | lib, 3 | }: 4 | 5 | { 6 | pkgs, 7 | torch, 8 | deps, 9 | }: 10 | 11 | let 12 | knownDeps = with pkgs.cudaPackages; { 13 | "cutlass_2_10" = [ 14 | pkgs.cutlass_2_10 15 | ]; 16 | "cutlass_3_5" = [ 17 | pkgs.cutlass_3_5 18 | ]; 19 | "cutlass_3_6" = [ 20 | pkgs.cutlass_3_6 21 | ]; 22 | "cutlass_3_8" = [ 23 | pkgs.cutlass_3_8 24 | ]; 25 | "torch" = [ 26 | torch 27 | torch.cxxdev 28 | ]; 29 | }; 30 | 31 | in 32 | let 33 | depToPkg = 34 | dep: 35 | assert lib.assertMsg (builtins.hasAttr dep knownDeps) "Unknown dependency: ${dep}"; 36 | knownDeps.${dep}; 37 | in 38 | lib.flatten (map depToPkg deps) 39 | -------------------------------------------------------------------------------- /lib/join-paths/default.nix: -------------------------------------------------------------------------------- 1 | args@{ 2 | pkgs, 3 | 4 
| name, 5 | 6 | # Attribute set with names to paths. 7 | namePaths, 8 | 9 | preferLocalBuild ? true, 10 | allowSubstitutes ? false, 11 | }: 12 | let 13 | inherit (pkgs) lib; 14 | args_ = removeAttrs args [ 15 | "name" 16 | "pkgs" 17 | "namePaths" 18 | ]; 19 | # Iterating over pairs in bash sucks, so let's generate 20 | # the commands in Nix instead. 21 | copyPath = path: pkg: '' 22 | mkdir -p ${placeholder "out"}/${path} 23 | cp -r ${pkg}/* ${placeholder "out"}/${path} 24 | ''; 25 | in 26 | pkgs.runCommand name args_ (lib.concatStringsSep "\n" (lib.mapAttrsToList copyPath namePaths)) 27 | -------------------------------------------------------------------------------- /lib/source-set.nix: -------------------------------------------------------------------------------- 1 | { lib }: 2 | 3 | path: 4 | 5 | let 6 | inherit (lib) fileset; 7 | readToml = path: builtins.fromTOML (builtins.readFile path); 8 | readBuildConfig = path: readToml (path + "/build.toml"); 9 | buildConfig = readBuildConfig path; 10 | nameToPath = path: name: path + "/${name}"; 11 | kernels = buildConfig.kernel or { }; 12 | extConfig = buildConfig.torch or { }; 13 | pyExt = 14 | extConfig.pyext or [ 15 | "py" 16 | "pyi" 17 | ]; 18 | pyFilter = file: builtins.any (ext: file.hasExt ext) pyExt; 19 | extSrc = extConfig.src or [ ] ++ [ "build.toml" ]; 20 | pySrcSet = fileset.fileFilter pyFilter (path + "/torch-ext"); 21 | kernelsSrc = fileset.unions ( 22 | lib.flatten (lib.mapAttrsToList (name: buildConfig: map (nameToPath path) buildConfig.src) kernels) 23 | ); 24 | srcSet = fileset.unions (map (nameToPath path) extSrc); 25 | in 26 | fileset.toSource { 27 | root = path; 28 | fileset = fileset.unions [ 29 | kernelsSrc 30 | srcSet 31 | pySrcSet 32 | ]; 33 | } 34 | -------------------------------------------------------------------------------- /lib/torch-extension-noarch/default.nix: -------------------------------------------------------------------------------- 1 | { 2 | stdenv, 3 | extensionName, 4 | 
rev, 5 | 6 | build2cmake, 7 | torch, 8 | 9 | src, 10 | }: 11 | 12 | stdenv.mkDerivation (prevAttrs: { 13 | name = "${extensionName}-torch-ext"; 14 | 15 | inherit src; 16 | 17 | # Add Torch as a dependency, so that devshells for universal kernels 18 | # also get torch as a build input. 19 | buildInputs = [ torch ]; 20 | 21 | nativeBuildInputs = [ build2cmake ]; 22 | 23 | dontBuild = true; 24 | 25 | # We do not strictly need this, since we don't use the setuptools-based 26 | # build. But `build2cmake` does proper validation of the build.toml, so 27 | # we run it anyway. 28 | postPatch = '' 29 | build2cmake generate-torch --ops-id ${rev} build.toml 30 | ''; 31 | 32 | installPhase = '' 33 | mkdir -p $out 34 | cp -r torch-ext/${extensionName} $out/ 35 | ''; 36 | }) 37 | -------------------------------------------------------------------------------- /lib/torch-extension/_ops.py.in: -------------------------------------------------------------------------------- 1 | import torch 2 | from . import _@EXTENSION_NAME@ 3 | ops = torch.ops._@EXTENSION_NAME@ 4 | 5 | def add_op_namespace_prefix(op_name: str): 6 | """ 7 | Prefix op by namespace. 8 | """ 9 | return f"_@EXTENSION_NAME@::{op_name}" 10 | -------------------------------------------------------------------------------- /lib/torch-extension/default.nix: -------------------------------------------------------------------------------- 1 | { 2 | extensionName, 3 | nvccThreads, 4 | rev, 5 | 6 | # Whether to strip rpath for non-Nix use. 7 | stripRPath ? false, 8 | 9 | src, 10 | 11 | config, 12 | cudaSupport ? config.cudaSupport, 13 | rocmSupport ? config.rocmSupport, 14 | 15 | lib, 16 | stdenv, 17 | cudaPackages, 18 | cmake, 19 | cmakeNvccThreadsHook, 20 | ninja, 21 | python3, 22 | kernel-abi-check, 23 | build2cmake, 24 | rocmPackages, 25 | 26 | apple-sdk_15, 27 | extraDeps ?
[ ], 28 | torch, 29 | 30 | doAbiCheck, 31 | }: 32 | 33 | let 34 | # CLR that uses the provided stdenv, which can be different from the default 35 | # to support old glibc/libstdc++ versions. 36 | clr = ( 37 | rocmPackages.clr.override { 38 | clang = rocmPackages.llvm.clang.override { 39 | inherit stdenv; 40 | bintools = rocmPackages.llvm.bintools.override { libc = stdenv.cc.libc; }; 41 | glibc = stdenv.cc.libc; 42 | }; 43 | } 44 | ); 45 | 46 | in 47 | stdenv.mkDerivation (prevAttrs: { 48 | name = "${extensionName}-torch-ext"; 49 | 50 | inherit doAbiCheck nvccThreads src; 51 | 52 | # Generate build files. 53 | postPatch = '' 54 | build2cmake generate-torch --backend ${ 55 | if cudaSupport then 56 | "cuda" 57 | else if rocmSupport then 58 | "rocm" 59 | else 60 | "metal" 61 | } --ops-id ${rev} build.toml 62 | ''; 63 | 64 | # hipify copies files, but its target is run in the CMake build and install 65 | # phases. Since some of the files come from the Nix store, this fails the 66 | # second time around. 67 | preInstall = '' 68 | chmod -R u+w . 69 | ''; 70 | 71 | nativeBuildInputs = 72 | [ 73 | kernel-abi-check 74 | cmake 75 | ninja 76 | build2cmake 77 | ] 78 | ++ lib.optionals cudaSupport [ 79 | cmakeNvccThreadsHook 80 | cudaPackages.cuda_nvcc 81 | ] 82 | ++ lib.optionals rocmSupport [ 83 | clr 84 | ]; 85 | 86 | buildInputs = 87 | [ 88 | torch 89 | torch.cxxdev 90 | ] 91 | ++ lib.optionals cudaSupport ( 92 | with cudaPackages; 93 | [ 94 | cuda_cudart 95 | 96 | # Make dependent on build configuration dependencies once 97 | # the Torch dependency is gone. 
98 | cuda_cccl 99 | libcublas 100 | libcusolver 101 | libcusparse 102 | ] 103 | ) 104 | #++ lib.optionals rocmSupport (with rocmPackages; [ clr rocm-core ]) 105 | ++ lib.optionals stdenv.hostPlatform.isDarwin [ 106 | apple-sdk_15 107 | ] 108 | ++ extraDeps; 109 | 110 | env = 111 | lib.optionalAttrs cudaSupport { 112 | CUDAToolkit_ROOT = "${lib.getDev cudaPackages.cuda_nvcc}"; 113 | TORCH_CUDA_ARCH_LIST = 114 | if cudaPackages.cudaOlder "12.8" then 115 | "7.0;7.5;8.0;8.6;8.9;9.0+PTX" 116 | else 117 | "7.0;7.5;8.0;8.6;8.9;9.0;10.0;10.1;12.0+PTX"; 118 | } 119 | // lib.optionalAttrs rocmSupport { 120 | PYTORCH_ROCM_ARCH = lib.concatStringsSep ";" torch.rocmArchs; 121 | }; 122 | 123 | # If we use the default setup, CMAKE_CUDA_HOST_COMPILER gets set to nixpkgs g++. 124 | dontSetupCUDAToolkitCompilers = true; 125 | 126 | cmakeFlags = 127 | [ 128 | (lib.cmakeFeature "Python_EXECUTABLE" "${python3.withPackages (ps: [ torch ])}/bin/python") 129 | ] 130 | ++ lib.optionals cudaSupport [ 131 | (lib.cmakeFeature "CMAKE_CUDA_HOST_COMPILER" "${stdenv.cc}/bin/g++") 132 | ] 133 | ++ lib.optionals rocmSupport [ 134 | # Make sure that we use HIP from our CLR override and not HIP from 135 | # the symlink-joined ROCm toolkit. 136 | (lib.cmakeFeature "CMAKE_HIP_COMPILER_ROCM_ROOT" "${clr}") 137 | (lib.cmakeFeature "HIP_ROOT_DIR" "${clr}") 138 | ]; 139 | 140 | postInstall = 141 | '' 142 | ( 143 | cd ..
144 | cp -r torch-ext/${extensionName} $out/ 145 | ) 146 | cp $out/_${extensionName}_*/* $out/${extensionName} 147 | rm -rf $out/_${extensionName}_* 148 | '' 149 | + lib.optionalString stripRPath '' 150 | find $out/${extensionName} -name '*.so' \ 151 | -exec patchelf --set-rpath "" {} \; 152 | ''; 153 | 154 | doInstallCheck = true; 155 | 156 | passthru = { 157 | inherit torch; 158 | }; 159 | }) 160 | -------------------------------------------------------------------------------- /lib/version-utils.nix: -------------------------------------------------------------------------------- 1 | { lib }: 2 | 3 | { 4 | flattenVersion = version: lib.replaceStrings [ "." ] [ "" ] (lib.versions.pad 2 version); 5 | abiString = cxx11Abi: if cxx11Abi then "cxx11" else "cxx98"; 6 | } 7 | -------------------------------------------------------------------------------- /overlay.nix: -------------------------------------------------------------------------------- 1 | final: prev: { 2 | cmakeNvccThreadsHook = prev.callPackage ./pkgs/cmake-nvcc-threads-hook { }; 3 | 4 | # Local packages 5 | 6 | kernel-abi-check = prev.callPackage ./pkgs/kernel-abi-check { }; 7 | 8 | build2cmake = prev.callPackage ./pkgs/build2cmake { }; 9 | 10 | stdenvGlibc_2_27 = prev.callPackage ./pkgs/stdenv-glibc-2_27 { }; 11 | } 12 | -------------------------------------------------------------------------------- /pkgs/build2cmake/default.nix: -------------------------------------------------------------------------------- 1 | { 2 | lib, 3 | rustPlatform, 4 | pkg-config, 5 | libgit2, 6 | openssl, 7 | }: 8 | 9 | let 10 | version = (builtins.fromTOML (builtins.readFile ../../build2cmake/Cargo.toml)).package.version; 11 | in 12 | rustPlatform.buildRustPackage { 13 | inherit version; 14 | pname = "build2cmake"; 15 | 16 | src = 17 | let 18 | sourceFiles = 19 | file: 20 | file.name == "Cargo.toml" 21 | || file.name == "Cargo.lock" 22 | || file.name == "pyproject.toml" 23 | || file.name == "pyproject_universal.toml" 24 | 
|| file.name == "cuda_supported_archs.json" 25 | || (builtins.any file.hasExt [ 26 | "cmake" 27 | "h" 28 | "py" 29 | "rs" 30 | ]); 31 | in 32 | lib.fileset.toSource { 33 | root = ../../build2cmake; 34 | fileset = lib.fileset.fileFilter sourceFiles ../../build2cmake; 35 | }; 36 | 37 | cargoLock = { 38 | lockFile = ../../build2cmake/Cargo.lock; 39 | }; 40 | 41 | nativeBuildInputs = [ pkg-config ]; 42 | 43 | buildInputs = [ 44 | libgit2 45 | openssl.dev 46 | ]; 47 | 48 | meta = { 49 | description = "Create cmake build infrastructure from build.toml files"; 50 | }; 51 | } 52 | -------------------------------------------------------------------------------- /pkgs/cmake-nvcc-threads-hook/cmake-nvcc-threads-hook.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | _setNvccThreadsHook() { 4 | if [ -z "${nvccThreads}" ] || [ "${nvccThreads}" -ne "${nvccThreads}" ] 2>/dev/null; then 5 | >&2 echo "Number of nvcc threads is not (correctly) set, setting to 4" 6 | nvccThreads=4 7 | fi 8 | 9 | # Ensure that we do not use more threads than build cores. 10 | nvccThreads=$((NIX_BUILD_CORES < nvccThreads ? NIX_BUILD_CORES : nvccThreads )) 11 | 12 | # Change the number of build cores so that build cores * threads is 13 | # within bounds. 
14 | export NIX_BUILD_CORES=$(($NIX_BUILD_CORES / nvccThreads)) 15 | 16 | appendToVar cmakeFlags -DNVCC_THREADS="${nvccThreads}" 17 | } 18 | 19 | preConfigureHooks+=(_setNvccThreadsHook) 20 | -------------------------------------------------------------------------------- /pkgs/cmake-nvcc-threads-hook/default.nix: -------------------------------------------------------------------------------- 1 | { makeSetupHook }: 2 | 3 | makeSetupHook { 4 | name = "cmake-nvcc-threads-hook"; 5 | } ./cmake-nvcc-threads-hook.sh 6 | -------------------------------------------------------------------------------- /pkgs/kernel-abi-check/default.nix: -------------------------------------------------------------------------------- 1 | { 2 | lib, 3 | rustPlatform, 4 | }: 5 | 6 | let 7 | version = (builtins.fromTOML (builtins.readFile ../../kernel-abi-check/Cargo.toml)).package.version; 8 | in 9 | rustPlatform.buildRustPackage { 10 | inherit version; 11 | pname = "kernel-abi-check"; 12 | 13 | src = 14 | let 15 | sourceFiles = 16 | file: 17 | file.name == "Cargo.toml" 18 | || file.name == "Cargo.lock" 19 | || file.name == "manylinux-policy.json" 20 | || file.hasExt "rs" 21 | || file.name == "stable_abi.toml"; 22 | in 23 | lib.fileset.toSource { 24 | root = ../../kernel-abi-check; 25 | fileset = lib.fileset.fileFilter sourceFiles ../../kernel-abi-check; 26 | }; 27 | 28 | cargoLock = { 29 | lockFile = ../../kernel-abi-check/Cargo.lock; 30 | }; 31 | 32 | setupHook = ./kernel-abi-check-hook.sh; 33 | 34 | meta = { 35 | description = "Check glibc and libstdc++ ABI compat"; 36 | }; 37 | } 38 | -------------------------------------------------------------------------------- /pkgs/kernel-abi-check/kernel-abi-check-hook.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | _checkAbiHook() { 4 | if [ -z "${doAbiCheck:-}" ]; then 5 | echo "Skipping ABI check" 6 | else 7 | echo "Checking of ABI compatibility" 8 | find $out/${extensionName} -name 
'*.so' \ 9 | -exec kernel-abi-check {} \; 10 | fi 11 | } 12 | 13 | postInstallCheckHooks+=(_checkAbiHook) 14 | -------------------------------------------------------------------------------- /pkgs/stdenv-glibc-2_27/default.nix: -------------------------------------------------------------------------------- 1 | { 2 | config, 3 | cudaSupport ? config.cudaSupport, 4 | fetchFromGitHub, 5 | overrideCC, 6 | system, 7 | wrapBintoolsWith, 8 | wrapCCWith, 9 | gcc12Stdenv, 10 | stdenv, 11 | bintools-unwrapped, 12 | cudaPackages, 13 | libgcc, 14 | }: 15 | 16 | let 17 | nixpkgs_20191230 = import (fetchFromGitHub { 18 | owner = "NixOS"; 19 | repo = "nixpkgs"; 20 | rev = "a9eb3eed170fa916e0a8364e5227ee661af76fde"; 21 | hash = "sha256-1ycrr9HMrGA3ZDM8qmKcZICBupE5UShnIIhPRWdvAzA="; 22 | }) { inherit system; }; 23 | 24 | glibc_2_27 = nixpkgs_20191230.glibc.overrideAttrs (prevAttrs: { 25 | # Slight adjustments for compatibility with modern nixpkgs: 26 | # 27 | # - pname is required 28 | # - an additional getent output 29 | # - passthru libgcc 30 | 31 | pname = "glibc"; 32 | 33 | outputs = prevAttrs.outputs ++ [ "getent" ]; 34 | 35 | postInstall = 36 | prevAttrs.postInstall 37 | + '' 38 | install -Dm755 $bin/bin/getent -t $getent/bin 39 | ''; 40 | 41 | passthru = prevAttrs.passthru // { 42 | # Should be stdenv's gcc, but we don't have access to it. 43 | libgcc = stdenv.cc.cc.libgcc; 44 | }; 45 | }); 46 | 47 | stdenvWith = 48 | newGlibc: newGcc: stdenv: 49 | let 50 | # We need gcc to have a libgcc/libstdc++ that is compatible with 51 | # glibc. We do this in three steps to avoid an infinite recursion: 52 | # (1) we create an stdenv with gcc and glibc; (2) we rebuild gcc using 53 | # this stdenv, so that we have a libgcc/libstdc++ that is compatible 54 | # with glibc; (3) we create the final stdenv that contains the compatible 55 | # gcc + glibc. 
56 | onlyGlibc = overrideCC stdenv (wrapCCWith { 57 | cc = newGcc; 58 | bintools = wrapBintoolsWith { 59 | bintools = bintools-unwrapped; 60 | libc = newGlibc; 61 | }; 62 | }); 63 | compilerWrapped = wrapCCWith rec { 64 | cc = newGcc.override { stdenv = onlyGlibc; }; 65 | bintools = wrapBintoolsWith { 66 | bintools = bintools-unwrapped; 67 | libc = newGlibc; 68 | }; 69 | libcxx = cc.lib; 70 | }; 71 | in 72 | overrideCC stdenv compilerWrapped; 73 | 74 | in 75 | stdenvWith glibc_2_27 (if cudaSupport then cudaPackages.backendStdenv else gcc12Stdenv).cc.cc stdenv 76 | -------------------------------------------------------------------------------- /scripts/gen_variants_markdown.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env nix-shell 2 | #!nix-shell -i python3 -p python3 3 | import json 4 | from pathlib import Path 5 | 6 | _PLATFORM_NAMES = { 7 | "cuda": "CUDA", 8 | "rocm": "ROCm", 9 | } 10 | 11 | HEADER = """# Build variants 12 | 13 | A kernel can be compliant for a specific compute framework (e.g. CUDA) or 14 | architecture (e.g. x86_64). For compliance with a compute framework and 15 | architecture combination, all the build variants listed below must be 16 | available. This list will be updated as new PyTorch versions are released.\n 17 | """ 18 | 19 | FOOTER = """## Universal 20 | 21 | Kernels that are in pure Python (e.g. 
Triton kernels) only need to provide 22 | a single build variant: 23 | 24 | - `torch-universal` 25 | """ 26 | 27 | 28 | def json_to_markdown(): 29 | project_root = Path(__file__).parent.parent 30 | 31 | with open(project_root / "build-variants.json", "r") as f: 32 | data = json.load(f) 33 | 34 | with open(project_root / "docs" / "build-variants.md", "w") as f: 35 | f.write(HEADER) 36 | for arch, platforms in data.items(): 37 | for platform, variants in platforms.items(): 38 | f.write(f"## {_PLATFORM_NAMES[platform]} {arch}\n\n") 39 | 40 | for variant in variants: 41 | f.write(f"- `{variant}`\n") 42 | 43 | f.write("\n") 44 | f.write(FOOTER) 45 | 46 | 47 | if __name__ == "__main__": 48 | json_to_markdown() 49 | -------------------------------------------------------------------------------- /scripts/init-kernel.py: -------------------------------------------------------------------------------- 1 | # This script creates the necessary files for a new kernel example in the specified directory. 2 | # 3 | # Example Usage: 4 | # $ uv run scripts/init-kernel.py relu 5 | # 6 | # Created directory: relu 7 | # 8 | # relu/ 9 | # ├── relu_kernel/ 10 | # │ └── relu.cu 11 | # ├── tests/ 12 | # │ ├── __init__.py 13 | # │ └── test_relu.py 14 | # ├── torch-ext/ 15 | # │ ├── relu/ 16 | # │ │ └── __init__.py 17 | # │ ├── torch_binding.cpp 18 | # │ └── torch_binding.h 19 | # ├── build.toml 20 | # └── flake.nix 21 | # 22 | # ✓ Success! All files for the ReLU example have been created successfully. 23 | # 24 | # Next steps: 25 | # 1. Build the kernel: cd relu && git add . && nix develop -L 26 | # 2. 
Run the tests: pytest -vv tests/ 27 | 28 | import os 29 | import argparse 30 | import pathlib 31 | 32 | 33 | class Colors: 34 | HEADER = "\033[95m" 35 | BLUE = "\033[94m" 36 | CYAN = "\033[96m" 37 | GREEN = "\033[92m" 38 | YELLOW = "\033[93m" 39 | RED = "\033[91m" 40 | ENDC = "\033[0m" 41 | BOLD = "\033[1m" 42 | UNDERLINE = "\033[4m" 43 | GREY = "\033[90m" 44 | 45 | 46 | def create_file_with_content(file_path: str, content: str): 47 | """Creates a file at 'file_path' with the specified content.""" 48 | directory = os.path.dirname(file_path) 49 | if directory and not os.path.exists(directory): 50 | os.makedirs(directory) 51 | 52 | with open(file_path, "w") as f: 53 | f.write(content) 54 | 55 | 56 | # Generate a tree view of the created files 57 | def print_tree(directory: str, prefix: str = ""): 58 | entries = sorted(os.listdir(directory)) 59 | 60 | # Process directories first, then files 61 | dirs = [e for e in entries if os.path.isdir(os.path.join(directory, e))] 62 | files = [e for e in entries if os.path.isfile(os.path.join(directory, e))] 63 | 64 | # Process all items except the last one 65 | count = len(dirs) + len(files) 66 | 67 | # Print directories 68 | for i, dirname in enumerate(dirs): 69 | is_last_dir = i == len(dirs) - 1 and len(files) == 0 70 | connector = "└── " if is_last_dir else "├── " 71 | print( 72 | f" {prefix}{connector}{Colors.BOLD}{Colors.BLUE}{dirname}/{Colors.ENDC}" 73 | ) 74 | 75 | # Prepare the prefix for the next level 76 | next_prefix = prefix + (" " if is_last_dir else "│ ") 77 | print_tree(os.path.join(directory, dirname), next_prefix) 78 | 79 | # Print files 80 | for i, filename in enumerate(files): 81 | is_last = i == len(files) - 1 82 | connector = "└── " if is_last else "├── " 83 | file_color = "" 84 | 85 | print(f" {prefix}{connector}{file_color}{filename}{Colors.ENDC}") 86 | 87 | 88 | def main(): 89 | # Get the directory where this script is located 90 | script_dir = pathlib.Path(__file__).parent.resolve().parent.resolve() 91 | 
92 | # Create argument parser 93 | parser = argparse.ArgumentParser( 94 | description="Create ReLU example files in the specified directory" 95 | ) 96 | parser.add_argument( 97 | "target_dir", help="Target directory where files will be created" 98 | ) 99 | args = parser.parse_args() 100 | 101 | # Get the target directory from arguments 102 | target_dir = args.target_dir 103 | 104 | # Create the target directory if it doesn't exist 105 | if not os.path.exists(target_dir): 106 | os.makedirs(target_dir) 107 | print( 108 | f"\n{Colors.CYAN}{Colors.BOLD}Created directory: {Colors.BOLD}{target_dir}{Colors.ENDC}\n" 109 | ) 110 | else: 111 | print( 112 | f"\n{Colors.CYAN}{Colors.BOLD}Directory already exists: {Colors.BOLD}{target_dir}{Colors.ENDC}\n" 113 | ) 114 | 115 | # get files from examples/relu 116 | relu_dir = script_dir / "examples" / "relu" 117 | for root, _, files in os.walk(relu_dir): 118 | for file in files: 119 | file_path = os.path.join(root, file) 120 | with open(file_path, "r") as f: 121 | content = f.read() 122 | 123 | # Replace kernel-builder.url with path:../ in flake.nix 124 | if file_path.endswith("flake.nix"): 125 | kernel_builder_url_start = content.find("kernel-builder.url =") 126 | kernel_builder_url_end = content.find(";", kernel_builder_url_start) 127 | content = ( 128 | content[:kernel_builder_url_start] 129 | + 'kernel-builder.url = "path:../"' 130 | + content[kernel_builder_url_end:] 131 | ) 132 | 133 | target_file = file_path.replace(str(relu_dir), target_dir) 134 | create_file_with_content(target_file, content) 135 | 136 | print(f" {Colors.BOLD}{target_dir}/{Colors.ENDC}") 137 | print_tree(target_dir) 138 | 139 | print( 140 | f"\n{Colors.GREEN}{Colors.BOLD}✓ Success!{Colors.ENDC} All files for the ReLU example have been created successfully." 
141 | ) 142 | 143 | print(f"\n{Colors.CYAN}{Colors.BOLD}Next steps:{Colors.ENDC}") 144 | 145 | commands = [ 146 | "nix run nixpkgs#cachix -- use huggingface", 147 | f"cd {target_dir}", 148 | "git add .", 149 | "nix develop -L", 150 | ] 151 | 152 | for index, command in enumerate(commands, start=1): 153 | print( 154 | f" {Colors.YELLOW}{index}.{Colors.ENDC} {Colors.BOLD}{command}{Colors.ENDC}" 155 | ) 156 | 157 | print( 158 | f"\none line build:\n{Colors.GREY}{Colors.BOLD}{' && '.join(commands)}{Colors.ENDC}{Colors.ENDC}" 159 | ) 160 | 161 | print(f"\n{Colors.CYAN}{Colors.BOLD}Run the tests{Colors.ENDC}") 162 | print( 163 | f" {Colors.YELLOW}{1}.{Colors.ENDC} {Colors.BOLD}pytest -vv tests/{Colors.ENDC}" 164 | ) 165 | 166 | print("") 167 | 168 | 169 | if __name__ == "__main__": 170 | main() 171 | -------------------------------------------------------------------------------- /tests/Dockerfile.test-kernel: -------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:1.4 2 | ARG PYTHON_VERSION=3.10 3 | # Ideally we'd test with 11.8, but the GELU kernel is subtly off. 
4 | ARG CUDA_VERSION=12.1.0 5 | ARG UBUNTU_VERSION=18.04 6 | ARG TORCH_VERSION=2.5.0 7 | 8 | FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} as base 9 | 10 | # Set environment variables 11 | ENV DEBIAN_FRONTEND=noninteractive \ 12 | PYTHONUNBUFFERED=1 \ 13 | PATH="/root/.local/bin:/root/.cargo/bin:${PATH}" \ 14 | NVIDIA_VISIBLE_DEVICES=all \ 15 | NVIDIA_DRIVER_CAPABILITIES=compute,utility 16 | 17 | # Install system dependencies 18 | RUN apt-get update && apt-get install -y --no-install-recommends \ 19 | curl \ 20 | python3 \ 21 | python3-pip \ 22 | && rm -rf /var/lib/apt/lists/* 23 | 24 | # Install uv package manager 25 | RUN curl -LsSf https://astral.sh/uv/install.sh | sh 26 | 27 | # Set working directory 28 | WORKDIR /app 29 | 30 | # Need to re-declare ARG after FROM for use in RUN 31 | ARG CUDA_VERSION 32 | ARG TORCH_VERSION 33 | ARG PYTHON_VERSION 34 | 35 | RUN echo "Building with CUDA_VERSION=${CUDA_VERSION}, TORCH_VERSION=${TORCH_VERSION}, PYTHON_VERSION=${PYTHON_VERSION}" 36 | 37 | # Initialize uv and create virtual env 38 | RUN uv init --app kernel-test --python "${PYTHON_VERSION}" 39 | 40 | # Move into the app 41 | WORKDIR /app/kernel-test 42 | 43 | # Install PyTorch with the appropriate CUDA version 44 | 45 | # NOTE: `markupsafe` must be installed first to avoid a conflict with the torch package. 46 | # See: https://github.com/astral-sh/uv/issues/9647 47 | 48 | RUN CUDA_MAJOR_MINOR=$(echo ${CUDA_VERSION} | cut -d'.' 
-f1,2) && \ 49 | case ${CUDA_MAJOR_MINOR} in \ 50 | "11.8") CUDA_TAG="cu118" ;; \ 51 | "12.1") CUDA_TAG="cu121" ;; \ 52 | "12.2") CUDA_TAG="cu122" ;; \ 53 | "12.4") CUDA_TAG="cu124" ;; \ 54 | *) CUDA_TAG="" ;; \ 55 | esac && \ 56 | if [ -n "${CUDA_TAG}" ]; then \ 57 | echo "Installing PyTorch ${TORCH_VERSION} with CUDA ${CUDA_TAG}" && \ 58 | uv add markupsafe --default-index "https://pypi.org/simple" && \ 59 | uv add "torch==${TORCH_VERSION}" --index-url "https://download.pytorch.org/whl/${CUDA_TAG}"; \ 60 | else \ 61 | echo "Installing PyTorch ${TORCH_VERSION} without CUDA-specific index" && \ 62 | uv add "torch==${TORCH_VERSION}"; \ 63 | fi 64 | 65 | # add pytest for runtime tests 66 | RUN uv add numpy pytest 67 | 68 | # Copy kernels and tests 69 | COPY activation-kernel ./activation-kernel 70 | COPY cutlass-gemm-kernel ./cutlass-gemm-kernel 71 | COPY silu-and-mul-universal-kernel ./silu-and-mul-universal-kernel 72 | COPY examples/activation/tests ./activation_tests 73 | COPY examples/cutlass-gemm/tests ./tests/cutlass_gemm_tests 74 | 75 | # Run tests 76 | ENV PYTHONPATH="activation-kernel:cutlass-gemm-kernel:silu-and-mul-universal-kernel:$PYTHONPATH" 77 | CMD ["/bin/sh", "-c", ".venv/bin/pytest", "activation_tests", "cutlass_gemm_tests"] 78 | 79 | # We only care about importing, the kernel is trivial. 
80 | CMD ["/bin/sh", "-c", ".venv/bin/python", "-c", "'import silu_and_mul_universal'"] 81 | -------------------------------------------------------------------------------- /versions.nix: -------------------------------------------------------------------------------- 1 | { lib }: 2 | 3 | rec { 4 | torchVersions = [ 5 | { 6 | torchVersion = "2.6"; 7 | cudaVersion = "11.8"; 8 | cxx11Abi = false; 9 | upstreamVariant = true; 10 | } 11 | { 12 | torchVersion = "2.6"; 13 | cudaVersion = "11.8"; 14 | cxx11Abi = true; 15 | upstreamVariant = true; 16 | } 17 | { 18 | torchVersion = "2.6"; 19 | cudaVersion = "12.4"; 20 | cxx11Abi = false; 21 | upstreamVariant = true; 22 | } 23 | { 24 | torchVersion = "2.6"; 25 | cudaVersion = "12.4"; 26 | cxx11Abi = true; 27 | upstreamVariant = true; 28 | } 29 | { 30 | torchVersion = "2.6"; 31 | cudaVersion = "12.6"; 32 | cxx11Abi = false; 33 | upstreamVariant = true; 34 | } 35 | { 36 | torchVersion = "2.6"; 37 | cudaVersion = "12.6"; 38 | cxx11Abi = true; 39 | upstreamVariant = true; 40 | } 41 | { 42 | torchVersion = "2.6"; 43 | rocmVersion = "6.2.4"; 44 | cxx11Abi = true; 45 | upstreamVariant = true; 46 | } 47 | 48 | { 49 | torchVersion = "2.7"; 50 | cudaVersion = "11.8"; 51 | cxx11Abi = true; 52 | upstreamVariant = true; 53 | } 54 | { 55 | torchVersion = "2.7"; 56 | cudaVersion = "12.6"; 57 | cxx11Abi = true; 58 | upstreamVariant = true; 59 | } 60 | { 61 | torchVersion = "2.7"; 62 | cudaVersion = "12.8"; 63 | cxx11Abi = true; 64 | upstreamVariant = true; 65 | } 66 | { 67 | torchVersion = "2.7"; 68 | rocmVersion = "6.3.4"; 69 | cxx11Abi = true; 70 | upstreamVariant = true; 71 | } 72 | { 73 | torchVersion = "2.7"; 74 | cxx11Abi = true; 75 | metal = true; 76 | # Set to false for now, needs more testing. 77 | upstreamVariant = false; 78 | } 79 | 80 | # Non-standard versions; not included in bundle builds. 
81 | { 82 | torchVersion = "2.7"; 83 | cudaVersion = "12.9"; 84 | cxx11Abi = true; 85 | } 86 | ]; 87 | 88 | cudaVersions = 89 | let 90 | withCuda = builtins.filter (torchVersion: torchVersion ? cudaVersion) torchVersions; 91 | in 92 | builtins.map (torchVersion: torchVersion.cudaVersion) withCuda; 93 | 94 | rocmVersions = 95 | let 96 | withRocm = builtins.filter (torchVersion: torchVersion ? rocmVersion) torchVersions; 97 | in 98 | builtins.map (torchVersion: torchVersion.rocmVersion) withRocm; 99 | 100 | # Upstream only builds aarch64 for CUDA >= 12.6. 101 | isCudaSupported = 102 | system: torchVersion: 103 | system == "x86_64-linux" 104 | || ( 105 | system == "aarch64-linux" && lib.strings.versionAtLeast (torchVersion.cudaVersion or "0.0") "12.6" 106 | ); 107 | 108 | isMetalSupported = system: torchVersion: system == "aarch64-darwin" && torchVersion ? metal; 109 | 110 | # ROCm only builds on x86_64. 111 | isRocmSupported = system: torchVersion: system == "x86_64-linux" && torchVersion ? rocmVersion; 112 | 113 | isSupported = 114 | system: torchVersion: 115 | (isCudaSupported system torchVersion) 116 | || (isMetalSupported system torchVersion) 117 | || (isRocmSupported system torchVersion); 118 | 119 | computeFramework = 120 | buildConfig: 121 | if buildConfig ? cudaVersion then 122 | "cuda" 123 | else if buildConfig ? metal then 124 | "metal" 125 | else if buildConfig ? "rocmVersion" then 126 | "rocm" 127 | else 128 | throw "Could not find compute framework: no CUDA or ROCm version specified and Metal is not enabled"; 129 | 130 | # All build configurations supported by Torch. 131 | buildConfigs = 132 | system: 133 | let 134 | supported = builtins.filter (isSupported system) torchVersions; 135 | in 136 | map (version: version // { gpu = computeFramework version; }) supported; 137 | 138 | # Upstream build variants. 
139 | buildVariants = 140 | system: 141 | let 142 | inherit (import ./lib/version-utils.nix { inherit lib; }) abiString flattenVersion; 143 | computeString = 144 | buildConfig: 145 | if buildConfig.gpu == "cuda" then 146 | "cu${flattenVersion (lib.versions.majorMinor buildConfig.cudaVersion)}" 147 | else if buildConfig.gpu == "rocm" then 148 | "rocm${flattenVersion (lib.versions.majorMinor buildConfig.rocmVersion)}" 149 | else 150 | throw "Unknown compute framework: ${buildConfig.gpu}"; 151 | buildName = 152 | buildConfig: 153 | "torch${flattenVersion buildConfig.torchVersion}-${abiString buildConfig.cxx11Abi}-${computeString buildConfig}-${system}"; 154 | filterMap = f: xs: builtins.filter (x: x != null) (builtins.map f xs); 155 | in 156 | { 157 | ${system} = lib.zipAttrs ( 158 | filterMap ( 159 | buildConfig: 160 | if buildConfig.upstreamVariant or false then 161 | { 162 | ${buildConfig.gpu} = buildName buildConfig; 163 | } 164 | else 165 | null 166 | ) (buildConfigs system) 167 | ); 168 | }; 169 | } 170 | --------------------------------------------------------------------------------