├── .github
└── workflows
│ ├── build_kernel.yaml
│ ├── build_kernel_rocm.yaml
│ ├── check_variants.yaml
│ ├── docker-build-push.yaml
│ ├── nix_fmt.yaml
│ └── rust.yaml
├── .gitignore
├── README.md
├── build-variants.json
├── build2cmake
├── Cargo.lock
├── Cargo.toml
├── build.rs
├── flake.lock
├── flake.nix
└── src
│ ├── config
│ ├── mod.rs
│ ├── v1.rs
│ └── v2.rs
│ ├── cuda_supported_archs.json
│ ├── fileset.rs
│ ├── main.rs
│ ├── templates
│ ├── _ops.py
│ ├── cuda
│ │ ├── dep-cutlass.cmake
│ │ ├── hipify.py
│ │ ├── kernel.cmake
│ │ ├── preamble.cmake
│ │ ├── setup.py
│ │ ├── torch-binding.cmake
│ │ └── torch-extension.cmake
│ ├── metal
│ │ ├── kernel.cmake
│ │ ├── preamble.cmake
│ │ ├── setup.py
│ │ ├── torch-binding.cmake
│ │ ├── torch-extension.cmake
│ │ └── utils.cmake
│ ├── pyproject.toml
│ ├── registration.h
│ ├── universal
│ │ ├── _ops.py
│ │ └── pyproject.toml
│ └── utils.cmake
│ └── torch
│ ├── cuda.rs
│ ├── metal.rs
│ ├── mod.rs
│ ├── ops_identifier.rs
│ └── universal.rs
├── default.nix
├── dockerfiles
├── Dockerfile
├── Dockerfile.user
└── README.md
├── docs
├── build-variants.md
├── docker.md
├── local-dev.md
├── nix.md
├── toolchain.md
├── why-nix.md
└── writing-kernels.md
├── examples
├── activation
│ ├── LICENSE
│ ├── README.md
│ ├── activation
│ │ ├── activation_kernels.cu
│ │ ├── cuda_compat.h
│ │ └── dispatch_utils.h
│ ├── build.toml
│ ├── flake.nix
│ ├── tests
│ │ ├── __init__.py
│ │ └── kernels
│ │ │ ├── __init__.py
│ │ │ ├── allclose_default.py
│ │ │ ├── test_activation.py
│ │ │ └── utils.py
│ └── torch-ext
│ │ ├── activation
│ │ └── __init__.py
│ │ ├── torch_binding.cpp
│ │ └── torch_binding.h
├── cutlass-gemm
│ ├── build.toml
│ ├── flake.nix
│ ├── gemm.cu
│ ├── tests
│ │ └── test_gemm.py
│ └── torch-ext
│ │ ├── cutlass_gemm
│ │ └── __init__.py
│ │ ├── registration.h
│ │ ├── torch_binding.cpp
│ │ └── torch_binding.h
├── relu
│ ├── build.toml
│ ├── flake.nix
│ ├── relu_cuda
│ │ └── relu.cu
│ ├── relu_metal
│ │ └── relu.mm
│ ├── tests
│ │ ├── __init__.py
│ │ └── test_relu.py
│ └── torch-ext
│ │ ├── relu
│ │ └── __init__.py
│ │ ├── torch_binding.cpp
│ │ └── torch_binding.h
└── silu-and-mul-universal
│ ├── build.toml
│ ├── flake.nix
│ ├── tests
│ └── test_silu_and_mul.py
│ └── torch-ext
│ └── silu_and_mul_universal
│ ├── __init__.py
│ └── silu_and_mul.py
├── flake.lock
├── flake.nix
├── kernel-abi-check
├── Cargo.lock
├── Cargo.toml
├── flake.lock
├── flake.nix
└── src
│ ├── lib.rs
│ ├── main.rs
│ ├── manylinux
│ ├── manylinux-policy.json
│ └── mod.rs
│ ├── python_abi
│ ├── mod.rs
│ └── stable_abi.toml
│ └── version.rs
├── kernel-compliance-check
├── Cargo.lock
├── Cargo.toml
└── src
│ ├── build_variants.json
│ ├── formatter.rs
│ ├── lib.rs
│ ├── main.rs
│ └── models.rs
├── lib
├── build-version.nix
├── build.nix
├── buildsets.nix
├── deps.nix
├── join-paths
│ └── default.nix
├── source-set.nix
├── torch-extension-noarch
│ └── default.nix
├── torch-extension
│ ├── _ops.py.in
│ └── default.nix
└── version-utils.nix
├── overlay.nix
├── pkgs
├── build2cmake
│ └── default.nix
├── cmake-nvcc-threads-hook
│ ├── cmake-nvcc-threads-hook.sh
│ └── default.nix
├── kernel-abi-check
│ ├── default.nix
│ └── kernel-abi-check-hook.sh
└── stdenv-glibc-2_27
│ └── default.nix
├── scripts
├── gen_variants_markdown.py
└── init-kernel.py
├── tests
└── Dockerfile.test-kernel
└── versions.nix
/.github/workflows/build_kernel.yaml:
--------------------------------------------------------------------------------
1 | name: "Build and test kernel"
2 | on:
3 | push:
4 | branches: [main]
5 | pull_request:
6 | branches: [main]
7 | types: [opened, synchronize, reopened] # trigger on PRs
8 | workflow_dispatch:
9 |
10 | jobs:
11 | build:
12 | name: Build kernel
13 | runs-on:
14 | group: aws-g6-12xlarge-plus
15 | steps:
16 | - uses: actions/checkout@v4
17 | - uses: cachix/install-nix-action@v27
18 | with:
19 | nix_path: nixpkgs=channel:nixos-unstable
20 | - uses: cachix/cachix-action@v14
21 | with:
22 | name: huggingface
23 | authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
24 | env:
25 | USER: github_runner
26 | - name: Build activation kernel
27 | run: ( cd examples/activation && nix build .\#redistributable.torch26-cxx98-cu124-x86_64-linux )
28 | - name: Copy activation kernel
29 | run: cp -rL examples/activation/result activation-kernel
30 |
31 | - name: Build cutlass GEMM kernel
32 | run: ( cd examples/cutlass-gemm && nix build .\#redistributable.torch26-cxx98-cu124-x86_64-linux )
33 | - name: Copy cutlass GEMM kernel
34 | run: cp -rL examples/cutlass-gemm/result cutlass-gemm-kernel
35 |
36 | - name: Build relu kernel
37 | run: ( cd examples/relu && nix build .\#redistributable.torch26-cxx98-cu124-x86_64-linux )
38 | - name: Copy relu kernel
39 | run: cp -rL examples/relu/result relu-kernel
40 |
41 | - name: Build silu-and-mul-universal kernel
42 | run: ( cd examples/silu-and-mul-universal && nix build .\#redistributable.torch26-cxx98-cu124-x86_64-linux )
43 | - name: Copy silu-and-mul-universal kernel
44 | run: cp -rL examples/silu-and-mul-universal/result silu-and-mul-universal-kernel
45 |
46 | - name: Set up Docker Buildx
47 | uses: docker/setup-buildx-action@v3
48 | - name: Build Docker image
49 | uses: docker/build-push-action@v5
50 | with:
51 | context: .
52 | file: tests/Dockerfile.test-kernel
53 | platforms: linux/amd64
54 | load: true
55 | push: false
56 | tags: kernel-builder:latest
57 |
58 | - name: Run Tests
59 | run: |
60 | docker run --gpus all kernel-builder:latest
61 |
--------------------------------------------------------------------------------
/.github/workflows/build_kernel_rocm.yaml:
--------------------------------------------------------------------------------
1 | name: "Build and test kernel (ROCm)"
2 | on:
3 | push:
4 | branches: [main]
5 | pull_request:
6 | branches: [main]
7 | types: [opened, synchronize, reopened] # trigger on PRs
8 | workflow_dispatch:
9 |
10 | jobs:
11 | build:
12 | name: Build kernel
13 | runs-on:
14 | group: aws-g6-12xlarge-plus
15 | steps:
16 | - uses: actions/checkout@v4
17 | - uses: cachix/install-nix-action@v27
18 | with:
19 | nix_path: nixpkgs=channel:nixos-unstable
20 | - uses: cachix/cachix-action@v14
21 | with:
22 | name: huggingface
23 | authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
24 | env:
25 | USER: github_runner
26 | # For now we only test that there are no regressions in building ROCm
27 | # kernels. Also run tests once we have a ROCm runner.
28 | - name: Build relu kernel
29 | run: ( cd examples/relu && nix build .\#redistributable.torch26-cxx11-rocm62-x86_64-linux -L )
30 |
--------------------------------------------------------------------------------
/.github/workflows/check_variants.yaml:
--------------------------------------------------------------------------------
1 | name: "Check build variants"
2 | on:
3 | push:
4 | branches: [main]
5 | pull_request:
6 | branches: [main]
7 | types: [opened, synchronize, reopened] # trigger on PRs
8 | workflow_dispatch:
9 |
10 | jobs:
11 | build:
12 | name: Check build variants
13 | runs-on: ubuntu-latest
14 | steps:
15 | - uses: actions/checkout@v4
16 | - uses: cachix/install-nix-action@v27
17 | with:
18 | nix_path: nixpkgs=channel:nixos-unstable
19 | - name: Generate variants JSON
20 | run: nix eval --raw .#lib.allBuildVariantsJSON | nix run nixpkgs#jq 'walk(if type == "array" then sort else . end)' > build-variants.json
21 | - name: Check if variants JSON is up-to-date
22 | run: |
23 | if git diff --exit-code build-variants.json; then
24 | echo "✅ variants.json is up-to-date"
25 | else
26 | echo "🛑 regenerate variants.json: nix eval --raw .#lib.allBuildVariantsJSON | nix run nixpkgs#jq 'walk(if type == "array" then sort else . end)' > build-variants.json"
27 | exit 1
28 | fi
29 | - name: Generate variants Markdown
30 | run: nix run nixpkgs#python3 scripts/gen_variants_markdown.py
31 | - name: Check if variants Markdown is up-to-date
32 | run: |
33 | if git diff --exit-code docs/build-variants.md; then
34 | echo "✅ docs/build-variants.md is up-to-date"
35 | else
36 | echo "🛑 regenerate docs/build-variants: nix run nixpkgs#python3 scripts/gen_variants_markdown.py"
37 | exit 1
38 | fi
39 |
--------------------------------------------------------------------------------
/.github/workflows/docker-build-push.yaml:
--------------------------------------------------------------------------------
1 | name: Build and Push Docker Image
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | paths:
8 | # Only run on changes to the Dockerfile or workflow file
9 | - "Dockerfile"
10 | - "dockerfiles/**"
11 | - ".github/workflows/docker-build-push.yaml"
12 | workflow_dispatch: # Allow manual triggering
13 |
14 | env:
15 | REGISTRY: ghcr.io
16 | IMAGE_NAME: ${{ github.repository_owner }}/kernel-builder
17 |
18 | jobs:
19 | build-and-push-user:
20 | name: Build and Push User Docker Image
21 | runs-on: ubuntu-latest
22 | permissions:
23 | contents: read
24 | packages: write
25 |
26 | steps:
27 | - name: Checkout repository
28 | uses: actions/checkout@v4
29 |
30 | - name: Set up Docker Buildx
31 | uses: docker/setup-buildx-action@v3
32 |
33 | - name: Log in to the Container registry
34 | uses: docker/login-action@v3
35 | with:
36 | registry: ${{ env.REGISTRY }}
37 | username: ${{ github.actor }}
38 | password: ${{ secrets.GITHUB_TOKEN }}
39 |
40 | - name: Extract metadata (tags, labels) for Docker
41 | id: meta
42 | uses: docker/metadata-action@v5
43 | with:
44 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
45 | tags: |
46 | type=raw,value=latest,enable=${{ github.ref == format('refs/heads/{0}', 'main') }}
47 | type=sha,prefix=user-,format=short
48 | type=ref,prefix=user-,event=branch
49 | type=semver,prefix=user-,pattern={{version}}
50 | type=semver,prefix=user-,pattern={{major}}.{{minor}}
51 |
52 | - name: Build and push Docker image
53 | uses: docker/build-push-action@v5
54 | with:
55 | context: .
56 | file: ./dockerfiles/Dockerfile.user
57 | push: true
58 | tags: ${{ steps.meta.outputs.tags }}
59 | labels: ${{ steps.meta.outputs.labels }}
60 | build-args: |
61 | MAX_JOBS=8
62 | CORES=8
63 | cache-from: type=gha
64 | cache-to: type=gha,mode=max
65 |
66 | build-and-push-root:
67 | name: Build and Push Root Docker Image
68 | runs-on: ubuntu-latest
69 | permissions:
70 | contents: read
71 | packages: write
72 |
73 | steps:
74 | - name: Checkout repository
75 | uses: actions/checkout@v4
76 |
77 | - name: Set up Docker Buildx
78 | uses: docker/setup-buildx-action@v3
79 |
80 | - name: Log in to the Container registry
81 | uses: docker/login-action@v3
82 | with:
83 | registry: ${{ env.REGISTRY }}
84 | username: ${{ github.actor }}
85 | password: ${{ secrets.GITHUB_TOKEN }}
86 |
87 | - name: Extract metadata (tags, labels) for Docker
88 | id: meta-root
89 | uses: docker/metadata-action@v5
90 | with:
91 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
92 | tags: |
93 | type=raw,value=root,enable=${{ github.ref == format('refs/heads/{0}', 'main') }}
94 | type=sha,format=short
95 | type=ref,event=branch
96 | type=semver,pattern={{version}}
97 | type=semver,pattern={{major}}.{{minor}}
98 |
99 | - name: Build and push Docker image
100 | uses: docker/build-push-action@v5
101 | with:
102 | context: .
103 | file: ./dockerfiles/Dockerfile
104 | push: true
105 | tags: ${{ steps.meta-root.outputs.tags }}
106 | labels: ${{ steps.meta-root.outputs.labels }}
107 | build-args: |
108 | MAX_JOBS=8
109 | CORES=8
110 | cache-from: type=gha,scope=root
111 | cache-to: type=gha,mode=max,scope=root
112 |
--------------------------------------------------------------------------------
/.github/workflows/nix_fmt.yaml:
--------------------------------------------------------------------------------
1 | name: "Check Nix formatting"
2 | on:
3 | push:
4 | branches: [main]
5 | pull_request:
6 | branches: [main]
7 | types: [opened, synchronize, reopened] # trigger on PRs
8 | workflow_dispatch:
9 |
10 | jobs:
11 | build:
12 | name: Check Nix formatting
13 | runs-on: ubuntu-latest
14 | steps:
15 | - uses: actions/checkout@v4
16 | - uses: cachix/install-nix-action@v27
17 | with:
18 | nix_path: nixpkgs=channel:nixos-unstable
19 | - name: Check formatting
20 | run: nix fmt -- --ci
21 |
--------------------------------------------------------------------------------
/.github/workflows/rust.yaml:
--------------------------------------------------------------------------------
1 | name: Rust
2 |
3 | on: [push, pull_request]
4 |
5 | jobs:
6 | fmt:
7 | name: Rustfmt
8 | runs-on: ubuntu-latest
9 | steps:
10 | - uses: actions/checkout@v4
11 | - uses: dtolnay/rust-toolchain@stable
12 | - name: Cargo fmt (kernel-abi-check)
13 | run: ( cd kernel-abi-check && cargo fmt --all -- --check )
14 | - name: Cargo fmt (build2cmake)
15 | run: ( cd build2cmake && cargo fmt --all -- --check )
16 |
17 | clippy:
18 | name: Clippy
19 | runs-on: ubuntu-latest
20 | steps:
21 | - uses: actions/checkout@v4
22 | - uses: actions/cache@v4
23 | with:
24 | path: |
25 | ~/.cargo/bin/
26 | ~/.cargo/registry/index/
27 | ~/.cargo/registry/cache/
28 | ~/.cargo/git/db/
29 | target/
30 | key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
31 | - uses: dtolnay/rust-toolchain@stable
32 | with:
33 | components: clippy
34 | - name: Clippy (kernel-abi-check)
35 | run: ( cd kernel-abi-check && cargo clippy -- -D warnings )
36 | - name: Clippy (build2cmake)
37 | run: ( cd build2cmake && cargo clippy -- -D warnings )
38 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .obsidian
2 | target
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # kernel-builder
2 |
3 |
4 |

5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 | This repo contains a Nix package that can be used to build custom machine learning kernels for PyTorch. The kernels are built using the [PyTorch C++ Frontend](https://pytorch.org/cppdocs/frontend.html) and can be loaded from the Hub with the [kernels](https://github.com/huggingface/kernels)
14 | Python package.
15 |
16 | This builder is a core component of the larger kernel build/distribution system.
17 |
18 | **Torch 2.7 note:** kernel-builder currently builds Torch 2.7 extensions based on
19 | the [final release candidate](https://dev-discuss.pytorch.org/t/pytorch-release-2-7-0-final-rc-is-available/2898).
20 | If you upload Torch 2.7 kernels, please validate them against
21 | the final Torch 2.7.0 release. In the unlikely case of an ABI-breaking
22 | change, you can rebuild and re-upload your kernel once kernel-builder
23 | is updated for the final release.
24 |
25 | ## 🚀 Quick Start
26 |
27 | We provide Docker containers for building kernels. For a quick build:
28 |
29 | ```bash
30 | # Using the prebuilt container
31 | docker run --mount type=bind,source=$(pwd),target=/kernelcode ghcr.io/huggingface/kernel-builder:{SHA}
32 | ```
33 |
34 | or build the container locally:
35 |
36 | ```bash
37 | docker build -t kernel-builder:local -f dockerfiles/Dockerfile .
38 | docker run --mount type=bind,source=$(pwd),target=/kernelcode kernel-builder:local
39 | ```
40 |
41 | See [dockerfiles/README.md](./dockerfiles/README.md) for more options, including a user-level container for CI/CD environments.
42 |
43 | ## 📚 Documentation
44 |
45 | - [Writing Hub kernels](./docs/writing-kernels.md)
46 | - [Building kernels with Docker](./docs/docker.md)
47 | - [Building kernels with Nix](./docs/nix.md)
48 | - [Local kernel development](./docs/local-dev.md) (IDE integration)
49 | - [Why Nix?](./docs/why-nix.md)
50 |
51 | ## Credits
52 |
53 | The generated CMake build files are based on the vLLM build infrastructure.
54 |
--------------------------------------------------------------------------------
/build-variants.json:
--------------------------------------------------------------------------------
1 | {
2 | "aarch64-darwin": {},
3 | "aarch64-linux": {
4 | "cuda": [
5 | "torch26-cxx11-cu126-aarch64-linux",
6 | "torch26-cxx98-cu126-aarch64-linux",
7 | "torch27-cxx11-cu126-aarch64-linux",
8 | "torch27-cxx11-cu128-aarch64-linux"
9 | ]
10 | },
11 | "x86_64-linux": {
12 | "cuda": [
13 | "torch26-cxx11-cu118-x86_64-linux",
14 | "torch26-cxx11-cu124-x86_64-linux",
15 | "torch26-cxx11-cu126-x86_64-linux",
16 | "torch26-cxx98-cu118-x86_64-linux",
17 | "torch26-cxx98-cu124-x86_64-linux",
18 | "torch26-cxx98-cu126-x86_64-linux",
19 | "torch27-cxx11-cu118-x86_64-linux",
20 | "torch27-cxx11-cu126-x86_64-linux",
21 | "torch27-cxx11-cu128-x86_64-linux"
22 | ],
23 | "rocm": [
24 | "torch26-cxx11-rocm62-x86_64-linux",
25 | "torch27-cxx11-rocm63-x86_64-linux"
26 | ]
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/build2cmake/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "build2cmake"
3 | version = "0.2.1"
4 | edition = "2021"
5 | description = "Generate CMake files for kernel-builder projects"
6 | homepage = "https://github.com/huggingface/kernel-builder"
7 | license = "Apache-2.0"
8 | documentation = "https://docs.rs/build2cmake"
9 | repository = "https://github.com/huggingface/kernel-builder"
10 |
11 | [dependencies]
12 | base32 = "0.5"
13 | clap = { version = "4", features = ["derive"] }
14 | eyre = "0.6.12"
15 | git2 = "0.20"
16 | itertools = "0.13"
17 | minijinja = "2.5"
18 | minijinja-embed = "2.5"
19 | rand = "0.8"
20 | serde = { version = "1", features = ["derive"] }
21 | serde_json = "1"
22 | serde-value = "0.7"
23 | toml = "0.8"
24 |
25 | [build-dependencies]
26 | minijinja-embed = "2.5"
27 |
--------------------------------------------------------------------------------
/build2cmake/build.rs:
--------------------------------------------------------------------------------
1 | fn main() {
2 | minijinja_embed::embed_templates!("src/templates");
3 | }
4 |
--------------------------------------------------------------------------------
/build2cmake/flake.lock:
--------------------------------------------------------------------------------
1 | {
2 | "nodes": {
3 | "flake-utils": {
4 | "inputs": {
5 | "systems": "systems"
6 | },
7 | "locked": {
8 | "lastModified": 1731533236,
9 | "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
10 | "owner": "numtide",
11 | "repo": "flake-utils",
12 | "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
13 | "type": "github"
14 | },
15 | "original": {
16 | "owner": "numtide",
17 | "repo": "flake-utils",
18 | "type": "github"
19 | }
20 | },
21 | "nixpkgs": {
22 | "locked": {
23 | "lastModified": 1734424634,
24 | "narHash": "sha256-cHar1vqHOOyC7f1+tVycPoWTfKIaqkoe1Q6TnKzuti4=",
25 | "owner": "nixos",
26 | "repo": "nixpkgs",
27 | "rev": "d3c42f187194c26d9f0309a8ecc469d6c878ce33",
28 | "type": "github"
29 | },
30 | "original": {
31 | "owner": "nixos",
32 | "ref": "nixos-unstable",
33 | "repo": "nixpkgs",
34 | "type": "github"
35 | }
36 | },
37 | "root": {
38 | "inputs": {
39 | "flake-utils": "flake-utils",
40 | "nixpkgs": "nixpkgs"
41 | }
42 | },
43 | "systems": {
44 | "locked": {
45 | "lastModified": 1681028828,
46 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
47 | "owner": "nix-systems",
48 | "repo": "default",
49 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
50 | "type": "github"
51 | },
52 | "original": {
53 | "owner": "nix-systems",
54 | "repo": "default",
55 | "type": "github"
56 | }
57 | }
58 | },
59 | "root": "root",
60 | "version": 7
61 | }
62 |
--------------------------------------------------------------------------------
/build2cmake/flake.nix:
--------------------------------------------------------------------------------
{
  description = "Development shell for build2cmake";

  inputs = {
    flake-utils.url = "github:numtide/flake-utils";
    nixpkgs.url = "github:nixos/nixpkgs?ref=nixos-unstable";
  };

  outputs =
    {
      self,
      flake-utils,
      nixpkgs,
    }:
    flake-utils.lib.eachDefaultSystem (
      system:
      let
        pkgs = nixpkgs.legacyPackages.${system};
      in
      {
        # Shell with the Rust toolchain (cargo, rustc, clippy, rustfmt,
        # rust-analyzer) plus openssl and pkg-config.
        devShells.default =
          with pkgs;
          mkShell {
            buildInputs = [
              cargo
              clippy
              openssl.dev
              pkg-config
              rustc
              rustfmt
              rust-analyzer
            ];

            # Points rust-analyzer at the standard library sources.
            RUST_SRC_PATH = "${rustPlatform.rustLibSrc}";
          };
      }
    );
}
40 |
--------------------------------------------------------------------------------
/build2cmake/src/config/mod.rs:
--------------------------------------------------------------------------------
1 | use eyre::Result;
2 | use serde::Deserialize;
3 |
4 | pub mod v1;
5 |
6 | mod v2;
7 | use serde_value::Value;
8 | pub use v2::{Backend, Build, Dependencies, Kernel, Torch};
9 |
10 | #[derive(Debug)]
11 | pub enum BuildCompat {
12 | V1(v1::Build),
13 | V2(Build),
14 | }
15 |
16 | impl<'de> Deserialize<'de> for BuildCompat {
17 | fn deserialize(deserializer: D) -> std::result::Result
18 | where
19 | D: serde::Deserializer<'de>,
20 | {
21 | let value = Value::deserialize(deserializer)?;
22 |
23 | match v1::Build::deserialize(value.clone()) {
24 | Ok(v1_build) => Ok(BuildCompat::V1(v1_build)),
25 | Err(_) => {
26 | let v2_build: Build =
27 | Build::deserialize(value).map_err(serde::de::Error::custom)?;
28 | Ok(BuildCompat::V2(v2_build))
29 | }
30 | }
31 | }
32 | }
33 |
34 | impl TryFrom for Build {
35 | type Error = eyre::Error;
36 |
37 | fn try_from(compat: BuildCompat) -> Result {
38 | match compat {
39 | BuildCompat::V1(v1_build) => v1_build.try_into(),
40 | BuildCompat::V2(v2_build) => Ok(v2_build),
41 | }
42 | }
43 | }
44 |
--------------------------------------------------------------------------------
/build2cmake/src/config/v1.rs:
--------------------------------------------------------------------------------
1 | use std::{collections::HashMap, fmt::Display, path::PathBuf};
2 |
3 | use serde::Deserialize;
4 |
5 | use super::v2::Dependencies;
6 |
7 | #[derive(Debug, Deserialize)]
8 | #[serde(deny_unknown_fields)]
9 | pub struct Build {
10 | pub general: General,
11 | pub torch: Option,
12 |
13 | #[serde(rename = "kernel", default)]
14 | pub kernels: HashMap,
15 | }
16 |
17 | #[derive(Debug, Deserialize)]
18 | #[serde(deny_unknown_fields)]
19 | pub struct General {
20 | pub name: String,
21 | }
22 |
23 | #[derive(Debug, Deserialize, Clone)]
24 | #[serde(deny_unknown_fields)]
25 | pub struct Torch {
26 | pub include: Option>,
27 | pub pyext: Option>,
28 |
29 | #[serde(default)]
30 | pub src: Vec,
31 |
32 | #[serde(default)]
33 | pub universal: bool,
34 | }
35 |
36 | #[derive(Debug, Deserialize)]
37 | #[serde(deny_unknown_fields, rename_all = "kebab-case")]
38 | pub struct Kernel {
39 | pub cuda_capabilities: Option>,
40 | pub rocm_archs: Option>,
41 | #[serde(default)]
42 | pub language: Language,
43 | pub depends: Vec,
44 | pub include: Option>,
45 | pub src: Vec,
46 | }
47 |
48 | #[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq)]
49 | #[serde(deny_unknown_fields, rename_all = "kebab-case")]
50 | pub enum Language {
51 | #[default]
52 | Cuda,
53 | CudaHipify,
54 | Metal,
55 | }
56 |
57 | impl Display for Language {
58 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 | match self {
60 | Language::Cuda => f.write_str("cuda"),
61 | Language::CudaHipify => f.write_str("cuda-hipify"),
62 | Language::Metal => f.write_str("metal"),
63 | }
64 | }
65 | }
66 |
--------------------------------------------------------------------------------
/build2cmake/src/config/v2.rs:
--------------------------------------------------------------------------------
1 | use std::{
2 | collections::{BTreeSet, HashMap},
3 | fmt::Display,
4 | path::PathBuf,
5 | str::FromStr,
6 | };
7 |
8 | use eyre::{bail, Result};
9 | use itertools::Itertools;
10 | use serde::{Deserialize, Serialize};
11 |
12 | use super::v1::{self, Language};
13 |
14 | #[derive(Debug, Deserialize, Serialize)]
15 | #[serde(deny_unknown_fields)]
16 | pub struct Build {
17 | pub general: General,
18 | pub torch: Option,
19 |
20 | #[serde(rename = "kernel", default)]
21 | pub kernels: HashMap,
22 | }
23 |
24 | impl Build {
25 | pub fn has_kernel_with_backend(&self, backend: &Backend) -> bool {
26 | self.kernels
27 | .values()
28 | .any(|kernel| kernel.backend == *backend)
29 | }
30 |
31 | pub fn backends(&self) -> BTreeSet {
32 | self.kernels.values().map(|kernel| kernel.backend).collect()
33 | }
34 | }
35 |
36 | #[derive(Debug, Deserialize, Serialize)]
37 | #[serde(deny_unknown_fields, rename_all = "kebab-case")]
38 | pub struct General {
39 | pub name: String,
40 | #[serde(default)]
41 | pub universal: bool,
42 | }
43 |
44 | #[derive(Debug, Deserialize, Clone, Serialize)]
45 | #[serde(deny_unknown_fields)]
46 | pub struct Torch {
47 | pub include: Option>,
48 | pub pyext: Option>,
49 |
50 | #[serde(default)]
51 | pub src: Vec,
52 | }
53 |
54 | impl Torch {
55 | pub fn data_globs(&self) -> Option> {
56 | match self.pyext.as_ref() {
57 | Some(exts) => {
58 | let globs = exts
59 | .iter()
60 | .filter(|&ext| ext != "py" && ext != "pyi")
61 | .map(|ext| format!("\"**/*.{}\"", ext))
62 | .collect_vec();
63 | if globs.is_empty() {
64 | None
65 | } else {
66 | Some(globs)
67 | }
68 | }
69 |
70 | None => None,
71 | }
72 | }
73 | }
74 |
75 | #[derive(Debug, Deserialize, Serialize)]
76 | #[serde(deny_unknown_fields, rename_all = "kebab-case")]
77 | pub struct Kernel {
78 | pub backend: Backend,
79 | pub cuda_capabilities: Option>,
80 | pub rocm_archs: Option>,
81 | pub depends: Vec,
82 | pub include: Option>,
83 | pub src: Vec,
84 | }
85 |
86 | #[derive(Clone, Copy, Debug, Deserialize, Eq, Ord, PartialEq, PartialOrd, Serialize)]
87 | #[serde(deny_unknown_fields, rename_all = "kebab-case")]
88 | pub enum Backend {
89 | Cuda,
90 | Metal,
91 | Rocm,
92 | }
93 |
94 | impl Display for Backend {
95 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 | match self {
97 | Backend::Cuda => write!(f, "cuda"),
98 | Backend::Metal => write!(f, "metal"),
99 | Backend::Rocm => write!(f, "rocm"),
100 | }
101 | }
102 | }
103 |
104 | impl FromStr for Backend {
105 | type Err = String;
106 |
107 | fn from_str(s: &str) -> Result {
108 | match s.to_lowercase().as_str() {
109 | "cuda" => Ok(Backend::Cuda),
110 | "metal" => Ok(Backend::Metal),
111 | "rocm" => Ok(Backend::Rocm),
112 | _ => Err(format!("Unknown backend: {}", s)),
113 | }
114 | }
115 | }
116 |
117 | #[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
118 | #[non_exhaustive]
119 | #[serde(rename_all = "lowercase")]
120 | pub enum Dependencies {
121 | #[serde(rename = "cutlass_2_10")]
122 | Cutlass2_10,
123 | #[serde(rename = "cutlass_3_5")]
124 | Cutlass3_5,
125 | #[serde(rename = "cutlass_3_6")]
126 | Cutlass3_6,
127 | #[serde(rename = "cutlass_3_8")]
128 | Cutlass3_8,
129 | Torch,
130 | }
131 |
132 | impl TryFrom for Build {
133 | type Error = eyre::Error;
134 |
135 | fn try_from(build: v1::Build) -> Result {
136 | let universal = build
137 | .torch
138 | .as_ref()
139 | .map(|torch| torch.universal)
140 | .unwrap_or(false);
141 | Ok(Self {
142 | general: General::from(build.general, universal),
143 | torch: build.torch.map(Into::into),
144 | kernels: convert_kernels(build.kernels)?,
145 | })
146 | }
147 | }
148 |
149 | impl General {
150 | fn from(general: v1::General, universal: bool) -> Self {
151 | Self {
152 | name: general.name,
153 | universal,
154 | }
155 | }
156 | }
157 |
158 | fn convert_kernels(v1_kernels: HashMap) -> Result> {
159 | let mut kernels = HashMap::new();
160 |
161 | for (name, kernel) in v1_kernels {
162 | if kernel.language == Language::CudaHipify {
163 | // We need to add an affix to avoid confflict with the CUDA kernel.
164 | let rocm_name = format!("{name}_rocm");
165 | if kernels.contains_key(&rocm_name) {
166 | bail!("Found an existing kernel with name `{rocm_name}` while expanding `{name}`")
167 | }
168 |
169 | kernels.insert(
170 | format!("{name}_rocm"),
171 | Kernel {
172 | backend: Backend::Rocm,
173 | cuda_capabilities: None,
174 | rocm_archs: kernel.rocm_archs,
175 | depends: kernel.depends.clone(),
176 | include: kernel.include.clone(),
177 | src: kernel.src.clone(),
178 | },
179 | );
180 | }
181 |
182 | kernels.insert(
183 | name,
184 | Kernel {
185 | backend: Backend::Cuda,
186 | cuda_capabilities: kernel.cuda_capabilities,
187 | rocm_archs: None,
188 | depends: kernel.depends,
189 | include: kernel.include,
190 | src: kernel.src,
191 | },
192 | );
193 | }
194 |
195 | Ok(kernels)
196 | }
197 |
198 | impl From for Torch {
199 | fn from(torch: v1::Torch) -> Self {
200 | Self {
201 | include: torch.include,
202 | pyext: torch.pyext,
203 | src: torch.src,
204 | }
205 | }
206 | }
207 |
--------------------------------------------------------------------------------
/build2cmake/src/cuda_supported_archs.json:
--------------------------------------------------------------------------------
1 | ["7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "10.0", "10.1", "12.0"]
2 |
--------------------------------------------------------------------------------
/build2cmake/src/fileset.rs:
--------------------------------------------------------------------------------
1 | use std::collections::HashMap;
2 | use std::path::{Path, PathBuf};
3 |
4 | use eyre::{bail, eyre, Context, Result};
5 |
/// In-memory set of files to write: file contents keyed by path relative to
/// the target directory passed to `FileSet::write`.
pub struct FileSet(HashMap<PathBuf, Vec<u8>>);
7 |
8 | impl FileSet {
9 | pub fn new() -> FileSet {
10 | FileSet(HashMap::new())
11 | }
12 |
13 | fn check_exist(&self, target_dir: &Path) -> Result<()> {
14 | let mut existing = Vec::new();
15 | for path in self.0.keys() {
16 | let full_path = target_dir.join(path);
17 | if full_path.exists() {
18 | existing.push(path.to_string_lossy().into_owned());
19 | }
20 | }
21 |
22 | if !existing.is_empty() {
23 | bail!(
24 | "File(s) already exists in target directory: {}\nUse `--force` to overwrite.",
25 | existing.join(", ")
26 | );
27 | }
28 |
29 | Ok(())
30 | }
31 |
32 | pub fn entry(&mut self, path: impl Into) -> &mut Vec {
33 | self.0.entry(path.into()).or_default()
34 | }
35 |
36 | pub fn write(&self, target_dir: &Path, force: bool) -> Result<()> {
37 | // Check that the paths do not exist and that we can write.
38 | if !force {
39 | self.check_exist(target_dir)?;
40 | }
41 |
42 | for (path, content) in &self.0 {
43 | let full_path = target_dir.join(path);
44 | write_to_file(&full_path, content)?;
45 | }
46 |
47 | Ok(())
48 | }
49 | }
50 |
51 | impl Default for FileSet {
52 | fn default() -> Self {
53 | FileSet::new()
54 | }
55 | }
56 |
57 | fn write_to_file(path: impl AsRef, data: &[u8]) -> Result<()> {
58 | let path = path.as_ref();
59 |
60 | let parent = path
61 | .parent()
62 | .ok_or_else(|| eyre!("Cannot get parent of `{}`", path.to_string_lossy()))?;
63 | std::fs::create_dir_all(parent)
64 | .wrap_err_with(|| format!("Cannot create directory `{}`", parent.to_string_lossy()))?;
65 |
66 | std::fs::write(path, data)
67 | .wrap_err_with(|| format!("Cannot create: {}", path.to_string_lossy()))
68 | }
69 |
--------------------------------------------------------------------------------
/build2cmake/src/main.rs:
--------------------------------------------------------------------------------
1 | use std::{
2 | fs::File,
3 | io::{BufWriter, Read, Write},
4 | path::{Path, PathBuf},
5 | };
6 |
7 | use clap::{Parser, Subcommand};
8 | use eyre::{bail, ensure, Context, Result};
9 | use minijinja::Environment;
10 |
11 | mod torch;
12 | use torch::{write_torch_ext, write_torch_ext_metal, write_torch_universal_ext};
13 |
14 | mod config;
15 | use config::{Backend, Build, BuildCompat};
16 |
17 | mod fileset;
18 | use fileset::FileSet;
19 |
20 | #[derive(Parser, Debug)]
21 | #[command(version, about, long_about = None)]
22 | struct Cli {
23 | #[command(subcommand)]
24 | command: Commands,
25 | }
26 |
27 | #[derive(Debug, Subcommand)]
28 | enum Commands {
29 | /// Generate CMake files for Torch extension builds.
30 | GenerateTorch {
31 | #[arg(name = "BUILD_TOML")]
32 | build_toml: PathBuf,
33 |
34 | /// The directory to write the generated files to
35 | /// (directory of `BUILD_TOML` when absent).
36 | #[arg(name = "TARGET_DIR")]
37 | target_dir: Option,
38 |
39 | /// Force-overwrite existing files.
40 | #[arg(short, long)]
41 | force: bool,
42 |
43 | /// This is an optional unique identifier that is suffixed to the
44 | /// kernel name to avoid name collisions. (e.g. Git SHA)
45 | #[arg(long)]
46 | ops_id: Option,
47 |
48 | #[arg(long)]
49 | backend: Option,
50 | },
51 |
52 | /// Update a `build.toml` to the current format.
53 | UpdateBuild {
54 | #[arg(name = "BUILD_TOML")]
55 | build_toml: PathBuf,
56 | },
57 |
58 | /// Validate the build.toml file.
59 | Validate {
60 | #[arg(name = "BUILD_TOML")]
61 | build_toml: PathBuf,
62 | },
63 | }
64 |
65 | fn main() -> Result<()> {
66 | let args = Cli::parse();
67 | match args.command {
68 | Commands::GenerateTorch {
69 | backend,
70 | build_toml,
71 | force,
72 | target_dir,
73 | ops_id,
74 | } => generate_torch(backend, build_toml, target_dir, force, ops_id),
75 | Commands::UpdateBuild { build_toml } => update_build(build_toml),
76 | Commands::Validate { build_toml } => {
77 | parse_and_validate(build_toml)?;
78 | Ok(())
79 | }
80 | }
81 | }
82 |
83 | fn generate_torch(
84 | backend: Option,
85 | build_toml: PathBuf,
86 | target_dir: Option,
87 | force: bool,
88 | ops_id: Option,
89 | ) -> Result<()> {
90 | let target_dir = check_or_infer_target_dir(&build_toml, target_dir)?;
91 |
92 | let build_compat = parse_and_validate(build_toml)?;
93 |
94 | if matches!(build_compat, BuildCompat::V1(_)) {
95 | eprintln!(
96 | "build.toml is in the deprecated V1 format, use `build2cmake update-build` to update."
97 | )
98 | }
99 |
100 | let build: Build = build_compat
101 | .try_into()
102 | .context("Cannot update build configuration")?;
103 |
104 | let mut env = Environment::new();
105 | env.set_trim_blocks(true);
106 | minijinja_embed::load_templates!(&mut env);
107 |
108 | let backend = match (backend, build.general.universal) {
109 | (None, true) => return write_torch_universal_ext(&env, &build, target_dir, force, ops_id),
110 | (Some(backend), true) => bail!("Universal kernel, cannot generate for backend {}", backend),
111 | (Some(backend), false) => {
112 | if !build.has_kernel_with_backend(&backend) {
113 | bail!("No kernels found for backend {}", backend);
114 | }
115 |
116 | backend
117 | }
118 | (None, false) => {
119 | let mut kernel_backends = build.backends();
120 | let backend = if let Some(backend) = kernel_backends.pop_first() {
121 | backend
122 | } else {
123 | bail!("No kernels found in build.toml");
124 | };
125 |
126 | if !kernel_backends.is_empty() {
127 | let kernel_backends: Vec<_> = build
128 | .backends()
129 | .into_iter()
130 | .map(|backend| backend.to_string())
131 | .collect();
132 | bail!(
133 | "Multiple supported backends found in build.toml: {}. Please specify one with --backend.",
134 | kernel_backends.join(", ")
135 | );
136 | }
137 |
138 | backend
139 | }
140 | };
141 |
142 | match backend {
143 | Backend::Cuda | Backend::Rocm => {
144 | write_torch_ext(&env, backend, &build, target_dir, force, ops_id)
145 | }
146 | Backend::Metal => write_torch_ext_metal(&env, &build, target_dir, force, ops_id),
147 | }
148 | }
149 |
150 | fn update_build(build_toml: PathBuf) -> Result<()> {
151 | let build_compat: BuildCompat = parse_and_validate(&build_toml)?;
152 |
153 | if matches!(build_compat, BuildCompat::V2(_)) {
154 | return Ok(());
155 | }
156 |
157 | let build: Build = build_compat
158 | .try_into()
159 | .context("Cannot update build configuration")?;
160 | let pretty_toml = toml::to_string_pretty(&build)?;
161 |
162 | let mut writer =
163 | BufWriter::new(File::create(&build_toml).wrap_err_with(|| {
164 | format!("Cannot open {} for writing", build_toml.to_string_lossy())
165 | })?);
166 | writer
167 | .write_all(pretty_toml.as_bytes())
168 | .wrap_err_with(|| format!("Cannot write to {}", build_toml.to_string_lossy()))?;
169 |
170 | Ok(())
171 | }
172 |
173 | fn check_or_infer_target_dir(
174 | build_toml: impl AsRef,
175 | target_dir: Option,
176 | ) -> Result {
177 | let build_toml = build_toml.as_ref();
178 | match target_dir {
179 | Some(target_dir) => {
180 | ensure!(
181 | target_dir.is_dir(),
182 | "`{}` is not a directory",
183 | target_dir.to_string_lossy()
184 | );
185 | Ok(target_dir)
186 | }
187 | None => {
188 | let absolute = std::path::absolute(build_toml)?;
189 | match absolute.parent() {
190 | Some(parent) => Ok(parent.to_owned()),
191 | None => bail!(
192 | "Cannot get parent path of `{}`",
193 | build_toml.to_string_lossy()
194 | ),
195 | }
196 | }
197 | }
198 | }
199 |
200 | fn parse_and_validate(build_toml: impl AsRef) -> Result {
201 | let build_toml = build_toml.as_ref();
202 | let mut toml_data = String::new();
203 | File::open(build_toml)
204 | .wrap_err_with(|| format!("Cannot open {} for reading", build_toml.to_string_lossy()))?
205 | .read_to_string(&mut toml_data)
206 | .wrap_err_with(|| format!("Cannot read from {}", build_toml.to_string_lossy()))?;
207 |
208 | let build_compat: BuildCompat = toml::from_str(&toml_data)
209 | .wrap_err_with(|| format!("Cannot parse TOML in {}", build_toml.to_string_lossy()))?;
210 |
211 | Ok(build_compat)
212 | }
213 |
--------------------------------------------------------------------------------
/build2cmake/src/templates/_ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from . import {{ ops_name }}
3 | ops = torch.ops.{{ ops_name }}
4 |
def add_op_namespace_prefix(op_name: str):
    """
    Return `op_name` qualified with this extension's op namespace.
    """
    return "{{ ops_name }}::{}".format(op_name)
10 |
--------------------------------------------------------------------------------
/build2cmake/src/templates/cuda/dep-cutlass.cmake:
--------------------------------------------------------------------------------
# Make CUTLASS available: use an installed NvidiaCutlass package when found,
# otherwise fetch it (or use a local checkout given through CUTLASS_SRC_DIR).
find_package(NvidiaCutlass)

if (NOT NvidiaCutlass_FOUND)
  set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")

  # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
  set(CUTLASS_REVISION "v{{ version }}" CACHE STRING "CUTLASS revision to use")

  # Use the specified CUTLASS source directory for compilation if CUTLASS_SRC_DIR is provided
  if (DEFINED ENV{CUTLASS_SRC_DIR})
    set(CUTLASS_SRC_DIR $ENV{CUTLASS_SRC_DIR})
  endif()

  if(CUTLASS_SRC_DIR)
    # IS_ABSOLUTE takes a path, not a variable name: expand the value so the
    # check tests the directory itself rather than the literal string.
    if(NOT IS_ABSOLUTE "${CUTLASS_SRC_DIR}")
      get_filename_component(CUTLASS_SRC_DIR "${CUTLASS_SRC_DIR}" ABSOLUTE)
    endif()
    message(STATUS "The CUTLASS_SRC_DIR is set, using ${CUTLASS_SRC_DIR} for compilation")
    FetchContent_Declare(cutlass SOURCE_DIR ${CUTLASS_SRC_DIR})
  else()
    FetchContent_Declare(
      cutlass
      GIT_REPOSITORY https://github.com/nvidia/cutlass.git
      GIT_TAG ${CUTLASS_REVISION}
      GIT_PROGRESS TRUE

      # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
      # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
      # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
      GIT_SHALLOW TRUE
    )
  endif()
  FetchContent_MakeAvailable(cutlass)

  include_directories(${CUTLASS_INCLUDE_DIR})
else()
  message(STATUS "Using system cutlass with version: ${NvidiaCutlass_VERSION}")
endif(NOT NvidiaCutlass_FOUND)
40 |
--------------------------------------------------------------------------------
/build2cmake/src/templates/cuda/hipify.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | # From vLLM: https://github.com/vllm-project/vllm/blob/main/cmake/hipify.py
5 |
6 | #
7 | # A command line tool for running pytorch's hipify preprocessor on CUDA
8 | # source files.
9 | #
10 | # See https://github.com/ROCm/hipify_torch
11 | # and /utils/hipify/hipify_python.py
12 | #
13 |
14 | import argparse
15 | import os
16 | import shutil
17 |
18 | from torch.utils.hipify.hipify_python import hipify
19 |
20 | if __name__ == '__main__':
21 | parser = argparse.ArgumentParser()
22 |
23 | # Project directory where all the source + include files live.
24 | parser.add_argument(
25 | "-p",
26 | "--project_dir",
27 | help="The project directory.",
28 | )
29 |
30 | # Directory where hipified files are written.
31 | parser.add_argument(
32 | "-o",
33 | "--output_dir",
34 | help="The output directory.",
35 | )
36 |
37 | # Source files to convert.
38 | parser.add_argument("sources",
39 | help="Source files to hipify.",
40 | nargs="*",
41 | default=[])
42 |
43 | args = parser.parse_args()
44 |
45 | # Limit include scope to project_dir only
46 | includes = [os.path.join(args.project_dir, '*')]
47 |
48 | # Get absolute path for all source files.
49 | extra_files = [os.path.abspath(s) for s in args.sources]
50 |
51 | # Copy sources from project directory to output directory.
52 | # The directory might already exist to hold object files so we ignore that.
53 | shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True)
54 |
55 | hipify_result = hipify(project_directory=args.project_dir,
56 | output_directory=args.output_dir,
57 | header_include_dirs=[],
58 | includes=includes,
59 | extra_files=extra_files,
60 | show_detailed=True,
61 | is_pytorch_extension=True,
62 | hipify_extra_files_only=True)
63 |
64 | hipified_sources = []
65 | for source in args.sources:
66 | s_abs = os.path.abspath(source)
67 | hipified_s_abs = (hipify_result[s_abs].hipified_path if
68 | (s_abs in hipify_result
69 | and hipify_result[s_abs].hipified_path is not None)
70 | else s_abs)
71 | hipified_sources.append(hipified_s_abs)
72 |
73 | assert (len(hipified_sources) == len(args.sources))
74 |
75 | # Print hipified source files.
76 | print("\n".join(hipified_sources))
77 |
--------------------------------------------------------------------------------
/build2cmake/src/templates/cuda/kernel.cmake:
--------------------------------------------------------------------------------
# Sources of one kernel; they are added to the extension's SRC list below
# only for the GPU language this kernel is built for.
set({{kernel_name}}_SRC
  {{ sources }}
)

{% if includes %}
# TODO: check if CLion supports this:
# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
set_source_files_properties(
  {{'${' + kernel_name + '_SRC}'}}
  PROPERTIES INCLUDE_DIRECTORIES "{{ includes }}")
{% endif %}

if(GPU_LANG STREQUAL "CUDA")
  # Build only for the requested (or all supported) capabilities that are
  # also among the target architectures.
{% if cuda_capabilities %}
  cuda_archs_loose_intersection({{kernel_name}}_ARCHS "{{ cuda_capabilities|join(";") }}" "${CUDA_ARCHS}")
{% else %}
  cuda_archs_loose_intersection({{kernel_name}}_ARCHS "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}")
{% endif %}
  message(STATUS "Capabilities for kernel {{kernel_name}}: {{ '${' + kernel_name + '_ARCHS}'}}")
  set_gencode_flags_for_srcs(SRCS {{'"${' + kernel_name + '_SRC}"'}} CUDA_ARCHS "{{ '${' + kernel_name + '_ARCHS}'}}")
  list(APPEND SRC {{'"${' + kernel_name + '_SRC}"'}})
{% if supports_hipify %}
elseif(GPU_LANG STREQUAL "HIP")
  # Quote the target list like the CUDA calls above: unquoted, the list
  # variable would be split into one argument per architecture.
  hip_archs_loose_intersection({{kernel_name}}_ARCHS "{{ rocm_archs|join(";") }}" "${ROCM_ARCHS}")
  list(APPEND SRC {{'"${' + kernel_name + '_SRC}"'}})
{% endif %}
endif()
28 |
29 |
--------------------------------------------------------------------------------
/build2cmake/src/templates/cuda/preamble.cmake:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.26)
project({{name}} LANGUAGES CXX)

set(TARGET_DEVICE "cuda" CACHE STRING "Target device backend for kernel")

install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)

include(FetchContent)
file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists
message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")

set(CUDA_SUPPORTED_ARCHS "{{ cuda_supported_archs }}")

set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101")

include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)

if(DEFINED Python_EXECUTABLE)
  # Allow passing through the interpreter (e.g. from setup.py).
  find_package(Python COMPONENTS Development Development.SABIModule Interpreter)
  if (NOT Python_FOUND)
    # Report the interpreter that was actually requested.
    message(FATAL_ERROR "Unable to find python matching: ${Python_EXECUTABLE}.")
  endif()
else()
  find_package(Python REQUIRED COMPONENTS Development Development.SABIModule Interpreter)
endif()

append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")

find_package(Torch REQUIRED)

# Nothing else to configure when targeting a non-GPU device.
if (NOT TARGET_DEVICE STREQUAL "cuda" AND
    NOT TARGET_DEVICE STREQUAL "rocm")
  return()
endif()

if (NOT HIP_FOUND AND CUDA_FOUND)
  set(GPU_LANG "CUDA")
elseif(HIP_FOUND)
  set(GPU_LANG "HIP")

  # Importing torch recognizes and sets up some HIP/ROCm configuration but does
  # not let cmake recognize .hip files. In order to get cmake to understand the
  # .hip extension automatically, HIP must be enabled explicitly.
  enable_language(HIP)
else()
  message(FATAL_ERROR "Can't find CUDA or HIP installation.")
endif()


if(GPU_LANG STREQUAL "CUDA")
  clear_cuda_arches(CUDA_ARCH_FLAGS)
  extract_unique_cuda_archs_ascending(CUDA_ARCHS "${CUDA_ARCH_FLAGS}")
  message(STATUS "CUDA target architectures: ${CUDA_ARCHS}")
  # Filter the target architectures by the supported archs
  # since for some files we will build for all CUDA_ARCHS.
  cuda_archs_loose_intersection(CUDA_ARCHS "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}")
  message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}")

  # Already inside the CUDA branch, so only NVCC_THREADS needs checking.
  if(NVCC_THREADS)
    list(APPEND GPU_FLAGS "--threads=${NVCC_THREADS}")
  endif()

  add_compile_definitions(CUDA_KERNEL)
elseif(GPU_LANG STREQUAL "HIP")
  set(ROCM_ARCHS "${HIP_SUPPORTED_ARCHS}")
  # TODO: remove this once we can set specific archs per source file set.
  override_gpu_arches(GPU_ARCHES
    ${GPU_LANG}
    "${${GPU_LANG}_SUPPORTED_ARCHS}")

  add_compile_definitions(ROCM_KERNEL)
else()
  override_gpu_arches(GPU_ARCHES
    ${GPU_LANG}
    "${${GPU_LANG}_SUPPORTED_ARCHS}")
endif()
78 |
--------------------------------------------------------------------------------
/build2cmake/src/templates/cuda/setup.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from shutil import which, move
4 | import subprocess
5 | import sys
6 | from pathlib import Path
7 |
8 | from setuptools import Extension, find_packages, setup
9 | from setuptools.command.build_ext import build_ext
10 |
logger = logging.getLogger(__name__)


def _on_path(program: str) -> bool:
    """True when `program` resolves to an executable on PATH."""
    return which(program) is not None


def is_sccache_available() -> bool:
    """Whether the sccache compiler cache is installed."""
    return _on_path("sccache")


def is_ccache_available() -> bool:
    """Whether the ccache compiler cache is installed."""
    return _on_path("ccache")


def is_ninja_available() -> bool:
    """Whether a ninja executable is on PATH."""
    return _on_path("ninja")
24 |
25 |
class CMakeExtension(Extension):
    """Extension without Python-visible sources; CMake builds it from the
    project rooted at `sourcedir` (resolved to an absolute path)."""

    def __init__(self, name: str, sourcedir: str = "") -> None:
        super().__init__(name, sources=[], py_limited_api=True)
        resolved = Path(sourcedir).resolve()
        self.sourcedir = os.fspath(resolved)
30 |
31 |
class CMakeBuild(build_ext):
    """`build_ext` command that configures and builds extensions with CMake."""

    def build_extension(self, ext: CMakeExtension) -> None:
        """Configure and build `ext` with CMake, writing the library where
        setuptools expects the extension module.

        Environment variables read: DEBUG (when the command's own debug flag
        is unset), CMAKE_GENERATOR, CMAKE_ARGS (space-separated extra
        configure arguments), MAX_JOBS and NVCC_THREADS.
        """
        ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name)
        extdir = ext_fullpath.parent.resolve()

        debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
        cfg = "Debug" if debug else "Release"

        cmake_generator = os.environ.get("CMAKE_GENERATOR", "")

        cmake_args = [
            f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}",
            f"-DPython_EXECUTABLE={sys.executable}",
            f"-DCMAKE_BUILD_TYPE={cfg}",  # not used on MSVC, but no harm
        ]
        build_args = []
        if "CMAKE_ARGS" in os.environ:
            cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item]

        if not cmake_generator or cmake_generator == "Ninja":
            # Use the ninja binary from the `ninja` Python package when installed.
            try:
                import ninja

                ninja_executable_path = Path(ninja.BIN_DIR) / "ninja"
                cmake_args += [
                    "-GNinja",
                    f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}",
                ]
            except ImportError:
                pass

        if is_sccache_available():
            cmake_args += [
                "-DCMAKE_C_COMPILER_LAUNCHER=sccache",
                "-DCMAKE_CXX_COMPILER_LAUNCHER=sccache",
                "-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache",
                "-DCMAKE_HIP_COMPILER_LAUNCHER=sccache",
            ]
        elif is_ccache_available():
            cmake_args += [
                "-DCMAKE_C_COMPILER_LAUNCHER=ccache",
                "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache",
                "-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache",
                "-DCMAKE_HIP_COMPILER_LAUNCHER=ccache",
            ]

        num_jobs = os.getenv("MAX_JOBS", None)
        if num_jobs is not None:
            num_jobs = int(num_jobs)
            logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs)
        else:
            try:
                # os.sched_getaffinity() isn't universally available, so fall
                # back to os.cpu_count() if we get an error here.
                num_jobs = len(os.sched_getaffinity(0))
            except AttributeError:
                # os.cpu_count() returns None when the count is unknown.
                num_jobs = os.cpu_count() or 1

        nvcc_threads = os.getenv("NVCC_THREADS", None)
        if nvcc_threads is not None:
            nvcc_threads = int(nvcc_threads)
            logger.info(
                "Using NVCC_THREADS=%d as the number of nvcc threads.", nvcc_threads
            )
        else:
            nvcc_threads = 1
        # Each build job may run several nvcc threads; divide the job budget
        # (guarding against NVCC_THREADS=0).
        num_jobs = max(1, num_jobs // max(1, nvcc_threads))

        build_args += [f"-j{num_jobs}"]
        if sys.platform == "win32":
            build_args += ["--config", cfg]

        if nvcc_threads:
            cmake_args += ["-DNVCC_THREADS={}".format(nvcc_threads)]

        build_temp = Path(self.build_temp) / ext.name
        if not build_temp.exists():
            build_temp.mkdir(parents=True)

        subprocess.run(
            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
        )
        subprocess.run(
            ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
        )
        if sys.platform == "win32":
            # Multi-config generators (e.g. Visual Studio) put the library in a
            # per-config subdirectory; single-config ones such as Ninja do not.
            # Move the dylib one folder up for discovery when needed.
            cfg_dir = extdir / cfg
            if cfg_dir.is_dir():
                for filename in os.listdir(cfg_dir):
                    move(cfg_dir / filename, extdir / filename)
124 |
125 |
126 |
# Package metadata. The compiled ops module is produced by CMakeBuild; the
# Python sources are taken from the torch-ext directory.
setup(
    name="{{ name }}",
    # The version is just a stub, it's not used by the final build artefact.
    version="0.1.0",
    # A single extension module, built by CMake inside the package.
    ext_modules=[CMakeExtension("{{ name }}.{{ ops_name }}")],
    cmdclass={"build_ext": CMakeBuild},
    packages=find_packages(where="torch-ext", include=["{{ name }}*"]),
    package_dir={"": "torch-ext"},
{% if data_globs %}
    package_data={"{{ name }}": [ {{ data_globs }} ]},
{% endif %}
    zip_safe=False,
    install_requires=["torch"],
    python_requires=">=3.9",
)
142 |
--------------------------------------------------------------------------------
/build2cmake/src/templates/cuda/torch-binding.cmake:
--------------------------------------------------------------------------------
# Add the compiler flags Torch wants for GPU sources to GPU_FLAGS.
# get_torch_gpu_compiler_flags is not defined in this template (presumably in
# cmake/utils.cmake).
get_torch_gpu_compiler_flags(TORCH_GPU_FLAGS ${GPU_LANG})
list(APPEND GPU_FLAGS ${TORCH_GPU_FLAGS})

# Sources of the Torch binding; presumably the source list from the `torch`
# section of build.toml -- confirm against the renderer.
set(TORCH_{{name}}_SRC
  {{ src|join(' ') }}
)

{% if includes %}
# TODO: check if CLion supports this:
# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
set_source_files_properties(
  {{'${TORCH_' + name + '_SRC}'}}
  PROPERTIES INCLUDE_DIRECTORIES "{{ includes }}")
{% endif %}

# Add the binding sources to the extension's source list.
list(APPEND SRC {{'"${TORCH_' + name + '_SRC}"'}})
17 |
18 |
--------------------------------------------------------------------------------
/build2cmake/src/templates/cuda/torch-extension.cmake:
--------------------------------------------------------------------------------
# Build the extension module from every source collected in SRC.
# define_gpu_extension_target is not defined in this template (presumably in
# cmake/utils.cmake). USE_SABI 3 presumably targets the Python 3 stable ABI,
# matching py_limited_api=True in setup.py.
define_gpu_extension_target(
  {{ ops_name }}
  DESTINATION {{ ops_name }}
  LANGUAGE ${GPU_LANG}
  SOURCES ${SRC}
  COMPILE_FLAGS ${GPU_FLAGS}
  ARCHITECTURES ${GPU_ARCHES}
  #INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
  USE_SABI 3
  WITH_SOABI)

# Link the C++ standard library statically into the extension.
target_link_options({{ ops_name }} PRIVATE -static-libstdc++)
13 |
14 |
--------------------------------------------------------------------------------
/build2cmake/src/templates/metal/kernel.cmake:
--------------------------------------------------------------------------------
# Sources of one kernel; presumably rendered once per kernel in build.toml
# (see the render_kernel loop in metal.rs).
set({{kernel_name}}_SRC
  {{ sources }}
)

{% if includes %}
# TODO: check if CLion supports this:
# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
set_source_files_properties(
  {{'${' + kernel_name + '_SRC}'}}
  PROPERTIES INCLUDE_DIRECTORIES "{{ includes }}")
{% endif %}

# Unlike the CUDA template there is no per-architecture filtering: every
# kernel source is added to the extension.
list(APPEND SRC {{'"${' + kernel_name + '_SRC}"'}})
--------------------------------------------------------------------------------
/build2cmake/src/templates/metal/preamble.cmake:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.26)
project({{name}} LANGUAGES CXX)

