├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── contextlab
│   ├── __init__.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── cc_attention
│   │   │   ├── __init__.py
│   │   │   ├── functions.py
│   │   │   ├── setup.py
│   │   │   └── src
│   │   │       ├── ca.cu
│   │   │       ├── ca.h
│   │   │       ├── common.h
│   │   │       ├── lib_cffi.cpp
│   │   │       └── lib_cffi.h
│   │   ├── dual_attention
│   │   │   ├── __init__.py
│   │   │   └── dual_attention.py
│   │   ├── em_attention
│   │   │   ├── __init__.py
│   │   │   └── emu.py
│   │   ├── gcnet
│   │   │   ├── __init__.py
│   │   │   └── gcnet.py
│   │   ├── latentgnn
│   │   │   ├── __init__.py
│   │   │   └── latentgnn.py
│   │   ├── non_local
│   │   │   ├── __init__.py
│   │   │   └── non_local.py
│   │   └── tree_filter
│   │       ├── __init__.py
│   │       ├── functions
│   │       │   ├── bfs.py
│   │       │   ├── mst.py
│   │       │   └── refine.py
│   │       ├── modules
│   │       │   └── tree_filter.py
│   │       ├── setup.py
│   │       └── src
│   │           ├── bfs
│   │           │   ├── bfs.cu
│   │           │   └── bfs.hpp
│   │           ├── mst
│   │           │   ├── boruvka.cpp
│   │           │   ├── boruvka.hpp
│   │           │   ├── mst.cu
│   │           │   └── mst.hpp
│   │           ├── refine
│   │           │   ├── refine.cu
│   │           │   └── refine.hpp
│   │           └── tree_filter.cpp
│   └── utils
│       ├── __init__.py
│       ├── layer_misc.py
│       └── weight_init.py
├── setup.py
├── src
│   └── images
│       └── contextlab_logo.png
└── test
    └── utils
        └── layer_misc.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | version.py
7 | # C extensions
8 | *.so
9 | # User added
10 | *.DS_Store
11 | .vscode/
12 | github/
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # pyenv
81 | .python-version
82 |
83 | # celery beat schedule file
84 | celerybeat-schedule
85 |
86 | # SageMath parsed files
87 | *.sage.py
88 |
89 | # Environments
90 | .env
91 | .venv
92 | env/
93 | venv/
94 | ENV/
95 | env.bak/
96 | venv.bak/
97 |
98 | # Spyder project settings
99 | .spyderproject
100 | .spyproject
101 |
102 | # Rope project settings
103 | .ropeproject
104 |
105 | # mkdocs documentation
106 | /site
107 |
108 | # mypy
109 | .mypy_cache/
110 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHTUPLUS/ContextLab/4e12f0af9d0640f29c763b915f02de763b577200/CONTRIBUTING.md
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 PLUS Lab, ShanghaiTech University
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ContextLab
2 | ContextLab: A Toolbox for Context Feature Augmentation developed with PyTorch
3 |
4 |
5 |
6 |
7 | ## Introduction
8 |
9 |
10 | The master branch works with **PyTorch 1.1** or higher
11 |
12 | ContextLab is an open source context feature augmentation toolbox based on PyTorch. It is a part of the Open-PLUS project developed by [ShanghaiTech PLUS Lab](http://plus.sist.shanghaitech.edu.cn)
13 |
14 |
15 |
16 | ## Major Features
17 | - **Modular Design**
18 |
19 | - **High Efficiency**
20 |
21 | - **State-of-the-art Performance**
22 |
23 | We have implemented several context feature augmentation algorithms in PyTorch, with performance comparable to the original implementations.
24 |
25 | ## License
26 | This project is released under the [MIT License](LICENSE)
27 |
28 | ## Updates
29 |
30 | v0.2.0 (27/09/2019)
31 | - Support for **CCNet**, **TreeFilter** and **EMANet**
32 |
33 | v0.1.0 (26/07/2019)
34 | - Start the project
35 |
36 |
37 | ## Benchmark and Model Zoo
38 |
39 |
40 |
41 | | Method | Block-wise | Stage-wise | Paper |
42 | |--------------------|:------------:|:-----------:|:--------:|
43 | | Non-local Network | ✗ | ✓ | [CVPR 18](https://arxiv.org/abs/1711.07971)
44 | | Dual-attention | ✗ | ✓ | [CVPR 19](https://arxiv.org/abs/1809.02983)
45 | | GCNet | ✗ | ✓ | [Arxiv ](https://arxiv.org/abs/1904.11492)
46 | | CCNet | ✓ | ✓ | [ICCV 19](https://arxiv.org/abs/1811.11721)
47 | | LatentGNN | ✗ | ✓ | [ICML 19](https://arxiv.org/abs/1905.11634)
48 | | TreeFilter | ✗ | ✓ | [NIPS 19]()
49 | | EMANet | ✗ | ✓ | [ICCV 19](https://arxiv.org/abs/1907.13426)
50 |
51 | ## Installation
52 |
53 | ```
54 | git clone https://github.com/SHTUPLUS/contextlab.git
55 | cd contextlab/
56 | python setup.py build develop
57 | ```
58 |
59 | ## Examples
60 | ```python
61 | # GCNet
62 | from contextlab.layers import GlobalContextBlock2d
63 | # Dual-Attention
64 | from contextlab.layers import SelfAttention
65 | # LatentGNN
66 | from contextlab.layers import LatentGNN
67 | # TreeFilter
68 | from contextlab.layers import MinimumSpanningTree, TreeFilter2D
69 | # CCNet
70 | from contextlab.layers import CrissCrossAttention
72 | # EMAttention
72 | from contextlab.layers import EMAttentionUnit
73 | ```
74 |
75 | ## To do
76 | - [ ] Experiments on Segmentation and Detection
77 | - [ ] Performance Comparison
78 |
79 | ## Contributing
80 |
81 | We appreciate all contributions to improve ContextLab. Please refer to [CONTRIBUTING.md](CONTRIBUTING.md) for the contributing guideline.
82 |
83 | ## Acknowledgement
84 | ContextLab is an open source project contributed to by researchers and engineers from various colleges and companies. We appreciate all the contributors who implement their methods or add new features.
85 |
86 | We hope that the toolbox and benchmark can serve the growing research community by providing a flexible toolkit to reimplement existing methods and develop their own new methods.
87 |
88 | ## Citation
89 | ```
90 | @misc{contextlab,
91 | title = {{ContextLab}: A Toolbox for Context Feature Augmentation},
92 | author = {Songyang Zhang},
93 | year={2019}
94 | }
95 | ```
96 | ## Contact
97 | ```
98 | email: sy.zhangbuaa@gmail.com
99 | ```
100 |
101 | ## Related Projects
102 | - [Non-local Network(CVPR 18)](https://arxiv.org/abs/1711.07971)
103 | - [Video Classification(Caffe)](https://github.com/facebookresearch/video-nonlocal-net)
104 | - [Dual-attention (CVPR 19)](https://arxiv.org/abs/1809.02983)
105 | - [Semantic Segmentation(PyTorch)](https://github.com/junfu1115/DANet)
106 | - [GCNet (Arxiv)](https://arxiv.org/abs/1904.11492)
107 | - [Object Detection(PyTorch)](https://github.com/xvjiarui/GCNet)
108 | - [CCNet (ICCV 19)](https://arxiv.org/abs/1811.11721)
109 | - [Semantic Segmentation(PyTorch)](https://github.com/speedinghzl/CCNet)
110 | - [LatentGNN (ICML 19)](https://arxiv.org/abs/1905.11634)
111 | - [Object Detection(PyTorch)](https://github.com/latentgnn/LatentGNN-V1-PyTorch)
112 | - [TreeFilter (NIPS 19)]()
113 | - [Semantic Segmentation(PyTorch)](https://github.com/StevenGrove/TreeFilter-Torch)
114 | - [EMANet (ICCV 19)](https://arxiv.org/abs/1907.13426)
115 | - [Semantic Segmentation(PyTorch)](https://github.com/XiaLiPKU/EMANet)
116 |
--------------------------------------------------------------------------------
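A minimal usage sketch for the imports listed in the README's Examples section. It assumes ContextLab has been installed with `python setup.py build develop` and that its CUDA extensions (`rcca`, `tree_filter_cuda`) have been built, since `contextlab.layers` imports them at package import time; the constructor arguments below are illustrative values, not prescribed defaults.

```python
import torch
from contextlab.layers import SelfAttention, GlobalContextBlock2d, LatentGNN

x = torch.randn(2, 256, 32, 32)  # (B, C, H, W) feature map

# Dual-attention style self-attention (dual_attention.py)
sa = SelfAttention(inplane=256, outplane=256, channel_stride=8)
out = sa(x)                      # (2, 256, 32, 32)

# Global context block (gcnet.py); `stride` is the channel reduction ratio
gc = GlobalContextBlock2d(inplanes=256, stride=4, pool='att', fusions=['channel_add'])
out = gc(x)                      # (2, 256, 32, 32)

# LatentGNN with a single latent kernel (latentgnn.py)
gnn = LatentGNN(in_channels=256, latent_dims=[64], channel_stride=4, num_kernels=1)
out = gnn(x)                     # (2, 256, 32, 32)
```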
/contextlab/__init__.py:
--------------------------------------------------------------------------------
1 | from .version import __version__, short_version
2 |
3 |
4 | # from .layers.gcnet import GlobalContextBlock2d
5 | # from .layers.dual_attention import SelfAttention
6 | # from .layers.latentgnn import LatentGNN
7 |
8 | # from .layers.tree_filter import
9 |
--------------------------------------------------------------------------------
/contextlab/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .gcnet import GlobalContextBlock2d
2 | from .dual_attention import SelfAttention
3 | from .latentgnn import LatentGNN
4 | from .tree_filter import MinimumSpanningTree, TreeFilter2D
5 | from .cc_attention import CrissCrossAttention
6 | from .em_attention import EMAttentionUnit
--------------------------------------------------------------------------------
/contextlab/layers/cc_attention/__init__.py:
--------------------------------------------------------------------------------
1 | from .functions import CrissCrossAttention, ca_weight, ca_map
--------------------------------------------------------------------------------
/contextlab/layers/cc_attention/functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import torch.autograd as autograd
5 | import torch.cuda.comm as comm
6 | import torch.nn.functional as F
7 | from torch.autograd.function import once_differentiable
8 | from torch.utils.cpp_extension import load
9 | import os, time
10 | import functools
11 |
12 | # curr_dir = os.path.dirname(os.path.abspath(__file__))
13 | # _src_path = os.path.join(curr_dir, "src")
14 | # _build_path = os.path.join(curr_dir, "build")
15 | # os.makedirs(_build_path, exist_ok=True)
16 | # rcca = load(name="rcca",
17 | # extra_cflags=["-O3"],
18 | # build_directory=_build_path,
19 | # verbose=True,
20 | # sources = [os.path.join(_src_path, f) for f in [
21 | # "lib_cffi.cpp", "ca.cu"
22 | # ]],
23 | # extra_cuda_cflags=["--expt-extended-lambda"])
24 | # import rcca
25 | from . import rcca
26 |
27 | def _check_contiguous(*args):
28 | if not all([mod is None or mod.is_contiguous() for mod in args]):
29 | raise ValueError("Non-contiguous input")
30 |
31 |
32 | class CA_Weight(autograd.Function):
33 | @staticmethod
34 | def forward(ctx, t, f):
35 | # Save context
36 | n, c, h, w = t.size()
37 | size = (n, h+w-1, h, w)
38 | weight = torch.zeros(size, dtype=t.dtype, layout=t.layout, device=t.device)
39 |
40 | rcca.ca_forward_cuda(t, f, weight)
41 |
42 | # Output
43 | ctx.save_for_backward(t, f)
44 |
45 | return weight
46 |
47 | @staticmethod
48 | @once_differentiable
49 | def backward(ctx, dw):
50 | t, f = ctx.saved_tensors
51 |
52 | dt = torch.zeros_like(t)
53 | df = torch.zeros_like(f)
54 |
55 | rcca.ca_backward_cuda(dw.contiguous(), t, f, dt, df)
56 |
57 | _check_contiguous(dt, df)
58 |
59 | return dt, df
60 |
61 | class CA_Map(autograd.Function):
62 | @staticmethod
63 | def forward(ctx, weight, g):
64 | # Save context
65 | out = torch.zeros_like(g)
66 | rcca.ca_map_forward_cuda(weight, g, out)
67 |
68 | # Output
69 | ctx.save_for_backward(weight, g)
70 |
71 | return out
72 |
73 | @staticmethod
74 | @once_differentiable
75 | def backward(ctx, dout):
76 | weight, g = ctx.saved_tensors
77 |
78 | dw = torch.zeros_like(weight)
79 | dg = torch.zeros_like(g)
80 |
81 | rcca.ca_map_backward_cuda(dout.contiguous(), weight, g, dw, dg)
82 |
83 | _check_contiguous(dw, dg)
84 |
85 | return dw, dg
86 |
87 | ca_weight = CA_Weight.apply
88 | ca_map = CA_Map.apply
89 |
90 |
91 | class CrissCrossAttention(nn.Module):
92 | """ Criss-Cross Attention Module"""
93 | def __init__(self,in_dim):
94 | super(CrissCrossAttention,self).__init__()
95 |         self.channel_in = in_dim
96 |
97 | self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
98 | self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
99 | self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
100 | self.gamma = nn.Parameter(torch.zeros(1))
101 |
102 | def forward(self, x):
103 | proj_query = self.query_conv(x)
104 | proj_key = self.key_conv(x)
105 | proj_value = self.value_conv(x)
106 |
107 | energy = ca_weight(proj_query, proj_key)
108 | attention = F.softmax(energy, 1)
109 | out = ca_map(attention, proj_value)
110 | out = self.gamma*out + x
111 | return out
112 |
113 |
114 |
115 | __all__ = ["CrissCrossAttention", "ca_weight", "ca_map"]
116 |
117 |
118 | if __name__ == "__main__":
119 |     # Smoke test: CrissCrossAttention takes a single feature map and
120 |     # returns an augmented map of the same shape.
121 |     ca = CrissCrossAttention(64).cuda()
122 |     x = torch.zeros(1, 64, 10, 10).cuda() + 1
123 |     out = ca(x)
124 |     print(out)
125 | 
--------------------------------------------------------------------------------
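For reference, a shape-level sketch of how the pieces above compose; it assumes the `rcca` extension has been built (see the `setup.py` below) and that a CUDA device is available, since the custom ops only have GPU kernels. The `H+W-1` axis of the energy tensor indexes the positions in the same row and column as each location.

```python
import torch
from contextlab.layers.cc_attention.functions import CrissCrossAttention, ca_weight, ca_map

x = torch.randn(1, 64, 16, 16, device='cuda')
cca = CrissCrossAttention(in_dim=64).cuda()
out = cca(x)                                    # (1, 64, 16, 16), gamma-weighted residual of x

# The two custom ops used inside the module:
q = torch.randn(1, 8, 16, 16, device='cuda')    # reduced-channel query
k = torch.randn(1, 8, 16, 16, device='cuda')    # reduced-channel key
v = torch.randn(1, 64, 16, 16, device='cuda')   # value
energy = ca_weight(q, k)                        # (1, H+W-1, 16, 16)
attn = torch.softmax(energy, dim=1)
ctx = ca_map(attn, v)                           # (1, 64, 16, 16)
```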
/contextlab/layers/cc_attention/setup.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHTUPLUS/ContextLab/4e12f0af9d0640f29c763b915f02de763b577200/contextlab/layers/cc_attention/setup.py
--------------------------------------------------------------------------------
/contextlab/layers/cc_attention/src/ca.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 |
5 | #include "common.h"
6 | #include "ca.h"
7 |
8 |
9 | __global__ void ca_forward_kernel(const float *t, const float *f, float *weight, int num, int chn, int height, int width) {
10 | int x = blockIdx.x * blockDim.x + threadIdx.x;
11 | int y = blockIdx.y * blockDim.y + threadIdx.y;
12 | int sp = height * width;
13 | int len = height + width - 1;
14 | int z = blockIdx.z;
15 |
16 | if (x < width && y < height && z < height+width-1) {
17 | for (int batch = 0; batch < num; ++batch) {
18 | for (int plane = 0; plane < chn; ++plane) {
19 | float _t = t[(batch * chn + plane) * sp + y*width + x];
20 |
21 | if (z < width) {
22 | int i = z;
23 | float _f = f[(batch * chn + plane) * sp + y*width + i];
24 | weight[(batch * len + i) * sp + y*width + x] += _t*_f;
25 | } else {
26 | int i = z - width;
27 | int j = iy ? y : y-1;
86 |
87 | float _dw = dw[(batch * len + width + j) * sp + i*width + x];
88 | float _t = t[(batch * chn + plane) * sp + i*width + x];
89 | df[(batch * chn + plane) * sp + y*width + x] += _dw * _t;
90 | }
91 | }
92 |
93 | }
94 | }
95 |
96 |
97 | __global__ void ca_map_forward_kernel(const float *weight, const float *g, float *out, int num, int chn, int height, int width) {
98 | int x = blockIdx.x * blockDim.x + threadIdx.x;
99 | int y = blockIdx.y * blockDim.y + threadIdx.y;
100 | int sp = height * width;
101 | int len = height + width - 1;
102 | int plane = blockIdx.z;
103 |
104 | if (x < width && y < height && plane < chn) {
105 | for (int batch = 0; batch < num; ++batch) {
106 |
107 | for (int i = 0; i < width; ++i) {
108 | float _g = g[(batch * chn + plane) * sp + y*width + i];
109 | float _w = weight[(batch * len + i) * sp + y*width + x];
110 | out[(batch * chn + plane) * sp + y*width + x] += _g * _w;
111 | }
112 | for (int i = 0; i < height; ++i) {
113 | if (i == y) continue;
114 |
115 | int j = iy ? y : y-1;
176 |
177 | float _dout = dout[(batch * chn + plane) * sp + i*width + x];
178 | float _w = weight[(batch * len + width + j) * sp + i*width + x];
179 | dg[(batch * chn + plane) * sp + y*width + x] += _dout * _w;
180 | }
181 | }
182 | }
183 | }
184 |
185 | /*
186 | * Implementations
187 | */
188 | extern "C" int _ca_forward_cuda(int N, int C, int H, int W, const float *t,
189 | const float *f, float *weight, cudaStream_t stream) {
190 | // Run kernel
191 | dim3 threads(32, 32);
192 | int d1 = (W+threads.x-1)/threads.x;
193 | int d2 = (H+threads.y-1)/threads.y;
194 | int d3 = H+W;
195 | dim3 blocks(d1, d2, d3);
196 |   ca_forward_kernel<<<blocks, threads, 0, stream>>>(t, f, weight, N, C, H, W);
197 |
198 | // Check for errors
199 | cudaError_t err = cudaGetLastError();
200 | if (err != cudaSuccess)
201 | return 0;
202 | else
203 | return 1;
204 | }
205 |
206 |
207 | extern "C" int _ca_backward_cuda(int N, int C, int H, int W, const float *dw, const float *t, const float *f, float *dt, float *df, cudaStream_t stream) {
208 | // Run kernel
209 | dim3 threads(32, 32);
210 | int d1 = (W+threads.x-1)/threads.x;
211 | int d2 = (H+threads.y-1)/threads.y;
212 | int d3 = C;
213 | dim3 blocks(d1, d2, d3);
214 | // printf("%f\n", dw[0]);
215 |   ca_backward_kernel_t<<<blocks, threads, 0, stream>>>(dw, t, f, dt, N, C, H, W);
216 |   ca_backward_kernel_f<<<blocks, threads, 0, stream>>>(dw, t, f, df, N, C, H, W);
217 |
218 | // Check for errors
219 | cudaError_t err = cudaGetLastError();
220 | if (err != cudaSuccess)
221 | return 0;
222 | else
223 | return 1;
224 | }
225 |
226 |
227 | extern "C" int _ca_map_forward_cuda(int N, int C, int H, int W, const float *weight, const float *g, float *out, cudaStream_t stream) {
228 | // Run kernel
229 | dim3 threads(32, 32);
230 | dim3 blocks((W+threads.x-1)/threads.x, (H+threads.y-1)/threads.y, C);
231 |   ca_map_forward_kernel<<<blocks, threads, 0, stream>>>(weight, g, out, N, C, H, W);
232 |
233 | // Check for errors
234 | cudaError_t err = cudaGetLastError();
235 | if (err != cudaSuccess)
236 | return 0;
237 | else
238 | return 1;
239 | }
240 |
241 | extern "C" int _ca_map_backward_cuda(int N, int C, int H, int W, const float *dout, const float *weight, const float *g, float *dw, float *dg, cudaStream_t stream) {
242 | // Run kernel
243 | dim3 threads(32, 32);
244 | int d1 = (W+threads.x-1)/threads.x;
245 | int d2 = (H+threads.y-1)/threads.y;
246 | int d3 = H+W;
247 | dim3 blocks(d1, d2, d3);
248 |   ca_map_backward_kernel_w<<<blocks, threads, 0, stream>>>(dout, weight, g, dw, N, C, H, W);
249 |
250 | d3 = C;
251 | blocks = dim3(d1, d2, d3);
252 |   ca_map_backward_kernel_g<<<blocks, threads, 0, stream>>>(dout, weight, g, dg, N, C, H, W);
253 |
254 | // Check for errors
255 | cudaError_t err = cudaGetLastError();
256 | if (err != cudaSuccess)
257 | return 0;
258 | else
259 | return 1;
260 | }
261 |
--------------------------------------------------------------------------------
/contextlab/layers/cc_attention/src/ca.h:
--------------------------------------------------------------------------------
1 | #ifndef __CA__
2 | #define __CA__
3 |
4 | /*
5 | * Exported functions
6 | */
7 | extern "C" int _ca_forward_cuda(int N, int C, int H, int W, const float *t, const float *f, float *weight, cudaStream_t stream);
8 | extern "C" int _ca_backward_cuda(int N, int C, int H, int W, const float *dw, const float *t, const float *f, float *dt, float *df, cudaStream_t stream);
9 | extern "C" int _ca_map_forward_cuda(int N, int C, int H, int W, const float *weight, const float *g, float *out, cudaStream_t stream);
10 | extern "C" int _ca_map_backward_cuda(int N, int C, int H, int W, const float *dout, const float *weight, const float *g, float *dw, float *dg, cudaStream_t stream);
11 |
12 | #endif
13 |
--------------------------------------------------------------------------------
/contextlab/layers/cc_attention/src/common.h:
--------------------------------------------------------------------------------
1 | #ifndef __COMMON__
2 | #define __COMMON__
3 | #include
4 |
5 | /*
6 | * General settings
7 | */
8 | const int WARP_SIZE = 32;
9 | const int MAX_BLOCK_SIZE = 512;
10 |
11 | /*
12 | * Utility functions
13 | */
14 | template <typename T>
15 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize,
16 | unsigned int mask = 0xffffffff) {
17 | #if CUDART_VERSION >= 9000
18 | return __shfl_xor_sync(mask, value, laneMask, width);
19 | #else
20 | return __shfl_xor(value, laneMask, width);
21 | #endif
22 | }
23 |
24 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }
25 |
26 | static int getNumThreads(int nElem) {
27 | int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
28 | for (int i = 0; i != 5; ++i) {
29 | if (nElem <= threadSizes[i]) {
30 | return threadSizes[i];
31 | }
32 | }
33 | return MAX_BLOCK_SIZE;
34 | }
35 |
36 |
37 | #endif
--------------------------------------------------------------------------------
/contextlab/layers/cc_attention/src/lib_cffi.cpp:
--------------------------------------------------------------------------------
1 | // All functions assume that input and output tensors are already initialized
2 | // and have the correct dimensions
3 | #include <torch/extension.h>
4 | #include <THC/THC.h>
5 | #include "ca.h"
6 |
7 | extern THCState *state;
8 |
9 | int ca_forward_cuda(const at::Tensor& t, const at::Tensor& f, at::Tensor& weight) {
10 | cudaStream_t stream = THCState_getCurrentStream(state);
11 | int N, C, H, W;
12 | N = t.size(0); C = t.size(1); H = t.size(2); W = t.size(3);
13 |     float * t_data = t.data<float>();
14 |     float * f_data = f.data<float>();
15 |     float * weight_data = weight.data<float>();
16 | return _ca_forward_cuda(N, C, H, W, t_data, f_data, weight_data, stream);
17 | }
18 |
19 | int ca_backward_cuda(const at::Tensor& dw, const at::Tensor& t, const at::Tensor& f, at::Tensor& dt, at::Tensor& df) {
20 |
21 | cudaStream_t stream = THCState_getCurrentStream(state);
22 | int N, C, H, W;
23 | N = t.size(0); C = t.size(1); H = t.size(2); W = t.size(3);
24 |     float * t_data = t.data<float>();
25 |     float * f_data = f.data<float>();
26 |     float * dt_data = dt.data<float>();
27 |     float * df_data = df.data<float>();
28 |     float * dw_data = dw.data<float>();
29 | return _ca_backward_cuda(N, C, H, W, dw_data, t_data, f_data, dt_data, df_data, stream);
30 | }
31 |
32 | int ca_map_forward_cuda(const at::Tensor& weight, const at::Tensor& g, at::Tensor& out) {
33 | cudaStream_t stream = THCState_getCurrentStream(state);
34 |
35 | int N, C, H, W;
36 | N = g.size(0); C = g.size(1); H = g.size(2); W = g.size(3);
37 |
38 |     const float *weight_data = weight.data<float>();
39 |     const float *g_data = g.data<float>();
40 |     float *out_data = out.data<float>();
41 |
42 | return _ca_map_forward_cuda(N, C, H, W, weight_data, g_data, out_data, stream);
43 | }
44 |
45 | int ca_map_backward_cuda(const at::Tensor& dout, const at::Tensor& weight, const at::Tensor& g,
46 | at::Tensor& dw, at::Tensor& dg) {
47 | cudaStream_t stream = THCState_getCurrentStream(state);
48 |
49 | int N, C, H, W;
50 | N = dout.size(0); C = dout.size(1); H = dout.size(2); W = dout.size(3);
51 |
52 |     const float *dout_data = dout.data<float>();
53 |     const float *weight_data = weight.data<float>();
54 |     const float *g_data = g.data<float>();
55 |     float *dw_data = dw.data<float>();
56 |     float *dg_data = dg.data<float>();
57 |
58 | return _ca_map_backward_cuda(N, C, H, W, dout_data, weight_data, g_data, dw_data, dg_data, stream);
59 | }
60 |
61 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
62 | m.def("ca_forward_cuda", &ca_forward_cuda, "CA forward CUDA");
63 | m.def("ca_backward_cuda", &ca_backward_cuda, "CA backward CUDA");
64 | m.def("ca_map_forward_cuda", &ca_map_forward_cuda, "CA map forward CUDA");
65 | m.def("ca_map_backward_cuda", &ca_map_backward_cuda, "CA map backward CUDA");
66 | }
67 |
68 |
--------------------------------------------------------------------------------
/contextlab/layers/cc_attention/src/lib_cffi.h:
--------------------------------------------------------------------------------
1 | int ca_forward_cuda(const at::Tensor& t, const at::Tensor& f, at::Tensor& weight);
2 | int ca_backward_cuda(const at::Tensor& dw, const at::Tensor& t, const at::Tensor& f, at::Tensor& dt, at::Tensor& df);
3 |
4 | int ca_map_forward_cuda(const at::Tensor& weight, const at::Tensor& g, at::Tensor& out);
5 | int ca_map_backward_cuda(const at::Tensor& dout, const at::Tensor& weight, const at::Tensor& g, at::Tensor& dw, at::Tensor& dg);
6 |
--------------------------------------------------------------------------------
/contextlab/layers/dual_attention/__init__.py:
--------------------------------------------------------------------------------
1 | from .dual_attention import SelfAttention
--------------------------------------------------------------------------------
/contextlab/layers/dual_attention/dual_attention.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Description: Dual Attention Network
3 | @Author: Songyang Zhang
4 | @Email: sy.zhangbuaa@gmail.com
5 | @Date: 2019-07-13 17:10:33
6 | @LastEditTime: 2019-08-15 10:22:03
7 | @LastEditors: Songyang Zhang
8 | '''
9 |
10 | import numpy as np
11 |
12 | import torch
13 | import math
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 |
17 |
18 | class SelfAttention(nn.Module):
19 | """
20 | Self-attention Network/Non-local Network
21 |
22 |     Args:
23 |         inplane (int): Number of input channels. outplane (int): Number of output channels. channel_stride (int): Channel reduction factor for the query/key projections. Default: 8.
24 |     Return:
25 |         augmented_feature (Tensor): Context-augmented feature of shape (B, outplane, H, W).
26 |     """
27 | def __init__(self, inplane, outplane, channel_stride=8):
28 | super(SelfAttention, self).__init__()
29 |
30 | self.inplane = inplane
31 | self.outplane = outplane
32 |
33 | self.inter_channel = inplane // channel_stride
34 |
35 | self.query_conv = nn.Conv2d(in_channels=inplane, out_channels=self.inter_channel, kernel_size=1)
36 | self.key_conv = nn.Conv2d(in_channels=inplane, out_channels=self.inter_channel, kernel_size=1)
37 |
38 | self.value_conv = nn.Conv2d(in_channels=inplane, out_channels=outplane, kernel_size=1)
39 | if outplane != inplane:
40 | self.input_conv = nn.Conv2d(in_channels=inplane, out_channels=outplane, kernel_size=1)
41 |
42 | self.gamma = nn.Parameter(torch.zeros(1))
43 |
44 | self.softmax = nn.Softmax(dim=-1)
45 |
46 | def forward(self, inputs):
47 | """
48 | Args:
49 | inputs: (B, C, H, W)
50 |
51 | Return:
52 | augmented_feature: (B, C, H, W)
53 | """
54 |
55 | B, C, H, W = inputs.size()
56 | query = self.query_conv(inputs).view(B, -1, H*W).permute(0, 2, 1) # B,N,C
57 | key = self.key_conv(inputs).view(B, -1, H*W) # B,C,N
58 |
59 | affinity_matrix = torch.bmm(query, key)
60 | affinity_matrix = self.softmax(affinity_matrix) # B, N, N
61 |
62 | value = self.value_conv(inputs).view(B, -1, H*W)
63 |
64 | out = torch.bmm(value, affinity_matrix) # B,C',N * B,N,N = B,C',N
65 | if self.inplane != self.outplane:
66 | inputs = self.input_conv(inputs)
67 | augmented_feature = self.gamma * out.view(B,-1, H, W) + inputs
68 |
69 | return augmented_feature
70 |
71 | # if __name__ == "__main__":
72 | # inputs = torch.randn(1,1024, 20,20)
73 | # model = SelfAttention(1024, 512,channel_stride=8)
74 |
75 | # out = model(inputs)
76 | # import pdb; pdb.set_trace()
77 |
78 |
79 |
--------------------------------------------------------------------------------
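A usage sketch based on the commented-out test at the bottom of `dual_attention.py`; the values are illustrative. With `outplane` different from `inplane`, the module projects the residual input through its extra 1x1 convolution, so the output carries `outplane` channels.

```python
import torch
from contextlab.layers.dual_attention.dual_attention import SelfAttention

inputs = torch.randn(1, 1024, 20, 20)            # (B, C, H, W)
model = SelfAttention(inplane=1024, outplane=512, channel_stride=8)

out = model(inputs)
print(out.shape)                                 # torch.Size([1, 512, 20, 20])
```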
/contextlab/layers/em_attention/__init__.py:
--------------------------------------------------------------------------------
1 | from .emu import EMAttentionUnit
--------------------------------------------------------------------------------
/contextlab/layers/em_attention/emu.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Description: Implement the Expectation-Maximization Attention Unit(EMAU)
3 | @Author: Songyang Zhang
4 | @Email: sy.zhangbuaa@gmail.com
5 | @License: (C)Copyright 2019-2020, PLUS Group@ShanghaiTech University
6 | @Date: 2019-09-28 13:54:35
7 | @LastEditors: Songyang Zhang
8 | @LastEditTime: 2019-09-28 14:30:26
9 | '''
10 | import math
11 |
12 | import torch
13 | import torch.nn as nn
14 | from torch.nn import functional as F
15 | from torch.nn.modules.batchnorm import _BatchNorm
16 |
17 | class EMAttentionUnit(nn.Module):
18 | '''The Expectation-Maximization Attention Unit (EMAU).
19 | Arguments:
20 | c (int): The input and output channel number.
21 | k (int): The number of the bases.
22 | stage_num (int): The iteration number for EM.
23 | '''
24 | def __init__(self, c, k, stage_num=3, norm_layer=nn.BatchNorm2d):
25 | super(EMAttentionUnit, self).__init__()
26 | self.stage_num = stage_num
27 |
28 | mu = torch.Tensor(1, c, k)
29 | mu.normal_(0, math.sqrt(2. / k)) # Init with Kaiming Norm.
30 | mu = self._l2norm(mu, dim=1)
31 | self.register_buffer('mu', mu)
32 |
33 | self.conv1 = nn.Conv2d(c, c, 1)
34 | self.conv2 = nn.Sequential(
35 | nn.Conv2d(c, c, 1, bias=False),
36 | norm_layer(c))
37 |
38 | for m in self.modules():
39 | if isinstance(m, nn.Conv2d):
40 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
41 | m.weight.data.normal_(0, math.sqrt(2. / n))
42 | elif isinstance(m, _BatchNorm):
43 | m.weight.data.fill_(1)
44 | if m.bias is not None:
45 | m.bias.data.zero_()
46 |
47 |
48 | def forward(self, x):
49 | idn = x
50 | # The first 1x1 conv
51 | x = self.conv1(x)
52 |
53 | # The EM Attention
54 | b, c, h, w = x.size()
55 | x = x.view(b, c, h*w) # b * c * n
56 | mu = self.mu.repeat(b, 1, 1) # b * c * k
57 | with torch.no_grad():
58 | for i in range(self.stage_num):
59 | x_t = x.permute(0, 2, 1) # b * n * c
60 | z = torch.bmm(x_t, mu) # b * n * k
61 | z = F.softmax(z, dim=2) # b * n * k
62 | z_ = z / (1e-6 + z.sum(dim=1, keepdim=True))
63 | mu = torch.bmm(x, z_) # b * c * k
64 | mu = self._l2norm(mu, dim=1)
65 |
66 | z_t = z.permute(0, 2, 1) # b * k * n
67 | x = mu.matmul(z_t) # b * c * n
68 | x = x.view(b, c, h, w) # b * c * h * w
69 | x = F.relu(x, inplace=True)
70 |
71 | # The second 1x1 conv
72 | x = self.conv2(x)
73 | x = x + idn
74 | x = F.relu(x, inplace=True)
75 |
76 | return x, mu
77 |
78 | def _l2norm(self, inp, dim):
79 |         '''Normalize the inp tensor with l2-norm.
80 | Returns a tensor where each sub-tensor of input along the given dim is
81 | normalized such that the 2-norm of the sub-tensor is equal to 1.
82 | Arguments:
83 | inp (tensor): The input tensor.
84 |             dim (int): The dimension to slice over to get the sub-tensors.
85 | Returns:
86 | (tensor) The normalized tensor.
87 | '''
88 | return inp / (1e-6 + inp.norm(dim=dim, keepdim=True))
--------------------------------------------------------------------------------
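A sketch of calling the EM attention unit above; note that `forward` returns both the reconstructed feature and the updated bases `mu`, and that the EM iterations run under `torch.no_grad()`. Channel and base counts are illustrative.

```python
import torch
from contextlab.layers.em_attention.emu import EMAttentionUnit

x = torch.randn(2, 64, 32, 32)              # (B, C, H, W)
emau = EMAttentionUnit(c=64, k=16, stage_num=3)

out, mu = emau(x)
print(out.shape)                            # torch.Size([2, 64, 32, 32])
print(mu.shape)                             # torch.Size([2, 64, 16]) -- bases for each sample
```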
/contextlab/layers/gcnet/__init__.py:
--------------------------------------------------------------------------------
1 | from .gcnet import GlobalContextBlock2d
--------------------------------------------------------------------------------
/contextlab/layers/gcnet/gcnet.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Description: GCNet file from the MSRA
3 | @Author: Songyang Zhang
4 | @Email: sy.zhangbuaa@gmail.com
5 | @Date: 2019-08-15 10:20:38
6 | @LastEditors: Songyang Zhang
7 | @LastEditTime: 2019-09-27 18:40:35
8 | '''
9 |
10 | import torch
11 | from torch import nn
12 | # from mmcv.cnn import constant_init, kaiming_init
13 | from contextlab.utils.weight_init import constant_init, kaiming_init
14 |
15 | def last_zero_init(m):
16 | if isinstance(m, nn.Sequential):
17 | constant_init(m[-1], val=0)
18 | m[-1].inited = True
19 | else:
20 | constant_init(m, val=0)
21 | m.inited = True
22 |
23 |
24 | class GlobalContextBlock2d(nn.Module):
25 |
26 | def __init__(self, inplanes, stride, pool, fusions):
27 | super(GlobalContextBlock2d, self).__init__()
28 | assert pool in ['avg', 'att']
29 | assert all([f in ['channel_add', 'channel_mul'] for f in fusions])
30 | assert len(fusions) > 0, 'at least one fusion should be used'
31 | self.inplanes = inplanes
32 | self.planes = inplanes // stride
33 | self.pool = pool
34 | self.fusions = fusions
35 | if 'att' in pool:
36 | self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
37 | self.softmax = nn.Softmax(dim=2)
38 | else:
39 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
40 | if 'channel_add' in fusions:
41 | self.channel_add_conv = nn.Sequential(
42 | nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
43 | nn.LayerNorm([self.planes, 1, 1]),
44 | nn.ReLU(inplace=True),
45 | nn.Conv2d(self.planes, self.inplanes, kernel_size=1)
46 | )
47 | else:
48 | self.channel_add_conv = None
49 | if 'channel_mul' in fusions:
50 | self.channel_mul_conv = nn.Sequential(
51 | nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
52 | nn.LayerNorm([self.planes, 1, 1]),
53 | nn.ReLU(inplace=True),
54 | nn.Conv2d(self.planes, self.inplanes, kernel_size=1)
55 | )
56 | else:
57 | self.channel_mul_conv = None
58 | self.reset_parameters()
59 |
60 | def reset_parameters(self):
61 | if self.pool == 'att':
62 | kaiming_init(self.conv_mask, mode='fan_in')
63 | self.conv_mask.inited = True
64 |
65 | if self.channel_add_conv is not None:
66 | last_zero_init(self.channel_add_conv)
67 | if self.channel_mul_conv is not None:
68 | last_zero_init(self.channel_mul_conv)
69 |
70 | def spatial_pool(self, x):
71 | batch, channel, height, width = x.size()
72 | if self.pool == 'att':
73 | input_x = x
74 | # [N, C, H * W]
75 | input_x = input_x.view(batch, channel, height * width)
76 | # [N, 1, C, H * W]
77 | input_x = input_x.unsqueeze(1)
78 | # [N, 1, H, W]
79 | context_mask = self.conv_mask(x)
80 | # [N, 1, H * W]
81 | context_mask = context_mask.view(batch, 1, height * width)
82 | # [N, 1, H * W]
83 | context_mask = self.softmax(context_mask)
84 | # [N, 1, H * W, 1]
85 | context_mask = context_mask.unsqueeze(3)
86 | # [N, 1, C, 1]
87 | context = torch.matmul(input_x, context_mask)
88 | # [N, C, 1, 1]
89 | context = context.view(batch, channel, 1, 1)
90 | else:
91 | # [N, C, 1, 1]
92 | context = self.avg_pool(x)
93 |
94 | return context
95 |
96 | def forward(self, x):
97 | # [N, C, 1, 1]
98 | context = self.spatial_pool(x)
99 |
100 | if self.channel_mul_conv is not None:
101 | # [N, C, 1, 1]
102 | channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
103 | out = x * channel_mul_term
104 | else:
105 | out = x
106 | if self.channel_add_conv is not None:
107 | # [N, C, 1, 1]
108 | channel_add_term = self.channel_add_conv(context)
109 | out = out + channel_add_term
110 |
111 | return out
112 |
--------------------------------------------------------------------------------
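A sketch of the global context block above. `stride` acts as the channel reduction ratio of the bottleneck transform, `pool` selects attention or average pooling, and `fusions` picks the additive and/or multiplicative fusion branches; the values below are illustrative.

```python
import torch
from contextlab.layers.gcnet.gcnet import GlobalContextBlock2d

x = torch.randn(2, 256, 32, 32)

# Attention pooling with an additive fusion branch
gcb = GlobalContextBlock2d(inplanes=256, stride=16, pool='att', fusions=['channel_add'])
out = gcb(x)                               # (2, 256, 32, 32)

# Average pooling with both fusion branches
gcb = GlobalContextBlock2d(inplanes=256, stride=16, pool='avg',
                           fusions=['channel_add', 'channel_mul'])
out = gcb(x)                               # (2, 256, 32, 32)
```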
/contextlab/layers/latentgnn/__init__.py:
--------------------------------------------------------------------------------
1 | from .latentgnn import LatentGNN
--------------------------------------------------------------------------------
/contextlab/layers/latentgnn/latentgnn.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Description: LatentGNN V1 Version
3 | @Author: Songyang Zhang
4 | @Email: sy.zhangbuaa@gmail.com
5 | @License: (C)Copyright 2019-2020, PLUS Group@ShanghaiTech University
6 | @Date: 2019-08-15 10:23:24
7 | @LastEditors: Songyang Zhang
8 | @LastEditTime: 2019-08-15 10:25:01
9 | '''
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 | import numpy as np
16 |
17 | class LatentGNN(nn.Module):
18 | """
19 | Latent Graph Neural Network for Non-local Relations Learning
20 | Args:
21 | in_channels (int): Number of channels in the input feature
22 | latent_dims (list): List of latent dimensions
23 | channel_stride (int): Channel reduction factor. Default: 4
24 | num_kernels (int): Number of latent kernels used. Default: 1
25 | mode (str): Mode of bipartite graph message propagation. Default: 'asymmetric'.
26 |         without_residual (bool): If True, return the output without the residual connection. Default: False
27 |         norm_layer (nn.Module): Module used for batch normalization. Default: nn.BatchNorm2d.
28 |         norm_func (function): Function used for normalization. Default: F.normalize
29 |         graph_conv_flag (bool): If True, apply a graph convolution layer to the output. Default: False
30 | """
31 | def __init__(self, in_channels, latent_dims,
32 | channel_stride=4, num_kernels=1,
33 | mode='asymmetric', without_residual=False,
34 | norm_layer=nn.BatchNorm2d, norm_func=F.normalize,
35 | graph_conv_flag=False):
36 | super(LatentGNN, self).__init__()
37 |         self.without_residual = without_residual
38 | self.num_kernels = num_kernels
39 | self.mode = mode
40 | self.norm_func = norm_func
41 |
42 | inter_channel = in_channels // channel_stride
43 |
44 | # Reduce the channel dimension for efficiency
45 | if mode == 'asymmetric':
46 | self.down_channel_v2l = nn.Sequential(
47 | nn.Conv2d(in_channels=in_channels,
48 | out_channels=inter_channel,
49 | kernel_size=1, padding=0, bias=False),
50 | norm_layer(inter_channel),
51 | )
52 | # nn.init.kaiming_uniform_(self.down_channel_v2l[0].weight, a=1)
53 | # nn.init.kaiming_uniform_(self.down_channel_v2l[0].weight, mode='fan_in')
54 | self.down_channel_l2v = nn.Sequential(
55 | nn.Conv2d(in_channels=in_channels,
56 | out_channels=inter_channel,
57 | kernel_size=1, padding=0, bias=False),
58 | norm_layer(inter_channel),
59 | )
60 | # nn.init.kaiming_uniform_(self.down_channel_l2v[0].weight, a=1)
61 | # nn.init.kaiming_uniform_(self.down_channel_l2v[0].weight, mode='fan_in')
62 |
63 | elif mode == 'symmetric':
64 | self.down_channel = nn.Sequential(
65 | nn.Conv2d(in_channels=in_channels,
66 | out_channels=inter_channel,
67 | kernel_size=1, padding=0, bias=False),
68 | norm_layer(inter_channel),
69 | )
70 | # nn.init.kaiming_uniform_(self.down_channel[0].weight, a=1)
71 | # nn.init.kaiming_uniform_(self.down_channel[0].weight, mode='fan_in')
72 | else:
73 | raise NotImplementedError
74 |
75 | # Define the latentgnn kernel
76 | assert len(latent_dims) == num_kernels, 'Latent dimensions mismatch with number of kernels'
77 |
78 | for i in range(num_kernels):
79 | self.add_module('LatentGNN_Kernel_{}'.format(i),
80 | LatentGNN_Kernel(in_channels=inter_channel,
81 | num_kernels=num_kernels,
82 | latent_dim=latent_dims[i],
83 | norm_layer=norm_layer,
84 | norm_func=norm_func,
85 | mode=mode,
86 | graph_conv_flag=graph_conv_flag))
87 | # Increase the channel for the output
88 | self.up_channel = nn.Sequential(
89 | nn.Conv2d(in_channels=inter_channel*num_kernels,
90 | out_channels=in_channels,
91 | kernel_size=1, padding=0,bias=False),
92 | norm_layer(in_channels),
93 | )
94 | # nn.init.kaiming_uniform_(self.up_channel[0].weight, a=1)
95 | # nn.init.kaiming_uniform_(self.up_channel[0].weight, mode='fan_in')
96 |
97 | # Residual Connection
98 | self.gamma = nn.Parameter(torch.zeros(1))
99 |
100 | def forward(self, conv_feature):
101 | # Generate visible space feature
102 | if self.mode == 'asymmetric':
103 | v2l_conv_feature = self.down_channel_v2l(conv_feature)
104 | l2v_conv_feature = self.down_channel_l2v(conv_feature)
105 | v2l_conv_feature = self.norm_func(v2l_conv_feature, dim=1)
106 | l2v_conv_feature = self.norm_func(l2v_conv_feature, dim=1)
107 | elif self.mode == 'symmetric':
108 | v2l_conv_feature = self.norm_func(self.down_channel(conv_feature), dim=1)
109 | l2v_conv_feature = None
110 | out_features = []
111 | for i in range(self.num_kernels):
112 |             out_features.append(getattr(self, 'LatentGNN_Kernel_{}'.format(i))(v2l_conv_feature, l2v_conv_feature))
113 |
114 | out_features = torch.cat(out_features, dim=1) if self.num_kernels > 1 else out_features[0]
115 |
116 | out_features = self.up_channel(out_features)
117 |
118 |         if self.without_residual:
119 | return out_features
120 | else:
121 | return conv_feature + out_features*self.gamma
122 |
123 | class LatentGNN_Kernel(nn.Module):
124 | """
125 | A LatentGNN Kernel Implementation
126 | Args:
127 | """
128 | def __init__(self, in_channels, num_kernels,
129 | latent_dim, norm_layer,
130 | norm_func, mode, graph_conv_flag):
131 | super(LatentGNN_Kernel, self).__init__()
132 | self.mode = mode
133 | self.norm_func = norm_func
134 | #----------------------------------------------
135 | # Step1 & 3: Visible-to-Latent & Latent-to-Visible
136 | #----------------------------------------------
137 |
138 | if mode == 'asymmetric':
139 | self.psi_v2l = nn.Sequential(
140 | nn.Conv2d(in_channels=in_channels,
141 | out_channels=latent_dim,
142 | kernel_size=1, padding=0,
143 | bias=False),
144 | norm_layer(latent_dim),
145 | nn.ReLU(inplace=True),
146 | )
147 | # nn.init.kaiming_uniform_(self.psi_v2l[0].weight, a=1)
148 | # nn.init.kaiming_uniform_(self.psi_v2l[0].weight, mode='fan_in')
149 | self.psi_l2v = nn.Sequential(
150 | nn.Conv2d(in_channels=in_channels,
151 | out_channels=latent_dim,
152 | kernel_size=1, padding=0,
153 | bias=False),
154 | norm_layer(latent_dim),
155 | nn.ReLU(inplace=True),
156 | )
157 | # nn.init.kaiming_uniform_(self.psi_l2v[0].weight, a=1)
158 | # nn.init.kaiming_uniform_(self.psi_l2v[0].weight, mode='fan_in')
159 | elif mode == 'symmetric':
160 | self.psi = nn.Sequential(
161 | nn.Conv2d(in_channels=in_channels,
162 | out_channels=latent_dim,
163 | kernel_size=1, padding=0,
164 | bias=False),
165 | norm_layer(latent_dim),
166 | nn.ReLU(inplace=True),
167 | )
168 | # nn.init.kaiming_uniform_(self.psi[0].weight, a=1)
169 | # nn.init.kaiming_uniform_(self.psi[0].weight, mode='fan_in')
170 |
171 | #----------------------------------------------
172 |         # Step2: Latent Message Passing
173 | #----------------------------------------------
174 | self.graph_conv_flag = graph_conv_flag
175 | if graph_conv_flag:
176 | self.GraphConvWeight = nn.Sequential(
177 | # nn.Linear(in_channels, in_channels,bias=False),
178 | nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0, bias=False),
179 | norm_layer(in_channels),
180 | nn.ReLU(inplace=True),
181 | )
182 | nn.init.normal_(self.GraphConvWeight[0].weight, std=0.01)
183 |
184 | def forward(self, v2l_conv_feature, l2v_conv_feature):
185 | B, C, H, W = v2l_conv_feature.shape
186 |
187 | # Generate Bipartite Graph Adjacency Matrix
188 | if self.mode == 'asymmetric':
189 | v2l_graph_adj = self.psi_v2l(v2l_conv_feature)
190 | l2v_graph_adj = self.psi_l2v(l2v_conv_feature)
191 | v2l_graph_adj = self.norm_func(v2l_graph_adj.view(B,-1, H*W), dim=2)
192 | l2v_graph_adj = self.norm_func(l2v_graph_adj.view(B,-1, H*W), dim=1)
193 | # l2v_graph_adj = self.norm_func(l2v_graph_adj.view(B,-1, H*W), dim=2)
194 | elif self.mode == 'symmetric':
195 | assert l2v_conv_feature is None
196 | l2v_graph_adj = v2l_graph_adj = self.norm_func(self.psi(v2l_conv_feature).view(B,-1, H*W), dim=1)
197 |
198 | #----------------------------------------------
199 | # Step1 : Visible-to-Latent
200 | #----------------------------------------------
201 | latent_node_feature = torch.bmm(v2l_graph_adj, v2l_conv_feature.view(B, -1, H*W).permute(0,2,1))
202 | # if self.graph_conv_flag:
203 | # latent_node_feature = F.relu(latent_node_feature, inplace=True)
204 | #----------------------------------------------
205 | # Step2 : Latent-to-Latent
206 | #----------------------------------------------
207 | # Generate Dense-connected Graph Adjacency Matrix
208 | latent_node_feature_n = self.norm_func(latent_node_feature, dim=-1)
209 | affinity_matrix = torch.bmm(latent_node_feature_n, latent_node_feature_n.permute(0,2,1))
210 | affinity_matrix = F.softmax(affinity_matrix, dim=-1)
211 |
212 | latent_node_feature = torch.bmm(affinity_matrix, latent_node_feature)
213 |
214 | #----------------------------------------------
215 | # Step3: Latent-to-Visible
216 | #----------------------------------------------
217 | visible_feature = torch.bmm(latent_node_feature.permute(0,2,1), l2v_graph_adj).view(B, -1, H, W)
218 | # if self.graph_conv_flag:
219 | # visible_feature = F.relu(visible_feature, inplace=True)
220 |
221 | if self.graph_conv_flag:
222 | visible_feature = self.GraphConvWeight(visible_feature)
223 |
224 | return visible_feature
--------------------------------------------------------------------------------
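A sketch of the LatentGNN wrapper above with two latent kernels. `latent_dims` must provide one entry per kernel, and the default asymmetric mode learns separate visible-to-latent and latent-to-visible projections; the numbers below are illustrative.

```python
import torch
from contextlab.layers.latentgnn.latentgnn import LatentGNN

x = torch.randn(2, 256, 32, 32)

gnn = LatentGNN(
    in_channels=256,
    latent_dims=[100, 100],    # one latent dimension per kernel
    channel_stride=4,
    num_kernels=2,
    mode='asymmetric',
)
out = gnn(x)                   # (2, 256, 32, 32), residual connection applied by default
```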
/contextlab/layers/non_local/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHTUPLUS/ContextLab/4e12f0af9d0640f29c763b915f02de763b577200/contextlab/layers/non_local/__init__.py
--------------------------------------------------------------------------------
/contextlab/layers/non_local/non_local.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHTUPLUS/ContextLab/4e12f0af9d0640f29c763b915f02de763b577200/contextlab/layers/non_local/non_local.py
--------------------------------------------------------------------------------
/contextlab/layers/tree_filter/__init__.py:
--------------------------------------------------------------------------------
1 | from .modules.tree_filter import MinimumSpanningTree
2 | from .modules.tree_filter import TreeFilter2D
3 |
--------------------------------------------------------------------------------
/contextlab/layers/tree_filter/functions/bfs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.autograd import Function
4 | from torch.autograd.function import once_differentiable
5 | from torch.nn.modules.utils import _pair
6 |
7 | # import ..tree_filter_cuda as _C
8 | # import tree_filter_cuda as _C
9 | from . import tree_filter_cuda as _C
10 | # from contextlab.layers.tree_filter import tree_filter_cuda as _C
11 |
12 | class _BFS(Function):
13 | @staticmethod
14 | def forward(ctx, edge_index, max_adj_per_vertex):
15 | sorted_index, sorted_parent, sorted_child =\
16 | _C.bfs_forward(edge_index, max_adj_per_vertex)
17 | return sorted_index, sorted_parent, sorted_child
18 |
19 | bfs = _BFS.apply
20 |
21 |
--------------------------------------------------------------------------------
/contextlab/layers/tree_filter/functions/mst.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.autograd import Function
4 | from torch.autograd.function import once_differentiable
5 | from torch.nn.modules.utils import _pair
6 | from . import tree_filter_cuda as _C
7 |
8 | # import tree_filter_cuda as _C
9 | # from contextlab.layers.tree_filter import tree_filter_cuda as _C
10 |
11 | class _MST(Function):
12 | @staticmethod
13 | def forward(ctx, edge_index, edge_weight, vertex_index):
14 | edge_out = _C.mst_forward(edge_index, edge_weight, vertex_index)
15 | return edge_out
16 |
17 | @staticmethod
18 | @once_differentiable
19 | def backward(ctx, grad_output):
20 | return None, None, None
21 |
22 | mst = _MST.apply
23 |
24 |
--------------------------------------------------------------------------------
/contextlab/layers/tree_filter/functions/refine.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.autograd import Function
4 | from torch.autograd.function import once_differentiable
5 | from torch.nn.modules.utils import _pair
6 | from . import tree_filter_cuda as _C
7 |
8 | # import tree_filter_cuda as _C
9 | # from contextlab.layers.tree_filter import tree_filter_cuda as _C
10 |
11 | class _Refine(Function):
12 | @staticmethod
13 | def forward(ctx, feature_in, edge_weight, sorted_index, sorted_parent, sorted_child):
14 | feature_out, feature_aggr, feature_aggr_up, weight_sum, weight_sum_up, =\
15 | _C.refine_forward(feature_in, edge_weight, sorted_index, sorted_parent, sorted_child)
16 |
17 | ctx.save_for_backward(feature_in, edge_weight, sorted_index, sorted_parent,
18 | sorted_child, feature_out, feature_aggr, feature_aggr_up, weight_sum,
19 | weight_sum_up)
20 | return feature_out
21 |
22 | @staticmethod
23 | @once_differentiable
24 | def backward(ctx, grad_output):
25 | feature_in, edge_weight, sorted_index, sorted_parent,\
26 | sorted_child, feature_out, feature_aggr, feature_aggr_up, weight_sum,\
27 | weight_sum_up, = ctx.saved_tensors;
28 |
29 | grad_feature = _C.refine_backward_feature(feature_in, edge_weight, sorted_index,
30 | sorted_parent, sorted_child, feature_out, feature_aggr, feature_aggr_up,
31 | weight_sum, weight_sum_up, grad_output)
32 | grad_weight = _C.refine_backward_weight(feature_in, edge_weight, sorted_index,
33 | sorted_parent, sorted_child, feature_out, feature_aggr, feature_aggr_up,
34 | weight_sum, weight_sum_up, grad_output)
35 |
36 | return grad_feature, grad_weight, None, None, None
37 |
38 | refine = _Refine.apply
39 |
40 |
--------------------------------------------------------------------------------
/contextlab/layers/tree_filter/modules/tree_filter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.distributed as dist
4 |
5 | from ..functions.mst import mst
6 | from ..functions.bfs import bfs
7 | from ..functions.refine import refine
8 |
9 | class MinimumSpanningTree(nn.Module):
10 | def __init__(self, distance_func):
11 | super(MinimumSpanningTree, self).__init__()
12 | self.distance_func = distance_func
13 |
14 | @staticmethod
15 | def _build_matrix_index(fm):
16 | batch, height, width = (fm.shape[0], *fm.shape[2:])
17 | row = torch.arange(width, dtype=torch.int32, device=fm.device).unsqueeze(0)
18 | col = torch.arange(height, dtype=torch.int32, device=fm.device).unsqueeze(1)
19 | raw_index = row + col * width
20 | row_index = torch.stack([raw_index[:-1, :], raw_index[1:, :]], 2)
21 | col_index = torch.stack([raw_index[:, :-1], raw_index[:, 1:]], 2)
22 | index = torch.cat([row_index.reshape(1, -1, 2),
23 | col_index.reshape(1, -1, 2)], 1)
24 | index = index.expand(batch, -1, -1)
25 | return index
26 |
27 | def _build_feature_weight(self, fm):
28 | batch = fm.shape[0]
29 | weight_row = self.distance_func(fm[:, :, :-1, :], fm[:, :, 1:, :])
30 | weight_col = self.distance_func(fm[:, :, :, :-1], fm[:, :, :, 1:])
31 | weight_row = weight_row.reshape([batch, -1])
32 | weight_col = weight_col.reshape([batch, -1])
33 | weight = torch.cat([weight_row, weight_col], dim=1) + 1
34 | return weight
35 |
36 | def forward(self, guide_in):
37 | with torch.no_grad():
38 | index = self._build_matrix_index(guide_in)
39 | weight = self._build_feature_weight(guide_in)
40 | tree = mst(index, weight, guide_in.shape[2] * guide_in.shape[3])
41 | return tree
42 |
43 |
44 | class TreeFilter2D(nn.Module):
45 | def __init__(self, groups=1, distance_func=None, enable_log=False):
46 | super(TreeFilter2D, self).__init__()
47 | self.groups = groups
48 | self.enable_log = enable_log
49 | if distance_func is None:
50 | self.distance_func = self.norm2_distance
51 | else:
52 | self.distance_func = distance_func
53 |
54 | @staticmethod
55 | def norm2_distance(fm_ref, fm_tar):
56 | diff = fm_ref - fm_tar
57 | weight = (diff * diff).sum(dim=1)
58 | return weight
59 |
60 | @staticmethod
61 | def batch_index_opr(data, index):
62 | with torch.no_grad():
63 | channel = data.shape[1]
64 | index = index.unsqueeze(1).expand(-1, channel, -1).long()
65 | data = torch.gather(data, 2, index)
66 | return data
67 |
68 | def build_edge_weight(self, fm, sorted_index, sorted_parent):
69 | batch = fm.shape[0]
70 | channel = fm.shape[1]
71 | vertex = fm.shape[2] * fm.shape[3]
72 |
73 | fm = fm.reshape([batch, channel, -1])
74 | fm_source = self.batch_index_opr(fm, sorted_index)
75 | fm_target = self.batch_index_opr(fm_source, sorted_parent)
76 | fm_source = fm_source.reshape([-1, channel // self.groups, vertex])
77 | fm_target = fm_target.reshape([-1, channel // self.groups, vertex])
78 |
79 | edge_weight = self.distance_func(fm_source, fm_target)
80 | edge_weight = torch.exp(-edge_weight)
81 | return edge_weight
82 |
83 | def split_group(self, feature_in, *tree_orders):
84 | feature_in = feature_in.reshape(feature_in.shape[0] * self.groups,
85 | feature_in.shape[1] // self.groups,
86 | -1)
87 | returns = [feature_in.contiguous()]
88 | for order in tree_orders:
89 | order = order.unsqueeze(1).expand(order.shape[0], self.groups, *order.shape[1:])
90 | order = order.reshape(-1, *order.shape[2:])
91 | returns.append(order.contiguous())
92 | return tuple(returns)
93 |
94 | def print_info(self, edge_weight):
95 | edge_weight = edge_weight.clone()
96 | info = torch.stack([edge_weight.mean(), edge_weight.std(), edge_weight.max(), edge_weight.min()])
97 | if self.training and dist.is_initialized():
98 | dist.all_reduce(info / dist.get_world_size())
99 | info_str = (float(x) for x in info)
100 | if dist.get_rank() == 0:
101 | print('Mean:{0:.4f}, Std:{1:.4f}, Max:{2:.4f}, Min:{3:.4f}'.format(*info_str))
102 | else:
103 | info_str = [float(x) for x in info]
104 | print('Mean:{0:.4f}, Std:{1:.4f}, Max:{2:.4f}, Min:{3:.4f}'.format(*info_str))
105 |
106 | def forward(self, feature_in, embed_in, tree):
107 | ori_shape = feature_in.shape
108 | sorted_index, sorted_parent, sorted_child = bfs(tree, 4)
109 | edge_weight = self.build_edge_weight(embed_in, sorted_index, sorted_parent)
110 | with torch.no_grad():
111 | if self.enable_log: self.print_info(edge_weight)
112 | feature_in, sorted_index, sorted_parent, sorted_child = \
113 | self.split_group(feature_in, sorted_index, sorted_parent, sorted_child)
114 | feature_out = refine(feature_in, edge_weight, sorted_index,
115 | sorted_parent, sorted_child)
116 | feature_out = feature_out.reshape(ori_shape)
117 | return feature_out
118 |
119 |
--------------------------------------------------------------------------------
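A sketch combining the two modules above, roughly mirroring the reference TreeFilter-Torch usage: a minimum spanning tree is built from a guidance feature, then another feature is filtered along that tree. It assumes the `tree_filter_cuda` extension has been built with the `setup.py` below and that all tensors live on a CUDA device; the shapes are illustrative.

```python
import torch
from contextlab.layers.tree_filter.modules.tree_filter import MinimumSpanningTree, TreeFilter2D

feature = torch.randn(2, 64, 32, 32, device='cuda')    # feature to be filtered
guide = torch.randn(2, 32, 32, 32, device='cuda')      # guidance feature for edge weights

mst = MinimumSpanningTree(TreeFilter2D.norm2_distance)
tree_filter = TreeFilter2D(groups=1)

tree = mst(guide)                                       # MST edges, roughly (B, H*W - 1, 2)
out = tree_filter(feature, embed_in=guide, tree=tree)   # (2, 64, 32, 32)
```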
/contextlab/layers/tree_filter/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import torch
4 | import shutil
5 | from setuptools import setup
6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7 |
8 | this_dir = os.path.dirname(os.path.abspath(__file__))
9 | extensions_dir = os.path.join(this_dir, "src")
10 |
11 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
12 | source_cpu = glob.glob(os.path.join(extensions_dir, "*", "*.cpp"))
13 | source_cuda = glob.glob(os.path.join(extensions_dir, "*", "*.cu"))
14 |
15 | if torch.cuda.is_available():
16 | if 'LD_LIBRARY_PATH' not in os.environ:
17 | raise Exception('LD_LIBRARY_PATH is not set.')
18 | cuda_lib_path = os.environ['LD_LIBRARY_PATH'].split(':')
19 | sources = source_cpu + source_cuda + main_file
20 | else:
21 |     raise Exception('This implementation is only available for CUDA devices.')
22 |
23 | print(sources)
24 |
25 | setup(
26 | name='tree_filter',
27 | version="0.1",
28 | description="learnable tree filter for pytorch",
29 | ext_modules=[
30 | CUDAExtension(
31 | name='tree_filter_cuda',
32 | include_dirs=[extensions_dir],
33 | sources=sources,
34 | library_dirs=cuda_lib_path,
35 | extra_compile_args={'cxx':['-O3'],
36 | 'nvcc':['-O3']})
37 | ],
38 | cmdclass={
39 | 'build_ext': BuildExtension
40 | })
41 |
42 |
--------------------------------------------------------------------------------
/contextlab/layers/tree_filter/src/bfs/bfs.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include
6 | #include
7 | #include
8 | #include
9 | #include
10 | #include
11 |
12 | #include
13 | #include
14 | #include
15 |
16 | #define CUDA_CHECK(call) if((call) != cudaSuccess) {cudaError_t err = cudaGetLastError(); std::cout << "CUDA error calling ""#call"", code is " << err << std::endl;}
17 |
18 | #define CUDA_NUM_THREADS 64
19 | #define GET_CUDA_BLOCKS(N) ceil((float)N / CUDA_NUM_THREADS)
20 |
21 | __global__ void adj_vec_kernel(
22 | int batch_size,
23 | int * edge_index,
24 | int vertex_count,
25 | int * adj_vec,
26 | int * adj_vec_len,
27 | int max_adj_per_node){
28 |
29 | const int edge_count = vertex_count - 1;
30 | const int batch_idx = blockIdx.x;
31 | const int thread_idx = threadIdx.x;
32 | const int thread_count = blockDim.x;
33 |
34 | edge_index += batch_idx * edge_count * 2;
35 | adj_vec += batch_idx * vertex_count * max_adj_per_node;
36 | adj_vec_len += batch_idx * vertex_count;
37 |
38 | for (int i = thread_idx; i < edge_count; i += thread_count){
39 | int source = edge_index[2 * i];
40 | int target = edge_index[2 * i + 1];
41 | int source_len = atomicAdd(&(adj_vec_len[source]), 1);
42 | adj_vec[source * max_adj_per_node + source_len] = target;
43 | int target_len = atomicAdd(&(adj_vec_len[target]), 1);
44 | adj_vec[target * max_adj_per_node + target_len] = source;
45 | }
46 | }
47 |
48 | __global__ void breadth_first_sort_kernel(
49 | int * sorted_index,
50 | int * sorted_parent_index,
51 | int * sorted_child_index,
52 | int * adj_vec,
53 | int * adj_vec_len,
54 | int * parent_index,
55 | int batch_size,
56 | int vertex_count,
57 | int max_adj_per_node){
58 |
59 | const int batch_idx = blockIdx.x;
60 | const int thread_idx = threadIdx.x;
61 | const int thread_count = blockDim.x;
62 |
63 | adj_vec += batch_idx * vertex_count * max_adj_per_node;
64 | adj_vec_len += batch_idx * vertex_count;
65 | parent_index += batch_idx * vertex_count;
66 | sorted_index += batch_idx * vertex_count;
67 | sorted_parent_index += batch_idx * vertex_count;
68 | sorted_child_index += batch_idx * vertex_count * max_adj_per_node;
69 |
70 | __shared__ int sorted_len;
71 | if (thread_idx == 0) {
72 | sorted_len = 1;
73 | parent_index[0] = 0;
74 | sorted_index[0] = 0;
75 | sorted_parent_index[0] = 0;
76 | }
77 | __syncthreads();
78 |
79 | int i = thread_idx;
80 | while (i < vertex_count){
81 | if ((sorted_index[i] > 0) || (i == 0)){
82 | int child_index = 0;
83 | int par = parent_index[i];
84 | int cur = sorted_index[i];
85 | for (int j = 0; j < adj_vec_len[cur]; j++){
86 | int child = adj_vec[cur * max_adj_per_node + j];
87 | if (child != par){
88 | int pos = atomicAdd(&(sorted_len), 1);
89 | sorted_index[pos] = child;
90 | parent_index[pos] = cur;
91 | sorted_parent_index[pos] = i;
92 | sorted_child_index[i * max_adj_per_node + child_index] = pos;
93 | child_index++;
94 | }
95 | }
96 | i += thread_count;
97 | }
98 | __syncthreads();
99 | }
100 | }
101 |
102 | std::tuple<at::Tensor, at::Tensor, at::Tensor>
103 | bfs_forward(
104 | const at::Tensor & edge_index_tensor,
105 | int max_adj_per_node){
106 |
107 | int batch_size = edge_index_tensor.size(0);
108 | int vertex_count = edge_index_tensor.size(1) + 1;
109 |
110 | auto options = edge_index_tensor.options();
111 | auto sorted_index_tensor = at::zeros({batch_size, vertex_count}, options);
112 | auto sorted_parent_tensor = at::zeros({batch_size, vertex_count}, options);
113 | auto sorted_child_tensor = at::zeros({batch_size, vertex_count, max_adj_per_node}, options);
114 | auto adj_vec_tensor = at::zeros({batch_size, vertex_count, max_adj_per_node}, options);
115 | auto adj_vec_len_tensor = at::zeros({batch_size, vertex_count}, options);
116 | auto parent_index_tensor = at::zeros({batch_size, vertex_count}, options);
117 |
118 | int * edge_index = edge_index_tensor.contiguous().data<int>();
119 | int * sorted_index = sorted_index_tensor.contiguous().data<int>();
120 | int * sorted_parent = sorted_parent_tensor.contiguous().data<int>();
121 | int * sorted_child = sorted_child_tensor.contiguous().data<int>();
122 | int * adj_vec = adj_vec_tensor.contiguous().data<int>();
123 | int * adj_vec_len = adj_vec_len_tensor.contiguous().data<int>();
124 | int * parent_index = parent_index_tensor.contiguous().data<int>();
125 |
126 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
127 |
128 | dim3 block_dims(CUDA_NUM_THREADS, 1, 1), grid_dims(batch_size, 1, 1);
129 | adj_vec_kernel <<< grid_dims, block_dims, 0, stream >>>(
130 | batch_size, edge_index, vertex_count, adj_vec, adj_vec_len, max_adj_per_node);
131 |
132 | breadth_first_sort_kernel <<< grid_dims, block_dims, 1, stream >>>(
133 | sorted_index, sorted_parent, sorted_child, adj_vec, adj_vec_len, parent_index,
134 | batch_size, vertex_count, max_adj_per_node);
135 |
136 | return std::make_tuple(sorted_index_tensor, sorted_parent_tensor, sorted_child_tensor);
137 | }
138 |
139 |
140 |
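141 | // Overview: adj_vec_kernel turns the (vertex_count - 1) tree edges of each batch
142 | // element into per-vertex adjacency lists (adj_vec / adj_vec_len), one CUDA block
143 | // per batch element with threads striding over edges. breadth_first_sort_kernel
144 | // then visits the tree in BFS order from vertex 0, filling sorted_index (vertex id
145 | // per BFS position) together with the parent and child positions of every node.
146 | // bfs_forward launches both kernels and returns (sorted_index, sorted_parent,
147 | // sorted_child), which schedule the tree propagation in refine.cu.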
--------------------------------------------------------------------------------
/contextlab/layers/tree_filter/src/bfs/bfs.hpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/torch.h>
3 |
4 | extern std::tuple<at::Tensor, at::Tensor, at::Tensor>
5 | bfs_forward(
6 | const at::Tensor & edge_index_tensor,
7 | int max_adj_per_node
8 | );
9 |
10 |
--------------------------------------------------------------------------------
/contextlab/layers/tree_filter/src/mst/boruvka.cpp:
--------------------------------------------------------------------------------
1 | // Boruvka's algorithm to find Minimum Spanning
2 | // Tree of a given connected, undirected and
3 | // weighted graph
4 | #include <cstdio>
5 | #include "boruvka.hpp"
6 |
7 | // A structure to represent a subset for union-find
8 | struct subset
9 | {
10 | int parent;
11 | int rank;
12 | };
13 |
14 | // Function prototypes for union-find (These functions are defined
15 | // after boruvkaMST() )
16 | int find(struct subset subsets[], int i);
17 | void Union(struct subset subsets[], int x, int y);
18 |
19 | // The main function for MST using Boruvka's algorithm
20 | void boruvkaMST(struct Graph* graph, int * edge_out)
21 | {
22 | // Get data of given graph
23 | int V = graph->V, E = graph->E;
24 | Edge *edge = graph->edge;
25 |
26 | // Allocate memory for creating V subsets.
27 | struct subset *subsets = new subset[V];
28 |
29 | // An array to store index of the cheapest edge of
30 | // subset. The stored index for indexing array 'edge[]'
31 | int *cheapest = new int[V];
32 |
33 | // Create V subsets with single elements
34 | for (int v = 0; v < V; ++v)
35 | {
36 | subsets[v].parent = v;
37 | subsets[v].rank = 0;
38 | cheapest[v] = -1;
39 | }
40 |
41 | // Initially there are V different trees.
42 | // Finally there will be one tree that will be MST
43 | int numTrees = V;
44 | int MSTweight = 0;
45 |
46 | // Keep combining components (or sets) until all
47 | // components are combined into a single MST.
48 | while (numTrees > 1)
49 | {
50 | // Every time, reinitialize the cheapest array
51 | for (int v = 0; v < V; ++v)
52 | {
53 | cheapest[v] = -1;
54 | }
55 |
56 | // Traverse through all edges and update
57 | // cheapest of every component
58 | for (int i=0; i<E; i++)
59 | {
60 | // Find components (or sets) of two corners
61 | // of current edge
62 | int set1 = find(subsets, edge[i].src);
63 | int set2 = find(subsets, edge[i].dest);
64 | 
65 | // If two corners of current edge belong to
66 | // same set, ignore current edge
67 | if (set1 == set2)
68 | continue;
69 | 
70 | // Else check if current edge is closer to previous
71 | // cheapest edges of set1 and set2
72 | else
73 | {
74 | if (cheapest[set1] == -1 ||
75 | edge[cheapest[set1]].weight > edge[i].weight)
76 | cheapest[set1] = i;
77 |
78 | if (cheapest[set2] == -1 ||
79 | edge[cheapest[set2]].weight > edge[i].weight)
80 | cheapest[set2] = i;
81 | }
82 | }
83 |
84 | // Consider the above picked cheapest edges and add them
85 | // to MST
86 | for (int i=0; i<V; i++)
87 | {
88 | // Check if cheapest for current set exists
89 | if (cheapest[i] != -1)
90 | {
91 | int set1 = find(subsets, edge[cheapest[i]].src);
92 | int set2 = find(subsets, edge[cheapest[i]].dest);
93 | 
94 | if (set1 == set2)
95 | continue;
96 | 
97 | MSTweight += edge[cheapest[i]].weight;
98 | *(edge_out++) = edge[cheapest[i]].src;
99 | *(edge_out++) = edge[cheapest[i]].dest;
100 | 
101 | // Do a union of set1 and set2 and decrease number
102 | // of trees
103 | Union(subsets, set1, set2);
104 | numTrees--;
105 | }
106 | }
107 | }
108 | 
109 | // Free the allocated memory
110 | delete[] subsets;
111 | delete[] cheapest;
112 | }
113 | 
114 | // Creates a graph with V vertices and E edges
115 | struct Graph* createGraph(int V, int E)
116 | {
117 | struct Graph* graph = new Graph;
118 | graph->V = V;
119 | graph->E = E;
120 | graph->edge = new Edge[E];
121 | return graph;
122 | }
123 |
124 | // A utility function to find set of an element i
125 | // (uses path compression technique)
126 | int find(struct subset subsets[], int i)
127 | {
128 | // find root and make root as parent of i
129 | // (path compression)
130 | if (subsets[i].parent != i)
131 | subsets[i].parent =
132 | find(subsets, subsets[i].parent);
133 |
134 | return subsets[i].parent;
135 | }
136 |
137 | // A function that does union of two sets of x and y
138 | // (uses union by rank)
139 | void Union(struct subset subsets[], int x, int y)
140 | {
141 | int xroot = find(subsets, x);
142 | int yroot = find(subsets, y);
143 |
144 | // Attach smaller rank tree under root of high
145 | // rank tree (Union by Rank)
146 | if (subsets[xroot].rank < subsets[yroot].rank)
147 | subsets[xroot].parent = yroot;
148 | else if (subsets[xroot].rank > subsets[yroot].rank)
149 | subsets[yroot].parent = xroot;
150 |
151 | // If ranks are same, then make one as root and
152 | // increment its rank by one
153 | else
154 | {
155 | subsets[yroot].parent = xroot;
156 | subsets[xroot].rank++;
157 | }
158 | }
159 |
160 |
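161 | // Note: boruvkaMST writes the V - 1 selected MST edges into edge_out as
162 | // consecutive (src, dest) integer pairs; createGraph only allocates the graph,
163 | // while the caller fills graph->edge and frees it afterwards (see mst.cu).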
--------------------------------------------------------------------------------
/contextlab/layers/tree_filter/src/mst/boruvka.hpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | // a structure to represent a weighted edge in graph
4 | struct Edge
5 | {
6 | int src, dest;
7 | float weight;
8 | };
9 |
10 | // a structure to represent a connected, undirected
11 | // and weighted graph as a collection of edges.
12 | struct Graph
13 | {
14 | // V-> Number of vertices, E-> Number of edges
15 | int V, E;
16 |
17 | // graph is represented as an array of edges.
18 | // Since the graph is undirected, the edge
19 | // from src to dest is also edge from dest
20 | // to src. Both are counted as 1 edge here.
21 | Edge* edge;
22 | };
23 |
24 | extern struct Graph* createGraph(int V, int E);
25 | extern void boruvkaMST(struct Graph* graph, int * edge_out);
26 |
27 |
--------------------------------------------------------------------------------
/contextlab/layers/tree_filter/src/mst/mst.cu:
--------------------------------------------------------------------------------
1 | #include <torch/torch.h>
2 | #include <cuda.h>
3 | #include <cuda_runtime.h>
4 | #include <iostream>
5 | 
6 | #include <thread>
7 | #include <vector>
8 | 
9 | #include <ATen/ATen.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <ATen/cuda/CUDAApplyUtils.cuh>
12 |
13 | /* Switch of minimal spanning tree algorithms */
14 | /* Note: we will migrate the cuda implementation to PyTorch in the next version */
15 | //#define MST_PRIM
16 | //#define MST_KRUSKAL
17 | #define MST_BORUVKA
18 |
19 | #ifdef MST_PRIM
20 | #include <boost/graph/adjacency_list.hpp>
21 | #include <boost/graph/prim_minimum_spanning_tree.hpp>
22 | #endif
23 | #ifdef MST_KRUSKAL
24 | #include <boost/graph/adjacency_list.hpp>
25 | #include <boost/graph/kruskal_min_spanning_tree.hpp>
26 | #endif
27 | #ifdef MST_BORUVKA
28 | #include "boruvka.hpp"
29 | #endif
30 |
31 |
32 | #ifndef MST_BORUVKA
33 | using namespace boost;
34 | typedef adjacency_list < vecS, vecS, undirectedS,
35 | property < vertex_distance_t, float >, property < edge_weight_t, float > > Graph;
36 | typedef graph_traits < Graph >::edge_descriptor Edge;
37 | typedef graph_traits < Graph >::vertex_descriptor Vertex;
38 | typedef std::pair < int, int > E;
39 | #endif
40 |
41 | static void forward_kernel(int * edge_index, float * edge_weight, int * edge_out, int vertex_count, int edge_count){
42 | #ifdef MST_BORUVKA
43 | struct Graph * g = createGraph(vertex_count, edge_count);
44 | for (int i = 0; i < edge_count; ++i){
45 | g->edge[i].src = edge_index[i * 2];
46 | g->edge[i].dest = edge_index[i * 2 + 1];
47 | g->edge[i].weight = edge_weight[i];
48 | }
49 | #else
50 | Graph g(vertex_count);
51 | for (int i = 0; i < edge_count; ++i)
52 | boost::add_edge((int)edge_index[i * 2], (int)edge_index[i * 2 + 1],
53 | edge_weight[i], g);
54 | #endif
55 |
56 | #ifdef MST_PRIM
57 | std::vector < graph_traits < Graph >::vertex_descriptor > p(num_vertices(g));
58 | prim_minimum_spanning_tree(g, &(p[0]));
59 | int * edge_out_ptr = edge_out;
60 | for (std::size_t i = 0; i != p.size(); ++i)
61 | if (p[i] != i) {
62 | *(edge_out_ptr++) = i;
63 | *(edge_out_ptr++) = p[i];
64 | }
65 | #endif
66 |
67 | #ifdef MST_KRUSKAL
68 | std::vector < Edge > spanning_tree;
69 | kruskal_minimum_spanning_tree(g, std::back_inserter(spanning_tree));
70 | float * edge_out_ptr = edge_out;
71 | for (std::vector < Edge >::iterator ei = spanning_tree.begin();
72 | ei != spanning_tree.end(); ++ei){
73 | *(edge_out_ptr++) = source(*ei, g);
74 | *(edge_out_ptr++) = target(*ei, g);
75 | }
76 | #endif
77 |
78 | #ifdef MST_BORUVKA
79 | boruvkaMST(g, edge_out);
80 | delete[] g->edge;
81 | delete g;
82 | #endif
83 |
84 | }
85 |
86 | at::Tensor mst_forward(
87 | const at::Tensor & edge_index_tensor,
88 | const at::Tensor & edge_weight_tensor,
89 | int vertex_count){
90 | unsigned batch_size = edge_index_tensor.size(0);
91 | unsigned edge_count = edge_index_tensor.size(1);
92 |
93 | auto edge_index_cpu = edge_index_tensor.cpu();
94 | auto edge_weight_cpu = edge_weight_tensor.cpu();
95 | auto edge_out_cpu = at::empty({batch_size, vertex_count - 1, 2}, edge_index_cpu.options());
96 |
97 | int * edge_out = edge_out_cpu.contiguous().data<int>();
98 | int * edge_index = edge_index_cpu.contiguous().data<int>();
99 | float * edge_weight = edge_weight_cpu.contiguous().data<float>();
100 |
101 | // Loop for batch
102 | std::vector<std::thread> pids(batch_size);
103 | for (unsigned i = 0; i < batch_size; i++){
104 | auto edge_index_iter = edge_index + i * edge_count * 2;
105 | auto edge_weight_iter = edge_weight + i * edge_count;
106 | auto edge_out_iter = edge_out + i * (vertex_count - 1) * 2;
107 | pids[i] = std::thread(forward_kernel, edge_index_iter, edge_weight_iter, edge_out_iter, vertex_count, edge_count);
108 | }
109 |
110 | for (unsigned i = 0; i < batch_size; i++){
111 | pids[i].join();
112 | }
113 |
114 | auto edge_out_tensor = edge_out_cpu.to(edge_index_tensor.device());
115 |
116 | return edge_out_tensor;
117 | }
118 |
119 |
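120 | // Overview: mst_forward expects edge_index of shape (batch, edge_count, 2) (int)
121 | // and edge_weight of shape (batch, edge_count) (float). It copies both to the CPU,
122 | // builds a minimum spanning tree for every batch element in its own std::thread
123 | // (Boruvka by default, Prim/Kruskal via the switches above), and returns the
124 | // selected (vertex_count - 1) edges on the device of the input tensor.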
--------------------------------------------------------------------------------
/contextlab/layers/tree_filter/src/mst/mst.hpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/torch.h>
3 |
4 | extern at::Tensor mst_forward(
5 | const at::Tensor & edge_index_tensor,
6 | const at::Tensor & edge_weight_tensor,
7 | int vertex_count);
8 |
9 |
--------------------------------------------------------------------------------
/contextlab/layers/tree_filter/src/refine/refine.cu:
--------------------------------------------------------------------------------
1 | #include <torch/torch.h>
2 | #include <cuda.h>
3 | #include <cuda_runtime.h>
4 | #include <algorithm>
5 | #include <cstdio>
6 | #include <cmath>
7 | #include <cfloat>
8 | #include <tuple>
9 | #include <vector>
10 | #include <iostream>
11 | 
12 | #include <ATen/ATen.h>
13 | #include <ATen/cuda/CUDAContext.h>
14 | #include <ATen/cuda/CUDAApplyUtils.cuh>
15 |
16 | #define CUDA_CHECK(call) if((call) != cudaSuccess) {cudaError_t err = cudaGetLastError(); std::cout << "CUDA error calling ""#call"", code is " << err << std::endl;}
17 |
18 | #define CUDA_NUM_THREADS 64
19 | #define GET_CUDA_CHANNEL(N) ceil(512.0f / N)
20 |
21 | __global__ void root_leaf_prop_kernel(
22 | float * in_data,
23 | float * out_data,
24 | float * weight,
25 | int * sorted_index,
26 | int * sorted_parent_index,
27 | int batch_size,
28 | int channel_size,
29 | int vertex_count){
30 |
31 | const int thread_idx = threadIdx.x;
32 | const int batch_idx = blockIdx.x;
33 | const int channel_idx = blockIdx.y;
34 | const int thread_count = blockDim.x;
35 | const int channel_step = gridDim.y;
36 |
37 | in_data += batch_idx * vertex_count * channel_size;
38 | out_data += batch_idx * vertex_count * channel_size;
39 | weight += batch_idx * vertex_count;
40 | sorted_index += batch_idx * vertex_count;
41 | sorted_parent_index += batch_idx * vertex_count;
42 |
43 | __shared__ int node_per_thread[CUDA_NUM_THREADS];
44 | node_per_thread[thread_idx] = -1;
45 | if (thread_idx == 0){
46 | weight[0] = 0;
47 | sorted_parent_index[0] = 0;
48 | }
49 | __syncthreads();
50 |
51 | int i = thread_idx;
52 | while (i < vertex_count){
53 | int par = sorted_parent_index[i];
54 | int par_thread = par % thread_count;
55 | if ((node_per_thread[par_thread] >= par) || (i == 0)){
56 | int cur_pos = sorted_index[i];
57 | int par_pos = sorted_index[par];
58 | for (int k = channel_idx * vertex_count; k < channel_size * vertex_count;
59 | k += channel_step * vertex_count){
60 | float edge_weight = weight[i];
61 | out_data[cur_pos + k] = in_data[i + k] * (1 - edge_weight * edge_weight) +
62 | out_data[par_pos + k] * edge_weight;
63 | __threadfence_block();
64 | }
65 | node_per_thread[thread_idx] = i;
66 | i += thread_count;
67 | }
68 | __syncthreads();
69 | }
70 | }
71 |
72 | __global__ void leaf_root_aggr_kernel(
73 | float * in_data,
74 | float * out_data,
75 | float * weight,
76 | int * sorted_index,
77 | int * sorted_child_index,
78 | int batch_size,
79 | int channel_size,
80 | int vertex_count,
81 | int max_adj_per_node){
82 |
83 | const int thread_idx = threadIdx.x;
84 | const int batch_idx = blockIdx.x;
85 | const int channel_idx = blockIdx.y;
86 | const int thread_count = blockDim.x;
87 | const int channel_step = gridDim.y;
88 |
89 | if (in_data != NULL){
90 | in_data += batch_idx * vertex_count * channel_size;
91 | }
92 | out_data += batch_idx * vertex_count * channel_size;
93 | weight += batch_idx * vertex_count;
94 | sorted_index += batch_idx * vertex_count;
95 | sorted_child_index += batch_idx * vertex_count * max_adj_per_node;
96 |
97 | __shared__ int node_per_thread[CUDA_NUM_THREADS];
98 | node_per_thread[thread_idx] = vertex_count;
99 | __syncthreads();
100 |
101 | int i = vertex_count - thread_idx - 1;
102 | while (i >= 0){
103 | int child_len = 0;
104 | bool valid = true;
105 | for (int j = 0; j < max_adj_per_node; j++){
106 | int child = sorted_child_index[i * max_adj_per_node + j];
107 | int child_thread = (vertex_count - child - 1) % thread_count;
108 |
109 | if (child <= 0) break;
110 | if (node_per_thread[child_thread] > child){
111 | valid = false;
112 | break;
113 | }
114 | child_len++;
115 | }
116 | if (valid){
117 | int cur_pos = sorted_index[i];
118 | for (int k = channel_idx * vertex_count; k < channel_size * vertex_count;
119 | k += channel_step * vertex_count){
120 | float aggr_sum;
121 | if (in_data != NULL)
122 | aggr_sum = in_data[cur_pos + k];
123 | else
124 | aggr_sum = 1;
125 | for (int j = 0; j < child_len; j++){
126 | int child = sorted_child_index[i * max_adj_per_node + j];
127 | aggr_sum += out_data[child + k] * weight[child];
128 | }
129 | out_data[i + k] = aggr_sum;
130 | }
131 | node_per_thread[thread_idx] = i;
132 | i -= thread_count;
133 | }
134 | __syncthreads();
135 | }
136 | }
137 |
138 | __global__ void root_leaf_grad_kernel(
139 | float * in_data,
140 | float * in_grad,
141 | float * out_data,
142 | float * out_grad,
143 | float * weight,
144 | float * grad,
145 | int * sorted_index,
146 | int * sorted_parent_index,
147 | int batch_size,
148 | int data_channel_size,
149 | int grad_channel_size,
150 | int vertex_count){
151 |
152 | const int thread_idx = threadIdx.x;
153 | const int batch_idx = blockIdx.x;
154 | const int channel_idx = blockIdx.y;
155 | const int thread_count = blockDim.x;
156 | const int channel_step = gridDim.y;
157 | const int channel_size = data_channel_size > grad_channel_size ? data_channel_size : grad_channel_size;
158 |
159 | in_data += batch_idx * vertex_count * data_channel_size;
160 | in_grad += batch_idx * vertex_count * grad_channel_size;
161 | out_data += batch_idx * vertex_count * data_channel_size;
162 | out_grad += batch_idx * vertex_count * grad_channel_size;
163 | weight += batch_idx * vertex_count;
164 | grad += batch_idx * vertex_count * channel_size;
165 | sorted_index += batch_idx * vertex_count;
166 | sorted_parent_index += batch_idx * vertex_count;
167 |
168 | __shared__ int node_per_thread[CUDA_NUM_THREADS];
169 | node_per_thread[thread_idx] = -1;
170 |
171 | int i = thread_idx;
172 | while (i < vertex_count){
173 | int cur = i;
174 | int par = sorted_parent_index[i];
175 | int par_pos = sorted_index[par];
176 | int par_thread = par % thread_count;
177 | if ((cur == 0) || (node_per_thread[par_thread] >= par)){
178 | for (int k = channel_idx; k < channel_size; k += channel_step){
179 | float edge_weight = weight[i];
180 | int data_offset = (k % data_channel_size) * vertex_count;
181 | int grad_offset = (k % grad_channel_size) * vertex_count;
182 | int out_offset = k * vertex_count;
183 |
184 | if (cur > 0){
185 | float left = in_grad[cur + grad_offset] * (out_data[par_pos + data_offset] - edge_weight * in_data[cur + data_offset]);
186 | float right = in_data[cur + data_offset] * (out_grad[par + grad_offset] - edge_weight * in_grad[cur + grad_offset]);
187 |
188 | grad[cur + out_offset] = left + right;
189 | out_grad[cur + grad_offset] = in_grad[cur + grad_offset] * (1 - edge_weight * edge_weight) +
190 | out_grad[par + grad_offset] * edge_weight;
191 | __threadfence_block();
192 | }
193 | else
194 | grad[cur + out_offset] = 0;
195 | }
196 | node_per_thread[thread_idx] = i;
197 | i += thread_count;
198 | }
199 | __syncthreads();
200 | }
201 | }
202 |
203 | std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
204 | refine_forward(
205 | const at::Tensor & feature_in_tensor,
206 | const at::Tensor & edge_weight_tensor,
207 | const at::Tensor & sorted_index_tensor,
208 | const at::Tensor & sorted_parent_tensor,
209 | const at::Tensor & sorted_child_tensor
210 | ){
211 |
212 | const int batch_size = feature_in_tensor.size(0);
213 | const int channel_size = feature_in_tensor.size(1);
214 | const int vertex_size = feature_in_tensor.size(2);
215 | const int max_adj_per_node = sorted_child_tensor.size(2);
216 |
217 | auto options = feature_in_tensor.options();
218 | auto feature_aggr_tensor = at::zeros_like(feature_in_tensor, options);
219 | auto feature_aggr_up_tensor = at::zeros_like(feature_in_tensor, options);
220 | auto weight_sum_tensor = at::zeros({batch_size, vertex_size}, options);
221 | auto weight_sum_up_tensor = at::zeros({batch_size, vertex_size}, options);
222 |
223 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
224 |
225 | float * feature_in = feature_in_tensor.contiguous().data<float>();
226 | float * edge_weight = edge_weight_tensor.contiguous().data<float>();
227 | int * sorted_index = sorted_index_tensor.contiguous().data<int>();
228 | int * sorted_parent_index = sorted_parent_tensor.contiguous().data<int>();
229 | int * sorted_child_index = sorted_child_tensor.contiguous().data<int>();
230 | float * feature_aggr = feature_aggr_tensor.contiguous().data<float>();
231 | float * feature_aggr_sum = feature_aggr_up_tensor.contiguous().data<float>();
232 | float * weight_sum = weight_sum_tensor.contiguous().data<float>();
233 | float * weight_aggr_sum = weight_sum_up_tensor.contiguous().data<float>();
234 |
235 | dim3 feature_block_dims(CUDA_NUM_THREADS, 1, 1), feature_grid_dims(batch_size, channel_size, 1);
236 | leaf_root_aggr_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
237 | feature_in, feature_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node);
238 | root_leaf_prop_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
239 | feature_aggr_sum, feature_aggr, edge_weight, sorted_index, sorted_parent_index, batch_size, channel_size, vertex_size);
240 |
241 | dim3 weight_block_dims(CUDA_NUM_THREADS, 1, 1), weight_grid_dims(batch_size, 1, 1);
242 | leaf_root_aggr_kernel <<< weight_grid_dims, weight_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
243 | NULL, weight_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, 1, vertex_size, max_adj_per_node);
244 | root_leaf_prop_kernel <<< weight_grid_dims, weight_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
245 | weight_aggr_sum, weight_sum, edge_weight, sorted_index, sorted_parent_index, batch_size, 1, vertex_size);
246 |
247 | auto feature_out_tensor = feature_aggr_tensor / weight_sum_tensor.unsqueeze(1);
248 | auto result = std::make_tuple(feature_out_tensor, feature_aggr_tensor, feature_aggr_up_tensor,
249 | weight_sum_tensor, weight_sum_up_tensor);
250 | return result;
251 | }
252 |
253 | at::Tensor refine_backward_feature(
254 | const at::Tensor & feature_in_tensor,
255 | const at::Tensor & edge_weight_tensor,
256 | const at::Tensor & sorted_index_tensor,
257 | const at::Tensor & sorted_parent_tensor,
258 | const at::Tensor & sorted_child_tensor,
259 | const at::Tensor & feature_out_tensor,
260 | const at::Tensor & feature_aggr_tensor,
261 | const at::Tensor & feature_aggr_up_tensor,
262 | const at::Tensor & weight_sum_tensor,
263 | const at::Tensor & weight_sum_up_tensor,
264 | const at::Tensor & grad_out_tensor
265 | ){
266 |
267 | auto options = feature_in_tensor.options();
268 | auto grad_feature_tensor = at::zeros_like(feature_in_tensor, options);
269 | auto grad_feature_aggr_sum_tensor = at::zeros_like(feature_in_tensor, options);
270 |
271 | auto grad_out_norm_tensor = grad_out_tensor / weight_sum_tensor.unsqueeze(1);
272 |
273 | const int batch_size = feature_in_tensor.size(0);
274 | const int channel_size = feature_in_tensor.size(1);
275 | const int vertex_size = feature_in_tensor.size(2);
276 | const int max_adj_per_node = sorted_child_tensor.size(2);
277 |
278 | float * feature_in = feature_in_tensor.contiguous().data<float>();
279 | float * edge_weight = edge_weight_tensor.contiguous().data<float>();
280 | int * sorted_index = sorted_index_tensor.contiguous().data<int>();
281 | int * sorted_parent_index = sorted_parent_tensor.contiguous().data<int>();
282 | int * sorted_child_index = sorted_child_tensor.contiguous().data<int>();
283 | float * feature_aggr = feature_aggr_tensor.contiguous().data<float>();
284 | float * feature_aggr_sum = feature_aggr_up_tensor.contiguous().data<float>();
285 | float * weight_sum = weight_sum_tensor.contiguous().data<float>();
286 | float * weight_aggr_sum = weight_sum_up_tensor.contiguous().data<float>();
287 | float * grad_out = grad_out_tensor.contiguous().data<float>();
288 | float * grad_feature = grad_feature_tensor.contiguous().data<float>();
289 | 
290 | float * grad_out_norm = grad_out_norm_tensor.contiguous().data<float>();
291 | float * grad_feature_aggr_sum = grad_feature_aggr_sum_tensor.contiguous().data<float>();
292 |
293 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
294 |
295 | dim3 feature_block_dims(CUDA_NUM_THREADS, 1, 1), feature_grid_dims(batch_size, channel_size, 1);
296 | leaf_root_aggr_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
297 | grad_out_norm, grad_feature_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node);
298 | root_leaf_prop_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
299 | grad_feature_aggr_sum, grad_feature, edge_weight, sorted_index, sorted_parent_index, batch_size, channel_size, vertex_size);
300 |
301 | return grad_feature_tensor;
302 | }
303 |
304 | at::Tensor refine_backward_weight(
305 | const at::Tensor & feature_in_tensor,
306 | const at::Tensor & edge_weight_tensor,
307 | const at::Tensor & sorted_index_tensor,
308 | const at::Tensor & sorted_parent_tensor,
309 | const at::Tensor & sorted_child_tensor,
310 | const at::Tensor & feature_out_tensor,
311 | const at::Tensor & feature_aggr_tensor,
312 | const at::Tensor & feature_aggr_up_tensor,
313 | const at::Tensor & weight_sum_tensor,
314 | const at::Tensor & weight_sum_up_tensor,
315 | const at::Tensor & grad_out_tensor
316 | ){
317 |
318 | auto options = feature_in_tensor.options();
319 | auto grad_weight_tensor = at::zeros_like(edge_weight_tensor, options);
320 |
321 | const int batch_size = feature_in_tensor.size(0);
322 | const int channel_size = feature_in_tensor.size(1);
323 | const int vertex_size = feature_in_tensor.size(2);
324 | const int max_adj_per_node = sorted_child_tensor.size(2);
325 |
326 | float * feature_in = feature_in_tensor.contiguous().data<float>();
327 | float * edge_weight = edge_weight_tensor.contiguous().data<float>();
328 | int * sorted_index = sorted_index_tensor.contiguous().data<int>();
329 | int * sorted_parent_index = sorted_parent_tensor.contiguous().data<int>();
330 | int * sorted_child_index = sorted_child_tensor.contiguous().data<int>();
331 | float * feature_out = feature_out_tensor.contiguous().data<float>();
332 | float * feature_aggr = feature_aggr_tensor.contiguous().data<float>();
333 | float * feature_aggr_sum = feature_aggr_up_tensor.contiguous().data<float>();
334 | float * weight_sum = weight_sum_tensor.contiguous().data<float>();
335 | float * weight_aggr_sum = weight_sum_up_tensor.contiguous().data<float>();
336 | float * grad_out = grad_out_tensor.contiguous().data<float>();
337 | float * grad_weight = grad_weight_tensor.contiguous().data<float>();
338 |
339 | auto grad_all_channel_tensor = at::zeros_like(feature_in_tensor, options);
340 | auto grad_norm_all_channel_tensor = at::zeros_like(feature_in_tensor, options);
341 | auto grad_out_norm_aggr_sum_tensor = at::zeros_like(feature_in_tensor, options);
342 | auto feature_grad_aggr_sum_tensor = at::zeros_like(feature_in_tensor, options);
343 |
344 | float * grad_all_channel = grad_all_channel_tensor.contiguous().data<float>();
345 | float * grad_norm_all_channel = grad_norm_all_channel_tensor.contiguous().data<float>();
346 | float * grad_out_norm_aggr_sum = grad_out_norm_aggr_sum_tensor.contiguous().data<float>();
347 | float * feature_grad_aggr_sum = feature_grad_aggr_sum_tensor.contiguous().data<float>();
348 |
349 | auto grad_out_norm_tensor = grad_out_tensor / weight_sum_tensor.unsqueeze(1);
350 | auto feature_grad_tensor = grad_out_norm_tensor * feature_out_tensor;
351 | float * grad_out_norm = grad_out_norm_tensor.contiguous().data<float>();
352 | float * feature_grad = feature_grad_tensor.contiguous().data<float>();
353 |
354 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
355 |
356 | dim3 feature_block_dims(CUDA_NUM_THREADS, 1, 1), feature_grid_dims(batch_size, channel_size, 1);
357 | leaf_root_aggr_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
358 | grad_out_norm, grad_out_norm_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node);
359 | leaf_root_aggr_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
360 | feature_grad, feature_grad_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node);
361 |
362 | root_leaf_grad_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
363 | feature_aggr_sum, grad_out_norm_aggr_sum, feature_aggr, grad_out_norm_aggr_sum, edge_weight, grad_all_channel,
364 | sorted_index, sorted_parent_index, batch_size, channel_size, channel_size, vertex_size);
365 | root_leaf_grad_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
366 | weight_aggr_sum, feature_grad_aggr_sum, weight_sum, feature_grad_aggr_sum, edge_weight, grad_norm_all_channel,
367 | sorted_index, sorted_parent_index, batch_size, 1, channel_size, vertex_size);
368 |
369 | grad_weight_tensor = (grad_all_channel_tensor - grad_norm_all_channel_tensor).sum(1);
370 |
371 | return grad_weight_tensor;
372 | }
373 |
374 |
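375 | // Overview: leaf_root_aggr_kernel accumulates values from the leaves towards the
376 | // root (out[i] = in[i] + sum over children of out[child] * weight[child]; with
377 | // in_data == NULL it accumulates ones for the normalization term).
378 | // root_leaf_prop_kernel then propagates from the root back to the leaves with
379 | // out[cur] = in[cur] * (1 - w^2) + out[parent] * w, where w is the edge weight to
380 | // the parent. refine_forward runs this pair once for the features and once for the
381 | // normalization weights and divides the two; refine_backward_feature reuses the
382 | // same pair on the normalized output gradient, while refine_backward_weight also
383 | // uses root_leaf_grad_kernel to obtain the gradient w.r.t. the edge weights.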
--------------------------------------------------------------------------------
/contextlab/layers/tree_filter/src/refine/refine.hpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/torch.h>
3 |
4 | extern std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
5 | refine_forward(
6 | const at::Tensor & feature_in_tensor,
7 | const at::Tensor & edge_weight_tensor,
8 | const at::Tensor & sorted_index_tensor,
9 | const at::Tensor & sorted_parent_index_tensor,
10 | const at::Tensor & sorted_child_index_tensor
11 | );
12 |
13 | extern at::Tensor refine_backward_feature(
14 | const at::Tensor & feature_in_tensor,
15 | const at::Tensor & edge_weight_tensor,
16 | const at::Tensor & sorted_index_tensor,
17 | const at::Tensor & sorted_parent_tensor,
18 | const at::Tensor & sorted_child_tensor,
19 | const at::Tensor & feature_out_tensor,
20 | const at::Tensor & feature_aggr_tensor,
21 | const at::Tensor & feature_aggr_up_tensor,
22 | const at::Tensor & weight_sum_tensor,
23 | const at::Tensor & weight_sum_up_tensor,
24 | const at::Tensor & grad_out_tensor
25 | );
26 |
27 | extern at::Tensor refine_backward_weight(
28 | const at::Tensor & feature_in_tensor,
29 | const at::Tensor & edge_weight_tensor,
30 | const at::Tensor & sorted_index_tensor,
31 | const at::Tensor & sorted_parent_tensor,
32 | const at::Tensor & sorted_child_tensor,
33 | const at::Tensor & feature_out_tensor,
34 | const at::Tensor & feature_aggr_tensor,
35 | const at::Tensor & feature_aggr_up_tensor,
36 | const at::Tensor & weight_sum_tensor,
37 | const at::Tensor & weight_sum_up_tensor,
38 | const at::Tensor & grad_out_tensor
39 | );
40 |
41 |
--------------------------------------------------------------------------------
/contextlab/layers/tree_filter/src/tree_filter.cpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 | #include "refine/refine.hpp"
4 | #include "mst/mst.hpp"
5 | #include "bfs/bfs.hpp"
6 |
7 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
8 | m.def("mst_forward", &mst_forward, "mst forward");
9 | m.def("bfs_forward", &bfs_forward, "bfs forward");
10 | m.def("refine_forward", &refine_forward, "refine forward");
11 | m.def("refine_backward_feature", &refine_backward_feature, "refine backward wrt feature");
12 | m.def("refine_backward_weight", &refine_backward_weight, "refine backward wrt weight");
13 | }
14 |
15 |
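16 | // Python-side sketch (the module path depends on how the extension is built; the
17 | // name below matches the CUDAExtension declared in the top-level setup.py):
18 | //   from contextlab.layers.tree_filter.functions import tree_filter_cuda
19 | //   edges = tree_filter_cuda.mst_forward(edge_index, edge_weight, vertex_count)
20 | //   order = tree_filter_cuda.bfs_forward(edges, max_adj_per_node)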
--------------------------------------------------------------------------------
/contextlab/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .layer_misc import ConvBNReLU, GraphAdjNetwork
2 | from .weight_init import constant_init, xavier_init, kaiming_init
--------------------------------------------------------------------------------
/contextlab/utils/layer_misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | __all__ = ['ConvBNReLU', 'GraphAdjNetwork']
5 |
6 |
7 | class ConvBNReLU(nn.Module):
8 | def __init__(self,
9 | in_channels,
10 | out_channels,
11 | kernel_size,
12 | norm_layer=nn.BatchNorm2d,
13 | with_relu=True,
14 | stride=1,
15 | padding=0,
16 | dilation=1,
17 | groups=1,
18 | bias=True,
19 | padding_mode='zeros'):
20 | super(ConvBNReLU, self).__init__()
21 | self.conv = nn.Conv2d(
22 | in_channels=in_channels,
23 | out_channels=out_channels,
24 | kernel_size=kernel_size,
25 | stride=stride,
26 | padding=padding,
27 | dilation=dilation,
28 | groups=groups,
29 | bias=bias,
30 | padding_mode=padding_mode)
31 | self.bn = norm_layer(out_channels)
32 | self.with_relu = with_relu
33 | if with_relu:
34 | self.relu = nn.ReLU(inplace=True)
35 |
36 | def forward(self, x):
37 | x = self.conv(x)
38 | x = self.bn(x)
39 | if self.with_relu:
40 | x = self.relu(x)
41 | return x
42 |
43 |
44 | class GraphAdjNetwork(nn.Module):
45 | def __init__(self,
46 | pair_function,
47 | in_channels,
48 | channel_stride):
49 | super(GraphAdjNetwork, self).__init__()
50 | self.pair_function = pair_function
51 |
52 | if pair_function == 'embedded_gaussian':
53 | inter_channel = in_channels // channel_stride
54 | self.phi = ConvBNReLU(
55 | in_channels=in_channels,
56 | out_channels=inter_channel,
57 | kernel_size=1,
58 | bias=False,
59 | norm_layer=nn.BatchNorm2d
60 | )
61 | self.theta = ConvBNReLU(
62 | in_channels=in_channels,
63 | out_channels=inter_channel,
64 | kernel_size=1,
65 | bias=False,
66 | norm_layer=nn.BatchNorm2d
67 | )
68 | elif pair_function == 'gaussian':
69 | pass
70 | elif pair_function == 'diff_learnable':
71 | self.learnable_adj_conv = ConvBNReLU(
72 | in_channels=in_channels,
73 | out_channels=1,
74 | kernel_size=1,
75 | bias=False,
76 | norm_layer=nn.BatchNorm2d
77 | )
78 | elif pair_function == 'sum_learnable':
79 | self.learnable_adj_conv = ConvBNReLU(
80 | in_channels=in_channels,
81 | out_channels=1,
82 | kernel_size=1,
83 | bias=False,
84 | norm_layer=nn.BatchNorm2d
85 | )
86 | elif pair_function == 'cat_learnable':
87 | self.learnable_adj_conv = ConvBNReLU(
88 | in_channels=in_channels*2,
89 | out_channels=1,
90 | kernel_size=1,
91 | bias=False,
92 | norm_layer=nn.BatchNorm2d
93 | )
94 | else:
95 | raise NotImplementedError
96 |
97 | def forward(self, x):
98 | """
99 | Args:
100 | x (Tensor):
101 | (B, N, C)
102 | """
103 | if self.pair_function == 'gaussian':
104 | adj = self.gaussian(x, x.permute(0, 2, 1))
105 | elif self.pair_function == 'embedded_gaussian':
106 | x = x.permute(0, 2, 1).unsqueeze(-1)
107 | x_1 = self.phi(x) # B, C, N, 1
108 | x_2 = self.theta(x) # B, C, N, 1
109 | adj = self.gaussian(
110 | x_1.squeeze(-1).permute(0, 2, 1), x_2.squeeze(-1))
111 | elif self.pair_function == 'diff_learnable':
112 | adj = self.diff_learnable_adj(x.unsqueeze(2), x.unsqueeze(1))
113 | elif self.pair_function == 'sum_learnable':
114 | adj = self.sum_learnable_adj(x.unsqueeze(2), x.unsqueeze(1))
115 | elif self.pair_function == 'cat_learnable':
116 | adj = self.cat_learnable_adj(x.unsqueeze(2), x.unsqueeze(1))
117 | else:
118 | raise NotImplementedError(self.pair_function)
119 |
120 | return adj
121 |
122 | def gaussian(self, x_1, x_2):
123 | """
124 | Args:
125 | x_1 (Tensor): (B, N, C)
126 | x_2 (Tensor): (B, C, N)
127 | Return:
128 | adj: (B, N, N), softmax-normalized over the last dimension
129 | """
130 | # (B, N, C) X (B, C, N) --> (B, N, N)
131 | adj = torch.bmm(x_1, x_2) # B, N, N
132 | adj = F.softmax(adj, dim=-1) # B, N, N
133 | return adj
134 |
135 | def diff_learnable_adj(self, x_1, x_2):
136 | """
137 | Learnable attention from the difference of the feature
138 |
139 | Return:
140 | adj: normalized over the last dimension
141 | """
142 | # x1:(B,N,1,C)
143 | # x2:(B,1,N,C)
144 | feature_diff = x_1 - x_2 # (B, N, N, C)
145 | feature_diff = feature_diff.permute(0, 3, 1, 2) # (B, C, N, N)
146 | adj = self.learnable_adj_conv(feature_diff) # (B, 1, N, N)
147 | adj = adj.squeeze(1) # (B, N, N)
148 | # Use the number of nodes as the normalization factor
149 | adj = adj / adj.size(-1) # (B, N, N)
150 |
151 | return adj
152 |
153 | def sum_learnable_adj(self, x_1, x_2):
154 | """
155 | Learnable attention from the sum of the features
156 | 
157 | Return:
158 | adj: normalized over the last dimension
159 | """
160 | # x1:(B,N,1,C)
161 | # x2:(B,1,N,C)
162 | feature_sum = x_1 + x_2  # (B, N, N, C)
163 | feature_sum = feature_sum.permute(0, 3, 1, 2)  # (B, C, N, N)
164 | adj = self.learnable_adj_conv(feature_sum)  # (B, 1, N, N)
165 | adj = adj.squeeze(1) # (B, N, N)
166 | # Use the number of nodes as the normalization factor
167 | adj = adj / adj.size(-1) # (B, N, N)
168 |
169 | return adj
170 |
171 | def cat_learnable_adj(self, x_1, x_2):
172 | """
173 | Learnable attention from the concatenation of the features
174 | """
175 | x_1 = x_1.repeat(1, 1, x_1.size(1), 1) # B, N, N, C
176 | x_2 = x_2.repeat(1, x_2.size(2), 1, 1) # B, N, N, C
177 | feature_cat = torch.cat([x_1, x_2], dim=-1) # B, N, N, 2C
178 | # import pdb; pdb.set_trace()
179 | feature_cat = feature_cat.permute(0, 3, 1, 2) # B, 2C, N, N
180 | adj = self.learnable_adj_conv(feature_cat) # B, 1, N, N
181 | adj = adj.squeeze(1) # (B, N, N)
182 | # Use the number of nodes as the normalization factor
183 | adj = adj / adj.size(-1) # (B, N, N)
184 |
185 | return adj
186 |
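187 | # Usage sketch (shapes only; see test/utils/layer_misc.py for a runnable example):
188 | #   x = torch.randn(2, 8, 1024)                                 # (B, N, C)
189 | #   adj = GraphAdjNetwork('embedded_gaussian', 1024, 8)(x)      # (B, N, N)
190 | #   y = ConvBNReLU(1024, 256, kernel_size=1)(x.permute(0, 2, 1).unsqueeze(-1))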
--------------------------------------------------------------------------------
/contextlab/utils/weight_init.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def constant_init(module, val, bias=0):
5 | nn.init.constant_(module.weight, val)
6 | if hasattr(module, 'bias') and module.bias is not None:
7 | nn.init.constant_(module.bias, bias)
8 |
9 |
10 | def xavier_init(module, gain=1, bias=0, distribution='normal'):
11 | assert distribution in ['uniform', 'normal']
12 | if distribution == 'uniform':
13 | nn.init.xavier_uniform_(module.weight, gain=gain)
14 | else:
15 | nn.init.xavier_normal_(module.weight, gain=gain)
16 | if hasattr(module, 'bias') and module.bias is not None:
17 | nn.init.constant_(module.bias, bias)
18 |
19 |
20 | def normal_init(module, mean=0, std=1, bias=0):
21 | nn.init.normal_(module.weight, mean, std)
22 | if hasattr(module, 'bias') and module.bias is not None:
23 | nn.init.constant_(module.bias, bias)
24 |
25 |
26 | def uniform_init(module, a=0, b=1, bias=0):
27 | nn.init.uniform_(module.weight, a, b)
28 | if hasattr(module, 'bias') and module.bias is not None:
29 | nn.init.constant_(module.bias, bias)
30 |
31 |
32 | def kaiming_init(module,
33 | a=0,
34 | mode='fan_out',
35 | nonlinearity='relu',
36 | bias=0,
37 | distribution='normal'):
38 | assert distribution in ['uniform', 'normal']
39 | if distribution == 'uniform':
40 | nn.init.kaiming_uniform_(
41 | module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
42 | else:
43 | nn.init.kaiming_normal_(
44 | module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
45 | if hasattr(module, 'bias') and module.bias is not None:
46 | nn.init.constant_(module.bias, bias)
47 |
48 |
49 | def caffe2_xavier_init(module, bias=0):
50 | # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
51 | # Acknowledgment to FAIR's internal code
52 | kaiming_init(
53 | module,
54 | a=1,
55 | mode='fan_in',
56 | nonlinearity='leaky_relu',
57 | distribution='uniform')
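58 | 
59 | 
60 | # Usage sketch: these helpers initialize a module's weight (and bias, if present)
61 | # in place, e.g.
62 | #   conv = nn.Conv2d(16, 32, 3)
63 | #   kaiming_init(conv)        # Kaiming-normal weight, bias = 0
64 | #   bn = nn.BatchNorm2d(32)
65 | #   constant_init(bn, 1)      # weight = 1, bias = 0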
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Description: Setup for contextlab project
3 | @Author: Songyang Zhang
4 | @Email: sy.zhangbuaa@gmail.com
5 | @Date: 2019-08-11 12:30:28
6 | @LastEditors: Songyang Zhang
7 | @LastEditTime: 2019-09-27 19:58:41
8 | '''
9 |
10 | import glob
11 | import os
12 |
13 | import subprocess
14 | import time
15 | import platform
16 |
17 | import torch
18 |
19 | from setuptools import find_packages
20 | from setuptools import setup
21 |
22 | from torch.utils.cpp_extension import CUDA_HOME
23 | from torch.utils.cpp_extension import CppExtension
24 | from torch.utils.cpp_extension import CUDAExtension
25 | from setuptools import Extension, find_packages, setup
26 |
27 | from Cython.Build import cythonize
28 |
29 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
30 |
31 | requirements = ['torch']
32 |
33 | def readme():
34 | with open('README.md', encoding='utf-8') as f:
35 | content = f.read()
36 | return content
37 |
38 | version_file = 'contextlab/version.py'
39 |
40 | if torch.cuda.is_available():
41 | if 'LD_LIBRARY_PATH' not in os.environ:
42 | raise Exception('LD_LIBRARY_PATH is not set.')
43 | cuda_lib_path = os.environ['LD_LIBRARY_PATH'].split(':')
44 | else:
45 |     raise Exception('This implementation is only available for CUDA devices.')
46 |
47 |
48 | MAJOR = 0
49 | MINOR = 2
50 | PATCH = 0
51 | SUFFIX = ''
52 | SHORT_VERSION = '{}.{}.{}{}'.format(MAJOR, MINOR, PATCH, SUFFIX)
53 |
54 |
55 | def get_git_hash():
56 | def _minimal_ext_cmd(cmd):
57 | # construct minimal environment
58 | env = {}
59 | for k in ['SYSTEMROOT', 'PATH', 'HOME']:
60 | v = os.environ.get(k)
61 | if v is not None:
62 | env[k] = v
63 |
64 | # LANGUAGE is used on win 32
65 | env['LANGUAGE'] = 'C'
66 | env['LANG'] = 'C'
67 | env['LC_ALL'] = 'C'
68 | out = subprocess.Popen(
69 | cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
70 |
71 | return out
72 | try:
73 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
74 | sha = out.strip().decode('ascii')
75 | except OSError:
76 | sha = 'unknown'
77 |
78 | return sha
79 |
80 | def get_hash():
81 | if os.path.exists('.git'):
82 | sha = get_git_hash()[:7]
83 | elif os.path.exists(version_file):
84 | try:
85 |             from contextlab.version import __version__
86 | sha = __version__.split('+')[-1]
87 | except ImportError:
88 | raise ImportError('Unable to get git version')
89 | else:
90 | sha = 'unknown'
91 |
92 | return sha
93 |
94 | def write_version_py():
95 | content = """# Generated Version File
96 | # Time: {}
97 |
98 | __version__ = '{}'
99 | short_version = '{}'
100 | """
101 | sha = get_hash()
102 | VERSION = SHORT_VERSION + '+' + sha
103 |
104 | with open(version_file, 'w') as f:
105 | f.write(content.format(time.asctime(), VERSION, SHORT_VERSION))
106 |
107 | def get_version():
108 | with open(version_file, 'r') as f:
109 | exec(compile(f.read(), version_file, 'exec'))
110 |
111 | return locals()['__version__']
112 |
113 |
114 | def make_cuda_ext(name, module, sources, include_dirs=[]):
115 |
116 |
117 | return CUDAExtension(
118 | name='{}.{}'.format(module, name),
119 | sources=[os.path.join(*module.split('.'), p) for p in sources],
120 | include_dirs=include_dirs,
121 | library_dirs=cuda_lib_path,
122 | extra_compile_args={
123 | 'cxx': ['-O3'],
124 | 'nvcc': [
125 | '-O3',
126 | # '-D__CUDA_NO_HALF_OPERATORS__',
127 | # '-D__CUDA_NO_HALF_CONVERSIONS__',
128 | # '-D__CUDA_NO_HALF2_OPERATORS__',
129 | ]
130 | })
131 |
132 |
133 | def tree_filter_files():
134 |
135 | extensions_dir = 'contextlab/layers/tree_filter/src'
136 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
137 | source_cpu = glob.glob(os.path.join(extensions_dir, "*", "*.cpp"))
138 | source_cuda = glob.glob(os.path.join(extensions_dir, "*", "*.cu"))
139 |
140 | sources = source_cpu + source_cuda + main_file
141 |
142 | return extensions_dir, sources
143 |
144 | if __name__ == "__main__":
145 |
146 | write_version_py()
147 |
148 | tree_extensions_dir, tree_sources = tree_filter_files()
149 |
150 | setup(
151 | name='contextlab',
152 | version=get_version(),
153 | author="Songyang Zhang",
154 | url="https://github.com/SHTUPLUS/contextlab",
155 | long_description=readme(),
156 | description="Context Feature Augmentation Lab developed with PyTorch from ShanghaiTech PLUS Lab",
157 | packages=find_packages(exclude=("src",)),
158 |         license='MIT',
159 | install_requires=requirements,
160 | ext_modules=[
161 | # make_cuda_ext(
162 | # name='tree_filter_cuda',
163 | # module='contextlab.layers.tree_filter',
164 | # include_dirs=[tree_extensions_dir],
165 | # sources=tree_sources),
166 | CUDAExtension(
167 | name='contextlab.layers.tree_filter.functions.tree_filter_cuda',
168 | # module='contextlab.layers.tree_filter',
169 | include_dirs=[tree_extensions_dir],
170 | sources=tree_sources,
171 | library_dirs=cuda_lib_path,
172 | extra_compile_args={'cxx':['-O3'],
173 | 'nvcc':['-O3']}),
174 | CUDAExtension(
175 | name='contextlab.layers.cc_attention.rcca',
176 | sources=['contextlab/layers/cc_attention/src/lib_cffi.cpp',
177 | 'contextlab/layers/cc_attention/src/ca.cu'],
178 | extra_compile_args= ['-std=c++11'],
179 | extra_cflags=["-O3"],
180 | extra_cuda_cflags=["--expt-extended-lambda"],
181 | )
182 | ],
183 | cmdclass={
184 | 'build_ext': BuildExtension
185 | },
186 | zip_safe=False
187 | )
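188 | 
189 | # Build note (sketch): `pip install -e .` or `python setup.py develop` compiles both
190 | # CUDA extensions above; as the checks near the top of this file enforce, this
191 | # requires a CUDA-enabled torch build and LD_LIBRARY_PATH pointing at the CUDA libraries.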
--------------------------------------------------------------------------------
/src/images/contextlab_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHTUPLUS/ContextLab/4e12f0af9d0640f29c763b915f02de763b577200/src/images/contextlab_logo.png
--------------------------------------------------------------------------------
/test/utils/layer_misc.py:
--------------------------------------------------------------------------------
1 | # from contextlab.layers.long_tail import CategoryAttentionNetwork
2 | from contextlab.utils import GraphAdjNetwork
3 |
4 | import torch
5 |
6 | if __name__ == "__main__":
7 | inputs = torch.randn(2, 1024, 8).permute(0, 2, 1)
8 | network = GraphAdjNetwork(
9 | pair_function='embedded_gaussian',
10 | in_channels=1024,
11 | channel_stride=8,
12 | )
13 | output = network(inputs)
14 | print('Embedding Gaussian run pass')
15 |
16 | network = GraphAdjNetwork(
17 | pair_function='gaussian',
18 | in_channels=1024,
19 | channel_stride=8,
20 | )
21 | output = network(inputs)
22 |
23 | print('Gaussian run pass')
24 |
25 | network = GraphAdjNetwork(
26 | pair_function='diff_learnable',
27 | in_channels=1024,
28 | channel_stride=8,
29 | )
30 | output = network(inputs)
31 |
32 | print('Diff learnable run pass')
33 |
34 | network = GraphAdjNetwork(
35 | pair_function='sum_learnable',
36 | in_channels=1024,
37 | channel_stride=8,
38 | )
39 | output = network(inputs)
40 | print('Sum learnable run pass')
41 |
42 | network = GraphAdjNetwork(
43 | pair_function='cat_learnable',
44 | in_channels=1024,
45 | channel_stride=8,
46 | )
47 | output = network(inputs)
48 | print('Cat learnable run pass')
49 |
--------------------------------------------------------------------------------