├── .gitignore
├── LICENSE
├── README.md
├── group_norm.py
├── package
│   ├── README.md
│   ├── build.py
│   ├── my_package
│   │   ├── __init__.py
│   │   ├── functions
│   │   │   ├── __init__.py
│   │   │   └── add.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   └── add.py
│   │   └── src
│   │       ├── my_lib.c
│   │       ├── my_lib.h
│   │       ├── my_lib_cuda.c
│   │       └── my_lib_cuda.h
│   └── setup.py
├── script
│   ├── README.md
│   ├── _ext
│   │   ├── __init__.py
│   │   └── my_lib
│   │       └── __init__.py
│   ├── build.py
│   ├── functions
│   │   ├── __init__.py
│   │   └── add.py
│   ├── modules
│   │   ├── __init__.py
│   │   └── add.py
│   └── src
│       ├── my_lib.c
│       ├── my_lib.h
│       ├── my_lib_cuda.c
│       ├── my_lib_cuda.h
│       └── new_lib
│           ├── new_lib.c
│           ├── new_lib.h
│           ├── new_lib_cuda.c
│           └── new_lib_cuda.h
└── test.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | test.py
3 | script/_ext/
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | env/
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | .hypothesis/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 |
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 |
65 | # Scrapy stuff:
66 | .scrapy
67 |
68 | # Sphinx documentation
69 | docs/_build/
70 |
71 | # PyBuilder
72 | target/
73 |
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 |
77 | # pyenv
78 | .python-version
79 |
80 | # celery beat schedule file
81 | celerybeat-schedule
82 |
83 | # SageMath parsed files
84 | *.sage.py
85 |
86 | # dotenv
87 | .env
88 |
89 | # virtualenv
90 | .venv
91 | venv/
92 | ENV/
93 |
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 |
98 | # Rope project settings
99 | .ropeproject
100 |
101 | # mkdocs documentation
102 | /site
103 |
104 | # mypy
105 | .mypy_cache/
106 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Ligeng Zhu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GroupNorm.pytorch
2 | PyTorch implementation of [Group Normalization](https://arxiv.org/abs/1803.08494).
3 |
4 | * Python version is ready.
5 | * C++/CUDA version is coming soon.
--------------------------------------------------------------------------------
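
For reference, Group Normalization as described in the paper splits the C channels into G groups and normalizes each group over its m = (C/G)·H·W activations (a sketch of the paper's formulation; γ and β are the per-channel `weight` and `bias` in `group_norm.py` below):

```latex
\mu_g = \frac{1}{m}\sum_{i\in\mathcal{S}_g} x_i,\qquad
\sigma_g^2 = \frac{1}{m}\sum_{i\in\mathcal{S}_g}\left(x_i-\mu_g\right)^2,\qquad
\hat{x}_i = \frac{x_i-\mu_g}{\sqrt{\sigma_g^2+\epsilon}},\qquad
y_i = \gamma\,\hat{x}_i+\beta
```

Here \(\mathcal{S}_g\) is the set of activations belonging to group \(g\); unlike Batch Norm, none of these statistics depend on the batch dimension.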
/group_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from torch.nn import Parameter
5 |
6 |
7 | class GroupNorm(nn.Module):
8 | def __init__(self, num_features, num_groups=32, eps=1e-5):
9 | super(GroupNorm, self).__init__()
10 | self.weight = nn.Parameter(torch.ones(1, num_features, 1, 1))
11 | self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1))
12 | self.num_groups = num_groups
13 | self.eps = eps
14 |
15 | def forward(self, x):
16 | N, C, H, W = x.size()
17 | G = self.num_groups
18 | assert C % G == 0
19 |
20 | x = x.view(N, G, -1)
21 | mean = x.mean(-1, keepdim=True)
22 | var = x.var(-1, keepdim=True)
23 |
24 | x = (x - mean) / (var + self.eps).sqrt()
25 | x = x.view(N, C, H, W)
26 | return x * self.weight + self.bias
27 |
28 |
29 | class GroupNormMoving(nn.Module):
30 | def __init__(self, num_features, num_groups=32, eps=1e-5,
31 | momentum=0.1, affine=True,
32 | track_running_stats=True
33 | ):
34 | super(GroupNormMoving, self).__init__()
35 |
36 | self.num_features = num_features
37 | self.num_groups = num_groups
38 | self.eps = eps
39 |
40 | self.momentum = momentum
41 | self.affine = affine
42 |
43 | self.track_running_stats = track_running_stats
44 |
45 | tensor_shape = (1, num_features, 1, 1)
46 |
47 | if self.affine:
48 | self.weight = Parameter(torch.Tensor(*tensor_shape))
49 | self.bias = Parameter(torch.Tensor(*tensor_shape))
50 | else:
51 | self.register_parameter('weight', None)
52 | self.register_parameter('bias', None)
53 |
54 |         if self.track_running_stats:
55 |             # The running statistics are created lazily in forward(),
56 |             # once the grouped shape (N, G, 1) is known, so only
57 |             # placeholder parameters are registered here.
58 |             self.register_parameter('running_mean', None)
59 |             self.register_parameter('running_var', None)
60 |         self.reset_parameters()
61 |
62 | def forward(self, x):
63 | N, C, H, W = x.size()
64 | G = self.num_groups
65 | assert C % G == 0, "Channel must be divided by groups"
66 |
67 | x = x.view(N, G, -1)
68 | mean = x.mean(-1, keepdim=True)
69 | var = x.var(-1, keepdim=True)
70 |
71 |         if self.running_mean is None or self.running_mean.size() != mean.size():
72 |             # Lazily initialize the running statistics from the first
73 |             # batch, once the grouped shape is known.
74 |             self.running_mean = Parameter(mean.data.clone())
75 |             self.running_var = Parameter(var.data.clone())
76 |
77 | if self.training and self.track_running_stats:
78 |             self.running_mean.data = mean.data * self.momentum + \
79 |                 self.running_mean.data * (1 - self.momentum)
80 |             self.running_var.data = var.data * self.momentum + \
81 |                 self.running_var.data * (1 - self.momentum)
82 |
83 |         # Normalization always uses the moving statistics; this is what
84 |         # distinguishes GroupNormMoving from the plain GroupNorm above.
85 |
86 | x = (x - self.running_mean) / (self.running_var + self.eps).sqrt()
87 | x = x.view(N, C, H, W)
88 | return x * self.weight + self.bias
89 |
90 | def reset_parameters(self):
91 | if self.track_running_stats:
92 | if self.running_mean is not None and self.running_var is not None:
93 | self.running_mean.zero_()
94 | self.running_var.fill_(1)
95 | if self.affine:
96 | self.weight.data.uniform_()
97 | self.bias.data.zero_()
98 |
99 | def __repr__(self):
100 | return ('{name}({num_features}, eps={eps}, momentum={momentum},'
101 | ' affine={affine}, track_running_stats={track_running_stats})'
102 | .format(name=self.__class__.__name__, **self.__dict__))
--------------------------------------------------------------------------------
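
A minimal sanity-check sketch for the `GroupNorm` module above (assumes a recent PyTorch where `torch.allclose` is available; the reference computation simply mirrors the module's forward pass):

```python
import torch

from group_norm import GroupNorm

# 2 images, 4 channels split into 2 groups of 2 channels each
m = GroupNorm(num_features=4, num_groups=2)
x = torch.randn(2, 4, 8, 8)
y = m(x)

# Recompute the normalization by hand for comparison
N, C, H, W = x.size()
g = x.view(N, 2, -1)
ref = (g - g.mean(-1, keepdim=True)) / (g.var(-1, keepdim=True) + 1e-5).sqrt()
ref = ref.view(N, C, H, W) * m.weight + m.bias

print(torch.allclose(y, ref))  # expected: True
```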
/package/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch FFI package
2 |
3 | This example shows how to structure the code to create an ffi package for
4 | PyTorch. It can later be distributed via pip.
5 |
6 | ### Required files:
7 |
 8 | * `setup.py` - setuptools file that defines package metadata and some extension
 9 |   options
10 | * `build.py` - cffi build file. It defines the extensions and builds them
11 |   when executed.
12 |
--------------------------------------------------------------------------------
/package/build.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.ffi import create_extension
4 |
5 | this_file = os.path.dirname(__file__)
6 |
 7 | sources = ['my_package/src/my_lib.c']
 8 | headers = ['my_package/src/my_lib.h']
9 | defines = []
10 | with_cuda = False
11 |
12 | if torch.cuda.is_available():
13 | print('Including CUDA code.')
14 |     sources += ['my_package/src/my_lib_cuda.c']
15 |     headers += ['my_package/src/my_lib_cuda.h']
16 | defines += [('WITH_CUDA', None)]
17 | with_cuda = True
18 |
19 | ffi = create_extension(
20 |     'my_package._ext.my_lib',
21 | package=True,
22 | headers=headers,
23 | sources=sources,
24 | define_macros=defines,
25 | relative_to=__file__,
26 | with_cuda=with_cuda
27 | )
28 |
29 | if __name__ == '__main__':
30 | ffi.build()
31 |
--------------------------------------------------------------------------------
/package/my_package/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lyken17/GroupNorm.pytorch/800975c77a839d6d85daae9698085a5ef38d8d17/package/my_package/__init__.py
--------------------------------------------------------------------------------
/package/my_package/functions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lyken17/GroupNorm.pytorch/800975c77a839d6d85daae9698085a5ef38d8d17/package/my_package/functions/__init__.py
--------------------------------------------------------------------------------
/package/my_package/functions/add.py:
--------------------------------------------------------------------------------
1 | # functions/add.py
2 | import torch
3 | from torch.autograd import Function
4 | from .._ext import my_lib
5 |
6 |
7 | class MyAddFunction(Function):
8 | def forward(self, input1, input2):
9 | output = input1.new()
10 | if not input1.is_cuda:
11 | my_lib.my_lib_add_forward(input1, input2, output)
12 | else:
13 | my_lib.my_lib_add_forward_cuda(input1, input2, output)
14 | return output
15 |
16 | def backward(self, grad_output):
17 | grad_input = grad_output.new()
18 | if not grad_output.is_cuda:
19 | my_lib.my_lib_add_backward(grad_output, grad_input)
20 | else:
21 | my_lib.my_lib_add_backward_cuda(grad_output, grad_input)
22 | return grad_input
23 |
--------------------------------------------------------------------------------
/package/my_package/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lyken17/GroupNorm.pytorch/800975c77a839d6d85daae9698085a5ef38d8d17/package/my_package/modules/__init__.py
--------------------------------------------------------------------------------
/package/my_package/modules/add.py:
--------------------------------------------------------------------------------
1 | from torch.nn.modules.module import Module
2 | from ..functions.add import MyAddFunction
3 |
4 | class MyAddModule(Module):
5 | def forward(self, input1, input2):
6 | return MyAddFunction()(input1, input2)
7 |
--------------------------------------------------------------------------------
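
Once the extension has been built (`python build.py`) so that `my_package/_ext/my_lib` exists, a minimal usage sketch of the module above might look like this (`Variable` reflects the pre-0.4 PyTorch API used throughout this repo):

```python
import torch
from torch.autograd import Variable

from my_package.modules.add import MyAddModule

add = MyAddModule()
a = Variable(torch.randn(5, 5))
b = Variable(torch.randn(5, 5))

# Computed by the C extension; equivalent to a + b
print(add(a, b))
```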
/package/my_package/src/my_lib.c:
--------------------------------------------------------------------------------
1 | #include <TH/TH.h>
2 |
3 | int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
4 | THFloatTensor *output)
5 | {
6 | if (!THFloatTensor_isSameSizeAs(input1, input2))
7 | return 0;
8 | THFloatTensor_resizeAs(output, input1);
9 | THFloatTensor_cadd(output, input1, 1.0, input2);
10 | return 1;
11 | }
12 |
13 | int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input)
14 | {
15 | THFloatTensor_resizeAs(grad_input, grad_output);
16 | THFloatTensor_fill(grad_input, 1);
17 | return 1;
18 | }
19 |
--------------------------------------------------------------------------------
/package/my_package/src/my_lib.h:
--------------------------------------------------------------------------------
1 | int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
2 | THFloatTensor *output);
3 | int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input);
4 |
--------------------------------------------------------------------------------
/package/my_package/src/my_lib_cuda.c:
--------------------------------------------------------------------------------
1 | #include <THC/THC.h>
2 |
3 | // this symbol will be resolved automatically from PyTorch libs
4 | extern THCState *state;
5 |
6 | int my_lib_add_forward_cuda(THCudaTensor *input1, THCudaTensor *input2,
7 | THCudaTensor *output)
8 | {
9 | if (!THCudaTensor_isSameSizeAs(state, input1, input2))
10 | return 0;
11 | THCudaTensor_resizeAs(state, output, input1);
12 | THCudaTensor_cadd(state, output, input1, 1.0, input2);
13 | return 1;
14 | }
15 |
16 | int my_lib_add_backward_cuda(THCudaTensor *grad_output, THCudaTensor *grad_input)
17 | {
18 | THCudaTensor_resizeAs(state, grad_input, grad_output);
19 | THCudaTensor_fill(state, grad_input, 1);
20 | return 1;
21 | }
22 |
--------------------------------------------------------------------------------
/package/my_package/src/my_lib_cuda.h:
--------------------------------------------------------------------------------
1 | int my_lib_add_forward_cuda(THCudaTensor *input1, THCudaTensor *input2,
2 | THCudaTensor *output);
3 | int my_lib_add_backward_cuda(THCudaTensor *grad_output, THCudaTensor *grad_input);
4 |
--------------------------------------------------------------------------------
/package/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import os
4 | import sys
5 |
6 | from setuptools import setup, find_packages
7 |
8 | import build
9 |
10 | this_file = os.path.dirname(__file__)
11 |
12 | setup(
13 | name="my_package",
14 | version="0.1",
15 | description="An example project using PyTorch FFI",
16 | url="https://github.com/pytorch/ffi-examples",
17 | author="XYZ",
18 | author_email="author@email.com",
19 | # Require cffi.
20 | install_requires=["cffi>=1.0.0"],
21 | setup_requires=["cffi>=1.0.0"],
22 | # Exclude the build files.
23 | packages=find_packages(exclude=["build"]),
24 |     # Package in which to put the extensions; must be a prefix of the extension name in build.py.
25 | ext_package="",
26 | # Extensions to compile.
27 | cffi_modules=[
28 | os.path.join(this_file, "build.py:ffi")
29 | ],
30 | )
31 |
--------------------------------------------------------------------------------
/script/README.md:
--------------------------------------------------------------------------------
1 | # An example C extension for PyTorch
2 |
3 | This example showcases adding a neural network layer that adds two input Tensors.
4 |
5 | - src: C source code
6 | - functions: the autograd functions
7 | - modules: code of the nn module
8 | - build.py: a small file that compiles your module to be ready to use
9 | - test.py: an example file that loads and uses the extension
10 |
11 | ```bash
12 | python build.py
13 | python test.py
14 | ```
15 |
--------------------------------------------------------------------------------
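
No `test.py` is checked in under `script/`, but a hypothetical version, run from the `script/` directory after `python build.py`, might look like:

```python
import torch
from torch.autograd import Variable

from modules.add import MyAddModule

my_add = MyAddModule()
x = Variable(torch.randn(4, 4))
y = Variable(torch.randn(4, 4))

# CPU path: dispatches to my_lib_add_forward
print(my_add(x, y))

# GPU path: dispatches to my_lib_add_forward_cuda, if available
if torch.cuda.is_available():
    print(my_add(x.cuda(), y.cuda()))
```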
/script/_ext/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lyken17/GroupNorm.pytorch/800975c77a839d6d85daae9698085a5ef38d8d17/script/_ext/__init__.py
--------------------------------------------------------------------------------
/script/_ext/my_lib/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.utils.ffi import _wrap_function
3 | from ._my_lib import lib as _lib, ffi as _ffi
4 |
5 | __all__ = []
6 | def _import_symbols(locals):
7 | for symbol in dir(_lib):
8 | fn = getattr(_lib, symbol)
9 | if callable(fn):
10 | locals[symbol] = _wrap_function(fn, _ffi)
11 | else:
12 | locals[symbol] = fn
13 | __all__.append(symbol)
14 |
15 | _import_symbols(locals())
16 |
--------------------------------------------------------------------------------
/script/build.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 |
4 | import torch
5 | from torch.utils.ffi import create_extension
6 |
7 |
8 | this_file = os.path.dirname(__file__)
9 |
10 | # sources = ['src/new_lib.c']
11 | # headers = ['src/new_lib.h']
12 |
13 | all_sources = glob.glob("src/**/*.c", recursive=True)
14 | all_headers = glob.glob("src/**/*.h", recursive=True)
15 |
16 | cu_sources = glob.glob("src/**/*_cuda.c", recursive=True)
17 | cu_headers = glob.glob("src/**/*_cuda.h", recursive=True)
18 |
19 | sources = list(set(all_sources) - set(cu_sources))
20 | headers = list(set(all_headers) - set(cu_headers))
21 |
22 | defines = []
23 |
24 | with_cuda = False
25 |
26 | if torch.cuda.is_available():
27 | print('Including CUDA code.')
28 | # sources += ['src/new_lib_cuda.c']
29 | sources += cu_sources
30 | # headers += ['src/new_lib_cuda.h']
31 | headers += cu_headers
32 | defines += [('WITH_CUDA', None)]
33 | with_cuda = True
34 |
35 | ffi = create_extension(
36 |     '_ext.my_lib',
37 | headers=headers,
38 | sources=sources,
39 | define_macros=defines,
40 | relative_to=__file__,
41 | with_cuda=with_cuda,
42 | extra_compile_args=["-std=c99"]
43 | )
44 |
45 | if __name__ == '__main__':
46 | ffi.build()
47 |
--------------------------------------------------------------------------------
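
A note on the globbing above: Python's `glob` only lets `**` match across directory levels when `recursive=True` is passed, which is why every pattern sets it; a quick illustration of the difference:

```python
import glob

# Without recursive=True, '**' degenerates to a single '*': the pattern
# 'src/**/*.c' then matches only files one directory below src/
# (e.g. src/new_lib/new_lib.c) and misses src/my_lib.c itself.
print(glob.glob("src/**/*.c"))

# With recursive=True, '**' matches zero or more directories, so both
# src/*.c and src/new_lib/*.c are picked up.
print(glob.glob("src/**/*.c", recursive=True))
```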
/script/functions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lyken17/GroupNorm.pytorch/800975c77a839d6d85daae9698085a5ef38d8d17/script/functions/__init__.py
--------------------------------------------------------------------------------
/script/functions/add.py:
--------------------------------------------------------------------------------
1 | # functions/add.py
2 | import torch
3 | from torch.autograd import Function
4 | from _ext import my_lib
5 |
6 |
7 | class MyAddFunction(Function):
8 | def forward(self, input1, input2):
9 | output = input1.new()
10 | if not input1.is_cuda:
11 | my_lib.my_lib_add_forward(input1, input2, output)
12 | else:
13 | my_lib.my_lib_add_forward_cuda(input1, input2, output)
14 | return output
15 |
16 | def backward(self, grad_output):
17 | grad_input = grad_output.new()
18 | if not grad_output.is_cuda:
19 | my_lib.my_lib_add_backward(grad_output, grad_input)
20 | else:
21 | my_lib.my_lib_add_backward_cuda(grad_output, grad_input)
22 | return grad_input
23 |
--------------------------------------------------------------------------------
/script/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lyken17/GroupNorm.pytorch/800975c77a839d6d85daae9698085a5ef38d8d17/script/modules/__init__.py
--------------------------------------------------------------------------------
/script/modules/add.py:
--------------------------------------------------------------------------------
1 | from torch.nn.modules.module import Module
2 | from functions.add import MyAddFunction
3 |
4 | class MyAddModule(Module):
5 | def forward(self, input1, input2):
6 | return MyAddFunction()(input1, input2)
7 |
--------------------------------------------------------------------------------
/script/src/my_lib.c:
--------------------------------------------------------------------------------
1 | #include <TH/TH.h>
2 |
3 | int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
4 | THFloatTensor *output)
5 | {
6 | if (!THFloatTensor_isSameSizeAs(input1, input2))
7 | return 0;
8 | THFloatTensor_resizeAs(output, input1);
9 | THFloatTensor_cadd(output, input1, 1.0, input2);
10 | return 1;
11 | }
12 |
13 | int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input)
14 | {
15 | THFloatTensor_resizeAs(grad_input, grad_output);
16 | THFloatTensor_fill(grad_input, 1);
17 | return 1;
18 | }
19 |
--------------------------------------------------------------------------------
/script/src/my_lib.h:
--------------------------------------------------------------------------------
1 | int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
2 | THFloatTensor *output);
3 | int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input);
4 |
--------------------------------------------------------------------------------
/script/src/my_lib_cuda.c:
--------------------------------------------------------------------------------
1 | #include <THC/THC.h>
2 |
3 | // this symbol will be resolved automatically from PyTorch libs
4 | extern THCState *state;
5 |
6 | int my_lib_add_forward_cuda(THCudaTensor *input1, THCudaTensor *input2,
7 | THCudaTensor *output)
8 | {
9 | if (!THCudaTensor_isSameSizeAs(state, input1, input2))
10 | return 0;
11 | THCudaTensor_resizeAs(state, output, input1);
12 | THCudaTensor_cadd(state, output, input1, 1.0, input2);
13 | return 1;
14 | }
15 |
16 | int my_lib_add_backward_cuda(THCudaTensor *grad_output, THCudaTensor *grad_input)
17 | {
18 | THCudaTensor_resizeAs(state, grad_input, grad_output);
19 | THCudaTensor_fill(state, grad_input, 1);
20 | return 1;
21 | }
22 |
--------------------------------------------------------------------------------
/script/src/my_lib_cuda.h:
--------------------------------------------------------------------------------
1 | int my_lib_add_forward_cuda(THCudaTensor *input1, THCudaTensor *input2,
2 | THCudaTensor *output);
3 | int my_lib_add_backward_cuda(THCudaTensor *grad_output, THCudaTensor *grad_input);
4 |
--------------------------------------------------------------------------------
/script/src/new_lib/new_lib.c:
--------------------------------------------------------------------------------
1 | #include <TH/TH.h>
2 |
3 | int new_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
4 | THFloatTensor *output)
5 | {
6 | if (!THFloatTensor_isSameSizeAs(input1, input2))
7 | return 0;
8 | THFloatTensor_resizeAs(output, input1);
9 | THFloatTensor_csub(output, input1, 1.0, input2);
10 | return 1;
11 | }
12 |
13 | int new_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input)
14 | {
15 | THFloatTensor_resizeAs(grad_input, grad_output);
16 | THFloatTensor_fill(grad_input, -1);
17 | return 1;
18 | }
19 |
--------------------------------------------------------------------------------
/script/src/new_lib/new_lib.h:
--------------------------------------------------------------------------------
1 | int new_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
2 | THFloatTensor *output);
3 | int new_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input);
4 |
--------------------------------------------------------------------------------
/script/src/new_lib/new_lib_cuda.c:
--------------------------------------------------------------------------------
1 | #include <THC/THC.h>
2 |
3 | // this symbol will be resolved automatically from PyTorch libs
4 | extern THCState *state;
5 |
6 | int new_lib_add_forward_cuda(THCudaTensor *input1, THCudaTensor *input2,
7 | THCudaTensor *output)
8 | {
9 | if (!THCudaTensor_isSameSizeAs(state, input1, input2))
10 | return 0;
11 | THCudaTensor_resizeAs(state, output, input1);
12 | THCudaTensor_csub(state, output, input1, 1.0, input2);
13 | return 1;
14 | }
15 |
16 | int new_lib_add_backward_cuda(THCudaTensor *grad_output, THCudaTensor *grad_input)
17 | {
18 | THCudaTensor_resizeAs(state, grad_input, grad_output);
19 | THCudaTensor_fill(state, grad_input, -1);
20 | return 1;
21 | }
22 |
--------------------------------------------------------------------------------
/script/src/new_lib/new_lib_cuda.h:
--------------------------------------------------------------------------------
1 | int new_lib_add_forward_cuda(THCudaTensor *input1, THCudaTensor *input2,
2 | THCudaTensor *output);
3 | int new_lib_add_backward_cuda(THCudaTensor *grad_output, THCudaTensor *grad_input);
4 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch.autograd import Variable
4 |
5 | from group_norm import GroupNormMoving, GroupNorm
6 |
7 | if __name__ == "__main__":
8 | m = GroupNormMoving(64)
9 | input = Variable(torch.randn(3, 64, 32, 32))
10 |
11 | output = m(input)
12 | print(output.size())
13 |
--------------------------------------------------------------------------------