set(CMAKE_OSX_DEPLOYMENT_TARGET "15.0" CACHE STRING "Minimum macOS deployment version")

install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)

include(FetchContent)
file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists
message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")

include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)

if(DEFINED Python_EXECUTABLE)
  # Allow passing through the interpreter (e.g. from setup.py).
  find_package(Python COMPONENTS Development Development.SABIModule Interpreter)
  if (NOT Python_FOUND)
    # Report the interpreter that was actually requested.
    message(FATAL_ERROR "Unable to find python matching: ${Python_EXECUTABLE}.")
  endif()
else()
  find_package(Python REQUIRED COMPONENTS Development Development.SABIModule Interpreter)
endif()

append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")

find_package(Torch REQUIRED)

add_compile_definitions(METAL_KERNEL)
29 |
--------------------------------------------------------------------------------
/build2cmake/src/templates/metal/setup.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from shutil import which, move
4 | import subprocess
5 | import sys
6 | from pathlib import Path
7 |
8 | from setuptools import Extension, find_packages, setup
9 | from setuptools.command.build_ext import build_ext
10 |
logger = logging.getLogger(__name__)


def _on_path(program: str) -> bool:
    """True when `program` resolves to an executable on PATH."""
    return which(program) is not None


def is_sccache_available() -> bool:
    """Whether the sccache compiler cache is installed."""
    return _on_path("sccache")


