├── .bazelignore ├── .bazelrc ├── .bazelversion ├── .bcr ├── config.yml ├── metadata.template.json ├── presubmit.yml └── source.template.json ├── .gitattributes ├── .github ├── actions │ ├── Set-VSEnv.ps1 │ └── set-build-env │ │ └── action.yaml ├── release_notes.template └── workflows │ ├── build-tests.yaml │ ├── ci.bazelrc │ ├── github-pages.yaml │ ├── integration-tests.yaml │ ├── pre-commit.yaml │ ├── release.yaml │ └── utilities-tests.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── BUILD.bazel ├── CODEOWNERS ├── LICENSE ├── MODULE.bazel ├── README.md ├── WORKSPACE.bazel ├── cuda ├── BUILD.bazel ├── defs.bzl ├── dummy │ ├── BUILD.bazel │ ├── dummy.cpp │ └── link.stub ├── extensions.bzl ├── private │ ├── BUILD.bazel │ ├── action_names.bzl │ ├── actions │ │ ├── compile.bzl │ │ └── dlink.bzl │ ├── artifact_categories.bzl │ ├── compat.bzl │ ├── cuda_helper.bzl │ ├── defs.bzl │ ├── macros │ │ ├── cuda_binary.bzl │ │ └── cuda_test.bzl │ ├── os_helpers.bzl │ ├── providers.bzl │ ├── repositories.bzl │ ├── rules │ │ ├── common.bzl │ │ ├── cuda_library.bzl │ │ ├── cuda_objects.bzl │ │ ├── cuda_toolkit_info.bzl │ │ └── flags.bzl │ ├── template_helper.bzl │ ├── templates │ │ ├── BUILD.cccl │ │ ├── BUILD.clang_compiler_deps │ │ ├── BUILD.cublas │ │ ├── BUILD.cuda_build_setting │ │ ├── BUILD.cuda_disabled │ │ ├── BUILD.cuda_headers │ │ ├── BUILD.cuda_shared │ │ ├── BUILD.cudart │ │ ├── BUILD.cufft │ │ ├── BUILD.cufile │ │ ├── BUILD.cupti │ │ ├── BUILD.curand │ │ ├── BUILD.cusolver │ │ ├── BUILD.cusparse │ │ ├── BUILD.npp │ │ ├── BUILD.nvcc │ │ ├── BUILD.nvidia_fs │ │ ├── BUILD.nvjitlink │ │ ├── BUILD.nvjpeg │ │ ├── BUILD.nvml │ │ ├── BUILD.nvrtc │ │ ├── BUILD.nvtx │ │ ├── BUILD.redist_json │ │ ├── BUILD.toolchain_clang │ │ ├── BUILD.toolchain_disabled │ │ ├── BUILD.toolchain_nvcc │ │ ├── BUILD.toolchain_nvcc_msvc │ │ ├── README.md │ │ ├── defs.bzl.tpl │ │ ├── redist.bzl.tpl │ │ └── registry.bzl │ ├── toolchain.bzl │ ├── toolchain_config_lib.bzl │ └── toolchain_configs │ │ 
├── clang.bzl │ │ ├── disabled.bzl │ │ ├── nvcc.bzl │ │ ├── nvcc_msvc.bzl │ │ └── utils.bzl └── repositories.bzl ├── docs ├── .bazelversion ├── .gitignore ├── BUILD.bazel ├── MODULE.bazel ├── WORKSPACE.bazel ├── WORKSPACE.bzlmod ├── build-docs.sh ├── developer_docs.bzl ├── mkdocs.yaml ├── mkdocs │ └── stylesheets │ │ └── extra.css ├── providers_docs.bzl ├── requirements.txt ├── toolchain_config_docs.bzl ├── user_docs.bzl └── versioning.py ├── examples ├── .bazelrc ├── MODULE.bazel ├── WORKSPACE.bazel ├── WORKSPACE.bzlmod ├── basic │ ├── BUILD.bazel │ ├── kernel.cu │ ├── kernel.h │ └── main.cpp ├── basic_macros │ ├── BUILD.bazel │ └── main.cu ├── cublas │ ├── BUILD.bazel │ └── cublas.cpp ├── if_cuda │ ├── BUILD.bazel │ ├── README.md │ ├── kernel.cu │ ├── kernel.h │ └── main.cpp ├── nccl │ ├── BUILD.bazel │ ├── nccl-tests-clang.patch │ ├── nccl-tests.BUILD │ ├── nccl-tests.bzl │ ├── nccl.BUILD │ └── nccl.bzl ├── rdc │ ├── BUILD.bazel │ ├── a.cu │ ├── b.cu │ └── b.cuh └── thrust │ ├── BUILD.bazel │ └── thrust.cu ├── renovate.json └── tests ├── flag ├── BUILD.bazel └── flag_validation_test.bzl ├── integration ├── BUILD.to_symlink ├── MODULE.bazel ├── WORKSPACE.bazel ├── test_all.sh ├── toolchain_components │ ├── BUILD.bazel │ ├── MODULE.bazel │ ├── WORKSPACE.bzlmod │ └── WORKSPACK.bazel ├── toolchain_none │ ├── BUILD.bazel │ └── MODULE.bazel ├── toolchain_redist_json │ ├── BUILD.bazel │ ├── MODULE.bazel │ ├── WORKSPACE.bazel │ └── WORKSPACE.bzlmod ├── toolchain_root │ ├── BUILD.bazel │ └── MODULE.bazel └── toolchain_rules │ ├── BUILD.bazel │ └── MODULE.bazel ├── toolchain_config_lib ├── BUILD.bazel └── toolchain_config_lib_test.bzl └── utils ├── BUILD.bazel └── utils_test.bzl /.bazelignore: -------------------------------------------------------------------------------- 1 | docs 2 | examples 3 | -------------------------------------------------------------------------------- /.bazelrc: -------------------------------------------------------------------------------- 1 | 
common --announce_rc 2 | 3 | # Convenient flag shortcuts. 4 | build --flag_alias=enable_cuda=//cuda:enable 5 | build --flag_alias=cuda_archs=//cuda:archs 6 | build --flag_alias=cuda_compiler=//cuda:compiler 7 | build --flag_alias=cuda_copts=//cuda:copts 8 | build --flag_alias=cuda_host_copts=//cuda:host_copts 9 | build --flag_alias=cuda_runtime=//cuda:runtime 10 | 11 | build --enable_cuda=True 12 | 13 | # Use --config=clang to build with clang instead of gcc and nvcc. 14 | build:clang --repo_env=CC=clang 15 | build:clang --//cuda:compiler=clang 16 | 17 | # https://github.com/bazel-contrib/rules_cuda/issues/1 18 | # build --ui_event_filters=-INFO 19 | -------------------------------------------------------------------------------- /.bazelversion: -------------------------------------------------------------------------------- 1 | 8.1.0 2 | -------------------------------------------------------------------------------- /.bcr/config.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bazel-contrib/rules_cuda/3f72f484a8ea5969c81a857a6785ebf0ede02c0c/.bcr/config.yml -------------------------------------------------------------------------------- /.bcr/metadata.template.json: -------------------------------------------------------------------------------- 1 | { 2 | "homepage": "https://github.com/bazel-contrib/rules_cuda", 3 | "maintainers": [ 4 | { 5 | "email": "james.sharpe@zenotech.com", 6 | "github": "jsharpe", 7 | "name": "James Sharpe" 8 | } 9 | ], 10 | "repository": ["github:bazel-contrib/rules_cuda"], 11 | "versions": [], 12 | "yanked_versions": {} 13 | } 14 | -------------------------------------------------------------------------------- /.bcr/presubmit.yml: -------------------------------------------------------------------------------- 1 | matrix: 2 | bazel: 3 | - 6.x 4 | - 7.x 5 | tasks: 6 | verify_targets_linux: 7 | name: Verify build targets 8 | bazel: ${{ bazel }} 9 | platform: 
ubuntu2004 10 | -------------------------------------------------------------------------------- /.bcr/source.template.json: -------------------------------------------------------------------------------- 1 | { 2 | "integrity": "", 3 | "strip_prefix": "{REPO}-{TAG}", 4 | "url": "https://github.com/{OWNER}/{REPO}/releases/download/{TAG}/{REPO}-{TAG}.tar.gz" 5 | } 6 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | * text eol=lf 2 | -------------------------------------------------------------------------------- /.github/actions/Set-VSEnv.ps1: -------------------------------------------------------------------------------- 1 | param ( 2 | [parameter(Mandatory = $false)] 3 | [ValidateSet(2022, 2019, 2017)][int]$Version = 2019, 4 | 5 | [parameter(Mandatory = $false)] 6 | [ValidateSet("all", "x86", "x64")][String]$Arch = "x64" 7 | ) 8 | 9 | function Set-EnvFromCmdSet { 10 | [CmdletBinding()] 11 | param( 12 | [Parameter(ValueFromPipeline)] 13 | [string]$CmdSetResult 14 | ) 15 | process { 16 | if ($CmdSetResult -Match "=") { 17 | $i = $CmdSetResult.IndexOf("=") 18 | $k = $CmdSetResult.Substring(0, $i) 19 | $v = $CmdSetResult.Substring($i + 1) 20 | Set-Item -Force -Path "Env:\$k" -Value "$v" 21 | } 22 | } 23 | } 24 | 25 | $vs_where = 'C:\Program Files (x86)\Microsoft Visual Studio\Installer\vswhere.exe' 26 | 27 | $version_range = switch ($Version) { 28 | 2022 { '[17,18)' } 29 | 2019 { '[16,17)' } 30 | 2017 { '[15,16)' } 31 | } 32 | $info = &$vs_where -version $version_range -format json | ConvertFrom-Json 33 | $vs_env = @{ 34 | install_path = $info ? 
$info[0].installationPath : $null 35 | all = 'Common7\Tools\VsDevCmd.bat' 36 | x64 = 'VC\Auxiliary\Build\vcvars64.bat' 37 | x86 = 'VC\Auxiliary\Build\vcvars32.bat' 38 | } 39 | 40 | if ( $null -eq $vs_env.install_path) { 41 | Write-Host -ForegroundColor Red "Visual Studio $Version is not installed." 42 | return 43 | } 44 | 45 | $path = Join-Path $vs_env.install_path $vs_env.$Arch 46 | 47 | C:/Windows/System32/cmd.exe /c "`"$path`" & set" | Set-EnvFromCmdSet 48 | Set-Item -Force -Path "Env:\BAZEL_VC" -Value "$env:VCINSTALLDIR" 49 | Write-Host -ForegroundColor Green "Visual Studio $Version $Arch Command Prompt variables set." 50 | -------------------------------------------------------------------------------- /.github/actions/set-build-env/action.yaml: -------------------------------------------------------------------------------- 1 | name: "Setup Build Environment" 2 | description: "" 3 | 4 | inputs: 5 | os: 6 | description: "matrix.cases.os" 7 | required: true 8 | cuda-version: 9 | description: "matrix.cases.cuda-version" 10 | required: true 11 | source: 12 | description: "matrix.cases.source" 13 | required: true 14 | toolchain: 15 | description: "matrix.cases.toolchain" 16 | required: false 17 | toolchain-version: 18 | description: "matrix.cases.toolchain-version" 19 | required: false 20 | 21 | runs: 22 | using: "composite" 23 | steps: 24 | - name: Install CUDA (NVIDIA, Linux) 25 | uses: Jimver/cuda-toolkit@v0.2.22 26 | if: ${{ !startsWith(inputs.os, 'windows') && inputs.source == 'nvidia' }} 27 | with: 28 | cuda: ${{ inputs.cuda-version }} 29 | sub-packages: '["nvcc", "cudart-dev"]' 30 | method: network 31 | - name: Show bin, include, lib (NVIDIA, Linux) 32 | if: ${{ !startsWith(inputs.os, 'windows') && inputs.source == 'nvidia' }} 33 | shell: bash 34 | run: | 35 | tree ${CUDA_PATH}/bin 36 | tree ${CUDA_PATH}/include 37 | tree ${CUDA_PATH}/lib64 38 | - name: Install LLVM ${{ inputs.toolchain-version }} 39 | if: ${{ !startsWith(inputs.os, 'windows') && 
inputs.toolchain == 'llvm' }} 40 | shell: bash 41 | run: | 42 | wget https://apt.llvm.org/llvm.sh 43 | chmod +x llvm.sh 44 | sudo ./llvm.sh ${{ inputs.toolchain-version }} 45 | sudo ln -sf /usr/bin/clang-${{ inputs.toolchain-version }} /usr/bin/clang 46 | clang --version 47 | - name: Install CURAND For LLVM 48 | uses: Jimver/cuda-toolkit@v0.2.22 49 | if: ${{ !startsWith(inputs.os, 'windows') && inputs.toolchain == 'llvm' }} 50 | with: 51 | cuda: ${{ inputs.cuda-version }} 52 | sub-packages: '["nvcc", "cudart-dev"]' # avoid full cuda install 53 | non-cuda-sub-packages: '["libcurand-dev"]' 54 | method: network 55 | - name: Install CUDA (Ubuntu) 56 | if: ${{ !startsWith(inputs.os, 'windows') && inputs.source == 'ubuntu' }} 57 | shell: bash 58 | run: | 59 | sudo apt-get update 60 | sudo apt-get install -y nvidia-cuda-dev=${{ inputs.cuda-version }} nvidia-cuda-toolkit=${{ inputs.cuda-version }} gcc-9 g++-9 61 | export CC=gcc-9 62 | export CXX=g++-9 63 | echo "CC=gcc-9" >> $GITHUB_ENV 64 | echo "CXX=g++-9" >> $GITHUB_ENV 65 | 66 | - name: Install CUDA (Windows) 67 | uses: Jimver/cuda-toolkit@v0.2.22 68 | if: ${{ startsWith(inputs.os, 'windows') }} 69 | with: 70 | cuda: ${{ inputs.cuda-version }} 71 | sub-packages: '["nvcc", "cudart"]' 72 | method: network 73 | - name: Show bin, include, lib64 (Windows) 74 | if: ${{ startsWith(inputs.os, 'windows') }} 75 | shell: pwsh 76 | run: | 77 | tree /F $env:CUDA_PATH/bin 78 | tree /F $env:CUDA_PATH/include 79 | tree /F $env:CUDA_PATH/lib/x64 80 | - name: Set Visual Studio Environment (Windows) 81 | if: ${{ startsWith(inputs.os, 'windows') }} 82 | shell: pwsh 83 | run: .github/actions/Set-VSEnv.ps1 2019 84 | -------------------------------------------------------------------------------- /.github/release_notes.template: -------------------------------------------------------------------------------- 1 | ## `WORKSPACE` code 2 | ```starlark 3 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 4 | http_archive( 5 | 
name = "rules_cuda", 6 | sha256 = "{archive_sha256}", 7 | strip_prefix = "rules_cuda-{version}", 8 | urls = ["https://github.com/bazel-contrib/rules_cuda/releases/download/{version}/rules_cuda-{version}.tar.gz"], 9 | ) 10 | 11 | load("@rules_cuda//cuda:repositories.bzl", "rules_cuda_dependencies", "rules_cuda_toolchains") 12 | rules_cuda_dependencies() 13 | rules_cuda_toolchains(register_toolchains = True) 14 | ``` 15 | -------------------------------------------------------------------------------- /.github/workflows/build-tests.yaml: -------------------------------------------------------------------------------- 1 | name: Test Example Build 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | 9 | jobs: 10 | test: 11 | name: "Test Examples Build (CUDA ${{ matrix.cases.cuda-version }} on ${{ matrix.cases.os }})" 12 | runs-on: ${{ matrix.cases.os }} 13 | timeout-minutes: 60 14 | strategy: 15 | matrix: 16 | cases: 17 | - { os: "ubuntu-22.04", cuda-version: "11.7.0", source: "nvidia" } 18 | - { os: "ubuntu-22.04", cuda-version: "11.8.0", source: "nvidia" } 19 | - { 20 | os: "ubuntu-22.04", 21 | cuda-version: "11.7.0", 22 | source: "nvidia", 23 | toolchain: "llvm", 24 | toolchain-version: "16", 25 | } 26 | - { 27 | os: "ubuntu-22.04", 28 | cuda-version: "11.5.1-1ubuntu1", 29 | source: "ubuntu", 30 | } 31 | - { os: "windows-2019", cuda-version: "10.1.243", source: "nvidia" } 32 | - { os: "windows-2019", cuda-version: "11.6.2", source: "nvidia" } 33 | steps: 34 | - uses: actions/checkout@v4 35 | 36 | - uses: bazelbuild/setup-bazelisk@v3 37 | - name: Mount bazel cache 38 | if: ${{ !startsWith(matrix.cases.os, 'windows') }} 39 | uses: actions/cache@v4 40 | with: 41 | path: ~/.cache/bazel 42 | key: bazel-${{ matrix.cases.os }}-cuda-${{ matrix.cases.cuda-version }}-${{ hashFiles('.bazelversion') }} 43 | 44 | - name: Setup build environment 45 | uses: ./.github/actions/set-build-env 46 | with: 47 | os: ${{ matrix.cases.os }} 48 | cuda-version: ${{ 
matrix.cases.cuda-version }} 49 | source: ${{ matrix.cases.source }} 50 | toolchain: ${{ matrix.cases.toolchain }} 51 | toolchain-version: ${{ matrix.cases.toolchain-version }} 52 | 53 | - name: Bazel build config for LLVM 54 | if: ${{ !startsWith(matrix.cases.os, 'windows') && matrix.cases.toolchain == 'llvm' }} 55 | run: | 56 | echo "build --config=clang" > $HOME/.bazelrc 57 | echo "build:clang --@rules_cuda//cuda:archs=sm_80" >> $HOME/.bazelrc 58 | 59 | # Check https://bazel.build/release#support-matrix, manually unroll the strategy matrix to avoid exploding 60 | # the combinations. 61 | 62 | # Use Bazel with version specified in .bazelversion 63 | - run: echo "USE_BAZEL_VERSION=$(cat .bazelversion)" >> $GITHUB_ENV 64 | if: ${{ !startsWith(matrix.cases.os, 'windows') }} 65 | - run: echo "USE_BAZEL_VERSION=$(cat .bazelversion)" >> $env:GITHUB_ENV 66 | if: ${{ startsWith(matrix.cases.os, 'windows') }} 67 | 68 | # out of @examples repo build requires WORKSPACE-based external dependency system 69 | - run: bazelisk build --jobs=1 @rules_cuda_examples//basic:all 70 | - run: bazelisk build --jobs=1 @rules_cuda_examples//rdc:all 71 | - run: bazelisk build --jobs=1 @rules_cuda_examples//if_cuda:main 72 | - run: bazelisk build --jobs=1 @rules_cuda_examples//if_cuda:main --enable_cuda=False 73 | # in @examples repo build, bzlmod is enabled by default since Bazel 7 74 | - run: cd examples && bazelisk build --jobs=1 //basic:all 75 | - run: cd examples && bazelisk build --jobs=1 //rdc:all 76 | - run: cd examples && bazelisk build --jobs=1 //if_cuda:main 77 | - run: cd examples && bazelisk build --jobs=1 //if_cuda:main --enable_cuda=False 78 | - run: bazelisk shutdown 79 | # run some repo integration tests 80 | - run: cd tests/integration && ./test_all.sh 81 | 82 | # Use Bazel 7 83 | - run: echo "USE_BAZEL_VERSION=7.5.0" >> $GITHUB_ENV 84 | if: ${{ !startsWith(matrix.cases.os, 'windows') }} 85 | - run: echo "USE_BAZEL_VERSION=7.5.0" >> $env:GITHUB_ENV 86 | if: ${{ 
startsWith(matrix.cases.os, 'windows') }} 87 | 88 | # out of @examples repo build requires WORKSPACE-based external dependency system 89 | - run: bazelisk build --jobs=1 --noenable_bzlmod @rules_cuda_examples//basic:all 90 | - run: bazelisk build --jobs=1 --noenable_bzlmod @rules_cuda_examples//rdc:all 91 | - run: bazelisk build --jobs=1 --noenable_bzlmod @rules_cuda_examples//if_cuda:main 92 | - run: bazelisk build --jobs=1 --noenable_bzlmod @rules_cuda_examples//if_cuda:main --enable_cuda=False 93 | # in @examples repo build, bzlmod is enabled by default since Bazel 7 94 | - run: cd examples && bazelisk build --jobs=1 //basic:all 95 | - run: cd examples && bazelisk build --jobs=1 //rdc:all 96 | - run: cd examples && bazelisk build --jobs=1 //if_cuda:main 97 | - run: cd examples && bazelisk build --jobs=1 //if_cuda:main --enable_cuda=False 98 | - run: bazelisk shutdown 99 | # run some repo integration tests 100 | - run: cd tests/integration && ./test_all.sh 101 | 102 | # Use Bazel 6 103 | - run: echo "USE_BAZEL_VERSION=6.4.0" >> $GITHUB_ENV 104 | if: ${{ !startsWith(matrix.cases.os, 'windows') }} 105 | - run: echo "USE_BAZEL_VERSION=6.4.0" >> $env:GITHUB_ENV 106 | if: ${{ startsWith(matrix.cases.os, 'windows') }} 107 | 108 | - run: bazelisk build --jobs=1 @rules_cuda_examples//basic:all 109 | - run: bazelisk build --jobs=1 @rules_cuda_examples//rdc:all 110 | - run: bazelisk build --jobs=1 @rules_cuda_examples//if_cuda:main 111 | - run: bazelisk build --jobs=1 @rules_cuda_examples//if_cuda:main --enable_cuda=False 112 | - run: cd examples && bazelisk build --jobs=1 --enable_bzlmod //basic:all 113 | - run: cd examples && bazelisk build --jobs=1 --enable_bzlmod //rdc:all 114 | - run: cd examples && bazelisk build --jobs=1 --enable_bzlmod //if_cuda:main 115 | - run: cd examples && bazelisk build --jobs=1 --enable_bzlmod //if_cuda:main --enable_cuda=False 116 | - run: bazelisk shutdown 117 | 
-------------------------------------------------------------------------------- /.github/workflows/ci.bazelrc: -------------------------------------------------------------------------------- 1 | # This file contains Bazel settings to apply on CI only. 2 | # It is referenced with a --bazelrc option in the call to bazel in release.yaml 3 | 4 | # Debug where options came from 5 | build --announce_rc 6 | # This directory is configured in GitHub actions to be persisted between runs. 7 | build --disk_cache=~/.cache/bazel 8 | build --repository_cache=~/.cache/bazel-repo 9 | # Don't rely on test logs being easily accessible from the test runner, 10 | # though it makes the log noisier. 11 | test --test_output=errors 12 | # Allows tests to run bazelisk-in-bazel, since this is the cache folder used 13 | test --test_env=XDG_CACHE_HOME -------------------------------------------------------------------------------- /.github/workflows/github-pages.yaml: -------------------------------------------------------------------------------- 1 | name: Generate docs 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | push: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | pages: 13 | runs-on: ubuntu-latest 14 | strategy: 15 | matrix: 16 | docs: 17 | # { ref: , name: } 18 | - { ref: main, name: latest } 19 | - { ref: v0.2.1, name: 0.2.1 } 20 | - { ref: v0.2.3, name: 0.2.3 } 21 | - { ref: v0.2.4, name: 0.2.4 } 22 | steps: 23 | - uses: actions/checkout@v4 24 | if: ${{ matrix.docs.ref == 'main' }} 25 | - uses: actions/checkout@v4 26 | with: 27 | ref: ${{ matrix.docs.ref }} 28 | if: ${{ matrix.docs.ref != 'main' }} 29 | 30 | - uses: bazelbuild/setup-bazelisk@v3 31 | 32 | - uses: actions/setup-python@v5 33 | with: 34 | python-version: "3.13" 35 | 36 | - name: Generate docs 37 | run: bash ./build-docs.sh 38 | env: 39 | CI: 1 40 | working-directory: ${{ github.workspace }}/docs 41 | 42 | - run: bazelisk shutdown 43 | 44 | - uses: actions/upload-artifact@v4 45 | with: 46 | name: "${{ 
matrix.docs.name }}" 47 | path: ${{ github.workspace }}/docs/site/ 48 | if-no-files-found: error 49 | if: ${{ github.event_name != 'pull_request' }} 50 | 51 | publish: 52 | needs: pages 53 | if: ${{ github.event_name != 'pull_request' }} 54 | runs-on: ubuntu-latest 55 | steps: 56 | - uses: actions/checkout@v4 57 | 58 | - uses: actions/download-artifact@v4 59 | with: 60 | path: ${{ github.workspace }}/docs/generated 61 | - name: Inspect docs site directory structure 62 | run: find ${{ github.workspace }}/docs/generated -maxdepth 2 63 | 64 | - uses: actions/setup-python@v5 65 | with: 66 | python-version: "3.13" 67 | - run: | 68 | pip install packaging==23.* 69 | python versioning.py generated/ 70 | working-directory: ${{ github.workspace }}/docs 71 | 72 | - uses: peaceiris/actions-gh-pages@v4 73 | with: 74 | github_token: ${{ secrets.GITHUB_TOKEN }} 75 | publish_dir: ./docs/generated 76 | force_orphan: true 77 | -------------------------------------------------------------------------------- /.github/workflows/integration-tests.yaml: -------------------------------------------------------------------------------- 1 | name: Integration Build 2 | 3 | on: 4 | workflow_dispatch: 5 | issue_comment: 6 | types: [created] 7 | 8 | jobs: 9 | test-manual: 10 | name: "Integration Test Build (Manual)" 11 | if: github.event_name == 'workflow_dispatch' 12 | runs-on: ${{ matrix.cases.os }} 13 | timeout-minutes: 60 14 | strategy: 15 | matrix: 16 | cases: 17 | - { 18 | os: "ubuntu-22.04", 19 | cuda-version: "11.7.0", 20 | source: "nvidia", 21 | toolchain: "nvcc", 22 | } 23 | - { 24 | os: "ubuntu-22.04", 25 | cuda-version: "11.7.0", 26 | source: "nvidia", 27 | toolchain: "llvm", 28 | toolchain-version: "16", 29 | } 30 | steps: 31 | - uses: actions/checkout@v4 32 | 33 | - uses: bazelbuild/setup-bazelisk@v3 34 | - name: Mount bazel cache 35 | if: ${{ !startsWith(matrix.cases.os, 'windows') }} 36 | uses: actions/cache@v4 37 | with: 38 | path: ~/.cache/bazel 39 | key: ${{ 
matrix.cases.toolchain }}-${{ matrix.cases.toolchain-version }} 40 | 41 | - name: Setup build environment 42 | uses: ./.github/actions/set-build-env 43 | with: 44 | os: ${{ matrix.cases.os }} 45 | cuda-version: ${{ matrix.cases.cuda-version }} 46 | source: ${{ matrix.cases.source }} 47 | toolchain: ${{ matrix.cases.toolchain }} 48 | toolchain-version: ${{ matrix.cases.toolchain-version }} 49 | 50 | - name: Bazel build config for LLVM 51 | if: ${{ !startsWith(matrix.cases.os, 'windows') && matrix.cases.toolchain == 'llvm' }} 52 | run: | 53 | echo "build --config=clang" > $HOME/.bazelrc 54 | echo "build:clang --cxxopt=--cuda-gpu-arch=sm_80" >> $HOME/.bazelrc 55 | 56 | - run: cd examples && bazelisk build --verbose_failures --cuda_archs='compute_80,sm_80' @rules_cuda_examples//nccl:perf_binaries 57 | 58 | - run: bazelisk shutdown 59 | 60 | # based on https://dev.to/zirkelc/trigger-github-workflow-for-comment-on-pull-request-45l2 61 | pre-test-comment: 62 | name: "Integration Test Build - Set commit status pending" 63 | if: github.event.issue.pull_request && contains(github.event.comment.body, '/test') 64 | runs-on: ubuntu-latest 65 | steps: 66 | - name: Get PR branch 67 | uses: xt0rted/pull-request-comment-branch@v2 68 | id: comment-branch 69 | - name: Set commit status as pending 70 | uses: myrotvorets/set-commit-status-action@master 71 | with: 72 | token: ${{ secrets.GITHUB_TOKEN }} 73 | sha: ${{ steps.comment-branch.outputs.head_sha }} 74 | status: pending 75 | 76 | test-comment: 77 | name: "Integration Test Build (CUDA ${{ matrix.cases.cuda-version }} on ${{ matrix.cases.os }})" 78 | needs: [pre-test-comment] 79 | runs-on: ${{ matrix.cases.os }} 80 | timeout-minutes: 60 81 | strategy: 82 | matrix: 83 | cases: 84 | - { 85 | os: "ubuntu-22.04", 86 | cuda-version: "11.7.0", 87 | source: "nvidia", 88 | toolchain: "nvcc", 89 | } 90 | - { 91 | os: "ubuntu-22.04", 92 | cuda-version: "11.7.0", 93 | source: "nvidia", 94 | toolchain: "llvm", 95 | toolchain-version: "16", 96 
| } 97 | steps: 98 | - name: Get PR branch 99 | uses: xt0rted/pull-request-comment-branch@v2 100 | id: comment-branch 101 | - name: Checkout PR branch 102 | uses: actions/checkout@v4 103 | with: 104 | ref: ${{ steps.comment-branch.outputs.head_ref }} 105 | 106 | - uses: bazelbuild/setup-bazelisk@v3 107 | - name: Mount bazel cache 108 | if: ${{ !startsWith(matrix.cases.os, 'windows') }} 109 | uses: actions/cache@v4 110 | with: 111 | path: ~/.cache/bazel 112 | key: ${{ matrix.cases.toolchain }}-${{ matrix.cases.toolchain-version }} 113 | 114 | - name: Setup build environment 115 | uses: ./.github/actions/set-build-env 116 | with: 117 | os: ${{ matrix.cases.os }} 118 | cuda-version: ${{ matrix.cases.cuda-version }} 119 | source: ${{ matrix.cases.source }} 120 | toolchain: ${{ matrix.cases.toolchain }} 121 | toolchain-version: ${{ matrix.cases.toolchain-version }} 122 | 123 | - name: Bazel build config for LLVM 124 | if: ${{ !startsWith(matrix.cases.os, 'windows') && matrix.cases.toolchain == 'llvm' }} 125 | run: | 126 | echo "build --config=clang" > $HOME/.bazelrc 127 | echo "build:clang --cxxopt=--cuda-gpu-arch=sm_80" >> $HOME/.bazelrc 128 | 129 | - run: cd examples && bazelisk build --verbose_failures --cuda_archs='compute_80,sm_80' @rules_cuda_examples//nccl:perf_binaries 130 | 131 | - run: bazelisk shutdown 132 | 133 | post-test-comment: 134 | name: "Integration Test Build - Set commit status as test result" 135 | needs: [test-comment] 136 | runs-on: ubuntu-latest 137 | steps: 138 | - name: Get PR branch 139 | uses: xt0rted/pull-request-comment-branch@v2 140 | id: comment-branch 141 | 142 | - name: Set latest commit status as ${{ job.status }} 143 | uses: myrotvorets/set-commit-status-action@master 144 | if: always() 145 | with: 146 | sha: ${{ steps.comment-branch.outputs.head_sha }} 147 | token: ${{ secrets.GITHUB_TOKEN }} 148 | status: ${{ job.status }} 149 | 150 | - name: Add comment to PR 151 | uses: actions/github-script@v7 152 | if: always() 153 | with: 154 
| script: | 155 | const name = '${{ github.workflow }}'; 156 | const url = '${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}'; 157 | const success = '${{ job.status }}' === 'success'; 158 | const body = `${name}: ${success ? 'succeeded ✅' : 'failed ❌'}\n${url}`; 159 | 160 | await github.rest.issues.createComment({ 161 | issue_number: context.issue.number, 162 | owner: context.repo.owner, 163 | repo: context.repo.repo, 164 | body: body 165 | }) 166 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yaml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | 9 | jobs: 10 | pre-commit: 11 | runs-on: ubuntu-24.04 12 | steps: 13 | - uses: actions/checkout@v4 14 | - name: set PY 15 | run: echo "PY=$(python -VV | sha256sum | cut -d' ' -f1)" >> $GITHUB_ENV 16 | - uses: actions/cache@v4 17 | with: 18 | path: ~/.cache/pre-commit 19 | key: pre-commit|${{ env.PY }}|${{ hashFiles('.pre-commit-config.yaml') }} 20 | - uses: pre-commit/action@v3.0.1 21 | - uses: pre-commit-ci/lite-action@v1.1.0 22 | if: always() 23 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | # Cut a release whenever a new tag is pushed to the repo. 2 | # You should use an annotated tag, like `git tag -a v1.2.3` 3 | # and put the release notes into the commit message for the tag. 
4 | name: Release 5 | 6 | on: 7 | push: 8 | tags: 9 | - "v*.*.*" 10 | 11 | jobs: 12 | build: 13 | runs-on: ${{ matrix.os }} 14 | timeout-minutes: 60 15 | strategy: 16 | matrix: 17 | os: 18 | - ubuntu-22.04 19 | cuda-version: 20 | - 11.8.0 21 | steps: 22 | - name: Checkout 23 | uses: actions/checkout@v4 24 | - uses: bazelbuild/setup-bazelisk@v3 25 | - name: Mount bazel cache 26 | uses: actions/cache@v4 27 | with: 28 | path: ~/.cache/bazel 29 | key: bazel-${{ matrix.os }}-cuda-${{ matrix.cuda-version }}-${{ hashFiles('.bazelversion') }} 30 | - name: Install CUDA (Linux) 31 | uses: Jimver/cuda-toolkit@v0.2.22 32 | with: 33 | cuda: ${{ matrix.cuda-version }} 34 | sub-packages: '["nvcc", "cudart-dev"]' 35 | method: network 36 | - name: bazel test //... 37 | env: 38 | # Bazelisk will download bazel to here, ensure it is cached within tests. 39 | XDG_CACHE_HOME: /home/runner/.cache/bazel-repo 40 | run: bazelisk --bazelrc=.github/workflows/ci.bazelrc --bazelrc=.bazelrc test //... 41 | - name: Create rules archive 42 | run: | 43 | PREFIX="rules_cuda-${GITHUB_REF_NAME}" 44 | git archive --format=tar.gz --prefix=${PREFIX}/ ${GITHUB_REF_NAME} -o ${{ github.workspace }}/.github/rules_cuda.tar.gz 45 | echo "ARCHIVE_SHA256=$(shasum -a 256 ${{ github.workspace }}/.github/rules_cuda.tar.gz | cut -d ' ' -f 1)" >> $GITHUB_ENV 46 | echo "RELEASE_VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV 47 | # Upload the artifact in case creating a release fails so all artifacts can then be manually recovered. 
48 | - uses: actions/upload-artifact@v4 49 | with: 50 | name: "rules_cuda.tar.gz" 51 | path: ${{ github.workspace }}/.github/rules_cuda.tar.gz 52 | if-no-files-found: error 53 | - name: Prepare workspace snippet 54 | run: | 55 | sed 's/{version}/${{ env.RELEASE_VERSION }}/g' ${{ github.workspace }}/.github/release_notes.template \ 56 | | sed 's/{archive_sha256}/${{ env.ARCHIVE_SHA256 }}/g' \ 57 | > ${{ github.workspace }}/.github/release_notes.txt 58 | - name: Create the release 59 | uses: softprops/action-gh-release@v2 60 | id: rules_cuda_release 61 | with: 62 | prerelease: true 63 | # Use GH feature to populate the changelog automatically 64 | generate_release_notes: true 65 | body_path: ${{ github.workspace }}/.github/release_notes.txt 66 | - name: "Upload the rules archive" 67 | uses: actions/upload-release-asset@v1 68 | env: 69 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 70 | with: 71 | upload_url: ${{ steps.rules_cuda_release.outputs.upload_url }} 72 | asset_name: rules_cuda-${{ env.RELEASE_VERSION }}.tar.gz 73 | asset_path: ${{ github.workspace }}/.github/rules_cuda.tar.gz 74 | asset_content_type: application/gzip 75 | -------------------------------------------------------------------------------- /.github/workflows/utilities-tests.yaml: -------------------------------------------------------------------------------- 1 | name: Test Utilities 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | 9 | jobs: 10 | test: 11 | name: "Test Utilities (bazel ${{ matrix.bazel-version }} on ${{ matrix.os }})" 12 | runs-on: ${{ matrix.os }} 13 | timeout-minutes: 60 14 | strategy: 15 | matrix: 16 | os: 17 | - ubuntu-22.04 18 | - windows-2019 19 | bazel-version: 20 | # NOTE: read from .bazelversion so that we don't randomly break our 21 | # ci due to latest bazel version change 22 | - .bazelversion 23 | - 7.5.0 24 | - 6.5.0 25 | env: 26 | USE_BAZEL_VERSION: ${{ matrix.bazel-version }} 27 | steps: 28 | - uses: actions/checkout@v4 29 | 30 | # conditionally 
override USE_BAZEL_VERSION 31 | - run: if [ "${{ matrix.bazel-version }}" = ".bazelversion" ]; then echo "USE_BAZEL_VERSION=$(cat .bazelversion)" >> $GITHUB_ENV; fi 32 | if: ${{ !startsWith(matrix.os, 'windows') }} 33 | - run: if ("${{ matrix.bazel-version }}" -eq ".bazelversion") { echo "USE_BAZEL_VERSION=$(cat .bazelversion)" >> $env:GITHUB_ENV } 34 | if: ${{ startsWith(matrix.os, 'windows') }} 35 | 36 | - uses: bazelbuild/setup-bazelisk@v3 37 | - name: Mount bazel cache 38 | if: ${{ !startsWith(matrix.os, 'windows') }} 39 | uses: actions/cache@v4 40 | with: 41 | path: ~/.cache/bazel 42 | key: bazel-${{ matrix.os }}-${{ matrix.bazel-version }} 43 | 44 | - uses: Jimver/cuda-toolkit@v0.2.22 45 | with: 46 | cuda: 11.7.0 47 | sub-packages: '["cudart"]' 48 | method: network 49 | 50 | - run: bazelisk test -- //tests/... 51 | 52 | - run: bazelisk shutdown 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | bazel-* 2 | 3 | # MODULE.bazel.lock is a new feature in bazel and is not stable yet. 4 | # Per https://github.com/bazelbuild/bazel/issues/20369, we ignore this file until 5 | # the MODULE.bazel.lock mechanism is stable and guarantees consistency across different 6 | # development environments. 7 | MODULE.bazel.lock -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See CONTRIBUTING.md for instructions. 
2 | # See https://pre-commit.com for more information 3 | # See https://pre-commit.com/hooks.html for more hooks 4 | 5 | default_language_version: 6 | node: 16.18.0 7 | 8 | # Commitizen runs in commit-msg stage 9 | # but we don't want to run the other hooks on commit messages 10 | default_stages: [commit] 11 | 12 | repos: 13 | # Check formatting and lint for starlark code 14 | - repo: https://github.com/garymm/bazel-buildifier-pre-commit-hooks 15 | rev: v6.1.2 16 | hooks: 17 | - id: bazel-buildifier 18 | # Enforce that commit messages allow for later changelog generation 19 | - repo: https://github.com/commitizen-tools/commitizen 20 | rev: v2.18.0 21 | hooks: 22 | # Requires that commitizen is already installed 23 | - id: commitizen 24 | stages: [commit-msg] 25 | - repo: https://github.com/pre-commit/mirrors-prettier 26 | rev: "v2.4.0" 27 | hooks: 28 | - id: prettier 29 | -------------------------------------------------------------------------------- /BUILD.bazel: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bazel-contrib/rules_cuda/3f72f484a8ea5969c81a857a6785ebf0ede02c0c/BUILD.bazel -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @cloudhan @jsharpe @ryanleary 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Guangyun Han 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 9 | of the Software, and to permit persons to whom the 
Software is furnished to do 10 | so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MODULE.bazel: -------------------------------------------------------------------------------- 1 | module( 2 | name = "rules_cuda", 3 | version = "0.0.0", 4 | compatibility_level = 1, 5 | ) 6 | 7 | bazel_dep(name = "bazel_skylib", version = "1.4.2") 8 | bazel_dep(name = "platforms", version = "0.0.6") 9 | 10 | cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain") 11 | cuda.toolkit( 12 | name = "cuda", 13 | toolkit_path = "", 14 | ) 15 | use_repo(cuda, "cuda") 16 | 17 | register_toolchains( 18 | "@cuda//toolchain:nvcc-local-toolchain", 19 | "@cuda//toolchain/clang:clang-local-toolchain", 20 | "@cuda//toolchain/disabled:disabled-local-toolchain", 21 | ) 22 | 23 | bazel_dep(name = "rules_cuda_examples", dev_dependency = True) 24 | local_path_override( 25 | module_name = "rules_cuda_examples", 26 | path = "./examples", 27 | ) 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CUDA rules for [Bazel](https://bazel.build) 2 | 3 | This repository contains [Starlark](https://github.com/bazelbuild/starlark) implementation of CUDA rules in 
Bazel. 4 | 5 | These rules provide some macros and rules that make it easier to build CUDA with Bazel. 6 | 7 | ## Getting Started 8 | 9 | ### Traditional WORKSPACE approach 10 | 11 | Add the following to your `WORKSPACE` file and replace the placeholders with actual values. 12 | 13 | ```starlark 14 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 15 | http_archive( 16 | name = "rules_cuda", 17 | sha256 = "{sha256_to_replace}", 18 | strip_prefix = "rules_cuda-{git_commit_hash}", 19 | urls = ["https://github.com/bazel-contrib/rules_cuda/archive/{git_commit_hash}.tar.gz"], 20 | ) 21 | load("@rules_cuda//cuda:repositories.bzl", "rules_cuda_dependencies", "rules_cuda_toolchains") 22 | rules_cuda_dependencies() 23 | rules_cuda_toolchains(register_toolchains = True) 24 | ``` 25 | 26 | **NOTE**: `rules_cuda_toolchains` implicitly calls `register_detected_cuda_toolchains`, and the use of 27 | `register_detected_cuda_toolchains` depends on the environment variable `CUDA_PATH`. You must also ensure the 28 | host compiler is available. On Windows, this means that you will also need to set the environment variable 29 | `BAZEL_VC` properly. 30 | 31 | [`detect_cuda_toolkit`](https://github.com/bazel-contrib/rules_cuda/blob/5633f0c0f7/cuda/private/repositories.bzl#L28-L58) 32 | and [`detect_clang`](https://github.com/bazel-contrib/rules_cuda/blob/5633f0c0f7/cuda/private/repositories.bzl#L143-L166) 33 | determine how the toolchains are detected. 34 | 35 | ### Bzlmod 36 | 37 | Add the following to your `MODULE.bazel` file and replace the placeholders with actual values. 
38 | 39 | ```starlark 40 | bazel_dep(name = "rules_cuda", version = "0.2.1") 41 | 42 | # pick a specific version (this is optional and can be skipped) 43 | archive_override( 44 | module_name = "rules_cuda", 45 | integrity = "{SRI value}", # see https://developer.mozilla.org/en-US/docs/Web/Security/Subresource_Integrity 46 | url = "https://github.com/bazel-contrib/rules_cuda/archive/{git_commit_hash}.tar.gz", 47 | strip_prefix = "rules_cuda-{git_commit_hash}", 48 | ) 49 | 50 | cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain") 51 | cuda.toolkit( 52 | name = "cuda", 53 | toolkit_path = "", 54 | ) 55 | use_repo(cuda, "cuda") 56 | ``` 57 | 58 | ### Rules 59 | 60 | - `cuda_library`: Can be used to compile and create a static library for CUDA kernel code. The resulting targets can be 61 | consumed by [C/C++ Rules](https://bazel.build/reference/be/c-cpp#rules). 62 | - `cuda_objects`: If you don't understand what _device link_ means, you must never use it. This rule produces incomplete 63 | object files that can only be consumed by `cuda_library`. It is created for relocatable device code and device link 64 | time optimization source files. 65 | 66 | ### Flags 67 | 68 | Some flags are defined in [cuda/BUILD.bazel](cuda/BUILD.bazel). To use them, for example: 69 | 70 | ``` 71 | bazel build --@rules_cuda//cuda:archs=compute_61:compute_61,sm_61 72 | ``` 73 | 74 | In `.bazelrc` file, you can define a shortcut alias for the flag, for example: 75 | 76 | ``` 77 | # Convenient flag shortcuts. 78 | build --flag_alias=cuda_archs=@rules_cuda//cuda:archs 79 | ``` 80 | 81 | and then you can use it as follows: 82 | 83 | ``` 84 | bazel build --cuda_archs=compute_61:compute_61,sm_61 85 | ``` 86 | 87 | #### Available flags 88 | 89 | - `@rules_cuda//cuda:enable` 90 | 91 | Enable or disable all rules_cuda related rules. When disabled, the detected cuda toolchains will also be disabled to avoid potential human error. 92 | By default, rules_cuda rules are enabled. 
See `examples/if_cuda` for how to support both cuda-enabled and cuda-free builds. 93 | 94 | - `@rules_cuda//cuda:archs` 95 | 96 | Select the cuda archs to support. See [cuda_archs specification DSL grammar](https://github.com/bazel-contrib/rules_cuda/blob/5633f0c0f7/cuda/private/rules/flags.bzl#L14-L44). 97 | 98 | - `@rules_cuda//cuda:compiler` 99 | 100 | Select the cuda compiler, available options are `nvcc` or `clang` 101 | 102 | - `@rules_cuda//cuda:copts` 103 | 104 | Add the copts to all cuda compile actions. 105 | 106 | - `@rules_cuda//cuda:host_copts` 107 | 108 | Add the copts to the host compiler. 109 | 110 | - `@rules_cuda//cuda:runtime` 111 | 112 | Set the default cudart to link, for example, `--@rules_cuda//cuda:runtime=@cuda//:cuda_runtime_static` links the static cuda runtime. 113 | 114 | - `--features=cuda_device_debug` 115 | 116 | Sets nvcc flags to enable debug information in device code. 117 | Currently ignored for clang, where `--compilation_mode=dbg` applies to both 118 | host and device code. 119 | 120 | ## Examples 121 | 122 | Check out the examples to see if they fit your needs. 123 | 124 | See [examples](./examples) for basic usage. 125 | 126 | See [rules_cuda_examples](https://github.com/cloudhan/rules_cuda_examples) for extended real-world projects. 127 | 128 | ## Known issue 129 | 130 | Sometimes the following error occurs: 131 | 132 | ``` 133 | cc1plus: fatal error: /tmp/tmpxft_00000002_00000019-2.cpp: No such file or directory 134 | ``` 135 | 136 | The problem is caused by nvcc using its PID to determine temporary file names. With `--spawn_strategy linux-sandbox`, which is the default strategy on Linux, the PIDs nvcc sees are all very small numbers, say 2~4, due to sandboxing. 
`linux-sandbox` is not hermetic because [it mounts root into the sandbox](https://docs.bazel.build/versions/main/command-line-reference.html#flag--experimental_use_hermetic_linux_sandbox), thus, `/tmp` is shared between sandboxes, which causes name conflicts under high parallelism. A similar problem has been reported at [nvidia forums](https://forums.developer.nvidia.com/t/avoid-generating-temp-files-in-tmp-while-nvcc-compiling/197657/10). 137 | 138 | To avoid it: 139 | 140 | - Update to Bazel 7 where `--incompatible_sandbox_hermetic_tmp` is enabled by default. 141 | - Using `--spawn_strategy local` should eliminate the problem because it lets nvcc see the true PIDs. 142 | - Using `--experimental_use_hermetic_linux_sandbox` should eliminate the problem because it avoids the sharing of `/tmp`. 143 | - Adding the `-objtemp` option to the command should make the problem less likely to happen. 144 | -------------------------------------------------------------------------------- /WORKSPACE.bazel: -------------------------------------------------------------------------------- 1 | workspace(name = "rules_cuda") 2 | 3 | local_repository( 4 | name = "rules_cuda_examples", 5 | path = "examples", 6 | ) 7 | 8 | load("@rules_cuda//cuda:repositories.bzl", "rules_cuda_dependencies", "rules_cuda_toolchains") 9 | 10 | rules_cuda_dependencies() 11 | 12 | rules_cuda_toolchains(register_toolchains = True) 13 | 14 | load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace") 15 | 16 | bazel_skylib_workspace() 17 | -------------------------------------------------------------------------------- /cuda/BUILD.bazel: -------------------------------------------------------------------------------- 1 | load("@bazel_skylib//:bzl_library.bzl", "bzl_library") 2 | load( 3 | "@bazel_skylib//rules:common_settings.bzl", 4 | "bool_flag", 5 | "string_flag", 6 | ) 7 | load("//cuda/private:rules/flags.bzl", "cuda_archs_flag", "repeatable_string_flag") 8 | 9 | package(default_visibility = ["//visibility:public"]) 
10 | 11 | bzl_library( 12 | name = "bzl_srcs", 13 | srcs = glob(["*.bzl"]), 14 | visibility = ["//visibility:public"], 15 | deps = ["//cuda/private:bzl_srcs"], 16 | ) 17 | 18 | toolchain_type(name = "toolchain_type") 19 | 20 | # Command line flag to set ":is_enabled" config setting. 21 | # 22 | # Set with --@rules_cuda//cuda:enable 23 | bool_flag( 24 | name = "enable", 25 | build_setting_default = True, 26 | ) 27 | 28 | # This config setting can be used for conditionally depend on cuda. 29 | # 30 | # Set with --@rules_cuda//cuda:enable 31 | config_setting( 32 | name = "is_enabled", 33 | flag_values = {":enable": "True"}, 34 | ) 35 | 36 | config_setting( 37 | name = "is_valid_toolchain_found", 38 | flag_values = {"@cuda//:valid_toolchain_found": "True"}, 39 | ) 40 | 41 | # Command line flag to specify the list of CUDA architectures to compile for. 42 | # 43 | # Provides CudaArchsInfo of the list of archs to build. 44 | # 45 | # Example usage: --@rules_cuda//cuda:archs=sm_70,sm_75;sm_80,sm_86 46 | # 47 | # See CudaArchsInfo for detailed grammar 48 | cuda_archs_flag( 49 | name = "archs", 50 | build_setting_default = "", 51 | ) 52 | 53 | # Command line flag to select compiler for cuda_library() code. 54 | string_flag( 55 | name = "compiler", 56 | build_setting_default = "nvcc", 57 | values = [ 58 | "nvcc", 59 | "clang", 60 | ], 61 | ) 62 | 63 | config_setting( 64 | name = "compiler_is_nvcc", 65 | flag_values = {":compiler": "nvcc"}, 66 | ) 67 | 68 | config_setting( 69 | name = "compiler_is_clang", 70 | flag_values = {":compiler": "clang"}, 71 | ) 72 | 73 | # Command line flag for copts to add to cuda_library() compile command. 74 | repeatable_string_flag( 75 | name = "copts", 76 | build_setting_default = "", 77 | ) 78 | 79 | repeatable_string_flag( 80 | name = "host_copts", 81 | build_setting_default = "", 82 | ) 83 | 84 | # Command line flag to specify the CUDA runtime. Use this target as CUDA 85 | # runtime dependency. 
86 | # 87 | # This target is implicitly added as a dependency to cuda_library() targets. 88 | # 89 | # Example usage: --@rules_cuda//cuda:runtime=@cuda//:cuda_runtime_static 90 | label_flag( 91 | name = "runtime", 92 | build_setting_default = "@cuda//:cuda_runtime", 93 | ) 94 | 95 | constraint_setting(name = "rules_are_enabled_setting") 96 | 97 | constraint_value( 98 | name = "rules_are_enabled", 99 | constraint_setting = ":rules_are_enabled_setting", 100 | ) 101 | 102 | constraint_setting(name = "valid_toolchain_is_configured_setting") 103 | 104 | constraint_value( 105 | name = "valid_toolchain_is_configured", 106 | constraint_setting = ":valid_toolchain_is_configured_setting", 107 | ) 108 | -------------------------------------------------------------------------------- /cuda/defs.bzl: -------------------------------------------------------------------------------- 1 | """ 2 | Core rules for building CUDA projects. 3 | """ 4 | 5 | load("//cuda/private:defs.bzl", _requires_cuda = "requires_cuda") 6 | load("//cuda/private:macros/cuda_binary.bzl", _cuda_binary = "cuda_binary") 7 | load("//cuda/private:macros/cuda_test.bzl", _cuda_test = "cuda_test") 8 | load("//cuda/private:os_helpers.bzl", _cc_import_versioned_sos = "cc_import_versioned_sos", _if_linux = "if_linux", _if_windows = "if_windows") 9 | load("//cuda/private:providers.bzl", _CudaArchsInfo = "CudaArchsInfo", _cuda_archs = "cuda_archs") 10 | load("//cuda/private:rules/cuda_library.bzl", _cuda_library = "cuda_library") 11 | load("//cuda/private:rules/cuda_objects.bzl", _cuda_objects = "cuda_objects") 12 | load("//cuda/private:rules/cuda_toolkit_info.bzl", _cuda_toolkit_info = "cuda_toolkit_info") 13 | load( 14 | "//cuda/private:toolchain.bzl", 15 | _cuda_toolchain = "cuda_toolchain", 16 | _find_cuda_toolchain = "find_cuda_toolchain", 17 | _use_cuda_toolchain = "use_cuda_toolchain", 18 | ) 19 | load("//cuda/private:toolchain_configs/clang.bzl", _cuda_toolchain_config_clang = "cuda_toolchain_config") 20 | 
load("//cuda/private:toolchain_configs/disabled.bzl", _cuda_toolchain_config_disabled = "disabled_toolchain_config") 21 | load("//cuda/private:toolchain_configs/nvcc.bzl", _cuda_toolchain_config_nvcc = "cuda_toolchain_config") 22 | load("//cuda/private:toolchain_configs/nvcc_msvc.bzl", _cuda_toolchain_config_nvcc_msvc = "cuda_toolchain_config") 23 | 24 | cuda_toolkit_info = _cuda_toolkit_info 25 | cuda_toolchain = _cuda_toolchain 26 | find_cuda_toolchain = _find_cuda_toolchain 27 | use_cuda_toolchain = _use_cuda_toolchain 28 | cuda_toolchain_config_clang = _cuda_toolchain_config_clang 29 | cuda_toolchain_config_disabled = _cuda_toolchain_config_disabled 30 | cuda_toolchain_config_nvcc_msvc = _cuda_toolchain_config_nvcc_msvc 31 | cuda_toolchain_config_nvcc = _cuda_toolchain_config_nvcc 32 | 33 | cuda_archs = _cuda_archs 34 | CudaArchsInfo = _CudaArchsInfo 35 | 36 | # rules 37 | cuda_objects = _cuda_objects 38 | cuda_library = _cuda_library 39 | 40 | # macros 41 | cuda_binary = _cuda_binary 42 | cuda_test = _cuda_test 43 | 44 | if_linux = _if_linux 45 | if_windows = _if_windows 46 | 47 | cc_import_versioned_sos = _cc_import_versioned_sos 48 | 49 | requires_cuda = _requires_cuda 50 | -------------------------------------------------------------------------------- /cuda/dummy/BUILD.bazel: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//visibility:public"]) 2 | 3 | cc_binary( 4 | name = "nvcc", 5 | srcs = ["dummy.cpp"], 6 | defines = ["TOOLNAME=nvcc"], 7 | ) 8 | 9 | cc_binary( 10 | name = "nvlink", 11 | srcs = ["dummy.cpp"], 12 | defines = ["TOOLNAME=nvlink"], 13 | ) 14 | 15 | exports_files(["link.stub"]) 16 | 17 | cc_binary( 18 | name = "bin2c", 19 | srcs = ["dummy.cpp"], 20 | defines = ["TOOLNAME=bin2c"], 21 | ) 22 | 23 | cc_binary( 24 | name = "fatbinary", 25 | srcs = ["dummy.cpp"], 26 | defines = ["TOOLNAME=fatbinary"], 27 | ) 28 | 
-------------------------------------------------------------------------------- /cuda/dummy/dummy.cpp: -------------------------------------------------------------------------------- 1 | #include "cstdio" 2 | 3 | #define TO_STRING_IND(X) #X 4 | #define TO_STRING(X) TO_STRING_IND(X) 5 | 6 | int main(int argc, char* argv[]) { 7 | std::printf("ERROR: " TO_STRING(TOOLNAME) " of cuda toolkit does not exist\n"); 8 | return -1; 9 | } 10 | -------------------------------------------------------------------------------- /cuda/dummy/link.stub: -------------------------------------------------------------------------------- 1 | #error link.stub of cuda toolkit does not exist 2 | -------------------------------------------------------------------------------- /cuda/extensions.bzl: -------------------------------------------------------------------------------- 1 | """Entry point for extensions used by bzlmod.""" 2 | 3 | load("//cuda/private:compat.bzl", "components_mapping_compat") 4 | load("//cuda/private:repositories.bzl", "cuda_component", "cuda_redist_json", "cuda_toolkit") 5 | 6 | cuda_component_tag = tag_class(attrs = { 7 | "name": attr.string(mandatory = True, doc = "Repo name for the deliverable cuda_component"), 8 | "component_name": attr.string(doc = "Short name of the component defined in registry."), 9 | "descriptive_name": attr.string(doc = "Official name of a component or simply the component name."), 10 | "integrity": attr.string( 11 | doc = "Expected checksum in Subresource Integrity format of the file downloaded. " + 12 | "This must match the checksum of the file downloaded.", 13 | ), 14 | "sha256": attr.string( 15 | doc = "The expected SHA-256 of the file downloaded. This must match the SHA-256 of the file downloaded.", 16 | ), 17 | "strip_prefix": attr.string( 18 | doc = "A directory prefix to strip from the extracted files. 
" + 19 | "Many archives contain a top-level directory that contains all of the useful files in archive.", 20 | ), 21 | "url": attr.string( 22 | doc = "A URL to a file that will be made available to Bazel. " + 23 | "This must be a file, http or https URL." + 24 | "Redirections are followed. Authentication is not supported. " + 25 | "More flexibility can be achieved by the urls parameter that allows " + 26 | "to specify alternative URLs to fetch from.", 27 | ), 28 | "urls": attr.string_list( 29 | doc = "A list of URLs to a file that will be made available to Bazel. " + 30 | "Each entry must be a file, http or https URL. " + 31 | "Redirections are followed. Authentication is not supported. " + 32 | "URLs are tried in order until one succeeds, so you should list local mirrors first. " + 33 | "If all downloads fail, the rule will fail.", 34 | ), 35 | "version": attr.string(doc = "A unique version number for component."), 36 | }) 37 | 38 | cuda_redist_json_tag = tag_class(attrs = { 39 | "name": attr.string(mandatory = True, doc = "Repo name for the cuda_redist_json"), 40 | "components": attr.string_list(mandatory = True, doc = "components to be used"), 41 | "integrity": attr.string( 42 | doc = "Expected checksum in Subresource Integrity format of the file downloaded. " + 43 | "This must match the checksum of the file downloaded.", 44 | ), 45 | "sha256": attr.string( 46 | doc = "The expected SHA-256 of the file downloaded. " + 47 | "This must match the SHA-256 of the file downloaded.", 48 | ), 49 | "urls": attr.string_list( 50 | doc = "A list of URLs to a file that will be made available to Bazel. " + 51 | "Each entry must be a file, http or https URL. Redirections are followed. " + 52 | "Authentication is not supported. " + 53 | "URLs are tried in order until one succeeds, so you should list local mirrors first. " + 54 | "If all downloads fail, the rule will fail.", 55 | ), 56 | "version": attr.string( 57 | doc = "Generate a URL by using the specified version." 
+ 58 | "This URL will be tried after all URLs specified in the `urls` attribute.", 59 | ), 60 | }) 61 | 62 | cuda_toolkit_tag = tag_class(attrs = { 63 | "name": attr.string(mandatory = True, doc = "Name for the toolchain repository", default = "cuda"), 64 | "toolkit_path": attr.string( 65 | doc = "Path to the CUDA SDK, if empty the environment variable CUDA_PATH will be used to deduce this path.", 66 | ), 67 | "components_mapping": components_mapping_compat.attr( 68 | doc = "A mapping from component names to component repos of a deliverable CUDA Toolkit. " + 69 | "Only the repo part of the label is useful", 70 | ), 71 | "version": attr.string(doc = "cuda toolkit version. Required for deliverable toolkit only."), 72 | "nvcc_version": attr.string( 73 | doc = "nvcc version. Required for deliverable toolkit only. Fallback to version if omitted.", 74 | ), 75 | }) 76 | 77 | def _find_modules(module_ctx): 78 | root = None 79 | our_module = None 80 | for mod in module_ctx.modules: 81 | if mod.is_root: 82 | root = mod 83 | if mod.name == "rules_cuda": 84 | our_module = mod 85 | if root == None: 86 | root = our_module # no module is flagged as root: fall back to rules_cuda itself 87 | if our_module == None: 88 | fail("Unable to find rules_cuda module") 89 | 90 | return root, our_module 91 | 92 | def _module_tag_to_dict(t): 93 | return {attr: getattr(t, attr) for attr in dir(t)} 94 | 95 | def _impl(module_ctx): 96 | # Toolchain configuration is only taken from the root module; if the root module declares no toolkit tag, the tags declared by rules_cuda itself are used instead. 
97 | root, rules_cuda = _find_modules(module_ctx) 98 | components = None 99 | redist_jsons = None 100 | toolkits = None 101 | if root.tags.toolkit: 102 | components = root.tags.component 103 | redist_jsons = root.tags.redist_json 104 | toolkits = root.tags.toolkit 105 | else: 106 | components = rules_cuda.tags.component 107 | redist_jsons = rules_cuda.tags.redist_json 108 | toolkits = rules_cuda.tags.toolkit 109 | 110 | for component in components: 111 | cuda_component(**_module_tag_to_dict(component)) 112 | 113 | for redist_json in redist_jsons: 114 | cuda_redist_json(**_module_tag_to_dict(redist_json)) 115 | 116 | registrations = {} # toolkit name -> first toolkit tag declared with that name 117 | for toolkit in toolkits: 118 | if toolkit.name in registrations: 119 | if toolkit.toolkit_path == registrations[toolkit.name].toolkit_path: 120 | # No problem to register a matching toolkit twice 121 | continue 122 | fail("Multiple conflicting toolkits declared for name {} ({} and {})".format(toolkit.name, toolkit.toolkit_path, registrations[toolkit.name].toolkit_path)) 123 | else: 124 | registrations[toolkit.name] = toolkit 125 | for _, toolkit in registrations.items(): 126 | cuda_toolkit(**_module_tag_to_dict(toolkit)) 127 | 128 | toolchain = module_extension( 129 | implementation = _impl, 130 | tag_classes = { 131 | "component": cuda_component_tag, 132 | "redist_json": cuda_redist_json_tag, 133 | "toolkit": cuda_toolkit_tag, 134 | }, 135 | ) 136 | -------------------------------------------------------------------------------- /cuda/private/BUILD.bazel: -------------------------------------------------------------------------------- 1 | load("@bazel_skylib//:bzl_library.bzl", "bzl_library") 2 | 3 | package(default_visibility = ["//visibility:private"]) 4 | 5 | bzl_library( 6 | name = "bzl_srcs", 7 | srcs = glob(["**/*.bzl"]), 8 | visibility = ["//visibility:public"], 9 | deps = [ 10 | "@bazel_skylib//lib:partial", 11 | "@bazel_skylib//lib:paths", 12 | 
"@bazel_skylib//lib:unittest", 13 | 
"@bazel_skylib//rules:common_settings", 14 | "@bazel_tools//tools/build_defs/repo:http.bzl", 15 | "@bazel_tools//tools/build_defs/repo:utils.bzl", 16 | "@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl", 17 | "@bazel_tools//tools/cpp:toolchain_utils.bzl", 18 | ], 19 | ) 20 | -------------------------------------------------------------------------------- /cuda/private/action_names.bzl: -------------------------------------------------------------------------------- 1 | CUDA_COMPILE = "cuda-compile" # cuda compile comprise of host and device compilation 2 | 3 | CUDA_DEVICE_LINK = "cuda-dlink" 4 | 5 | ACTION_NAMES = struct( 6 | cuda_compile = CUDA_COMPILE, 7 | device_link = CUDA_DEVICE_LINK, 8 | ) 9 | -------------------------------------------------------------------------------- /cuda/private/actions/compile.bzl: -------------------------------------------------------------------------------- 1 | load("@bazel_tools//tools/build_defs/cc:action_names.bzl", CC_ACTION_NAMES = "ACTION_NAMES") 2 | load("//cuda/private:action_names.bzl", "ACTION_NAMES") 3 | load("//cuda/private:cuda_helper.bzl", "cuda_helper") 4 | load("//cuda/private:rules/common.bzl", "ALLOW_CUDA_SRCS") 5 | 6 | def compile( 7 | ctx, 8 | cuda_toolchain, 9 | cc_toolchain, 10 | srcs, 11 | common, 12 | pic = False, 13 | rdc = False, 14 | _prefix = "_objs"): 15 | """Perform CUDA compilation, return compiled object files. 16 | 17 | Notes: 18 | 19 | - If `rdc` is set to `True`, then an additional step of device link must be performed. 20 | - The rules should call this action only once in case srcs have non-unique basenames, 21 | say `foo/kernel.cu` and `bar/kernel.cu`. 22 | 23 | Args: 24 | ctx: A [context object](https://bazel.build/rules/lib/ctx). 25 | cuda_toolchain: A `platform_common.ToolchainInfo` of a cuda toolchain, Can be obtained with `find_cuda_toolchain(ctx)`. 26 | cc_toolchain: A `CcToolchainInfo`. Can be obtained with `find_cpp_toolchain(ctx)`. 27 | srcs: A list of `File`s to be compiled. 
28 | common: A cuda common object. Can be obtained with `cuda_helper.create_common(ctx)` 29 | pic: Whether the `srcs` are compiled for position independent code. 30 | rdc: Whether the `srcs` are compiled for relocatable device code. 31 | _prefix: DON'T USE IT! Prefix of the output dir. Exposed for device link to redirect the output. 32 | 33 | Returns: 34 | A list of compiled object `File`s. 35 | """ 36 | actions = ctx.actions 37 | cc_feature_configuration = cc_common.configure_features( 38 | ctx = ctx, 39 | cc_toolchain = cc_toolchain, 40 | requested_features = ctx.features, 41 | unsupported_features = ctx.disabled_features, 42 | ) 43 | host_compiler = cc_common.get_tool_for_action(feature_configuration = cc_feature_configuration, action_name = CC_ACTION_NAMES.cpp_compile) 44 | cuda_compiler = cuda_toolchain.compiler_executable 45 | 46 | cuda_feature_config = cuda_helper.configure_features(ctx, cuda_toolchain, requested_features = [ACTION_NAMES.cuda_compile]) 47 | artifact_category_name = cuda_helper.get_artifact_category_from_action(ACTION_NAMES.cuda_compile, pic, rdc) 48 | 49 | basename_counter = {} 50 | src_and_indexed_basenames = [] 51 | for src in srcs: 52 | # this also filters out all header files 53 | basename = cuda_helper.get_basename_without_ext(src.basename, ALLOW_CUDA_SRCS, fail_if_not_match = False) 54 | if not basename: 55 | continue 56 | basename_index = basename_counter.setdefault(basename, default = 0) 57 | basename_counter[basename] += 1 58 | src_and_indexed_basenames.append((src, basename, basename_index)) 59 | 60 | ret = [] 61 | for src, basename, basename_index in src_and_indexed_basenames: 62 | filename = None 63 | filename = cuda_helper.get_artifact_name(cuda_toolchain, artifact_category_name, basename) 64 | 65 | # Objects are placed in <_prefix>/<target_name>/<filename>. 
66 | # For files with the same basename, say srcs = ["kernel.cu", "foo/kernel.cu", "bar/kernel.cu"], we get 67 | # <_prefix>/<target_name>/0/kernel.<ext>, <_prefix>/<target_name>/1/kernel.<ext>, <_prefix>/<target_name>/2/kernel.<ext>. 
68 | # Otherwise, the index directory is omitted. 69 | if basename_counter[basename] > 1: 70 | filename = "{}/{}".format(basename_index, filename) 71 | obj_file = actions.declare_file("{}/{}/{}".format(_prefix, ctx.attr.name, filename)) 72 | ret.append(obj_file) 73 | 74 | var = cuda_helper.create_compile_variables( 75 | cuda_toolchain, 76 | cuda_feature_config, 77 | common.cuda_archs_info, 78 | common.sysroot, 79 | source_file = src.path, 80 | output_file = obj_file.path, 81 | host_compiler = host_compiler, 82 | compile_flags = common.compile_flags, 83 | host_compile_flags = common.host_compile_flags, 84 | include_paths = common.includes, 85 | quote_include_paths = common.quote_includes, 86 | system_include_paths = common.system_includes, 87 | defines = common.local_defines + common.defines, 88 | host_defines = common.host_local_defines + common.host_defines, 89 | ptxas_flags = common.ptxas_flags, 90 | use_pic = pic, 91 | use_rdc = rdc, 92 | ) 93 | cmd = cuda_helper.get_command_line(cuda_feature_config, ACTION_NAMES.cuda_compile, var) 94 | env = cuda_helper.get_environment_variables(cuda_feature_config, ACTION_NAMES.cuda_compile, var) 95 | 96 | args = actions.args() 97 | args.add_all(cmd) 98 | 99 | actions.run( 100 | executable = cuda_compiler, 101 | arguments = [args], 102 | outputs = [obj_file], 103 | inputs = depset([src], transitive = [common.headers, cc_toolchain.all_files, cuda_toolchain.all_files]), 104 | env = env, 105 | mnemonic = "CudaCompile", 106 | progress_message = "Compiling %s" % src.path, 107 | ) 108 | return ret 109 | -------------------------------------------------------------------------------- /cuda/private/actions/dlink.bzl: -------------------------------------------------------------------------------- 1 | "" 2 | 3 | load("@bazel_tools//tools/build_defs/cc:action_names.bzl", CC_ACTION_NAMES = "ACTION_NAMES") 4 | load("//cuda/private:action_names.bzl", "ACTION_NAMES") 5 | load("//cuda/private:actions/compile.bzl", "compile") 6 | 
load("//cuda/private:cuda_helper.bzl", "cuda_helper") 7 | load("//cuda/private:toolchain.bzl", "find_cuda_toolkit") 8 | 9 | def device_link( 10 | ctx, 11 | cuda_toolchain, 12 | cc_toolchain, 13 | objects, 14 | common, 15 | pic = False, 16 | rdc = False, 17 | dlto = False): 18 | """Perform device link, return a dlink-ed object file. 19 | 20 | Notes: 21 | Compilation is carried out during device linking, which involves the embedding of the fatbin into the resulting object `File`. 22 | 23 | Args: 24 | ctx: A [context object](https://bazel.build/rules/lib/ctx). 25 | cuda_toolchain: A `platform_common.ToolchainInfo` of a cuda toolchain, Can be obtained with `find_cuda_toolchain(ctx)`. 26 | cc_toolchain: A `CcToolchainInfo`. Can be obtained with `find_cpp_toolchain(ctx)`. 27 | objects: A `depset` of `File`s to be device linked. 28 | common: A cuda common object. Can be obtained with `cuda_helper.create_common(ctx)` 29 | pic: Whether the `objects` are compiled for position independent code. 30 | rdc: Whether the `objects` are device linked for relocatable device code. 31 | dlto: Whether the device link time optimization is enabled. 32 | 33 | Returns: 34 | A device-linked object `File`. 
35 | """ 36 | cuda_feature_config = cuda_helper.configure_features(ctx, cuda_toolchain, requested_features = [ACTION_NAMES.device_link]) 37 | if cuda_helper.is_enabled(cuda_feature_config, "supports_compiler_device_link"): 38 | return _compiler_device_link(ctx, cuda_toolchain, cc_toolchain, cuda_feature_config, objects, common, pic = pic, rdc = rdc, dlto = dlto) 39 | elif cuda_helper.is_enabled(cuda_feature_config, "supports_wrapper_device_link"): 40 | return _wrapper_device_link(ctx, cuda_toolchain, cc_toolchain, objects, common, pic = pic, rdc = rdc, dlto = dlto) 41 | else: 42 | fail("toolchain must be configured to enable feature supports_compiler_device_link or supports_wrapper_device_link.") 43 | 44 | def _compiler_device_link( 45 | ctx, 46 | cuda_toolchain, 47 | cc_toolchain, 48 | cuda_feature_config, 49 | objects, 50 | common, 51 | pic = False, 52 | rdc = False, 53 | dlto = False): 54 | """Perform compiler-supported native device link, return a dlink-ed object file. Fails unless `rdc` is True.""" 55 | if not rdc: 56 | fail("device link is only meaningful on building relocatable device code") 57 | 58 | actions = ctx.actions 59 | cc_feature_configuration = cc_common.configure_features( 60 | ctx = ctx, 61 | cc_toolchain = cc_toolchain, 62 | requested_features = ctx.features, 63 | unsupported_features = ctx.disabled_features, 64 | ) 65 | host_compiler = cc_common.get_tool_for_action(feature_configuration = cc_feature_configuration, action_name = CC_ACTION_NAMES.cpp_compile) 66 | cuda_compiler = cuda_toolchain.compiler_executable 67 | 68 | artifact_category_name = cuda_helper.get_artifact_category_from_action(ACTION_NAMES.device_link, pic, rdc) 69 | basename = ctx.attr.name + "_dlink" 70 | filename = cuda_helper.get_artifact_name(cuda_toolchain, artifact_category_name, basename) 71 | 72 | obj_file = actions.declare_file("_objs/{}/{}".format(ctx.attr.name, filename)) 73 | 74 | var = cuda_helper.create_device_link_variables( 75 | cuda_toolchain, 76 | 
cuda_feature_config, 77 | 
common.cuda_archs_info, 78 | common.sysroot, 79 | output_file = obj_file.path, 80 | host_compiler = host_compiler, 81 | host_compile_flags = common.host_compile_flags, 82 | user_link_flags = common.link_flags, 83 | use_pic = pic, 84 | ) 85 | cmd = cuda_helper.get_command_line(cuda_feature_config, ACTION_NAMES.device_link, var) 86 | env = cuda_helper.get_environment_variables(cuda_feature_config, ACTION_NAMES.device_link, var) 87 | args = actions.args() 88 | args.add_all(cmd) 89 | args.add_all(objects) # the objects to link follow the toolchain-generated flags 90 | 91 | actions.run( 92 | executable = cuda_compiler, 93 | arguments = [args], 94 | outputs = [obj_file], 95 | inputs = depset(transitive = [objects, cc_toolchain.all_files, cuda_toolchain.all_files]), 96 | env = env, 97 | mnemonic = "CudaDeviceLink", 98 | progress_message = "Device linking %{output}", 99 | ) 100 | return obj_file 101 | 102 | def _wrapper_device_link( 103 | ctx, 104 | cuda_toolchain, 105 | cc_toolchain, 106 | objects, 107 | common, 108 | pic = False, 109 | rdc = False, 110 | dlto = False): 111 | """perform bazel macro supported device link, return a dlink-ed object file""" 112 | if not rdc: 113 | fail("device link is only meaningful on building relocatable device code") 114 | 115 | cuda_toolkit = find_cuda_toolkit(ctx) 116 | 117 | actions = ctx.actions 118 | pic_suffix = "_pic" if pic else "" 119 | 120 | # Device-link to cubins for each gpu architecture. The stage1 compiled PTX is embedded in the object files. 121 | # We don't need to do any thing about it, presumably. 
122 | register_h = None 123 | cubins = [] 124 | images = [] 125 | obj_args = actions.args() 126 | obj_args.add_all(objects) 127 | if len(common.cuda_archs_info.arch_specs) == 0: 128 | fail('cuda toolchain "' + cuda_toolchain.name + '" is configured to enable feature supports_wrapper_device_link,' + 129 | " at least one cuda arch must be specified.") 130 | for arch_spec in common.cuda_archs_info.arch_specs: 131 | for stage2_arch in arch_spec.stage2_archs: 132 | if stage2_arch.gpu: 133 | arch = "sm_" + stage2_arch.arch 134 | elif stage2_arch.lto: 135 | arch = "lto_" + stage2_arch.arch 136 | else: 137 | # PTX is JIT-linked at runtime 138 | continue 139 | 140 | register_h = ctx.actions.declare_file("_dlink{suffix}/{0}/{0}_register_{1}.h".format(ctx.attr.name, arch, suffix = pic_suffix)) 141 | cubin = ctx.actions.declare_file("_dlink{suffix}/{0}/{0}_{1}.cubin".format(ctx.attr.name, arch, suffix = pic_suffix)) 142 | ctx.actions.run( 143 | outputs = [register_h, cubin], 144 | inputs = objects, 145 | executable = cuda_toolkit.nvlink, 146 | arguments = [ 147 | "--arch=" + arch, 148 | "--register-link-binaries=" + register_h.path, 149 | "--output-file=" + cubin.path, 150 | obj_args, 151 | ], 152 | mnemonic = "nvlink", 153 | ) 154 | cubins.append(cubin) 155 | images.append("--image=profile={},file={}".format(arch, cubin.path)) 156 | 157 | # Generate fatbin header from all cubins. 
158 | fatbin = ctx.actions.declare_file("_dlink{suffix}/{0}/{0}.fatbin".format(ctx.attr.name, suffix = pic_suffix)) 159 | fatbin_h = ctx.actions.declare_file("_dlink{suffix}/{0}/{0}_fatbin.h".format(ctx.attr.name, suffix = pic_suffix)) 160 | 161 | arguments = [ 162 | "-64", 163 | "--cmdline=--compile-only", 164 | "--link", 165 | "--compress-all", 166 | "--create=" + fatbin.path, 167 | "--embedded-fatbin=" + fatbin_h.path, 168 | ] 169 | bin2c = cuda_toolkit.bin2c 170 | if (cuda_toolkit.version_major, cuda_toolkit.version_minor) <= (10, 1): 171 | arguments.append("--bin2c-path=%s" % bin2c.dirname) 172 | ctx.actions.run( 173 | outputs = [fatbin, fatbin_h], 174 | inputs = cubins, 175 | executable = cuda_toolkit.fatbinary, 176 | arguments = arguments + images, 177 | tools = [bin2c], 178 | mnemonic = "fatbinary", 179 | ) 180 | 181 | # Generate the source file #including the headers generated above. 182 | fatbin_c = ctx.actions.declare_file("_dlink{suffix}/{0}/{0}.cu".format(ctx.attr.name, suffix = pic_suffix)) 183 | ctx.actions.expand_template( 184 | output = fatbin_c, 185 | template = cuda_toolkit.link_stub, 186 | substitutions = { 187 | "REGISTERLINKBINARYFILE": '"{}"'.format(register_h.short_path), 188 | "FATBINFILE": '"{}"'.format(fatbin_h.short_path), 189 | }, 190 | ) 191 | 192 | # cc_common.compile will cause file conflict for pic and non-pic objects, 193 | # and it accepts only one set of src files. But pic fatbin_c and non-pic 194 | # fatbin_c have different compilation trajectories. This makes me feel bad. 195 | # Just avoid cc_common.compile at all. 196 | compile_common = cuda_helper.create_common_info( 197 | # this is useless 198 | cuda_archs_info = common.cuda_archs_info, 199 | headers = [fatbin_h, register_h], 200 | defines = [ 201 | # Silence warning about including internal header. 202 | "__CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__", 203 | # Macros that need to be defined starting with CUDA 10. 
204 | "__NV_EXTRA_INITIALIZATION=", 205 | "__NV_EXTRA_FINALIZATION=", 206 | ], 207 | includes = common.includes, 208 | system_includes = common.system_includes, 209 | quote_includes = common.quote_includes, 210 | # suppress cuda mode as c++ mode 211 | compile_flags = ["-x", "c++"], 212 | host_compile_flags = common.host_compile_flags, 213 | ) 214 | ret = compile(ctx, cuda_toolchain, cc_toolchain, srcs = [fatbin_c], common = compile_common, pic = pic, rdc = rdc, _prefix = "_objs/_dlink") 215 | return ret[0] 216 | -------------------------------------------------------------------------------- /cuda/private/artifact_categories.bzl: -------------------------------------------------------------------------------- 1 | OBJECT_FILE = "object_file" 2 | PIC_OBJECT_FILE = "pic_object_file" 3 | RDC_OBJECT_FILE = "rdc_object_file" 4 | RDC_PIC_OBJECT_FILE = "rdc_pic_object_file" 5 | 6 | ARTIFACT_CATEGORIES = struct( 7 | object_file = OBJECT_FILE, 8 | pic_object_file = PIC_OBJECT_FILE, 9 | rdc_object_file = RDC_OBJECT_FILE, 10 | rdc_pic_object_file = RDC_PIC_OBJECT_FILE, 11 | ) 12 | -------------------------------------------------------------------------------- /cuda/private/compat.bzl: -------------------------------------------------------------------------------- 1 | _is_attr_string_keyed_label_dict_available = getattr(attr, "string_keyed_label_dict", None) != None 2 | _is_bzlmod_enabled = str(Label("//:invalid")).startswith("@@") 3 | 4 | def _attr(*args, **kwargs): 5 | """Compatibility layer for attr.string_keyed_label_dict(...)""" 6 | if _is_attr_string_keyed_label_dict_available: 7 | return attr.string_keyed_label_dict(*args, **kwargs) 8 | else: 9 | return attr.string_dict(*args, **kwargs) 10 | 11 | def _repo_str(repo_str_or_repo_label): 12 | """Get mapped repo as string. 
13 | 14 | Args: 15 | repo_str_or_repo_label: `"@repo"` or `Label("@repo")` """ 16 | if type(repo_str_or_repo_label) == "Label": 17 | canonical_repo_name = repo_str_or_repo_label.repo_name 18 | repo_str = ("@@{}" if _is_bzlmod_enabled else "@{}").format(canonical_repo_name) 19 | return repo_str 20 | else: 21 | return repo_str_or_repo_label 22 | 23 | components_mapping_compat = struct( 24 | attr = _attr, 25 | repo_str = _repo_str, 26 | ) 27 | -------------------------------------------------------------------------------- /cuda/private/defs.bzl: -------------------------------------------------------------------------------- 1 | """private""" 2 | 3 | def _requires_rules_are_enabled(): 4 | return select({ 5 | "@rules_cuda//cuda:is_enabled": [], 6 | "//conditions:default": ["@rules_cuda//cuda:rules_are_enabled"], 7 | }) 8 | 9 | def _requires_valid_toolchain_is_configured(): 10 | return select({ 11 | "@rules_cuda//cuda:is_valid_toolchain_found": [], 12 | "//conditions:default": ["@rules_cuda//cuda:valid_toolchain_is_configured"], 13 | }) 14 | 15 | def requires_cuda(): 16 | """Returns constraint_setting that is satisfied if: 17 | 18 | * rules are enabled and 19 | * a valid toolchain is configured. 20 | 21 | Add to 'target_compatible_with' attribute to mark a target incompatible when 22 | the conditions are not satisfied. Incompatible targets are excluded 23 | from bazel target wildcards and fail to build if requested explicitly. 24 | """ 25 | return _requires_rules_are_enabled() + _requires_valid_toolchain_is_configured() 26 | -------------------------------------------------------------------------------- /cuda/private/macros/cuda_binary.bzl: -------------------------------------------------------------------------------- 1 | load("//cuda/private:rules/cuda_library.bzl", _cuda_library = "cuda_library") 2 | 3 | def cuda_binary(name, **attrs): 4 | """A macro wraps cuda_library and cc_binary to ensure the binary is compiled with the CUDA compiler. 
5 | 6 | Notes: 7 | host_copts, host_defines, host_local_defines and host_linkopts will be used for cc_binary and renamed without "host_" prefix 8 | 9 | Args: 10 | name: A unique name for this target (cc_binary). 11 | **attrs: attrs of cc_binary and cuda_library. 12 | """ 13 | cuda_library_only_attrs = ["deps", "srcs", "hdrs", "alwayslink", "rdc", "ptxasopts"] 14 | cuda_library_only_attrs_defaults = { 15 | "alwayslink": True, 16 | } 17 | rename_attrs = { 18 | # for cc_binary 19 | "host_copts": "copts", 20 | "host_defines": "defines", 21 | "host_local_defines": "local_defines", 22 | "host_linkopts": "linkopts", 23 | } 24 | 25 | # https://bazel.build/reference/be/common-definitions?hl=en#common-attributes-binaries 26 | cc_binary_only_attrs = ["args", "env", "output_licenses"] 27 | 28 | cuda_library_attrs = {k: v for k, v in attrs.items() if k not in cc_binary_only_attrs} 29 | for attr in cuda_library_only_attrs_defaults: 30 | if attr not in cuda_library_attrs: 31 | cuda_library_attrs[attr] = cuda_library_only_attrs_defaults[attr] 32 | 33 | cuda_library_name = "_" + name 34 | _cuda_library( 35 | name = cuda_library_name, 36 | **cuda_library_attrs 37 | ) 38 | 39 | cc_attrs = {k: v for k, v in attrs.items() if k not in cuda_library_only_attrs} 40 | for src, dst in rename_attrs.items(): 41 | if dst in cc_attrs: 42 | cc_attrs.pop(dst) 43 | if src in cc_attrs: 44 | cc_attrs[dst] = cc_attrs.pop(src) 45 | 46 | native.cc_binary( 47 | name = name, 48 | deps = [cuda_library_name], 49 | **cc_attrs 50 | ) 51 | -------------------------------------------------------------------------------- /cuda/private/macros/cuda_test.bzl: -------------------------------------------------------------------------------- 1 | load("//cuda/private:rules/cuda_library.bzl", _cuda_library = "cuda_library") 2 | 3 | def cuda_test(name, **attrs): 4 | """A macro wraps cuda_library and cc_test to ensure the test is compiled with the CUDA compiler. 
5 | 6 | Notes: 7 | host_copts, host_defines, host_local_defines and host_linkopts will be used for cc_test and renamed without "host_" prefix 8 | 9 | Args: 10 | name: A unique name for this target (cc_test). 11 | **attrs: attrs of cc_test and cuda_library. 12 | """ 13 | cuda_library_only_attrs = ["deps", "srcs", "hdrs", "testonly", "alwayslink", "rdc", "ptxasopts"] 14 | cuda_library_only_attrs_defaults = { 15 | "testonly": True, 16 | "alwayslink": True, 17 | } 18 | rename_attrs = { 19 | # for cc_test 20 | "host_copts": "copts", 21 | "host_defines": "defines", 22 | "host_local_defines": "local_defines", 23 | "host_linkopts": "linkopts", 24 | } 25 | 26 | # https://bazel.build/reference/be/common-definitions?hl=en#common-attributes-tests 27 | cc_test_only_attrs = ["args", "env", "env_inherit", "size", "timeout", "flaky", "shard_count", "local", "data"] 28 | 29 | cuda_library_attrs = {k: v for k, v in attrs.items() if k not in cc_test_only_attrs} 30 | for attr in cuda_library_only_attrs_defaults: 31 | if attr not in cuda_library_attrs: 32 | cuda_library_attrs[attr] = cuda_library_only_attrs_defaults[attr] 33 | 34 | cuda_library_name = "_" + name 35 | _cuda_library( 36 | name = cuda_library_name, 37 | **cuda_library_attrs 38 | ) 39 | 40 | cc_attrs = {k: v for k, v in attrs.items() if k not in cuda_library_only_attrs} 41 | for src, dst in rename_attrs.items(): 42 | if dst in cc_attrs: 43 | cc_attrs.pop(dst) 44 | if src in cc_attrs: 45 | cc_attrs[dst] = cc_attrs.pop(src) 46 | 47 | native.cc_test( 48 | name = name, 49 | deps = [cuda_library_name], 50 | **cc_attrs 51 | ) 52 | -------------------------------------------------------------------------------- /cuda/private/os_helpers.bzl: -------------------------------------------------------------------------------- 1 | load("@bazel_skylib//lib:paths.bzl", "paths") 2 | 3 | def if_linux(if_true, if_false = []): 4 | return select({ 5 | "@platforms//os:linux": if_true, 6 | "//conditions:default": if_false, 7 | }) 8 | 9 | def 
if_windows(if_true, if_false = []): 10 | return select({ 11 | "@platforms//os:windows": if_true, 12 | "//conditions:default": if_false, 13 | }) 14 | 15 | def cc_import_versioned_sos(name, shared_library): 16 | """Creates a cc_library that depends on all versioned .so files with the given prefix. 17 | 18 | If `shared_library` is path/to/foo.so, and it is a symlink to a versioned file such as foo.so.X.Y, 19 | this should be used instead of cc_import. 20 | The versioned files are typically needed at runtime, but not at build time. 21 | 22 | Args: 23 | name: Name of the cc_library. 24 | shared_library: Prefix of the versioned .so files. 25 | """ 26 | 27 | # NOTE: only empty when the component is not installed on the system, say, cublas is not installed with apt-get 28 | so_paths = native.glob([shared_library + "*"], allow_empty = True) 29 | 30 | [native.cc_import( 31 | name = paths.basename(p), 32 | shared_library = p, 33 | target_compatible_with = ["@platforms//os:linux"], 34 | ) for p in so_paths] 35 | 36 | native.cc_library( 37 | name = name, 38 | deps = [":%s" % paths.basename(p) for p in so_paths], 39 | ) 40 | -------------------------------------------------------------------------------- /cuda/private/providers.bzl: -------------------------------------------------------------------------------- 1 | """Defines all providers that are used in this repo.""" 2 | 3 | cuda_archs = [ 4 | "30", 5 | "32", 6 | "35", 7 | "37", 8 | "50", 9 | "52", 10 | "53", 11 | "60", 12 | "61", 13 | "62", 14 | "70", 15 | "72", 16 | "75", 17 | "80", 18 | "86", 19 | "87", 20 | "89", 21 | "90", 22 | "90a", 23 | "100", 24 | "100a", 25 | "101", 26 | "101a", 27 | "120", 28 | "120a", 29 | ] 30 | 31 | Stage2ArchInfo = provider( 32 | """Provides the information of how the stage 2 compilation is carried out. 33 | 34 | One and only one of `virtual`, `gpu` and `lto` must be set to True. For example, if `arch` is set to `80` and `virtual` is `True`, then a 35 | ptx embedding process is carried out for `compute_80`. 
Multiple `Stage2ArchInfo` can be used for specifying how a stage 1 result is 36 | transformed into its final form.""", 37 | fields = { 38 | "arch": "str, arch number", 39 | "virtual": "bool, use virtual arch, default False", 40 | "gpu": "bool, use gpu arch, default False", 41 | "lto": "bool, use lto, default False", 42 | }, 43 | ) 44 | 45 | ArchSpecInfo = provider( 46 | """Provides the information of how [GPU compilation](https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#gpu-compilation) 47 | is carried out of a single virtual architecture.""", 48 | fields = { 49 | "stage1_arch": "A virtual architecture, str, arch number only", 50 | "stage2_archs": "A list of virtual or gpu architecture, list of Stage2ArchInfo", 51 | }, 52 | ) 53 | 54 | CudaArchsInfo = provider( 55 | """Provides a list of CUDA archs to compile for. 56 | 57 | Read the whole [Chapter 5 of CUDA Compiler Driver NVCC Reference Guide](https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#gpu-compilation) 58 | if more detail is needed.""", 59 | fields = { 60 | "arch_specs": "A list of ArchSpecInfo", 61 | }, 62 | ) 63 | 64 | CudaInfo = provider( 65 | """Provides cuda build artifacts that can be consumed by device linking or linking process. 66 | 67 | This provider is analog to [CcInfo](https://bazel.build/rules/lib/CcInfo) but only contains necessary information for 68 | linking in a flat structure. Objects are grouped by direct and transitive, because we have no way to split them again 69 | if merged a single depset. 70 | """, 71 | fields = { 72 | "defines": "A depset of strings. It is used for the compilation during device linking.", 73 | # direct only: 74 | "objects": "A depset of objects. Direct artifacts of the rule.", # but not rdc and pic 75 | "pic_objects": "A depset of position independent code objects. Direct artifacts of the rule.", # but not rdc 76 | "rdc_objects": "A depset of relocatable device code objects. 
Direct artifacts of the rule.", # but not pic 77 | "rdc_pic_objects": "A depset of relocatable device code and position independent code objects. Direct artifacts of the rule.", 78 | # transitive archive only (cuda_objects): 79 | "archive_objects": "A depset of rdc objects. cuda_objects only. Gathered from the transitive dependencies for archiving.", 80 | "archive_pic_objects": "A depset of rdc pic objects. cuda_objects only. Gathered from the transitive dependencies for archiving.", 81 | "archive_rdc_objects": "A depset of rdc objects. cuda_objects only. Gathered from the transitive dependencies for archiving or device linking.", 82 | "archive_rdc_pic_objects": "A depset of rdc pic objects. cuda_objects only. Gathered from the transitive dependencies for archiving or device linking.", 83 | 84 | # transitive dlink only (cuda_library): 85 | # NOTE: ideally, we can use the archived library to do the device linking, but the nvlink is not happy with library with *_dlink.o included 86 | "dlink_rdc_objects": "A depset of rdc objects. cuda_library only. Gathered from the transitive dependencies for device linking.", 87 | "dlink_rdc_pic_objects": "A depset of rdc pic objects. cuda_library only. 
Gathered from the transitive dependencies for device linking.", 88 | }, 89 | ) 90 | 91 | CudaToolkitInfo = provider( 92 | """Provides the information of CUDA Toolkit.""", 93 | fields = { 94 | "path": "string of path to cuda toolkit root", 95 | "version_major": "int of the cuda toolkit major version, e.g, 11 for 11.6", 96 | "version_minor": "int of the cuda toolkit minor version, e.g, 6 for 11.6", 97 | "nvlink": "File to the nvlink executable", 98 | "link_stub": "File to the link.stub file", 99 | "bin2c": "File to the bin2c executable", 100 | "fatbinary": "File to the fatbinary executable", 101 | }, 102 | ) 103 | 104 | CudaToolchainConfigInfo = provider( 105 | """Provides the information of what the toolchain is and how the toolchain is configured.""", 106 | fields = { 107 | "action_configs": "A list of action_configs.", 108 | "artifact_name_patterns": "A list of artifact_name_patterns.", 109 | "cuda_toolkit": "A target that provides a `CudaToolkitInfo`", 110 | "features": "A list of features.", 111 | "toolchain_identifier": "nvcc or clang", 112 | }, 113 | ) 114 | -------------------------------------------------------------------------------- /cuda/private/rules/common.bzl: -------------------------------------------------------------------------------- 1 | ALLOW_CUDA_HDRS = [ 2 | ".cuh", 3 | ".h", 4 | ".hpp", 5 | ".hh", 6 | ".inl", 7 | ] 8 | 9 | ALLOW_CUDA_SRCS = [ 10 | ".cc", 11 | ".cpp", 12 | ".cu", 13 | ] 14 | -------------------------------------------------------------------------------- /cuda/private/rules/cuda_objects.bzl: -------------------------------------------------------------------------------- 1 | load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain", "use_cpp_toolchain") 2 | load("//cuda/private:actions/compile.bzl", "compile") 3 | load("//cuda/private:cuda_helper.bzl", "cuda_helper") 4 | load("//cuda/private:providers.bzl", "CudaInfo") 5 | load("//cuda/private:rules/common.bzl", "ALLOW_CUDA_HDRS", "ALLOW_CUDA_SRCS") 6 | 
load("//cuda/private:toolchain.bzl", "find_cuda_toolchain", "use_cuda_toolchain") 7 | 8 | def _cuda_objects_impl(ctx): 9 | attr = ctx.attr 10 | cuda_helper.check_srcs_extensions(ctx, ALLOW_CUDA_SRCS + ALLOW_CUDA_HDRS, "cuda_object") 11 | 12 | cc_toolchain = find_cpp_toolchain(ctx) 13 | cuda_toolchain = find_cuda_toolchain(ctx) 14 | 15 | common = cuda_helper.create_common(ctx) 16 | 17 | # flatten first, so that non-unique basenames can be properly deduplicated 18 | src_files = [] 19 | for src in ctx.attr.srcs: 20 | src_files.extend(src[DefaultInfo].files.to_list()) 21 | 22 | # merge deps' direct objects and archive objects as our archive objects 23 | archive_objects = depset(transitive = [dep[CudaInfo].objects for dep in attr.deps if CudaInfo in dep] + 24 | [dep[CudaInfo].archive_objects for dep in attr.deps if CudaInfo in dep]) 25 | archive_pic_objects = depset(transitive = [dep[CudaInfo].pic_objects for dep in attr.deps if CudaInfo in dep] + 26 | [dep[CudaInfo].archive_pic_objects for dep in attr.deps if CudaInfo in dep]) 27 | archive_rdc_objects = depset(transitive = [dep[CudaInfo].rdc_objects for dep in attr.deps if CudaInfo in dep] + 28 | [dep[CudaInfo].archive_rdc_objects for dep in attr.deps if CudaInfo in dep]) 29 | archive_rdc_pic_objects = depset(transitive = [dep[CudaInfo].rdc_pic_objects for dep in attr.deps if CudaInfo in dep] + 30 | [dep[CudaInfo].archive_rdc_pic_objects for dep in attr.deps if CudaInfo in dep]) 31 | 32 | # direct outputs 33 | objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = False, rdc = False)) 34 | pic_objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = True, rdc = False)) 35 | rdc_objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = False, rdc = True)) 36 | rdc_pic_objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = True, rdc = True)) 37 | 38 | compilation_ctx = cc_common.create_compilation_context( 
39 | headers = common.headers, 40 | includes = depset(common.includes), 41 | system_includes = depset(common.system_includes), 42 | quote_includes = depset(common.quote_includes), 43 | defines = depset(common.host_defines), 44 | local_defines = depset(common.host_local_defines), 45 | ) 46 | 47 | cc_info = cc_common.merge_cc_infos( 48 | direct_cc_infos = [CcInfo(compilation_context = compilation_ctx)], 49 | cc_infos = [common.transitive_cc_info], 50 | ) 51 | 52 | return [ 53 | # default output is only enabled for rdc_objects, otherwise, when you build with 54 | # 55 | # > bazel build //cuda_objects/that/needs/rdc/... 56 | # 57 | # compiling errors might be trigger due to objects and pic_objects been built if srcs require device link 58 | DefaultInfo( 59 | files = depset(transitive = [ 60 | # objects, 61 | # pic_objects, 62 | rdc_objects, 63 | # rdc_pic_objects, 64 | ]), 65 | ), 66 | OutputGroupInfo( 67 | objects = objects, 68 | pic_objects = pic_objects, 69 | rdc_objects = rdc_objects, 70 | rdc_pic_objects = rdc_pic_objects, 71 | ), 72 | CcInfo( 73 | compilation_context = cc_info.compilation_context, 74 | linking_context = cc_info.linking_context, 75 | ), 76 | cuda_helper.create_cuda_info( 77 | defines = depset(common.defines), 78 | objects = objects, 79 | pic_objects = pic_objects, 80 | rdc_objects = rdc_objects, 81 | rdc_pic_objects = rdc_pic_objects, 82 | archive_objects = archive_objects, 83 | archive_pic_objects = archive_pic_objects, 84 | archive_rdc_objects = archive_rdc_objects, 85 | archive_rdc_pic_objects = archive_rdc_pic_objects, 86 | ), 87 | ] 88 | 89 | cuda_objects = rule( 90 | doc = """This rule produces incomplete object files that can only be consumed by `cuda_library`. 
It is created for relocatable device 91 | code and device link time optimization source files.""", 92 | implementation = _cuda_objects_impl, 93 | attrs = { 94 | "srcs": attr.label_list(allow_files = ALLOW_CUDA_SRCS + ALLOW_CUDA_HDRS), 95 | "hdrs": attr.label_list(allow_files = ALLOW_CUDA_HDRS), 96 | "deps": attr.label_list(providers = [[CcInfo], [CudaInfo]]), 97 | "includes": attr.string_list(doc = "List of include dirs to be added to the compile line."), 98 | # host_* attrs will be passed transitively to cc_* and cuda_* targets 99 | "host_copts": attr.string_list(doc = "Add these options to the CUDA host compilation command."), 100 | "host_defines": attr.string_list(doc = "List of defines to add to the compile line."), 101 | "host_local_defines": attr.string_list(doc = "List of defines to add to the compile line, but only apply to this rule."), 102 | # non-host attrs will be passed transitively to cuda_* targets only. 103 | "copts": attr.string_list(doc = "Add these options to the CUDA device compilation command."), 104 | "defines": attr.string_list(doc = "List of defines to add to the compile line."), 105 | "local_defines": attr.string_list(doc = "List of defines to add to the compile line, but only apply to this rule."), 106 | "ptxasopts": attr.string_list(doc = "Add these flags to the ptxas command."), 107 | "_cc_toolchain": attr.label(default = "@bazel_tools//tools/cpp:current_cc_toolchain"), # legacy behaviour 108 | "_default_cuda_copts": attr.label(default = "//cuda:copts"), 109 | "_default_host_copts": attr.label(default = "//cuda:host_copts"), 110 | "_default_cuda_archs": attr.label(default = "//cuda:archs"), 111 | }, 112 | fragments = ["cpp"], 113 | toolchains = use_cpp_toolchain() + use_cuda_toolchain(), 114 | provides = [DefaultInfo, OutputGroupInfo, CcInfo, CudaInfo], 115 | ) 116 | -------------------------------------------------------------------------------- /cuda/private/rules/cuda_toolkit_info.bzl: 
-------------------------------------------------------------------------------- 1 | load("//cuda/private:providers.bzl", "CudaToolkitInfo") 2 | 3 | def _impl(ctx): 4 | version_major, version_minor = ctx.attr.version.split(".")[:2] 5 | return CudaToolkitInfo( 6 | path = ctx.attr.path, 7 | version_major = int(version_major), 8 | version_minor = int(version_minor), 9 | nvlink = ctx.file.nvlink, 10 | link_stub = ctx.file.link_stub, 11 | bin2c = ctx.file.bin2c, 12 | fatbinary = ctx.file.fatbinary, 13 | ) 14 | 15 | cuda_toolkit_info = rule( 16 | doc = """This rule provides CudaToolkitInfo.""", 17 | implementation = _impl, 18 | attrs = { 19 | "path": attr.string(mandatory = True, doc = "Root path to the CUDA Toolkit."), 20 | "version": attr.string(mandatory = True, doc = "Version of the CUDA Toolkit."), 21 | "nvlink": attr.label(allow_single_file = True, doc = "The nvlink executable."), 22 | "link_stub": attr.label(allow_single_file = True, doc = "The link.stub text file."), 23 | "bin2c": attr.label(allow_single_file = True, doc = "The bin2c executable."), 24 | "fatbinary": attr.label(allow_single_file = True, doc = "The fatbinary executable."), 25 | }, 26 | provides = [CudaToolkitInfo], 27 | ) 28 | -------------------------------------------------------------------------------- /cuda/private/rules/flags.bzl: -------------------------------------------------------------------------------- 1 | load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo") 2 | load("//cuda/private:cuda_helper.bzl", "cuda_helper") 3 | load("//cuda/private:providers.bzl", "CudaArchsInfo") 4 | 5 | def _cuda_archs_flag_impl(ctx): 6 | specs_str = ctx.build_setting_value 7 | return CudaArchsInfo(arch_specs = cuda_helper.get_arch_specs(specs_str)) 8 | 9 | cuda_archs_flag = rule( 10 | doc = """A build setting for specifying cuda archs to compile for. 
11 | 12 | To retain the flexibility of NVCC, the [extended notation](https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#extended-notation) is adopted. 13 | 14 | When passing cuda_archs from commandline, its spec grammar is as follows: 15 | 16 | ARCH_SPECS ::= ARCH_SPEC [ ';' ARCH_SPECS ] 17 | ARCH_SPEC ::= [ VIRTUAL_ARCH ':' ] GPU_ARCHS 18 | GPU_ARCHS ::= GPU_ARCH [ ',' GPU_ARCHS ] 19 | GPU_ARCH ::= 'sm_' ARCH_NUMBER 20 | | 'lto_' ARCH_NUMBER 21 | | VIRTUAL_ARCH 22 | VIRTUAL_ARCH ::= 'compute_' ARCH_NUMBER 23 | | 'lto_' ARCH_NUMBER 24 | ARCH_NUMBER ::= (a string in predefined cuda_archs list) 25 | 26 | E.g.: 27 | 28 | - `compute_80:sm_80,sm_86`: 29 | Use `compute_80` PTX, generate cubin with `sm_80` and `sm_86`, no PTX embedded 30 | - `compute_80:compute_80,sm_80,sm_86`: 31 | Use `compute_80` PTX, generate cubin with `sm_80` and `sm_86`, PTX embedded 32 | - `compute_80:compute_80`: 33 | Embed `compute_80` PTX, fully rely on `ptxas` 34 | - `sm_80,sm_86`: 35 | Same as `compute_80:sm_80,sm_86`, the arch with minimum integer value will be automatically populated. 36 | - `sm_80;sm_86`: 37 | Two specs used. 
- `compute_80`:
    Same as `compute_80:compute_80`

Best Practices:

- Library supports a full range of archs from xx to yy, you should embed the yy PTX
- Library supports a sparse range of archs from xx to yy, you should embed the xx PTX""",
    implementation = _cuda_archs_flag_impl,
    build_setting = config.string(flag = True),
    provides = [CudaArchsInfo],
)

def _repeatable_string_flag_impl(ctx):
    """Exposes the accumulated flag values as `BuildSettingInfo`.

    A value of `[""]` is normalized to an empty list.
    """
    flags = ctx.build_setting_value
    if (flags == [""]):
        flags = []
    return BuildSettingInfo(value = flags)

repeatable_string_flag = rule(
    implementation = _repeatable_string_flag_impl,
    build_setting = config.string(flag = True, allow_multiple = True),
    provides = [BuildSettingInfo],
)

--------------------------------------------------------------------------------
/cuda/private/template_helper.bzl:
--------------------------------------------------------------------------------
load("//cuda/private:compat.bzl", "components_mapping_compat")
load("//cuda/private:templates/registry.bzl", "REGISTRY")

def _to_forward_slash(s):
    """Returns `s` with every backslash replaced by a forward slash."""
    return s.replace("\\", "/")

def _is_linux(ctx):
    """Returns True when `ctx.os.name` starts with "linux" (compared in lower case, like `_is_windows`)."""
    return ctx.os.name.lower().startswith("linux")

def _is_windows(ctx):
    """Returns True when `ctx.os.name` starts with "windows" (compared in lower case)."""
    return ctx.os.name.lower().startswith("windows")

def _generate_build_impl(repository_ctx, libpath, components, is_cuda_repo, is_deliverable):
    """Stitches template fragments into `BUILD.tpl` and expands it into `BUILD`.

    Args:
        repository_ctx: repository_ctx
        libpath: substitution of %{libpath}
        components: the components to generate for. It is iterated in every mode; it is
            indexed by component name and `.keys()` is called on it in the deliverable modes.
        is_cuda_repo: True for `@cuda` generation, False for `@cuda_<component>` generation.
        is_deliverable: see `_generate_build`.
    """

    # stitch template fragment
    fragments = [
        Label("//cuda/private:templates/BUILD.cuda_shared"),
        Label("//cuda/private:templates/BUILD.cuda_headers"),
        Label("//cuda/private:templates/BUILD.cuda_build_setting"),
    ]
    if is_cuda_repo and not is_deliverable:  # generate `@cuda//BUILD` for local host CTK
        fragments.extend([Label("//cuda/private:templates/BUILD.{}".format(c)) for c in components])
    elif is_cuda_repo and is_deliverable:  # generate `@cuda//BUILD` for CTK with deliverables
        pass
    elif not is_cuda_repo and is_deliverable:  # generate `@cuda_<component>//BUILD` for a deliverable
        if len(components) != 1:
            fail("one deliverable at a time")
        fragments.append(Label("//cuda/private:templates/BUILD.{}".format(components.keys()[0])))
    else:
        fail("unreachable")

    template_content = []
    for frag in fragments:
        template_content.append("# Generated from fragment " + str(frag))
        template_content.append(repository_ctx.read(frag))

    if is_cuda_repo and is_deliverable:  # generate `@cuda//BUILD` for CTK with deliverables
        # `@cuda` only forwards to the per-component repositories: one alias per registered target.
        for comp in components:
            for target in REGISTRY[comp]:
                repo = components_mapping_compat.repo_str(components[comp])
                line = 'alias(name = "{target}", actual = "{repo}//:{target}")'.format(target = target, repo = repo)
                template_content.append(line)

            # add an empty line to separate aliased targets from different components
            template_content.append("")

    template_content = "\n".join(template_content)

    # The stitched text still contains placeholders, so it is written out as an
    # intermediate template and then expanded into the final BUILD file.
    template_path = repository_ctx.path("BUILD.tpl")
    repository_ctx.file(template_path, content = template_content, executable = False)

    substitutions = {
        "%{component_name}": "cuda" if is_cuda_repo else components.keys()[0],
        "%{libpath}": libpath,
    }
    repository_ctx.template("BUILD", template_path, substitutions = substitutions, executable = False)

def _generate_build(repository_ctx, libpath, components = None, is_cuda_repo = True, is_deliverable = False):
    """Generate `@cuda//BUILD` or `@cuda_<component>//BUILD`

    Notes:
        - is_cuda_repo==False and is_deliverable==False is an error
        - is_cuda_repo==True and is_deliverable==False generate `@cuda//BUILD` for local host CTK
        - is_cuda_repo==True and is_deliverable==True generate `@cuda//BUILD` for CTK with deliverables
        - is_cuda_repo==False and is_deliverable==True generate `@cuda_<component>//BUILD` for a deliverable

    Args:
        repository_ctx: repository_ctx
        libpath: substitution of %{libpath}
        components: dict[str, str], the components of CTK to be included, mapped to the repo names for the components.
            May be None only for the local host CTK mode, where it defaults to every non-empty entry of REGISTRY.
        is_cuda_repo: See Notes, True for @cuda generation, False for @cuda_<component> generation.
        is_deliverable: See Notes
    """

    # NOTE(review): the nesting of the `else` below was reconstructed from a
    # whitespace-collapsed listing; it is read as pairing with `components == None`.
    if not is_cuda_repo and not is_deliverable:
        fail("is_cuda_repo == False and is_deliverable == False is not a valid combination")

    if is_cuda_repo and not is_deliverable:
        if components == None:
            components = [c for c in REGISTRY if len(REGISTRY[c]) > 0]
        else:
            for c in components:
                if c not in REGISTRY:
                    # The component name was previously never formatted into the message.
                    fail("{} is not a valid component".format(c))
    elif components == None:
        # Both deliverable modes iterate/index `components`; fail clearly instead of with a type error.
        fail("components must be specified when is_deliverable is True")

    _generate_build_impl(repository_ctx, libpath, components, is_cuda_repo, is_deliverable)

def _generate_defs_bzl(repository_ctx, is_local_ctk):
    """Generate `defs.bzl` from `defs.bzl.tpl`

    Args:
        repository_ctx: repository_ctx
        is_local_ctk: substituted (via `str()`) for %{is_local_ctk}
    """
    tpl_label = Label("//cuda/private:templates/defs.bzl.tpl")
    substitutions = {
        "%{is_local_ctk}": str(is_local_ctk),
    }
    repository_ctx.template("defs.bzl", tpl_label, substitutions = substitutions, executable = False)

def _generate_redist_bzl(repository_ctx, component_specs, redist_version):
    """Generate `@rules_cuda_redist_json//:redist.bzl`

    Args:
        repository_ctx: repository_ctx
        component_specs: list of dict, dict keys are component_name, descriptive_name, urls, sha256,
            strip_prefix and (optionally) version
        redist_version: substitution of %{version}
    """

    rules_cuda_components_body = []
    mapping = {}

    # NOTE(review): whitespace inside this literal and inside the join separator
    # below was reconstructed from a whitespace-collapsed listing; confirm it
    # matches the indentation of the placeholder in redist.bzl.tpl.
    component_tpl = """cuda_component(
        name = "{repo_name}",
        component_name = "{component_name}",
        descriptive_name = "{descriptive_name}",
        sha256 = {sha256},
        strip_prefix = {strip_prefix},
        urls = {urls},
        version = "{version}",
    )"""

    for spec in component_specs:
        repo_name = "cuda_" + spec["component_name"]
        version = spec.get("version", None)
        if version != None:
            repo_name = repo_name + "_v" + version

        rules_cuda_components_body.append(
            component_tpl.format(
                repo_name = repo_name,
                component_name = spec["component_name"],
                descriptive_name = spec["descriptive_name"],
                sha256 = repr(spec["sha256"]),
                strip_prefix = repr(spec["strip_prefix"]),
                urls = repr(spec["urls"]),
                # `version` is optional above, so do not index spec["version"] here.
                # assumes cuda_component accepts an empty version string — TODO confirm
                version = version if version != None else "",
            ),
        )
        mapping[spec["component_name"]] = "@" + repo_name

    tpl_label = Label("//cuda/private:templates/redist.bzl.tpl")
    substitutions = {
        "%{rules_cuda_components_body}": "\n\n    ".join(rules_cuda_components_body),
        "%{components_mapping}": repr(mapping),
        "%{version}": redist_version,
    }
    repository_ctx.template("redist.bzl", tpl_label, substitutions = substitutions, executable = False)

def _generate_toolchain_build(repository_ctx, cuda):
    """Generate `toolchain/BUILD` from the nvcc toolchain template

    `BUILD.toolchain_nvcc` is used when `_is_linux` is true, `BUILD.toolchain_nvcc_msvc` otherwise.

    Args:
        repository_ctx: repository_ctx
        cuda: struct; the fields read here are path, version_major, version_minor,
            nvcc_version_major, nvcc_version_minor, nvcc_label, nvlink_label,
            link_stub_label, bin2c_label and fatbinary_label
    """
    tpl_label = Label(
        "//cuda/private:templates/BUILD.toolchain_" +
        ("nvcc" if _is_linux(repository_ctx) else "nvcc_msvc"),
    )
    substitutions = {
        "%{cuda_path}": _to_forward_slash(cuda.path) if cuda.path else "cuda-not-found",
        "%{cuda_version}": "{}.{}".format(cuda.version_major, cuda.version_minor),
        "%{nvcc_version_major}": str(cuda.nvcc_version_major),
        "%{nvcc_version_minor}": str(cuda.nvcc_version_minor),
        "%{nvcc_label}": cuda.nvcc_label,
        "%{nvlink_label}": cuda.nvlink_label,
        "%{link_stub_label}": cuda.link_stub_label,
        "%{bin2c_label}": cuda.bin2c_label,
        "%{fatbinary_label}": cuda.fatbinary_label,
    }

    # TMP takes precedence over TEMP; when neither is set no %{env_tmp} substitution is registered.
    env_tmp = repository_ctx.os.environ.get("TMP", repository_ctx.os.environ.get("TEMP", None))
    if env_tmp != None:
        substitutions["%{env_tmp}"] = _to_forward_slash(env_tmp)
    repository_ctx.template("toolchain/BUILD", tpl_label, substitutions = substitutions, executable = False)

def _generate_toolchain_clang_build(repository_ctx, cuda, clang_path_or_label):
    """Generate `toolchain/clang/BUILD` from `BUILD.toolchain_clang`

    The compiler attribute of the generated `cuda_toolchain` is chosen as follows:
        - environment variable CUDA_COMPILER_USE_CC_TOOLCHAIN == "true": `compiler_use_cc_toolchain = True`
        - clang_path_or_label starts with `//` or `@`: `compiler_label`
        - otherwise: `compiler_executable` (a placeholder path is used when clang_path_or_label is empty)

    Args:
        repository_ctx: repository_ctx
        cuda: struct; the fields read here are path, version_major, version_minor, nvcc_label,
            nvlink_label, link_stub_label, bin2c_label and fatbinary_label
        clang_path_or_label: path or label string of clang, may be None
    """
    tpl_label = Label("//cuda/private:templates/BUILD.toolchain_clang")
    compiler_attr_line = ""
    clang_path_for_subst = ""
    clang_label_for_subst = ""

    # The attribute line deliberately still contains a %{clang_label} / %{clang_path}
    # placeholder: it is filled in by a later entry of `substitutions`, which relies
    # on the entries being applied in dict order.
    compiler_use_cc_toolchain_env = repository_ctx.os.environ.get("CUDA_COMPILER_USE_CC_TOOLCHAIN", "false")
    if compiler_use_cc_toolchain_env == "true":
        compiler_attr_line = "compiler_use_cc_toolchain = True,"
    elif clang_path_or_label != None and (clang_path_or_label.startswith("//") or clang_path_or_label.startswith("@")):
        # Use compiler_label
        compiler_attr_line = 'compiler_label = "%{clang_label}",'
        clang_label_for_subst = clang_path_or_label
    else:
        # Use compiler_executable
        compiler_attr_line = 'compiler_executable = "%{clang_path}",'
        clang_path_for_subst = _to_forward_slash(clang_path_or_label) if clang_path_or_label else "cuda-clang-path-not-found"

    cuda_path_for_subst = ""
    if cuda.path:
        cuda_path_for_subst = _to_forward_slash(cuda.path)
    else:
        cuda_path_for_subst = "{}/clang_compiler_deps".format(Label("@cuda//cuda").workspace_root)

    substitutions = {
        "# %{compiler_attribute_line}": compiler_attr_line,
        "%{clang_compiler_files}": "@cuda//:compiler_deps" if cuda.path else "@cuda//clang_compiler_deps",
        "%{clang_path}": clang_path_for_subst,  # Will be empty if label is used
        "%{clang_label}": clang_label_for_subst,  # Will be empty if path is used
        "%{cuda_path}": cuda_path_for_subst,
        "%{cuda_version}": "{}.{}".format(cuda.version_major, cuda.version_minor),
        "%{nvcc_label}": cuda.nvcc_label,
        "%{nvlink_label}": cuda.nvlink_label,
        "%{link_stub_label}": cuda.link_stub_label,
        "%{bin2c_label}": cuda.bin2c_label,
        "%{fatbinary_label}": cuda.fatbinary_label,
    }

    # Drop the substitution that belongs to the form that was not chosen.
    if clang_label_for_subst:
        substitutions.pop("%{clang_path}")
    if clang_path_for_subst:
        substitutions.pop("%{clang_label}")

    repository_ctx.template(
        "toolchain/clang/BUILD",
        tpl_label,
        substitutions = substitutions,
        executable = False,
    )

template_helper = struct(
    generate_build = _generate_build,
    generate_defs_bzl = _generate_defs_bzl,
    generate_redist_bzl = _generate_redist_bzl,
    generate_toolchain_build = _generate_toolchain_build,
    generate_toolchain_clang_build = _generate_toolchain_clang_build,
)

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.cccl:
--------------------------------------------------------------------------------
cc_library(
    name = "cub",
    hdrs = glob(
        ["%{component_name}/include/cub/**"],
        allow_empty = True,
    ),
    includes = [
        "%{component_name}/include",
    ],
)

cc_library(
    name = "thrust",
    hdrs = glob(
        ["%{component_name}/include/thrust/**"],
        allow_empty = True,
    ),
    includes = [
        "%{component_name}/include",
    ],
    deps = [":cub"],
)

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.clang_compiler_deps:
--------------------------------------------------------------------------------
# clang needs these files at a single location and be passed as `cuda-path` arg.
# These include, lib, bin, nvvm files are collected from cccl, nvvm, nvcc, cudart in cuda/private/repositories.bzl
# Public filegroup over the bin/, include/, lib/ and nvvm/ trees of this repository.
filegroup(
    name = "clang_compiler_deps",
    srcs = glob([
        "bin/**",
        "include/**",
        "lib/**",
        "nvvm/**",
    ]),
    visibility = ["//visibility:public"],
)

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.cublas:
--------------------------------------------------------------------------------
# cuBLAS fragment. It has no load() statements of its own: template_helper stitches it
# after BUILD.cuda_shared, which loads cc_import_versioned_sos, if_linux and if_windows.
cc_import_versioned_sos(
    name = "cublas_so",
    shared_library = "%{component_name}/%{libpath}/libcublas.so",
)

cc_import_versioned_sos(
    name = "cublasLt_so",
    shared_library = "%{component_name}/%{libpath}/libcublasLt.so",
)

# Windows-only import libraries.
cc_import(
    name = "cublas_lib",
    interface_library = "%{component_name}/%{libpath}/x64/cublas.lib",
    system_provided = 1,
    target_compatible_with = ["@platforms//os:windows"],
)

cc_import(
    name = "cublasLt_lib",
    interface_library = "%{component_name}/%{libpath}/x64/cublasLt.lib",
    system_provided = 1,
    target_compatible_with = ["@platforms//os:windows"],
)

# Aggregate target: headers plus the Linux or Windows libraries defined above.
cc_library(
    name = "cublas",
    deps = [
        ":%{component_name}_headers",
    ] + if_linux([
        ":cublasLt_so",
        ":cublas_so",
    ]) + if_windows([
        ":cublasLt_lib",
        ":cublas_lib",
    ]),
)

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.cuda_build_setting:
--------------------------------------------------------------------------------
# Defaults to True here; BUILD.cuda_disabled declares the same setting with a False default.
bool_setting(
    name = "valid_toolchain_found",
    build_setting_default = True,
)

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.cuda_disabled:
--------------------------------------------------------------------------------
load("@bazel_skylib//rules:common_settings.bzl", "bool_setting")

package(
    default_visibility = ["//visibility:public"],
)

# Same setting name as in BUILD.cuda_build_setting, but defaulting to False.
bool_setting(
    name = "valid_toolchain_found",
    build_setting_default = False,
)

# Empty stand-in for the `compiler_deps` target.
filegroup(
    name = "compiler_deps",
    srcs = [],
)

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.cuda_headers:
--------------------------------------------------------------------------------
# Every file under the component's include/ directory; the glob may be empty.
filegroup(
    name = "%{component_name}_header_files",
    srcs = glob(
        ["%{component_name}/include/**"],
        allow_empty = True,
    ),
    visibility = ["//visibility:private"],
)

# Header-only library that the per-library fragments depend on.
cc_library(
    name = "%{component_name}_headers",
    hdrs = [":%{component_name}_header_files"],
    includes = ["%{component_name}/include"],
)

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.cuda_shared:
--------------------------------------------------------------------------------
# Shared preamble: the loads below serve the fragments that template_helper stitches
# after this one, hence the `@unused` markers.
load("@bazel_skylib//rules:common_settings.bzl", "bool_setting")  # @unused
load("@cuda//:defs.bzl", "additional_header_deps", "if_local_cuda_toolkit")  # @unused
load("@rules_cuda//cuda:defs.bzl", "cc_import_versioned_sos", "if_linux", "if_windows")  # @unused

package(
    default_visibility = ["//visibility:public"],
)

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.cudart:
--------------------------------------------------------------------------------
cc_import_versioned_sos(
    name = "cudart_so",
    shared_library = "%{component_name}/%{libpath}/libcudart.so",
)

# Linux-only static archives.
cc_library(
    name = "cudadevrt_a",
    srcs = ["%{component_name}/%{libpath}/libcudadevrt.a"],
    target_compatible_with = ["@platforms//os:linux"],
)

cc_library(
    name = "culibos_a",
    srcs = ["%{component_name}/%{libpath}/libculibos.a"],
    target_compatible_with = ["@platforms//os:linux"],
)

# Windows-only import libraries.
cc_import(
    name = "cudart_lib",
    interface_library = "%{component_name}/%{libpath}/x64/cudart.lib",
    system_provided = 1,
    target_compatible_with = ["@platforms//os:windows"],
)

cc_import(
    name = "cudadevrt_lib",
    interface_library = "%{component_name}/%{libpath}/x64/cudadevrt.lib",
    system_provided = 1,
    target_compatible_with = ["@platforms//os:windows"],
)

# Note: do not use this target directly, use the configurable label_flag
# @rules_cuda//cuda:runtime instead.
cc_library(
    name = "cuda_runtime",
    linkopts = if_linux([
        "-ldl",
        "-lpthread",
        "-lrt",
    ]),
    deps = additional_header_deps("cudart") + [
        ":%{component_name}_headers",
    ] + if_linux([
        # devrt is required for jit linking when rdc is enabled
        ":cudadevrt_a",
        ":culibos_a",
        ":cudart_so",
    ]) + if_windows([
        # devrt is required for jit linking when rdc is enabled
        ":cudadevrt_lib",
        ":cudart_lib",
    ]),
    # FIXME:
    # visibility = ["@rules_cuda//cuda:__pkg__"],
)

# Note: do not use this target directly, use the configurable label_flag
# @rules_cuda//cuda:runtime instead.
# Static CUDA runtime; its only explicit library dependency, :cudadevrt_a, is Linux-only.
cc_library(
    name = "cuda_runtime_static",
    srcs = ["%{component_name}/%{libpath}/libcudart_static.a"],
    hdrs = [":%{component_name}_header_files"],
    includes = ["%{component_name}/include"],
    linkopts = if_linux([
        "-ldl",
        "-lpthread",
        "-lrt",
    ]),
    deps = additional_header_deps("cudart") + [":cudadevrt_a"],
    # FIXME:
    # visibility = ["@rules_cuda//cuda:__pkg__"],
)

# Empty library: a runtime choice that contributes nothing.
cc_library(
    name = "no_cuda_runtime",
    # FIXME:
    # visibility = ["@rules_cuda//cuda:__pkg__"],
)

# libcuda.so is taken from the stubs/ directory.
cc_import(
    name = "cuda_so",
    shared_library = "%{component_name}/%{libpath}/stubs/libcuda.so",
    target_compatible_with = ["@platforms//os:linux"],
)

cc_import(
    name = "cuda_lib",
    interface_library = "%{component_name}/%{libpath}/x64/cuda.lib",
    system_provided = 1,
    target_compatible_with = ["@platforms//os:windows"],
)

cc_library(
    name = "cuda",
    deps = [
        ":%{component_name}_headers",
    ] + if_linux([
        ":cuda_so",
    ]) + if_windows([
        ":cuda_lib",
    ]),
)

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.cufft:
--------------------------------------------------------------------------------
cc_import_versioned_sos(
    name = "cufft_so",
    shared_library = "%{component_name}/%{libpath}/libcufft.so",
)

cc_import(
    name = "cufft_lib",
    interface_library = "%{component_name}/%{libpath}/x64/cufft.lib",
    system_provided = 1,
    target_compatible_with = ["@platforms//os:windows"],
)

cc_import_versioned_sos(
    name = "cufftw_so",
    shared_library = "%{component_name}/%{libpath}/libcufftw.so",
)

cc_import(
    name = "cufftw_lib",
    interface_library = "%{component_name}/%{libpath}/x64/cufftw.lib",
    system_provided = 1,
    target_compatible_with = ["@platforms//os:windows"],
)

# Dynamic cuFFT: links both cufft and cufftw.
cc_library(
    name = "cufft",
    deps = [
        ":%{component_name}_headers",
    ] + if_linux([
        ":cufft_so",
        ":cufftw_so",
    ]) + if_windows([
        ":cufft_lib",
        ":cufftw_lib",
    ]),
)

# Static archives are declared for Linux only.
cc_import(
    name = "cufftw_static_a",
    static_library = "%{component_name}/%{libpath}/libcufftw_static.a",
    target_compatible_with = ["@platforms//os:linux"],
)

cc_import(
    name = "cufft_static_a",
    static_library = "%{component_name}/%{libpath}/libcufft_static.a",
    target_compatible_with = ["@platforms//os:linux"],
)

cc_import(
    name = "cufft_static_nocallback_a",
    static_library = "%{component_name}/%{libpath}/libcufft_static_nocallback.a",
    target_compatible_with = ["@platforms//os:linux"],
)

cc_library(
    name = "cufftw_static",
    deps = [
        ":%{component_name}_headers",
    ] + if_linux([
        ":cufftw_static_a",
    ]),
)

cc_library(
    name = "cufft_static",
    deps = [
        ":%{component_name}_headers",
    ] + if_linux([
        ":cufftw_static_a",
        ":cufft_static_a",
    ]),
)

cc_library(
    name = "cufft_static_nocallback",
    deps = [
        ":%{component_name}_headers",
    ] + if_linux([
        ":cufftw_static_a",
        ":cufft_static_nocallback_a",
    ]),
)

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.cufile:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bazel-contrib/rules_cuda/3f72f484a8ea5969c81a857a6785ebf0ede02c0c/cuda/private/templates/BUILD.cufile

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.cupti:
--------------------------------------------------------------------------------
# CUPTI
cc_import_versioned_sos(
    name = "cupti_so",
    shared_library = "%{component_name}/**/%{libpath}/libcupti.so",
)

cc_import(
    name = "cupti_lib",
    interface_library = "%{component_name}/extras/CUPTI/lib64/cupti.lib",
    system_provided = 1,
    target_compatible_with = ["@platforms//os:windows"],
)

# Headers under extras/CUPTI/include; only the Windows branches below depend on this target.
cc_library(
    name = "cupti_headers",
    hdrs = glob(
        ["%{component_name}/extras/CUPTI/include/*.h"],
        allow_empty = True,
    ),
    includes = ["%{component_name}/extras/CUPTI/include"],
)

cc_library(
    name = "cupti",
    deps = [
        ":%{component_name}_headers",
    ] + if_linux([
        ":cupti_so",
    ]) + if_windows([
        ":cupti_headers",
        ":cupti_lib",
    ]),
)

# nvperf
cc_import(
    name = "nvperf_host_so",
    shared_library = "%{component_name}/%{libpath}/libnvperf_host.so",
    target_compatible_with = ["@platforms//os:linux"],
)

cc_import(
    name = "nvperf_host_lib",
    interface_library = "%{component_name}/extras/CUPTI/lib64/nvperf_host.lib",
    system_provided = 1,
    target_compatible_with = ["@platforms//os:windows"],
)

cc_library(
    name = "nvperf_host",
    deps = [
        ":%{component_name}_headers",
    ] + if_linux([
        ":nvperf_host_so",
    ]) + if_windows([
        ":cupti_headers",
        ":nvperf_host_lib",
    ]),
)

cc_import(
    name = "nvperf_target_so",
    shared_library = "%{component_name}/%{libpath}/libnvperf_target.so",
    target_compatible_with = ["@platforms//os:linux"],
)

cc_import(
    name = "nvperf_target_lib",
    interface_library = "%{component_name}/extras/CUPTI/lib64/nvperf_target.lib",
    system_provided = 1,
    target_compatible_with = ["@platforms//os:windows"],
)

cc_library(
    name = "nvperf_target",
    deps = [
        ":%{component_name}_headers",
    ] + if_linux([
        ":nvperf_target_so",
    ]) + if_windows([
        ":cupti_headers",
        ":nvperf_target_lib",
    ]),
)

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.curand:
--------------------------------------------------------------------------------
cc_import_versioned_sos(
    name = "curand_so",
    shared_library = "%{component_name}/%{libpath}/libcurand.so",
)

cc_import(
    name = "curand_lib",
    interface_library = "%{component_name}/%{libpath}/x64/curand.lib",
    system_provided = 1,
    target_compatible_with = ["@platforms//os:windows"],
)

# Headers plus the Linux shared object or the Windows import library.
cc_library(
    name = "curand",
    deps = [
        ":%{component_name}_headers",
    ] + if_linux([
        ":curand_so",
    ]) + if_windows([
        ":curand_lib",
    ]),
)

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.cusolver:
--------------------------------------------------------------------------------
cc_import_versioned_sos(
    name = "cusolver_so",
    shared_library = "%{component_name}/%{libpath}/libcusolver.so",
)

cc_import(
    name = "cusolver_lib",
    interface_library = "%{component_name}/%{libpath}/x64/cusolver.lib",
    system_provided = 1,
    target_compatible_with = ["@platforms//os:windows"],
)

# Headers plus the Linux shared object or the Windows import library.
cc_library(
    name = "cusolver",
    deps = [
        ":%{component_name}_headers",
    ] + if_linux([
        ":cusolver_so",
    ]) + if_windows([
        ":cusolver_lib",
    ]),
)

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.cusparse:
--------------------------------------------------------------------------------
cc_import_versioned_sos(
    name = "cusparse_so",
    shared_library = "%{component_name}/%{libpath}/libcusparse.so",
)

cc_import(
    name = "cusparse_lib",
    interface_library = "%{component_name}/%{libpath}/x64/cusparse.lib",
    system_provided = 1,
    target_compatible_with = ["@platforms//os:windows"],
)

# Headers plus the Linux shared object or the Windows import library.
cc_library(
    name = "cusparse",
    deps = [
        ":%{component_name}_headers",
    ] + if_linux([
        ":cusparse_so",
    ]) + if_windows([
        ":cusparse_lib",
    ]),
)

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.npp:
--------------------------------------------------------------------------------
# Maps each NPP library name to the headers (relative to include/) exposed by its cc_library.
_NPP_LIBS = {
    "nppc": [
        "npp.h",
        "nppcore.h",
        "nppdefs.h",
    ],
    "nppial": ["nppi_arithmetic_and_logical_operations.h"],
    "nppicc": ["nppi_color_conversion.h"],
    "nppidei": ["nppi_data_exchange_and_initialization.h"],
    "nppif": ["nppi_filtering_functions.h"],
    "nppig": ["nppi_geometry_transforms.h"],
    "nppim": ["nppi_morphological_operations.h"],
    "nppist": [
        "nppi_statistics_functions.h",
        "nppi_linear_transforms.h",
    ],
    "nppisu": ["nppi_support_functions.h"],
    "nppitc": ["nppi_threshold_and_compare_operations.h"],
    "npps": [
        "npps_arithmetic_and_logical_operations.h",
        "npps_conversion_functions.h",
        "npps_filtering_functions.h",
        "npps.h",
        "npps_initialization.h",
        "npps_statistics_functions.h",
        "npps_support_functions.h",
    ],
}

# One `<name>_so` import per library. The `{}` is filled by .format() when the generated
# BUILD file is evaluated, i.e. after the template placeholders have been substituted.
[
    cc_import_versioned_sos(
        name = name + "_so",
        shared_library = "%{component_name}/%{libpath}/lib{}.so".format(name),
    )
    for name in _NPP_LIBS.keys()
]

# One Windows-only `<name>_lib` import library per library.
[
    cc_import(
        name = name + "_lib",
        interface_library = "%{component_name}/%{libpath}/x64/{}.lib".format(name),
        system_provided = 1,
        target_compatible_with = ["@platforms//os:windows"],
    )
    for name in _NPP_LIBS.keys()
]

# One public cc_library per library; every library except nppc itself depends on :nppc.
[
    cc_library(
        name = name,
        hdrs = ["%{component_name}/include/" + hdr for hdr in hdrs],
        includes = ["%{component_name}/include"],
        visibility = ["//visibility:public"],
        deps = ([":nppc"] if name != "nppc" else []) + if_linux([
            ":{}_so".format(name),
        ]) + if_windows([
            ":{}_lib".format(name),
        ]),
    )
    for name, hdrs in _NPP_LIBS.items()
]

# Umbrella target for nppi.h that pulls in all nppi* libraries.
cc_library(
    name = "nppi",
    hdrs = ["%{component_name}/include/nppi.h"],
    deps = [
        ":nppial",
        ":nppicc",
        ":nppidei",
        ":nppif",
        ":nppig",
        ":nppim",
        ":nppist",
        ":nppisu",
        ":nppitc",
    ],
)

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.nvcc:
--------------------------------------------------------------------------------
# Headers plus the version files and the bin/, library and nvvm/ trees.
filegroup(
    name = "compiler_deps",
    srcs = [
        ":%{component_name}_header_files",
    ] + glob(
        [
            "%{component_name}/version.json",
            "%{component_name}/version.txt",
            "%{component_name}/bin/**",
            "%{component_name}/%{libpath}/**",
            "%{component_name}/nvvm/**",
        ],
        allow_empty = True,
    ),
)

# NOTE(review): despite the `_so` suffix this target imports a static archive.
cc_import(
    name = "nvptxcompiler_so",
    static_library = "%{component_name}/%{libpath}/libnvptxcompiler_static.a",
    target_compatible_with = ["@platforms//os:linux"],
)

cc_import(
    name = "nvptxcompiler_lib",
    interface_library = "%{component_name}/%{libpath}/x64/nvptxcompiler_static.lib",
    system_provided = 1,
    target_compatible_with = ["@platforms//os:windows"],
)

cc_library(
    name = "nvptxcompiler",
    srcs = [],
    hdrs = glob(
        [
            "%{component_name}/include/fatbinary_section.h",
            "%{component_name}/include/nvPTXCompiler.h",
            "%{component_name}/include/crt/*",
        ],
        allow_empty = True,
    ),
    includes = [
        "%{component_name}/include",
    ],
    visibility = ["//visibility:public"],
    deps = [] + if_linux([
        ":nvptxcompiler_so",
    ]) + if_windows([
        ":nvptxcompiler_lib",
    ]),
)

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.nvidia_fs:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bazel-contrib/rules_cuda/3f72f484a8ea5969c81a857a6785ebf0ede02c0c/cuda/private/templates/BUILD.nvidia_fs

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.nvjitlink:
--------------------------------------------------------------------------------
cc_import_versioned_sos(
    name = "nvJitLink_so",
    shared_library = "%{component_name}/%{libpath}/libnvJitLink.so",
)

cc_import(
    name = "nvJitLink_lib",
    interface_library = "%{component_name}/%{libpath}/x64/nvJitLink.lib",
    system_provided = 1,
    target_compatible_with = ["@platforms//os:windows"],
)

# Headers plus the Linux shared object or the Windows import library.
cc_library(
    name = "nvJitLink",
    deps = [
        ":%{component_name}_headers",
    ] + if_linux([
        ":nvJitLink_so",
    ]) + if_windows([
        ":nvJitLink_lib",
    ]),
)

# The static archive is declared for Linux only.
cc_import(
    name = "nvJitLink_static_a",
    static_library = "%{component_name}/%{libpath}/libnvJitLink_static.a",
    target_compatible_with = ["@platforms//os:linux"],
)

cc_library(
    name = "nvJitLink_static",
    deps = [
        ":%{component_name}_headers",
    ] + if_linux([
        ":nvJitLink_static_a",
    ]),
)

# Lower-case spellings of the two targets above.
alias(
    name = "nvjitlink",
    actual = ":nvJitLink",
)

# Must forward to the static target; it previously pointed at the dynamic :nvJitLink.
alias(
    name = "nvjitlink_static",
    actual = ":nvJitLink_static",
)

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.nvjpeg:
--------------------------------------------------------------------------------
cc_import_versioned_sos(
    name = "nvjpeg_so",
    shared_library = "%{component_name}/%{libpath}/libnvjpeg.so",
)

cc_import(
    name = "nvjpeg_lib",
    interface_library = "%{component_name}/%{libpath}/x64/nvjpeg.lib",
    system_provided = 1,
    target_compatible_with = ["@platforms//os:windows"],
)

cc_library(
    name = "nvjpeg",
    deps = [
        ":%{component_name}_headers",
    ] + if_linux([
        ":nvjpeg_so",
    ]) + if_windows([
        ":nvjpeg_lib",
    ]),
)

# The static archive is declared for Linux only.
cc_import(
    name = "nvjpeg_static_a",
    static_library = "%{component_name}/%{libpath}/libnvjpeg_static.a",
    target_compatible_with = ["@platforms//os:linux"],
)

cc_library(
    name = "nvjpeg_static",
    deps = [
        ":%{component_name}_headers",
    ] + if_linux([
        ":nvjpeg_static_a",
    ]),
)

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.nvml:
--------------------------------------------------------------------------------
# libnvidia-ml.so is taken from the stubs/ directory.
cc_import(
    name = "nvidia-ml_so",
    shared_library = "%{component_name}/%{libpath}/stubs/libnvidia-ml.so",
    target_compatible_with = ["@platforms//os:linux"],
)

cc_import(
    name = "nvml_lib",
    interface_library = "%{component_name}/%{libpath}/x64/nvml.lib",
    system_provided = 1,
    target_compatible_with = ["@platforms//os:windows"],
)

cc_library(
    name = "nvml",
    deps = [
        ":%{component_name}_headers",
    ] + if_linux([
        ":nvidia-ml_so",
    ]) + if_windows([
        ":nvml_lib",
    ]),
)

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.nvrtc:
--------------------------------------------------------------------------------
cc_import_versioned_sos(
    name = "nvrtc_so",
    shared_library = "%{component_name}/%{libpath}/libnvrtc.so",
)

cc_import(
    name = "nvrtc_lib",
    interface_library = "%{component_name}/%{libpath}/x64/nvrtc.lib",
    system_provided = 1,
    target_compatible_with = ["@platforms//os:windows"],
)

cc_library(
    name = "nvrtc",
    deps = [
        ":%{component_name}_headers",
    ] + if_linux([
        ":nvrtc_so",
    ]) + if_windows([
        ":nvrtc_lib",
    ]),
)

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.nvtx:
--------------------------------------------------------------------------------
cc_import_versioned_sos(
    name = "nvtx_so",
    shared_library = "%{component_name}/%{libpath}/libnvToolsExt.so",
)

cc_import(
    name = "nvtx_lib",
    interface_library = "%{component_name}/%{libpath}/x64/libnvToolsExt.lib",
    system_provided = 1,
    target_compatible_with = ["@platforms//os:windows"],
)

cc_library(
    name = "nvtx",
    deps = [
        ":%{component_name}_headers",
    ] + if_linux([
        ":nvtx_so",
    ]) + if_windows([
        ":nvtx_lib",
    ]),
)

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.redist_json:
--------------------------------------------------------------------------------
package(
    default_visibility = ["//visibility:public"],
)

# redist.bzl and redist.json are exposed both as filegroups and as exported files.
filegroup(
    name = "redist_bzl",
    srcs = [":redist.bzl"],
)

filegroup(
    name = "redist_json",
    srcs = [":redist.json"],
)

exports_files([
    "redist.bzl",
    "redist.json",
])

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.toolchain_clang:
--------------------------------------------------------------------------------
# This becomes the BUILD file for @cuda//toolchain/clang if LLVM is detected

load(
    "@rules_cuda//cuda:defs.bzl",
    "cuda_toolchain",
    "cuda_toolkit_info",
    cuda_toolchain_config = "cuda_toolchain_config_clang",
)

cuda_toolkit_info(
    name = "cuda-toolkit",
    bin2c = "%{bin2c_label}",
    fatbinary = "%{fatbinary_label}",
    link_stub = "%{link_stub_label}",
    nvlink = "%{nvlink_label}",
    path = "%{cuda_path}",
    version = "%{cuda_version}",
)

cuda_toolchain_config(
    name = "clang-local-config",
    cuda_toolkit = ":cuda-toolkit",
    toolchain_identifier = "clang",
)

# The placeholder comment inside this target is a substitution key: template_helper
# replaces the whole comment with the compiler attribute, so it must stay byte-identical.
cuda_toolchain(
    name = "clang-local",
    # %{compiler_attribute_line}
    compiler_files = "%{clang_compiler_files}",
    toolchain_config = ":clang-local-config",
)

toolchain(
    name = "clang-local-toolchain",
    target_settings = [
        "@rules_cuda//cuda:is_enabled",
        "@rules_cuda//cuda:compiler_is_clang",
    ],
    toolchain = ":clang-local",
    toolchain_type = "@rules_cuda//cuda:toolchain_type",
    visibility = ["//visibility:public"],
)

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.toolchain_disabled:
--------------------------------------------------------------------------------
load("@rules_cuda//cuda:defs.bzl", "cuda_toolchain_config_disabled")

# Matches when the @rules_cuda//cuda:enable flag is "False".
config_setting(
    name = "cuda_is_disabled",
    flag_values = {"@rules_cuda//cuda:enable": "False"},
)

cuda_toolchain_config_disabled(name = "disabled-local")

toolchain(
    name = "disabled-local-toolchain",
    target_settings = [":cuda_is_disabled"],
    toolchain = ":disabled-local",
    toolchain_type = "@rules_cuda//cuda:toolchain_type",
    visibility = ["//visibility:public"],
)

--------------------------------------------------------------------------------
/cuda/private/templates/BUILD.toolchain_nvcc:
--------------------------------------------------------------------------------
# This becomes the BUILD file for @cuda//toolchain/ under Linux.
2 | 3 | load( 4 | "@rules_cuda//cuda:defs.bzl", 5 | "cuda_toolchain", 6 | "cuda_toolkit_info", 7 | cuda_toolchain_config = "cuda_toolchain_config_nvcc", 8 | ) 9 | 10 | cuda_toolkit_info( 11 | name = "cuda-toolkit", 12 | bin2c = "%{bin2c_label}", 13 | fatbinary = "%{fatbinary_label}", 14 | link_stub = "%{link_stub_label}", 15 | nvlink = "%{nvlink_label}", 16 | path = "%{cuda_path}", 17 | version = "%{cuda_version}", 18 | ) 19 | 20 | cuda_toolchain_config( 21 | name = "nvcc-local-config", 22 | cuda_toolkit = ":cuda-toolkit", 23 | # int("%{foo}") instead of %{foo} to make the file valid syntactically. 24 | nvcc_version_major = int("%{nvcc_version_major}"), 25 | nvcc_version_minor = int("%{nvcc_version_minor}"), 26 | toolchain_identifier = "nvcc", 27 | ) 28 | 29 | cuda_toolchain( 30 | name = "nvcc-local", 31 | compiler_files = "@cuda//:compiler_deps", 32 | compiler_label = "%{nvcc_label}", 33 | toolchain_config = ":nvcc-local-config", 34 | ) 35 | 36 | toolchain( 37 | name = "nvcc-local-toolchain", 38 | exec_compatible_with = [ 39 | "@platforms//os:linux", 40 | ], 41 | target_compatible_with = [ 42 | "@platforms//os:linux", 43 | ], 44 | target_settings = [ 45 | "@rules_cuda//cuda:is_enabled", 46 | "@rules_cuda//cuda:compiler_is_nvcc", 47 | ], 48 | toolchain = ":nvcc-local", 49 | toolchain_type = "@rules_cuda//cuda:toolchain_type", 50 | visibility = ["//visibility:public"], 51 | ) 52 | -------------------------------------------------------------------------------- /cuda/private/templates/BUILD.toolchain_nvcc_msvc: -------------------------------------------------------------------------------- 1 | # This becomes the BUILD file for @cuda//toolchain/ under Windows. 
2 | 3 | load( 4 | "@rules_cuda//cuda:defs.bzl", 5 | "cuda_toolchain", 6 | "cuda_toolkit_info", 7 | cuda_toolchain_config = "cuda_toolchain_config_nvcc_msvc", 8 | ) 9 | 10 | cuda_toolkit_info( 11 | name = "cuda-toolkit", 12 | bin2c = "%{bin2c_label}", 13 | fatbinary = "%{fatbinary_label}", 14 | link_stub = "%{link_stub_label}", 15 | nvlink = "%{nvlink_label}", 16 | path = "%{cuda_path}", 17 | version = "%{cuda_version}", 18 | ) 19 | 20 | cuda_toolchain_config( 21 | name = "nvcc-local-config", 22 | cuda_toolkit = ":cuda-toolkit", 23 | msvc_env_tmp = "%{env_tmp}", 24 | # int("%{foo}") instead of %{foo} to make the file valid syntactically. 25 | nvcc_version_major = int("%{nvcc_version_major}"), 26 | nvcc_version_minor = int("%{nvcc_version_minor}"), 27 | toolchain_identifier = "nvcc", 28 | ) 29 | 30 | cuda_toolchain( 31 | name = "nvcc-local", 32 | compiler_label = "%{nvcc_label}", 33 | toolchain_config = ":nvcc-local-config", 34 | ) 35 | 36 | toolchain( 37 | name = "nvcc-local-toolchain", 38 | exec_compatible_with = [ 39 | "@platforms//os:windows", 40 | "@platforms//cpu:x86_64", 41 | ], 42 | target_compatible_with = [ 43 | "@platforms//os:windows", 44 | "@platforms//cpu:x86_64", 45 | ], 46 | target_settings = [ 47 | "@rules_cuda//cuda:is_enabled", 48 | "@rules_cuda//cuda:compiler_is_nvcc", 49 | ], 50 | toolchain = ":nvcc-local", 51 | toolchain_type = "@rules_cuda//cuda:toolchain_type", 52 | visibility = ["//visibility:public"], 53 | ) 54 | -------------------------------------------------------------------------------- /cuda/private/templates/README.md: -------------------------------------------------------------------------------- 1 | ## Template files 2 | 3 | - `BUILD.cuda_shared`: For `cuda` repo (CTK + toolchain) or `cuda_%{component_name}` 4 | - `BUILD.cuda_headers`: For `cuda` repo (CTK + toolchain) or `cuda_%{component_name}` headers 5 | - `BUILD.cuda_build_setting`: For `cuda` repo (CTK + toolchain) build_setting 6 | - `BUILD.cuda_disabled`: For creating a 
dummy local configuration. 7 | - `BUILD.toolchain_disabled`: For creating a dummy local toolchain. 8 | - `BUILD.toolchain_clang`: For Clang device compilation toolchain. 9 | - `BUILD.toolchain_nvcc`: For NVCC device compilation toolchain. 10 | - `BUILD.toolchain_nvcc_msvc`: For NVCC device compilation toolchain (with MSVC as the host compiler). 11 | - Otherwise, each `BUILD.*` corresponds to a component in the CUDA Toolkit. 12 | 13 | ## Repository organization 14 | 15 | We organize the generated repo as follows, for both `cuda` and `cuda_%{component_name}` 16 | 17 | ``` 18 | # bazel unconditionally creates a directory for us 19 | ├── %{component_name}/ # cuda for local ctk, component name otherwise 20 | │ ├── include/ # 21 | │ └── %{libpath}/ # lib or lib64, platform dependent 22 | ├── defs.bzl # generated 23 | ├── BUILD # generated with template_helper 24 | └── WORKSPACE # generated 25 | ``` 26 | 27 | If the repo is `cuda`, we additionally generate toolchain config as follows 28 | 29 | ``` 30 | 31 | └── toolchain/ 32 | ├── BUILD # the default nvcc toolchain 33 | ├── clang/ # the optional clang toolchain 34 | │ └── BUILD # 35 | └── disabled/ # the fallback toolchain 36 | └── BUILD # 37 | ``` 38 | 39 | ## How are component repositories and `@cuda` connected? 40 | 41 | The `registry.bzl` file holds mappings from our (`rules_cuda`) component names to various things. 42 | 43 | The registry serves the following purposes: 44 | 45 | 1. maps our component names to the full component names used in the `redistrib.json` file. 46 | 47 | This is purely for looking up the json files. 48 | 49 | 2. maps our component names to target names to be exposed under `@cuda` repo. 50 | 51 | To expose those targets, we use a `components_mapping` attr from our component names to labels of component 52 | repository (for example, `@cuda_nvcc`) as follows 53 | 54 | ```starlark 55 | # in registry.bzl 56 | ... 57 | "cudart": ["cuda", "cuda_runtime", "cuda_runtime_static"], 58 | ... 
59 | 60 | # in WORKSPACE.bazel 61 | cuda_component( 62 | name = "cuda_cudart_v12.6.77", 63 | component_name = "cudart", 64 | ... 65 | ) 66 | 67 | cuda_toolkit( 68 | name = "cuda", 69 | components_mapping = {"cudart": "@cuda_cudart_v12.6.77"}, 70 | ... 71 | ) 72 | ``` 73 | 74 | This basically means the component `cudart` has `cuda`, `cuda_runtime` and `cuda_runtime_static` targets defined. 75 | 76 | - In locally installed CTK, we setup the targets in `@cuda` directly. 77 | - In a deliverable CTK, we setup the targets in `@cuda_cudart_v12.6.77` repo. And alias all targets to 78 | `@cuda` as follows 79 | 80 | ```starlark 81 | alias(name = "cuda", actual = "@cuda_cudart_v12.6.77//:cuda") 82 | alias(name = "cuda_runtime", actual = "@cuda_cudart_v12.6.77//:cuda_runtime") 83 | alias(name = "cuda_runtime_static", actual = "@cuda_cudart_v12.6.77//:cuda_runtime_static") 84 | ``` 85 | 86 | `cuda_component` is in charge of setting up the repo `@cuda_cudart_v12.6.77`. 87 | -------------------------------------------------------------------------------- /cuda/private/templates/defs.bzl.tpl: -------------------------------------------------------------------------------- 1 | def if_local_cuda_toolkit(if_true, if_false = []): 2 | is_local_ctk = %{is_local_ctk} 3 | if is_local_ctk: 4 | return if_true 5 | else: 6 | return if_false 7 | 8 | def if_deliverable_cuda_toolkit(if_true, if_false = []): 9 | return if_local_cuda_toolkit(if_false, if_true) 10 | 11 | def additional_header_deps(component_name): 12 | if component_name == "cudart": 13 | return if_deliverable_cuda_toolkit([ 14 | "@cuda//:nvcc_headers", 15 | "@cuda//:cccl_headers", 16 | ]) 17 | 18 | return [] 19 | -------------------------------------------------------------------------------- /cuda/private/templates/redist.bzl.tpl: -------------------------------------------------------------------------------- 1 | load("@rules_cuda//cuda:repositories.bzl", "cuda_component", "rules_cuda_toolchains") 2 | 3 | def 
rules_cuda_components(): 4 | # See template_helper.generate_redist_bzl(...) for body generation logic 5 | %{rules_cuda_components_body} 6 | 7 | return %{components_mapping} 8 | 9 | def rules_cuda_components_and_toolchains(register_toolchains = False): 10 | components_mapping = rules_cuda_components() 11 | rules_cuda_toolchains( 12 | components_mapping= components_mapping, 13 | register_toolchains = register_toolchains, 14 | version = "%{version}", 15 | ) 16 | -------------------------------------------------------------------------------- /cuda/private/templates/registry.bzl: -------------------------------------------------------------------------------- 1 | # map short component name to consumable targets 2 | REGISTRY = { 3 | "cudart": ["cuda", "cuda_runtime", "cuda_runtime_static"], 4 | "nvcc": ["compiler_deps", "nvptxcompiler", "nvcc_headers"], 5 | "cccl": ["cub", "thrust", "cccl_headers"], 6 | "cublas": ["cublas"], 7 | "cufft": ["cufft", "cufft_static"], 8 | "cufile": [], 9 | "cupti": ["cupti", "nvperf_host", "nvperf_target"], 10 | "curand": ["curand"], 11 | "cusolver": ["cusolver"], 12 | "cusparse": ["cusparse"], 13 | "npp": ["nppc", "nppi", "nppial", "nppicc", "nppidei", "nppif", "nppig", "nppim", "nppist", "nppisu", "nppitc", "npps"], 14 | "nvidia_fs": [], 15 | "nvjitlink": ["nvjitlink", "nvjitlink_static"], 16 | "nvjpeg": ["nvjpeg", "nvjpeg_static"], 17 | "nvml": ["nvml"], 18 | "nvprof": [], 19 | "nvrtc": ["nvrtc"], 20 | "nvtx": ["nvtx"], 21 | } 22 | 23 | # map short component name to full component name 24 | FULL_COMPONENT_NAME = { 25 | "cudart": "cuda_cudart", 26 | "nvcc": "cuda_nvcc", 27 | "cccl": "cuda_cccl", 28 | "cublas": "libcublas", 29 | "cufft": "libcufft", 30 | "cufile": "libcufile", 31 | "cupti": "libcupti", 32 | "curand": "libcurand", 33 | "cusolver": "libcusolver", 34 | "cusparse": "libcusparse", 35 | "npp": "libnpp", 36 | "nvidia_fs": "nvidia_fs", 37 | "nvjitlink": "libnvjitlink", 38 | "nvjpeg": "libnvjpeg", 39 | "nvml": "cuda_nvml_dev", 40 | 
def _cuda_toolchain_impl(ctx):
    """Implementation of the `cuda_toolchain` rule.

    Validates the compiler-selection attributes, resolves the path of the
    compiler executable and packs it, together with the data read from the
    `toolchain_config` target, into a `platform_common.ToolchainInfo`.

    The compiler is taken from exactly one source, checked in this order:
      1. `compiler_use_cc_toolchain`: the compiler of the resolved cc_toolchain.
      2. `compiler_executable`: a path given as a plain string.
      3. `compiler_label`: an executable target.

    Args:
        ctx: The rule context.

    Returns:
        A list holding a single `platform_common.ToolchainInfo` with the fields
        `name`, `compiler_executable`, `all_files`, `selectables_info`,
        `artifact_name_patterns` and `cuda_toolkit`.
    """
    cc_toolchain = find_cpp_toolchain(ctx)
    has_cc_toolchain = cc_toolchain != None
    has_compiler_executable = ctx.attr.compiler_executable != None and ctx.attr.compiler_executable != ""
    has_compiler_label = ctx.attr.compiler_label != None

    # Validation
    # compiler_use_cc_toolchain should be used alone and not along with compiler_executable or compiler_label
    if (ctx.attr.compiler_use_cc_toolchain == True) and (has_compiler_executable or has_compiler_label):
        fail("compiler_use_cc_toolchain set to True but compiler_executable or compiler_label also set.")
    elif (ctx.attr.compiler_use_cc_toolchain == False) and not has_compiler_executable and not has_compiler_label:
        fail("Either compiler_executable or compiler_label must be specified or if a valid cc_toolchain is registered, set attr compiler_use_cc_toolchain to True.")

    # First, attempt to use configured cc_toolchain if attr compiler_use_cc_toolchain set.
    if (ctx.attr.compiler_use_cc_toolchain == True):
        if has_cc_toolchain:
            compiler_executable = cc_toolchain.compiler_executable
        else:
            fail("compiler_use_cc_toolchain set to True but cannot find a configured cc_toolchain")
    elif has_compiler_executable:
        compiler_executable = ctx.attr.compiler_executable
    elif has_compiler_label:
        label = ctx.attr.compiler_label.label

        # Join only the non-empty segments. `workspace_root` is empty for a
        # label in the main repository and `package` is empty for a label in a
        # root package; formatting them blindly would yield a path with a
        # leading "/" (i.e. an absolute path) or a doubled "//".
        compiler_executable = "/".join([
            segment
            for segment in [label.workspace_root, label.package, label.name]
            if segment
        ])

    cuda_toolchain_config = ctx.attr.toolchain_config[CudaToolchainConfigInfo]
    selectables_info = config_helper.collect_selectables_info(cuda_toolchain_config.action_configs + cuda_toolchain_config.features)

    # Names of action_configs/features every toolchain config has to provide.
    # Currently empty, so the check below is a no-op.
    must_have_selectables = []
    for name in must_have_selectables:
        if not config_helper.is_configured(selectables_info, name):
            fail(name, "is not configured (not exists) in the provided toolchain_config")

    # Index the patterns by category; a later pattern replaces an earlier one
    # of the same category.
    artifact_name_patterns = {}
    for pattern in cuda_toolchain_config.artifact_name_patterns:
        artifact_name_patterns[pattern.category_name] = pattern

    # construct compiler_depset
    compiler_depset = depset()
    if ctx.attr.compiler_use_cc_toolchain:
        compiler_depset = cc_toolchain.all_files  # pass all cc_toolchain to toolchain_files
    elif has_compiler_executable:
        # A compiler given as a plain path contributes no files.
        pass
    elif has_compiler_label:
        compiler_target_info = ctx.attr.compiler_label[DefaultInfo]
        if not compiler_target_info.files_to_run or not compiler_target_info.files_to_run.executable:
            fail("compiler_label specified is not an executable, specify a valid compiler_label")
        compiler_depset = depset(direct = [compiler_target_info.files_to_run.executable], transitive = [compiler_target_info.default_runfiles.files])

    toolchain_files = depset(transitive = [
        compiler_depset,
        ctx.attr.compiler_files.files if ctx.attr.compiler_files else depset(),
    ])

    return [
        platform_common.ToolchainInfo(
            name = ctx.label.name,
            compiler_executable = compiler_executable,
            all_files = toolchain_files,
            selectables_info = selectables_info,
            artifact_name_patterns = artifact_name_patterns,
            cuda_toolkit = cuda_toolchain_config.cuda_toolkit,
        ),
    ]
def nvcc_version_ge(ctx, major, minor):
    """Tells whether the configured nvcc is at least version `major.minor`.

    Args:
        ctx: A rule context whose `attr` exposes `toolchain_identifier`,
            `nvcc_version_major` and `nvcc_version_minor`.
        major: Required major version.
        minor: Required minor version.

    Returns:
        False if the toolchain is not nvcc; otherwise True exactly when the
        configured (major, minor) pair is not older than the requested one.
    """
    if ctx.attr.toolchain_identifier != "nvcc":
        return False

    # The major version decides on its own unless both majors are equal;
    # only then does the minor version matter (e.g. 12.0 >= 11.3).
    if ctx.attr.nvcc_version_major != major:
        return ctx.attr.nvcc_version_major > major
    return ctx.attr.nvcc_version_minor >= minor
-------------------------------------------------------------------------------- 1 | load( 2 | "//cuda/private:repositories.bzl", 3 | _cuda_component = "cuda_component", 4 | _cuda_redist_json = "cuda_redist_json", 5 | _cuda_toolkit = "cuda_toolkit", 6 | _default_components_mapping = "default_components_mapping", 7 | _rules_cuda_dependencies = "rules_cuda_dependencies", 8 | _rules_cuda_toolchains = "rules_cuda_toolchains", 9 | ) 10 | load("//cuda/private:toolchain.bzl", _register_detected_cuda_toolchains = "register_detected_cuda_toolchains") 11 | 12 | # rules 13 | cuda_component = _cuda_component 14 | cuda_redist_json = _cuda_redist_json 15 | cuda_toolkit = _cuda_toolkit 16 | 17 | # macros 18 | rules_cuda_dependencies = _rules_cuda_dependencies 19 | rules_cuda_toolchains = _rules_cuda_toolchains 20 | register_detected_cuda_toolchains = _register_detected_cuda_toolchains 21 | default_components_mapping = _default_components_mapping 22 | -------------------------------------------------------------------------------- /docs/.bazelversion: -------------------------------------------------------------------------------- 1 | 7.6.1 2 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | docs/ 2 | generated/ 3 | site/ 4 | -------------------------------------------------------------------------------- /docs/BUILD.bazel: -------------------------------------------------------------------------------- 1 | load("@bazel_skylib//:bzl_library.bzl", "bzl_library") 2 | load("@io_bazel_stardoc//stardoc:stardoc.bzl", "stardoc") 3 | 4 | bzl_library( 5 | name = "bazel_tools_bzl_srcs", 6 | srcs = ["@bazel_tools//tools:bzl_srcs"], 7 | ) 8 | 9 | # NOTE: when the `out` is changed, the `nav` part of `mkdocs.yaml` must be also be changed correspondingly. 
10 | stardoc( 11 | name = "user_docs", 12 | out = "user/user_docs.md", 13 | input = "user_docs.bzl", 14 | deps = [ 15 | ":bazel_tools_bzl_srcs", 16 | "@rules_cuda//cuda:bzl_srcs", 17 | ], 18 | ) 19 | 20 | stardoc( 21 | name = "toolchain_config_docs", 22 | out = "user/toolchain_config_docs.md", 23 | input = "toolchain_config_docs.bzl", 24 | deps = [ 25 | ":bazel_tools_bzl_srcs", 26 | "@rules_cuda//cuda:bzl_srcs", 27 | ], 28 | ) 29 | 30 | stardoc( 31 | name = "providers_docs", 32 | out = "developer/providers_docs.md", 33 | input = "providers_docs.bzl", 34 | deps = [ 35 | ":bazel_tools_bzl_srcs", 36 | "@rules_cuda//cuda:bzl_srcs", 37 | ], 38 | ) 39 | 40 | stardoc( 41 | name = "developer_docs", 42 | out = "developer/developer_docs.md", 43 | input = "developer_docs.bzl", 44 | deps = [ 45 | ":bazel_tools_bzl_srcs", 46 | "@rules_cuda//cuda:bzl_srcs", 47 | ], 48 | ) 49 | 50 | filegroup( 51 | name = "all_docs", 52 | srcs = [ 53 | ":developer_docs", 54 | ":providers_docs", 55 | ":toolchain_config_docs", 56 | ":user_docs", 57 | ], 58 | ) 59 | -------------------------------------------------------------------------------- /docs/MODULE.bazel: -------------------------------------------------------------------------------- 1 | module( 2 | name = "rules_cuda_docs", 3 | version = "0.0.0", 4 | compatibility_level = 1, 5 | ) 6 | 7 | bazel_dep(name = "bazel_skylib", version = "1.4.2") 8 | bazel_dep(name = "platforms", version = "0.0.6") 9 | bazel_dep(name = "rules_cuda", version = "0.2.3") 10 | local_path_override( 11 | module_name = "rules_cuda", 12 | path = "..", 13 | ) 14 | 15 | cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain") 16 | cuda.toolkit( 17 | name = "cuda", 18 | toolkit_path = "", 19 | ) 20 | use_repo(cuda, "cuda") 21 | 22 | bazel_dep( 23 | name = "stardoc", 24 | version = "0.7.0", 25 | repo_name = "io_bazel_stardoc", 26 | ) 27 | -------------------------------------------------------------------------------- /docs/WORKSPACE.bazel: 
-------------------------------------------------------------------------------- 1 | workspace(name = "rules_cuda_docs") 2 | 3 | local_repository( 4 | name = "rules_cuda", 5 | path = "..", 6 | ) 7 | 8 | load("@rules_cuda//cuda:repositories.bzl", "rules_cuda_dependencies") 9 | 10 | rules_cuda_dependencies() 11 | 12 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 13 | 14 | http_archive( 15 | name = "io_bazel_stardoc", 16 | sha256 = "ca933f39f2a6e0ad392fa91fd662545afcbd36c05c62365538385d35a0323096", 17 | urls = [ 18 | "https://mirror.bazel.build/github.com/bazelbuild/stardoc/releases/download/0.8.0/stardoc-0.8.0.tar.gz", 19 | "https://github.com/bazelbuild/stardoc/releases/download/0.8.0/stardoc-0.8.0.tar.gz", 20 | ], 21 | ) 22 | 23 | load("@io_bazel_stardoc//:setup.bzl", "stardoc_repositories") 24 | 25 | stardoc_repositories() 26 | 27 | load("@rules_jvm_external//:repositories.bzl", "rules_jvm_external_deps") 28 | 29 | rules_jvm_external_deps() 30 | 31 | load("@rules_jvm_external//:setup.bzl", "rules_jvm_external_setup") 32 | 33 | rules_jvm_external_setup() 34 | 35 | load("@io_bazel_stardoc//:deps.bzl", "stardoc_external_deps") 36 | 37 | stardoc_external_deps() 38 | 39 | load("@stardoc_maven//:defs.bzl", stardoc_pinned_maven_install = "pinned_maven_install") 40 | 41 | stardoc_pinned_maven_install() 42 | -------------------------------------------------------------------------------- /docs/WORKSPACE.bzlmod: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bazel-contrib/rules_cuda/3f72f484a8ea5969c81a857a6785ebf0ede02c0c/docs/WORKSPACE.bzlmod -------------------------------------------------------------------------------- /docs/build-docs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | function prepare-env { 4 | pip install -r requirements.txt 5 | } 6 | 7 | function compose-docs { 8 | rm -rf docs site 9 | mkdir -p docs 10 
| cp ../README.md docs/index.md 11 | bazel build :all_docs 12 | rsync -a --prune-empty-dirs --include '*/' mkdocs/stylesheets docs/ 13 | rsync -a --prune-empty-dirs --include '*/' --include '*.md' --exclude '*' bazel-bin/ docs/ 14 | find docs/ -name '*.md' -exec sed -i 's#
#
#g' {} \;
15 |     find docs/ -name '*.md' -exec sed -i 's#
#
#g' {} \; 16 | } 17 | 18 | function compose-versioned-site { 19 | mkdir -p generated 20 | rsync -a --prune-empty-dirs --include '*/' site/ generated/$1/ 21 | python versioning.py generated/ --force 22 | 23 | printf "\nRun following command to update version list then serve locally:\n\n" 24 | printf "\tpython -m http.server -d generated/\n\n" 25 | } 26 | 27 | 28 | CI=${CI:-0} # 1 for CI only logic 29 | 30 | if [ $CI == "1" ]; then 31 | set -ex 32 | prepare-env 33 | compose-docs 34 | mkdocs build 35 | else 36 | if [[ $# -ne 1 ]]; then 37 | printf "Usage: $0 \n" 38 | exit -1 39 | fi 40 | version=$1 41 | 42 | # env should be prepared manually 43 | compose-docs 44 | mkdocs build 45 | compose-versioned-site $version 46 | fi 47 | -------------------------------------------------------------------------------- /docs/developer_docs.bzl: -------------------------------------------------------------------------------- 1 | load("@rules_cuda//cuda/private:actions/compile.bzl", _compile = "compile") 2 | load("@rules_cuda//cuda/private:actions/dlink.bzl", _device_link = "device_link") 3 | load("@rules_cuda//cuda/private:cuda_helper.bzl", _cuda_helper = "cuda_helper") 4 | load( 5 | "@rules_cuda//cuda/private:repositories.bzl", 6 | _config_clang = "config_clang", 7 | _config_cuda_toolkit_and_nvcc = "config_cuda_toolkit_and_nvcc", 8 | _detect_clang = "detect_clang", 9 | _detect_cuda_toolkit = "detect_cuda_toolkit", 10 | ) 11 | load( 12 | "@rules_cuda//cuda/private:toolchain.bzl", 13 | _find_cuda_toolchain = "find_cuda_toolchain", 14 | _find_cuda_toolkit = "find_cuda_toolkit", 15 | _use_cuda_toolchain = "use_cuda_toolchain", 16 | ) 17 | load("@rules_cuda//cuda/private:toolchain_config_lib.bzl", _config_helper = "config_helper") 18 | 19 | # create a struct to group toolchain symbols semantically 20 | toolchain = struct( 21 | use_cuda_toolchain = _use_cuda_toolchain, 22 | find_cuda_toolchain = _find_cuda_toolchain, 23 | find_cuda_toolkit = _find_cuda_toolkit, 24 | ) 25 | 26 | 
cuda_helper = _cuda_helper 27 | config_helper = _config_helper 28 | 29 | # create a struct to group action symbols semantically 30 | actions = struct( 31 | compile = _compile, 32 | device_link = _device_link, 33 | ) 34 | 35 | # create a struct to group repositories symbols semantically 36 | repositories = struct( 37 | config_clang = _config_clang, 38 | config_cuda_toolkit_and_nvcc = _config_cuda_toolkit_and_nvcc, 39 | detect_clang = _detect_clang, 40 | detect_cuda_toolkit = _detect_cuda_toolkit, 41 | ) 42 | -------------------------------------------------------------------------------- /docs/mkdocs.yaml: -------------------------------------------------------------------------------- 1 | site_name: "rules_cuda: Starlark implementation of bazel rules for CUDA" 2 | repo_url: https://github.com/bazel-contrib/rules_cuda 3 | docs_dir: docs 4 | 5 | theme: 6 | name: material 7 | palette: 8 | - media: "(prefers-color-scheme: light)" 9 | scheme: default 10 | primary: green 11 | locale: en 12 | features: 13 | - navigation.tabs 14 | extra_css: 15 | - stylesheets/extra.css 16 | 17 | extra: 18 | version: 19 | # we are not actually using mike 20 | # just for `versions.json` to be functional 21 | provider: mike 22 | 23 | nav: 24 | - Home: index.md 25 | - User: 26 | - Using the rules: user/user_docs.md 27 | - Configure the toolchain: user/toolchain_config_docs.md 28 | - Developer: 29 | - Providers: developer/providers_docs.md 30 | - Rule Authoring: developer/developer_docs.md 31 | -------------------------------------------------------------------------------- /docs/mkdocs/stylesheets/extra.css: -------------------------------------------------------------------------------- 1 | .md-header { 2 | background: #44a147; 3 | } 4 | 5 | .stardoc-pre { 6 | height: fit-content; 7 | width: inherit; 8 | overflow: auto; 9 | scrollbar-width: thin; 10 | font-size: 0.8em; 11 | } 12 | 13 | .stardoc-pre::-webkit-scrollbar { 14 | height: 0.25em; 15 | width: 0.25em; 16 | } 17 | 18 | 
.stardoc-pre::-webkit-scrollbar-thumb { 19 | background-color: var(--md-default-fg-color--lighter); 20 | } 21 | -------------------------------------------------------------------------------- /docs/providers_docs.bzl: -------------------------------------------------------------------------------- 1 | load( 2 | "@rules_cuda//cuda/private:providers.bzl", 3 | _ArchSpecInfo = "ArchSpecInfo", 4 | _CudaArchsInfo = "CudaArchsInfo", 5 | _CudaInfo = "CudaInfo", 6 | _CudaToolchainConfigInfo = "CudaToolchainConfigInfo", 7 | _CudaToolkitInfo = "CudaToolkitInfo", 8 | _Stage2ArchInfo = "Stage2ArchInfo", 9 | ) 10 | 11 | ArchSpecInfo = _ArchSpecInfo 12 | Stage2ArchInfo = _Stage2ArchInfo 13 | 14 | CudaArchsInfo = _CudaArchsInfo 15 | CudaInfo = _CudaInfo 16 | CudaToolkitInfo = _CudaToolkitInfo 17 | CudaToolchainConfigInfo = _CudaToolchainConfigInfo 18 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | mkdocs-material==9.6.11 2 | -------------------------------------------------------------------------------- /docs/toolchain_config_docs.bzl: -------------------------------------------------------------------------------- 1 | load( 2 | "@rules_cuda//cuda:defs.bzl", 3 | _cuda_toolchain = "cuda_toolchain", 4 | _cuda_toolchain_config_clang = "cuda_toolchain_config_clang", 5 | _cuda_toolchain_config_nvcc = "cuda_toolchain_config_nvcc", 6 | _cuda_toolchain_config_nvcc_msvc = "cuda_toolchain_config_nvcc_msvc", 7 | _cuda_toolkit_info = "cuda_toolkit_info", 8 | ) 9 | 10 | cuda_toolkit_info = _cuda_toolkit_info 11 | cuda_toolchain = _cuda_toolchain 12 | cuda_toolchain_config_clang = _cuda_toolchain_config_clang 13 | cuda_toolchain_config_nvcc_msvc = _cuda_toolchain_config_nvcc_msvc 14 | cuda_toolchain_config_nvcc = _cuda_toolchain_config_nvcc 15 | -------------------------------------------------------------------------------- /docs/user_docs.bzl: 
-------------------------------------------------------------------------------- 1 | load("@rules_cuda//cuda:defs.bzl", _cuda_binary = "cuda_binary", _cuda_library = "cuda_library", _cuda_objects = "cuda_objects", _cuda_test = "cuda_test") 2 | load( 3 | "@rules_cuda//cuda:repositories.bzl", 4 | _register_detected_cuda_toolchains = "register_detected_cuda_toolchains", 5 | _rules_cuda_dependencies = "rules_cuda_dependencies", 6 | _rules_cuda_toolchains = "rules_cuda_toolchains", 7 | ) 8 | load("@rules_cuda//cuda/private:rules/flags.bzl", _cuda_archs_flag = "cuda_archs_flag") 9 | 10 | cuda_library = _cuda_library 11 | cuda_objects = _cuda_objects 12 | 13 | cuda_binary = _cuda_binary 14 | cuda_test = _cuda_test 15 | 16 | cuda_archs = _cuda_archs_flag 17 | 18 | register_detected_cuda_toolchains = _register_detected_cuda_toolchains 19 | rules_cuda_dependencies = _rules_cuda_dependencies 20 | rules_cuda_toolchains = _rules_cuda_toolchains 21 | -------------------------------------------------------------------------------- /docs/versioning.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from packaging.version import parse as parse_version 5 | 6 | TEMPLATE = """ 7 | 8 | 9 | 10 | 11 | 12 | Redirecting 13 | 16 | 19 | 20 | 21 | Redirecting to {version}/... 
def collect_versions(work_dir):
    """List the sub-directories of ``work_dir``, version-like names first.

    Directory names accepted by ``parse_version`` are ordered from the newest
    version to the oldest. All remaining directory names follow in
    alphabetical order.

    Args:
        work_dir: Directory whose immediate sub-directories are inspected.

    Returns:
        A list of directory names: sorted versions, then sorted other names.
    """
    # Close the scandir iterator deterministically instead of leaving it to
    # the garbage collector.
    with os.scandir(work_dir) as entries:
        dir_names = [entry.name for entry in entries if entry.is_dir()]

    parsed = []  # (parsed version, directory name) pairs
    names = []
    for dir_name in dir_names:
        try:
            parsed.append((parse_version(dir_name), dir_name))
        except ValueError:
            # packaging's InvalidVersion derives from ValueError; unlike a
            # bare `except:` this lets KeyboardInterrupt/SystemExit through.
            names.append(dir_name)

    # Sort on the already-parsed version so every name is parsed only once.
    parsed.sort(key=lambda item: item[0], reverse=True)
    names.sort()
    return [dir_name for _, dir_name in parsed] + names
4 | build --flag_alias=enable_cuda=@rules_cuda//cuda:enable 5 | build --flag_alias=cuda_archs=@rules_cuda//cuda:archs 6 | build --flag_alias=cuda_compiler=@rules_cuda//cuda:compiler 7 | build --flag_alias=cuda_copts=@rules_cuda//cuda:copts 8 | build --flag_alias=cuda_runtime=@rules_cuda//cuda:runtime 9 | 10 | build --enable_cuda=True 11 | 12 | # Use --config=clang to build with clang instead of gcc and nvcc. 13 | build:clang --repo_env=CC=clang 14 | build:clang --@rules_cuda//cuda:compiler=clang 15 | 16 | # https://github.com/bazel-contrib/rules_cuda/issues/1 17 | # build --ui_event_filters=-INFO 18 | -------------------------------------------------------------------------------- /examples/MODULE.bazel: -------------------------------------------------------------------------------- 1 | module( 2 | name = "rules_cuda_examples", 3 | version = "0.0.0", 4 | compatibility_level = 1, 5 | ) 6 | 7 | bazel_dep(name = "rules_cuda", version = "0.2.3") 8 | local_path_override( 9 | module_name = "rules_cuda", 10 | path = "..", 11 | ) 12 | 13 | cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain") 14 | cuda.toolkit( 15 | name = "cuda", 16 | toolkit_path = "", 17 | ) 18 | use_repo(cuda, "cuda") 19 | 20 | ################################# 21 | # Dependencies for nccl example # 22 | ################################# 23 | # See WORKSPACE.bzlmod for the remaining parts 24 | bazel_dep(name = "bazel_skylib", version = "1.4.2") 25 | -------------------------------------------------------------------------------- /examples/WORKSPACE.bazel: -------------------------------------------------------------------------------- 1 | workspace(name = "rules_cuda_examples") 2 | 3 | local_repository( 4 | name = "rules_cuda", 5 | path = "../", 6 | ) 7 | 8 | ###################### 9 | # rules_cuda setup # 10 | ###################### 11 | # Fetches the rules_cuda dependencies and initializes the cuda toolchain.
12 | # If you want to have a different version of some dependency, 13 | # you should fetch it *before* calling this. 14 | 15 | load("@rules_cuda//cuda:repositories.bzl", "rules_cuda_dependencies", "rules_cuda_toolchains") 16 | 17 | rules_cuda_dependencies() 18 | 19 | rules_cuda_toolchains(register_toolchains = True) 20 | 21 | ################################# 22 | # Dependencies for nccl example # 23 | ################################# 24 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 25 | 26 | http_archive( 27 | name = "nccl", 28 | build_file = "@rules_cuda_examples//nccl:nccl.BUILD", 29 | sha256 = "83b299cfc2dfe63887dadf3590b3ac2b8b2fd68ec5515b6878774eda39a697d2", 30 | strip_prefix = "nccl-9814c75eea18fc7374cde884592233b6b7dc055b", 31 | urls = ["https://github.com/nvidia/nccl/archive/9814c75eea18fc7374cde884592233b6b7dc055b.tar.gz"], 32 | ) 33 | 34 | http_archive( 35 | name = "nccl-tests", 36 | build_file = "@rules_cuda_examples//nccl:nccl-tests.BUILD", 37 | patch_args = [ 38 | "-p1", 39 | ], 40 | patches = ["@rules_cuda_examples//nccl:nccl-tests-clang.patch"], 41 | sha256 = "946adb84f63aec66aea7aab9739d41df81c24f783e85fba6328ba243cfc057e0", 42 | strip_prefix = "nccl-tests-1a5f551ffd6e3271982b03a9d5653a3f6ba545fa", 43 | urls = ["https://github.com/nvidia/nccl-tests/archive/1a5f551ffd6e3271982b03a9d5653a3f6ba545fa.tar.gz"], 44 | ) 45 | -------------------------------------------------------------------------------- /examples/WORKSPACE.bzlmod: -------------------------------------------------------------------------------- 1 | ################################# 2 | # Dependencies for nccl example # 3 | ################################# 4 | # NOTE: this should have been placed in MODULE.bazel, but use_repo_rule is introduced in Bazel 7. 
5 | # See https://github.com/bazelbuild/bazel/issues/17141 6 | 7 | # http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") # For MODULE.bazel 8 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") # For WORKSPACE.bazel 9 | 10 | http_archive( 11 | name = "nccl", 12 | build_file = "@rules_cuda_examples//nccl:nccl.BUILD", 13 | sha256 = "83b299cfc2dfe63887dadf3590b3ac2b8b2fd68ec5515b6878774eda39a697d2", 14 | strip_prefix = "nccl-9814c75eea18fc7374cde884592233b6b7dc055b", 15 | urls = ["https://github.com/nvidia/nccl/archive/9814c75eea18fc7374cde884592233b6b7dc055b.tar.gz"], 16 | ) 17 | 18 | http_archive( 19 | name = "nccl-tests", 20 | build_file = "@rules_cuda_examples//nccl:nccl-tests.BUILD", 21 | patch_args = [ 22 | "-p1", 23 | ], 24 | patches = ["@rules_cuda_examples//nccl:nccl-tests-clang.patch"], 25 | sha256 = "946adb84f63aec66aea7aab9739d41df81c24f783e85fba6328ba243cfc057e0", 26 | strip_prefix = "nccl-tests-1a5f551ffd6e3271982b03a9d5653a3f6ba545fa", 27 | urls = ["https://github.com/nvidia/nccl-tests/archive/1a5f551ffd6e3271982b03a9d5653a3f6ba545fa.tar.gz"], 28 | ) 29 | -------------------------------------------------------------------------------- /examples/basic/BUILD.bazel: -------------------------------------------------------------------------------- 1 | load("@rules_cuda//cuda:defs.bzl", "cuda_library") 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | cuda_library( 6 | name = "kernel", 7 | srcs = ["kernel.cu"], 8 | hdrs = ["kernel.h"], 9 | ) 10 | 11 | cc_binary( 12 | name = "main", 13 | srcs = ["main.cpp"], 14 | deps = [":kernel"], 15 | ) 16 | -------------------------------------------------------------------------------- /examples/basic/kernel.cu: -------------------------------------------------------------------------------- 1 | #include "kernel.h" 2 | 3 | #include 4 | 5 | #define CUDA_CHECK(expr) \ 6 | do { \ 7 | cudaError_t err = (expr); \ 8 | if (err != cudaSuccess) { \ 9 | 
// Device kernel: every launched thread prints one fixed message.
__global__ void kernel() {
  printf("cuda kernel called!\n");
}

// Host entry point: launches `kernel` with one block of one thread, then
// checks the launch status and waits for the device to finish.
void launch() {
  const dim3 grid_dim(1);
  const dim3 block_dim(1);

  kernel<<<grid_dim, block_dim>>>();

  // Launch-time errors are reported through cudaGetLastError().
  CUDA_CHECK(cudaGetLastError());
  // Block until the kernel has completed and check its result.
  CUDA_CHECK(cudaDeviceSynchronize());
}
CUDA_CHECK(cudaDeviceSynchronize()); 21 | } 22 | 23 | int main() { 24 | launch(); 25 | return 0; 26 | } 27 | -------------------------------------------------------------------------------- /examples/cublas/BUILD.bazel: -------------------------------------------------------------------------------- 1 | # Run with 'bazel run //examples/cublas:main' 2 | cc_binary( 3 | name = "main", 4 | srcs = ["cublas.cpp"], 5 | deps = ["@cuda//:cublas"], 6 | ) 7 | -------------------------------------------------------------------------------- /examples/cublas/cublas.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #define CUBLAS_CHECK(expr) \ 6 | do { \ 7 | cublasStatus_t err = (expr); \ 8 | if (err != CUBLAS_STATUS_SUCCESS) { \ 9 | fprintf(stderr, "CUBLAS Error: %d at %s:%d\n", err, __FILE__, __LINE__); \ 10 | exit(err); \ 11 | } \ 12 | } while (0) 13 | 14 | int main() { 15 | cublasHandle_t handle; 16 | CUBLAS_CHECK(cublasCreate(&handle)); 17 | CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); 18 | printf("cublas handle created\n"); 19 | return 0; 20 | } 21 | -------------------------------------------------------------------------------- /examples/if_cuda/BUILD.bazel: -------------------------------------------------------------------------------- 1 | load("@rules_cuda//cuda:defs.bzl", "cuda_library", "requires_cuda") 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | cuda_library( 6 | name = "kernel", 7 | srcs = ["kernel.cu"], 8 | hdrs = ["kernel.h"], 9 | target_compatible_with = requires_cuda(), 10 | ) 11 | 12 | cc_binary( 13 | name = "main", 14 | srcs = ["main.cpp"], 15 | defines = [] + select({ 16 | "@rules_cuda//cuda:is_enabled": ["CUDA_ENABLED"], 17 | "//conditions:default": ["CUDA_DISABLED"], 18 | }), 19 | deps = [] + select({ 20 | "@rules_cuda//cuda:is_enabled": [":kernel"], 21 | "//conditions:default": [], 22 | }), 23 | ) 24 | 
-------------------------------------------------------------------------------- /examples/if_cuda/README.md: -------------------------------------------------------------------------------- 1 | # if_cuda Example 2 | 3 | This example demonstrates how to conditionally include CUDA targets in your build. 4 | 5 | By default, _rules_cuda_ rules are enabled. They can be disabled by passing the `--@rules_cuda//cuda:enable=False` 6 | flag on the command line or via `.bazelrc`. 7 | 8 | ## Building Example with rules_cuda 9 | 10 | From the `examples` directory, build the sample application: 11 | 12 | ```bash 13 | bazel build //if_cuda:main 14 | ``` 15 | 16 | And run the binary: 17 | 18 | ```bash 19 | ./bazel-bin/if_cuda/main 20 | ``` 21 | 22 | If a valid GPU device is available on your development machine, the application will exit successfully and print: 23 | 24 | ``` 25 | cuda enabled 26 | ``` 27 | 28 | If running without a valid GPU device, the code, as written, will print a CUDA error and exit: 29 | 30 | ``` 31 | CUDA_VISIBLE_DEVICES=-1 ./bazel-bin/if_cuda/main 32 | CUDA Error Code : 100 33 | Error String: no CUDA-capable device is detected 34 | ``` 35 | 36 | ## Building Example without rules_cuda 37 | 38 | To build the binary without CUDA support, disable rules_cuda: 39 | 40 | ```bash 41 | bazel build //if_cuda:main --@rules_cuda//cuda:enable=False 42 | ``` 43 | 44 | And run the binary: 45 | 46 | ```bash 47 | ./bazel-bin/if_cuda/main 48 | ``` 49 | 50 | The binary will output: 51 | 52 | ``` 53 | cuda disabled 54 | ``` 55 | 56 | ### rules_cuda targets 57 | 58 | It is good practice to set target compatibility, as is done here for the `//if_cuda:kernel` target: 59 | 60 | ``` 61 | target_compatible_with = requires_cuda(), 62 | ``` 63 | 64 | With target compatibility set up, any attempt to build a `rules_cuda`-defined rule (e.g.
`cuda_library` or `cuda_objects`) will _FAIL_ if `rules_cuda` is disabled: 65 | 66 | ``` 67 | bazel build //if_cuda:kernel --@rules_cuda//cuda:enable=False 68 | ERROR: Target //if_cuda:kernel is incompatible and cannot be built, but was explicitly requested. 69 | Dependency chain: 70 | //if_cuda:kernel (6b3a99) <-- target platform (@local_config_platform//:host) didn't satisfy constraints [@rules_cuda//cuda:rules_are_enabled, @rules_cuda//cuda:valid_toolchain_is_configured] 71 | FAILED: Build did NOT complete successfully (0 packages loaded, 0 targets configured) 72 | ``` 73 | 74 | ## Developing for CUDA and CUDA-free targets 75 | 76 | Note that the `BUILD.bazel` file takes care to ensure that: 77 | 78 | - CUDA-related dependencies are excluded when CUDA is disabled 79 | - Preprocessor variables are set to enable compile-time differentiated codepaths between CUDA and non-CUDA builds 80 | 81 | Our example build (when CUDA is enabled) sets a `CUDA_ENABLED` preprocessor variable. This variable is then checked in `main.cpp` to determine 82 | the type of compilation underway. You are free to define any set of preprocessor variables as needed by your particular project. Checking whether 83 | rules_cuda is enabled or not can be achieved simply by using Bazel's `select` feature: 84 | 85 | ``` 86 | select({ 87 | "@rules_cuda//cuda:is_enabled": [], # add whatever settings are required if using CUDA 88 | "//conditions:default": [] # add whatever settings are required if not using CUDA 89 | }) 90 | ``` 91 | 92 | We use this same mechanism to include the `cuda_library` dependencies. 
93 | -------------------------------------------------------------------------------- /examples/if_cuda/kernel.cu: -------------------------------------------------------------------------------- 1 | #include "kernel.h" 2 | 3 | __global__ void kernel() { 4 | printf("cuda enabled\n"); 5 | } 6 | 7 | void launch() { 8 | kernel<<<1, 1>>>(); 9 | CUDA_CHECK(cudaGetLastError()); 10 | CUDA_CHECK(cudaDeviceSynchronize()); 11 | } 12 | -------------------------------------------------------------------------------- /examples/if_cuda/kernel.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #define CUDA_CHECK(expr) \ 4 | do { \ 5 | cudaError_t err = (expr); \ 6 | if (err != cudaSuccess) { \ 7 | fprintf(stderr, "CUDA Error Code : %d\n Error String: %s\n", \ 8 | err, cudaGetErrorString(err)); \ 9 | exit(err); \ 10 | } \ 11 | } while (0) 12 | 13 | void launch(); 14 | -------------------------------------------------------------------------------- /examples/if_cuda/main.cpp: -------------------------------------------------------------------------------- 1 | #if defined(CUDA_ENABLED) 2 | #include "kernel.h" 3 | #endif 4 | 5 | #include 6 | 7 | void do_something_else() { 8 | fprintf(stderr, "cuda disabled\n"); 9 | } 10 | 11 | int main() { 12 | #if defined(CUDA_ENABLED) 13 | launch(); 14 | return 0; 15 | #elif defined(CUDA_DISABLED) 16 | do_something_else(); 17 | return -1; 18 | #else 19 | #error either CUDA_ENABLED or CUDA_NOT_ENABLED must be defined 20 | #endif 21 | } 22 | -------------------------------------------------------------------------------- /examples/nccl/BUILD.bazel: -------------------------------------------------------------------------------- 1 | filegroup( 2 | name = "nccl_shared", 3 | srcs = [ 4 | "@nccl//:nccl_shared", 5 | ], 6 | ) 7 | 8 | filegroup( 9 | name = "perf_binaries", 10 | srcs = [ 11 | "@nccl-tests//:all_gather_perf", 12 | "@nccl-tests//:all_reduce_perf", 13 | "@nccl-tests//:alltoall_perf", 14 | 
"@nccl-tests//:broadcast_perf", 15 | "@nccl-tests//:gather_perf", 16 | "@nccl-tests//:hypercube_perf", 17 | "@nccl-tests//:reduce_perf", 18 | "@nccl-tests//:reduce_scatter_perf", 19 | "@nccl-tests//:scatter_perf", 20 | "@nccl-tests//:sendrecv_perf", 21 | ], 22 | ) 23 | -------------------------------------------------------------------------------- /examples/nccl/nccl-tests-clang.patch: -------------------------------------------------------------------------------- 1 | diff --git a/src/all_gather.cu b/src/all_gather.cu 2 | index 0831207..941ec1b 100644 3 | --- a/src/all_gather.cu 4 | +++ b/src/all_gather.cu 5 | @@ -85,9 +85,7 @@ testResult_t AllGatherRunTest(struct threadArgs* args, int root, ncclDataType_t 6 | return testSuccess; 7 | } 8 | 9 | -struct testEngine allGatherEngine = { 10 | +struct testEngine ncclTestEngine = { 11 | AllGatherGetBuffSize, 12 | AllGatherRunTest 13 | }; 14 | - 15 | -#pragma weak ncclTestEngine=allGatherEngine 16 | diff --git a/src/all_reduce.cu b/src/all_reduce.cu 17 | index a38eabe..acb66a8 100644 18 | --- a/src/all_reduce.cu 19 | +++ b/src/all_reduce.cu 20 | @@ -93,9 +93,7 @@ testResult_t AllReduceRunTest(struct threadArgs* args, int root, ncclDataType_t 21 | return testSuccess; 22 | } 23 | 24 | -struct testEngine allReduceEngine = { 25 | +struct testEngine ncclTestEngine = { 26 | AllReduceGetBuffSize, 27 | AllReduceRunTest 28 | }; 29 | - 30 | -#pragma weak ncclTestEngine=allReduceEngine 31 | diff --git a/src/alltoall.cu b/src/alltoall.cu 32 | index 41c7c4a..712e664 100644 33 | --- a/src/alltoall.cu 34 | +++ b/src/alltoall.cu 35 | @@ -99,9 +99,7 @@ testResult_t AlltoAllRunTest(struct threadArgs* args, int root, ncclDataType_t t 36 | return testSuccess; 37 | } 38 | 39 | -struct testEngine alltoAllEngine = { 40 | +struct testEngine ncclTestEngine = { 41 | AlltoAllGetBuffSize, 42 | AlltoAllRunTest 43 | }; 44 | - 45 | -#pragma weak ncclTestEngine=alltoAllEngine 46 | diff --git a/src/broadcast.cu b/src/broadcast.cu 47 | index 
903066a..778c664 100644 48 | --- a/src/broadcast.cu 49 | +++ b/src/broadcast.cu 50 | @@ -99,9 +99,7 @@ testResult_t BroadcastRunTest(struct threadArgs* args, int root, ncclDataType_t 51 | return testSuccess; 52 | } 53 | 54 | -struct testEngine broadcastEngine = { 55 | +struct testEngine ncclTestEngine = { 56 | BroadcastGetBuffSize, 57 | BroadcastRunTest 58 | }; 59 | - 60 | -#pragma weak ncclTestEngine=broadcastEngine 61 | diff --git a/src/common.cu b/src/common.cu 62 | index 48a629c..d888edc 100644 63 | --- a/src/common.cu 64 | +++ b/src/common.cu 65 | @@ -330,7 +330,7 @@ testResult_t startColl(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t 66 | size_t count = args->nbytes / wordSize(type); 67 | 68 | // Try to change offset for each iteration so that we avoid cache effects and catch race conditions in ptrExchange 69 | - size_t totalnbytes = max(args->sendBytes, args->expectedBytes); 70 | + size_t totalnbytes = std::max(args->sendBytes, args->expectedBytes); 71 | size_t steps = totalnbytes ? 
args->maxbytes / totalnbytes : 1; 72 | size_t shift = totalnbytes * (iter % steps); 73 | 74 | @@ -597,7 +597,7 @@ testResult_t TimeTest(struct threadArgs* args, ncclDataType_t type, const char* 75 | setupArgs(size, type, args); 76 | char rootName[100]; 77 | sprintf(rootName, "%6i", root); 78 | - PRINT("%12li %12li %8s %6s %6s", max(args->sendBytes, args->expectedBytes), args->nbytes / wordSize(type), typeName, opName, rootName); 79 | + PRINT("%12li %12li %8s %6s %6s", std::max(args->sendBytes, args->expectedBytes), args->nbytes / wordSize(type), typeName, opName, rootName); 80 | TESTCHECK(BenchTime(args, type, op, root, 0)); 81 | TESTCHECK(BenchTime(args, type, op, root, 1)); 82 | PRINT("\n"); 83 | diff --git a/src/gather.cu b/src/gather.cu 84 | index 03ef4d9..242a298 100644 85 | --- a/src/gather.cu 86 | +++ b/src/gather.cu 87 | @@ -108,9 +108,7 @@ testResult_t GatherRunTest(struct threadArgs* args, int root, ncclDataType_t typ 88 | return testSuccess; 89 | } 90 | 91 | -struct testEngine gatherEngine = { 92 | +struct testEngine ncclTestEngine = { 93 | GatherGetBuffSize, 94 | GatherRunTest 95 | }; 96 | - 97 | -#pragma weak ncclTestEngine=gatherEngine 98 | diff --git a/src/hypercube.cu b/src/hypercube.cu 99 | index 5c1456f..9aadfc5 100644 100 | --- a/src/hypercube.cu 101 | +++ b/src/hypercube.cu 102 | @@ -110,9 +110,7 @@ testResult_t HyperCubeRunTest(struct threadArgs* args, int root, ncclDataType_t 103 | return testSuccess; 104 | } 105 | 106 | -struct testEngine hyperCubeEngine = { 107 | +struct testEngine ncclTestEngine = { 108 | HyperCubeGetBuffSize, 109 | HyperCubeRunTest 110 | }; 111 | - 112 | -#pragma weak ncclTestEngine=hyperCubeEngine 113 | diff --git a/src/reduce.cu b/src/reduce.cu 114 | index f2fa80d..80aadc5 100644 115 | --- a/src/reduce.cu 116 | +++ b/src/reduce.cu 117 | @@ -102,9 +102,7 @@ testResult_t ReduceRunTest(struct threadArgs* args, int root, ncclDataType_t typ 118 | return testSuccess; 119 | } 120 | 121 | -struct testEngine reduceEngine = { 122 
| +struct testEngine ncclTestEngine = { 123 | ReduceGetBuffSize, 124 | ReduceRunTest 125 | }; 126 | - 127 | -#pragma weak ncclTestEngine=reduceEngine 128 | diff --git a/src/reduce_scatter.cu b/src/reduce_scatter.cu 129 | index ed372e3..212a6f0 100644 130 | --- a/src/reduce_scatter.cu 131 | +++ b/src/reduce_scatter.cu 132 | @@ -97,9 +97,7 @@ testResult_t ReduceScatterRunTest(struct threadArgs* args, int root, ncclDataTyp 133 | return testSuccess; 134 | } 135 | 136 | -struct testEngine reduceScatterEngine = { 137 | +struct testEngine ncclTestEngine = { 138 | ReduceScatterGetBuffSize, 139 | ReduceScatterRunTest 140 | }; 141 | - 142 | -#pragma weak ncclTestEngine=reduceScatterEngine 143 | diff --git a/src/scatter.cu b/src/scatter.cu 144 | index 49d20e1..56f5ede 100644 145 | --- a/src/scatter.cu 146 | +++ b/src/scatter.cu 147 | @@ -104,9 +104,7 @@ testResult_t ScatterRunTest(struct threadArgs* args, int root, ncclDataType_t ty 148 | return testSuccess; 149 | } 150 | 151 | -struct testEngine scatterEngine = { 152 | +struct testEngine ncclTestEngine = { 153 | ScatterGetBuffSize, 154 | ScatterRunTest 155 | }; 156 | - 157 | -#pragma weak ncclTestEngine=scatterEngine 158 | diff --git a/src/sendrecv.cu b/src/sendrecv.cu 159 | index c9eb5bb..316a449 100644 160 | --- a/src/sendrecv.cu 161 | +++ b/src/sendrecv.cu 162 | @@ -106,9 +106,7 @@ testResult_t SendRecvRunTest(struct threadArgs* args, int root, ncclDataType_t t 163 | return testSuccess; 164 | } 165 | 166 | -struct testEngine sendRecvEngine = { 167 | +struct testEngine ncclTestEngine = { 168 | SendRecvGetBuffSize, 169 | SendRecvRunTest 170 | }; 171 | - 172 | -#pragma weak ncclTestEngine=sendRecvEngine 173 | -------------------------------------------------------------------------------- /examples/nccl/nccl-tests.BUILD: -------------------------------------------------------------------------------- 1 | load("@rules_cuda//cuda:defs.bzl", "cuda_library") 2 | load("@rules_cuda_examples//nccl:nccl-tests.bzl", 
"nccl_tests_binary") 3 | 4 | # NOTE: all paths in this file relative to @nccl-tests repo root. 5 | 6 | cc_library( 7 | name = "nccl_tests_include", 8 | hdrs = glob(["src/*.h"]), 9 | includes = ["src"], 10 | ) 11 | 12 | cuda_library( 13 | name = "common_cuda", 14 | srcs = [ 15 | "src/common.cu", 16 | "verifiable/verifiable.cu", 17 | ] + glob([ 18 | "**/*.h", 19 | ]), 20 | deps = [ 21 | ":nccl_tests_include", 22 | "@nccl", 23 | ], 24 | ) 25 | 26 | cc_library( 27 | name = "common_cc", 28 | srcs = ["src/timer.cc"], 29 | hdrs = ["src/timer.h"], 30 | alwayslink = 1, 31 | ) 32 | 33 | # :common_cuda, :common_cc and @nccl//:nccl_shared are implicitly hardcoded in `nccl_tests_binary` 34 | nccl_tests_binary(name = "all_reduce") 35 | 36 | nccl_tests_binary(name = "all_gather") 37 | 38 | nccl_tests_binary(name = "broadcast") 39 | 40 | nccl_tests_binary(name = "reduce_scatter") 41 | 42 | nccl_tests_binary(name = "reduce") 43 | 44 | nccl_tests_binary(name = "alltoall") 45 | 46 | nccl_tests_binary(name = "scatter") 47 | 48 | nccl_tests_binary(name = "gather") 49 | 50 | nccl_tests_binary(name = "sendrecv") 51 | 52 | nccl_tests_binary(name = "hypercube") 53 | -------------------------------------------------------------------------------- /examples/nccl/nccl-tests.bzl: -------------------------------------------------------------------------------- 1 | load("@rules_cuda//cuda:defs.bzl", "cuda_library") 2 | 3 | # NOTE: all paths in this file relative to @nccl-tests repo root. 
def nccl_tests_binary(name, cc_deps = [], cuda_deps = []):
    """Declares one nccl-tests benchmark: a cuda_library plus a cc_binary.

    Creates:
      * `cuda_library` `<name>` built from `src/<name>.cu`, always depending
        on `@nccl//:nccl_shared` and `:common_cuda`.
      * `cc_binary` `<name>_perf`, publicly visible, always depending on
        `:common_cc` and the library above.

    Args:
      name: Base name of the test; also selects the source `src/<name>.cu`.
      cc_deps: Extra deps appended to the `cc_binary` (previously accepted
        but ignored).
      cuda_deps: Extra deps appended to the `cuda_library` (previously
        accepted but ignored).
    """
    cuda_library(
        name = name,
        srcs = ["src/{}.cu".format(name)],
        deps = [
            "@nccl//:nccl_shared",
            ":common_cuda",
        ] + cuda_deps,
        alwayslink = 1,
    )

    bin_name = name + "_perf"
    native.cc_binary(
        name = bin_name,
        deps = [":common_cc", ":" + name] + cc_deps,
        visibility = ["//visibility:public"],
    )
33 | "src/include", 34 | ], 35 | ) 36 | 37 | cuda_objects( 38 | name = "nccl_device_common", 39 | srcs = [ 40 | "src/collectives/device/functions.cu", 41 | "src/collectives/device/onerank_reduce.cu", 42 | ] + glob([ 43 | "src/collectives/device/**/*.h", 44 | ]), 45 | copts = if_cuda_nvcc(["--extended-lambda"]), 46 | ptxasopts = ["-maxrregcount=96"], 47 | deps = [":nccl_include"], 48 | ) 49 | 50 | # must be manually disabled if cuda version is lower than 11. 51 | USE_BF16 = True 52 | 53 | filegroup( 54 | name = "collective_dev_hdrs", 55 | srcs = [ 56 | "src/collectives/device/all_gather.h", 57 | "src/collectives/device/all_reduce.h", 58 | "src/collectives/device/broadcast.h", 59 | "src/collectives/device/common.h", 60 | "src/collectives/device/common_kernel.h", 61 | "src/collectives/device/gen_rules.sh", 62 | "src/collectives/device/op128.h", 63 | "src/collectives/device/primitives.h", 64 | "src/collectives/device/prims_ll.h", 65 | "src/collectives/device/prims_ll128.h", 66 | "src/collectives/device/prims_simple.h", 67 | "src/collectives/device/reduce.h", 68 | "src/collectives/device/reduce_kernel.h", 69 | "src/collectives/device/reduce_scatter.h", 70 | "src/collectives/device/sendrecv.h", 71 | ], 72 | ) 73 | 74 | # cuda_objects for each type of primitive 75 | nccl_primitive( 76 | name = "all_gather", 77 | hdrs = ["collective_dev_hdrs"], 78 | use_bf16 = USE_BF16, 79 | deps = [":nccl_device_common"], 80 | ) 81 | 82 | nccl_primitive( 83 | name = "all_reduce", 84 | hdrs = ["collective_dev_hdrs"], 85 | use_bf16 = USE_BF16, 86 | deps = [":nccl_device_common"], 87 | ) 88 | 89 | nccl_primitive( 90 | name = "broadcast", 91 | hdrs = ["collective_dev_hdrs"], 92 | use_bf16 = USE_BF16, 93 | deps = [":nccl_device_common"], 94 | ) 95 | 96 | nccl_primitive( 97 | name = "reduce", 98 | hdrs = ["collective_dev_hdrs"], 99 | use_bf16 = USE_BF16, 100 | deps = [":nccl_device_common"], 101 | ) 102 | 103 | nccl_primitive( 104 | name = "reduce_scatter", 105 | hdrs = ["collective_dev_hdrs"], 
106 | use_bf16 = USE_BF16, 107 | deps = [":nccl_device_common"], 108 | ) 109 | 110 | nccl_primitive( 111 | name = "sendrecv", 112 | hdrs = ["collective_dev_hdrs"], 113 | use_bf16 = USE_BF16, 114 | deps = [":nccl_device_common"], 115 | ) 116 | 117 | # device link 118 | cuda_library( 119 | name = "collectives", 120 | rdc = 1, 121 | deps = [ 122 | ":all_gather", 123 | ":all_reduce", 124 | ":broadcast", 125 | ":reduce", 126 | ":reduce_scatter", 127 | ":sendrecv", 128 | ], 129 | alwayslink = 1, 130 | ) 131 | 132 | cc_binary( 133 | name = "nccl", 134 | srcs = glob( 135 | [ 136 | "src/*.cc", 137 | "src/collectives/*.cc", 138 | "src/graph/*.cc", 139 | "src/graph/*.h", 140 | "src/misc/*.cc", 141 | "src/transport/*.cc", 142 | ], 143 | exclude = [ 144 | # https://github.com/NVIDIA/nccl/issues/658 145 | "src/enhcompat.cc", 146 | ], 147 | ), 148 | copts = if_cuda_clang(["-xcu"]), 149 | linkshared = 1, 150 | linkstatic = 1, 151 | visibility = ["//visibility:public"], 152 | deps = [ 153 | ":collectives", 154 | ":nccl_include", 155 | "@rules_cuda//cuda:runtime", 156 | ], 157 | ) 158 | 159 | # To allow downstream targets to link with the nccl shared library, we need to `cc_import` it again. 160 | # See https://groups.google.com/g/bazel-discuss/c/RtbidPdVFyU/m/TsUDOVHIAwAJ 161 | cc_import( 162 | name = "nccl_shared", 163 | shared_library = ":nccl", 164 | visibility = ["//visibility:public"], 165 | ) 166 | -------------------------------------------------------------------------------- /examples/nccl/nccl.bzl: -------------------------------------------------------------------------------- 1 | load("@bazel_skylib//rules:copy_file.bzl", "copy_file") 2 | load("@rules_cuda//cuda:defs.bzl", "cuda_library", "cuda_objects") 3 | 4 | # NOTE: all paths in this file relative to @nccl repo root. 
def if_cuda_nvcc(if_true, if_false = []):
    """Selects `if_true` when the CUDA compiler is nvcc, else `if_false`."""
    branches = {
        "@rules_cuda//cuda:compiler_is_nvcc": if_true,
        "//conditions:default": if_false,
    }
    return select(branches)

