├── .gitignore ├── 1.4 └── cuda_config.h ├── LICENSE ├── common.cc ├── work_sharder.h ├── common.h ├── README.md ├── ps_roi_align_op.h ├── rotated_ps_roi_align_op.h ├── CMakeLists.txt ├── ps_roi_align_op.cu ├── ps_roi_align_grad_op.cu ├── rotated_ps_roi_align_op.cu ├── test_op.py ├── rotated_ps_roi_align_grad_op.cu ├── ps_roi_align_op.cc ├── rotated_ps_roi_align_op.cc ├── ps_roi_align_grad_op.cc └── rotated_ps_roi_align_grad_op.cc /.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Compiled Object files 5 | *.slo 6 | *.lo 7 | *.o 8 | *.obj 9 | 10 | # Precompiled Headers 11 | *.gch 12 | *.pch 13 | 14 | # Compiled Dynamic libraries 15 | *.so 16 | *.dylib 17 | *.dll 18 | 19 | # Fortran module files 20 | *.mod 21 | *.smod 22 | 23 | # Compiled Static libraries 24 | *.lai 25 | *.la 26 | *.a 27 | *.lib 28 | 29 | # Executables 30 | *.exe 31 | *.out 32 | *.app 33 | -------------------------------------------------------------------------------- /1.4/cuda_config.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | // DO NOT EDIT: automatically generated file 17 | #ifndef CUDA_CUDA_CONFIG_H_ 18 | #define CUDA_CUDA_CONFIG_H_ 19 | 20 | #define TF_CUDA_CAPABILITIES CudaVersion("3.0") 21 | 22 | #define TF_CUDA_VERSION "8.0" 23 | #define TF_CUDNN_VERSION "5" 24 | 25 | #define TF_CUDA_TOOLKIT_PATH "/usr/local/cuda-8.0" 26 | 27 | #endif // CUDA_CUDA_CONFIG_H_ 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Changan Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /common.cc: -------------------------------------------------------------------------------- 1 | // MIT License 2 | 3 | // Copyright (c) 2018 Changan Wang 4 | 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 14 | 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE. 
22 | #include "common.h" 23 | 24 | //__attribute__((always_inline)) 25 | // template::type>::value> 26 | void atomic_float_add(volatile float* ptr, const float operand) 27 | { 28 | assert(is_aligned(ptr, 4)); 29 | 30 | volatile int32_t* iptr = reinterpret_cast(ptr); 31 | int32_t expected = *iptr; 32 | 33 | while (true) 34 | { 35 | const float value = binary_cast(expected); 36 | const int32_t new_value = binary_cast(value + operand); 37 | const int32_t actual = __sync_val_compare_and_swap(iptr, expected, new_value); 38 | if (actual == expected) 39 | return; 40 | expected = actual; 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /work_sharder.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_UTIL_WORK_SHARDER_H_ 17 | #define TENSORFLOW_UTIL_WORK_SHARDER_H_ 18 | 19 | #include 20 | 21 | #include "tensorflow/core/lib/core/threadpool.h" 22 | #include "tensorflow/core/platform/types.h" 23 | 24 | namespace tensorflow { 25 | 26 | // Shards the "total" unit of work assuming each unit of work having 27 | // roughly "cost_per_unit". Each unit of work is indexed 0, 1, ..., 28 | // total - 1. 
Each shard contains 1 or more units of work and the 29 | // total cost of each shard is roughly the same. The calling thread and the 30 | // "workers" are used to compute each shard (calling work(start, 31 | // limit). A common configuration is that "workers" is a thread pool 32 | // with at least "max_parallelism" threads. 33 | // 34 | // "cost_per_unit" is an estimate of the number of CPU cycles (or nanoseconds 35 | // if not CPU-bound) to complete a unit of work. Overestimating creates too 36 | // many shards and CPU time will be dominated by per-shard overhead, such as 37 | // Context creation. Underestimating may not fully make use of the specified 38 | // parallelism. 39 | // 40 | // "work" should be a callable taking (int64, int64) arguments. 41 | // work(start, limit) computes the work units from [start, 42 | // limit), i.e., [start, limit) is a shard. 43 | // 44 | // REQUIRES: max_parallelism >= 0 45 | // REQUIRES: workers != nullptr 46 | // REQUIRES: total >= 0 47 | // REQUIRES: cost_per_unit >= 0 48 | void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total, 49 | int64 cost_per_unit, std::function work); 50 | 51 | } // end namespace tensorflow 52 | 53 | #endif // TENSORFLOW_UTIL_WORK_SHARDER_H_ 54 | -------------------------------------------------------------------------------- /common.h: -------------------------------------------------------------------------------- 1 | // MIT License 2 | 3 | // Copyright (c) 2018 Changan Wang 4 | 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | 12 | // The above copyright 
notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 14 | 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE. 22 | #ifndef COMMON_H_ 23 | #define COMMON_H_ 24 | 25 | #include 26 | #include 27 | #include 28 | 29 | // atomic addition for float using gcc built-in functions for atomic memory access 30 | // this code snippet borrowed from https://codereview.stackexchange.com/questions/135852/atomic-floating-point-addition 31 | template 32 | __attribute__((always_inline)) Target binary_cast(Source s) 33 | { 34 | static_assert(sizeof(Target) == sizeof(Source), "binary_cast: 'Target' must has the same size as 'Source'"); 35 | union 36 | { 37 | Source m_source; 38 | Target m_target; 39 | } u; 40 | 41 | u.m_source = s; 42 | return u.m_target; 43 | } 44 | 45 | template 46 | __attribute__((always_inline)) bool is_pow2(const T x) 47 | { 48 | return (x & (x - 1)) == 0; 49 | } 50 | 51 | template 52 | __attribute__((always_inline)) bool is_aligned(const T ptr, const size_t alignment) 53 | { 54 | assert(alignment > 0); 55 | assert(is_pow2(alignment)); 56 | 57 | const uintptr_t p = (uintptr_t)ptr; 58 | return (p & (alignment - 1)) == 0; 59 | } 60 | 61 | extern void atomic_float_add(volatile float* ptr, const float operand); 62 | 63 | #endif // COMMON_H_ 64 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PsRoIAlign Operation In 
Tensorflow C++ API 2 | PsRoIAlign adds interpolation to [PsRoiPooling](https://arxiv.org/abs/1605.06409) (the position-sensitive RoI pooling operation); the interpolation idea was proposed in [RoIAlign](https://arxiv.org/abs/1703.06870) to avoid any quantization of the RoI boundaries. PsRoIAlign may have been first adopted in the paper [Light-Head R-CNN: In Defense of Two-Stage Object Detector](https://arxiv.org/abs/1711.07264). 3 | 4 | You can find more details about the RoIPooling technique in [Fast R-CNN](https://arxiv.org/abs/1504.08083) and [Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition](https://arxiv.org/abs/1406.4729). 5 | 6 | ## ## 7 | This repository contains an implementation of the PsRoIAlign operation using the TensorFlow C++ API. You can use this operation in many popular two-stage object detectors. Both research work using PsRoIAlign and contributions to this repository are welcome. 8 | 9 | To use this op on your own machine, just follow these steps: 10 | 11 | - copy the header file "cuda\_config.h" from "your\_python\_path/site-packages/external/local\_config\_cuda/cuda/cuda/cuda\_config.h" to "your\_python\_path/site-packages/tensorflow/include/tensorflow/stream\_executor/cuda/cuda\_config.h". 12 | 13 | - run the following script: 14 | 15 | 16 | ```sh 17 | mkdir build 18 | cd build && cmake .. 19 | make 20 | ``` 21 | 22 | - run "test\_op.py" and check the numeric errors to test your installation 23 | 24 | - follow the code snippet below to integrate this Op into your own code: 25 | 26 | ```python 27 | op_module = tf.load_op_library(so_lib_path) 28 | ps_roi_align = op_module.ps_roi_align 29 | 30 | @ops.RegisterGradient("PsRoiAlign") 31 | def _ps_roi_align_grad(op, grad, _): 32 | '''The gradients for `PsRoiAlign`.
33 | ''' 34 | inputs_features = op.inputs[0] 35 | rois = op.inputs[1] 36 | pooled_features_grad = op.outputs[0] 37 | pooled_index = op.outputs[1] 38 | grid_dim_width = op.get_attr('grid_dim_width') 39 | grid_dim_height = op.get_attr('grid_dim_height') 40 | 41 | return [op_module.ps_roi_align_grad(inputs_features, rois, grad, pooled_index, grid_dim_width, grid_dim_height, pool_method), None] 42 | 43 | pool_method = 'max' # or 'mean' 44 | pool_result = ps_roi_align(features, rois, 2, 2, pool_method) 45 | ``` 46 | 47 | The code has been tested with TensorFlow 1.6 and CUDA 8.0 on Ubuntu 16.04. This PsRoIAlign Op has been used to successfully train an Xception-based [Light-Head RCNN](https://arxiv.org/abs/1711.07264), reaching ~75% mAP on the PASCAL VOC 2007 test set; you can find the code [here](https://github.com/HiKapok/X-Detector). 48 | 49 | Update: 50 | 51 | - Added support for mean pooling (default is max pooling) 52 | - PsRoIAlign now supports oriented RoI inputs (for both max and mean pooling). 53 | 54 | Future Work: 55 | 56 | - Check whether the convexity of the polygon needs to be ensured 57 | - Improve performance 58 | 59 | If you encounter linkage problems when building or loading the *.so, it is highly recommended that you read this section of the [official tutorial](https://www.tensorflow.org/extend/adding_an_op#compile_the_op_using_your_system_compiler_tensorflow_binary_installation) to make sure you are using the same C++ ABI version.
60 | 61 | ## ## 62 | The MIT License -------------------------------------------------------------------------------- /ps_roi_align_op.h: -------------------------------------------------------------------------------- 1 | // MIT License 2 | 3 | // Copyright (c) 2018 Changan Wang 4 | 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 14 | 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE. 
22 | #ifndef KERNEL_PSROI_POOLING_H_ 23 | #define KERNEL_PSROI_POOLING_H_ 24 | 25 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 26 | #include "tensorflow/core/framework/tensor_types.h" 27 | #include "tensorflow/core/framework/op_kernel.h" 28 | 29 | #include 30 | #include 31 | #include 32 | #include 33 | 34 | using tensorflow::TTypes; 35 | using tensorflow::OpKernelContext; 36 | 37 | using CPUDevice = Eigen::ThreadPoolDevice; 38 | using GPUDevice = Eigen::GpuDevice; 39 | /* Launch dimensions packed into one argument; the GPU functor in ps_roi_align_op.cu unpacks it as (batch_size, num_channals, map_height, map_width, num_rois, using_max_pool). */ 40 | using KDimSize = std::tuple; 41 | /* Forward pass, specialised per Device: pools every RoI in `rois` over `inputs` into `pooled_features`. For max pooling, `pooled_index` receives the position of the winning sample inside its bin; for mean pooling it is 0. The tensor layouts follow from PSROIAlignCudaKernel's indexing in ps_roi_align_op.cu. */ 42 | template 43 | struct PSROIAlignFunctor { 44 | void operator()(OpKernelContext* context, const Device& d, typename TTypes::ConstFlat inputs, typename TTypes::ConstFlat rois, const int32_t grid_dim_width, const int32_t grid_dim_height, typename TTypes::Flat pooled_features, typename TTypes::Flat pooled_index, KDimSize dim_info); 45 | }; 46 | /* Backward pass of PSROIAlignFunctor: consumes the incoming gradient `pooled_features_grad` and the forward `pooled_index`, and writes `grad_output`. NOTE(review): the implementation is not visible here; presumably `grad_output` has the shape of `inputs` - confirm in ps_roi_align_grad_op.cu. */ 47 | template 48 | struct PSROIAlignGradFunctor { 49 | void operator()(OpKernelContext* context, const Device& d, typename TTypes::ConstFlat inputs, typename TTypes::ConstFlat rois, const int32_t grid_dim_width, const int32_t grid_dim_height, typename TTypes::ConstFlat pooled_features_grad, typename TTypes::ConstFlat pooled_index, typename TTypes::Flat grad_output, KDimSize dim_info); 50 | }; 51 | /* GPU specialisation of the forward functor; its operator() is defined in ps_roi_align_op.cu. */ 52 | #if GOOGLE_CUDA == 1 53 | template 54 | struct PSROIAlignFunctor { 55 | void operator()(OpKernelContext* context, const GPUDevice& d, typename TTypes::ConstFlat inputs, typename TTypes::ConstFlat rois, const int32_t grid_dim_width, const int32_t grid_dim_height, typename TTypes::Flat pooled_features, typename TTypes::Flat pooled_index, KDimSize dim_info); 56 | }; 57 | #endif 58 | /* GPU specialisation of the backward functor; NOTE(review): presumably defined in ps_roi_align_grad_op.cu (compiled by CMakeLists.txt) - confirm. */ 59 | #if GOOGLE_CUDA == 1 60 | template 61 | struct PSROIAlignGradFunctor { 62 | void operator()(OpKernelContext* context, const GPUDevice& d, typename TTypes::ConstFlat inputs, typename TTypes::ConstFlat rois, const int32_t grid_dim_width, const int32_t grid_dim_height, typename TTypes::ConstFlat pooled_features_grad, typename
TTypes::ConstFlat pooled_index, typename TTypes::Flat grad_output, KDimSize dim_info); 63 | }; 64 | #endif 65 | 66 | #endif // KERNEL_PSROI_POOLING_H_ 67 | 68 | -------------------------------------------------------------------------------- /rotated_ps_roi_align_op.h: -------------------------------------------------------------------------------- 1 | // MIT License 2 | 3 | // Copyright (c) 2018 Changan Wang 4 | 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 14 | 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE.
22 | #ifndef KERNEL_ROTATED_PSROI_POOLING_H_ 23 | #define KERNEL_ROTATED_PSROI_POOLING_H_ 24 | 25 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 26 | #include "tensorflow/core/framework/tensor_types.h" 27 | #include "tensorflow/core/framework/op_kernel.h" 28 | 29 | #include 30 | #include 31 | #include 32 | #include 33 | 34 | using tensorflow::TTypes; 35 | using tensorflow::OpKernelContext; 36 | 37 | using CPUDevice = Eigen::ThreadPoolDevice; 38 | using GPUDevice = Eigen::GpuDevice; 39 | /* Launch dimensions packed into one argument. NOTE(review): presumably packed like KDimSize in ps_roi_align_op.h, i.e. (batch_size, num_channals, map_height, map_width, num_rois, using_max_pool) - confirm against rotated_ps_roi_align_op.cu. */ 40 | using KDimSize = std::tuple; 41 | /* NOTE(review): an object-like macro named PI is visible to every file that includes this header and can clash with other PI definitions; a namespaced constexpr constant would be safer. */ 42 | #define PI 3.14159265359 43 | /* Forward pass for rotated RoIs: same outputs as the axis-aligned PSROIAlignFunctor (`pooled_features`, `pooled_index`) plus an extra `orders` input. NOTE(review): the meaning of `orders` (presumably a vertex ordering for each rotated RoI) is not visible here - confirm in rotated_ps_roi_align_op.cu. */ 44 | template 45 | struct RotatedPSROIAlignFunctor { 46 | void operator()(OpKernelContext* context, const Device& d, typename TTypes::ConstFlat inputs, typename TTypes::ConstFlat rois, typename TTypes::ConstFlat orders, const int32_t grid_dim_width, const int32_t grid_dim_height, typename TTypes::Flat pooled_features, typename TTypes::Flat pooled_index, KDimSize dim_info); 47 | }; 48 | /* Backward pass of RotatedPSROIAlignFunctor: consumes `pooled_features_grad` and the forward `pooled_index`, and writes `grad_output`. */ 49 | template 50 | struct RotatedPSROIAlignGradFunctor { 51 | void operator()(OpKernelContext* context, const Device& d, typename TTypes::ConstFlat inputs, typename TTypes::ConstFlat rois, typename TTypes::ConstFlat orders, const int32_t grid_dim_width, const int32_t grid_dim_height, typename TTypes::ConstFlat pooled_features_grad, typename TTypes::ConstFlat pooled_index, typename TTypes::Flat grad_output, KDimSize dim_info); 52 | }; 53 | /* GPU specialisation of the forward functor; NOTE(review): presumably defined in rotated_ps_roi_align_op.cu - confirm. */ 54 | #if GOOGLE_CUDA == 1 55 | template 56 | struct RotatedPSROIAlignFunctor { 57 | void operator()(OpKernelContext* context, const GPUDevice& d, typename TTypes::ConstFlat inputs, typename TTypes::ConstFlat rois, typename TTypes::ConstFlat orders, const int32_t grid_dim_width, const int32_t grid_dim_height, typename TTypes::Flat pooled_features, typename TTypes::Flat pooled_index, KDimSize dim_info); 58 | }; 59 | #endif 60 | /* GPU specialisation of the backward functor; NOTE(review): presumably defined in rotated_ps_roi_align_grad_op.cu - confirm. */ 61 | #if GOOGLE_CUDA == 1 62 | template 63 | struct RotatedPSROIAlignGradFunctor { 64 | void operator()(OpKernelContext* context, const GPUDevice& d,
typename TTypes::ConstFlat inputs, typename TTypes::ConstFlat rois, typename TTypes::ConstFlat orders, const int32_t grid_dim_width, const int32_t grid_dim_height, typename TTypes::ConstFlat pooled_features_grad, typename TTypes::ConstFlat pooled_index, typename TTypes::Flat grad_output, KDimSize dim_info); 65 | }; 66 | #endif 67 | 68 | #endif // KERNEL_ROTATED_PSROI_POOLING_H_ 69 | 70 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | CMAKE_MINIMUM_REQUIRED(VERSION 2.8) 2 | PROJECT(ps_roi_align) 3 | 4 | find_package(CUDA REQUIRED) 5 | 6 | # Pass options to NVCC 7 | set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_61,code=sm_61 -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC --expt-relaxed-constexpr") 8 | 9 | # compiler flags 10 | SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -O2 ${OpenMP_CXX_FLAGS} -Wall -fPIC -D_GLIBCXX_USE_CXX11_ABI=0 -DGOOGLE_CUDA=1") 11 | 12 | # TensorFlow dependencies 13 | EXECUTE_PROCESS(COMMAND python3.5 -c "import os; os.environ['TF_CPP_MIN_LOG_LEVEL']='3'; import tensorflow as tf; print(tf.sysconfig.get_include(), end='', flush=True)" OUTPUT_VARIABLE TF_INC) 14 | 15 | EXECUTE_PROCESS(COMMAND python3.5 -c "import os; os.environ['TF_CPP_MIN_LOG_LEVEL']='3'; import tensorflow as tf; print(tf.sysconfig.get_lib(), end='', flush=True)" OUTPUT_VARIABLE TF_LIB) 16 | 17 | 18 | MESSAGE(STATUS "Found TF_INC: " ${TF_INC}) 19 | MESSAGE(STATUS "Found TF_INC_EXTERNAL: " ${TF_INC}/external/nsync/public) 20 | MESSAGE(STATUS "Found TF_LIB: " ${TF_LIB}) 21 | 22 | 23 | INCLUDE_DIRECTORIES(${TF_INC}) 24 | INCLUDE_DIRECTORIES(${TF_INC}/external/nsync/public) 25 | LINK_DIRECTORIES(${TF_LIB}) 26 | 27 | # approach 1 28 | # CUDA_ADD_LIBRARY(ps_roi_align_gpu SHARED ps_roi_align_op.cu OPTIONS -I$TF_INC/tensorflow/stream_executor/cuda -I/usr/local) 29 | 30 | # ADD_LIBRARY(ps_roi_align SHARED 31 | # ps_roi_align_op.h 32 | #
ps_roi_align_op.cc 33 | # ) 34 | 35 | # TARGET_LINK_LIBRARIES(ps_roi_align tensorflow_framework ${CUDA_LIBRARIES} ps_roi_align_gpu) 36 | 37 | 38 | # approach 2
# Compile every .cu file into an object with nvcc, then link those objects
# together with the host-side sources into a single shared library.
#
# CMake only expands variables written as ${VAR}. The previous "-I$TF_INC"
# was not a variable reference, so it reached nvcc unexpanded and never named
# the TensorFlow include directory. The shared nvcc options are kept in one
# list so the four compile steps cannot drift apart.
set(PSROI_NVCC_OPTIONS -I${TF_INC} -I/usr/local)

CUDA_COMPILE(PSROI_ALIGN_CU_O ps_roi_align_op.cu MODULE OPTIONS ${PSROI_NVCC_OPTIONS})
CUDA_COMPILE(PSROI_ALIGN_GRAD_CU_O ps_roi_align_grad_op.cu MODULE OPTIONS ${PSROI_NVCC_OPTIONS})
CUDA_COMPILE(ROTATED_PSROI_ALIGN_CU_O rotated_ps_roi_align_op.cu MODULE OPTIONS ${PSROI_NVCC_OPTIONS})
CUDA_COMPILE(ROTATED_PSROI_ALIGN_GRAD_CU_O rotated_ps_roi_align_grad_op.cu MODULE OPTIONS ${PSROI_NVCC_OPTIONS})

ADD_LIBRARY(ps_roi_align SHARED
  ${PSROI_ALIGN_CU_O}
  ${PSROI_ALIGN_GRAD_CU_O}
  ${ROTATED_PSROI_ALIGN_CU_O}
  ${ROTATED_PSROI_ALIGN_GRAD_CU_O}
  rotated_ps_roi_align_op.h
  rotated_ps_roi_align_op.cc
  rotated_ps_roi_align_grad_op.cc
  ps_roi_align_op.h
  ps_roi_align_op.cc
  ps_roi_align_grad_op.cc
  common.h
  common.cc
)

TARGET_LINK_LIBRARIES(ps_roi_align tensorflow_framework ${CUDA_LIBRARIES})
60 | 61 | 62 | #nvcc -std=c++11 -c -o cuda_op_kernel.cu.o ../focal_loss_op.cu -I $TF_INC -I$TF_INC/external/nsync/public -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -I/usr/local/ -I$TF_INC/tensorflow/stream_executor/cuda --expt-relaxed-constexpr -gencode arch=compute_61,code=sm_61 63 | 64 | 65 | #TF_INC=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())') 66 | #TF_LIB=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())') 67 | #g++ -std=c++11 -shared zero_out.cc -o zero_out.so -fPIC -I$TF_INC -I$TF_INC/external/nsync/public -L$TF_LIB -ltensorflow_framework -O2 68 | #nvcc -std=c++11 -c -o cuda_op_kernel.cu.o ../focal_loss_op.cu -I $TF_INC -I$TF_INC/external/nsync/public -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -I/usr/local/ -I$TF_INC/tensorflow/stream_executor/cuda 69 | #cd `python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())'` 70 | #cd tensorflow/stream_executor/cuda 71 | #curl -O
https://raw.githubusercontent.com/tensorflow/tensorflow/master/third_party/toolchains/gpus/cuda/cuda/cuda_config.h 72 | 73 | 74 | # g++ -std=c++11 -shared -o libfocal_loss.so focal_loss_op.cc focal_loss_grad_op.cc cuda_compile_generated_focal_loss_op.cu.o -I $TF_INC -I$TF_INC/external/nsync/public -fPIC -lcudart -L$TF_LIB -ltensorflow_framework 75 | 76 | # nvcc -std=c++11 -c -o focal_loss_op.cu.o focal_loss_op.cu -I$TF_INC -I$TF_INC/external/nsync/public -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -I/usr/local -O2 -gencode arch=compute_61,code=sm_61 --expt-relaxed-constexpr 77 | 78 | # g++ -std=c++11 -shared -o libfocal_loss.so focal_loss_op.cc focal_loss_grad_op.cc focal_loss_op.cu.o -I$TF_INC -I$TF_INC/external/nsync/public -fPIC -L/usr/local/cuda/lib64 -lcudart -L$TF_LIB -ltensorflow_framework 79 | 80 | 81 | -------------------------------------------------------------------------------- /ps_roi_align_op.cu: -------------------------------------------------------------------------------- 1 | // MIT License 2 | 3 | // Copyright (c) 2018 Changan Wang 4 | 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 14 | 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE. 22 | #if GOOGLE_CUDA == 1 23 | #define EIGEN_USE_GPU 24 | #include "ps_roi_align_op.h" 25 | #include "tensorflow/core/util/cuda_kernel_helper.h" 26 | #include "tensorflow/core/framework/register_types.h" 27 | #include "tensorflow/core/framework/tensor_shape.h" 28 | 29 | using namespace tensorflow; 30 | 31 | #include 32 | #include 33 | #include 34 | 35 | // Define the CUDA kernel.
//
// Position-sensitive RoI Align, forward pass.
//
// Launch: the GPU PSROIAlignFunctor builds a CudaLaunchConfig for
// batch_size * num_rois * num_channals virtual threads, and
// CUDA_1D_KERNEL_LOOP hands out those flat indices. Each index produces one
// element of pooled_features and one of pooled_index. No shared memory is
// used, and batch_size itself is not read here (the launch size covers it).
//
// Layouts implied by the indexing below:
//   inputs          [batch_size][num_channals][map_height][map_width]
//   rois            [batch_size][num_rois][4] as (y_center, x_center, height,
//                   width), each normalized by the feature-map height/width
//   pooled_features [batch_size][num_rois][num_channals]
//   pooled_index    [batch_size][num_rois][num_channals]
//
// Channels are split into grid_dim_height * grid_dim_width banks of
// bank_size = num_channals / (grid_dim_height * grid_dim_width) channels.
// Bank k pools spatial bin (k / grid_dim_width, k % grid_dim_width) of the
// RoI. Each bin is sampled on a regular sub-grid with bilinear interpolation
// and reduced by max (using_max_pool) or mean. pooled_index receives the
// row-major position of the winning sample inside its bin for max pooling,
// and 0 otherwise (including for degenerate RoIs).
template <typename T>
__global__ void PSROIAlignCudaKernel(CudaLaunchConfig config, const T * inputs, const T * rois,
                                     T * pooled_features, int32_t * pooled_index,
                                     const int32_t grid_dim_width, const int32_t grid_dim_height,
                                     const int batch_size, const int num_channals,
                                     const int map_height, const int map_width,
                                     const int num_rois, const bool using_max_pool) {
  const int32_t grid_size = grid_dim_width * grid_dim_height;
  const int32_t bank_size = num_channals / grid_size;

  CUDA_1D_KERNEL_LOOP(worker_index, config.virtual_thread_count) {
    // Decompose the flat index into (image, roi, channel) and find the
    // spatial bin that this channel's bank is responsible for.
    const int32_t position_index = (worker_index % num_channals) / bank_size;
    const int32_t row_index = position_index / grid_dim_width;
    const int32_t col_index = position_index % grid_dim_width;
    const int32_t channal_pos_remainder = worker_index % bank_size;
    const int32_t pool_index = worker_index / num_channals;
    const int32_t image_index = pool_index / num_rois;
    const int32_t roi_index = pool_index % num_rois;

    // Channel read from `inputs` and written in both outputs.
    const int32_t channel_index = (position_index % grid_size) * bank_size + channal_pos_remainder;
    const int32_t output_offset = (image_index * num_rois + roi_index) * num_channals + channel_index;

    const T * roi_to_pool = rois + (image_index * num_rois + roi_index) * 4;
    const T * feature_map_to_pool = inputs + (image_index * num_channals + channel_index) * map_height * map_width;
    T * pooled_features_start = pooled_features + output_offset;
    int32_t * pooled_index_start = pooled_index + output_offset;

    // Degenerate RoI (height or width below the smallest positive normal T):
    // emit zeros. Both outputs must be written; previously pooled_index was
    // skipped here and kept whatever the output buffer already contained.
    if (roi_to_pool[2] < std::numeric_limits<T>::min() || roi_to_pool[3] < std::numeric_limits<T>::min()) {
      *pooled_features_start = static_cast<T>(0);
      *pooled_index_start = static_cast<int32_t>(0);
      continue;
    }

    // Convert the normalized RoI to feature-map coordinates, enforce a minimum
    // extent of one cell, then clip it to the map.
    const T roi_y_center = static_cast<T>(ldg(roi_to_pool) * map_height);
    const T roi_x_center = static_cast<T>(ldg(roi_to_pool + 1) * map_width);
    const T roi_height = tf_max(ldg(roi_to_pool + 2) * map_height, static_cast<T>(1));
    const T roi_width = tf_max(ldg(roi_to_pool + 3) * map_width, static_cast<T>(1));

    const T roi_ymin = tf_max(roi_y_center - static_cast<T>(roi_height / 2.), static_cast<T>(0));
    const T roi_xmin = tf_max(roi_x_center - static_cast<T>(roi_width / 2.), static_cast<T>(0));
    const T roi_ymax = tf_min(roi_y_center + static_cast<T>(roi_height / 2.), static_cast<T>(map_height) - std::numeric_limits<T>::min());
    const T roi_xmax = tf_min(roi_x_center + static_cast<T>(roi_width / 2.), static_cast<T>(map_width) - std::numeric_limits<T>::min());

    const T clipped_h = roi_ymax - roi_ymin;
    const T clipped_w = roi_xmax - roi_xmin;
    const float pool_bin_width = static_cast<float>(clipped_w) / grid_dim_width;
    const float pool_bin_height = static_cast<float>(clipped_h) / grid_dim_height;
    // floor(bin extent) + 1 samples per axis.
    const int32_t num_elem_width = static_cast<int32_t>(pool_bin_width) + 1;
    const int32_t num_elem_height = static_cast<int32_t>(pool_bin_height) + 1;

    const float step_width_each_bin = pool_bin_width / num_elem_width;
    const float step_height_each_bin = pool_bin_height / num_elem_height;

    const float pool_width_start = roi_xmin + pool_bin_width * col_index;
    const float pool_height_start = roi_ymin + pool_bin_height * row_index;

    int32_t max_pool_ind = 0;
    T max_or_acc_elem = using_max_pool ? std::numeric_limits<T>::lowest() : static_cast<T>(0);
    for (int32_t h_ind = 0; h_ind < num_elem_height; ++h_ind) {
      for (int32_t w_ind = 0; w_ind < num_elem_width; ++w_ind) {
        // Sample at the centre of sub-cell (h_ind, w_ind) of this bin.
        const float col_to_pool = pool_width_start + step_width_each_bin * w_ind + step_width_each_bin / 2.;
        const float row_to_pool = pool_height_start + step_height_each_bin * h_ind + step_height_each_bin / 2.;
        const int32_t int_col_to_pool = static_cast<int32_t>(col_to_pool);
        const int32_t int_row_to_pool = static_cast<int32_t>(row_to_pool);
        const float float_col_to_pool = col_to_pool - int_col_to_pool;
        const float float_row_to_pool = row_to_pool - int_row_to_pool;
        // Right/lower interpolation neighbours, clamped to the map edge.
        const int32_t next_row = tf_min(int_row_to_pool + 1, map_height - 1);
        const int32_t next_col = tf_min(int_col_to_pool + 1, map_width - 1);

        const int32_t current_switch_ind = num_elem_width * h_ind + w_ind;
        // Bilinear interpolation between the four surrounding cells.
        const T temp_value = static_cast<T>(
            (1. - float_col_to_pool) * (1. - float_row_to_pool) * ldg(feature_map_to_pool + int_row_to_pool * map_width + int_col_to_pool) +
            (1. - float_col_to_pool) * float_row_to_pool * ldg(feature_map_to_pool + next_row * map_width + int_col_to_pool) +
            float_col_to_pool * (1. - float_row_to_pool) * ldg(feature_map_to_pool + int_row_to_pool * map_width + next_col) +
            float_col_to_pool * float_row_to_pool * ldg(feature_map_to_pool + next_row * map_width + next_col));
        if (using_max_pool) {
          if (max_or_acc_elem < temp_value) {
            max_or_acc_elem = temp_value;
            max_pool_ind = current_switch_ind;
          }
        } else {
          max_or_acc_elem += temp_value;
        }
      }
    }
    if (!using_max_pool) max_or_acc_elem /= static_cast<T>(num_elem_height * num_elem_width);
    *pooled_features_start = max_or_acc_elem;
    *pooled_index_start = using_max_pool ? max_pool_ind : static_cast<int32_t>(0);
  }
}
133 | 134 | template 135 | void PSROIAlignFunctor::operator()(OpKernelContext* context, const GPUDevice& d, typename TTypes::ConstFlat inputs, typename TTypes::ConstFlat rois, const int32_t grid_dim_width, const int32_t grid_dim_height, typename TTypes::Flat pooled_features, typename TTypes::Flat pooled_index, KDimSize dim_info) { 136 | 137 | int batch_size = 0; 138 | int num_channals = 0; 139 | int map_height = 0; 140 | int map_width = 0; 141 | int num_rois = 0; 142 | bool using_max_pool = false; 143 | 144 | std::tie(batch_size, num_channals, map_height, map_width, num_rois, using_max_pool) = dim_info; 145 | 146 | CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * num_rois * num_channals, d); 147 | PSROIAlignCudaKernel <<>> (config, inputs.data(), rois.data(), pooled_features.data(), pooled_index.data(), grid_dim_width, grid_dim_height, batch_size, num_channals, map_height, map_width, num_rois, using_max_pool); 149 | 150 | cudaError_t
err = cudaGetLastError(); 151 | if(cudaSuccess != err) 152 | { 153 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 154 | exit( -1 ); 155 | } 156 | } 157 | 158 | template struct PSROIAlignFunctor; 159 | // #define DEFINE_GPU_SPECS(T) \ 160 | // template struct PSROIAlignFunctorGPU; 161 | 162 | // TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS); 163 | 164 | #endif // GOOGLE_CUDA 165 | -------------------------------------------------------------------------------- /ps_roi_align_grad_op.cu: -------------------------------------------------------------------------------- 1 | // MIT License 2 | 3 | // Copyright (c) 2018 Changan Wang 4 | 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 14 | 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE. 
22 | #if GOOGLE_CUDA == 1 23 | #define EIGEN_USE_GPU 24 | #include "ps_roi_align_op.h" 25 | #include "tensorflow/core/util/cuda_kernel_helper.h" 26 | #include "tensorflow/core/framework/register_types.h" 27 | #include "tensorflow/core/framework/tensor_shape.h" 28 | 29 | using namespace tensorflow; 30 | 31 | #include 32 | #include 33 | #include 34 | 35 | // Define the CUDA kernel. 36 | template 37 | __global__ void PSROIAlignGradCudaKernel(CudaLaunchConfig config, const T * inputs, const T * rois, const T * pooled_features_grad, const int32_t * pooled_index, T * grad_output, const int32_t grid_dim_width, const int32_t grid_dim_height, const int batch_size, const int num_channals, const int map_height, const int map_width, const int num_rois, const bool using_max_pool) { 38 | 39 | const int32_t grid_size = grid_dim_width * grid_dim_height; 40 | const int32_t bank_size = num_channals / grid_size; 41 | 42 | CUDA_1D_KERNEL_LOOP(worker_index, config.virtual_thread_count) { 43 | // image_index * roi_index * channal_pos_remainder * row_index * col_index 44 | const int32_t position_index = (worker_index % num_channals) / bank_size; 45 | const int32_t row_index = position_index / grid_dim_width; 46 | const int32_t col_index = position_index % grid_dim_width; 47 | // position of the channal of pooled feature 48 | // position of the channal in the bank of feature map 49 | const int32_t channal_pos_remainder = worker_index % bank_size; 50 | const int32_t pool_index = worker_index / num_channals; 51 | const int32_t image_index = pool_index / num_rois; 52 | const int32_t roi_index = pool_index % num_rois; 53 | 54 | const T * roi_to_pool = rois + (image_index * num_rois + roi_index) * 4; 55 | 56 | if(ldg(roi_to_pool + 2) < std::numeric_limits::min() || ldg(roi_to_pool + 3) < std::numeric_limits::min()) continue; 57 | // T roi_ymin = static_cast(0); 58 | // T roi_xmin = static_cast(0); 59 | // T roi_ymax = static_cast(0); 60 | // T roi_xmax = static_cast(0); 61 | // fix ROI 62 | // 
std::tie(roi_ymin, roi_xmin, roi_ymax, roi_xmax) = [roi_to_pool, map_height, map_width](){ 63 | T _roi_y_center = static_cast(ldg(roi_to_pool) * map_height); 64 | T _roi_x_center = static_cast(ldg(roi_to_pool + 1) * map_width); 65 | T _roi_h = tf_max(ldg(roi_to_pool + 2) * map_height, static_cast(1)); 66 | T _roi_w = tf_max(ldg(roi_to_pool + 3) * map_width, static_cast(1)); 67 | 68 | T roi_ymin = tf_max(_roi_y_center - static_cast(_roi_h / 2.), static_cast(0)); 69 | T roi_xmin = tf_max(_roi_x_center - static_cast(_roi_w / 2.), static_cast(0)); 70 | T roi_ymax = tf_min(_roi_y_center + static_cast(_roi_h / 2.), static_cast(map_height) - std::numeric_limits::min()); 71 | T roi_xmax = tf_min(_roi_x_center + static_cast(_roi_w / 2.), static_cast(map_width) - std::numeric_limits::min()); 72 | // return std::make_tuple(roi_ymin, roi_xmin, roi_ymax, roi_xmax); 73 | // }(); 74 | 75 | T roi_h = roi_ymax - roi_ymin; 76 | T roi_w = roi_xmax - roi_xmin; 77 | float pool_bin_width = static_cast(roi_w) / grid_dim_width; 78 | float pool_bin_height = static_cast(roi_h) / grid_dim_height; 79 | int32_t num_elem_width = static_cast(pool_bin_width) + 1; 80 | int32_t num_elem_height = static_cast(pool_bin_height) + 1; 81 | 82 | // std::cout << "pool_bin_width: " << pool_bin_width << " pool_bin_height: " << pool_bin_height << " num_elem_width: " << num_elem_width << " num_elem_height: " << num_elem_height << std::endl; 83 | 84 | // std::cout << "worker_index: " << worker_index << " roi_index: " << roi_index 85 | // << " roi_ymin: " << roi_ymin << " roi_xmin: " << roi_xmin << " roi_ymax: " << roi_ymax << " roi_xmax: " << roi_xmax << " image_index: " << image_index << " position_index: " << (position_index % grid_size) << " channal_pos_remainder: " << channal_pos_remainder << std::endl; 86 | 87 | float step_width_each_bin = pool_bin_width / num_elem_width; 88 | float step_height_each_bin = pool_bin_height / num_elem_height; 89 | 90 | T * grad_output_start = reinterpret_cast(grad_output + 
(image_index * num_channals + position_index * bank_size + channal_pos_remainder) * map_height * map_width); 91 | 92 | const T * pooled_features_start = pooled_features_grad + worker_index; 93 | const int32_t * pooled_index_start = pooled_index + worker_index; 94 | // T * pooled_features_start = pooled_features_grad + image_index * (num_rois * num_channals) + roi_index * num_channals + (position_index % grid_size) * bank_size + channal_pos_remainder; 95 | // int32_t * pooled_index_start = pooled_index + image_index * (num_rois * num_channals) + roi_index * num_channals + (position_index % grid_size) * bank_size + channal_pos_remainder; 96 | 97 | float pool_width_start = roi_xmin + pool_bin_width * col_index; 98 | float pool_height_start = roi_ymin + pool_bin_height * row_index; 99 | 100 | if(using_max_pool){ 101 | const int32_t h_ind = ldg(pooled_index_start) / num_elem_width; 102 | const int32_t w_ind = ldg(pooled_index_start) % num_elem_width; 103 | 104 | float col_to_pool = pool_width_start + step_width_each_bin * w_ind + step_width_each_bin / 2.; 105 | float row_to_pool = pool_height_start + step_height_each_bin * h_ind + step_height_each_bin / 2.; 106 | //std::cout << "col_to_pool: " << col_to_pool << " row_to_pool: " << row_to_pool << std::endl; 107 | int32_t int_col_to_pool = static_cast(col_to_pool); 108 | int32_t int_row_to_pool = static_cast(row_to_pool); 109 | float float_col_to_pool = col_to_pool - int_col_to_pool; 110 | float float_row_to_pool = row_to_pool - int_row_to_pool; 111 | 112 | const T grad_in = ldg(pooled_features_start); 113 | 114 | atomicAdd(grad_output_start + int_row_to_pool * map_width + int_col_to_pool, static_cast((1. - float_col_to_pool) * (1. - float_row_to_pool) * grad_in)); 115 | atomicAdd(grad_output_start + tf_min(int_row_to_pool + 1, map_height - 1) * map_width + int_col_to_pool, static_cast((1. 
- float_col_to_pool) * float_row_to_pool * grad_in)); 116 | atomicAdd(grad_output_start + int_row_to_pool * map_width + tf_min(int_col_to_pool + 1, map_width - 1), static_cast(float_col_to_pool * (1. - float_row_to_pool) * grad_in)); 117 | atomicAdd(grad_output_start + tf_min(int_row_to_pool + 1, map_height - 1) * map_width + tf_min(int_col_to_pool + 1, map_width - 1), static_cast(float_col_to_pool * float_row_to_pool * grad_in)); 118 | }else{ 119 | const T grad_in = ldg(pooled_features_start) / static_cast(num_elem_width * num_elem_height); 120 | for (int32_t h_ind = 0; h_ind < num_elem_height; ++h_ind) { 121 | for (int32_t w_ind = 0; w_ind < num_elem_width; ++w_ind) { 122 | float col_to_pool = pool_width_start + step_width_each_bin * w_ind + step_width_each_bin / 2.; 123 | float row_to_pool = pool_height_start + step_height_each_bin * h_ind + step_height_each_bin / 2.; 124 | 125 | int32_t int_col_to_pool = static_cast(col_to_pool); 126 | int32_t int_row_to_pool = static_cast(row_to_pool); 127 | float float_col_to_pool = col_to_pool - int_col_to_pool; 128 | float float_row_to_pool = row_to_pool - int_row_to_pool; 129 | 130 | atomicAdd(grad_output_start + int_row_to_pool * map_width + int_col_to_pool, static_cast((1. - float_col_to_pool) * (1. - float_row_to_pool) * grad_in)); 131 | atomicAdd(grad_output_start + tf_min(int_row_to_pool + 1, map_height - 1) * map_width + int_col_to_pool, static_cast((1. - float_col_to_pool) * float_row_to_pool * grad_in)); 132 | atomicAdd(grad_output_start + int_row_to_pool * map_width + tf_min(int_col_to_pool + 1, map_width - 1), static_cast(float_col_to_pool * (1. 
- float_row_to_pool) * grad_in)); 133 | atomicAdd(grad_output_start + tf_min(int_row_to_pool + 1, map_height - 1) * map_width + tf_min(int_col_to_pool + 1, map_width - 1), static_cast(float_col_to_pool * float_row_to_pool * grad_in)); 134 | } 135 | } 136 | } 137 | } 138 | } 139 | 140 | template 141 | void PSROIAlignGradFunctor::operator()(OpKernelContext* context, const GPUDevice& d, typename TTypes::ConstFlat inputs, typename TTypes::ConstFlat rois, const int32_t grid_dim_width, const int32_t grid_dim_height, typename TTypes::ConstFlat pooled_features_grad, typename TTypes::ConstFlat pooled_index, typename TTypes::Flat grad_output, KDimSize dim_info) { 142 | 143 | int batch_size = 0; 144 | int num_channals = 0; 145 | int map_height = 0; 146 | int map_width = 0; 147 | int num_rois = 0; 148 | bool using_max_pool = false; 149 | 150 | std::tie(batch_size, num_channals, map_height, map_width, num_rois, using_max_pool) = dim_info; 151 | 152 | CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * num_rois * num_channals, d); 153 | //grad_output = grad_output.setZero(); 154 | SetZero <<>> (batch_size * map_height * map_width * num_channals, grad_output.data()); 155 | 156 | PSROIAlignGradCudaKernel <<>> (config, inputs.data(), rois.data(), pooled_features_grad.data(), pooled_index.data(), grad_output.data(), grid_dim_width, grid_dim_height, batch_size, num_channals, map_height, map_width, num_rois, using_max_pool); 158 | 159 | cudaError_t err = cudaGetLastError(); 160 | if(cudaSuccess != err) 161 | { 162 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 163 | exit( -1 ); 164 | } 165 | } 166 | 167 | template struct PSROIAlignGradFunctor; 168 | // #define DEFINE_GPU_SPECS(T) \ 169 | // template struct PSROIAlignFunctorGPU; 170 | 171 | // TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS); 172 | 173 | #endif // GOOGLE_CUDA 174 | -------------------------------------------------------------------------------- /rotated_ps_roi_align_op.cu: 
-------------------------------------------------------------------------------- 1 | // MIT License 2 | 3 | // Copyright (c) 2018 Changan Wang 4 | 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 14 | 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE. 22 | #if GOOGLE_CUDA == 1 23 | #define EIGEN_USE_GPU 24 | #include "rotated_ps_roi_align_op.h" 25 | #include "tensorflow/core/util/cuda_kernel_helper.h" 26 | #include "tensorflow/core/framework/register_types.h" 27 | #include "tensorflow/core/framework/tensor_shape.h" 28 | 29 | using namespace tensorflow; 30 | 31 | #include 32 | #include 33 | #include 34 | 35 | // Define the CUDA kernel. 
36 | template 37 | __global__ void RotatedPSROIAlignCudaKernel(CudaLaunchConfig config, const T * inputs, const T * rois, const int32_t * orders, T * pooled_features, int32_t * pooled_index, const int32_t grid_dim_width, const int32_t grid_dim_height, const int batch_size, const int num_channals, const int map_height, const int map_width, const int num_rois, const bool using_max_pool) { 38 | 39 | const int32_t grid_size = grid_dim_width * grid_dim_height; 40 | const int32_t bank_size = num_channals / grid_size; 41 | 42 | CUDA_1D_KERNEL_LOOP(worker_index, config.virtual_thread_count) { 43 | // image_index * roi_index * channal_pos_remainder * row_index * col_index 44 | const int32_t position_index = (worker_index % num_channals) / bank_size; 45 | const int32_t row_index = position_index / grid_dim_width; 46 | const int32_t col_index = position_index % grid_dim_width; 47 | // position of the channal of pooled feature 48 | // position of the channal in the bank of feature map 49 | const int32_t channal_pos_remainder = worker_index % bank_size; 50 | const int32_t pool_index = worker_index / num_channals; 51 | const int32_t image_index = pool_index / num_rois; 52 | const int32_t roi_index = pool_index % num_rois; 53 | 54 | const T * roi_to_pool = rois + (image_index * num_rois + roi_index) * 8; 55 | const int32_t * roi_order = orders + image_index * num_rois + roi_index; 56 | 57 | const T * feature_map_to_pool = inputs + (image_index * num_channals + (position_index % grid_size) * bank_size + channal_pos_remainder) * map_height * map_width; 58 | T * pooled_features_start = pooled_features + image_index * (num_rois * num_channals) + roi_index * num_channals + (position_index % grid_size) * bank_size + channal_pos_remainder; 59 | int32_t * pooled_index_start = pooled_index + image_index * (num_rois * num_channals) + roi_index * num_channals + (position_index % grid_size) * bank_size + channal_pos_remainder; 60 | 61 | int32_t order = ldg(roi_order) < 0 ? 
0 : ldg(roi_order) * 2; 62 | 63 | T roi_y0 = static_cast(ldg(roi_to_pool + (order++) % 8) * map_height); 64 | T roi_x0 = static_cast(ldg(roi_to_pool + (order++) % 8) * map_width); 65 | T roi_y1 = static_cast(ldg(roi_to_pool + (order++) % 8) * map_height); 66 | T roi_x1 = static_cast(ldg(roi_to_pool + (order++) % 8) * map_width); 67 | T roi_y2 = static_cast(ldg(roi_to_pool + (order++) % 8) * map_height); 68 | T roi_x2 = static_cast(ldg(roi_to_pool + (order++) % 8) * map_width); 69 | T roi_y3 = static_cast(ldg(roi_to_pool + (order++) % 8) * map_height); 70 | T roi_x3 = static_cast(ldg(roi_to_pool + (order++) % 8) * map_width); 71 | 72 | 73 | double len0 = static_cast((roi_y1 - roi_y0) * (roi_y1 - roi_y0) + (roi_x1 - roi_x0) * (roi_x1 - roi_x0)); 74 | double len1 = static_cast((roi_y2 - roi_y1) * (roi_y2 - roi_y1) + (roi_x2 - roi_x1) * (roi_x2 - roi_x1)); 75 | double len2 = static_cast((roi_y3 - roi_y2) * (roi_y3 - roi_y2) + (roi_x3 - roi_x2) * (roi_x3 - roi_x2)); 76 | double len3 = static_cast((roi_y0 - roi_y3) * (roi_y0 - roi_y3) + (roi_x0 - roi_x3) * (roi_x0 - roi_x3)); 77 | double cross_len0 = static_cast((roi_y0 - roi_y2) * (roi_y0 - roi_y2) + (roi_x0 - roi_x2) * (roi_x0 - roi_x2)); 78 | double cross_len1 = static_cast((roi_y3 - roi_y1) * (roi_y3 - roi_y1) + (roi_x3 - roi_x1) * (roi_x3 - roi_x1)); 79 | 80 | order = ldg(roi_order) < 0 ? (len0 + len2 > len1 + len3 ? 1 : 0) : 0; 81 | // fix ROI 82 | if(len0 < std::numeric_limits::min() || len1 < std::numeric_limits::min() || len2 < std::numeric_limits::min() || len3 < std::numeric_limits::min()){ 83 | // not check convex for faster speed 84 | //if(is_convex(roi_to_pool)){ 85 | *pooled_features_start = static_cast(0); 86 | *pooled_index_start = static_cast(0); 87 | continue; 88 | } 89 | 90 | T roi_y0_order = (order == 0) ? roi_y0 : roi_y1; 91 | T roi_x0_order = (order == 0) ? roi_x0 : roi_x1; 92 | T roi_y1_order = (order == 0) ? roi_y1 : roi_y2; 93 | T roi_x1_order = (order == 0) ? 
roi_x1 : roi_x2; 94 | T roi_y2_order = (order == 0) ? roi_y2 : roi_y3; 95 | T roi_x2_order = (order == 0) ? roi_x2 : roi_x3; 96 | T roi_y3_order = (order == 0) ? roi_y3 : roi_y0; 97 | T roi_x3_order = (order == 0) ? roi_x3 : roi_x0; 98 | 99 | T y_step_left = (roi_y3_order - roi_y0_order)/(grid_dim_height * 1.); 100 | T y_step_right = (roi_y2_order - roi_y1_order)/(grid_dim_height * 1.); 101 | T x_step_top = (roi_x1_order - roi_x0_order)/(grid_dim_width * 1.); 102 | T x_step_bottom = (roi_x2_order - roi_x3_order)/(grid_dim_width * 1.); 103 | 104 | T left_y1 = (roi_y0_order + row_index * y_step_left); 105 | T right_y1 = (roi_y1_order + row_index * y_step_right); 106 | T left_y2 = (roi_y0_order + (row_index + 1.) * y_step_left); 107 | T right_y2 = (roi_y1_order + (row_index + 1.) * y_step_right); 108 | 109 | T left_top_y = left_y1 + col_index * (right_y1 - left_y1)/(grid_dim_width); 110 | T right_top_y = left_y1 + (col_index + 1.) * (right_y1 - left_y1)/(grid_dim_width); 111 | T left_bottom_y = left_y2 + col_index * (right_y2 - left_y2)/(grid_dim_width); 112 | T right_bottom_y = left_y2 + (col_index + 1.) * (right_y2 - left_y2)/(grid_dim_width); 113 | 114 | T top_x1 = (roi_x0_order + col_index * x_step_top); 115 | T bottom_x1 = (roi_x3_order + col_index * x_step_bottom); 116 | T top_x2 = (roi_x0_order + (col_index + 1.) * x_step_top); 117 | T bottom_x2 = (roi_x3_order + (col_index + 1.) * x_step_bottom); 118 | 119 | T left_top_x = top_x1 + row_index * (bottom_x1 - top_x1)/(grid_dim_height); 120 | T left_bottom_x = top_x1 + (row_index + 1.) * (bottom_x1 - top_x1)/(grid_dim_height); 121 | T right_top_x = top_x2 + row_index * (bottom_x2 - top_x2)/(grid_dim_height); 122 | T right_bottom_x = top_x2 + (row_index + 1.) 
* (bottom_x2 - top_x2)/(grid_dim_height); 123 | 124 | float pool_bin_width = static_cast(tf_max(tf_min(fabsf(right_top_x - left_top_x), fabsf(right_top_y - left_top_y)), tf_min(fabsf(right_bottom_x - left_bottom_x), fabsf(right_bottom_y - left_bottom_y)))); 125 | float pool_bin_height = static_cast(tf_max(tf_min(fabsf(left_bottom_x - left_top_x), fabsf(left_bottom_y - left_top_y)), tf_min(fabsf(right_bottom_x - right_top_x), fabsf(right_bottom_y - right_top_y)))); 126 | int32_t num_elem_width = static_cast(pool_bin_width) + 1; 127 | int32_t num_elem_height = static_cast(pool_bin_height) + 1; 128 | 129 | T grid_y_step_left = (left_bottom_y - left_top_y)/(num_elem_height + 1.); 130 | T grid_y_step_right = (right_bottom_y - right_top_y)/(num_elem_height + 1.); 131 | T grid_x_step_top = (right_top_x - left_top_x)/(num_elem_width + 1.); 132 | T grid_x_step_bottom = (right_bottom_x - left_bottom_x)/(num_elem_width + 1.); 133 | 134 | int32_t max_pool_ind = 0; 135 | //T max_elem = std::numeric_limits::lowest(); 136 | T max_or_acc_elem = using_max_pool ? std::numeric_limits::lowest() : static_cast(0); 137 | 138 | for(int32_t pool_h = 0; pool_h < num_elem_height; ++pool_h){ 139 | for(int32_t pool_w = 0; pool_w < num_elem_width; ++pool_w){ 140 | //std::cout << "col_to_pool: " << col_to_pool << " row_to_pool: " << row_to_pool << std::endl; 141 | T col_to_pool = (left_top_x + (pool_w + 1.) * grid_x_step_top + left_bottom_x + (pool_w + 1.) * grid_x_step_bottom) / 2.; 142 | T row_to_pool = (left_top_y + (pool_h + 1.) * grid_y_step_left + right_top_y + (pool_h + 1.) 
* grid_y_step_right) / 2.; 143 | 144 | int32_t int_col_to_pool = static_cast(col_to_pool); 145 | int32_t int_row_to_pool = static_cast(row_to_pool); 146 | float float_col_to_pool = col_to_pool - int_col_to_pool; 147 | float float_row_to_pool = row_to_pool - int_row_to_pool; 148 | 149 | int32_t current_switch_ind = num_elem_width * pool_h + pool_w; 150 | //std::cout << "current_switch_ind: " << current_switch_ind << std::endl; 151 | T temp_value = static_cast((1. - float_col_to_pool) * (1. - float_row_to_pool) * ldg(feature_map_to_pool + int_row_to_pool * map_width + int_col_to_pool) + 152 | (1. - float_col_to_pool) * float_row_to_pool * ldg(feature_map_to_pool + tf_min(int_row_to_pool + 1, map_height - 1) * map_width + int_col_to_pool) + 153 | float_col_to_pool * (1. - float_row_to_pool) * ldg(feature_map_to_pool + int_row_to_pool * map_width + tf_min(int_col_to_pool + 1, map_width - 1)) + 154 | float_col_to_pool * float_row_to_pool * ldg(feature_map_to_pool + tf_min(int_row_to_pool + 1, map_height - 1) * map_width + tf_min(int_col_to_pool + 1, map_width - 1))); 155 | if(using_max_pool){ 156 | if(max_or_acc_elem < temp_value){ 157 | max_or_acc_elem = temp_value; 158 | max_pool_ind = current_switch_ind; 159 | } 160 | }else{ 161 | max_or_acc_elem += temp_value; 162 | } 163 | } 164 | } 165 | 166 | if(!using_max_pool) max_or_acc_elem /= static_cast(num_elem_height * num_elem_width); 167 | *pooled_features_start = max_or_acc_elem; 168 | *pooled_index_start = using_max_pool ? 
max_pool_ind : static_cast(0); 169 | } 170 | } 171 | 172 | template 173 | void RotatedPSROIAlignFunctor::operator()(OpKernelContext* context, const GPUDevice& d, typename TTypes::ConstFlat inputs, typename TTypes::ConstFlat rois, typename TTypes::ConstFlat orders, const int32_t grid_dim_width, const int32_t grid_dim_height, typename TTypes::Flat pooled_features, typename TTypes::Flat pooled_index, KDimSize dim_info) { 174 | 175 | int batch_size = 0; 176 | int num_channals = 0; 177 | int map_height = 0; 178 | int map_width = 0; 179 | int num_rois = 0; 180 | bool using_max_pool = false; 181 | 182 | std::tie(batch_size, num_channals, map_height, map_width, num_rois, using_max_pool) = dim_info; 183 | 184 | CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * num_rois * num_channals, d); 185 | RotatedPSROIAlignCudaKernel <<>> (config, inputs.data(), rois.data(), orders.data(), pooled_features.data(), pooled_index.data(), grid_dim_width, grid_dim_height, batch_size, num_channals, map_height, map_width, num_rois, using_max_pool); 187 | 188 | cudaError_t err = cudaGetLastError(); 189 | if(cudaSuccess != err) 190 | { 191 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 192 | exit( -1 ); 193 | } 194 | } 195 | 196 | template struct RotatedPSROIAlignFunctor; 197 | // #define DEFINE_GPU_SPECS(T) \ 198 | // template struct RotatedPSROIAlignFunctorGPU; 199 | 200 | // TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS); 201 | 202 | #endif // GOOGLE_CUDA 203 | -------------------------------------------------------------------------------- /test_op.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2018 Changan Wang 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to 
use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | import os 23 | import shutil 24 | import uuid 25 | import numpy as np 26 | import tensorflow as tf 27 | from tensorflow.python.framework import ops 28 | from tensorflow.python.ops import array_ops 29 | import math 30 | 31 | LIB_NAME = 'ps_roi_align' 32 | 33 | def load_op_module(lib_name): 34 | """ 35 | Load TensorFlow operator library. 
36 | """ 37 | # use absolute path so that ops.py can be called from other directory 38 | lib_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'build/lib{0}.so'.format(lib_name)) 39 | # duplicate library with a random new name so that 40 | # a running program will not be interrupted when the original library is updated 41 | lib_copy_path = '/tmp/lib{0}_{1}.so'.format(str(uuid.uuid4())[:8], LIB_NAME) 42 | shutil.copyfile(lib_path, lib_copy_path) 43 | oplib = tf.load_op_library(lib_copy_path) 44 | #print(_) 45 | return oplib 46 | 47 | op_module = load_op_module(LIB_NAME) 48 | #print("----",op_module.OP_LIST) 49 | 50 | # map_to_pool = [[[[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.], [11., 12., 13., 14., 15.], [16., 17., 18., 19., 20.], [21., 22., 23., 24., 25.]], [[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.], [11., 12., 13., 14., 15.], [16., 17., 18., 19., 20.], [21., 22., 23., 24., 25.]], [[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.], [11., 12., 13., 14., 15.], [16., 17., 18., 19., 20.], [21., 22., 23., 24., 25.]], [[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.], [11., 12., 13., 14., 15.], [16., 17., 18., 19., 20.], [21., 22., 23., 24., 25.]]]] 51 | 52 | map_to_pool = [[ 53 | [[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.], [11., 12., 13., 14., 15.], [16., 17., 18., 19., 20.], [21., 22., 23., 24., 25.]], 54 | [[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.], [11., 12., 13., 14., 15.], [16., 17., 18., 19., 20.], [21., 22., 23., 24., 25.]], 55 | [[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.], [11., 12., 13., 14., 15.], [16., 17., 18., 19., 20.], [21., 22., 23., 24., 25.]], 56 | [[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.], [11., 12., 13., 14., 15.], [16., 17., 18., 19., 20.], [21., 22., 23., 24., 25.]], 57 | 58 | [[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.], [11., 12., 13., 14., 15.], [16., 17., 18., 19., 20.], [21., 22., 23., 24., 25.]], 59 | [[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.], [11., 12., 13., 14., 15.], [16., 17., 18., 19., 20.], [21., 22., 23., 24., 25.]], 
60 | [[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.], [11., 12., 13., 14., 15.], [16., 17., 18., 19., 20.], [21., 22., 23., 24., 25.]], 61 | [[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.], [11., 12., 13., 14., 15.], [16., 17., 18., 19., 20.], [21., 22., 23., 24., 25.]], 62 | 63 | [[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.], [11., 12., 13., 14., 15.], [16., 17., 18., 19., 20.], [21., 22., 23., 24., 25.]], 64 | [[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.], [11., 12., 13., 14., 15.], [16., 17., 18., 19., 20.], [21., 22., 23., 24., 25.]], 65 | [[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.], [11., 12., 13., 14., 15.], [16., 17., 18., 19., 20.], [21., 22., 23., 24., 25.]], 66 | [[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.], [11., 12., 13., 14., 15.], [16., 17., 18., 19., 20.], [21., 22., 23., 24., 25.]], 67 | 68 | [[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.], [11., 12., 13., 14., 15.], [16., 17., 18., 19., 20.], [21., 22., 23., 24., 25.]], 69 | [[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.], [11., 12., 13., 14., 15.], [16., 17., 18., 19., 20.], [21., 22., 23., 24., 25.]], 70 | [[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.], [11., 12., 13., 14., 15.], [16., 17., 18., 19., 20.], [21., 22., 23., 24., 25.]], 71 | [[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.], [11., 12., 13., 14., 15.], [16., 17., 18., 19., 20.], [21., 22., 23., 24., 25.]] 72 | ]] 73 | 74 | pool_method = 'max' 75 | 76 | class PSROIAlignTest(tf.test.TestCase): 77 | def testPSROIAlign(self): 78 | with tf.device('/gpu:1'): 79 | # map C++ operators to python objects 80 | ps_roi_align = op_module.ps_roi_align 81 | result = ps_roi_align(map_to_pool, [[[0.2, 0.2, 0.7, 0.7], [0.5, 0.5, 0.9, 0.9], [0.9, 0.9, 1., 1.]]], 2, 2, pool_method) 82 | with self.test_session() as sess: 83 | print('ps_roi_align in gpu:', sess.run(result)) 84 | with tf.device('/cpu:0'): 85 | # map C++ operators to python objects 86 | ps_roi_align = op_module.ps_roi_align 87 | result = ps_roi_align(map_to_pool, [[[0.2, 0.2, 0.7, 0.7], [0.5, 0.5, 0.9, 0.9], 
[0.9, 0.9, 1., 1.]]], 2, 2, pool_method) 88 | with self.test_session() as sess: 89 | print('ps_roi_align in cpu:', sess.run(result)) 90 | # expect [3.18034267 0.39960092 0.00709875 2.96500921] 91 | #self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0]) 92 | 93 | @ops.RegisterGradient("PsRoiAlign") 94 | def _ps_roi_align_grad(op, grad, _): 95 | '''The gradients for `PsRoiAlign`. 96 | ''' 97 | inputs_features = op.inputs[0] 98 | rois = op.inputs[1] 99 | pooled_features_grad = op.outputs[0] 100 | pooled_index = op.outputs[1] 101 | grid_dim_width = op.get_attr('grid_dim_width') 102 | grid_dim_height = op.get_attr('grid_dim_height') 103 | 104 | return [op_module.ps_roi_align_grad(inputs_features, rois, grad, pooled_index, grid_dim_width, grid_dim_height, pool_method), None] 105 | 106 | class PSROIAlignGradTest(tf.test.TestCase): 107 | def testPSROIAlignGrad(self): 108 | with tf.device('/cpu:0'): 109 | ps_roi_align = op_module.ps_roi_align 110 | inputs_features = tf.constant(map_to_pool, dtype=tf.float32) 111 | pool_result = ps_roi_align(inputs_features, [[[0.2, 0.2, 0.7, 0.7], [0.5, 0.5, 0.9, 0.9], [0.9, 0.9, 1., 1.]]], 2, 2, pool_method) 112 | with tf.Session() as sess: 113 | #print(sess.run(tf.gradients(pool_result[0], [inputs_features]))) 114 | print(tf.test.compute_gradient_error(inputs_features, [1, 16, 5, 5], pool_result[0], [1, 3, 4, 4], delta=0.0001, x_init_value=np.array(map_to_pool))) 115 | # _, jaccobian = tf.test.compute_gradient(inputs_features, [1, 16, 5, 5], pool_result[0], [1, 3, 4, 4], delta=0.0001, x_init_value=np.array(map_to_pool)) 116 | # y = sess.run(pool_result[0]) 117 | # print(jaccobian.shape) 118 | # print(np.reshape(np.matmul(jaccobian, np.ones_like(y.flatten())), np.array(map_to_pool).shape)) 119 | print(tf.test.compute_gradient(inputs_features, [1, 16, 5, 5], pool_result[0], [1, 3, 4, 4], delta=0.0001, x_init_value=np.array(map_to_pool))) 120 | with tf.device('/gpu:0'): 121 | ps_roi_align = op_module.ps_roi_align 122 | inputs_features = 
tf.constant(map_to_pool, dtype=tf.float32) 123 | pool_result = ps_roi_align(inputs_features, [[[0.2, 0.2, 0.7, 0.7], [0.5, 0.5, 0.9, 0.9], [0.9, 0.9, 1., 1.]]], 2, 2, pool_method) 124 | with tf.Session() as sess: 125 | #print(sess.run(tf.gradients(pool_result[0], [inputs_features]))) 126 | print(tf.test.compute_gradient_error(inputs_features, [1, 16, 5, 5], pool_result[0], [1, 3, 4, 4], delta=0.0001, x_init_value=np.array(map_to_pool))) 127 | # _, jaccobian = tf.test.compute_gradient(inputs_features, [1, 16, 5, 5], pool_result[0], [1, 3, 4, 4], delta=0.0001, x_init_value=np.array(map_to_pool)) 128 | # y = sess.run(pool_result[0]) 129 | # print(jaccobian.shape) 130 | # print(np.reshape(np.matmul(jaccobian, np.ones_like(y.flatten())), np.array(map_to_pool).shape)) 131 | print(tf.test.compute_gradient(inputs_features, [1, 16, 5, 5], pool_result[0], [1, 3, 4, 4], delta=0.0001, x_init_value=np.array(map_to_pool))) 132 | 133 | class RotatedPSROIAlignTest(tf.test.TestCase): 134 | def testRotatedPSROIAlign(self): 135 | with tf.device('/gpu:1'): 136 | # map C++ operators to python objects 137 | rotated_ps_roi_align = op_module.rotated_ps_roi_align 138 | result = rotated_ps_roi_align(map_to_pool, [[[0.1, 0.1, 0.2, 0.3, 0.5, 0.5, 0.3, 0.2], [0.5, 0.5, 0.6, 0.7, 0.9, 0.9, 0.7, 0.6], [0.6, 0.7, 0.9, 0.9, 0.7, 0.6, 0.2, 0.2]]], [[1, -1, 3]], 2, 2, pool_method) 139 | with self.test_session() as sess: 140 | print('rotated_ps_roi_align in gpu:', sess.run(result)) 141 | with tf.device('/cpu:0'): 142 | # map C++ operators to python objects 143 | rotated_ps_roi_align = op_module.rotated_ps_roi_align 144 | result = rotated_ps_roi_align(map_to_pool, [[[0.1, 0.1, 0.2, 0.3, 0.5, 0.5, 0.3, 0.2], [0.5, 0.5, 0.6, 0.7, 0.9, 0.9, 0.7, 0.6], [0.6, 0.7, 0.9, 0.9, 0.7, 0.6, 0.2, 0.2]]], [[1, -1, 3]], 2, 2, pool_method) 145 | with self.test_session() as sess: 146 | print('rotated_ps_roi_align in cpu:', sess.run(result)) 147 | # expect [3.18034267 0.39960092 0.00709875 2.96500921] 148 | 
#self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0]) 149 | 150 | @ops.RegisterGradient("RotatedPsRoiAlign") 151 | def _rotated_ps_roi_align_grad(op, grad, _): 152 | '''The gradients for `RotatedPsRoiAlign`. 153 | ''' 154 | inputs_features = op.inputs[0] 155 | rois = op.inputs[1] 156 | orders = op.inputs[2] 157 | pooled_features_grad = op.outputs[0] 158 | pooled_index = op.outputs[1] 159 | grid_dim_width = op.get_attr('grid_dim_width') 160 | grid_dim_height = op.get_attr('grid_dim_height') 161 | 162 | return [op_module.rotated_ps_roi_align_grad(inputs_features, rois, orders, grad, pooled_index, grid_dim_width, grid_dim_height, pool_method), None, None] 163 | 164 | class RotatedPSROIAlignGradTest(tf.test.TestCase): 165 | def testRotatedPSROIAlignGrad(self): 166 | with tf.device('/cpu:0'): 167 | rotated_ps_roi_align = op_module.rotated_ps_roi_align 168 | inputs_features = tf.constant(map_to_pool, dtype=tf.float32) 169 | pool_result = rotated_ps_roi_align(inputs_features, [[[0.1, 0.1, 0.2, 0.3, 0.5, 0.5, 0.3, 0.2], [0.5, 0.5, 0.6, 0.7, 0.9, 0.9, 0.7, 0.6], [0.6, 0.7, 0.9, 0.9, 0.7, 0.6, 0.2, 0.2]]], [[1, -1, 3]], 2, 2, pool_method) 170 | with tf.Session() as sess: 171 | #print(sess.run(tf.gradients(pool_result[0], [inputs_features]))) 172 | print(tf.test.compute_gradient_error(inputs_features, [1, 16, 5, 5], pool_result[0], [1, 3, 4, 4], delta=0.0001, x_init_value=np.array(map_to_pool))) 173 | # _, jaccobian = tf.test.compute_gradient(inputs_features, [1, 16, 5, 5], pool_result[0], [1, 3, 4, 4], delta=0.0001, x_init_value=np.array(map_to_pool)) 174 | # y = sess.run(pool_result[0]) 175 | # print(jaccobian.shape) 176 | # print(np.reshape(np.matmul(jaccobian, np.ones_like(y.flatten())), np.array(map_to_pool).shape)) 177 | print(tf.test.compute_gradient(inputs_features, [1, 16, 5, 5], pool_result[0], [1, 3, 4, 4], delta=0.0001, x_init_value=np.array(map_to_pool))) 178 | with tf.device('/gpu:0'): 179 | rotated_ps_roi_align = op_module.rotated_ps_roi_align 180 | 
inputs_features = tf.constant(map_to_pool, dtype=tf.float32) 181 | pool_result = rotated_ps_roi_align(inputs_features, [[[0.1, 0.1, 0.2, 0.3, 0.5, 0.5, 0.3, 0.2], [0.5, 0.5, 0.6, 0.7, 0.9, 0.9, 0.7, 0.6], [0.6, 0.7, 0.9, 0.9, 0.7, 0.6, 0.2, 0.2]]], [[1, -1, 3]], 2, 2, pool_method) 182 | with tf.Session() as sess: 183 | #print(sess.run(tf.gradients(pool_result[0], [inputs_features]))) 184 | print(tf.test.compute_gradient_error(inputs_features, [1, 16, 5, 5], pool_result[0], [1, 3, 4, 4], delta=0.0001, x_init_value=np.array(map_to_pool))) 185 | # _, jaccobian = tf.test.compute_gradient(inputs_features, [1, 16, 5, 5], pool_result[0], [1, 3, 4, 4], delta=0.0001, x_init_value=np.array(map_to_pool)) 186 | # y = sess.run(pool_result[0]) 187 | # print(jaccobian.shape) 188 | # print(np.reshape(np.matmul(jaccobian, np.ones_like(y.flatten())), np.array(map_to_pool).shape)) 189 | print(tf.test.compute_gradient(inputs_features, [1, 16, 5, 5], pool_result[0], [1, 3, 4, 4], delta=0.0001, x_init_value=np.array(map_to_pool))) 190 | 191 | if __name__ == "__main__": 192 | tf.test.main() 193 | -------------------------------------------------------------------------------- /rotated_ps_roi_align_grad_op.cu: -------------------------------------------------------------------------------- 1 | // MIT License 2 | 3 | // Copyright (c) 2018 Changan Wang 4 | 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
#if GOOGLE_CUDA == 1
#define EIGEN_USE_GPU
#include "rotated_ps_roi_align_op.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"

using namespace tensorflow;

// NOTE(review): the header names below and the `<...>` template arguments in
// this file (e.g. `template`, `static_cast(`, `TTypes::`, `<<>>`) appear to
// have been stripped from this dump; restore them from the original source.
#include
#include
#include

// Backward pass of RotatedPsRoiAlign.
//
// One work item per element of `pooled_features_grad`; the functor below
// launches batch_size * num_rois * num_channals of them. Each item sends its
// incoming gradient back to the feature-map channel it was pooled from,
// distributing it over the four bilinear-interpolation neighbours of each
// sample point with atomicAdd (several items can hit the same pixel).
// `grad_output` must be zeroed before launch; the functor does that.
//
// rois:         per ROI, 8 normalized values [y0, x0, y1, x1, y2, x2, y3, x3].
// orders:       per ROI, index of the first vertex; a negative value lets the
//               kernel choose the starting vertex from the side lengths.
// pooled_index: for max pooling, the flattened sample index of the argmax
//               within the bin (row = index / num_elem_width).
// NOTE(review): sample positions are recomputed here and are assumed to
// mirror the forward kernel's sampling exactly; confirm against the forward op.
template
__global__ void RotatedPSROIAlignGradCudaKernel(CudaLaunchConfig config, const T * inputs, const T * rois, const int32_t * orders, const T * pooled_features_grad, const int32_t * pooled_index, T * grad_output, const int32_t grid_dim_width, const int32_t grid_dim_height, const int batch_size, const int num_channals, const int map_height, const int map_width, const int num_rois, const bool using_max_pool) {

  const int32_t grid_size = grid_dim_width * grid_dim_height;
  const int32_t bank_size = num_channals / grid_size;

  CUDA_1D_KERNEL_LOOP(worker_index, config.virtual_thread_count) {
    // Decompose the flat index into (image, roi, grid cell, channel in bank).
    // image_index * roi_index * channal_pos_remainder * row_index * col_index
    const int32_t position_index = (worker_index % num_channals) / bank_size;
    const int32_t row_index = position_index / grid_dim_width;
    const int32_t col_index = position_index % grid_dim_width;
    // position of the channal in the bank of feature map
    const int32_t channal_pos_remainder = worker_index % bank_size;
    const int32_t pool_index = worker_index / num_channals;
    const int32_t image_index = pool_index / num_rois;
    const int32_t roi_index = pool_index % num_rois;

    const T * roi_to_pool = rois + (image_index * num_rois + roi_index) * 8;
    const int32_t * roi_order = orders + image_index * num_rois + roi_index;

    // Gradient plane of the feature-map channel this pooled element read from.
    T * grad_output_start = grad_output + (image_index * num_channals + position_index * bank_size + channal_pos_remainder) * map_height * map_width;

    const T * pooled_features_start = pooled_features_grad + worker_index;
    const int32_t * pooled_index_start = pooled_index + worker_index;

    // Read the requested first vertex once; negative means "choose automatically".
    const int32_t given_order = ldg(roi_order);
    // Start reading the 8 coordinates at the y of vertex `given_order`
    // (wrapping around), so roi_*0 is the requested first vertex.
    int32_t order = given_order < 0 ? 0 : given_order * 2;

    T roi_y0 = static_cast(ldg(roi_to_pool + (order++) % 8) * map_height);
    T roi_x0 = static_cast(ldg(roi_to_pool + (order++) % 8) * map_width);
    T roi_y1 = static_cast(ldg(roi_to_pool + (order++) % 8) * map_height);
    T roi_x1 = static_cast(ldg(roi_to_pool + (order++) % 8) * map_width);
    T roi_y2 = static_cast(ldg(roi_to_pool + (order++) % 8) * map_height);
    T roi_x2 = static_cast(ldg(roi_to_pool + (order++) % 8) * map_width);
    T roi_y3 = static_cast(ldg(roi_to_pool + (order++) % 8) * map_height);
    T roi_x3 = static_cast(ldg(roi_to_pool + (order++) % 8) * map_width);

    // Squared lengths of the four sides of the quadrilateral.
    double len0 = static_cast((roi_y1 - roi_y0) * (roi_y1 - roi_y0) + (roi_x1 - roi_x0) * (roi_x1 - roi_x0));
    double len1 = static_cast((roi_y2 - roi_y1) * (roi_y2 - roi_y1) + (roi_x2 - roi_x1) * (roi_x2 - roi_x1));
    double len2 = static_cast((roi_y3 - roi_y2) * (roi_y3 - roi_y2) + (roi_x3 - roi_x2) * (roi_x3 - roi_x2));
    double len3 = static_cast((roi_y0 - roi_y3) * (roi_y0 - roi_y3) + (roi_x0 - roi_x3) * (roi_x0 - roi_x3));

    // Automatic ordering: shift the start by one vertex when sides 0 and 2 are
    // the longer pair. An explicit order is used as-is (shift 0).
    order = given_order < 0 ? (len0 + len2 > len1 + len3 ? 1 : 0) : 0;
    // fix ROI: a side of (near) zero length makes the ROI degenerate, and it
    // contributes no gradient.
    if(len0 < std::numeric_limits::min() || len1 < std::numeric_limits::min() || len2 < std::numeric_limits::min() || len3 < std::numeric_limits::min()){
      // not check convex for faster speed
      //if(is_convex(roi_to_pool)){
      continue;
    }

    T roi_y0_order = (order == 0) ? roi_y0 : roi_y1;
    T roi_x0_order = (order == 0) ? roi_x0 : roi_x1;
    T roi_y1_order = (order == 0) ? roi_y1 : roi_y2;
    T roi_x1_order = (order == 0) ? roi_x1 : roi_x2;
    T roi_y2_order = (order == 0) ? roi_y2 : roi_y3;
    T roi_x2_order = (order == 0) ? roi_x2 : roi_x3;
    T roi_y3_order = (order == 0) ? roi_y3 : roi_y0;
    T roi_x3_order = (order == 0) ? roi_x3 : roi_x0;

    // Corners of grid cell (row_index, col_index), obtained by subdividing the
    // re-ordered quadrilateral evenly along both pairs of opposite sides.
    T y_step_left = (roi_y3_order - roi_y0_order)/(grid_dim_height * 1.);
    T y_step_right = (roi_y2_order - roi_y1_order)/(grid_dim_height * 1.);
    T x_step_top = (roi_x1_order - roi_x0_order)/(grid_dim_width * 1.);
    T x_step_bottom = (roi_x2_order - roi_x3_order)/(grid_dim_width * 1.);

    T left_y1 = (roi_y0_order + row_index * y_step_left);
    T right_y1 = (roi_y1_order + row_index * y_step_right);
    T left_y2 = (roi_y0_order + (row_index + 1.) * y_step_left);
    T right_y2 = (roi_y1_order + (row_index + 1.) * y_step_right);

    T left_top_y = left_y1 + col_index * (right_y1 - left_y1)/(grid_dim_width);
    T right_top_y = left_y1 + (col_index + 1.) * (right_y1 - left_y1)/(grid_dim_width);
    T left_bottom_y = left_y2 + col_index * (right_y2 - left_y2)/(grid_dim_width);
    T right_bottom_y = left_y2 + (col_index + 1.) * (right_y2 - left_y2)/(grid_dim_width);

    T top_x1 = (roi_x0_order + col_index * x_step_top);
    T bottom_x1 = (roi_x3_order + col_index * x_step_bottom);
    T top_x2 = (roi_x0_order + (col_index + 1.) * x_step_top);
    T bottom_x2 = (roi_x3_order + (col_index + 1.) * x_step_bottom);

    T left_top_x = top_x1 + row_index * (bottom_x1 - top_x1)/(grid_dim_height);
    T left_bottom_x = top_x1 + (row_index + 1.) * (bottom_x1 - top_x1)/(grid_dim_height);
    T right_top_x = top_x2 + row_index * (bottom_x2 - top_x2)/(grid_dim_height);
    T right_bottom_x = top_x2 + (row_index + 1.) * (bottom_x2 - top_x2)/(grid_dim_height);

    // Number of sample points per cell along each direction (at least 1).
    float pool_bin_width = static_cast(tf_max(tf_min(fabsf(right_top_x - left_top_x), fabsf(right_top_y - left_top_y)), tf_min(fabsf(right_bottom_x - left_bottom_x), fabsf(right_bottom_y - left_bottom_y))));
    float pool_bin_height = static_cast(tf_max(tf_min(fabsf(left_bottom_x - left_top_x), fabsf(left_bottom_y - left_top_y)), tf_min(fabsf(right_bottom_x - right_top_x), fabsf(right_bottom_y - right_top_y))));
    int32_t num_elem_width = static_cast(pool_bin_width) + 1;
    int32_t num_elem_height = static_cast(pool_bin_height) + 1;

    T grid_y_step_left = (left_bottom_y - left_top_y)/(num_elem_height + 1.);
    T grid_y_step_right = (right_bottom_y - right_top_y)/(num_elem_height + 1.);
    T grid_x_step_top = (right_top_x - left_top_x)/(num_elem_width + 1.);
    T grid_x_step_bottom = (right_bottom_x - left_bottom_x)/(num_elem_width + 1.);

    if(using_max_pool){
      // Max pooling: only the argmax sample recorded in pooled_index gets gradient.
      const int32_t pool_h = ldg(pooled_index_start) / num_elem_width;
      const int32_t pool_w = ldg(pooled_index_start) % num_elem_width;

      T col_to_pool = (left_top_x + (pool_w + 1.) * grid_x_step_top + left_bottom_x + (pool_w + 1.) * grid_x_step_bottom) / 2.;
      T row_to_pool = (left_top_y + (pool_h + 1.) * grid_y_step_left + right_top_y + (pool_h + 1.) * grid_y_step_right) / 2.;

      int32_t int_col_to_pool = static_cast(col_to_pool);
      int32_t int_row_to_pool = static_cast(row_to_pool);
      float float_col_to_pool = col_to_pool - int_col_to_pool;
      float float_row_to_pool = row_to_pool - int_row_to_pool;

      const T grad_in = ldg(pooled_features_start);

      atomicAdd(grad_output_start + int_row_to_pool * map_width + int_col_to_pool, static_cast((1. - float_col_to_pool) * (1. - float_row_to_pool) * grad_in));
      atomicAdd(grad_output_start + tf_min(int_row_to_pool + 1, map_height - 1) * map_width + int_col_to_pool, static_cast((1. - float_col_to_pool) * float_row_to_pool * grad_in));
      atomicAdd(grad_output_start + int_row_to_pool * map_width + tf_min(int_col_to_pool + 1, map_width - 1), static_cast(float_col_to_pool * (1. - float_row_to_pool) * grad_in));
      atomicAdd(grad_output_start + tf_min(int_row_to_pool + 1, map_height - 1) * map_width + tf_min(int_col_to_pool + 1, map_width - 1), static_cast(float_col_to_pool * float_row_to_pool * grad_in));
    }else{
      // Mean pooling: the gradient is split evenly over every sample in the cell.
      const T grad_in = ldg(pooled_features_start) / static_cast(num_elem_width * num_elem_height);
      for(int32_t pool_h = 0; pool_h < num_elem_height; ++pool_h){
        for(int32_t pool_w = 0; pool_w < num_elem_width; ++pool_w){
          T col_to_pool = (left_top_x + (pool_w + 1.) * grid_x_step_top + left_bottom_x + (pool_w + 1.) * grid_x_step_bottom) / 2.;
          T row_to_pool = (left_top_y + (pool_h + 1.) * grid_y_step_left + right_top_y + (pool_h + 1.) * grid_y_step_right) / 2.;

          int32_t int_col_to_pool = static_cast(col_to_pool);
          int32_t int_row_to_pool = static_cast(row_to_pool);
          float float_col_to_pool = col_to_pool - int_col_to_pool;
          float float_row_to_pool = row_to_pool - int_row_to_pool;

          atomicAdd(grad_output_start + int_row_to_pool * map_width + int_col_to_pool, static_cast((1. - float_col_to_pool) * (1. - float_row_to_pool) * grad_in));
          atomicAdd(grad_output_start + tf_min(int_row_to_pool + 1, map_height - 1) * map_width + int_col_to_pool, static_cast((1. - float_col_to_pool) * float_row_to_pool * grad_in));
          atomicAdd(grad_output_start + int_row_to_pool * map_width + tf_min(int_col_to_pool + 1, map_width - 1), static_cast(float_col_to_pool * (1. - float_row_to_pool) * grad_in));
          atomicAdd(grad_output_start + tf_min(int_row_to_pool + 1, map_height - 1) * map_width + tf_min(int_col_to_pool + 1, map_width - 1), static_cast(float_col_to_pool * float_row_to_pool * grad_in));
        }
      }
    }
  }
}

// GPU entry point of the RotatedPsRoiAlign gradient.
//
// Zeroes `grad_output` (batch_size * map_height * map_width * num_channals
// elements), then launches RotatedPSROIAlignGradCudaKernel over
// batch_size * num_rois * num_channals work items. A launch failure is
// reported through `context`'s status instead of terminating the process.
template
void RotatedPSROIAlignGradFunctor::operator()(OpKernelContext* context, const GPUDevice& d, typename TTypes::ConstFlat inputs, typename TTypes::ConstFlat rois, typename TTypes::ConstFlat orders, const int32_t grid_dim_width, const int32_t grid_dim_height, typename TTypes::ConstFlat pooled_features_grad, typename TTypes::ConstFlat pooled_index, typename TTypes::Flat grad_output, KDimSize dim_info) {

  int batch_size = 0;
  int num_channals = 0;
  int map_height = 0;
  int map_width = 0;
  int num_rois = 0;
  bool using_max_pool = false;

  std::tie(batch_size, num_channals, map_height, map_width, num_rois, using_max_pool) = dim_info;

  CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * num_rois * num_channals, d);
  //grad_output = grad_output.setZero();
  SetZero <<>> (batch_size * map_height * map_width * num_channals, grad_output.data());

  RotatedPSROIAlignGradCudaKernel <<>> (config, inputs.data(), rois.data(), orders.data(), pooled_features_grad.data(), pooled_index.data(), grad_output.data(), grid_dim_width, grid_dim_height, batch_size, num_channals, map_height, map_width, num_rois, using_max_pool);

  // Surface launch errors as an op failure; exiting here would kill the whole
  // host process (and every other session in it).
  cudaError_t err = cudaGetLastError();
  if (cudaSuccess != err) {
    context->SetStatus(errors::Internal("cudaCheckError() failed : ", cudaGetErrorString(err)));
    return;
  }
}

template struct RotatedPSROIAlignGradFunctor;
// #define DEFINE_GPU_SPECS(T) \
//   template struct RotatedPSROIAlignFunctorGPU;

// TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);

#endif // GOOGLE_CUDA

--------------------------------------------------------------------------------
/ps_roi_align_op.cc:
--------------------------------------------------------------------------------
// MIT License

// Copyright (c) 2018 Changan Wang

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:

// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.

// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
22 | #include "ps_roi_align_op.h" 23 | #include "work_sharder.h" 24 | 25 | #include "tensorflow/core/framework/op_kernel.h" 26 | #include "tensorflow/core/framework/register_types.h" 27 | #include "tensorflow/core/framework/tensor.h" 28 | #include "tensorflow/core/framework/tensor_shape.h" 29 | #include "tensorflow/core/framework/register_types.h" 30 | #include "tensorflow/core/framework/op.h" 31 | #include "tensorflow/core/framework/shape_inference.h" 32 | 33 | #include 34 | 35 | using namespace tensorflow; 36 | 37 | // the inputs should have format NCHW, which is faster on GPUs 38 | REGISTER_OP("PsRoiAlign") 39 | .Attr("T: {float}") 40 | .Attr("grid_dim_width: int") 41 | .Attr("grid_dim_height: int") 42 | .Attr("pool_method: string") 43 | .Input("inputs: T") 44 | .Input("rois: T") 45 | // .Input("grid_dim_width: int32") 46 | // .Input("grid_dim_height: int32") 47 | .Output("pooled_features: T") 48 | .Output("pooled_index: int32") 49 | .Doc(R"doc( 50 | PsRoiAlign is a new PsRoiPooling method without align problems. 51 | The input rois to be pooled must in format [center_y, center_x, h, w] and each element must be in range [0, 1.]. 52 | The caller must make sure that all rois is valid (has a intersect region (one pixel at least) with the window [0.5, 0.5, 1., 1.]). 
53 | )doc") 54 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 55 | shape_inference::ShapeHandle inputs_shape = c->input(0); 56 | shape_inference::DimensionHandle num_per_batch = c->Dim(inputs_shape, 0); 57 | shape_inference::DimensionHandle num_channals = c->Dim(inputs_shape, 1); 58 | shape_inference::DimensionHandle num_rois = c->Dim(c->input(1), 1); 59 | //TF_RETURN_IF_ERROR(c->MakeDimGetAttrForScalarInput(3, &grid_dim_height)); 60 | int32_t grid_dim_width(0); 61 | TF_RETURN_IF_ERROR(c->GetAttr("grid_dim_width", &grid_dim_width)); 62 | int32_t grid_dim_height(0); 63 | TF_RETURN_IF_ERROR(c->GetAttr("grid_dim_height", &grid_dim_height)); 64 | // one can use following function to make more check on input shape 65 | // use WithValue check DimensionHandle, and use WithRank check ShapeHandle 66 | // TF_RETURN_IF_ERROR(c->WithRank(logits_shape, 2, &logits_shape)); 67 | // TF_RETURN_IF_ERROR(c->WithValue(num_per_batch, 128, &num_per_batch)); 68 | const int32_t grid_size(grid_dim_width * grid_dim_height); 69 | shape_inference::DimensionHandle bank_size; 70 | TF_RETURN_IF_ERROR(c->Divide(num_channals, grid_size, true, &bank_size)); 71 | // use MakeShape to create a ShapeHandle from one DimensionHandle 72 | c->set_output(0, c->MakeShape({num_per_batch, num_rois, grid_size, bank_size})); 73 | c->set_output(1, c->MakeShape({num_per_batch, num_rois, grid_size, bank_size})); 74 | //c->set_output(1, c->MakeShape({num_per_batch, num_classes})); 75 | return Status::OK(); 76 | }); 77 | 78 | 79 | // CPU specialization of actual computation. 
80 | //template 81 | template 82 | struct PSROIAlignFunctor { 83 | void operator()(OpKernelContext* context, const CPUDevice& d, typename TTypes::ConstFlat inputs, typename TTypes::ConstFlat rois, const int32_t grid_dim_width, const int32_t grid_dim_height, typename TTypes::Flat pooled_features, typename TTypes::Flat pooled_index, KDimSize dim_info) { 84 | 85 | int batch_size = 0; 86 | int num_channals = 0; 87 | int map_height = 0; 88 | int map_width = 0; 89 | int num_rois = 0; 90 | bool using_max_pool = false; 91 | 92 | std::tie(batch_size, num_channals, map_height, map_width, num_rois, using_max_pool) = dim_info; 93 | 94 | auto pooling_routine = [&inputs, &rois, &pooled_features, &pooled_index, grid_dim_width, grid_dim_height, batch_size, num_channals, map_height, map_width, num_rois, using_max_pool](int64_t start, int64_t limit){ 95 | const int32_t grid_size = grid_dim_width * grid_dim_height; 96 | const int32_t bank_size = num_channals/grid_size; 97 | for (int64_t worker_index = start; worker_index < limit; ++worker_index){ 98 | // worker_index / bank_size / grid_size * num_channals + worker_index / bank_size / grid_size % num_rois * num_channals + (worker_index % grid_size) + worker_index % bank_size; 99 | // image_index * roi_index * channal_pos_remainder * row_index * col_index 100 | const int32_t position_index = (worker_index % num_channals) / bank_size; 101 | const int32_t row_index = position_index / grid_dim_width; 102 | const int32_t col_index = position_index % grid_dim_width; 103 | // position of the channal of pooled feature 104 | // position of the channal in the bank of feature map 105 | const int32_t channal_pos_remainder = worker_index % bank_size; 106 | const int32_t pool_index = worker_index / num_channals; 107 | const int32_t image_index = pool_index / num_rois; 108 | const int32_t roi_index = pool_index % num_rois; 109 | 110 | const T * roi_to_pool = rois.data() + (image_index * num_rois + roi_index) * 4; 111 | 112 | const T * 
feature_map_to_pool = inputs.data() + (image_index * num_channals + position_index * bank_size + channal_pos_remainder) * map_height * map_width; 113 | // T * pooled_features_start = pooled_features.data() + image_index * (num_rois * num_channals) + roi_index * num_channals + (position_index % grid_size) * bank_size + channal_pos_remainder; 114 | // int32_t * pooled_index_start = pooled_index.data() + image_index * (num_rois * num_channals) + roi_index * num_channals + (position_index % grid_size) * bank_size + channal_pos_remainder; 115 | T * pooled_features_start = pooled_features.data() + worker_index; 116 | int32_t * pooled_index_start = pooled_index.data() + worker_index; 117 | 118 | T roi_ymin = static_cast(0); 119 | T roi_xmin = static_cast(0); 120 | T roi_ymax = static_cast(0); 121 | T roi_xmax = static_cast(0); 122 | // fix ROI 123 | if(roi_to_pool[2] < std::numeric_limits::min() || roi_to_pool[3] < std::numeric_limits::min()){ 124 | *pooled_features_start = static_cast(0); 125 | continue; 126 | } 127 | 128 | std::tie(roi_ymin, roi_xmin, roi_ymax, roi_xmax) = [roi_to_pool, map_height, map_width](){ 129 | T roi_y_center = static_cast(roi_to_pool[0] * map_height); 130 | T roi_x_center = static_cast(roi_to_pool[1] * map_width); 131 | T roi_h = std::max(roi_to_pool[2] * map_height, static_cast(1)); 132 | T roi_w = std::max(roi_to_pool[3] * map_width, static_cast(1)); 133 | 134 | T roi_ymin = std::max(roi_y_center - static_cast(roi_h / 2.), static_cast(0)); 135 | T roi_xmin = std::max(roi_x_center - static_cast(roi_w / 2.), static_cast(0)); 136 | T roi_ymax = std::min(roi_y_center + static_cast(roi_h / 2.), static_cast(map_height) - std::numeric_limits::min()); 137 | T roi_xmax = std::min(roi_x_center + static_cast(roi_w / 2.), static_cast(map_width) - std::numeric_limits::min()); 138 | return std::make_tuple(roi_ymin, roi_xmin, roi_ymax, roi_xmax); 139 | }(); 140 | // T roi_center_y = roi_to_pool[0]; 141 | // T roi_center_x = roi_to_pool[1]; 142 | T roi_h = 
roi_ymax - roi_ymin; 143 | T roi_w = roi_xmax - roi_xmin; 144 | float pool_bin_width = static_cast(roi_w) / grid_dim_width; 145 | float pool_bin_height = static_cast(roi_h) / grid_dim_height; 146 | int32_t num_elem_width = static_cast(pool_bin_width) + 1; 147 | int32_t num_elem_height = static_cast(pool_bin_height) + 1; 148 | 149 | // std::cout << "pool_bin_width: " << pool_bin_width << " pool_bin_height: " << pool_bin_height << " num_elem_width: " << num_elem_width << " num_elem_height: " << num_elem_height << std::endl; 150 | 151 | // std::cout << "worker_index: " << worker_index << " roi_index: " << roi_index 152 | // << " roi_ymin: " << roi_ymin << " roi_xmin: " << roi_xmin << " roi_ymax: " << roi_ymax << " roi_xmax: " << roi_xmax << " image_index: " << image_index << " position_index: " << (position_index % grid_size) << " channal_pos_remainder: " << channal_pos_remainder << std::endl; 153 | 154 | float step_width_each_bin = pool_bin_width / num_elem_width; 155 | float step_height_each_bin = pool_bin_height / num_elem_height; 156 | 157 | float pool_width_start = roi_xmin + pool_bin_width * col_index; 158 | float pool_height_start = roi_ymin + pool_bin_height * row_index; 159 | int32_t max_pool_ind = 0; 160 | T max_or_acc_elem = using_max_pool ? 
std::numeric_limits::lowest() : static_cast(0); 161 | for (int32_t h_ind = 0; h_ind < num_elem_height; ++h_ind) { 162 | for (int32_t w_ind = 0; w_ind < num_elem_width; ++w_ind) { 163 | float col_to_pool = pool_width_start + step_width_each_bin * w_ind + step_width_each_bin / 2.; 164 | float row_to_pool = pool_height_start + step_height_each_bin * h_ind + step_height_each_bin / 2.; 165 | //std::cout << "col_to_pool: " << col_to_pool << " row_to_pool: " << row_to_pool << std::endl; 166 | int32_t int_col_to_pool = static_cast(col_to_pool); 167 | int32_t int_row_to_pool = static_cast(row_to_pool); 168 | float float_col_to_pool = col_to_pool - int_col_to_pool; 169 | float float_row_to_pool = row_to_pool - int_row_to_pool; 170 | 171 | int32_t current_switch_ind = num_elem_width * h_ind + w_ind; 172 | //std::cout << "current_switch_ind: " << current_switch_ind << std::endl; 173 | T temp_value = static_cast((1. - float_col_to_pool) * (1. - float_row_to_pool) * feature_map_to_pool[int_row_to_pool * map_width + int_col_to_pool] + 174 | (1. - float_col_to_pool) * float_row_to_pool * feature_map_to_pool[std::min(int_row_to_pool + 1, map_height - 1) * map_width + int_col_to_pool] + 175 | float_col_to_pool * (1. - float_row_to_pool) * feature_map_to_pool[int_row_to_pool * map_width + std::min(int_col_to_pool + 1, map_width - 1)] + 176 | float_col_to_pool * float_row_to_pool * feature_map_to_pool[std::min(int_row_to_pool + 1, map_height - 1) * map_width + std::min(int_col_to_pool + 1, map_width - 1)]); 177 | if(using_max_pool){ 178 | if(max_or_acc_elem < temp_value){ 179 | max_or_acc_elem = temp_value; 180 | max_pool_ind = current_switch_ind; 181 | } 182 | }else{ 183 | max_or_acc_elem += temp_value; 184 | } 185 | } 186 | } 187 | 188 | if(!using_max_pool) max_or_acc_elem /= static_cast(num_elem_height * num_elem_width); 189 | 190 | *pooled_features_start = max_or_acc_elem; 191 | *pooled_index_start = using_max_pool ? 
max_pool_ind : static_cast(0); 192 | } 193 | }; 194 | 195 | const DeviceBase::CpuWorkerThreads& worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); 196 | // one worker for one position in each ROI 197 | const int64_t shard_cost = 4 * map_height * map_width / grid_dim_width / grid_dim_height / 4; 198 | Shard(worker_threads.num_threads, worker_threads.workers, 199 | pooled_features.size(), shard_cost, pooling_routine); 200 | } 201 | }; 202 | 203 | // OpKernel definition. 204 | // template parameter is the datatype of the tensors. 205 | template 206 | class PSROIAlignOp : public OpKernel { 207 | public: 208 | explicit PSROIAlignOp(OpKernelConstruction* context) : OpKernel(context) { 209 | OP_REQUIRES_OK(context, context->GetAttr("grid_dim_width", &grid_dim_width_in)); 210 | OP_REQUIRES(context, grid_dim_width_in >= 0, errors::InvalidArgument("Need Attr grid_dim_width >= 0, got ", grid_dim_width_in)); 211 | 212 | OP_REQUIRES_OK(context, context->GetAttr("grid_dim_height", &grid_dim_height_in)); 213 | OP_REQUIRES(context, grid_dim_height_in >= 0, errors::InvalidArgument("Need Attr grid_dim_height >= 0, got ", grid_dim_height_in)); 214 | 215 | OP_REQUIRES_OK(context, context->GetAttr("pool_method", &pool_method)); 216 | OP_REQUIRES(context, StringPiece(pool_method).contains(StringPiece("mean")) || StringPiece(pool_method).contains(StringPiece("max")), errors::InvalidArgument("Need Attr pool_method to be either 'mean' or 'max', got ", pool_method)); 217 | // std::cout << (StringPiece(pool_method).contains(StringPiece("mean")) || StringPiece(pool_method).contains(StringPiece("max"))) << std::endl; 218 | } 219 | 220 | void Compute(OpKernelContext* context) override { 221 | const Tensor& inputs_in = context->input(0); 222 | const Tensor& rois_in = context->input(1); 223 | 224 | OP_REQUIRES(context, inputs_in.shape().dims() == 4, errors::InvalidArgument("inputs must be in 'NCHW' format.")); 225 | OP_REQUIRES(context, rois_in.shape().dims() == 3 && 
rois_in.shape().dim_size(2) == 4, errors::InvalidArgument("rois must be in 'batch_size x num_rois x 4' format.")); 226 | OP_REQUIRES(context, inputs_in.dim_size(0) == rois_in.dim_size(0), errors::InvalidArgument("'batch_size' in inputs and rois don't match.")); 227 | 228 | const int batch_size = inputs_in.dim_size(0); 229 | const int num_channals = inputs_in.dim_size(1); 230 | const int map_height = inputs_in.dim_size(2); 231 | const int map_width = inputs_in.dim_size(3); 232 | const int num_rois = rois_in.dim_size(1); 233 | 234 | const int32_t grid_size = grid_dim_width_in * grid_dim_height_in; 235 | 236 | auto bank_size = static_cast(num_channals / grid_size); 237 | Tensor* pooled_features = nullptr; 238 | OP_REQUIRES_OK(context, context->allocate_output(0, {batch_size, num_rois, grid_size, bank_size}, &pooled_features)); 239 | Tensor* pooled_index = nullptr; 240 | OP_REQUIRES_OK(context, context->allocate_output(1, {batch_size, num_rois, grid_size, bank_size}, &pooled_index)); 241 | 242 | PSROIAlignFunctor()(context, context->eigen_device(), inputs_in.template flat(), rois_in.template flat(), grid_dim_width_in, grid_dim_height_in, pooled_features->template flat(), pooled_index->template flat(), std::make_tuple(batch_size, num_channals, map_height, map_width, num_rois, StringPiece(pool_method).contains(StringPiece("max")))); 243 | // PSROIPoolingFunctor()(context, context->eigen_device(), inputs_in.tensor(), rois_in.tensor(), grid_dim_buffer[0], pooled_features->tensor()); 244 | } 245 | 246 | private: 247 | int32_t grid_dim_width_in{-1}; 248 | int32_t grid_dim_height_in{-1}; 249 | std::string pool_method{"max"}; 250 | }; 251 | 252 | // Register the CPU kernels. 253 | #define REGISTER_CPU(T) \ 254 | REGISTER_KERNEL_BUILDER( \ 255 | Name("PsRoiAlign").Device(DEVICE_CPU).TypeConstraint("T"), \ 256 | PSROIAlignOp); 257 | REGISTER_CPU(float); 258 | 259 | // TF_CALL_NUMBER_TYPES(REGISTER_CPU); 260 | // #undef REGISTER_CPU 261 | 262 | // Register the GPU kernels. 
263 | #if GOOGLE_CUDA == 1 264 | #define REGISTER_GPU(T) \ 265 | REGISTER_KERNEL_BUILDER( \ 266 | Name("PsRoiAlign").Device(DEVICE_GPU).TypeConstraint("T"), \ 267 | PSROIAlignOp); 268 | REGISTER_GPU(float); 269 | #endif // GOOGLE_CUDA 270 | -------------------------------------------------------------------------------- /rotated_ps_roi_align_op.cc: -------------------------------------------------------------------------------- 1 | // MIT License 2 | 3 | // Copyright (c) 2018 Changan Wang 4 | 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 14 | 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE. 
22 | #include "rotated_ps_roi_align_op.h" 23 | #include "work_sharder.h" 24 | 25 | #include "tensorflow/core/framework/op_kernel.h" 26 | #include "tensorflow/core/framework/register_types.h" 27 | #include "tensorflow/core/framework/tensor.h" 28 | #include "tensorflow/core/framework/tensor_shape.h" 29 | #include "tensorflow/core/framework/register_types.h" 30 | #include "tensorflow/core/framework/op.h" 31 | #include "tensorflow/core/framework/shape_inference.h" 32 | 33 | #include 34 | 35 | using namespace tensorflow; 36 | 37 | // the inputs should have format NCHW, which is faster on GPUs 38 | REGISTER_OP("RotatedPsRoiAlign") 39 | .Attr("T: {float}") 40 | .Attr("grid_dim_width: int") 41 | .Attr("grid_dim_height: int") 42 | .Attr("pool_method: string") 43 | .Input("inputs: T") 44 | .Input("rois: T") 45 | .Input("orders: int32") 46 | .Output("pooled_features: T") 47 | .Output("pooled_index: int32") 48 | .Doc(R"doc( 49 | RotatedPsRoiAlign is a new PsRoiPooling method without align problems. 50 | The input rois to be pooled must in format [y0, x0, y1, x1, y2, x2, y3, x3] which is four vertexes defining quadrilateral in clockwise order and each element must be in range [0, 1.]. 51 | The input orders define which point is the first one, each element must be in range [-1, 4). The order will be determined to be the first vertex of the shorter side if given -1. 52 | The caller must make sure that all rois is valid (has a intersect region (one pixel at least) with the window [0., 0., 0., 1., 1., 1., 1., 0.]). 
53 | )doc") 54 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 55 | shape_inference::ShapeHandle inputs_shape = c->input(0); 56 | shape_inference::DimensionHandle num_per_batch = c->Dim(inputs_shape, 0); 57 | shape_inference::DimensionHandle num_channals = c->Dim(inputs_shape, 1); 58 | shape_inference::DimensionHandle num_rois = c->Dim(c->input(1), 1); 59 | //TF_RETURN_IF_ERROR(c->MakeDimGetAttrForScalarInput(3, &grid_dim_height)); 60 | int32_t grid_dim_width(0); 61 | TF_RETURN_IF_ERROR(c->GetAttr("grid_dim_width", &grid_dim_width)); 62 | int32_t grid_dim_height(0); 63 | TF_RETURN_IF_ERROR(c->GetAttr("grid_dim_height", &grid_dim_height)); 64 | // one can use following function to make more check on input shape 65 | // use WithValue check DimensionHandle, and use WithRank check ShapeHandle 66 | // TF_RETURN_IF_ERROR(c->WithRank(logits_shape, 2, &logits_shape)); 67 | // TF_RETURN_IF_ERROR(c->WithValue(num_per_batch, 128, &num_per_batch)); 68 | const int32_t grid_size(grid_dim_width * grid_dim_height); 69 | shape_inference::DimensionHandle bank_size; 70 | TF_RETURN_IF_ERROR(c->Divide(num_channals, grid_size, true, &bank_size)); 71 | // use MakeShape to create a ShapeHandle from one DimensionHandle 72 | c->set_output(0, c->MakeShape({num_per_batch, num_rois, grid_size, bank_size})); 73 | c->set_output(1, c->MakeShape({num_per_batch, num_rois, grid_size, bank_size})); 74 | //c->set_output(1, c->MakeShape({num_per_batch, num_classes})); 75 | return Status::OK(); 76 | }); 77 | 78 | 79 | // CPU specialization of actual computation. 
80 | //template 81 | template 82 | struct RotatedPSROIAlignFunctor { 83 | void operator()(OpKernelContext* context, const CPUDevice& d, typename TTypes::ConstFlat inputs, typename TTypes::ConstFlat rois, typename TTypes::ConstFlat orders, const int32_t grid_dim_width, const int32_t grid_dim_height, typename TTypes::Flat pooled_features, typename TTypes::Flat pooled_index, KDimSize dim_info) { 84 | 85 | int batch_size = 0; 86 | int num_channals = 0; 87 | int map_height = 0; 88 | int map_width = 0; 89 | int num_rois = 0; 90 | bool using_max_pool = false; 91 | 92 | std::tie(batch_size, num_channals, map_height, map_width, num_rois, using_max_pool) = dim_info; 93 | 94 | auto pooling_routine = [&inputs, &rois, &orders, &pooled_features, &pooled_index, grid_dim_width, grid_dim_height, batch_size, num_channals, map_height, map_width, num_rois, using_max_pool](int64_t start, int64_t limit){ 95 | const int32_t grid_size = grid_dim_width * grid_dim_height; 96 | const int32_t bank_size = num_channals/grid_size; 97 | 98 | // auto fn_get_nth_dividing_points = [](T y1, T x1, T y2, T x2, int32_t nth /* from 0 to total - 1 */, int32_t total) -> std::make_tuple { 99 | // T nth_x = (x1 + x2) / 2.; 100 | // double x_step = (x1 - x2) / (total + 1.); 101 | // double y_step = (y1 - y2) / (total + 1.); 102 | // if(std::abs(x1 - x2) > std::numeric_limits::min()){ 103 | // nth_x = x2 + nth * x_step; 104 | // } 105 | // T nth_y = y2 + nth * y_step; 106 | // return std::make_tuple(nth_y, nth_x); 107 | // }; 108 | 109 | // https://stackoverflow.com/questions/471962/how-do-determine-if-a-polygon-is-complex-convex-nonconvex 110 | // Return True if the polynomial defined by the sequence of 2D 111 | // points is 'strictly convex': points are valid, side lengths non- 112 | // zero, interior angles are strictly between zero and a straight 113 | // angle, and the polygon does not intersect itself. 114 | 115 | // NOTES: 1. 
Algorithm: the signed changes of the direction angles 116 | // from one side to the next side must be all positive or 117 | // all negative, and their sum must equal plus-or-minus 118 | // one full turn (2 pi radians). Also check for too few, 119 | // invalid, or repeated points. 120 | // 2. No check is explicitly done for zero internal angles 121 | // (180 degree direction-change angle) as this is covered 122 | // in other ways, including the `n < 3` check. 123 | auto is_convex = [](const T * points){ 124 | double TWO_PI = 2 * PI; 125 | // Get starting information 126 | T old_x = points[5]; 127 | T old_y = points[4]; 128 | T new_x = points[7]; 129 | T new_y = points[6]; 130 | if(std::abs(new_y - old_y) < std::numeric_limits::min() && std::abs(new_x - old_x) < std::numeric_limits::min()) return false; 131 | T new_direction = std::atan2(new_y - old_y, new_x - old_x); 132 | T old_direction = 0.; 133 | double angle_sum = 0.; 134 | double orientation = 1.; 135 | // Check each point (the side ending there, its angle) and accum. angles 136 | for(uint16_t index = 0; index < 4; index++){ 137 | // Update point coordinates and side directions, check side length 138 | old_x = new_x; 139 | old_y = new_y; 140 | old_direction = new_direction; 141 | new_y = points[2 * index]; 142 | new_x = points[2 * index + 1]; 143 | if(std::abs(old_x - new_x) < std::numeric_limits::min() && std::abs(old_y - new_y) < std::numeric_limits::min()) return false; // repeated consecutive points 144 | new_direction = std::atan2(new_y - old_y, new_x - old_x); 145 | // Calculate & check the normalized direction-change angle 146 | double angle = new_direction - old_direction; 147 | if(angle <= -PI) angle += TWO_PI; // make it in half-open interval (-Pi, Pi] 148 | else if(angle > PI) angle -= TWO_PI; 149 | 150 | if(index == 0){ // if first time through loop, initialize orientation 151 | if(angle == 0.0) return false; 152 | orientation = angle > 0.0 ? 
1.0 : -1.0; 153 | }else{ // if other time through loop, check orientation is stable 154 | if(orientation * angle <= 0.0) return false;// not both pos. or both neg. 155 | } 156 | // Accumulate the direction-change angle 157 | angle_sum += angle; 158 | } 159 | // Check that the total number of full turns is plus-or-minus 1 160 | return std::abs(std::round(angle_sum / TWO_PI)) == 1; 161 | }; 162 | 163 | for (int64_t worker_index = start; worker_index < limit; ++worker_index){ 164 | // worker_index / bank_size / grid_size * num_channals + worker_index / bank_size / grid_size % num_rois * num_channals + (worker_index % grid_size) + worker_index % bank_size; 165 | // image_index * roi_index * channal_pos_remainder * row_index * col_index 166 | const int32_t position_index = (worker_index % num_channals) / bank_size; 167 | const int32_t row_index = position_index / grid_dim_width; 168 | const int32_t col_index = position_index % grid_dim_width; 169 | // position of the channal of pooled feature 170 | // position of the channal in the bank of feature map 171 | const int32_t channal_pos_remainder = worker_index % bank_size; 172 | const int32_t pool_index = worker_index / num_channals; 173 | const int32_t image_index = pool_index / num_rois; 174 | const int32_t roi_index = pool_index % num_rois; 175 | 176 | const T * roi_to_pool = rois.data() + (image_index * num_rois + roi_index) * 8; 177 | const int32_t * roi_order = orders.data() + image_index * num_rois + roi_index; 178 | 179 | const T * feature_map_to_pool = inputs.data() + (image_index * num_channals + position_index * bank_size + channal_pos_remainder) * map_height * map_width; 180 | // T * pooled_features_start = pooled_features.data() + image_index * (num_rois * num_channals) + roi_index * num_channals + (position_index % grid_size) * bank_size + channal_pos_remainder; 181 | // int32_t * pooled_index_start = pooled_index.data() + image_index * (num_rois * num_channals) + roi_index * num_channals + (position_index % 
grid_size) * bank_size + channal_pos_remainder; 182 | T * pooled_features_start = pooled_features.data() + worker_index; 183 | int32_t * pooled_index_start = pooled_index.data() + worker_index; 184 | 185 | int32_t order = *roi_order < 0 ? 0 : *roi_order * 2; 186 | 187 | T roi_y0 = static_cast(roi_to_pool[(order++) % 8] * map_height); 188 | T roi_x0 = static_cast(roi_to_pool[(order++) % 8] * map_width); 189 | T roi_y1 = static_cast(roi_to_pool[(order++) % 8] * map_height); 190 | T roi_x1 = static_cast(roi_to_pool[(order++) % 8] * map_width); 191 | T roi_y2 = static_cast(roi_to_pool[(order++) % 8] * map_height); 192 | T roi_x2 = static_cast(roi_to_pool[(order++) % 8] * map_width); 193 | T roi_y3 = static_cast(roi_to_pool[(order++) % 8] * map_height); 194 | T roi_x3 = static_cast(roi_to_pool[(order++) % 8] * map_width); 195 | 196 | double len0 = static_cast((roi_y1 - roi_y0) * (roi_y1 - roi_y0) + (roi_x1 - roi_x0) * (roi_x1 - roi_x0)); 197 | double len1 = static_cast((roi_y2 - roi_y1) * (roi_y2 - roi_y1) + (roi_x2 - roi_x1) * (roi_x2 - roi_x1)); 198 | double len2 = static_cast((roi_y3 - roi_y2) * (roi_y3 - roi_y2) + (roi_x3 - roi_x2) * (roi_x3 - roi_x2)); 199 | double len3 = static_cast((roi_y0 - roi_y3) * (roi_y0 - roi_y3) + (roi_x0 - roi_x3) * (roi_x0 - roi_x3)); 200 | double cross_len0 = static_cast((roi_y0 - roi_y2) * (roi_y0 - roi_y2) + (roi_x0 - roi_x2) * (roi_x0 - roi_x2)); 201 | double cross_len1 = static_cast((roi_y3 - roi_y1) * (roi_y3 - roi_y1) + (roi_x3 - roi_x1) * (roi_x3 - roi_x1)); 202 | 203 | order = *roi_order < 0 ? (len0 + len2 > len1 + len3 ? 
1 : 0) : 0; 204 | // fix ROI 205 | if(len0 < std::numeric_limits::min() || len1 < std::numeric_limits::min() || len2 < std::numeric_limits::min() || len3 < std::numeric_limits::min()){ 206 | // not check convex for faster speed 207 | //if(is_convex(roi_to_pool)){ 208 | *pooled_features_start = static_cast(0); 209 | *pooled_index_start = static_cast(0); 210 | continue; 211 | } 212 | 213 | T roi_y0_order = (order == 0) ? roi_y0 : roi_y1; 214 | T roi_x0_order = (order == 0) ? roi_x0 : roi_x1; 215 | T roi_y1_order = (order == 0) ? roi_y1 : roi_y2; 216 | T roi_x1_order = (order == 0) ? roi_x1 : roi_x2; 217 | T roi_y2_order = (order == 0) ? roi_y2 : roi_y3; 218 | T roi_x2_order = (order == 0) ? roi_x2 : roi_x3; 219 | T roi_y3_order = (order == 0) ? roi_y3 : roi_y0; 220 | T roi_x3_order = (order == 0) ? roi_x3 : roi_x0; 221 | 222 | T y_step_left = (roi_y3_order - roi_y0_order)/(grid_dim_height * 1.); 223 | T y_step_right = (roi_y2_order - roi_y1_order)/(grid_dim_height * 1.); 224 | T x_step_top = (roi_x1_order - roi_x0_order)/(grid_dim_width * 1.); 225 | T x_step_bottom = (roi_x2_order - roi_x3_order)/(grid_dim_width * 1.); 226 | 227 | T left_y1 = (roi_y0_order + row_index * y_step_left); 228 | T right_y1 = (roi_y1_order + row_index * y_step_right); 229 | T left_y2 = (roi_y0_order + (row_index + 1.) * y_step_left); 230 | T right_y2 = (roi_y1_order + (row_index + 1.) * y_step_right); 231 | 232 | T left_top_y = left_y1 + col_index * (right_y1 - left_y1)/(grid_dim_width); 233 | T right_top_y = left_y1 + (col_index + 1.) * (right_y1 - left_y1)/(grid_dim_width); 234 | T left_bottom_y = left_y2 + col_index * (right_y2 - left_y2)/(grid_dim_width); 235 | T right_bottom_y = left_y2 + (col_index + 1.) * (right_y2 - left_y2)/(grid_dim_width); 236 | 237 | T top_x1 = (roi_x0_order + col_index * x_step_top); 238 | T bottom_x1 = (roi_x3_order + col_index * x_step_bottom); 239 | T top_x2 = (roi_x0_order + (col_index + 1.) 
* x_step_top); 240 | T bottom_x2 = (roi_x3_order + (col_index + 1.) * x_step_bottom); 241 | 242 | T left_top_x = top_x1 + row_index * (bottom_x1 - top_x1)/(grid_dim_height); 243 | T left_bottom_x = top_x1 + (row_index + 1.) * (bottom_x1 - top_x1)/(grid_dim_height); 244 | T right_top_x = top_x2 + row_index * (bottom_x2 - top_x2)/(grid_dim_height); 245 | T right_bottom_x = top_x2 + (row_index + 1.) * (bottom_x2 - top_x2)/(grid_dim_height); 246 | 247 | float pool_bin_width = static_cast(std::max(std::min(std::abs(right_top_x - left_top_x), std::abs(right_top_y - left_top_y)), std::min(std::abs(right_bottom_x - left_bottom_x), std::abs(right_bottom_y - left_bottom_y)))); 248 | float pool_bin_height = static_cast(std::max(std::min(std::abs(left_bottom_x - left_top_x), std::abs(left_bottom_y - left_top_y)), std::min(std::abs(right_bottom_x - right_top_x), std::abs(right_bottom_y - right_top_y)))); 249 | int32_t num_elem_width = static_cast(pool_bin_width) + 1; 250 | int32_t num_elem_height = static_cast(pool_bin_height) + 1; 251 | 252 | T grid_y_step_left = (left_bottom_y - left_top_y)/(num_elem_height + 1.); 253 | T grid_y_step_right = (right_bottom_y - right_top_y)/(num_elem_height + 1.); 254 | T grid_x_step_top = (right_top_x - left_top_x)/(num_elem_width + 1.); 255 | T grid_x_step_bottom = (right_bottom_x - left_bottom_x)/(num_elem_width + 1.); 256 | 257 | int32_t max_pool_ind = 0; 258 | T max_or_acc_elem = using_max_pool ? std::numeric_limits::lowest() : static_cast(0); 259 | //std::cout << "num_elem_height: " << num_elem_height << " num_elem_width:" << num_elem_width << std::endl; 260 | for(int32_t pool_h = 0; pool_h < num_elem_height; ++pool_h){ 261 | for(int32_t pool_w = 0; pool_w < num_elem_width; ++pool_w){ 262 | //std::cout << "col_to_pool: " << col_to_pool << " row_to_pool: " << row_to_pool << std::endl; 263 | T col_to_pool = (left_top_x + (pool_w + 1.) * grid_x_step_top + left_bottom_x + (pool_w + 1.) 
* grid_x_step_bottom) / 2.; 264 | T row_to_pool = (left_top_y + (pool_h + 1.) * grid_y_step_left + right_top_y + (pool_h + 1.) * grid_y_step_right) / 2.; 265 | 266 | int32_t int_col_to_pool = static_cast(col_to_pool); 267 | int32_t int_row_to_pool = static_cast(row_to_pool); 268 | float float_col_to_pool = col_to_pool - int_col_to_pool; 269 | float float_row_to_pool = row_to_pool - int_row_to_pool; 270 | 271 | int32_t current_switch_ind = num_elem_width * pool_h + pool_w; 272 | //std::cout << "current_switch_ind: " << current_switch_ind << std::endl; 273 | T temp_value = static_cast((1. - float_col_to_pool) * (1. - float_row_to_pool) * feature_map_to_pool[int_row_to_pool * map_width + int_col_to_pool] + 274 | (1. - float_col_to_pool) * float_row_to_pool * feature_map_to_pool[std::min(int_row_to_pool + 1, map_height - 1) * map_width + int_col_to_pool] + 275 | float_col_to_pool * (1. - float_row_to_pool) * feature_map_to_pool[int_row_to_pool * map_width + std::min(int_col_to_pool + 1, map_width - 1)] + 276 | float_col_to_pool * float_row_to_pool * feature_map_to_pool[std::min(int_row_to_pool + 1, map_height - 1) * map_width + std::min(int_col_to_pool + 1, map_width - 1)]); 277 | if(using_max_pool){ 278 | if(max_or_acc_elem < temp_value){ 279 | max_or_acc_elem = temp_value; 280 | max_pool_ind = current_switch_ind; 281 | } 282 | }else{ 283 | max_or_acc_elem += temp_value; 284 | } 285 | } 286 | } 287 | 288 | if(!using_max_pool) max_or_acc_elem /= static_cast(num_elem_height * num_elem_width); 289 | 290 | *pooled_features_start = max_or_acc_elem; 291 | *pooled_index_start = using_max_pool ? 
max_pool_ind : static_cast(0); 292 | } 293 | }; 294 | 295 | const DeviceBase::CpuWorkerThreads& worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); 296 | // one worker for one position in each ROI 297 | const int64_t shard_cost = 4 * map_height * map_width / grid_dim_width / grid_dim_height / 4; 298 | Shard(worker_threads.num_threads, worker_threads.workers, 299 | pooled_features.size(), shard_cost, pooling_routine); 300 | } 301 | }; 302 | 303 | // OpKernel definition. 304 | // template parameter is the datatype of the tensors. 305 | template 306 | class RotatedPSROIAlignOp : public OpKernel { 307 | public: 308 | explicit RotatedPSROIAlignOp(OpKernelConstruction* context) : OpKernel(context) { 309 | OP_REQUIRES_OK(context, context->GetAttr("grid_dim_width", &grid_dim_width_in)); 310 | OP_REQUIRES(context, grid_dim_width_in >= 0, errors::InvalidArgument("Need Attr grid_dim_width >= 0, got ", grid_dim_width_in)); 311 | 312 | OP_REQUIRES_OK(context, context->GetAttr("grid_dim_height", &grid_dim_height_in)); 313 | OP_REQUIRES(context, grid_dim_height_in >= 0, errors::InvalidArgument("Need Attr grid_dim_height >= 0, got ", grid_dim_height_in)); 314 | 315 | OP_REQUIRES_OK(context, context->GetAttr("pool_method", &pool_method)); 316 | OP_REQUIRES(context, StringPiece(pool_method).contains(StringPiece("mean")) || StringPiece(pool_method).contains(StringPiece("max")), errors::InvalidArgument("Need Attr pool_method to be either 'mean' or 'max', got ", pool_method)); 317 | // std::cout << (StringPiece(pool_method).contains(StringPiece("mean")) || StringPiece(pool_method).contains(StringPiece("max"))) << std::endl; 318 | } 319 | 320 | void Compute(OpKernelContext* context) override { 321 | const Tensor& inputs_in = context->input(0); 322 | const Tensor& rois_in = context->input(1); 323 | const Tensor& orders_in = context->input(2); 324 | 325 | OP_REQUIRES(context, inputs_in.shape().dims() == 4, errors::InvalidArgument("inputs must be in 'NCHW' format.")); 
326 | OP_REQUIRES(context, rois_in.shape().dims() == 3 && rois_in.shape().dim_size(2) == 8, errors::InvalidArgument("rois must be in 'batch_size x num_rois x 8' format.")); 327 | OP_REQUIRES(context, orders_in.shape().dims() == 2, errors::InvalidArgument("orders must be in 'batch_size x num_rois' format.")); 328 | OP_REQUIRES(context, inputs_in.dim_size(0) == rois_in.dim_size(0), errors::InvalidArgument("'batch_size' in inputs and rois don't match.")); 329 | OP_REQUIRES(context, (orders_in.dim_size(0) == rois_in.dim_size(0)) && (orders_in.dim_size(1) == rois_in.dim_size(1)), errors::InvalidArgument("'batch_size' or 'num_rois' in orders and rois don't match.")); 330 | 331 | const int batch_size = inputs_in.dim_size(0); 332 | const int num_channals = inputs_in.dim_size(1); 333 | const int map_height = inputs_in.dim_size(2); 334 | const int map_width = inputs_in.dim_size(3); 335 | const int num_rois = rois_in.dim_size(1); 336 | 337 | const int32_t grid_size = grid_dim_width_in * grid_dim_height_in; 338 | 339 | auto bank_size = static_cast(num_channals / grid_size); 340 | Tensor* pooled_features = nullptr; 341 | OP_REQUIRES_OK(context, context->allocate_output(0, {batch_size, num_rois, grid_size, bank_size}, &pooled_features)); 342 | Tensor* pooled_index = nullptr; 343 | OP_REQUIRES_OK(context, context->allocate_output(1, {batch_size, num_rois, grid_size, bank_size}, &pooled_index)); 344 | 345 | RotatedPSROIAlignFunctor()(context, context->eigen_device(), inputs_in.template flat(), rois_in.template flat(), orders_in.template flat(), grid_dim_width_in, grid_dim_height_in, pooled_features->template flat(), pooled_index->template flat(), std::make_tuple(batch_size, num_channals, map_height, map_width, num_rois, StringPiece(pool_method).contains(StringPiece("max")))); 346 | // RotatedPSROIPoolingFunctor()(context, context->eigen_device(), inputs_in.tensor(), rois_in.tensor(), grid_dim_buffer[0], pooled_features->tensor()); 347 | } 348 | 349 | private: 350 | int32_t 
grid_dim_width_in{-1}; 351 | int32_t grid_dim_height_in{-1}; 352 | std::string pool_method{"max"}; 353 | }; 354 | 355 | // Register the CPU kernels. 356 | #define REGISTER_CPU(T) \ 357 | REGISTER_KERNEL_BUILDER( \ 358 | Name("RotatedPsRoiAlign").Device(DEVICE_CPU).TypeConstraint("T"), \ 359 | RotatedPSROIAlignOp); 360 | REGISTER_CPU(float); 361 | 362 | // TF_CALL_NUMBER_TYPES(REGISTER_CPU); 363 | // #undef REGISTER_CPU 364 | 365 | // Register the GPU kernels. 366 | #if GOOGLE_CUDA == 1 367 | #define REGISTER_GPU(T) \ 368 | REGISTER_KERNEL_BUILDER( \ 369 | Name("RotatedPsRoiAlign").Device(DEVICE_GPU).TypeConstraint("T"), \ 370 | RotatedPSROIAlignOp); 371 | REGISTER_GPU(float); 372 | #endif // GOOGLE_CUDA 373 | -------------------------------------------------------------------------------- /ps_roi_align_grad_op.cc: -------------------------------------------------------------------------------- 1 | // MIT License 2 | 3 | // Copyright (c) 2018 Changan Wang 4 | 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 14 | 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE. 22 | #include "ps_roi_align_op.h" 23 | #include "common.h" 24 | #include "work_sharder.h" 25 | 26 | #include "tensorflow/core/framework/op_kernel.h" 27 | #include "tensorflow/core/framework/register_types.h" 28 | #include "tensorflow/core/framework/tensor.h" 29 | #include "tensorflow/core/framework/tensor_shape.h" 30 | #include "tensorflow/core/framework/register_types.h" 31 | #include "tensorflow/core/framework/op.h" 32 | #include "tensorflow/core/framework/shape_inference.h" 33 | 34 | #include 35 | 36 | using namespace tensorflow; 37 | 38 | // the inputs should have format NCHW, which is faster on GPUs 39 | REGISTER_OP("PsRoiAlignGrad") 40 | .Attr("T: {float}") 41 | .Attr("grid_dim_width: int") 42 | .Attr("grid_dim_height: int") 43 | .Attr("pool_method: string") 44 | .Input("inputs: T") 45 | .Input("rois: T") 46 | .Input("pooled_features_grad: T") 47 | .Input("pooled_index: int32") 48 | .Output("grad_output: T") 49 | .Doc(R"doc( 50 | PsRoiAlignGrad is the Gradient op of PsRoiAlign. 51 | The input rois to be pooled must in format [center_y, center_x, h, w] and each element must be in range [0, 1.]. 52 | The caller must make sure that all rois is valid (has a intersect region (one pixel at least) with the window [0.5, 0.5, 1., 1.]). 53 | )doc") 54 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 55 | c->set_output(0, c->input(0)); 56 | return Status::OK(); 57 | }); 58 | 59 | // CPU specialization of actual computation. 
60 | // template 61 | // struct PSROIAlignGradFunctor { 62 | // void operator()(OpKernelContext* context, const CPUDevice& d, typename TTypes::ConstFlat inputs, typename TTypes::ConstFlat rois, const int32_t grid_dim_width, const int32_t grid_dim_height, typename TTypes::ConstFlat pooled_features_grad, typename TTypes::ConstFlat pooled_index, typename TTypes::Flat grad_output, KDimSize dim_info) { 63 | 64 | // int batch_size = 0; 65 | // int num_channals = 0; 66 | // int map_height = 0; 67 | // int map_width = 0; 68 | // int num_rois = 0; 69 | // bool using_max_pool = false; 70 | 71 | // std::tie(batch_size, num_channals, map_height, map_width, num_rois, using_max_pool) = dim_info; 72 | // grad_output = grad_output.setZero(); 73 | 74 | // auto pooling_grad_routine = [&rois, &pooled_features_grad, &pooled_index, &grad_output, grid_dim_width, grid_dim_height, batch_size, num_channals, map_height, map_width, num_rois, using_max_pool](int64_t start, int64_t limit){ 75 | // const int32_t grid_size = grid_dim_width * grid_dim_height; 76 | // const int32_t bank_size = num_channals / grid_size; 77 | // for (int64_t worker_index = start; worker_index < limit; ++worker_index){ 78 | // // image_index * roi_index * channal_pos_remainder * row_index * col_index 79 | // const int32_t position_index = (worker_index % num_channals) / bank_size; 80 | // const int32_t row_index = position_index / grid_dim_width; 81 | // const int32_t col_index = position_index % grid_dim_width; 82 | // // position of the channal of pooled feature 83 | // // position of the channal in the bank of feature map 84 | // const int32_t channal_pos_remainder = worker_index % bank_size; 85 | // const int32_t pool_index = worker_index / num_channals; 86 | // const int32_t image_index = pool_index / num_rois; 87 | // const int32_t roi_index = pool_index % num_rois; 88 | 89 | // const T * roi_to_pool = rois.data() + (image_index * num_rois + roi_index) * 4; 90 | 91 | // volatile T * grad_output_start = 
reinterpret_cast(grad_output.data() + (image_index * num_channals + position_index * bank_size + channal_pos_remainder) * map_height * map_width); 92 | // const T * pooled_features_start = pooled_features_grad.data() + worker_index; 93 | // const int32_t * pooled_index_start = pooled_index.data() + worker_index; 94 | 95 | // T roi_ymin = static_cast(0); 96 | // T roi_xmin = static_cast(0); 97 | // T roi_ymax = static_cast(0); 98 | // T roi_xmax = static_cast(0); 99 | // if(roi_to_pool[2] < std::numeric_limits::min() || roi_to_pool[3] < std::numeric_limits::min()) continue; 100 | 101 | // // fix ROI 102 | // std::tie(roi_ymin, roi_xmin, roi_ymax, roi_xmax) = [roi_to_pool, map_height, map_width](){ 103 | // T roi_y_center = static_cast(roi_to_pool[0] * map_height); 104 | // T roi_x_center = static_cast(roi_to_pool[1] * map_width); 105 | // T roi_h = std::max(roi_to_pool[2] * map_height, static_cast(1)); 106 | // T roi_w = std::max(roi_to_pool[3] * map_width, static_cast(1)); 107 | 108 | // T roi_ymin = std::max(roi_y_center - static_cast(roi_h / 2.), static_cast(0)); 109 | // T roi_xmin = std::max(roi_x_center - static_cast(roi_w / 2.), static_cast(0)); 110 | // T roi_ymax = std::min(roi_y_center + static_cast(roi_h / 2.), static_cast(map_height) - std::numeric_limits::min()); 111 | // T roi_xmax = std::min(roi_x_center + static_cast(roi_w / 2.), static_cast(map_width) - std::numeric_limits::min()); 112 | // return std::make_tuple(roi_ymin, roi_xmin, roi_ymax, roi_xmax); 113 | // }(); 114 | // // T roi_center_y = roi_to_pool[0]; 115 | // // T roi_center_x = roi_to_pool[1]; 116 | // T roi_h = roi_ymax - roi_ymin; 117 | // T roi_w = roi_xmax - roi_xmin; 118 | // float pool_bin_width = static_cast(roi_w) / grid_dim_width; 119 | // float pool_bin_height = static_cast(roi_h) / grid_dim_height; 120 | // int32_t num_elem_width = static_cast(pool_bin_width) + 1; 121 | // int32_t num_elem_height = static_cast(pool_bin_height) + 1; 122 | 123 | // // std::cout << 
"pool_bin_width: " << pool_bin_width << " pool_bin_height: " << pool_bin_height << " num_elem_width: " << num_elem_width << " num_elem_height: " << num_elem_height << std::endl; 124 | 125 | // // std::cout << "worker_index: " << worker_index << " roi_index: " << roi_index 126 | // // << " roi_ymin: " << roi_ymin << " roi_xmin: " << roi_xmin << " roi_ymax: " << roi_ymax << " roi_xmax: " << roi_xmax << " image_index: " << image_index << " position_index: " << (position_index % grid_size) << " channal_pos_remainder: " << channal_pos_remainder << std::endl; 127 | 128 | // float step_width_each_bin = pool_bin_width / num_elem_width; 129 | // float step_height_each_bin = pool_bin_height / num_elem_height; 130 | 131 | // float pool_width_start = roi_xmin + pool_bin_width * col_index; 132 | // float pool_height_start = roi_ymin + pool_bin_height * row_index; 133 | 134 | // if(using_max_pool){ 135 | // const int32_t h_ind = *pooled_index_start / num_elem_width; 136 | // const int32_t w_ind = *pooled_index_start % num_elem_width; 137 | 138 | // float col_to_pool = pool_width_start + step_width_each_bin * w_ind + step_width_each_bin / 2.; 139 | // float row_to_pool = pool_height_start + step_height_each_bin * h_ind + step_height_each_bin / 2.; 140 | // //std::cout << "col_to_pool: " << col_to_pool << " row_to_pool: " << row_to_pool << std::endl; 141 | // int32_t int_col_to_pool = static_cast(col_to_pool); 142 | // int32_t int_row_to_pool = static_cast(row_to_pool); 143 | // float float_col_to_pool = col_to_pool - int_col_to_pool; 144 | // float float_row_to_pool = row_to_pool - int_row_to_pool; 145 | 146 | // const T grad_in = *pooled_features_start; 147 | // atomic_float_add(grad_output_start + int_row_to_pool * map_width + int_col_to_pool, static_cast((1. - float_col_to_pool) * (1. - float_row_to_pool) * grad_in)); 148 | // atomic_float_add(grad_output_start + std::min(int_row_to_pool + 1, map_height - 1) * map_width + int_col_to_pool, static_cast((1. 
- float_col_to_pool) * float_row_to_pool * grad_in)); 149 | // atomic_float_add(grad_output_start + int_row_to_pool * map_width + std::min(int_col_to_pool + 1, map_width - 1), static_cast(float_col_to_pool * (1. - float_row_to_pool) * grad_in)); 150 | // atomic_float_add(grad_output_start + std::min(int_row_to_pool + 1, map_height - 1) * map_width + std::min(int_col_to_pool + 1, map_width - 1), static_cast(float_col_to_pool * float_row_to_pool * grad_in)); 151 | // }else{ 152 | // const T grad_in = *pooled_features_start / static_cast(num_elem_width * num_elem_height); 153 | // for (int32_t h_ind = 0; h_ind < num_elem_height; ++h_ind) { 154 | // for (int32_t w_ind = 0; w_ind < num_elem_width; ++w_ind) { 155 | // float col_to_pool = pool_width_start + step_width_each_bin * w_ind + step_width_each_bin / 2.; 156 | // float row_to_pool = pool_height_start + step_height_each_bin * h_ind + step_height_each_bin / 2.; 157 | 158 | // int32_t int_col_to_pool = static_cast(col_to_pool); 159 | // int32_t int_row_to_pool = static_cast(row_to_pool); 160 | // float float_col_to_pool = col_to_pool - int_col_to_pool; 161 | // float float_row_to_pool = row_to_pool - int_row_to_pool; 162 | 163 | // atomic_float_add(grad_output_start + int_row_to_pool * map_width + int_col_to_pool, static_cast((1. - float_col_to_pool) * (1. - float_row_to_pool) * grad_in)); 164 | // atomic_float_add(grad_output_start + std::min(int_row_to_pool + 1, map_height - 1) * map_width + int_col_to_pool, static_cast((1. - float_col_to_pool) * float_row_to_pool * grad_in)); 165 | // atomic_float_add(grad_output_start + int_row_to_pool * map_width + std::min(int_col_to_pool + 1, map_width - 1), static_cast(float_col_to_pool * (1. 
- float_row_to_pool) * grad_in)); 166 | // atomic_float_add(grad_output_start + std::min(int_row_to_pool + 1, map_height - 1) * map_width + std::min(int_col_to_pool + 1, map_width - 1), static_cast(float_col_to_pool * float_row_to_pool * grad_in)); 167 | // } 168 | // } 169 | // } 170 | // } 171 | // }; 172 | 173 | // const DeviceBase::CpuWorkerThreads& worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); 174 | // // one worker for one position in each ROI 175 | // const int64_t shard_cost = 4 * map_height * map_width / grid_dim_width / grid_dim_height / 4; 176 | // Shard(worker_threads.num_threads, worker_threads.workers, 177 | // pooled_features_grad.size(), shard_cost, pooling_grad_routine); 178 | // } 179 | // }; 180 | 181 | // // calculate gradients from input side 182 | // // the result of this kernel is same as the above kernel which is calculate gradients from the output side 183 | // // the different is that this kernel don't need synchronous gradients of the same input cell 184 | // // but the drawback of this kernel is that more threads scheduling may be occurred due to the larger input feature map size compared with output feature map 185 | // // you can choose any one to use depends on the relative overhead between the scheduling and atomic sync operation 186 | template 187 | struct PSROIAlignGradFunctor { 188 | void operator()(OpKernelContext* context, const CPUDevice& d, typename TTypes::ConstFlat inputs, typename TTypes::ConstFlat rois, const int32_t grid_dim_width, const int32_t grid_dim_height, typename TTypes::ConstFlat pooled_features_grad, typename TTypes::ConstFlat pooled_index, typename TTypes::Flat grad_output, KDimSize dim_info) { 189 | int batch_size = 0; 190 | int num_channals = 0; 191 | int map_height = 0; 192 | int map_width = 0; 193 | int num_rois = 0; 194 | bool using_max_pool = false; 195 | 196 | std::tie(batch_size, num_channals, map_height, map_width, num_rois, using_max_pool) = dim_info; 197 | 198 | grad_output 
= grad_output.setZero(); 199 | 200 | auto pooling_grad_routine = [&rois, &pooled_features_grad, &pooled_index, &grad_output, grid_dim_width, grid_dim_height, batch_size, num_channals, map_height, map_width, num_rois, using_max_pool](int64_t start, int64_t limit){ 201 | const int32_t grid_size = grid_dim_width * grid_dim_height; 202 | const int32_t bank_size = num_channals/grid_size; 203 | for (int64_t worker_index = start; worker_index < limit; ++worker_index){ 204 | const int32_t cur_image_index = worker_index / (num_channals * map_height * map_width); 205 | const int32_t cur_channal_index = (worker_index % (num_channals * map_height * map_width)) / (map_height * map_width); 206 | const int32_t offset_on_map = worker_index % (map_height * map_width); 207 | const int32_t col_on_map = offset_on_map % map_width; 208 | const int32_t row_on_map = offset_on_map / map_width; 209 | 210 | T * grad_to_fill = reinterpret_cast(grad_output.data() + worker_index); 211 | 212 | for(int roi_index = 0;roi_index < num_rois;++roi_index){ 213 | const T * roi_to_pool = rois.data() + (cur_image_index * num_rois + roi_index) * 4; 214 | 215 | T roi_ymin = static_cast(0); 216 | T roi_xmin = static_cast(0); 217 | T roi_ymax = static_cast(0); 218 | T roi_xmax = static_cast(0); 219 | // fix ROI 220 | if(roi_to_pool[2] < std::numeric_limits::min() || roi_to_pool[3] < std::numeric_limits::min()) continue; 221 | std::tie(roi_ymin, roi_xmin, roi_ymax, roi_xmax) = [roi_to_pool, map_height, map_width](){ 222 | T roi_y_center = static_cast(roi_to_pool[0] * map_height); 223 | T roi_x_center = static_cast(roi_to_pool[1] * map_width); 224 | T roi_h = std::max(roi_to_pool[2] * map_height, static_cast(1)); 225 | T roi_w = std::max(roi_to_pool[3] * map_width, static_cast(1)); 226 | 227 | T roi_ymin = std::max(roi_y_center - static_cast(roi_h / 2.), static_cast(0)); 228 | T roi_xmin = std::max(roi_x_center - static_cast(roi_w / 2.), static_cast(0)); 229 | T roi_ymax = std::min(roi_y_center + 
static_cast(roi_h / 2.), static_cast(map_height) - std::numeric_limits::min()); 230 | T roi_xmax = std::min(roi_x_center + static_cast(roi_w / 2.), static_cast(map_width) - std::numeric_limits::min()); 231 | return std::make_tuple(roi_ymin, roi_xmin, roi_ymax, roi_xmax); 232 | }(); 233 | // T roi_center_y = roi_to_pool[0]; 234 | // T roi_center_x = roi_to_pool[1]; 235 | T roi_h = roi_ymax - roi_ymin; 236 | T roi_w = roi_xmax - roi_xmin; 237 | float pool_bin_width = static_cast(roi_w) / grid_dim_width; 238 | float pool_bin_height = static_cast(roi_h) / grid_dim_height; 239 | int32_t num_elem_width = static_cast(pool_bin_width) + 1; 240 | int32_t num_elem_height = static_cast(pool_bin_height) + 1; 241 | 242 | // std::cout << "pool_bin_width: " << pool_bin_width << " pool_bin_height: " << pool_bin_height << " num_elem_width: " << num_elem_width << " num_elem_height: " << num_elem_height << std::endl; 243 | 244 | // std::cout << "worker_index: " << worker_index << " roi_index: " << roi_index 245 | // << " roi_ymin: " << roi_ymin << " roi_xmin: " << roi_xmin << " roi_ymax: " << roi_ymax << " roi_xmax: " << roi_xmax << " cur_image_index: " << cur_image_index << " position_index: " << (position_index % grid_size) << " channal_pos_remainder: " << channal_pos_remainder << std::endl; 246 | float step_width_each_bin = pool_bin_width / num_elem_width; 247 | float step_height_each_bin = pool_bin_height / num_elem_height; 248 | 249 | const T pooled_features_grad_in = *(pooled_features_grad.data() + cur_image_index * (num_rois * num_channals) + roi_index * num_channals + cur_channal_index); 250 | const int32_t pooled_max_index = *(pooled_index.data() + cur_image_index * (num_rois * num_channals) + roi_index * num_channals + cur_channal_index); 251 | 252 | const int32_t row_index = (cur_channal_index / bank_size) / grid_dim_width; 253 | const int32_t col_index = (cur_channal_index / bank_size) % grid_dim_width; 254 | 255 | float pool_width_start = roi_xmin + pool_bin_width * 
col_index; 256 | float pool_height_start = roi_ymin + pool_bin_height * row_index; 257 | 258 | if(using_max_pool){ 259 | const int32_t h_ind = pooled_max_index / num_elem_width; 260 | const int32_t w_ind = pooled_max_index % num_elem_width; 261 | 262 | float col_to_pool = pool_width_start + step_width_each_bin * w_ind + step_width_each_bin / 2.; 263 | float row_to_pool = pool_height_start + step_height_each_bin * h_ind + step_height_each_bin / 2.; 264 | //std::cout << "col_to_pool: " << col_to_pool << " row_to_pool: " << row_to_pool << std::endl; 265 | int32_t int_col_to_pool = static_cast(col_to_pool); 266 | int32_t int_row_to_pool = static_cast(row_to_pool); 267 | float float_col_to_pool = col_to_pool - int_col_to_pool; 268 | float float_row_to_pool = row_to_pool - int_row_to_pool; 269 | 270 | // not 'if else' here for there may be collapsing in pooling operation when the ROI is small enough 271 | if(col_on_map == int_col_to_pool && row_on_map == int_row_to_pool){ 272 | *grad_to_fill += static_cast((1. - float_col_to_pool) * (1. - float_row_to_pool) * pooled_features_grad_in); 273 | } 274 | if(col_on_map == int_col_to_pool && row_on_map == std::min(int_row_to_pool + 1, map_height - 1)){ 275 | *grad_to_fill += static_cast((1. - float_col_to_pool) * float_row_to_pool * pooled_features_grad_in); 276 | } 277 | if(col_on_map == std::min(int_col_to_pool + 1, map_width - 1) && row_on_map == int_row_to_pool){ 278 | *grad_to_fill += static_cast(float_col_to_pool * (1. 
- float_row_to_pool) * pooled_features_grad_in); 279 | } 280 | if(col_on_map == std::min(int_col_to_pool + 1, map_width - 1) && row_on_map == std::min(int_row_to_pool + 1, map_height - 1)){ 281 | *grad_to_fill += static_cast(float_col_to_pool * float_row_to_pool * pooled_features_grad_in); 282 | } 283 | }else{ 284 | T acc_back_grad = static_cast(0); 285 | for (int32_t h_ind = 0; h_ind < num_elem_height; ++h_ind) { 286 | for (int32_t w_ind = 0; w_ind < num_elem_width; ++w_ind) { 287 | float col_to_pool = pool_width_start + step_width_each_bin * w_ind + step_width_each_bin / 2.; 288 | float row_to_pool = pool_height_start + step_height_each_bin * h_ind + step_height_each_bin / 2.; 289 | //std::cout << "col_to_pool: " << col_to_pool << " row_to_pool: " << row_to_pool << std::endl; 290 | int32_t int_col_to_pool = static_cast(col_to_pool); 291 | int32_t int_row_to_pool = static_cast(row_to_pool); 292 | float float_col_to_pool = col_to_pool - int_col_to_pool; 293 | float float_row_to_pool = row_to_pool - int_row_to_pool; 294 | 295 | if(col_on_map == int_col_to_pool && row_on_map == int_row_to_pool){ 296 | acc_back_grad += static_cast((1. - float_col_to_pool) * (1. - float_row_to_pool) * pooled_features_grad_in); 297 | } 298 | if(col_on_map == int_col_to_pool && row_on_map == std::min(int_row_to_pool + 1, map_height - 1)){ 299 | acc_back_grad += static_cast((1. - float_col_to_pool) * float_row_to_pool * pooled_features_grad_in); 300 | } 301 | if(col_on_map == std::min(int_col_to_pool + 1, map_width - 1) && row_on_map == int_row_to_pool){ 302 | acc_back_grad += static_cast(float_col_to_pool * (1. 
- float_row_to_pool) * pooled_features_grad_in); 303 | } 304 | if(col_on_map == std::min(int_col_to_pool + 1, map_width - 1) && row_on_map == std::min(int_row_to_pool + 1, map_height - 1)){ 305 | acc_back_grad += static_cast(float_col_to_pool * float_row_to_pool * pooled_features_grad_in); 306 | } 307 | } 308 | } 309 | *grad_to_fill += acc_back_grad / static_cast(num_elem_width * num_elem_height); 310 | } 311 | } 312 | 313 | } 314 | }; 315 | 316 | const DeviceBase::CpuWorkerThreads& worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); 317 | // one worker for one position in each ROI 318 | const int64_t shard_cost = num_rois * 4; 319 | Shard(worker_threads.num_threads, worker_threads.workers, 320 | grad_output.size(), shard_cost, pooling_grad_routine); 321 | } 322 | }; 323 | 324 | // OpKernel definition. 325 | // template parameter is the datatype of the tensors. 326 | template 327 | class PSROIAlignGradOp : public OpKernel { 328 | public: 329 | explicit PSROIAlignGradOp(OpKernelConstruction* context) : OpKernel(context) { 330 | OP_REQUIRES_OK(context, context->GetAttr("grid_dim_width", &grid_dim_width_in)); 331 | OP_REQUIRES(context, grid_dim_width_in >= 0, errors::InvalidArgument("Need Attr grid_dim_width >= 0, got ", grid_dim_width_in)); 332 | 333 | OP_REQUIRES_OK(context, context->GetAttr("grid_dim_height", &grid_dim_height_in)); 334 | OP_REQUIRES(context, grid_dim_height_in >= 0, errors::InvalidArgument("Need Attr grid_dim_height >= 0, got ", grid_dim_height_in)); 335 | 336 | OP_REQUIRES_OK(context, context->GetAttr("pool_method", &pool_method)); 337 | OP_REQUIRES(context, StringPiece(pool_method).contains(StringPiece("mean")) || StringPiece(pool_method).contains(StringPiece("max")), errors::InvalidArgument("Need Attr pool_method to be either 'mean' or 'max', got ", pool_method)); 338 | } 339 | 340 | void Compute(OpKernelContext* context) override { 341 | const Tensor& inputs_in = context->input(0); 342 | const Tensor& rois_in = 
context->input(1); 343 | const Tensor& pooled_features_grad = context->input(2); 344 | const Tensor& pooled_index = context->input(3); 345 | 346 | OP_REQUIRES(context, inputs_in.shape().dims() == 4, errors::InvalidArgument("inputs must be in 'NCHW' format.")); 347 | OP_REQUIRES(context, pooled_features_grad.shape() == pooled_index.shape(), errors::InvalidArgument("pooled_index and pooled_features_grad must have the same shape")); 348 | OP_REQUIRES(context, rois_in.shape().dims() == 3 && rois_in.shape().dim_size(2) == 4, errors::InvalidArgument("rois must be in 'batch_size x num_rois x 4' format.")); 349 | OP_REQUIRES(context, inputs_in.dim_size(0) == rois_in.dim_size(0), errors::InvalidArgument("'batch_size' in inputs and rois don't match.")); 350 | 351 | const int batch_size = inputs_in.dim_size(0); 352 | const int num_channals = inputs_in.dim_size(1); 353 | const int map_height = inputs_in.dim_size(2); 354 | const int map_width = inputs_in.dim_size(3); 355 | const int num_rois = rois_in.dim_size(1); 356 | 357 | const int32_t grid_size = grid_dim_width_in * grid_dim_height_in; 358 | auto bank_size = static_cast(num_channals / grid_size); 359 | 360 | OP_REQUIRES(context, pooled_features_grad.shape() == TensorShape({batch_size, num_rois, grid_size, bank_size}), errors::InvalidArgument("both pooled_index and pooled_features_grad must have the shape 'batch_size x num_rois x grid_size x bank_size'")); 361 | 362 | Tensor* grad_output = nullptr; 363 | OP_REQUIRES_OK(context, context->allocate_output(0, inputs_in.shape(), &grad_output)); 364 | 365 | PSROIAlignGradFunctor()(context, context->eigen_device(), inputs_in.template flat(), rois_in.template flat(), grid_dim_width_in, grid_dim_height_in, pooled_features_grad.template flat(), pooled_index.template flat(), grad_output->template flat(), std::make_tuple(batch_size, num_channals, map_height, map_width, num_rois, StringPiece(pool_method).contains(StringPiece("max")))); 366 | // PSROIPoolingFunctor()(context, 
context->eigen_device(), inputs_in.tensor(), rois_in.tensor(), grid_dim_buffer[0], pooled_features->tensor()); 367 | } 368 | 369 | private: 370 | int32_t grid_dim_width_in{-1}; 371 | int32_t grid_dim_height_in{-1}; 372 | std::string pool_method{"max"}; 373 | }; 374 | 375 | // Register the CPU kernels. 376 | #define REGISTER_CPU(T) \ 377 | REGISTER_KERNEL_BUILDER( \ 378 | Name("PsRoiAlignGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ 379 | PSROIAlignGradOp); 380 | REGISTER_CPU(float); 381 | 382 | // TF_CALL_NUMBER_TYPES(REGISTER_CPU); 383 | // #undef REGISTER_CPU 384 | 385 | // Register the GPU kernels. 386 | #if GOOGLE_CUDA == 1 387 | #define REGISTER_GPU(T) \ 388 | REGISTER_KERNEL_BUILDER( \ 389 | Name("PsRoiAlignGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ 390 | PSROIAlignGradOp); 391 | REGISTER_GPU(float); 392 | #endif // GOOGLE_CUDA 393 | -------------------------------------------------------------------------------- /rotated_ps_roi_align_grad_op.cc: -------------------------------------------------------------------------------- 1 | // MIT License 2 | 3 | // Copyright (c) 2018 Changan Wang 4 | 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 14 | 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE. 22 | #include "rotated_ps_roi_align_op.h" 23 | #include "common.h" 24 | #include "work_sharder.h" 25 | 26 | #include "tensorflow/core/framework/op_kernel.h" 27 | #include "tensorflow/core/framework/register_types.h" 28 | #include "tensorflow/core/framework/tensor.h" 29 | #include "tensorflow/core/framework/tensor_shape.h" 30 | #include "tensorflow/core/framework/register_types.h" 31 | #include "tensorflow/core/framework/op.h" 32 | #include "tensorflow/core/framework/shape_inference.h" 33 | 34 | #include 35 | 36 | using namespace tensorflow; 37 | 38 | // the inputs should have format NCHW, which is faster on GPUs 39 | REGISTER_OP("RotatedPsRoiAlignGrad") 40 | .Attr("T: {float}") 41 | .Attr("grid_dim_width: int") 42 | .Attr("grid_dim_height: int") 43 | .Attr("pool_method: string") 44 | .Input("inputs: T") 45 | .Input("rois: T") 46 | .Input("orders: int32") 47 | .Input("pooled_features_grad: T") 48 | .Input("pooled_index: int32") 49 | .Output("grad_output: T") 50 | .Doc(R"doc( 51 | RotatedPsRoiAlignGrad is the Gradient op of RotatedPsRoiAlign. 52 | The input rois to be pooled must in format [y0, x0, y1, x1, y2, x2, y3, x3] which is four vertexes defining quadrilateral in clockwise order and each element must be in range [0, 1.]. 53 | The input orders define which point is the first one, each element must be in range [-1, 4). The order will be determined to be the first vertex of the shorter side if given -1. 54 | The caller must make sure that all rois is valid (has a intersect region (one pixel at least) with the window [0., 0., 0., 1., 1., 1., 1., 0.]).
55 | )doc") 56 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 57 | c->set_output(0, c->input(0)); 58 | return Status::OK(); 59 | }); 60 | 61 | // CPU specialization of actual computation. 62 | // template 63 | // struct RotatedPSROIAlignGradFunctor { 64 | // void operator()(OpKernelContext* context, const CPUDevice& d, typename TTypes::ConstFlat inputs, typename TTypes::ConstFlat rois, typename TTypes::ConstFlat orders, const int32_t grid_dim_width, const int32_t grid_dim_height, typename TTypes::ConstFlat pooled_features_grad, typename TTypes::ConstFlat pooled_index, typename TTypes::Flat grad_output, KDimSize dim_info) { 65 | 66 | // int batch_size = 0; 67 | // int num_channals = 0; 68 | // int map_height = 0; 69 | // int map_width = 0; 70 | // int num_rois = 0; 71 | // bool using_max_pool = false; 72 | 73 | // std::tie(batch_size, num_channals, map_height, map_width, num_rois, using_max_pool) = dim_info; 74 | // grad_output = grad_output.setZero(); 75 | 76 | // auto pooling_grad_routine = [&rois, &orders, &pooled_features_grad, &pooled_index, &grad_output, grid_dim_width, grid_dim_height, batch_size, num_channals, map_height, map_width, num_rois, using_max_pool](int64_t start, int64_t limit){ 77 | // const int32_t grid_size = grid_dim_width * grid_dim_height; 78 | // const int32_t bank_size = num_channals / grid_size; 79 | // for (int64_t worker_index = start; worker_index < limit; ++worker_index){ 80 | // // image_index * roi_index * channal_pos_remainder * row_index * col_index 81 | // const int32_t position_index = (worker_index % num_channals) / bank_size; 82 | // const int32_t row_index = position_index / grid_dim_width; 83 | // const int32_t col_index = position_index % grid_dim_width; 84 | // // position of the channal of pooled feature 85 | // // position of the channal in the bank of feature map 86 | // const int32_t channal_pos_remainder = worker_index % bank_size; 87 | // const int32_t pool_index = worker_index / num_channals;
// NOTE(review): REGISTER_OP("RotatedPsRoiAlignGrad") above registers the gradient op of
// RotatedPsRoiAlign. Attrs: T (float only), grid_dim_width and grid_dim_height (int),
// pool_method (string). Inputs, in order: inputs (T), rois (T), orders (int32),
// pooled_features_grad (T), pooled_index (int32). Output: grad_output (T), whose shape
// SetShapeFn copies unchanged from input 0 (`inputs`). The .Doc text states the ROI
// format (four clockwise vertices, normalized to [0, 1]) and the valid range of
// `orders` ([-1, 4), -1 = pick the first vertex of the shorter side).
// NOTE(review): the "CPU specialization" starting just above is commented out and is
// not compiled; it is an output-side variant that scatters gradients with
// atomic_float_add, kept for reference next to the live input-side functor further below.
// NOTE(review): register_types.h appears twice in the include list above; the duplicate
// is harmless.
88 | // const int32_t image_index = pool_index / num_rois; 89 | // const int32_t roi_index = pool_index % num_rois; 90 | 91 | // const T * roi_to_pool = rois.data() + (image_index * num_rois + roi_index) * 8; 92 | // const int32_t * roi_order = orders.data() + image_index * num_rois + roi_index; 93 | 94 | // volatile T * grad_output_start = reinterpret_cast(grad_output.data() + (image_index * num_channals + position_index * bank_size + channal_pos_remainder) * map_height * map_width); 95 | // const T * pooled_features_start = pooled_features_grad.data() + worker_index; 96 | // const int32_t * pooled_index_start = pooled_index.data() + worker_index; 97 | 98 | // int32_t order = *roi_order < 0 ? 0 : *roi_order * 2; 99 | 100 | // T roi_y0 = static_cast(roi_to_pool[(order++) % 8] * map_height); 101 | // T roi_x0 = static_cast(roi_to_pool[(order++) % 8] * map_width); 102 | // T roi_y1 = static_cast(roi_to_pool[(order++) % 8] * map_height); 103 | // T roi_x1 = static_cast(roi_to_pool[(order++) % 8] * map_width); 104 | // T roi_y2 = static_cast(roi_to_pool[(order++) % 8] * map_height); 105 | // T roi_x2 = static_cast(roi_to_pool[(order++) % 8] * map_width); 106 | // T roi_y3 = static_cast(roi_to_pool[(order++) % 8] * map_height); 107 | // T roi_x3 = static_cast(roi_to_pool[(order++) % 8] * map_width); 108 | 109 | // double len0 = static_cast((roi_y1 - roi_y0) * (roi_y1 - roi_y0) + (roi_x1 - roi_x0) * (roi_x1 - roi_x0)); 110 | // double len1 = static_cast((roi_y2 - roi_y1) * (roi_y2 - roi_y1) + (roi_x2 - roi_x1) * (roi_x2 - roi_x1)); 111 | // double len2 = static_cast((roi_y3 - roi_y2) * (roi_y3 - roi_y2) + (roi_x3 - roi_x2) * (roi_x3 - roi_x2)); 112 | // double len3 = static_cast((roi_y0 - roi_y3) * (roi_y0 - roi_y3) + (roi_x0 - roi_x3) * (roi_x0 - roi_x3)); 113 | // double cross_len0 = static_cast((roi_y0 - roi_y2) * (roi_y0 - roi_y2) + (roi_x0 - roi_x2) * (roi_x0 - roi_x2)); 114 | // double cross_len1 = static_cast((roi_y3 - roi_y1) * (roi_y3 - roi_y1) + (roi_x3 - 
roi_x1) * (roi_x3 - roi_x1)); 115 | 116 | // order = *roi_order < 0 ? (len0 + len2 > len1 + len3 ? 1 : 0) : 0; 117 | // // fix ROI 118 | // if(len0 < std::numeric_limits::min() || len1 < std::numeric_limits::min() || len2 < std::numeric_limits::min() || len3 < std::numeric_limits::min()){ 119 | // // not check convex for faster speed 120 | // //if(is_convex(roi_to_pool)){ 121 | // continue; 122 | // } 123 | 124 | // T roi_y0_order = (order == 0) ? roi_y0 : roi_y1; 125 | // T roi_x0_order = (order == 0) ? roi_x0 : roi_x1; 126 | // T roi_y1_order = (order == 0) ? roi_y1 : roi_y2; 127 | // T roi_x1_order = (order == 0) ? roi_x1 : roi_x2; 128 | // T roi_y2_order = (order == 0) ? roi_y2 : roi_y3; 129 | // T roi_x2_order = (order == 0) ? roi_x2 : roi_x3; 130 | // T roi_y3_order = (order == 0) ? roi_y3 : roi_y0; 131 | // T roi_x3_order = (order == 0) ? roi_x3 : roi_x0; 132 | 133 | // T y_step_left = (roi_y3_order - roi_y0_order)/(grid_dim_height * 1.); 134 | // T y_step_right = (roi_y2_order - roi_y1_order)/(grid_dim_height * 1.); 135 | // T x_step_top = (roi_x1_order - roi_x0_order)/(grid_dim_width * 1.); 136 | // T x_step_bottom = (roi_x2_order - roi_x3_order)/(grid_dim_width * 1.); 137 | 138 | // T left_y1 = (roi_y0_order + row_index * y_step_left); 139 | // T right_y1 = (roi_y1_order + row_index * y_step_right); 140 | // T left_y2 = (roi_y0_order + (row_index + 1.) * y_step_left); 141 | // T right_y2 = (roi_y1_order + (row_index + 1.) * y_step_right); 142 | 143 | // T left_top_y = left_y1 + col_index * (right_y1 - left_y1)/(grid_dim_width); 144 | // T right_top_y = left_y1 + (col_index + 1.) * (right_y1 - left_y1)/(grid_dim_width); 145 | // T left_bottom_y = left_y2 + col_index * (right_y2 - left_y2)/(grid_dim_width); 146 | // T right_bottom_y = left_y2 + (col_index + 1.) 
* (right_y2 - left_y2)/(grid_dim_width); 147 | 148 | // T top_x1 = (roi_x0_order + col_index * x_step_top); 149 | // T bottom_x1 = (roi_x3_order + col_index * x_step_bottom); 150 | // T top_x2 = (roi_x0_order + (col_index + 1.) * x_step_top); 151 | // T bottom_x2 = (roi_x3_order + (col_index + 1.) * x_step_bottom); 152 | 153 | // T left_top_x = top_x1 + row_index * (bottom_x1 - top_x1)/(grid_dim_height); 154 | // T left_bottom_x = top_x1 + (row_index + 1.) * (bottom_x1 - top_x1)/(grid_dim_height); 155 | // T right_top_x = top_x2 + row_index * (bottom_x2 - top_x2)/(grid_dim_height); 156 | // T right_bottom_x = top_x2 + (row_index + 1.) * (bottom_x2 - top_x2)/(grid_dim_height); 157 | 158 | // float pool_bin_width = static_cast(std::max(std::min(std::abs(right_top_x - left_top_x), std::abs(right_top_y - left_top_y)), std::min(std::abs(right_bottom_x - left_bottom_x), std::abs(right_bottom_y - left_bottom_y)))); 159 | // float pool_bin_height = static_cast(std::max(std::min(std::abs(left_bottom_x - left_top_x), std::abs(left_bottom_y - left_top_y)), std::min(std::abs(right_bottom_x - right_top_x), std::abs(right_bottom_y - right_top_y)))); 160 | // int32_t num_elem_width = static_cast(pool_bin_width) + 1; 161 | // int32_t num_elem_height = static_cast(pool_bin_height) + 1; 162 | 163 | // T grid_y_step_left = (left_bottom_y - left_top_y)/(num_elem_height + 1.); 164 | // T grid_y_step_right = (right_bottom_y - right_top_y)/(num_elem_height + 1.); 165 | // T grid_x_step_top = (right_top_x - left_top_x)/(num_elem_width + 1.); 166 | // T grid_x_step_bottom = (right_bottom_x - left_bottom_x)/(num_elem_width + 1.); 167 | 168 | // if(using_max_pool){ 169 | // const int32_t pool_h = *pooled_index_start / num_elem_width; 170 | // const int32_t pool_w = *pooled_index_start % num_elem_width; 171 | 172 | // T col_to_pool = (left_top_x + (pool_w + 1.) * grid_x_step_top + left_bottom_x + (pool_w + 1.) * grid_x_step_bottom) / 2.; 173 | // T row_to_pool = (left_top_y + (pool_h + 1.) 
* grid_y_step_left + right_top_y + (pool_h + 1.) * grid_y_step_right) / 2.; 174 | 175 | // int32_t int_col_to_pool = static_cast(col_to_pool); 176 | // int32_t int_row_to_pool = static_cast(row_to_pool); 177 | // float float_col_to_pool = col_to_pool - int_col_to_pool; 178 | // float float_row_to_pool = row_to_pool - int_row_to_pool; 179 | 180 | // const T grad_in = *pooled_features_start; 181 | // atomic_float_add(grad_output_start + int_row_to_pool * map_width + int_col_to_pool, static_cast((1. - float_col_to_pool) * (1. - float_row_to_pool) * grad_in)); 182 | // atomic_float_add(grad_output_start + std::min(int_row_to_pool + 1, map_height - 1) * map_width + int_col_to_pool, static_cast((1. - float_col_to_pool) * float_row_to_pool * grad_in)); 183 | // atomic_float_add(grad_output_start + int_row_to_pool * map_width + std::min(int_col_to_pool + 1, map_width - 1), static_cast(float_col_to_pool * (1. - float_row_to_pool) * grad_in)); 184 | // atomic_float_add(grad_output_start + std::min(int_row_to_pool + 1, map_height - 1) * map_width + std::min(int_col_to_pool + 1, map_width - 1), static_cast(float_col_to_pool * float_row_to_pool * grad_in)); 185 | // }else{ 186 | // const T grad_in = *pooled_features_start / static_cast(num_elem_width * num_elem_height); 187 | // for(int32_t pool_h = 0; pool_h < num_elem_height; ++pool_h){ 188 | // for(int32_t pool_w = 0; pool_w < num_elem_width; ++pool_w){ 189 | // T col_to_pool = (left_top_x + (pool_w + 1.) * grid_x_step_top + left_bottom_x + (pool_w + 1.) * grid_x_step_bottom) / 2.; 190 | // T row_to_pool = (left_top_y + (pool_h + 1.) * grid_y_step_left + right_top_y + (pool_h + 1.) 
* grid_y_step_right) / 2.; 191 | 192 | // int32_t int_col_to_pool = static_cast(col_to_pool); 193 | // int32_t int_row_to_pool = static_cast(row_to_pool); 194 | // float float_col_to_pool = col_to_pool - int_col_to_pool; 195 | // float float_row_to_pool = row_to_pool - int_row_to_pool; 196 | 197 | // atomic_float_add(grad_output_start + int_row_to_pool * map_width + int_col_to_pool, static_cast((1. - float_col_to_pool) * (1. - float_row_to_pool) * grad_in)); 198 | // atomic_float_add(grad_output_start + std::min(int_row_to_pool + 1, map_height - 1) * map_width + int_col_to_pool, static_cast((1. - float_col_to_pool) * float_row_to_pool * grad_in)); 199 | // atomic_float_add(grad_output_start + int_row_to_pool * map_width + std::min(int_col_to_pool + 1, map_width - 1), static_cast(float_col_to_pool * (1. - float_row_to_pool) * grad_in)); 200 | // atomic_float_add(grad_output_start + std::min(int_row_to_pool + 1, map_height - 1) * map_width + std::min(int_col_to_pool + 1, map_width - 1), static_cast(float_col_to_pool * float_row_to_pool * grad_in)); 201 | // } 202 | // } 203 | // } 204 | // } 205 | // }; 206 | 207 | // const DeviceBase::CpuWorkerThreads& worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); 208 | // // one worker for one position in each ROI 209 | // const int64_t shard_cost = 4 * map_height * map_width / grid_dim_width / grid_dim_height / 4; 210 | // Shard(worker_threads.num_threads, worker_threads.workers, 211 | // pooled_features_grad.size(), shard_cost, pooling_grad_routine); 212 | // } 213 | // }; 214 | 215 | // // calculate gradients from the input side 216 | // // the result of this kernel is the same as that of the kernel above, which calculates gradients from the output side 217 | // // the difference is that this kernel doesn't need to synchronize gradients written to the same input cell 218 | // // but the drawback of this kernel is that more thread scheduling may occur, due to the larger input feature map size compared with the output feature map
219 | // // you can choose any one to use depends on the relative overhead between the scheduling and atomic sync operation 220 | template 221 | struct RotatedPSROIAlignGradFunctor { 222 | void operator()(OpKernelContext* context, const CPUDevice& d, typename TTypes::ConstFlat inputs, typename TTypes::ConstFlat rois, typename TTypes::ConstFlat orders, const int32_t grid_dim_width, const int32_t grid_dim_height, typename TTypes::ConstFlat pooled_features_grad, typename TTypes::ConstFlat pooled_index, typename TTypes::Flat grad_output, KDimSize dim_info) { 223 | int batch_size = 0; 224 | int num_channals = 0; 225 | int map_height = 0; 226 | int map_width = 0; 227 | int num_rois = 0; 228 | bool using_max_pool = false; 229 | 230 | std::tie(batch_size, num_channals, map_height, map_width, num_rois, using_max_pool) = dim_info; 231 | 232 | grad_output = grad_output.setZero(); 233 | 234 | auto pooling_grad_routine = [&rois, &orders, &pooled_features_grad, &pooled_index, &grad_output, grid_dim_width, grid_dim_height, batch_size, num_channals, map_height, map_width, num_rois, using_max_pool](int64_t start, int64_t limit){ 235 | const int32_t grid_size = grid_dim_width * grid_dim_height; 236 | const int32_t bank_size = num_channals/grid_size; 237 | for (int64_t worker_index = start; worker_index < limit; ++worker_index){ 238 | const int32_t cur_image_index = worker_index / (num_channals * map_height * map_width); 239 | const int32_t cur_channal_index = (worker_index % (num_channals * map_height * map_width)) / (map_height * map_width); 240 | const int32_t offset_on_map = worker_index % (map_height * map_width); 241 | const int32_t col_on_map = offset_on_map % map_width; 242 | const int32_t row_on_map = offset_on_map / map_width; 243 | 244 | T * grad_to_fill = reinterpret_cast(grad_output.data() + worker_index); 245 | 246 | for(int roi_index = 0;roi_index < num_rois;++roi_index){ 247 | const T * roi_to_pool = rois.data() + (cur_image_index * num_rois + roi_index) * 8; 248 | 
const int32_t * roi_order = orders.data() + cur_image_index * num_rois + roi_index; 249 | 250 | const T pooled_features_grad_in = *(pooled_features_grad.data() + cur_image_index * (num_rois * num_channals) + roi_index * num_channals + cur_channal_index); 251 | const int32_t pooled_max_index = *(pooled_index.data() + cur_image_index * (num_rois * num_channals) + roi_index * num_channals + cur_channal_index); 252 | 253 | const int32_t row_index = (cur_channal_index / bank_size) / grid_dim_width; 254 | const int32_t col_index = (cur_channal_index / bank_size) % grid_dim_width; 255 | 256 | int32_t order = *roi_order < 0 ? 0 : *roi_order * 2; 257 | 258 | T roi_y0 = static_cast(roi_to_pool[(order++) % 8] * map_height); 259 | T roi_x0 = static_cast(roi_to_pool[(order++) % 8] * map_width); 260 | T roi_y1 = static_cast(roi_to_pool[(order++) % 8] * map_height); 261 | T roi_x1 = static_cast(roi_to_pool[(order++) % 8] * map_width); 262 | T roi_y2 = static_cast(roi_to_pool[(order++) % 8] * map_height); 263 | T roi_x2 = static_cast(roi_to_pool[(order++) % 8] * map_width); 264 | T roi_y3 = static_cast(roi_to_pool[(order++) % 8] * map_height); 265 | T roi_x3 = static_cast(roi_to_pool[(order++) % 8] * map_width); 266 | 267 | double len0 = static_cast((roi_y1 - roi_y0) * (roi_y1 - roi_y0) + (roi_x1 - roi_x0) * (roi_x1 - roi_x0)); 268 | double len1 = static_cast((roi_y2 - roi_y1) * (roi_y2 - roi_y1) + (roi_x2 - roi_x1) * (roi_x2 - roi_x1)); 269 | double len2 = static_cast((roi_y3 - roi_y2) * (roi_y3 - roi_y2) + (roi_x3 - roi_x2) * (roi_x3 - roi_x2)); 270 | double len3 = static_cast((roi_y0 - roi_y3) * (roi_y0 - roi_y3) + (roi_x0 - roi_x3) * (roi_x0 - roi_x3)); 271 | double cross_len0 = static_cast((roi_y0 - roi_y2) * (roi_y0 - roi_y2) + (roi_x0 - roi_x2) * (roi_x0 - roi_x2)); 272 | double cross_len1 = static_cast((roi_y3 - roi_y1) * (roi_y3 - roi_y1) + (roi_x3 - roi_x1) * (roi_x3 - roi_x1)); 273 | 274 | order = *roi_order < 0 ? (len0 + len2 > len1 + len3 ? 
1 : 0) : 0; 275 | // fix ROI 276 | if(len0 < std::numeric_limits::min() || len1 < std::numeric_limits::min() || len2 < std::numeric_limits::min() || len3 < std::numeric_limits::min()){ 277 | // not check convex for faster speed 278 | //if(is_convex(roi_to_pool)){ 279 | continue; 280 | } 281 | 282 | T roi_y0_order = (order == 0) ? roi_y0 : roi_y1; 283 | T roi_x0_order = (order == 0) ? roi_x0 : roi_x1; 284 | T roi_y1_order = (order == 0) ? roi_y1 : roi_y2; 285 | T roi_x1_order = (order == 0) ? roi_x1 : roi_x2; 286 | T roi_y2_order = (order == 0) ? roi_y2 : roi_y3; 287 | T roi_x2_order = (order == 0) ? roi_x2 : roi_x3; 288 | T roi_y3_order = (order == 0) ? roi_y3 : roi_y0; 289 | T roi_x3_order = (order == 0) ? roi_x3 : roi_x0; 290 | 291 | T y_step_left = (roi_y3_order - roi_y0_order)/(grid_dim_height * 1.); 292 | T y_step_right = (roi_y2_order - roi_y1_order)/(grid_dim_height * 1.); 293 | T x_step_top = (roi_x1_order - roi_x0_order)/(grid_dim_width * 1.); 294 | T x_step_bottom = (roi_x2_order - roi_x3_order)/(grid_dim_width * 1.); 295 | 296 | T left_y1 = (roi_y0_order + row_index * y_step_left); 297 | T right_y1 = (roi_y1_order + row_index * y_step_right); 298 | T left_y2 = (roi_y0_order + (row_index + 1.) * y_step_left); 299 | T right_y2 = (roi_y1_order + (row_index + 1.) * y_step_right); 300 | 301 | T left_top_y = left_y1 + col_index * (right_y1 - left_y1)/(grid_dim_width); 302 | T right_top_y = left_y1 + (col_index + 1.) * (right_y1 - left_y1)/(grid_dim_width); 303 | T left_bottom_y = left_y2 + col_index * (right_y2 - left_y2)/(grid_dim_width); 304 | T right_bottom_y = left_y2 + (col_index + 1.) * (right_y2 - left_y2)/(grid_dim_width); 305 | 306 | T top_x1 = (roi_x0_order + col_index * x_step_top); 307 | T bottom_x1 = (roi_x3_order + col_index * x_step_bottom); 308 | T top_x2 = (roi_x0_order + (col_index + 1.) * x_step_top); 309 | T bottom_x2 = (roi_x3_order + (col_index + 1.) 
* x_step_bottom); 310 | 311 | T left_top_x = top_x1 + row_index * (bottom_x1 - top_x1)/(grid_dim_height); 312 | T left_bottom_x = top_x1 + (row_index + 1.) * (bottom_x1 - top_x1)/(grid_dim_height); 313 | T right_top_x = top_x2 + row_index * (bottom_x2 - top_x2)/(grid_dim_height); 314 | T right_bottom_x = top_x2 + (row_index + 1.) * (bottom_x2 - top_x2)/(grid_dim_height); 315 | 316 | float pool_bin_width = static_cast(std::max(std::min(std::abs(right_top_x - left_top_x), std::abs(right_top_y - left_top_y)), std::min(std::abs(right_bottom_x - left_bottom_x), std::abs(right_bottom_y - left_bottom_y)))); 317 | float pool_bin_height = static_cast(std::max(std::min(std::abs(left_bottom_x - left_top_x), std::abs(left_bottom_y - left_top_y)), std::min(std::abs(right_bottom_x - right_top_x), std::abs(right_bottom_y - right_top_y)))); 318 | int32_t num_elem_width = static_cast(pool_bin_width) + 1; 319 | int32_t num_elem_height = static_cast(pool_bin_height) + 1; 320 | 321 | T grid_y_step_left = (left_bottom_y - left_top_y)/(num_elem_height + 1.); 322 | T grid_y_step_right = (right_bottom_y - right_top_y)/(num_elem_height + 1.); 323 | T grid_x_step_top = (right_top_x - left_top_x)/(num_elem_width + 1.); 324 | T grid_x_step_bottom = (right_bottom_x - left_bottom_x)/(num_elem_width + 1.); 325 | 326 | if(using_max_pool){ 327 | const int32_t pool_h = pooled_max_index / num_elem_width; 328 | const int32_t pool_w = pooled_max_index % num_elem_width; 329 | 330 | T col_to_pool = (left_top_x + (pool_w + 1.) * grid_x_step_top + left_bottom_x + (pool_w + 1.) * grid_x_step_bottom) / 2.; 331 | T row_to_pool = (left_top_y + (pool_h + 1.) * grid_y_step_left + right_top_y + (pool_h + 1.) 
* grid_y_step_right) / 2.; 332 | 333 | int32_t int_col_to_pool = static_cast(col_to_pool); 334 | int32_t int_row_to_pool = static_cast(row_to_pool); 335 | float float_col_to_pool = col_to_pool - int_col_to_pool; 336 | float float_row_to_pool = row_to_pool - int_row_to_pool; 337 | 338 | // not 'if else' here for there may be collapsing in pooling operation when the ROI is small enough 339 | if(col_on_map == int_col_to_pool && row_on_map == int_row_to_pool){ 340 | *grad_to_fill += static_cast((1. - float_col_to_pool) * (1. - float_row_to_pool) * pooled_features_grad_in); 341 | } 342 | if(col_on_map == int_col_to_pool && row_on_map == std::min(int_row_to_pool + 1, map_height - 1)){ 343 | *grad_to_fill += static_cast((1. - float_col_to_pool) * float_row_to_pool * pooled_features_grad_in); 344 | } 345 | if(col_on_map == std::min(int_col_to_pool + 1, map_width - 1) && row_on_map == int_row_to_pool){ 346 | *grad_to_fill += static_cast(float_col_to_pool * (1. - float_row_to_pool) * pooled_features_grad_in); 347 | } 348 | if(col_on_map == std::min(int_col_to_pool + 1, map_width - 1) && row_on_map == std::min(int_row_to_pool + 1, map_height - 1)){ 349 | *grad_to_fill += static_cast(float_col_to_pool * float_row_to_pool * pooled_features_grad_in); 350 | } 351 | }else{ 352 | T acc_back_grad = static_cast(0); 353 | for(int32_t pool_h = 0; pool_h < num_elem_height; ++pool_h){ 354 | for(int32_t pool_w = 0; pool_w < num_elem_width; ++pool_w){ 355 | T col_to_pool = (left_top_x + (pool_w + 1.) * grid_x_step_top + left_bottom_x + (pool_w + 1.) * grid_x_step_bottom) / 2.; 356 | T row_to_pool = (left_top_y + (pool_h + 1.) * grid_y_step_left + right_top_y + (pool_h + 1.) 
* grid_y_step_right) / 2.; 357 | 358 | int32_t int_col_to_pool = static_cast(col_to_pool); 359 | int32_t int_row_to_pool = static_cast(row_to_pool); 360 | float float_col_to_pool = col_to_pool - int_col_to_pool; 361 | float float_row_to_pool = row_to_pool - int_row_to_pool; 362 | 363 | if(col_on_map == int_col_to_pool && row_on_map == int_row_to_pool){ 364 | acc_back_grad += static_cast((1. - float_col_to_pool) * (1. - float_row_to_pool) * pooled_features_grad_in); 365 | } 366 | if(col_on_map == int_col_to_pool && row_on_map == std::min(int_row_to_pool + 1, map_height - 1)){ 367 | acc_back_grad += static_cast((1. - float_col_to_pool) * float_row_to_pool * pooled_features_grad_in); 368 | } 369 | if(col_on_map == std::min(int_col_to_pool + 1, map_width - 1) && row_on_map == int_row_to_pool){ 370 | acc_back_grad += static_cast(float_col_to_pool * (1. - float_row_to_pool) * pooled_features_grad_in); 371 | } 372 | if(col_on_map == std::min(int_col_to_pool + 1, map_width - 1) && row_on_map == std::min(int_row_to_pool + 1, map_height - 1)){ 373 | acc_back_grad += static_cast(float_col_to_pool * float_row_to_pool * pooled_features_grad_in); 374 | } 375 | } 376 | } 377 | *grad_to_fill += acc_back_grad / static_cast(num_elem_width * num_elem_height); 378 | } 379 | } 380 | 381 | } 382 | }; 383 | 384 | const DeviceBase::CpuWorkerThreads& worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); 385 | // one worker for one position in each ROI 386 | const int64_t shard_cost = num_rois * 4; 387 | Shard(worker_threads.num_threads, worker_threads.workers, 388 | grad_output.size(), shard_cost, pooling_grad_routine); 389 | } 390 | }; 391 | 392 | // OpKernel definition. 393 | // template parameter is the datatype of the tensors. 
394 | template 395 | class RotatedPSROIAlignGradOp : public OpKernel { 396 | public: 397 | explicit RotatedPSROIAlignGradOp(OpKernelConstruction* context) : OpKernel(context) { 398 | OP_REQUIRES_OK(context, context->GetAttr("grid_dim_width", &grid_dim_width_in)); 399 | OP_REQUIRES(context, grid_dim_width_in >= 0, errors::InvalidArgument("Need Attr grid_dim_width >= 0, got ", grid_dim_width_in)); 400 | 401 | OP_REQUIRES_OK(context, context->GetAttr("grid_dim_height", &grid_dim_height_in)); 402 | OP_REQUIRES(context, grid_dim_height_in >= 0, errors::InvalidArgument("Need Attr grid_dim_height >= 0, got ", grid_dim_height_in)); 403 | 404 | OP_REQUIRES_OK(context, context->GetAttr("pool_method", &pool_method)); 405 | OP_REQUIRES(context, StringPiece(pool_method).contains(StringPiece("mean")) || StringPiece(pool_method).contains(StringPiece("max")), errors::InvalidArgument("Need Attr pool_method to be either 'mean' or 'max', got ", pool_method)); 406 | } 407 | 408 | void Compute(OpKernelContext* context) override { 409 | const Tensor& inputs_in = context->input(0); 410 | const Tensor& rois_in = context->input(1); 411 | const Tensor& orders_in = context->input(2); 412 | const Tensor& pooled_features_grad = context->input(3); 413 | const Tensor& pooled_index = context->input(4); 414 | 415 | OP_REQUIRES(context, inputs_in.shape().dims() == 4, errors::InvalidArgument("inputs must be in 'NCHW' format.")); 416 | OP_REQUIRES(context, pooled_features_grad.shape() == pooled_index.shape(), errors::InvalidArgument("pooled_index and pooled_features_grad must have the same shape")); 417 | OP_REQUIRES(context, rois_in.shape().dims() == 3 && rois_in.shape().dim_size(2) == 8, errors::InvalidArgument("rois must be in 'batch_size x num_rois x 8' format.")); 418 | OP_REQUIRES(context, inputs_in.dim_size(0) == rois_in.dim_size(0), errors::InvalidArgument("'batch_size' in inputs and rois don't match.")); 419 | OP_REQUIRES(context, orders_in.shape().dims() == 2, 
errors::InvalidArgument("orders must be in 'batch_size x num_rois' format.")); 420 | OP_REQUIRES(context, (orders_in.dim_size(0) == rois_in.dim_size(0)) && (orders_in.dim_size(1) == rois_in.dim_size(1)), errors::InvalidArgument("'batch_size' or 'num_rois' in orders and rois don't match.")); 421 | 422 | const int batch_size = inputs_in.dim_size(0); 423 | const int num_channals = inputs_in.dim_size(1); 424 | const int map_height = inputs_in.dim_size(2); 425 | const int map_width = inputs_in.dim_size(3); 426 | const int num_rois = rois_in.dim_size(1); 427 | 428 | const int32_t grid_size = grid_dim_width_in * grid_dim_height_in; 429 | auto bank_size = static_cast(num_channals / grid_size); 430 | 431 | OP_REQUIRES(context, pooled_features_grad.shape() == TensorShape({batch_size, num_rois, grid_size, bank_size}), errors::InvalidArgument("both pooled_index and pooled_features_grad must have the shape 'batch_size x num_rois x grid_size x bank_size'")); 432 | 433 | Tensor* grad_output = nullptr; 434 | OP_REQUIRES_OK(context, context->allocate_output(0, inputs_in.shape(), &grad_output)); 435 | 436 | RotatedPSROIAlignGradFunctor()(context, context->eigen_device(), inputs_in.template flat(), rois_in.template flat(), orders_in.template flat(), grid_dim_width_in, grid_dim_height_in, pooled_features_grad.template flat(), pooled_index.template flat(), grad_output->template flat(), std::make_tuple(batch_size, num_channals, map_height, map_width, num_rois, StringPiece(pool_method).contains(StringPiece("max")))); 437 | // PSROIPoolingFunctor()(context, context->eigen_device(), inputs_in.tensor(), rois_in.tensor(), grid_dim_buffer[0], pooled_features->tensor()); 438 | } 439 | 440 | private: 441 | int32_t grid_dim_width_in{-1}; 442 | int32_t grid_dim_height_in{-1}; 443 | std::string pool_method{"max"}; 444 | }; 445 | 446 | // Register the CPU kernels. 
447 | #define REGISTER_CPU(T) \ 448 | REGISTER_KERNEL_BUILDER( \ 449 | Name("RotatedPsRoiAlignGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ 450 | RotatedPSROIAlignGradOp); 451 | REGISTER_CPU(float); 452 | 453 | // TF_CALL_NUMBER_TYPES(REGISTER_CPU); 454 | // #undef REGISTER_CPU 455 | 456 | // Register the GPU kernels. 457 | #if GOOGLE_CUDA == 1 458 | #define REGISTER_GPU(T) \ 459 | REGISTER_KERNEL_BUILDER( \ 460 | Name("RotatedPsRoiAlignGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ 461 | RotatedPSROIAlignGradOp); 462 | REGISTER_GPU(float); 463 | #endif // GOOGLE_CUDA 464 | --------------------------------------------------------------------------------