├── .gitignore ├── LICENSE ├── README.md ├── exp └── train_val_step_se_resnet50.sh ├── fig ├── adaptive_conv.png ├── octave_conv.png ├── res2net.png ├── sablock.png └── srm.png ├── libs ├── __init__.py ├── flops_counter.py ├── logger.py ├── lr_scheduler.py ├── nn │ ├── OCtaveResnet.py │ ├── OctaveConv1.py │ ├── OctaveConv2.py │ ├── __init__.py │ ├── res2net.py │ ├── resnet.py │ ├── resnet_adaptiveconv.py │ ├── resnet_eca.py │ ├── resnet_ge.py │ ├── resnet_se.py │ ├── resnet_sge.py │ ├── resnet_sk.py │ └── resnet_srm.py ├── progress │ ├── __init__.py │ ├── bar.py │ ├── counter.py │ ├── helpers.py │ └── spinner.py └── utils.py ├── main_imagenet.py ├── requirement.txt └── test_speed.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_Store 2 | *.swp 3 | *.pyc 4 | 5 | 6 | data/ 7 | build/ 8 | encoding/_ext/ 9 | encoding.egg-info/ 10 | 11 | .idea -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 XingtaiLi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Beyond Convolution 2 | ## ~~OctaveConv_pytorch~~ 3 | ## PyTorch implementation of recent operators 4 | This is a **third-party** (unofficial) implementation of the following papers. 5 | 1. Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution (ICCV 2019). 6 | [paper](https://arxiv.org/pdf/1904.05049) 7 | ![](fig/octave_conv.png) 8 | 2. Adaptively Connected Neural Networks (CVPR 2019) 9 | [paper](https://arxiv.org/abs/1904.03579) 10 | ![](fig/adaptive_conv.png) 11 | 3. Res2Net: A New Multi-scale Backbone Architecture (PAMI 2019) 12 | [paper](https://arxiv.org/abs/1904.01169) 13 | ![](fig/res2net.png) 14 | 4. ScaleNet: Data-Driven Neuron Allocation for Scale Aggregation Networks (CVPR 2019) 15 | [paper](https://arxiv.org/pdf/1904.09460) 16 | ![](fig/sablock.png) 17 | 5. SRM: A Style-based Recalibration Module for Convolutional Neural Networks 18 | [paper](https://arxiv.org/abs/1903.10829) 19 | ![](fig/srm.png) 20 | 6. SENet: Squeeze-and-Excitation Networks (CVPR 2018) [paper](http://openaccess.thecvf.com/content_cvpr_2018/papers/Hu_Squeeze-and-Excitation_Networks_CVPR_2018_paper.pdf) 21 | 7. GENet: Exploiting Feature Context in Convolutional Neural Networks (NIPS 2018) [paper](https://papers.nips.cc/paper/8151-gather-excite-exploiting-feature-context-in-convolutional-neural-networks.pdf) 22 | 8. ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks [paper](https://arxiv.org/abs/1910.03151) 23 | 9.
SK-Net: Selective Kernel Networks (CVPR 2019) [paper](http://openaccess.thecvf.com/content_CVPR_2019/papers/Li_Selective_Kernel_Networks_CVPR_2019_paper.pdf) 24 | 10. More networks will be added. 25 | 26 | ### Plan 27 | 1. Add a Res2Net block with an SE layer (done) 28 | 2. Add Adaptive Convolution, both pixel-aware and dataset-aware (done) 29 | 3. Training code on ImageNet (done) 30 | 4. Add SE-like models (done) 31 | 5. Keep track of newly proposed operators (-) 32 | 33 | ### Usage 34 | Model definitions live under the libs/nn folder. 35 | 36 | ```python 37 | from libs.nn.OCtaveResnet import Octresnet50 38 | from libs.nn.res2net import se_resnet50 39 | from libs.nn.resnet_adaptiveconv import PixelAwareResnet50, DataSetAwareResnet50 40 | 41 | model = Octresnet50().cuda() 42 | model = se_resnet50().cuda() 43 | model = PixelAwareResnet50().cuda() 44 | model = DataSetAwareResnet50().cuda() 45 | 46 | ``` 47 | ### Training 48 | 49 | See the exp folder for the training scripts (e.g. exp/train_val_step_se_resnet50.sh). 50 | 51 | ### CheckPoint 52 | 53 | 54 | ## Reference and Citation: 55 | 56 | 1. OctaveConv: MXNet implementation [here](https://github.com/terrychenism/OctaveConv) 57 | 2. AdaptiveConv: official TensorFlow implementation [here](https://github.com/wanggrun/Adaptively-Connected-Neural-Networks) 58 | 3. ScaleNet: [here](https://github.com/Eli-YiLi/ScaleNet) 59 | 4. SGENet: [here](https://github.com/implus/PytorchInsight) 60 | 61 | Please consider citing the authors' papers when using this code in your research.
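For a rough sense of the savings Octave Convolution targets, the sketch below counts multiply-accumulates (MACs) with the same rule that `conv_flops_counter_hook` in `libs/flops_counter.py` applies to a `Conv2d`: kernel area × input channels × (output channels / groups) per output position, plus one add per output element when a bias is present. `conv2d_macs` and `octave_conv_macs` are hypothetical helpers for illustration, not part of this repo. The octave helper assumes `alpha_in == alpha_out == alpha`, low-frequency maps at half resolution, and ignores the cost of the pooling and upsampling steps.

```python
def conv2d_macs(in_ch, out_ch, k, out_h, out_w, groups=1, bias=False, batch=1):
    """MACs of one k x k Conv2d call, mirroring conv_flops_counter_hook."""
    per_position = k * k * in_ch * (out_ch // groups)
    positions = batch * out_h * out_w
    macs = per_position * positions
    if bias:
        macs += out_ch * positions  # one add per output element
    return macs


def octave_conv_macs(channels, k, h, w, alpha):
    """Hypothetical helper: conv MACs of an octave convolution with
    alpha_in == alpha_out == alpha; low-frequency maps are h//2 x w//2.
    Pooling and upsampling costs are ignored."""
    c_low = int(alpha * channels)
    c_high = channels - c_low
    h2, w2 = h // 2, w // 2
    return (
        conv2d_macs(c_high, c_high, k, h, w)      # high -> high, full resolution
        + conv2d_macs(c_high, c_low, k, h2, w2)   # high -> low: pool first, conv at half res
        + conv2d_macs(c_low, c_high, k, h2, w2)   # low -> high: conv at half res, then upsample
        + conv2d_macs(c_low, c_low, k, h2, w2)    # low -> low, half resolution
    )


if __name__ == "__main__":
    vanilla = conv2d_macs(64, 64, 3, 56, 56)
    octave = octave_conv_macs(64, 3, 56, 56, alpha=0.5)
    print(vanilla, octave, octave / vanilla)  # 115605504 50577408 0.4375
```

Summing the four paths gives (1 - α)² + α(1 - α)/2 + α²/4 = 1 - (3/4)·α(2 - α) of the vanilla convolution's MACs, i.e. 43.75% at α = 0.5, which is what the example prints.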
62 | ## License 63 | MIT License 64 | -------------------------------------------------------------------------------- /exp/train_val_step_se_resnet50.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python -m torch.distributed.launch --nproc_per_node=8 main_imagenet.py \ 4 | -a se_resnet50 --data /data/lxt/ImageNet \ 5 | --epochs 120 \ 6 | --schedule 30 60 90 \ 7 | --wd 1e-4 --gamma 0.1 \ 8 | --train-batch 64 \ 9 | --pretrained False \ 10 | --pretrained_dir /home/user/pretrained \ 11 | -c checkpoints/imagenet/se_res50_bs_512 \ 12 | --opt-level O0 \ 13 | --wd-all \ 14 | --label-smoothing 0. \ 15 | --warmup_epochs 5 -------------------------------------------------------------------------------- /fig/adaptive_conv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxtGH/OctaveConv_pytorch/079f7da29d55c2eeed8985d33f0b2f765d7a469e/fig/adaptive_conv.png -------------------------------------------------------------------------------- /fig/octave_conv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxtGH/OctaveConv_pytorch/079f7da29d55c2eeed8985d33f0b2f765d7a469e/fig/octave_conv.png -------------------------------------------------------------------------------- /fig/res2net.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxtGH/OctaveConv_pytorch/079f7da29d55c2eeed8985d33f0b2f765d7a469e/fig/res2net.png -------------------------------------------------------------------------------- /fig/sablock.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxtGH/OctaveConv_pytorch/079f7da29d55c2eeed8985d33f0b2f765d7a469e/fig/sablock.png -------------------------------------------------------------------------------- /fig/srm.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxtGH/OctaveConv_pytorch/079f7da29d55c2eeed8985d33f0b2f765d7a469e/fig/srm.png -------------------------------------------------------------------------------- /libs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxtGH/OctaveConv_pytorch/079f7da29d55c2eeed8985d33f0b2f765d7a469e/libs/__init__.py -------------------------------------------------------------------------------- /libs/flops_counter.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | 5 | def get_model_complexity_info(model, input_res, print_per_layer_stat=True, as_strings=True): 6 | assert type(input_res) is tuple 7 | assert len(input_res) == 2 8 | batch = torch.zeros(1, 3, *input_res, device=next(model.parameters()).device)  # dummy input on the model's device 9 | flops_model = add_flops_counting_methods(model) 10 | flops_model.eval().start_flops_count() 11 | out = flops_model(batch) 12 | 13 | if print_per_layer_stat: 14 | print_model_with_flops(flops_model) 15 | flops_count = flops_model.compute_average_flops_cost() 16 | params_count = get_model_parameters_number(flops_model) 17 | flops_model.stop_flops_count() 18 | 19 | if as_strings: 20 | return flops_to_string(flops_count), params_to_string(params_count) 21 | 22 | return flops_count, params_count 23 | 24 | def flops_to_string(flops, units='GMac', precision=2): 25 | if units is None: 26 | if flops // 10**9 > 0: 27 | return str(round(flops / 10.**9, precision)) + ' GMac' 28 | elif flops // 10**6 > 0: 29 | return str(round(flops / 10.**6, precision)) + ' MMac' 30 | elif flops // 10**3 > 0: 31 | return str(round(flops / 10.**3, precision)) + ' KMac' 32 | else: 33 | return str(flops) + ' Mac' 34 | else: 35 | if units == 'GMac': 36 | return str(round(flops / 10.**9, precision)) + ' ' + units 37 | elif units == 'MMac': 38 | return
str(round(flops / 10.**6, precision)) + ' ' + units 39 | elif units == 'KMac': 40 | return str(round(flops / 10.**3, precision)) + ' ' + units 41 | else: 42 | return str(flops) + ' Mac' 43 | 44 | def params_to_string(params_num): 45 | if params_num // 10 ** 6 > 0: 46 | return str(round(params_num / 10 ** 6, 2)) + ' M' 47 | elif params_num // 10 ** 3 > 0: 48 | return str(round(params_num / 10 ** 3, 2)) + ' k' 49 | return str(params_num)  # fewer than 1k parameters (previously fell through and returned None) 50 | def print_model_with_flops(model, units='GMac', precision=3): 51 | total_flops = model.compute_average_flops_cost() 52 | 53 | def accumulate_flops(self): 54 | if is_supported_instance(self): 55 | return self.__flops__ / model.__batch_counter__ 56 | else: 57 | total = 0 58 | for m in self.children(): 59 | total += m.accumulate_flops() 60 | return total 61 | 62 | def flops_repr(self): 63 | accumulated_flops_cost = self.accumulate_flops() 64 | return ', '.join([flops_to_string(accumulated_flops_cost, units=units, precision=precision), 65 | '{:.3%} MACs'.format(accumulated_flops_cost / total_flops), 66 | self.original_extra_repr()]) 67 | 68 | def add_extra_repr(m): 69 | m.accumulate_flops = accumulate_flops.__get__(m) 70 | flops_extra_repr = flops_repr.__get__(m) 71 | if m.extra_repr != flops_extra_repr: 72 | m.original_extra_repr = m.extra_repr 73 | m.extra_repr = flops_extra_repr 74 | assert m.extra_repr != m.original_extra_repr 75 | 76 | def del_extra_repr(m): 77 | if hasattr(m, 'original_extra_repr'): 78 | m.extra_repr = m.original_extra_repr 79 | del m.original_extra_repr 80 | if hasattr(m, 'accumulate_flops'): 81 | del m.accumulate_flops 82 | 83 | model.apply(add_extra_repr) 84 | print(model) 85 | model.apply(del_extra_repr) 86 | 87 | def get_model_parameters_number(model): 88 | params_num = sum(p.numel() for p in model.parameters() if p.requires_grad) 89 | return params_num 90 | 91 | def add_flops_counting_methods(net_main_module): 92 | # adding additional methods to the existing module object, 93 | # this is done this way so that each function has access to
self object 94 | net_main_module.start_flops_count = start_flops_count.__get__(net_main_module) 95 | net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module) 96 | net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module) 97 | net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module) 98 | 99 | net_main_module.reset_flops_count() 100 | 101 | # Adding variables necessary for masked flops computation 102 | net_main_module.apply(add_flops_mask_variable_or_reset) 103 | 104 | return net_main_module 105 | 106 | 107 | def compute_average_flops_cost(self): 108 | """ 109 | A method that will be available after add_flops_counting_methods() is called 110 | on a desired net object. 111 | 112 | Returns current mean flops consumption per image. 113 | 114 | """ 115 | 116 | batches_count = self.__batch_counter__ 117 | flops_sum = 0 118 | for module in self.modules(): 119 | if is_supported_instance(module): 120 | flops_sum += module.__flops__ 121 | 122 | return flops_sum / batches_count 123 | 124 | 125 | def start_flops_count(self): 126 | """ 127 | A method that will be available after add_flops_counting_methods() is called 128 | on a desired net object. 129 | 130 | Activates the computation of mean flops consumption per image. 131 | Call it before you run the network. 132 | 133 | """ 134 | add_batch_counter_hook_function(self) 135 | self.apply(add_flops_counter_hook_function) 136 | 137 | 138 | def stop_flops_count(self): 139 | """ 140 | A method that will be available after add_flops_counting_methods() is called 141 | on a desired net object. 142 | 143 | Stops computing the mean flops consumption per image. 144 | Call whenever you want to pause the computation. 
145 | 146 | """ 147 | remove_batch_counter_hook_function(self) 148 | self.apply(remove_flops_counter_hook_function) 149 | 150 | 151 | def reset_flops_count(self): 152 | """ 153 | A method that will be available after add_flops_counting_methods() is called 154 | on a desired net object. 155 | 156 | Resets statistics computed so far. 157 | 158 | """ 159 | add_batch_counter_variables_or_reset(self) 160 | self.apply(add_flops_counter_variable_or_reset) 161 | 162 | 163 | def add_flops_mask(module, mask): 164 | def add_flops_mask_func(module): 165 | if isinstance(module, torch.nn.Conv2d): 166 | module.__mask__ = mask 167 | module.apply(add_flops_mask_func) 168 | 169 | 170 | def remove_flops_mask(module): 171 | module.apply(add_flops_mask_variable_or_reset) 172 | 173 | 174 | # ---- Internal functions 175 | def is_supported_instance(module): 176 | if isinstance(module, (torch.nn.Conv2d, torch.nn.ReLU, torch.nn.PReLU, torch.nn.ELU, \ 177 | torch.nn.LeakyReLU, torch.nn.ReLU6, torch.nn.Linear, \ 178 | torch.nn.MaxPool2d, torch.nn.AvgPool2d, torch.nn.BatchNorm2d, \ 179 | torch.nn.Upsample, nn.AdaptiveMaxPool2d, nn.AdaptiveAvgPool2d)): 180 | return True 181 | 182 | return False 183 | 184 | 185 | def empty_flops_counter_hook(module, input, output): 186 | module.__flops__ += 0 187 | 188 | 189 | def upsample_flops_counter_hook(module, input, output): 190 | output_size = output  # count the whole batch; output[0] would be only the first sample 191 | batch_size = output_size.shape[0] 192 | output_elements_count = batch_size 193 | for val in output_size.shape[1:]: 194 | output_elements_count *= val 195 | module.__flops__ += output_elements_count 196 | 197 | 198 | def relu_flops_counter_hook(module, input, output): 199 | active_elements_count = output.numel() 200 | module.__flops__ += active_elements_count 201 | 202 | 203 | def linear_flops_counter_hook(module, input, output): 204 | input = input[0] 205 | batch_size = input.shape[0] 206 | module.__flops__ += batch_size * input.shape[1] * output.shape[1] 207 | 208 | 209 | def
pool_flops_counter_hook(module, input, output): 210 | input = input[0] 211 | module.__flops__ += np.prod(input.shape) 212 | 213 | def bn_flops_counter_hook(module, input, output): 214 | # an affine BN applies a scale and a shift per element, doubling the cost 215 | input = input[0] 216 | 217 | batch_flops = np.prod(input.shape) 218 | if module.affine: 219 | batch_flops *= 2 220 | module.__flops__ += batch_flops 221 | 222 | def conv_flops_counter_hook(conv_module, input, output): 223 | # Can have multiple inputs, getting the first one 224 | input = input[0] 225 | 226 | batch_size = input.shape[0] 227 | output_height, output_width = output.shape[2:] 228 | 229 | kernel_height, kernel_width = conv_module.kernel_size 230 | in_channels = conv_module.in_channels 231 | out_channels = conv_module.out_channels 232 | groups = conv_module.groups 233 | 234 | filters_per_channel = out_channels // groups 235 | conv_per_position_flops = kernel_height * kernel_width * in_channels * filters_per_channel 236 | 237 | active_elements_count = batch_size * output_height * output_width 238 | 239 | if conv_module.__mask__ is not None: 240 | # (b, 1, h, w) 241 | flops_mask = conv_module.__mask__.expand(batch_size, 1, output_height, output_width) 242 | active_elements_count = flops_mask.sum() 243 | 244 | overall_conv_flops = conv_per_position_flops * active_elements_count 245 | 246 | bias_flops = 0 247 | 248 | if conv_module.bias is not None: 249 | 250 | bias_flops = out_channels * active_elements_count 251 | 252 | overall_flops = overall_conv_flops + bias_flops 253 | 254 | conv_module.__flops__ += overall_flops 255 | 256 | 257 | def batch_counter_hook(module, input, output): 258 | # Can have multiple inputs, getting the first one 259 | input = input[0] 260 | batch_size = input.shape[0] 261 | module.__batch_counter__ += batch_size 262 | 263 | 264 | def add_batch_counter_variables_or_reset(module): 265 | 266 | module.__batch_counter__ = 0 267 | 268 | 269 | def add_batch_counter_hook_function(module): 270 | if hasattr(module,
'__batch_counter_handle__'): 271 | return 272 | 273 | handle = module.register_forward_hook(batch_counter_hook) 274 | module.__batch_counter_handle__ = handle 275 | 276 | 277 | def remove_batch_counter_hook_function(module): 278 | if hasattr(module, '__batch_counter_handle__'): 279 | module.__batch_counter_handle__.remove() 280 | del module.__batch_counter_handle__ 281 | 282 | 283 | def add_flops_counter_variable_or_reset(module): 284 | if is_supported_instance(module): 285 | module.__flops__ = 0 286 | 287 | 288 | def add_flops_counter_hook_function(module): 289 | if is_supported_instance(module): 290 | if hasattr(module, '__flops_handle__'): 291 | return 292 | 293 | if isinstance(module, torch.nn.Conv2d): 294 | handle = module.register_forward_hook(conv_flops_counter_hook) 295 | elif isinstance(module, (torch.nn.ReLU, torch.nn.PReLU, torch.nn.ELU, \ 296 | torch.nn.LeakyReLU, torch.nn.ReLU6)): 297 | handle = module.register_forward_hook(relu_flops_counter_hook) 298 | elif isinstance(module, torch.nn.Linear): 299 | handle = module.register_forward_hook(linear_flops_counter_hook) 300 | elif isinstance(module, (torch.nn.AvgPool2d, torch.nn.MaxPool2d, nn.AdaptiveMaxPool2d, \ 301 | nn.AdaptiveAvgPool2d)): 302 | handle = module.register_forward_hook(pool_flops_counter_hook) 303 | elif isinstance(module, torch.nn.BatchNorm2d): 304 | handle = module.register_forward_hook(bn_flops_counter_hook) 305 | elif isinstance(module, torch.nn.Upsample): 306 | handle = module.register_forward_hook(upsample_flops_counter_hook) 307 | else: 308 | handle = module.register_forward_hook(empty_flops_counter_hook) 309 | module.__flops_handle__ = handle 310 | 311 | 312 | def remove_flops_counter_hook_function(module): 313 | if is_supported_instance(module): 314 | if hasattr(module, '__flops_handle__'): 315 | module.__flops_handle__.remove() 316 | del module.__flops_handle__ 317 | # --- Masked flops counting 318 | 319 | 320 | # Also being run in the initialization 321 | def 
add_flops_mask_variable_or_reset(module): 322 | if is_supported_instance(module): 323 | module.__mask__ = None 324 | -------------------------------------------------------------------------------- /libs/logger.py: -------------------------------------------------------------------------------- 1 | # A simple torch style logger 2 | # (C) Wei YANG 2017 3 | from __future__ import absolute_import 4 | import matplotlib.pyplot as plt 5 | import os 6 | import sys 7 | import numpy as np 8 | 9 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 10 | 11 | def savefig(fname, dpi=None): 12 | dpi = 150 if dpi == None else dpi 13 | plt.savefig(fname, dpi=dpi) 14 | 15 | def plot_overlap(logger, names=None): 16 | names = logger.names if names == None else names 17 | numbers = logger.numbers 18 | for _, name in enumerate(names): 19 | x = np.arange(len(numbers[name])) 20 | plt.plot(x, np.asarray(numbers[name])) 21 | return [logger.title + '(' + name + ')' for name in names] 22 | 23 | class Logger(object): 24 | '''Save training process to log file with simple plot function.''' 25 | def __init__(self, fpath, title=None, resume=False): 26 | self.file = None 27 | self.resume = resume 28 | self.title = '' if title == None else title 29 | if fpath is not None: 30 | if resume: 31 | self.file = open(fpath, 'r') 32 | name = self.file.readline() 33 | self.names = name.rstrip().split('\t') 34 | self.numbers = {} 35 | for _, name in enumerate(self.names): 36 | self.numbers[name] = [] 37 | 38 | for numbers in self.file: 39 | numbers = numbers.rstrip().split('\t') 40 | for i in range(0, len(numbers)): 41 | self.numbers[self.names[i]].append(numbers[i]) 42 | self.file.close() 43 | self.file = open(fpath, 'a') 44 | else: 45 | self.file = open(fpath, 'w') 46 | 47 | def set_names(self, names): 48 | if self.resume: 49 | pass 50 | # initialize numbers as empty list 51 | self.numbers = {} 52 | self.names = names 53 | for _, name in enumerate(self.names): 54 | self.file.write(name) 55 | 
self.file.write('\t') 56 | self.numbers[name] = [] 57 | self.file.write('\n') 58 | self.file.flush() 59 | 60 | 61 | def append(self, numbers): 62 | assert len(self.names) == len(numbers), 'Numbers do not match names' 63 | for index, num in enumerate(numbers): 64 | self.file.write("{0:.6f}".format(num)) 65 | self.file.write('\t') 66 | self.numbers[self.names[index]].append(num) 67 | self.file.write('\n') 68 | self.file.flush() 69 | 70 | def plot(self, names=None): 71 | names = self.names if names is None else names 72 | numbers = self.numbers 73 | for _, name in enumerate(names): 74 | x = np.arange(len(numbers[name])) 75 | plt.plot(x, np.asarray(numbers[name])) 76 | plt.legend([self.title + '(' + name + ')' for name in names]) 77 | plt.grid(True) 78 | 79 | def close(self): 80 | if self.file is not None: 81 | self.file.close() 82 | 83 | class LoggerMonitor(object): 84 | '''Load and visualize multiple logs.''' 85 | def __init__(self, paths): 86 | '''paths is a dictionary of {name: filepath} pairs''' 87 | self.loggers = [] 88 | for title, path in paths.items(): 89 | logger = Logger(path, title=title, resume=True) 90 | self.loggers.append(logger) 91 | 92 | def plot(self, names=None): 93 | plt.figure() 94 | plt.subplot(121) 95 | legend_text = [] 96 | for logger in self.loggers: 97 | legend_text += plot_overlap(logger, names) 98 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
99 | plt.grid(True) 100 | 101 | if __name__ == '__main__': 102 | # # Example 103 | # logger = Logger('test.txt') 104 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 105 | 106 | # length = 100 107 | # t = np.arange(length) 108 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 109 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 110 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 111 | 112 | # for i in range(0, length): 113 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 114 | # logger.plot() 115 | 116 | # Example: logger monitor 117 | paths = { 118 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 119 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 120 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 121 | } 122 | 123 | field = ['Valid Acc.'] 124 | 125 | monitor = LoggerMonitor(paths) 126 | monitor.plot(names=field) 127 | savefig('test.eps') -------------------------------------------------------------------------------- /libs/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | """Learning Rate Schedulers""" 2 | from __future__ import division 3 | import math 4 | from math import pi, cos 5 | import torch 6 | 7 | class LRScheduler(object): 8 | r"""Learning Rate Scheduler 9 | For mode='step', we multiply lr with `decay_factor` at each epoch in `step`. 10 | For mode='poly':: 11 | lr = targetlr + (baselr - targetlr) * (1 - iter / maxiter) ^ power 12 | For mode='cosine':: 13 | lr = targetlr + (baselr - targetlr) * (1 + cos(pi * iter / maxiter)) / 2 14 | If warmup_epochs > 0, a warmup stage will be inserted before the main lr scheduler. 
15 | For warmup_mode='linear':: 16 | lr = warmup_lr + (baselr - warmup_lr) * iter / max_warmup_iter 17 | For warmup_mode='constant':: 18 | lr = warmup_lr 19 | Parameters 20 | ---------- 21 | mode : str 22 | Modes for learning rate scheduler. 23 | Currently it supports 'step', 'poly' and 'cosine'. 24 | niters : int 25 | Number of iterations in each epoch. 26 | base_lr : float 27 | Base learning rate, i.e. the starting learning rate. 28 | epochs : int 29 | Number of training epochs. 30 | step : list 31 | A list of epochs to decay the learning rate. 32 | decay_factor : float 33 | Learning rate decay factor. 34 | targetlr : float 35 | Target learning rate for poly and cosine, as the ending learning rate. 36 | power : float 37 | Power of poly function. 38 | warmup_epochs : int 39 | Number of epochs for the warmup stage. 40 | warmup_lr : float 41 | The base learning rate for the warmup stage. 42 | warmup_mode : str 43 | Modes for the warmup stage. 44 | Currently it supports 'linear' and 'constant'. 
45 | """ 46 | def __init__(self, optimizer, niters, args): 47 | super(LRScheduler, self).__init__() 48 | 49 | self.mode = args.lr_mode 50 | self.warmup_mode = args.warmup_mode if hasattr(args,'warmup_mode') else 'linear' 51 | assert(self.mode in ['step', 'poly', 'cosine']) 52 | assert(self.warmup_mode in ['linear', 'constant']) 53 | 54 | self.optimizer = optimizer 55 | 56 | self.base_lr = args.base_lr if hasattr(args,'base_lr') else 0.1 57 | self.learning_rate = self.base_lr 58 | self.niters = niters 59 | 60 | self.step = [int(i) for i in args.step.split(',')] if hasattr(args,'step') else [30, 60, 90] 61 | self.decay_factor = args.decay_factor if hasattr(args,'decay_factor') else 0.1 62 | self.targetlr = args.targetlr if hasattr(args,'targetlr') else 0.0 63 | self.power = args.power if hasattr(args,'power') else 2.0 64 | self.warmup_lr = args.warmup_lr if hasattr(args,'warmup_lr') else 0.0 65 | self.max_iter = args.epochs * niters 66 | self.warmup_iters = (args.warmup_epochs if hasattr(args,'warmup_epochs') else 0) * niters 67 | 68 | def update(self, i, epoch): 69 | T = epoch * self.niters + i 70 | assert (T >= 0 and T <= self.max_iter) 71 | 72 | if self.warmup_iters > T: 73 | # Warm-up Stage 74 | if self.warmup_mode == 'linear': 75 | self.learning_rate = self.warmup_lr + (self.base_lr - self.warmup_lr) * \ 76 | T / self.warmup_iters 77 | elif self.warmup_mode == 'constant': 78 | self.learning_rate = self.warmup_lr 79 | else: 80 | raise NotImplementedError 81 | else: 82 | if self.mode == 'step': 83 | count = sum([1 for s in self.step if s <= epoch]) 84 | self.learning_rate = self.base_lr * pow(self.decay_factor, count) 85 | elif self.mode == 'poly': 86 | self.learning_rate = self.targetlr + (self.base_lr - self.targetlr) * \ 87 | pow(1 - (T - self.warmup_iters) / (self.max_iter - self.warmup_iters), self.power) 88 | elif self.mode == 'cosine': 89 | self.learning_rate = self.targetlr + (self.base_lr - self.targetlr) * \ 90 | (1 + cos(pi * (T - self.warmup_iters) / 
(self.max_iter - self.warmup_iters))) / 2 91 | else: 92 | raise NotImplementedError 93 | 94 | for i, param_group in enumerate(self.optimizer.param_groups): 95 | param_group['lr'] = self.learning_rate 96 | 97 | 98 | class CosineAnnealingLR(object): 99 | def __init__(self, optimizer, T_max, N_batch, eta_min=0, last_epoch=-1, warmup=0): 100 | if not isinstance(optimizer, torch.optim.Optimizer): 101 | raise TypeError('{} is not an Optimizer'.format( 102 | type(optimizer).__name__)) 103 | self.optimizer = optimizer 104 | self.T_max = T_max 105 | self.N_batch = N_batch 106 | self.eta_min = eta_min 107 | self.warmup = warmup 108 | 109 | if last_epoch == -1: 110 | for group in optimizer.param_groups: 111 | group.setdefault('initial_lr', group['lr']) 112 | else: 113 | for i, group in enumerate(optimizer.param_groups): 114 | if 'initial_lr' not in group: 115 | raise KeyError("param 'initial_lr' is not specified " 116 | "in param_groups[{}] when resuming an optimizer".format(i)) 117 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) 118 | self.update(last_epoch+1) 119 | self.last_epoch = last_epoch 120 | self.iter = 0 121 | 122 | def state_dict(self): 123 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 124 | 125 | def load_state_dict(self, state_dict): 126 | self.__dict__.update(state_dict) 127 | 128 | def get_lr(self): 129 | if self.last_epoch < self.warmup: 130 | lrs = [base_lr * (self.last_epoch + self.iter / self.N_batch) / self.warmup for base_lr in self.base_lrs] 131 | else: 132 | lrs = [self.eta_min + (base_lr - self.eta_min) * 133 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup + self.iter / self.N_batch) / (self.T_max - self.warmup))) / 2 134 | for base_lr in self.base_lrs] 135 | return lrs 136 | 137 | def update(self, epoch, batch=0): 138 | self.last_epoch = epoch 139 | self.iter = batch + 1 140 | lrs = self.get_lr() 141 | for param_group, lr in zip(self.optimizer.param_groups, lrs): 
142 | param_group['lr'] = lr 143 | 144 | return lrs -------------------------------------------------------------------------------- /libs/nn/OCtaveResnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Author: Xiangtai Li(lxtpku@pku.edu.cn) 4 | # Pytorch Implementation of Octave Resnet 5 | # original code from: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 6 | import torch.nn as nn 7 | 8 | __all__ = ['Octresnet50','Octresnet101'] 9 | 10 | from libs.nn.OctaveConv2 import * 11 | 12 | 13 | def conv3x3(in_planes, out_planes, stride=1, groups=1): 14 | """3x3 conv with padding""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=(3,3), stride=stride, 16 | padding=1, groups=groups, bias=False) 17 | 18 | 19 | def conv1x1(in_planes, out_planes, stride=1): 20 | """1x1 conv""" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=(1,1), stride=stride, bias=False, padding=0) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 28 | base_width=64, norm_layer=None): 29 | super(BasicBlock, self).__init__() 30 | if norm_layer is None: 31 | norm_layer = nn.BatchNorm2d 32 | if groups != 1 or base_width != 64: 33 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 34 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 35 | self.conv1 = conv3x3(inplanes, planes, stride) 36 | self.bn1 = norm_layer(planes) 37 | self.relu = nn.ReLU(inplace=True) 38 | self.conv2 = conv3x3(planes, planes) 39 | self.bn2 = norm_layer(planes) 40 | self.downsample = downsample 41 | self.stride = stride 42 | 43 | def forward(self, x): 44 | identity = x 45 | 46 | out = self.conv1(x) 47 | out = self.bn1(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv2(out) 51 | out = self.bn2(out) 52 | 53 | if self.downsample is not None: 54 
| identity = self.downsample(x) 55 | 56 | out += identity 57 | out = self.relu(out) 58 | 59 | return out 60 | 61 | 62 | class Bottleneck(nn.Module): 63 | expansion = 4 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 66 | base_width=64, norm_layer=None,First=False): 67 | super(Bottleneck, self).__init__() 68 | if norm_layer is None: 69 | norm_layer = nn.BatchNorm2d 70 | width = int(planes * (base_width / 64.)) * groups 71 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 72 | self.first = First 73 | if self.first: 74 | self.ocb1 = FirstOctaveCBR(inplanes, width, kernel_size=(1, 1),norm_layer=norm_layer,padding=0) 75 | else: 76 | self.ocb1 = OctaveCBR(inplanes, width, kernel_size=(1,1),norm_layer=norm_layer,padding=0) 77 | 78 | self.ocb2 = OctaveCBR(width, width, kernel_size=(3,3), stride=stride, groups=groups, norm_layer=norm_layer) 79 | 80 | self.ocb3 = OctaveCB(width, planes * self.expansion, kernel_size=(1,1), norm_layer=norm_layer,padding=0) 81 | self.relu = nn.ReLU(inplace=True) 82 | self.downsample = downsample 83 | self.stride = stride 84 | 85 | def forward(self, x): 86 | 87 | if self.first: 88 | x_h_res, x_l_res = self.ocb1(x) 89 | x_h, x_l = self.ocb2((x_h_res, x_l_res)) 90 | else: 91 | x_h_res, x_l_res = x 92 | x_h, x_l = self.ocb1((x_h_res,x_l_res)) 93 | x_h, x_l = self.ocb2((x_h, x_l)) 94 | 95 | x_h, x_l = self.ocb3((x_h, x_l)) 96 | 97 | if self.downsample is not None: 98 | x_h_res, x_l_res = self.downsample((x_h_res,x_l_res)) 99 | 100 | x_h += x_h_res 101 | x_l += x_l_res 102 | 103 | x_h = self.relu(x_h) 104 | x_l = self.relu(x_l) 105 | 106 | return x_h, x_l 107 | 108 | 109 | class BottleneckLast(nn.Module): 110 | expansion = 4 111 | 112 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 113 | base_width=64, norm_layer=None): 114 | super(BottleneckLast, self).__init__() 115 | if norm_layer is None: 116 | norm_layer = nn.BatchNorm2d 117 | width = int(planes * 
(base_width / 64.)) * groups 118 | # Last: merges the high- and low-frequency branches back into a single output 119 | self.ocb1 = OctaveCBR(inplanes, width, kernel_size=(1,1), norm_layer=norm_layer, padding=0) 120 | self.ocb2 = OctaveCBR(width, width, kernel_size=(3, 3), stride=stride, groups=groups, norm_layer=norm_layer) 121 | self.ocb3 = LastOCtaveCB(width, planes * self.expansion, kernel_size=(1, 1), norm_layer=norm_layer, padding=0) 122 | self.relu = nn.ReLU(inplace=True) 123 | self.downsample = downsample 124 | self.stride = stride 125 | 126 | def forward(self,x): 127 | 128 | x_h_res, x_l_res = x 129 | x_h, x_l = self.ocb1((x_h_res, x_l_res)) 130 | 131 | x_h, x_l = self.ocb2((x_h, x_l)) 132 | x_h = self.ocb3((x_h, x_l)) 133 | 134 | if self.downsample is not None: 135 | x_h_res = self.downsample((x_h_res, x_l_res)) 136 | 137 | x_h += x_h_res 138 | x_h = self.relu(x_h) 139 | 140 | return x_h 141 | 142 | 143 | class BottleneckOrigin(nn.Module): 144 | expansion = 4 145 | 146 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 147 | base_width=64, norm_layer=None): 148 | super(BottleneckOrigin, self).__init__() 149 | if norm_layer is None: 150 | norm_layer = nn.BatchNorm2d 151 | width = int(planes * (base_width / 64.)) * groups 152 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 153 | self.conv1 = conv1x1(inplanes, width) 154 | self.bn1 = norm_layer(width) 155 | self.conv2 = conv3x3(width, width, stride, groups) 156 | self.bn2 = norm_layer(width) 157 | self.conv3 = conv1x1(width, planes * self.expansion) 158 | self.bn3 = norm_layer(planes * self.expansion) 159 | self.relu = nn.ReLU(inplace=True) 160 | self.downsample = downsample 161 | self.stride = stride 162 | 163 | def forward(self, x): 164 | identity = x 165 | 166 | out = self.conv1(x) 167 | out = self.bn1(out) 168 | out = self.relu(out) 169 | 170 | out = self.conv2(out) 171 | out = self.bn2(out) 172 | out = self.relu(out) 173 | 174 | out = self.conv3(out) 175 | out = self.bn3(out) 176 | 177 | if self.downsample is
not None: 178 | identity = self.downsample(x) 179 | 180 | out += identity 181 | out = self.relu(out) 182 | 183 | return out 184 | 185 | 186 | class OCtaveResNet(nn.Module): 187 | 188 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 189 | groups=1, width_per_group=64, norm_layer=None): 190 | super(OCtaveResNet, self).__init__() 191 | if norm_layer is None: 192 | norm_layer = nn.BatchNorm2d 193 | 194 | self.inplanes = 64 195 | self.groups = groups 196 | self.base_width = width_per_group 197 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 198 | bias=False) 199 | self.bn1 = norm_layer(self.inplanes) 200 | self.relu = nn.ReLU(inplace=True) 201 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 202 | 203 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, First=True) 204 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) 205 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer) 206 | self.layer4 = self._make_last_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer) 207 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 208 | self.fc = nn.Linear(512 * block.expansion, num_classes) 209 | 210 | for m in self.modules(): 211 | if isinstance(m, nn.Conv2d): 212 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 213 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 214 | nn.init.constant_(m.weight, 1) 215 | nn.init.constant_(m.bias, 0) 216 | 217 | # Zero-initialize the last BN in each residual branch, 218 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
219 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 220 | if zero_init_residual: 221 | for m in self.modules(): 222 | if isinstance(m, Bottleneck): # the octave Bottleneck has no bn3; its last BNs are ocb3.bn_h and ocb3.bn_l 223 | for bn in (m.ocb3.bn_h, m.ocb3.bn_l): nn.init.constant_(bn.weight, 0) 224 | elif isinstance(m, BasicBlock): 225 | nn.init.constant_(m.bn2.weight, 0) 226 | 227 | def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None, First=False): 228 | if norm_layer is None: 229 | norm_layer = nn.BatchNorm2d 230 | downsample = None 231 | if stride != 1 or self.inplanes != planes * block.expansion: 232 | downsample = nn.Sequential( 233 | OctaveCB(in_channels=self.inplanes,out_channels=planes * block.expansion, kernel_size=(1,1), stride=stride, padding=0, norm_layer=norm_layer) 234 | ) 235 | 236 | layers = [] 237 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 238 | self.base_width, norm_layer, First)) 239 | self.inplanes = planes * block.expansion 240 | for _ in range(1, blocks): 241 | layers.append(block(self.inplanes, planes, groups=self.groups, 242 | base_width=self.base_width, norm_layer=norm_layer)) 243 | 244 | return nn.Sequential(*layers) 245 | 246 | def _make_last_layer(self, block, planes, blocks, stride=1, norm_layer=None): 247 | 248 | if norm_layer is None: 249 | norm_layer = nn.BatchNorm2d 250 | downsample = None 251 | if stride != 1 or self.inplanes != planes * block.expansion: 252 | downsample = nn.Sequential( 253 | LastOCtaveCB(in_channels=self.inplanes,out_channels=planes * block.expansion, kernel_size=(1,1), stride=stride, padding=0, norm_layer=norm_layer) 254 | ) 255 | 256 | layers = [] 257 | layers.append(BottleneckLast(self.inplanes, planes, stride, downsample, self.groups, 258 | self.base_width, norm_layer)) 259 | self.inplanes = planes * block.expansion 260 | 261 | for _ in range(1, blocks): 262 | layers.append(BottleneckOrigin(self.inplanes, planes, groups=self.groups, 263 | base_width=self.base_width, norm_layer=norm_layer)) 264 | 265 | return nn.Sequential(*layers) 266 | 267 | def forward(self, x): 268
| x = self.conv1(x) 269 | x = self.bn1(x) 270 | x = self.relu(x) 271 | x = self.maxpool(x) 272 | 273 | x_h, x_l = self.layer1(x) 274 | x_h, x_l = self.layer2((x_h,x_l)) 275 | x_h, x_l = self.layer3((x_h,x_l)) 276 | # print(x_h.size(), x_l.size()) 277 | x_h = self.layer4((x_h,x_l)) 278 | x = self.avgpool(x_h) 279 | x = x.view(x.size(0), -1) 280 | x = self.fc(x) 281 | 282 | return x 283 | 284 | 285 | def Octresnet50(pretrained=False, **kwargs): 286 | """Constructs an Octave ResNet-50 model. 287 | 288 | Args: 289 | pretrained (bool): currently ignored; no pre-trained weights are loaded 290 | """ 291 | model = OCtaveResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 292 | return model 293 | 294 | 295 | def Octresnet101(pretrained=False, **kwargs): 296 | """Constructs an Octave ResNet-101 model. 297 | 298 | Args: 299 | pretrained (bool): currently ignored; no pre-trained weights are loaded 300 | """ 301 | model = OCtaveResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 302 | return model 303 | 304 | 305 | def Octresnet152(pretrained=False, **kwargs): 306 | """Constructs an Octave ResNet-152 model.
307 | 308 | Args: 309 | pretrained (bool): currently ignored; no pre-trained weights are loaded 310 | """ 311 | model = OCtaveResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 312 | return model 313 | 314 | 315 | if __name__ == '__main__': 316 | model = Octresnet50(num_classes=10).cuda() 317 | print(model) 318 | i = torch.randn(1, 3, 256, 256).cuda() 319 | y= model(i) 320 | print(y.size()) 321 | """ 322 | layer output sizes (high-frequency branch for layer1-3, then logits): 323 | torch.Size([1, 128, 64, 64]) 324 | torch.Size([1, 256, 32, 32]) 325 | torch.Size([1, 512, 16, 16]) 326 | torch.Size([1, 2048, 8, 8]) 327 | torch.Size([1, 10]) 328 | """ 329 | -------------------------------------------------------------------------------- /libs/nn/OctaveConv1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Author: Xiangtai Li(lxtpku@pku.edu.cn) 4 | # Pytorch Implementation of Octave Conv Operation 5 | # This version uses F.conv2d on slices of a single learnable weight tensor 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | up_kwargs = {'mode': 'nearest'} 12 | 13 | 14 | class OctaveConv(nn.Module): 15 | def __init__(self, in_channels, out_channels, kernel_size, alpha_in=0.5, alpha_out=0.5, stride=1, padding=1, dilation=1, 16 | groups=1, bias=False, up_kwargs = up_kwargs): 17 | super(OctaveConv, self).__init__() 18 | self.weights = nn.Parameter(nn.init.kaiming_normal_(torch.empty(out_channels, in_channels, kernel_size[0], kernel_size[1]), mode='fan_out', nonlinearity='relu')) 19 | self.stride = stride 20 | self.padding = padding 21 | self.dilation = dilation 22 | self.groups = groups 23 | if bias: 24 | self.bias = nn.Parameter(torch.zeros(out_channels)) 25 | else: 26 | self.register_buffer('bias', torch.zeros(out_channels)) # follows .to()/.cuda() instead of forcing CUDA 27 | self.up_kwargs = up_kwargs 28 | self.h2g_pool = nn.AvgPool2d(kernel_size=(2,2), stride=2) 29 | 30 | self.in_channels = in_channels 31 | self.out_channels = out_channels 32 | self.alpha_in = alpha_in 33 | self.alpha_out = alpha_out 34 | 35 | def
forward(self, x): 36 | X_h, X_l = x 37 | 38 | if self.stride ==2: 39 | X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l) 40 | 41 | X_h2l = self.h2g_pool(X_h) 42 | 43 | 44 | end_h_x = int(self.in_channels*(1- self.alpha_in)) 45 | end_h_y = int(self.out_channels*(1- self.alpha_out)) 46 | 47 | X_h2h = F.conv2d(X_h, self.weights[0:end_h_y, 0:end_h_x, :,:], self.bias[0:end_h_y], 1, 48 | self.padding, self.dilation, self.groups) 49 | 50 | X_l2l = F.conv2d(X_l, self.weights[end_h_y:, end_h_x:, :,:], self.bias[end_h_y:], 1, 51 | self.padding, self.dilation, self.groups) 52 | 53 | X_h2l = F.conv2d(X_h2l, self.weights[end_h_y:, 0: end_h_x, :,:], self.bias[end_h_y:], 1, 54 | self.padding, self.dilation, self.groups) 55 | 56 | X_l2h = F.conv2d(X_l, self.weights[0:end_h_y, end_h_x:, :,:], self.bias[0:end_h_y], 1, 57 | self.padding, self.dilation, self.groups) 58 | 59 | X_l2h = F.interpolate(X_l2h, scale_factor=2, **self.up_kwargs) 60 | 61 | X_h = X_h2h + X_l2h 62 | X_l = X_l2l + X_h2l 63 | 64 | return X_h, X_l 65 | 66 | 67 | class FirstOctaveConv(nn.Module): 68 | def __init__(self, in_channels, out_channels,kernel_size, alpha_in=0.0, alpha_out=0.5, stride=1, padding=1, dilation=1, 69 | groups=1, bias=False, up_kwargs = up_kwargs): 70 | super(FirstOctaveConv, self).__init__() 71 | self.weights = nn.Parameter(nn.init.kaiming_normal_(torch.empty(out_channels, in_channels, kernel_size[0], kernel_size[1]), mode='fan_out', nonlinearity='relu')) 72 | self.stride = stride 73 | self.padding = padding 74 | self.dilation = dilation 75 | self.groups = groups 76 | if bias: 77 | self.bias = nn.Parameter(torch.zeros(out_channels)) 78 | else: 79 | self.register_buffer('bias', torch.zeros(out_channels)) # follows .to()/.cuda() instead of forcing CUDA 80 | self.up_kwargs = up_kwargs 81 | self.h2g_pool = nn.AvgPool2d(kernel_size=(2,2), stride=2) 82 | 83 | self.in_channels = in_channels 84 | self.out_channels = out_channels 85 | self.alpha_in = alpha_in 86 | self.alpha_out = alpha_out 87 | 88 | def forward(self, x): 89 | 90 | if self.stride ==2: 91 | x = self.h2g_pool(x) 92 | 93 | X_h2l = self.h2g_pool(x) 94 | X_h =
x 95 | 96 | end_h_x = int(self.in_channels*(1- self.alpha_in)) 97 | end_h_y = int(self.out_channels*(1- self.alpha_out)) 98 | 99 | X_h2h = F.conv2d(X_h, self.weights[0:end_h_y, 0: end_h_x, :,:], self.bias[0:end_h_y], 1, 100 | self.padding, self.dilation, self.groups) 101 | 102 | X_h2l = F.conv2d(X_h2l, self.weights[end_h_y:, 0: end_h_x, :,:], self.bias[end_h_y:], 1, 103 | self.padding, self.dilation, self.groups) 104 | 105 | X_h = X_h2h 106 | X_l = X_h2l 107 | 108 | return X_h, X_l 109 | 110 | 111 | class LastOctaveConv(nn.Module): 112 | def __init__(self, in_channels, out_channels, kernel_size, alpha_in=0.5, alpha_out=0.0, stride=1, padding=1, dilation=1, 113 | groups=1, bias=False, up_kwargs = up_kwargs): 114 | super(LastOctaveConv, self).__init__() 115 | self.weights = nn.Parameter(nn.init.kaiming_normal_(torch.empty(out_channels, in_channels, kernel_size[0], kernel_size[1]), mode='fan_out', nonlinearity='relu')) 116 | self.stride = stride 117 | self.padding = padding 118 | self.dilation = dilation 119 | self.groups = groups 120 | if bias: 121 | self.bias = nn.Parameter(torch.zeros(out_channels)) 122 | else: 123 | self.register_buffer('bias', torch.zeros(out_channels)) # follows .to()/.cuda() instead of forcing CUDA 124 | self.up_kwargs = up_kwargs 125 | self.h2g_pool = nn.AvgPool2d(kernel_size=(2,2), stride=2) 126 | 127 | self.in_channels = in_channels 128 | self.out_channels = out_channels 129 | self.alpha_in = alpha_in 130 | self.alpha_out = alpha_out 131 | 132 | def forward(self, x): 133 | X_h, X_l = x 134 | 135 | if self.stride ==2: 136 | X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l) 137 | 138 | end_h_x = int(self.in_channels*(1- self.alpha_in)) 139 | end_h_y = int(self.out_channels*(1- self.alpha_out)) 140 | 141 | X_h2h = F.conv2d(X_h, self.weights[0:end_h_y, 0:end_h_x, :,:], self.bias[:end_h_y], 1, 142 | self.padding, self.dilation, self.groups) 143 | 144 | X_l2h = F.conv2d(X_l, self.weights[0:end_h_y, end_h_x:, :,:], self.bias[:end_h_y], 1, 145 | self.padding, self.dilation, self.groups) 146 | X_l2h = F.interpolate(X_l2h, scale_factor=2, **self.up_kwargs) 147 | 148 |
X_h = X_h2h + X_l2h 149 | 150 | return X_h 151 | 152 | 153 | class OctaveCBR(nn.Module): 154 | def __init__(self,in_channels, out_channels, kernel_size=(3,3),alpha_in=0.5, alpha_out=0.5, stride=1, padding=1, dilation=1, 155 | groups=1, bias=False, up_kwargs = up_kwargs, norm_layer=nn.BatchNorm2d): 156 | super(OctaveCBR, self).__init__() 157 | self.conv = OctaveConv(in_channels,out_channels,kernel_size, alpha_in,alpha_out, stride, padding, dilation, groups, bias, up_kwargs) 158 | self.bn_h = norm_layer(int(out_channels*(1-alpha_out))) 159 | self.bn_l = norm_layer(int(out_channels*alpha_out)) 160 | self.relu = nn.ReLU(inplace=True) 161 | 162 | def forward(self, x): 163 | x_h, x_l = self.conv(x) 164 | x_h = self.relu(self.bn_h(x_h)) 165 | x_l = self.relu(self.bn_l(x_l)) 166 | return x_h, x_l 167 | 168 | 169 | class OctaveCB(nn.Module): 170 | def __init__(self, in_channels, out_channels, kernel_size=(3,3), alpha_in=0.5, alpha_out=0.5, stride=1, padding=1, dilation=1, 171 | groups=1, bias=False, up_kwargs=up_kwargs, norm_layer=nn.BatchNorm2d): 172 | super(OctaveCB, self).__init__() 173 | self.conv = OctaveConv(in_channels, out_channels, kernel_size, alpha_in, alpha_out, stride, padding, dilation, 174 | groups, bias, up_kwargs) 175 | self.bn_h = norm_layer(int(out_channels * (1 - alpha_out))) 176 | self.bn_l = norm_layer(int(out_channels * alpha_out)) 177 | 178 | def forward(self, x): 179 | x_h, x_l = self.conv(x) 180 | x_h = self.bn_h(x_h) 181 | x_l = self.bn_l(x_l) 182 | return x_h, x_l 183 | 184 | 185 | class FirstOctaveCBR(nn.Module): 186 | def __init__(self, in_channels, out_channels, kernel_size=(3,3),alpha_in=0.0, alpha_out=0.5, stride=1, padding=1, dilation=1, 187 | groups=1, bias=False, up_kwargs = up_kwargs, norm_layer=nn.BatchNorm2d): 188 | super(FirstOctaveCBR, self).__init__() 189 | self.conv = FirstOctaveConv(in_channels,out_channels,kernel_size, alpha_in,alpha_out,stride,padding,dilation,groups,bias,up_kwargs) 190 | self.bn_h = norm_layer(int(out_channels 
* (1 - alpha_out))) 191 | self.bn_l = norm_layer(int(out_channels * alpha_out)) 192 | self.relu = nn.ReLU(inplace=True) 193 | 194 | def forward(self, x): 195 | x_h, x_l = self.conv(x) 196 | x_h = self.relu(self.bn_h(x_h)) 197 | x_l = self.relu(self.bn_l(x_l)) 198 | return x_h, x_l 199 | 200 | 201 | class LastOCtaveCBR(nn.Module): 202 | def __init__(self, in_channels, out_channels, kernel_size=(3,3), alpha_in=0.5, alpha_out=0.0, stride=1, padding=1, dilation=1, 203 | groups=1, bias=False, up_kwargs = up_kwargs, norm_layer=nn.BatchNorm2d): 204 | super(LastOCtaveCBR, self).__init__() 205 | self.conv = LastOctaveConv(in_channels, out_channels, kernel_size, alpha_in, alpha_out, stride, padding, dilation, groups, bias, up_kwargs) 206 | self.bn_h = norm_layer(int(out_channels * (1 - alpha_out))) 207 | self.relu = nn.ReLU(inplace=True) 208 | 209 | def forward(self, x): 210 | x_h = self.conv(x) 211 | x_h = self.relu(self.bn_h(x_h)) 212 | return x_h 213 | 214 | 215 | class FirstOctaveCB(nn.Module): 216 | def __init__(self, in_channels, out_channels, kernel_size=(3,3), alpha_in=0.0, alpha_out=0.5, stride=1, padding=1, dilation=1, 217 | groups=1, bias=False, up_kwargs = up_kwargs, norm_layer=nn.BatchNorm2d): 218 | super(FirstOctaveCB, self).__init__() 219 | self.conv = FirstOctaveConv(in_channels,out_channels,kernel_size, alpha_in,alpha_out,stride,padding,dilation,groups,bias,up_kwargs) 220 | self.bn_h = norm_layer(int(out_channels * (1 - alpha_out))) 221 | self.bn_l = norm_layer(int(out_channels * alpha_out)) 222 | self.relu = nn.ReLU(inplace=True) 223 | 224 | def forward(self, x): 225 | x_h, x_l = self.conv(x) 226 | x_h = self.bn_h(x_h) 227 | x_l = self.bn_l(x_l) 228 | return x_h, x_l 229 | 230 | 231 | class LastOCtaveCB(nn.Module): 232 | def __init__(self, in_channels, out_channels, kernel_size, alpha_in=0.5, alpha_out=0.0, stride=1, padding=1, dilation=1, 233 | groups=1, bias=False, up_kwargs = up_kwargs, norm_layer=nn.BatchNorm2d): 234 | super(LastOCtaveCB, 
self).__init__() 235 | self.conv = LastOctaveConv( in_channels, out_channels, kernel_size, alpha_in, alpha_out, stride, padding, dilation, groups, bias, up_kwargs) 236 | self.bn_h = norm_layer(int(out_channels * (1 - alpha_out))) 237 | self.relu = nn.ReLU(inplace=True) 238 | 239 | def forward(self, x): 240 | x_h = self.conv(x) 241 | x_h = self.bn_h(x_h) 242 | return x_h 243 | 244 | 245 | if __name__ == '__main__': 246 | # random inputs: 64 high-frequency + 192 low-frequency channels (alpha = 0.75 of 256) 247 | high = torch.randn(1, 64, 32, 32).cuda() 248 | low = torch.randn(1, 192, 16, 16).cuda() 249 | # test Oc conv 250 | OCconv = OctaveConv(kernel_size=(3,3),in_channels=256,out_channels=512,bias=False,stride=2,alpha_in=0.75,alpha_out=0.75).cuda() 251 | i = high,low 252 | x_out,y_out = OCconv(i) 253 | print(x_out.size()) 254 | print(y_out.size()) 255 | # test First Octave Conv 256 | i = torch.randn(1, 3, 512, 512).cuda() 257 | FOCconv = FirstOctaveConv(kernel_size=(3,3), in_channels=3, out_channels=128).cuda() 258 | x_out, y_out = FOCconv(i) 259 | # test Last Octave Conv 260 | LOCconv = LastOctaveConv(kernel_size=(3,3), in_channels=256, out_channels=128, alpha_out=0.75, alpha_in=0.75).cuda() 261 | i = high, low 262 | out = LOCconv(i) 263 | print(out.size()) 264 | # test OCB 265 | ocb = OctaveCB(in_channels=256, out_channels=128, alpha_out=0.75, alpha_in=0.75).cuda() 266 | i = high, low 267 | x_out_h, y_out_l = ocb(i) 268 | print(x_out_h.size()) 269 | print(y_out_l.size()) 270 | 271 | ocb_last = LastOCtaveCBR(256,128, alpha_out=0.0, alpha_in=0.75).cuda() 272 | i = high, low 273 | x_out_h = ocb_last(i) 274 | print(x_out_h.size()) 275 | -------------------------------------------------------------------------------- /libs/nn/OctaveConv2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Author: Xiangtai Li(lxtpku@pku.edu.cn) 4 | # Pytorch Implementation of Octave Conv Operation 5 | # This version uses nn.Conv2d because alpha_in always equals
alpha_out 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class OctaveConv(nn.Module): 12 | def __init__(self, in_channels, out_channels, kernel_size, alpha=0.5, stride=1, padding=1, dilation=1, 13 | groups=1, bias=False): 14 | super(OctaveConv, self).__init__() 15 | kernel_size = kernel_size[0] 16 | self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2) 17 | self.upsample = torch.nn.Upsample(scale_factor=2, mode='nearest') 18 | self.stride = stride 19 | self.l2l = torch.nn.Conv2d(int(alpha * in_channels), int(alpha * out_channels), 20 | kernel_size, 1, padding, dilation, groups, bias) 21 | self.l2h = torch.nn.Conv2d(int(alpha * in_channels), out_channels - int(alpha * out_channels), 22 | kernel_size, 1, padding, dilation, groups, bias) 23 | self.h2l = torch.nn.Conv2d(in_channels - int(alpha * in_channels), int(alpha * out_channels), 24 | kernel_size, 1, padding, dilation, groups, bias) 25 | self.h2h = torch.nn.Conv2d(in_channels - int(alpha * in_channels), 26 | out_channels - int(alpha * out_channels), 27 | kernel_size, 1, padding, dilation, groups, bias) 28 | 29 | def forward(self, x): 30 | X_h, X_l = x 31 | 32 | if self.stride ==2: 33 | X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l) 34 | 35 | X_h2l = self.h2g_pool(X_h) 36 | 37 | X_h2h = self.h2h(X_h) 38 | X_l2h = self.l2h(X_l) 39 | 40 | X_l2l = self.l2l(X_l) 41 | X_h2l = self.h2l(X_h2l) 42 | 43 | X_l2h = self.upsample(X_l2h) 44 | X_h = X_l2h + X_h2h 45 | X_l = X_h2l + X_l2l 46 | 47 | return X_h, X_l 48 | 49 | 50 | class FirstOctaveConv(nn.Module): 51 | def __init__(self, in_channels, out_channels,kernel_size, alpha=0.5, stride=1, padding=1, dilation=1, 52 | groups=1, bias=False): 53 | super(FirstOctaveConv, self).__init__() 54 | self.stride = stride 55 | kernel_size = kernel_size[0] 56 | self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2) 57 | self.h2l = torch.nn.Conv2d(in_channels, int(alpha * out_channels), 58 | kernel_size, 1, padding, dilation, groups, bias) 59 | self.h2h = 
torch.nn.Conv2d(in_channels, out_channels - int(alpha * out_channels), 60 | kernel_size, 1, padding, dilation, groups, bias) 61 | 62 | def forward(self, x): 63 | if self.stride ==2: 64 | x = self.h2g_pool(x) 65 | 66 | X_h2l = self.h2g_pool(x) 67 | X_h = x 68 | X_h = self.h2h(X_h) 69 | X_l = self.h2l(X_h2l) 70 | 71 | return X_h, X_l 72 | 73 | 74 | class LastOctaveConv(nn.Module): 75 | def __init__(self, in_channels, out_channels, kernel_size, alpha=0.5, stride=1, padding=1, dilation=1, 76 | groups=1, bias=False): 77 | super(LastOctaveConv, self).__init__() 78 | self.stride = stride 79 | kernel_size = kernel_size[0] 80 | self.h2g_pool = nn.AvgPool2d(kernel_size=(2,2), stride=2) 81 | 82 | self.l2h = torch.nn.Conv2d(int(alpha * in_channels), out_channels, 83 | kernel_size, 1, padding, dilation, groups, bias) 84 | self.h2h = torch.nn.Conv2d(in_channels - int(alpha * in_channels), 85 | out_channels, 86 | kernel_size, 1, padding, dilation, groups, bias) 87 | self.upsample = torch.nn.Upsample(scale_factor=2, mode='nearest') 88 | 89 | def forward(self, x): 90 | X_h, X_l = x 91 | 92 | if self.stride ==2: 93 | X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l) 94 | 95 | X_l2h = self.l2h(X_l) 96 | X_h2h = self.h2h(X_h) 97 | X_l2h = self.upsample(X_l2h) 98 | 99 | X_h = X_h2h + X_l2h 100 | 101 | return X_h 102 | 103 | 104 | class OctaveCBR(nn.Module): 105 | def __init__(self,in_channels, out_channels, kernel_size=(3,3),alpha=0.5, stride=1, padding=1, dilation=1, 106 | groups=1, bias=False, norm_layer=nn.BatchNorm2d): 107 | super(OctaveCBR, self).__init__() 108 | self.conv = OctaveConv(in_channels,out_channels,kernel_size, alpha, stride, padding, dilation, groups, bias) 109 | self.bn_h = norm_layer(int(out_channels*(1-alpha))) 110 | self.bn_l = norm_layer(int(out_channels*alpha)) 111 | self.relu = nn.ReLU(inplace=True) 112 | 113 | def forward(self, x): 114 | x_h, x_l = self.conv(x) 115 | x_h = self.relu(self.bn_h(x_h)) 116 | x_l = self.relu(self.bn_l(x_l)) 117 | return x_h, x_l 
118 | 119 | 120 | class OctaveCB(nn.Module): 121 | def __init__(self, in_channels, out_channels, kernel_size=(3,3), alpha=0.5, stride=1, padding=1, dilation=1, 122 | groups=1, bias=False, norm_layer=nn.BatchNorm2d): 123 | super(OctaveCB, self).__init__() 124 | self.conv = OctaveConv(in_channels, out_channels, kernel_size, alpha, stride, padding, dilation, 125 | groups, bias) 126 | self.bn_h = norm_layer(int(out_channels * (1 - alpha))) 127 | self.bn_l = norm_layer(int(out_channels * alpha)) 128 | 129 | def forward(self, x): 130 | x_h, x_l = self.conv(x) 131 | x_h = self.bn_h(x_h) 132 | x_l = self.bn_l(x_l) 133 | return x_h, x_l 134 | 135 | 136 | class FirstOctaveCBR(nn.Module): 137 | def __init__(self, in_channels, out_channels, kernel_size=(3,3),alpha=0.5, stride=1, padding=1, dilation=1, 138 | groups=1, bias=False,norm_layer=nn.BatchNorm2d): 139 | super(FirstOctaveCBR, self).__init__() 140 | self.conv = FirstOctaveConv(in_channels,out_channels,kernel_size, alpha,stride,padding,dilation,groups,bias) 141 | self.bn_h = norm_layer(int(out_channels * (1 - alpha))) 142 | self.bn_l = norm_layer(int(out_channels * alpha)) 143 | self.relu = nn.ReLU(inplace=True) 144 | 145 | def forward(self, x): 146 | x_h, x_l = self.conv(x) 147 | x_h = self.relu(self.bn_h(x_h)) 148 | x_l = self.relu(self.bn_l(x_l)) 149 | return x_h, x_l 150 | 151 | 152 | class LastOCtaveCBR(nn.Module): 153 | def __init__(self, in_channels, out_channels, kernel_size=(3,3), alpha=0.5, stride=1, padding=1, dilation=1, 154 | groups=1, bias=False, norm_layer=nn.BatchNorm2d): 155 | super(LastOCtaveCBR, self).__init__() 156 | self.conv = LastOctaveConv(in_channels, out_channels, kernel_size, alpha, stride, padding, dilation, groups, bias) 157 | self.bn_h = norm_layer(out_channels) 158 | self.relu = nn.ReLU(inplace=True) 159 | 160 | def forward(self, x): 161 | x_h = self.conv(x) 162 | x_h = self.relu(self.bn_h(x_h)) 163 | return x_h 164 | 165 | 166 | class FirstOctaveCB(nn.Module): 167 | def __init__(self, 
in_channels, out_channels, kernel_size=(3,3), alpha=0.5,stride=1, padding=1, dilation=1, 168 | groups=1, bias=False, norm_layer=nn.BatchNorm2d): 169 | super(FirstOctaveCB, self).__init__() 170 | self.conv = FirstOctaveConv(in_channels,out_channels,kernel_size, alpha,stride,padding,dilation,groups,bias) 171 | self.bn_h = norm_layer(int(out_channels * (1 - alpha))) 172 | self.bn_l = norm_layer(int(out_channels * alpha)) 173 | self.relu = nn.ReLU(inplace=True) 174 | 175 | def forward(self, x): 176 | x_h, x_l = self.conv(x) 177 | x_h = self.bn_h(x_h) 178 | x_l = self.bn_l(x_l) 179 | return x_h, x_l 180 | 181 | 182 | class LastOCtaveCB(nn.Module): 183 | def __init__(self, in_channels, out_channels, kernel_size, alpha=0.5, stride=1, padding=1, dilation=1, 184 | groups=1, bias=False, norm_layer=nn.BatchNorm2d): 185 | super(LastOCtaveCB, self).__init__() 186 | self.conv = LastOctaveConv( in_channels, out_channels, kernel_size, alpha, stride, padding, dilation, groups, bias) 187 | self.bn_h = norm_layer(out_channels) 188 | self.relu = nn.ReLU(inplace=True) 189 | 190 | def forward(self, x): 191 | x_h = self.conv(x) 192 | x_h = self.bn_h(x_h) 193 | return x_h 194 | 195 | 196 | if __name__ == '__main__': 197 | # random inputs: 64 high-frequency + 192 low-frequency channels (alpha = 0.75 of 256) 198 | high = torch.randn(1, 64, 32, 32).cuda() 199 | low = torch.randn(1, 192, 16, 16).cuda() 200 | # test Oc conv 201 | OCconv = OctaveConv(kernel_size=(3,3),in_channels=256,out_channels=512,bias=False,stride=2,alpha=0.75).cuda() 202 | i = high,low 203 | x_out,y_out = OCconv(i) 204 | print(x_out.size()) 205 | print(y_out.size()) 206 | 207 | i = torch.randn(1, 3, 512, 512).cuda() 208 | FOCconv = FirstOctaveConv(kernel_size=(3, 3), in_channels=3, out_channels=128).cuda() 209 | x_out, y_out = FOCconv(i) 210 | print("First: ", x_out.size(), y_out.size()) 211 | # test Last Octave Conv 212 | LOCconv = LastOctaveConv(kernel_size=(3, 3), in_channels=256, out_channels=128, alpha=0.75).cuda() 213 | i = high, low 214 | out = LOCconv(i) 215 | print("Last: ",
out.size()) 216 | 217 | # test OCB 218 | ocb = OctaveCB(in_channels=256, out_channels=128, alpha=0.75).cuda() 219 | i = high, low 220 | x_out_h, y_out_l = ocb(i) 221 | print("OCB:",x_out_h.size(),y_out_l.size()) 222 | 223 | # test last OCB 224 | ocb_last = LastOCtaveCBR(256, 128, alpha=0.75).cuda() 225 | i = high, low 226 | x_out_h = ocb_last(i) 227 | print("Last OCB", x_out_h.size()) 228 | -------------------------------------------------------------------------------- /libs/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .OCtaveResnet import * 2 | from .resnet_sge import * 3 | from .resnet_sk import * 4 | from .resnet_se import * 5 | from .res2net import * 6 | from .resnet_eca import * 7 | from .resnet_srm import * 8 | from .resnet_ge import * -------------------------------------------------------------------------------- /libs/nn/res2net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Author: Xiangtai Li(lxtpku@pku.edu.cn) 4 | # Pytorch Implementation of Res2Net 5 | # original code from: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 6 | 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | __all__ = ['res2net50', 'res2net101', 'res2net152', 'res2next50_32x4d', 12 | 'res2next101_32x8d','se_res2net50','se_res2net101','se_res2net152', 13 | 'se_res2next50_32x4d','se_res2next101_32x8d'] 14 | 15 | 16 | def conv3x3(in_planes, out_planes, stride=1, groups=1): 17 | """3x3 convolution with padding""" 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 19 | padding=1, groups=groups, bias=False) 20 | 21 | 22 | def conv1x1(in_planes, out_planes, stride=1): 23 | """1x1 convolution""" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 25 | 26 | 27 | class Res2NetBlock(nn.Module): 28 | def __init__(self, planes, scale=1, stride=1, 
groups=1, norm_layer=None): 29 | super(Res2NetBlock, self).__init__() 30 | 31 | self.relu = nn.ReLU(inplace=True) 32 | if norm_layer is None: 33 | norm_layer = nn.BatchNorm2d 34 | 35 | self.scale = scale 36 | ch_per_sub = planes // self.scale 37 | ch_res = planes % self.scale 38 | self.chunks = [ch_per_sub * i + ch_res for i in range(1, scale + 1)] 39 | self.conv_blocks = self._make_sub_convs(ch_per_sub, norm_layer, stride, groups) 40 | 41 | def forward(self, x): 42 | sub_convs = [] 43 | sub_convs.append(x[:, :self.chunks[0]]) 44 | sub_convs.append(self.conv_blocks[0](x[:, self.chunks[0]: self.chunks[1]])) 45 | for s in range(2, self.scale): 46 | sub_x = x[:, self.chunks[s - 1]: self.chunks[s]] 47 | sub_x = sub_x + sub_convs[-1] # out-of-place: "+=" would write into a view of x and break autograd 48 | sub_convs.append(self.conv_blocks[s - 1](sub_x)) 49 | 50 | return torch.cat(sub_convs, dim=1) 51 | 52 | def _make_sub_convs(self, ch_per_sub, norm_layer, stride, groups): 53 | layers = [] 54 | for s in range(1, self.scale): 55 | layers.append(nn.Sequential( 56 | conv3x3(ch_per_sub, ch_per_sub, stride, groups), 57 | norm_layer(ch_per_sub))) 58 | # norm_layer(ch_per_sub), self.relu)) 59 | 60 | return nn.Sequential(*layers) 61 | 62 | 63 | class SELayer(nn.Module): 64 | def __init__(self, channel, reduction=16): 65 | super(SELayer, self).__init__() 66 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 67 | self.fc = nn.Sequential( 68 | nn.Linear(channel, channel // reduction, bias=False), 69 | nn.ReLU(inplace=True), 70 | nn.Linear(channel // reduction, channel, bias=False), 71 | nn.Sigmoid() 72 | ) 73 | 74 | def forward(self, x): 75 | b, c, _, _ = x.size() 76 | y = self.avg_pool(x).view(b, c) 77 | y = self.fc(y).view(b, c, 1, 1) 78 | return x * y.expand_as(x) 79 | 80 | 81 | class Res2NetBlockSE(nn.Module): 82 | def __init__(self, planes, scale=1, stride=1, groups=1, norm_layer=None): 83 | super(Res2NetBlockSE, self).__init__() 84 | 85 | self.relu = nn.ReLU(inplace=True) 86 | if norm_layer is None: 87 | norm_layer = nn.BatchNorm2d 88 | 89 | self.scale =
scale 90 | ch_per_sub = planes // self.scale 91 | ch_res = planes % self.scale 92 | self.chunks = [ch_per_sub * i + ch_res for i in range(1, scale + 1)] 93 | self.conv_blocks = self._make_sub_convs(ch_per_sub, norm_layer, stride, groups) 94 | self.se = SELayer(planes) 95 | 96 | def forward(self, x): 97 | sub_convs = [] 98 | sub_convs.append(x[:, :self.chunks[0]]) 99 | sub_convs.append(self.conv_blocks[0](x[:, self.chunks[0]: self.chunks[1]])) 100 | for s in range(2, self.scale): 101 | sub_x = x[:, self.chunks[s - 1]: self.chunks[s]] 102 | sub_x = sub_x + sub_convs[-1] # out-of-place: "+=" would write into a view of x and break autograd 103 | sub_convs.append(self.conv_blocks[s - 1](sub_x)) 104 | out = torch.cat(sub_convs, dim=1) 105 | out = self.se(out) 106 | return out 107 | 108 | def _make_sub_convs(self, ch_per_sub, norm_layer, stride, groups): 109 | layers = [] 110 | for s in range(1, self.scale): 111 | layers.append(nn.Sequential( 112 | conv3x3(ch_per_sub, ch_per_sub, stride, groups), 113 | norm_layer(ch_per_sub))) 114 | #norm_layer(ch_per_sub), self.relu)) 115 | 116 | return nn.Sequential(*layers) 117 | 118 | 119 | class Res2NetBottleneck(nn.Module): 120 | expansion = 4 121 | 122 | def __init__(self, inplanes, planes, scale=1, stride=1, downsample=None, groups=1, norm_layer=None, se=False,reduction=16): 123 | super(Res2NetBottleneck, self).__init__() 124 | if norm_layer is None: 125 | norm_layer = nn.BatchNorm2d 126 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 127 | self.conv1 = conv1x1(inplanes, planes) 128 | self.bn1 = norm_layer(planes) 129 | if downsample is None and scale > 1: 130 | self.conv2 = Res2NetBlock(planes, scale, stride, groups) 131 | else: 132 | self.conv2 = conv3x3(planes, planes, stride, groups) 133 | self.bn2 = norm_layer(planes) 134 | self.conv3 = conv1x1(planes, planes * self.expansion) 135 | self.bn3 = norm_layer(planes * self.expansion) 136 | self.relu = nn.ReLU(inplace=True) 137 | self.downsample = downsample 138 | self.stride = stride 139 | self.se = se 140 | if se: 141
| self.se_layer = SELayer(planes*self.expansion, reduction=reduction) 142 | 143 | def forward(self, x): 144 | identity = x 145 | 146 | out = self.conv1(x) 147 | out = self.bn1(out) 148 | out = self.relu(out) 149 | 150 | out = self.conv2(out) 151 | out = self.bn2(out) 152 | out = self.relu(out) 153 | 154 | out = self.conv3(out) 155 | out = self.bn3(out) 156 | if self.se: 157 | out = self.se_layer(out) 158 | 159 | if self.downsample is not None: 160 | identity = self.downsample(x) 161 | 162 | out += identity 163 | out = self.relu(out) 164 | 165 | return out 166 | 167 | 168 | class ResNet(nn.Module): 169 | 170 | def __init__(self, block, layers, scale=1, se=False, num_classes=1000, zero_init_residual=False, 171 | groups=1, width_per_group=64, norm_layer=None): 172 | super(ResNet, self).__init__() 173 | self.scale = scale 174 | if norm_layer is None: 175 | norm_layer = nn.BatchNorm2d 176 | planes = [int(width_per_group * groups * 2 ** i) for i in range(4)] 177 | self.inplanes = planes[0] 178 | self.conv1 = nn.Conv2d(3, planes[0], kernel_size=7, stride=2, padding=3, 179 | bias=False) 180 | self.bn1 = norm_layer(planes[0]) 181 | self.relu = nn.ReLU(inplace=True) 182 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 183 | self.layer1 = self._make_layer(block, planes[0], layers[0], groups=groups, norm_layer=norm_layer,se=se) 184 | self.layer2 = self._make_layer(block, planes[1], layers[1], stride=2, groups=groups, norm_layer=norm_layer,se=se) 185 | self.layer3 = self._make_layer(block, planes[2], layers[2], stride=2, groups=groups, norm_layer=norm_layer,se=se) 186 | self.layer4 = self._make_layer(block, planes[3], layers[3], stride=2, groups=groups, norm_layer=norm_layer,se=se) 187 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 188 | self.fc = nn.Linear(planes[3] * block.expansion, num_classes) 189 | 190 | for m in self.modules(): 191 | if isinstance(m, nn.Conv2d): 192 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 193 | elif 
isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 194 | nn.init.constant_(m.weight, 1) 195 | nn.init.constant_(m.bias, 0) 196 | 197 | # Zero-initialize the last BN in each residual branch, 198 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 199 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 200 | if zero_init_residual: 201 | for m in self.modules(): 202 | if isinstance(m, Res2NetBottleneck): 203 | nn.init.constant_(m.bn3.weight, 0) 204 | # Res2NetBlock has no bn2 of its own; zeroing bn3 above already zeroes the residual branch 205 | 206 | 207 | def _make_layer(self, block, planes, blocks, stride=1, groups=1, norm_layer=None, se=False): 208 | if norm_layer is None: 209 | norm_layer = nn.BatchNorm2d 210 | downsample = None 211 | if stride != 1 or self.inplanes != planes * block.expansion: 212 | downsample = nn.Sequential( 213 | conv1x1(self.inplanes, planes * block.expansion, stride), 214 | norm_layer(planes * block.expansion), 215 | ) 216 | 217 | layers = [] 218 | layers.append(block(self.inplanes, planes, self.scale, stride, downsample, groups, norm_layer, se)) 219 | self.inplanes = planes * block.expansion 220 | for _ in range(1, blocks): 221 | layers.append(block(self.inplanes, planes, self.scale, groups=groups, norm_layer=norm_layer, se=se)) 222 | 223 | return nn.Sequential(*layers) 224 | 225 | def forward(self, x): 226 | x = self.conv1(x) 227 | x = self.bn1(x) 228 | x = self.relu(x) 229 | x = self.maxpool(x) 230 | 231 | x = self.layer1(x) 232 | x = self.layer2(x) 233 | x = self.layer3(x) 234 | x = self.layer4(x) 235 | 236 | x = self.avgpool(x) 237 | x = x.view(x.size(0), -1) 238 | x = self.fc(x) 239 | 240 | return x 241 | 242 | 243 | def res2net50(scale=4, **kwargs): 244 | """Constructs a Res2Net-50 model.
245 | 246 | Args: 247 | scale (int): Number of feature groups in the Res2Net block 248 | """ 249 | model = ResNet(Res2NetBottleneck, [3, 4, 6, 3], scale=scale, **kwargs) 250 | return model 251 | 252 | 253 | def res2net101(scale=4, **kwargs): 254 | """Constructs a Res2Net-101 model. 255 | 256 | Args: 257 | scale (int): Number of feature groups in the Res2Net block 258 | """ 259 | model = ResNet(Res2NetBottleneck, [3, 4, 23, 3], scale=scale, **kwargs) 260 | return model 261 | 262 | 263 | def res2net152(scale=4, **kwargs): 264 | """Constructs a Res2Net-152 model. 265 | 266 | Args: 267 | scale (int): Number of feature groups in the Res2Net block 268 | """ 269 | model = ResNet(Res2NetBottleneck, [3, 8, 36, 3], scale=scale, **kwargs) 270 | return model 271 | 272 | 273 | def res2next50_32x4d(scale=4, **kwargs): 274 | """Constructs a Res2NeXt50_32x4d model. 275 | 276 | Args: 277 | scale (int): Number of feature groups in the Res2Net block 278 | """ 279 | model = ResNet(Res2NetBottleneck, [3, 4, 6, 3], groups=4, width_per_group=32, scale=scale, **kwargs) 280 | return model 281 | 282 | 283 | def res2next101_32x8d(scale=4, **kwargs): 284 | """Constructs a Res2NeXt101_32x8d model. 285 | 286 | Args: 287 | scale (int): Number of feature groups in the Res2Net block 288 | If scale=1, the standard conv3x3 block is used instead 289 | """ 290 | model = ResNet(Res2NetBottleneck, [3, 4, 23, 3], groups=8, width_per_group=32, scale=scale, **kwargs) 291 | return model 292 | 293 | 294 | def se_res2net50(scale=4, **kwargs): 295 | """Constructs an SE-Res2Net-50 model. 296 | 297 | Args: 298 | scale (int): Number of feature groups in the Res2Net block 299 | """ 300 | model = ResNet(Res2NetBottleneck, [3, 4, 6, 3], scale=scale, se=True, **kwargs) 301 | return model 302 | 303 | 304 | def se_res2net101(scale=4, **kwargs): 305 | """Constructs an SE-Res2Net-101 model.
306 | 307 | Args: 308 | scale (int): Number of feature groups in the Res2Net block 309 | """ 310 | model = ResNet(Res2NetBottleneck, [3, 4, 23, 3], scale=scale, se=True, **kwargs) 311 | return model 312 | 313 | 314 | def se_res2net152(scale=4, **kwargs): 315 | """Constructs an SE-Res2Net-152 model. 316 | 317 | Args: 318 | scale (int): Number of feature groups in the Res2Net block 319 | """ 320 | model = ResNet(Res2NetBottleneck, [3, 8, 36, 3], scale=scale, se=True, **kwargs) 321 | return model 322 | 323 | def se_res2next50_32x4d(scale=4, **kwargs): 324 | """Constructs an SE-Res2NeXt50_32x4d model. 325 | 326 | Args: 327 | scale (int): Number of feature groups in the Res2Net block 328 | """ 329 | model = ResNet(Res2NetBottleneck, [3, 4, 6, 3], groups=4, width_per_group=32, scale=scale, se=True, **kwargs) 330 | return model 331 | 332 | 333 | def se_res2next101_32x8d(scale=4, **kwargs): 334 | """Constructs an SE-Res2NeXt101_32x8d model. 335 | 336 | Args: 337 | scale (int): Number of feature groups in the Res2Net block 338 | If scale=1, the standard conv3x3 block is used instead 339 | """ 340 | model = ResNet(Res2NetBottleneck, [3, 4, 23, 3], groups=8, width_per_group=32, scale=scale, se=True, **kwargs) 341 | return model 342 | 343 | 344 | if __name__ == '__main__': 345 | model = res2next101_32x8d().cuda() 346 | # model = se_resnet50().cuda() 347 | print(model) 348 | i = torch.randn(1, 3, 256, 256).cuda() 349 | y = model(i) 350 | print(y.size()) 351 | -------------------------------------------------------------------------------- /libs/nn/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 4 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1, groups=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1,
groups=groups, bias=False) 11 | 12 | 13 | def conv1x1(in_planes, out_planes, stride=1): 14 | """1x1 convolution""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 16 | 17 | 18 | class BasicBlock(nn.Module): 19 | expansion = 1 20 | 21 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 22 | base_width=64, norm_layer=None): 23 | super(BasicBlock, self).__init__() 24 | if norm_layer is None: 25 | norm_layer = nn.BatchNorm2d 26 | if groups != 1 or base_width != 64: 27 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 28 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = norm_layer(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = norm_layer(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | identity = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | identity = self.downsample(x) 49 | 50 | out += identity 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 60 | base_width=64, norm_layer=None): 61 | super(Bottleneck, self).__init__() 62 | if norm_layer is None: 63 | norm_layer = nn.BatchNorm2d 64 | width = int(planes * (base_width / 64.)) * groups 65 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 66 | self.conv1 = conv1x1(inplanes, width) 67 | self.bn1 = norm_layer(width) 68 | self.conv2 = conv3x3(width, width, stride, groups) 69 | self.bn2 = norm_layer(width) 70 | self.conv3 = conv1x1(width, planes * self.expansion) 71 | self.bn3 = norm_layer(planes * 
self.expansion) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.downsample = downsample 74 | self.stride = stride 75 | 76 | def forward(self, x): 77 | identity = x 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv3(out) 88 | out = self.bn3(out) 89 | 90 | if self.downsample is not None: 91 | identity = self.downsample(x) 92 | 93 | out += identity 94 | out = self.relu(out) 95 | 96 | return out 97 | 98 | 99 | class ResNet(nn.Module): 100 | 101 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 102 | groups=1, width_per_group=64, norm_layer=None): 103 | super(ResNet, self).__init__() 104 | if norm_layer is None: 105 | norm_layer = nn.BatchNorm2d 106 | 107 | self.inplanes = 64 108 | self.groups = groups 109 | self.base_width = width_per_group 110 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 111 | bias=False) 112 | self.bn1 = norm_layer(self.inplanes) 113 | self.relu = nn.ReLU(inplace=True) 114 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 115 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer) 116 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) 117 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer) 118 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer) 119 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 120 | self.fc = nn.Linear(512 * block.expansion, num_classes) 121 | 122 | for m in self.modules(): 123 | if isinstance(m, nn.Conv2d): 124 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 125 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 126 | nn.init.constant_(m.weight, 1) 127 | nn.init.constant_(m.bias, 0) 128 | 129 | # Zero-initialize the last BN in each residual branch, 130 | # 
so that the residual branch starts with zeros, and each residual block behaves like an identity. 131 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 132 | if zero_init_residual: 133 | for m in self.modules(): 134 | if isinstance(m, Bottleneck): 135 | nn.init.constant_(m.bn3.weight, 0) 136 | elif isinstance(m, BasicBlock): 137 | nn.init.constant_(m.bn2.weight, 0) 138 | 139 | def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None): 140 | if norm_layer is None: 141 | norm_layer = nn.BatchNorm2d 142 | downsample = None 143 | if stride != 1 or self.inplanes != planes * block.expansion: 144 | downsample = nn.Sequential( 145 | conv1x1(self.inplanes, planes * block.expansion, stride,), 146 | norm_layer(planes * block.expansion), 147 | ) 148 | 149 | layers = [] 150 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 151 | self.base_width, norm_layer)) 152 | self.inplanes = planes * block.expansion 153 | for _ in range(1, blocks): 154 | layers.append(block(self.inplanes, planes, groups=self.groups, 155 | base_width=self.base_width, norm_layer=norm_layer)) 156 | 157 | return nn.Sequential(*layers) 158 | 159 | def forward(self, x): 160 | x = self.conv1(x) 161 | x = self.bn1(x) 162 | x = self.relu(x) 163 | x = self.maxpool(x) 164 | 165 | x = self.layer1(x) 166 | x = self.layer2(x) 167 | x = self.layer3(x) 168 | x = self.layer4(x) 169 | x = self.avgpool(x) 170 | x = x.view(x.size(0), -1) 171 | x = self.fc(x) 172 | 173 | return x 174 | 175 | 176 | def resnet18(pretrained=False, **kwargs): 177 | """Constructs a ResNet-18 model. 178 | 179 | Args: 180 | pretrained (bool): If True, returns a model pre-trained on ImageNet 181 | """ 182 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 183 | # if pretrained: 184 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 185 | return model 186 | 187 | 188 | def resnet34(pretrained=False, **kwargs): 189 | """Constructs a ResNet-34 model. 
190 | 191 | Args: 192 | pretrained (bool): If True, returns a model pre-trained on ImageNet 193 | """ 194 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 195 | # if pretrained: 196 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 197 | return model 198 | 199 | 200 | def resnet50(pretrained=False, **kwargs): 201 | """Constructs a ResNet-50 model. 202 | 203 | Args: 204 | pretrained (bool): If True, returns a model pre-trained on ImageNet 205 | """ 206 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 207 | # if pretrained: 208 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 209 | return model 210 | 211 | 212 | def resnet101(pretrained=False, **kwargs): 213 | """Constructs a ResNet-101 model. 214 | 215 | Args: 216 | pretrained (bool): If True, returns a model pre-trained on ImageNet 217 | """ 218 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 219 | # if pretrained: 220 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 221 | return model 222 | 223 | 224 | def resnet152(pretrained=False, **kwargs): 225 | """Constructs a ResNet-152 model. 
226 | 227 | Args: 228 | pretrained (bool): If True, returns a model pre-trained on ImageNet 229 | """ 230 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 231 | # if pretrained: 232 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 233 | return model 234 | 235 | 236 | def resnext50_32x4d(pretrained=False, **kwargs): 237 | model = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4, **kwargs) 238 | # if pretrained: 239 | # model.load_state_dict(model_zoo.load_url(model_urls['resnext50_32x4d'])) 240 | return model 241 | 242 | 243 | def resnext101_32x8d(pretrained=False, **kwargs): 244 | model = ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=8, **kwargs) 245 | # if pretrained: 246 | # model.load_state_dict(model_zoo.load_url(model_urls['resnext101_32x8d'])) 247 | return model 248 | 249 | if __name__ == '__main__': 250 | import torch 251 | model = resnet50().cuda() 252 | i = torch.randn(1, 3, 256, 256).cuda() 253 | y = model(i) 254 | print(y.size()) 255 | 256 | """ 257 | layer output sizes for a 1x3x256x256 input: 258 | torch.Size([1, 256, 64, 64]) 259 | torch.Size([1, 512, 32, 32]) 260 | torch.Size([1, 1024, 16, 16]) 261 | torch.Size([1, 2048, 8, 8]) 262 | torch.Size([1, 1000]) 263 | """ -------------------------------------------------------------------------------- /libs/nn/resnet_adaptiveconv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Author: Xiangtai Li(lxtpku@pku.edu.cn) 4 | # PyTorch implementation of Adaptive Conv 5 | # This is an unofficial implementation of Adaptive Conv 6 | # PixelAwareAdaptiveBottleneck: (finished) 7 | # DataSetAwareAdaptiveBottleneck: (finished) depends on the input size, which makes it position sensitive 8 | # (data sensitive with learnable weights) 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | def conv3x3(in_planes, out_planes, stride=1, groups=1): 16 | """3x3 conv with
padding""" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=(3,3), stride=stride, 18 | padding=1, groups=groups, bias=False) 19 | 20 | 21 | def conv1x1(in_planes, out_planes, stride=1): 22 | """1x1 conv""" 23 | return nn.Conv2d(in_planes, out_planes, kernel_size=(1,1), stride=stride, bias=False, padding=0) 24 | 25 | 26 | class PixelAwareAdaptiveBottleneck(nn.Module): 27 | expansion = 4 28 | 29 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 30 | base_width=64, norm_layer=None, input_size=None): 31 | super(PixelAwareAdaptiveBottleneck, self).__init__() 32 | if norm_layer is None: 33 | norm_layer = nn.BatchNorm2d 34 | width = int(planes * (base_width / 64.)) * groups 35 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 36 | self.conv1 = conv1x1(inplanes, width) 37 | self.bn1 = norm_layer(width) 38 | self.conv2_3x3 = conv3x3(width, width, stride, groups) 39 | self.bn2 = norm_layer(width) 40 | self.gap = nn.AdaptiveAvgPool2d(1) 41 | self.fc1 = nn.Conv2d(inplanes,width,kernel_size=1) 42 | self.fc2 = nn.Conv2d(width,width,kernel_size=1) 43 | 44 | self.fusion_conv1 = nn.Conv2d(width*2,width,1) 45 | self.fusion_conv2 = nn.Conv2d(width,width,1) 46 | 47 | 48 | self.conv3 = conv1x1(width, planes * self.expansion) 49 | self.bn3 = norm_layer(planes * self.expansion) 50 | self.relu = nn.ReLU(inplace=True) 51 | self.downsample = downsample 52 | self.stride = stride 53 | self.sigmod = nn.Sigmoid() 54 | 55 | def forward(self, x): 56 | identity = x 57 | 58 | out = self.conv1(x) 59 | out = self.bn1(out) 60 | out = self.relu(out) 61 | 62 | # conv 63 | out_conv3x3 = self.conv2_3x3(out) 64 | 65 | # global context: GAP -> two 1x1 convs, resized to the conv output 66 | size = out_conv3x3.size()[2:] 67 | gap = self.gap(x) 68 | gap = self.relu(self.fc1(gap)) 69 | gap = self.fc2(gap) 70 | gap = F.interpolate(gap, size=size, mode="bilinear", align_corners=True) 71 | 72 | # concat 73 | out_concat = torch.cat([gap,out_conv3x3],dim=1) 74 | out_concat = self.fusion_conv1(out_concat) 75 | #
out_concat = self.bn_fusion1(out_concat) 76 | out_concat = self.relu(out_concat) 77 | out_concat = self.fusion_conv2(out_concat) 78 | # out_concat = self.bn_fusion2(out_concat) 79 | out_concat = self.sigmod(out_concat) 80 | 81 | out = out_conv3x3 + gap * out_concat 82 | 83 | out = self.bn2(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv3(out) 87 | out = self.bn3(out) 88 | 89 | if self.downsample is not None: 90 | identity = self.downsample(x) 91 | 92 | out += identity 93 | out = self.relu(out) 94 | 95 | return out 96 | 97 | 98 | class DataSetAwareAdaptiveBottleneck(nn.Module): 99 | expansion = 4 100 | 101 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 102 | base_width=64, norm_layer=None, input_size=(224, 224)): 103 | super(DataSetAwareAdaptiveBottleneck, self).__init__() 104 | if norm_layer is None: 105 | norm_layer = nn.BatchNorm2d 106 | width = int(planes * (base_width / 64.)) * groups 107 | H, W = input_size 108 | self.H, self.W = H, W 109 | if stride == 1: 110 | self.H, self.W = H, W 111 | else: 112 | self.H, self.W = H//2, W//2 113 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 114 | # which is a little different 115 | self.conv1 = conv1x1(inplanes, width, stride=stride) 116 | self.bn1 = norm_layer(width) 117 | self.conv2 = AdaptiveConv(width, width, stride=1, groups=groups,size=(self.H, self.W)) 118 | self.bn2 = norm_layer(width) 119 | self.conv3 = conv1x1(width, planes * self.expansion) 120 | self.bn3 = norm_layer(planes * self.expansion) 121 | self.relu = nn.ReLU(inplace=True) 122 | self.downsample = downsample 123 | self.stride = stride 124 | 125 | def forward(self, x): 126 | identity = x 127 | 128 | out = self.conv1(x) 129 | out = self.bn1(out) 130 | out = self.relu(out) 131 | _,_,h,w = out.size() 132 | assert self.H == h and self.W == w 133 | 134 | out = self.conv2(out) 135 | out = self.bn2(out) 136 | out = self.relu(out) 137 | 138 | out = self.conv3(out) 139 | out = 
self.bn3(out) 140 | 141 | if self.downsample is not None: 142 | identity = self.downsample(x) 143 | 144 | out += identity 145 | out = self.relu(out) 146 | 147 | return out 148 | 149 | 150 | class AdaptiveConv(nn.Module): 151 | def __init__(self, in_channels, out_channels, stride=1, padding=1, dilation=1, 152 | groups=1, bias=False, size=(256, 256)): 153 | super(AdaptiveConv, self).__init__() 154 | 155 | self.conv3x3 = nn.Conv2d(in_channels, out_channels,3, stride, padding=1, dilation=dilation,groups=groups, bias=bias) 156 | self.conv1x1 = nn.Conv2d(in_channels, out_channels,1, stride, padding=0, dilation=dilation,groups=groups, bias=bias) 157 | self.gap = nn.AdaptiveAvgPool2d(1) 158 | self.fc1 = nn.Conv2d(in_channels, out_channels, kernel_size=1) 159 | self.fc2 = nn.Conv2d(out_channels, out_channels, kernel_size=1) 160 | self.size = size 161 | self.w = nn.Parameter(torch.ones(3, 1, self.size[0], self.size[1])) 162 | self.softmax = nn.Softmax(dim=0)  # normalize across the 3 branch weights at each position (the implicit dim would be 1, which has size 1 and always yields 1) 163 | self.relu = nn.ReLU(inplace=True) 164 | 165 | def forward(self, x): 166 | _, _, h, w = x.size() 167 | 168 | weight = self.softmax(self.w) 169 | w1 = weight[0, :, :, :] 170 | w2 = weight[1, :, :, :] 171 | w3 = weight[2, :, :, :] 172 | 173 | x1 = self.conv3x3(x) # 3x3 conv branch 174 | x2 = self.conv1x1(x) # 1x1 (self) branch 175 | 176 | size = x1.size()[2:] # global-context branch 177 | gap = self.gap(x) 178 | gap = self.relu(self.fc1(gap)) 179 | gap = self.fc2(gap) 180 | # gap = F.upsample(gap, size=size, mode="bilinear", align_corners=True) 181 | gap = F.interpolate(gap, size=size, mode="nearest") 182 | 183 | x = w1 * x1 + w2 * x2 + w3 * gap 184 | 185 | return x 186 | 187 | class ResNet(nn.Module): 188 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 189 | groups=1, width_per_group=64, norm_layer=None, input_size=(256,256)): 190 | super(ResNet, self).__init__() 191 | if norm_layer is None: 192 | norm_layer = nn.BatchNorm2d 193 | 194 | self.inplanes = 64 195 | self.groups = groups 196 | self.base_width = width_per_group 197 |
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 198 | bias=False) 199 | self.bn1 = norm_layer(self.inplanes) 200 | self.relu = nn.ReLU(inplace=True) 201 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 202 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, input_size=(input_size[0]//4, 203 | input_size[1]//4, 204 | )) 205 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer, input_size=(input_size[0]//4, 206 | input_size[1]//4 207 | )) 208 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer, input_size=(input_size[0]//8, 209 | input_size[1]//8 210 | )) 211 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer, input_size=(input_size[0]//16, 212 | input_size[1]//16 213 | )) 214 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 215 | self.fc = nn.Linear(512 * block.expansion, num_classes) 216 | 217 | for m in self.modules(): 218 | if isinstance(m, nn.Conv2d): 219 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 220 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 221 | nn.init.constant_(m.weight, 1) 222 | nn.init.constant_(m.bias, 0) 223 | 224 | # Zero-initialize the last BN in each residual branch, 225 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
226 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 227 | if zero_init_residual: 228 | for m in self.modules(): 229 | if isinstance(m, PixelAwareAdaptiveBottleneck): 230 | nn.init.constant_(m.bn3.weight, 0) 231 | 232 | def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None, input_size=(224, 224)): 233 | if norm_layer is None: 234 | norm_layer = nn.BatchNorm2d 235 | downsample = None 236 | if stride != 1 or self.inplanes != planes * block.expansion: 237 | downsample = nn.Sequential( 238 | conv1x1(self.inplanes, planes * block.expansion, stride,), 239 | norm_layer(planes * block.expansion), 240 | ) 241 | 242 | layers = [] 243 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 244 | self.base_width, norm_layer, input_size)) 245 | self.inplanes = planes * block.expansion 246 | for _ in range(1, blocks): 247 | if stride !=1: 248 | layers.append(block(self.inplanes, planes, groups=self.groups, 249 | base_width=self.base_width, norm_layer=norm_layer, 250 | input_size=(input_size[0]//2, input_size[1]//2))) 251 | else: 252 | layers.append(block(self.inplanes, planes, groups=self.groups, 253 | base_width=self.base_width, norm_layer=norm_layer, input_size=input_size)) 254 | 255 | return nn.Sequential(*layers) 256 | 257 | def forward(self, x): 258 | x = self.conv1(x) 259 | x = self.bn1(x) 260 | x = self.relu(x) 261 | x = self.maxpool(x) 262 | # print(x.size()) 263 | x = self.layer1(x) 264 | # print(x.size()) 265 | x = self.layer2(x) 266 | # print(x.size()) 267 | x = self.layer3(x) 268 | # print(x.size()) 269 | x = self.layer4(x) 270 | x = self.avgpool(x) 271 | x = x.view(x.size(0), -1) 272 | x = self.fc(x) 273 | 274 | return x 275 | 276 | 277 | 278 | def PixelAwareResnet50(pretrained=False, **kwargs): 279 | """Constructs a ResNet-50 model. 
280 | 281 | Args: 282 | pretrained (bool): If True, returns a model pre-trained on ImageNet 283 | """ 284 | model = ResNet(PixelAwareAdaptiveBottleneck, [3, 4, 6, 3], **kwargs) 285 | return model 286 | 287 | 288 | def PixelAwareResnet101(pretrained=False, **kwargs): 289 | """Constructs a ResNet-101 model. 290 | 291 | Args: 292 | pretrained (bool): If True, returns a model pre-trained on ImageNet 293 | """ 294 | model = ResNet(PixelAwareAdaptiveBottleneck, [3, 4, 23, 3], **kwargs) 295 | return model 296 | 297 | 298 | def PixelAwareResnet152(pretrained=False, **kwargs): 299 | """Constructs a ResNet-152 model. 300 | 301 | Args: 302 | pretrained (bool): If True, returns a model pre-trained on ImageNet 303 | """ 304 | model = ResNet(PixelAwareAdaptiveBottleneck, [3, 8, 36, 3], **kwargs) 305 | return model 306 | 307 | 308 | def DataSetAwareResnet50(pretrained=False, **kwargs): 309 | """Constructs a ResNet-50 model. 310 | 311 | Args: 312 | pretrained (bool): If True, returns a model pre-trained on ImageNet 313 | """ 314 | model = ResNet(DataSetAwareAdaptiveBottleneck, [3, 4, 6, 3], **kwargs) 315 | return model 316 | 317 | 318 | def DataSetAwareResnet101(pretrained=False, **kwargs): 319 | """Constructs a ResNet-101 model. 320 | 321 | Args: 322 | pretrained (bool): If True, returns a model pre-trained on ImageNet 323 | """ 324 | model = ResNet(DataSetAwareAdaptiveBottleneck, [3, 4, 23, 3], **kwargs) 325 | return model 326 | 327 | 328 | def DataSetAwareResnet152(pretrained=False, **kwargs): 329 | """Constructs a ResNet-152 model.
330 | 331 | Args: 332 | pretrained (bool): If True, returns a model pre-trained on ImageNet 333 | """ 334 | model = ResNet(DataSetAwareAdaptiveBottleneck, [3, 8, 36, 3], **kwargs) 335 | return model 336 | 337 | 338 | if __name__ == '__main__': 339 | model = DataSetAwareResnet50().cuda() 340 | i = torch.randn(1, 3, 256, 256).cuda() 341 | y = model(i) 342 | print(y.size()) -------------------------------------------------------------------------------- /libs/nn/resnet_eca.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | __all__ = ['eca_resnet50', 'eca_resnet101', 'eca_resnet152'] 5 | 6 | 7 | class eca_layer(nn.Module): 8 | """Constructs an ECA module. 9 | Args: 10 | channel: Number of channels of the input feature map 11 | k_size: Kernel size of the 1D convolution across channels (fixed here; the paper derives it adaptively from the channel count) 12 | """ 13 | 14 | def __init__(self, channel, k_size=3): 15 | super(eca_layer, self).__init__() 16 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 17 | self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) 18 | self.sigmoid = nn.Sigmoid() 19 | 20 | def forward(self, x): 21 | # x: input features with shape [b, c, h, w] 22 | b, c, h, w = x.size() 23 | 24 | # feature descriptor on the global spatial information 25 | y = self.avg_pool(x) 26 | 27 | # 1D convolution along the channel dimension: local cross-channel interaction, [b, c, 1, 1] -> [b, 1, c] -> [b, c, 1, 1] 28 | y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) 29 | 30 | # Sigmoid gate: per-channel attention weights in (0, 1) 31 | y = self.sigmoid(y) 32 | 33 | return x * y.expand_as(x) 34 | 35 | 36 | def conv3x3(in_planes, out_planes, stride=1): 37 | """3x3 convolution with padding""" 38 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 39 | padding=1, bias=False) 40 | 41 | 42 | def conv1x1(in_planes, out_planes, stride=1): 43 | """1x1 convolution""" 44 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 45 | 46 | 47 | class BasicBlock(nn.Module): 48 | expansion = 1 49 | 50 |
def __init__(self, inplanes, planes, stride=1, downsample=None): 51 | super(BasicBlock, self).__init__() 52 | self.conv1 = conv3x3(inplanes, planes, stride) 53 | self.bn1 = nn.BatchNorm2d(planes) 54 | self.relu = nn.ReLU(inplace=True) 55 | self.conv2 = conv3x3(planes, planes) 56 | self.bn2 = nn.BatchNorm2d(planes) 57 | self.downsample = downsample 58 | self.stride = stride 59 | self.se = eca_layer(planes) 60 | 61 | def forward(self, x): 62 | identity = x 63 | 64 | out = self.conv1(x) 65 | out = self.bn1(out) 66 | out = self.relu(out) 67 | 68 | out = self.conv2(out) 69 | out = self.bn2(out) 70 | out = self.se(out) 71 | 72 | if self.downsample is not None: 73 | identity = self.downsample(x) 74 | 75 | out += identity 76 | out = self.relu(out) 77 | 78 | return out 79 | 80 | 81 | class Bottleneck(nn.Module): 82 | expansion = 4 83 | 84 | def __init__(self, inplanes, planes, stride=1, downsample=None): 85 | super(Bottleneck, self).__init__() 86 | self.conv1 = conv1x1(inplanes, planes) 87 | self.bn1 = nn.BatchNorm2d(planes) 88 | self.conv2 = conv3x3(planes, planes, stride) 89 | self.bn2 = nn.BatchNorm2d(planes) 90 | self.conv3 = conv1x1(planes, planes * self.expansion) 91 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 92 | self.se = eca_layer(planes * self.expansion) 93 | self.relu = nn.ReLU(inplace=True) 94 | self.downsample = downsample 95 | self.stride = stride 96 | 97 | def forward(self, x): 98 | identity = x 99 | 100 | out = self.conv1(x) 101 | out = self.bn1(out) 102 | out = self.relu(out) 103 | 104 | out = self.conv2(out) 105 | out = self.bn2(out) 106 | out = self.relu(out) 107 | 108 | out = self.conv3(out) 109 | out = self.bn3(out) 110 | out = self.se(out) 111 | 112 | if self.downsample is not None: 113 | identity = self.downsample(x) 114 | 115 | out += identity 116 | out = self.relu(out) 117 | 118 | return out 119 | 120 | 121 | class ResNet(nn.Module): 122 | 123 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False): 124 | 
super(ResNet, self).__init__() 125 | self.inplanes = 64 126 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 127 | bias=False) 128 | self.bn1 = nn.BatchNorm2d(64) 129 | self.relu = nn.ReLU(inplace=True) 130 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 131 | self.layer1 = self._make_layer(block, 64, layers[0]) 132 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 133 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 134 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 135 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 136 | self.fc = nn.Linear(512 * block.expansion, num_classes) 137 | 138 | for m in self.modules(): 139 | if isinstance(m, nn.Conv2d): 140 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 141 | elif isinstance(m, nn.BatchNorm2d): 142 | nn.init.constant_(m.weight, 1) 143 | nn.init.constant_(m.bias, 0) 144 | 145 | # Zero-initialize the last BN in each residual branch, 146 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
147 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 148 | if zero_init_residual: 149 | for m in self.modules(): 150 | if isinstance(m, Bottleneck): 151 | nn.init.constant_(m.bn3.weight, 0) 152 | elif isinstance(m, BasicBlock): 153 | nn.init.constant_(m.bn2.weight, 0) 154 | 155 | def _make_layer(self, block, planes, blocks, stride=1): 156 | downsample = None 157 | if stride != 1 or self.inplanes != planes * block.expansion: 158 | downsample = nn.Sequential( 159 | conv1x1(self.inplanes, planes * block.expansion, stride), 160 | nn.BatchNorm2d(planes * block.expansion), 161 | ) 162 | 163 | layers = [] 164 | layers.append(block(self.inplanes, planes, stride, downsample)) 165 | self.inplanes = planes * block.expansion 166 | for _ in range(1, blocks): 167 | layers.append(block(self.inplanes, planes)) 168 | 169 | return nn.Sequential(*layers) 170 | 171 | def forward(self, x): 172 | x = self.conv1(x) 173 | x = self.bn1(x) 174 | x = self.relu(x) 175 | x = self.maxpool(x) 176 | 177 | x = self.layer1(x) 178 | x = self.layer2(x) 179 | x = self.layer3(x) 180 | x = self.layer4(x) 181 | 182 | x = self.avgpool(x) 183 | x = x.view(x.size(0), -1) 184 | x = self.fc(x) 185 | 186 | return x 187 | 188 | 189 | 190 | def eca_resnet50(pretrained=False, **kwargs): 191 | """Constructs a ResNet-50 model. 192 | Args: 193 | pretrained (bool): If True, returns a model pre-trained on ImageNet 194 | """ 195 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 196 | return model 197 | 198 | 199 | def eca_resnet101(pretrained=False, **kwargs): 200 | """Constructs a ResNet-101 model. 201 | Args: 202 | pretrained (bool): If True, returns a model pre-trained on ImageNet 203 | """ 204 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 205 | return model 206 | 207 | 208 | def eca_resnet152(pretrained=False, **kwargs): 209 | """Constructs a ResNet-152 model. 
210 | Args: 211 | pretrained (bool): If True, returns a model pre-trained on ImageNet 212 | """ 213 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 214 | return model 215 | 216 | -------------------------------------------------------------------------------- /libs/nn/resnet_ge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | # Author: Xiangtai Li(lxtpku@pku.edu.cn) 5 | # PyTorch implementation of GE-Net (Gather-Excite, NIPS 2018) 6 | 7 | __all__ = [ 'ge_resnet50', 'ge_resnet101', 'ge_resnet152'] 8 | 9 | 10 | 11 | class GELayerv1(nn.Module): 12 | def __init__(self): 13 | super(GELayerv1, self).__init__() 14 | self.avg_pool = nn.AvgPool2d(kernel_size=(15, 15), stride=8) 15 | self.sigmoid = nn.Sigmoid() 16 | 17 | 18 | def forward(self, x): 19 | b, c, h, w = x.size() 20 | res = x 21 | y = self.avg_pool(x) 22 | y = F.interpolate(y, size=(h, w), mode="bilinear", align_corners=True) 23 | y = self.sigmoid(y) * x 24 | return res + y 25 | 26 | 27 | class GELayerv2(nn.Module): 28 | def __init__(self): 29 | super(GELayerv2, self).__init__() 30 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 31 | self.sigmoid = nn.Sigmoid() 32 | 33 | def forward(self, x): 34 | b, c, _, _ = x.size() 35 | res = x 36 | y = self.avg_pool(x) 37 | y = self.sigmoid(y) 38 | z = x * y 39 | return res + z 40 | 41 | 42 | class GELayerv3(nn.Module): 43 | def __init__(self, inplane): 44 | super(GELayerv3, self).__init__() 45 | self.dconv1 = nn.Sequential( 46 | nn.Conv2d(inplane, inplane, kernel_size=3, groups=inplane, stride=2), 47 | nn.BatchNorm2d(inplane), 48 | nn.ReLU(inplace=False) 49 | ) 50 | self.dconv2 = nn.Sequential( 51 | nn.Conv2d(inplane, inplane, kernel_size=3, groups=inplane, stride=2), 52 | nn.BatchNorm2d(inplane), 53 | nn.ReLU(inplace=False) 54 | ) 55 | self.dconv3 = nn.Sequential( 56 | nn.Conv2d(inplane, inplane, kernel_size=3, groups=inplane, stride=2), 57 | nn.BatchNorm2d(inplane), 58 | nn.ReLU(inplace=False) 59 | )
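`GELayerv1` (AvgPool2d, kernel 15, stride 8) and `GELayerv3` (three unpadded depthwise 3x3 convs, stride 2) gather context with windows that have no padding. Each therefore needs a minimum input size. The pure-Python sketch below applies the output-size formula to both gather paths, returning `None` when the window no longer fits; PyTorch would be expected to reject such an input with a size error. It assumes a 224x224 image, so the stage sizes are 56, 28, 14 and 7. The blocks in this file use `GELayerv2`, whose global pooling has no such limit.

```python
def out_size(n, kernel, stride, padding=0):
    """Output length along one axis, or None if the window does not fit."""
    if n + 2 * padding < kernel:
        return None
    return (n + 2 * padding - kernel) // stride + 1


def gather_size_v1(n):
    # GELayerv1: nn.AvgPool2d(kernel_size=(15, 15), stride=8), no padding
    return out_size(n, kernel=15, stride=8)


def gather_size_v3(n):
    # GELayerv3: dconv1..dconv3, each a 3x3 depthwise conv with stride 2 and no padding
    for _ in range(3):
        if n is None:
            return None
        n = out_size(n, kernel=3, stride=2)
    return n


for n in (56, 28, 14, 7):  # layer1..layer4 sizes for a 224x224 input
    print(n, gather_size_v1(n), gather_size_v3(n))
# 56 6 6
# 28 2 2
# 14 None None
# 7 None None
```

In other words, at this resolution either variant could only be dropped into `layer1` and `layer2`.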
60 | self.sigmoid_spatial = nn.Sigmoid() 61 | 62 | def forward(self, x): 63 | b, c, h, w = x.size() 64 | res1 = x 65 | res2 = x 66 | x = self.dconv1(x) 67 | x = self.dconv2(x) 68 | x = self.dconv3(x) 69 | x = F.interpolate(x, size=(h, w), mode="bilinear", align_corners=True) 70 | x = self.sigmoid_spatial(x) 71 | res1 = res1 * x 72 | 73 | return res2 + res1 74 | 75 | 76 | def conv3x3(in_planes, out_planes, stride=1): 77 | """3x3 convolution with padding""" 78 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 79 | padding=1, bias=False) 80 | 81 | 82 | def conv1x1(in_planes, out_planes, stride=1): 83 | """1x1 convolution""" 84 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 85 | 86 | 87 | class BasicBlock(nn.Module): 88 | expansion = 1 89 | 90 | def __init__(self, inplanes, planes, stride=1, downsample=None): 91 | super(BasicBlock, self).__init__() 92 | self.conv1 = conv3x3(inplanes, planes, stride) 93 | self.bn1 = nn.BatchNorm2d(planes) 94 | self.relu = nn.ReLU(inplace=True) 95 | self.conv2 = conv3x3(planes, planes) 96 | self.bn2 = nn.BatchNorm2d(planes) 97 | self.downsample = downsample 98 | self.stride = stride 99 | self.ge = GELayerv2() 100 | 101 | def forward(self, x): 102 | identity = x 103 | 104 | out = self.conv1(x) 105 | out = self.bn1(out) 106 | out = self.relu(out) 107 | 108 | out = self.conv2(out) 109 | out = self.bn2(out) 110 | out = self.ge(out) 111 | 112 | if self.downsample is not None: 113 | identity = self.downsample(x) 114 | 115 | out += identity 116 | out = self.relu(out) 117 | 118 | return out 119 | 120 | 121 | class Bottleneck(nn.Module): 122 | expansion = 4 123 | 124 | def __init__(self, inplanes, planes, stride=1, downsample=None): 125 | super(Bottleneck, self).__init__() 126 | self.conv1 = conv1x1(inplanes, planes) 127 | self.bn1 = nn.BatchNorm2d(planes) 128 | self.conv2 = conv3x3(planes, planes, stride) 129 | self.bn2 = nn.BatchNorm2d(planes) 130 | self.conv3 = conv1x1(planes, planes *
self.expansion) 131 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 132 | self.ge = GELayerv2() 133 | self.relu = nn.ReLU(inplace=True) 134 | self.downsample = downsample 135 | self.stride = stride 136 | 137 | def forward(self, x): 138 | identity = x 139 | 140 | out = self.conv1(x) 141 | out = self.bn1(out) 142 | out = self.relu(out) 143 | 144 | out = self.conv2(out) 145 | out = self.bn2(out) 146 | out = self.relu(out) 147 | 148 | out = self.conv3(out) 149 | out = self.bn3(out) 150 | out = self.ge(out) 151 | 152 | if self.downsample is not None: 153 | identity = self.downsample(x) 154 | 155 | out += identity 156 | out = self.relu(out) 157 | 158 | return out 159 | 160 | 161 | class ResNet(nn.Module): 162 | 163 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False): 164 | super(ResNet, self).__init__() 165 | self.inplanes = 64 166 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 167 | bias=False) 168 | self.bn1 = nn.BatchNorm2d(64) 169 | self.relu = nn.ReLU(inplace=True) 170 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 171 | self.layer1 = self._make_layer(block, 64, layers[0]) 172 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 173 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 174 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 175 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 176 | self.fc = nn.Linear(512 * block.expansion, num_classes) 177 | 178 | for m in self.modules(): 179 | if isinstance(m, nn.Conv2d): 180 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 181 | elif isinstance(m, nn.BatchNorm2d): 182 | nn.init.constant_(m.weight, 1) 183 | nn.init.constant_(m.bias, 0) 184 | 185 | # Zero-initialize the last BN in each residual branch, 186 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
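`GELayerv2`, which both blocks in resnet_ge.py use, works as follows. It pools each channel to one value, passes it through a sigmoid, and adds the gated copy back to the input. Each channel is therefore scaled by 1 + sigmoid(mean), a factor strictly between 1 and 2, and the layer has no learned parameters. The following is a pure-Python sketch for one sample, using nested lists as a stand-in for a C x H x W tensor (the function name is illustrative):

```python
import math


def ge_v2(feature_map):
    """GELayerv2-style gating for one sample.

    feature_map: list of channels, each a list of rows (C x H x W nested lists).
    Returns x + x * sigmoid(mean_over_HW(x)), channel by channel.
    """
    out = []
    for channel in feature_map:
        values = [v for row in channel for v in row]
        mean = sum(values) / len(values)      # AdaptiveAvgPool2d(1)
        gate = 1.0 / (1.0 + math.exp(-mean))  # sigmoid
        # residual plus gated copy is the same as x * (1 + gate)
        out.append([[v * (1.0 + gate) for v in row] for row in channel])
    return out


x = [[[0.0, 0.0], [0.0, 0.0]],   # channel 0: mean 0, gate 0.5
     [[1.0, 3.0], [-1.0, 1.0]]]  # channel 1: mean 1, gate about 0.731
y = ge_v2(x)
print([round(v, 3) for v in y[1][0]])  # [1.731, 5.193]
```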
187 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 188 | if zero_init_residual: 189 | for m in self.modules(): 190 | if isinstance(m, Bottleneck): 191 | nn.init.constant_(m.bn3.weight, 0) 192 | elif isinstance(m, BasicBlock): 193 | nn.init.constant_(m.bn2.weight, 0) 194 | 195 | def _make_layer(self, block, planes, blocks, stride=1): 196 | downsample = None 197 | if stride != 1 or self.inplanes != planes * block.expansion: 198 | downsample = nn.Sequential( 199 | conv1x1(self.inplanes, planes * block.expansion, stride), 200 | nn.BatchNorm2d(planes * block.expansion), 201 | ) 202 | 203 | layers = [] 204 | layers.append(block(self.inplanes, planes, stride, downsample)) 205 | self.inplanes = planes * block.expansion 206 | for _ in range(1, blocks): 207 | layers.append(block(self.inplanes, planes)) 208 | 209 | return nn.Sequential(*layers) 210 | 211 | def forward(self, x): 212 | x = self.conv1(x) 213 | x = self.bn1(x) 214 | x = self.relu(x) 215 | x = self.maxpool(x) 216 | 217 | x = self.layer1(x) 218 | x = self.layer2(x) 219 | x = self.layer3(x) 220 | x = self.layer4(x) 221 | 222 | x = self.avgpool(x) 223 | x = x.view(x.size(0), -1) 224 | x = self.fc(x) 225 | 226 | return x 227 | 228 | 229 | 230 | 231 | def ge_resnet50(pretrained=False, **kwargs): 232 | """Constructs a ResNet-50 model. 233 | Args: 234 | pretrained (bool): If True, returns a model pre-trained on ImageNet 235 | """ 236 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 237 | return model 238 | 239 | 240 | def ge_resnet101(pretrained=False, **kwargs): 241 | """Constructs a ResNet-101 model. 242 | Args: 243 | pretrained (bool): If True, returns a model pre-trained on ImageNet 244 | """ 245 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 246 | return model 247 | 248 | 249 | def ge_resnet152(pretrained=False, **kwargs): 250 | """Constructs a ResNet-152 model. 
Args: 252 | pretrained (bool): If True, returns a model pre-trained on ImageNet 253 | """ 254 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 255 | return model 256 | 257 | 258 | if __name__ == '__main__': 259 | model = ge_resnet50() 260 | i = torch.randn(1, 3, 224, 224) 261 | y = model(i) 262 | print(y.size()) -------------------------------------------------------------------------------- /libs/nn/resnet_se.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | __all__ = ['se_resnet18', 'se_resnet34', 'se_resnet50', 'se_resnet101', 'se_resnet152'] 6 | 7 | 8 | class SELayer(nn.Module): 9 | def __init__(self, channel, reduction = 16): 10 | super(SELayer, self).__init__() 11 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 12 | self.fc = nn.Sequential( 13 | nn.Linear(channel, channel // reduction), 14 | nn.ReLU(inplace = True), 15 | nn.Linear(channel // reduction, channel), 16 | nn.Sigmoid() 17 | ) 18 | 19 | def forward(self, x): 20 | b, c, _, _ = x.size() 21 | y = self.avg_pool(x).view(b, c) 22 | y = self.fc(y).view(b, c, 1, 1) 23 | return x * y 24 | 25 | 26 | def conv3x3(in_planes, out_planes, stride=1): 27 | """3x3 convolution with padding""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=1, bias=False) 30 | 31 | 32 | def conv1x1(in_planes, out_planes, stride=1): 33 | """1x1 convolution""" 34 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 35 | 36 | 37 | class BasicBlock(nn.Module): 38 | expansion = 1 39 | 40 | def __init__(self, inplanes, planes, stride=1, downsample=None): 41 | super(BasicBlock, self).__init__() 42 | self.conv1 = conv3x3(inplanes, planes, stride) 43 | self.bn1 = nn.BatchNorm2d(planes) 44 | self.relu = nn.ReLU(inplace=True) 45 | self.conv2 = conv3x3(planes, planes) 46 | self.bn2 = nn.BatchNorm2d(planes) 47 | self.downsample = downsample 48 | self.stride = stride 49 | self.se = SELayer(planes) 50 | 51 | def
forward(self, x): 52 | identity = x 53 | 54 | out = self.conv1(x) 55 | out = self.bn1(out) 56 | out = self.relu(out) 57 | 58 | out = self.conv2(out) 59 | out = self.bn2(out) 60 | out = self.se(out) 61 | 62 | if self.downsample is not None: 63 | identity = self.downsample(x) 64 | 65 | out += identity 66 | out = self.relu(out) 67 | 68 | return out 69 | 70 | 71 | class Bottleneck(nn.Module): 72 | expansion = 4 73 | 74 | def __init__(self, inplanes, planes, stride=1, downsample=None): 75 | super(Bottleneck, self).__init__() 76 | self.conv1 = conv1x1(inplanes, planes) 77 | self.bn1 = nn.BatchNorm2d(planes) 78 | self.conv2 = conv3x3(planes, planes, stride) 79 | self.bn2 = nn.BatchNorm2d(planes) 80 | self.conv3 = conv1x1(planes, planes * self.expansion) 81 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 82 | self.se = SELayer(planes * self.expansion) 83 | self.relu = nn.ReLU(inplace=True) 84 | self.downsample = downsample 85 | self.stride = stride 86 | 87 | def forward(self, x): 88 | identity = x 89 | 90 | out = self.conv1(x) 91 | out = self.bn1(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv2(out) 95 | out = self.bn2(out) 96 | out = self.relu(out) 97 | 98 | out = self.conv3(out) 99 | out = self.bn3(out) 100 | out = self.se(out) 101 | 102 | if self.downsample is not None: 103 | identity = self.downsample(x) 104 | 105 | out += identity 106 | out = self.relu(out) 107 | 108 | return out 109 | 110 | 111 | class ResNet(nn.Module): 112 | 113 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False): 114 | super(ResNet, self).__init__() 115 | self.inplanes = 64 116 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 117 | bias=False) 118 | self.bn1 = nn.BatchNorm2d(64) 119 | self.relu = nn.ReLU(inplace=True) 120 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 121 | self.layer1 = self._make_layer(block, 64, layers[0]) 122 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 123 | self.layer3 = 
self._make_layer(block, 256, layers[2], stride=2) 124 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 125 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 126 | self.fc = nn.Linear(512 * block.expansion, num_classes) 127 | 128 | for m in self.modules(): 129 | if isinstance(m, nn.Conv2d): 130 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 131 | elif isinstance(m, nn.BatchNorm2d): 132 | nn.init.constant_(m.weight, 1) 133 | nn.init.constant_(m.bias, 0) 134 | 135 | # Zero-initialize the last BN in each residual branch, 136 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 137 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 138 | if zero_init_residual: 139 | for m in self.modules(): 140 | if isinstance(m, Bottleneck): 141 | nn.init.constant_(m.bn3.weight, 0) 142 | elif isinstance(m, BasicBlock): 143 | nn.init.constant_(m.bn2.weight, 0) 144 | 145 | def _make_layer(self, block, planes, blocks, stride=1): 146 | downsample = None 147 | if stride != 1 or self.inplanes != planes * block.expansion: 148 | downsample = nn.Sequential( 149 | conv1x1(self.inplanes, planes * block.expansion, stride), 150 | nn.BatchNorm2d(planes * block.expansion), 151 | ) 152 | 153 | layers = [] 154 | layers.append(block(self.inplanes, planes, stride, downsample)) 155 | self.inplanes = planes * block.expansion 156 | for _ in range(1, blocks): 157 | layers.append(block(self.inplanes, planes)) 158 | 159 | return nn.Sequential(*layers) 160 | 161 | def forward(self, x): 162 | x = self.conv1(x) 163 | x = self.bn1(x) 164 | x = self.relu(x) 165 | x = self.maxpool(x) 166 | 167 | x = self.layer1(x) 168 | x = self.layer2(x) 169 | x = self.layer3(x) 170 | x = self.layer4(x) 171 | 172 | x = self.avgpool(x) 173 | x = x.view(x.size(0), -1) 174 | x = self.fc(x) 175 | 176 | return x 177 | 178 | 179 | def se_resnet18(pretrained=False, **kwargs): 180 | """Constructs a ResNet-18 model. 
181 | Args: 182 | pretrained (bool): If True, returns a model pre-trained on ImageNet 183 | """ 184 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 185 | return model 186 | 187 | 188 | def se_resnet34(pretrained=False, **kwargs): 189 | """Constructs a ResNet-34 model. 190 | Args: 191 | pretrained (bool): If True, returns a model pre-trained on ImageNet 192 | """ 193 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 194 | return model 195 | 196 | 197 | def se_resnet50(pretrained=False, **kwargs): 198 | """Constructs a ResNet-50 model. 199 | Args: 200 | pretrained (bool): If True, returns a model pre-trained on ImageNet 201 | """ 202 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 203 | return model 204 | 205 | 206 | def se_resnet101(pretrained=False, **kwargs): 207 | """Constructs a ResNet-101 model. 208 | Args: 209 | pretrained (bool): If True, returns a model pre-trained on ImageNet 210 | """ 211 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 212 | return model 213 | 214 | 215 | def se_resnet152(pretrained=False, **kwargs): 216 | """Constructs a ResNet-152 model. 
Args: 218 | pretrained (bool): If True, returns a model pre-trained on ImageNet 219 | """ 220 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 221 | return model 222 | 223 | if __name__ == '__main__': 224 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 225 | model = se_resnet50().to(device) 226 | i = torch.randn(1, 3, 256, 256, device=device) 227 | y = model(i) 228 | print(y.size()) -------------------------------------------------------------------------------- /libs/nn/resnet_sge.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torch.nn import Parameter 4 | # code from origin repo: https://github.com/implus/PytorchInsight 5 | 6 | __all__ = ['sge_resnet18', 'sge_resnet34', 'sge_resnet50', 'sge_resnet101', 7 | 'sge_resnet152'] 8 | 9 | 10 | class SpatialGroupEnhance(nn.Module): 11 | def __init__(self, groups = 64): 12 | super(SpatialGroupEnhance, self).__init__() 13 | self.groups = groups 14 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 15 | self.weight = Parameter(torch.zeros(1, groups, 1, 1)) 16 | self.bias = Parameter(torch.ones(1, groups, 1, 1)) 17 | self.sig = nn.Sigmoid() 18 | 19 | def forward(self, x): # (b, c, h, w) 20 | b, c, h, w = x.size() 21 | x = x.view(b * self.groups, -1, h, w) 22 | xn = x * self.avg_pool(x) 23 | xn = xn.sum(dim=1, keepdim=True) 24 | t = xn.view(b * self.groups, -1) 25 | t = t - t.mean(dim=1, keepdim=True) 26 | std = t.std(dim=1, keepdim=True) + 1e-5 27 | t = t / std 28 | t = t.view(b, self.groups, h, w) 29 | t = t * self.weight + self.bias 30 | t = t.view(b * self.groups, 1, h, w) 31 | x = x * self.sig(t) 32 | x = x.view(b, c, h, w) 33 | return x 34 | 35 | 36 | def conv3x3(in_planes, out_planes, stride=1): 37 | """3x3 convolution with padding""" 38 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 39 | padding=1, bias=False) 40 | 41 | 42 | def conv1x1(in_planes, out_planes, stride=1): 43 | """1x1 convolution""" 44 | return nn.Conv2d(in_planes, out_planes, kernel_size=1,
stride=stride, bias=False) 45 | 46 | 47 | class BasicBlock(nn.Module): 48 | expansion = 1 49 | 50 | def __init__(self, inplanes, planes, stride=1, downsample=None): 51 | super(BasicBlock, self).__init__() 52 | self.conv1 = conv3x3(inplanes, planes, stride) 53 | self.bn1 = nn.BatchNorm2d(planes) 54 | self.relu = nn.ReLU(inplace=True) 55 | self.conv2 = conv3x3(planes, planes) 56 | self.bn2 = nn.BatchNorm2d(planes) 57 | self.downsample = downsample 58 | self.stride = stride 59 | self.sge = SpatialGroupEnhance(64) 60 | 61 | def forward(self, x): 62 | identity = x 63 | 64 | out = self.conv1(x) 65 | out = self.bn1(out) 66 | out = self.relu(out) 67 | 68 | out = self.conv2(out) 69 | out = self.bn2(out) 70 | out = self.sge(out) 71 | 72 | if self.downsample is not None: 73 | identity = self.downsample(x) 74 | 75 | out += identity 76 | out = self.relu(out) 77 | 78 | return out 79 | 80 | 81 | class Bottleneck(nn.Module): 82 | expansion = 4 83 | 84 | def __init__(self, inplanes, planes, stride=1, downsample=None): 85 | super(Bottleneck, self).__init__() 86 | self.conv1 = conv1x1(inplanes, planes) 87 | self.bn1 = nn.BatchNorm2d(planes) 88 | self.conv2 = conv3x3(planes, planes, stride) 89 | self.bn2 = nn.BatchNorm2d(planes) 90 | self.conv3 = conv1x1(planes, planes * self.expansion) 91 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 92 | self.relu = nn.ReLU(inplace=True) 93 | self.downsample = downsample 94 | self.stride = stride 95 | self.sge = SpatialGroupEnhance(64) 96 | 97 | def forward(self, x): 98 | identity = x 99 | 100 | out = self.conv1(x) 101 | out = self.bn1(out) 102 | out = self.relu(out) 103 | 104 | out = self.conv2(out) 105 | out = self.bn2(out) 106 | out = self.relu(out) 107 | 108 | out = self.conv3(out) 109 | out = self.bn3(out) 110 | out = self.sge(out) 111 | 112 | if self.downsample is not None: 113 | identity = self.downsample(x) 114 | 115 | out += identity 116 | out = self.relu(out) 117 | 118 | return out 119 | 120 | 121 | class ResNet(nn.Module): 122 | 
123 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False): 124 | super(ResNet, self).__init__() 125 | self.inplanes = 64 126 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 127 | bias=False) 128 | self.bn1 = nn.BatchNorm2d(64) 129 | self.relu = nn.ReLU(inplace=True) 130 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 131 | self.layer1 = self._make_layer(block, 64, layers[0]) 132 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 133 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 134 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 135 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 136 | self.fc = nn.Linear(512 * block.expansion, num_classes) 137 | 138 | for m in self.modules(): 139 | if isinstance(m, nn.Conv2d): 140 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 141 | elif isinstance(m, nn.BatchNorm2d): 142 | nn.init.constant_(m.weight, 1) 143 | nn.init.constant_(m.bias, 0) 144 | 145 | # Zero-initialize the last BN in each residual branch, 146 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
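The `SpatialGroupEnhance` module defined earlier in resnet_sge.py works in five steps:

1. It splits the channels into `groups`.
2. Within each group, it takes the dot product of every spatial position with the group's global average.
3. It normalises that similarity map over positions, using the unbiased std plus 1e-5.
4. It applies the per-group `weight` and `bias`.
5. It uses the sigmoid of the result as a spatial gate.

The following is a pure-Python sketch of one group for one sample, with the spatial axes flattened. The function name and list layout are illustrative, not part of this repo.

```python
import math
import statistics


def sge_group(group, weight=0.0, bias=1.0, eps=1e-5):
    """SpatialGroupEnhance for one group of one sample.

    group: list of channels, each a flat list of H*W spatial values.
    weight/bias: the per-group affine parameters (initialised to 0 and 1 in the module).
    """
    n = len(group[0])
    # avg_pool per channel, then sum over channels of x * mean -> one value per position
    means = [sum(ch) / n for ch in group]
    sim = [sum(m * ch[i] for m, ch in zip(means, group)) for i in range(n)]
    # normalise over spatial positions; statistics.stdev is unbiased, like Tensor.std
    mu = sum(sim) / n
    std = statistics.stdev(sim) + eps
    t = [(s - mu) / std * weight + bias for s in sim]
    gate = [1.0 / (1.0 + math.exp(-v)) for v in t]
    return [[v * g for v, g in zip(ch, gate)] for ch in group]


g = [[1.0, 2.0, 3.0, 4.0], [0.0, 1.0, 0.0, 1.0]]
print(sge_group(g, weight=1.0, bias=0.0))
```

With the initial `weight=0` and `bias=1`, every position receives the same gate, sigmoid(1). The module therefore starts as a uniform scaling and only learns spatial selectivity as `weight` moves away from zero.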
147 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 148 | if zero_init_residual: 149 | for m in self.modules(): 150 | if isinstance(m, Bottleneck): 151 | nn.init.constant_(m.bn3.weight, 0) 152 | elif isinstance(m, BasicBlock): 153 | nn.init.constant_(m.bn2.weight, 0) 154 | 155 | def _make_layer(self, block, planes, blocks, stride=1): 156 | downsample = None 157 | if stride != 1 or self.inplanes != planes * block.expansion: 158 | downsample = nn.Sequential( 159 | conv1x1(self.inplanes, planes * block.expansion, stride), 160 | nn.BatchNorm2d(planes * block.expansion), 161 | ) 162 | 163 | layers = [] 164 | layers.append(block(self.inplanes, planes, stride, downsample)) 165 | self.inplanes = planes * block.expansion 166 | for _ in range(1, blocks): 167 | layers.append(block(self.inplanes, planes)) 168 | 169 | return nn.Sequential(*layers) 170 | 171 | def forward(self, x): 172 | x = self.conv1(x) 173 | x = self.bn1(x) 174 | x = self.relu(x) 175 | x = self.maxpool(x) 176 | 177 | x = self.layer1(x) 178 | x = self.layer2(x) 179 | x = self.layer3(x) 180 | x = self.layer4(x) 181 | 182 | x = self.avgpool(x) 183 | x = x.view(x.size(0), -1) 184 | x = self.fc(x) 185 | 186 | return x 187 | 188 | 189 | 190 | 191 | def sge_resnet18(pretrained=False, **kwargs): 192 | """Constructs a ResNet-18 model. 193 | Args: 194 | pretrained (bool): If True, returns a model pre-trained on ImageNet 195 | """ 196 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 197 | return model 198 | 199 | 200 | def sge_resnet34(pretrained=False, **kwargs): 201 | """Constructs a ResNet-34 model. 202 | Args: 203 | pretrained (bool): If True, returns a model pre-trained on ImageNet 204 | """ 205 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 206 | return model 207 | 208 | 209 | def sge_resnet50(pretrained=False, **kwargs): 210 | """Constructs a ResNet-50 model. 
211 | Args: 212 | pretrained (bool): If True, returns a model pre-trained on ImageNet 213 | """ 214 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 215 | return model 216 | 217 | 218 | def sge_resnet101(pretrained=False, **kwargs): 219 | """Constructs a ResNet-101 model. 220 | Args: 221 | pretrained (bool): If True, returns a model pre-trained on ImageNet 222 | """ 223 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 224 | return model 225 | 226 | 227 | def sge_resnet152(pretrained=False, **kwargs): 228 | """Constructs a ResNet-152 model. 229 | Args: 230 | pretrained (bool): If True, returns a model pre-trained on ImageNet 231 | """ 232 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 233 | return model 234 | 235 | -------------------------------------------------------------------------------- /libs/nn/resnet_sk.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | # code from origin repo: https://github.com/implus/PytorchInsight 5 | # TODO sknet for basicblocks 6 | 7 | __all__ = ['sk_resnet18', 'sk_resnet34', 'sk_resnet50', 'sk_resnet101', 8 | 'sk_resnet152'] 9 | 10 | 11 | def conv3x3(in_planes, out_planes, stride=1, groups=1): 12 | """3x3 convolution with padding""" 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 14 | padding=1, bias=False, groups=groups) 15 | 16 | 17 | def conv1x1(in_planes, out_planes, stride=1): 18 | """1x1 convolution""" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, inplanes, planes, stride=1, downsample=None): 26 | super(BasicBlock, self).__init__() 27 | self.conv1 = conv3x3(inplanes, planes, stride) 28 | self.bn1 = nn.BatchNorm2d(planes) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.conv2 = conv3x3(planes, planes) 31 | self.bn2 = nn.BatchNorm2d(planes) 
32 | self.downsample = downsample 33 | self.stride = stride 34 | 35 | def forward(self, x): 36 | identity = x 37 | 38 | out = self.conv1(x) 39 | out = self.bn1(out) 40 | out = self.relu(out) 41 | 42 | out = self.conv2(out) 43 | out = self.bn2(out) 44 | 45 | if self.downsample is not None: 46 | identity = self.downsample(x) 47 | 48 | out += identity 49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | 54 | class Bottleneck(nn.Module): 55 | expansion = 4 56 | 57 | def __init__(self, inplanes, planes, stride=1, downsample=None): 58 | super(Bottleneck, self).__init__() 59 | self.conv1 = conv1x1(inplanes, planes) 60 | self.bn1 = nn.BatchNorm2d(planes) 61 | 62 | self.conv2 = conv3x3(planes, planes, stride) 63 | self.bn2 = nn.BatchNorm2d(planes) 64 | self.conv2g = conv3x3(planes, planes, stride, groups = 32) 65 | self.bn2g = nn.BatchNorm2d(planes) 66 | 67 | self.conv3 = conv1x1(planes, planes * self.expansion) 68 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.downsample = downsample 71 | self.stride = stride 72 | 73 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 74 | self.conv_fc1 = nn.Conv2d(planes, planes//16, 1, bias=False) 75 | self.bn_fc1 = nn.BatchNorm2d(planes//16) 76 | self.conv_fc2 = nn.Conv2d(planes//16, 2 * planes, 1, bias=False) 77 | 78 | self.D = planes 79 | 80 | def forward(self, x): 81 | identity = x 82 | 83 | out = self.conv1(x) 84 | out = self.bn1(out) 85 | out = self.relu(out) 86 | 87 | d1 = self.conv2(out) 88 | d1 = self.bn2(d1) 89 | d1 = self.relu(d1) 90 | 91 | d2 = self.conv2g(out) 92 | d2 = self.bn2g(d2) 93 | d2 = self.relu(d2) 94 | 95 | d = self.avg_pool(d1) + self.avg_pool(d2) 96 | d = F.relu(self.bn_fc1(self.conv_fc1(d))) 97 | d = self.conv_fc2(d) 98 | d = torch.unsqueeze(d, 1).view(-1, 2, self.D, 1, 1) 99 | d = F.softmax(d, 1) 100 | d1 = d1 * d[:, 0, :, :, :].squeeze(1) 101 | d2 = d2 * d[:, 1, :, :, :].squeeze(1) 102 | d = d1 + d2 103 | 104 | out = self.conv3(d) 105 | out = self.bn3(out) 106 | 
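In the SK Bottleneck, the fusion weights come from `conv_fc2`, which produces 2*D logits per sample. `view(-1, 2, self.D, 1, 1)` assigns the first D logits to branch `d1` and the next D to branch `d2`. The softmax over dim 1 then makes the two weights for each channel sum to one, so the fused output is a per-channel convex combination of the two branches. The following is a pure-Python sketch of that fusion step for one sample. The names are illustrative; the logits would come from `avg_pool -> conv_fc1 -> bn_fc1 -> relu -> conv_fc2`.

```python
import math


def sk_fuse(branch1, branch2, logits1, logits2):
    """Selective-kernel fusion for one sample.

    branch1/branch2: per-channel features (each channel a flat list of values).
    logits1/logits2: one attention logit per channel for each branch
                     (first and second half of conv_fc2's 2*D outputs).
    """
    fused = []
    for ch1, ch2, z1, z2 in zip(branch1, branch2, logits1, logits2):
        # softmax across the two branches, separately for each channel
        m = max(z1, z2)
        e1, e2 = math.exp(z1 - m), math.exp(z2 - m)
        a, b = e1 / (e1 + e2), e2 / (e1 + e2)
        fused.append([a * u + b * v for u, v in zip(ch1, ch2)])
    return fused


print(sk_fuse([[2.0, 4.0]], [[0.0, 0.0]], [0.0], [0.0]))  # [[1.0, 2.0]]
```

Equal logits average the branches. A large logit gap lets one branch, here the full 3x3 conv or its grouped counterpart, dominate that channel.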
107 | if self.downsample is not None: 108 | identity = self.downsample(x) 109 | 110 | out += identity 111 | out = self.relu(out) 112 | 113 | return out 114 | 115 | 116 | class ResNet(nn.Module): 117 | 118 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False): 119 | super(ResNet, self).__init__() 120 | self.inplanes = 64 121 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 122 | bias=False) 123 | self.bn1 = nn.BatchNorm2d(64) 124 | self.relu = nn.ReLU(inplace=True) 125 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 126 | self.layer1 = self._make_layer(block, 64, layers[0]) 127 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 128 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 129 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 130 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 131 | self.fc = nn.Linear(512 * block.expansion, num_classes) 132 | 133 | for m in self.modules(): 134 | if isinstance(m, nn.Conv2d): 135 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 136 | elif isinstance(m, nn.BatchNorm2d): 137 | nn.init.constant_(m.weight, 1) 138 | nn.init.constant_(m.bias, 0) 139 | 140 | # Zero-initialize the last BN in each residual branch, 141 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
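The constructors in each file pass a per-stage block count to `ResNet`. The depth in the model name counts only the weighted layers on the main path: blocks x convs per block (2 for `BasicBlock`, 3 for `Bottleneck`), plus the stem conv and the final `fc`. The attention modules are not counted. The following is a quick check of the configurations used in these files:

```python
CONVS_PER_BLOCK = {"BasicBlock": 2, "Bottleneck": 3}


def nominal_depth(block, layers):
    """Depth as used in ResNet names: convs in all blocks + stem conv + fc."""
    return CONVS_PER_BLOCK[block] * sum(layers) + 2


configs = {
    "resnet18": ("BasicBlock", [2, 2, 2, 2]),
    "resnet34": ("BasicBlock", [3, 4, 6, 3]),
    "resnet50": ("Bottleneck", [3, 4, 6, 3]),
    "resnet101": ("Bottleneck", [3, 4, 23, 3]),
    "resnet152": ("Bottleneck", [3, 8, 36, 3]),
}
for name, (block, layers) in configs.items():
    print(name, nominal_depth(block, layers))
```

In resnet_sk.py, `sk_resnet18` and `sk_resnet34` build the plain `BasicBlock`, which contains no selective-kernel unit. The TODO at the top of that file notes this, so those two models are ordinary ResNets.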
142 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 143 | if zero_init_residual: 144 | for m in self.modules(): 145 | if isinstance(m, Bottleneck): 146 | nn.init.constant_(m.bn3.weight, 0) 147 | elif isinstance(m, BasicBlock): 148 | nn.init.constant_(m.bn2.weight, 0) 149 | 150 | def _make_layer(self, block, planes, blocks, stride=1): 151 | downsample = None 152 | if stride != 1 or self.inplanes != planes * block.expansion: 153 | downsample = nn.Sequential( 154 | conv1x1(self.inplanes, planes * block.expansion, stride), 155 | nn.BatchNorm2d(planes * block.expansion), 156 | ) 157 | 158 | layers = [] 159 | layers.append(block(self.inplanes, planes, stride, downsample)) 160 | self.inplanes = planes * block.expansion 161 | for _ in range(1, blocks): 162 | layers.append(block(self.inplanes, planes)) 163 | 164 | return nn.Sequential(*layers) 165 | 166 | def forward(self, x): 167 | x = self.conv1(x) 168 | x = self.bn1(x) 169 | x = self.relu(x) 170 | x = self.maxpool(x) 171 | 172 | x = self.layer1(x) 173 | x = self.layer2(x) 174 | x = self.layer3(x) 175 | x = self.layer4(x) 176 | 177 | x = self.avgpool(x) 178 | x = x.view(x.size(0), -1) 179 | x = self.fc(x) 180 | 181 | return x 182 | 183 | 184 | def sk_resnet18(pretrained=False, **kwargs): 185 | """Constructs a ResNet-18 model. 186 | Args: 187 | pretrained (bool): If True, returns a model pre-trained on ImageNet 188 | """ 189 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 190 | return model 191 | 192 | 193 | def sk_resnet34(pretrained=False, **kwargs): 194 | """Constructs a ResNet-34 model. 195 | Args: 196 | pretrained (bool): If True, returns a model pre-trained on ImageNet 197 | """ 198 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 199 | return model 200 | 201 | 202 | def sk_resnet50(pretrained=False, **kwargs): 203 | """Constructs a ResNet-50 model. 
Args: 205 | pretrained (bool): If True, returns a model pre-trained on ImageNet 206 | """ 207 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 208 | return model 209 | 210 | 211 | def sk_resnet101(pretrained=False, **kwargs): 212 | """Constructs a ResNet-101 model. 213 | Args: 214 | pretrained (bool): If True, returns a model pre-trained on ImageNet 215 | """ 216 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 217 | return model 218 | 219 | def sk_resnet152(pretrained=False, **kwargs): 220 | """Constructs a ResNet-152 model. 221 | Args: 222 | pretrained (bool): If True, returns a model pre-trained on ImageNet 223 | """ 224 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 225 | return model 226 | 227 | -------------------------------------------------------------------------------- /libs/nn/resnet_srm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import Parameter 4 | __all__ = [ 'srm_resnet18', 'srm_resnet34', 'srm_resnet50', 'srm_resnet101', 5 | 'srm_resnet152', 'srm_resnext50_32x4d', 'srm_resnext101_32x8d'] 6 | 7 | 8 | class SRMLayer(nn.Module): 9 | def __init__(self, channel): 10 | super(SRMLayer, self).__init__() 11 | 12 | self.cfc = Parameter(torch.Tensor(channel, 2)) 13 | self.cfc.data.fill_(0) 14 | 15 | self.bn = nn.BatchNorm2d(channel) 16 | self.activation = nn.Sigmoid() 17 | 18 | setattr(self.cfc, 'srm_param', True) 19 | setattr(self.bn.weight, 'srm_param', True) 20 | setattr(self.bn.bias, 'srm_param', True) 21 | 22 | def _style_pooling(self, x): 23 | N, C, _, _ = x.size() 24 | 25 | channel_mean = x.view(N, C, -1).mean(dim=2).view(N, C, -1) 26 | channel_std = x.view(N, C, -1).std(dim=2).view(N, C, -1) 27 | 28 | t = torch.cat((channel_mean, channel_std), dim=2) 29 | return t 30 | 31 | def _style_integration(self, t): 32 | z = t * self.cfc[None, :, :] # B x C x 2 33 | z = torch.sum(z, dim=2)[:, :, None, None] # B x C x 1 x 1 34 | 35 | z_hat = self.bn(z)
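`SRMLayer` computes its channel gate in two stages. Style pooling summarises each channel by its mean and unbiased standard deviation. Style integration then mixes the two with the learned per-channel pair `cfc`, applies BatchNorm, and takes a sigmoid. The following is a pure-Python sketch of the gate for one sample. For clarity it leaves out the BatchNorm step, treating it as the identity, and the names are illustrative.

```python
import math
import statistics


def srm_gate(channels, cfc):
    """SRM channel gates for one sample, with the BatchNorm step omitted.

    channels: list of C channels, each a flat list of H*W values.
    cfc: list of C (w_mean, w_std) pairs, i.e. the rows of the module's cfc parameter.
    """
    gates = []
    for values, (w_mean, w_std) in zip(channels, cfc):
        mean = sum(values) / len(values)  # style pooling: channel mean
        std = statistics.stdev(values)    # and unbiased std, like Tensor.std
        z = w_mean * mean + w_std * std   # style integration (channel-wise weights)
        gates.append(1.0 / (1.0 + math.exp(-z)))
    return gates


x = [[1.0, 2.0, 3.0, 4.0], [5.0, 5.0, 5.0, 5.0]]
print(srm_gate(x, [(0.0, 0.0), (0.0, 0.0)]))  # [0.5, 0.5]
```

Because `cfc` is filled with zeros at construction, every gate starts at 0.5 before BatchNorm. The layer only begins to favour channels by their statistics once `cfc` is trained.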
36 | g = self.activation(z_hat) 37 | 38 | return g 39 | 40 | def forward(self, x): 41 | # B x C x 2 42 | t = self._style_pooling(x) 43 | 44 | # B x C x 1 x 1 45 | g = self._style_integration(t) 46 | 47 | return x * g 48 | 49 | 50 | def conv3x3(in_planes, out_planes, stride=1, groups=1): 51 | """3x3 convolution with padding""" 52 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 53 | padding=1, groups=groups, bias=False) 54 | 55 | 56 | def conv1x1(in_planes, out_planes, stride=1): 57 | """1x1 convolution""" 58 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 59 | 60 | 61 | class BasicBlock(nn.Module): 62 | expansion = 1 63 | 64 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 65 | base_width=64, norm_layer=None): 66 | super(BasicBlock, self).__init__() 67 | if norm_layer is None: 68 | norm_layer = nn.BatchNorm2d 69 | if groups != 1 or base_width != 64: 70 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 71 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 72 | self.conv1 = conv3x3(inplanes, planes, stride) 73 | self.bn1 = norm_layer(planes) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.conv2 = conv3x3(planes, planes) 76 | self.bn2 = norm_layer(planes) 77 | self.downsample = downsample 78 | self.stride = stride 79 | self.srm = SRMLayer(planes) 80 | 81 | def forward(self, x): 82 | identity = x 83 | 84 | out = self.conv1(x) 85 | out = self.bn1(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv2(out) 89 | out = self.bn2(out) 90 | 91 | if self.downsample is not None: 92 | identity = self.downsample(x) 93 | 94 | out = self.srm(out) 95 | 96 | out += identity 97 | out = self.relu(out) 98 | 99 | return out 100 | 101 | 102 | class Bottleneck(nn.Module): 103 | expansion = 4 104 | 105 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 106 | base_width=64, norm_layer=None): 107 | super(Bottleneck, 
self).__init__() 108 | if norm_layer is None: 109 | norm_layer = nn.BatchNorm2d 110 | width = int(planes * (base_width / 64.)) * groups 111 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 112 | self.conv1 = conv1x1(inplanes, width) 113 | self.bn1 = norm_layer(width) 114 | self.conv2 = conv3x3(width, width, stride, groups) 115 | self.bn2 = norm_layer(width) 116 | self.conv3 = conv1x1(width, planes * self.expansion) 117 | self.bn3 = norm_layer(planes * self.expansion) 118 | self.relu = nn.ReLU(inplace=True) 119 | self.downsample = downsample 120 | self.stride = stride 121 | 122 | self.srm = SRMLayer(planes * self.expansion) 123 | 124 | def forward(self, x): 125 | identity = x 126 | 127 | out = self.conv1(x) 128 | out = self.bn1(out) 129 | out = self.relu(out) 130 | 131 | out = self.conv2(out) 132 | out = self.bn2(out) 133 | out = self.relu(out) 134 | 135 | out = self.conv3(out) 136 | out = self.bn3(out) 137 | 138 | out = self.srm(out) 139 | 140 | if self.downsample is not None: 141 | identity = self.downsample(x) 142 | 143 | out += identity 144 | out = self.relu(out) 145 | 146 | return out 147 | 148 | 149 | class ResNet(nn.Module): 150 | 151 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 152 | groups=1, width_per_group=64, norm_layer=None): 153 | super(ResNet, self).__init__() 154 | if norm_layer is None: 155 | norm_layer = nn.BatchNorm2d 156 | 157 | self.inplanes = 64 158 | self.groups = groups 159 | self.base_width = width_per_group 160 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 161 | bias=False) 162 | self.bn1 = norm_layer(self.inplanes) 163 | self.relu = nn.ReLU(inplace=True) 164 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 165 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer) 166 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) 167 | self.layer3 = self._make_layer(block, 256, layers[2],
stride=2, norm_layer=norm_layer) 168 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer) 169 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 170 | self.fc = nn.Linear(512 * block.expansion, num_classes) 171 | 172 | for m in self.modules(): 173 | if isinstance(m, nn.Conv2d): 174 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 175 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 176 | nn.init.constant_(m.weight, 1) 177 | nn.init.constant_(m.bias, 0) 178 | 179 | # Zero-initialize the last BN in each residual branch, 180 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 181 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 182 | if zero_init_residual: 183 | for m in self.modules(): 184 | if isinstance(m, Bottleneck): 185 | nn.init.constant_(m.bn3.weight, 0) 186 | elif isinstance(m, BasicBlock): 187 | nn.init.constant_(m.bn2.weight, 0) 188 | 189 | def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None): 190 | if norm_layer is None: 191 | norm_layer = nn.BatchNorm2d 192 | downsample = None 193 | if stride != 1 or self.inplanes != planes * block.expansion: 194 | downsample = nn.Sequential( 195 | conv1x1(self.inplanes, planes * block.expansion, stride,), 196 | norm_layer(planes * block.expansion), 197 | ) 198 | 199 | layers = [] 200 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 201 | self.base_width, norm_layer)) 202 | self.inplanes = planes * block.expansion 203 | for _ in range(1, blocks): 204 | layers.append(block(self.inplanes, planes, groups=self.groups, 205 | base_width=self.base_width, norm_layer=norm_layer)) 206 | 207 | return nn.Sequential(*layers) 208 | 209 | def forward(self, x): 210 | x = self.conv1(x) 211 | x = self.bn1(x) 212 | x = self.relu(x) 213 | x = self.maxpool(x) 214 | 215 | x = self.layer1(x) 216 | x = self.layer2(x) 217 | x = self.layer3(x) 218 | x = 
self.layer4(x) 219 | x = self.avgpool(x) 220 | x = x.view(x.size(0), -1) 221 | x = self.fc(x) 222 | 223 | return x 224 | 225 | 226 | def srm_resnet18(pretrained=False, **kwargs): 227 | """Constructs a ResNet-18 model. 228 | 229 | Args: 230 | pretrained (bool): If True, returns a model pre-trained on ImageNet 231 | """ 232 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 233 | # if pretrained: 234 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 235 | return model 236 | 237 | 238 | def srm_resnet34(pretrained=False, **kwargs): 239 | """Constructs a ResNet-34 model. 240 | 241 | Args: 242 | pretrained (bool): If True, returns a model pre-trained on ImageNet 243 | """ 244 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 245 | # if pretrained: 246 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 247 | return model 248 | 249 | 250 | def srm_resnet50(pretrained=False, **kwargs): 251 | """Constructs a ResNet-50 model. 252 | 253 | Args: 254 | pretrained (bool): If True, returns a model pre-trained on ImageNet 255 | """ 256 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 257 | # if pretrained: 258 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 259 | return model 260 | 261 | 262 | def srm_resnet101(pretrained=False, **kwargs): 263 | """Constructs a ResNet-101 model. 264 | 265 | Args: 266 | pretrained (bool): If True, returns a model pre-trained on ImageNet 267 | """ 268 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 269 | # if pretrained: 270 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 271 | return model 272 | 273 | 274 | def srm_resnet152(pretrained=False, **kwargs): 275 | """Constructs a ResNet-152 model. 
276 | 277 | Args: 278 | pretrained (bool): If True, returns a model pre-trained on ImageNet 279 | """ 280 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 281 | # if pretrained: 282 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 283 | return model 284 | 285 | 286 | def srm_resnext50_32x4d(pretrained=False, **kwargs): 287 | model = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4, **kwargs) 288 | # if pretrained: 289 | # model.load_state_dict(model_zoo.load_url(model_urls['resnext50_32x4d'])) 290 | return model 291 | 292 | 293 | def srm_resnext101_32x8d(pretrained=False, **kwargs): 294 | model = ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=8, **kwargs) 295 | # if pretrained: 296 | # model.load_state_dict(model_zoo.load_url(model_urls['resnext101_32x8d'])) 297 | return model 298 | 299 | if __name__ == '__main__': 300 | import torch 301 | model = srm_resnet18().cuda() 302 | i = torch.Tensor(2, 3, 256, 256).cuda() 303 | y = model(i) 304 | print(y.size()) 305 | 306 | """ 307 | layer output size: 308 | torch.Size([1, 256, 64, 64]) 309 | torch.Size([1, 512, 32, 32]) 310 | torch.Size([1, 1024, 16, 16]) 311 | torch.Size([1, 2048, 8, 8]) 312 | torch.Size([1, 1000]) 313 | """ -------------------------------------------------------------------------------- /libs/progress/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 14 | 15 | from __future__ import division 16 | 17 | from collections import deque 18 | from datetime import timedelta 19 | from math import ceil 20 | from sys import stderr 21 | from time import time 22 | 23 | 24 | __version__ = '1.3' 25 | 26 | 27 | class Infinite(object): 28 | file = stderr 29 | sma_window = 10 # Simple Moving Average window 30 | 31 | def __init__(self, *args, **kwargs): 32 | self.index = 0 33 | self.start_ts = time() 34 | self.avg = 0 35 | self._ts = self.start_ts 36 | self._xput = deque(maxlen=self.sma_window) 37 | for key, val in kwargs.items(): 38 | setattr(self, key, val) 39 | 40 | def __getitem__(self, key): 41 | if key.startswith('_'): 42 | return None 43 | return getattr(self, key, None) 44 | 45 | @property 46 | def elapsed(self): 47 | return int(time() - self.start_ts) 48 | 49 | @property 50 | def elapsed_td(self): 51 | return timedelta(seconds=self.elapsed) 52 | 53 | def update_avg(self, n, dt): 54 | if n > 0: 55 | self._xput.append(dt / n) 56 | self.avg = sum(self._xput) / len(self._xput) 57 | 58 | def update(self): 59 | pass 60 | 61 | def start(self): 62 | pass 63 | 64 | def finish(self): 65 | pass 66 | 67 | def next(self, n=1): 68 | now = time() 69 | dt = now - self._ts 70 | self.update_avg(n, dt) 71 | self._ts = now 72 | self.index = self.index + n 73 | self.update() 74 | 75 | def iter(self, it): 76 | try: 77 | for x in it: 78 | yield x 79 | self.next() 80 | finally: 81 | self.finish() 82 | 83 | 84 | class Progress(Infinite): 85 | def __init__(self, *args, **kwargs): 86 | super(Progress, self).__init__(*args, **kwargs) 87 | self.max = kwargs.get('max', 100) 88 | 89 | @property 90 | def 
eta(self): 91 | return int(ceil(self.avg * self.remaining)) 92 | 93 | @property 94 | def eta_td(self): 95 | return timedelta(seconds=self.eta) 96 | 97 | @property 98 | def percent(self): 99 | return self.progress * 100 100 | 101 | @property 102 | def progress(self): 103 | return min(1, self.index / self.max) 104 | 105 | @property 106 | def remaining(self): 107 | return max(self.max - self.index, 0) 108 | 109 | def start(self): 110 | self.update() 111 | 112 | def goto(self, index): 113 | incr = index - self.index 114 | self.next(incr) 115 | 116 | def iter(self, it): 117 | try: 118 | self.max = len(it) 119 | except TypeError: 120 | pass 121 | 122 | try: 123 | for x in it: 124 | yield x 125 | self.next() 126 | finally: 127 | self.finish() 128 | -------------------------------------------------------------------------------- /libs/progress/bar.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . 
import Progress 19 | from .helpers import WritelnMixin 20 | 21 | 22 | class Bar(WritelnMixin, Progress): 23 | width = 32 24 | message = '' 25 | suffix = '%(index)d/%(max)d' 26 | bar_prefix = ' |' 27 | bar_suffix = '| ' 28 | empty_fill = ' ' 29 | fill = '#' 30 | hide_cursor = True 31 | 32 | def update(self): 33 | filled_length = int(self.width * self.progress) 34 | empty_length = self.width - filled_length 35 | 36 | message = self.message % self 37 | bar = self.fill * filled_length 38 | empty = self.empty_fill * empty_length 39 | suffix = self.suffix % self 40 | line = ''.join([message, self.bar_prefix, bar, empty, self.bar_suffix, 41 | suffix]) 42 | self.writeln(line) 43 | 44 | 45 | class ChargingBar(Bar): 46 | suffix = '%(percent)d%%' 47 | bar_prefix = ' ' 48 | bar_suffix = ' ' 49 | empty_fill = '∙' 50 | fill = '█' 51 | 52 | 53 | class FillingSquaresBar(ChargingBar): 54 | empty_fill = '▢' 55 | fill = '▣' 56 | 57 | 58 | class FillingCirclesBar(ChargingBar): 59 | empty_fill = '◯' 60 | fill = '◉' 61 | 62 | 63 | class IncrementalBar(Bar): 64 | phases = (' ', '▏', '▎', '▍', '▌', '▋', '▊', '▉', '█') 65 | 66 | def update(self): 67 | nphases = len(self.phases) 68 | filled_len = self.width * self.progress 69 | nfull = int(filled_len) # Number of full chars 70 | phase = int((filled_len - nfull) * nphases) # Phase of last char 71 | nempty = self.width - nfull # Number of empty chars 72 | 73 | message = self.message % self 74 | bar = self.phases[-1] * nfull 75 | current = self.phases[phase] if phase > 0 else '' 76 | empty = self.empty_fill * max(0, nempty - len(current)) 77 | suffix = self.suffix % self 78 | line = ''.join([message, self.bar_prefix, bar, current, empty, 79 | self.bar_suffix, suffix]) 80 | self.writeln(line) 81 | 82 | 83 | class PixelBar(IncrementalBar): 84 | phases = ('⡀', '⡄', '⡆', '⡇', '⣇', '⣧', '⣷', '⣿') 85 | 86 | 87 | class ShadyBar(IncrementalBar): 88 | phases = (' ', '░', '▒', '▓', '█') 89 | 
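Aside on `bar.py`: `Bar.update` and `IncrementalBar.update` rely on two small tricks. The `message` and `suffix` templates are filled with expressions such as `self.suffix % self`, which works because `Infinite.__getitem__` lets the bar act as a read-only mapping over its own attributes. `IncrementalBar` then splits the fractional fill `width * progress` into whole cells plus one partial glyph picked from `phases`. The snippet below is a minimal standalone sketch of that arithmetic, not part of `libs/progress`: `FakeProgress` and `render` are hypothetical names, and the 10-cell width is chosen only to keep the output short (the library's `Bar.width` defaults to 32).

```python
# Standalone sketch (hypothetical names, not part of libs/progress) of how
# Bar/IncrementalBar build one line of output.

PHASES = (' ', '▏', '▎', '▍', '▌', '▋', '▊', '▉', '█')  # same glyphs as IncrementalBar


class FakeProgress(object):
    """Hypothetical stand-in exposing the attributes the templates use."""

    def __init__(self, index, max_value):
        self.index = index
        self.max = max_value

    @property
    def progress(self):
        # Same clamp as Progress.progress: never report more than 100%.
        return min(1, self.index / self.max)

    def __getitem__(self, key):
        # Mirrors Infinite.__getitem__: `template % obj` looks keys up here.
        if key.startswith('_'):
            return None
        return getattr(self, key, None)


def render(p, width=10, empty_fill=' ', suffix='%(index)d/%(max)d'):
    """Return the text IncrementalBar.update would write (message omitted)."""
    nphases = len(PHASES)
    filled_len = width * p.progress
    nfull = int(filled_len)                       # number of full cells
    phase = int((filled_len - nfull) * nphases)   # glyph for the partial cell
    current = PHASES[phase] if phase > 0 else ''
    empty = empty_fill * max(0, width - nfull - len(current))
    return ''.join([' |', PHASES[-1] * nfull, current, empty, '| ', suffix % p])


if __name__ == '__main__':
    for i in (0, 25, 100):
        print(repr(render(FakeProgress(i, 100))))
```

Because `render` returns the string instead of writing it to a TTY, the layout is easy to check at a few progress values; the real classes additionally clear and redraw the terminal line through `WritelnMixin.writeln`.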
-------------------------------------------------------------------------------- /libs/progress/counter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . 
import Infinite, Progress 19 | from .helpers import WriteMixin 20 | 21 | 22 | class Counter(WriteMixin, Infinite): 23 | message = '' 24 | hide_cursor = True 25 | 26 | def update(self): 27 | self.write(str(self.index)) 28 | 29 | 30 | class Countdown(WriteMixin, Progress): 31 | hide_cursor = True 32 | 33 | def update(self): 34 | self.write(str(self.remaining)) 35 | 36 | 37 | class Stack(WriteMixin, Progress): 38 | phases = (' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█') 39 | hide_cursor = True 40 | 41 | def update(self): 42 | nphases = len(self.phases) 43 | i = min(nphases - 1, int(self.progress * nphases)) 44 | self.write(self.phases[i]) 45 | 46 | 47 | class Pie(Stack): 48 | phases = ('○', '◔', '◑', '◕', '●') 49 | -------------------------------------------------------------------------------- /libs/progress/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
14 | 15 | from __future__ import print_function 16 | 17 | 18 | HIDE_CURSOR = '\x1b[?25l' 19 | SHOW_CURSOR = '\x1b[?25h' 20 | 21 | 22 | class WriteMixin(object): 23 | hide_cursor = False 24 | 25 | def __init__(self, message=None, **kwargs): 26 | super(WriteMixin, self).__init__(**kwargs) 27 | self._width = 0 28 | if message: 29 | self.message = message 30 | 31 | if self.file.isatty(): 32 | if self.hide_cursor: 33 | print(HIDE_CURSOR, end='', file=self.file) 34 | print(self.message, end='', file=self.file) 35 | self.file.flush() 36 | 37 | def write(self, s): 38 | if self.file.isatty(): 39 | b = '\b' * self._width 40 | c = s.ljust(self._width) 41 | print(b + c, end='', file=self.file) 42 | self._width = max(self._width, len(s)) 43 | self.file.flush() 44 | 45 | def finish(self): 46 | if self.file.isatty() and self.hide_cursor: 47 | print(SHOW_CURSOR, end='', file=self.file) 48 | 49 | 50 | class WritelnMixin(object): 51 | hide_cursor = False 52 | 53 | def __init__(self, message=None, **kwargs): 54 | super(WritelnMixin, self).__init__(**kwargs) 55 | if message: 56 | self.message = message 57 | 58 | if self.file.isatty() and self.hide_cursor: 59 | print(HIDE_CURSOR, end='', file=self.file) 60 | 61 | def clearln(self): 62 | if self.file.isatty(): 63 | print('\r\x1b[K', end='', file=self.file) 64 | 65 | def writeln(self, line): 66 | if self.file.isatty(): 67 | self.clearln() 68 | print(line, end='', file=self.file) 69 | self.file.flush() 70 | 71 | def finish(self): 72 | if self.file.isatty(): 73 | print(file=self.file) 74 | if self.hide_cursor: 75 | print(SHOW_CURSOR, end='', file=self.file) 76 | 77 | 78 | from signal import signal, SIGINT 79 | from sys import exit 80 | 81 | 82 | class SigIntMixin(object): 83 | """Registers a signal handler that calls finish on SIGINT""" 84 | 85 | def __init__(self, *args, **kwargs): 86 | super(SigIntMixin, self).__init__(*args, **kwargs) 87 | signal(SIGINT, self._sigint_handler) 88 | 89 | def _sigint_handler(self, signum, frame): 90 | 
self.finish() 91 | exit(0) 92 | -------------------------------------------------------------------------------- /libs/progress/spinner.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . 
import Infinite 19 | from .helpers import WriteMixin 20 | 21 | 22 | class Spinner(WriteMixin, Infinite): 23 | message = '' 24 | phases = ('-', '\\', '|', '/') 25 | hide_cursor = True 26 | 27 | def update(self): 28 | i = self.index % len(self.phases) 29 | self.write(self.phases[i]) 30 | 31 | 32 | class PieSpinner(Spinner): 33 | phases = ['◷', '◶', '◵', '◴'] 34 | 35 | 36 | class MoonSpinner(Spinner): 37 | phases = ['◑', '◒', '◐', '◓'] 38 | 39 | 40 | class LineSpinner(Spinner): 41 | phases = ['⎺', '⎻', '⎼', '⎽', '⎼', '⎻'] 42 | 43 | class PixelSpinner(Spinner): 44 | phases = ['⣾','⣷', '⣯', '⣟', '⡿', '⢿', '⣻', '⣽'] 45 | -------------------------------------------------------------------------------- /libs/utils.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the per-channel mean and std of a dataset. 3 | - init_params: net parameter initialization. 4 | - progress_bar: progress bar mimicking xlua.progress.
5 | ''' 6 | import os 7 | import sys 8 | import time 9 | import errno 10 | import shutil 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | 16 | def get_mean_and_std(dataset): 17 | '''Compute the mean and std value of dataset.''' 18 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 19 | mean = torch.zeros(3) 20 | std = torch.zeros(3) 21 | print('==> Computing mean and std..') 22 | for inputs, targets in dataloader: 23 | for i in range(3): 24 | mean[i] += inputs[:,i,:,:].mean() 25 | std[i] += inputs[:,i,:,:].std() 26 | mean.div_(len(dataset)) 27 | std.div_(len(dataset)) 28 | return mean, std 29 | 30 | 31 | def init_params(net): 32 | '''Init layer parameters.''' 33 | for m in net.modules(): 34 | if isinstance(m, nn.Conv2d): 35 | init.kaiming_normal_(m.weight, mode='fan_out') 36 | if m.bias is not None: 37 | init.constant_(m.bias, 0) 38 | elif isinstance(m, nn.BatchNorm2d): 39 | init.constant_(m.weight, 1) 40 | init.constant_(m.bias, 0) 41 | elif isinstance(m, nn.Linear): 42 | init.normal_(m.weight, std=1e-3) 43 | if m.bias is not None: 44 | init.constant_(m.bias, 0) 45 | 46 | 47 | # Falls back to 80 columns when stdout is not attached to a terminal (e.g. nohup or batch jobs). 48 | term_width = shutil.get_terminal_size(fallback=(80, 24)).columns 49 | 50 | TOTAL_BAR_LENGTH = 65. 51 | last_time = time.time() 52 | begin_time = last_time 53 | 54 | 55 | def progress_bar(current, total, msg=None): 56 | global last_time, begin_time 57 | if current == 0: 58 | begin_time = time.time() # Reset for new bar.
59 | 60 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 61 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 62 | 63 | sys.stdout.write(' [') 64 | for i in range(cur_len): 65 | sys.stdout.write('=') 66 | sys.stdout.write('>') 67 | for i in range(rest_len): 68 | sys.stdout.write('.') 69 | sys.stdout.write(']') 70 | 71 | cur_time = time.time() 72 | step_time = cur_time - last_time 73 | last_time = cur_time 74 | tot_time = cur_time - begin_time 75 | 76 | L = [] 77 | L.append(' Step: %s' % format_time(step_time)) 78 | L.append(' | Tot: %s' % format_time(tot_time)) 79 | if msg: 80 | L.append(' | ' + msg) 81 | 82 | msg = ''.join(L) 83 | sys.stdout.write(msg) 84 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 85 | sys.stdout.write(' ') 86 | 87 | # Go back to the center of the bar. 88 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 89 | sys.stdout.write('\b') 90 | sys.stdout.write(' %d/%d ' % (current+1, total)) 91 | 92 | if current < total-1: 93 | sys.stdout.write('\r') 94 | else: 95 | sys.stdout.write('\n') 96 | sys.stdout.flush() 97 | 98 | 99 | def format_time(seconds): 100 | days = int(seconds / 3600/24) 101 | seconds = seconds - days*3600*24 102 | hours = int(seconds / 3600) 103 | seconds = seconds - hours*3600 104 | minutes = int(seconds / 60) 105 | seconds = seconds - minutes*60 106 | secondsf = int(seconds) 107 | seconds = seconds - secondsf 108 | millis = int(seconds*1000) 109 | 110 | f = '' 111 | i = 1 112 | if days > 0: 113 | f += str(days) + 'D' 114 | i += 1 115 | if hours > 0 and i <= 2: 116 | f += str(hours) + 'h' 117 | i += 1 118 | if minutes > 0 and i <= 2: 119 | f += str(minutes) + 'm' 120 | i += 1 121 | if secondsf > 0 and i <= 2: 122 | f += str(secondsf) + 's' 123 | i += 1 124 | if millis > 0 and i <= 2: 125 | f += str(millis) + 'ms' 126 | i += 1 127 | if f == '': 128 | f = '0ms' 129 | return f 130 | 131 | 132 | def mkdir_p(path): 133 | '''make dir if not exist''' 134 | try: 135 | os.makedirs(path) 136 | except OSError as exc: # 
Python >2.5 137 | if exc.errno == errno.EEXIST and os.path.isdir(path): 138 | pass 139 | else: 140 | raise 141 | 142 | 143 | def accuracy(output, target, topk=(1,)): 144 | """Computes the top-k accuracy (in percent) for the specified values of k""" 145 | maxk = max(topk) 146 | batch_size = target.size(0) 147 | 148 | _, pred = output.topk(maxk, 1, True, True) 149 | pred = pred.t() 150 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 151 | 152 | res = [] 153 | for k in topk: 154 | correct_k = correct[:k].reshape(-1).float().sum(0) 155 | res.append(correct_k.mul_(100.0 / batch_size)) 156 | return res 157 | 158 | 159 | class AverageMeter(object): 160 | """Computes and stores the average and current value 161 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 162 | """ 163 | def __init__(self): 164 | self.reset() 165 | 166 | def reset(self): 167 | self.val = 0 168 | self.avg = 0 169 | self.sum = 0 170 | self.count = 0 171 | 172 | def update(self, val, n=1): 173 | self.val = val 174 | self.sum += val * n 175 | self.count += n 176 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /main_imagenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | 4 | import argparse 5 | import os 6 | import shutil 7 | import time 8 | import random 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.parallel 14 | import torch.nn.functional as F 15 | import torch.backends.cudnn as cudnn 16 | import torch.distributed as dist 17 | import torch.optim as optim 18 | import torch.utils.data.distributed 19 | import torchvision.transforms as transforms 20 | import torchvision.datasets as datasets 21 | import torchvision.models as models 22 | import libs.nn as customized_models 23 | 24 | from PIL import ImageFile 25 | 26 | ImageFile.LOAD_TRUNCATED_IMAGES = True 27 | 28 | #
import logger and printer 29 | from libs.progress.bar import Bar 30 | from libs.utils import accuracy, mkdir_p, AverageMeter 31 | from libs.logger import Logger 32 | 33 | from libs.flops_counter import get_model_complexity_info 34 | 35 | import warnings 36 | 37 | warnings.filterwarnings('ignore') 38 | 39 | try: 40 | from apex.parallel import DistributedDataParallel as DDP 41 | from apex.fp16_utils import * 42 | from apex import amp, optimizers 43 | except ImportError: 44 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") 45 | 46 | 47 | # Flush stdout after every print so that logs show up immediately on servers 48 | def flush_print(func): 49 | def new_print(*args, **kwargs): 50 | func(*args, **kwargs) 51 | sys.stdout.flush() 52 | 53 | return new_print 54 | 55 | 56 | print = flush_print(print) 57 | 58 | # Models 59 | default_model_names = sorted(name for name in models.__dict__ 60 | if name.islower() and not name.startswith("__") 61 | and callable(models.__dict__[name])) 62 | 63 | customized_models_names = sorted(name for name in customized_models.__dict__ 64 | if name.islower() and not name.startswith("__") 65 | and callable(customized_models.__dict__[name])) 66 | 67 | for name in customized_models.__dict__: 68 | if name.islower() and not name.startswith("__") and callable(customized_models.__dict__[name]): 69 | models.__dict__[name] = customized_models.__dict__[name] 70 | 71 | model_names = default_model_names + customized_models_names 72 | 73 | # Parse arguments 74 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 75 | 76 | # Datasets 77 | parser.add_argument('-d', '--data', default='path to dataset', type=str) 78 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', 79 | help='number of data loading workers (default: 32)') 80 | # Optimization options 81 | parser.add_argument('--opt-level', default='O2', type=str, 82 | help='O2 is mixed FP16/32 training, see more in
https://github.com/NVIDIA/apex/tree/f5cd5ae937f168c763985f627bbf850648ea5f3f/examples/imagenet') 83 | parser.add_argument('--keep-batchnorm-fp32', default=True, action='store_true', 84 | help='keep BatchNorm in FP32 so the faster cuDNN BN kernels are used (default=True, so always on)') 85 | parser.add_argument('--loss-scale', type=float, default=None) 86 | 87 | parser.add_argument('--label-smoothing', '--ls', default=0.1, type=float) 88 | 89 | parser.add_argument('--mixup', dest='mixup', action='store_true', 90 | help='whether to use mixup') 91 | parser.add_argument('--alpha', default=0.2, type=float, 92 | metavar='mixup alpha', help='alpha value for mixup B(alpha, alpha) distribution') 93 | parser.add_argument('--cos', dest='cos', action='store_true', 94 | help='use a cosine-decay learning-rate schedule') 95 | parser.add_argument('--warmup', '--wp', default=5, type=int, 96 | help='number of warmup epochs') 97 | parser.add_argument('--epochs', default=100, type=int, metavar='N', 98 | help='number of total epochs to run') 99 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 100 | help='manual epoch number (useful on restarts)') 101 | parser.add_argument('--train-batch', default=32, type=int, metavar='N', 102 | help='train batchsize (default: 32)') 103 | parser.add_argument('--test-batch', default=125, type=int, metavar='N', 104 | help='test batchsize (default: 125)') 105 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 106 | metavar='LR', help='initial learning rate') 107 | parser.add_argument('--drop', '--dropout', default=0, type=float, 108 | metavar='Dropout', help='Dropout ratio') 109 | parser.add_argument('--schedule', type=int, nargs='+', default=[30, 60, 90], 110 | help='Decrease learning rate at these epochs.') 111 | parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.') 112 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 113 | help='momentum') 114 | parser.add_argument('--weight-decay', '--wd',
                    default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--wd-all', dest='wdall', action='store_true',
                    help='weight decay on all parameters')

# Checkpoints
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
                    help='path to save checkpoint (default: checkpoint)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')

# Architecture
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet18)')
parser.add_argument('--depth', type=int, default=29, help='Model depth.')
parser.add_argument('--cardinality', type=int, default=32, help='ResNet cardinality (group).')
parser.add_argument('--base-width', type=int, default=4, help='ResNet base width.')
parser.add_argument('--widen-factor', type=int, default=4, help='Widen factor. 4 -> 64, 8 -> 128, ...')
# Miscs
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
# store_true instead of type=bool: argparse's bool('False') is True
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument("--pretrained_dir", default=None, type=str, help="pretrained model directory")
# Device options
parser.add_argument('--local_rank', default=0, type=int)

args = parser.parse_args()
state = {k: v for k, v in args._get_kwargs()}

print("opt_level = {}".format(args.opt_level))
print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))

# Use CUDA
# os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
use_cuda = torch.cuda.is_available()

# Random seed
if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
if use_cuda:
    torch.cuda.manual_seed_all(args.manualSeed)

best_acc = 0  # best test accuracy


def fast_collate(batch):
    # Stack PIL images into a uint8 NCHW tensor without ToTensor()/Normalize;
    # normalization is done on the GPU by data_prefetcher.
    imgs = [img[0] for img in batch]
    targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
    w = imgs[0].size[0]
    h = imgs[0].size[1]
    tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8)
    for i, img in enumerate(imgs):
        nump_array = np.asarray(img, dtype=np.uint8)
        # tens = torch.from_numpy(nump_array)
        if nump_array.ndim < 3:
            nump_array = np.expand_dims(nump_array, axis=-1)
        nump_array = np.rollaxis(nump_array, 2)  # HWC -> CHW

        tensor[i] += torch.from_numpy(nump_array)

    return tensor, targets


class data_prefetcher():
    # Copies the next batch to the GPU and normalizes it on a side CUDA stream,
    # overlapping the transfer with computation on the current batch.
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]) \
            .cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]) \
            .cuda().view(1, 3, 1, 1)
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return

        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        if input is not None:
            self.preload()
        return input, target


def main():
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.checkpoint) and args.local_rank == 0:
        mkdir_p(args.checkpoint)

    args.distributed = True
    args.gpu = args.local_rank
    torch.cuda.set_device(args.gpu)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.world_size = torch.distributed.get_world_size()
    print('world_size = ', args.world_size)

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    # create model
    if args.pretrained:
        # NOTE: the constructor is called without arguments and --pretrained_dir is
        # not used here, so this branch does not load any weights by itself
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()
    elif 'resnext' in args.arch:
        model = models.__dict__[args.arch](
            baseWidth=args.base_width,
            cardinality=args.cardinality,
        )
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    flops, params = get_model_complexity_info(model, (224, 224), as_strings=False, print_per_layer_stat=False)
    print('Flops: %.3fG' % (flops / 1e9))
    print('Params: %.2fM' % (params / 1e6))

    cudnn.benchmark = True
    # define loss function (criterion) and optimizer
    # criterion = nn.CrossEntropyLoss().cuda()
    criterion = SoftCrossEntropyLoss(label_smoothing=args.label_smoothing).cuda()
    model = model.cuda()

    # linear scaling rule: --lr is the learning rate for a total batch size of 256
    args.lr = args.lr * float(args.train_batch * args.world_size) / 256.
    state['lr'] = args.lr
    optimizer = set_optimizer(model)
    # optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=args.opt_level,
                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                      loss_scale=args.loss_scale)

    # model = torch.nn.DataParallel(model).cuda()
    # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)
    model = DDP(model, delay_allreduce=True)

    # Data loading code
    traindir = os.path.join(args.data, 'img_train')
    valdir = os.path.join(args.data, 'img_val')
    # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    data_aug_scale = (0.08, 1.0)

    train_dataset = datasets.ImageFolder(traindir, transforms.Compose([
        transforms.RandomResizedCrop(224, scale=data_aug_scale),
        transforms.RandomHorizontalFlip(),
        # transforms.ToTensor(),
        # normalize,
    ]))
    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        # transforms.ToTensor(),
        # normalize,
    ]))

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.test_batch, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler, collate_fn=fast_collate)

    # Resume
    title = 'ImageNet-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..', args.resume)
        assert os.path.isfile(args.resume), 'Error: no checkpoint file found!'
        args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume, map_location='cpu')
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        # the model may have keys that the checkpoint lacks: keep their current values
        t = model.state_dict()
        c = checkpoint['state_dict']
        flag = True
        for k in t:
            if k not in c:
                print('not in loading dict! fill it', k, t[k])
                c[k] = t[k]
                flag = False
        model.load_state_dict(c)

        print('optimizer state is not restored; using a new optimizer')
        if args.local_rank == 0:
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True)
    else:
        if args.local_rank == 0:
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
            logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])

    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc = test(val_loader, model, criterion, start_epoch, use_cuda)
        print(' Test Loss: %.8f, Test Acc: %.2f' % (test_loss, test_acc))
        return

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        train_sampler.set_epoch(epoch)

        adjust_learning_rate(optimizer, epoch)

        if args.local_rank == 0:
            print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr']))

        train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, use_cuda)
        test_loss, test_acc = test(val_loader, model, criterion, epoch, use_cuda)

        # save model
        if args.local_rank == 0:
            # append logger file
            logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc])

            is_best = test_acc > best_acc
            best_acc = max(test_acc, best_acc)
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            }, is_best, checkpoint=args.checkpoint)

    if args.local_rank == 0:
        logger.close()

    print('Best acc:')
    print(best_acc)


def train(train_loader, model, criterion, optimizer, epoch, use_cuda):
    # switch to train mode
    model.train()
    torch.set_grad_enabled(True)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    if args.local_rank == 0:
        bar = Bar('Processing', max=len(train_loader))
    # defined on every rank (and at least 1) so the periodic print below is safe
    show_step = max(1, len(train_loader) // 10)

    prefetcher = data_prefetcher(train_loader)
    inputs, targets = prefetcher.next()

    batch_idx = -1
    while inputs is not None:
        # for batch_idx, (inputs, targets) in enumerate(train_loader):
        batch_idx += 1
        batch_size = inputs.size(0)
        # drop the last incomplete batch
        if batch_size < args.train_batch:
            break

        # if use_cuda:
        #     inputs, targets = inputs.cuda(), targets.cuda(async=True)
        # inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

        if args.mixup:
            inputs, targets_a, targets_b, lam = mixup_data(inputs, targets, args.alpha, use_cuda)
            outputs = model(inputs)
            loss_func = mixup_criterion(targets_a, targets_b, lam)
            old_loss = loss_func(criterion, outputs)
        else:
            outputs = model(inputs)
            old_loss = criterion(outputs, targets)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        # amp.scale_loss yields the loss multiplied by the current loss scale
        with amp.scale_loss(old_loss, optimizer) as loss:
            loss.backward()
        optimizer.step()

        if batch_idx % args.print_freq == 0:
            # measure accuracy and record the unscaled loss
            prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
            reduced_loss = reduce_tensor(old_loss.data)
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)

            # to_python_float incurs a host<->device sync
            losses.update(to_python_float(reduced_loss), inputs.size(0))
            top1.update(to_python_float(prec1), inputs.size(0))
            top5.update(to_python_float(prec5), inputs.size(0))

            torch.cuda.synchronize()
            # measure elapsed time (per-batch average since the last print)
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()

            if args.local_rank == 0:  # plot progress
                bar.suffix = '({batch}/{size}) | Batch: {bt:.3f}s | Total: {total:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                    batch=batch_idx + 1,
                    size=len(train_loader),
                    bt=batch_time.val,
                    total=bar.elapsed_td,
                    loss=losses.avg,
                    top1=top1.avg,
                    top5=top5.avg,
                )
                bar.next()
        if args.local_rank == 0 and batch_idx % show_step == 0:
            print('E%d' % (epoch) + bar.suffix)

        inputs, targets = prefetcher.next()

    if args.local_rank == 0:
        bar.finish()
    return (losses.avg, top1.avg)


def test(val_loader, model, criterion, epoch, use_cuda):
    global best_acc

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()
    # torch.set_grad_enabled(False)

    end = time.time()
    if args.local_rank == 0:
        bar = Bar('Processing', max=len(val_loader))

    prefetcher = data_prefetcher(val_loader)
    inputs, targets = prefetcher.next()

    batch_idx = -1
    while inputs is not None:
        # for batch_idx, (inputs, targets) in enumerate(val_loader):
        batch_idx += 1

        # if use_cuda:
        #     inputs, targets = inputs.cuda(), targets.cuda()
        # inputs, targets = torch.autograd.Variable(inputs, volatile=True), torch.autograd.Variable(targets)

        # compute output
        with torch.no_grad():
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))

        reduced_loss = reduce_tensor(loss.data)
        prec1 = reduce_tensor(prec1)
        prec5 = reduce_tensor(prec5)

        # to_python_float incurs a host<->device sync
        losses.update(to_python_float(reduced_loss), inputs.size(0))
        top1.update(to_python_float(prec1), inputs.size(0))
        top5.update(to_python_float(prec5), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        if args.local_rank == 0:
            bar.suffix = 'Valid({batch}/{size}) | Batch: {bt:.3f}s | Total: {total:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                batch=batch_idx + 1,
                size=len(val_loader),
                bt=batch_time.avg,
                total=bar.elapsed_td,
                loss=losses.avg,
                top1=top1.avg,
                top5=top5.avg,
            )
            bar.next()

        inputs, targets = prefetcher.next()

    if args.local_rank == 0:
        print(bar.suffix)
        bar.finish()
    return (losses.avg, top1.avg)


def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))


def set_optimizer(model):
    if args.wdall:
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
        print('weight decay on all parameters')
    else:
        # NOTE: parameters are grouped by name, so only biases and parameters whose
        # name contains 'bn' are exempt from weight decay; BatchNorm layers named
        # differently (e.g. 'norm1') still receive weight decay.
        def no_decay(name):
            return 'bias' in name or 'bn' in name

        params = [{'params': [p for name, p in model.named_parameters() if no_decay(name)], 'weight_decay': 0},
                  {'params': [p for name, p in model.named_parameters() if not no_decay(name)]}]
        names = [{'params': [name for name, _ in model.named_parameters() if no_decay(name)], 'weight_decay': 0},
                 {'params': [name for name, _ in model.named_parameters() if not no_decay(name)]}]
        print('optimizer group names:', names)
        optimizer = optim.SGD(params,
                              lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    print('optimizer = ', optimizer)
    return optimizer


def adjust_learning_rate(optimizer, epoch):
    """Epoch-wise LR: linear warmup for args.warmup epochs, then cosine decay (--cos)
    or step decay by args.gamma at each epoch in args.schedule."""
    global state

    def adjust_optimizer():
        for param_group in optimizer.param_groups:
            param_group['lr'] = state['lr']

    if epoch < args.warmup:
        state['lr'] = args.lr * (epoch + 1) / args.warmup
        adjust_optimizer()

    elif args.cos:  # cosine decay lr schedule (Note: epoch-wise, not batch-wise)
        state['lr'] = args.lr * 0.5 * (1 + np.cos(np.pi * epoch / args.epochs))
        adjust_optimizer()

    elif epoch in args.schedule:  # step lr schedule
        state['lr'] *= args.gamma
        adjust_optimizer()


class SoftCrossEntropyLoss(nn.NLLLoss):
    """Label-smoothed cross-entropy, computed as a KL divergence to a soft target.

    The target puts 1 - label_smoothing on the true class and
    label_smoothing / (num_classes - 1) on every other class.
    """

    def __init__(self, label_smoothing=0, num_classes=1000, **kwargs):
        assert 0 <= label_smoothing <= 1
        super(SoftCrossEntropyLoss, self).__init__(**kwargs)
        self.confidence = 1 - label_smoothing
        self.other = label_smoothing * 1.0 / (num_classes - 1)
        self.criterion = nn.KLDivLoss(reduction='batchmean')
        print('using label-smoothed cross-entropy loss, label_smoothing = ', label_smoothing)

    def forward(self, input, target):
        one_hot = torch.zeros_like(input)
        one_hot.fill_(self.other)
        one_hot.scatter_(1, target.unsqueeze(1).long(), self.confidence)
        input = F.log_softmax(input, 1)
        return self.criterion(input, one_hot)


def mixup_data(x, y, alpha=1.0, use_cuda=True):
    """mixup: blend each sample with a randomly chosen partner, lam ~ Beta(alpha, alpha)."""
    if alpha > 0.:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1.

    batch_size = x.size(0)
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index, ...]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(y_a, y_b, lam):
    return lambda criterion, pred: lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


def reduce_tensor(tensor):
    # average a tensor over all distributed processes
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= args.world_size
    return rt


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
# NVIDIA apex (https://github.com/NVIDIA/apex), installed from source; the PyPI package named "apex" is unrelated
apex
torch==1.1.0
torchvision
opencv-python
--------------------------------------------------------------------------------
/test_speed.py:
--------------------------------------------------------------------------------
# simple script to benchmark a convolution module
import time

import torch


def predictImage(model):
    # random input (torch.Tensor(...) would return uninitialized memory)
    img = torch.randn(4, 3, 256, 256).cuda()
    for i in range(100):
        with torch.no_grad():
            torch.cuda.synchronize()
            start = time.time()
            out = model(img)
            torch.cuda.synchronize()
            # each forward pass processes a batch of 4 images, so this counts batches per second
            print('Total speed: {} fps.'.format(1.0 / (time.time() - start)))


if __name__ == '__main__':
    from libs.nn import PixelAwareResnet50
    model = PixelAwareResnet50().cuda()
    model.eval()
    predictImage(model)
"""
Octave Conv runs at about half the speed of the original ResNet, probably because the
current implementation uses nn.Conv2d; F.conv2d is a little faster than nn.Conv2d.
Single 1080 Ti:
F.conv2d: 46 fps
nn.Conv2d: 42 fps
"""
--------------------------------------------------------------------------------
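The learning-rate logic in main_imagenet.py is split between `main()`, which rescales the base rate (0.1 by default) by `train_batch * world_size / 256`, and `adjust_learning_rate()`, which mutates `state['lr']` once per epoch. The sketch below is a hypothetical, stateless restatement of that schedule in plain Python, so the LR at any epoch can be checked without launching training. `scaled_base_lr` and `lr_at_epoch` are illustrative names, not part of the repo; the defaults mirror the script's argparse defaults, and the sketch assumes a run that starts at epoch 0 (not resumed).

```python
import math


def scaled_base_lr(per_gpu_batch, world_size, lr_per_256=0.1):
    """Linear scaling rule: the base LR grows with the total batch size."""
    return lr_per_256 * per_gpu_batch * world_size / 256.0


def lr_at_epoch(epoch, base_lr, warmup=5, epochs=100, cos=False,
                schedule=(30, 60, 90), gamma=0.1):
    """Stateless restatement of adjust_learning_rate() for a run started at epoch 0.

    Hypothetical helper; defaults mirror --warmup, --epochs, --schedule and --gamma.
    """
    if epoch < warmup:
        # linear warmup: base_lr / warmup, 2 * base_lr / warmup, ..., base_lr
        return base_lr * (epoch + 1) / warmup
    if cos:
        # epoch-wise cosine decay over the whole run (epoch is not offset by warmup)
        return base_lr * 0.5 * (1 + math.cos(math.pi * epoch / epochs))
    # step decay: one factor of gamma per milestone reached after warmup
    # (milestones inside the warmup window are never applied by the warmup branch)
    num_decays = sum(1 for m in set(schedule) if warmup <= m <= epoch)
    return base_lr * gamma ** num_decays
```

For example, 8 GPUs with `--train-batch 32` give a base LR of 0.1; with the default step schedule the LR ramps up over epochs 0 to 4, stays at 0.1 through epoch 29, and drops to 0.01 at epoch 30.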
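`SoftCrossEntropyLoss` implements label smoothing by building a soft target q (`1 - label_smoothing` on the true class, `label_smoothing / (num_classes - 1)` on every other class) and minimizing `nn.KLDivLoss` against `log_softmax(input)`. Because KL(q || p) = H(q, p) - H(q) and the target entropy H(q) does not depend on the model, this gives the same gradients as label-smoothed cross-entropy; only the logged loss value is shifted by H(q). The plain-Python sketch below checks that identity for a single sample (`nn.KLDivLoss(reduction='batchmean')` averages the per-sample value over the batch); all function names are illustrative and not part of the repo.

```python
import math


def smoothed_target(target, num_classes, label_smoothing):
    """Per-sample target built in SoftCrossEntropyLoss.forward."""
    other = label_smoothing / (num_classes - 1)
    q = [other] * num_classes
    q[target] = 1.0 - label_smoothing
    return q


def log_softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def kl_div(q, log_p):
    """KL(q || p) with p given as log-probabilities; zero-probability terms add 0."""
    return sum(qi * (math.log(qi) - lp) for qi, lp in zip(q, log_p) if qi > 0)


def cross_entropy(q, log_p):
    return -sum(qi * lp for qi, lp in zip(q, log_p))


def entropy(q):
    return -sum(qi * math.log(qi) for qi in q if qi > 0)
```

With `label_smoothing=0` the target is one-hot, H(q) = 0, and the KL form reduces to the ordinary cross-entropy `-log p[target]`.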
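`mixup_criterion` returns `lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)` for the blended batch produced by `mixup_data`. Because cross-entropy is linear in its target distribution, this equals the cross-entropy against the mixed target `lam * onehot(y_a) + (1 - lam) * onehot(y_b)`, i.e. the labels are mixed with the same weight as the images. The pure-Python sketch below checks this for one sample using plain (unsmoothed) cross-entropy; the helpers are hypothetical, and `random.betavariate` stands in for `np.random.beta`.

```python
import math
import random


def sample_lam(alpha):
    """lam ~ Beta(alpha, alpha), or 1.0 when alpha <= 0, as in mixup_data."""
    return random.betavariate(alpha, alpha) if alpha > 0 else 1.0


def log_softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def cross_entropy(q, log_p):
    """Cross-entropy between target distribution q and predicted log-probabilities."""
    return -sum(qi * lp for qi, lp in zip(q, log_p))


def one_hot(label, num_classes):
    return [1.0 if c == label else 0.0 for c in range(num_classes)]


def mixup_pair_loss(logits, y_a, y_b, lam, num_classes):
    """What mixup_criterion computes for one sample with a hard-label criterion."""
    log_p = log_softmax(logits)
    return (lam * cross_entropy(one_hot(y_a, num_classes), log_p)
            + (1 - lam) * cross_entropy(one_hot(y_b, num_classes), log_p))


def mixed_target_loss(logits, y_a, y_b, lam, num_classes):
    """Cross-entropy against the mixed target lam * e_a + (1 - lam) * e_b."""
    q = [lam * a + (1 - lam) * b
         for a, b in zip(one_hot(y_a, num_classes), one_hot(y_b, num_classes))]
    return cross_entropy(q, log_softmax(logits))
```

The identity relies on linearity in the target; the script's KL-based `SoftCrossEntropyLoss` matches it only up to a model-independent constant.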