def is_ccache_available() -> bool:
    """Whether the ccache compiler cache is installed."""
    return _on_path("ccache")


def is_ninja_available() -> bool:
    """Whether a ninja executable is on PATH."""
    return _on_path("ninja")
24 |
25 |
class CMakeExtension(Extension):
    """Extension without Python-visible sources; CMake builds it from the
    project rooted at `sourcedir` (resolved to an absolute path)."""

    def __init__(self, name: str, sourcedir: str = "") -> None:
        super().__init__(name, sources=[], py_limited_api=True)
        resolved = Path(sourcedir).resolve()
        self.sourcedir = os.fspath(resolved)
30 |
31 |
class CMakeBuild(build_ext):
    """`build_ext` command that configures and builds extensions with CMake."""

    def build_extension(self, ext: CMakeExtension) -> None:
        """Configure and build `ext` with CMake, writing the library where
        setuptools expects the extension module.

        Environment variables read: DEBUG (when the command's own debug flag
        is unset), CMAKE_GENERATOR, CMAKE_ARGS (space-separated extra
        configure arguments) and MAX_JOBS (parallel build jobs; defaults to
        the number of usable CPUs).
        """
        ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name)
        extdir = ext_fullpath.parent.resolve()

        debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
        cfg = "Debug" if debug else "Release"

        cmake_generator = os.environ.get("CMAKE_GENERATOR", "")

        cmake_args = [
            f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}",
            f"-DPython_EXECUTABLE={sys.executable}",
            f"-DCMAKE_BUILD_TYPE={cfg}",  # not used on MSVC, but no harm
        ]
        build_args = []
        if "CMAKE_ARGS" in os.environ:
            cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item]

        if not cmake_generator or cmake_generator == "Ninja":
            # Use the ninja binary from the `ninja` Python package when installed.
            try:
                import ninja

                ninja_executable_path = Path(ninja.BIN_DIR) / "ninja"
                cmake_args += [
                    "-GNinja",
                    f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}",
                ]
            except ImportError:
                pass

        if is_sccache_available():
            cmake_args += [
                "-DCMAKE_C_COMPILER_LAUNCHER=sccache",
                "-DCMAKE_CXX_COMPILER_LAUNCHER=sccache",
                "-DCMAKE_HIP_COMPILER_LAUNCHER=sccache",
                "-DCMAKE_OBJC_COMPILER_LAUNCHER=sccache",
                "-DCMAKE_OBJCXX_COMPILER_LAUNCHER=sccache",
            ]
        elif is_ccache_available():
            cmake_args += [
                "-DCMAKE_C_COMPILER_LAUNCHER=ccache",
                "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache",
                "-DCMAKE_HIP_COMPILER_LAUNCHER=ccache",
                "-DCMAKE_OBJC_COMPILER_LAUNCHER=ccache",
                "-DCMAKE_OBJCXX_COMPILER_LAUNCHER=ccache",
            ]

        num_jobs = os.getenv("MAX_JOBS", None)
        if num_jobs is not None:
            num_jobs = int(num_jobs)
            logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs)
        else:
            try:
                # os.sched_getaffinity() isn't universally available, so fall
                # back to os.cpu_count() if we get an error here.
                num_jobs = len(os.sched_getaffinity(0))
            except AttributeError:
                # os.cpu_count() returns None when the count is unknown.
                num_jobs = os.cpu_count() or 1

        # Hand the job count to the build tool so MAX_JOBS takes effect.
        build_args += [f"-j{max(1, num_jobs)}"]

        build_temp = Path(self.build_temp) / ext.name
        if not build_temp.exists():
            build_temp.mkdir(parents=True)

        subprocess.run(
            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
        )
        subprocess.run(
            ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
        )
105 |
106 |
# Package metadata. The compiled ops module is produced by CMakeBuild; the
# Python sources are taken from the torch-ext directory.
setup(
    name="{{ name }}",
    # The version is just a stub, it's not used by the final build artefact.
    version="0.1.0",
    # A single extension module, built by CMake inside the package.
    ext_modules=[CMakeExtension("{{ name }}.{{ ops_name }}")],
    cmdclass={"build_ext": CMakeBuild},
    packages=find_packages(where="torch-ext", include=["{{ name }}*"]),
    package_dir={"": "torch-ext"},
{% if data_globs %}
    package_data={"{{ name }}": [ {{ data_globs }} ]},
{% endif %}
    zip_safe=False,
    install_requires=["torch"],
    python_requires=">=3.9",
)
122 |
--------------------------------------------------------------------------------
/build2cmake/src/templates/metal/torch-binding.cmake:
--------------------------------------------------------------------------------
# Torch GPU compiler flags are not collected for Metal builds; the CUDA
# template's calls are kept here commented out.
#get_torch_gpu_compiler_flags(TORCH_GPU_FLAGS ${GPU_LANG})
#list(APPEND GPU_FLAGS ${TORCH_GPU_FLAGS})

# Sources of the Torch binding; presumably the source list from the `torch`
# section of build.toml -- confirm against the renderer.
set(TORCH_{{name}}_SRC
  {{ src|join(' ') }}
)

{% if includes %}
# TODO: check if CLion supports this:
# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
set_source_files_properties(
  {{'${TORCH_' + name + '_SRC}'}}
  PROPERTIES INCLUDE_DIRECTORIES "{{ includes }}")
{% endif %}

# Add the binding sources to the extension's source list.
list(APPEND SRC {{'"${TORCH_' + name + '_SRC}"'}})
17 |
--------------------------------------------------------------------------------
/build2cmake/src/templates/metal/torch-extension.cmake:
--------------------------------------------------------------------------------
# Build the extension module from every source collected in SRC.
# NOTE(review): GPU_LANG, GPU_FLAGS and GPU_ARCHES are not set by the Metal
# preamble -- confirm they are defined elsewhere or are meant to be empty.
define_gpu_extension_target(
  {{ ops_name }}
  DESTINATION {{ ops_name }}
  LANGUAGE ${GPU_LANG}
  SOURCES ${SRC}
  COMPILE_FLAGS ${GPU_FLAGS}
  ARCHITECTURES ${GPU_ARCHES}
  USE_SABI 3
  WITH_SOABI)
--------------------------------------------------------------------------------
/build2cmake/src/templates/metal/utils.cmake:
--------------------------------------------------------------------------------
# Run `EXPR` in python after importing `PKG`. Use the result of this to extend
# `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported.
# This is a macro, so `_PREFIX_PATH` is also set in the caller's scope.
# NOTE(review): `run_python` is not defined in this file; metal.rs embeds the
# shared templates/utils.cmake rather than this one -- confirm this file is used.
macro (append_cmake_prefix_path PKG EXPR)
  run_python(_PREFIX_PATH
    "import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path")
  list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH})
endmacro()
--------------------------------------------------------------------------------
/build2cmake/src/templates/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "cmake>=3.26",
4 | "ninja",
5 | "packaging",
6 | "setuptools>=61",
7 | "torch",
8 | "wheel",
9 | ]
10 | build-backend = "setuptools.build_meta"
11 |
--------------------------------------------------------------------------------
/build2cmake/src/templates/registration.h:
--------------------------------------------------------------------------------
1 | // Registration macros from vLLM:
2 | // https://github.com/vllm-project/vllm/blob/main/csrc/core/registration.h
3 |
4 | #pragma once
5 |
6 | #include
7 |
8 | #define _CONCAT(A, B) A##B
9 | #define CONCAT(A, B) _CONCAT(A, B)
10 |
11 | #define _STRINGIFY(A) #A
12 | #define STRINGIFY(A) _STRINGIFY(A)
13 |
14 | // A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
15 | // could be a macro instead of a literal token.
16 | #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
17 |
18 | // A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
19 | // could be a macro instead of a literal token.
20 | #define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
21 | TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
22 |
23 | // REGISTER_EXTENSION allows the shared library to be loaded and initialized
24 | // via python's import statement.
25 | #define REGISTER_EXTENSION(NAME) \
26 | PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
27 | static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \
28 | STRINGIFY(NAME), nullptr, 0, nullptr}; \
29 | return PyModule_Create(&module); \
30 | }
31 |
--------------------------------------------------------------------------------
/build2cmake/src/templates/universal/_ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | ops = torch.ops.{{ ops_name }}
3 |
def add_op_namespace_prefix(op_name: str):
    """
    Return `op_name` qualified with this extension's op namespace.
    """
    return "{{ ops_name }}::{}".format(op_name)
9 |
--------------------------------------------------------------------------------
/build2cmake/src/templates/universal/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "{{ name }}"
3 | version = "0.0.1"
4 | requires-python = ">= 3.9"
5 | dependencies = ["torch>=2.4"]
6 |
7 | [tool.setuptools]
8 | package-dir = { "" = "torch-ext" }
9 |
10 | [tool.setuptools.packages.find]
11 | where = ["torch-ext"]
12 | include = ["{{name}}*"]
13 |
14 | {% if data_globs %}
15 | [tool.setuptools.package-data]
16 | {{ name }} = [ {{ data_globs }} ]
17 | {% endif %}
18 |
--------------------------------------------------------------------------------
/build2cmake/src/torch/metal.rs:
--------------------------------------------------------------------------------
1 | use std::{io::Write, path::PathBuf};
2 |
3 | use eyre::{bail, Context, Result};
4 | use itertools::Itertools;
5 | use minijinja::{context, Environment};
6 |
7 | use super::kernel_ops_identifier;
8 | use crate::{
9 | config::{Build, Kernel, Torch},
10 | fileset::FileSet,
11 | };
12 |
/// Contents of the shared `templates/utils.cmake` (not
/// `templates/metal/utils.cmake`); `write_cmake` emits it as
/// `cmake/utils.cmake` in the generated project.
static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake");
/// Contents of `templates/registration.h`, presumably emitted by
/// `write_torch_registration_macros` (defined outside this view).
static REGISTRATION_H: &str = include_str!("../templates/registration.h");
15 |
16 | pub fn write_torch_ext_metal(
17 | env: &Environment,
18 | build: &Build,
19 | target_dir: PathBuf,
20 | force: bool,
21 | ops_id: Option,
22 | ) -> Result<()> {
23 | let torch_ext = match build.torch.as_ref() {
24 | Some(torch_ext) => torch_ext,
25 | None => bail!("Build configuration does not have `torch` section"),
26 | };
27 |
28 | let mut file_set = FileSet::default();
29 |
30 | let ops_name = kernel_ops_identifier(&target_dir, &build.general.name, ops_id);
31 |
32 | write_cmake(
33 | env,
34 | build,
35 | torch_ext,
36 | &build.general.name,
37 | &ops_name,
38 | &mut file_set,
39 | )?;
40 |
41 | write_setup_py(
42 | env,
43 | torch_ext,
44 | &build.general.name,
45 | &ops_name,
46 | &mut file_set,
47 | )?;
48 |
49 | write_ops_py(env, &build.general.name, &ops_name, &mut file_set)?;
50 |
51 | write_pyproject_toml(env, &mut file_set)?;
52 |
53 | write_torch_registration_macros(&mut file_set)?;
54 |
55 | file_set.write(&target_dir, force)?;
56 |
57 | Ok(())
58 | }
59 |
60 | fn write_cmake(
61 | env: &Environment,
62 | build: &Build,
63 | torch: &Torch,
64 | name: &str,
65 | ops_name: &str,
66 | file_set: &mut FileSet,
67 | ) -> Result<()> {
68 | let mut utils_path = PathBuf::new();
69 | utils_path.push("cmake");
70 | utils_path.push("utils.cmake");
71 | file_set
72 | .entry(utils_path.clone())
73 | .extend_from_slice(CMAKE_UTILS.as_bytes());
74 |
75 | let cmake_writer = file_set.entry("CMakeLists.txt");
76 |
77 | render_preamble(env, name, cmake_writer)?;
78 |
79 | // Add deps once we have any non-CUDA deps.
80 | // render_deps(env, build, cmake_writer)?;
81 |
82 | render_binding(env, torch, name, cmake_writer)?;
83 |
84 | for (kernel_name, kernel) in &build.kernels {
85 | render_kernel(env, kernel_name, kernel, cmake_writer)?;
86 | }
87 |
88 | render_extension(env, ops_name, cmake_writer)?;
89 |
90 | Ok(())
91 | }
92 |
93 | fn render_binding(
94 | env: &Environment,
95 | torch: &Torch,
96 | name: &str,
97 | write: &mut impl Write,
98 | ) -> Result<()> {
99 | env.get_template("metal/torch-binding.cmake")
100 | .wrap_err("Cannot get Torch binding template")?
101 | .render_to_write(
102 | context! {
103 | includes => torch.include.as_ref().map(prefix_and_join_includes),
104 | name => name,
105 | src => torch.src
106 | },
107 | &mut *write,
108 | )
109 | .wrap_err("Cannot render Torch binding template")?;
110 |
111 | write.write_all(b"\n")?;
112 |
113 | Ok(())
114 | }
115 |
116 | pub fn render_extension(env: &Environment, ops_name: &str, write: &mut impl Write) -> Result<()> {
117 | env.get_template("metal/torch-extension.cmake")
118 | .wrap_err("Cannot get Torch extension template")?
119 | .render_to_write(
120 | context! {
121 | ops_name => ops_name,
122 | },
123 | &mut *write,
124 | )
125 | .wrap_err("Cannot render Torch extension template")?;
126 |
127 | write.write_all(b"\n")?;
128 |
129 | Ok(())
130 | }
131 |
132 | pub fn render_kernel(
133 | env: &Environment,
134 | kernel_name: &str,
135 | kernel: &Kernel,
136 | write: &mut impl Write,
137 | ) -> Result<()> {
138 | // Easier to do in Rust than Jinja.
139 | let sources = kernel
140 | .src
141 | .iter()
142 | .map(|src| format!("\"{src}\""))
143 | .collect_vec()
144 | .join("\n");
145 |
146 | env.get_template("metal/kernel.cmake")
147 | .wrap_err("Cannot get kernel template")?
148 | .render_to_write(
149 | context! {
150 | includes => kernel.include.as_ref().map(prefix_and_join_includes),
151 | kernel_name => kernel_name,
152 | sources => sources,
153 | },
154 | &mut *write,
155 | )
156 | .wrap_err("Cannot render kernel template")?;
157 |
158 | write.write_all(b"\n")?;
159 |
160 | Ok(())
161 | }
162 |
163 | fn render_preamble(env: &Environment, name: &str, write: &mut impl Write) -> Result<()> {
164 | env.get_template("metal/preamble.cmake")
165 | .wrap_err("Cannot get CMake prelude template")?
166 | .render_to_write(
167 | context! {
168 | name => name,
169 | },
170 | &mut *write,
171 | )
172 | .wrap_err("Cannot render CMake prelude template")?;
173 |
174 | write.write_all(b"\n")?;
175 |
176 | Ok(())
177 | }
178 |
179 | fn write_ops_py(
180 | env: &Environment,
181 | name: &str,
182 | ops_name: &str,
183 | file_set: &mut FileSet,
184 | ) -> Result<()> {
185 | let mut path = PathBuf::new();
186 | path.push("torch-ext");
187 | path.push(name);
188 | path.push("_ops.py");
189 | let writer = file_set.entry(path);
190 |
191 | env.get_template("_ops.py")
192 | .wrap_err("Cannot get _ops.py template")?
193 | .render_to_write(
194 | context! {
195 | ops_name => ops_name,
196 | },
197 | writer,
198 | )
199 | .wrap_err("Cannot render kernel template")?;
200 |
201 | Ok(())
202 | }
203 |
204 | fn write_pyproject_toml(env: &Environment, file_set: &mut FileSet) -> Result<()> {
205 | let writer = file_set.entry("pyproject.toml");
206 |
207 | env.get_template("pyproject.toml")
208 | .wrap_err("Cannot get pyproject.toml template")?
209 | .render_to_write(context! {}, writer)
210 | .wrap_err("Cannot render kernel template")?;
211 |
212 | Ok(())
213 | }
214 |
215 | fn write_setup_py(
216 | env: &Environment,
217 | torch: &Torch,
218 | name: &str,
219 | ops_name: &str,
220 | file_set: &mut FileSet,
221 | ) -> Result<()> {
222 | let writer = file_set.entry("setup.py");
223 |
224 | let data_globs = torch.data_globs().map(|globs| globs.join(", "));
225 |
226 | env.get_template("metal/setup.py")
227 | .wrap_err("Cannot get setup.py template")?
228 | .render_to_write(
229 | context! {
230 | data_globs => data_globs,
231 | ops_name => ops_name,
232 | name => name,
233 | version => "0.1.0",
234 | },
235 | writer,
236 | )
237 | .wrap_err("Cannot render kernel template")?;
238 |
239 | Ok(())
240 | }
241 |
242 | fn write_torch_registration_macros(file_set: &mut FileSet) -> Result<()> {
243 | let mut path = PathBuf::new();
244 | path.push("torch-ext");
245 | path.push("registration.h");
246 | file_set
247 | .entry(path)
248 | .extend_from_slice(REGISTRATION_H.as_bytes());
249 |
250 | Ok(())
251 | }
252 |
/// Turns include directories (relative to the CMake source dir) into a
/// CMake list.
///
/// Each entry is prefixed with `${CMAKE_SOURCE_DIR}/` and the results are
/// joined with `;`. An empty input yields an empty string.
fn prefix_and_join_includes<S>(includes: impl AsRef<[S]>) -> String
where
    S: AsRef<str>,
{
    let dirs: &[S] = includes.as_ref();
    let mut prefixed = Vec::with_capacity(dirs.len());
    for dir in dirs {
        prefixed.push(format!("${{CMAKE_SOURCE_DIR}}/{}", dir.as_ref()));
    }
    prefixed.join(";")
}
264 |
--------------------------------------------------------------------------------
/build2cmake/src/torch/mod.rs:
--------------------------------------------------------------------------------
1 | mod cuda;
2 | pub use cuda::write_torch_ext;
3 |
4 | mod metal;
5 | pub use metal::write_torch_ext_metal;
6 |
7 | mod ops_identifier;
8 | pub(crate) use ops_identifier::kernel_ops_identifier;
9 |
10 | mod universal;
11 | pub use universal::write_torch_universal_ext;
12 |
--------------------------------------------------------------------------------
/build2cmake/src/torch/ops_identifier.rs:
--------------------------------------------------------------------------------
use std::path::Path;

use eyre::{Result, WrapErr};
use git2::{Repository, StatusOptions};
use rand::Rng;
6 |
7 | fn random_identifier() -> String {
8 | // Generate a random string when no ops_id is provided
9 | let mut rng = rand::thread_rng();
10 | let build_id: u64 = rng.gen();
11 | base32::encode(
12 | base32::Alphabet::Rfc4648Lower { padding: false },
13 | &build_id.to_le_bytes(),
14 | )
15 | }
16 |
17 | fn git_identifier(target_dir: impl AsRef) -> Result {
18 | let repo = Repository::discover(target_dir.as_ref()).context("Cannot open git repository")?;
19 | let head = repo.head()?;
20 | let commit = head.peel_to_commit()?;
21 | let rev = commit.tree_id().to_string().chars().take(7).collect();
22 | let dirty = !repo.statuses(None)?.is_empty();
23 | Ok(if dirty { format!("{rev}_dirty") } else { rev })
24 | }
25 |
26 | pub fn kernel_ops_identifier(
27 | target_dir: impl AsRef,
28 | name: &str,
29 | ops_id: Option,
30 | ) -> String {
31 | let identifier = ops_id.unwrap_or_else(|| match git_identifier(target_dir.as_ref()) {
32 | Ok(rev) => rev,
33 | Err(_) => random_identifier(),
34 | });
35 |
36 | format!("_{name}_{identifier}")
37 | }
38 |
--------------------------------------------------------------------------------
/build2cmake/src/torch/universal.rs:
--------------------------------------------------------------------------------
1 | use std::path::PathBuf;
2 |
3 | use eyre::{Context, Result};
4 | use minijinja::{context, Environment};
5 |
6 | use crate::{
7 | config::{Build, Torch},
8 | fileset::FileSet,
9 | torch::kernel_ops_identifier,
10 | };
11 |
12 | pub fn write_torch_universal_ext(
13 | env: &Environment,
14 | build: &Build,
15 | target_dir: PathBuf,
16 | force: bool,
17 | ops_id: Option,
18 | ) -> Result<()> {
19 | let mut file_set = FileSet::default();
20 |
21 | let ops_name = kernel_ops_identifier(&target_dir, &build.general.name, ops_id);
22 |
23 | write_ops_py(env, &build.general.name, &ops_name, &mut file_set)?;
24 | write_pyproject_toml(
25 | env,
26 | build.torch.as_ref(),
27 | &build.general.name,
28 | &mut file_set,
29 | )?;
30 |
31 | file_set.write(&target_dir, force)?;
32 |
33 | Ok(())
34 | }
35 |
36 | fn write_ops_py(
37 | env: &Environment,
38 | name: &str,
39 | ops_name: &str,
40 | file_set: &mut FileSet,
41 | ) -> Result<()> {
42 | let mut path = PathBuf::new();
43 | path.push("torch-ext");
44 | path.push(name);
45 | path.push("_ops.py");
46 | let writer = file_set.entry(path);
47 |
48 | env.get_template("universal/_ops.py")
49 | .wrap_err("Cannot get _ops-universal.py template")?
50 | .render_to_write(
51 | context! {
52 | ops_name => ops_name,
53 | },
54 | writer,
55 | )
56 | .wrap_err("Cannot render kernel template")?;
57 |
58 | Ok(())
59 | }
60 |
61 | fn write_pyproject_toml(
62 | env: &Environment,
63 | torch: Option<&Torch>,
64 | name: &str,
65 | file_set: &mut FileSet,
66 | ) -> Result<()> {
67 | let writer = file_set.entry("pyproject.toml");
68 |
69 | let data_globs = torch.and_then(|torch| torch.data_globs().map(|globs| globs.join(", ")));
70 |
71 | env.get_template("universal/pyproject.toml")
72 | .wrap_err("Cannot get universal pyproject.toml template")?
73 | .render_to_write(
74 | context! {
75 | data_globs => data_globs,
76 | name => name,
77 | },
78 | writer,
79 | )
80 | .wrap_err("Cannot render kernel template")?;
81 |
82 | Ok(())
83 | }
84 |
--------------------------------------------------------------------------------
/default.nix:
--------------------------------------------------------------------------------
# Non-flake entry point: imports flake-compat (pinned in flake.lock) with
# `src = ./.` and returns its `defaultNix` attribute.
(import (
let
lock = builtins.fromJSON (builtins.readFile ./flake.lock);
in
# Use the URL recorded in the lock file when present; otherwise build the
# GitHub archive URL from the locked revision.
fetchTarball {
url =
lock.nodes.flake-compat.locked.url
or "https://github.com/edolstra/flake-compat/archive/${lock.nodes.flake-compat.locked.rev}.tar.gz";
sha256 = lock.nodes.flake-compat.locked.narHash;
}
) { src = ./.; }).defaultNix
12 |
--------------------------------------------------------------------------------
/dockerfiles/Dockerfile:
--------------------------------------------------------------------------------
# Root variant of the kernel-builder image: Nix with flakes enabled,
# cachix and git-lfs installed, the huggingface binary cache configured,
# and a CLI script (below) as the entrypoint.
FROM nixos/nix:2.18.8

# default build args
ARG MAX_JOBS=4
ARG CORES=4