def if_cuda_clang(if_true, if_false = []):
    """Selects `if_true` when the CUDA compiler is clang, else `if_false`."""
    branches = {
        "@rules_cuda//cuda:compiler_is_clang": if_true,
        "//conditions:default": if_false,
    }
    return select(branches)

def nccl_primitive(name, hdrs = [], deps = [], use_bf16 = True):
    """Declares cuda_objects for every (op, datatype) variant of a primitive.

    For each combination, `src/collectives/device/<name>.cu` is copied to
    `src/collectives/device/<name>_<op>_<dt>.cu` and compiled with
    `NCCL_OP` / `NCCL_TYPE` set to the list index of the op / datatype.
    A final `cuda_objects` target called `name` depends on all variants.

    Args:
      name: Primitive name; selects the source file and the aggregate target.
      hdrs: Headers passed to every per-variant `cuda_objects`.
      deps: Deps passed to every per-variant `cuda_objects`.
      use_bf16: When true, "bf16" is added as the last datatype.
    """
    reduction_ops = ["sum", "prod", "min", "max", "premulsum", "sumpostdiv"]
    element_types = ["i8", "u8", "i32", "u32", "i64", "u64", "f16", "f32", "f64"] + (
        ["bf16"] if use_bf16 else []
    )

    # Every variant is compiled from a copy of the same source file.
    template_src = "src/collectives/device/{}.cu".format(name)

    variant_labels = []
    for op_index in range(len(reduction_ops)):
        for type_index in range(len(element_types)):
            variant = "{}_{}_{}".format(
                name,
                reduction_ops[op_index],
                element_types[type_index],
            )

            copy_file(
                name = variant + "_rename",
                src = template_src,
                out = "src/collectives/device/{}.cu".format(variant),
            )

            cuda_objects(
                name = variant,
                srcs = [":{}_rename".format(variant)],
                hdrs = hdrs,
                deps = deps,
                ptxasopts = ["-maxrregcount=96"],
                defines = [
                    "NCCL_OP={}".format(op_index),
                    "NCCL_TYPE={}".format(type_index),
                ],
                includes = ["src/collectives/device"],
            )
            variant_labels.append(":" + variant)

    cuda_objects(name = name, deps = variant_labels)
