├── requirements.txt
├── __init__.py
├── pybind
│   ├── extern.hpp
│   └── bind.cpp
├── .gitignore
├── src
│   ├── foo.cpp
│   ├── add.cuh
│   ├── common.hpp
│   ├── gpu.cuh
│   ├── utils.hpp
│   ├── add.cpp
│   └── add.cu
├── Add.py
├── example.py
├── README.md
├── setup.py
└── Makefile

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
import torch

from MakePytorchBackend import AddGPU, Foo

from Add import add_gpu

--------------------------------------------------------------------------------
/pybind/extern.hpp:
--------------------------------------------------------------------------------
template <typename Dtype>
void AddGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c);

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.so
*.o
*.swo
*.swp
*.swn
*.pyc

build/
dist/
MinkowskiEngine/

--------------------------------------------------------------------------------
/src/foo.cpp:
--------------------------------------------------------------------------------
#include "src/common.hpp"

void Foo::setKey(uint64_t key) { key_ = key; }

uint64_t Foo::getKey() { return key_; }

--------------------------------------------------------------------------------
/src/add.cuh:
--------------------------------------------------------------------------------
template <typename Dtype>
void AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N,
                  cudaStream_t stream);

--------------------------------------------------------------------------------
/Add.py:
--------------------------------------------------------------------------------
import torch

from MakePytorchBackend import AddGPU


def add_gpu(a, b):
    assert isinstance(a, torch.cuda.FloatTensor) \
        and isinstance(b, torch.cuda.FloatTensor)
    assert a.numel() == b.numel()

    c = a.new()  # empty tensor of the same type; AddGPU resizes it
    AddGPU(a, b, c)
    return c

--------------------------------------------------------------------------------
/src/common.hpp:
--------------------------------------------------------------------------------
#ifndef COMMON
#define COMMON

#include <cstdint>
#include <string>

class Foo {
private:
  uint64_t key_;

public:
  void setKey(uint64_t key);
  uint64_t getKey();
  std::string toString() const {
    return "< Foo, key: " + std::to_string(key_) + " > ";
  };
};

#endif

--------------------------------------------------------------------------------
/src/gpu.cuh:
--------------------------------------------------------------------------------
#include <algorithm>

constexpr int CUDA_NUM_THREADS = 128;

constexpr int MAXIMUM_NUM_BLOCKS = 4096;

inline int GET_BLOCKS(const int N) {
  return std::max(std::min((N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS,
                           MAXIMUM_NUM_BLOCKS),
                  // Use at least 1 block, since CUDA does not allow an empty grid
                  1);
}

--------------------------------------------------------------------------------
/pybind/bind.cpp:
--------------------------------------------------------------------------------
#include <string>

#include <torch/extension.h>

#include "pybind/extern.hpp"
#include "src/common.hpp"

namespace py = pybind11;

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  std::string name = std::string("Foo");
  py::class_<Foo>(m, name.c_str())
      .def(py::init<>())
      .def("setKey", &Foo::setKey)
      .def("getKey", &Foo::getKey)
      .def("__repr__", [](const Foo &a) { return a.toString(); });

  m.def("AddGPU", &AddGPU<float>);
}
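
Once built and installed, the bindings above are importable directly under the extension name chosen in `setup.py` (`MakePytorchBackend`). A minimal sanity-check sketch, assuming a CUDA-enabled torch build:

```python
import torch
import MakePytorchBackend

foo = MakePytorchBackend.Foo()
foo.setKey(7)
print(foo)  # __repr__ is bound to Foo::toString: "< Foo, key: 7 > "

a = torch.cuda.FloatTensor([1., 2.])
b = torch.cuda.FloatTensor([3., 4.])
c = a.new()                          # empty tensor; AddGPU resizes it
MakePytorchBackend.AddGPU(a, b, c)   # bound to AddGPU<float> in bind.cpp
print(c)                             # expected values: [4., 6.]
```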
"src/common.hpp" 7 | 8 | namespace py = pybind11; 9 | 10 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ 11 | std::string name = std::string("Foo"); 12 | py::class_(m, name.c_str()) 13 | .def(py::init<>()) 14 | .def("setKey", &Foo::setKey) 15 | .def("getKey", &Foo::getKey) 16 | .def("__repr__", [](const Foo &a) { return a.toString(); }); 17 | 18 | m.def("AddGPU", &AddGPU); 19 | } 20 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import MakePytorchPlusPlus as MPP 3 | 4 | 5 | def test_foo(): 6 | foo = MPP.Foo() 7 | print(foo) 8 | foo.setKey(3) 9 | print(foo) 10 | print(foo.getKey()) 11 | 12 | 13 | def test_add_gpu(): 14 | if not torch.cuda.is_available(): 15 | return 16 | a = torch.cuda.FloatTensor(4) 17 | b = torch.cuda.FloatTensor(4) 18 | a.normal_() 19 | b.normal_() 20 | c = MPP.add_gpu(a, b) 21 | print(a, b, c) 22 | 23 | 24 | if __name__ == '__main__': 25 | test_foo() 26 | test_add_gpu() 27 | -------------------------------------------------------------------------------- /src/utils.hpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | class Formatter { 6 | public: 7 | Formatter() {} 8 | ~Formatter() {} 9 | 10 | template Formatter &operator<<(const Type &value) { 11 | stream_ << value; 12 | return *this; 13 | } 14 | 15 | std::string str() const { return stream_.str(); } 16 | operator std::string() const { return stream_.str(); } 17 | 18 | enum ConvertToString { to_str }; 19 | 20 | std::string operator>>(ConvertToString) { return stream_.str(); } 21 | 22 | private: 23 | std::stringstream stream_; 24 | Formatter(const Formatter &); 25 | Formatter &operator=(Formatter &); 26 | }; 27 | -------------------------------------------------------------------------------- /src/add.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "src/add.cuh" 5 | #include "src/utils.hpp" 6 | 7 | template 8 | void AddGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c) { 9 | int N = in_a.numel(); 10 | if (N != in_b.numel()) 11 | throw std::invalid_argument(Formatter() 12 | << "Size mismatch A.numel(): " << in_a.numel() 13 | << ", B.numel(): " << in_b.numel()); 14 | 15 | out_c.resize_({N}); 16 | 17 | AddGPUKernel(in_a.data(), in_b.data(), 18 | out_c.data(), N, at::cuda::getCurrentCUDAStream()); 19 | } 20 | 21 | template void AddGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c); 22 | -------------------------------------------------------------------------------- /src/add.cu: -------------------------------------------------------------------------------- 1 | #include "src/gpu.cuh" 2 | #include "src/utils.hpp" 3 | 4 | template 5 | __global__ void sum(Dtype *a, Dtype *b, Dtype *c, int N) { 6 | int i = blockIdx.x * blockDim.x + threadIdx.x; 7 | if (i <= N) { 8 | c[i] = a[i] + b[i]; 9 | } 10 | } 11 | 12 | template 13 | void AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N, 14 | cudaStream_t stream) { 15 | sum 16 | <<>>(in_a, in_b, out_c, N); 17 | 18 | cudaError_t err = cudaGetLastError(); 19 | if (cudaSuccess != err) 20 | throw std::runtime_error(Formatter() 21 | << "CUDA kernel failed : " << std::to_string(err)); 22 | } 23 | 24 | template void AddGPUKernel(float *in_a, float *in_b, float *out_c, int N, 25 | cudaStream_t stream); 26 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Developing a Pytorch CPP/CUDA Extension with a Makefile

[Pytorch cpp extensions](https://pytorch.org/tutorials/advanced/cpp_extension.html) provide a good way to augment pytorch with custom functions. The cpp-extension uses setuptools to compile the files. However, since setuptools is designed mainly for deployment rather than for debugging and development, using it during development can be slow and cumbersome.

In this repository, I provide an alternative way to compile and debug your custom extension with a makefile.
The associated tutorial can be found at the [blog post](https://chrischoy.github.io/research/pytorch-extension-with-makefile).


## Installation

You must have `torch` installed in your current python (virtual) environment.

```
git clone https://github.com/chrischoy/MakePytorchPlusPlus
cd MakePytorchPlusPlus
python setup.py install
```

The setup script automatically uses all available CPU cores for parallel compilation.


## Running the example

```
python example.py
```

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import os
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension

# Build the static library objs/libmake_pytorch.a with the Makefile first,
# using all available CPU cores.
os.system('make -j%d' % os.cpu_count())

# Python interface
setup(
    name='PytorchMakefileTutorial',
    version='0.2.0',
    install_requires=['torch'],
    packages=['MakePytorchPlusPlus'],
    package_dir={'MakePytorchPlusPlus': './'},
    ext_modules=[
        CUDAExtension(
            name='MakePytorchBackend',
            include_dirs=['./'],
            sources=[
                'pybind/bind.cpp',
            ],
            # Link against the prebuilt static library instead of letting
            # setuptools compile the sources under src/.
            libraries=['make_pytorch'],
            library_dirs=['objs'],
            # extra_compile_args=['-g']
        )
    ],
    cmdclass={'build_ext': BuildExtension},
    author='Christopher B. Choy',
    author_email='chrischoy@ai.stanford.edu',
    description='Tutorial for Pytorch C++ Extension with a Makefile',
    keywords='Pytorch C++ Extension',
    url='https://github.com/chrischoy/MakePytorchPlusPlus',
    zip_safe=False,
)
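
For contrast, the same extension could be built entirely by setuptools, with no Makefile and no prebuilt static library; this is the standard approach the README describes as slow for development, because `build_ext` recompiles every source on each install. A hypothetical all-setuptools variant (the package name here is made up for illustration):

```python
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension

setup(
    name='PytorchSetuptoolsOnly',  # hypothetical name
    ext_modules=[
        CUDAExtension(
            name='MakePytorchBackend',
            include_dirs=['./'],
            # hand every source to setuptools instead of linking objs/libmake_pytorch.a
            sources=['pybind/bind.cpp', 'src/foo.cpp', 'src/add.cpp', 'src/add.cu'],
        )
    ],
    cmdclass={'build_ext': BuildExtension},
)
```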
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
###############################################################################
# Uncomment for debugging
# DEBUG := 1
# Pretty build
# Q ?= @

CXX := g++

# PYTHON Header path
PYTHON_HEADER_DIR := $(shell python -c 'from distutils.sysconfig import get_python_inc; print(get_python_inc())')
PYTORCH_INCLUDES := $(shell python -c 'from torch.utils.cpp_extension import include_paths; [print(p) for p in include_paths()]')
PYTORCH_LIBRARIES := $(shell python -c 'from torch.utils.cpp_extension import library_paths; [print(p) for p in library_paths()]')

# CUDA ROOT DIR that contains bin/ lib64/ and include/
# CUDA_DIR := /usr/local/cuda
CUDA_DIR := $(shell python -c 'from torch.utils.cpp_extension import _find_cuda_home; print(_find_cuda_home())')

# Assume pytorch > v1.1
WITH_ABI := $(shell python -c 'import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))')

INCLUDE_DIRS := ./ $(CUDA_DIR)/include

INCLUDE_DIRS += $(PYTHON_HEADER_DIR)
INCLUDE_DIRS += $(PYTORCH_INCLUDES)

# Custom (MKL/ATLAS/OpenBLAS) include and lib directories.
# Leave commented to accept the defaults for your choice of BLAS
# (which should work)!
# BLAS_INCLUDE := /path/to/your/blas
# BLAS_LIB := /path/to/your/blas

###############################################################################
SRC_DIR := ./src
OBJ_DIR := ./objs
CPP_SRCS := $(wildcard $(SRC_DIR)/*.cpp)
CU_SRCS := $(wildcard $(SRC_DIR)/*.cu)
OBJS := $(patsubst $(SRC_DIR)/%.cpp,$(OBJ_DIR)/%.o,$(CPP_SRCS))
CU_OBJS := $(patsubst $(SRC_DIR)/%.cu,$(OBJ_DIR)/cuda/%.o,$(CU_SRCS))
STATIC_LIB := $(OBJ_DIR)/libmake_pytorch.a

# CUDA architecture setting: going with all of them.
# For CUDA < 6.0, comment the *_50 through *_61 lines for compatibility.
# For CUDA < 8.0, comment the *_60 and *_61 lines for compatibility.
CUDA_ARCH := -gencode arch=compute_30,code=sm_30 \
             -gencode arch=compute_35,code=sm_35 \
             -gencode arch=compute_50,code=sm_50 \
             -gencode arch=compute_52,code=sm_52 \
             -gencode arch=compute_60,code=sm_60 \
             -gencode arch=compute_61,code=sm_61 \
             -gencode arch=compute_61,code=compute_61

# We will also explicitly add stdc++ to the link target.
LIBRARIES += stdc++ cudart c10 caffe2 torch torch_python caffe2_gpu

# Debugging
ifeq ($(DEBUG), 1)
  COMMON_FLAGS += -DDEBUG -g -O0
  # https://gcoe-dresden.de/reaching-the-shore-with-a-fog-warning-my-eurohack-day-4-morning-session/
  NVCCFLAGS += -g -G # -rdc true
else
  COMMON_FLAGS += -DNDEBUG -O3
endif

WARNINGS := -Wall -Wno-sign-compare -Wcomment

INCLUDE_DIRS += $(BLAS_INCLUDE)

# Automatic dependency generation (nvcc is handled separately)
CXXFLAGS += -MMD -MP

# Complete build flags.
COMMON_FLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir)) \
                -DTORCH_API_INCLUDE_EXTENSION_H -D_GLIBCXX_USE_CXX11_ABI=$(WITH_ABI)
# Note: _GLIBCXX_USE_CXX11_ABI must match the value pytorch itself was built
# with (queried above via torch._C._GLIBCXX_USE_CXX11_ABI); a mismatch shows
# up as undefined C++ symbols when the extension is imported.
CXXFLAGS += -pthread -fPIC -fwrapv -std=c++11 $(COMMON_FLAGS) $(WARNINGS)
NVCCFLAGS += -std=c++11 -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS)

all: $(STATIC_LIB)
	# python setup.py install --force

$(OBJ_DIR):
	@ mkdir -p $@
	@ mkdir -p $@/cuda

$(OBJ_DIR)/%.o: $(SRC_DIR)/%.cpp | $(OBJ_DIR)
	@ echo CXX $<
	$(Q)$(CXX) $< $(CXXFLAGS) -c -o $@

$(OBJ_DIR)/cuda/%.o: $(SRC_DIR)/%.cu | $(OBJ_DIR)
	@ echo NVCC $<
	$(Q)nvcc $(NVCCFLAGS) $(CUDA_ARCH) -M $< -o ${@:.o=.d} \
		-odir $(@D)
	$(Q)nvcc $(NVCCFLAGS) $(CUDA_ARCH) -c $< -o $@

$(STATIC_LIB): $(OBJS) $(CU_OBJS) | $(OBJ_DIR)
	$(RM) -f $(STATIC_LIB)
	$(RM) -rf build dist
	@ echo LD -o $@
	ar rc $(STATIC_LIB) $(OBJS) $(CU_OBJS)

clean:
	@- $(RM) -rf $(OBJ_DIR) build dist
--------------------------------------------------------------------------------
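
A typical edit-compile-test loop with this setup might look like the following sketch (the job count is arbitrary):

```
make -j8                          # recompile only the changed .cpp/.cu into objs/libmake_pytorch.a
python setup.py install --force   # relink the thin pybind layer against the static library
python example.py
```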