# Combine RUN commands to reduce layers and improve caching
# The build args are also written to nix.conf as Nix-wide defaults.
RUN echo "experimental-features = nix-command flakes" >> /etc/nix/nix.conf \
&& echo "max-jobs = $MAX_JOBS" >> /etc/nix/nix.conf \
&& echo "cores = $CORES" >> /etc/nix/nix.conf \
&& nix profile install nixpkgs#cachix nixpkgs#git-lfs \
&& cachix use huggingface
WORKDIR /app
# Copy builder files
COPY . /etc/kernel-builder/
# Set environment variables
# The CLI script reads these as defaults; they can be overridden at run time
# with `docker run -e MAX_JOBS=... -e CORES=...`.
ENV MAX_JOBS=${MAX_JOBS}
ENV CORES=${CORES}
# Create directory and setup script
RUN mkdir -p /etc/kernel-builder && \
cat <<'EOF' > /etc/kernel-builder/cli.sh
#!/bin/sh
# NOTE(review): this script uses bash features (`function`, `[[ ]]`,
# arrays); it relies on /bin/sh being bash in the base image — confirm.
set -e
# Default values
# NOTE(review): BUILD_URL, DEV_SHELL and HELP are not referenced below.
BUILD_URL=""
DEV_SHELL=0
HELP=0
# CLI usage function
# Prints the commands, options and example invocations. $MAX_JOBS and
# $CORES are expanded so the defaults shown are the current values. In the
# examples, $(pwd) is escaped so the user sees the literal placeholder;
# unescaped, it would be replaced by the container's working directory.
function show_usage {
  echo "Kernel Builder CLI"
  echo ""
  echo "Usage: docker run [docker-options] kernel-builder:root [command] [options]"
  echo ""
  echo "Commands:"
  echo " build Build the kernel extension (default if no command specified)"
  echo " dev Start a development shell"
  echo " fetch [URL] Clone and build from a Git URL"
  echo " help Show this help message"
  echo ""
  echo "Options:"
  echo " --jobs, -j NUMBER Set maximum number of parallel jobs (default: $MAX_JOBS)"
  echo " --cores, -c NUMBER Set number of cores per job (default: $CORES)"
  echo ""
  echo "Examples:"
  echo " docker run --mount type=bind,source=\$(pwd),target=/kernelcode kernel-builder:root build"
  echo " docker run -it --mount type=bind,source=\$(pwd),target=/kernelcode kernel-builder:root dev"
  echo " docker run --mount type=bind,source=\$(pwd),target=/kernelcode kernel-builder:root fetch https://huggingface.co/user/repo.git"
}
49 |
# Function to generate a basic flake.nix if it doesn't exist
# Writes a minimal flake.nix (delegating to kernel-builder's
# genFlakeOutputs) into the directory given as $1. An existing flake.nix
# is left untouched.
function ensure_flake_exists {
  local flake_file="$1/flake.nix"
  if [ -f "${flake_file}" ]; then
    echo "flake.nix already exists, skipping creation."
    return 0
  fi

  echo "No flake.nix found, creating a basic one..."
  cat <<'FLAKE_EOF' > "${flake_file}"
{
description = "Flake for Torch kernel extension";
inputs = {
kernel-builder.url = "github:huggingface/kernel-builder";
};
outputs = { self, kernel-builder, }:
kernel-builder.lib.genFlakeOutputs {
path = ./.;
rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
};
}
FLAKE_EOF
  echo "flake.nix created. You can customize it as needed."
}
# Function to build the extension
# Builds the kernel flake in $1 with Nix and copies the results into $2.
# Arguments:
#   $1 - kernel source directory (update-build is run on its build.toml)
#   $2 - output directory that receives a dereferenced copy of ./result
function build_extension {
  local work_dir=$1
  local output_dir=$2

  echo "Building Torch Extension Bundle from ${work_dir}"
  cd "${work_dir}"

  # Check if work_dir is a git repo and get hash if possible
  if [ -d "${work_dir}/.git" ]; then
    # Mark git as safe to allow commands
    git config --global --add safe.directory "${work_dir}"
    # Try to get git revision
    REV=$(git rev-parse --short=8 HEAD)

    # Check if working directory is dirty. stderr must be redirected with
    # `2>/dev/null`: a bare `2` is parsed as a pathspec and would limit the
    # status to a file named "2".
    if [ -n "$(git status --porcelain 2>/dev/null)" ]; then
      REV="${REV}-dirty"
    fi
  else
    # Generate random material if not a git repo
    REV=$(dd if=/dev/urandom status=none bs=1 count=10 2>/dev/null | base32 | tr '[:upper:]' '[:lower:]' | head -c 10)
  fi
  # NOTE(review): REV is only reported here; it is not passed to the build.
  echo "Building with rev $REV"

  # Check for flake.nix or create one
  ensure_flake_exists "${work_dir}"

  # Make sure the build is up to date
  nix run github:huggingface/kernel-builder#update-build -- build.toml

  # Pure bundle build
  echo "Building with Nix..."
  nix build \
    . \
    --max-jobs "$MAX_JOBS" \
    -j "$CORES" \
    -L

  echo "Build completed. Copying results to ${output_dir}"
  mkdir -p "${output_dir}"
  cp -r --dereference ./result/* "${output_dir}/"
  # As root, ensure proper permissions for host access
  chmod -R 777 "${output_dir}"
  echo "Done - results available in ${output_dir}"
}
# Function to start a dev shell
# Ensures /kernelcode has a flake.nix, then runs `nix develop` there.
function start_dev_shell {
  local code_dir=/kernelcode
  echo "Starting development shell..."
  ensure_flake_exists "${code_dir}"
  cd "${code_dir}"
  /root/.nix-profile/bin/nix develop
}
# Function to fetch and build from URL
# Clones the repository at $1 into /tmp/kernel-src and builds it there.
# Results are copied to /kernelcode/result on the mounted volume.
function fetch_and_build {
  local repo_url="$1"
  if [ -z "${repo_url}" ]; then
    echo "Error: URL required for fetch command"
    show_usage
    exit 1
  fi

  local src_dir="/tmp/kernel-src"
  local output_dir="/kernelcode/result"

  echo "Fetching code from ${repo_url} to ${src_dir}"
  mkdir -p "${src_dir}"

  git lfs install
  git clone "${repo_url}" "${src_dir}"

  build_extension "${src_dir}" "${output_dir}"
}
# Parse arguments
COMMAND="build" # Default command
ARGS=()

# Aborts with an error and the usage text when option $1 has no value ($2).
# Without this check, a trailing `--jobs`/`--cores` makes `shift 2` fail,
# which under `set -e` ends the script without explaining why.
function require_option_value {
  if [ -z "$2" ]; then
    echo "Option $1 requires a value"
    show_usage
    exit 1
  fi
}

while [[ $# -gt 0 ]]; do
  case $1 in
    build|dev|fetch|help)
      COMMAND="$1"
      shift
      ;;
    --jobs|-j)
      require_option_value "$1" "${2:-}"
      MAX_JOBS="$2"
      shift 2
      ;;
    --cores|-c)
      require_option_value "$1" "${2:-}"
      CORES="$2"
      shift 2
      ;;
    -*)
      echo "Unknown option: $1"
      show_usage
      exit 1
      ;;
    *)
      ARGS+=("$1")
      shift
      ;;
  esac
done
# Execute the command
case $COMMAND in
  build)
    # When building existing code, use the mounted directory
    build_extension "/kernelcode" "/kernelcode/build"
    ;;
  dev)
    start_dev_shell
    ;;
  fetch)
    fetch_and_build "${ARGS[0]}"
    ;;
  help)
    show_usage
    ;;
  *)
    echo "Unknown command: $COMMAND"
    show_usage
    exit 1
    ;;
esac
EOF
RUN chmod +x /etc/kernel-builder/cli.sh
# Create output directory structure
RUN mkdir -p /kernelcode/build
# Set up volume for kernelcode
VOLUME /kernelcode

# Container arguments are passed to cli.sh (its default command is `build`).
ENTRYPOINT ["/etc/kernel-builder/cli.sh"]
208 |
--------------------------------------------------------------------------------
/dockerfiles/Dockerfile.user:
--------------------------------------------------------------------------------
FROM nixos/nix:2.18.8

# default build args
ARG MAX_JOBS=1
ARG CORES=1

# Set up Nix configuration and user
RUN echo "experimental-features = nix-command flakes" >> /etc/nix/nix.conf \
&& echo "max-jobs = $MAX_JOBS" >> /etc/nix/nix.conf \
&& echo "cores = $CORES" >> /etc/nix/nix.conf \
&& echo "trusted-users = root nixuser" >> /etc/nix/nix.conf \
# Create user entries directly in password and group files
&& echo "nixuser:x:1000:1000:NixOS User:/home/nixuser:/bin/bash" >> /etc/passwd \
&& echo "nixuser:x:1000:" >> /etc/group \
&& mkdir -p /home/nixuser/kernelcode \
# Create Nix directories with proper permissions
&& mkdir -p /nix/var/nix/profiles/per-user/nixuser \
&& mkdir -p /nix/var/nix/gcroots/per-user/nixuser \
&& chown -R 1000:1000 /home/nixuser /nix/var/nix/profiles/per-user/nixuser /nix/var/nix/gcroots/per-user/nixuser \
# Install necessary packages
&& nix profile install nixpkgs#cachix nixpkgs#git-lfs nixpkgs#gawk \
&& cachix use huggingface

# Set permissions for Nix directories
RUN chown -R nixuser:nixuser /nix

# Set working directory and copy files
WORKDIR /home/nixuser/kernelcode
COPY --chown=nixuser:nixuser . /home/nixuser/kernel-builder/

# Set environment variables
ENV MAX_JOBS=${MAX_JOBS}
ENV CORES=${CORES}
# HF_TOKEN is deliberately not set here. No HF_TOKEN build arg is declared,
# so `ENV HF_TOKEN=${HF_TOKEN}` only ever baked an empty value into the image.
# Pass the token at run time (e.g. `docker run -e HF_TOKEN=...`) so it is
# never stored in an image layer.
ENV HOME=/home/nixuser
ENV PUSH_REVISION=hfjob-build
ENV REPO=kernels-community/job-build-test-repo
# Set up CLI script in nixuser's home
RUN mkdir -p /home/nixuser/bin && \
cat <<'EOF' > /home/nixuser/bin/cli.sh
#!/bin/sh
# NOTE(review): this script uses bash features (`function`, `[[ ]]`,
# arrays); it relies on /bin/sh being bash in the base image — confirm.
set -e

# Default values
# NOTE(review): BUILD_URL, DEV_SHELL and HELP are not referenced below.
BUILD_URL=""
DEV_SHELL=0
HELP=0

# CLI usage function
# Prints the commands, options and example invocations. $MAX_JOBS and
# $CORES are expanded so the defaults shown are the current values; the
# examples escape $(pwd) so it is printed literally.
function show_usage {
echo "Kernel Builder CLI"
echo ""
echo "Usage: docker run [docker-options] kernel-builder:dev [command] [options]"
echo ""
echo "Commands:"
echo " build Build the kernel extension (default if no command specified)"
echo " dev Start a development shell"
echo " fetch [URL] Clone and build from a Git URL"
echo " help Show this help message"
echo ""
echo "Options:"
echo " --jobs, -j NUMBER Set maximum number of parallel jobs (default: $MAX_JOBS)"
echo " --cores, -c NUMBER Set number of cores per job (default: $CORES)"
echo ""
echo "Examples:"
echo " docker run -v \$(pwd):/home/nixuser/kernelcode kernel-builder:dev build"
echo " docker run -it -v \$(pwd):/home/nixuser/kernelcode kernel-builder:dev dev"
echo " docker run kernel-builder:dev fetch https://huggingface.co/user/repo.git"
}
71 |
# Function to generate a basic flake.nix if it doesn't exist
# Writes a minimal flake.nix (delegating to kernel-builder's
# genFlakeOutputs) to /home/nixuser/kernelcode. An existing flake.nix is
# left untouched.
function ensure_flake_exists {
  local flake_file=/home/nixuser/kernelcode/flake.nix
  if [ -f "${flake_file}" ]; then
    echo "flake.nix already exists, skipping creation."
    return 0
  fi

  echo "No flake.nix found, creating a basic one..."
  cat <<'FLAKE_EOF' > "${flake_file}"
{
description = "Flake for Torch kernel extension";

inputs = {
kernel-builder.url = "github:huggingface/kernel-builder";
};

outputs = { self, kernel-builder, }:
kernel-builder.lib.genFlakeOutputs {
path = ./.;
rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
};
}
FLAKE_EOF
  echo "flake.nix created. You can customize it as needed."
}
96 |
# Function to build the extension
# Builds the `bundle` output of the flake in the current directory
# (WORKDIR /home/nixuser/kernelcode) and copies ./result into
# /home/nixuser/kernelcode/build.
function build_extension {
  echo "Building Torch Extension Bundle"
  # Check if kernelcode is a git repo and get hash if possible
  if [ -d "/home/nixuser/kernelcode/.git" ]; then
    # Mark git as safe to allow commands
    git config --global --add safe.directory /home/nixuser/kernelcode
    # Try to get git revision
    REV=$(git rev-parse --short=8 HEAD)

    # Check if working directory is dirty. stderr must be redirected with
    # `2>/dev/null`: a bare `2` is parsed as a pathspec and would limit the
    # status to a file named "2".
    if [ -n "$(git status --porcelain 2>/dev/null)" ]; then
      REV="${REV}-dirty"
    fi
  else
    # Generate random material if not a git repo
    REV=$(dd if=/dev/urandom status=none bs=1 count=10 2>/dev/null | base32 | tr '[:upper:]' '[:lower:]' | head -c 10)
  fi
  # NOTE(review): REV is only reported here; it is not passed to the build.
  echo "Building with rev $REV"

  # Check for flake.nix or create one
  ensure_flake_exists

  # Make sure the build is up to date
  nix run github:huggingface/kernel-builder#update-build -- build.toml

  # Pure bundle build
  # TODO: remove the "bundle" after resolving
  echo "Building with Nix..."
  nix build \
    .\#bundle \
    --max-jobs "$MAX_JOBS" \
    -j "$CORES" \
    -L 2>&1 | awk '{ print strftime("[%Y-%m-%d %H:%M:%S]"), $0; fflush(); }'
  # A pipeline's exit status is that of its last command (awk), so `set -e`
  # alone would not notice a failed nix build and could go on to copy a
  # stale ./result. Check nix's own status explicitly.
  nix_status=${PIPESTATUS[0]}
  if [ "${nix_status}" -ne 0 ]; then
    echo "Nix build failed (exit code ${nix_status})"
    exit "${nix_status}"
  fi

  echo "Build completed. Copying results to /home/nixuser/kernelcode/build/"
  mkdir -p /home/nixuser/kernelcode/build
  cp -r --dereference ./result/* /home/nixuser/kernelcode/build/
  chmod -R u+w /home/nixuser/kernelcode/build
  echo 'Done'
}
138 |
# Function to start a dev shell
# Makes sure /home/nixuser/kernelcode has a flake.nix, then runs
# `nix develop` in the current directory.
function start_dev_shell {
  echo "Starting development shell..."
  ensure_flake_exists
  nix develop
}
146 |
# Function to fetch and build from URL
# Replaces the contents of /home/nixuser/kernelcode with a fresh clone of
# the repository at $1 and builds it. The resulting build/ directory is then
# uploaded to the Hub repository $REPO at revision $PUSH_REVISION; the Hub
# token is expected in the environment.
function fetch_and_build {
  if [ -z "$1" ]; then
    echo "Error: URL required for fetch command"
    show_usage
    exit 1
  fi

  echo "Fetching code from $1"
  # Best-effort cleanup of the checkout directory, including dotfiles.
  # Errors are ignored (e.g. for `.`/`..` matched by the `.*` glob).
  rm -rf /home/nixuser/kernelcode/* /home/nixuser/kernelcode/.* 2>/dev/null || true
  git lfs install
  git clone "$1" /home/nixuser/kernelcode
  cd /home/nixuser/kernelcode
  build_extension
  echo "Build completed. Results are in /home/nixuser/kernelcode/build/"

  # skip login to huggingface since token is set in the env
  # check user
  nix shell nixpkgs#python3 nixpkgs#python3Packages.huggingface-hub -c huggingface-cli whoami

  # upload the build to the repo
  # The variables are quoted so unexpected characters cannot split or
  # glob-expand the arguments.
  nix shell nixpkgs#python3 nixpkgs#python3Packages.huggingface-hub -c huggingface-cli \
    upload \
    --revision "${PUSH_REVISION}" \
    --commit-message "Build from kernel-builder job" \
    "${REPO}" \
    /home/nixuser/kernelcode/build/ \
    build/
}
176 |
# Parse arguments
COMMAND="build" # Default command
ARGS=()

# Aborts with an error and the usage text when option $1 has no value ($2).
# Without this check, a trailing `--jobs`/`--cores` makes `shift 2` fail,
# which under `set -e` ends the script without explaining why.
function require_option_value {
  if [ -z "$2" ]; then
    echo "Option $1 requires a value"
    show_usage
    exit 1
  fi
}

while [[ $# -gt 0 ]]; do
  case $1 in
    build|dev|fetch|help)
      COMMAND="$1"
      shift
      ;;
    --jobs|-j)
      require_option_value "$1" "${2:-}"
      MAX_JOBS="$2"
      shift 2
      ;;
    --cores|-c)
      require_option_value "$1" "${2:-}"
      CORES="$2"
      shift 2
      ;;
    -*)
      echo "Unknown option: $1"
      show_usage
      exit 1
      ;;
    *)
      ARGS+=("$1")
      shift
      ;;
  esac
done

# Execute the command
case $COMMAND in
  build)
    build_extension
    ;;
  dev)
    start_dev_shell
    ;;
  fetch)
    fetch_and_build "${ARGS[0]}"
    ;;
  help)
    show_usage
    ;;
  *)
    echo "Unknown command: $COMMAND"
    show_usage
    exit 1
    ;;
esac
EOF
228 |
# Set permissions and make the script executable
RUN chmod +x /home/nixuser/bin/cli.sh && \
chown -R nixuser:nixuser /home/nixuser

# Switch to nixuser
USER nixuser

# Use the cli.sh script directly
# Container arguments are passed to cli.sh (its default command is `build`).
ENTRYPOINT ["/home/nixuser/bin/cli.sh"]
--------------------------------------------------------------------------------
/dockerfiles/README.md:
--------------------------------------------------------------------------------
1 | # Kernel Builder Docker Containers
2 |
3 | This directory contains two Docker containers for different use cases:
4 |
5 | ## Root Container (`Dockerfile`)
6 |
7 | This container runs as root and can modify file permissions when mounting volumes.
8 |
9 | ```bash
10 | # Build the container
11 | docker build -t kernel-builder:root -f Dockerfile ..
12 |
13 | # Use the container
14 | docker run --mount type=bind,source=$(pwd),target=/kernelcode kernel-builder:root build
15 | ```
16 |
17 | ## User Container (`Dockerfile.user`)
18 |
19 | This container runs as a non-root user (nixuser with UID 1000) for more secure environments.
20 |
21 | ```bash
22 | # Build the container
23 | docker build -t kernel-builder:user -f Dockerfile.user ..
24 |
25 | # Important: Prepare a directory with correct permissions
26 | mkdir -p ./build
27 | chown -R 1000:1000 ./build # Match the UID:GID of nixuser in the container
28 |
29 | # Use with proper permissions for the build directory
30 | docker run --mount type=bind,source=$(pwd),target=/home/nixuser/kernelcode \
31 | --mount type=bind,source=$(pwd)/build,target=/home/nixuser/kernelcode/build \
32 | kernel-builder:user build
33 | ```
34 |
35 | ## Environment Variables
36 |
37 | Both containers support these build options:
38 |
39 | ```bash
40 | # Set options at build time
41 | docker build -t kernel-builder:custom --build-arg MAX_JOBS=8 --build-arg CORES=2 -f Dockerfile ..
42 |
43 | # Or at runtime
44 | docker run -e MAX_JOBS=8 -e CORES=2 --mount type=bind,source=$(pwd),target=/kernelcode kernel-builder:root build
45 | ```
46 |
--------------------------------------------------------------------------------
/docs/build-variants.md:
--------------------------------------------------------------------------------
1 | # Build variants
2 |
A kernel can be compliant with a specific compute framework (e.g. CUDA) or
4 | architecture (e.g. x86_64). For compliance with a compute framework and
5 | architecture combination, all the build variants listed below must be
6 | available. This list will be updated as new PyTorch versions are released.
7 |
8 | ## CUDA aarch64-linux
9 |
10 | - `torch26-cxx11-cu126-aarch64-linux`
11 | - `torch26-cxx98-cu126-aarch64-linux`
12 | - `torch27-cxx11-cu126-aarch64-linux`
13 | - `torch27-cxx11-cu128-aarch64-linux`
14 |
15 | ## CUDA x86_64-linux
16 |
17 | - `torch26-cxx11-cu118-x86_64-linux`
18 | - `torch26-cxx11-cu124-x86_64-linux`
19 | - `torch26-cxx11-cu126-x86_64-linux`
20 | - `torch26-cxx98-cu118-x86_64-linux`
21 | - `torch26-cxx98-cu124-x86_64-linux`
22 | - `torch26-cxx98-cu126-x86_64-linux`
23 | - `torch27-cxx11-cu118-x86_64-linux`
24 | - `torch27-cxx11-cu126-x86_64-linux`
25 | - `torch27-cxx11-cu128-x86_64-linux`
26 |
27 | ## ROCm x86_64-linux
28 |
29 | - `torch26-cxx11-rocm62-x86_64-linux`
30 | - `torch27-cxx11-rocm63-x86_64-linux`
31 |
32 | ## Universal
33 |
34 | Kernels that are in pure Python (e.g. Triton kernels) only need to provide
35 | a single build variant:
36 |
37 | - `torch-universal`
38 |
--------------------------------------------------------------------------------
/docs/local-dev.md:
--------------------------------------------------------------------------------
1 | # Local development of kernels
2 |
3 | ## Introduction
4 |
5 | `kernel-builder` builds kernels in a sandbox. This has various benefits,
6 | such as building kernels for a wide range of Torch versions, compatibility
7 | with older C library versions and avoiding accidental dependencies.
8 |
9 | However, this is not ideal during kernel development, since language
10 | servers and IDEs do not interpret the `build.toml` file. As a result,
11 | code completion will typically not work. `kernel-builder` provides the
12 | `build2cmake` utility to generate CMake files to build native code and
13 | setuptools files for building the kernel as a regular Python package.
14 | Since CMake and setuptools are widely supported by IDEs, this provides
15 | a much-improved development experience.
16 |
17 | ## Installing `build2cmake`
18 |
19 | `build2cmake` is available as a Rust crate. After [installing Rust](https://rustup.rs),
20 | it can be built and installed as follows:
21 |
22 | ```bash
23 | $ cargo install build2cmake
24 | ```
25 |
26 | ## Generating a Python project with `build2cmake`
27 |
28 | `build2cmake` generates a CMake/Python project from a [`build.toml`](./writing-kernels.md)
29 | file. The invocation is as follows:
30 |
31 | ```bash
32 | $ build2cmake generate-torch build.toml -f
33 | ```
34 |
35 | The `-f` flag is optional and instructs `build2cmake` to overwrite
36 | existing files.
37 |
38 | It is recommended to do an editable install of the generated project into
39 | your Python virtual environment for development:
40 |
41 | ```bash
42 | $ pip install wheel # Needed once to enable bdist_wheel.
43 | $ pip install --no-build-isolation -e .
44 | ```
45 |
46 | **Warnings:**
47 |
48 | - Kernels built in this way should **not** be published on the Kernel
49 | Hub. They do not fulfill the [kernel requirements](https://github.com/huggingface/kernels/blob/main/docs/kernel-requirements.md).
- Do not add the generated files to Git. `build2cmake` is updated regularly,
51 | and you generally want to use files generated by the latest version.
52 |
--------------------------------------------------------------------------------
/docs/nix.md:
--------------------------------------------------------------------------------
1 | # Using the kernel builder with Nix
2 |
3 | The kernel builder uses Nix for building kernels. You can build or
4 | run the kernels directly if you have [Nix installed](https://nixos.org/download/)
5 | on your system. On systems without Nix you can use the [Docker](./docker.md)
6 | image, which is a wrapper around Nix.
7 |
8 | ## Getting started
9 |
The easiest way to get all the Nix functionality is by putting a
11 | `flake.nix` in your kernel repository. To do so, copy
12 | [`examples/relu/flake.nix`](../examples/relu/flake.nix) into the
13 | same directory as your `build.toml` file. Then run `nix flake update`.
14 | This generates a `flake.lock` file that pins the kernel builder
15 | and _all_ its transitive dependencies. Commit both `flake.nix`
and `flake.lock` to your repository; this ensures that kernel
17 | builds are reproducible.
18 |
19 | Since the kernel builder depends on many packages (e.g. every supported
20 | PyTorch version), it is recommended to [enable the huggingface cache](https://app.cachix.org/cache/huggingface)
21 | to avoid expensive rebuilds.
22 |
23 | The kernel builder also provides Nix development shells with all Torch
24 | and CUDA/ROCm dependencies needed to develop kernels (see below). If
25 | you want to test your kernels inside a Nix development shell and you
26 | are not using NixOS, [make sure that the CUDA driver is visible](https://danieldk.eu/Nix-CUDA-on-non-NixOS-systems#make-runopengl-driverlib-and-symlink-the-driver-library) to Torch.
27 |
28 | ## Building kernels with Nix
29 |
30 | A kernel that has a `flake.nix` file can be built with `nix build`.
31 | For example:
32 |
33 | ```bash
34 | cd examples/activation
35 | nix build . -L
36 | ```
37 |
38 | You can put this `flake.nix` in your own kernel's root directory to
39 | add Nix support to your kernel.
40 |
41 | ## Shell for local development
42 |
43 | `kernel-builder` provides shells for developing kernels. In such a shell,
44 | all required dependencies are available, as well as `build2cmake` for generating
45 | project files. For example:
46 |
47 | ```bash
48 | $ nix develop
49 | $ build2cmake generate-torch build.toml
50 | $ cmake -B build-ext
51 | $ cmake --build build-ext
52 | ```
53 |
54 | If you want to test the kernel as a Python package, you can make a virtual
55 | environment inside the shell:
56 |
57 | ```bash
58 | $ nix develop
59 | $ build2cmake generate-torch build.toml
60 | $ python -m venv .venv
61 | $ source .venv/bin/activate
62 | $ pip install --no-build-isolation -e .
63 | ```
64 |
65 | Development shells are available for every build configuration. For
66 | instance, you can get a Torch 2.6 development shell for ROCm extensions
67 | using:
68 |
69 | ```bash
70 | $ nix develop .#devShells.torch26-cxx11-rocm62-x86_64-linux
71 | ```
72 |
73 | ## Shell for testing a kernel
74 |
75 | You can also start a test shell. This gives you a Python interpreter
76 | with the kernel in Python's search path. This makes it more convenient to run
77 | tests:
78 |
79 | ```bash
80 | cd examples/activation
81 | nix develop -L .#test
82 | python -m pytest tests
83 | ```
84 |
85 | ## Building a kernel without `flake.nix`
86 |
87 | If a kernel's source directory does not have a `flake.nix` file, you can build the
88 | kernel using the `buildTorchExtensionBundle` function from the kernel builder
89 | itself:
90 |
91 | ```bash
92 | cd examples/activation
93 | nix build --impure --expr 'with import ../..; lib.x86_64-linux.buildTorchExtensionBundle ./.' -L
94 | ```
95 |
--------------------------------------------------------------------------------
/docs/toolchain.md:
--------------------------------------------------------------------------------
1 | # Toolchain
2 |
3 | ## ABI compatibility
4 |
5 | Kernels and kernel extensions typically do not have any explicit external
6 | dependencies, except:
7 |
8 | - The CUDA library version that they were compiled against.
9 | - The Torch version that they were compiled against.
10 | - The C and C++ standard libraries (glibc and libstdc++ in Linux).
11 | - The Python library.
12 |
13 | Of course, the versions on a user's system can differ from the build
14 | system, so we have to account for these dependencies.
15 |
16 | ### Python
17 |
18 | In the case of Python we use the [limited API](https://docs.python.org/3/c-api/stable.html#limited-c-api).
19 | For the limited API, ABI stability is guaranteed. This excludes the
20 | possibility to use some dependencies like pybind11, but since it
21 | reduces the number of builds drastically, we use it anyway.
22 |
23 | ### CUDA/Torch
24 |
25 | Torch and CUDA only have limited ABI compatibility. Therefore, we
26 | compile extensions for all supported CUDA/Torch combinations.
27 |
28 | ### glibc/libstdc++
29 |
30 | glibc and libstdc++ use symbol versioning. As a result, a binary or
31 | library compiled against an older version of these libraries works
32 | on newer versions (modulo the C++11 ABI change). It is, however,
33 | not possible to use only older symbol versions when building against
34 | a newer version of these libraries.
35 |
36 | The traditional solution to this problem is to build software in
37 | a container that uses an ancient Linux distribution with old
38 | glibc/libstdc++ versions.
39 |
40 | With Nix we can do better --- with some work we can compile for old
41 | versions of these libraries using a recent nixpkgs. There are several
42 | nuances:
43 |
44 | - libstdc++ is distributed with gcc, so it has the same library version.
45 | - CUDA's nvcc uses a 'backend stdenv'. This stdenv has the latest
46 | gcc that is supported by nvcc. It can differ from the default gcc,
47 | because gcc in nixpkgs is sometimes newer than the version supported
48 | by CUDA.
49 | - gcc also links a binary against libgcc. libgcc must also be compiled
50 | against the target glibc, otherwise the resulting extensions will
51 | still rely on symbols from newer glibc versions.
52 |
53 | With that in mind, there are (at least?) three ways to do this:
54 |
55 | 1. Override glibc and libstdc++ system-wide using an overlay.
56 | 2. Override the backend stdenv of CUDA with one that has older
57 | library versions.
58 | 3. Only override glibc and libstdc++ through the stdenv for
59 | the kernel/extension packages.
60 |
61 | (1) is the most principled approach -- it guarantees that all packages
62 | use the same library versions, making it impossible for a newer version
63 | to creep in. Unfortunately, this has many issues, libraries and
64 | derivations from simply don't interoperate well with a package set from 2024.
65 | For instance, the build of older glibc versions hangs with GNU
66 | make >= 4.4 due to some dependency cycle.
67 |
68 | Overriding the backend stdenv of CUDA (2) has the issue that some
69 | derivations end up with two versions. E.g. they would build using the
70 | headers of the latest glibc and then try to link against the old glibc.
71 |
72 | Finally, (3) seems to work really well. We build everything except
73 | the kernels and extensions using unmodified nixpkgs. Then we tell nvcc
74 | to use our modified stdenv via CMake.
75 |
76 | To make this possible, we import the glibc and libstdc++ derivations
77 | from an old nixpkgs. We then create an intermediate stdenv to rebuild
78 | gcc/libgcc against the old glibc. Then glibc, libstdc++, and the
79 | rebuilt gcc together form a new stdenv. We also link libstdc++
80 | statically.
81 |
--------------------------------------------------------------------------------
/docs/why-nix.md:
--------------------------------------------------------------------------------
1 | # Why Nix?
2 |
3 | The Kernel Builder project uses Nix to build custom kernels designed specifically for PyTorch.
4 |
5 | Here’s why we chose Nix and why it's particularly suited to our workflow:
6 |
7 | ## 1. Consistent and Reproducible Builds
8 |
9 | Nix guarantees deterministic evaluation, ensuring that every kernel is built identically, regardless of the host environment. This consistency prevents "it works on my machine" problems, making debugging and deployment straightforward.
10 |
11 | ## 2. Simplified Dependency Management
12 |
13 | Compiling PyTorch kernels often involves complex dependencies such as CUDA versions, PyTorch APIs, and various C++ toolchains. Nix explicitly defines and manages these dependencies, eliminating version conflicts and making maintenance easier.
14 |
15 | ## 3. Declarative Configuration
16 |
17 | Nix’s declarative approach clearly specifies exactly what each kernel build needs. This transparency aids collaboration, speeds up troubleshooting, and makes it easy to document the build process.
18 |
19 | ## 4. Isolated, Reliable Builds
20 |
21 | Each kernel build with Nix runs in a fully isolated sandbox, removing any uncertainty about external state. This isolation ensures clean builds, free of unexpected side effects.
22 |
23 | ## 5. Efficient Caching and CI Integration
24 |
25 | Kernel compilation can be resource-intensive. Nix leverages efficient caching of build artifacts, significantly reducing build times and optimizing continuous integration workflows.
26 |
27 | ## 6. Easy Experimentation and Rollbacks
28 |
29 | Nix allows you to experiment with different kernel configurations, PyTorch versions, or CUDA toolkits easily. If a change introduces an issue, reverting to a previous state is quick and effortless.
30 |
31 | Overall, Nix streamlines the Kernel Builder workflow, allowing us to efficiently and reliably manage complex machine learning kernel builds.
32 |
33 | ---
34 |
35 | If you want to learn more about Nix, check out the following resources:
36 |
37 | ## References
38 |
39 | - **The Official Nix Manual:**
40 | - The definitive source for all things Nix, providing comprehensive coverage of its features, commands, and ecosystem.
41 | - Link: [Nix Manual (nixos.org)](https://nixos.org/manual/nix/stable/)
42 | - **Nix Pills:**
43 | - A series of blog posts breaking down complex Nix concepts into digestible pieces, ideal for a structured, tutorial-style approach.
44 | - Link: [Nix Pills (nixos.org)](https://nixos.org/guides/nix-pills/)
45 | - **nix.dev**:
46 | - Home of official documentation for the Nix ecosystem.
47 | - Link: [nix.dev](https://nix.dev/)
48 | - **NixOS Wiki:**
49 | - A community-driven wiki with a wealth of information, including tips, tricks, and tutorials, covering a wide range of topics, including NixOS-specific information.
50 | - Link: [NixOS Wiki](https://nixos.wiki/wiki/Main_Page)
51 |
52 |
--------------------------------------------------------------------------------
/examples/activation/README.md:
--------------------------------------------------------------------------------
1 | ## Activation
2 |
3 | Activation kernels from [vLLM](https://github.com/vllm-project/vllm/blob/main/csrc/activation_kernels.cu).
4 |
5 | Copyright 2023-2024, the vLLM team.
6 |
--------------------------------------------------------------------------------
/examples/activation/activation/cuda_compat.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #ifdef USE_ROCM
4 | #include
5 | #endif
6 |
// Portability shims so the same kernel source builds for CUDA and ROCm.
// The first branch maps every macro to the CUDA intrinsic; the second maps
// it to the HIP equivalent when USE_ROCM is defined.
#ifndef USE_ROCM

// Threads per warp on CUDA.
#define WARP_SIZE 32

// Load through __ldg on CUDA.
#define VLLM_LDG(arg) __ldg(arg)

// Shuffles over the full warp: uint32_t(-1) sets every bit of the lane mask.
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
  __shfl_xor_sync(uint32_t(-1), var, lane_mask)
#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
  __shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
#define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
  __shfl_down_sync(uint32_t(-1), var, lane_delta)

// Set the maximum dynamic shared memory size attribute of kernel FUNC.
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
  cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)

#else  // USE_ROCM

// On ROCm the width comes from the runtime's warpSize.
#define WARP_SIZE warpSize

// Plain dereference instead of __ldg.
#define VLLM_LDG(arg) *(arg)

// HIP shuffles take no lane mask.
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
  __shfl_xor(var, lane_mask, width)
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)

#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
  hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)

#endif  // USE_ROCM
50 |
--------------------------------------------------------------------------------
/examples/activation/activation/dispatch_utils.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Adapted from
3 | * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
4 | */
5 | #pragma once
6 |
7 | #include
8 |
// Each *_CASE_* macro expands to a list of AT_DISPATCH_CASE entries; the
// matching non-CASE macro wraps that list in AT_DISPATCH_SWITCH on TYPE,
// with NAME passed through to ATen (used in its dispatch error message).

// float32, float16 and bfloat16.
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)          \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

// The floating-point cases above followed by uint8 (Byte).
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
  VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)        \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME,                               \
                     VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))

// Signed and unsigned integer types from 8 to 64 bits.
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)          \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
36 |
--------------------------------------------------------------------------------
/examples/activation/build.toml:
--------------------------------------------------------------------------------
1 | [general]
2 | name = "activation"
3 | universal = false
4 |
5 | [torch]
6 | src = [
7 | "torch-ext/torch_binding.cpp",
8 | "torch-ext/torch_binding.h",
9 | ]
10 |
11 | [kernel.activation]
12 | backend = "cuda"
13 | depends = ["torch"]
14 | src = [
15 | "activation/activation_kernels.cu",
16 | "activation/cuda_compat.h",
17 | "activation/dispatch_utils.h",
18 | ]
19 |
--------------------------------------------------------------------------------
/examples/activation/flake.nix:
--------------------------------------------------------------------------------
{
  description = "Flake for activation kernels";

  inputs.kernel-builder.url = "path:../..";

  outputs =
    { self, kernel-builder }:
    let
      # Revision string: the clean short rev, else the dirty short rev,
      # else the last-modified date.
      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
    in
    kernel-builder.lib.genFlakeOutputs {
      path = ./.;
      inherit rev;
    };
}
19 |
--------------------------------------------------------------------------------
/examples/activation/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/kernel-builder/924bce43ecb17b61e1cfa928443a5bed2f3fbc3c/examples/activation/tests/__init__.py
--------------------------------------------------------------------------------
/examples/activation/tests/kernels/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/kernel-builder/924bce43ecb17b61e1cfa928443a5bed2f3fbc3c/examples/activation/tests/kernels/__init__.py
--------------------------------------------------------------------------------
/examples/activation/tests/kernels/allclose_default.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
# Reference default values of atol and rtol are from
# https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
# Table of dtype -> (atol, rtol).
_DEFAULT_TOLERANCES = {
    torch.float16: (1e-3, 1e-3),
    torch.bfloat16: (1e-3, 1.6e-2),
    torch.float: (1e-5, 1.3e-6),
}
default_atol = {dtype: tols[0] for dtype, tols in _DEFAULT_TOLERANCES.items()}
default_rtol = {dtype: tols[1] for dtype, tols in _DEFAULT_TOLERANCES.items()}
7 |
8 |
def get_default_atol(output) -> float:
    """Return the default absolute tolerance for ``output``'s dtype.

    Raises:
        KeyError: if the dtype has no entry in ``default_atol``.
    """
    dtype = output.dtype
    return default_atol[dtype]