// Launches foo() on one block of N threads, copies the device array `g`
// back to the host and checks the sum of its elements.
// Prints PASSED / FAILED and returns 0 on success, 1 on a wrong result.
int main(void) {
  int *dg, hg[N];
  int sum = 0;

  foo<<<1, N>>>();
  CUDA_CHECK(cudaGetLastError());
  CUDA_CHECK(cudaGetSymbolAddress((void**)&dg, g));
  CUDA_CHECK(cudaMemcpy(hg, dg, N * sizeof(int), cudaMemcpyDeviceToHost));

  for (int i = 0; i < N; i++) {
    sum += hg[i];
  }

  // foo() stores (N - 1 - tid) into g[tid] and bar() increments it, so
  // g[tid] == N - tid and the total is 1 + 2 + ... + N. For N == 8 this is
  // the 36 that used to be hard-coded.
  const int expected = N * (N + 1) / 2;
  if (sum != expected) {
    printf("FAILED (%d)\n", sum);
    // Make the failure visible to scripts and CI through the exit status.
    return 1;
  }

  printf("PASSED\n");
  return 0;
}
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/reduce.h>

#include <iostream>

// Fills a device vector with a constant, sums it with thrust::reduce and
// prints the sum and the mean.
int main() {
  const int num_elements = 8192;
  thrust::device_vector<float> vec(num_elements, 42.0);
  // The initial value and the binary operator are both float so that the
  // reduction is carried out in the element type of the vector.
  auto sum = thrust::reduce(vec.begin(), vec.end(), (float)0.0, thrust::plus<float>());
  std::cout << "thrust device_vector created, sum reduce as " << sum << ", mean: " << sum / num_elements << std::endl;
  return 0;
}
load("@bazel_skylib//lib:unittest.bzl", "analysistest", "asserts")

