├── .gitignore
├── 3rdparty
│   ├── opencv
│   │   ├── linux.zip
│   │   └── win.zip
│   └── pybind11.zip
├── README.md
├── docs
│   ├── README.pdf
│   ├── files
│   │   ├── Custom C++ and CUDA Extensions — PyTorch Tutorials 1.12.1+cu102 documentation.pdf
│   │   ├── Keywords - setuptools 65.4.1.post20221009 documentation.pdf
│   │   ├── PyTorch 源码解读之 cpp_extension:揭秘 C++_CUDA 算子实现和调用全流程 - 知乎.pdf
│   │   ├── Python中使用C++_CUDA|以PointNet中的ball query 为例 - 知乎.pdf
│   │   ├── Pytorch拓展进阶(二):Pytorch结合C++以及Cuda拓展 - Oldpan的个人博客.pdf
│   │   ├── mmdetection源码剖析(1)--NMS.pdf
│   │   └── setup.py实现C++扩展和python库编译.pdf
│   └── imgs
│       ├── Dependencies_x64_Release.png
│       ├── compile_win.png
│       ├── dali.png
│       ├── pbind11_pipeline.pptx
│       ├── permute.png
│       ├── pybind11.png
│       ├── pybind11_pipeline.png
│       ├── windows_opencv_cfg.png
│       ├── 动态库拷贝到python文件同路径.png
│       └── 安装成功.png
├── mmcv_setup.py
├── orbbec
│   ├── __init__.py
│   ├── nms
│   │   ├── __init__.py
│   │   ├── nms_wrapper.py
│   │   └── src
│   │       ├── cpu
│   │       │   └── nms_cpu.cpp
│   │       ├── cuda
│   │       │   ├── nms_cuda.cpp
│   │       │   └── nms_kernel.cu
│   │       └── nms_ext.cpp
│   ├── roi_align
│   │   ├── __init__.py
│   │   ├── gradcheck.py
│   │   ├── roi_align.py
│   │   └── src
│   │       ├── cpu
│   │       │   └── roi_align_v2.cpp
│   │       ├── cuda
│   │       │   ├── roi_align_kernel.cu
│   │       │   └── roi_align_kernel_v2.cu
│   │       └── roi_align_ext.cpp
│   ├── utils
│   │   ├── __init__.py
│   │   └── src
│   │       └── compiling_info.cpp
│   └── warpaffine
│       ├── __init__.py
│       └── src
│           ├── cpu
│           │   ├── warpaffine_opencv.cpp
│           │   ├── warpaffine_torch_v1.cpp
│           │   └── warpaffine_torch_v2.cpp
│           ├── cuda
│           │   ├── warpaffine_cuda.cpp
│           │   └── warpaffine_kernel.cu
│           └── warpaffine_ext.cpp
├── requirements.txt
├── requirements
│   ├── build.txt
│   ├── docs.txt
│   ├── optional.txt
│   ├── readthedocs.txt
│   ├── runtime.txt
│   └── tests.txt
├── scripts
│   ├── benchmark.py
│   ├── check.py
│   ├── collect_env.py
│   ├── cpp_img.png
│   ├── demo.png
│   ├── grad_check.py
│   ├── image.jpg
│   ├── opencv_world453.dll
│   ├── py_img.png
│   ├── test_nms.py
│   ├── test_warpaffine_opencv.py
│   ├── test_warpaffine_torch_cpu.py
│   ├── test_warpaffine_torch_gpu.py
│   ├── torch_affine_cpu.png
│   └── torch_affine_gpu.png
├── setup.py
└── tools
    └── Dependencies_x64_Release.7z
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 | *.pyd
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 | .pytest_cache/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | db.sqlite3
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # pyenv
77 | .python-version
78 |
79 | # celery beat schedule file
80 | celerybeat-schedule
81 |
82 | # SageMath parsed files
83 | *.sage.py
84 |
85 | # Environments
86 | .env
87 | .venv
88 | env/
89 | venv/
90 | ENV/
91 | env.bak/
92 | venv.bak/
93 |
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 |
98 | # Rope project settings
99 | .ropeproject
100 |
101 | # mkdocs documentation
102 | /site
103 |
104 | # mypy
105 | .mypy_cache/
106 |
107 | mmdet/version.py
108 | ./data/
109 | .vscode
110 | .idea
111 | .DS_Store
112 |
113 | # custom
114 | *.pkl
115 | *.pkl.json
116 | *.log.json
117 | work_dirs/
118 |
119 | # Pytorch
120 | *.pth
121 | *.py~
122 | *.sh~
123 |
--------------------------------------------------------------------------------
/3rdparty/opencv/linux.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenCVer/python_cpp_extension/7984cc513026d5ffde88ea711fc803adb4044850/3rdparty/opencv/linux.zip
--------------------------------------------------------------------------------
/3rdparty/opencv/win.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenCVer/python_cpp_extension/7984cc513026d5ffde88ea711fc803adb4044850/3rdparty/opencv/win.zip
--------------------------------------------------------------------------------
/3rdparty/pybind11.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenCVer/python_cpp_extension/7984cc513026d5ffde88ea711fc803adb4044850/3rdparty/pybind11.zip
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Introduction:
2 |
3 | For a long time now, thanks to the steady progress of GPUs, both training and deployment/inference in deep learning have been getting faster, and the mainstream frameworks (PyTorch, TensorFlow, OneFlow, ...) ship many GPU-accelerated operators. From the network point of view, PyTorch already relies on low-level libraries such as NVIDIA cuDNN, Intel MKL and NNPACK to speed up training, but in some cases this is not enough, for example when we need to implement a specific algorithm and the operators PyTorch already provides do not suffice. PyTorch is well optimized for its individual operations, but when we combine existing operations into a new operator (OP), PyTorch does not know the concrete execution flow of the algorithm: it simply dispatches the predefined operations to the GPU one after another, so the GPU may be under-utilized or overloaded, and the Python interpreter cannot optimize across them either, which can slow training down considerably[^1]. From the data-flow point of view, **deep learning usually needs a complex, multi-stage data-processing pipeline, including data loading, decoding and a fair amount of augmentation/preprocessing; this pipeline, executed on the CPU today, has become a bottleneck that makes model training very time-consuming**. To address this, NVIDIA proposed the Data Loading Library (DALI)[^2], which moves data preprocessing onto the GPU to relieve the CPU bottleneck. DALI has its own execution engine, built to maximize the throughput of the input pipeline; features such as prefetching, parallel execution and batching are handled transparently for the user, as shown in the figure below:
4 |
5 | 
6 | *(figure: DALI Pipeline)*
7 | 
18 | Using DALI and setting up its environment is fairly involved, and the set of operations DALI currently supports is limited; see the documentation in reference [^2] for details. In practice, complex project-specific augmentation operations still have to be implemented by hand. This project therefore uses PyTorch's C++/CUDA extension mechanism to implement data augmentation on the GPU and feed the result directly to the network, which speeds up training. To help readers build up the related knowledge systematically, the project also covers plain-Python C++ extensions and explains in detail how to write setup.py (including the configuration needed when third-party libraries are involved); compilation and testing are explained in detail later.
19 |
20 | ## Project Structure:
21 |
22 | ```python
23 | ├── 3rdparty          # third-party libraries the project depends on
24 | │   ├── opencv
25 | │   │   ├── linux
26 | │   │   └── win
27 | │   └── pybind11
28 | ├── docs               # documentation and related material
29 | ├── requirements       # python dependencies
30 | ├── requirements.txt   # python dependencies, used together with the requirements folder
31 | ├── scripts            # test scripts
32 | ├── tools              # analysis tools
33 | ├── orbbec             # source code
34 | │   ├── nms            # non-maximum suppression
35 | │   ├── roi_align      # ROI Align
36 | │   ├── utils          # build helper functions
37 | │   └── warpaffine     # affine-transform augmentation
38 | └── setup.py           # builds and packages the python package (.egg), similar to a CMakeLists.txt
39 | ```
40 |
41 | ## Compilation And Python Environment:
42 |
43 | ### Compile Environment:
44 |
45 | - [x] **GCC/G++ >= 5.5.0(Visual Studio 2017 or newer for Windows)**
46 | - [x] **CUDA(NVCC): 10.1~11.5**
47 | ### Python Environment(requirements.txt):
48 |
49 | ```python
50 | certifi==2021.5.30
51 | cycler==0.11.0
52 | future==0.18.2
53 | kiwisolver==1.3.1
54 | matplotlib==3.3.4
55 | mkl-fft==1.3.0
56 | mkl-random==1.1.1
57 | mkl-service==2.3.0
58 | numpy @ file:///C:/ci/numpy_and_numpy_base_1603480701039/work
59 | olefile==0.46
60 | opencv-python==3.4.0.12
61 | Pillow @ file:///C:/ci/pillow_1625663293114/work
62 | pyparsing==3.0.9
63 | python-dateutil==2.8.2
64 | six @ file:///tmp/build/80754af9/six_1644875935023/work
65 | terminaltables==3.1.10
66 | torch==1.5.0
67 | torchvision==0.6.0
68 | wincertstore==0.2
69 | ```
70 |
71 | ### Python Package infos:
72 |
73 | ```python
74 | Package Version
75 | --------------- ---------
76 | certifi 2016.2.28
77 | cycler 0.11.0
78 | Cython 0.29.32
79 | future 0.18.2
80 | kiwisolver 1.3.1
81 | matplotlib 3.3.4
82 | mkl-fft 1.3.0
83 | mkl-random 1.1.1
84 | mkl-service 2.3.0
85 | numpy 1.19.2
86 | olefile 0.44
87 | opencv-python 3.4.0.12
88 | Pillow 8.3.1
89 | pip 21.3.1
90 | pyparsing 3.0.9
91 | python-dateutil 2.8.2
92 | setuptools 59.6.0
93 | six 1.10.0
94 | terminaltables 3.1.10
95 | torch 1.5.0
96 | torchvision 0.6.0
97 | wheel 0.29.0
98 | wincertstore 0.2
99 | ```
100 |
101 | [Note]: the PyTorch version in the environment above must match the installed CUDA version. PyTorch versions supported by this project: **1.5.0 ~ latest**.
102 |
103 | ## C++ And CUDA Extensions For Python/Pytorch:
104 |
105 | For interfacing C++ with python or pytorch, the mainstream approach in industry is **pybind11**; see reference [^15] for a more detailed introduction. Its core principle is illustrated in the figure below:
106 |
107 | 
108 | *(figure: pybind11 pipeline)*
109 | 
119 | PyTorch C++ extensions differ somewhat from pure-Python ones, because PyTorch's basic data type is **torch.Tensor**, which can be regarded as a higher-level wrapper around np.array. Consequently, the data types expected by the interface functions and the libraries they call are slightly different; this is explained in detail below.
120 |
121 | ### C++ Extensions For Python:
122 |
123 | Let us start with the Python side (scripts/test_warpaffine_opencv.py):
124 |
125 | ```python
126 | import cv2
127 | import torch # do not remove: some torch shared libraries are loaded dynamically; explained in detail later.
128 | import numpy as np
129 | from orbbec.warpaffine import affine_opencv # C++ interface
130 |
131 | data_path = "./demo.png"
132 | img = cv2.imread(data_path, cv2.IMREAD_GRAYSCALE)
133 |
134 | # numpy.array() in python corresponds one-to-one to py::array_t in pybind11.
135 | src_point = np.array([[262.0, 324.0], [325.0, 323.0], [295.0, 349.0]], dtype=np.float32)
136 | dst_point = np.array([[38.29, 51.69], [73.53, 51.69], [56.02, 71.73]], dtype=np.float32)
137 | # python interface
138 | mat_trans = cv2.getAffineTransform(src_point, dst_point)
139 | res = cv2.warpAffine(img, mat_trans, (600,800))
140 | cv2.imwrite("py_img.png", res)
141 |
142 | # C++ interface
143 | warpffine_img = affine_opencv(img, src_point, dst_point)
144 | cv2.imwrite("cpp_img.png", warpffine_img)
145 | ```
146 |
147 | As the code shows, the python file calls the affine_opencv function, whose C++ implementation lives in orbbec/warpaffine/src/cpu/warpaffine_opencv.cpp:
148 |
149 | ```c++
150 | #include <pybind11/pybind11.h>
151 | #include <pybind11/numpy.h>
152 | #include <opencv2/core/core.hpp>
153 | #include <opencv2/highgui/highgui.hpp>
154 | #include <opencv2/imgproc/imgproc.hpp>
155 | #include <vector>
156 |
157 |
158 | namespace py = pybind11;
159 |
160 | /* Python -> C++ Mat */
161 | cv::Mat numpy_uint8_1c_to_cv_mat(py::array_t<unsigned char>& input)
162 | {
163 |     ...
164 | }
165 | 
166 | cv::Mat numpy_uint8_3c_to_cv_mat(py::array_t<unsigned char>& input)
167 | {
168 |     ...
169 | }
170 | 
171 | /* C++ Mat -> numpy */
172 | py::array_t<unsigned char> cv_mat_uint8_1c_to_numpy(cv::Mat& input)
173 | {
174 |     ...
175 | }
176 | 
177 | py::array_t<unsigned char> cv_mat_uint8_3c_to_numpy(cv::Mat& input)
178 | {
179 |     ...
180 | }
181 | 
182 | py::array_t<unsigned char> affine_opencv(py::array_t<unsigned char>& input,
183 |                                          py::array_t<float>& from_point,
184 |                                          py::array_t<float>& to_point)
185 | {
186 |     ...
187 | }
188 | ```
189 |
190 | Since this project also supports PyTorch C++/CUDA extensions, for consistency all bindings are defined via PYBIND11_MODULE in the extension entry file (orbbec/warpaffine/src/warpaffine_ext.cpp), as shown below:
191 |
192 | ```c++
193 | #include <torch/extension.h>
194 | #include <pybind11/numpy.h>
195 | 
196 | // declaration of the C++ extension function for plain python
197 | py::array_t<unsigned char> affine_opencv(py::array_t<unsigned char>& input,
198 |                                          py::array_t<float>& from_point,
199 |                                          py::array_t<float>& to_point);
200 |
201 | // declaration of the PyTorch C++ extension function (CPU)
202 | at::Tensor affine_cpu(const at::Tensor& input, /*[B, C, H, W]*/
203 | const at::Tensor& affine_matrix, /*[B, 2, 3]*/
204 | const int out_h,
205 | const int out_w);
206 |
207 | // declaration of the PyTorch CUDA extension function (GPU)
208 | #ifdef WITH_CUDA
209 | at::Tensor affine_gpu(const at::Tensor& input, /*[B, C, H, W]*/
210 | const at::Tensor& affine_matrix, /*[B, 2, 3]*/
211 | const int out_h,
212 | const int out_w);
213 | #endif
214 |
215 | // wrap the PyTorch extension interface behind the WITH_CUDA macro
216 | at::Tensor affine_torch(const at::Tensor& input, /*[B, C, H, W]*/
217 | const at::Tensor& affine_matrix, /*[B, 2, 3]*/
218 | const int out_h,
219 | const int out_w)
220 | {
221 | if (input.device().is_cuda())
222 | {
223 | #ifdef WITH_CUDA
224 | return affine_gpu(input, affine_matrix, out_h, out_w);
225 | #else
226 | AT_ERROR("affine is not compiled with GPU support");
227 | #endif
228 | }
229 | return affine_cpu(input, affine_matrix, out_h, out_w);
230 | }
231 |
232 | // define the python/pytorch bindings via pybind11
233 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
234 | m.def("affine_opencv", &affine_opencv, "affine with c++ opencv");
235 | m.def("affine_torch", &affine_torch, "affine with c++ libtorch");
236 | }
237 | ```
238 |
239 | As the code above shows, **np.array on the python side corresponds to py::array_t in pybind11: a np.array passed in from python is received as a py::array_t in the corresponding C++ function, and manipulating NumPy arrays requires including the pybind11 numpy header (pybind11/numpy.h)**. Under the hood an array is a one-dimensional, contiguous block of memory; pybind11's request() method parses it into a py::buffer_info struct, which exposes a buffer view giving fast, direct access to the underlying data, as shown below:
240 |
241 | ```c++
242 | struct buffer_info {
243 |     void *ptr;                          // pointer to the underlying array (buffer) data
244 |     py::ssize_t itemsize;               // size in bytes of a single element
245 |     std::string format;                 // element format (Python-style type string)
246 |     py::ssize_t ndim;                   // number of dimensions
247 |     std::vector<py::ssize_t> shape;     // shape of the array
248 |     std::vector<py::ssize_t> strides;   // step between adjacent elements in each dimension (in bytes)
249 | };
250 | ```
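To illustrate how these fields are used, here is a minimal, self-contained sketch (not taken from this repository; the function name invert_image is made up) that reads a uint8 grayscale array through request() and accesses the data via the raw pointer:

```c++
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <stdexcept>

namespace py = pybind11;

// Hypothetical helper: invert a 2-D uint8 image and return a new array.
py::array_t<unsigned char> invert_image(py::array_t<unsigned char>& input)
{
    py::buffer_info buf = input.request();                 // parse the numpy array into buffer_info
    if (buf.ndim != 2)
        throw std::runtime_error("expected a 2-D uint8 array");

    auto result = py::array_t<unsigned char>(buf.shape);   // allocate an output with the same shape
    py::buffer_info out = result.request();

    auto* src = static_cast<unsigned char*>(buf.ptr);
    auto* dst = static_cast<unsigned char*>(out.ptr);
    const py::ssize_t total = buf.shape[0] * buf.shape[1];
    for (py::ssize_t i = 0; i < total; ++i)
        dst[i] = 255 - src[i];                             // direct access through the raw data pointer
    return result;
}
```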
251 |
252 | Once the C++ sources are written, list them in setup.py together with the paths of the third-party dependencies opencv and pybind11 (this is already done for this project; see setup.py), then compile and install:
253 |
254 | ```python
255 | # change to the project directory
256 | step 1: cd F:/code/python_cpp_extension
257 | # build
258 | step 2: python setup.py develop
259 | # install; without --prefix, the built package (.egg) is installed into site-packages of the current python environment.
260 | step 3: python setup.py install
261 | ```
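For orientation, the skeleton below shows what such a setup.py typically looks like when it mixes a plain pybind11/OpenCV extension with a PyTorch CUDA extension. It is only a sketch, not the setup.py shipped with this project: the include/library paths, the OpenCV library name and the FORCE_CUDA switch are illustrative assumptions; the real setup.py is the authoritative configuration.

```python
# Minimal sketch only; paths, sources and library names below are placeholders,
# not the actual configuration used by this project's setup.py.
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension


def make_extension():
    sources = ['orbbec/warpaffine/src/warpaffine_ext.cpp',
               'orbbec/warpaffine/src/cpu/warpaffine_torch_v2.cpp',
               'orbbec/warpaffine/src/cpu/warpaffine_opencv.cpp']
    include_dirs = [os.path.abspath('3rdparty/pybind11/include'),
                    os.path.abspath('3rdparty/opencv/linux/include')]
    library_dirs = [os.path.abspath('3rdparty/opencv/linux/lib')]
    libraries = ['opencv_world']        # library name depends on the shipped OpenCV build
    define_macros = []

    if os.getenv('FORCE_CUDA', '0') == '1':
        # also compile the .cu kernels and mark the build as CUDA-enabled
        sources += ['orbbec/warpaffine/src/cuda/warpaffine_cuda.cpp',
                    'orbbec/warpaffine/src/cuda/warpaffine_kernel.cu']
        define_macros += [('WITH_CUDA', None)]
        ext_cls = CUDAExtension
    else:
        ext_cls = CppExtension

    return ext_cls(name='orbbec.warpaffine.warpaffine_ext',
                   sources=sources,
                   include_dirs=include_dirs,
                   library_dirs=library_dirs,
                   libraries=libraries,
                   define_macros=define_macros)


setup(name='orbbec',
      version='0.0.1',
      packages=['orbbec'],
      ext_modules=[make_extension()],
      cmdclass={'build_ext': BuildExtension})
```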
262 |
263 | [Note]: for the setup.py background used here, see references [^7][^12][^13], which explain it in detail. After step 2 and step 3, as shown in the figures below, the source files are compiled into a .pyd binary (a .so file on Linux) and a python package is produced: **orbbec-0.0.1-py36-win-amd64.egg** (the name is determined by the name and version fields in setup.py). The package is installed into the site-packages folder of the current python environment, and running **pip list** in the terminal then shows the package and its version. Once installation succeeds, any python file run in this environment (in this project the environment is called cpp_extension) can import the interface functions of the orbbec package, e.g. the line **``from orbbec.warpaffine import affine_opencv``** in scripts/test_warpaffine_opencv.py above.
264 | 
265 | *(figure: build and installation succeeded)*
266 | 
267 | *(figure: package information after installation)*
268 | 
289 |
290 | After the build you can run scripts/collect_env.py to print the versions of the relevant tools, for example:
291 |
292 | ```python
293 | sys.platform : win32
294 | Python : 3.6.13 |Anaconda, Inc.| (default, Mar 16 2021, 11:37:27) [MSC v.1916 64 bit (AMD64)]
295 | CUDA available : True
296 | CUDA_HOME : C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1
297 | NVCC : Not Available
298 | GPU 0 : NVIDIA GeForce GTX 1650
299 | OpenCV : 3.4.0
300 | PyTorch : 1.5.0
301 | PyTorch compiling details : PyTorch built with:
302 | - C++ Version: 199711
303 | - MSVC 191627039
304 | - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191125 for Intel(R) 64 architecture applications
305 | - Intel(R) MKL-DNN v0.21.1 (Git Hash 7d2fd500bc78936d1d648ca713b901012f470dbc)
306 | - OpenMP 200203
307 | - CPU capability usage: AVX2
308 | - CUDA Runtime 10.1
309 | - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_37,code=compute_37
310 | - CuDNN 7.6.4
311 | - Magma 2.5.2
312 | - Build settings: BLAS=MKL, BUILD_TYPE=Release, CXX_FLAGS=/DWIN32 /D_WINDOWS /GR /w /EHa /bigobj -openmp -DNDEBUG -DUSE_FBGEMM, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=OFF, USE_OPENMP=ON, USE_STATIC_DISPATCH=OFF,
313 |
314 | TorchVision : 0.6.0
315 | C/C++ Compiler : MSVC 191627045
316 | CUDA Compiler : 10.1
317 | ```
318 |
319 | Before running scripts/test_warpaffine_opencv.py, the dynamic-library search path still has to be configured, because warpaffine_opencv.cpp links against OpenCV. On Windows this is configured as follows:
320 | 
321 | *(figure: OpenCV library path configuration on Windows)*
322 | 
333 | 
334 | Linux needs the same configuration; the commands are as follows:
335 |
336 | ```python
337 | root@aistation:/xxx/code/python_cpp_extension# export LD_LIBRARY_PATH=/xxx/code/python_cpp_extension/3rdparty/opencv/linux/lib
338 | root@aistation:/xxx/code/python_cpp_extension# ldconfig
339 | ```
340 |
341 | You can also edit ~/.bashrc, add the ``export LD_LIBRARY_PATH=/...`` line above and then run ``source ~/.bashrc``; editing the /etc/profile configuration file works the same way as .bashrc but takes effect for all users.
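For example, to make the setting permanent for the current user (the path below follows the placeholder used above):

```sh
# append the OpenCV library path to ~/.bashrc and reload it
echo 'export LD_LIBRARY_PATH=/xxx/code/python_cpp_extension/3rdparty/opencv/linux/lib:$LD_LIBRARY_PATH' >> ~/.bashrc
source ~/.bashrc
```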
342 |
343 | If you train on a server such as Inspur AIStation, the commands above can be written into a .sh script and the training job started through that script, for example:
344 |
345 | ```sh
346 | # add dll path to env
347 | export LD_LIBRARY_PATH=/jiaozhu01/code/insightface_ir_train
348 | ldconfig
349 |
350 | # run
351 | cd /jiaozhu01/code/insightface_ir_train/
352 | OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=2 --master_addr="127.0.0.1" --master_port=12581 train.py configs/mbf200.py
353 |
354 | # kill process
355 | ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh
356 | ```
357 |
358 | The Dependencies_x64_Release tool under tools (run DependenciesGui.exe) can be used to check whether all dynamic libraries required by the compiled .pyd file can be resolved, as shown below:
359 | 
360 | *(figure: checking which dynamic libraries the compiled .pyd depends on)*
361 | 
372 | You will notice that the tool cannot resolve python36.dll, c10.dll, torch_cpu.dll, torch_python.dll and c10_cuda.dll. This is expected: the python- and torch-related DLLs are loaded dynamically, i.e. only when the program actually runs a statement such as **import torch** are the torch libraries loaded, so Dependencies_x64_Release cannot fully verify the dependencies of the compiled warpaffine_ext.cp36-win_amd64.pyd. The reason warpaffine_ext.cp36-win_amd64.pyd depends on torch libraries at all is that orbbec/warpaffine/src/warpaffine_ext.cpp also implements the PyTorch extension and therefore links against torch- and CUDA-related libraries; if that file only implemented a pure-python C++ extension, those libraries would not be needed.
373 |
374 | After that, the dynamic library that warpaffine_ext.cp36-win_amd64.pyd cannot **load dynamically** (opencv_world453.dll) has to be copied next to scripts/test_warpaffine_opencv.py (the same applies on Linux), as shown below:
375 | 
376 | *(figure: copy the DLL into the same directory as the test script)*
377 | 
388 | One thing to watch out for: when building and installing inside docker, the resulting python package (.egg) is sometimes not installed into site-packages of the current python environment, which means ``from orbbec.warpaffine import affine_opencv`` fails because ``orbbec.warpaffine`` is not on the python search path. There are two workarounds: pass ``--prefix='install path'`` to ``python setup.py install`` (which, in my experience, does not always work), or add the path of the ``orbbec`` folder to the python search path inside the python file, as shown below:
389 |
390 | ```python
391 | import os, sys
392 | import cv2
393 | import torch  # do not remove: some torch shared libraries are loaded dynamically.
394 | import numpy as np
395 | 
396 | # the two lines below assume the orbbec folder lives in the parent directory of this script.
397 | _FILE_PATH = os.path.dirname(os.path.abspath(__file__))
398 | sys.path.insert(0, os.path.join(_FILE_PATH, "../"))
398 |
399 | from orbbec.warpaffine import affine_opencv # C++ interface
400 | ```
401 |
402 | ### C++/CUDA Extensions For Pytorch:
403 |
404 | PyTorch C++/CUDA extensions also use pybind11, but since PyTorch's basic data type is torch.Tensor, the extension code must use the corresponding type from libtorch (the C++ side of PyTorch) so that arguments can be passed correctly. A few commonly used libtorch libraries and namespaces:
405 | 
406 | Commonly used **namespaces**:
407 | ① at (ATen) declares and defines Tensor operations and is the namespace used most often;
408 | ② c10 is the foundation of ATen; it contains PyTorch's core abstractions and the actual implementations of the Tensor and Storage data structures;
409 | ③ Tensors defined in the torch namespace add autograd support on top of ATen.
410 | 
411 | The main components under PyTorch's aten directory:
412 | ① ATen (ATen core source files);
413 | ② TH (Torch tensor library);
414 | ③ THC (Torch CUDA tensor library);
415 | ④ THCUNN (Torch CUDA neural-network library);
416 | ⑤ THNN (Torch neural-network library).
417 | 
418 | C10 is short for Caffe Tensor Library. It holds the most fundamental Tensor code, runs on both server and mobile, and one of its main goals is to unify the tensor back-end code of PyTorch and caffe2.
419 | libtorch also contains a csrc module, which mainly handles the **mapping between the C++ and python APIs** (for example, PyTorch's nn.Conv2d corresponds to at::conv2d), plus autograd and the automatic-differentiation machinery. With this background, let us look at the python test code (scripts/test_warpaffine_torch_cpu.py):
420 |
421 | ```python
422 | import cv2
423 | import torch
424 | import numpy as np
425 | from orbbec.warpaffine import affine_torch # C++ interface
426 |
427 | data_path = "demo.png"
428 |
429 | img = cv2.imread(data_path)
430 | # transform img(numpy.array) to tensor(torch.Tensor)
431 | # use permute
432 | img_tensor = torch.from_numpy(img / 255.0).permute(2, 0, 1).contiguous()
433 | img_tensor = img_tensor.unsqueeze(0).float()
434 |
435 | src_tensor = torch.tensor([[38.29, 51.69, 1.0], [73.53, 51.69, 1.0], [56.02, 71.73, 1.0]], dtype=torch.float32).unsqueeze(0)
436 | dst_tensor = torch.tensor([[262.0, 324.0], [325.0, 323.0], [295.0, 349.0]], dtype=torch.float32).unsqueeze(0)
437 |
438 | # compute affine transform matrix
439 | matrix_l = torch.transpose(src_tensor, 1, 2).bmm(src_tensor)
440 | matrix_l = torch.inverse(matrix_l)
441 | matrix_r = torch.transpose(src_tensor, 1, 2).bmm(dst_tensor)
442 | affine_matrix = torch.transpose(matrix_l.bmm(matrix_r), 1, 2)
443 |
444 | warpffine_img = affine_torch(img_tensor, affine_matrix, 112, 112)
445 |
446 | warpffine_img = warpffine_img.squeeze(0).permute(1, 2, 0).numpy()
447 | cv2.imwrite("torch_affine_cpu.png", np.uint8(warpffine_img * 255.0))
448 | ```
449 |
450 | As the code shows, the python file calls affine_torch and passes CPU tensors; the C++ implementation of affine_torch is in orbbec/warpaffine/src/warpaffine_ext.cpp:
451 |
452 | ```c++
453 | #include <torch/extension.h>
454 | #include <pybind11/numpy.h>
455 | 
456 | // declaration of the C++ extension function for plain python
457 | py::array_t<unsigned char> affine_opencv(py::array_t<unsigned char>& input,
458 |                                          py::array_t<float>& from_point,
459 |                                          py::array_t<float>& to_point);
460 |
461 | // declaration of the PyTorch C++ extension function (CPU)
462 | at::Tensor affine_cpu(const at::Tensor& input, /*[B, C, H, W]*/
463 | const at::Tensor& affine_matrix, /*[B, 2, 3]*/
464 | const int out_h,
465 | const int out_w);
466 |
467 | // declaration of the PyTorch CUDA extension function (GPU)
468 | #ifdef WITH_CUDA
469 | at::Tensor affine_gpu(const at::Tensor& input, /*[B, C, H, W]*/
470 | const at::Tensor& affine_matrix, /*[B, 2, 3]*/
471 | const int out_h,
472 | const int out_w);
473 | #endif
474 |
475 | // wrap the PyTorch extension interface behind the WITH_CUDA macro
476 | at::Tensor affine_torch(const at::Tensor& input, /*[B, C, H, W]*/
477 | const at::Tensor& affine_matrix, /*[B, 2, 3]*/
478 | const int out_h,
479 | const int out_w)
480 | {
481 | if (input.device().is_cuda())
482 | {
483 | #ifdef WITH_CUDA
484 | return affine_gpu(input, affine_matrix, out_h, out_w);
485 | #else
486 | AT_ERROR("affine is not compiled with GPU support");
487 | #endif
488 | }
489 | return affine_cpu(input, affine_matrix, out_h, out_w);
490 | }
491 |
492 | // define the python/pytorch bindings via pybind11
493 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
494 | m.def("affine_opencv", &affine_opencv, "affine with c++ opencv");
495 | m.def("affine_torch", &affine_torch, "affine with c++ libtorch");
496 | }
497 | ```
498 |
499 | As shown above, the WITH_CUDA macro and the tensor's device decide whether affine_torch ultimately runs affine_cpu or affine_gpu. Note also that **torch.Tensor on the python side corresponds to at::Tensor in libtorch**. Next, the implementation of affine_cpu (orbbec/warpaffine/src/cpu/warpaffine_torch_v2.cpp):
500 |
501 | ```c++
502 | at::Tensor affine_cpu(const at::Tensor& input, /*[B, C, H, W]*/
503 | const at::Tensor& affine_matrix, /*[B, 2, 3]*/
504 | const int out_h,
505 | const int out_w)
506 | {
507 | at::Tensor result;
508 | // AT_DISPATCH_FLOATING_TYPES: input.scalar_type() => scalar_t
509 | AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "affine_cpu", [&] {
510 | result = affine_torch_cpu(input, affine_matrix, out_h, out_w);
511 | });
512 | return result;
513 | }
514 | ```
515 |
516 | Looking further into affine_torch_cpu (orbbec/warpaffine/src/cpu/warpaffine_torch_v2.cpp):
517 |
518 | ```c++
519 | template <typename scalar_t>
520 | at::Tensor affine_torch_cpu(const at::Tensor& input, /*[B, C, H, W]*/
521 | const at::Tensor& affine_matrix, /*[B, 2, 3]*/
522 | const int out_h,
523 | const int out_w)
524 | {
525 | AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
526 | AT_ASSERTM(affine_matrix.device().is_cpu(), "affine_matrix must be a CPU tensor");
527 |
528 | auto matrix_ptr = affine_matrix.contiguous().data_ptr<scalar_t>();
529 | auto input_ptr = input.contiguous().data_ptr<scalar_t>();
530 | auto nimgs = input.size(0);
531 | auto img_c = input.size(1);
532 | auto img_h = input.size(2);
533 | auto img_w = input.size(3);
534 | auto in_img_size = img_c * img_h * img_w;
535 | auto out_img_size = img_c * out_h * out_w;
536 |
537 | // build dst tensor
538 | auto output_tensor = at::zeros({nimgs, img_c, out_h, out_w}, input.options());
539 | auto output_ptr = output_tensor.contiguous().data_ptr<scalar_t>();
540 |
541 | for(int i = 0; i < nimgs; i++)
542 | {
543 | scalar_t* matrix = matrix_ptr + i * 6;
544 | scalar_t* in = input_ptr + i * in_img_size;
545 | scalar_t* out = output_ptr + i * out_img_size;
546 | affine_cpu_kernel(img_h, img_w, img_c, img_w*img_h,
547 | out_h, out_w, out_h*out_w, out, in, matrix, 0.0f);
548 | }
549 |
550 | return output_tensor;
551 | }
552 | ```
553 |
554 | One point deserves special attention: the **.contiguous() calls on the tensors (lines 10, 11 and 21 of the code above)**. When taking a tensor's data pointer with **data_ptr()**, the official PyTorch example code and the related code in mmdetection/mmcv all call .contiguous() first. The reason is that, both in python and in C++, methods such as **permute(), transpose() and view()** return a new tensor that **shares its storage** with the old one: the storage does not change, only a new view is returned, which avoids data copies, saves memory and speeds up training/inference to some extent, but it also means the raw memory is no longer guaranteed to be laid out in the order the kernel expects. If .contiguous() has already been applied on the python side, it does not need to be repeated on the C++ side, because .contiguous() materializes the data into a fresh contiguous copy (a deep copy whenever the tensor is not already contiguous).
555 |
556 | 
557 | *(figure: analysis of the permute operation)*
558 | 
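The storage-sharing behaviour is easy to verify with a few lines of python (illustrative only, not one of the repository's scripts):

```python
import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
y = x.permute(1, 0)                    # new view, same underlying storage

print(x.data_ptr() == y.data_ptr())    # True: no data was copied
print(y.is_contiguous())               # False: memory order no longer matches the shape

z = y.contiguous()                     # materializes a contiguous copy
print(z.data_ptr() == y.data_ptr())    # False: now the data really was copied
print(z.is_contiguous())               # True
```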
568 | Next, the PyTorch CUDA extension. The test file test_warpaffine_torch_gpu.py:
569 |
570 | ```python
571 | import cv2
572 | import torch
573 | import numpy as np
574 | from orbbec.warpaffine import affine_torch # CUDA interface
575 |
576 | data_path = "demo.png"
577 |
578 | img = cv2.imread(data_path)
579 | # transform img(numpy.array) to tensor(torch.Tensor)
580 | # use permute
581 | img_tensor = torch.from_numpy(img / 255.0).permute(2, 0, 1).contiguous()
582 | img_tensor = img_tensor.unsqueeze(0).float()
583 | img_tensor = img_tensor.cuda() # gpu tensor
584 |
585 | # dst -> src
586 | src_tensor = torch.tensor([[38.29, 51.69, 1.0], [73.53, 51.69, 1.0], [56.02, 71.73, 1.0]], dtype=torch.float32).unsqueeze(0)
587 | dst_tensor = torch.tensor([[262.0, 324.0], [325.0, 323.0], [295.0, 349.0]], dtype=torch.float32).unsqueeze(0)
588 | src_tensor = src_tensor.cuda() # gpu tensor
589 | dst_tensor = dst_tensor.cuda() # gpu tensor
590 |
591 | # compute affine transform matrix
592 | matrix_l = torch.transpose(src_tensor, 1, 2).bmm(src_tensor)
593 | matrix_l = torch.inverse(matrix_l)
594 | matrix_r = torch.transpose(src_tensor, 1, 2).bmm(dst_tensor)
595 | affine_matrix = torch.transpose(matrix_l.bmm(matrix_r), 1, 2)
596 | affine_matrix = affine_matrix.contiguous().cuda() # gpu tensor
597 |
598 | warpffine_img = affine_torch(img_tensor, affine_matrix, 112, 112)
599 | warpffine_img = warpffine_img.cpu().squeeze(0).permute(1, 2, 0).numpy()
600 | cv2.imwrite("torch_affine_gpu.png", np.uint8(warpffine_img * 255.0))
601 | ```
602 |
603 | As the script shows, affine_torch now receives GPU tensors, and the underlying computation runs on the GPU. Following the CUDA path of affine_torch() in orbbec/warpaffine/src/warpaffine_ext.cpp, the call ends up in affine_gpu():
604 |
605 | ```c++
606 | at::Tensor affine_gpu(const at::Tensor& input, /*[B, C, H, W]*/
607 | const at::Tensor& affine_matrix, /*[B, 2, 3]*/
608 | const int out_h,
609 | const int out_w)
610 | {
611 | CHECK_INPUT(input);
612 | CHECK_INPUT(affine_matrix);
613 |
614 | // Ensure CUDA uses the input tensor device.
615 | at::DeviceGuard guard(input.device());
616 |
617 | return affine_cuda_forward(input, affine_matrix, out_h, out_w);
618 | }
619 | ```
620 |
621 | As can be seen, the actual work is done by affine_cuda_forward(), shown below:
622 |
623 | ```c++
624 | at::Tensor affine_cuda_forward(const at::Tensor& input, /*[B, C, H, W]*/
625 | const at::Tensor& affine_matrix, /*[B, 2, 3]*/
626 | const int out_h,
627 | const int out_w)
628 | {
629 | // build dst tensor
630 | auto nimgs = input.size(0);
631 | auto img_c = input.size(1);
632 | auto img_h = input.size(2);
633 | auto img_w = input.size(3);
634 | const int output_size = nimgs * img_c * out_h * out_w;
635 | auto output_tensor = at::zeros({nimgs, img_c, out_h, out_w}, input.options());
636 |
637 | AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "affine_cuda", [&] {
638 | auto matrix_ptr = affine_matrix.data_ptr<scalar_t>();
639 | auto input_ptr = input.data_ptr<scalar_t>();
640 | auto output_ptr = output_tensor.data_ptr<scalar_t>();
641 |
642 | // launch kernel function on GPU with CUDA.
643 | affine_gpu_kernel<scalar_t><<<grid_size, block_size>>>(output_size, img_h,
644 | img_w, img_c, out_h, out_w, output_ptr, input_ptr, matrix_ptr, 0.0f);
646 | });
647 |
648 | return output_tensor;
649 | }
650 | ```
651 |
652 | After choosing grid_size and block_size, the kernel **affine_gpu_kernel** is launched. The kernel itself involves a fair amount of CUDA knowledge and is not expanded on here; in the end a GPU output_tensor is returned to the python interface.
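For orientation, the launch configuration of such an element-wise kernel is normally derived from the number of output elements; the sketch below shows the usual pattern (the names THREADS_PER_BLOCK and GET_BLOCKS are illustrative, not this repository's actual helpers):

```c++
// Generic CUDA launch-configuration pattern for an element-wise kernel.
// THREADS_PER_BLOCK / GET_BLOCKS are illustrative names, not this repo's API.
constexpr int THREADS_PER_BLOCK = 512;

inline int GET_BLOCKS(const int n)
{
    // one thread per output element, rounded up to a whole number of blocks
    return (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
}

// inside AT_DISPATCH_FLOATING_TYPES the launch then looks like:
//   const int grid_size  = GET_BLOCKS(output_size);
//   const int block_size = THREADS_PER_BLOCK;
//   affine_gpu_kernel<scalar_t><<<grid_size, block_size>>>(output_size, ..., 0.0f);
```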
653 |
654 | ## GPU-Accelerated Augmentation:
655 |
656 | Once the PyTorch C++/CUDA extension mechanism is mastered, we can obtain the same kind of acceleration as NVIDIA's DALI: no matter how complex the augmentation, it can be accelerated to some degree in this way. Pseudocode (assuming the build and install steps are done):
657 |
658 | ```python
659 | for _, (img, local_labels) in enumerate(train_loader):
660 | global_step += 1
661 | # assume the tensors from train_loader are already GPU tensors; if they are CPU tensors, first move them to the GPU with index local_rank.
662 | # local_rank = torch.distributed.get_rank()
663 | # ================== add data augmentation (for illustration only) ===================
664 | batch = img.shape[0] # get batchsize
665 | device = img.device # get local_rank
666 | src_tensor = torch.tensor([[38.29, 51.69, 1.0], [73.53, 51.69, 1.0], [56.02, 71.73, 1.0]],dtype=torch.float32).unsqueeze(0)
667 | dst_tensor = torch.tensor([[42.0, 52.0], [78.0, 55.0], [58.0, 74.0]], dtype=torch.float32).unsqueeze(0)
668 | src_tensor = src_tensor.repeat(batch, 1, 1)
669 | dst_tensor = dst_tensor.repeat(batch, 1, 1)
670 | # compute affine transform matrix
671 | matrix_l = torch.transpose(src_tensor, 1, 2).bmm(src_tensor)
672 | matrix_l = torch.inverse(matrix_l)
673 | matrix_r = torch.transpose(src_tensor, 1, 2).bmm(dst_tensor)
674 | affine_matrix = torch.transpose(matrix_l.bmm(matrix_r), 1, 2)
675 | affine_matrix = affine_matrix.contiguous().to(device) # .contiguous() was applied on the python side, so the CUDA extension does not need to repeat it.
676 | img = affine_torch(img, affine_matrix, 112, 112) # run the augmentation on the GPU
677 | # ==============================================================================
678 | local_embeddings = backbone(img)
679 | loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels, opt)
680 | ```
681 |
682 | [**Note**]: after building, the orbbec folder can be copied directly into your own training project; in the python file that calls the extension functions (the code above would live in train.py), add the directory containing the orbbec folder to the python path as described earlier, and the extension functions (e.g. affine_torch) can then be called normally.
683 |
684 | ## Reference:
685 |
686 | [^1]: https://pytorch.org/tutorials/advanced/cpp_extension.html
687 | [^2]: https://github.com/NVIDIA/DALI
688 | [^3]: https://github.com/open-mmlab/mmdetection/tree/v2.0.0
689 | [^4]: https://github.com/open-mmlab/mmcv
690 | [^5]: https://github.com/openppl-public/ppl.cv
691 | [^6]: https://github.com/pytorch/extension-cpp
692 | [^7]: https://setuptools.pypa.io/en/latest/references/keywords.html
693 | [^8]: https://www.bbsmax.com/A/MAzAwZAo59/
694 | [^9]: https://zhuanlan.zhihu.com/p/419076427
695 | [^10]: https://zhuanlan.zhihu.com/p/348555597
696 | [^11]: https://oldpan.me/archives/pytorch-cuda-c-plus-plus
697 | [^12]: https://docs.python.org/zh-cn/3/extending/building.html
698 | [^13]: https://zhuanlan.zhihu.com/p/276461821
699 | [^14]: https://blog.51cto.com/u_15357586/3784891
700 | [^15]: https://zhuanlan.zhihu.com/p/383572973
701 |
--------------------------------------------------------------------------------
/docs/README.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenCVer/python_cpp_extension/7984cc513026d5ffde88ea711fc803adb4044850/docs/README.pdf
--------------------------------------------------------------------------------
/docs/files/Custom C++ and CUDA Extensions — PyTorch Tutorials 1.12.1+cu102 documentation.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenCVer/python_cpp_extension/7984cc513026d5ffde88ea711fc803adb4044850/docs/files/Custom C++ and CUDA Extensions — PyTorch Tutorials 1.12.1+cu102 documentation.pdf
--------------------------------------------------------------------------------
/docs/files/Keywords - setuptools 65.4.1.post20221009 documentation.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenCVer/python_cpp_extension/7984cc513026d5ffde88ea711fc803adb4044850/docs/files/Keywords - setuptools 65.4.1.post20221009 documentation.pdf
--------------------------------------------------------------------------------
/docs/files/PyTorch 源码解读之 cpp_extension:揭秘 C++_CUDA 算子实现和调用全流程 - 知乎.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenCVer/python_cpp_extension/7984cc513026d5ffde88ea711fc803adb4044850/docs/files/PyTorch 源码解读之 cpp_extension:揭秘 C++_CUDA 算子实现和调用全流程 - 知乎.pdf
--------------------------------------------------------------------------------
/docs/files/Python中使用C++_CUDA|以PointNet中的ball query 为例 - 知乎.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenCVer/python_cpp_extension/7984cc513026d5ffde88ea711fc803adb4044850/docs/files/Python中使用C++_CUDA|以PointNet中的ball query 为例 - 知乎.pdf
--------------------------------------------------------------------------------
/docs/files/Pytorch拓展进阶(二):Pytorch结合C++以及Cuda拓展 - Oldpan的个人博客.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenCVer/python_cpp_extension/7984cc513026d5ffde88ea711fc803adb4044850/docs/files/Pytorch拓展进阶(二):Pytorch结合C++以及Cuda拓展 - Oldpan的个人博客.pdf
--------------------------------------------------------------------------------
/docs/files/mmdetection源码剖析(1)--NMS.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenCVer/python_cpp_extension/7984cc513026d5ffde88ea711fc803adb4044850/docs/files/mmdetection源码剖析(1)--NMS.pdf
--------------------------------------------------------------------------------
/docs/files/setup.py实现C++扩展和python库编译.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenCVer/python_cpp_extension/7984cc513026d5ffde88ea711fc803adb4044850/docs/files/setup.py实现C++扩展和python库编译.pdf
--------------------------------------------------------------------------------
/docs/imgs/Dependencies_x64_Release.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenCVer/python_cpp_extension/7984cc513026d5ffde88ea711fc803adb4044850/docs/imgs/Dependencies_x64_Release.png
--------------------------------------------------------------------------------
/docs/imgs/compile_win.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenCVer/python_cpp_extension/7984cc513026d5ffde88ea711fc803adb4044850/docs/imgs/compile_win.png
--------------------------------------------------------------------------------
/docs/imgs/dali.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenCVer/python_cpp_extension/7984cc513026d5ffde88ea711fc803adb4044850/docs/imgs/dali.png
--------------------------------------------------------------------------------
/docs/imgs/pbind11_pipeline.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenCVer/python_cpp_extension/7984cc513026d5ffde88ea711fc803adb4044850/docs/imgs/pbind11_pipeline.pptx
--------------------------------------------------------------------------------
/docs/imgs/permute.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenCVer/python_cpp_extension/7984cc513026d5ffde88ea711fc803adb4044850/docs/imgs/permute.png
--------------------------------------------------------------------------------
/docs/imgs/pybind11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenCVer/python_cpp_extension/7984cc513026d5ffde88ea711fc803adb4044850/docs/imgs/pybind11.png
--------------------------------------------------------------------------------
/docs/imgs/pybind11_pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenCVer/python_cpp_extension/7984cc513026d5ffde88ea711fc803adb4044850/docs/imgs/pybind11_pipeline.png
--------------------------------------------------------------------------------
/docs/imgs/windows_opencv_cfg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenCVer/python_cpp_extension/7984cc513026d5ffde88ea711fc803adb4044850/docs/imgs/windows_opencv_cfg.png
--------------------------------------------------------------------------------
/docs/imgs/动态库拷贝到python文件同路径.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenCVer/python_cpp_extension/7984cc513026d5ffde88ea711fc803adb4044850/docs/imgs/动态库拷贝到python文件同路径.png
--------------------------------------------------------------------------------
/docs/imgs/安装成功.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenCVer/python_cpp_extension/7984cc513026d5ffde88ea711fc803adb4044850/docs/imgs/安装成功.png
--------------------------------------------------------------------------------
/mmcv_setup.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import platform
4 | import re
5 | import warnings
6 | from pkg_resources import DistributionNotFound, get_distribution
7 | from setuptools import find_packages, setup
8 |
9 | EXT_TYPE = ''
10 | try:
11 | import torch
12 | if torch.__version__ == 'parrots':
13 | from parrots.utils.build_extension import BuildExtension
14 | EXT_TYPE = 'parrots'
15 | elif (hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()) or \
16 | os.getenv('FORCE_MLU', '0') == '1':
17 | from torch_mlu.utils.cpp_extension import BuildExtension
18 | EXT_TYPE = 'pytorch'
19 | else:
20 | from torch.utils.cpp_extension import BuildExtension
21 | EXT_TYPE = 'pytorch'
22 | cmd_class = {'build_ext': BuildExtension}
23 | except ModuleNotFoundError:
24 | cmd_class = {}
25 | print('Skip building ext ops due to the absence of torch.')
26 |
27 |
28 | def choose_requirement(primary, secondary):
29 | """If some version of primary requirement installed, return primary, else
30 | return secondary."""
31 | try:
32 | name = re.split(r'[!<>=]', primary)[0]
33 | get_distribution(name)
34 | except DistributionNotFound:
35 | return secondary
36 |
37 | return str(primary)
38 |
39 |
40 | def get_version():
41 | version_file = 'mmcv/version.py'
42 | with open(version_file, encoding='utf-8') as f:
43 | exec(compile(f.read(), version_file, 'exec'))
44 | return locals()['__version__']
45 |
46 |
47 | def parse_requirements(fname='requirements/runtime.txt', with_version=True):
48 | """Parse the package dependencies listed in a requirements file but strips
49 | specific versioning information.
50 |
51 | Args:
52 | fname (str): path to requirements file
53 | with_version (bool, default=False): if True include version specs
54 |
55 | Returns:
56 | List[str]: list of requirements items
57 |
58 | CommandLine:
59 | python -c "import setup; print(setup.parse_requirements())"
60 | """
61 | import sys
62 | from os.path import exists
63 | require_fpath = fname
64 |
65 | def parse_line(line):
66 | """Parse information from a line in a requirements text file."""
67 | if line.startswith('-r '):
68 | # Allow specifying requirements in other files
69 | target = line.split(' ')[1]
70 | for info in parse_require_file(target):
71 | yield info
72 | else:
73 | info = {'line': line}
74 | if line.startswith('-e '):
75 | info['package'] = line.split('#egg=')[1]
76 | else:
77 | # Remove versioning from the package
78 | pat = '(' + '|'.join(['>=', '==', '>']) + ')'
79 | parts = re.split(pat, line, maxsplit=1)
80 | parts = [p.strip() for p in parts]
81 |
82 | info['package'] = parts[0]
83 | if len(parts) > 1:
84 | op, rest = parts[1:]
85 | if ';' in rest:
86 | # Handle platform specific dependencies
87 | # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
88 | version, platform_deps = map(str.strip,
89 | rest.split(';'))
90 | info['platform_deps'] = platform_deps
91 | else:
92 | version = rest # NOQA
93 | info['version'] = (op, version)
94 | yield info
95 |
96 | def parse_require_file(fpath):
97 | with open(fpath) as f:
98 | for line in f.readlines():
99 | line = line.strip()
100 | if line and not line.startswith('#'):
101 | yield from parse_line(line)
102 |
103 | def gen_packages_items():
104 | if exists(require_fpath):
105 | for info in parse_require_file(require_fpath):
106 | parts = [info['package']]
107 | if with_version and 'version' in info:
108 | parts.extend(info['version'])
109 | if not sys.version.startswith('3.4'):
110 | # apparently package_deps are broken in 3.4
111 | platform_deps = info.get('platform_deps')
112 | if platform_deps is not None:
113 | parts.append(';' + platform_deps)
114 | item = ''.join(parts)
115 | yield item
116 |
117 | packages = list(gen_packages_items())
118 | return packages
119 |
120 |
121 | install_requires = parse_requirements()
122 |
123 | try:
124 | # OpenCV installed via conda.
125 | import cv2 # NOQA: F401
126 | major, minor, *rest = cv2.__version__.split('.')
127 | if int(major) < 3:
128 | raise RuntimeError(
129 | f'OpenCV >=3 is required but {cv2.__version__} is installed')
130 | except ImportError:
131 | # If first not installed install second package
132 | CHOOSE_INSTALL_REQUIRES = [('opencv-python-headless>=3',
133 | 'opencv-python>=3')]
134 | for main, secondary in CHOOSE_INSTALL_REQUIRES:
135 | install_requires.append(choose_requirement(main, secondary))
136 |
137 |
138 | def get_extensions():
139 | extensions = []
140 |
141 | if os.getenv('MMCV_WITH_TRT', '0') != '0':
142 |
143 | # Following strings of text style are from colorama package
144 | bright_style, reset_style = '\x1b[1m', '\x1b[0m'
145 | red_text, blue_text = '\x1b[31m', '\x1b[34m'
146 | white_background = '\x1b[107m'
147 |
148 | msg = white_background + bright_style + red_text
149 | msg += 'DeprecationWarning: ' + \
150 | 'Custom TensorRT Ops will be deprecated in future. '
151 | msg += blue_text + \
152 | 'Welcome to use the unified model deployment toolbox '
153 | msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
154 | msg += reset_style
155 | warnings.warn(msg)
156 |
157 | ext_name = 'mmcv._ext_trt'
158 | from torch.utils.cpp_extension import include_paths, library_paths
159 | library_dirs = []
160 | libraries = []
161 | include_dirs = []
162 | tensorrt_path = os.getenv('TENSORRT_DIR', '0')
163 | tensorrt_lib_path = glob.glob(
164 | os.path.join(tensorrt_path, 'targets', '*', 'lib'))[0]
165 | library_dirs += [tensorrt_lib_path]
166 | libraries += ['nvinfer', 'nvparsers', 'nvinfer_plugin']
167 | libraries += ['cudart']
168 | define_macros = []
169 | extra_compile_args = {'cxx': []}
170 |
171 | include_path = os.path.abspath('./mmcv/ops/csrc/common/cuda')
172 | include_trt_path = os.path.abspath('./mmcv/ops/csrc/tensorrt')
173 | include_dirs.append(include_path)
174 | include_dirs.append(include_trt_path)
175 | include_dirs.append(os.path.join(tensorrt_path, 'include'))
176 | include_dirs += include_paths(cuda=True)
177 |
178 | op_files = glob.glob('./mmcv/ops/csrc/tensorrt/plugins/*')
179 | define_macros += [('MMCV_WITH_CUDA', None)]
180 | define_macros += [('MMCV_WITH_TRT', None)]
181 | cuda_args = os.getenv('MMCV_CUDA_ARGS')
182 | extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
183 | # prevent cub/thrust conflict with other python library
184 | # More context See issues #1454
185 | extra_compile_args['nvcc'] += ['-Xcompiler=-fno-gnu-unique']
186 | library_dirs += library_paths(cuda=True)
187 |
188 | from setuptools import Extension
189 | ext_ops = Extension(
190 | name=ext_name,
191 | sources=op_files,
192 | include_dirs=include_dirs,
193 | define_macros=define_macros,
194 | extra_compile_args=extra_compile_args,
195 | language='c++',
196 | library_dirs=library_dirs,
197 | libraries=libraries)
198 | extensions.append(ext_ops)
199 |
200 | if os.getenv('MMCV_WITH_OPS', '0') == '0':
201 | return extensions
202 |
203 | if EXT_TYPE == 'parrots':
204 | ext_name = 'mmcv._ext'
205 | from parrots.utils.build_extension import Extension
206 |
207 | # new parrots op impl do not use MMCV_USE_PARROTS
208 | # define_macros = [('MMCV_USE_PARROTS', None)]
209 | define_macros = []
210 | include_dirs = []
211 | op_files = glob.glob('./mmcv/ops/csrc/pytorch/cuda/*.cu') +\
212 | glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') +\
213 | glob.glob('./mmcv/ops/csrc/parrots/*.cpp')
214 | include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
215 | include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/cuda'))
216 | cuda_args = os.getenv('MMCV_CUDA_ARGS')
217 | extra_compile_args = {
218 | 'nvcc': [cuda_args, '-std=c++14'] if cuda_args else ['-std=c++14'],
219 | 'cxx': ['-std=c++14'],
220 | }
221 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
222 | define_macros += [('MMCV_WITH_CUDA', None)]
223 | extra_compile_args['nvcc'] += [
224 | '-D__CUDA_NO_HALF_OPERATORS__',
225 | '-D__CUDA_NO_HALF_CONVERSIONS__',
226 | '-D__CUDA_NO_HALF2_OPERATORS__',
227 | ]
228 | ext_ops = Extension(
229 | name=ext_name,
230 | sources=op_files,
231 | include_dirs=include_dirs,
232 | define_macros=define_macros,
233 | extra_compile_args=extra_compile_args,
234 | cuda=True,
235 | pytorch=True)
236 | extensions.append(ext_ops)
237 | elif EXT_TYPE == 'pytorch':
238 | ext_name = 'mmcv._ext'
239 | from torch.utils.cpp_extension import CppExtension, CUDAExtension
240 |
241 | # prevent ninja from using too many resources
242 | try:
243 | import psutil
244 | num_cpu = len(psutil.Process().cpu_affinity())
245 | cpu_use = max(4, num_cpu - 1)
246 | except (ModuleNotFoundError, AttributeError):
247 | cpu_use = 4
248 |
249 | os.environ.setdefault('MAX_JOBS', str(cpu_use))
250 | define_macros = []
251 |
252 | # Before PyTorch1.8.0, when compiling CUDA code, `cxx` is a
253 | # required key passed to PyTorch. Even if there is no flag passed
254 | # to cxx, users also need to pass an empty list to PyTorch.
255 | # Since PyTorch1.8.0, it has a default value so users do not need
256 | # to pass an empty list anymore.
257 | # More details at https://github.com/pytorch/pytorch/pull/45956
258 | extra_compile_args = {'cxx': []}
259 |
260 | # Since the PR (https://github.com/open-mmlab/mmcv/pull/1463) uses
261 | # c++14 features, the argument ['std=c++14'] must be added here.
262 | # However, in the windows environment, some standard libraries
263 | # will depend on c++17 or higher. In fact, for the windows
264 | # environment, the compiler will choose the appropriate compiler
265 | # to compile those cpp files, so there is no need to add the
266 | # argument
267 | if platform.system() != 'Windows':
268 | extra_compile_args['cxx'] = ['-std=c++14']
269 |
270 | include_dirs = []
271 |
272 | is_rocm_pytorch = False
273 | try:
274 | from torch.utils.cpp_extension import ROCM_HOME
275 | is_rocm_pytorch = True if ((torch.version.hip is not None) and
276 | (ROCM_HOME is not None)) else False
277 | except ImportError:
278 | pass
279 |
280 | if is_rocm_pytorch or torch.cuda.is_available() or os.getenv(
281 | 'FORCE_CUDA', '0') == '1':
282 | if is_rocm_pytorch:
283 | define_macros += [('MMCV_WITH_HIP', None)]
284 | define_macros += [('MMCV_WITH_CUDA', None)]
285 | cuda_args = os.getenv('MMCV_CUDA_ARGS')
286 | extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
287 | op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \
288 | glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \
289 | glob.glob('./mmcv/ops/csrc/pytorch/cuda/*.cu') + \
290 | glob.glob('./mmcv/ops/csrc/pytorch/cuda/*.cpp')
291 | extension = CUDAExtension
292 | include_dirs.append(os.path.abspath('./mmcv/ops/csrc/pytorch'))
293 | include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
294 | include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/cuda'))
295 | elif (hasattr(torch, 'is_mlu_available') and
296 | torch.is_mlu_available()) or \
297 | os.getenv('FORCE_MLU', '0') == '1':
298 | from torch_mlu.utils.cpp_extension import MLUExtension
299 | define_macros += [('MMCV_WITH_MLU', None)]
300 | mlu_args = os.getenv('MMCV_MLU_ARGS')
301 | extra_compile_args['cncc'] = [mlu_args] if mlu_args else []
302 | op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \
303 | glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \
304 | glob.glob('./mmcv/ops/csrc/pytorch/mlu/*.cpp') + \
305 | glob.glob('./mmcv/ops/csrc/common/mlu/*.mlu')
306 | extension = MLUExtension
307 | include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
308 | include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mlu'))
309 | elif (hasattr(torch.backends, 'mps')
310 | and torch.backends.mps.is_available()) or os.getenv(
311 | 'FORCE_MPS', '0') == '1':
312 | # objc compiler support
313 | from distutils.unixccompiler import UnixCCompiler
314 | if '.mm' not in UnixCCompiler.src_extensions:
315 | UnixCCompiler.src_extensions.append('.mm')
316 | UnixCCompiler.language_map['.mm'] = 'objc'
317 |
318 | define_macros += [('MMCV_WITH_MPS', None)]
319 | extra_compile_args = {}
320 | extra_compile_args['cxx'] = ['-Wall', '-std=c++17']
321 | extra_compile_args['cxx'] += [
322 | '-framework', 'Metal', '-framework', 'Foundation'
323 | ]
324 | extra_compile_args['cxx'] += ['-ObjC++']
325 | # src
326 | op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \
327 | glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \
328 | glob.glob('./mmcv/ops/csrc/common/mps/*.mm') + \
329 | glob.glob('./mmcv/ops/csrc/pytorch/mps/*.mm')
330 | extension = CppExtension
331 | include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
332 | include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mps'))
333 | elif (os.getenv('FORCE_NPU', '0') == '1'):
334 | print(f'Compiling {ext_name} only with CPU and NPU')
335 | try:
336 | from torch_npu.utils.cpp_extension import NpuExtension
337 | define_macros += [('MMCV_WITH_NPU', None)]
338 | extension = NpuExtension
339 | except Exception:
340 | raise ImportError('can not find any torch_npu')
341 | # src
342 | op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \
343 | glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \
344 | glob.glob('./mmcv/ops/csrc/common/npu/*.cpp') + \
345 | glob.glob('./mmcv/ops/csrc/pytorch/npu/*.cpp')
346 | include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
347 | include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/npu'))
348 | else:
349 | print(f'Compiling {ext_name} only with CPU')
350 | op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \
351 | glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp')
352 | extension = CppExtension
353 | include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
354 |
355 | # Since the PR (https://github.com/open-mmlab/mmcv/pull/1463) uses
356 | # c++14 features, the argument ['std=c++14'] must be added here.
357 | # However, in the windows environment, some standard libraries
358 | # will depend on c++17 or higher. In fact, for the windows
359 | # environment, the compiler will choose the appropriate compiler
360 | # to compile those cpp files, so there is no need to add the
361 | # argument
362 | if 'nvcc' in extra_compile_args and platform.system() != 'Windows':
363 | extra_compile_args['nvcc'] += ['-std=c++14']
364 |
365 | ext_ops = extension(
366 | name=ext_name,
367 | sources=op_files,
368 | include_dirs=include_dirs,
369 | define_macros=define_macros,
370 | extra_compile_args=extra_compile_args)
371 | extensions.append(ext_ops)
372 |
373 | if EXT_TYPE == 'pytorch' and os.getenv('MMCV_WITH_ORT', '0') != '0':
374 |
375 | # Following strings of text style are from colorama package
376 | bright_style, reset_style = '\x1b[1m', '\x1b[0m'
377 | red_text, blue_text = '\x1b[31m', '\x1b[34m'
378 | white_background = '\x1b[107m'
379 |
380 | msg = white_background + bright_style + red_text
381 | msg += 'DeprecationWarning: ' + \
382 | 'Custom ONNXRuntime Ops will be deprecated in future. '
383 | msg += blue_text + \
384 | 'Welcome to use the unified model deployment toolbox '
385 | msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
386 | msg += reset_style
387 | warnings.warn(msg)
388 | ext_name = 'mmcv._ext_ort'
389 | import onnxruntime
390 | from torch.utils.cpp_extension import include_paths, library_paths
391 | library_dirs = []
392 | libraries = []
393 | include_dirs = []
394 | ort_path = os.getenv('ONNXRUNTIME_DIR', '0')
395 | library_dirs += [os.path.join(ort_path, 'lib')]
396 | libraries.append('onnxruntime')
397 | define_macros = []
398 | extra_compile_args = {'cxx': []}
399 |
400 | include_path = os.path.abspath('./mmcv/ops/csrc/onnxruntime')
401 | include_dirs.append(include_path)
402 | include_dirs.append(os.path.join(ort_path, 'include'))
403 |
404 | op_files = glob.glob('./mmcv/ops/csrc/onnxruntime/cpu/*')
405 | if onnxruntime.get_device() == 'GPU' or os.getenv('FORCE_CUDA',
406 | '0') == '1':
407 | define_macros += [('MMCV_WITH_CUDA', None)]
408 | cuda_args = os.getenv('MMCV_CUDA_ARGS')
409 | extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
410 | op_files += glob.glob('./mmcv/ops/csrc/onnxruntime/gpu/*')
411 | include_dirs += include_paths(cuda=True)
412 | library_dirs += library_paths(cuda=True)
413 | else:
414 | include_dirs += include_paths(cuda=False)
415 | library_dirs += library_paths(cuda=False)
416 |
417 | from setuptools import Extension
418 | ext_ops = Extension(
419 | name=ext_name,
420 | sources=op_files,
421 | include_dirs=include_dirs,
422 | define_macros=define_macros,
423 | extra_compile_args=extra_compile_args,
424 | language='c++',
425 | library_dirs=library_dirs,
426 | libraries=libraries)
427 | extensions.append(ext_ops)
428 |
429 | return extensions
430 |
431 |
432 | setup(
433 | name='mmcv' if os.getenv('MMCV_WITH_OPS', '0') == '0' else 'mmcv-full',
434 | version=get_version(),
435 | description='OpenMMLab Computer Vision Foundation',
436 | keywords='computer vision',
437 | packages=find_packages(),
438 | include_package_data=True,
439 | classifiers=[
440 | 'Development Status :: 4 - Beta',
441 | 'License :: OSI Approved :: Apache Software License',
442 | 'Operating System :: OS Independent',
443 | 'Programming Language :: Python :: 3',
444 | 'Programming Language :: Python :: 3.6',
445 | 'Programming Language :: Python :: 3.7',
446 | 'Programming Language :: Python :: 3.8',
447 | 'Programming Language :: Python :: 3.9',
448 | 'Programming Language :: Python :: 3.10',
449 | 'Topic :: Utilities',
450 | ],
451 | url='https://github.com/open-mmlab/mmcv',
452 | author='MMCV Contributors',
453 | author_email='openmmlab@gmail.com',
454 | install_requires=install_requires,
455 | extras_require={
456 | 'all': parse_requirements('requirements.txt'),
457 | 'tests': parse_requirements('requirements/test.txt'),
458 | 'build': parse_requirements('requirements/build.txt'),
459 | 'optional': parse_requirements('requirements/optional.txt'),
460 | },
461 | ext_modules=get_extensions(),
462 | cmdclass=cmd_class,
463 | zip_safe=False)
464 |
--------------------------------------------------------------------------------
/orbbec/__init__.py:
--------------------------------------------------------------------------------
1 | from .warpaffine import affine_opencv, affine_torch
2 | from .nms import batched_nms, nms, soft_nms
3 | from .roi_align import RoIAlign, roi_align
4 | from .utils import get_compiler_version, get_compiling_cuda_version
5 |
6 | __all__ = [
7 | 'nms', 'soft_nms', 'batched_nms',
8 | 'RoIAlign', 'roi_align',
9 | 'get_compiler_version', 'get_compiling_cuda_version',
10 | 'affine_opencv','affine_torch'
11 | ]
12 |
--------------------------------------------------------------------------------
/orbbec/nms/__init__.py:
--------------------------------------------------------------------------------
1 | from .nms_wrapper import batched_nms, nms, soft_nms
2 |
3 | __all__ = ['nms', 'soft_nms', 'batched_nms']
4 |
--------------------------------------------------------------------------------
/orbbec/nms/nms_wrapper.py:
--------------------------------------------------------------------------------
1 | from re import I
2 | import numpy as np
3 | import torch
4 |
5 | from . import nms_ext
6 |
7 |
8 | def nms(dets, iou_thr, device_id=None):
9 | """Dispatch to either CPU or GPU NMS implementations.
10 |
11 | The input can be either a torch tensor or numpy array. GPU NMS will be used
12 | if the input is a gpu tensor or device_id is specified, otherwise CPU NMS
13 | will be used. The returned type will always be the same as inputs.
14 |
15 | Arguments:
16 | dets (torch.Tensor or np.ndarray): bboxes with scores.
17 | iou_thr (float): IoU threshold for NMS.
18 | device_id (int, optional): when `dets` is a numpy array, if `device_id`
19 | is None, then cpu nms is used, otherwise gpu_nms will be used.
20 |
21 | Returns:
22 | tuple: kept bboxes and indices, which are always the same data type as
23 | the input.
24 |
25 | Example:
26 | >>> dets = np.array([[49.1, 32.4, 51.0, 35.9, 0.9],
27 | >>> [49.3, 32.9, 51.0, 35.3, 0.9],
28 | >>> [49.2, 31.8, 51.0, 35.4, 0.5],
29 | >>> [35.1, 11.5, 39.1, 15.7, 0.5],
30 | >>> [35.6, 11.8, 39.3, 14.2, 0.5],
31 | >>> [35.3, 11.5, 39.9, 14.5, 0.4],
32 | >>> [35.2, 11.7, 39.7, 15.7, 0.3]], dtype=np.float32)
33 | >>> iou_thr = 0.6
34 | >>> suppressed, inds = nms(dets, iou_thr)
35 | >>> assert len(inds) == len(suppressed) == 3
36 | """
37 | # convert dets (tensor or numpy array) to tensor
38 | if isinstance(dets, torch.Tensor):
39 | is_numpy = False
40 | dets_th = dets
41 | elif isinstance(dets, np.ndarray):
42 | is_numpy = True
43 | device = 'cpu' if device_id is None else f'cuda:{device_id}'
44 | dets_th = torch.from_numpy(dets).to(device)
45 | else:
46 | raise TypeError('dets must be either a Tensor or numpy array, '
47 | f'but got {type(dets)}')
48 |
49 | # execute cpu or cuda nms
50 | if dets_th.shape[0] == 0:
51 | inds = dets_th.new_zeros(0, dtype=torch.long)
52 | else:
53 |         # nms_ext.nms dispatches to the CPU or CUDA implementation based
54 |         # on the device of dets_th, so a single call covers both the
55 |         # CPU-tensor and CUDA-tensor cases
56 |         inds = nms_ext.nms(dets_th, iou_thr)
57 |
58 | if is_numpy:
59 | inds = inds.cpu().numpy()
60 | return dets[inds, :], inds
61 |
62 |
63 | def soft_nms(dets, iou_thr, method='linear', sigma=0.5, min_score=1e-3):
64 | """Dispatch to only CPU Soft NMS implementations.
65 |
66 | The input can be either a torch tensor or numpy array.
67 | The returned type will always be the same as inputs.
68 |
69 | Arguments:
70 | dets (torch.Tensor or np.ndarray): bboxes with scores.
71 | iou_thr (float): IoU threshold for Soft NMS.
72 | method (str): either 'linear' or 'gaussian'
73 | sigma (float): hyperparameter for gaussian method
74 | min_score (float): score filter threshold
75 |
76 | Returns:
77 |         tuple: new det bboxes and indices, which are always the same
78 | data type as the input.
79 |
80 | Example:
81 | >>> dets = np.array([[4., 3., 5., 3., 0.9],
82 | >>> [4., 3., 5., 4., 0.9],
83 | >>> [3., 1., 3., 1., 0.5],
84 | >>> [3., 1., 3., 1., 0.5],
85 | >>> [3., 1., 3., 1., 0.4],
86 | >>> [3., 1., 3., 1., 0.0]], dtype=np.float32)
87 | >>> iou_thr = 0.6
88 | >>> new_dets, inds = soft_nms(dets, iou_thr, sigma=0.5)
89 | >>> assert len(inds) == len(new_dets) == 5
90 | """
91 | # convert dets (tensor or numpy array) to tensor
92 | if isinstance(dets, torch.Tensor):
93 | is_tensor = True
94 | dets_t = dets.detach().cpu()
95 | elif isinstance(dets, np.ndarray):
96 | is_tensor = False
97 | dets_t = torch.from_numpy(dets)
98 | else:
99 | raise TypeError('dets must be either a Tensor or numpy array, '
100 | f'but got {type(dets)}')
101 |
102 | method_codes = {'linear': 1, 'gaussian': 2}
103 | if method not in method_codes:
104 | raise ValueError(f'Invalid method for SoftNMS: {method}')
105 | results = nms_ext.soft_nms(dets_t, iou_thr, method_codes[method], sigma,
106 | min_score)
107 |
108 | new_dets = results[:, :5]
109 | inds = results[:, 5]
110 |
111 | if is_tensor:
112 | return new_dets.to(
113 | device=dets.device, dtype=dets.dtype), inds.to(
114 | device=dets.device, dtype=torch.long)
115 | else:
116 | return new_dets.numpy().astype(dets.dtype), inds.numpy().astype(
117 | np.int64)
118 |
119 |
120 | def batched_nms(bboxes, scores, inds, nms_cfg):
121 | """Performs non-maximum suppression in a batched fashion.
122 |
123 | Modified from https://github.com/pytorch/vision/blob
124 | /505cd6957711af790211896d32b40291bea1bc21/torchvision/ops/boxes.py#L39.
125 | In order to perform NMS independently per class, we add an offset to all
126 | the boxes. The offset is dependent only on the class idx, and is large
127 | enough so that boxes from different classes do not overlap.
128 |
129 | Arguments:
130 | bboxes (torch.Tensor): bboxes in shape (N, 4).
131 | scores (torch.Tensor): scores in shape (N, ).
132 |         inds (torch.Tensor): each index value corresponds to a bbox cluster,
133 | and NMS will not be applied between elements of different inds,
134 | shape (N, ).
135 | nms_cfg (dict): specify nms type and other parameters like iou_thr.
136 |
137 | Returns:
138 |         tuple: kept bboxes and indices.
139 | """
140 | max_coordinate = bboxes.max()
141 | offsets = inds.to(bboxes) * (max_coordinate + 1)
142 | bboxes_for_nms = bboxes + offsets[:, None]
143 | nms_cfg_ = nms_cfg.copy()
144 | nms_type = nms_cfg_.pop('type', 'nms')
145 | nms_op = eval(nms_type)
146 | dets, keep = nms_op(
147 | torch.cat([bboxes_for_nms, scores[:, None]], -1), **nms_cfg_)
148 | bboxes = bboxes[keep]
149 | scores = dets[:, -1]
150 | return torch.cat([bboxes, scores[:, None]], -1), keep
151 |
--------------------------------------------------------------------------------
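
A minimal usage sketch of the three wrappers above, assuming the `nms_ext` extension has been built and the `orbbec` package is importable (the box values are made up for illustration):

    import numpy as np
    import torch

    from orbbec.nms import nms, soft_nms, batched_nms

    # (x1, y1, x2, y2, score) boxes; arbitrary example data
    dets = np.array([[49.1, 32.4, 51.0, 35.9, 0.9],
                     [49.3, 32.9, 51.0, 35.3, 0.9],
                     [35.1, 11.5, 39.1, 15.7, 0.5]], dtype=np.float32)

    # numpy in -> numpy out; pass device_id=0 to run the CUDA kernel instead
    kept, inds = nms(dets, iou_thr=0.6)

    # soft-NMS only has a CPU implementation; scores are decayed, not dropped
    new_dets, inds = soft_nms(dets, iou_thr=0.6, method='gaussian', sigma=0.5)

    # batched_nms offsets boxes by class id so NMS never mixes classes
    bboxes = torch.from_numpy(dets[:, :4])
    scores = torch.from_numpy(dets[:, 4])
    class_ids = torch.tensor([0, 0, 1])
    dets_out, keep = batched_nms(bboxes, scores, class_ids,
                                 dict(type='nms', iou_thr=0.6))
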
/orbbec/nms/src/cpu/nms_cpu.cpp:
--------------------------------------------------------------------------------
1 | // Soft-NMS is added by MMDetection.
2 | // Modified from
3 | // https://github.com/bharatsingh430/soft-nms/blob/master/lib/nms/cpu_nms.pyx.
4 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
5 | #include <torch/extension.h>
6 |
7 | template <typename scalar_t>
8 | at::Tensor nms_cpu_kernel(const at::Tensor& dets, const float threshold) {
9 | AT_ASSERTM(!dets.device().is_cuda(), "dets must be a CPU tensor");
10 |
11 | if (dets.numel() == 0) {
12 | return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
13 | }
14 |
15 | auto x1_t = dets.select(1, 0).contiguous();
16 | auto y1_t = dets.select(1, 1).contiguous();
17 | auto x2_t = dets.select(1, 2).contiguous();
18 | auto y2_t = dets.select(1, 3).contiguous();
19 | auto scores = dets.select(1, 4).contiguous();
20 |
21 | at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t);
22 |
23 | auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
24 |
25 | auto ndets = dets.size(0);
26 | at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte).device(at::kCPU));
27 |
28 |   auto suppressed = suppressed_t.data_ptr<uint8_t>();
29 |   auto order = order_t.data_ptr<int64_t>();
30 |   auto x1 = x1_t.data_ptr<scalar_t>();
31 |   auto y1 = y1_t.data_ptr<scalar_t>();
32 |   auto x2 = x2_t.data_ptr<scalar_t>();
33 |   auto y2 = y2_t.data_ptr<scalar_t>();
34 |   auto areas = areas_t.data_ptr<scalar_t>();
35 |
36 | for (int64_t _i = 0; _i < ndets; _i++) {
37 | auto i = order[_i];
38 | if (suppressed[i] == 1) continue;
39 | auto ix1 = x1[i];
40 | auto iy1 = y1[i];
41 | auto ix2 = x2[i];
42 | auto iy2 = y2[i];
43 | auto iarea = areas[i];
44 |
45 | for (int64_t _j = _i + 1; _j < ndets; _j++) {
46 | auto j = order[_j];
47 | if (suppressed[j] == 1) continue;
48 | auto xx1 = std::max(ix1, x1[j]);
49 | auto yy1 = std::max(iy1, y1[j]);
50 | auto xx2 = std::min(ix2, x2[j]);
51 | auto yy2 = std::min(iy2, y2[j]);
52 |
53 |       auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1);
54 |       auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1);
55 | auto inter = w * h;
56 | auto ovr = inter / (iarea + areas[j] - inter);
57 | if (ovr >= threshold) suppressed[j] = 1;
58 | }
59 | }
60 | return at::nonzero(suppressed_t == 0).squeeze(1);
61 | }
62 |
63 | at::Tensor nms_cpu(const at::Tensor& dets, const float threshold) {
64 | at::Tensor result;
65 | AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms", [&] {
66 |     result = nms_cpu_kernel<scalar_t>(dets, threshold);
67 | });
68 | return result;
69 | }
70 |
71 | template <typename scalar_t>
72 | at::Tensor soft_nms_cpu_kernel(const at::Tensor& dets, const float threshold,
73 | const unsigned char method, const float sigma,
74 | const float min_score) {
75 | AT_ASSERTM(!dets.device().is_cuda(), "dets must be a CPU tensor");
76 |
77 | if (dets.numel() == 0) {
78 | return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
79 | }
80 |
81 | auto x1_t = dets.select(1, 0).contiguous();
82 | auto y1_t = dets.select(1, 1).contiguous();
83 | auto x2_t = dets.select(1, 2).contiguous();
84 | auto y2_t = dets.select(1, 3).contiguous();
85 | auto scores_t = dets.select(1, 4).contiguous();
86 |
87 | at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t);
88 |
89 | auto ndets = dets.size(0);
90 |   auto x1 = x1_t.data_ptr<scalar_t>();
91 |   auto y1 = y1_t.data_ptr<scalar_t>();
92 |   auto x2 = x2_t.data_ptr<scalar_t>();
93 |   auto y2 = y2_t.data_ptr<scalar_t>();
94 |   auto scores = scores_t.data_ptr<scalar_t>();
95 |   auto areas = areas_t.data_ptr<scalar_t>();
96 |
97 | int64_t pos = 0;
98 | at::Tensor inds_t = at::arange(ndets, dets.options());
99 |   auto inds = inds_t.data_ptr<scalar_t>();
100 |
101 | for (int64_t i = 0; i < ndets; i++) {
102 | auto max_score = scores[i];
103 | auto max_pos = i;
104 |
105 | auto ix1 = x1[i];
106 | auto iy1 = y1[i];
107 | auto ix2 = x2[i];
108 | auto iy2 = y2[i];
109 | auto iscore = scores[i];
110 | auto iarea = areas[i];
111 | auto iind = inds[i];
112 |
113 | pos = i + 1;
114 | // get max box
115 | while (pos < ndets) {
116 | if (max_score < scores[pos]) {
117 | max_score = scores[pos];
118 | max_pos = pos;
119 | }
120 | pos = pos + 1;
121 | }
122 | // add max box as a detection
123 | x1[i] = x1[max_pos];
124 | y1[i] = y1[max_pos];
125 | x2[i] = x2[max_pos];
126 | y2[i] = y2[max_pos];
127 | scores[i] = scores[max_pos];
128 | areas[i] = areas[max_pos];
129 | inds[i] = inds[max_pos];
130 |
131 | // swap ith box with position of max box
132 | x1[max_pos] = ix1;
133 | y1[max_pos] = iy1;
134 | x2[max_pos] = ix2;
135 | y2[max_pos] = iy2;
136 | scores[max_pos] = iscore;
137 | areas[max_pos] = iarea;
138 | inds[max_pos] = iind;
139 |
140 | ix1 = x1[i];
141 | iy1 = y1[i];
142 | ix2 = x2[i];
143 | iy2 = y2[i];
144 | iscore = scores[i];
145 | iarea = areas[i];
146 |
147 | pos = i + 1;
148 | // NMS iterations, note that N changes if detection boxes fall below
149 | // threshold
150 | while (pos < ndets) {
151 | auto xx1 = std::max(ix1, x1[pos]);
152 | auto yy1 = std::max(iy1, y1[pos]);
153 | auto xx2 = std::min(ix2, x2[pos]);
154 | auto yy2 = std::min(iy2, y2[pos]);
155 |
156 |       auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1);
157 |       auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1);
158 | auto inter = w * h;
159 | auto ovr = inter / (iarea + areas[pos] - inter);
160 |
161 | scalar_t weight = 1.;
162 | if (method == 1) {
163 | if (ovr > threshold) weight = 1 - ovr;
164 | } else if (method == 2) {
165 | weight = std::exp(-(ovr * ovr) / sigma);
166 | } else {
167 | // original NMS
168 | if (ovr > threshold) {
169 | weight = 0;
170 | } else {
171 | weight = 1;
172 | }
173 | }
174 | scores[pos] = weight * scores[pos];
175 | // if box score falls below threshold, discard the box by
176 | // swapping with last box update N
177 | if (scores[pos] < min_score) {
178 | x1[pos] = x1[ndets - 1];
179 | y1[pos] = y1[ndets - 1];
180 | x2[pos] = x2[ndets - 1];
181 | y2[pos] = y2[ndets - 1];
182 | scores[pos] = scores[ndets - 1];
183 | areas[pos] = areas[ndets - 1];
184 | inds[pos] = inds[ndets - 1];
185 | ndets = ndets - 1;
186 | pos = pos - 1;
187 | }
188 | pos = pos + 1;
189 | }
190 | }
191 | at::Tensor result = at::zeros({6, ndets}, dets.options());
192 | result[0] = x1_t.slice(0, 0, ndets);
193 | result[1] = y1_t.slice(0, 0, ndets);
194 | result[2] = x2_t.slice(0, 0, ndets);
195 | result[3] = y2_t.slice(0, 0, ndets);
196 | result[4] = scores_t.slice(0, 0, ndets);
197 | result[5] = inds_t.slice(0, 0, ndets);
198 |
199 | result = result.t().contiguous();
200 | return result;
201 | }
202 |
203 | at::Tensor soft_nms_cpu(const at::Tensor& dets, const float threshold,
204 | const unsigned char method, const float sigma,
205 | const float min_score) {
206 | at::Tensor result;
207 | AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "soft_nms", [&] {
208 |     result = soft_nms_cpu_kernel<scalar_t>(dets, threshold, method, sigma,
209 | min_score);
210 | });
211 | return result;
212 | }
213 |
--------------------------------------------------------------------------------
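
The `method` code passed down from `nms_wrapper.py` (1 = linear, 2 = gaussian, anything else = hard NMS) selects the score decay applied in the loop above. A small NumPy sketch of just that weighting rule, for intuition rather than as a drop-in replacement for the C++ kernel:

    import numpy as np

    def soft_nms_weight(iou, iou_thr=0.3, method='gaussian', sigma=0.5):
        """Factor a neighbouring box's score is multiplied by."""
        if method == 'linear':          # method code 1 in the C++ kernel
            return 1.0 - iou if iou > iou_thr else 1.0
        if method == 'gaussian':        # method code 2
            return float(np.exp(-(iou * iou) / sigma))
        # hard NMS: suppress completely once the overlap exceeds the threshold
        return 0.0 if iou > iou_thr else 1.0

    # a heavily overlapping neighbour keeps only a fraction of its score
    print(soft_nms_weight(0.8, method='linear'))    # 0.2
    print(soft_nms_weight(0.8, method='gaussian'))  # ~0.28
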
/orbbec/nms/src/cuda/nms_cuda.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | #include <torch/extension.h>
3 |
4 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDA tensor ")
5 |
6 | at::Tensor nms_cuda_forward(const at::Tensor boxes, float nms_overlap_thresh);
7 |
8 | at::Tensor nms_cuda(const at::Tensor& dets, const float threshold) {
9 | CHECK_CUDA(dets);
10 | if (dets.numel() == 0)
11 | return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
12 | return nms_cuda_forward(dets, threshold);
13 | }
14 |
--------------------------------------------------------------------------------
/orbbec/nms/src/cuda/nms_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | // /* pytorch: 1.5.0 ~ 1.10.x */
3 | #include <ATen/ATen.h>
4 | #include <ATen/cuda/CUDAContext.h>
5 | #include <c10/cuda/CUDAGuard.h>
6 |
7 | #include <THC/THC.h>
8 | #include <THC/THCDeviceUtils.cuh>
9 | // /* ---------------------- */
10 |
11 | // /* pytorch: 1.11.0 ~ latest */
12 | // #include
13 | // #include
14 | // #include
15 | // #include
16 |
17 | // #include
18 | // #include
19 |
20 | #include <vector>
21 | #include <iostream>
22 |
23 | int const threadsPerBlock = sizeof(unsigned long long) * 8;
24 |
25 | __device__ inline float devIoU(float const * const a, float const * const b) {
26 | float left = max(a[0], b[0]), right = min(a[2], b[2]);
27 | float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
28 | float width = max(right - left, 0.f), height = max(bottom - top, 0.f);
29 | float interS = width * height;
30 | float Sa = (a[2] - a[0]) * (a[3] - a[1]);
31 | float Sb = (b[2] - b[0]) * (b[3] - b[1]);
32 | return interS / (Sa + Sb - interS);
33 | }
34 |
35 | __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
36 | const float *dev_boxes, unsigned long long *dev_mask) {
37 | const int row_start = blockIdx.y;
38 | const int col_start = blockIdx.x;
39 |
40 | // if (row_start > col_start) return;
41 |
42 | const int row_size =
43 | min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
44 | const int col_size =
45 | min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
46 |
47 | __shared__ float block_boxes[threadsPerBlock * 5];
48 | if (threadIdx.x < col_size) {
49 | block_boxes[threadIdx.x * 5 + 0] =
50 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
51 | block_boxes[threadIdx.x * 5 + 1] =
52 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
53 | block_boxes[threadIdx.x * 5 + 2] =
54 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
55 | block_boxes[threadIdx.x * 5 + 3] =
56 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
57 | block_boxes[threadIdx.x * 5 + 4] =
58 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
59 | }
60 | __syncthreads();
61 |
62 | if (threadIdx.x < row_size) {
63 | const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
64 | const float *cur_box = dev_boxes + cur_box_idx * 5;
65 | int i = 0;
66 | unsigned long long t = 0;
67 | int start = 0;
68 | if (row_start == col_start) {
69 | start = threadIdx.x + 1;
70 | }
71 | for (i = start; i < col_size; i++) {
72 | if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
73 | t |= 1ULL << i;
74 | }
75 | }
76 | const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock);
77 | dev_mask[cur_box_idx * col_blocks + col_start] = t;
78 | }
79 | }
80 |
81 | // boxes is a N x 5 tensor
82 | at::Tensor nms_cuda_forward(const at::Tensor boxes, float nms_overlap_thresh) {
83 |
84 | // Ensure CUDA uses the input tensor device.
85 | at::DeviceGuard guard(boxes.device());
86 |
87 | using scalar_t = float;
88 | AT_ASSERTM(boxes.device().is_cuda(), "boxes must be a CUDA tensor");
89 | auto scores = boxes.select(1, 4);
90 | auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
91 | auto boxes_sorted = boxes.index_select(0, order_t);
92 |
93 | int boxes_num = boxes.size(0);
94 |
95 | const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);
96 |
97 |   scalar_t* boxes_dev = boxes_sorted.data_ptr<scalar_t>();
98 |
99 | THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState
100 |
101 | unsigned long long* mask_dev = NULL;
102 | //THCudaCheck(THCudaMalloc(state, (void**) &mask_dev,
103 | // boxes_num * col_blocks * sizeof(unsigned long long)));
104 |
105 | mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));
106 |
107 | dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
108 | THCCeilDiv(boxes_num, threadsPerBlock));
109 | dim3 threads(threadsPerBlock);
110 |   nms_kernel<<<blocks, threads>>>(boxes_num,
111 | nms_overlap_thresh,
112 | boxes_dev,
113 | mask_dev);
114 |
115 |   std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
116 | THCudaCheck(cudaMemcpyAsync(
117 | &mask_host[0],
118 | mask_dev,
119 | sizeof(unsigned long long) * boxes_num * col_blocks,
120 | cudaMemcpyDeviceToHost,
121 | at::cuda::getCurrentCUDAStream()
122 | ));
123 |
124 |   std::vector<unsigned long long> remv(col_blocks);
125 | memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
126 |
127 | at::Tensor keep = at::empty({boxes_num}, boxes.options().dtype(at::kLong).device(at::kCPU));
128 | int64_t* keep_out = keep.data_ptr();
129 |
130 | int num_to_keep = 0;
131 | for (int i = 0; i < boxes_num; i++) {
132 | int nblock = i / threadsPerBlock;
133 | int inblock = i % threadsPerBlock;
134 |
135 | if (!(remv[nblock] & (1ULL << inblock))) {
136 | keep_out[num_to_keep++] = i;
137 | unsigned long long *p = &mask_host[0] + i * col_blocks;
138 | for (int j = nblock; j < col_blocks; j++) {
139 | remv[j] |= p[j];
140 | }
141 | }
142 | }
143 |
144 | THCudaFree(state, mask_dev);
145 | // TODO improve this part
146 | return order_t.index({
147 | keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
148 | order_t.device(), keep.scalar_type())});
149 | }
150 |
--------------------------------------------------------------------------------
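
The kernel above writes, for every box, one 64-bit word per column block marking which later (lower-scored) boxes it overlaps beyond the threshold; the host loop then walks boxes in score order and accumulates those masks. A pure-Python sketch of that host-side reduction, with the bit-packed masks replaced by a plain boolean matrix, to make the bookkeeping explicit:

    import numpy as np

    def reduce_overlap_masks(overlaps):
        """overlaps[i, j] is True when sorted box i suppresses sorted box j (j > i).

        Mirrors the remv / keep_out loop in nms_cuda_forward, minus the
        64-boxes-per-word bit packing.
        """
        n = overlaps.shape[0]
        suppressed = np.zeros(n, dtype=bool)   # plays the role of the remv bitmask
        keep = []
        for i in range(n):                     # boxes are already sorted by score
            if suppressed[i]:
                continue
            keep.append(i)
            suppressed |= overlaps[i]          # OR this box's mask into the running mask
        return keep

    # box 0 suppresses box 1; box 2 survives
    overlaps = np.array([[False, True, False],
                         [False, False, False],
                         [False, False, False]])
    print(reduce_overlap_masks(overlaps))      # [0, 2]
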
/orbbec/nms/src/nms_ext.cpp:
--------------------------------------------------------------------------------
1 | // Modified from https://github.com/bharatsingh430/soft-nms/blob/master/lib/nms/cpu_nms.pyx, Soft-NMS is added
2 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 | #include <torch/extension.h>
4 |
5 | at::Tensor nms_cpu(const at::Tensor& dets, const float threshold);
6 |
7 | at::Tensor soft_nms_cpu(const at::Tensor& dets, const float threshold,
8 | const unsigned char method, const float sigma, const
9 | float min_score);
10 |
11 | #ifdef WITH_CUDA
12 | at::Tensor nms_cuda(const at::Tensor& dets, const float threshold);
13 | #endif
14 |
15 | at::Tensor nms(const at::Tensor& dets, const float threshold){
16 | if (dets.device().is_cuda()) {
17 | #ifdef WITH_CUDA
18 | return nms_cuda(dets, threshold);
19 | #else
20 | AT_ERROR("nms is not compiled with GPU support");
21 | #endif
22 | }
23 | return nms_cpu(dets, threshold);
24 | }
25 |
26 | at::Tensor soft_nms(const at::Tensor& dets, const float threshold,
27 | const unsigned char method, const float sigma, const
28 | float min_score) {
29 | if (dets.device().is_cuda()) {
30 | AT_ERROR("soft_nms is not implemented on GPU");
31 | }
32 | return soft_nms_cpu(dets, threshold, method, sigma, min_score);
33 | }
34 |
35 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
36 | m.def("nms", &nms, "non-maximum suppression");
37 | m.def("soft_nms", &soft_nms, "soft non-maximum suppression");
38 | }
39 |
--------------------------------------------------------------------------------
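
A minimal sketch of how a binding file like this can be compiled into the `nms_ext` module with torch.utils.cpp_extension. This is a hypothetical, stripped-down setup file for illustration only; the project's real setup.py also builds the other ops and wires in the OpenCV dependencies. Paths follow this repo's layout:

    # hypothetical minimal setup file, not the project's actual setup.py
    import torch
    from setuptools import setup
    from torch.utils.cpp_extension import (BuildExtension, CppExtension,
                                           CUDAExtension)

    def make_nms_ext():
        sources = ['orbbec/nms/src/nms_ext.cpp', 'orbbec/nms/src/cpu/nms_cpu.cpp']
        if torch.cuda.is_available():
            # WITH_CUDA gates the nms_cuda() declaration in nms_ext.cpp
            return CUDAExtension(
                name='orbbec.nms.nms_ext',
                sources=sources + ['orbbec/nms/src/cuda/nms_cuda.cpp',
                                   'orbbec/nms/src/cuda/nms_kernel.cu'],
                define_macros=[('WITH_CUDA', None)])
        return CppExtension(name='orbbec.nms.nms_ext', sources=sources)

    setup(name='orbbec-nms-demo',
          ext_modules=[make_nms_ext()],
          cmdclass={'build_ext': BuildExtension})
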
/orbbec/roi_align/__init__.py:
--------------------------------------------------------------------------------
1 | from .roi_align import RoIAlign, roi_align
2 |
3 | __all__ = ['roi_align', 'RoIAlign']
4 |
--------------------------------------------------------------------------------
/orbbec/roi_align/gradcheck.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import sys
3 |
4 | import numpy as np
5 | import torch
6 | from torch.autograd import gradcheck
7 |
8 | sys.path.append(osp.abspath(osp.join(__file__, '../../')))
9 | from roi_align import RoIAlign # noqa: E402, isort:skip
10 |
11 | feat_size = 15
12 | spatial_scale = 1.0 / 8
13 | img_size = feat_size / spatial_scale
14 | num_imgs = 2
15 | num_rois = 20
16 |
17 | batch_ind = np.random.randint(num_imgs, size=(num_rois, 1))
18 | rois = np.random.rand(num_rois, 4) * img_size * 0.5
19 | rois[:, 2:] += img_size * 0.5
20 | rois = np.hstack((batch_ind, rois))
21 |
22 | feat = torch.randn(
23 | num_imgs, 16, feat_size, feat_size, requires_grad=True, device='cuda:0')
24 | rois = torch.from_numpy(rois).float().cuda()
25 | inputs = (feat, rois)
26 | print('Gradcheck for roi align...')
27 | test = gradcheck(RoIAlign(3, spatial_scale), inputs, atol=1e-3, eps=1e-3)
28 | print(test)
29 | test = gradcheck(RoIAlign(3, spatial_scale, 2), inputs, atol=1e-3, eps=1e-3)
30 | print(test)
31 |
--------------------------------------------------------------------------------
/orbbec/roi_align/roi_align.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.autograd import Function
3 | from torch.autograd.function import once_differentiable
4 | from torch.nn.modules.utils import _pair
5 |
6 | from . import roi_align_ext
7 |
8 |
9 | class RoIAlignFunction(Function):
10 |
11 | @staticmethod
12 | def forward(ctx,
13 | features,
14 | rois,
15 | out_size,
16 | spatial_scale,
17 | sample_num=0,
18 | aligned=True):
19 | out_h, out_w = _pair(out_size)
20 | assert isinstance(out_h, int) and isinstance(out_w, int)
21 | ctx.spatial_scale = spatial_scale
22 | ctx.sample_num = sample_num
23 | ctx.save_for_backward(rois)
24 | ctx.feature_size = features.size()
25 | ctx.aligned = aligned
26 |
27 | if aligned:
28 | output = roi_align_ext.forward_v2(features, rois, spatial_scale,
29 | out_h, out_w, sample_num,
30 | aligned)
31 | elif features.is_cuda:
32 | (batch_size, num_channels, data_height,
33 | data_width) = features.size()
34 | num_rois = rois.size(0)
35 |
36 | output = features.new_zeros(num_rois, num_channels, out_h, out_w)
37 | roi_align_ext.forward_v1(features, rois, out_h, out_w,
38 | spatial_scale, sample_num, output)
39 | else:
40 | raise NotImplementedError
41 |
42 | return output
43 |
44 | @staticmethod
45 | @once_differentiable
46 | def backward(ctx, grad_output):
47 | feature_size = ctx.feature_size
48 | spatial_scale = ctx.spatial_scale
49 | sample_num = ctx.sample_num
50 | rois = ctx.saved_tensors[0]
51 | aligned = ctx.aligned
52 | assert feature_size is not None
53 |
54 | batch_size, num_channels, data_height, data_width = feature_size
55 | out_w = grad_output.size(3)
56 | out_h = grad_output.size(2)
57 |
58 | grad_input = grad_rois = None
59 | if not aligned:
60 | if ctx.needs_input_grad[0]:
61 | grad_input = rois.new_zeros(batch_size, num_channels,
62 | data_height, data_width)
63 | roi_align_ext.backward_v1(grad_output.contiguous(), rois,
64 | out_h, out_w, spatial_scale,
65 | sample_num, grad_input)
66 | else:
67 | grad_input = roi_align_ext.backward_v2(grad_output, rois,
68 | spatial_scale, out_h, out_w,
69 | batch_size, num_channels,
70 | data_height, data_width,
71 | sample_num, aligned)
72 |
73 | return grad_input, grad_rois, None, None, None, None
74 |
75 |
76 | roi_align = RoIAlignFunction.apply
77 |
78 |
79 | class RoIAlign(nn.Module):
80 |
81 | def __init__(self,
82 | out_size,
83 | spatial_scale,
84 | sample_num=0,
85 | use_torchvision=False,
86 | aligned=True):
87 | """
88 | Args:
89 | out_size (tuple): h, w
90 | spatial_scale (float): scale the input boxes by this number
91 |             sample_num (int): number of input samples to take for each
92 | output sample. 2 to take samples densely for current models.
93 | use_torchvision (bool): whether to use roi_align from torchvision
94 | aligned (bool): if False, use the legacy implementation in
95 | MMDetection. If True, align the results more perfectly.
96 |
97 | Note:
98 | The implementation of RoIAlign when aligned=True is modified from
99 | https://github.com/facebookresearch/detectron2/
100 |
101 | The meaning of aligned=True:
102 |
103 | Given a continuous coordinate c, its two neighboring pixel
104 | indices (in our pixel model) are computed by floor(c - 0.5) and
105 | ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete
106 | indices [0] and [1] (which are sampled from the underlying signal
107 | at continuous coordinates 0.5 and 1.5). But the original roi_align
108 | (aligned=False) does not subtract the 0.5 when computing
109 | neighboring pixel indices and therefore it uses pixels with a
110 | slightly incorrect alignment (relative to our pixel model) when
111 | performing bilinear interpolation.
112 |
113 | With `aligned=True`,
114 | we first appropriately scale the ROI and then shift it by -0.5
115 | prior to calling roi_align. This produces the correct neighbors;
116 |
117 |             In practice this difference does not affect the model's
118 |             performance if RoIAlign is used together with conv layers.
119 | """
120 | super(RoIAlign, self).__init__()
121 | self.out_size = _pair(out_size)
122 | self.spatial_scale = float(spatial_scale)
123 | self.aligned = aligned
124 | self.sample_num = int(sample_num)
125 | self.use_torchvision = use_torchvision
126 | assert not (use_torchvision and
127 |                     aligned), 'Torchvision does not support aligned RoIAlign'
128 |
129 | def forward(self, features, rois):
130 | """
131 | Args:
132 | features: NCHW images
133 | rois: Bx5 boxes. First column is the index into N. The other 4
134 | columns are xyxy.
135 | """
136 | assert rois.dim() == 2 and rois.size(1) == 5
137 |
138 | if self.use_torchvision:
139 | from torchvision.ops import roi_align as tv_roi_align
140 | return tv_roi_align(features, rois, self.out_size,
141 | self.spatial_scale, self.sample_num)
142 | else:
143 | return roi_align(features, rois, self.out_size, self.spatial_scale,
144 | self.sample_num, self.aligned)
145 |
146 | def __repr__(self):
147 | indent_str = '\n '
148 | format_str = self.__class__.__name__
149 | format_str += f'({indent_str}out_size={self.out_size},'
150 | format_str += f'{indent_str}spatial_scale={self.spatial_scale},'
151 | format_str += f'{indent_str}sample_num={self.sample_num},'
152 | format_str += f'{indent_str}use_torchvision={self.use_torchvision},'
153 | format_str += f'{indent_str}aligned={self.aligned})'
154 | return format_str
155 |
--------------------------------------------------------------------------------
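
A small usage sketch for the module above, assuming the `roi_align_ext` extension has been built and its forward_v2 entry point dispatches to the CPU kernel shown next when the inputs are CPU tensors (random data, feature map at 1/8 of image scale to match spatial_scale=0.125):

    import torch
    from orbbec.roi_align import RoIAlign, roi_align

    feats = torch.randn(2, 16, 32, 32)          # NCHW feature map
    # rois: (batch_index, x1, y1, x2, y2) in image coordinates
    rois = torch.tensor([[0.,  16.,  16., 128., 128.],
                         [1.,  32.,  48., 200., 220.]])

    layer = RoIAlign(out_size=7, spatial_scale=1.0 / 8, sample_num=2, aligned=True)
    out = layer(feats, rois)                    # -> (2, 16, 7, 7)
    print(out.shape)

    # the functional form takes the same arguments
    out = roi_align(feats, rois, (7, 7), 1.0 / 8, 2, True)
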
/orbbec/roi_align/src/cpu/roi_align_v2.cpp:
--------------------------------------------------------------------------------
1 | // Modified from
2 | // https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/ROIAlign
3 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
4 | #include <ATen/ATen.h>
5 | #include <ATen/TensorUtils.h>
6 |
7 | // implementation taken from Caffe2
8 | template <typename T>
9 | struct PreCalc {
10 | int pos1;
11 | int pos2;
12 | int pos3;
13 | int pos4;
14 | T w1;
15 | T w2;
16 | T w3;
17 | T w4;
18 | };
19 |
20 | template <typename T>
21 | void pre_calc_for_bilinear_interpolate(
22 | const int height, const int width, const int pooled_height,
23 | const int pooled_width, const int iy_upper, const int ix_upper,
24 | T roi_start_h, T roi_start_w, T bin_size_h, T bin_size_w,
25 |     int roi_bin_grid_h, int roi_bin_grid_w, std::vector<PreCalc<T>>& pre_calc) {
26 | int pre_calc_index = 0;
27 | for (int ph = 0; ph < pooled_height; ph++) {
28 | for (int pw = 0; pw < pooled_width; pw++) {
29 | for (int iy = 0; iy < iy_upper; iy++) {
30 | const T yy = roi_start_h + ph * bin_size_h +
31 |           static_cast<T>(iy + .5f) * bin_size_h /
32 |               static_cast<T>(roi_bin_grid_h);  // e.g., 0.5, 1.5
33 | for (int ix = 0; ix < ix_upper; ix++) {
34 | const T xx = roi_start_w + pw * bin_size_w +
35 |             static_cast<T>(ix + .5f) * bin_size_w /
36 |                 static_cast<T>(roi_bin_grid_w);
37 |
38 | T x = xx;
39 | T y = yy;
40 | // deal with: inverse elements are out of feature map boundary
41 | if (y < -1.0 || y > height || x < -1.0 || x > width) {
42 | // empty
43 |           PreCalc<T> pc;
44 | pc.pos1 = 0;
45 | pc.pos2 = 0;
46 | pc.pos3 = 0;
47 | pc.pos4 = 0;
48 | pc.w1 = 0;
49 | pc.w2 = 0;
50 | pc.w3 = 0;
51 | pc.w4 = 0;
52 | pre_calc[pre_calc_index] = pc;
53 | pre_calc_index += 1;
54 | continue;
55 | }
56 |
57 | if (y <= 0) {
58 | y = 0;
59 | }
60 | if (x <= 0) {
61 | x = 0;
62 | }
63 |
64 | int y_low = (int)y;
65 | int x_low = (int)x;
66 | int y_high;
67 | int x_high;
68 |
69 | if (y_low >= height - 1) {
70 | y_high = y_low = height - 1;
71 | y = (T)y_low;
72 | } else {
73 | y_high = y_low + 1;
74 | }
75 |
76 | if (x_low >= width - 1) {
77 | x_high = x_low = width - 1;
78 | x = (T)x_low;
79 | } else {
80 | x_high = x_low + 1;
81 | }
82 |
83 | T ly = y - y_low;
84 | T lx = x - x_low;
85 | T hy = 1. - ly, hx = 1. - lx;
86 | T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
87 |
88 | // save weights and indices
89 |           PreCalc<T> pc;
90 | pc.pos1 = y_low * width + x_low;
91 | pc.pos2 = y_low * width + x_high;
92 | pc.pos3 = y_high * width + x_low;
93 | pc.pos4 = y_high * width + x_high;
94 | pc.w1 = w1;
95 | pc.w2 = w2;
96 | pc.w3 = w3;
97 | pc.w4 = w4;
98 | pre_calc[pre_calc_index] = pc;
99 |
100 | pre_calc_index += 1;
101 | }
102 | }
103 | }
104 | }
105 | }
106 |
107 | template <typename T>
108 | void ROIAlignForward(const int nthreads, const T* input, const T& spatial_scale,
109 | const int channels, const int height, const int width,
110 | const int pooled_height, const int pooled_width,
111 | const int sampling_ratio, const T* rois, T* output,
112 | bool aligned) {
113 | int n_rois = nthreads / channels / pooled_width / pooled_height;
114 | // (n, c, ph, pw) is an element in the pooled output
115 | // can be parallelized using omp
116 | // #pragma omp parallel for num_threads(32)
117 | for (int n = 0; n < n_rois; n++) {
118 | int index_n = n * channels * pooled_width * pooled_height;
119 |
120 | const T* offset_rois = rois + n * 5;
121 | int roi_batch_ind = offset_rois[0];
122 |
123 | // Do not use rounding; this implementation detail is critical
124 | T offset = aligned ? (T)0.5 : (T)0.0;
125 | T roi_start_w = offset_rois[1] * spatial_scale - offset;
126 | T roi_start_h = offset_rois[2] * spatial_scale - offset;
127 | T roi_end_w = offset_rois[3] * spatial_scale - offset;
128 | T roi_end_h = offset_rois[4] * spatial_scale - offset;
129 |
130 | T roi_width = roi_end_w - roi_start_w;
131 | T roi_height = roi_end_h - roi_start_h;
132 | if (aligned) {
133 | AT_ASSERTM(roi_width >= 0 && roi_height >= 0,
134 |                  "ROIs in ROIAlign cannot have negative size!");
135 | } else { // for backward-compatibility only
136 | roi_width = std::max(roi_width, (T)1.);
137 | roi_height = std::max(roi_height, (T)1.);
138 | }
139 |     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
140 |     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
141 |
142 | // We use roi_bin_grid to sample the grid and mimic integral
143 | int roi_bin_grid_h = (sampling_ratio > 0)
144 | ? sampling_ratio
145 | : ceil(roi_height / pooled_height); // e.g., = 2
146 | int roi_bin_grid_w =
147 | (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
148 |
149 | // We do average (integral) pooling inside a bin
150 | // When the grid is empty, output zeros == 0/1, instead of NaN.
151 | const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
152 |
153 | // we want to precalculate indices and weights shared by all channels,
154 | // this is the key point of optimization
155 |     std::vector<PreCalc<T>> pre_calc(roi_bin_grid_h * roi_bin_grid_w *
156 | pooled_width * pooled_height);
157 | pre_calc_for_bilinear_interpolate(
158 | height, width, pooled_height, pooled_width, roi_bin_grid_h,
159 | roi_bin_grid_w, roi_start_h, roi_start_w, bin_size_h, bin_size_w,
160 | roi_bin_grid_h, roi_bin_grid_w, pre_calc);
161 |
162 | for (int c = 0; c < channels; c++) {
163 | int index_n_c = index_n + c * pooled_width * pooled_height;
164 | const T* offset_input =
165 | input + (roi_batch_ind * channels + c) * height * width;
166 | int pre_calc_index = 0;
167 |
168 | for (int ph = 0; ph < pooled_height; ph++) {
169 | for (int pw = 0; pw < pooled_width; pw++) {
170 | int index = index_n_c + ph * pooled_width + pw;
171 |
172 | T output_val = 0.;
173 | for (int iy = 0; iy < roi_bin_grid_h; iy++) {
174 | for (int ix = 0; ix < roi_bin_grid_w; ix++) {
175 |             PreCalc<T> pc = pre_calc[pre_calc_index];
176 | output_val += pc.w1 * offset_input[pc.pos1] +
177 | pc.w2 * offset_input[pc.pos2] +
178 | pc.w3 * offset_input[pc.pos3] +
179 | pc.w4 * offset_input[pc.pos4];
180 |
181 | pre_calc_index += 1;
182 | }
183 | }
184 | output_val /= count;
185 |
186 | output[index] = output_val;
187 | } // for pw
188 | } // for ph
189 | } // for c
190 | } // for n
191 | }
192 |
193 | template <typename T>
194 | void bilinear_interpolate_gradient(const int height, const int width, T y, T x,
195 | T& w1, T& w2, T& w3, T& w4, int& x_low,
196 | int& x_high, int& y_low, int& y_high,
197 | const int index /* index for debug only*/) {
198 | // deal with cases that inverse elements are out of feature map boundary
199 | if (y < -1.0 || y > height || x < -1.0 || x > width) {
200 | // empty
201 | w1 = w2 = w3 = w4 = 0.;
202 | x_low = x_high = y_low = y_high = -1;
203 | return;
204 | }
205 |
206 | if (y <= 0) y = 0;
207 | if (x <= 0) x = 0;
208 |
209 | y_low = (int)y;
210 | x_low = (int)x;
211 |
212 | if (y_low >= height - 1) {
213 | y_high = y_low = height - 1;
214 | y = (T)y_low;
215 | } else {
216 | y_high = y_low + 1;
217 | }
218 |
219 | if (x_low >= width - 1) {
220 | x_high = x_low = width - 1;
221 | x = (T)x_low;
222 | } else {
223 | x_high = x_low + 1;
224 | }
225 |
226 | T ly = y - y_low;
227 | T lx = x - x_low;
228 | T hy = 1. - ly, hx = 1. - lx;
229 |
230 | // reference in forward
231 | // T v1 = input[y_low * width + x_low];
232 | // T v2 = input[y_low * width + x_high];
233 | // T v3 = input[y_high * width + x_low];
234 | // T v4 = input[y_high * width + x_high];
235 | // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
236 |
237 | w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
238 |
239 | return;
240 | }
241 |
242 | template <typename T>
243 | inline void add(T* address, const T& val) {
244 | *address += val;
245 | }
246 |
247 | template <typename T>
248 | void ROIAlignBackward(const int nthreads, const T* grad_output,
249 | const T& spatial_scale, const int channels,
250 | const int height, const int width,
251 | const int pooled_height, const int pooled_width,
252 | const int sampling_ratio, T* grad_input, const T* rois,
253 | const int n_stride, const int c_stride,
254 | const int h_stride, const int w_stride, bool aligned) {
255 | for (int index = 0; index < nthreads; index++) {
256 | // (n, c, ph, pw) is an element in the pooled output
257 | int pw = index % pooled_width;
258 | int ph = (index / pooled_width) % pooled_height;
259 | int c = (index / pooled_width / pooled_height) % channels;
260 | int n = index / pooled_width / pooled_height / channels;
261 |
262 | const T* offset_rois = rois + n * 5;
263 | int roi_batch_ind = offset_rois[0];
264 |
265 | // Do not use rounding; this implementation detail is critical
266 | T offset = aligned ? (T)0.5 : (T)0.0;
267 | T roi_start_w = offset_rois[1] * spatial_scale - offset;
268 | T roi_start_h = offset_rois[2] * spatial_scale - offset;
269 | T roi_end_w = offset_rois[3] * spatial_scale - offset;
270 | T roi_end_h = offset_rois[4] * spatial_scale - offset;
271 |
272 | T roi_width = roi_end_w - roi_start_w;
273 | T roi_height = roi_end_h - roi_start_h;
274 | if (aligned) {
275 | AT_ASSERTM(roi_width >= 0 && roi_height >= 0,
276 |                  "ROIs in ROIAlign cannot have negative size!");
277 | } else { // for backward-compatibility only
278 | roi_width = std::max(roi_width, (T)1.);
279 | roi_height = std::max(roi_height, (T)1.);
280 | }
281 |     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
282 |     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
283 |
284 | T* offset_grad_input =
285 | grad_input + ((roi_batch_ind * channels + c) * height * width);
286 |
287 | int output_offset = n * n_stride + c * c_stride;
288 | const T* offset_grad_output = grad_output + output_offset;
289 | const T grad_output_this_bin =
290 | offset_grad_output[ph * h_stride + pw * w_stride];
291 |
292 | // We use roi_bin_grid to sample the grid and mimic integral
293 | int roi_bin_grid_h = (sampling_ratio > 0)
294 | ? sampling_ratio
295 | : ceil(roi_height / pooled_height); // e.g., = 2
296 | int roi_bin_grid_w =
297 | (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
298 |
299 | // We do average (integral) pooling inside a bin
300 | const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
301 |
302 | for (int iy = 0; iy < roi_bin_grid_h; iy++) {
303 | const T y = roi_start_h + ph * bin_size_h +
304 |                 static_cast<T>(iy + .5f) * bin_size_h /
305 |                     static_cast<T>(roi_bin_grid_h);  // e.g., 0.5, 1.5
306 | for (int ix = 0; ix < roi_bin_grid_w; ix++) {
307 | const T x = roi_start_w + pw * bin_size_w +
308 |                     static_cast<T>(ix + .5f) * bin_size_w /
309 |                         static_cast<T>(roi_bin_grid_w);
310 |
311 | T w1, w2, w3, w4;
312 | int x_low, x_high, y_low, y_high;
313 |
314 | bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
315 | x_low, x_high, y_low, y_high, index);
316 |
317 | T g1 = grad_output_this_bin * w1 / count;
318 | T g2 = grad_output_this_bin * w2 / count;
319 | T g3 = grad_output_this_bin * w3 / count;
320 | T g4 = grad_output_this_bin * w4 / count;
321 |
322 | if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
323 | // atomic add is not needed for now since it is single threaded
324 |           add(offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
325 |           add(offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
326 |           add(offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
327 |           add(offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
328 | } // if
329 | } // ix
330 | } // iy
331 | } // for
332 | } // ROIAlignBackward
333 |
334 | at::Tensor ROIAlignForwardV2CPULaucher(const at::Tensor& input,
335 | const at::Tensor& rois,
336 | const float spatial_scale,
337 | const int pooled_height,
338 | const int pooled_width,
339 | const int sampling_ratio, bool aligned) {
340 | AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
341 | AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
342 |
343 | at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
344 |
345 | at::CheckedFrom c = "ROIAlignForwardV2CPULaucher";
346 | at::checkAllSameType(c, {input_t, rois_t});
347 |
348 | auto num_rois = rois.size(0);
349 | auto channels = input.size(1);
350 | auto height = input.size(2);
351 | auto width = input.size(3);
352 |
353 | at::Tensor output = at::zeros(
354 | {num_rois, channels, pooled_height, pooled_width}, input.options());
355 |
356 | auto output_size = num_rois * pooled_height * pooled_width * channels;
357 |
358 | if (output.numel() == 0) return output;
359 |
360 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "ROIAlign_forward", [&] {
361 |     ROIAlignForward<scalar_t>(
362 |         output_size, input.contiguous().data_ptr<scalar_t>(), spatial_scale,
363 |         channels, height, width, pooled_height, pooled_width, sampling_ratio,
364 |         rois.contiguous().data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), aligned);
365 | });
366 | return output;
367 | }
368 |
369 | at::Tensor ROIAlignBackwardV2CPULaucher(
370 | const at::Tensor& grad, const at::Tensor& rois, const float spatial_scale,
371 | const int pooled_height, const int pooled_width, const int batch_size,
372 | const int channels, const int height, const int width,
373 | const int sampling_ratio, bool aligned) {
374 | AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor");
375 | AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
376 |
377 | at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
378 |
379 | at::CheckedFrom c = "ROIAlignBackwardV2CPULaucher";
380 | at::checkAllSameType(c, {grad_t, rois_t});
381 |
382 | at::Tensor grad_input =
383 | at::zeros({batch_size, channels, height, width}, grad.options());
384 |
385 | // handle possibly empty gradients
386 | if (grad.numel() == 0) {
387 | return grad_input;
388 | }
389 |
390 | // get stride values to ensure indexing into gradients is correct.
391 | int n_stride = grad.stride(0);
392 | int c_stride = grad.stride(1);
393 | int h_stride = grad.stride(2);
394 | int w_stride = grad.stride(3);
395 |
396 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.scalar_type(), "ROIAlign_backward", [&] {
397 |     ROIAlignBackward<scalar_t>(
398 |         grad.numel(), grad.contiguous().data_ptr<scalar_t>(), spatial_scale,
399 |         channels, height, width, pooled_height, pooled_width, sampling_ratio,
400 |         grad_input.data_ptr<scalar_t>(), rois.contiguous().data_ptr<scalar_t>(),
401 | n_stride, c_stride, h_stride, w_stride, aligned);
402 | });
403 | return grad_input;
404 | }
405 |
--------------------------------------------------------------------------------
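
The pre-computation above boils down to ordinary bilinear interpolation at each sampling point: clamp the point into the feature map, take its four integer neighbours, and weight them by the opposite fractional areas. A NumPy sketch of that single-point computation, mirroring `pre_calc_for_bilinear_interpolate` without the per-bin batching:

    import numpy as np

    def bilinear_sample(feature, y, x):
        """Sample feature at continuous (y, x), as the kernels above do."""
        height, width = feature.shape
        if y < -1.0 or y > height or x < -1.0 or x > width:
            return 0.0                          # outside the feature map -> empty
        y, x = max(y, 0.0), max(x, 0.0)
        y_low, x_low = int(y), int(x)
        if y_low >= height - 1:                 # clamp to the last row, as the kernel does
            y_high = y_low = height - 1
            y = float(y_low)
        else:
            y_high = y_low + 1
        if x_low >= width - 1:
            x_high = x_low = width - 1
            x = float(x_low)
        else:
            x_high = x_low + 1
        ly, lx = y - y_low, x - x_low           # fractional offsets
        hy, hx = 1.0 - ly, 1.0 - lx
        w1, w2, w3, w4 = hy * hx, hy * lx, ly * hx, ly * lx
        return (w1 * feature[y_low, x_low] + w2 * feature[y_low, x_high] +
                w3 * feature[y_high, x_low] + w4 * feature[y_high, x_high])

    feat = np.arange(16, dtype=np.float32).reshape(4, 4)
    print(bilinear_sample(feat, 1.5, 1.5))      # 7.5, the mean of the 4 centre pixels
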
/orbbec/roi_align/src/cuda/roi_align_kernel.cu:
--------------------------------------------------------------------------------
1 | // /* pytorch: 1.5.0 ~ 1.10.x */
2 | #include <ATen/ATen.h>
3 | #include <ATen/cuda/CUDAContext.h>
4 | #include <THC/THCAtomics.cuh>
5 |
6 | // /* pytorch: 1.11.0 ~ latest */
7 | // #include
8 | // #include
9 | // #include
10 | // #include
11 |
12 | // #include
13 | // #include
14 |
15 | #define CUDA_1D_KERNEL_LOOP(i, n) \
16 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
17 | i += blockDim.x * gridDim.x)
18 |
19 | #define THREADS_PER_BLOCK 1024
20 |
21 | inline int GET_BLOCKS(const int N) {
22 | int optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
23 | int max_block_num = 65000;
24 | return min(optimal_block_num, max_block_num);
25 | }
26 |
27 | template <typename scalar_t>
28 | __device__ scalar_t bilinear_interpolate(const scalar_t *bottom_data,
29 | const int height, const int width,
30 | scalar_t y, scalar_t x) {
31 | // deal with cases that inverse elements are out of feature map boundary
32 | if (y < -1.0 || y > height || x < -1.0 || x > width) {
33 | return 0;
34 | }
35 |
36 | if (y <= 0) y = 0;
37 | if (x <= 0) x = 0;
38 |
39 | int y_low = (int)y;
40 | int x_low = (int)x;
41 | int y_high;
42 | int x_high;
43 |
44 | if (y_low >= height - 1) {
45 | y_high = y_low = height - 1;
46 | y = (scalar_t)y_low;
47 | } else {
48 | y_high = y_low + 1;
49 | }
50 |
51 | if (x_low >= width - 1) {
52 | x_high = x_low = width - 1;
53 | x = (scalar_t)x_low;
54 | } else {
55 | x_high = x_low + 1;
56 | }
57 |
58 | scalar_t ly = y - y_low;
59 | scalar_t lx = x - x_low;
60 | scalar_t hy = 1. - ly;
61 | scalar_t hx = 1. - lx;
62 | // do bilinear interpolation
63 | scalar_t lt = bottom_data[y_low * width + x_low];
64 | scalar_t rt = bottom_data[y_low * width + x_high];
65 | scalar_t lb = bottom_data[y_high * width + x_low];
66 | scalar_t rb = bottom_data[y_high * width + x_high];
67 | scalar_t w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
68 |
69 | scalar_t val = (w1 * lt + w2 * rt + w3 * lb + w4 * rb);
70 |
71 | return val;
72 | }
73 |
74 | template <typename scalar_t>
75 | __global__ void ROIAlignForwardV1(
76 | const int nthreads, const scalar_t *bottom_data,
77 | const scalar_t *bottom_rois, const scalar_t spatial_scale,
78 | const int sample_num, const int channels, const int height, const int width,
79 | const int pooled_height, const int pooled_width, scalar_t *top_data) {
80 | CUDA_1D_KERNEL_LOOP(index, nthreads) {
81 | // (n, c, ph, pw) is an element in the aligned output
82 | int pw = index % pooled_width;
83 | int ph = (index / pooled_width) % pooled_height;
84 | int c = (index / pooled_width / pooled_height) % channels;
85 | int n = index / pooled_width / pooled_height / channels;
86 |
87 | const scalar_t *offset_bottom_rois = bottom_rois + n * 5;
88 | int roi_batch_ind = offset_bottom_rois[0];
89 | scalar_t roi_start_w = offset_bottom_rois[1] * spatial_scale;
90 | scalar_t roi_start_h = offset_bottom_rois[2] * spatial_scale;
91 | scalar_t roi_end_w = (offset_bottom_rois[3] + 1) * spatial_scale;
92 | scalar_t roi_end_h = (offset_bottom_rois[4] + 1) * spatial_scale;
93 |
94 | // Force malformed ROIs to be 1x1
95 | scalar_t roi_width = fmaxf((scalar_t)roi_end_w - roi_start_w, 0.);
96 | scalar_t roi_height = fmaxf((scalar_t)roi_end_h - roi_start_h, 0.);
97 |
98 | scalar_t bin_size_h = roi_height / pooled_height;
99 | scalar_t bin_size_w = roi_width / pooled_width;
100 |
101 | const scalar_t *offset_bottom_data =
102 | bottom_data + (roi_batch_ind * channels + c) * height * width;
103 |
104 | int sample_num_h = (sample_num > 0)
105 | ? sample_num
106 | : ceil(roi_height / pooled_height); // e.g., = 2
107 | int sample_num_w =
108 | (sample_num > 0) ? sample_num : ceil(roi_width / pooled_width);
109 |
110 | scalar_t output_val = 0;
111 | for (int iy = 0; iy < sample_num_h; iy++) {
112 | const scalar_t y = roi_start_h + ph * bin_size_h +
113 | (scalar_t)(iy + scalar_t(.5f)) * bin_size_h /
114 | (scalar_t)(sample_num_h);
115 | for (int ix = 0; ix < sample_num_w; ix++) {
116 | const scalar_t x = roi_start_w + pw * bin_size_w +
117 | (scalar_t)(ix + scalar_t(.5f)) * bin_size_w /
118 | (scalar_t)(sample_num_w);
119 |         scalar_t val = bilinear_interpolate<scalar_t>(offset_bottom_data,
120 | height, width, y, x);
121 | output_val += val;
122 | }
123 | }
124 | output_val /= (sample_num_h * sample_num_w);
125 | top_data[index] = output_val;
126 | }
127 | }
128 |
129 | int ROIAlignForwardLaucher(const at::Tensor features, const at::Tensor rois,
130 | const float spatial_scale, const int sample_num,
131 | const int channels, const int height,
132 | const int width, const int num_rois,
133 | const int pooled_height, const int pooled_width,
134 | at::Tensor output) {
135 | const int output_size = num_rois * pooled_height * pooled_width * channels;
136 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(
137 | features.scalar_type(), "ROIAlignLaucherForward", ([&] {
138 |         const scalar_t *bottom_data = features.data_ptr<scalar_t>();
139 |         const scalar_t *rois_data = rois.data_ptr<scalar_t>();
140 |         scalar_t *top_data = output.data_ptr<scalar_t>();
141 |
142 |         ROIAlignForwardV1<scalar_t>
143 |             <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0,
144 |                at::cuda::getCurrentCUDAStream()>>>(
145 | output_size, bottom_data, rois_data, scalar_t(spatial_scale),
146 | sample_num, channels, height, width, pooled_height,
147 | pooled_width, top_data);
148 | }));
149 | THCudaCheck(cudaGetLastError());
150 | return 1;
151 | }
152 |
153 | template <typename scalar_t>
154 | __device__ void bilinear_interpolate_gradient(const int height, const int width,
155 | scalar_t y, scalar_t x,
156 | scalar_t &w1, scalar_t &w2,
157 | scalar_t &w3, scalar_t &w4,
158 | int &x_low, int &x_high,
159 | int &y_low, int &y_high) {
160 | // deal with cases that inverse elements are out of feature map boundary
161 | if (y < -1.0 || y > height || x < -1.0 || x > width) {
162 | w1 = w2 = w3 = w4 = 0.;
163 | x_low = x_high = y_low = y_high = -1;
164 | return;
165 | }
166 |
167 | if (y <= 0) y = 0;
168 | if (x <= 0) x = 0;
169 |
170 | y_low = (int)y;
171 | x_low = (int)x;
172 |
173 | if (y_low >= height - 1) {
174 | y_high = y_low = height - 1;
175 | y = (scalar_t)y_low;
176 | } else {
177 | y_high = y_low + 1;
178 | }
179 |
180 | if (x_low >= width - 1) {
181 | x_high = x_low = width - 1;
182 | x = (scalar_t)x_low;
183 | } else {
184 | x_high = x_low + 1;
185 | }
186 |
187 | scalar_t ly = y - y_low;
188 | scalar_t lx = x - x_low;
189 | scalar_t hy = 1. - ly;
190 | scalar_t hx = 1. - lx;
191 |
192 | w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
193 |
194 | return;
195 | }
196 |
197 | template <typename scalar_t>
198 | __global__ void ROIAlignBackwardV1(
199 | const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_rois,
200 | const scalar_t spatial_scale, const int sample_num, const int channels,
201 | const int height, const int width, const int pooled_height,
202 | const int pooled_width, scalar_t *bottom_diff) {
203 | CUDA_1D_KERNEL_LOOP(index, nthreads) {
204 | // (n, c, ph, pw) is an element in the aligned output
205 | int pw = index % pooled_width;
206 | int ph = (index / pooled_width) % pooled_height;
207 | int c = (index / pooled_width / pooled_height) % channels;
208 | int n = index / pooled_width / pooled_height / channels;
209 |
210 | const scalar_t *offset_bottom_rois = bottom_rois + n * 5;
211 | int roi_batch_ind = offset_bottom_rois[0];
212 | scalar_t roi_start_w = offset_bottom_rois[1] * spatial_scale;
213 | scalar_t roi_start_h = offset_bottom_rois[2] * spatial_scale;
214 | scalar_t roi_end_w = (offset_bottom_rois[3] + 1) * spatial_scale;
215 | scalar_t roi_end_h = (offset_bottom_rois[4] + 1) * spatial_scale;
216 |
217 | // Force malformed ROIs to be 1x1
218 | scalar_t roi_width = fmaxf((scalar_t)roi_end_w - roi_start_w, 0.);
219 | scalar_t roi_height = fmaxf((scalar_t)roi_end_h - roi_start_h, 0.);
220 |
221 | scalar_t bin_size_h = roi_height / pooled_height;
222 | scalar_t bin_size_w = roi_width / pooled_width;
223 |
224 | scalar_t *offset_bottom_diff =
225 | bottom_diff + (roi_batch_ind * channels + c) * height * width;
226 | int offset_top = (n * channels + c) * pooled_height * pooled_width +
227 | ph * pooled_width + pw;
228 | scalar_t offset_top_diff = top_diff[offset_top];
229 |
230 | int sample_num_h = (sample_num > 0)
231 | ? sample_num
232 | : ceil(roi_height / pooled_height); // e.g., = 2
233 | int sample_num_w =
234 | (sample_num > 0) ? sample_num : ceil(roi_width / pooled_width);
235 |
236 | const scalar_t count = (scalar_t)(sample_num_h * sample_num_w);
237 |
238 | for (int iy = 0; iy < sample_num_h; iy++) {
239 | const scalar_t y =
240 | roi_start_h + ph * bin_size_h +
241 | (scalar_t)(iy + .5f) * bin_size_h / (scalar_t)(sample_num_h);
242 | for (int ix = 0; ix < sample_num_w; ix++) {
243 | const scalar_t x =
244 | roi_start_w + pw * bin_size_w +
245 | (scalar_t)(ix + .5f) * bin_size_w / (scalar_t)(sample_num_w);
246 | scalar_t w1, w2, w3, w4;
247 | int x_low, x_high, y_low, y_high;
248 |
249 | bilinear_interpolate_gradient(
250 | height, width, y, x, w1, w2, w3, w4, x_low, x_high, y_low, y_high);
251 | scalar_t g1 = offset_top_diff * w1 / count;
252 | scalar_t g2 = offset_top_diff * w2 / count;
253 | scalar_t g3 = offset_top_diff * w3 / count;
254 | scalar_t g4 = offset_top_diff * w4 / count;
255 | if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
256 | atomicAdd(offset_bottom_diff + y_low * width + x_low, g1);
257 | atomicAdd(offset_bottom_diff + y_low * width + x_high, g2);
258 | atomicAdd(offset_bottom_diff + y_high * width + x_low, g3);
259 | atomicAdd(offset_bottom_diff + y_high * width + x_high, g4);
260 | }
261 | }
262 | }
263 | }
264 | }
265 |
266 | int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
267 | const float spatial_scale, const int sample_num,
268 | const int channels, const int height,
269 | const int width, const int num_rois,
270 | const int pooled_height, const int pooled_width,
271 | at::Tensor bottom_grad) {
272 | const int output_size = num_rois * pooled_height * pooled_width * channels;
273 |
274 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(
275 | top_grad.scalar_type(), "ROIAlignLaucherBackward", ([&] {
276 |         const scalar_t *top_diff = top_grad.data_ptr<scalar_t>();
277 |         const scalar_t *rois_data = rois.data_ptr<scalar_t>();
278 |         scalar_t *bottom_diff = bottom_grad.data_ptr<scalar_t>();
279 | if (sizeof(scalar_t) == sizeof(double)) {
280 | fprintf(stderr, "double is not supported\n");
281 | exit(-1);
282 | }
283 |
284 |         ROIAlignBackwardV1<scalar_t>
285 |             <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0,
286 |                at::cuda::getCurrentCUDAStream()>>>(
287 | output_size, top_diff, rois_data, spatial_scale, sample_num,
288 | channels, height, width, pooled_height, pooled_width,
289 | bottom_diff);
290 | }));
291 | THCudaCheck(cudaGetLastError());
292 | return 1;
293 | }
294 |
--------------------------------------------------------------------------------
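
The v1 kernel above uses the legacy ROI mapping (no half-pixel shift, box end grown by one pixel, size clamped at zero), while the v2 kernel that follows shifts coordinates by 0.5 as described in roi_align.py. A short sketch contrasting the two mappings for one ROI, just the coordinate arithmetic in plain Python:

    def roi_extent(x1, y1, x2, y2, spatial_scale, aligned):
        """Return (start_w, start_h, width, height) of the ROI on the feature map."""
        if aligned:
            # v2: continuous coordinates shifted by half a pixel (detectron2 style)
            start_w = x1 * spatial_scale - 0.5
            start_h = y1 * spatial_scale - 0.5
            width = x2 * spatial_scale - 0.5 - start_w
            height = y2 * spatial_scale - 0.5 - start_h
        else:
            # v1: legacy mapping, box end grown by one pixel and size clamped to >= 0
            start_w = x1 * spatial_scale
            start_h = y1 * spatial_scale
            width = max((x2 + 1) * spatial_scale - start_w, 0.0)
            height = max((y2 + 1) * spatial_scale - start_h, 0.0)
        return start_w, start_h, width, height

    print(roi_extent(16, 16, 128, 128, 1 / 8, aligned=False))  # (2.0, 2.0, 14.125, 14.125)
    print(roi_extent(16, 16, 128, 128, 1 / 8, aligned=True))   # (1.5, 1.5, 14.0, 14.0)
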
/orbbec/roi_align/src/cuda/roi_align_kernel_v2.cu:
--------------------------------------------------------------------------------
1 | // Modified from
2 | // https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/ROIAlign
3 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
4 | // /* pytorch: 1.5.0 ~ 1.10.x */
5 | #include <ATen/ATen.h>
6 | #include <ATen/cuda/CUDAContext.h>
7 | #include <c10/cuda/CUDAGuard.h>
8 | #include <ATen/cuda/CUDAApplyUtils.cuh>
9 |
10 |
11 | // /* pytorch: 1.11.0 ~ latest */
12 | // #include
13 | // #include
14 | // #include
15 | // #include
16 |
17 | // #include
18 | // #include
19 |
20 |
21 | // TODO make it in a common file
22 | #define CUDA_1D_KERNEL_LOOP(i, n) \
23 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
24 | i += blockDim.x * gridDim.x)
25 |
26 | template <typename T>
27 | __device__ T bilinear_interpolate(const T* bottom_data, const int height,
28 | const int width, T y, T x,
29 | const int index /* index for debug only*/) {
30 | // deal with cases that inverse elements are out of feature map boundary
31 | if (y < -1.0 || y > height || x < -1.0 || x > width) {
32 | // empty
33 | return 0;
34 | }
35 |
36 | if (y <= 0) y = 0;
37 | if (x <= 0) x = 0;
38 |
39 | int y_low = (int)y;
40 | int x_low = (int)x;
41 | int y_high;
42 | int x_high;
43 |
44 | if (y_low >= height - 1) {
45 | y_high = y_low = height - 1;
46 | y = (T)y_low;
47 | } else {
48 | y_high = y_low + 1;
49 | }
50 |
51 | if (x_low >= width - 1) {
52 | x_high = x_low = width - 1;
53 | x = (T)x_low;
54 | } else {
55 | x_high = x_low + 1;
56 | }
57 |
58 | T ly = y - y_low;
59 | T lx = x - x_low;
60 | T hy = 1. - ly, hx = 1. - lx;
61 | // do bilinear interpolation
62 | T v1 = bottom_data[y_low * width + x_low];
63 | T v2 = bottom_data[y_low * width + x_high];
64 | T v3 = bottom_data[y_high * width + x_low];
65 | T v4 = bottom_data[y_high * width + x_high];
66 | T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
67 |
68 | T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
69 |
70 | return val;
71 | }
72 |
73 | template <typename T>
74 | __global__ void RoIAlignForwardV2(
75 | const int nthreads, const T* bottom_data, const T spatial_scale,
76 | const int channels, const int height, const int width,
77 | const int pooled_height, const int pooled_width, const int sampling_ratio,
78 | const T* bottom_rois, T* top_data, bool aligned) {
79 | CUDA_1D_KERNEL_LOOP(index, nthreads) {
80 | // (n, c, ph, pw) is an element in the pooled output
81 | int pw = index % pooled_width;
82 | int ph = (index / pooled_width) % pooled_height;
83 | int c = (index / pooled_width / pooled_height) % channels;
84 | int n = index / pooled_width / pooled_height / channels;
85 |
86 | const T* offset_bottom_rois = bottom_rois + n * 5;
87 | int roi_batch_ind = offset_bottom_rois[0];
88 |
89 | // Do not use rounding; this implementation detail is critical
90 | T offset = aligned ? (T)0.5 : (T)0.0;
91 | T roi_start_w = offset_bottom_rois[1] * spatial_scale - offset;
92 | T roi_start_h = offset_bottom_rois[2] * spatial_scale - offset;
93 | T roi_end_w = offset_bottom_rois[3] * spatial_scale - offset;
94 | T roi_end_h = offset_bottom_rois[4] * spatial_scale - offset;
95 |
96 | T roi_width = roi_end_w - roi_start_w;
97 | T roi_height = roi_end_h - roi_start_h;
98 | if (!aligned) { // for backward-compatibility only
99 | roi_width = max(roi_width, (T)1.);
100 | roi_height = max(roi_height, (T)1.);
101 | }
102 |     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
103 |     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
104 |
105 | const T* offset_bottom_data =
106 | bottom_data + (roi_batch_ind * channels + c) * height * width;
107 |
108 | // We use roi_bin_grid to sample the grid and mimic integral
109 | int roi_bin_grid_h = (sampling_ratio > 0)
110 | ? sampling_ratio
111 | : ceil(roi_height / pooled_height); // e.g., = 2
112 | int roi_bin_grid_w =
113 | (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
114 |
115 | // We do average (integral) pooling inside a bin
116 | // When the grid is empty, output zeros.
117 | const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
118 |
119 | T output_val = 0.;
120 | for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
121 | {
122 | const T y = roi_start_h + ph * bin_size_h +
123 |                   static_cast<T>(iy + .5f) * bin_size_h /
124 |                       static_cast<T>(roi_bin_grid_h);  // e.g., 0.5, 1.5
125 | for (int ix = 0; ix < roi_bin_grid_w; ix++) {
126 | const T x = roi_start_w + pw * bin_size_w +
127 |                     static_cast<T>(ix + .5f) * bin_size_w /
128 |                         static_cast<T>(roi_bin_grid_w);
129 |
130 | T val = bilinear_interpolate(offset_bottom_data, height, width, y, x,
131 | index);
132 | output_val += val;
133 | }
134 | }
135 | output_val /= count;
136 |
137 | top_data[index] = output_val;
138 | }
139 | }
140 |
141 | template <typename T>
142 | __device__ void bilinear_interpolate_gradient(
143 | const int height, const int width, T y, T x, T& w1, T& w2, T& w3, T& w4,
144 | int& x_low, int& x_high, int& y_low, int& y_high,
145 | const int index /* index for debug only*/) {
146 | // deal with cases that inverse elements are out of feature map boundary
147 | if (y < -1.0 || y > height || x < -1.0 || x > width) {
148 | // empty
149 | w1 = w2 = w3 = w4 = 0.;
150 | x_low = x_high = y_low = y_high = -1;
151 | return;
152 | }
153 |
154 | if (y <= 0) y = 0;
155 | if (x <= 0) x = 0;
156 |
157 | y_low = (int)y;
158 | x_low = (int)x;
159 |
160 | if (y_low >= height - 1) {
161 | y_high = y_low = height - 1;
162 | y = (T)y_low;
163 | } else {
164 | y_high = y_low + 1;
165 | }
166 |
167 | if (x_low >= width - 1) {
168 | x_high = x_low = width - 1;
169 | x = (T)x_low;
170 | } else {
171 | x_high = x_low + 1;
172 | }
173 |
174 | T ly = y - y_low;
175 | T lx = x - x_low;
176 | T hy = 1. - ly, hx = 1. - lx;
177 |
178 | // reference in forward
179 | // T v1 = bottom_data[y_low * width + x_low];
180 | // T v2 = bottom_data[y_low * width + x_high];
181 | // T v3 = bottom_data[y_high * width + x_low];
182 | // T v4 = bottom_data[y_high * width + x_high];
183 | // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
184 |
185 | w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
186 |
187 | return;
188 | }
189 |
190 | template <typename T>
191 | __global__ void RoIAlignBackwardFeatureV2(
192 | const int nthreads, const T* top_diff, const int num_rois,
193 | const T spatial_scale, const int channels, const int height,
194 | const int width, const int pooled_height, const int pooled_width,
195 | const int sampling_ratio, T* bottom_diff, const T* bottom_rois,
196 | bool aligned) {
197 | CUDA_1D_KERNEL_LOOP(index, nthreads) {
198 | // (n, c, ph, pw) is an element in the pooled output
199 | int pw = index % pooled_width;
200 | int ph = (index / pooled_width) % pooled_height;
201 | int c = (index / pooled_width / pooled_height) % channels;
202 | int n = index / pooled_width / pooled_height / channels;
203 |
204 | const T* offset_bottom_rois = bottom_rois + n * 5;
205 | int roi_batch_ind = offset_bottom_rois[0];
206 |
207 | // Do not use rounding; this implementation detail is critical
208 | T offset = aligned ? (T)0.5 : (T)0.0;
209 | T roi_start_w = offset_bottom_rois[1] * spatial_scale - offset;
210 | T roi_start_h = offset_bottom_rois[2] * spatial_scale - offset;
211 | T roi_end_w = offset_bottom_rois[3] * spatial_scale - offset;
212 | T roi_end_h = offset_bottom_rois[4] * spatial_scale - offset;
213 |
214 | T roi_width = roi_end_w - roi_start_w;
215 | T roi_height = roi_end_h - roi_start_h;
216 | if (!aligned) { // for backward-compatibility only
217 | roi_width = max(roi_width, (T)1.);
218 | roi_height = max(roi_height, (T)1.);
219 | }
220 | T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
221 | T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
222 |
223 | T* offset_bottom_diff =
224 | bottom_diff + (roi_batch_ind * channels + c) * height * width;
225 |
226 | int top_offset = (n * channels + c) * pooled_height * pooled_width;
227 | const T* offset_top_diff = top_diff + top_offset;
228 | const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
229 |
230 | // We use roi_bin_grid to sample the grid and mimic integral
231 | int roi_bin_grid_h = (sampling_ratio > 0)
232 | ? sampling_ratio
233 | : ceil(roi_height / pooled_height); // e.g., = 2
234 | int roi_bin_grid_w =
235 | (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
236 |
237 | // We do average (integral) pooling inside a bin
238 | const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
239 |
240 | for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
241 | {
242 | const T y = roi_start_h + ph * bin_size_h +
243 | static_cast<T>(iy + .5f) * bin_size_h /
244 | static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
245 | for (int ix = 0; ix < roi_bin_grid_w; ix++) {
246 | const T x = roi_start_w + pw * bin_size_w +
247 | static_cast<T>(ix + .5f) * bin_size_w /
248 | static_cast<T>(roi_bin_grid_w);
249 |
250 | T w1, w2, w3, w4;
251 | int x_low, x_high, y_low, y_high;
252 |
253 | bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
254 | x_low, x_high, y_low, y_high, index);
255 |
256 | T g1 = top_diff_this_bin * w1 / count;
257 | T g2 = top_diff_this_bin * w2 / count;
258 | T g3 = top_diff_this_bin * w3 / count;
259 | T g4 = top_diff_this_bin * w4 / count;
260 |
261 | if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
262 | atomicAdd(offset_bottom_diff + y_low * width + x_low,
263 | static_cast<T>(g1));
264 | atomicAdd(offset_bottom_diff + y_low * width + x_high,
265 | static_cast<T>(g2));
266 | atomicAdd(offset_bottom_diff + y_high * width + x_low,
267 | static_cast<T>(g3));
268 | atomicAdd(offset_bottom_diff + y_high * width + x_high,
269 | static_cast<T>(g4));
270 | } // if
271 | } // ix
272 | } // iy
273 | } // CUDA_1D_KERNEL_LOOP
274 | } // RoIAlignBackward
275 |
276 | at::Tensor ROIAlignForwardV2Laucher(const at::Tensor& input,
277 | const at::Tensor& rois,
278 | const float spatial_scale,
279 | const int pooled_height,
280 | const int pooled_width,
281 | const int sampling_ratio, bool aligned) {
282 | AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor");
283 | AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
284 | at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
285 |
286 | at::CheckedFrom c = "ROIAlign_forward_cuda";
287 | at::checkAllSameGPU(c, {input_t, rois_t});
288 | at::checkAllSameType(c, {input_t, rois_t});
289 | at::cuda::CUDAGuard device_guard(input.device());
290 |
291 | auto num_rois = rois.size(0);
292 | auto channels = input.size(1);
293 | auto height = input.size(2);
294 | auto width = input.size(3);
295 |
296 | auto output = at::empty({num_rois, channels, pooled_height, pooled_width},
297 | input.options());
298 | auto output_size = num_rois * pooled_height * pooled_width * channels;
299 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
300 |
301 | dim3 grid(std::min(at::cuda::ATenCeilDiv(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)));
302 | dim3 block(512);
303 |
304 | if (output.numel() == 0) {
305 | AT_CUDA_CHECK(cudaGetLastError());
306 | return output;
307 | }
308 |
309 | AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "ROIAlign_forward", [&] {
310 | RoIAlignForwardV2<scalar_t><<<grid, block, 0, stream>>>(
311 | output_size, input.contiguous().data_ptr<scalar_t>(), spatial_scale,
312 | channels, height, width, pooled_height, pooled_width, sampling_ratio,
313 | rois.contiguous().data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), aligned);
314 | });
315 | cudaDeviceSynchronize();
316 | AT_CUDA_CHECK(cudaGetLastError());
317 | return output;
318 | }
319 |
320 | // TODO remove the dependency on input and use instead its sizes -> save memory
321 | at::Tensor ROIAlignBackwardV2Laucher(
322 | const at::Tensor& grad, const at::Tensor& rois, const float spatial_scale,
323 | const int pooled_height, const int pooled_width, const int batch_size,
324 | const int channels, const int height, const int width,
325 | const int sampling_ratio, bool aligned) {
326 | AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor");
327 | AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
328 |
329 | at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
330 | at::CheckedFrom c = "ROIAlign_backward_cuda";
331 | at::checkAllSameGPU(c, {grad_t, rois_t});
332 | at::checkAllSameType(c, {grad_t, rois_t});
333 | at::cuda::CUDAGuard device_guard(grad.device());
334 |
335 | auto num_rois = rois.size(0);
336 | auto grad_input =
337 | at::zeros({batch_size, channels, height, width}, grad.options());
338 |
339 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
340 |
341 | dim3 grid(std::min(at::cuda::ATenCeilDiv(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)));
342 | dim3 block(512);
343 |
344 | // handle possibly empty gradients
345 | if (grad.numel() == 0) {
346 | AT_CUDA_CHECK(cudaGetLastError());
347 | return grad_input;
348 | }
349 |
350 | AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "ROIAlign_backward", [&] {
351 | RoIAlignBackwardFeatureV2<scalar_t><<<grid, block, 0, stream>>>(
352 | grad.numel(), grad.contiguous().data_ptr<scalar_t>(), num_rois,
353 | spatial_scale, channels, height, width, pooled_height, pooled_width,
354 | sampling_ratio, grad_input.data_ptr<scalar_t>(),
355 | rois.contiguous().data_ptr<scalar_t>(), aligned);
356 | });
357 | AT_CUDA_CHECK(cudaGetLastError());
358 | return grad_input;
359 | }
360 |
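The forward kernel above reduces to averaging a small grid of bilinear samples per output bin. The following NumPy sketch (toy sizes, aligned=True, hypothetical helper names) mirrors what a single thread computes for one (n, c, ph, pw) element; it is a reading aid for the kernel, not part of the build.

# Minimal NumPy sketch of one RoIAlignForwardV2 output element (toy values).
import numpy as np

def bilinear(feat, y, x):
    """Bilinear sample from a 2-D map; zero outside, clamped at the border."""
    h, w = feat.shape
    if y < -1.0 or y > h or x < -1.0 or x > w:
        return 0.0
    y, x = max(y, 0.0), max(x, 0.0)
    y0, x0 = min(int(y), h - 1), min(int(x), w - 1)
    y1, x1 = min(y0 + 1, h - 1), min(x0 + 1, w - 1)
    ly, lx = min(y - y0, 1.0), min(x - x0, 1.0)
    return (feat[y0, x0] * (1 - ly) * (1 - lx) + feat[y0, x1] * (1 - ly) * lx +
            feat[y1, x0] * ly * (1 - lx) + feat[y1, x1] * ly * lx)

def roi_align_bin(feat, roi, ph, pw, pooled_h=7, pooled_w=7,
                  spatial_scale=1.0, sampling_ratio=2):
    x1, y1, x2, y2 = [v * spatial_scale - 0.5 for v in roi]  # aligned=True offset
    bin_h = (y2 - y1) / pooled_h
    bin_w = (x2 - x1) / pooled_w
    vals = []
    for iy in range(sampling_ratio):
        y = y1 + ph * bin_h + (iy + 0.5) * bin_h / sampling_ratio
        for ix in range(sampling_ratio):
            x = x1 + pw * bin_w + (ix + 0.5) * bin_w / sampling_ratio
            vals.append(bilinear(feat, y, x))
    return np.mean(vals)  # average pooling inside the bin, count = grid_h * grid_w

feat = np.arange(64, dtype=np.float32).reshape(8, 8)
print(roi_align_bin(feat, roi=(1.0, 1.0, 6.0, 6.0), ph=0, pw=0))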
--------------------------------------------------------------------------------
/orbbec/roi_align/src/roi_align_ext.cpp:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 | #include <torch/extension.h>
3 |
4 | #include <cmath>
5 | #include <vector>
6 |
7 | #ifdef WITH_CUDA
8 | int ROIAlignForwardLaucher(const at::Tensor features, const at::Tensor rois,
9 | const float spatial_scale, const int sample_num,
10 | const int channels, const int height,
11 | const int width, const int num_rois,
12 | const int pooled_height, const int pooled_width,
13 | at::Tensor output);
14 |
15 | int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
16 | const float spatial_scale, const int sample_num,
17 | const int channels, const int height,
18 | const int width, const int num_rois,
19 | const int pooled_height, const int pooled_width,
20 | at::Tensor bottom_grad);
21 |
22 | at::Tensor ROIAlignForwardV2Laucher(const at::Tensor& input,
23 | const at::Tensor& rois,
24 | const float spatial_scale,
25 | const int pooled_height,
26 | const int pooled_width,
27 | const int sampling_ratio, bool aligned);
28 |
29 | at::Tensor ROIAlignBackwardV2Laucher(
30 | const at::Tensor& grad, const at::Tensor& rois, const float spatial_scale,
31 | const int pooled_height, const int pooled_width, const int batch_size,
32 | const int channels, const int height, const int width,
33 | const int sampling_ratio, bool aligned);
34 | #endif
35 |
36 | at::Tensor ROIAlignForwardV2CPULaucher(const at::Tensor& input,
37 | const at::Tensor& rois,
38 | const float spatial_scale,
39 | const int pooled_height,
40 | const int pooled_width,
41 | const int sampling_ratio, bool aligned);
42 |
43 | at::Tensor ROIAlignBackwardV2CPULaucher(
44 | const at::Tensor& grad, const at::Tensor& rois, const float spatial_scale,
45 | const int pooled_height, const int pooled_width, const int batch_size,
46 | const int channels, const int height, const int width,
47 | const int sampling_ratio, bool aligned);
48 |
49 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDAtensor ")
50 | #define CHECK_CONTIGUOUS(x) \
51 | TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
52 | #define CHECK_INPUT(x) \
53 | CHECK_CUDA(x); \
54 | CHECK_CONTIGUOUS(x)
55 |
56 | int ROIAlign_forwardV1(at::Tensor features, at::Tensor rois, int pooled_height,
57 | int pooled_width, float spatial_scale, int sample_num,
58 | at::Tensor output) {
59 | if (features.device().is_cuda()) {
60 | #ifdef WITH_CUDA
61 | CHECK_INPUT(features);
62 | CHECK_INPUT(rois);
63 | CHECK_INPUT(output);
64 | at::DeviceGuard guard(features.device());
65 |
66 | // Number of ROIs
67 | int num_rois = rois.size(0);
68 | int size_rois = rois.size(1);
69 |
70 | if (size_rois != 5) {
71 | printf("wrong roi size\n");
72 | return 0;
73 | }
74 |
75 | int num_channels = features.size(1);
76 | int data_height = features.size(2);
77 | int data_width = features.size(3);
78 |
79 | ROIAlignForwardLaucher(features, rois, spatial_scale, sample_num,
80 | num_channels, data_height, data_width, num_rois,
81 | pooled_height, pooled_width, output);
82 |
83 | return 1;
84 | #else
85 | AT_ERROR("ROIAlign is not compiled with GPU support");
86 | #endif
87 | }
88 | AT_ERROR("ROIAlign is not implemented on CPU");
89 | }
90 |
91 | int ROIAlign_backwardV1(at::Tensor top_grad, at::Tensor rois, int pooled_height,
92 | int pooled_width, float spatial_scale, int sample_num,
93 | at::Tensor bottom_grad) {
94 | if (top_grad.device().is_cuda()) {
95 | #ifdef WITH_CUDA
96 | CHECK_INPUT(top_grad);
97 | CHECK_INPUT(rois);
98 | CHECK_INPUT(bottom_grad);
99 | at::DeviceGuard guard(top_grad.device());
100 |
101 | // Number of ROIs
102 | int num_rois = rois.size(0);
103 | int size_rois = rois.size(1);
104 | if (size_rois != 5) {
105 | printf("wrong roi size\n");
106 | return 0;
107 | }
108 |
109 | int num_channels = bottom_grad.size(1);
110 | int data_height = bottom_grad.size(2);
111 | int data_width = bottom_grad.size(3);
112 |
113 | ROIAlignBackwardLaucher(top_grad, rois, spatial_scale, sample_num,
114 | num_channels, data_height, data_width, num_rois,
115 | pooled_height, pooled_width, bottom_grad);
116 |
117 | return 1;
118 | #else
119 | AT_ERROR("ROIAlign is not compiled with GPU support");
120 | #endif
121 | }
122 | AT_ERROR("ROIAlign is not implemented on CPU");
123 | }
124 |
125 | // Interface for Python
126 | inline at::Tensor ROIAlign_forwardV2(const at::Tensor& input,
127 | const at::Tensor& rois,
128 | const float spatial_scale,
129 | const int pooled_height,
130 | const int pooled_width,
131 | const int sampling_ratio, bool aligned) {
132 | if (input.device().is_cuda()) {
133 | #ifdef WITH_CUDA
134 | return ROIAlignForwardV2Laucher(input, rois, spatial_scale, pooled_height,
135 | pooled_width, sampling_ratio, aligned);
136 | #else
137 | AT_ERROR("ROIAlignV2 is not compiled with GPU support");
138 | #endif
139 | }
140 | return ROIAlignForwardV2CPULaucher(input, rois, spatial_scale, pooled_height,
141 | pooled_width, sampling_ratio, aligned);
142 | }
143 |
144 | inline at::Tensor ROIAlign_backwardV2(
145 | const at::Tensor& grad, const at::Tensor& rois, const float spatial_scale,
146 | const int pooled_height, const int pooled_width, const int batch_size,
147 | const int channels, const int height, const int width,
148 | const int sampling_ratio, bool aligned) {
149 | if (grad.device().is_cuda()) {
150 | #ifdef WITH_CUDA
151 | return ROIAlignBackwardV2Laucher(grad, rois, spatial_scale, pooled_height,
152 | pooled_width, batch_size, channels, height,
153 | width, sampling_ratio, aligned);
154 | #else
155 | AT_ERROR("ROIAlignV2 is not compiled with GPU support");
156 | #endif
157 | }
158 | return ROIAlignBackwardV2CPULaucher(grad, rois, spatial_scale, pooled_height,
159 | pooled_width, batch_size, channels,
160 | height, width, sampling_ratio, aligned);
161 | }
162 |
163 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
164 | m.def("forward_v1", &ROIAlign_forwardV1, "Roi_Align V1 forward");
165 | m.def("backward_v1", &ROIAlign_backwardV1, "Roi_Align V1 backward");
166 | m.def("forward_v2", &ROIAlign_forwardV2, "Roi_Align V2 forward");
167 | m.def("backward_v2", &ROIAlign_backwardV2, "Roi_Align V2 backward");
168 | }
169 |
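With the extension built, the V2 bindings can be exercised directly from Python. A minimal smoke test might look like the sketch below; it assumes setup.py names the module orbbec.roi_align.roi_align_ext (the actual name comes from TORCH_EXTENSION_NAME), and it uses the [K, 5] ROI layout (batch_index, x1, y1, x2, y2) that the CUDA kernel reads.

# Hypothetical smoke test for the forward_v2 binding.
import torch
from orbbec.roi_align import roi_align_ext  # assumed module path

feat = torch.randn(2, 16, 32, 32, device='cuda')
rois = torch.tensor([[0, 4.0, 4.0, 20.0, 20.0],
                     [1, 0.0, 0.0, 15.0, 31.0]], device='cuda')  # [K, 5]

out = roi_align_ext.forward_v2(feat, rois,
                               0.5,    # spatial_scale
                               7, 7,   # pooled_height, pooled_width
                               2,      # sampling_ratio
                               True)   # aligned
print(out.shape)  # expected: [2, 16, 7, 7]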
--------------------------------------------------------------------------------
/orbbec/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # from . import compiling_info
2 | from .compiling_info import get_compiler_version, get_compiling_cuda_version
3 |
4 | # get_compiler_version = compiling_info.get_compiler_version
5 | # get_compiling_cuda_version = compiling_info.get_compiling_cuda_version
6 |
7 | __all__ = ['get_compiler_version', 'get_compiling_cuda_version']
8 |
--------------------------------------------------------------------------------
/orbbec/utils/src/compiling_info.cpp:
--------------------------------------------------------------------------------
1 | // modified from
2 | // https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/vision.cpp
3 | #include <torch/extension.h>
4 |
5 | #ifdef WITH_CUDA
6 | #include <cuda_runtime_api.h>
7 | int get_cudart_version() { return CUDART_VERSION; }
8 | #endif
9 |
10 | std::string get_compiling_cuda_version() {
11 | #ifdef WITH_CUDA
12 | std::ostringstream oss;
13 |
14 | // copied from
15 | // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
16 | auto printCudaStyleVersion = [&](int v) {
17 | oss << (v / 1000) << "." << (v / 10 % 100);
18 | if (v % 10 != 0) {
19 | oss << "." << (v % 10);
20 | }
21 | };
22 | printCudaStyleVersion(get_cudart_version());
23 | return oss.str();
24 | #else
25 | return std::string("not available");
26 | #endif
27 | }
28 |
29 | // similar to
30 | // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Version.cpp
31 | std::string get_compiler_version() {
32 | std::ostringstream ss;
33 | #if defined(__GNUC__)
34 | #ifndef __clang__
35 | { ss << "GCC " << __GNUC__ << "." << __GNUC_MINOR__; }
36 | #endif
37 | #endif
38 |
39 | #if defined(__clang_major__)
40 | {
41 | ss << "clang " << __clang_major__ << "." << __clang_minor__ << "."
42 | << __clang_patchlevel__;
43 | }
44 | #endif
45 |
46 | #if defined(_MSC_VER)
47 | { ss << "MSVC " << _MSC_FULL_VER; }
48 | #endif
49 | return ss.str();
50 | }
51 |
52 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
53 | m.def("get_compiler_version", &get_compiler_version, "get_compiler_version");
54 | m.def("get_compiling_cuda_version", &get_compiling_cuda_version, "get_compiling_cuda_version");
55 | }
56 |
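Once compiled, the two functions bound here are re-exported by orbbec/utils/__init__.py, so a quick sanity check from Python looks like this (printed values are illustrative):

from orbbec.utils import get_compiler_version, get_compiling_cuda_version

print(get_compiler_version())        # e.g. "GCC 7.5" or "MSVC 192930147"
print(get_compiling_cuda_version())  # e.g. "10.2", or "not available" for a CPU-only build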
--------------------------------------------------------------------------------
/orbbec/warpaffine/__init__.py:
--------------------------------------------------------------------------------
1 | from .warpaffine_ext import affine_opencv, affine_torch
2 |
3 | __all__ = ['affine_opencv', 'affine_torch']
--------------------------------------------------------------------------------
/orbbec/warpaffine/src/cpu/warpaffine_opencv.cpp:
--------------------------------------------------------------------------------
1 | #include <iostream>
2 | #include <stdexcept>
3 | #include <opencv2/core/core.hpp>
4 | #include <opencv2/imgproc/imgproc.hpp>
5 | #include <pybind11/pybind11.h>
6 | #include <pybind11/numpy.h>
7 |
8 |
9 | namespace py = pybind11;
10 |
11 | /* Python->C++ Mat */
12 | cv::Mat numpy_uint8_1c_to_cv_mat(py::array_t<unsigned char>& input)
13 | {
14 | if (input.ndim() != 2)
15 | throw std::runtime_error("1-channel image must be 2 dims ");
16 |
17 | py::buffer_info buf = input.request();
18 | cv::Mat mat(buf.shape[0], buf.shape[1], CV_8UC1, (unsigned char*)buf.ptr);
19 |
20 | return mat;
21 | }
22 |
23 | cv::Mat numpy_uint8_3c_to_cv_mat(py::array_t<unsigned char>& input)
24 | {
25 | if (input.ndim() != 3)
26 | throw std::runtime_error("3-channel image must be 3 dims ");
27 |
28 | py::buffer_info buf = input.request();
29 | cv::Mat mat(buf.shape[0], buf.shape[1], CV_8UC3, (unsigned char*)buf.ptr);
30 |
31 | return mat;
32 | }
33 |
34 | /* C++ Mat ->numpy */
35 | py::array_t<unsigned char> cv_mat_uint8_1c_to_numpy(cv::Mat& input)
36 | {
37 | py::array_t<unsigned char> dst;
38 | dst = py::array_t<unsigned char>({ input.rows,input.cols }, input.data);
39 |
40 | return dst;
41 | }
42 |
43 | py::array_t<unsigned char> cv_mat_uint8_3c_to_numpy(cv::Mat& input)
44 | {
45 | py::array_t<unsigned char> dst;
46 | dst = py::array_t<unsigned char>({ input.rows,input.cols,3}, input.data);
47 |
48 | return dst;
49 | }
50 |
51 | py::array_t<unsigned char> affine_opencv(py::array_t<unsigned char>& input,
52 | py::array_t<float>& from_point,
53 | py::array_t<float>& to_point)
54 | {
55 | // step1: get affine transform matrix
56 | py::buffer_info from_p_buf = from_point.request();
57 | py::buffer_info from_t_buf = to_point.request();
58 | float* fp = (float*)from_p_buf.ptr;
59 | float* tp = (float*)from_t_buf.ptr;
60 | int fp_stride = from_p_buf.shape[1];
61 | int tp_stride = from_t_buf.shape[1];
62 |
63 | cv::Point2f src[3] = {};
64 | cv::Point2f dst[3] = {};
65 |
66 | for(int i = 0; i < from_p_buf.shape[0]; i++)
67 | {
68 | src[i] = cv::Point2f(fp[fp_stride * i + 0], fp[fp_stride * i + 1]);
69 | dst[i] = cv::Point2f(tp[tp_stride * i + 0], tp[tp_stride * i + 1]);
70 | }
71 |
72 | cv::Mat H = cv::getAffineTransform(src, dst);
73 |
74 | // step2: run affine transform
75 | cv::Mat input_mat = numpy_uint8_1c_to_cv_mat(input);
76 | cv::Mat output;
77 | cv::warpAffine(input_mat, output, H, cv::Size(600, 800), cv::INTER_LINEAR);
78 |
79 | return cv_mat_uint8_1c_to_numpy(output);
80 | }
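A minimal sketch of calling this binding from Python, assuming the extension is importable as orbbec.warpaffine (per the package __init__.py). The input must be a 2-D uint8 array and the two point sets float32 arrays of three (x, y) rows; note that the output size is hard-coded to cv::Size(600, 800), so an 800x600 (H x W) array comes back.

import numpy as np
from orbbec.warpaffine import affine_opencv

img = np.random.randint(0, 256, size=(800, 600), dtype=np.uint8)  # 1-channel, 2-D
src_pts = np.array([[38.3, 51.7], [86.2, 50.9], [62.5, 71.7]], dtype=np.float32)
dst_pts = np.array([[192.0, 240.0], [408.0, 240.0], [300.0, 368.0]], dtype=np.float32)

warped = affine_opencv(img, src_pts, dst_pts)
print(warped.shape, warped.dtype)  # (800, 600) uint8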
--------------------------------------------------------------------------------
/orbbec/warpaffine/src/cpu/warpaffine_torch_v1.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | template <typename scalar_t>
4 | void affine_cpu_kernel(
5 | int32_t inHeight,
6 | int32_t inWidth,
7 | int32_t inChannel,
8 | int32_t inPlanarSize,
9 | int32_t outHeight,
10 | int32_t outWidth,
11 | int32_t outPlanarSize,
12 | scalar_t* dst,
13 | const scalar_t* src,
14 | const scalar_t* M,
15 | float delta);
16 |
17 | template <typename scalar_t>
18 | at::Tensor affine_torch_cpu(const at::Tensor& input, /*[B, C, H, W]*/
19 | const at::Tensor& from, /*[B, 3, 3]*/
20 | const at::Tensor& to, /*[B, 3, 2]*/
21 | const int out_h,
22 | const int out_w)
23 | {
24 | // step1. get affine transform matrix
25 | AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
26 | AT_ASSERTM(from.device().is_cpu(), "from point must be a CPU tensor");
27 | AT_ASSERTM(to.device().is_cpu(), "to point must be a CPU tensor");
28 |
29 | // F = ((X^T*X)^-1)*(X^T*Y)
30 | auto matrix_l = (torch::transpose(from,1, 2).bmm(from)).inverse(); //(X^T*X)^-1)
31 | auto matrix_l_ptr = matrix_l.contiguous().data_ptr<scalar_t>();
32 | auto matrix_r = (torch::transpose(from, 1, 2)).bmm(to); //(X^T*Y)
33 | auto matrix_r_ptr = matrix_r.contiguous().data_ptr<scalar_t>();
34 | auto affine = matrix_l.bmm(matrix_r);
35 | auto affine_ptr = affine.data_ptr<scalar_t>();
36 |
37 | auto affine_matrix = torch::transpose(affine, 1, 2); // ((X^T*X)^-1) * (X^T*Y))^T --> [B, 2, 3].
38 |
39 | // step2. affine per imgs
40 | // get data pointer
41 | auto matrix_ptr = affine_matrix.contiguous().data_ptr<scalar_t>();
42 | auto input_ptr = input.contiguous().data_ptr<scalar_t>();
43 |
44 | auto nimgs = input.size(0);
45 | auto img_c = input.size(1);
46 | auto img_h = input.size(2);
47 | auto img_w = input.size(3);
48 | auto input_size = img_c * img_h * img_w;
49 | auto output_size = img_c * out_h * out_w;
50 |
51 | // build dst tensor
52 | auto output_tensor = at::zeros({nimgs, img_c, out_h, out_w}, input.options());
53 | auto output_ptr = output_tensor.contiguous().data_ptr<scalar_t>();
54 |
55 | for(int i = 0; i < nimgs; i++)
56 | {
57 | scalar_t* matrix = matrix_ptr + i * 6;
58 | scalar_t* in = input_ptr + i * input_size;
59 | scalar_t* out = output_ptr + i * output_size;
60 | affine_cpu_kernel<scalar_t>(img_h, img_w, img_c, img_w*img_h,
61 | out_h, out_w, out_h*out_w, out, in, matrix, 0.0f);
62 | }
63 |
64 | return output_tensor;
65 | }
66 |
67 |
68 | at::Tensor affine_cpu(const at::Tensor& input, /*[B, C, H, W]*/
69 | const at::Tensor& from, /*[B, 3, 3]*/
70 | const at::Tensor& to, /*[B, 3, 2]*/
71 | const int out_h,
72 | const int out_w)
73 | {
74 | at::Tensor result;
75 | // AT_DISPATCH_FLOATING_TYPES: input.scalar_type() => scalar_t
76 | AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "affine_cpu", [&] {
77 | result = affine_torch_cpu<scalar_t>(input, from, to, out_h, out_w);
78 | });
79 |
80 | return result;
81 | }
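Step 1 of affine_torch_cpu is an ordinary least-squares fit, F = (X^T X)^{-1} X^T Y, solved per batch item with bmm. The short PyTorch sketch below reproduces that step with made-up points; it can help when checking the [B, 2, 3] matrix that the kernel ultimately consumes.

import torch

from_pts = torch.tensor([[[38.3, 51.7, 1.0],
                          [86.2, 50.9, 1.0],
                          [62.5, 71.7, 1.0]]])             # X: [B=1, 3, 3], homogeneous [x, y, 1]
to_pts = torch.tensor([[[192.0, 240.0],
                        [408.0, 240.0],
                        [300.0, 368.0]]])                  # Y: [B=1, 3, 2]

xt = from_pts.transpose(1, 2)                              # X^T
affine = (xt.bmm(from_pts)).inverse().bmm(xt.bmm(to_pts))  # (X^T X)^-1 (X^T Y): [B, 3, 2]
matrix = affine.transpose(1, 2)                            # [B, 2, 3], row-major as the kernel expects
print(matrix)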
--------------------------------------------------------------------------------
/orbbec/warpaffine/src/cpu/warpaffine_torch_v2.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | // #include <opencv2/opencv.hpp>
3 |
4 | /*
5 | https://www.cnblogs.com/shine-lee/p/10950963.html
6 | */
7 | template <typename scalar_t>
8 | void affine_cpu_kernel(
9 | int32_t inHeight,
10 | int32_t inWidth,
11 | int32_t inChannel,
12 | int32_t inPlanarSize,
13 | int32_t outHeight,
14 | int32_t outWidth,
15 | int32_t outPlanarSize,
16 | scalar_t* dst,
17 | const scalar_t* src,
18 | const scalar_t* M,
19 | float delta)
20 | {
21 | for (int32_t i = 0; i < outHeight; i++)
22 | {
23 | float base_x = M[1] * i + M[2];
24 | float base_y = M[4] * i + M[5];
25 |
26 | for (int32_t j = 0; j < outWidth; j++)
27 | {
28 | float x = base_x + M[0] * j;
29 | float y = base_y + M[3] * j;
30 | int32_t sx0 = (int32_t)x;
31 | int32_t sy0 = (int32_t)y;
32 |
33 | float u = x - sx0;
34 | float v = y - sy0;
35 |
36 | float tab[4];
37 | float taby[2], tabx[2];
38 | float v0, v1, v2, v3;
39 | taby[0] = 1.0f - v;
40 | taby[1] = v;
41 | tabx[0] = 1.0f - u;
42 | tabx[1] = u;
43 |
44 | tab[0] = taby[0] * tabx[0];
45 | tab[1] = taby[0] * tabx[1];
46 | tab[2] = taby[1] * tabx[0];
47 | tab[3] = taby[1] * tabx[1];
48 |
49 | int32_t idxDst = (i * outWidth + j);
50 |
51 | bool flag0 = (sx0 >= 0 && sx0 < inWidth && sy0 >= 0 && sy0 < inHeight);
52 | bool flag1 = (sx0 + 1 >= 0 && sx0 + 1 < inWidth && sy0 >= 0 && sy0 < inHeight);
53 | bool flag2 = (sx0 >= 0 && sx0 < inWidth && sy0 + 1 >= 0 && sy0 + 1 < inHeight);
54 | bool flag3 = (sx0 + 1 >= 0 && sx0 + 1 < inWidth && sy0 + 1 >= 0 && sy0 + 1 < inHeight);
55 |
56 | for(int32_t c = 0; c < inChannel; c++)
57 | {
58 | int32_t position1 = ((sy0 + 0) * inWidth + sx0);
59 | int32_t position2 = ((sy0 + 1) * inWidth + sx0);
60 | v0 = flag0 ? src[position1 + c * inPlanarSize + 0] : delta;
61 | v1 = flag1 ? src[position1 + c * inPlanarSize + 1] : delta;
62 | v2 = flag2 ? src[position2 + c * inPlanarSize + 0] : delta;
63 | v3 = flag3 ? src[position2 + c * inPlanarSize + 1] : delta;
64 | scalar_t sum = 0.0f;
65 | sum += v0 * tab[0] + v1 * tab[1] + v2 * tab[2] + v3 * tab[3];
66 | dst[idxDst + c * outPlanarSize] = static_cast<scalar_t>(sum);
67 | }
68 | }
69 | }
70 | }
71 |
72 | template <typename scalar_t>
73 | at::Tensor affine_torch_cpu(const at::Tensor& input, /*[B, C, H, W]*/
74 | const at::Tensor& affine_matrix, /*[B, 2, 3]*/
75 | const int out_h,
76 | const int out_w)
77 | {
78 | AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
79 | AT_ASSERTM(affine_matrix.device().is_cpu(), "affine_matrix must be a CPU tensor");
80 |
81 | // auto input_cpy = input.contiguous();
82 | // auto in_tensor = input_cpy.squeeze().permute({1, 2, 0}).contiguous();
83 | // in_tensor = in_tensor.mul(255).clamp(0, 255).to(torch::kU8);
84 | // in_tensor = in_tensor.to(torch::kCPU);
85 | // cv::Mat resultImg(800, 600, CV_8UC3);
86 | // std::memcpy((void *) resultImg.data, in_tensor.data_ptr(), sizeof(torch::kU8) * in_tensor.numel());
87 | // cv::imwrite("input.png", resultImg);
88 |
89 | auto matrix_ptr = affine_matrix.contiguous().data_ptr<scalar_t>();
90 | auto input_ptr = input.contiguous().data_ptr<scalar_t>();
91 | auto nimgs = input.size(0);
92 | auto img_c = input.size(1);
93 | auto img_h = input.size(2);
94 | auto img_w = input.size(3);
95 | auto in_img_size = img_c * img_h * img_w;
96 | auto out_img_size = img_c * out_h * out_w;
97 |
98 | // build dst tensor
99 | auto output_tensor = at::zeros({nimgs, img_c, out_h, out_w}, input.options());
100 | auto output_ptr = output_tensor.contiguous().data_ptr<scalar_t>();
101 |
102 | for(int i = 0; i < nimgs; i++)
103 | {
104 | scalar_t* matrix = matrix_ptr + i * 6;
105 | scalar_t* in = input_ptr + i * in_img_size;
106 | scalar_t* out = output_ptr + i * out_img_size;
107 | affine_cpu_kernel<scalar_t>(img_h, img_w, img_c, img_w*img_h,
108 | out_h, out_w, out_h*out_w, out, in, matrix, 0.0f);
109 | }
110 |
111 | return output_tensor;
112 | }
113 |
114 |
115 | at::Tensor affine_cpu(const at::Tensor& input, /*[B, C, H, W]*/
116 | const at::Tensor& affine_matrix, /*[B, 2, 3]*/
117 | const int out_h,
118 | const int out_w)
119 | {
120 | at::Tensor result;
121 | // AT_DISPATCH_FLOATING_TYPES: input.scalar_type() => scalar_t
122 | AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "affine_cpu", [&] {
123 | result = affine_torch_cpu<scalar_t>(input, affine_matrix, out_h, out_w);
124 | });
125 | return result;
126 | }
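The kernel above is an inverse-mapping warp: each destination pixel is projected back into the source with the row-major 2x3 matrix M and bilinearly interpolated, and samples falling outside the source are replaced by delta. A small pure-Python reference of that inner loop (toy sizes, hypothetical function name):

import numpy as np

def warp_affine_ref(src, M, out_h, out_w, delta=0.0):
    """Single-channel reference of affine_cpu_kernel's destination-to-source loop."""
    dst = np.full((out_h, out_w), delta, dtype=np.float32)
    for i in range(out_h):
        for j in range(out_w):
            x = M[0] * j + M[1] * i + M[2]
            y = M[3] * j + M[4] * i + M[5]
            x0, y0 = int(x), int(y)
            u, v = x - x0, y - y0
            acc = 0.0
            for dy, wy in ((0, 1.0 - v), (1, v)):
                for dx, wx in ((0, 1.0 - u), (1, u)):
                    yy, xx = y0 + dy, x0 + dx
                    inside = 0 <= xx < src.shape[1] and 0 <= yy < src.shape[0]
                    acc += (src[yy, xx] if inside else delta) * wy * wx
            dst[i, j] = acc
    return dst

src = np.arange(16, dtype=np.float32).reshape(4, 4)
M = np.array([1.0, 0.0, 0.5, 0.0, 1.0, 0.5], dtype=np.float32)  # half-pixel shift
print(warp_affine_ref(src, M, 4, 4))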
--------------------------------------------------------------------------------
/orbbec/warpaffine/src/cuda/warpaffine_cuda.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | #include <torch/extension.h>
3 |
4 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDAtensor ")
5 | #define CHECK_CONTIGUOUS(x) \
6 | TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
7 | #define CHECK_INPUT(x) \
8 | CHECK_CUDA(x); \
9 | CHECK_CONTIGUOUS(x)
10 |
11 | at::Tensor affine_cuda_forward(const at::Tensor& input, /*[B, C, H, W]*/
12 | const at::Tensor& affine_matrix, /*[B, 2, 3]*/
13 | const int out_h,
14 | const int out_w);
15 |
16 | at::Tensor affine_gpu(const at::Tensor& input, /*[B, C, H, W]*/
17 | const at::Tensor& affine_matrix, /*[B, 2, 3]*/
18 | const int out_h,
19 | const int out_w)
20 | {
21 | CHECK_INPUT(input);
22 | CHECK_INPUT(affine_matrix);
23 |
24 | // Ensure CUDA uses the input tensor device.
25 | at::DeviceGuard guard(input.device());
26 |
27 | return affine_cuda_forward(input, affine_matrix, out_h, out_w);
28 | }
--------------------------------------------------------------------------------
/orbbec/warpaffine/src/cuda/warpaffine_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | // /* pytorch: 1.5.0 ~ 1.10.x */
3 | #include <ATen/ATen.h>
4 | #include <ATen/cuda/CUDAContext.h>
5 | #include <c10/cuda/CUDAGuard.h>
6 |
7 | #include <THC/THC.h>
8 | #include <THC/THCAtomics.cuh>
9 | // /* ---------------------- */
10 |
11 | // /* pytorch: 1.11.0 ~ latest */
12 | // #include <ATen/ATen.h>
13 | // #include <ATen/cuda/CUDAContext.h>
14 | // #include <ATen/cuda/Atomic.cuh>