├── .gitignore
├── LICENSE
├── README.md
├── group_norm.py
├── package
│   ├── README.md
│   ├── build.py
│   ├── my_package
│   │   ├── __init__.py
│   │   ├── functions
│   │   │   ├── __init__.py
│   │   │   └── add.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   └── add.py
│   │   └── src
│   │       ├── my_lib.c
│   │       ├── my_lib.h
│   │       ├── my_lib_cuda.c
│   │       └── my_lib_cuda.h
│   └── setup.py
├── script
│   ├── README.md
│   ├── _ext
│   │   ├── __init__.py
│   │   └── my_lib
│   │       └── __init__.py
│   ├── build.py
│   ├── functions
│   │   ├── __init__.py
│   │   └── add.py
│   ├── modules
│   │   ├── __init__.py
│   │   └── add.py
│   └── src
│       ├── my_lib.c
│       ├── my_lib.h
│       ├── my_lib_cuda.c
│       ├── my_lib_cuda.h
│       └── new_lib
│           ├── new_lib.c
│           ├── new_lib.h
│           ├── new_lib_cuda.c
│           └── new_lib_cuda.h
└── test.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | test.py
3 | script/_ext/
4 | 
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 | 
10 | # C extensions
11 | *.so
12 | 
13 | # Distribution / packaging
14 | .Python
15 | env/
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | 
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 | 
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 | 
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | .hypothesis/
52 | 
53 | # Translations
54 | *.mo
55 | *.pot
56 | 
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | 
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 | 
65 | # Scrapy stuff:
66 | .scrapy
67 | 
68 | # Sphinx documentation
69 | docs/_build/
70 | 
71 | # PyBuilder
72 | target/
73 | 
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 | 
77 | # pyenv
78 | .python-version
79 | 
80 | # celery beat schedule file
81 | celerybeat-schedule
82 | 
83 | # SageMath parsed files
84 | *.sage.py
85 | 
86 | # dotenv
87 | .env
88 | 
89 | # virtualenv
90 | .venv
91 | venv/
92 | ENV/
93 | 
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 | 
98 | # Rope project settings
99 | .ropeproject
100 | 
101 | # mkdocs documentation
102 | /site
103 | 
104 | # mypy
105 | .mypy_cache/
106 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 Ligeng Zhu
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GroupNorm.pytorch
2 | PyTorch implementation of Group Normalization (https://arxiv.org/abs/1803.08494).
3 | 
4 | * The Python version is ready.
5 | * A C++/CUDA version is coming soon.
--------------------------------------------------------------------------------
/group_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | from torch.nn import Parameter
5 | 
6 | 
7 | class GroupNorm(nn.Module):
8 |     def __init__(self, num_features, num_groups=32, eps=1e-5):
9 |         super(GroupNorm, self).__init__()
10 |         self.weight = nn.Parameter(torch.ones(1, num_features, 1, 1))
11 |         self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1))
12 |         self.num_groups = num_groups
13 |         self.eps = eps
14 | 
15 |     def forward(self, x):
16 |         N, C, H, W = x.size()
17 |         G = self.num_groups
18 |         assert C % G == 0, "num_features must be divisible by num_groups"
19 | 
20 |         # normalize over all pixels of each of the G channel groups
21 |         x = x.view(N, G, -1)
22 |         mean = x.mean(-1, keepdim=True)
23 |         var = x.var(-1, keepdim=True)
24 | 
25 |         x = (x - mean) / (var + self.eps).sqrt()
26 |         x = x.view(N, C, H, W)
27 |         return x * self.weight + self.bias
28 | 
29 | 
30 | class GroupNormMoving(nn.Module):
31 |     def __init__(self, num_features, num_groups=32, eps=1e-5,
32 |                  momentum=0.1, affine=True,
33 |                  track_running_stats=True):
34 |         super(GroupNormMoving, self).__init__()
35 | 
36 |         self.num_features = num_features
37 |         self.num_groups = num_groups
38 |         self.eps = eps
39 | 
40 |         self.momentum = momentum
41 |         self.affine = affine
42 |         self.track_running_stats = track_running_stats
43 | 
44 |         tensor_shape = (1, num_features, 1, 1)
45 | 
46 |         if self.affine:
47 |             self.weight = Parameter(torch.Tensor(*tensor_shape))
48 |             self.bias = Parameter(torch.Tensor(*tensor_shape))
49 |         else:
50 |             self.register_parameter('weight', None)
51 |             self.register_parameter('bias', None)
52 | 
53 |         # The running statistics are shaped (N, G, 1), which depends on the
54 |         # input, so they are created lazily on the first forward pass.
55 |         self.register_parameter('running_mean', None)
56 |         self.register_parameter('running_var', None)
57 |         self.reset_parameters()
58 | 
59 |     def forward(self, x):
60 |         N, C, H, W = x.size()
61 |         G = self.num_groups
62 |         assert C % G == 0, "num_features must be divisible by num_groups"
63 | 
64 |         x = x.view(N, G, -1)
65 |         mean = x.mean(-1, keepdim=True)
66 |         var = x.var(-1, keepdim=True)
67 | 
68 |         # lazily (re-)initialize the running statistics from the batch stats
69 |         if self.running_mean is None or self.running_mean.size() != mean.size():
70 |             self.running_mean = Parameter(mean.data.clone())
71 |             self.running_var = Parameter(var.data.clone())
72 | 
73 |         if self.training and self.track_running_stats:
74 |             self.running_mean.data = mean.data * self.momentum + \
75 |                 self.running_mean.data * (1 - self.momentum)
76 |             self.running_var.data = var.data * self.momentum + \
77 |                 self.running_var.data * (1 - self.momentum)
78 | 
79 |         x = (x - self.running_mean) / (self.running_var + self.eps).sqrt()
80 |         x = x.view(N, C, H, W)
81 |         return x * self.weight + self.bias
82 | 
83 |     def reset_parameters(self):
84 |         if self.track_running_stats:
85 |             if self.running_mean is not None and self.running_var is not None:
86 |                 self.running_mean.data.zero_()
87 |                 self.running_var.data.fill_(1)
88 |         if self.affine:
89 |             self.weight.data.uniform_()
90 |             self.bias.data.zero_()
91 | 
92 |     def __repr__(self):
93 |         return ('{name}({num_features}, eps={eps}, momentum={momentum},'
94 |                 ' affine={affine}, track_running_stats={track_running_stats})'
95 |                 .format(name=self.__class__.__name__, **self.__dict__))
--------------------------------------------------------------------------------
/package/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch FFI package
2 | 
3 | This example shows how to structure the code to create an FFI package for
4 | PyTorch. It can later be distributed via pip.
5 | 
6 | ### Required files:
7 | 
8 | * `setup.py` - setuptools file that defines the package metadata and some
9 |   extension options
10 | * `build.py` - cffi build file. Defines the extensions and builds them if
11 |   executed.
12 | 
--------------------------------------------------------------------------------
/package/build.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.ffi import create_extension
4 | 
5 | this_file = os.path.dirname(__file__)
6 | 
7 | sources = ['my_package/src/my_lib.c']
8 | headers = ['my_package/src/my_lib.h']
9 | defines = []
10 | with_cuda = False
11 | 
12 | if torch.cuda.is_available():
13 |     print('Including CUDA code.')
14 |     sources += ['my_package/src/my_lib_cuda.c']
15 |     headers += ['my_package/src/my_lib_cuda.h']
16 |     defines += [('WITH_CUDA', None)]
17 |     with_cuda = True
18 | 
19 | ffi = create_extension(
20 |     'my_package._ext.my_lib',
21 |     package=True,
22 |     headers=headers,
23 |     sources=sources,
24 |     define_macros=defines,
25 |     relative_to=__file__,
26 |     with_cuda=with_cuda
27 | )
28 | 
29 | if __name__ == '__main__':
30 |     ffi.build()
31 | 
--------------------------------------------------------------------------------
/package/my_package/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lyken17/GroupNorm.pytorch/800975c77a839d6d85daae9698085a5ef38d8d17/package/my_package/__init__.py
--------------------------------------------------------------------------------
/package/my_package/functions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lyken17/GroupNorm.pytorch/800975c77a839d6d85daae9698085a5ef38d8d17/package/my_package/functions/__init__.py
--------------------------------------------------------------------------------
/package/my_package/functions/add.py:
--------------------------------------------------------------------------------
1 | # functions/add.py
2 | import torch
3 | from torch.autograd import Function
4 | from .._ext import my_lib
5 | 
6 | 
7 | class MyAddFunction(Function):
8 |     def forward(self, input1, input2):
9 |         output = input1.new()
10 |         if not input1.is_cuda:
11 |             my_lib.my_lib_add_forward(input1, input2, output)
12 |         else:
13 |             my_lib.my_lib_add_forward_cuda(input1, input2, output)
14 |         return output
15 | 
16 |     def backward(self, grad_output):
17 |         grad_input = grad_output.new()
18 |         if not grad_output.is_cuda:
19 |             my_lib.my_lib_add_backward(grad_output, grad_input)
20 |         else:
21 |             my_lib.my_lib_add_backward_cuda(grad_output, grad_input)
22 |         return grad_input
23 | 
--------------------------------------------------------------------------------
/package/my_package/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lyken17/GroupNorm.pytorch/800975c77a839d6d85daae9698085a5ef38d8d17/package/my_package/modules/__init__.py
--------------------------------------------------------------------------------
/package/my_package/modules/add.py:
--------------------------------------------------------------------------------
1 | from torch.nn.modules.module import Module
2 | from ..functions.add import MyAddFunction
3 | 
4 | class MyAddModule(Module):
5 |     def forward(self, input1, input2):
6 |         return MyAddFunction()(input1, input2)
7 | 
--------------------------------------------------------------------------------
/package/my_package/src/my_lib.c:
--------------------------------------------------------------------------------
1 | #include <TH/TH.h>
2 | 
3 | int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
4 |                        THFloatTensor *output)
5 | {
6 |   if (!THFloatTensor_isSameSizeAs(input1, input2))
7 |     return 0;
8 |   THFloatTensor_resizeAs(output, input1);
9 |   THFloatTensor_cadd(output, input1, 1.0, input2);
10 |   return 1;
11 | }
12 | 
13 | int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input)
14 | {
15 |   THFloatTensor_resizeAs(grad_input, grad_output);
16 |   THFloatTensor_fill(grad_input, 1);
17 |   return 1;
18 | }
19 | 
--------------------------------------------------------------------------------
/package/my_package/src/my_lib.h:
--------------------------------------------------------------------------------
1 | int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
2 |                        THFloatTensor *output);
3 | int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input);
4 | 
--------------------------------------------------------------------------------
/package/my_package/src/my_lib_cuda.c:
--------------------------------------------------------------------------------
1 | #include <THC/THC.h>
2 | 
3 | // this symbol will be resolved automatically from PyTorch libs
4 | extern THCState *state;
5 | 
6 | int my_lib_add_forward_cuda(THCudaTensor *input1, THCudaTensor *input2,
7 |                             THCudaTensor *output)
8 | {
9 |   if (!THCudaTensor_isSameSizeAs(state, input1, input2))
10 |     return 0;
11 |   THCudaTensor_resizeAs(state, output, input1);
12 |   THCudaTensor_cadd(state, output, input1, 1.0, input2);
13 |   return 1;
14 | }
15 | 
16 | int my_lib_add_backward_cuda(THCudaTensor *grad_output, THCudaTensor *grad_input)
17 | {
18 |   THCudaTensor_resizeAs(state, grad_input, grad_output);
19 |   THCudaTensor_fill(state, grad_input, 1);
20 |   return 1;
21 | }
22 | 
--------------------------------------------------------------------------------
/package/my_package/src/my_lib_cuda.h:
--------------------------------------------------------------------------------
1 | int my_lib_add_forward_cuda(THCudaTensor *input1, THCudaTensor *input2,
2 |                             THCudaTensor *output);
3 | int my_lib_add_backward_cuda(THCudaTensor *grad_output, THCudaTensor *grad_input);
4 | 
--------------------------------------------------------------------------------
/package/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | import os
4 | import sys
5 | 
6 | from setuptools import setup, find_packages
7 | 
8 | import build
9 | 
10 | this_file = os.path.dirname(__file__)
11 | 
12 | setup(
13 |     name="my_package",
14 |     version="0.1",
15 |     description="An example project using PyTorch FFI",
16 |     url="https://github.com/pytorch/ffi-examples",
17 |     author="XYZ",
18 |     author_email="author@email.com",
19 |     # Require cffi.
20 |     install_requires=["cffi>=1.0.0"],
21 |     setup_requires=["cffi>=1.0.0"],
22 |     # Exclude the build files.
23 |     packages=find_packages(exclude=["build"]),
24 |     # Package where to put the extensions. Has to be a prefix of build.py.
25 |     ext_package="",
26 |     # Extensions to compile.
27 |     cffi_modules=[
28 |         os.path.join(this_file, "build.py:ffi")
29 |     ],
30 | )
31 | 
--------------------------------------------------------------------------------
/script/README.md:
--------------------------------------------------------------------------------
1 | # An example C extension for PyTorch
2 | 
3 | This example showcases adding a neural network layer that adds two input
4 | Tensors.
5 | 
6 | - src: C source code
7 | - functions: the autograd functions
8 | - modules: code of the nn module
9 | - build.py: a small file that compiles your module so it is ready to use
10 | - test.py: an example file that loads and uses the extension
11 | 
12 | ```bash
13 | python build.py
14 | python test.py
15 | ```
16 | 
--------------------------------------------------------------------------------
/script/_ext/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lyken17/GroupNorm.pytorch/800975c77a839d6d85daae9698085a5ef38d8d17/script/_ext/__init__.py
--------------------------------------------------------------------------------
/script/_ext/my_lib/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | from torch.utils.ffi import _wrap_function
3 | from ._my_lib import lib as _lib, ffi as _ffi
4 | 
5 | __all__ = []
6 | def _import_symbols(locals):
7 |     for symbol in dir(_lib):
8 |         fn = getattr(_lib, symbol)
9 |         if callable(fn):
10 |             locals[symbol] = _wrap_function(fn, _ffi)
11 |         else:
12 |             locals[symbol] = fn
13 |         __all__.append(symbol)
14 | 
15 | _import_symbols(locals())
16 | 
--------------------------------------------------------------------------------
/script/build.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | 
4 | import torch
5 | from torch.utils.ffi import create_extension
6 | 
7 | 
8 | this_file = os.path.dirname(__file__)
9 | 
10 | # Collect every C source/header under src/, including sub-directories such
11 | # as src/new_lib/. recursive=True is required for ** to cross directories.
12 | all_sources = glob.glob("src/**/*.c", recursive=True)
13 | all_headers = glob.glob("src/**/*.h", recursive=True)
14 | 
15 | cu_sources = glob.glob("src/**/*_cuda.c", recursive=True)
16 | cu_headers = glob.glob("src/**/*_cuda.h", recursive=True)
17 | 
18 | sources = list(set(all_sources) - set(cu_sources))
19 | headers = list(set(all_headers) - set(cu_headers))
20 | 
21 | defines = []
22 | with_cuda = False
23 | 
24 | if torch.cuda.is_available():
25 |     print('Including CUDA code.')
26 |     sources += cu_sources
27 |     headers += cu_headers
28 |     defines += [('WITH_CUDA', None)]
29 |     with_cuda = True
30 | 
31 | ffi = create_extension(
32 |     '_ext.my_lib',
33 |     headers=headers,
34 |     sources=sources,
35 |     define_macros=defines,
36 |     relative_to=__file__,
37 |     with_cuda=with_cuda,
38 |     extra_compile_args=["-std=c99"]
39 | )
40 | 
41 | if __name__ == '__main__':
42 |     ffi.build()
43 | 
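44 | # Usage sketch (assumes the old torch.utils.ffi API, PyTorch <= 0.4): running
45 | # `python build.py` generates _ext/my_lib/ with the compiled ._my_lib inside,
46 | # which functions/add.py then imports via `from _ext import my_lib`.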
--------------------------------------------------------------------------------
/script/functions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lyken17/GroupNorm.pytorch/800975c77a839d6d85daae9698085a5ef38d8d17/script/functions/__init__.py
--------------------------------------------------------------------------------
/script/functions/add.py:
--------------------------------------------------------------------------------
1 | # functions/add.py
2 | import torch
3 | from torch.autograd import Function
4 | from _ext import my_lib
5 | 
6 | 
7 | class MyAddFunction(Function):
8 |     def forward(self, input1, input2):
9 |         output = input1.new()
10 |         if not input1.is_cuda:
11 |             my_lib.my_lib_add_forward(input1, input2, output)
12 |         else:
13 |             my_lib.my_lib_add_forward_cuda(input1, input2, output)
14 |         return output
15 | 
16 |     def backward(self, grad_output):
17 |         grad_input = grad_output.new()
18 |         if not grad_output.is_cuda:
19 |             my_lib.my_lib_add_backward(grad_output, grad_input)
20 |         else:
21 |             my_lib.my_lib_add_backward_cuda(grad_output, grad_input)
22 |         return grad_input
23 | 
--------------------------------------------------------------------------------
/script/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lyken17/GroupNorm.pytorch/800975c77a839d6d85daae9698085a5ef38d8d17/script/modules/__init__.py
--------------------------------------------------------------------------------
/script/modules/add.py:
--------------------------------------------------------------------------------
1 | from torch.nn.modules.module import Module
2 | from functions.add import MyAddFunction
3 | 
4 | class MyAddModule(Module):
5 |     def forward(self, input1, input2):
6 |         return MyAddFunction()(input1, input2)
7 | 
--------------------------------------------------------------------------------
/script/src/my_lib.c:
--------------------------------------------------------------------------------
1 | #include <TH/TH.h>
2 | 
3 | int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
4 |                        THFloatTensor *output)
5 | {
6 |   if (!THFloatTensor_isSameSizeAs(input1, input2))
7 |     return 0;
8 |   THFloatTensor_resizeAs(output, input1);
9 |   THFloatTensor_cadd(output, input1, 1.0, input2);
10 |   return 1;
11 | }
12 | 
13 | int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input)
14 | {
15 |   THFloatTensor_resizeAs(grad_input, grad_output);
16 |   THFloatTensor_fill(grad_input, 1);
17 |   return 1;
18 | }
19 | 
--------------------------------------------------------------------------------
/script/src/my_lib.h:
--------------------------------------------------------------------------------
1 | int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
2 |                        THFloatTensor *output);
3 | int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input);
4 | 
--------------------------------------------------------------------------------
/script/src/my_lib_cuda.c:
--------------------------------------------------------------------------------
1 | #include <THC/THC.h>
2 | 
3 | // this symbol will be resolved automatically from PyTorch libs
4 | extern THCState *state;
5 | 
6 | int my_lib_add_forward_cuda(THCudaTensor *input1, THCudaTensor *input2,
7 |                             THCudaTensor *output)
8 | {
9 |   if (!THCudaTensor_isSameSizeAs(state, input1, input2))
10 |     return 0;
11 |   THCudaTensor_resizeAs(state, output, input1);
12 |   THCudaTensor_cadd(state, output, input1, 1.0, input2);
13 |   return 1;
14 | }
15 | 
16 | int my_lib_add_backward_cuda(THCudaTensor *grad_output, THCudaTensor *grad_input)
17 | {
18 |   THCudaTensor_resizeAs(state, grad_input, grad_output);
19 |   THCudaTensor_fill(state, grad_input, 1);
20 |   return 1;
21 | }
22 | 
--------------------------------------------------------------------------------
/script/src/my_lib_cuda.h:
--------------------------------------------------------------------------------
1 | int my_lib_add_forward_cuda(THCudaTensor *input1, THCudaTensor *input2,
2 |                             THCudaTensor *output);
3 | int my_lib_add_backward_cuda(THCudaTensor *grad_output, THCudaTensor *grad_input);
4 | 
--------------------------------------------------------------------------------
/script/src/new_lib/new_lib.c:
--------------------------------------------------------------------------------
1 | #include <TH/TH.h>
2 | 
3 | /* note: despite the "add" in its name, this variant subtracts (csub) */
4 | int new_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
5 |                         THFloatTensor *output)
6 | {
7 |   if (!THFloatTensor_isSameSizeAs(input1, input2))
8 |     return 0;
9 |   THFloatTensor_resizeAs(output, input1);
10 |   THFloatTensor_csub(output, input1, 1.0, input2);
11 |   return 1;
12 | }
13 | 
14 | int new_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input)
15 | {
16 |   THFloatTensor_resizeAs(grad_input, grad_output);
17 |   THFloatTensor_fill(grad_input, -1);
18 |   return 1;
19 | }
20 | 
--------------------------------------------------------------------------------
/script/src/new_lib/new_lib.h:
--------------------------------------------------------------------------------
1 | int new_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
2 |                         THFloatTensor *output);
3 | int new_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input);
4 | 
--------------------------------------------------------------------------------
/script/src/new_lib/new_lib_cuda.c:
--------------------------------------------------------------------------------
1 | #include <THC/THC.h>
2 | 
3 | // this symbol will be resolved automatically from PyTorch libs
4 | extern THCState *state;
5 | 
6 | int new_lib_add_forward_cuda(THCudaTensor *input1, THCudaTensor *input2,
7 |                              THCudaTensor *output)
8 | {
9 |   if (!THCudaTensor_isSameSizeAs(state, input1, input2))
10 |     return 0;
11 |   THCudaTensor_resizeAs(state, output, input1);
12 |   THCudaTensor_csub(state, output, input1, 1.0, input2);
13 |   return 1;
14 | }
15 | 
16 | int new_lib_add_backward_cuda(THCudaTensor *grad_output, THCudaTensor *grad_input)
17 | {
18 |   THCudaTensor_resizeAs(state, grad_input, grad_output);
19 |   THCudaTensor_fill(state, grad_input, -1);
20 |   return 1;
21 | }
22 | 
--------------------------------------------------------------------------------
/script/src/new_lib/new_lib_cuda.h:
--------------------------------------------------------------------------------
1 | int new_lib_add_forward_cuda(THCudaTensor *input1, THCudaTensor *input2,
2 |                              THCudaTensor *output);
3 | int new_lib_add_backward_cuda(THCudaTensor *grad_output, THCudaTensor *grad_input);
4 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from torch.autograd import Variable
4 | 
5 | from group_norm import GroupNormMoving, GroupNorm
6 | 
7 | if __name__ == "__main__":
8 |     m = GroupNormMoving(64)
9 |     input = Variable(torch.randn(3, 64, 32, 32))
10 | 
11 |     output = m(input)
12 |     print(output.size())
13 | 
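14 |     # Added sanity check (sketch): with GroupNorm's default weight=1 and
15 |     # bias=0, every group of the output should have near-zero mean and
16 |     # near-unit variance.
17 |     gn = GroupNorm(64)
18 |     normed = gn(input).view(3, 32, -1)
19 |     print(normed.mean(-1).abs().max())  # close to 0
20 |     print(normed.var(-1).mean())        # close to 1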
--------------------------------------------------------------------------------
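Usage note: a minimal sketch of `GroupNorm` as a drop-in replacement for `nn.BatchNorm2d`, assuming the Variable-era PyTorch this repo targets and that `group_norm.py` is on the import path. Because the statistics are computed per sample, the layer behaves identically at any batch size, including batch size 1.

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

from group_norm import GroupNorm

# GroupNorm in place of nn.BatchNorm2d: each sample is normalized over its
# own channel groups, so no batch-level statistics are needed.
model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    GroupNorm(64, num_groups=32),
    nn.ReLU(inplace=True),
)

x = Variable(torch.randn(1, 3, 32, 32))  # works even with a single sample
print(model(x).size())  # torch.Size([1, 64, 32, 32])
```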