def _num_actions_test_impl(ctx):
    """Asserts the number of actions registered by the target under test.

    The comparison is only made when the `num_actions` attribute is positive.
    """
    env = analysistest.begin(ctx)
    _target_under_test = analysistest.target_under_test(env)
    registered_actions = analysistest.target_actions(env)
    expected_count = ctx.attr.num_actions
    if expected_count > 0:
        asserts.equals(env, expected_count, len(registered_actions))
    return analysistest.end(env)

num_actions_test = analysistest.make(
    _num_actions_test_impl,
    attrs = {
        "num_actions": attr.int(),
    },
)

def _command_has_flag(cmd, single_flag):
    """Returns True if `single_flag` occurs in `cmd` as a space-delimited token.

    The flag must be preceded by a space and either followed by a space or be
    the end of the command line.
    """
    spaced_flag = " " + single_flag
    return (spaced_flag + " ") in cmd or cmd.endswith(spaced_flag)

def cuda_library_flag_test_impl(ctx):
    """Checks the command line of actions selected by mnemonic (and output name).

    Every action whose mnemonic equals `action_mnemonic` is inspected; when
    `output_name` is non-empty, only actions with an output of that basename
    are inspected. Each inspected command line must contain all of
    `contain_flags` and none of `not_contain_flags`. The test fails if no
    action was inspected.
    """
    env = analysistest.begin(ctx)
    target_under_test = analysistest.target_under_test(env)
    actions = analysistest.target_actions(env)

    required_flags = ctx.attr.contain_flags
    forbidden_flags = ctx.attr.not_contain_flags
    asserts.true(env, len(required_flags) + len(forbidden_flags) > 0, "Invalid test config")

    inspected_any = False
    name_matched = True
    for action in actions:
        if action.mnemonic != ctx.attr.action_mnemonic:
            continue

        if ctx.attr.output_name != "":
            name_matched = any([
                produced.basename == ctx.attr.output_name
                for produced in action.outputs.to_list()
            ])
            if not name_matched:
                continue

        inspected_any = True
        cmd = " ".join(action.argv)
        for flag in required_flags:
            asserts.true(env, _command_has_flag(cmd, flag), 'flag "{}" not in command line "{}"'.format(flag, cmd))
        for flag in forbidden_flags:
            asserts.true(env, not _command_has_flag(cmd, flag), 'flag "{}" in command line "{}"'.format(flag, cmd))

    failure_message = 'target "{}" do not have action with mnemonic "{}"'.format(
        str(target_under_test),
        ctx.attr.action_mnemonic,
    )
    if not name_matched:
        # Mention the requested output name when filtering by name rejected an action.
        failure_message += ' has output named "{}"'.format(ctx.attr.output_name)
    asserts.true(env, inspected_any, failure_message)

    return analysistest.end(env)

