├── requirements.txt
├── extension_cpp
│   ├── __init__.py
│   ├── ops.py
│   └── csrc
│       ├── cuda
│       │   └── muladd.cu
│       └── muladd.cpp
├── pyproject.toml
├── .github
│   ├── scripts
│   │   ├── unittest.sh
│   │   └── setup-env.sh
│   ├── ISSUE_TEMPLATE.md
│   └── workflows
│       └── tests.yml
├── README.md
├── .gitignore
├── setup.py
└── test
    └── test_extension.py

/requirements.txt:
--------------------------------------------------------------------------------
torch
numpy
--------------------------------------------------------------------------------
/extension_cpp/__init__.py:
--------------------------------------------------------------------------------
import torch
from pathlib import Path
from . import _C, ops
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[build-system]
requires = [
    "setuptools",
    "torch",
]
build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/.github/scripts/unittest.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

set -euo pipefail

./.github/scripts/setup-env.sh

# Activate conda environment
eval "$($(which conda) shell.bash hook)" && conda deactivate && conda activate ci

echo '::group::Install testing utilities'
pip install --progress-bar=off pytest pytest-mock pytest-cov expecttest numpy
echo '::endgroup::'

pytest test/ --junit-xml="${RUNNER_TEST_RESULTS_DIR}/test-results.xml" -v
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# C++/CUDA Extensions in PyTorch

An example of writing a C++/CUDA extension for PyTorch. See
[here](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html) for the accompanying tutorial.
This repo demonstrates how to write an example `extension_cpp.ops.mymuladd`
custom op that has both custom CPU and CUDA kernels.

The examples in this repo work with PyTorch 2.4+.

To build:
```
pip install --no-build-isolation -e .
```

To test:
```
python test/test_extension.py
```
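
## Example usage

Once built, the op can be called through the Python wrapper or directly on the
`torch.ops` namespace. A minimal sketch (the tensor sizes are arbitrary):

```python
import torch
import extension_cpp

a, b = torch.randn(8), torch.randn(8)

out = extension_cpp.ops.mymuladd(a, b, 3.0)
print(torch.allclose(out, a * b + 3.0))  # True

# The op also traces under torch.compile thanks to its FakeTensor kernel.
compiled = torch.compile(lambda x, y: extension_cpp.ops.mymuladd(x, y, 3.0))
print(torch.allclose(compiled(a, b), out))
```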

## Authors

[Peter Goldsborough](https://github.com/goldsborough), [Richard Zou](https://github.com/zou3519)
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE.md:
--------------------------------------------------------------------------------
PyTorch GitHub Issues Guidelines
--------------------------------

We like to limit our issues to bug reports and feature requests. If you have a question or would like help and support, please visit our forums: https://discuss.pytorch.org/

If you are submitting a feature request, please preface the title with [feature request].

When submitting a bug report, please include the following information (where relevant):
- OS:
- PyTorch version:
- How you installed PyTorch (conda, pip, source):
- Python version:
- CUDA/cuDNN version:
- GPU models and configuration:
- GCC version (if compiling from source):

In addition, including the following information will also be very helpful for us to diagnose the problem:
- A script to reproduce the bug. Please try to provide as minimal a test case as possible.
- Error messages and/or stack traces of the bug
- Context around what you are trying to do
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
name: Tests

on:
  pull_request:
  push:
    branches:
      - master
  workflow_dispatch:

jobs:
  unittests-linux:
    strategy:
      matrix:
        python-version:
          - "3.11"
        runner: ["linux.12xlarge"]
        gpu-arch-type: ["cpu"]
        include:
          - python-version: "3.13"
            runner: linux.g5.4xlarge.nvidia.gpu
            gpu-arch-type: cuda
            gpu-arch-version: "12.4"
      fail-fast: false
    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
    permissions:
      id-token: write
      contents: read
    with:
      repository: pytorch/extension-cpp
      runner: ${{ matrix.runner }}
      gpu-arch-type: ${{ matrix.gpu-arch-type }}
      gpu-arch-version: ${{ matrix.gpu-arch-version }}
      timeout: 120
      script: |
        set -euo pipefail

        export PYTHON_VERSION=${{ matrix.python-version }}
        export GPU_ARCH_TYPE=${{ matrix.gpu-arch-type }}
        export GPU_ARCH_VERSION=${{ matrix.gpu-arch-version }}

        ./.github/scripts/unittest.sh
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
build/
dist/
torch.egg-info/
*/**/__pycache__
torch/version.py
torch/csrc/generic/TensorMethods.cpp
torch/lib/*.so*
torch/lib/*.a*
torch/lib/*.dll*
torch/lib/*.lib
torch/lib/*.dylib*
torch/lib/*.h
torch/lib/build
torch/lib/tmp_install
torch/lib/include
torch/lib/torch_shm_manager
torch/csrc/jit/generated/*
torch/csrc/autograd/generated/*
torch/csrc/cudnn/cuDNN.cpp
torch/csrc/nn/THNN.cwrap
torch/csrc/nn/THNN.cpp
torch/csrc/nn/THCUNN.cwrap
torch/csrc/nn/THCUNN.cpp
torch/csrc/nn/THNN_generic.cwrap
torch/csrc/nn/THNN_generic.cpp
torch/csrc/nn/THNN_generic.h
torch/csrc/generated
docs/src/**/*
test/data/legacy_modules.t7
test/data/gpu_tensors.pt
test/htmlcov
test/.coverage
*/*.pyc
*/**/*.pyc
*/**/**/*.pyc
*/**/**/**/*.pyc
*/**/**/**/**/*.pyc
*/*.so*
*/**/*.so*
*/**/*.dylib*
test/data/legacy_serialized.pt
test/data/linear.pt
.mypy_cache

# IPython notebook checkpoints
.ipynb_checkpoints

# Editor temporaries
*.swn
*.swo
*.swp
*~

# macOS dir files
.DS_Store

# Ninja files
.ninja_deps
.ninja_log
compile_commands.json
*.egg-info/
--------------------------------------------------------------------------------
/extension_cpp/ops.py:
--------------------------------------------------------------------------------
import torch
from torch import Tensor

__all__ = ["mymuladd", "myadd_out"]


def mymuladd(a: Tensor, b: Tensor, c: float) -> Tensor:
    """Performs a * b + c in an efficient fused kernel"""
    return torch.ops.extension_cpp.mymuladd.default(a, b, c)


# Registers a FakeTensor kernel (aka "meta kernel", "abstract impl")
# that describes what the properties of the output Tensor are given
# the properties of the input Tensor. The FakeTensor kernel is necessary
# for the op to work performantly with torch.compile.
@torch.library.register_fake("extension_cpp::mymuladd")
def _(a, b, c):
    torch._check(a.shape == b.shape)
    torch._check(a.dtype == torch.float)
    torch._check(b.dtype == torch.float)
    torch._check(a.device == b.device)
    return torch.empty_like(a)
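

# A rough sketch (illustrative, not part of the library) of what the fake kernel
# enables: torch.compile can trace the op without running a real kernel, because
# the registration above describes the output purely from the input metadata.
#
#     @torch.compile(fullgraph=True)
#     def f(x, y):
#         return torch.ops.extension_cpp.mymuladd.default(x, y, 1.0)
#
#     f(torch.randn(8), torch.randn(8))  # traced through the fake kernel, then compiled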


def _backward(ctx, grad):
    a, b = ctx.saved_tensors
    grad_a, grad_b = None, None
    if ctx.needs_input_grad[0]:
        grad_a = torch.ops.extension_cpp.mymul.default(grad, b)
    if ctx.needs_input_grad[1]:
        grad_b = torch.ops.extension_cpp.mymul.default(grad, a)
    return grad_a, grad_b, None


def _setup_context(ctx, inputs, output):
    a, b, c = inputs
    saved_a, saved_b = None, None
    if ctx.needs_input_grad[0]:
        saved_b = b
    if ctx.needs_input_grad[1]:
        saved_a = a
    ctx.save_for_backward(saved_a, saved_b)


# This adds training support for the operator. You must provide us
# the backward formula for the operator and a `setup_context` function
# to save values to be used in the backward.
torch.library.register_autograd(
    "extension_cpp::mymuladd", _backward, setup_context=_setup_context)
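
# A small illustration (not part of the library) of the training support added above:
# gradients flow through mymuladd like any built-in op.
#
#     a = torch.randn(4, requires_grad=True)
#     b = torch.randn(4, requires_grad=True)
#     mymuladd(a, b, 0.5).sum().backward()
#     # a.grad == b and b.grad == a, since d(a*b + c)/da = b and d(a*b + c)/db = a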


@torch.library.register_fake("extension_cpp::mymul")
def _(a, b):
    torch._check(a.shape == b.shape)
    torch._check(a.dtype == torch.float)
    torch._check(b.dtype == torch.float)
    torch._check(a.device == b.device)
    return torch.empty_like(a)


def myadd_out(a: Tensor, b: Tensor, out: Tensor) -> None:
    """Writes a + b into out"""
    torch.ops.extension_cpp.myadd_out.default(a, b, out)
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
import glob

from setuptools import find_packages, setup

from torch.utils.cpp_extension import (
    CppExtension,
    CUDAExtension,
    BuildExtension,
    CUDA_HOME,
)

library_name = "extension_cpp"

if torch.__version__ >= "2.6.0":
    py_limited_api = True
else:
    py_limited_api = False


def get_extensions():
    debug_mode = os.getenv("DEBUG", "0") == "1"
    use_cuda = os.getenv("USE_CUDA", "1") == "1"
    if debug_mode:
        print("Compiling in debug mode")
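
    # For example (a sketch of typical invocations; nothing here enforces them):
    #   DEBUG=1 pip install --no-build-isolation -e .     # unoptimized build with debug symbols
    #   USE_CUDA=0 pip install --no-build-isolation -e .  # skip the CUDA kernels, build CPU-only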

    use_cuda = use_cuda and torch.cuda.is_available() and CUDA_HOME is not None
    extension = CUDAExtension if use_cuda else CppExtension

    extra_link_args = []
    extra_compile_args = {
        "cxx": [
            "-O3" if not debug_mode else "-O0",
            "-fdiagnostics-color=always",
            "-DPy_LIMITED_API=0x03090000",  # min CPython version 3.9
        ],
        "nvcc": [
            "-O3" if not debug_mode else "-O0",
        ],
    }
    if debug_mode:
        extra_compile_args["cxx"].append("-g")
        extra_compile_args["nvcc"].append("-g")
        extra_link_args.extend(["-O0", "-g"])

    this_dir = os.path.dirname(os.path.curdir)
    extensions_dir = os.path.join(this_dir, library_name, "csrc")
    sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp")))

    extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
    cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu")))

    if use_cuda:
        sources += cuda_sources

    ext_modules = [
        extension(
            f"{library_name}._C",
            sources,
            extra_compile_args=extra_compile_args,
            extra_link_args=extra_link_args,
            py_limited_api=py_limited_api,
        )
    ]

    return ext_modules


setup(
    name=library_name,
    version="0.0.1",
    packages=find_packages(),
    ext_modules=get_extensions(),
    install_requires=["torch"],
    description="Example of PyTorch C++ and CUDA extensions",
    long_description=open("README.md").read(),
    long_description_content_type="text/markdown",
    url="https://github.com/pytorch/extension-cpp",
    cmdclass={"build_ext": BuildExtension},
    options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {},
)
--------------------------------------------------------------------------------
/extension_cpp/csrc/cuda/muladd.cu:
--------------------------------------------------------------------------------
#include <ATen/Operators.h>
#include <torch/all.h>
#include <torch/library.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>

namespace extension_cpp {

__global__ void muladd_kernel(int numel, const float* a, const float* b, float c, float* result) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < numel) result[idx] = a[idx] * b[idx] + c;
}

at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
  TORCH_CHECK(a.sizes() == b.sizes());
  TORCH_CHECK(a.dtype() == at::kFloat);
  TORCH_CHECK(b.dtype() == at::kFloat);
  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
  at::Tensor a_contig = a.contiguous();
  at::Tensor b_contig = b.contiguous();
  at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());
  const float* a_ptr = a_contig.data_ptr<float>();
  const float* b_ptr = b_contig.data_ptr<float>();
  float* result_ptr = result.data_ptr<float>();

  int numel = a_contig.numel();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  muladd_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, c, result_ptr);
  return result;
}
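
// Launch-configuration arithmetic used above, spelled out with a concrete number
// (illustrative): with 256 threads per block and ceil(numel / 256) blocks, numel = 1000
// launches <<<4, 256>>> = 1024 threads, and the `idx < numel` guard in the kernel
// masks off the 24 surplus threads in the last block.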

__global__ void mul_kernel(int numel, const float* a, const float* b, float* result) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < numel) result[idx] = a[idx] * b[idx];
}

at::Tensor mymul_cuda(const at::Tensor& a, const at::Tensor& b) {
  TORCH_CHECK(a.sizes() == b.sizes());
  TORCH_CHECK(a.dtype() == at::kFloat);
  TORCH_CHECK(b.dtype() == at::kFloat);
  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
  at::Tensor a_contig = a.contiguous();
  at::Tensor b_contig = b.contiguous();
  at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());
  const float* a_ptr = a_contig.data_ptr<float>();
  const float* b_ptr = b_contig.data_ptr<float>();
  float* result_ptr = result.data_ptr<float>();
  int numel = a_contig.numel();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  mul_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
  return result;
}

__global__ void add_kernel(int numel, const float* a, const float* b, float* result) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < numel) result[idx] = a[idx] + b[idx];
}

void myadd_out_cuda(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
  TORCH_CHECK(a.sizes() == b.sizes());
  TORCH_CHECK(b.sizes() == out.sizes());
  TORCH_CHECK(a.dtype() == at::kFloat);
  TORCH_CHECK(b.dtype() == at::kFloat);
  TORCH_CHECK(out.dtype() == at::kFloat);
  TORCH_CHECK(out.is_contiguous());
  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
  TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CUDA);
  at::Tensor a_contig = a.contiguous();
  at::Tensor b_contig = b.contiguous();
  const float* a_ptr = a_contig.data_ptr<float>();
  const float* b_ptr = b_contig.data_ptr<float>();
  float* result_ptr = out.data_ptr<float>();
  int numel = a_contig.numel();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  add_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
}


// Registers CUDA implementations for mymuladd, mymul, myadd_out
TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
  m.impl("mymuladd", &mymuladd_cuda);
  m.impl("mymul", &mymul_cuda);
  m.impl("myadd_out", &myadd_out_cuda);
}

}
--------------------------------------------------------------------------------
/extension_cpp/csrc/muladd.cpp:
--------------------------------------------------------------------------------
#include <Python.h>
#include <ATen/Operators.h>
#include <torch/all.h>
#include <torch/library.h>

#include <vector>

extern "C" {
  /* Creates a dummy empty _C module that can be imported from Python.
     The import from Python will load the .so consisting of this file
     in this extension, so that the TORCH_LIBRARY static initializers
     below are run. */
  PyObject* PyInit__C(void)
  {
      static struct PyModuleDef module_def = {
          PyModuleDef_HEAD_INIT,
          "_C",  /* name of module */
          NULL,  /* module documentation, may be NULL */
          -1,    /* size of per-interpreter state of the module,
                    or -1 if the module keeps state in global variables. */
          NULL,  /* methods */
      };
      return PyModule_Create(&module_def);
  }
}
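
// How this is reached from Python (a sketch): `extension_cpp/__init__.py` does
// `from . import _C`, which imports the module defined above; loading the shared
// library runs the TORCH_LIBRARY / TORCH_LIBRARY_IMPL registrations below. After
// that, the ops are available on the torch.ops namespace, e.g.:
//
//     import torch, extension_cpp
//     torch.ops.extension_cpp.mymuladd(torch.randn(3), torch.randn(3), 1.0)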

namespace extension_cpp {

at::Tensor mymuladd_cpu(const at::Tensor& a, const at::Tensor& b, double c) {
  TORCH_CHECK(a.sizes() == b.sizes());
  TORCH_CHECK(a.dtype() == at::kFloat);
  TORCH_CHECK(b.dtype() == at::kFloat);
  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
  at::Tensor a_contig = a.contiguous();
  at::Tensor b_contig = b.contiguous();
  at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
  const float* a_ptr = a_contig.data_ptr<float>();
  const float* b_ptr = b_contig.data_ptr<float>();
  float* result_ptr = result.data_ptr<float>();
  for (int64_t i = 0; i < result.numel(); i++) {
    result_ptr[i] = a_ptr[i] * b_ptr[i] + c;
  }
  return result;
}

at::Tensor mymul_cpu(const at::Tensor& a, const at::Tensor& b) {
  TORCH_CHECK(a.sizes() == b.sizes());
  TORCH_CHECK(a.dtype() == at::kFloat);
  TORCH_CHECK(b.dtype() == at::kFloat);
  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
  at::Tensor a_contig = a.contiguous();
  at::Tensor b_contig = b.contiguous();
  at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
  const float* a_ptr = a_contig.data_ptr<float>();
  const float* b_ptr = b_contig.data_ptr<float>();
  float* result_ptr = result.data_ptr<float>();
  for (int64_t i = 0; i < result.numel(); i++) {
    result_ptr[i] = a_ptr[i] * b_ptr[i];
  }
  return result;
}

// An example of an operator that mutates one of its inputs.
void myadd_out_cpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
  TORCH_CHECK(a.sizes() == b.sizes());
  TORCH_CHECK(b.sizes() == out.sizes());
  TORCH_CHECK(a.dtype() == at::kFloat);
  TORCH_CHECK(b.dtype() == at::kFloat);
  TORCH_CHECK(out.dtype() == at::kFloat);
  TORCH_CHECK(out.is_contiguous());
  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
  TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CPU);
  at::Tensor a_contig = a.contiguous();
  at::Tensor b_contig = b.contiguous();
  const float* a_ptr = a_contig.data_ptr<float>();
  const float* b_ptr = b_contig.data_ptr<float>();
  float* result_ptr = out.data_ptr<float>();
  for (int64_t i = 0; i < out.numel(); i++) {
    result_ptr[i] = a_ptr[i] + b_ptr[i];
  }
}

// Defines the operators
TORCH_LIBRARY(extension_cpp, m) {
  m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
  m.def("mymul(Tensor a, Tensor b) -> Tensor");
  m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()");
}
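
// Notes on the schemas above (illustrative): `float` in a schema corresponds to a
// C++ `double` argument, and the `Tensor(a!)` annotation marks `out` as mutated in
// place, which is why `myadd_out` returns `()`. From Python the same signatures
// look like, e.g.:
//
//     torch.ops.extension_cpp.mymul(a, b)            # returns a new Tensor
//     torch.ops.extension_cpp.myadd_out(a, b, out)   # writes into `out`, returns None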

// Registers CPU implementations for mymuladd, mymul, myadd_out
TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
  m.impl("mymuladd", &mymuladd_cpu);
  m.impl("mymul", &mymul_cpu);
  m.impl("myadd_out", &myadd_out_cpu);
}

}
--------------------------------------------------------------------------------
/.github/scripts/setup-env.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

set -euxo pipefail

# Prepare conda
set +x && eval "$($(which conda) shell.bash hook)" && set -x

# Setup the OS_TYPE environment variable that should be used for conditions involving the OS below.
case $(uname) in
  Linux)
    OS_TYPE=linux
    ;;
  Darwin)
    OS_TYPE=macos
    ;;
  MSYS*)
    OS_TYPE=windows
    ;;
  *)
    echo "Unknown OS type:" $(uname)
    exit 1
    ;;
esac

if [[ "${OS_TYPE}" == "macos" && $(uname -m) == x86_64 ]]; then
  echo '::group::Uninstall system JPEG libraries on macOS'
  # The x86 macOS runners, e.g. the GitHub Actions native "macos-12" runner, have some JPEG and PNG libraries
  # installed by default that interfere with our build. We uninstall them here and use the ones from conda below.
  IMAGE_LIBS=$(brew list | grep -E "jpeg|png")
  for lib in $IMAGE_LIBS; do
    brew uninstall --ignore-dependencies --force "${lib}"
  done
  echo '::endgroup::'
fi

echo '::group::Create build environment'
# See https://github.com/pytorch/vision/issues/7296 for ffmpeg
conda create \
  --name ci \
  --quiet --yes \
  python="${PYTHON_VERSION}" pip \
  ninja cmake
conda activate ci
pip install --progress-bar=off --upgrade setuptools

echo '::endgroup::'

if [[ "${OS_TYPE}" == windows && "${GPU_ARCH_TYPE}" == cuda ]]; then
  echo '::group::Install VisualStudio CUDA extensions on Windows'
  if [[ "${VC_YEAR:-}" == "2022" ]]; then
    TARGET_DIR="/c/Program Files (x86)/Microsoft Visual Studio/2022/BuildTools/MSBuild/Microsoft/VC/v170/BuildCustomizations"
  else
    TARGET_DIR="/c/Program Files (x86)/Microsoft Visual Studio/2019/BuildTools/MSBuild/Microsoft/VC/v160/BuildCustomizations"
  fi
  mkdir -p "${TARGET_DIR}"
  cp -r "${CUDA_HOME}/MSBuildExtensions/"* "${TARGET_DIR}"
  echo '::endgroup::'
fi

echo '::group::Install PyTorch'
# TODO: Can we maybe have this as environment variable in the job template? For example, `IS_RELEASE`.
if [[ (${GITHUB_EVENT_NAME} = 'pull_request' && (${GITHUB_BASE_REF} = 'release'*)) || (${GITHUB_REF} = 'refs/heads/release'*) ]]; then
  CHANNEL=test
else
  CHANNEL=nightly
fi

case $GPU_ARCH_TYPE in
  cpu)
    GPU_ARCH_ID="cpu"
    ;;
  cuda)
    VERSION_WITHOUT_DOT=$(echo "${GPU_ARCH_VERSION}" | sed 's/\.//')
    GPU_ARCH_ID="cu${VERSION_WITHOUT_DOT}"
    ;;
  *)
    echo "Unknown GPU_ARCH_TYPE=${GPU_ARCH_TYPE}"
    exit 1
    ;;
esac
PYTORCH_WHEEL_INDEX="https://download.pytorch.org/whl/${CHANNEL}/${GPU_ARCH_ID}"
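# For example, a CI job on the nightly channel with GPU_ARCH_TYPE=cuda and
# GPU_ARCH_VERSION=12.4 resolves this to https://download.pytorch.org/whl/nightly/cu124,
# while the CPU jobs resolve it to https://download.pytorch.org/whl/nightly/cpu.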
pip install --progress-bar=off --pre torch --index-url="${PYTORCH_WHEEL_INDEX}"

if [[ $GPU_ARCH_TYPE == 'cuda' ]]; then
  python -c "import torch; exit(not torch.cuda.is_available())"
fi
echo '::endgroup::'

echo '::group::Install third party dependencies prior to extension-cpp install'
# Installing with `easy_install`, e.g. `python setup.py install` or `python setup.py develop`, has some quirks
# when pulling in third-party dependencies. For example:
# - On Windows, we often hit an SSL error although `pip` can install just fine.
# - It happily pulls in pre-releases, which can lead to more problems down the line.
#   `pip` does not unless explicitly told to do so.
# Thus, we use `easy_install` to extract the third-party dependencies here and install them upfront with `pip`.
python setup.py egg_info
# The requires.txt cannot be used with `pip install -r` directly. The requirements are listed at the top and the
# optional dependencies come in non-standard syntax after a blank line. Thus, we just extract the header.
sed -e '/^$/,$d' *.egg-info/requires.txt | tee requirements.txt
pip install --progress-bar=off -r requirements.txt
echo '::endgroup::'

echo '::group::Install extension-cpp'
python setup.py develop
echo '::endgroup::'

echo '::group::Collect environment information'
conda list
python -m torch.utils.collect_env
echo '::endgroup::'
--------------------------------------------------------------------------------
/test/test_extension.py:
--------------------------------------------------------------------------------
import torch
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.optests import opcheck
import unittest
import extension_cpp
from torch import Tensor
from typing import Tuple
import torch.nn.functional as F
import torch.nn as nn


def reference_muladd(a, b, c):
    return a * b + c


class TestMyMulAdd(TestCase):
    def sample_inputs(self, device, *, requires_grad=False):
        def make_tensor(*size):
            return torch.randn(size, device=device, requires_grad=requires_grad)

        def make_nondiff_tensor(*size):
            return torch.randn(size, device=device, requires_grad=False)

        return [
            [make_tensor(3), make_tensor(3), 1],
            [make_tensor(20), make_tensor(20), 3.14],
            [make_tensor(20), make_nondiff_tensor(20), -123],
            [make_nondiff_tensor(2, 3), make_tensor(2, 3), -0.3],
        ]

    def _test_correctness(self, device):
        samples = self.sample_inputs(device)
        for args in samples:
            result = extension_cpp.ops.mymuladd(*args)
            expected = reference_muladd(*args)
            torch.testing.assert_close(result, expected)

    def test_correctness_cpu(self):
        self._test_correctness("cpu")

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_correctness_cuda(self):
        self._test_correctness("cuda")

    def _test_gradients(self, device):
        samples = self.sample_inputs(device, requires_grad=True)
        for args in samples:
            diff_tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
            out = extension_cpp.ops.mymuladd(*args)
            grad_out = torch.randn_like(out)
            result = torch.autograd.grad(out, diff_tensors, grad_out)

            out = reference_muladd(*args)
            expected = torch.autograd.grad(out, diff_tensors, grad_out)

            torch.testing.assert_close(result, expected)

    def test_gradients_cpu(self):
        self._test_gradients("cpu")

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_gradients_cuda(self):
        self._test_gradients("cuda")

    def _opcheck(self, device):
        # Use opcheck to check for incorrect usage of operator registration APIs
        samples = self.sample_inputs(device, requires_grad=True)
        samples.extend(self.sample_inputs(device, requires_grad=False))
        for args in samples:
            opcheck(torch.ops.extension_cpp.mymuladd.default, args)

    def test_opcheck_cpu(self):
        self._opcheck("cpu")

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_opcheck_cuda(self):
        self._opcheck("cuda")
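

# What `opcheck` exercises above (roughly): the schema, the FakeTensor kernel, and the
# autograd registration are each tested against the real implementation. A sketch of
# the equivalent direct call, assuming the public `torch.library.opcheck` helper from
# recent PyTorch releases:
#
#     torch.library.opcheck(
#         torch.ops.extension_cpp.mymuladd.default,
#         (torch.randn(3), torch.randn(3), 1.0),
#     )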


class TestMyAddOut(TestCase):
    def sample_inputs(self, device, *, requires_grad=False):
        def make_tensor(*size):
            return torch.randn(size, device=device, requires_grad=requires_grad)

        def make_nondiff_tensor(*size):
            return torch.randn(size, device=device, requires_grad=False)

        return [
            [make_tensor(3), make_tensor(3), make_tensor(3)],
            [make_tensor(20), make_tensor(20), make_tensor(20)],
        ]

    def _test_correctness(self, device):
        samples = self.sample_inputs(device)
        for args in samples:
            result = args[-1]
            extension_cpp.ops.myadd_out(*args)
            expected = torch.add(*args[:2])
            torch.testing.assert_close(result, expected)

    def test_correctness_cpu(self):
        self._test_correctness("cpu")

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_correctness_cuda(self):
        self._test_correctness("cuda")

    def _opcheck(self, device):
        # Use opcheck to check for incorrect usage of operator registration APIs
        samples = self.sample_inputs(device, requires_grad=True)
        samples.extend(self.sample_inputs(device, requires_grad=False))
        for args in samples:
            opcheck(torch.ops.extension_cpp.myadd_out.default, args)

    def test_opcheck_cpu(self):
        self._opcheck("cpu")

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_opcheck_cuda(self):
        self._opcheck("cuda")


class TestTorchCompileStreamSync(TestCase):
    """Test for GitHub issue pytorch/pytorch#157363 - stream synchronization with torch.compile"""

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_compile_with_linear_layer(self):
        """Test custom CUDA kernels with nn.Linear + torch.compile (the original failing case)"""

        class Model(nn.Module):
            def __init__(self, size):
                super().__init__()
                self.linear = nn.Linear(size, size, device="cuda", dtype=torch.float32)

            def forward(self, x):
                return extension_cpp.ops.mymuladd(self.linear(x), self.linear(x), 0.0)

        # Test sizes that previously failed
        for size in [1000, 5000, 10000]:
            with self.subTest(size=size):
                torch.manual_seed(42)
                model = Model(size)
                x = torch.randn((1, size), device="cuda", dtype=torch.float32)

                with torch.no_grad():
                    expected = model(x)
                    compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
                    actual = compiled_model(x)

                self.assertEqual(actual, expected)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_compile_custom_only(self):
        """Test custom operations alone with torch.compile"""

        def model(x):
            return extension_cpp.ops.mymuladd(x, x, 1.0)

        for size in [1000, 5000, 10000]:
            with self.subTest(size=size):
                torch.manual_seed(42)
                x = torch.randn((size,), device="cuda", dtype=torch.float32)

                with torch.no_grad():
                    expected = model(x)
                    compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
                    actual = compiled_model(x)

                self.assertEqual(actual, expected)


if __name__ == "__main__":
    unittest.main()
--------------------------------------------------------------------------------