├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── contextlab ├── __init__.py ├── layers │ ├── __init__.py │ ├── cc_attention │ │ ├── __init__.py │ │ ├── functions.py │ │ ├── setup.py │ │ └── src │ │ │ ├── ca.cu │ │ │ ├── ca.h │ │ │ ├── common.h │ │ │ ├── lib_cffi.cpp │ │ │ └── lib_cffi.h │ ├── dual_attention │ │ ├── __init__.py │ │ └── dual_attention.py │ ├── em_attention │ │ ├── __init__.py │ │ └── emu.py │ ├── gcnet │ │ ├── __init__.py │ │ └── gcnet.py │ ├── latentgnn │ │ ├── __init__.py │ │ └── latentgnn.py │ ├── non_local │ │ ├── __init__.py │ │ └── non_local.py │ └── tree_filter │ │ ├── __init__.py │ │ ├── functions │ │ ├── bfs.py │ │ ├── mst.py │ │ └── refine.py │ │ ├── modules │ │ └── tree_filter.py │ │ ├── setup.py │ │ └── src │ │ ├── bfs │ │ ├── bfs.cu │ │ └── bfs.hpp │ │ ├── mst │ │ ├── boruvka.cpp │ │ ├── boruvka.hpp │ │ ├── mst.cu │ │ └── mst.hpp │ │ ├── refine │ │ ├── refine.cu │ │ └── refine.hpp │ │ └── tree_filter.cpp └── utils │ ├── __init__.py │ ├── layer_misc.py │ └── weight_init.py ├── setup.py ├── src └── images │ └── contextlab_logo.png └── test └── utils └── layer_misc.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | version.py 7 | # C extensions 8 | *.so 9 | # User added 10 | *.DS_Store 11 | .vscode/ 12 | github/ 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHTUPLUS/ContextLab/4e12f0af9d0640f29c763b915f02de763b577200/CONTRIBUTING.md -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 PLUS Lab, ShanghaiTech University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ContextLab 2 | ContextLab: A Toolbox for Context Feature Augmentation developed with PyTorch 3 | 4 | 5 | 6 | 7 | ## Introduction 8 | 9 | 10 | The master branch works with **PyTorch 1.1** or higher 11 | 12 | ContextLab is an open source context feature augmentation toolbox based on PyTorch. 
It is a part of the Open-PLUS project developed by [ShanghaiTech PLUS Lab](http://plus.sist.shanghaitech.edu.cn). 13 | 14 | 15 | 16 | ## Major Features 17 | - **Modular Design** 18 | 19 | - **High Efficiency** 20 | 21 | - **State-of-the-art Performance** 22 | 23 | We have implemented several context augmentation algorithms in PyTorch with comparable performance. 24 | 25 | ## License 26 | This project is released under the [MIT License](LICENSE). 27 | 28 | ## Updates 29 | 30 | v0.2.0 (27/09/2019) 31 | - Support for **CCNet**, **TreeFilter** and **EMANet** 32 | 33 | v0.1.0 (26/07/2019) 34 | - Start the project 35 | 36 | 37 | ## Benchmark and Model Zoo 38 | 39 | 40 | 41 | | Method | Block-wise | Stage-wise | Paper | 42 | |--------------------|:------------:|:-----------:|:--------:| 43 | | Non-local Network | ✗ | ✓ | [CVPR 18](https://arxiv.org/abs/1711.07971) 44 | | Dual-attention | ✗ | ✓ | [CVPR 19](https://arxiv.org/abs/1809.02983) 45 | | GCNet | ✗ | ✓ | [arXiv](https://arxiv.org/abs/1904.11492) 46 | | CCNet | ✓ | ✓ | [ICCV 19](https://arxiv.org/abs/1811.11721) 47 | | LatentGNN | ✗ | ✓ | [ICML 19](https://arxiv.org/abs/1905.11634) 48 | | TreeFilter | ✗ | ✓ | [NIPS 19]() 49 | | EMANet | ✗ | ✓ | [ICCV 19](https://arxiv.org/abs/1907.13426) 50 | 51 | ## Installation 52 | 53 | ``` 54 | git clone https://github.com/SHTUPLUS/contextlab.git 55 | cd contextlab/ 56 | python setup.py build develop 57 | ``` 58 | 59 | ## Examples 60 | ```python 61 | # GCNet 62 | from contextlab.layers import GlobalContextBlock2d 63 | # Dual-Attention 64 | from contextlab.layers import SelfAttention 65 | # LatentGNN 66 | from contextlab.layers import LatentGNN 67 | # TreeFilter 68 | from contextlab.layers import MinimumSpanningTree, TreeFilter2D 69 | # CCNet 70 | from contextlab.layers import CrissCrossAttention 71 | # EMAttention 72 | from contextlab.layers import EMAttentionUnit 73 | ``` 74 | A minimal end-to-end usage sketch with illustrative constructor arguments is given below, after the Acknowledgement section. 75 | ## To do 76 | - [ ] Experiments on Segmentation and Detection 77 | - [ ] Performance Comparison 78 | 79 | ## Contributing 80 | 81 | We appreciate all contributions to improve ContextLab. Please refer to [CONTRIBUTING.md](CONTRIBUTING.md) for the contributing guidelines. 82 | 83 | ## Acknowledgement 84 | ContextLab is an open-source project contributed to by researchers and engineers from various colleges and companies. We appreciate all the contributors who implement their methods or add new features. 85 | 86 | We hope the toolbox and benchmark serve the growing research community by providing a flexible toolkit to reimplement existing methods and develop new segmentation methods.
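The snippet below is a minimal usage sketch for some of the layers listed in the Examples section above. The constructor arguments and tensor shapes are illustrative choices rather than a documented configuration, and the sketch assumes the package has been installed as described in the Installation section (importing `contextlab.layers` also pulls in the CCNet and TreeFilter CUDA extensions, so a CUDA-enabled build is assumed even though the layers shown here run on CPU tensors).

```python
import torch
from contextlab.layers import (GlobalContextBlock2d, SelfAttention,
                               LatentGNN, EMAttentionUnit)

x = torch.randn(2, 256, 32, 32)  # (batch, channels, height, width)

# GCNet: attention pooling with channel-add fusion, channel reduction ratio 4
gc = GlobalContextBlock2d(inplanes=256, stride=4, pool='att', fusions=['channel_add'])
out_gc = gc(x)                    # (2, 256, 32, 32)

# Dual-Attention position branch (self-attention over all spatial locations)
sa = SelfAttention(inplane=256, outplane=256, channel_stride=8)
out_sa = sa(x)                    # (2, 256, 32, 32)

# LatentGNN with a single kernel of 100 latent nodes
gnn = LatentGNN(in_channels=256, latent_dims=[100], num_kernels=1)
out_gnn = gnn(x)                  # (2, 256, 32, 32)

# EM attention unit with 64 bases and 3 EM iterations
ema = EMAttentionUnit(c=256, k=64, stage_num=3)
out_ema, mu = ema(x)              # filtered feature map and the updated bases
```

In these configurations each block returns a tensor with the same shape as its input (EMAttentionUnit additionally returns the updated bases), which is what allows them to be dropped into an existing backbone as context-augmentation modules.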
87 | 88 | ## Citation 89 | ``` 90 | @misc{contextlab, 91 | title = {{ContextLab}: A Toolbox for Context Feature Augmentation}, 92 | author = {Songyang Zhang}, 93 | year = {2019} 94 | } 95 | ``` 96 | ## Contact 97 | ``` 98 | email: sy.zhangbuaa@gmail.com 99 | ``` 100 | 101 | ## Related Projects 102 | - [Non-local Network (CVPR 18)](https://arxiv.org/abs/1711.07971) 103 | - [Video Classification (Caffe)](https://github.com/facebookresearch/video-nonlocal-net) 104 | - [Dual-attention (CVPR 19)](https://arxiv.org/abs/1809.02983) 105 | - [Semantic Segmentation (PyTorch)](https://github.com/junfu1115/DANet) 106 | - [GCNet (arXiv)](https://arxiv.org/abs/1904.11492) 107 | - [Object Detection (PyTorch)](https://github.com/xvjiarui/GCNet) 108 | - [CCNet (ICCV 19)](https://arxiv.org/abs/1811.11721) 109 | - [Semantic Segmentation (PyTorch)](https://github.com/speedinghzl/CCNet) 110 | - [LatentGNN (ICML 19)](https://arxiv.org/abs/1905.11634) 111 | - [Object Detection (PyTorch)](https://github.com/latentgnn/LatentGNN-V1-PyTorch) 112 | - [TreeFilter (NIPS 19)]() 113 | - [Semantic Segmentation (PyTorch)](https://github.com/StevenGrove/TreeFilter-Torch) 114 | - [EMANet (ICCV 19)](https://arxiv.org/abs/1907.13426) 115 | - [Semantic Segmentation (PyTorch)](https://github.com/XiaLiPKU/EMANet) 116 | -------------------------------------------------------------------------------- /contextlab/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__, short_version 2 | 3 | 4 | # from .layers.gcnet import GlobalContextBlock2d 5 | # from .layers.dual_attention import SelfAttention 6 | # from .layers.latentgnn import LatentGNN 7 | 8 | # from .layers.tree_filter import 9 | -------------------------------------------------------------------------------- /contextlab/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .gcnet import GlobalContextBlock2d 2 | from .dual_attention import SelfAttention 3 | from .latentgnn import LatentGNN 4 | from .tree_filter import MinimumSpanningTree, TreeFilter2D 5 | from .cc_attention import CrissCrossAttention 6 | from .em_attention import EMAttentionUnit -------------------------------------------------------------------------------- /contextlab/layers/cc_attention/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import CrissCrossAttention, ca_weight, ca_map -------------------------------------------------------------------------------- /contextlab/layers/cc_attention/functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import torch.autograd as autograd 5 | import torch.cuda.comm as comm 6 | import torch.nn.functional as F 7 | from torch.autograd.function import once_differentiable 8 | from torch.utils.cpp_extension import load 9 | import os, time 10 | import functools 11 | 12 | # curr_dir = os.path.dirname(os.path.abspath(__file__)) 13 | # _src_path = os.path.join(curr_dir, "src") 14 | # _build_path = os.path.join(curr_dir, "build") 15 | # os.makedirs(_build_path, exist_ok=True) 16 | # rcca = load(name="rcca", 17 | # extra_cflags=["-O3"], 18 | # build_directory=_build_path, 19 | # verbose=True, 20 | # sources = [os.path.join(_src_path, f) for f in [ 21 | # "lib_cffi.cpp", "ca.cu" 22 | # ]], 23 | # extra_cuda_cflags=["--expt-extended-lambda"]) 24 | # import rcca 25 | from . 
import rcca 26 | 27 | def _check_contiguous(*args): 28 | if not all([mod is None or mod.is_contiguous() for mod in args]): 29 | raise ValueError("Non-contiguous input") 30 | 31 | 32 | class CA_Weight(autograd.Function): 33 | @staticmethod 34 | def forward(ctx, t, f): 35 | # Save context 36 | n, c, h, w = t.size() 37 | size = (n, h+w-1, h, w) 38 | weight = torch.zeros(size, dtype=t.dtype, layout=t.layout, device=t.device) 39 | 40 | rcca.ca_forward_cuda(t, f, weight) 41 | 42 | # Output 43 | ctx.save_for_backward(t, f) 44 | 45 | return weight 46 | 47 | @staticmethod 48 | @once_differentiable 49 | def backward(ctx, dw): 50 | t, f = ctx.saved_tensors 51 | 52 | dt = torch.zeros_like(t) 53 | df = torch.zeros_like(f) 54 | 55 | rcca.ca_backward_cuda(dw.contiguous(), t, f, dt, df) 56 | 57 | _check_contiguous(dt, df) 58 | 59 | return dt, df 60 | 61 | class CA_Map(autograd.Function): 62 | @staticmethod 63 | def forward(ctx, weight, g): 64 | # Save context 65 | out = torch.zeros_like(g) 66 | rcca.ca_map_forward_cuda(weight, g, out) 67 | 68 | # Output 69 | ctx.save_for_backward(weight, g) 70 | 71 | return out 72 | 73 | @staticmethod 74 | @once_differentiable 75 | def backward(ctx, dout): 76 | weight, g = ctx.saved_tensors 77 | 78 | dw = torch.zeros_like(weight) 79 | dg = torch.zeros_like(g) 80 | 81 | rcca.ca_map_backward_cuda(dout.contiguous(), weight, g, dw, dg) 82 | 83 | _check_contiguous(dw, dg) 84 | 85 | return dw, dg 86 | 87 | ca_weight = CA_Weight.apply 88 | ca_map = CA_Map.apply 89 | 90 | 91 | class CrissCrossAttention(nn.Module): 92 | """ Criss-Cross Attention Module""" 93 | def __init__(self,in_dim): 94 | super(CrissCrossAttention,self).__init__() 95 | self.chanel_in = in_dim 96 | 97 | self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1) 98 | self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1) 99 | self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1) 100 | self.gamma = nn.Parameter(torch.zeros(1)) 101 | 102 | def forward(self, x): 103 | proj_query = self.query_conv(x) 104 | proj_key = self.key_conv(x) 105 | proj_value = self.value_conv(x) 106 | 107 | energy = ca_weight(proj_query, proj_key) 108 | attention = F.softmax(energy, 1) 109 | out = ca_map(attention, proj_value) 110 | out = self.gamma*out + x 111 | return out 112 | 113 | 114 | 115 | __all__ = ["CrissCrossAttention", "ca_weight", "ca_map"] 116 | 117 | 118 | if __name__ == "__main__": 119 | ca = CrissCrossAttention(256).cuda() 120 | x = torch.zeros(1, 8, 10, 10).cuda() + 1 121 | y = torch.zeros(1, 8, 10, 10).cuda() + 2 122 | z = torch.zeros(1, 64, 10, 10).cuda() + 3 123 | out = ca(x, y, z) 124 | print (out) 125 | -------------------------------------------------------------------------------- /contextlab/layers/cc_attention/setup.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHTUPLUS/ContextLab/4e12f0af9d0640f29c763b915f02de763b577200/contextlab/layers/cc_attention/setup.py -------------------------------------------------------------------------------- /contextlab/layers/cc_attention/src/ca.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "common.h" 6 | #include "ca.h" 7 | 8 | 9 | __global__ void ca_forward_kernel(const float *t, const float *f, float *weight, int num, int chn, int height, int width) { 10 | int x = blockIdx.x * blockDim.x + threadIdx.x; 11 | 
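  // (Descriptive note) Each thread handles one query pixel (x, y); blockIdx.z enumerates
  // the H+W-1 criss-cross positions (the W pixels in the same row plus the H-1 other
  // pixels in the same column) whose affinities with (x, y) are accumulated into `weight`,
  // which is laid out as (N, H+W-1, H, W).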
int y = blockIdx.y * blockDim.y + threadIdx.y; 12 | int sp = height * width; 13 | int len = height + width - 1; 14 | int z = blockIdx.z; 15 | 16 | if (x < width && y < height && z < height+width-1) { 17 | for (int batch = 0; batch < num; ++batch) { 18 | for (int plane = 0; plane < chn; ++plane) { 19 | float _t = t[(batch * chn + plane) * sp + y*width + x]; 20 | 21 | if (z < width) { 22 | int i = z; 23 | float _f = f[(batch * chn + plane) * sp + y*width + i]; 24 | weight[(batch * len + i) * sp + y*width + x] += _t*_f; 25 | } else { 26 | int i = z - width; 27 | int j = iy ? y : y-1; 86 | 87 | float _dw = dw[(batch * len + width + j) * sp + i*width + x]; 88 | float _t = t[(batch * chn + plane) * sp + i*width + x]; 89 | df[(batch * chn + plane) * sp + y*width + x] += _dw * _t; 90 | } 91 | } 92 | 93 | } 94 | } 95 | 96 | 97 | __global__ void ca_map_forward_kernel(const float *weight, const float *g, float *out, int num, int chn, int height, int width) { 98 | int x = blockIdx.x * blockDim.x + threadIdx.x; 99 | int y = blockIdx.y * blockDim.y + threadIdx.y; 100 | int sp = height * width; 101 | int len = height + width - 1; 102 | int plane = blockIdx.z; 103 | 104 | if (x < width && y < height && plane < chn) { 105 | for (int batch = 0; batch < num; ++batch) { 106 | 107 | for (int i = 0; i < width; ++i) { 108 | float _g = g[(batch * chn + plane) * sp + y*width + i]; 109 | float _w = weight[(batch * len + i) * sp + y*width + x]; 110 | out[(batch * chn + plane) * sp + y*width + x] += _g * _w; 111 | } 112 | for (int i = 0; i < height; ++i) { 113 | if (i == y) continue; 114 | 115 | int j = iy ? y : y-1; 176 | 177 | float _dout = dout[(batch * chn + plane) * sp + i*width + x]; 178 | float _w = weight[(batch * len + width + j) * sp + i*width + x]; 179 | dg[(batch * chn + plane) * sp + y*width + x] += _dout * _w; 180 | } 181 | } 182 | } 183 | } 184 | 185 | /* 186 | * Implementations 187 | */ 188 | extern "C" int _ca_forward_cuda(int N, int C, int H, int W, const float *t, 189 | const float *f, float *weight, cudaStream_t stream) { 190 | // Run kernel 191 | dim3 threads(32, 32); 192 | int d1 = (W+threads.x-1)/threads.x; 193 | int d2 = (H+threads.y-1)/threads.y; 194 | int d3 = H+W; 195 | dim3 blocks(d1, d2, d3); 196 | ca_forward_kernel<<>>(t, f, weight, N, C, H, W); 197 | 198 | // Check for errors 199 | cudaError_t err = cudaGetLastError(); 200 | if (err != cudaSuccess) 201 | return 0; 202 | else 203 | return 1; 204 | } 205 | 206 | 207 | extern "C" int _ca_backward_cuda(int N, int C, int H, int W, const float *dw, const float *t, const float *f, float *dt, float *df, cudaStream_t stream) { 208 | // Run kernel 209 | dim3 threads(32, 32); 210 | int d1 = (W+threads.x-1)/threads.x; 211 | int d2 = (H+threads.y-1)/threads.y; 212 | int d3 = C; 213 | dim3 blocks(d1, d2, d3); 214 | // printf("%f\n", dw[0]); 215 | ca_backward_kernel_t<<>>(dw, t, f, dt, N, C, H, W); 216 | ca_backward_kernel_f<<>>(dw, t, f, df, N, C, H, W); 217 | 218 | // Check for errors 219 | cudaError_t err = cudaGetLastError(); 220 | if (err != cudaSuccess) 221 | return 0; 222 | else 223 | return 1; 224 | } 225 | 226 | 227 | extern "C" int _ca_map_forward_cuda(int N, int C, int H, int W, const float *weight, const float *g, float *out, cudaStream_t stream) { 228 | // Run kernel 229 | dim3 threads(32, 32); 230 | dim3 blocks((W+threads.x-1)/threads.x, (H+threads.y-1)/threads.y, C); 231 | ca_map_forward_kernel<<>>(weight, g, out, N, C, H, W); 232 | 233 | // Check for errors 234 | cudaError_t err = cudaGetLastError(); 235 | if (err != cudaSuccess) 236 | 
return 0; 237 | else 238 | return 1; 239 | } 240 | 241 | extern "C" int _ca_map_backward_cuda(int N, int C, int H, int W, const float *dout, const float *weight, const float *g, float *dw, float *dg, cudaStream_t stream) { 242 | // Run kernel 243 | dim3 threads(32, 32); 244 | int d1 = (W+threads.x-1)/threads.x; 245 | int d2 = (H+threads.y-1)/threads.y; 246 | int d3 = H+W; 247 | dim3 blocks(d1, d2, d3); 248 | ca_map_backward_kernel_w<<>>(dout, weight, g, dw, N, C, H, W); 249 | 250 | d3 = C; 251 | blocks = dim3(d1, d2, d3); 252 | ca_map_backward_kernel_g<<>>(dout, weight, g, dg, N, C, H, W); 253 | 254 | // Check for errors 255 | cudaError_t err = cudaGetLastError(); 256 | if (err != cudaSuccess) 257 | return 0; 258 | else 259 | return 1; 260 | } 261 | -------------------------------------------------------------------------------- /contextlab/layers/cc_attention/src/ca.h: -------------------------------------------------------------------------------- 1 | #ifndef __CA__ 2 | #define __CA__ 3 | 4 | /* 5 | * Exported functions 6 | */ 7 | extern "C" int _ca_forward_cuda(int N, int C, int H, int W, const float *t, const float *f, float *weight, cudaStream_t stream); 8 | extern "C" int _ca_backward_cuda(int N, int C, int H, int W, const float *dw, const float *t, const float *f, float *dt, float *df, cudaStream_t stream); 9 | extern "C" int _ca_map_forward_cuda(int N, int C, int H, int W, const float *weight, const float *g, float *out, cudaStream_t stream); 10 | extern "C" int _ca_map_backward_cuda(int N, int C, int H, int W, const float *dout, const float *weight, const float *g, float *dw, float *dg, cudaStream_t stream); 11 | 12 | #endif 13 | -------------------------------------------------------------------------------- /contextlab/layers/cc_attention/src/common.h: -------------------------------------------------------------------------------- 1 | #ifndef __COMMON__ 2 | #define __COMMON__ 3 | #include 4 | 5 | /* 6 | * General settings 7 | */ 8 | const int WARP_SIZE = 32; 9 | const int MAX_BLOCK_SIZE = 512; 10 | 11 | /* 12 | * Utility functions 13 | */ 14 | template 15 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, 16 | unsigned int mask = 0xffffffff) { 17 | #if CUDART_VERSION >= 9000 18 | return __shfl_xor_sync(mask, value, laneMask, width); 19 | #else 20 | return __shfl_xor(value, laneMask, width); 21 | #endif 22 | } 23 | 24 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); } 25 | 26 | static int getNumThreads(int nElem) { 27 | int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE}; 28 | for (int i = 0; i != 5; ++i) { 29 | if (nElem <= threadSizes[i]) { 30 | return threadSizes[i]; 31 | } 32 | } 33 | return MAX_BLOCK_SIZE; 34 | } 35 | 36 | 37 | #endif -------------------------------------------------------------------------------- /contextlab/layers/cc_attention/src/lib_cffi.cpp: -------------------------------------------------------------------------------- 1 | // All functions assume that input and output tensors are already initialized 2 | // and have the correct dimensions 3 | #include 4 | #include 5 | #include "ca.h" 6 | 7 | extern THCState *state; 8 | 9 | int ca_forward_cuda(const at::Tensor& t, const at::Tensor& f, at::Tensor& weight) { 10 | cudaStream_t stream = THCState_getCurrentStream(state); 11 | int N, C, H, W; 12 | N = t.size(0); C = t.size(1); H = t.size(2); W = t.size(3); 13 | float * t_data = t.data(); 14 | float * f_data = f.data(); 15 | float * weight_data = weight.data(); 16 | return 
_ca_forward_cuda(N, C, H, W, t_data, f_data, weight_data, stream); 17 | } 18 | 19 | int ca_backward_cuda(const at::Tensor& dw, const at::Tensor& t, const at::Tensor& f, at::Tensor& dt, at::Tensor& df) { 20 | 21 | cudaStream_t stream = THCState_getCurrentStream(state); 22 | int N, C, H, W; 23 | N = t.size(0); C = t.size(1); H = t.size(2); W = t.size(3); 24 | float * t_data = t.data(); 25 | float * f_data = f.data(); 26 | float * dt_data = dt.data(); 27 | float * df_data = df.data(); 28 | float * dw_data = dw.data(); 29 | return _ca_backward_cuda(N, C, H, W, dw_data, t_data, f_data, dt_data, df_data, stream); 30 | } 31 | 32 | int ca_map_forward_cuda(const at::Tensor& weight, const at::Tensor& g, at::Tensor& out) { 33 | cudaStream_t stream = THCState_getCurrentStream(state); 34 | 35 | int N, C, H, W; 36 | N = g.size(0); C = g.size(1); H = g.size(2); W = g.size(3); 37 | 38 | const float *weight_data = weight.data(); 39 | const float *g_data = g.data(); 40 | float *out_data = out.data(); 41 | 42 | return _ca_map_forward_cuda(N, C, H, W, weight_data, g_data, out_data, stream); 43 | } 44 | 45 | int ca_map_backward_cuda(const at::Tensor& dout, const at::Tensor& weight, const at::Tensor& g, 46 | at::Tensor& dw, at::Tensor& dg) { 47 | cudaStream_t stream = THCState_getCurrentStream(state); 48 | 49 | int N, C, H, W; 50 | N = dout.size(0); C = dout.size(1); H = dout.size(2); W = dout.size(3); 51 | 52 | const float *dout_data = dout.data(); 53 | const float *weight_data = weight.data(); 54 | const float *g_data = g.data(); 55 | float *dw_data = dw.data(); 56 | float *dg_data = dg.data(); 57 | 58 | return _ca_map_backward_cuda(N, C, H, W, dout_data, weight_data, g_data, dw_data, dg_data, stream); 59 | } 60 | 61 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ 62 | m.def("ca_forward_cuda", &ca_forward_cuda, "CA forward CUDA"); 63 | m.def("ca_backward_cuda", &ca_backward_cuda, "CA backward CUDA"); 64 | m.def("ca_map_forward_cuda", &ca_map_forward_cuda, "CA map forward CUDA"); 65 | m.def("ca_map_backward_cuda", &ca_map_backward_cuda, "CA map backward CUDA"); 66 | } 67 | 68 | -------------------------------------------------------------------------------- /contextlab/layers/cc_attention/src/lib_cffi.h: -------------------------------------------------------------------------------- 1 | int ca_forward_cuda(const at::Tensor& t, const at::Tensor& *f, at::Tensor& weight); 2 | int ca_backward_cuda(const at::Tensor& dw, const at::Tensor& t, const at::Tensor& f, at::Tensor& dt, at::Tensor& df); 3 | 4 | int ca_map_forward_cuda(const at::Tensor& weight, const at::Tensor& g, at::Tensor& out); 5 | int ca_map_backward_cuda(const at::Tensor& dout, const at::Tensor& weight, const at::Tensor& g, at::Tensor& dw, at::Tensor& dg); 6 | -------------------------------------------------------------------------------- /contextlab/layers/dual_attention/__init__.py: -------------------------------------------------------------------------------- 1 | from .dual_attention import SelfAttention -------------------------------------------------------------------------------- /contextlab/layers/dual_attention/dual_attention.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @Description: Dual Attention Network 3 | @Author: Songyang Zhang 4 | @Email: sy.zhangbuaa@gmail.com 5 | @Date: 2019-07-13 17:10:33 6 | @LastEditTime: 2019-08-15 10:22:03 7 | @LastEditors: Songyang Zhang 8 | ''' 9 | 10 | import numpy as np 11 | 12 | import torch 13 | import math 14 | import torch.nn as nn 15 | import 
torch.nn.functional as F 16 | 17 | 18 | class SelfAttention(nn.Module): 19 | """ 20 | Self-attention Network/Non-local Network 21 | 22 | Args: 23 | 24 | Return: 25 | 26 | """ 27 | def __init__(self, inplane, outplane, channel_stride=8): 28 | super(SelfAttention, self).__init__() 29 | 30 | self.inplane = inplane 31 | self.outplane = outplane 32 | 33 | self.inter_channel = inplane // channel_stride 34 | 35 | self.query_conv = nn.Conv2d(in_channels=inplane, out_channels=self.inter_channel, kernel_size=1) 36 | self.key_conv = nn.Conv2d(in_channels=inplane, out_channels=self.inter_channel, kernel_size=1) 37 | 38 | self.value_conv = nn.Conv2d(in_channels=inplane, out_channels=outplane, kernel_size=1) 39 | if outplane != inplane: 40 | self.input_conv = nn.Conv2d(in_channels=inplane, out_channels=outplane, kernel_size=1) 41 | 42 | self.gamma = nn.Parameter(torch.zeros(1)) 43 | 44 | self.softmax = nn.Softmax(dim=-1) 45 | 46 | def forward(self, inputs): 47 | """ 48 | Args: 49 | inputs: (B, C, H, W) 50 | 51 | Return: 52 | augmented_feature: (B, C, H, W) 53 | """ 54 | 55 | B, C, H, W = inputs.size() 56 | query = self.query_conv(inputs).view(B, -1, H*W).permute(0, 2, 1) # B,N,C 57 | key = self.key_conv(inputs).view(B, -1, H*W) # B,C,N 58 | 59 | affinity_matrix = torch.bmm(query, key) 60 | affinity_matrix = self.softmax(affinity_matrix) # B, N, N 61 | 62 | value = self.value_conv(inputs).view(B, -1, H*W) 63 | 64 | out = torch.bmm(value, affinity_matrix) # B,C',N * B,N,N = B,C',N 65 | if self.inplane != self.outplane: 66 | inputs = self.input_conv(inputs) 67 | augmented_feature = self.gamma * out.view(B,-1, H, W) + inputs 68 | 69 | return augmented_feature 70 | 71 | # if __name__ == "__main__": 72 | # inputs = torch.randn(1,1024, 20,20) 73 | # model = SelfAttention(1024, 512,channel_stride=8) 74 | 75 | # out = model(inputs) 76 | # import pdb; pdb.set_trace() 77 | 78 | 79 | -------------------------------------------------------------------------------- /contextlab/layers/em_attention/__init__.py: -------------------------------------------------------------------------------- 1 | from .emu import EMAttentionUnit -------------------------------------------------------------------------------- /contextlab/layers/em_attention/emu.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @Description: Implement the Expectation-Maximization Attention Unit(EMAU) 3 | @Author: Songyang Zhang 4 | @Email: sy.zhangbuaa@gmail.com 5 | @License: (C)Copyright 2019-2020, PLUS Group@ShanhaiTech University: 6 | @Date: 2019-09-28 13:54:35 7 | @LastEditors: Songyang Zhang 8 | @LastEditTime: 2019-09-28 14:30:26 9 | ''' 10 | import math 11 | 12 | import torch 13 | import torch.nn as nn 14 | from torch.nn import functional as F 15 | from torch.nn.modules.batchnorm import _BatchNorm 16 | 17 | class EMAttentionUnit(nn.Module): 18 | '''The Expectation-Maximization Attention Unit (EMAU). 19 | Arguments: 20 | c (int): The input and output channel number. 21 | k (int): The number of the bases. 22 | stage_num (int): The iteration number for EM. 23 | ''' 24 | def __init__(self, c, k, stage_num=3, norm_layer=nn.BatchNorm2d): 25 | super(EMAttentionUnit, self).__init__() 26 | self.stage_num = stage_num 27 | 28 | mu = torch.Tensor(1, c, k) 29 | mu.normal_(0, math.sqrt(2. / k)) # Init with Kaiming Norm. 
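        # (Descriptive note) `mu` holds the k bases as a single (1, c, k) tensor shared
        # across the batch; it is l2-normalized below and registered as a buffer rather
        # than a Parameter, since it is re-estimated by the EM iterations in forward().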
30 | mu = self._l2norm(mu, dim=1) 31 | self.register_buffer('mu', mu) 32 | 33 | self.conv1 = nn.Conv2d(c, c, 1) 34 | self.conv2 = nn.Sequential( 35 | nn.Conv2d(c, c, 1, bias=False), 36 | norm_layer(c)) 37 | 38 | for m in self.modules(): 39 | if isinstance(m, nn.Conv2d): 40 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 41 | m.weight.data.normal_(0, math.sqrt(2. / n)) 42 | elif isinstance(m, _BatchNorm): 43 | m.weight.data.fill_(1) 44 | if m.bias is not None: 45 | m.bias.data.zero_() 46 | 47 | 48 | def forward(self, x): 49 | idn = x 50 | # The first 1x1 conv 51 | x = self.conv1(x) 52 | 53 | # The EM Attention 54 | b, c, h, w = x.size() 55 | x = x.view(b, c, h*w) # b * c * n 56 | mu = self.mu.repeat(b, 1, 1) # b * c * k 57 | with torch.no_grad(): 58 | for i in range(self.stage_num): 59 | x_t = x.permute(0, 2, 1) # b * n * c 60 | z = torch.bmm(x_t, mu) # b * n * k 61 | z = F.softmax(z, dim=2) # b * n * k 62 | z_ = z / (1e-6 + z.sum(dim=1, keepdim=True)) 63 | mu = torch.bmm(x, z_) # b * c * k 64 | mu = self._l2norm(mu, dim=1) 65 | 66 | z_t = z.permute(0, 2, 1) # b * k * n 67 | x = mu.matmul(z_t) # b * c * n 68 | x = x.view(b, c, h, w) # b * c * h * w 69 | x = F.relu(x, inplace=True) 70 | 71 | # The second 1x1 conv 72 | x = self.conv2(x) 73 | x = x + idn 74 | x = F.relu(x, inplace=True) 75 | 76 | return x, mu 77 | 78 | def _l2norm(self, inp, dim): 79 | '''Normlize the inp tensor with l2-norm. 80 | Returns a tensor where each sub-tensor of input along the given dim is 81 | normalized such that the 2-norm of the sub-tensor is equal to 1. 82 | Arguments: 83 | inp (tensor): The input tensor. 84 | dim (int): The dimension to slice over to get the ssub-tensors. 85 | Returns: 86 | (tensor) The normalized tensor. 87 | ''' 88 | return inp / (1e-6 + inp.norm(dim=dim, keepdim=True)) -------------------------------------------------------------------------------- /contextlab/layers/gcnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .gcnet import GlobalContextBlock2d -------------------------------------------------------------------------------- /contextlab/layers/gcnet/gcnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @Description: GCNet file from the MSRA 3 | @Author: Songyang Zhang 4 | @Email: sy.zhangbuaa@gmail.com 5 | @Date: 2019-08-15 10:20:38 6 | @LastEditors: Songyang Zhang 7 | @LastEditTime: 2019-09-27 18:40:35 8 | ''' 9 | 10 | import torch 11 | from torch import nn 12 | # from mmcv.cnn import constant_init, kaiming_init 13 | from contextlab.utils.weight_init import constant_init, kaiming_init 14 | 15 | def last_zero_init(m): 16 | if isinstance(m, nn.Sequential): 17 | constant_init(m[-1], val=0) 18 | m[-1].inited = True 19 | else: 20 | constant_init(m, val=0) 21 | m.inited = True 22 | 23 | 24 | class GlobalContextBlock2d(nn.Module): 25 | 26 | def __init__(self, inplanes, stride, pool, fusions): 27 | super(GlobalContextBlock2d, self).__init__() 28 | assert pool in ['avg', 'att'] 29 | assert all([f in ['channel_add', 'channel_mul'] for f in fusions]) 30 | assert len(fusions) > 0, 'at least one fusion should be used' 31 | self.inplanes = inplanes 32 | self.planes = inplanes // stride 33 | self.pool = pool 34 | self.fusions = fusions 35 | if 'att' in pool: 36 | self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1) 37 | self.softmax = nn.Softmax(dim=2) 38 | else: 39 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 40 | if 'channel_add' in fusions: 41 | self.channel_add_conv = 
nn.Sequential( 42 | nn.Conv2d(self.inplanes, self.planes, kernel_size=1), 43 | nn.LayerNorm([self.planes, 1, 1]), 44 | nn.ReLU(inplace=True), 45 | nn.Conv2d(self.planes, self.inplanes, kernel_size=1) 46 | ) 47 | else: 48 | self.channel_add_conv = None 49 | if 'channel_mul' in fusions: 50 | self.channel_mul_conv = nn.Sequential( 51 | nn.Conv2d(self.inplanes, self.planes, kernel_size=1), 52 | nn.LayerNorm([self.planes, 1, 1]), 53 | nn.ReLU(inplace=True), 54 | nn.Conv2d(self.planes, self.inplanes, kernel_size=1) 55 | ) 56 | else: 57 | self.channel_mul_conv = None 58 | self.reset_parameters() 59 | 60 | def reset_parameters(self): 61 | if self.pool == 'att': 62 | kaiming_init(self.conv_mask, mode='fan_in') 63 | self.conv_mask.inited = True 64 | 65 | if self.channel_add_conv is not None: 66 | last_zero_init(self.channel_add_conv) 67 | if self.channel_mul_conv is not None: 68 | last_zero_init(self.channel_mul_conv) 69 | 70 | def spatial_pool(self, x): 71 | batch, channel, height, width = x.size() 72 | if self.pool == 'att': 73 | input_x = x 74 | # [N, C, H * W] 75 | input_x = input_x.view(batch, channel, height * width) 76 | # [N, 1, C, H * W] 77 | input_x = input_x.unsqueeze(1) 78 | # [N, 1, H, W] 79 | context_mask = self.conv_mask(x) 80 | # [N, 1, H * W] 81 | context_mask = context_mask.view(batch, 1, height * width) 82 | # [N, 1, H * W] 83 | context_mask = self.softmax(context_mask) 84 | # [N, 1, H * W, 1] 85 | context_mask = context_mask.unsqueeze(3) 86 | # [N, 1, C, 1] 87 | context = torch.matmul(input_x, context_mask) 88 | # [N, C, 1, 1] 89 | context = context.view(batch, channel, 1, 1) 90 | else: 91 | # [N, C, 1, 1] 92 | context = self.avg_pool(x) 93 | 94 | return context 95 | 96 | def forward(self, x): 97 | # [N, C, 1, 1] 98 | context = self.spatial_pool(x) 99 | 100 | if self.channel_mul_conv is not None: 101 | # [N, C, 1, 1] 102 | channel_mul_term = torch.sigmoid(self.channel_mul_conv(context)) 103 | out = x * channel_mul_term 104 | else: 105 | out = x 106 | if self.channel_add_conv is not None: 107 | # [N, C, 1, 1] 108 | channel_add_term = self.channel_add_conv(context) 109 | out = out + channel_add_term 110 | 111 | return out 112 | -------------------------------------------------------------------------------- /contextlab/layers/latentgnn/__init__.py: -------------------------------------------------------------------------------- 1 | from .latentgnn import LatentGNN -------------------------------------------------------------------------------- /contextlab/layers/latentgnn/latentgnn.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @Description: LatentGNN V1 Version 3 | @Author: Songyang Zhang 4 | @Email: sy.zhangbuaa@gmail.com 5 | @License: (C)Copyright 2019-2020, PLUS Group@ShanhaiTech University: 6 | @Date: 2019-08-15 10:23:24 7 | @LastEditors: Songyang Zhang 8 | @LastEditTime: 2019-08-15 10:25:01 9 | ''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | import numpy as np 16 | 17 | class LatentGNN(nn.Module): 18 | """ 19 | Latent Graph Neural Network for Non-local Relations Learning 20 | Args: 21 | in_channels (int): Number of channels in the input feature 22 | latent_dims (list): List of latent dimensions 23 | channel_stride (int): Channel reduction factor. Default: 4 24 | num_kernels (int): Number of latent kernels used. Default: 1 25 | mode (str): Mode of bipartite graph message propagation. Default: 'asymmetric'. 26 | without_residual (bool): Flag of use residual connetion. 
Default: False 27 | norm_layer (nn.Module): Module used for batch normalization. Default: nn.BatchNorm2d. 28 | norm_func (function): Function used for normalization. Default: F.normalize 29 | graph_conv_flag (bool): Flag of use graph convolution layer. Default: False 30 | """ 31 | def __init__(self, in_channels, latent_dims, 32 | channel_stride=4, num_kernels=1, 33 | mode='asymmetric', without_residual=False, 34 | norm_layer=nn.BatchNorm2d, norm_func=F.normalize, 35 | graph_conv_flag=False): 36 | super(LatentGNN, self).__init__() 37 | self.without_resisual = without_residual 38 | self.num_kernels = num_kernels 39 | self.mode = mode 40 | self.norm_func = norm_func 41 | 42 | inter_channel = in_channels // channel_stride 43 | 44 | # Reduce the channel dimension for efficiency 45 | if mode == 'asymmetric': 46 | self.down_channel_v2l = nn.Sequential( 47 | nn.Conv2d(in_channels=in_channels, 48 | out_channels=inter_channel, 49 | kernel_size=1, padding=0, bias=False), 50 | norm_layer(inter_channel), 51 | ) 52 | # nn.init.kaiming_uniform_(self.down_channel_v2l[0].weight, a=1) 53 | # nn.init.kaiming_uniform_(self.down_channel_v2l[0].weight, mode='fan_in') 54 | self.down_channel_l2v = nn.Sequential( 55 | nn.Conv2d(in_channels=in_channels, 56 | out_channels=inter_channel, 57 | kernel_size=1, padding=0, bias=False), 58 | norm_layer(inter_channel), 59 | ) 60 | # nn.init.kaiming_uniform_(self.down_channel_l2v[0].weight, a=1) 61 | # nn.init.kaiming_uniform_(self.down_channel_l2v[0].weight, mode='fan_in') 62 | 63 | elif mode == 'symmetric': 64 | self.down_channel = nn.Sequential( 65 | nn.Conv2d(in_channels=in_channels, 66 | out_channels=inter_channel, 67 | kernel_size=1, padding=0, bias=False), 68 | norm_layer(inter_channel), 69 | ) 70 | # nn.init.kaiming_uniform_(self.down_channel[0].weight, a=1) 71 | # nn.init.kaiming_uniform_(self.down_channel[0].weight, mode='fan_in') 72 | else: 73 | raise NotImplementedError 74 | 75 | # Define the latentgnn kernel 76 | assert len(latent_dims) == num_kernels, 'Latent dimensions mismatch with number of kernels' 77 | 78 | for i in range(num_kernels): 79 | self.add_module('LatentGNN_Kernel_{}'.format(i), 80 | LatentGNN_Kernel(in_channels=inter_channel, 81 | num_kernels=num_kernels, 82 | latent_dim=latent_dims[i], 83 | norm_layer=norm_layer, 84 | norm_func=norm_func, 85 | mode=mode, 86 | graph_conv_flag=graph_conv_flag)) 87 | # Increase the channel for the output 88 | self.up_channel = nn.Sequential( 89 | nn.Conv2d(in_channels=inter_channel*num_kernels, 90 | out_channels=in_channels, 91 | kernel_size=1, padding=0,bias=False), 92 | norm_layer(in_channels), 93 | ) 94 | # nn.init.kaiming_uniform_(self.up_channel[0].weight, a=1) 95 | # nn.init.kaiming_uniform_(self.up_channel[0].weight, mode='fan_in') 96 | 97 | # Residual Connection 98 | self.gamma = nn.Parameter(torch.zeros(1)) 99 | 100 | def forward(self, conv_feature): 101 | # Generate visible space feature 102 | if self.mode == 'asymmetric': 103 | v2l_conv_feature = self.down_channel_v2l(conv_feature) 104 | l2v_conv_feature = self.down_channel_l2v(conv_feature) 105 | v2l_conv_feature = self.norm_func(v2l_conv_feature, dim=1) 106 | l2v_conv_feature = self.norm_func(l2v_conv_feature, dim=1) 107 | elif self.mode == 'symmetric': 108 | v2l_conv_feature = self.norm_func(self.down_channel(conv_feature), dim=1) 109 | l2v_conv_feature = None 110 | out_features = [] 111 | for i in range(self.num_kernels): 112 | out_features.append(eval('self.LatentGNN_Kernel_{}'.format(i))(v2l_conv_feature, l2v_conv_feature)) 113 | 114 | 
out_features = torch.cat(out_features, dim=1) if self.num_kernels > 1 else out_features[0] 115 | 116 | out_features = self.up_channel(out_features) 117 | 118 | if self.without_resisual: 119 | return out_features 120 | else: 121 | return conv_feature + out_features*self.gamma 122 | 123 | class LatentGNN_Kernel(nn.Module): 124 | """ 125 | A LatentGNN Kernel Implementation 126 | Args: 127 | """ 128 | def __init__(self, in_channels, num_kernels, 129 | latent_dim, norm_layer, 130 | norm_func, mode, graph_conv_flag): 131 | super(LatentGNN_Kernel, self).__init__() 132 | self.mode = mode 133 | self.norm_func = norm_func 134 | #---------------------------------------------- 135 | # Step1 & 3: Visible-to-Latent & Latent-to-Visible 136 | #---------------------------------------------- 137 | 138 | if mode == 'asymmetric': 139 | self.psi_v2l = nn.Sequential( 140 | nn.Conv2d(in_channels=in_channels, 141 | out_channels=latent_dim, 142 | kernel_size=1, padding=0, 143 | bias=False), 144 | norm_layer(latent_dim), 145 | nn.ReLU(inplace=True), 146 | ) 147 | # nn.init.kaiming_uniform_(self.psi_v2l[0].weight, a=1) 148 | # nn.init.kaiming_uniform_(self.psi_v2l[0].weight, mode='fan_in') 149 | self.psi_l2v = nn.Sequential( 150 | nn.Conv2d(in_channels=in_channels, 151 | out_channels=latent_dim, 152 | kernel_size=1, padding=0, 153 | bias=False), 154 | norm_layer(latent_dim), 155 | nn.ReLU(inplace=True), 156 | ) 157 | # nn.init.kaiming_uniform_(self.psi_l2v[0].weight, a=1) 158 | # nn.init.kaiming_uniform_(self.psi_l2v[0].weight, mode='fan_in') 159 | elif mode == 'symmetric': 160 | self.psi = nn.Sequential( 161 | nn.Conv2d(in_channels=in_channels, 162 | out_channels=latent_dim, 163 | kernel_size=1, padding=0, 164 | bias=False), 165 | norm_layer(latent_dim), 166 | nn.ReLU(inplace=True), 167 | ) 168 | # nn.init.kaiming_uniform_(self.psi[0].weight, a=1) 169 | # nn.init.kaiming_uniform_(self.psi[0].weight, mode='fan_in') 170 | 171 | #---------------------------------------------- 172 | # Step2: Latent Messge Passing 173 | #---------------------------------------------- 174 | self.graph_conv_flag = graph_conv_flag 175 | if graph_conv_flag: 176 | self.GraphConvWeight = nn.Sequential( 177 | # nn.Linear(in_channels, in_channels,bias=False), 178 | nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0, bias=False), 179 | norm_layer(in_channels), 180 | nn.ReLU(inplace=True), 181 | ) 182 | nn.init.normal_(self.GraphConvWeight[0].weight, std=0.01) 183 | 184 | def forward(self, v2l_conv_feature, l2v_conv_feature): 185 | B, C, H, W = v2l_conv_feature.shape 186 | 187 | # Generate Bipartite Graph Adjacency Matrix 188 | if self.mode == 'asymmetric': 189 | v2l_graph_adj = self.psi_v2l(v2l_conv_feature) 190 | l2v_graph_adj = self.psi_l2v(l2v_conv_feature) 191 | v2l_graph_adj = self.norm_func(v2l_graph_adj.view(B,-1, H*W), dim=2) 192 | l2v_graph_adj = self.norm_func(l2v_graph_adj.view(B,-1, H*W), dim=1) 193 | # l2v_graph_adj = self.norm_func(l2v_graph_adj.view(B,-1, H*W), dim=2) 194 | elif self.mode == 'symmetric': 195 | assert l2v_conv_feature is None 196 | l2v_graph_adj = v2l_graph_adj = self.norm_func(self.psi(v2l_conv_feature).view(B,-1, H*W), dim=1) 197 | 198 | #---------------------------------------------- 199 | # Step1 : Visible-to-Latent 200 | #---------------------------------------------- 201 | latent_node_feature = torch.bmm(v2l_graph_adj, v2l_conv_feature.view(B, -1, H*W).permute(0,2,1)) 202 | # if self.graph_conv_flag: 203 | # latent_node_feature = F.relu(latent_node_feature, inplace=True) 204 | 
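        # (Descriptive note) After the visible-to-latent step, latent_node_feature has shape
        # (B, latent_dim, C'): each latent node is a weighted sum of the H*W visible-node
        # features, with weights given by the normalized v2l_graph_adj affinities.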
#---------------------------------------------- 205 | # Step2 : Latent-to-Latent 206 | #---------------------------------------------- 207 | # Generate Dense-connected Graph Adjacency Matrix 208 | latent_node_feature_n = self.norm_func(latent_node_feature, dim=-1) 209 | affinity_matrix = torch.bmm(latent_node_feature_n, latent_node_feature_n.permute(0,2,1)) 210 | affinity_matrix = F.softmax(affinity_matrix, dim=-1) 211 | 212 | latent_node_feature = torch.bmm(affinity_matrix, latent_node_feature) 213 | 214 | #---------------------------------------------- 215 | # Step3: Latent-to-Visible 216 | #---------------------------------------------- 217 | visible_feature = torch.bmm(latent_node_feature.permute(0,2,1), l2v_graph_adj).view(B, -1, H, W) 218 | # if self.graph_conv_flag: 219 | # visible_feature = F.relu(visible_feature, inplace=True) 220 | 221 | if self.graph_conv_flag: 222 | visible_feature = self.GraphConvWeight(visible_feature) 223 | 224 | return visible_feature -------------------------------------------------------------------------------- /contextlab/layers/non_local/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHTUPLUS/ContextLab/4e12f0af9d0640f29c763b915f02de763b577200/contextlab/layers/non_local/__init__.py -------------------------------------------------------------------------------- /contextlab/layers/non_local/non_local.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHTUPLUS/ContextLab/4e12f0af9d0640f29c763b915f02de763b577200/contextlab/layers/non_local/non_local.py -------------------------------------------------------------------------------- /contextlab/layers/tree_filter/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules.tree_filter import MinimumSpanningTree 2 | from .modules.tree_filter import TreeFilter2D 3 | -------------------------------------------------------------------------------- /contextlab/layers/tree_filter/functions/bfs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Function 4 | from torch.autograd.function import once_differentiable 5 | from torch.nn.modules.utils import _pair 6 | 7 | # import ..tree_filter_cuda as _C 8 | # import tree_filter_cuda as _C 9 | from . import tree_filter_cuda as _C 10 | # from contextlab.layers.tree_filter import tree_filter_cuda as _C 11 | 12 | class _BFS(Function): 13 | @staticmethod 14 | def forward(ctx, edge_index, max_adj_per_vertex): 15 | sorted_index, sorted_parent, sorted_child =\ 16 | _C.bfs_forward(edge_index, max_adj_per_vertex) 17 | return sorted_index, sorted_parent, sorted_child 18 | 19 | bfs = _BFS.apply 20 | 21 | -------------------------------------------------------------------------------- /contextlab/layers/tree_filter/functions/mst.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Function 4 | from torch.autograd.function import once_differentiable 5 | from torch.nn.modules.utils import _pair 6 | from . 
import tree_filter_cuda as _C 7 | 8 | # import tree_filter_cuda as _C 9 | # from contextlab.layers.tree_filter import tree_filter_cuda as _C 10 | 11 | class _MST(Function): 12 | @staticmethod 13 | def forward(ctx, edge_index, edge_weight, vertex_index): 14 | edge_out = _C.mst_forward(edge_index, edge_weight, vertex_index) 15 | return edge_out 16 | 17 | @staticmethod 18 | @once_differentiable 19 | def backward(ctx, grad_output): 20 | return None, None, None 21 | 22 | mst = _MST.apply 23 | 24 | -------------------------------------------------------------------------------- /contextlab/layers/tree_filter/functions/refine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Function 4 | from torch.autograd.function import once_differentiable 5 | from torch.nn.modules.utils import _pair 6 | from . import tree_filter_cuda as _C 7 | 8 | # import tree_filter_cuda as _C 9 | # from contextlab.layers.tree_filter import tree_filter_cuda as _C 10 | 11 | class _Refine(Function): 12 | @staticmethod 13 | def forward(ctx, feature_in, edge_weight, sorted_index, sorted_parent, sorted_child): 14 | feature_out, feature_aggr, feature_aggr_up, weight_sum, weight_sum_up, =\ 15 | _C.refine_forward(feature_in, edge_weight, sorted_index, sorted_parent, sorted_child) 16 | 17 | ctx.save_for_backward(feature_in, edge_weight, sorted_index, sorted_parent, 18 | sorted_child, feature_out, feature_aggr, feature_aggr_up, weight_sum, 19 | weight_sum_up) 20 | return feature_out 21 | 22 | @staticmethod 23 | @once_differentiable 24 | def backward(ctx, grad_output): 25 | feature_in, edge_weight, sorted_index, sorted_parent,\ 26 | sorted_child, feature_out, feature_aggr, feature_aggr_up, weight_sum,\ 27 | weight_sum_up, = ctx.saved_tensors; 28 | 29 | grad_feature = _C.refine_backward_feature(feature_in, edge_weight, sorted_index, 30 | sorted_parent, sorted_child, feature_out, feature_aggr, feature_aggr_up, 31 | weight_sum, weight_sum_up, grad_output) 32 | grad_weight = _C.refine_backward_weight(feature_in, edge_weight, sorted_index, 33 | sorted_parent, sorted_child, feature_out, feature_aggr, feature_aggr_up, 34 | weight_sum, weight_sum_up, grad_output) 35 | 36 | return grad_feature, grad_weight, None, None, None 37 | 38 | refine = _Refine.apply 39 | 40 | -------------------------------------------------------------------------------- /contextlab/layers/tree_filter/modules/tree_filter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.distributed as dist 4 | 5 | from ..functions.mst import mst 6 | from ..functions.bfs import bfs 7 | from ..functions.refine import refine 8 | 9 | class MinimumSpanningTree(nn.Module): 10 | def __init__(self, distance_func): 11 | super(MinimumSpanningTree, self).__init__() 12 | self.distance_func = distance_func 13 | 14 | @staticmethod 15 | def _build_matrix_index(fm): 16 | batch, height, width = (fm.shape[0], *fm.shape[2:]) 17 | row = torch.arange(width, dtype=torch.int32, device=fm.device).unsqueeze(0) 18 | col = torch.arange(height, dtype=torch.int32, device=fm.device).unsqueeze(1) 19 | raw_index = row + col * width 20 | row_index = torch.stack([raw_index[:-1, :], raw_index[1:, :]], 2) 21 | col_index = torch.stack([raw_index[:, :-1], raw_index[:, 1:]], 2) 22 | index = torch.cat([row_index.reshape(1, -1, 2), 23 | col_index.reshape(1, -1, 2)], 1) 24 | index = index.expand(batch, -1, -1) 25 | return index 
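    # (Descriptive note) _build_matrix_index above enumerates the edges of the 4-connected
    # pixel grid: every pixel is paired with its lower and right neighbours, giving
    # (H-1)*W + H*(W-1) candidate edges per image. _build_feature_weight below assigns each
    # of these edges a distance computed from the guide feature map, and the minimum
    # spanning tree is built over them in forward().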
26 | 27 | def _build_feature_weight(self, fm): 28 | batch = fm.shape[0] 29 | weight_row = self.distance_func(fm[:, :, :-1, :], fm[:, :, 1:, :]) 30 | weight_col = self.distance_func(fm[:, :, :, :-1], fm[:, :, :, 1:]) 31 | weight_row = weight_row.reshape([batch, -1]) 32 | weight_col = weight_col.reshape([batch, -1]) 33 | weight = torch.cat([weight_row, weight_col], dim=1) + 1 34 | return weight 35 | 36 | def forward(self, guide_in): 37 | with torch.no_grad(): 38 | index = self._build_matrix_index(guide_in) 39 | weight = self._build_feature_weight(guide_in) 40 | tree = mst(index, weight, guide_in.shape[2] * guide_in.shape[3]) 41 | return tree 42 | 43 | 44 | class TreeFilter2D(nn.Module): 45 | def __init__(self, groups=1, distance_func=None, enable_log=False): 46 | super(TreeFilter2D, self).__init__() 47 | self.groups = groups 48 | self.enable_log = enable_log 49 | if distance_func is None: 50 | self.distance_func = self.norm2_distance 51 | else: 52 | self.distance_func = distance_func 53 | 54 | @staticmethod 55 | def norm2_distance(fm_ref, fm_tar): 56 | diff = fm_ref - fm_tar 57 | weight = (diff * diff).sum(dim=1) 58 | return weight 59 | 60 | @staticmethod 61 | def batch_index_opr(data, index): 62 | with torch.no_grad(): 63 | channel = data.shape[1] 64 | index = index.unsqueeze(1).expand(-1, channel, -1).long() 65 | data = torch.gather(data, 2, index) 66 | return data 67 | 68 | def build_edge_weight(self, fm, sorted_index, sorted_parent): 69 | batch = fm.shape[0] 70 | channel = fm.shape[1] 71 | vertex = fm.shape[2] * fm.shape[3] 72 | 73 | fm = fm.reshape([batch, channel, -1]) 74 | fm_source = self.batch_index_opr(fm, sorted_index) 75 | fm_target = self.batch_index_opr(fm_source, sorted_parent) 76 | fm_source = fm_source.reshape([-1, channel // self.groups, vertex]) 77 | fm_target = fm_target.reshape([-1, channel // self.groups, vertex]) 78 | 79 | edge_weight = self.distance_func(fm_source, fm_target) 80 | edge_weight = torch.exp(-edge_weight) 81 | return edge_weight 82 | 83 | def split_group(self, feature_in, *tree_orders): 84 | feature_in = feature_in.reshape(feature_in.shape[0] * self.groups, 85 | feature_in.shape[1] // self.groups, 86 | -1) 87 | returns = [feature_in.contiguous()] 88 | for order in tree_orders: 89 | order = order.unsqueeze(1).expand(order.shape[0], self.groups, *order.shape[1:]) 90 | order = order.reshape(-1, *order.shape[2:]) 91 | returns.append(order.contiguous()) 92 | return tuple(returns) 93 | 94 | def print_info(self, edge_weight): 95 | edge_weight = edge_weight.clone() 96 | info = torch.stack([edge_weight.mean(), edge_weight.std(), edge_weight.max(), edge_weight.min()]) 97 | if self.training and dist.is_initialized(): 98 | dist.all_reduce(info / dist.get_world_size()) 99 | info_str = (float(x) for x in info) 100 | if dist.get_rank() == 0: 101 | print('Mean:{0:.4f}, Std:{1:.4f}, Max:{2:.4f}, Min:{3:.4f}'.format(*info_str)) 102 | else: 103 | info_str = [float(x) for x in info] 104 | print('Mean:{0:.4f}, Std:{1:.4f}, Max:{2:.4f}, Min:{3:.4f}'.format(*info_str)) 105 | 106 | def forward(self, feature_in, embed_in, tree): 107 | ori_shape = feature_in.shape 108 | sorted_index, sorted_parent, sorted_child = bfs(tree, 4) 109 | edge_weight = self.build_edge_weight(embed_in, sorted_index, sorted_parent) 110 | with torch.no_grad(): 111 | if self.enable_log: self.print_info(edge_weight) 112 | feature_in, sorted_index, sorted_parent, sorted_child = \ 113 | self.split_group(feature_in, sorted_index, sorted_parent, sorted_child) 114 | feature_out = refine(feature_in, edge_weight, 
sorted_index, 115 | sorted_parent, sorted_child) 116 | feature_out = feature_out.reshape(ori_shape) 117 | return feature_out 118 | 119 | -------------------------------------------------------------------------------- /contextlab/layers/tree_filter/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import shutil 5 | from setuptools import setup 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | this_dir = os.path.dirname(os.path.abspath(__file__)) 9 | extensions_dir = os.path.join(this_dir, "src") 10 | 11 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 12 | source_cpu = glob.glob(os.path.join(extensions_dir, "*", "*.cpp")) 13 | source_cuda = glob.glob(os.path.join(extensions_dir, "*", "*.cu")) 14 | 15 | if torch.cuda.is_available(): 16 | if 'LD_LIBRARY_PATH' not in os.environ: 17 | raise Exception('LD_LIBRARY_PATH is not set.') 18 | cuda_lib_path = os.environ['LD_LIBRARY_PATH'].split(':') 19 | sources = source_cpu + source_cuda + main_file 20 | else: 21 | raise Exception('This implementation is only avaliable for CUDA devices.') 22 | 23 | print(sources) 24 | 25 | setup( 26 | name='tree_filter', 27 | version="0.1", 28 | description="learnable tree filter for pytorch", 29 | ext_modules=[ 30 | CUDAExtension( 31 | name='tree_filter_cuda', 32 | include_dirs=[extensions_dir], 33 | sources=sources, 34 | library_dirs=cuda_lib_path, 35 | extra_compile_args={'cxx':['-O3'], 36 | 'nvcc':['-O3']}) 37 | ], 38 | cmdclass={ 39 | 'build_ext': BuildExtension 40 | }) 41 | 42 | -------------------------------------------------------------------------------- /contextlab/layers/tree_filter/src/bfs/bfs.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | #define CUDA_CHECK(call) if((call) != cudaSuccess) {cudaError_t err = cudaGetLastError(); std::cout << "CUDA error calling ""#call"", code is " << err << std::endl;} 17 | 18 | #define CUDA_NUM_THREADS 64 19 | #define GET_CUDA_BLOCKS(N) ceil((float)N / CUDA_NUM_THREADS) 20 | 21 | __global__ void adj_vec_kernel( 22 | int batch_size, 23 | int * edge_index, 24 | int vertex_count, 25 | int * adj_vec, 26 | int * adj_vec_len, 27 | int max_adj_per_node){ 28 | 29 | const int edge_count = vertex_count - 1; 30 | const int batch_idx = blockIdx.x; 31 | const int thread_idx = threadIdx.x; 32 | const int thread_count = blockDim.x; 33 | 34 | edge_index += batch_idx * edge_count * 2; 35 | adj_vec += batch_idx * vertex_count * max_adj_per_node; 36 | adj_vec_len += batch_idx * vertex_count; 37 | 38 | for (int i = thread_idx; i < edge_count; i += thread_count){ 39 | int source = edge_index[2 * i]; 40 | int target = edge_index[2 * i + 1]; 41 | int source_len = atomicAdd(&(adj_vec_len[source]), 1); 42 | adj_vec[source * max_adj_per_node + source_len] = target; 43 | int target_len = atomicAdd(&(adj_vec_len[target]), 1); 44 | adj_vec[target * max_adj_per_node + target_len] = source; 45 | } 46 | } 47 | 48 | __global__ void breadth_first_sort_kernel( 49 | int * sorted_index, 50 | int * sorted_parent_index, 51 | int * sorted_child_index, 52 | int * adj_vec, 53 | int * adj_vec_len, 54 | int * parent_index, 55 | int batch_size, 56 | int vertex_count, 57 | int max_adj_per_node){ 58 | 59 | const int batch_idx = blockIdx.x; 60 | const int 
thread_idx = threadIdx.x; 61 | const int thread_count = blockDim.x; 62 | 63 | adj_vec += batch_idx * vertex_count * max_adj_per_node; 64 | adj_vec_len += batch_idx * vertex_count; 65 | parent_index += batch_idx * vertex_count; 66 | sorted_index += batch_idx * vertex_count; 67 | sorted_parent_index += batch_idx * vertex_count; 68 | sorted_child_index += batch_idx * vertex_count * max_adj_per_node; 69 | 70 | __shared__ int sorted_len; 71 | if (thread_idx == 0) { 72 | sorted_len = 1; 73 | parent_index[0] = 0; 74 | sorted_index[0] = 0; 75 | sorted_parent_index[0] = 0; 76 | } 77 | __syncthreads(); 78 | 79 | int i = thread_idx; 80 | while (i < vertex_count){ 81 | if ((sorted_index[i] > 0) || (i == 0)){ 82 | int child_index = 0; 83 | int par = parent_index[i]; 84 | int cur = sorted_index[i]; 85 | for (int j = 0; j < adj_vec_len[cur]; j++){ 86 | int child = adj_vec[cur * max_adj_per_node + j]; 87 | if (child != par){ 88 | int pos = atomicAdd(&(sorted_len), 1); 89 | sorted_index[pos] = child; 90 | parent_index[pos] = cur; 91 | sorted_parent_index[pos] = i; 92 | sorted_child_index[i * max_adj_per_node + child_index] = pos; 93 | child_index++; 94 | } 95 | } 96 | i += thread_count; 97 | } 98 | __syncthreads(); 99 | } 100 | } 101 | 102 | std::tuple 103 | bfs_forward( 104 | const at::Tensor & edge_index_tensor, 105 | int max_adj_per_node){ 106 | 107 | int batch_size = edge_index_tensor.size(0); 108 | int vertex_count = edge_index_tensor.size(1) + 1; 109 | 110 | auto options = edge_index_tensor.options(); 111 | auto sorted_index_tensor = at::zeros({batch_size, vertex_count}, options); 112 | auto sorted_parent_tensor = at::zeros({batch_size, vertex_count}, options); 113 | auto sorted_child_tensor = at::zeros({batch_size, vertex_count, max_adj_per_node}, options); 114 | auto adj_vec_tensor = at::zeros({batch_size, vertex_count, max_adj_per_node}, options); 115 | auto adj_vec_len_tensor = at::zeros({batch_size, vertex_count}, options); 116 | auto parent_index_tensor = at::zeros({batch_size, vertex_count}, options); 117 | 118 | int * edge_index = edge_index_tensor.contiguous().data(); 119 | int * sorted_index = sorted_index_tensor.contiguous().data(); 120 | int * sorted_parent = sorted_parent_tensor.contiguous().data(); 121 | int * sorted_child = sorted_child_tensor.contiguous().data(); 122 | int * adj_vec = adj_vec_tensor.contiguous().data(); 123 | int * adj_vec_len = adj_vec_len_tensor.contiguous().data(); 124 | int * parent_index = parent_index_tensor.contiguous().data(); 125 | 126 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 127 | 128 | dim3 block_dims(CUDA_NUM_THREADS, 1, 1), grid_dims(batch_size, 1, 1); 129 | adj_vec_kernel <<< grid_dims, block_dims, 0, stream >>>( 130 | batch_size, edge_index, vertex_count, adj_vec, adj_vec_len, max_adj_per_node); 131 | 132 | breadth_first_sort_kernel <<< grid_dims, block_dims, 1, stream >>>( 133 | sorted_index, sorted_parent, sorted_child, adj_vec, adj_vec_len, parent_index, 134 | batch_size, vertex_count, max_adj_per_node); 135 | 136 | return std::make_tuple(sorted_index_tensor, sorted_parent_tensor, sorted_child_tensor); 137 | } 138 | 139 | 140 | -------------------------------------------------------------------------------- /contextlab/layers/tree_filter/src/bfs/bfs.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | extern std::tuple 5 | bfs_forward( 6 | const at::Tensor & edge_index_tensor, 7 | int max_adj_per_node 8 | ); 9 | 10 | 
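
mst_forward above expects a batched list of candidate edges over the H x W grid (one vertex per pixel, flattened row-major, int32 indices of shape (B, E, 2)) plus a per-edge float weight, and returns the spanning tree that bfs_forward then sorts breadth-first. For reference, the following self-contained PyTorch sketch shows one way such a 4-connected edge index could be built; it is an illustration only, since the repository constructs the index in _build_matrix_index inside tree_filter.py (not shown in this excerpt), which may differ in detail.

```python
# Illustrative sketch (not part of the repository sources).
import torch


def grid_edge_index(batch, height, width, device=None):
    """Edges of a 4-connected grid over an H x W map, flattened row-major.

    Returns an int32 tensor of shape (batch, E, 2) with
    E = (H - 1) * W + H * (W - 1), listing vertical edges first and then
    horizontal edges, matching the row/column order used by
    _build_feature_weight in tree_filter.py.
    """
    idx = torch.arange(height * width, device=device).reshape(height, width)
    vertical = torch.stack([idx[:-1, :].reshape(-1), idx[1:, :].reshape(-1)], dim=-1)
    horizontal = torch.stack([idx[:, :-1].reshape(-1), idx[:, 1:].reshape(-1)], dim=-1)
    edges = torch.cat([vertical, horizontal], dim=0).int()
    return edges.unsqueeze(0).expand(batch, -1, -1).contiguous()
```

Feeding such an index together with the per-edge weights from _build_feature_weight into mst_forward would yield the (B, H*W - 1, 2) tree edges that bfs_forward consumes.
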
-------------------------------------------------------------------------------- /contextlab/layers/tree_filter/src/mst/boruvka.cpp: -------------------------------------------------------------------------------- 1 | // Boruvka's algorithm to find Minimum Spanning 2 | // Tree of a given connected, undirected and 3 | // weighted graph 4 | #include 5 | #include "boruvka.hpp" 6 | 7 | // A structure to represent a subset for union-find 8 | struct subset 9 | { 10 | int parent; 11 | int rank; 12 | }; 13 | 14 | // Function prototypes for union-find (These functions are defined 15 | // after boruvkaMST() ) 16 | int find(struct subset subsets[], int i); 17 | void Union(struct subset subsets[], int x, int y); 18 | 19 | // The main function for MST using Boruvka's algorithm 20 | void boruvkaMST(struct Graph* graph, int * edge_out) 21 | { 22 | // Get data of given graph 23 | int V = graph->V, E = graph->E; 24 | Edge *edge = graph->edge; 25 | 26 | // Allocate memory for creating V subsets. 27 | struct subset *subsets = new subset[V]; 28 | 29 | // An array to store index of the cheapest edge of 30 | // subset. The stored index for indexing array 'edge[]' 31 | int *cheapest = new int[V]; 32 | 33 | // Create V subsets with single elements 34 | for (int v = 0; v < V; ++v) 35 | { 36 | subsets[v].parent = v; 37 | subsets[v].rank = 0; 38 | cheapest[v] = -1; 39 | } 40 | 41 | // Initially there are V different trees. 42 | // Finally there will be one tree that will be MST 43 | int numTrees = V; 44 | int MSTweight = 0; 45 | 46 | // Keep combining components (or sets) until all 47 | // compnentes are not combined into single MST. 48 | while (numTrees > 1) 49 | { 50 | // Everytime initialize cheapest array 51 | for (int v = 0; v < V; ++v) 52 | { 53 | cheapest[v] = -1; 54 | } 55 | 56 | // Traverse through all edges and update 57 | // cheapest of every component 58 | for (int i=0; i edge[i].weight) 76 | cheapest[set1] = i; 77 | 78 | if (cheapest[set2] == -1 || 79 | edge[cheapest[set2]].weight > edge[i].weight) 80 | cheapest[set2] = i; 81 | } 82 | } 83 | 84 | // Consider the above picked cheapest edges and add them 85 | // to MST 86 | for (int i=0; iV = V; 119 | graph->E = E; 120 | graph->edge = new Edge[E]; 121 | return graph; 122 | } 123 | 124 | // A utility function to find set of an element i 125 | // (uses path compression technique) 126 | int find(struct subset subsets[], int i) 127 | { 128 | // find root and make root as parent of i 129 | // (path compression) 130 | if (subsets[i].parent != i) 131 | subsets[i].parent = 132 | find(subsets, subsets[i].parent); 133 | 134 | return subsets[i].parent; 135 | } 136 | 137 | // A function that does union of two sets of x and y 138 | // (uses union by rank) 139 | void Union(struct subset subsets[], int x, int y) 140 | { 141 | int xroot = find(subsets, x); 142 | int yroot = find(subsets, y); 143 | 144 | // Attach smaller rank tree under root of high 145 | // rank tree (Union by Rank) 146 | if (subsets[xroot].rank < subsets[yroot].rank) 147 | subsets[xroot].parent = yroot; 148 | else if (subsets[xroot].rank > subsets[yroot].rank) 149 | subsets[yroot].parent = xroot; 150 | 151 | // If ranks are same, then make one as root and 152 | // increment its rank by one 153 | else 154 | { 155 | subsets[yroot].parent = xroot; 156 | subsets[xroot].rank++; 157 | } 158 | } 159 | 160 | -------------------------------------------------------------------------------- /contextlab/layers/tree_filter/src/mst/boruvka.hpp: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // a structure to represent a weighted edge in graph 4 | struct Edge 5 | { 6 | int src, dest; 7 | float weight; 8 | }; 9 | 10 | // a structure to represent a connected, undirected 11 | // and weighted graph as a collection of edges. 12 | struct Graph 13 | { 14 | // V-> Number of vertices, E-> Number of edges 15 | int V, E; 16 | 17 | // graph is represented as an array of edges. 18 | // Since the graph is undirected, the edge 19 | // from src to dest is also edge from dest 20 | // to src. Both are counted as 1 edge here. 21 | Edge* edge; 22 | }; 23 | 24 | extern struct Graph* createGraph(int V, int E); 25 | extern void boruvkaMST(struct Graph* graph, int * edge_out); 26 | 27 | -------------------------------------------------------------------------------- /contextlab/layers/tree_filter/src/mst/mst.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | /* Switch of minimal spanning tree algorithms */ 14 | /* Note: we will migrate the cuda implementaion to PyTorch in the next version */ 15 | //#define MST_PRIM 16 | //#define MST_KRUSKAL 17 | #define MST_BORUVKA 18 | 19 | #ifdef MST_PRIM 20 | #include 21 | #include 22 | #endif 23 | #ifdef MST_KRUSKAL 24 | #include 25 | #include 26 | #endif 27 | #ifdef MST_BORUVKA 28 | #include "boruvka.hpp" 29 | #endif 30 | 31 | 32 | #ifndef MST_BORUVKA 33 | using namespace boost; 34 | typedef adjacency_list > Graph; 36 | typedef graph_traits < Graph >::edge_descriptor Edge; 37 | typedef graph_traits < Graph >::vertex_descriptor Vertex; 38 | typedef std::pair E; 39 | #endif 40 | 41 | static void forward_kernel(int * edge_index, float * edge_weight, int * edge_out, int vertex_count, int edge_count){ 42 | #ifdef MST_BORUVKA 43 | struct Graph * g = createGraph(vertex_count, edge_count); 44 | for (int i = 0; i < edge_count; ++i){ 45 | g->edge[i].src = edge_index[i * 2]; 46 | g->edge[i].dest = edge_index[i * 2 + 1]; 47 | g->edge[i].weight = edge_weight[i]; 48 | } 49 | #else 50 | Graph g(vertex_count); 51 | for (int i = 0; i < edge_count; ++i) 52 | boost::add_edge((int)edge_index[i * 2], (int)edge_index[i * 2 + 1], 53 | edge_weight[i], g); 54 | #endif 55 | 56 | #ifdef MST_PRIM 57 | std::vector < graph_traits < Graph >::vertex_descriptor > p(num_vertices(g)); 58 | prim_minimum_spanning_tree(g, &(p[0])); 59 | int * edge_out_ptr = edge_out; 60 | for (std::size_t i = 0; i != p.size(); ++i) 61 | if (p[i] != i) { 62 | *(edge_out_ptr++) = i; 63 | *(edge_out_ptr++) = p[i]; 64 | } 65 | #endif 66 | 67 | #ifdef MST_KRUSKAL 68 | std::vector < Edge > spanning_tree; 69 | kruskal_minimum_spanning_tree(g, std::back_inserter(spanning_tree)); 70 | float * edge_out_ptr = edge_out; 71 | for (std::vector < Edge >::iterator ei = spanning_tree.begin(); 72 | ei != spanning_tree.end(); ++ei){ 73 | *(edge_out_ptr++) = source(*ei, g); 74 | *(edge_out_ptr++) = target(*ei, g); 75 | } 76 | #endif 77 | 78 | #ifdef MST_BORUVKA 79 | boruvkaMST(g, edge_out); 80 | delete[] g->edge; 81 | delete[] g; 82 | #endif 83 | 84 | } 85 | 86 | at::Tensor mst_forward( 87 | const at::Tensor & edge_index_tensor, 88 | const at::Tensor & edge_weight_tensor, 89 | int vertex_count){ 90 | unsigned batch_size = edge_index_tensor.size(0); 91 | unsigned edge_count = edge_index_tensor.size(1); 92 | 93 | auto edge_index_cpu = edge_index_tensor.cpu(); 94 | 
auto edge_weight_cpu = edge_weight_tensor.cpu(); 95 | auto edge_out_cpu = at::empty({batch_size, vertex_count - 1, 2}, edge_index_cpu.options()); 96 | 97 | int * edge_out = edge_out_cpu.contiguous().data(); 98 | int * edge_index = edge_index_cpu.contiguous().data(); 99 | float * edge_weight = edge_weight_cpu.contiguous().data(); 100 | 101 | // Loop for batch 102 | std::thread pids[batch_size]; 103 | for (unsigned i = 0; i < batch_size; i++){ 104 | auto edge_index_iter = edge_index + i * edge_count * 2; 105 | auto edge_weight_iter = edge_weight + i * edge_count; 106 | auto edge_out_iter = edge_out + i * (vertex_count - 1) * 2; 107 | pids[i] = std::thread(forward_kernel, edge_index_iter, edge_weight_iter, edge_out_iter, vertex_count, edge_count); 108 | } 109 | 110 | for (unsigned i = 0; i < batch_size; i++){ 111 | pids[i].join(); 112 | } 113 | 114 | auto edge_out_tensor = edge_out_cpu.to(edge_index_tensor.device()); 115 | 116 | return edge_out_tensor; 117 | } 118 | 119 | -------------------------------------------------------------------------------- /contextlab/layers/tree_filter/src/mst/mst.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | extern at::Tensor mst_forward( 5 | const at::Tensor & edge_index_tensor, 6 | const at::Tensor & edge_weight_tensor, 7 | int vertex_count); 8 | 9 | -------------------------------------------------------------------------------- /contextlab/layers/tree_filter/src/refine/refine.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | #define CUDA_CHECK(call) if((call) != cudaSuccess) {cudaError_t err = cudaGetLastError(); std::cout << "CUDA error calling ""#call"", code is " << err << std::endl;} 17 | 18 | #define CUDA_NUM_THREADS 64 19 | #define GET_CUDA_CHANNEL(N) ceil(512.0f / N) 20 | 21 | __global__ void root_leaf_prop_kernel( 22 | float * in_data, 23 | float * out_data, 24 | float * weight, 25 | int * sorted_index, 26 | int * sorted_parent_index, 27 | int batch_size, 28 | int channel_size, 29 | int vertex_count){ 30 | 31 | const int thread_idx = threadIdx.x; 32 | const int batch_idx = blockIdx.x; 33 | const int channel_idx = blockIdx.y; 34 | const int thread_count = blockDim.x; 35 | const int channel_step = gridDim.y; 36 | 37 | in_data += batch_idx * vertex_count * channel_size; 38 | out_data += batch_idx * vertex_count * channel_size; 39 | weight += batch_idx * vertex_count; 40 | sorted_index += batch_idx * vertex_count; 41 | sorted_parent_index += batch_idx * vertex_count; 42 | 43 | __shared__ int node_per_thread[CUDA_NUM_THREADS]; 44 | node_per_thread[thread_idx] = -1; 45 | if (thread_idx == 0){ 46 | weight[0] = 0; 47 | sorted_parent_index[0] = 0; 48 | } 49 | __syncthreads(); 50 | 51 | int i = thread_idx; 52 | while (i < vertex_count){ 53 | int par = sorted_parent_index[i]; 54 | int par_thread = par % thread_count; 55 | if ((node_per_thread[par_thread] >= par) || (i == 0)){ 56 | int cur_pos = sorted_index[i]; 57 | int par_pos = sorted_index[par]; 58 | for (int k = channel_idx * vertex_count; k < channel_size * vertex_count; 59 | k += channel_step * vertex_count){ 60 | float edge_weight = weight[i]; 61 | out_data[cur_pos + k] = in_data[i + k] * (1 - edge_weight * edge_weight) + 62 | out_data[par_pos + k] * edge_weight; 63 | __threadfence_block(); 64 | } 65 | 
node_per_thread[thread_idx] = i; 66 | i += thread_count; 67 | } 68 | __syncthreads(); 69 | } 70 | } 71 | 72 | __global__ void leaf_root_aggr_kernel( 73 | float * in_data, 74 | float * out_data, 75 | float * weight, 76 | int * sorted_index, 77 | int * sorted_child_index, 78 | int batch_size, 79 | int channel_size, 80 | int vertex_count, 81 | int max_adj_per_node){ 82 | 83 | const int thread_idx = threadIdx.x; 84 | const int batch_idx = blockIdx.x; 85 | const int channel_idx = blockIdx.y; 86 | const int thread_count = blockDim.x; 87 | const int channel_step = gridDim.y; 88 | 89 | if (in_data != NULL){ 90 | in_data += batch_idx * vertex_count * channel_size; 91 | } 92 | out_data += batch_idx * vertex_count * channel_size; 93 | weight += batch_idx * vertex_count; 94 | sorted_index += batch_idx * vertex_count; 95 | sorted_child_index += batch_idx * vertex_count * max_adj_per_node; 96 | 97 | __shared__ int node_per_thread[CUDA_NUM_THREADS]; 98 | node_per_thread[thread_idx] = vertex_count; 99 | __syncthreads(); 100 | 101 | int i = vertex_count - thread_idx - 1; 102 | while (i >= 0){ 103 | int child_len = 0; 104 | bool valid = true; 105 | for (int j = 0; j < max_adj_per_node; j++){ 106 | int child = sorted_child_index[i * max_adj_per_node + j]; 107 | int child_thread = (vertex_count - child - 1) % thread_count; 108 | 109 | if (child <= 0) break; 110 | if (node_per_thread[child_thread] > child){ 111 | valid = false; 112 | break; 113 | } 114 | child_len++; 115 | } 116 | if (valid){ 117 | int cur_pos = sorted_index[i]; 118 | for (int k = channel_idx * vertex_count; k < channel_size * vertex_count; 119 | k += channel_step * vertex_count){ 120 | float aggr_sum; 121 | if (in_data != NULL) 122 | aggr_sum = in_data[cur_pos + k]; 123 | else 124 | aggr_sum = 1; 125 | for (int j = 0; j < child_len; j++){ 126 | int child = sorted_child_index[i * max_adj_per_node + j]; 127 | aggr_sum += out_data[child + k] * weight[child]; 128 | } 129 | out_data[i + k] = aggr_sum; 130 | } 131 | node_per_thread[thread_idx] = i; 132 | i -= thread_count; 133 | } 134 | __syncthreads(); 135 | } 136 | } 137 | 138 | __global__ void root_leaf_grad_kernel( 139 | float * in_data, 140 | float * in_grad, 141 | float * out_data, 142 | float * out_grad, 143 | float * weight, 144 | float * grad, 145 | int * sorted_index, 146 | int * sorted_parent_index, 147 | int batch_size, 148 | int data_channel_size, 149 | int grad_channel_size, 150 | int vertex_count){ 151 | 152 | const int thread_idx = threadIdx.x; 153 | const int batch_idx = blockIdx.x; 154 | const int channel_idx = blockIdx.y; 155 | const int thread_count = blockDim.x; 156 | const int channel_step = gridDim.y; 157 | const int channel_size = data_channel_size > grad_channel_size ? 
data_channel_size : grad_channel_size; 158 | 159 | in_data += batch_idx * vertex_count * data_channel_size; 160 | in_grad += batch_idx * vertex_count * grad_channel_size; 161 | out_data += batch_idx * vertex_count * data_channel_size; 162 | out_grad += batch_idx * vertex_count * grad_channel_size; 163 | weight += batch_idx * vertex_count; 164 | grad += batch_idx * vertex_count * channel_size; 165 | sorted_index += batch_idx * vertex_count; 166 | sorted_parent_index += batch_idx * vertex_count; 167 | 168 | __shared__ int node_per_thread[CUDA_NUM_THREADS]; 169 | node_per_thread[thread_idx] = -1; 170 | 171 | int i = thread_idx; 172 | while (i < vertex_count){ 173 | int cur = i; 174 | int par = sorted_parent_index[i]; 175 | int par_pos = sorted_index[par]; 176 | int par_thread = par % thread_count; 177 | if ((cur == 0) || (node_per_thread[par_thread] >= par)){ 178 | for (int k = channel_idx; k < channel_size; k += channel_step){ 179 | float edge_weight = weight[i]; 180 | int data_offset = (k % data_channel_size) * vertex_count; 181 | int grad_offset = (k % grad_channel_size) * vertex_count; 182 | int out_offset = k * vertex_count; 183 | 184 | if (cur > 0){ 185 | float left = in_grad[cur + grad_offset] * (out_data[par_pos + data_offset] - edge_weight * in_data[cur + data_offset]); 186 | float right = in_data[cur + data_offset] * (out_grad[par + grad_offset] - edge_weight * in_grad[cur + grad_offset]); 187 | 188 | grad[cur + out_offset] = left + right; 189 | out_grad[cur + grad_offset] = in_grad[cur + grad_offset] * (1 - edge_weight * edge_weight) + 190 | out_grad[par + grad_offset] * edge_weight; 191 | __threadfence_block(); 192 | } 193 | else 194 | grad[cur + out_offset] = 0; 195 | } 196 | node_per_thread[thread_idx] = i; 197 | i += thread_count; 198 | } 199 | __syncthreads(); 200 | } 201 | } 202 | 203 | std::tuple 204 | refine_forward( 205 | const at::Tensor & feature_in_tensor, 206 | const at::Tensor & edge_weight_tensor, 207 | const at::Tensor & sorted_index_tensor, 208 | const at::Tensor & sorted_parent_tensor, 209 | const at::Tensor & sorted_child_tensor 210 | ){ 211 | 212 | const int batch_size = feature_in_tensor.size(0); 213 | const int channel_size = feature_in_tensor.size(1); 214 | const int vertex_size = feature_in_tensor.size(2); 215 | const int max_adj_per_node = sorted_child_tensor.size(2); 216 | 217 | auto options = feature_in_tensor.options(); 218 | auto feature_aggr_tensor = at::zeros_like(feature_in_tensor, options); 219 | auto feature_aggr_up_tensor = at::zeros_like(feature_in_tensor, options); 220 | auto weight_sum_tensor = at::zeros({batch_size, vertex_size}, options); 221 | auto weight_sum_up_tensor = at::zeros({batch_size, vertex_size}, options); 222 | 223 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 224 | 225 | float * feature_in = feature_in_tensor.contiguous().data(); 226 | float * edge_weight = edge_weight_tensor.contiguous().data(); 227 | int * sorted_index = sorted_index_tensor.contiguous().data(); 228 | int * sorted_parent_index = sorted_parent_tensor.contiguous().data(); 229 | int * sorted_child_index = sorted_child_tensor.contiguous().data(); 230 | float * feature_aggr = feature_aggr_tensor.contiguous().data(); 231 | float * feature_aggr_sum = feature_aggr_up_tensor.contiguous().data(); 232 | float * weight_sum = weight_sum_tensor.contiguous().data(); 233 | float * weight_aggr_sum = weight_sum_up_tensor.contiguous().data(); 234 | 235 | dim3 feature_block_dims(CUDA_NUM_THREADS, 1, 1), feature_grid_dims(batch_size, channel_size, 1); 236 | 
leaf_root_aggr_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>( 237 | feature_in, feature_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node); 238 | root_leaf_prop_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>( 239 | feature_aggr_sum, feature_aggr, edge_weight, sorted_index, sorted_parent_index, batch_size, channel_size, vertex_size); 240 | 241 | dim3 weight_block_dims(CUDA_NUM_THREADS, 1, 1), weight_grid_dims(batch_size, 1, 1); 242 | leaf_root_aggr_kernel <<< weight_grid_dims, weight_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>( 243 | NULL, weight_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, 1, vertex_size, max_adj_per_node); 244 | root_leaf_prop_kernel <<< weight_grid_dims, weight_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>( 245 | weight_aggr_sum, weight_sum, edge_weight, sorted_index, sorted_parent_index, batch_size, 1, vertex_size); 246 | 247 | auto feature_out_tensor = feature_aggr_tensor / weight_sum_tensor.unsqueeze(1); 248 | auto result = std::make_tuple(feature_out_tensor, feature_aggr_tensor, feature_aggr_up_tensor, 249 | weight_sum_tensor, weight_sum_up_tensor); 250 | return result; 251 | } 252 | 253 | at::Tensor refine_backward_feature( 254 | const at::Tensor & feature_in_tensor, 255 | const at::Tensor & edge_weight_tensor, 256 | const at::Tensor & sorted_index_tensor, 257 | const at::Tensor & sorted_parent_tensor, 258 | const at::Tensor & sorted_child_tensor, 259 | const at::Tensor & feature_out_tensor, 260 | const at::Tensor & feature_aggr_tensor, 261 | const at::Tensor & feature_aggr_up_tensor, 262 | const at::Tensor & weight_sum_tensor, 263 | const at::Tensor & weight_sum_up_tensor, 264 | const at::Tensor & grad_out_tensor 265 | ){ 266 | 267 | auto options = feature_in_tensor.options(); 268 | auto grad_feature_tensor = at::zeros_like(feature_in_tensor, options); 269 | auto grad_feature_aggr_sum_tensor = at::zeros_like(feature_in_tensor, options); 270 | 271 | auto grad_out_norm_tensor = grad_out_tensor / weight_sum_tensor.unsqueeze(1); 272 | 273 | const int batch_size = feature_in_tensor.size(0); 274 | const int channel_size = feature_in_tensor.size(1); 275 | const int vertex_size = feature_in_tensor.size(2); 276 | const int max_adj_per_node = sorted_child_tensor.size(2); 277 | 278 | float * feature_in = feature_in_tensor.contiguous().data(); 279 | float * edge_weight = edge_weight_tensor.contiguous().data(); 280 | int * sorted_index = sorted_index_tensor.contiguous().data(); 281 | int * sorted_parent_index = sorted_parent_tensor.contiguous().data(); 282 | int * sorted_child_index = sorted_child_tensor.contiguous().data(); 283 | float * feature_aggr = feature_aggr_tensor.contiguous().data(); 284 | float * feature_aggr_sum = feature_aggr_up_tensor.contiguous().data(); 285 | float * weight_sum = weight_sum_tensor.contiguous().data(); 286 | float * weight_aggr_sum = weight_sum_up_tensor.contiguous().data(); 287 | float * grad_out = grad_out_tensor.contiguous().data(); 288 | float * grad_feature = grad_feature_tensor.contiguous().data(); 289 | 290 | float * grad_out_norm = grad_out_norm_tensor.contiguous().data(); 291 | float * grad_feature_aggr_sum = grad_feature_aggr_sum_tensor.contiguous().data(); 292 | 293 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 294 | 295 | dim3 feature_block_dims(CUDA_NUM_THREADS, 1, 1), feature_grid_dims(batch_size, channel_size, 1); 296 | 
leaf_root_aggr_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>( 297 | grad_out_norm, grad_feature_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node); 298 | root_leaf_prop_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>( 299 | grad_feature_aggr_sum, grad_feature, edge_weight, sorted_index, sorted_parent_index, batch_size, channel_size, vertex_size); 300 | 301 | return grad_feature_tensor; 302 | } 303 | 304 | at::Tensor refine_backward_weight( 305 | const at::Tensor & feature_in_tensor, 306 | const at::Tensor & edge_weight_tensor, 307 | const at::Tensor & sorted_index_tensor, 308 | const at::Tensor & sorted_parent_tensor, 309 | const at::Tensor & sorted_child_tensor, 310 | const at::Tensor & feature_out_tensor, 311 | const at::Tensor & feature_aggr_tensor, 312 | const at::Tensor & feature_aggr_up_tensor, 313 | const at::Tensor & weight_sum_tensor, 314 | const at::Tensor & weight_sum_up_tensor, 315 | const at::Tensor & grad_out_tensor 316 | ){ 317 | 318 | auto options = feature_in_tensor.options(); 319 | auto grad_weight_tensor = at::zeros_like(edge_weight_tensor, options); 320 | 321 | const int batch_size = feature_in_tensor.size(0); 322 | const int channel_size = feature_in_tensor.size(1); 323 | const int vertex_size = feature_in_tensor.size(2); 324 | const int max_adj_per_node = sorted_child_tensor.size(2); 325 | 326 | float * feature_in = feature_in_tensor.contiguous().data(); 327 | float * edge_weight = edge_weight_tensor.contiguous().data(); 328 | int * sorted_index = sorted_index_tensor.contiguous().data(); 329 | int * sorted_parent_index = sorted_parent_tensor.contiguous().data(); 330 | int * sorted_child_index = sorted_child_tensor.contiguous().data(); 331 | float * feature_out = feature_out_tensor.contiguous().data(); 332 | float * feature_aggr = feature_aggr_tensor.contiguous().data(); 333 | float * feature_aggr_sum = feature_aggr_up_tensor.contiguous().data(); 334 | float * weight_sum = weight_sum_tensor.contiguous().data(); 335 | float * weight_aggr_sum = weight_sum_up_tensor.contiguous().data(); 336 | float * grad_out = grad_out_tensor.contiguous().data(); 337 | float * grad_weight = grad_weight_tensor.contiguous().data(); 338 | 339 | auto grad_all_channel_tensor = at::zeros_like(feature_in_tensor, options); 340 | auto grad_norm_all_channel_tensor = at::zeros_like(feature_in_tensor, options); 341 | auto grad_out_norm_aggr_sum_tensor = at::zeros_like(feature_in_tensor, options); 342 | auto feature_grad_aggr_sum_tensor = at::zeros_like(feature_in_tensor, options); 343 | 344 | float * grad_all_channel = grad_all_channel_tensor.contiguous().data(); 345 | float * grad_norm_all_channel = grad_norm_all_channel_tensor.contiguous().data(); 346 | float * grad_out_norm_aggr_sum = grad_out_norm_aggr_sum_tensor.contiguous().data(); 347 | float * feature_grad_aggr_sum = feature_grad_aggr_sum_tensor.contiguous().data(); 348 | 349 | auto grad_out_norm_tensor = grad_out_tensor / weight_sum_tensor.unsqueeze(1); 350 | auto feature_grad_tensor = grad_out_norm_tensor * feature_out_tensor; 351 | float * grad_out_norm = grad_out_norm_tensor.contiguous().data(); 352 | float * feature_grad = feature_grad_tensor.contiguous().data(); 353 | 354 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 355 | 356 | dim3 feature_block_dims(CUDA_NUM_THREADS, 1, 1), feature_grid_dims(batch_size, channel_size, 1); 357 | leaf_root_aggr_kernel <<< 
feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>( 358 | grad_out_norm, grad_out_norm_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node); 359 | leaf_root_aggr_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>( 360 | feature_grad, feature_grad_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node); 361 | 362 | root_leaf_grad_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>( 363 | feature_aggr_sum, grad_out_norm_aggr_sum, feature_aggr, grad_out_norm_aggr_sum, edge_weight, grad_all_channel, 364 | sorted_index, sorted_parent_index, batch_size, channel_size, channel_size, vertex_size); 365 | root_leaf_grad_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>( 366 | weight_aggr_sum, feature_grad_aggr_sum, weight_sum, feature_grad_aggr_sum, edge_weight, grad_norm_all_channel, 367 | sorted_index, sorted_parent_index, batch_size, 1, channel_size, vertex_size); 368 | 369 | grad_weight_tensor = (grad_all_channel_tensor - grad_norm_all_channel_tensor).sum(1); 370 | 371 | return grad_weight_tensor; 372 | } 373 | 374 | -------------------------------------------------------------------------------- /contextlab/layers/tree_filter/src/refine/refine.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | extern std::tuple 5 | refine_forward( 6 | const at::Tensor & feature_in_tensor, 7 | const at::Tensor & edge_weight_tensor, 8 | const at::Tensor & sorted_index_tensor, 9 | const at::Tensor & sorted_parent_index_tensor, 10 | const at::Tensor & sorted_child_index_tensor 11 | ); 12 | 13 | extern at::Tensor refine_backward_feature( 14 | const at::Tensor & feature_in_tensor, 15 | const at::Tensor & edge_weight_tensor, 16 | const at::Tensor & sorted_index_tensor, 17 | const at::Tensor & sorted_parent_tensor, 18 | const at::Tensor & sorted_child_tensor, 19 | const at::Tensor & feature_out_tensor, 20 | const at::Tensor & feature_aggr_tensor, 21 | const at::Tensor & feature_aggr_up_tensor, 22 | const at::Tensor & weight_sum_tensor, 23 | const at::Tensor & weight_sum_up_tensor, 24 | const at::Tensor & grad_out_tensor 25 | ); 26 | 27 | extern at::Tensor refine_backward_weight( 28 | const at::Tensor & feature_in_tensor, 29 | const at::Tensor & edge_weight_tensor, 30 | const at::Tensor & sorted_index_tensor, 31 | const at::Tensor & sorted_parent_tensor, 32 | const at::Tensor & sorted_child_tensor, 33 | const at::Tensor & feature_out_tensor, 34 | const at::Tensor & feature_aggr_tensor, 35 | const at::Tensor & feature_aggr_up_tensor, 36 | const at::Tensor & weight_sum_tensor, 37 | const at::Tensor & weight_sum_up_tensor, 38 | const at::Tensor & grad_out_tensor 39 | ); 40 | 41 | -------------------------------------------------------------------------------- /contextlab/layers/tree_filter/src/tree_filter.cpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include "refine/refine.hpp" 4 | #include "mst/mst.hpp" 5 | #include "bfs/bfs.hpp" 6 | 7 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 8 | m.def("mst_forward", &mst_forward, "mst forward"); 9 | m.def("bfs_forward", &bfs_forward, "bfs forward"); 10 | m.def("refine_forward", &refine_forward, "refine forward"); 11 | m.def("refine_backward_feature", 
&refine_backward_feature, "refine backward wrt feature"); 12 | m.def("refine_backward_weight", &refine_backward_weight, "refine backward wrt weight"); 13 | } 14 | 15 | -------------------------------------------------------------------------------- /contextlab/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .layer_misc import ConvBNReLU, GraphAdjNetwork 2 | from .weight_init import constant_init, xavier_init, kaiming_init -------------------------------------------------------------------------------- /contextlab/utils/layer_misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | __all__ = ['ConvBNReLU'] 5 | 6 | 7 | class ConvBNReLU(nn.Module): 8 | def __init__(self, 9 | in_channels, 10 | out_channels, 11 | kernel_size, 12 | norm_layer=nn.BatchNorm2d, 13 | with_relu=True, 14 | stride=1, 15 | padding=0, 16 | dilation=1, 17 | groups=1, 18 | bias=True, 19 | padding_mode='zeros'): 20 | super(ConvBNReLU, self).__init__() 21 | self.conv = nn.Conv2d( 22 | in_channels=in_channels, 23 | out_channels=out_channels, 24 | kernel_size=kernel_size, 25 | stride=stride, 26 | padding=padding, 27 | dilation=dilation, 28 | groups=groups, 29 | bias=bias, 30 | padding_mode=padding_mode) 31 | self.bn = norm_layer(out_channels) 32 | self.with_relu = with_relu 33 | if with_relu: 34 | self.relu = nn.ReLU(inplace=True) 35 | 36 | def forward(self, x): 37 | x = self.conv(x) 38 | x = self.bn(x) 39 | if self.with_relu: 40 | x = self.relu(x) 41 | return x 42 | 43 | 44 | class GraphAdjNetwork(nn.Module): 45 | def __init__(self, 46 | pair_function, 47 | in_channels, 48 | channel_stride): 49 | super(GraphAdjNetwork, self).__init__() 50 | self.pair_function = pair_function 51 | 52 | if pair_function == 'embedded_gaussian': 53 | inter_channel = in_channels // channel_stride 54 | self.phi = ConvBNReLU( 55 | in_channels=in_channels, 56 | out_channels=inter_channel, 57 | kernel_size=1, 58 | bias=False, 59 | norm_layer=nn.BatchNorm2d 60 | ) 61 | self.theta = ConvBNReLU( 62 | in_channels=in_channels, 63 | out_channels=inter_channel, 64 | kernel_size=1, 65 | bias=False, 66 | norm_layer=nn.BatchNorm2d 67 | ) 68 | elif pair_function == 'gaussian': 69 | pass 70 | elif pair_function == 'diff_learnable': 71 | self.learnable_adj_conv = ConvBNReLU( 72 | in_channels=in_channels, 73 | out_channels=1, 74 | kernel_size=1, 75 | bias=False, 76 | norm_layer=nn.BatchNorm2d 77 | ) 78 | elif pair_function == 'sum_learnable': 79 | self.learnable_adj_conv = ConvBNReLU( 80 | in_channels=in_channels, 81 | out_channels=1, 82 | kernel_size=1, 83 | bias=False, 84 | norm_layer=nn.BatchNorm2d 85 | ) 86 | elif pair_function == 'cat_learnable': 87 | self.learnable_adj_conv = ConvBNReLU( 88 | in_channels=in_channels*2, 89 | out_channels=1, 90 | kernel_size=1, 91 | bias=False, 92 | norm_layer=nn.BatchNorm2d 93 | ) 94 | else: 95 | raise NotImplementedError 96 | 97 | def forward(self, x): 98 | """ 99 | Args: 100 | x (Tensor): 101 | (B, N, C) 102 | """ 103 | if self.pair_function == 'gaussian': 104 | adj = self.gaussian(x, x.permute(0, 2, 1)) 105 | elif self.pair_function == 'embedded_gaussian': 106 | x = x.permute(0, 2, 1).unsqueeze(-1) 107 | x_1 = self.phi(x) # B, C, N, 1 108 | x_2 = self.theta(x) # B, C, N, 1 109 | adj = self.gaussian( 110 | x_1.squeeze(-1).permute(0, 2, 1), x_2.squeeze(-1)) 111 | elif self.pair_function == 'diff_learnable': 112 | adj = 
self.diff_learnable_adj(x.unsqueeze(2), x.unsqueeze(1)) 113 | elif self.pair_function == 'sum_learnable': 114 | adj = self.sum_learnable_adj(x.unsqueeze(2), x.unsqueeze(1)) 115 | elif self.pair_function == 'cat_learnable': 116 | adj = self.cat_learnable_adj(x.unsqueeze(2), x.unsqueeze(1)) 117 | else: 118 | raise NotImplementedError(self.pair_function) 119 | 120 | return adj 121 | 122 | def gaussian(self, x_1, x_2): 123 | """ 124 | Args: 125 | x_1: 126 | x_2: 127 | Return: 128 | adj: normalized in the last dimenstion 129 | """ 130 | # (B, N, C) X (B, C, N) --> (B, N, N) 131 | adj = torch.bmm(x_1, x_2) # B, N, N 132 | adj = F.softmax(adj, dim=-1) # B, N, N 133 | return adj 134 | 135 | def diff_learnable_adj(self, x_1, x_2): 136 | """ 137 | Learnable attention from the difference of the feature 138 | 139 | Return: 140 | adj: normalzied at the last dimension 141 | """ 142 | # x1:(B,N,1,C) 143 | # x2:(B,1,N,C) 144 | feature_diff = x_1 - x_2 # (B, N, N, C) 145 | feature_diff = feature_diff.permute(0, 3, 1, 2) # (B, C, N, N) 146 | adj = self.learnable_adj_conv(feature_diff) # (B, 1, N, N) 147 | adj = adj.squeeze(1) # (B, N, N) 148 | # Use the number of nodes as the normalization factor 149 | adj = adj / adj.size(-1) # (B, N, N) 150 | 151 | return adj 152 | 153 | def sum_learnable_adj(self, x_1, x_2): 154 | """ 155 | Learnable attention from the difference of the feature 156 | 157 | Return: 158 | adj: normalzied at the last dimension 159 | """ 160 | # x1:(B,N,1,C) 161 | # x2:(B,1,N,C) 162 | feature_diff = x_1 + x_2 # (B, N, N, C) 163 | feature_diff = feature_diff.permute(0, 3, 1, 2) # (B, C, N, N) 164 | adj = self.learnable_adj_conv(feature_diff) # (B, 1, N, N) 165 | adj = adj.squeeze(1) # (B, N, N) 166 | # Use the number of nodes as the normalization factor 167 | adj = adj / adj.size(-1) # (B, N, N) 168 | 169 | return adj 170 | 171 | def cat_learnable_adj(self, x_1, x_2): 172 | """ 173 | Learable attention from the concatnation of the features 174 | """ 175 | x_1 = x_1.repeat(1, 1, x_1.size(1), 1) # B, N, N, C 176 | x_2 = x_2.repeat(1, x_2.size(2), 1, 1) # B, N, N, C 177 | feature_cat = torch.cat([x_1, x_2], dim=-1) # B, N, N, 2C 178 | # import pdb; pdb.set_trace() 179 | feature_cat = feature_cat.permute(0, 3, 1, 2) # B, 2C, N, N 180 | adj = self.learnable_adj_conv(feature_cat) # B, 1, N, N 181 | adj = adj.squeeze(1) # (B, N, N) 182 | # Use the number of nodes as the normalization factor 183 | adj = adj / adj.size(-1) # (B, N, N) 184 | 185 | return adj 186 | -------------------------------------------------------------------------------- /contextlab/utils/weight_init.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def constant_init(module, val, bias=0): 5 | nn.init.constant_(module.weight, val) 6 | if hasattr(module, 'bias') and module.bias is not None: 7 | nn.init.constant_(module.bias, bias) 8 | 9 | 10 | def xavier_init(module, gain=1, bias=0, distribution='normal'): 11 | assert distribution in ['uniform', 'normal'] 12 | if distribution == 'uniform': 13 | nn.init.xavier_uniform_(module.weight, gain=gain) 14 | else: 15 | nn.init.xavier_normal_(module.weight, gain=gain) 16 | if hasattr(module, 'bias') and module.bias is not None: 17 | nn.init.constant_(module.bias, bias) 18 | 19 | 20 | def normal_init(module, mean=0, std=1, bias=0): 21 | nn.init.normal_(module.weight, mean, std) 22 | if hasattr(module, 'bias') and module.bias is not None: 23 | nn.init.constant_(module.bias, bias) 24 | 25 | 26 | def uniform_init(module, 
a=0, b=1, bias=0): 27 | nn.init.uniform_(module.weight, a, b) 28 | if hasattr(module, 'bias') and module.bias is not None: 29 | nn.init.constant_(module.bias, bias) 30 | 31 | 32 | def kaiming_init(module, 33 | a=0, 34 | mode='fan_out', 35 | nonlinearity='relu', 36 | bias=0, 37 | distribution='normal'): 38 | assert distribution in ['uniform', 'normal'] 39 | if distribution == 'uniform': 40 | nn.init.kaiming_uniform_( 41 | module.weight, a=a, mode=mode, nonlinearity=nonlinearity) 42 | else: 43 | nn.init.kaiming_normal_( 44 | module.weight, a=a, mode=mode, nonlinearity=nonlinearity) 45 | if hasattr(module, 'bias') and module.bias is not None: 46 | nn.init.constant_(module.bias, bias) 47 | 48 | 49 | def caffe2_xavier_init(module, bias=0): 50 | # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch 51 | # Acknowledgment to FAIR's internal code 52 | kaiming_init( 53 | module, 54 | a=1, 55 | mode='fan_in', 56 | nonlinearity='leaky_relu', 57 | distribution='uniform') -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @Description: Setup for contextlab project 3 | @Author: Songyang Zhang 4 | @Email: sy.zhangbuaa@gmail.com 5 | @Date: 2019-08-11 12:30:28 6 | @LastEditors: Songyang Zhang 7 | @LastEditTime: 2019-09-27 19:58:41 8 | ''' 9 | 10 | import glob 11 | import os 12 | 13 | import subprocess 14 | import time 15 | import platform 16 | 17 | import torch 18 | 19 | from setuptools import find_packages 20 | from setuptools import setup 21 | 22 | from torch.utils.cpp_extension import CUDA_HOME 23 | from torch.utils.cpp_extension import CppExtension 24 | from torch.utils.cpp_extension import CUDAExtension 25 | from setuptools import Extension, find_packages, setup 26 | 27 | from Cython.Build import cythonize 28 | 29 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 30 | 31 | requirements = ['torch'] 32 | 33 | def readme(): 34 | with open('README.md', encoding='utf-8') as f: 35 | content = f.read() 36 | return content 37 | 38 | version_file = 'contextlab/version.py' 39 | 40 | if torch.cuda.is_available(): 41 | if 'LD_LIBRARY_PATH' not in os.environ: 42 | raise Exception('LD_LIBRARY_PATH is not set.') 43 | cuda_lib_path = os.environ['LD_LIBRARY_PATH'].split(':') 44 | else: 45 | raise Exception('This implementation is only avaliable for CUDA devices.') 46 | 47 | 48 | MAJOR = 0 49 | MINOR = 2 50 | PATCH = 0 51 | SUFFIX = '' 52 | SHORT_VERSION = '{}.{}.{}{}'.format(MAJOR, MINOR, PATCH, SUFFIX) 53 | 54 | 55 | def get_git_hash(): 56 | def _minimal_ext_cmd(cmd): 57 | # construct minimal environment 58 | env = {} 59 | for k in ['SYSTEMROOT', 'PATH', 'HOME']: 60 | v = os.environ.get(k) 61 | if v is not None: 62 | env[k] = v 63 | 64 | # LANGUAGE is used on win 32 65 | env['LANGUAGE'] = 'C' 66 | env['LANG'] = 'C' 67 | env['LC_ALL'] = 'C' 68 | out = subprocess.Popen( 69 | cmd, stdout=subprocess.PIPE, env=env).communicate()[0] 70 | 71 | return out 72 | try: 73 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) 74 | sha = out.strip().decode('ascii') 75 | except OSError: 76 | sha = 'unknown' 77 | 78 | return sha 79 | 80 | def get_hash(): 81 | if os.path.exists('.git'): 82 | sha = get_git_hash()[:7] 83 | elif os.path.exists(version_file): 84 | try: 85 | from pluscv.version import __version__ 86 | sha = __version__.split('+')[-1] 87 | except ImportError: 88 | raise ImportError('Unable to get git version') 89 | else: 90 | sha = 'unknown' 91 | 92 | 
return sha 93 | 94 | def write_version_py(): 95 | content = """# Generated Version File 96 | # Time: {} 97 | 98 | __version__ = '{}' 99 | short_version = '{}' 100 | """ 101 | sha = get_hash() 102 | VERSION = SHORT_VERSION + '+' + sha 103 | 104 | with open(version_file, 'w') as f: 105 | f.write(content.format(time.asctime(), VERSION, SHORT_VERSION)) 106 | 107 | def get_version(): 108 | with open(version_file, 'r') as f: 109 | exec(compile(f.read(), version_file, 'exec')) 110 | 111 | return locals()['__version__'] 112 | 113 | 114 | def make_cuda_ext(name, module, sources, include_dirs=[]): 115 | 116 | 117 | return CUDAExtension( 118 | name='{}.{}'.format(module, name), 119 | sources=[os.path.join(*module.split('.'), p) for p in sources], 120 | include_dirs=include_dirs, 121 | library_dirs=cuda_lib_path, 122 | extra_compile_args={ 123 | 'cxx': ['-O3'], 124 | 'nvcc': [ 125 | '-O3', 126 | # '-D__CUDA_NO_HALF_OPERATORS__', 127 | # '-D__CUDA_NO_HALF_CONVERSIONS__', 128 | # '-D__CUDA_NO_HALF2_OPERATORS__', 129 | ] 130 | }) 131 | 132 | 133 | def tree_filter_files(): 134 | 135 | extensions_dir = 'contextlab/layers/tree_filter/src' 136 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 137 | source_cpu = glob.glob(os.path.join(extensions_dir, "*", "*.cpp")) 138 | source_cuda = glob.glob(os.path.join(extensions_dir, "*", "*.cu")) 139 | 140 | sources = source_cpu + source_cuda + main_file 141 | 142 | return extensions_dir, sources 143 | 144 | if __name__ == "__main__": 145 | 146 | write_version_py() 147 | 148 | tree_extensions_dir, tree_sources = tree_filter_files() 149 | 150 | setup( 151 | name='contextlab', 152 | version=get_version(), 153 | author="Songyang Zhang", 154 | url="https://github.com/SHTUPLUS/contextlab", 155 | long_description=readme(), 156 | description="Context Feature Augmentation Lab developed with PyTorch from ShanghaiTech PLUS Lab", 157 | packages=find_packages(exclude=("src",)), 158 | license='Apache License 2.0', 159 | install_requires=requirements, 160 | ext_modules=[ 161 | # make_cuda_ext( 162 | # name='tree_filter_cuda', 163 | # module='contextlab.layers.tree_filter', 164 | # include_dirs=[tree_extensions_dir], 165 | # sources=tree_sources), 166 | CUDAExtension( 167 | name='contextlab.layers.tree_filter.functions.tree_filter_cuda', 168 | # module='contextlab.layers.tree_filter', 169 | include_dirs=[tree_extensions_dir], 170 | sources=tree_sources, 171 | library_dirs=cuda_lib_path, 172 | extra_compile_args={'cxx':['-O3'], 173 | 'nvcc':['-O3']}), 174 | CUDAExtension( 175 | name='contextlab.layers.cc_attention.rcca', 176 | sources=['contextlab/layers/cc_attention/src/lib_cffi.cpp', 177 | 'contextlab/layers/cc_attention/src/ca.cu'], 178 | extra_compile_args= ['-std=c++11'], 179 | extra_cflags=["-O3"], 180 | extra_cuda_cflags=["--expt-extended-lambda"], 181 | ) 182 | ], 183 | cmdclass={ 184 | 'build_ext': BuildExtension 185 | }, 186 | zip_safe=False 187 | ) -------------------------------------------------------------------------------- /src/images/contextlab_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHTUPLUS/ContextLab/4e12f0af9d0640f29c763b915f02de763b577200/src/images/contextlab_logo.png -------------------------------------------------------------------------------- /test/utils/layer_misc.py: -------------------------------------------------------------------------------- 1 | # from contextlab.layers.long_tail import CategoryAttentionNetwork 2 | from contextlab.utils import 
GraphAdjNetwork 3 | 4 | import torch 5 | 6 | if __name__ == "__main__": 7 | inputs = torch.randn(2, 1024, 8).permute(0, 2, 1) 8 | network = GraphAdjNetwork( 9 | pair_function='embedded_gaussian', 10 | in_channels=1024, 11 | channel_stride=8, 12 | ) 13 | output = network(inputs) 14 | print('Embedding Gaussian run pass') 15 | 16 | network = GraphAdjNetwork( 17 | pair_function='gaussian', 18 | in_channels=1024, 19 | channel_stride=8, 20 | ) 21 | output = network(inputs) 22 | 23 | print('Gaussian run pass') 24 | 25 | network = GraphAdjNetwork( 26 | pair_function='diff_learnable', 27 | in_channels=1024, 28 | channel_stride=8, 29 | ) 30 | output = network(inputs) 31 | 32 | print('Diff learnable run pass') 33 | 34 | network = GraphAdjNetwork( 35 | pair_function='sum_learnable', 36 | in_channels=1024, 37 | channel_stride=8, 38 | ) 39 | output = network(inputs) 40 | print('Sum learnable run pass') 41 | 42 | network = GraphAdjNetwork( 43 | pair_function='cat_learnable', 44 | in_channels=1024, 45 | channel_stride=8, 46 | ) 47 | output = network(inputs) 48 | print('Cat learnable run pass') 49 | --------------------------------------------------------------------------------
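
A minimal smoke test for the tree filter layer, in the spirit of the GraphAdjNetwork checks above, could look like the sketch below. The import path and the MinimumSpanningTree constructor are assumptions here: the guide-tree class definition is not part of this excerpt, and the compiled tree_filter_cuda extension (CUDA only) must be built first via the setup.py files shown earlier.

```python
# Hypothetical usage sketch; class and module names are assumed, not verified.
import torch
from contextlab.layers.tree_filter.modules.tree_filter import (
    MinimumSpanningTree, TreeFilter2D)

if __name__ == "__main__":
    feature = torch.randn(2, 64, 32, 32).cuda()  # features to be filtered (B, C, H, W)
    guide = torch.randn(2, 64, 32, 32).cuda()    # low-level guidance for the spanning tree
    embed = torch.randn(2, 64, 32, 32).cuda()    # embedding that defines the edge weights

    mst_layer = MinimumSpanningTree(TreeFilter2D.norm2_distance)
    tree_filter = TreeFilter2D(groups=1)

    tree = mst_layer(guide)                 # (B, H*W - 1, 2) spanning-tree edges
    out = tree_filter(feature, embed, tree) # filtered features, same shape as `feature`
    print('Tree filter run pass', out.shape)
```
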