def _rules_cuda_target(target):
    """Returns the label string to use as a config_settings key for `target`."""

    # https://github.com/bazelbuild/bazel/issues/19286#issuecomment-1684325913
    # must only apply to rules_cuda related labels when bzlmod is enabled
    label_str = "@//" + target
    bzlmod_enabled = str(Label("//:invalid")).startswith("@@")
    return str(Label(label_str)) if bzlmod_enabled else label_str

def _create_cuda_library_flag_test(*config_settings):
    """Creates a flag test rule whose config_settings is the merge of the given dicts.

    Later dicts override earlier ones on duplicate keys.
    """
    merged = {}
    for settings in config_settings:
        merged.update(settings)
    return analysistest.make(
        cuda_library_flag_test_impl,
        config_settings = merged,
        attrs = {
            "action_mnemonic": attr.string(mandatory = True),
            "output_name": attr.string(),
            "contain_flags": attr.string_list(),
            "not_contain_flags": attr.string_list(),
        },
    )

cuda_library_flag_test = _create_cuda_library_flag_test({})

# Compilation-mode variants.
config_settings_dbg = {"//command_line_option:compilation_mode": "dbg"}
config_settings_fastbuild = {"//command_line_option:compilation_mode": "fastbuild"}
config_settings_opt = {"//command_line_option:compilation_mode": "opt"}