11 |
12 |
def get_default_rtol(output) -> float:
    """Return the default relative tolerance for ``output``'s dtype.

    Raises:
        KeyError: if the dtype has no entry in ``default_rtol``.
    """
    dtype = output.dtype
    return default_rtol[dtype]
15 |
--------------------------------------------------------------------------------
/examples/activation/tests/kernels/test_activation.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | from typing import Type
4 |
5 | import activation
6 | import pytest
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | from .utils import opcheck
11 | from .allclose_default import get_default_atol, get_default_rtol
12 |
DTYPES = [torch.half, torch.bfloat16, torch.float]
# Arbitrary sizes for testing.
NUM_TOKENS = [7, 83, 2048]
D = [512, 13824]
SEEDS = [0]
# One device when exactly one GPU is visible, otherwise cuda:0 and cuda:1.
_NUM_TEST_DEVICES = 1 if torch.cuda.device_count() == 1 else 2
CUDA_DEVICES = [f"cuda:{i}" for i in range(_NUM_TEST_DEVICES)]
18 |
19 |
def gelu_fast(x: torch.Tensor) -> torch.Tensor:
    """Reference tanh-approximated GELU with a hard-coded sqrt(2/pi) constant."""
    inner = x * 0.7978845608 * (1.0 + 0.044715 * x * x)
    return 0.5 * x * (1.0 + torch.tanh(inner))
22 |
23 |
def gelu_new(x: torch.Tensor) -> torch.Tensor:
    """Reference GPT-2 style GELU: tanh approximation with a cubic term."""
    sqrt_2_over_pi = math.sqrt(2.0 / math.pi)
    cubic = torch.pow(x, 3.0)
    return 0.5 * x * (1.0 + torch.tanh(sqrt_2_over_pi * (x + 0.044715 * cubic)))
27 |
28 |
def gelu_quick(x: torch.Tensor) -> torch.Tensor:
    """Reference quick GELU: ``x * sigmoid(1.702 * x)``."""
    scaled = 1.702 * x
    return x * torch.sigmoid(scaled)
31 |
32 |
def fatrelu_and_mul(x: torch.Tensor, threshold: float) -> torch.Tensor:
    """Reference FATReLU-and-multiply.

    The first half of the last dimension is zeroed wherever it is not greater
    than ``threshold`` (via ``F.threshold``) and then multiplied by the
    second half.
    """
    half = x.shape[-1] // 2
    gate, up = x[..., :half], x[..., half:]
    return F.threshold(gate, threshold, 0.0) * up
39 |
40 |
def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    """Reference SiLU-and-multiply: ``silu(first half) * second half``."""
    width = x.shape[-1]
    half = width // 2
    gate, up = x.split([half, width - half], dim=-1)
    return F.silu(gate) * up
44 |
45 |
def gelu_and_mul(x: torch.Tensor, approximate: str) -> torch.Tensor:
    """Reference GELU-and-multiply using ``F.gelu`` with the given approximation."""
    width = x.shape[-1]
    half = width // 2
    gate, up = x.split([half, width - half], dim=-1)
    return F.gelu(gate, approximate=approximate) * up
49 |
50 |
@pytest.mark.parametrize("activation_name", ["silu", "gelu", "gelu_tanh", "fatrelu"])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_act_and_mul(
    activation_name: str,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Compare each fused activation-and-multiply kernel with its reference.

    The kernel output must match the PyTorch reference exactly, and the
    registered op must pass ``opcheck``.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.set_default_device(device)
    x = torch.randn(num_tokens, 2 * d, dtype=dtype)

    # Extra positional op arguments that follow (out, x) in opcheck.
    extra_op_args = ()
    if activation_name == "silu":
        torch_fn = silu_and_mul
        fn = activation.silu_and_mul
        op = activation.ops.silu_and_mul
    elif activation_name == "gelu":
        torch_fn = lambda t: gelu_and_mul(t, "none")
        fn = activation.gelu_and_mul
        op = activation.ops.gelu_and_mul
    elif activation_name == "gelu_tanh":
        torch_fn = lambda t: gelu_and_mul(t, "tanh")
        fn = activation.gelu_tanh_and_mul
        op = activation.ops.gelu_tanh_and_mul
    elif activation_name == "fatrelu":
        threshold = random.uniform(0, 1)
        extra_op_args = (threshold,)
        torch_fn = lambda t: fatrelu_and_mul(t, threshold)
        fn = lambda out, t: activation.fatrelu_and_mul(out, t, threshold)
        op = activation.ops.fatrelu_and_mul

    # The output keeps all leading dims and halves the last one.
    out_shape = x.shape[:-1] + (x.shape[-1] // 2,)

    out = fn(torch.empty(out_shape, dtype=x.dtype, device=x.device), x)
    ref_out = torch_fn(x)

    # The SiLU, GELU and FatReLU implementations are equivalent to the native
    # PyTorch implementations, so we can do exact comparison.
    torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)

    op_out = torch.empty(out_shape, dtype=x.dtype, device=x.device)
    opcheck(op, (op_out, x) + extra_op_args)
104 |
105 |
@pytest.mark.parametrize(
    "activation_fns",
    [
        (gelu_fast, activation.gelu_fast, activation.ops.gelu_fast),
        (gelu_new, activation.gelu_new, activation.ops.gelu_new),
        (gelu_quick, activation.gelu_quick, activation.ops.gelu_quick),
    ],
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_activation(
    activation_fns,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Compare the GELU-variant kernels with their references within dtype tolerances, then opcheck the op."""
    reference_fn, kernel_fn, op = activation_fns

    torch.manual_seed(seed)
    torch.set_default_device(device)
    x = torch.randn(num_tokens, d, dtype=dtype)

    out = kernel_fn(torch.empty_like(x), x)
    expected = reference_fn(x)
    torch.testing.assert_close(
        out, expected, atol=get_default_atol(out), rtol=get_default_rtol(out)
    )

    opcheck(op, (torch.empty_like(x), x))
140 |
--------------------------------------------------------------------------------
/examples/activation/tests/kernels/utils.py:
--------------------------------------------------------------------------------
1 | """Kernel test utils"""
2 |
3 | import itertools
4 | import random
5 | import unittest
6 | from numbers import Number
7 | from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
8 |
9 | import pytest
10 | import torch
11 | from torch._prims_common import TensorLikeType
12 |
# Every opcheck test util used by these tests.
ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
)

# For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4.
DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = tuple(
    name for name in ALL_OPCHECK_TEST_UTILS if name != "test_aot_dispatch_dynamic"
)
27 |
28 |
# Copied/modified from torch._refs.__init__.py
def fp8_allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """
    Reference implementation of torch.allclose

    Both inputs are upcast to float64 before ``torch.isclose``, presumably so
    that low-precision dtypes such as fp8 (hence the name) can be compared.
    """
    torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)

    a64 = a.double()
    b64 = b.double()
    elementwise = torch.isclose(a64, b64, rtol=rtol, atol=atol, equal_nan=equal_nan)
    return bool(elementwise.all().item())
49 |
50 |
# A version of torch.library.opcheck that patches torch.allclose with an
# fp8-capable implementation and can be disabled with `cond`.
def opcheck(
    op: Union[
        torch._ops.OpOverload,
        torch._ops.OpOverloadPacket,
        torch._library.custom_ops.CustomOpDef,
    ],
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
    raise_exception: bool = True,
    cond: bool = True,
) -> Dict[str, str]:
    """Run ``torch.library.opcheck`` with ``torch.allclose`` replaced by ``fp8_allclose``.

    Args:
        op: The operator to check.
        args: Positional arguments passed to the operator.
        kwargs: Keyword arguments passed to the operator.
        test_utils: Which opcheck tests to run. Defaults to
            ``ALL_OPCHECK_TEST_UTILS``; pass ``DEFAULT_OPCHECK_TEST_UTILS`` to
            skip ``test_aot_dispatch_dynamic``.
        raise_exception: Forwarded to ``torch.library.opcheck``.
        cond: When false, skip the check entirely and return ``{}``.

    Returns:
        The result dict from ``torch.library.opcheck``, or ``{}`` when
        ``cond`` is false.
    """
    # `import unittest` alone does not load the `unittest.mock` submodule,
    # so import it explicitly instead of relying on another module having
    # done so.
    from unittest import mock

    if not cond:
        return {}
    with mock.patch("torch.allclose", new=fp8_allclose):
        return torch.library.opcheck(
            op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
        )
74 |
--------------------------------------------------------------------------------
/examples/activation/torch-ext/activation/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
try:
    from ._ops import ops
except ImportError as e:
    # Fallback for local development: the extension was built as a top-level
    # `_activation` module. Loading it runs the TORCH_LIBRARY registrations;
    # the namespace is presumably `_activation` (the TORCH_EXTENSION_NAME of
    # such a build) -- verify against the local build setup.
    try:
        import _activation  # noqa: F401  (imported for its registration side effect)

        # Previously misspelled `_activition`. torch.ops returns an empty
        # namespace for any name, so the typo only failed later, on first op
        # access.
        ops = torch.ops._activation
    except ImportError:
        raise e
13 |
14 |
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Run the fused SiLU-and-multiply op (used in SwiGLU), writing into ``out``.

    Args:
        out: Preallocated tensor the op writes into.
        x: Input tensor.

    Returns:
        ``out`` (the annotation previously claimed ``None``).
    """
    ops.silu_and_mul(out, x)
    return out
18 |
19 |
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Run the fused GELU-and-multiply op (GeGLU, no approximation), writing into ``out``.

    Args:
        out: Preallocated tensor the op writes into.
        x: Input tensor.

    Returns:
        ``out`` (the annotation previously claimed ``None``).
    """
    ops.gelu_and_mul(out, x)
    return out
23 |
24 |
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Run the fused tanh-approximated GELU-and-multiply op (GeGLU), writing into ``out``.

    Args:
        out: Preallocated tensor the op writes into.
        x: Input tensor.

    Returns:
        ``out`` (the annotation previously claimed ``None``).
    """
    ops.gelu_tanh_and_mul(out, x)
    return out
28 |
29 |
def fatrelu_and_mul(out: torch.Tensor, x: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
    """Run the fused FATReLU-and-multiply op, writing into ``out``.

    Args:
        out: Preallocated tensor the op writes into.
        x: Input tensor.
        threshold: Activation threshold forwarded to the op.

    Returns:
        ``out`` (the annotation previously claimed ``None``).
    """
    ops.fatrelu_and_mul(out, x, threshold)
    return out
33 |
34 |
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Run the approximate ("fast") GELU op, writing into ``out``.

    Args:
        out: Preallocated tensor the op writes into.
        x: Input tensor.

    Returns:
        ``out`` (the annotation previously claimed ``None``).
    """
    ops.gelu_fast(out, x)
    return out
38 |
39 |
def gelu_new(out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Run the GPT-2 style GELU op, writing into ``out``.

    Args:
        out: Preallocated tensor the op writes into.
        x: Input tensor.

    Returns:
        ``out`` (the annotation previously claimed ``None``).
    """
    ops.gelu_new(out, x)
    return out
43 |
44 |
def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Run the quick GELU op, writing into ``out``.

    Args:
        out: Preallocated tensor the op writes into.
        x: Input tensor.

    Returns:
        ``out`` (the annotation previously claimed ``None``).
    """
    ops.gelu_quick(out, x)
    return out
48 |
--------------------------------------------------------------------------------
/examples/activation/torch-ext/torch_binding.cpp:
--------------------------------------------------------------------------------
1 | #include
2 |
3 | #include "registration.h"
4 | #include "torch_binding.h"
5 |
6 | TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
7 | // Activation ops
8 | // Activation function used in SwiGLU.
9 | ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
10 | ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
11 |
12 | // Activation function used in GeGLU with `none` approximation.
13 | ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
14 | ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
15 |
16 | // Activation function used in GeGLU with `tanh` approximation.
17 | ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
18 | ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
19 |
20 | // FATReLU implementation.
21 | ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");
22 | ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul);
23 |
24 | // GELU implementation used in GPT-2.
25 | ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
26 | ops.impl("gelu_new", torch::kCUDA, &gelu_new);
27 |
28 | // Approximate GELU implementation.
29 | ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
30 | ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
31 |
32 | // Quick GELU implementation.
33 | ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
34 | ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
35 | }
36 |
37 | REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
38 |
--------------------------------------------------------------------------------
/examples/activation/torch-ext/torch_binding.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include
4 |
5 | void silu_and_mul(torch::Tensor &out, torch::Tensor &input);
6 |
7 | void gelu_and_mul(torch::Tensor &out, torch::Tensor &input);
8 |
9 | void gelu_tanh_and_mul(torch::Tensor &out, torch::Tensor &input);
10 |
11 | void fatrelu_and_mul(torch::Tensor &out, torch::Tensor &input,
12 | double threshold);
13 |
14 | void gelu_new(torch::Tensor &out, torch::Tensor &input);
15 |
16 | void gelu_fast(torch::Tensor &out, torch::Tensor &input);
17 |
18 | void gelu_quick(torch::Tensor &out, torch::Tensor &input);
19 |
--------------------------------------------------------------------------------
/examples/cutlass-gemm/build.toml:
--------------------------------------------------------------------------------
1 | [general]
2 | name = "cutlass_gemm"
3 | universal = false
4 |
5 | [torch]
6 | src = [
7 | "torch-ext/torch_binding.cpp",
8 | "torch-ext/torch_binding.h",
9 | ]
10 |
11 | [kernel.gemm]
12 | backend = "cuda"
13 | depends = [
14 | "torch",
15 | "cutlass_3_6",
16 | ]
17 | src = ["gemm.cu"]
18 |
--------------------------------------------------------------------------------
/examples/cutlass-gemm/flake.nix:
--------------------------------------------------------------------------------
{
  description = "Flake for CUTLASS gemm test kernel";

  inputs.kernel-builder.url = "path:../..";

  outputs =
    { self, kernel-builder }:
    let
      # Revision string: the clean short rev, else the dirty short rev,
      # else the last-modified date.
      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
    in
    kernel-builder.lib.genFlakeOutputs {
      path = ./.;
      inherit rev;
    };
}
18 |
--------------------------------------------------------------------------------
/examples/cutlass-gemm/gemm.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 |
4 | void cutlass_gemm(torch::Tensor &out, torch::Tensor const &A, torch::Tensor const &B) {
5 | TORCH_CHECK(A.device().is_cuda(), "A must be a CUDA tensor");
6 | TORCH_CHECK(B.device().is_cuda(), "B must be a CUDA tensor");
7 | TORCH_CHECK(out.device().is_cuda(), "out must be a CUDA tensor");
8 |
9 | TORCH_CHECK(A.is_contiguous(), "A must be a contiguous");
10 | TORCH_CHECK(B.is_contiguous(), "B must be a contiguous");
11 | TORCH_CHECK(out.is_contiguous(), "out must be a contiguous");
12 |
13 | // Define the GEMM operation
14 | using Gemm = cutlass::gemm::device::Gemm;
17 |
18 | // Create a GEMM object
19 | Gemm gemm_op;
20 |
21 | // Define the problem size
22 | cutlass::gemm::GemmCoord problem_size(A.size(0), B.size(1), A.size(1));
23 |
24 | // Define the arguments for the GEMM operation
25 | typename Gemm::Arguments args(
26 | problem_size,
27 | {A.data_ptr(), A.size(1)},
28 | {B.data_ptr(), B.size(1)},
29 | {out.data_ptr(), out.size(1)},
30 | {out.data_ptr(), out.size(1)},
31 | {1.0f, 0.0f}
32 | );
33 |
34 | // Launch the GEMM operation
35 | cutlass::Status status = gemm_op(args);
36 |
37 | // Check for errors
38 | if (status != cutlass::Status::kSuccess) {
39 | throw std::runtime_error("CUTLASS GEMM operation failed");
40 | }
41 | }
42 |
43 |
44 |
--------------------------------------------------------------------------------
/examples/cutlass-gemm/tests/test_gemm.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from cutlass_gemm import cutlass_gemm
4 |
def test_gemm():
    """The CUTLASS GEMM op must match ``torch.mm`` on small float32 CUDA matrices."""
    A = torch.randn((10, 20), device="cuda", dtype=torch.float32)
    B = torch.randn((20, 30), device="cuda", dtype=torch.float32)
    # Start `out` with random values so the test also checks it is overwritten.
    out = torch.randn((10, 30), device="cuda", dtype=torch.float32)

    cutlass_gemm(out, A, B)

    # `torch.testing.assert_allclose` is deprecated in favor of `assert_close`.
    torch.testing.assert_close(out, torch.mm(A, B))
13 |
--------------------------------------------------------------------------------
/examples/cutlass-gemm/torch-ext/cutlass_gemm/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ._ops import ops
4 |
def cutlass_gemm(out: torch.Tensor, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Compute ``A @ B`` with the CUTLASS GEMM op, writing the result into ``out``.

    Args:
        out: Preallocated result tensor the op writes into.
        A: Left matrix operand.
        B: Right matrix operand.

    Returns:
        ``out`` (the annotation previously claimed ``None``).
    """
    ops.cutlass_gemm(out, A, B)
    return out
8 |
--------------------------------------------------------------------------------
/examples/cutlass-gemm/torch-ext/registration.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include
4 |
// Two-level helpers: the outer macro lets its arguments be macro-expanded
// before the inner one pastes (##) or stringifies (#) them.
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)

#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)

// Wrappers for TORCH_LIBRARY / TORCH_LIBRARY_IMPL that expand NAME first, so
// NAME may itself be a macro (e.g. TORCH_EXTENSION_NAME) rather than a
// literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
  TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)

// Defines the PyInit_<NAME> entry point so the shared library can be loaded
// with Python's import statement. The module has no doc, no per-module
// state (size 0) and no methods.
#define REGISTER_EXTENSION(NAME)                                      \
  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                            \
    static struct PyModuleDef module_def = {                          \
        PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \
    return PyModule_Create(&module_def);                              \
  }
28 |
--------------------------------------------------------------------------------
/examples/cutlass-gemm/torch-ext/torch_binding.cpp:
--------------------------------------------------------------------------------
1 | #include
2 |
3 | #include "registration.h"
4 | #include "torch_binding.h"
5 |
6 | TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
7 | ops.def("cutlass_gemm(Tensor! out, Tensor A, Tensor B) -> ()");
8 | ops.impl("cutlass_gemm", torch::kCUDA, &cutlass_gemm);
9 | }
10 |
11 | REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
12 |
--------------------------------------------------------------------------------
/examples/cutlass-gemm/torch-ext/torch_binding.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include
4 |
5 | void cutlass_gemm(torch::Tensor &out, torch::Tensor const &A, torch::Tensor const &B);
6 |
--------------------------------------------------------------------------------
/examples/relu/build.toml:
--------------------------------------------------------------------------------
1 | [general]
2 | name = "relu"
3 | universal = false
4 |
5 | [torch]
6 | src = [
7 | "torch-ext/torch_binding.cpp",
8 | "torch-ext/torch_binding.h",
9 | ]
10 |
11 | [kernel.activation]
12 | backend = "cuda"
13 | depends = ["torch"]
14 | src = ["relu_cuda/relu.cu"]
15 |
16 | [kernel.activation_metal]
17 | backend = "metal"
18 | src = [
19 | "relu_metal/relu.mm",
20 | ]
21 | depends = [ "torch" ]
22 |
23 | [kernel.activation_rocm]
24 | backend = "rocm"
25 | rocm-archs = [
26 | "gfx906",
27 | "gfx908",
28 | "gfx90a",
29 | "gfx940",
30 | "gfx941",
31 | "gfx942",
32 | "gfx1030",
33 | "gfx1100",
34 | "gfx1101",
35 | ]
36 | depends = ["torch"]
37 | src = ["relu_cuda/relu.cu"]
38 |
--------------------------------------------------------------------------------
/examples/relu/flake.nix:
--------------------------------------------------------------------------------
{
  description = "Flake for ReLU kernel";

  # kernel-builder is taken from the repository root (two levels up).
  inputs.kernel-builder.url = "path:../..";

  outputs =
    { self, kernel-builder }:
    kernel-builder.lib.genFlakeOutputs {
      path = ./.;
      # Prefer the short git revision, then the dirty revision for uncommitted
      # trees, and finally the last-modified date.
      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
    };
}
18 |
--------------------------------------------------------------------------------
/examples/relu/relu_cuda/relu.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 |
5 | #include
6 |
// Element-wise ReLU over a contiguous float32 buffer viewed as rows of `d`
// elements. Block `blockIdx.x` processes row `blockIdx.x`; its threads step
// through the row in increments of blockDim.x.
__global__ void relu_kernel(float *__restrict__ out,
                            float const *__restrict__ input, const int64_t d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const float x = input[token_idx * d + idx];
    out[token_idx * d + idx] = x > 0.0f ? x : 0.0f;
  }
}

// Host entry point: writes relu(input) into `out`.
//
// Requirements (checked with TORCH_CHECK, which throws on failure):
//   - `input` is a CUDA tensor; both tensors are contiguous float32,
//   - both tensors have identical shapes and live on the same device.
// Empty tensors are a no-op. Throws if the kernel launch fails.
void relu(torch::Tensor &out, torch::Tensor const &input) {
  TORCH_CHECK(input.device().is_cuda(), "input must be a CUDA tensor");
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  // The kernel writes `out` through a flat index, so it must be contiguous too.
  TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Float &&
                  out.scalar_type() == at::ScalarType::Float,
              "relu_kernel only supports float32");

  TORCH_CHECK(input.sizes() == out.sizes(),
              "Tensors must have the same shape. Got input shape: ",
              input.sizes(), " and output shape: ", out.sizes());

  TORCH_CHECK(input.scalar_type() == out.scalar_type(),
              "Tensors must have the same data type. Got input dtype: ",
              input.scalar_type(), " and output dtype: ", out.scalar_type());

  TORCH_CHECK(input.device() == out.device(),
              "Tensors must be on the same device. Got input device: ",
              input.device(), " and output device: ", out.device());

  // Nothing to compute. Returning early also avoids dividing by a zero-sized
  // last dimension and launching a kernel with an empty grid/block.
  if (input.numel() == 0) {
    return;
  }

  // A 0-dim tensor holds a single element; treat it as one row of length 1.
  const int64_t d = input.dim() == 0 ? 1 : input.size(-1);
  const int64_t num_tokens = input.numel() / d;
  dim3 grid(num_tokens);
  dim3 block(std::min<int64_t>(d, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  relu_kernel<<<grid, block, 0, stream>>>(out.data_ptr<float>(),
                                          input.data_ptr<float>(), d);

  // Surface launch-configuration errors instead of failing silently.
  const cudaError_t launch_err = cudaGetLastError();
  TORCH_CHECK(launch_err == cudaSuccess, "relu_kernel launch failed: ",
              cudaGetErrorString(launch_err));
}
44 |
--------------------------------------------------------------------------------
/examples/relu/relu_metal/relu.mm:
--------------------------------------------------------------------------------
1 | #include
2 |
3 | #import
4 | #import
5 | #include
6 |
// Metal Shading Language source for the ReLU kernels. It is compiled at
// runtime by dispatchReluKernel() via newLibraryWithSource. There is one
// kernel per supported dtype, and each thread handles the single element at
// `thread_position_in_grid`.
//
// `static` and a const pointer keep this symbol local to the translation
// unit. Another source file that uses the same generic name therefore cannot
// cause a duplicate-symbol link error.
static char const *const CUSTOM_KERNEL = R"(
#include <metal_stdlib>
using namespace metal;

kernel void relu_forward_kernel_float(device const float *inA [[buffer(0)]],
                                      device float *outC [[buffer(1)]],
                                      uint index [[thread_position_in_grid]]) {
  // Explicitly write to output
  outC[index] = max(0.0f, inA[index]);
}

kernel void relu_forward_kernel_half(device const half *inA [[buffer(0)]],
                                     device half *outC [[buffer(1)]],
                                     uint index [[thread_position_in_grid]]) {
  // Explicitly write to output
  outC[index] = max(static_cast<half>(0.0), inA[index]);
}
)";
25 |
// Reinterprets the tensor's storage data pointer as an id<MTLBuffer>.
// NOTE(review): this assumes MPS tensor storage stores the MTLBuffer handle as
// its data pointer — confirm against the PyTorch version in use.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor &tensor) {
  auto rawStorage = tensor.storage().data();
  return __builtin_bit_cast(id<MTLBuffer>, rawStorage);
}
29 |
// Compiles CUSTOM_KERNEL, then encodes and commits the ReLU kernel that reads
// `input` and writes `output` on the command buffer from torch::mps.
// The caller is expected to have validated the tensors: MPS device, float32
// or float16, same shape and dtype, and contiguous. Returns `output`.
torch::Tensor &dispatchReluKernel(torch::Tensor const &input,
                                  torch::Tensor &output) {
  // One GPU thread per element. NSUInteger avoids narrowing numel()
  // (int64_t) to int, which could overflow for very large tensors.
  const NSUInteger numThreads = static_cast<NSUInteger>(input.numel());
  if (numThreads == 0) {
    // Nothing to compute, and a zero-sized grid/threadgroup is not a valid
    // dispatch.
    return output;
  }

  @autoreleasepool {
    id<MTLDevice> device = MTLCreateSystemDefaultDevice();
    NSError *error = nil;

    // NOTE(review): the library and pipeline are rebuilt on every call;
    // caching them would avoid repeated runtime shader compilation.
    id<MTLLibrary> customKernelLibrary = [device
        newLibraryWithSource:[NSString stringWithUTF8String:CUSTOM_KERNEL]
                     options:nil
                       error:&error];
    TORCH_CHECK(customKernelLibrary,
                "Failed to create custom kernel library, error: ",
                error.localizedDescription.UTF8String);

    // Any dtype other than float32 selects the half kernel.
    std::string kernel_name =
        std::string("relu_forward_kernel_") +
        (input.scalar_type() == torch::kFloat ? "float" : "half");
    id<MTLFunction> customReluFunction = [customKernelLibrary
        newFunctionWithName:[NSString
                                stringWithUTF8String:kernel_name.c_str()]];
    TORCH_CHECK(customReluFunction,
                "Failed to create function state object for ",
                kernel_name.c_str());

    id<MTLComputePipelineState> reluPSO =
        [device newComputePipelineStateWithFunction:customReluFunction
                                              error:&error];
    TORCH_CHECK(reluPSO, error.localizedDescription.UTF8String);

    id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
    TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");

    dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();

    // NOTE(review): TORCH_CHECK throws a C++ exception; throwing out of a
    // dispatch_sync block may not propagate cleanly — confirm, or move the
    // check outside the block.
    dispatch_sync(serialQueue, ^() {
      id<MTLComputeCommandEncoder> computeEncoder =
          [commandBuffer computeCommandEncoder];
      TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");

      [computeEncoder setComputePipelineState:reluPSO];
      // Bind each tensor's backing buffer at its byte offset into storage.
      [computeEncoder setBuffer:getMTLBufferStorage(input)
                         offset:input.storage_offset() * input.element_size()
                        atIndex:0];
      [computeEncoder setBuffer:getMTLBufferStorage(output)
                         offset:output.storage_offset() * output.element_size()
                        atIndex:1];

      MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);

      // Never request a threadgroup larger than the whole grid.
      NSUInteger threadGroupSize = reluPSO.maxTotalThreadsPerThreadgroup;
      if (threadGroupSize > numThreads) {
        threadGroupSize = numThreads;
      }
      MTLSize threadgroupSize = MTLSizeMake(threadGroupSize, 1, 1);

      [computeEncoder dispatchThreads:gridSize
                threadsPerThreadgroup:threadgroupSize];

      [computeEncoder endEncoding];

      torch::mps::commit();
    });
  }

  return output;
}
98 |
// Validates the tensors and writes relu(input) into `out` on the MPS device.
// Supports float32 and float16. Both tensors must be contiguous and share
// shape, dtype and device; violations throw via TORCH_CHECK. Empty tensors
// are a no-op.
void relu(torch::Tensor &out, const torch::Tensor &input) {
  TORCH_CHECK(input.device().is_mps(), "input must be a MPS tensor");
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  // The kernel writes `out` linearly from its storage offset. A strided
  // output would therefore receive values in the wrong positions.
  TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
  TORCH_CHECK(input.scalar_type() == torch::kFloat ||
                  input.scalar_type() == torch::kHalf,
              "Unsupported data type: ", input.scalar_type());

  TORCH_CHECK(input.sizes() == out.sizes(),
              "Tensors must have the same shape. Got input shape: ",
              input.sizes(), " and output shape: ", out.sizes());

  TORCH_CHECK(input.scalar_type() == out.scalar_type(),
              "Tensors must have the same data type. Got input dtype: ",
              input.scalar_type(), " and output dtype: ", out.scalar_type());

  TORCH_CHECK(input.device() == out.device(),
              "Tensors must be on the same device. Got input device: ",
              input.device(), " and output device: ", out.device());

  // A zero-element dispatch is invalid in Metal; there is nothing to do.
  if (input.numel() == 0) {
    return;
  }

  dispatchReluKernel(input, out);
}
120 |
--------------------------------------------------------------------------------
/examples/relu/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/kernel-builder/924bce43ecb17b61e1cfa928443a5bed2f3fbc3c/examples/relu/tests/__init__.py
--------------------------------------------------------------------------------
/examples/relu/tests/test_relu.py:
--------------------------------------------------------------------------------
1 | import platform
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | import relu
7 |
8 |
def test_relu():
    """Compare the custom ReLU op with ``torch.nn.functional.relu``.

    Runs on MPS on macOS and on CUDA on other platforms, with a float32 input.
    """
    if platform.system() == "Darwin":
        device = torch.device("mps")
    else:
        device = torch.device("cuda")
    x = torch.randn(1024, 1024, dtype=torch.float32, device=device)
    # `assert_allclose` is deprecated; `assert_close` takes (actual, expected).
    torch.testing.assert_close(relu.relu(x), F.relu(x))