cuda_library_c_dbg_flag_test = _create_cuda_library_flag_test(config_settings_dbg)
cuda_library_c_fastbuild_flag_test = _create_cuda_library_flag_test(config_settings_fastbuild)
cuda_library_c_opt_flag_test = _create_cuda_library_flag_test(config_settings_opt)

# Compilation-mode variants combined with the static_link_msvcrt feature.
static_link_msvcrt = {"//command_line_option:features": ["static_link_msvcrt"]}

cuda_library_c_dbg_static_msvcrt_flag_test = _create_cuda_library_flag_test(config_settings_dbg, static_link_msvcrt)
cuda_library_c_fastbuild_static_msvcrt_flag_test = _create_cuda_library_flag_test(config_settings_fastbuild, static_link_msvcrt)
cuda_library_c_opt_static_msvcrt_flag_test = _create_cuda_library_flag_test(config_settings_opt, static_link_msvcrt)

# NOTE: @rules_cuda//cuda:archs does not work
_ARCHS_SETTING = _rules_cuda_target("cuda:archs")

config_settings_sm61 = {_ARCHS_SETTING: "sm_61"}
config_settings_compute60 = {_ARCHS_SETTING: "compute_60"}
config_settings_compute60_sm61 = {_ARCHS_SETTING: "compute_60,sm_61"}
config_settings_compute61_sm61 = {_ARCHS_SETTING: "compute_61,sm_61"}
config_settings_sm90a = {_ARCHS_SETTING: "sm_90a"}
config_settings_sm90a_sm90 = {_ARCHS_SETTING: "sm_90a,sm_90"}
config_settings_sm100_sm100a = {_ARCHS_SETTING: "sm_100;sm_100a"}  # NOTE: two specs

cuda_library_sm61_flag_test = _create_cuda_library_flag_test(config_settings_sm61)
cuda_library_sm90a_flag_test = _create_cuda_library_flag_test(config_settings_sm90a)
cuda_library_sm90a_sm90_flag_test = _create_cuda_library_flag_test(config_settings_sm90a_sm90)
cuda_library_sm100_sm100a_flag_test = _create_cuda_library_flag_test(config_settings_sm100_sm100a)
cuda_library_compute60_flag_test = _create_cuda_library_flag_test(config_settings_compute60)
cuda_library_compute60_sm61_flag_test = _create_cuda_library_flag_test(config_settings_compute60_sm61)
cuda_library_compute61_sm61_flag_test = _create_cuda_library_flag_test(config_settings_compute61_sm61)
#!/bin/bash
#
# Builds the shared integration targets inside each toolchain_* workspace next
# to this script. Any unexpected bazel failure aborts the script (set -e).

# Quote both expansions so a checkout path containing spaces is not word-split.
this_dir=$(realpath "$(dirname "$0")")

set -ex

# Builds the set of targets that every working toolchain configuration is
# expected to support. Arguments, if any, are passed to `bazel build` ahead of
# the target pattern (e.g. --enable_workspace).
build_supported_targets() {
  bazel build "$@" //... --@rules_cuda//cuda:enable=False
  bazel build "$@" //... --@rules_cuda//cuda:enable=True
  bazel build "$@" //:optinally_use_rule --@rules_cuda//cuda:enable=False
  bazel build "$@" //:optinally_use_rule --@rules_cuda//cuda:enable=True
  bazel build "$@" //:use_library
  bazel build "$@" //:use_rule
}