16 |
--------------------------------------------------------------------------------
/examples/relu/torch-ext/relu/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 |
5 | from ._ops import ops
6 |
7 |
def relu(x: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Apply ReLU to ``x`` via the compiled ``ops.relu`` kernel.

    Args:
        x: Input tensor.
        out: Optional destination. When omitted, a tensor is allocated with
            ``torch.empty_like(x)``.

    Returns:
        The tensor the result was written into (``out`` when provided).
    """
    destination = out if out is not None else torch.empty_like(x)
    ops.relu(destination, x)
    return destination
13 |
--------------------------------------------------------------------------------
/examples/relu/torch-ext/torch_binding.cpp:
--------------------------------------------------------------------------------
1 | #include
2 |
3 | #include "registration.h"
4 | #include "torch_binding.h"
5 |
6 | TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
7 | ops.def("relu(Tensor! out, Tensor input) -> ()");
8 | #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
9 | ops.impl("relu", torch::kCUDA, &relu);
10 | #elif defined(METAL_KERNEL)
11 | ops.impl("relu", torch::kMPS, relu);
12 | #endif
13 | }
14 |
15 | REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
16 |
--------------------------------------------------------------------------------
/examples/relu/torch-ext/torch_binding.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include
4 |
5 | void relu(torch::Tensor &out, torch::Tensor const &input);
6 |
--------------------------------------------------------------------------------
/examples/silu-and-mul-universal/build.toml:
--------------------------------------------------------------------------------
1 | [general]
2 | name = "silu_and_mul_universal"
3 | universal = true
4 |
--------------------------------------------------------------------------------
/examples/silu-and-mul-universal/flake.nix:
--------------------------------------------------------------------------------
{
  description = "Flake for kernels tests";

  # kernel-builder is taken from the repository root (two levels up).
  inputs.kernel-builder.url = "path:../..";

  outputs =
    { self, kernel-builder }:
    kernel-builder.lib.genFlakeOutputs {
      path = ./.;
      # Prefer the short git revision, then the dirty revision for uncommitted
      # trees, and finally the last-modified date.
      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
    };
}
18 |
--------------------------------------------------------------------------------
/examples/silu-and-mul-universal/tests/test_silu_and_mul.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.library import opcheck
5 |
6 | from silu_and_mul_universal import ops, silu_and_mul
7 |
8 |
def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    """Reference ``silu(first half) * second half`` along the last dimension."""
    width = x.shape[-1]
    half = width // 2
    gate = x.narrow(-1, 0, half)
    up = x.narrow(-1, half, width - half)
    return F.silu(gate) * up
12 |
13 |
@pytest.mark.parametrize(
    "device",
    [
        "cpu",
        # Skip, rather than error, on machines without a CUDA device.
        pytest.param(
            "cuda",
            marks=pytest.mark.skipif(
                not torch.cuda.is_available(), reason="CUDA is not available"
            ),
        ),
    ],
)
@pytest.mark.parametrize("requires_grad", [False, True])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_opcheck(device, requires_grad, dtype):
    """Run ``torch.library.opcheck`` on the registered op for each combination."""
    torch.manual_seed(42)
    x = torch.randn(32, 128, device=device, requires_grad=requires_grad, dtype=dtype)
    opcheck(ops.silu_and_mul, (x,))
21 |
22 |
@pytest.mark.parametrize(
    "device",
    [
        "cpu",
        # Skip, rather than error, on machines without a CUDA device.
        pytest.param(
            "cuda",
            marks=pytest.mark.skipif(
                not torch.cuda.is_available(), reason="CUDA is not available"
            ),
        ),
    ],
)
@pytest.mark.parametrize("requires_grad", [False, True])
# Only do float32, the numerical instabilities of float16 and bfloat16
# are too large with the different orderings of computing the gradients.
@pytest.mark.parametrize("dtype", [torch.float32])
def test_silu_and_mul(device, requires_grad, dtype):
    """Compare forward (and, if requested, backward) against the reference."""
    torch.manual_seed(42)
    x_ref = torch.randn(
        32, 128, device=device, requires_grad=requires_grad, dtype=dtype
    )
    # A separate leaf holding the same values, so each path gets its own grad.
    x = torch.empty(32, 128, device=device, requires_grad=requires_grad, dtype=dtype)
    with torch.no_grad():
        x.copy_(x_ref)

    y_ref = silu_and_mul_ref(x_ref)
    y = silu_and_mul(x)

    torch.testing.assert_close(y_ref, y)

    if requires_grad:
        d_y = torch.randn((32, 64), device=device, dtype=dtype)
        y_ref.backward(d_y)
        y.backward(d_y)
        torch.testing.assert_close(x_ref.grad, x.grad)
47 |
--------------------------------------------------------------------------------
/examples/silu-and-mul-universal/torch-ext/silu_and_mul_universal/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ._ops import ops
4 | from .silu_and_mul import _silu_and_mul
5 |
6 |
7 | def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
8 | return ops.silu_and_mul(x)
9 |
10 |
11 | __all__ = ["silu_and_mul"]
12 |
--------------------------------------------------------------------------------
/examples/silu-and-mul-universal/torch-ext/silu_and_mul_universal/silu_and_mul.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from ._ops import add_op_namespace_prefix
5 |
6 |
@torch.library.custom_op(add_op_namespace_prefix("silu_and_mul"), mutates_args=())
def _silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    """Forward: ``silu(gate) * up``, where ``gate``/``up`` halve the last dim."""
    half = x.shape[-1] // 2
    gate = x[..., :half]
    up = x[..., half:]
    return F.silu(gate) * up
11 |
12 |
def backward(ctx, grad_output):
    """Gradient of ``silu(x1) * x2`` with respect to ``x = cat([x1, x2], -1)``.

    ``d/dx1 = grad * x2 * silu'(x1)``, with
    ``silu'(x1) = sigmoid(x1) + silu(x1) * (1 - sigmoid(x1))``;
    ``d/dx2 = grad * silu(x1)``.
    """
    saved = ctx.saved_tensors
    x = saved[0]
    half = x.shape[-1] // 2
    gate, up = x[..., :half], x[..., half:]

    sig = torch.sigmoid(gate)
    silu_gate = F.silu(gate)
    silu_grad = sig + silu_gate * (1 - sig)

    grad_gate = grad_output * up * silu_grad
    grad_up = grad_output * silu_gate
    return torch.cat([grad_gate, grad_up], dim=-1)
23 |
24 |
def setup_context(ctx, inputs, output):
    """Save the single forward input on ``ctx`` for use in ``backward``."""
    [x] = inputs
    ctx.save_for_backward(x)


# Attach the hand-written backward (and its context setup) to the custom op.
_silu_and_mul.register_autograd(backward, setup_context=setup_context)
31 |
32 |
@_silu_and_mul.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    """Fake (meta) implementation that only produces an output of the right shape/dtype.

    Mirrors the real op, which slices the last dimension in half
    (``x[..., :d]``) and keeps every leading dimension. Inputs of any rank are
    therefore supported, not only 2-D ones.
    """
    return x.new_empty(*x.shape[:-1], x.shape[-1] // 2)
36 |
--------------------------------------------------------------------------------
/flake.lock:
--------------------------------------------------------------------------------
1 | {
2 | "nodes": {
3 | "flake-compat": {
4 | "locked": {
5 | "lastModified": 1747046372,
6 | "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
7 | "owner": "edolstra",
8 | "repo": "flake-compat",
9 | "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
10 | "type": "github"
11 | },
12 | "original": {
13 | "owner": "edolstra",
14 | "repo": "flake-compat",
15 | "type": "github"
16 | }
17 | },
18 | "flake-compat_2": {
19 | "locked": {
20 | "lastModified": 1733328505,
21 | "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
22 | "owner": "edolstra",
23 | "repo": "flake-compat",
24 | "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
25 | "type": "github"
26 | },
27 | "original": {
28 | "owner": "edolstra",
29 | "repo": "flake-compat",
30 | "type": "github"
31 | }
32 | },
33 | "flake-utils": {
34 | "inputs": {
35 | "systems": "systems"
36 | },
37 | "locked": {
38 | "lastModified": 1731533236,
39 | "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
40 | "owner": "numtide",
41 | "repo": "flake-utils",
42 | "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
43 | "type": "github"
44 | },
45 | "original": {
46 | "owner": "numtide",
47 | "repo": "flake-utils",
48 | "type": "github"
49 | }
50 | },
51 | "flake-utils_2": {
52 | "inputs": {
53 | "systems": "systems_2"
54 | },
55 | "locked": {
56 | "lastModified": 1731533236,
57 | "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
58 | "owner": "numtide",
59 | "repo": "flake-utils",
60 | "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
61 | "type": "github"
62 | },
63 | "original": {
64 | "owner": "numtide",
65 | "repo": "flake-utils",
66 | "type": "github"
67 | }
68 | },
69 | "hf-nix": {
70 | "inputs": {
71 | "flake-compat": "flake-compat_2",
72 | "flake-utils": "flake-utils_2",
73 | "nixpkgs": "nixpkgs"
74 | },
75 | "locked": {
76 | "lastModified": 1748598786,
77 | "owner": "huggingface",
78 | "repo": "hf-nix",
79 | "rev": "6ca679441494139fde1f2355691ddb5dc8170269",
80 | "type": "github"
81 | },
82 | "original": {
83 | "owner": "huggingface",
84 | "repo": "hf-nix",
85 | "type": "github"
86 | }
87 | },
88 | "nixpkgs": {
89 | "locked": {
90 | "lastModified": 1747820358,
91 | "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
92 | "owner": "danieldk",
93 | "repo": "nixpkgs",
94 | "rev": "d3c1681180717528068082103bf323147de6ab0b",
95 | "type": "github"
96 | },
97 | "original": {
98 | "owner": "danieldk",
99 | "ref": "cudatoolkit-12.9-kernel-builder",
100 | "repo": "nixpkgs",
101 | "type": "github"
102 | }
103 | },
104 | "root": {
105 | "inputs": {
106 | "flake-compat": "flake-compat",
107 | "flake-utils": "flake-utils",
108 | "hf-nix": "hf-nix",
109 | "nixpkgs": [
110 | "hf-nix",
111 | "nixpkgs"
112 | ]
113 | }
114 | },
115 | "systems": {
116 | "locked": {
117 | "lastModified": 1681028828,
118 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
119 | "owner": "nix-systems",
120 | "repo": "default",
121 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
122 | "type": "github"
123 | },
124 | "original": {
125 | "owner": "nix-systems",
126 | "repo": "default",
127 | "type": "github"
128 | }
129 | },
130 | "systems_2": {
131 | "locked": {
132 | "lastModified": 1681028828,
133 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
134 | "owner": "nix-systems",
135 | "repo": "default",
136 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
137 | "type": "github"
138 | },
139 | "original": {
140 | "owner": "nix-systems",
141 | "repo": "default",
142 | "type": "github"
143 | }
144 | }
145 | },
146 | "root": "root",
147 | "version": 7
148 | }
149 |
--------------------------------------------------------------------------------
/flake.nix:
--------------------------------------------------------------------------------
1 | {
2 | description = "Kernel builder";
3 |
4 | inputs = {
5 | flake-utils.url = "github:numtide/flake-utils";
6 | nixpkgs.follows = "hf-nix/nixpkgs";
7 | flake-compat.url = "github:edolstra/flake-compat";
8 | hf-nix.url = "github:huggingface/hf-nix";
9 | };
10 |
11 | outputs =
12 | {
13 | self,
14 | flake-compat,
15 | flake-utils,
16 | hf-nix,
17 | nixpkgs,
18 | }:
19 | let
20 | systems = with flake-utils.lib.system; [
21 | aarch64-darwin
22 | aarch64-linux
23 | x86_64-linux
24 | ];
25 |
# Build sets for every supported system, as an attrset of the form
# { "<system>" = [ ... ]; ... }.
buildSetPerSystem = nixpkgs.lib.genAttrs systems (
  system:
  import ./lib/buildsets.nix {
    inherit nixpkgs system;
    hf-nix = hf-nix.overlays.default;
  }
);
36 |
# Per-system build library, instantiated from that system's build sets.
libPerSystem = builtins.mapAttrs (
  _system: buildSets:
  import ./lib/build.nix {
    inherit (nixpkgs) lib;
    inherit buildSets;
  }
) buildSetPerSystem;
44 |
45 | # The lib output consists of two parts:
46 | #
47 | # - Per-system build functions.
48 | # - `genFlakeOutputs`, which can be used by downstream flakes to make
49 | # standardized outputs (for all supported systems).
50 | lib = {
51 | allBuildVariantsJSON =
52 | let
53 | buildVariants = (import ./versions.nix { inherit (nixpkgs) lib; }).buildVariants;
54 | in
55 | builtins.toJSON (nixpkgs.lib.foldl' (acc: system: acc // buildVariants system) { } systems);
56 | genFlakeOutputs =
57 | { path, rev }:
58 | flake-utils.lib.eachSystem systems (
59 | system:
60 | let
61 | build = libPerSystem.${system};
62 | revUnderscored = builtins.replaceStrings [ "-" ] [ "_" ] rev;
63 | pkgs = nixpkgs.legacyPackages.${system};
64 | shellTorch =
65 | if system == "aarch64-darwin" then "torch27-metal-${system}" else "torch27-cxx11-cu126-${system}";
66 | in
67 | {
68 | devShells = rec {
69 | default = devShells.${shellTorch};
70 | test = testShells.${shellTorch};
71 | devShells = build.torchDevShells {
72 | inherit path;
73 | rev = revUnderscored;
74 | };
75 | testShells = build.torchExtensionShells {
76 | inherit path;
77 | rev = revUnderscored;
78 | };
79 | };
80 | packages = rec {
81 | default = bundle;
82 | bundle = build.buildTorchExtensionBundle {
83 | inherit path;
84 | rev = revUnderscored;
85 | };
86 | redistributable = build.buildDistTorchExtensions {
87 | inherit path;
88 | buildSets = buildSetPerSystem.${system};
89 | rev = revUnderscored;
90 | };
91 | buildTree =
92 | let
93 | build2cmake = self.packages.${system}.build2cmake;
94 | src = build.mkSourceSet path;
95 | in
96 | pkgs.runCommand "torch-extension-build-tree"
97 | {
98 | nativeBuildInputs = [ build2cmake ];
99 | inherit src;
100 | meta = {
101 | description = "Build tree for torch extension with source files and CMake configuration";
102 | };
103 | }
104 | ''
105 | # Copy sources
106 | install -dm755 $out/src
107 | cp -r $src/. $out/src/
108 |
109 | # Generate cmake files
110 | build2cmake generate-torch --ops-id "${revUnderscored}" $src/build.toml $out --force
111 | '';
112 | };
113 | }
114 | );
115 | } // libPerSystem;
116 |
117 | in
118 | flake-utils.lib.eachSystem systems (
119 | system:
120 | let
# Plain nixpkgs that we use to access utility functions.
pkgs = import nixpkgs { system = system; };
125 | inherit (nixpkgs) lib;
126 |
127 | buildVersion = import ./lib/build-version.nix;
128 |
129 | buildSets = buildSetPerSystem.${system};
130 |
131 | in
132 | rec {
133 | formatter = pkgs.nixfmt-tree;
134 |
135 | packages = rec {
136 | build2cmake = pkgs.callPackage ./pkgs/build2cmake { };
137 |
138 | update-build = pkgs.writeShellScriptBin "update-build" ''
139 | ${build2cmake}/bin/build2cmake update-build ''${1:-build.toml}
140 | '';
141 |
142 | # This package set is exposed so that we can prebuild the Torch versions.
143 | torch = builtins.listToAttrs (
144 | map (buildSet: {
145 | name = buildVersion buildSet;
146 | value = buildSet.torch;
147 | }) buildSets
148 | );
149 |
150 | # Dependencies that should be cached.
151 | forCache =
152 | let
153 | filterDist = lib.filter (output: output != "dist");
154 | # Get all `torch` outputs except for `dist`. Not all outputs
155 | # are dependencies of `out`, but we'll need the `cxxdev` and
156 | # `dev` outputs for kernel builds.
157 | torchOutputs = builtins.listToAttrs (
158 | lib.flatten (
159 | # Map over build sets.
160 | map (
161 | buildSet:
162 | # Map over all outputs of `torch` in a buildset.
163 | map (output: {
164 | name = "${buildVersion buildSet}-${output}";
165 | value = buildSet.torch.${output};
166 | }) (filterDist buildSet.torch.outputs)
167 | ) buildSets
168 | )
169 | );
170 | oldLinuxStdenvs = builtins.listToAttrs (
171 | map (buildSet: {
172 | name = "stdenv-${buildVersion buildSet}";
173 | value = buildSet.pkgs.stdenvGlibc_2_27;
174 | }) buildSets
175 | );
176 | in
177 | pkgs.linkFarm "packages-for-cache" (
178 | torchOutputs // lib.optionalAttrs nixpkgs.legacyPackages.${system}.stdenv.isLinux oldLinuxStdenvs
179 | );
180 | };
181 | }
182 | )
183 | // {
184 | inherit lib;
185 | };
186 | }
187 |
--------------------------------------------------------------------------------
/kernel-abi-check/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "kernel-abi-check"
3 | version = "0.4.0"
4 | edition = "2021"
5 | description = "Check the ABI of Hub Kernels"
6 | homepage = "https://github.com/huggingface/kernel-builder"
7 | license = "Apache-2.0"
8 | documentation = "https://docs.rs/kernel-abi-check"
9 | repository = "https://github.com/huggingface/kernel-builder"
10 |
11 |
12 | [dependencies]
13 | clap = { version = "4", features = ["derive"] }
14 | color-eyre = "0.6"
15 | eyre = "0.6"
16 | itertools = "0.14.0"
17 | object = "0.36.7"
18 | once_cell = "1"
19 | serde = { version = "1", features = ["derive"] }
20 | serde_json = "1"
21 | toml = "0.8"
22 |
23 |
--------------------------------------------------------------------------------
/kernel-abi-check/flake.lock:
--------------------------------------------------------------------------------
1 | {
2 | "nodes": {
3 | "flake-utils": {
4 | "inputs": {
5 | "systems": "systems"
6 | },
7 | "locked": {
8 | "lastModified": 1731533236,
9 | "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
10 | "owner": "numtide",
11 | "repo": "flake-utils",
12 | "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
13 | "type": "github"
14 | },
15 | "original": {
16 | "owner": "numtide",
17 | "repo": "flake-utils",
18 | "type": "github"
19 | }
20 | },
21 | "nixpkgs": {
22 | "locked": {
23 | "lastModified": 1742889210,
24 | "narHash": "sha256-hw63HnwnqU3ZQfsMclLhMvOezpM7RSB0dMAtD5/sOiw=",
25 | "owner": "nixos",
26 | "repo": "nixpkgs",
27 | "rev": "698214a32beb4f4c8e3942372c694f40848b360d",
28 | "type": "github"
29 | },
30 | "original": {
31 | "owner": "nixos",
32 | "ref": "nixos-unstable",
33 | "repo": "nixpkgs",
34 | "type": "github"
35 | }
36 | },
37 | "root": {
38 | "inputs": {
39 | "flake-utils": "flake-utils",
40 | "nixpkgs": "nixpkgs"
41 | }
42 | },
43 | "systems": {
44 | "locked": {
45 | "lastModified": 1681028828,
46 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
47 | "owner": "nix-systems",
48 | "repo": "default",
49 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
50 | "type": "github"
51 | },
52 | "original": {
53 | "owner": "nix-systems",
54 | "repo": "default",
55 | "type": "github"
56 | }
57 | }
58 | },
59 | "root": "root",
60 | "version": 7
61 | }
62 |
--------------------------------------------------------------------------------
/kernel-abi-check/flake.nix:
--------------------------------------------------------------------------------
{
  description = "kernel-abi-check devenv";

  inputs = {
    flake-utils.url = "github:numtide/flake-utils";
    nixpkgs.url = "github:nixos/nixpkgs?ref=nixos-unstable";
  };

  outputs =
    { self, flake-utils, nixpkgs }:
    flake-utils.lib.eachDefaultSystem (
      system:
      let
        pkgs = nixpkgs.legacyPackages.${system};
      in
      {
        # Development shell with the Rust toolchain and tooling.
        devShells.default = pkgs.mkShell {
          buildInputs = with pkgs; [
            cargo
            clippy
            openssl.dev
            pkg-config
            rustc
            rustfmt
            rust-analyzer
          ];

          # Points tools such as rust-analyzer at the standard library sources.
          RUST_SRC_PATH = "${pkgs.rustPlatform.rustLibSrc}";
        };
      }
    );
}
40 |
--------------------------------------------------------------------------------
/kernel-abi-check/src/lib.rs:
--------------------------------------------------------------------------------
1 | //! Functions for checking kernel ABI compatibility.
2 |
3 | mod manylinux;
4 | pub use manylinux::{check_manylinux, ManylinuxViolation};
5 |
6 | mod python_abi;
7 | pub use python_abi::{check_python_abi, PythonAbiViolation};
8 |
9 | mod version;
10 | pub use version::Version;
11 |
--------------------------------------------------------------------------------
/kernel-abi-check/src/main.rs:
--------------------------------------------------------------------------------
1 | use std::path::PathBuf;
2 | use std::{collections::BTreeSet, fs};
3 |
4 | use clap::Parser;
5 | use eyre::{Context, Result};
6 | use object::Object;
7 |
8 | use kernel_abi_check::{
9 | check_manylinux, check_python_abi, ManylinuxViolation, PythonAbiViolation, Version,
10 | };
11 |
12 | /// CLI tool to check library versions
13 | #[derive(Parser, Debug)]
14 | #[command(version, about, long_about = None)]
15 | struct Cli {
16 | /// Python extension library.
17 | object: PathBuf,
18 |
19 | /// Manylinux version.
20 | #[arg(short, long, value_name = "VERSION", default_value = "manylinux_2_28")]
21 | manylinux: String,
22 |
23 | /// Python ABI version.
24 | #[arg(short, long, value_name = "VERSION", default_value = "3.9")]
25 | python_abi: Version,
26 | }
27 |
28 | fn main() -> Result<()> {
29 | // Initialize color_eyre error handling
30 | color_eyre::install()?;
31 |
32 | // Parse command-line arguments
33 | let args = Cli::parse();
34 |
35 | eprintln!(
36 | "🐍 Checking for compatibility with {} and Python ABI version {}",
37 | args.manylinux, args.python_abi
38 | );
39 |
40 | let binary_data = fs::read(args.object).context("Cannot open object file")?;
41 | let file = object::File::parse(&*binary_data).context("Cannot parse object")?;
42 |
43 | let many_linux_violations = check_manylinux(
44 | &args.manylinux,
45 | file.architecture(),
46 | file.endianness(),
47 | file.symbols(),
48 | )?;
49 | print_manylinux_violations(&many_linux_violations, &args.manylinux)?;
50 |
51 | let python_abi_violations = check_python_abi(&args.python_abi, file.symbols())?;
52 | print_python_abi_violations(&python_abi_violations, &args.python_abi);
53 |
54 | if !(many_linux_violations.is_empty() && python_abi_violations.is_empty()) {
55 | return Err(eyre::eyre!("Incompatible symbols found"));
56 | } else {
57 | eprintln!("✅ No compatibility issues found");
58 | }
59 |
60 | Ok(())
61 | }
62 |
63 | fn print_manylinux_violations(
64 | violations: &BTreeSet,
65 | manylinux_version: &str,
66 | ) -> Result<()> {
67 | if !violations.is_empty() {
68 | eprintln!(
69 | "\n⛔ Symbols incompatible with `{}` found:\n",
70 | manylinux_version
71 | );
72 | for violation in violations {
73 | match violation {
74 | ManylinuxViolation::Symbol { name, dep, version } => {
75 | eprintln!("{}_{}: {}", name, dep, version);
76 | }
77 | }
78 | }
79 | }
80 | Ok(())
81 | }
82 |
83 | fn print_python_abi_violations(violations: &BTreeSet, python_abi: &Version) {
84 | if !violations.is_empty() {
85 | let newer_abi3_symbols = violations
86 | .iter()
87 | .filter(|v| matches!(v, PythonAbiViolation::IncompatibleAbi3Symbol { .. }))
88 | .collect::>();
89 | let non_abi3_symbols = violations
90 | .iter()
91 | .filter(|v| matches!(v, PythonAbiViolation::NonAbi3Symbol { .. }))
92 | .collect::>();
93 |
94 | if !newer_abi3_symbols.is_empty() {
95 | eprintln!("\n⛔ Symbols >= Python ABI {} found:\n", python_abi);
96 | for violation in newer_abi3_symbols {
97 | if let PythonAbiViolation::IncompatibleAbi3Symbol { name, added } = violation {
98 | eprintln!("{}: {}", name, added);
99 | }
100 | }
101 | }
102 |
103 | if !non_abi3_symbols.is_empty() {
104 | eprintln!("\n⛔ Non-ABI3 symbols found:\n");
105 | for violation in &non_abi3_symbols {
106 | if let PythonAbiViolation::NonAbi3Symbol { name } = violation {
107 | eprintln!("{}", name);
108 | }
109 | }
110 | }
111 | }
112 | }
113 |
--------------------------------------------------------------------------------
/kernel-abi-check/src/manylinux/mod.rs:
--------------------------------------------------------------------------------
1 | use std::collections::{BTreeSet, HashSet};
2 | use std::str;
3 | use std::{collections::HashMap, str::FromStr};
4 |
5 | use eyre::{bail, Context, ContextCompat, OptionExt, Result};
6 | use object::{Architecture, Endianness, ObjectSymbol, Symbol};
7 | use once_cell::sync::Lazy;
8 | use serde::Deserialize;
9 |
10 | use crate::Version;
11 |
12 | #[derive(Debug, Deserialize)]
13 | struct ManyLinux {
14 | name: String,
15 | #[allow(dead_code)]
16 | aliases: Vec,
17 | #[allow(dead_code)]
18 | priority: u32,
19 | symbol_versions: HashMap>>,
20 | #[allow(dead_code)]
21 | lib_whitelist: Vec,
22 | #[allow(dead_code)]
23 | blacklist: HashMap>,
24 | }
25 |
26 | static MANYLINUX_POLICY_JSON: &str = include_str!("manylinux-policy.json");
27 |
28 | static MANYLINUX_VERSIONS: Lazy> = Lazy::new(|| {
29 | let deserialized: Vec = serde_json::from_str(MANYLINUX_POLICY_JSON).unwrap();
30 | deserialized
31 | .into_iter()
32 | .map(|manylinux| (manylinux.name.clone(), manylinux))
33 | .collect()
34 | });
35 |
/// A way in which an object file breaks the requested manylinux policy.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum ManylinuxViolation {
    /// An undefined symbol requires a dependency version that the policy does
    /// not list for this architecture.
    Symbol {
        /// Symbol name without its `@` version suffix.
        name: String,
        /// Versioned dependency name (the part of the version before `_`).
        dep: String,
        /// Required version of `dep`.
        version: String,
    },
}
46 |
47 | pub fn check_manylinux<'a>(
48 | manylinux_version: &str,
49 | architecture: Architecture,
50 | endianness: Endianness,
51 | symbols: impl IntoIterator- >,
52 | ) -> Result> {
53 | let arch_str = architecture.arch_str(endianness)?;
54 | let symbol_versions = MANYLINUX_VERSIONS
55 | .get(manylinux_version)
56 | .context(format!("Unknown manylinux version: {}", manylinux_version))?
57 | .symbol_versions
58 | .get(&arch_str)
59 | .context(format!(
60 | "Cannot find arch `{}` for: {}`",
61 | arch_str, manylinux_version
62 | ))?;
63 |
64 | let mut violations = BTreeSet::new();
65 |
66 | for symbol in symbols {
67 | if symbol.is_undefined() {
68 | let symbol = symbol.name_bytes().context("Cannot get symbol name")?;
69 | let symbol = str::from_utf8(symbol).context("Cannot parse symbol name as UTF-8")?;
70 |
71 | let mut symbol_parts = symbol.split('@');
72 | let symbol_name = symbol_parts.next().context("Cannot get symbol name")?;
73 |
74 | let version_info = match symbol_parts.next() {
75 | Some(version_info) => version_info,
76 | None => continue,
77 | };
78 |
79 | let mut version_parts = version_info.split('_');
80 |
81 | let dep = version_parts
82 | .next()
83 | .ok_or_eyre("Cannot get symbol version name")?;
84 |
85 | let version = match version_parts.next() {
86 | Some(version) => Version::from_str(version)?,
87 | // We also get symbol versions like: libcudart.so.12
88 | None => continue,
89 | };
90 |
91 | if let Some(versions) = symbol_versions.get(dep) {
92 | if !versions.contains(&version.to_string()) {
93 | violations.insert(ManylinuxViolation::Symbol {
94 | name: symbol_name.to_string(),
95 | dep: dep.to_string(),
96 | version: version.to_string(),
97 | });
98 | }
99 | }
100 | }
101 | }
102 |
103 | Ok(violations)
104 | }
105 |
106 | pub trait ArchStr {
107 | fn arch_str(&self, endiannes: Endianness) -> Result;
108 | }
109 |
110 | impl ArchStr for Architecture {
111 | fn arch_str(&self, endiannes: Endianness) -> Result {
112 | Ok(match self {
113 | Architecture::Aarch64 => "aarch64",
114 | Architecture::I386 => "i686",
115 | Architecture::PowerPc64 if matches!(endiannes, Endianness::Big) => "ppc64",
116 | Architecture::PowerPc64 if matches!(endiannes, Endianness::Little) => "ppc64le",
117 | Architecture::S390x => "s390x",
118 | Architecture::X86_64 => "x86_64",
119 | _ => bail!("Unsupported architecture: {:?}", self),
120 | }
121 | .to_string())
122 | }
123 | }
124 |
--------------------------------------------------------------------------------
/kernel-abi-check/src/python_abi/mod.rs:
--------------------------------------------------------------------------------
1 | use std::collections::{BTreeSet, HashMap};
2 |
3 | use eyre::Result;
4 | use object::{ObjectSymbol, Symbol};
5 | use once_cell::sync::Lazy;
6 | use serde::Deserialize;
7 |
8 | use crate::version::Version;
9 |
10 | static ABI_TOML: &str = include_str!("stable_abi.toml");
11 |
12 | #[derive(Deserialize)]
13 | struct AbiInfoSerde {
14 | added: Version,
15 | }
16 |
17 | #[derive(Deserialize)]
18 | struct StableAbiSerde {
19 | function: HashMap,
20 | data: HashMap,
21 | }
22 |
23 | #[derive(Clone, Copy, Debug)]
24 | pub enum SymbolType {
25 | Data,
26 | Function,
27 | }
28 |
29 | #[derive(Clone, Debug)]
30 | pub struct AbiInfo {
31 | #[allow(dead_code)]
32 | pub symbol_type: SymbolType,
33 | pub added: Version,
34 | }
35 |
36 | pub static PYTHON_STABLE_ABI: Lazy> = Lazy::new(|| {
37 | let deserialized: StableAbiSerde = toml::de::from_str(ABI_TOML).unwrap();
38 | let mut symbols = HashMap::new();
39 | for (name, abi) in deserialized.function {
40 | symbols.insert(
41 | name,
42 | AbiInfo {
43 | symbol_type: SymbolType::Function,
44 | added: abi.added,
45 | },
46 | );
47 | }
48 | for (name, abi) in deserialized.data {
49 | symbols.insert(
50 | name,
51 | AbiInfo {
52 | symbol_type: SymbolType::Data,
53 | added: abi.added,
54 | },
55 | );
56 | }
57 | symbols
58 | });
59 |
60 | /// Python ABI violation.
61 | #[derive(Debug, Clone, Eq, Ord, PartialEq, PartialOrd)]
62 | pub enum PythonAbiViolation {
63 | /// Symbol is newer than the specified Python ABI version.
64 | IncompatibleAbi3Symbol { name: String, added: Version },
65 |
66 | /// Symbol is not part of ABI3.
67 | NonAbi3Symbol { name: String },
68 | }
69 |
70 | /// Check for violations of the Python ABI policy.
71 | pub fn check_python_abi<'a>(
72 | python_abi: &Version,
73 | symbols: impl IntoIterator
- >,
74 | ) -> Result> {
75 | let mut violations = BTreeSet::new();
76 | for symbol in symbols {
77 | if !symbol.is_undefined() {
78 | continue;
79 | }
80 |
81 | let symbol_name = symbol.name()?;
82 |
83 | match PYTHON_STABLE_ABI.get(symbol_name) {
84 | Some(abi_info) => {
85 | if &abi_info.added > python_abi {
86 | violations.insert(PythonAbiViolation::IncompatibleAbi3Symbol {
87 | name: symbol_name.to_string(),
88 | added: abi_info.added.clone(),
89 | });
90 | }
91 | }
92 | None => {
93 | if symbol_name.starts_with("Py") || symbol_name.starts_with("_Py") {
94 | violations.insert(PythonAbiViolation::NonAbi3Symbol {
95 | name: symbol_name.to_string(),
96 | });
97 | }
98 | }
99 | }
100 | }
101 |
102 | Ok(violations)
103 | }
104 |
--------------------------------------------------------------------------------
/kernel-abi-check/src/version.rs:
--------------------------------------------------------------------------------
1 | use std::{fmt::Display, str::FromStr};
2 |
3 | use eyre::{ensure, Context, Result};
4 | use serde::{de, Deserialize, Deserializer};
5 |
/// A dotted numeric version such as `3.9` or `2.17`, stored as its components.
///
/// Comparison is lexicographic over the components, so `3.9` orders before
/// `3.9.0` and the two are not equal.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct Version(pub Vec<usize>);
9 |
10 | impl<'de> Deserialize<'de> for Version {
11 | fn deserialize(deserializer: D) -> Result
12 | where
13 | D: Deserializer<'de>,
14 | {
15 | let s = String::deserialize(deserializer)?;
16 | FromStr::from_str(&s).map_err(de::Error::custom)
17 | }
18 | }
19 |
20 | impl Display for Version {
21 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22 | write!(
23 | f,
24 | "{}",
25 | itertools::join(self.0.iter().map(|v| v.to_string()), ".")
26 | )
27 | }
28 | }
29 |
30 | impl FromStr for Version {
31 | type Err = eyre::Report;
32 |
33 | fn from_str(version: &str) -> Result {
34 | let version = version.trim().to_owned();
35 | ensure!(!version.is_empty(), "Empty version string");
36 | let mut version_parts = Vec::new();
37 | for part in version.split('.') {
38 | let version_part: usize = part.parse().context("Version must consist of numbers")?;
39 | version_parts.push(version_part);
40 | }
41 |
42 | Ok(Version(version_parts))
43 | }
44 | }
45 |
--------------------------------------------------------------------------------
/kernel-compliance-check/Cargo.toml:
--------------------------------------------------------------------------------
[package]
name = "kernel-compliance-check"
version = "0.1.0"
edition = "2021"
description = "Command-line utility for validating kernel compliance standards"
license = "Apache-2.0"
homepage = "https://github.com/huggingface/kernel-builder"
repository = "https://github.com/huggingface/kernel-builder"
documentation = "https://docs.rs/kernel-compliance-check"

[features]
default = []
# Include ROCm build variants in the compliance checks and their output.
enable_rocm = []

[dependencies]
clap = { version = "4.5.35", features = ["derive"] }
colored = "3.0.0"
dirs = "6.0.0"
eyre = "0.6.12"
futures = "0.3.31"
hf-hub = { version = "0.4.2", features = ["tokio"] }
kernel-abi-check = "0.4.0"
object = { version = "0.36.7", default-features = false, features = ["read"] }
once_cell = "1.18.0"
reqwest = { version = "0.11", features = ["json"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.114"
thiserror = "1.0.40"
tokio = { version = "1.44.2", features = ["full"] }

[build-dependencies]
ureq = "2.7.1"
--------------------------------------------------------------------------------
/kernel-compliance-check/src/build_variants.json:
--------------------------------------------------------------------------------
1 | ../../build_variants.json
--------------------------------------------------------------------------------
/kernel-compliance-check/src/formatter.rs:
--------------------------------------------------------------------------------
1 | use crate::models::AbiCheckResult;
2 | use colored::Colorize;
3 |
4 | /// Struct for console output formatting
5 | pub struct Console;
6 |
7 | impl Console {
8 | pub fn format_repo_list(repos: &[String], count: usize) {
9 | println!(".");
10 | for repo_id in repos {
11 | println!("├── {repo_id}");
12 | }
13 | println!("╰── {count} kernel repositories found\n");
14 | }
15 |
16 | pub fn format_fetch_status(repo_id: &str, fetching: bool, result: Option<&str>) {
17 | println!("repository: {repo_id}");
18 | if fetching {
19 | println!("status: not found locally, fetching...");
20 | }
21 | if let Some(message) = result {
22 | println!("status: {message}");
23 | }
24 | }
25 |
26 | #[allow(clippy::too_many_arguments)]
27 | pub fn format_repository_check_result(
28 | repo_id: &str,
29 | build_status: &str,
30 | cuda_compatible: bool,
31 | #[cfg(feature = "enable_rocm")] rocm_compatible: bool,
32 | #[cfg(not(feature = "enable_rocm"))] _rocm_compatible: bool,
33 | cuda_variants: &[String],
34 | #[cfg(feature = "enable_rocm")] rocm_variants: &[String],
35 | #[cfg(not(feature = "enable_rocm"))] _rocm_variants: &[String],
36 | cuda_variants_present: &[String],
37 | #[cfg(feature = "enable_rocm")] rocm_variants_present: Vec,
38 | #[cfg(not(feature = "enable_rocm"))] _rocm_variants_present: Vec,
39 | compact_output: bool,
40 | abi_output: &AbiCheckResult,
41 | abi_status: &str,
42 | ) {
43 | // Display console-formatted output
44 | let abi_mark = if abi_output.overall_compatible {
45 | "✓".green()
46 | } else {
47 | "✗".red()
48 | };
49 |
50 | let cuda_mark = if cuda_compatible {
51 | "✓".green()
52 | } else {
53 | "✗".red()
54 | };
55 |
56 | #[cfg(feature = "enable_rocm")]
57 | let rocm_mark = if rocm_compatible {
58 | "✓".green()
59 | } else {
60 | "✗".red()
61 | };
62 |
63 | let label = format!(" {repo_id} ").black().on_bright_white().bold();
64 |
65 | println!("\n{label}");
66 | println!("├── build: {build_status}");
67 |
68 | if compact_output {
69 | // Compact output
70 | #[cfg(feature = "enable_rocm")]
71 | {
72 | println!("│ ├── {} CUDA", cuda_mark);
73 | println!("│ ╰── {} ROCM", rocm_mark);
74 | }
75 |
76 | #[cfg(not(feature = "enable_rocm"))]
77 | {
78 | println!("│ ╰── {cuda_mark} CUDA");
79 | }
80 | } else {
81 | println!("│ {} {}", cuda_mark, "CUDA".bold());
82 |
83 | // Print variant list with proper tree characters
84 | let mut cuda_iter = cuda_variants.iter().peekable();
85 | while let Some(cuda_variant) = cuda_iter.next() {
86 | let is_last = cuda_iter.peek().is_none();
87 | let is_present = cuda_variants_present.contains(cuda_variant);
88 | let prefix = if is_last {
89 | "│ ╰── "
90 | } else {
91 | "│ ├── "
92 | };
93 |
94 | if is_present {
95 | println!("{prefix}{cuda_variant}");
96 | } else {
97 | println!("{}{}", prefix, cuda_variant.dimmed());
98 | }
99 | }
100 |
101 | // Only show ROCm section if the feature is enabled
102 | #[cfg(feature = "enable_rocm")]
103 | {
104 | println!("│ {} {}", rocm_mark, "ROCM".bold());
105 |
106 | let mut rocm_iter = rocm_variants.iter().peekable();
107 | while let Some(rocm_variant) = rocm_iter.next() {
108 | let is_last = rocm_iter.peek().is_none();
109 | let is_present = rocm_variants_present.contains(rocm_variant);
110 | let prefix = if is_last {
111 | "│ ╰── "
112 | } else {
113 | "│ ├── "
114 | };
115 |
116 | if is_present {
117 | println!("{}{}", prefix, rocm_variant);
118 | } else {
119 | println!("{}{}", prefix, rocm_variant.dimmed());
120 | }
121 | }
122 | }
123 | }
124 |
125 | // ABI status section
126 | println!("╰── abi: {abi_status}");
127 | println!(" ├── {} {}", abi_mark, abi_output.manylinux_version);
128 | println!(
129 | " ╰── {} python {}",
130 | abi_mark, abi_output.python_abi_version
131 | );
132 | }
133 | }
134 |
--------------------------------------------------------------------------------
/kernel-compliance-check/src/main.rs:
--------------------------------------------------------------------------------
1 | use core::str::FromStr as _;
2 |
3 | use clap::Parser as _;
4 | use eyre::{Context as _, Result};
5 | use kernel_abi_check::Version;
6 | use kernel_compliance_check::{process_repository, Cli, Commands, Format};
7 |
8 | fn main() -> Result<()> {
9 | // Parse CLI arguments
10 | let cli = Cli::parse();
11 |
12 | // Respect KERNELS_CACHE if set
13 | let non_standard_cache: Option = std::env::var("KERNELS_CACHE").ok();
14 |
15 | // Prefer the cli unless explicitly set to avoid it
16 | let prefer_hub_cli = !std::env::var("AVOID_HUB_CLI")
17 | .map(|v| v == "1" || v == "true")
18 | .unwrap_or(false);
19 |
20 | match cli.command {
21 | Commands::Check {
22 | repos,
23 | manylinux,
24 | python_abi,
25 | revision,
26 | long,
27 | force_fetch,
28 | show_violations,
29 | format,
30 | } => {
31 | eprintln!("Running kernel compliance check");
32 | eprintln!("Repositories: {repos}");
33 | eprintln!("Kernel Revision: {revision}");
34 |
35 | // Check repositories for compliance
36 | check_repositories(
37 | &repos,
38 | &manylinux.to_string(),
39 | &python_abi,
40 | prefer_hub_cli,
41 | force_fetch,
42 | &revision,
43 | long,
44 | show_violations,
45 | format,
46 | non_standard_cache.as_ref(),
47 | )?;
48 | }
49 | }
50 |
51 | Ok(())
52 | }
53 |
54 | #[allow(clippy::fn_params_excessive_bools)]
55 | #[expect(clippy::too_many_arguments)]
56 | fn check_repositories(
57 | repos: &str,
58 | manylinux: &str,
59 | python_abi: &str,
60 | prefer_hub_cli: bool,
61 | force_fetch: bool,
62 | revision: &str,
63 | long: bool,
64 | show_violations: bool,
65 | format: Format,
66 | non_standard_cache: Option<&String>,
67 | ) -> Result<()> {
68 | let repositories: Vec = repos
69 | .split(',')
70 | .map(|s| s.trim().to_owned())
71 | .filter(|s| !s.is_empty())
72 | .collect();
73 |
74 | if repositories.is_empty() {
75 | #[derive(serde::Serialize)]
76 | struct ErrorResponse {
77 | status: &'static str,
78 | error: &'static str,
79 | }
80 |
81 | if format.is_json() {
82 | let error = ErrorResponse {
83 | status: "error",
84 | error: "no repository ids provided",
85 | };
86 | let json = serde_json::to_string_pretty(&error)
87 | .context("Failed to serialize error response")?;
88 | println!("{json}");
89 | } else {
90 | eprintln!("no repository ids provided");
91 | }
92 | return Ok(());
93 | }
94 |
95 | let python_version = Version::from_str(python_abi)
96 | .map_err(|e| eyre::eyre!("Invalid Python ABI version {}: {}", python_abi, e))?;
97 |
98 | for repo_id in &repositories {
99 | if let Err(e) = process_repository(
100 | repo_id,
101 | revision,
102 | force_fetch,
103 | prefer_hub_cli,
104 | manylinux,
105 | &python_version,
106 | !long,
107 | show_violations,
108 | format,
109 | non_standard_cache,
110 | ) {
111 | eprintln!("Error processing repository {repo_id}: {e}");
112 |
113 | // Continue processing other repositories rather than exiting early
114 | // This is more user-friendly for batch processing
115 | }
116 | }
117 |
118 | Ok(())
119 | }
120 |
--------------------------------------------------------------------------------
/kernel-compliance-check/src/models.rs:
--------------------------------------------------------------------------------
1 | use std::fmt;
2 |
3 | use clap::{Parser, Subcommand, ValueEnum};
4 | use serde::{Deserialize, Serialize};
5 | use thiserror::Error;
6 |
7 | #[derive(Error, Debug)]
8 | pub enum CompliantError {
9 | #[error("IO error: {0}")]
10 | Io(#[from] std::io::Error),
11 |
12 | #[error("Repository not found: {0}")]
13 | RepositoryNotFound(String),
14 |
15 | #[error("Build directory not found in repository: {0}")]
16 | BuildDirNotFound(String),
17 |
18 | #[error("Failed to fetch repository: {0}")]
19 | FetchError(String),
20 |
21 | #[error("Failed to parse object file: {0}")]
22 | ObjectParseError(String),
23 |
24 | #[error("Failed to check ABI compatibility: {0}")]
25 | AbiCheckError(String),
26 |
27 | #[error("Failed to serialize JSON: {0}")]
28 | SerializationError(String),
29 |
30 | #[error("Failed to fetch variants: {0}")]
31 | VariantsFetchError(String),
32 |
33 | #[error("Network error: {0}")]
34 | NetworkError(String),
35 |
36 | #[error("Unknown error: {0}")]
37 | Other(String),
38 | }
39 |
40 | /// Hugging Face kernel compliance checker
41 | #[derive(Parser)]
42 | #[command(author, version, about)]
43 | pub struct Cli {
44 | #[command(subcommand)]
45 | pub command: Commands,
46 | }
47 |
48 | #[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
49 | pub enum Format {
50 | Console,
51 | Json,
52 | }
53 |
54 | impl Format {
55 | #[must_use]
56 | pub fn is_json(&self) -> bool {
57 | matches!(self, Format::Json)
58 | }
59 | }
60 |
61 | #[derive(Subcommand)]
62 | pub enum Commands {
63 | /// Check repository compliance and ABI compatibility
64 | Check {
65 | /// Repository IDs or names (comma-separated)
66 | #[arg(value_name = "REPOS")]
67 | repos: String,
68 |
69 | /// Manylinux version to check against
70 | #[arg(short, long, default_value = "manylinux_2_28")]
71 | manylinux: String,
72 |
73 | /// Python ABI version to check against
74 | #[arg(short, long, default_value = "3.9")]
75 | python_abi: String,
76 |
77 | /// Revision (branch, tag, or commit hash) to use when fetching
78 | #[arg(short, long, default_value = "main")]
79 | revision: String,
80 |
81 | /// Show all variants in a long format. Default is compact output.
82 | #[arg(long, default_value_t = false)]
83 | long: bool,
84 |
85 | /// Force fetch the repository if not found locally
86 | #[arg(long, alias = "force", default_value_t = false)]
87 | force_fetch: bool,
88 |
89 | /// Show ABI violations in the output. Default is to only show compatibility status.
90 | #[arg(long, default_value_t = false)]
91 | show_violations: bool,
92 |
93 | /// Format of the output. Default is console
94 | #[arg(long, default_value = "console")]
95 | format: Format,
96 | },
97 | }
98 |
99 | /// Structured representation of build variants
100 | #[derive(Debug, Deserialize)]
101 | pub struct VariantsConfig {
102 | #[serde(rename = "x86_64-linux")]
103 | pub x86_64_linux: ArchConfig,
104 | #[serde(rename = "aarch64-linux")]
105 | pub aarch64_linux: ArchConfig,
106 | }
107 |
108 | #[derive(Debug, Deserialize)]
109 | pub struct ArchConfig {
110 | pub cuda: Vec,
111 | #[serde(default)]
112 | #[cfg(feature = "enable_rocm")]
113 | pub rocm: Vec,
114 | #[cfg(not(feature = "enable_rocm"))]
115 | #[serde(default, skip)]
116 | _rocm: Vec,
117 | }
118 |
119 | #[derive(Debug, Clone, Serialize, Deserialize)]
120 | pub struct Variant {
121 | pub torch_version: String,
122 | pub cxx_abi: String,
123 | pub compute_framework: String,
124 | pub arch: String,
125 | pub os: String,
126 | }
127 |
128 | impl fmt::Display for Variant {
129 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130 | write!(
131 | f,
132 | "{}-{}-{}-{}-{}",
133 | self.torch_version, self.cxx_abi, self.compute_framework, self.arch, self.os
134 | )
135 | }
136 | }
137 |
138 | impl Variant {
139 | #[must_use]
140 | pub fn from_name(name: &str) -> Option {
141 | let parts: Vec<&str> = name.split('-').collect();
142 | if parts.len() < 5 {
143 | return None;
144 | }
145 | // Format: torch{major}{minor}-{cxxabi}-{compute_framework}-{arch}-{os}
146 | Some(Variant {
147 | torch_version: parts[0].to_string(),
148 | cxx_abi: parts[1].to_string(),
149 | compute_framework: parts[2].to_string(),
150 | arch: parts[3].to_string(),
151 | os: parts[4].to_string(),
152 | })
153 | }
154 | }
155 |
156 | #[derive(Serialize)]
157 | pub struct RepoErrorResponse {
158 | pub repository: String,
159 | pub status: String,
160 | pub error: String,
161 | }
162 |
163 | #[derive(Serialize)]
164 | pub struct RepositoryCheckResult {
165 | pub repository: String,
166 | pub status: String,
167 | pub build_status: BuildStatus,
168 | pub abi_status: AbiStatus,
169 | }
170 |
171 | #[derive(Serialize)]
172 | pub struct BuildStatus {
173 | pub summary: String,
174 | pub cuda: CudaStatus,
175 | #[serde(skip_serializing_if = "Option::is_none")]
176 | pub rocm: Option,
177 | }
178 |
179 | #[derive(Serialize)]
180 | pub struct CudaStatus {
181 | pub compatible: bool,
182 | pub present: Vec,
183 | pub missing: Vec,
184 | }
185 |
186 | #[derive(Serialize)]
187 | pub struct RocmStatus {
188 | pub compatible: bool,
189 | pub present: Vec,
190 | pub missing: Vec,
191 | }
192 |
193 | #[derive(Serialize)]
194 | pub struct AbiStatus {
195 | pub compatible: bool,
196 | pub manylinux_version: String,
197 | pub python_abi_version: String,
198 | pub variants: Vec,
199 | }
200 |
201 | #[derive(Serialize)]
202 | pub struct VariantCheckOutput {
203 | pub name: String,
204 | pub compatible: bool,
205 | pub has_shared_objects: bool,
206 | pub violations: Vec,
207 | }
208 |
209 | #[derive(Debug, Clone, Serialize, Deserialize)]
210 | pub struct SharedObjectViolation {
211 | pub message: String,
212 | }
213 |
214 | #[derive(Debug, Clone, Serialize, Deserialize)]
215 | pub struct VariantResult {
216 | pub name: String,
217 | pub is_compatible: bool,
218 | pub violations: Vec,
219 | pub has_shared_objects: bool,
220 | }
221 |
222 | #[derive(Debug, Clone)]
223 | pub struct AbiCheckResult {
224 | pub overall_compatible: bool,
225 | pub variants: Vec,
226 | pub manylinux_version: String,
227 | pub python_abi_version: kernel_abi_check::Version,
228 | }
229 |
--------------------------------------------------------------------------------
/lib/build-version.nix:
--------------------------------------------------------------------------------
# Compute the build-variant name for a Torch build: on Darwin
# `torch<version>-metal-<system>`, otherwise
# `torch<version>-<cxx abi>-<framework>-<system>`.
{
  # `gpu` and `upstreamVariant` are accepted but not used here.
  gpu,
  pkgs,
  torch,
  upstreamVariant,
}:
let
  inherit (pkgs) lib;
  inherit (import ./version-utils.nix { inherit lib; }) flattenVersion abiString;

  torchTag = "torch${flattenVersion torch.version}";
  platform = pkgs.stdenv.targetPlatform.system;

  # Only evaluated on the non-Darwin branch below.
  frameworkTag =
    if torch.cudaSupport then
      "cu${flattenVersion torch.cudaPackages.cudaMajorMinorVersion}"
    else
      "rocm${flattenVersion (lib.versions.majorMinor torch.rocmPackages.rocm.version)}";
in
if pkgs.stdenv.hostPlatform.isDarwin then
  "${torchTag}-metal-${platform}"
else
  "${torchTag}-${abiString torch.passthru.cxx11Abi}-${frameworkTag}-${platform}"
22 |
--------------------------------------------------------------------------------
/lib/buildsets.nix:
--------------------------------------------------------------------------------
# Instantiate one nixpkgs package set (plus a matching Torch) per build
# configuration returned by `buildConfigs system`.
{
  nixpkgs,
  system,
  hf-nix,
}:

let
  inherit (nixpkgs) lib;

  overlay = import ../overlay.nix;

  # Get versions.
  inherit (import ../versions.nix { inherit lib; }) buildConfigs cudaVersions rocmVersions;

  # Pad/truncate to major.minor and replace dots, e.g. "12.4" -> "12_4".
  flattenVersion = version: lib.replaceStrings [ "." ] [ "_" ] (lib.versions.pad 2 version);

  # An overlay that overrides CUDA to the given version.
  overlayForCudaVersion = cudaVersion: final: prev: {
    cudaPackages = prev."cudaPackages_${flattenVersion cudaVersion}";
  };

  # An overlay that overrides ROCm to the given version.
  overlayForRocmVersion = rocmVersion: final: prev: {
    rocmPackages = prev."rocmPackages_${flattenVersion rocmVersion}";
  };

  # Instantiate nixpkgs once per toolkit version, with `supportFlag` enabled
  # in the nixpkgs config and the version-specific overlay applied. Returns
  # an attribute set like `{ "12.4" = <nixpkgs>; ... }`.
  pkgsForGpuVersions =
    { supportFlag, overlayForVersion }:
    versions:
    lib.genAttrs versions (
      version:
      import nixpkgs {
        inherit system;
        config = {
          allowUnfree = true;
          ${supportFlag} = true;
        };
        overlays = [
          hf-nix
          overlay
          (overlayForVersion version)
        ];
      }
    );

  pkgsByCudaVer = pkgsForGpuVersions {
    supportFlag = "cudaSupport";
    overlayForVersion = overlayForCudaVersion;
  } cudaVersions;

  pkgsByRocmVer = pkgsForGpuVersions {
    supportFlag = "rocmSupport";
    overlayForVersion = overlayForRocmVersion;
  } rocmVersions;

  pkgsForMetal = import nixpkgs {
    inherit system;
    overlays = [
      hf-nix
      overlay
    ];
  };

  # Construct the nixpkgs package set and Torch for one build configuration.
  pkgsForVersions =
    {
      gpu,
      cudaVersion ? "",
      # Not used here: the framework is selected by `gpu`. It is still listed
      # so that build configurations which set it are accepted.
      metal ? false,
      rocmVersion ? "",
      torchVersion,
      cxx11Abi,
      upstreamVariant ? false,
    }:
    let
      pkgs =
        if gpu == "cuda" then
          pkgsByCudaVer.${cudaVersion}
        else if gpu == "rocm" then
          pkgsByRocmVer.${rocmVersion}
        else if gpu == "metal" then
          pkgsForMetal
        else
          throw "Unknown compute framework: ${gpu}";
      torch = pkgs.python3.pkgs."torch_${flattenVersion torchVersion}".override {
        inherit cxx11Abi;
      };
    in
    {
      inherit
        gpu
        pkgs
        torch
        upstreamVariant
        ;
    };

in
map pkgsForVersions (buildConfigs system)
129 |
--------------------------------------------------------------------------------
/lib/deps.nix:
--------------------------------------------------------------------------------
# Resolve a kernel's dependency names to the packages they provide.
{
  lib,
}:

{
  pkgs,
  torch,
  deps,
}:

let
  # Packages provided for each supported dependency name. All references are
  # explicit, so no `with` scope is needed.
  knownDeps = {
    "cutlass_2_10" = [ pkgs.cutlass_2_10 ];
    "cutlass_3_5" = [ pkgs.cutlass_3_5 ];
    "cutlass_3_6" = [ pkgs.cutlass_3_6 ];
    "cutlass_3_8" = [ pkgs.cutlass_3_8 ];
    "torch" = [
      torch
      torch.cxxdev
    ];
  };

  # Look up one dependency; evaluation fails for unknown names.
  depToPkg =
    dep:
    assert lib.assertMsg (builtins.hasAttr dep knownDeps) "Unknown dependency: ${dep}";
    knownDeps.${dep};
in
# Flat list of every package required by `deps`.
lib.concatMap depToPkg deps
39 |
--------------------------------------------------------------------------------
/lib/join-paths/default.nix:
--------------------------------------------------------------------------------
# Build a derivation whose output merges several paths: for every
# `subdir = source;` entry in `namePaths`, the contents of `source` are
# copied into `$out/subdir`.
{
  pkgs,

  name,

  # Attribute set mapping output subdirectories to the paths whose contents
  # are copied into them.
  namePaths,

  preferLocalBuild ? true,
  allowSubstitutes ? false,
}:
let
  inherit (pkgs) lib;

  # Inherit these explicitly: an `args@{ ... }` binding only contains the
  # attributes the caller actually passed, never the declared defaults, so
  # forwarding `args` would silently drop the defaults above.
  runCommandArgs = {
    inherit preferLocalBuild allowSubstitutes;
  };

  # Iterating over pairs in bash sucks, so let's generate the commands in Nix
  # instead. Copying `<source>/.` also includes dotfiles and does not fail on
  # an empty directory, unlike a `*` glob.
  copyPath =
    path: pkg:
    let
      dest = lib.escapeShellArg "${placeholder "out"}/${path}";
    in
    ''
      mkdir -p ${dest}
      cp -r ${pkg}/. ${dest}
    '';
in
pkgs.runCommand name runCommandArgs (
  lib.concatStringsSep "\n" (lib.mapAttrsToList copyPath namePaths)
)
27 |
--------------------------------------------------------------------------------
/lib/source-set.nix:
--------------------------------------------------------------------------------
{ lib }:

# Given a kernel directory containing `build.toml`, return a filtered source
# tree with: `build.toml`, the files listed in `torch.src`, every kernel's
# `src` files, and files under `torch-ext` whose extension is listed in
# `torch.pyext` (default: py, pyi).
root:

let
  inherit (lib) fileset;

  buildConfig = builtins.fromTOML (builtins.readFile (root + "/build.toml"));

  # Resolve a path relative to the kernel directory.
  inRoot = relPath: root + "/${relPath}";

  torchConfig = buildConfig.torch or { };

  pythonExtensions =
    torchConfig.pyext or [
      "py"
      "pyi"
    ];
  isPythonFile = file: lib.any file.hasExt pythonExtensions;

  kernelSources = lib.concatMap (kernel: map inRoot kernel.src) (
    lib.attrValues (buildConfig.kernel or { })
  );

  extensionSources = map inRoot ((torchConfig.src or [ ]) ++ [ "build.toml" ]);
in
fileset.toSource {
  inherit root;
  fileset = fileset.unions [
    (fileset.unions kernelSources)
    (fileset.unions extensionSources)
    (fileset.fileFilter isPythonFile (inRoot "torch-ext"))
  ];
}
34 |
--------------------------------------------------------------------------------
/lib/torch-extension-noarch/default.nix:
--------------------------------------------------------------------------------
# Package a universal (no native code) Torch extension: nothing is compiled,
# and `torch-ext/<extensionName>` is copied to the output as-is.
{
  stdenv,
  extensionName,
  rev,

  build2cmake,
  torch,

  src,
}:

stdenv.mkDerivation (finalAttrs: {
  name = "${extensionName}-torch-ext";
  inherit src;

  nativeBuildInputs = [ build2cmake ];

  # Torch is a build input so that devshells for universal kernels also
  # provide it.
  buildInputs = [ torch ];

  # The generated setuptools/CMake files are not used, but running
  # build2cmake validates build.toml.
  postPatch = ''
    build2cmake generate-torch --ops-id ${rev} build.toml
  '';

  dontBuild = true;

  installPhase = ''
    mkdir -p $out
    cp -r torch-ext/${extensionName} $out/
  '';
})
37 |
--------------------------------------------------------------------------------
/lib/torch-extension/_ops.py.in:
--------------------------------------------------------------------------------
1 | import torch
2 | from . import _@EXTENSION_NAME@
3 | ops = torch.ops._@EXTENSION_NAME@
4 |
5 | def add_op_namespace_prefix(op_name: str):
6 | """
7 | Prefix op by namespace.
8 | """
9 | return f"_@EXTENSION_NAME@::{op_name}"
10 |
--------------------------------------------------------------------------------
/lib/torch-extension/default.nix:
--------------------------------------------------------------------------------
{
  extensionName,
  nvccThreads,
  rev,

  # Whether to strip rpath for non-nix use.
  stripRPath ? false,

  src,

  config,
  cudaSupport ? config.cudaSupport,
  rocmSupport ? config.rocmSupport,

  lib,
  stdenv,
  cudaPackages,
  cmake,
  cmakeNvccThreadsHook,
  ninja,
  python3,
  kernel-abi-check,
  build2cmake,
  rocmPackages,

  apple-sdk_15,
  extraDeps ? [ ],
  torch,

  doAbiCheck,
}:

let
  # CLR that uses the provided stdenv, which can be different from the default
  # to support old glibc/libstdc++ versions.
  clr = (
    rocmPackages.clr.override {
      clang = rocmPackages.llvm.clang.override {
        inherit stdenv;
        bintools = rocmPackages.llvm.bintools.override { libc = stdenv.cc.libc; };
        glibc = stdenv.cc.libc;
      };
    }
  );

in
# Builds a Torch extension from a kernel source tree that contains build.toml:
# build2cmake generates the CMake files, CMake/Ninja build and install them,
# and postInstall moves the result into $out/${extensionName}.
stdenv.mkDerivation (finalAttrs: {
  name = "${extensionName}-torch-ext";

  # extensionName is exported to the build environment because setup hooks
  # read it as a shell variable (the kernel-abi-check hook scans
  # $out/$extensionName); without it they would see an empty/unset value.
  inherit
    doAbiCheck
    extensionName
    nvccThreads
    src
    ;

  # Generate build files.
  postPatch = ''
    build2cmake generate-torch --backend ${
      if cudaSupport then
        "cuda"
      else if rocmSupport then
        "rocm"
      else
        "metal"
    } --ops-id ${rev} build.toml
  '';

  # hipify copies files, but its target is run in the CMake build and install
  # phases. Since some of the files come from the Nix store, this fails the
  # second time around.
  preInstall = ''
    chmod -R u+w .
  '';

  nativeBuildInputs =
    [
      kernel-abi-check
      cmake
      ninja
      build2cmake
    ]
    ++ lib.optionals cudaSupport [
      cmakeNvccThreadsHook
      cudaPackages.cuda_nvcc
    ]
    ++ lib.optionals rocmSupport [
      clr
    ];

  buildInputs =
    [
      torch
      torch.cxxdev
    ]
    ++ lib.optionals cudaSupport (
      with cudaPackages;
      [
        cuda_cudart

        # Make dependent on build configuration dependencies once
        # the Torch dependency is gone.
        cuda_cccl
        libcublas
        libcusolver
        libcusparse
      ]
    )
    #++ lib.optionals rocmSupport (with rocmPackages; [ clr rocm-core ])
    ++ lib.optionals stdenv.hostPlatform.isDarwin [
      apple-sdk_15
    ]
    ++ extraDeps;

  env =
    lib.optionalAttrs cudaSupport {
      CUDAToolkit_ROOT = "${lib.getDev cudaPackages.cuda_nvcc}";
      TORCH_CUDA_ARCH_LIST =
        if cudaPackages.cudaOlder "12.8" then
          "7.0;7.5;8.0;8.6;8.9;9.0+PTX"
        else
          "7.0;7.5;8.0;8.6;8.9;9.0;10.0;10.1;12.0+PTX";
    }
    // lib.optionalAttrs rocmSupport {
      PYTORCH_ROCM_ARCH = lib.concatStringsSep ";" torch.rocmArchs;
    };

  # If we use the default setup, CMAKE_CUDA_HOST_COMPILER gets set to nixpkgs g++.
  dontSetupCUDAToolkitCompilers = true;

  cmakeFlags =
    [
      (lib.cmakeFeature "Python_EXECUTABLE" "${python3.withPackages (ps: [ torch ])}/bin/python")
    ]
    ++ lib.optionals cudaSupport [
      (lib.cmakeFeature "CMAKE_CUDA_HOST_COMPILER" "${stdenv.cc}/bin/g++")
    ]
    ++ lib.optionals rocmSupport [
      # Ensure that we use HIP from our CLR override and not HIP from
      # the symlink-joined ROCm toolkit.
      (lib.cmakeFeature "CMAKE_HIP_COMPILER_ROCM_ROOT" "${clr}")
      (lib.cmakeFeature "HIP_ROOT_DIR" "${clr}")
    ];

  # Copy the Python package sources next to the built ops library and fold
  # the CMake-installed _${extensionName}_* directory into it.
  postInstall =
    ''
      (
        cd ..
        cp -r torch-ext/${extensionName} $out/
      )
      cp $out/_${extensionName}_*/* $out/${extensionName}
      rm -rf $out/_${extensionName}_*
    ''
    + lib.optionalString stripRPath ''
      find $out/${extensionName} -name '*.so' \
        -exec patchelf --set-rpath "" {} \;
    '';

  # Enables the install-check phase, where the kernel-abi-check setup hook runs.
  doInstallCheck = true;

  passthru = {
    inherit torch;
  };
})
160 |
--------------------------------------------------------------------------------
/lib/version-utils.nix:
--------------------------------------------------------------------------------
{ lib }:

let
  inherit (lib.versions) pad;
in
{
  # Turns a version into its dot-less form after normalizing it to two
  # components with lib.versions.pad, e.g. "2.6" -> "26".
  flattenVersion = version: lib.concatStrings (lib.splitString "." (pad 2 version));

  # Name used for the C++ ABI in variant names, chosen by the cxx11Abi flag.
  abiString = cxx11Abi: if cxx11Abi then "cxx11" else "cxx98";
}
7 |
--------------------------------------------------------------------------------
/overlay.nix:
--------------------------------------------------------------------------------
final: prev:

let
  # Instantiate a package expression from this repository with default arguments.
  callLocal = path: prev.callPackage path { };
in
{
  # Setup hook that passes NVCC_THREADS to CMake.
  cmakeNvccThreadsHook = callLocal ./pkgs/cmake-nvcc-threads-hook;

  # Local packages
  kernel-abi-check = callLocal ./pkgs/kernel-abi-check;
  build2cmake = callLocal ./pkgs/build2cmake;
  stdenvGlibc_2_27 = callLocal ./pkgs/stdenv-glibc-2_27;
}
12 |
--------------------------------------------------------------------------------
/pkgs/build2cmake/default.nix:
--------------------------------------------------------------------------------
{
  lib,
  rustPlatform,
  pkg-config,
  libgit2,
  openssl,
}:

let
  crateRoot = ../../build2cmake;

  # The package version is taken from the crate manifest.
  cargoToml = lib.importTOML (crateRoot + "/Cargo.toml");

  # Files kept in the build source: Cargo metadata, a few named data/template
  # files, and every CMake, header, Python and Rust file in the crate tree.
  isSourceFile =
    file:
    builtins.elem file.name [
      "Cargo.toml"
      "Cargo.lock"
      "pyproject.toml"
      "pyproject_universal.toml"
      "cuda_supported_archs.json"
    ]
    || builtins.any file.hasExt [
      "cmake"
      "h"
      "py"
      "rs"
    ];
in
rustPlatform.buildRustPackage {
  pname = "build2cmake";
  inherit (cargoToml.package) version;

  src = lib.fileset.toSource {
    root = crateRoot;
    fileset = lib.fileset.fileFilter isSourceFile crateRoot;
  };

  cargoLock.lockFile = crateRoot + "/Cargo.lock";

  nativeBuildInputs = [ pkg-config ];

  buildInputs = [
    libgit2
    openssl.dev
  ];

  meta.description = "Create cmake build infrastructure from build.toml files";
}
52 |
--------------------------------------------------------------------------------
/pkgs/cmake-nvcc-threads-hook/cmake-nvcc-threads-hook.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
# Pre-configure hook: chooses the number of nvcc threads, lowers
# NIX_BUILD_CORES so that (build jobs x nvcc threads) stays within the core
# budget, and passes the thread count to CMake as -DNVCC_THREADS.
_setNvccThreadsHook() {
  # Accept only a positive integer. `[ x -gt 0 ]` fails with an error for
  # non-numeric input (silenced), so empty, non-numeric, zero and negative
  # values all fall back to the default. `:-` keeps this safe under `set -u`.
  if ! [ "${nvccThreads:-}" -gt 0 ] 2>/dev/null; then
    >&2 echo "Number of nvcc threads is not (correctly) set, setting to 4"
    nvccThreads=4
  fi

  # Ensure that we do not use more threads than build cores.
  # NOTE(review): assumes NIX_BUILD_CORES is a positive integer, as normalized
  # by stdenv's setup script before hooks run.
  nvccThreads=$(( NIX_BUILD_CORES < nvccThreads ? NIX_BUILD_CORES : nvccThreads ))

  # Change the number of build cores so that build cores * threads is
  # within bounds.
  export NIX_BUILD_CORES=$(( NIX_BUILD_CORES / nvccThreads ))

  appendToVar cmakeFlags -DNVCC_THREADS="${nvccThreads}"
}

preConfigureHooks+=(_setNvccThreadsHook)
20 |
--------------------------------------------------------------------------------
/pkgs/cmake-nvcc-threads-hook/default.nix:
--------------------------------------------------------------------------------
{ makeSetupHook }:

let
  # Shell hook registered as a pre-configure hook; it appends -DNVCC_THREADS
  # to cmakeFlags and adjusts NIX_BUILD_CORES.
  hookScript = ./cmake-nvcc-threads-hook.sh;
in
makeSetupHook { name = "cmake-nvcc-threads-hook"; } hookScript
6 |
--------------------------------------------------------------------------------
/pkgs/kernel-abi-check/default.nix:
--------------------------------------------------------------------------------
{
  lib,
  rustPlatform,
}:

let
  crateRoot = ../../kernel-abi-check;

  # The package version is taken from the crate manifest.
  cargoToml = lib.importTOML (crateRoot + "/Cargo.toml");

  # Files kept in the build source: Cargo metadata, the policy/ABI data
  # files, and all Rust sources.
  isSourceFile =
    file:
    builtins.elem file.name [
      "Cargo.toml"
      "Cargo.lock"
      "manylinux-policy.json"
      "stable_abi.toml"
    ]
    || file.hasExt "rs";
in
rustPlatform.buildRustPackage {
  pname = "kernel-abi-check";
  inherit (cargoToml.package) version;

  src = lib.fileset.toSource {
    root = crateRoot;
    fileset = lib.fileset.fileFilter isSourceFile crateRoot;
  };

  cargoLock.lockFile = crateRoot + "/Cargo.lock";

  # Packages that depend on this one get the ABI-check install-check hook.
  setupHook = ./kernel-abi-check-hook.sh;

  meta.description = "Check glibc and libstdc++ ABI compat";
}
38 |
--------------------------------------------------------------------------------
/pkgs/kernel-abi-check/kernel-abi-check-hook.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
# Post-install-check hook: when doAbiCheck is non-empty, runs kernel-abi-check
# on every shared library under $out/$extensionName and fails if any check
# fails.
_checkAbiHook() {
  if [ -z "${doAbiCheck:-}" ]; then
    echo "Skipping ABI check"
    return 0
  fi

  echo "Checking of ABI compatibility"

  # `find -exec ... \;` ignores the exit status of the command it runs, so a
  # failing check would never fail the build. Iterate explicitly and record
  # failures instead. `${extensionName:-}` keeps this safe under `set -u`.
  local failed=0 lib
  while IFS= read -r -d '' lib; do
    if ! kernel-abi-check "$lib"; then
      failed=1
    fi
  done < <(find "${out}/${extensionName:-}" -name '*.so' -print0)

  if [ "$failed" -ne 0 ]; then
    >&2 echo "ABI check failed for one or more libraries"
    return 1
  fi
}

postInstallCheckHooks+=(_checkAbiHook)
14 |
--------------------------------------------------------------------------------
/pkgs/stdenv-glibc-2_27/default.nix:
--------------------------------------------------------------------------------
{
  config,
  cudaSupport ? config.cudaSupport,
  fetchFromGitHub,
  overrideCC,
  system,
  wrapBintoolsWith,
  wrapCCWith,
  gcc12Stdenv,
  stdenv,
  bintools-unwrapped,
  cudaPackages,
  libgcc,
}:

let
  # Pinned nixpkgs revision (2019-12-30) whose glibc is used as glibc 2.27.
  oldNixpkgs = import (fetchFromGitHub {
    owner = "NixOS";
    repo = "nixpkgs";
    rev = "a9eb3eed170fa916e0a8364e5227ee661af76fde";
    hash = "sha256-1ycrr9HMrGA3ZDM8qmKcZICBupE5UShnIIhPRWdvAzA=";
  }) { inherit system; };

  # The old glibc, adjusted to the interface modern nixpkgs expects:
  # a pname, an extra `getent` output, and a passthru libgcc.
  oldGlibc = oldNixpkgs.glibc.overrideAttrs (prevAttrs: {
    pname = "glibc";

    outputs = prevAttrs.outputs ++ [ "getent" ];

    postInstall =
      prevAttrs.postInstall
      + ''
        install -Dm755 $bin/bin/getent -t $getent/bin
      '';

    passthru = prevAttrs.passthru // {
      # Should be stdenv's gcc, but we don't have access to it.
      libgcc = stdenv.cc.cc.libgcc;
    };
  });

  # Build an stdenv whose compiler links against `glibc` and whose
  # libgcc/libstdc++ are rebuilt against that glibc. Done in three steps to
  # avoid infinite recursion.
  mkStdenvWith =
    glibc: gcc: baseStdenv:
    let
      glibcBintools = wrapBintoolsWith {
        bintools = bintools-unwrapped;
        libc = glibc;
      };

      # Step 1: an stdenv with the unmodified gcc plus the new glibc.
      bootstrapStdenv = overrideCC baseStdenv (wrapCCWith {
        cc = gcc;
        bintools = glibcBintools;
      });

      # Step 2: rebuild gcc with that stdenv so its runtime libraries
      # match the new glibc.
      rebuiltGcc = gcc.override { stdenv = bootstrapStdenv; };

      # Step 3: wrap the rebuilt gcc together with the new glibc.
      finalCC = wrapCCWith {
        cc = rebuiltGcc;
        bintools = glibcBintools;
        libcxx = rebuiltGcc.lib;
      };
    in
    overrideCC baseStdenv finalCC;

  # Unwrapped gcc from the CUDA backend stdenv, or gcc 12 without CUDA.
  hostGcc = (if cudaSupport then cudaPackages.backendStdenv else gcc12Stdenv).cc.cc;
in
mkStdenvWith oldGlibc hostGcc stdenv
76 |
--------------------------------------------------------------------------------
/scripts/gen_variants_markdown.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env nix-shell
2 | #!nix-shell -i python3 -p python3
3 | import json
4 | from pathlib import Path
5 |
# Display names for the compute-framework keys found in build-variants.json.
_PLATFORM_NAMES = dict(
    cuda="CUDA",
    rocm="ROCm",
)

# Text written at the top of docs/build-variants.md.
HEADER = """# Build variants

A kernel can be compliant for a specific compute framework (e.g. CUDA) or
architecture (e.g. x86_64). For compliance with a compute framework and
architecture combination, all the build variants listed below must be
available. This list will be updated as new PyTorch versions are released.\n
"""

# Closing section written after all per-platform variant lists.
FOOTER = """## Universal

Kernels that are in pure Python (e.g. Triton kernels) only need to provide
a single build variant:

- `torch-universal`
"""
26 |
27 |
def json_to_markdown(project_root=None):
    """Render ``build-variants.json`` into ``docs/build-variants.md``.

    The output is HEADER, then one ``## <platform> <arch>`` section per
    platform listing its variants as bullet points, then FOOTER.

    Args:
        project_root: Repository root that contains ``build-variants.json``
            and the ``docs`` directory. Defaults to the parent of the
            directory holding this script.
    """
    root = Path(project_root) if project_root is not None else Path(__file__).parent.parent

    # Explicit UTF-8 so the result does not depend on the locale's encoding.
    with open(root / "build-variants.json", "r", encoding="utf-8") as f:
        data = json.load(f)

    with open(root / "docs" / "build-variants.md", "w", encoding="utf-8") as f:
        f.write(HEADER)
        for arch, platforms in data.items():
            for platform, variants in platforms.items():
                # Platforms without a display name use their raw key instead
                # of aborting the whole generation with a KeyError.
                name = _PLATFORM_NAMES.get(platform, platform)
                f.write(f"## {name} {arch}\n\n")

                for variant in variants:
                    f.write(f"- `{variant}`\n")

                f.write("\n")
        f.write(FOOTER)


if __name__ == "__main__":
    json_to_markdown()
49 |
--------------------------------------------------------------------------------
/scripts/init-kernel.py:
--------------------------------------------------------------------------------
1 | # This script creates the necessary files for a new kernel example in the specified directory.
2 | #
3 | # Example Usage:
4 | # $ uv run scripts/init-kernel.py relu
5 | #
6 | # Created directory: relu
7 | #
8 | # relu/
# ├── relu_cuda/
# │   └── relu.cu
# ├── relu_metal/
# │   └── relu.mm
11 | # ├── tests/
12 | # │ ├── __init__.py
13 | # │ └── test_relu.py
14 | # ├── torch-ext/
15 | # │ ├── relu/
16 | # │ │ └── __init__.py
17 | # │ ├── torch_binding.cpp
18 | # │ └── torch_binding.h
19 | # ├── build.toml
20 | # └── flake.nix
21 | #
22 | # ✓ Success! All files for the ReLU example have been created successfully.
23 | #
24 | # Next steps:
25 | # 1. Build the kernel: cd relu && git add . && nix develop -L
26 | # 2. Run the tests: pytest -vv tests/
27 |
28 | import os
29 | import argparse
30 | import pathlib
31 |
32 |
class Colors:
    """ANSI escape sequences used to style terminal output."""

    # Foreground colors.
    HEADER = "\x1b[95m"
    BLUE = "\x1b[94m"
    CYAN = "\x1b[96m"
    GREEN = "\x1b[92m"
    YELLOW = "\x1b[93m"
    RED = "\x1b[91m"
    GREY = "\x1b[90m"

    # Text attributes and reset.
    BOLD = "\x1b[1m"
    UNDERLINE = "\x1b[4m"
    ENDC = "\x1b[0m"
45 |
def create_file_with_content(file_path: str, content: str):
    """Write ``content`` to ``file_path``, creating parent directories first.

    An existing file at ``file_path`` is overwritten.
    """
    directory = os.path.dirname(file_path)
    if directory:
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(directory, exist_ok=True)

    with open(file_path, "w") as f:
        f.write(content)
54 |
55 |
def print_tree(directory: str, prefix: str = ""):
    """Print a tree view of ``directory`` to stdout.

    Sub-directories are listed (recursively) before files, each group sorted
    by name. Entries that are neither directories nor regular files (per
    ``os.path.isdir``/``os.path.isfile``) are not shown.

    Args:
        directory: Directory whose contents are printed.
        prefix: Drawing prefix accumulated from parent levels; callers
            normally leave it empty.
    """
    entries = sorted(os.listdir(directory))
    dirs = [e for e in entries if os.path.isdir(os.path.join(directory, e))]
    files = [e for e in entries if os.path.isfile(os.path.join(directory, e))]

    for i, dirname in enumerate(dirs):
        # A directory is drawn as the last branch only when no files follow it.
        is_last_dir = i == len(dirs) - 1 and not files
        connector = "└── " if is_last_dir else "├── "
        print(
            f" {prefix}{connector}{Colors.BOLD}{Colors.BLUE}{dirname}/{Colors.ENDC}"
        )

        # The continuation prefix is as wide as the connectors ("├── " /
        # "└── ") so that children line up under their parent's name.
        next_prefix = prefix + ("    " if is_last_dir else "│   ")
        print_tree(os.path.join(directory, dirname), next_prefix)

    for i, filename in enumerate(files):
        connector = "└── " if i == len(files) - 1 else "├── "
        print(f" {prefix}{connector}{filename}{Colors.ENDC}")
86 |
87 |
def _point_flake_at_local_builder(content: str) -> str:
    """Rewrite the ``kernel-builder.url = ...;`` binding to ``"path:../"``.

    Content without such a binding (or without a terminating ``;``) is
    returned unchanged instead of being corrupted.
    """
    start = content.find("kernel-builder.url =")
    if start == -1:
        return content
    end = content.find(";", start)
    if end == -1:
        return content
    return content[:start] + 'kernel-builder.url = "path:../"' + content[end:]


def _copy_relu_example(relu_dir: pathlib.Path, target_dir: str):
    """Copy every file under ``relu_dir`` into ``target_dir``, keeping the layout."""
    for root, _, files in os.walk(relu_dir):
        for file in files:
            file_path = os.path.join(root, file)
            with open(file_path, "r") as f:
                content = f.read()

            # Make the copied flake use the local kernel-builder checkout.
            if file_path.endswith("flake.nix"):
                content = _point_flake_at_local_builder(content)

            # Map the example-relative path onto the target directory.
            relative_path = os.path.relpath(file_path, relu_dir)
            create_file_with_content(os.path.join(target_dir, relative_path), content)


def _print_next_steps(target_dir: str):
    """Print the build and test instructions for the new kernel."""
    print(f"\n{Colors.CYAN}{Colors.BOLD}Next steps:{Colors.ENDC}")

    commands = [
        "nix run nixpkgs#cachix -- use huggingface",
        f"cd {target_dir}",
        "git add .",
        "nix develop -L",
    ]

    for index, command in enumerate(commands, start=1):
        print(
            f" {Colors.YELLOW}{index}.{Colors.ENDC} {Colors.BOLD}{command}{Colors.ENDC}"
        )

    print(
        f"\none line build:\n{Colors.GREY}{Colors.BOLD}{' && '.join(commands)}{Colors.ENDC}{Colors.ENDC}"
    )

    print(f"\n{Colors.CYAN}{Colors.BOLD}Run the tests{Colors.ENDC}")
    print(
        f" {Colors.YELLOW}1.{Colors.ENDC} {Colors.BOLD}pytest -vv tests/{Colors.ENDC}"
    )

    print("")


def main():
    """Scaffold a new kernel by copying ``examples/relu`` into a target directory.

    The target directory is taken from the command line and created if it
    does not exist; files that already exist there are overwritten.
    """
    # Repository root: this script lives in <root>/scripts/.
    repo_root = pathlib.Path(__file__).parent.resolve().parent.resolve()

    parser = argparse.ArgumentParser(
        description="Create ReLU example files in the specified directory"
    )
    parser.add_argument(
        "target_dir", help="Target directory where files will be created"
    )
    args = parser.parse_args()
    target_dir = args.target_dir

    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
        print(
            f"\n{Colors.CYAN}{Colors.BOLD}Created directory: {Colors.BOLD}{target_dir}{Colors.ENDC}\n"
        )
    else:
        print(
            f"\n{Colors.CYAN}{Colors.BOLD}Directory already exists: {Colors.BOLD}{target_dir}{Colors.ENDC}\n"
        )

    _copy_relu_example(repo_root / "examples" / "relu", target_dir)

    print(f" {Colors.BOLD}{target_dir}/{Colors.ENDC}")
    print_tree(target_dir)

    print(
        f"\n{Colors.GREEN}{Colors.BOLD}✓ Success!{Colors.ENDC} All files for the ReLU example have been created successfully."
    )

    _print_next_steps(target_dir)


if __name__ == "__main__":
    main()
171 |
--------------------------------------------------------------------------------
/tests/Dockerfile.test-kernel:
--------------------------------------------------------------------------------
# syntax=docker/dockerfile:1.4
ARG PYTHON_VERSION=3.10
# Ideally we'd test with 11.8, but the GELU kernel is subtly off.
ARG CUDA_VERSION=12.1.0
ARG UBUNTU_VERSION=18.04
ARG TORCH_VERSION=2.5.0

FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} AS base

# Set environment variables
ENV DEBIAN_FRONTEND=noninteractive \
    PYTHONUNBUFFERED=1 \
    PATH="/root/.local/bin:/root/.cargo/bin:${PATH}" \
    NVIDIA_VISIBLE_DEVICES=all \
    NVIDIA_DRIVER_CAPABILITIES=compute,utility

# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    curl \
    python3 \
    python3-pip \
    && rm -rf /var/lib/apt/lists/*

# Install uv package manager
RUN curl -LsSf https://astral.sh/uv/install.sh | sh

# Set working directory
WORKDIR /app

# Need to re-declare ARG after FROM for use in RUN
ARG CUDA_VERSION
ARG TORCH_VERSION
ARG PYTHON_VERSION

RUN echo "Building with CUDA_VERSION=${CUDA_VERSION}, TORCH_VERSION=${TORCH_VERSION}, PYTHON_VERSION=${PYTHON_VERSION}"

# Initialize uv and create virtual env
RUN uv init --app kernel-test --python "${PYTHON_VERSION}"

# Move into the app
WORKDIR /app/kernel-test

# Install PyTorch with the appropriate CUDA version

# NOTE: `markupsafe` must be installed first to avoid a conflict with the torch package.
# See: https://github.com/astral-sh/uv/issues/9647

RUN CUDA_MAJOR_MINOR=$(echo ${CUDA_VERSION} | cut -d'.' -f1,2) && \
    case ${CUDA_MAJOR_MINOR} in \
      "11.8") CUDA_TAG="cu118" ;; \
      "12.1") CUDA_TAG="cu121" ;; \
      "12.2") CUDA_TAG="cu122" ;; \
      "12.4") CUDA_TAG="cu124" ;; \
      *) CUDA_TAG="" ;; \
    esac && \
    if [ -n "${CUDA_TAG}" ]; then \
      echo "Installing PyTorch ${TORCH_VERSION} with CUDA ${CUDA_TAG}" && \
      uv add markupsafe --default-index "https://pypi.org/simple" && \
      uv add "torch==${TORCH_VERSION}" --index-url "https://download.pytorch.org/whl/${CUDA_TAG}"; \
    else \
      echo "Installing PyTorch ${TORCH_VERSION} without CUDA-specific index" && \
      uv add "torch==${TORCH_VERSION}"; \
    fi

# add pytest for runtime tests
RUN uv add numpy pytest

# Copy kernels and tests. Both test directories sit side by side in the
# working directory, where the CMD below expects them.
COPY activation-kernel ./activation-kernel
COPY cutlass-gemm-kernel ./cutlass-gemm-kernel
COPY silu-and-mul-universal-kernel ./silu-and-mul-universal-kernel
COPY examples/activation/tests ./activation_tests
COPY examples/cutlass-gemm/tests ./cutlass_gemm_tests

# Run tests
ENV PYTHONPATH="activation-kernel:cutlass-gemm-kernel:silu-and-mul-universal-kernel:$PYTHONPATH"
# Only the last CMD of a Dockerfile takes effect, and `sh -c` executes only
# its first argument, so the test run and the import check are combined into
# a single command string. For the universal kernel we only care about
# importing; the kernel is trivial.
CMD ["/bin/sh", "-c", ".venv/bin/pytest activation_tests cutlass_gemm_tests && .venv/bin/python -c 'import silu_and_mul_universal'"]
81 |
--------------------------------------------------------------------------------
/versions.nix:
--------------------------------------------------------------------------------
{ lib }:

rec {
  torchVersions = [
    {
      torchVersion = "2.6";
      cudaVersion = "11.8";
      cxx11Abi = false;
      upstreamVariant = true;
    }
    {
      torchVersion = "2.6";
      cudaVersion = "11.8";
      cxx11Abi = true;
      upstreamVariant = true;
    }
    {
      torchVersion = "2.6";
      cudaVersion = "12.4";
      cxx11Abi = false;
      upstreamVariant = true;
    }
    {
      torchVersion = "2.6";
      cudaVersion = "12.4";
      cxx11Abi = true;
      upstreamVariant = true;
    }
    {
      torchVersion = "2.6";
      cudaVersion = "12.6";
      cxx11Abi = false;
      upstreamVariant = true;
    }
    {
      torchVersion = "2.6";
      cudaVersion = "12.6";
      cxx11Abi = true;
      upstreamVariant = true;
    }
    {
      torchVersion = "2.6";
      rocmVersion = "6.2.4";
      cxx11Abi = true;
      upstreamVariant = true;
    }

    {
      torchVersion = "2.7";
      cudaVersion = "11.8";
      cxx11Abi = true;
      upstreamVariant = true;
    }
    {
      torchVersion = "2.7";
      cudaVersion = "12.6";
      cxx11Abi = true;
      upstreamVariant = true;
    }
    {
      torchVersion = "2.7";
      cudaVersion = "12.8";
      cxx11Abi = true;
      upstreamVariant = true;
    }
    {
      torchVersion = "2.7";
      rocmVersion = "6.3.4";
      cxx11Abi = true;
      upstreamVariant = true;
    }
    {
      torchVersion = "2.7";
      cxx11Abi = true;
      metal = true;
      # Set to false for now, needs more testing.
      upstreamVariant = false;
    }

    # Non-standard versions; not included in bundle builds.
    {
      torchVersion = "2.7";
      cudaVersion = "12.9";
      cxx11Abi = true;
    }
  ];

  # CUDA versions of all configurations that target CUDA (may contain duplicates).
  cudaVersions =
    let
      withCuda = builtins.filter (torchVersion: torchVersion ? cudaVersion) torchVersions;
    in
    builtins.map (torchVersion: torchVersion.cudaVersion) withCuda;

  # ROCm versions of all configurations that target ROCm (may contain duplicates).
  rocmVersions =
    let
      withRocm = builtins.filter (torchVersion: torchVersion ? rocmVersion) torchVersions;
    in
    builtins.map (torchVersion: torchVersion.rocmVersion) withRocm;

  # Only configurations that actually specify a CUDA version are CUDA builds;
  # without this guard every ROCm/Metal configuration counted as CUDA-supported
  # on x86_64-linux. Upstream only builds aarch64 for CUDA >= 12.6.
  isCudaSupported =
    system: torchVersion:
    torchVersion ? cudaVersion
    && (
      system == "x86_64-linux"
      || (system == "aarch64-linux" && lib.strings.versionAtLeast torchVersion.cudaVersion "12.6")
    );

  isMetalSupported = system: torchVersion: system == "aarch64-darwin" && torchVersion ? metal;

  # ROCm only builds on x86_64.
  isRocmSupported = system: torchVersion: system == "x86_64-linux" && torchVersion ? rocmVersion;

  isSupported =
    system: torchVersion:
    (isCudaSupported system torchVersion)
    || (isMetalSupported system torchVersion)
    || (isRocmSupported system torchVersion);

  # Name of the compute framework a build configuration targets.
  computeFramework =
    buildConfig:
    if buildConfig ? cudaVersion then
      "cuda"
    else if buildConfig ? metal then
      "metal"
    else if buildConfig ? rocmVersion then
      "rocm"
    else
      throw "Could not find compute framework: no CUDA or ROCm version specified and Metal is not enabled";

  # All build configurations supported by Torch.
  buildConfigs =
    system:
    let
      supported = builtins.filter (isSupported system) torchVersions;
    in
    map (version: version // { gpu = computeFramework version; }) supported;

  # Upstream build variants.
  buildVariants =
    system:
    let
      inherit (import ./lib/version-utils.nix { inherit lib; }) abiString flattenVersion;
      computeString =
        buildConfig:
        if buildConfig.gpu == "cuda" then
          "cu${flattenVersion (lib.versions.majorMinor buildConfig.cudaVersion)}"
        else if buildConfig.gpu == "rocm" then
          "rocm${flattenVersion (lib.versions.majorMinor buildConfig.rocmVersion)}"
        else
          throw "Unknown compute framework: ${buildConfig.gpu}";
      buildName =
        buildConfig:
        "torch${flattenVersion buildConfig.torchVersion}-${abiString buildConfig.cxx11Abi}-${computeString buildConfig}-${system}";
      filterMap = f: xs: builtins.filter (x: x != null) (builtins.map f xs);
    in
    {
      ${system} = lib.zipAttrs (
        filterMap (
          buildConfig:
          if buildConfig.upstreamVariant or false then
            {
              ${buildConfig.gpu} = buildName buildConfig;
            }
          else
            null
        ) (buildConfigs system)
      );
    };
}
170 |
--------------------------------------------------------------------------------