# toolchain configured by the root module of the user
pushd "$this_dir/toolchain_root"
  build_supported_targets
  bazel clean && bazel shutdown
popd

# toolchain does not exist
pushd "$this_dir/toolchain_none"
  # analysis pass
  bazel build //... --@rules_cuda//cuda:enable=False
  bazel build //... --@rules_cuda//cuda:enable=True

  # force build optional targets
  bazel build //:optinally_use_rule --@rules_cuda//cuda:enable=False
  # `|| true` keeps set -e from aborting: the build is expected to fail and
  # only its output is inspected.
  ERR=$(bazel build //:optinally_use_rule --@rules_cuda//cuda:enable=True 2>&1 || true)
  if ! [[ $ERR == *"didn't satisfy constraint"*"valid_toolchain_is_configured"* ]]; then exit 1; fi

  # use library fails because the library file does not exist
  ERR=$(bazel build //:use_library 2>&1 || true)
  if ! [[ $ERR =~ "target 'cuda_runtime' not declared in package" ]]; then exit 1; fi
  if ! [[ $ERR =~ "ERROR: Analysis of target '//:use_library' failed" ]]; then exit 1; fi

  # use rule fails because rules_cuda depends on a non-existent cuda toolkit
  ERR=$(bazel build //:use_rule 2>&1 || true)
  if ! [[ $ERR =~ "target 'cuda_runtime' not declared in package" ]]; then exit 1; fi
  if ! [[ $ERR =~ "ERROR: Analysis of target '//:use_rule' failed" ]]; then exit 1; fi

  bazel clean && bazel shutdown
popd

# toolchain configured by rules_cuda
pushd "$this_dir/toolchain_rules"
  build_supported_targets
  bazel clean && bazel shutdown
popd

# toolchain configured with deliverables (manual components with workspace)
pushd "$this_dir/toolchain_components"
  build_supported_targets --enable_workspace
  bazel clean && bazel shutdown
popd

# toolchain configured with deliverables (manual components with bzlmod)
pushd "$this_dir/toolchain_components"
  build_supported_targets --enable_bzlmod
  bazel clean && bazel shutdown
popd

# toolchain configured with deliverables (redistrib.json with workspace)
pushd "$this_dir/toolchain_redist_json"
  build_supported_targets --enable_workspace
  bazel clean && bazel shutdown
popd
# rules_cuda must be declared before anything is loaded from it, so the load
# below cannot be placed at the top of the file. The path matches the
# local_path_override in this directory's MODULE.bazel.
local_repository(
    name = "rules_cuda",
    path = "../../..",
)

# buildifier: disable=load-on-top
load("@rules_cuda//cuda:repositories.bzl", "cuda_component", "default_components_mapping", "rules_cuda_dependencies", "rules_cuda_toolchains")

rules_cuda_dependencies()

# Individually downloaded CUDA redistributable components.
cuda_component(
    name = "cuda_cccl",
    component_name = "cccl",
    sha256 = "9c3145ef01f73e50c0f5fcf923f0899c847f487c529817daa8f8b1a3ecf20925",
    strip_prefix = "cuda_cccl-linux-x86_64-12.6.77-archive",
    urls = ["https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/linux-x86_64/cuda_cccl-linux-x86_64-12.6.77-archive.tar.xz"],
)

cuda_component(
    name = "cuda_cudart",
    component_name = "cudart",
    sha256 = "f74689258a60fd9c5bdfa7679458527a55e22442691ba678dcfaeffbf4391ef9",
    strip_prefix = "cuda_cudart-linux-x86_64-12.6.77-archive",
    urls = ["https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/linux-x86_64/cuda_cudart-linux-x86_64-12.6.77-archive.tar.xz"],
)

cuda_component(
    name = "cuda_nvcc",
    component_name = "nvcc",
    sha256 = "840deff234d9bef20d6856439c49881cb4f29423b214f9ecd2fa59b7ac323817",
    strip_prefix = "cuda_nvcc-linux-x86_64-12.6.85-archive",
    urls = ["https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/linux-x86_64/cuda_nvcc-linux-x86_64-12.6.85-archive.tar.xz"],
)

# Component names passed to default_components_mapping; they correspond to the
# component_name values of the cuda_component repositories declared above.
COMPONENTS = [
    "cccl",
    "cudart",
    "nvcc",
]

rules_cuda_toolchains(
    components_mapping = default_components_mapping(COMPONENTS),
    version = "12.6",
    register_toolchains = True,
)
-------------------------------------------------------------------------------- /tests/integration/toolchain_redist_json/BUILD.bazel: -------------------------------------------------------------------------------- 1 | ../BUILD.to_symlink -------------------------------------------------------------------------------- /tests/integration/toolchain_redist_json/MODULE.bazel: -------------------------------------------------------------------------------- 1 | module(name = "bzlmod_components") 2 | 3 | # FIXME: cuda_redist_json is not exposed in bzlmod now. Fallback to manually specified components for tests 4 | bazel_dep(name = "rules_cuda", version = "0.0.0") 5 | local_path_override( 6 | module_name = "rules_cuda", 7 | path = "../../..", 8 | ) 9 | 10 | cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain") 11 | cuda.component( 12 | name = "cuda_cccl", 13 | component_name = "cccl", 14 | sha256 = "9c3145ef01f73e50c0f5fcf923f0899c847f487c529817daa8f8b1a3ecf20925", 15 | strip_prefix = "cuda_cccl-linux-x86_64-12.6.77-archive", 16 | urls = ["https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/linux-x86_64/cuda_cccl-linux-x86_64-12.6.77-archive.tar.xz"], 17 | ) 18 | cuda.component( 19 | name = "cuda_cudart", 20 | component_name = "cudart", 21 | sha256 = "f74689258a60fd9c5bdfa7679458527a55e22442691ba678dcfaeffbf4391ef9", 22 | strip_prefix = "cuda_cudart-linux-x86_64-12.6.77-archive", 23 | urls = ["https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/linux-x86_64/cuda_cudart-linux-x86_64-12.6.77-archive.tar.xz"], 24 | ) 25 | cuda.component( 26 | name = "cuda_nvcc", 27 | component_name = "nvcc", 28 | sha256 = "840deff234d9bef20d6856439c49881cb4f29423b214f9ecd2fa59b7ac323817", 29 | strip_prefix = "cuda_nvcc-linux-x86_64-12.6.85-archive", 30 | urls = ["https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/linux-x86_64/cuda_nvcc-linux-x86_64-12.6.85-archive.tar.xz"], 31 | ) 32 | cuda.toolkit( 33 | name = "cuda", 34 | 
components_mapping = { 35 | "cccl": "@cuda_cccl", 36 | "cudart": "@cuda_cudart", 37 | "nvcc": "@cuda_nvcc", 38 | }, 39 | version = "12.6", 40 | ) 41 | use_repo( 42 | cuda, 43 | "cuda", 44 | "cuda_cccl", 45 | "cuda_cudart", 46 | "cuda_nvcc", 47 | ) 48 | 49 | bazel_dep(name = "rules_cuda_examples", version = "0.0.0") 50 | local_path_override( 51 | module_name = "rules_cuda_examples", 52 | path = "../../../examples", 53 | ) 54 | -------------------------------------------------------------------------------- /tests/integration/toolchain_redist_json/WORKSPACE.bazel: -------------------------------------------------------------------------------- 1 | local_repository( 2 | name = "rules_cuda", 3 | path = "../../..", 4 | ) 5 | 6 | # buildifier: disable=load-on-top 7 | load("@rules_cuda//cuda:repositories.bzl", "cuda_redist_json", "rules_cuda_dependencies") 8 | 9 | rules_cuda_dependencies() 10 | 11 | cuda_redist_json( 12 | name = "rules_cuda_redist_json", 13 | components = [ 14 | "cccl", 15 | "cudart", 16 | "nvcc", 17 | ], 18 | version = "12.6.3", 19 | ) 20 | 21 | load("@rules_cuda_redist_json//:redist.bzl", "rules_cuda_components_and_toolchains") 22 | 23 | rules_cuda_components_and_toolchains(register_toolchains = True) 24 | -------------------------------------------------------------------------------- /tests/integration/toolchain_redist_json/WORKSPACE.bzlmod: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bazel-contrib/rules_cuda/3f72f484a8ea5969c81a857a6785ebf0ede02c0c/tests/integration/toolchain_redist_json/WORKSPACE.bzlmod -------------------------------------------------------------------------------- /tests/integration/toolchain_root/BUILD.bazel: -------------------------------------------------------------------------------- 1 | ../BUILD.to_symlink -------------------------------------------------------------------------------- /tests/integration/toolchain_root/MODULE.bazel: 
-------------------------------------------------------------------------------- 1 | module(name = "bzlmod_use_repo_no_toolchain") 2 | 3 | bazel_dep(name = "rules_cuda", version = "0.0.0") 4 | local_path_override( 5 | module_name = "rules_cuda", 6 | path = "../../..", 7 | ) 8 | 9 | cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain") 10 | cuda.toolkit( 11 | name = "cuda", 12 | toolkit_path = "", 13 | ) 14 | use_repo(cuda, "cuda") 15 | 16 | bazel_dep(name = "rules_cuda_examples", version = "0.0.0") 17 | local_path_override( 18 | module_name = "rules_cuda_examples", 19 | path = "../../../examples", 20 | ) 21 | -------------------------------------------------------------------------------- /tests/integration/toolchain_rules/BUILD.bazel: -------------------------------------------------------------------------------- 1 | ../BUILD.to_symlink -------------------------------------------------------------------------------- /tests/integration/toolchain_rules/MODULE.bazel: -------------------------------------------------------------------------------- 1 | module(name = "bzlmod_use_repo") 2 | 3 | bazel_dep(name = "rules_cuda", version = "0.0.0") 4 | local_path_override( 5 | module_name = "rules_cuda", 6 | path = "../../..", 7 | ) 8 | 9 | cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain") 10 | use_repo(cuda, "cuda") 11 | 12 | bazel_dep(name = "rules_cuda_examples", version = "0.0.0") 13 | local_path_override( 14 | module_name = "rules_cuda_examples", 15 | path = "../../../examples", 16 | ) 17 | -------------------------------------------------------------------------------- /tests/toolchain_config_lib/BUILD.bazel: -------------------------------------------------------------------------------- 1 | load( 2 | ":toolchain_config_lib_test.bzl", 3 | "eval_flag_group_failure_tests", 4 | "eval_flag_group_test", 5 | "expand_flag_test", 6 | "feature_configuration_env_test", 7 | "feature_configuration_failure_tests", 8 | 
load("@bazel_skylib//lib:unittest.bzl", "asserts", "unittest")
load("//cuda/private:cuda_helper.bzl", "cuda_helper")
load("//cuda/private:providers.bzl", "ArchSpecInfo", "Stage2ArchInfo")

def _sm(arch):
    """Expected stage-2 entry for an `sm_<arch>` token."""
    return Stage2ArchInfo(arch = arch, virtual = False, gpu = True, lto = False)

def _compute(arch):
    """Expected stage-2 entry for a `compute_<arch>` token."""
    return Stage2ArchInfo(arch = arch, virtual = True, gpu = False, lto = False)

def _spec(stage1_arch, stage2_archs):
    """Expected ArchSpecInfo with the given stage-1 arch and stage-2 entries."""
    return ArchSpecInfo(stage1_arch = stage1_arch, stage2_archs = stage2_archs)

def _get_arch_specs_test_impl(ctx):
    """Checks cuda_helper.get_arch_specs against a table of arch strings."""
    env = unittest.begin(ctx)

    # (arch string, expected list of ArchSpecInfo), checked in order.
    cases = [
        ("", []),
        (";", []),
        ("compute_80:sm_80,sm_86", [_spec("80", [_sm("80"), _sm("86")])]),
        ("compute_60:compute_60,sm_61,sm_62", [_spec("60", [_compute("60"), _sm("61"), _sm("62")])]),
        ("compute_80:compute_80", [_spec("80", [_compute("80")])]),
        ("sm_80,sm_86", [_spec("80", [_sm("80"), _sm("86")])]),
        ("sm_80;sm_86", [_spec("80", [_sm("80")]), _spec("86", [_sm("86")])]),
        ("compute_80", [_spec("80", [_compute("80")])]),
    ]
    for arch_string, expected in cases:
        asserts.equals(env, expected, cuda_helper.get_arch_specs(arch_string))

    return unittest.end(env)

get_arch_specs_test = unittest.make(_get_arch_specs_test_impl)