├── .gitignore ├── LICENSE ├── README.md ├── configs ├── __init__.py ├── cifar10.yaml ├── cifar100.yaml ├── cub2011.yaml ├── default.py └── imagenet.yaml ├── figures ├── ._C3Net_framework.png ├── C3Net_framework.png └── C3Net_motivation.png ├── lib ├── .DS_Store ├── ._.DS_Store ├── ._model.py ├── __init__.py ├── data.py ├── layers │ ├── __init__.py │ ├── cross_neuron.py │ ├── cross_neuron_distributed.py │ ├── data_parallel.py │ └── selayer.py ├── model.py ├── model_analysis.py ├── networks │ ├── .DS_Store │ ├── ._.DS_Store │ ├── __init__.py │ ├── densenet_cifar.py │ ├── invcnn_pytorch │ │ ├── __init__.py │ │ └── invnet │ │ │ ├── __init__.py │ │ │ └── inv_cnn.py │ ├── mobilenet_v2.py │ ├── resnet_cifar.py │ ├── resnet_cifar_analysis.py │ ├── resnet_cifar_analysis1.py │ ├── resnext.py │ ├── resnext_cifar.py │ └── wide_resnet_cifar.py └── utils │ ├── __init__.py │ ├── cub2011.py │ ├── imagenet.py │ ├── net_utils.py │ ├── verbo.py │ └── visualize.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | # lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Jianwei Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 
| copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # C3Net 2 | Pytorch implementation of our NeurIPS 2019 paper ["Cross-channel Communication Networks"](https://papers.nips.cc/paper/8411-cross-channel-communication-networks.pdf) 3 | 4 |
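In code, the C3 block is a standalone layer (`lib/layers/cross_neuron_distributed.py`) that can be wrapped around a stage of an existing backbone; `add_cross_neuron` automates this for whole networks. A minimal usage sketch, assuming a torchvision ResNet-18 with 224x224 inputs (so `layer2` sees a 64x56x56 feature map):

```python
# Illustrative sketch (not from train.py): wrap one ResNet stage with a C3 block.
import torch
import torchvision.models as models
from lib.layers.cross_neuron_distributed import CrossNeuronWrapper

net = models.resnet18(pretrained=False)
# Run cross-channel communication on layer2's 64 x 56 x 56 input before the stage itself.
net.layer2 = CrossNeuronWrapper(net.layer2, 64, 56, 56, 56, 56, reduction=4)
out = net(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])
```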
5 | ![C3Net motivation](figures/C3Net_motivation.png) 6 | 
7 | 8 | ## Introduction 9 | 10 | As shown above, the motivation behind our proposed C3Net is twofold: 11 | 12 | * Neurons at the same layer do not directly interact with each other. 13 | * Different neurons might respond to the same patterns and locations. 14 | 15 | In this paper, we mainly focus on convolutional neural networks (CNNs). In a CNN, channel responses naturally encode which patterns appear and where they are located. Our main idea is to enable channels at the same layer to communicate with each other and then calibrate their responses accordingly. We want different filters to learn to focus on different useful patterns, so that they are complementary to each other. 16 | 17 | The main contributions are: 18 | 19 | * We propose a cross-channel communication (C3) block to enable full interactions across channels at the same layer. 20 | * It achieves better performance on image classification, object detection, and semantic segmentation. 21 | * It captures more diverse representations with light-weight networks. 22 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | from .default import _C as cfg 2 | -------------------------------------------------------------------------------- /configs/cifar10.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | input_size: 32 3 | initialize: 4 | seed: 1 5 | training: 6 | batch_size: 128 7 | workers: 8 8 | test: 9 | batch_size: 1000 10 | workers: 8 11 | optimizer: 12 | name: 'sgd' 13 | lr: 0.1 14 | lr_decay_gamma: 0.1 15 | weight_decay: 1e-4 16 | momentum: 0.9 17 | lr_decay_schedule: (100, 140) 18 | max_epoch: 160 19 | log: 20 | print_interval: 20 21 | checkpoint_interval: 2 22 | -------------------------------------------------------------------------------- /configs/cifar100.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | input_size: 32 3 | initialize: 4 | seed: 1 5 | training: 6 | batch_size: 128 7 | workers: 8 8 | test: 9 | batch_size: 1000 10 | workers: 8 11 | optimizer: 12 | name: 'sgd' 13 | lr: 0.1 14 | lr_decay_gamma: 0.1 15 | weight_decay: 1e-4 16 | momentum: 0.9 17 | lr_decay_schedule: (120, 140) 18 | max_epoch: 160 19 | log: 20 | print_interval: 20 21 | checkpoint_interval: 20 22 | -------------------------------------------------------------------------------- /configs/cub2011.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | input_size: 224 3 | initialize: 4 | seed: 1 5 | training: 6 | batch_size: 32 7 | workers: 8 8 | test: 9 | batch_size: 128 10 | workers: 8 11 | optimizer: 12 | name: 'sgd' 13 | lr: 0.001 14 | lr_decay_gamma: 0.9 15 | weight_decay: 1e-4 16 | momentum: 0.9 17 | lr_decay_schedule: (10, 20, 30, 40, 50, 60, 70, 80) 18 | max_epoch: 80 19 | log: 20 | print_interval: 20 21 | checkpoint_interval: 2 22 | -------------------------------------------------------------------------------- /configs/default.py: -------------------------------------------------------------------------------- 1 | import os 2 | from yacs.config import CfgNode as CN 3 | 4 | _C = CN() 5 | 6 | _C.data = CN() 7 | _C.data.input_size = 32 8 | _C.data.traindir = "" 9 | _C.data.valdir = "" 10 | _C.data.testdir = "" 11 | 12 | _C.initialize = CN() 13 | _C.initialize.seed = 1 14 | 15 | _C.training = CN() 16 | _C.training.batch_size = 128 17 | _C.training.workers = 8 18 | _C.training.distributed = False 19 | 
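These defaults are exposed through `configs/__init__.py` as `cfg`, and the dataset yamls above override them at runtime. A minimal sketch of that merge, assuming yacs is installed and the repo root is the working directory (the exact call site in `train.py` is not shown in this dump):

```python
# Hypothetical merge sketch: illustrates the yacs config flow, not train.py itself.
from configs import cfg  # this is _C from configs/default.py

cfg.merge_from_file("configs/cifar10.yaml")   # yaml values override the defaults
cfg.freeze()                                  # make the config read-only
print(cfg.training.batch_size)                # 128
print(cfg.optimizer.lr_decay_schedule)        # (100, 140), literal-eval'd from the yaml string
```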
20 | _C.test = CN() 21 | _C.test.batch_size = 1000 22 | _C.test.workers = 8 23 | 24 | _C.optimizer = CN() 25 | _C.optimizer.name = "sgd" 26 | _C.optimizer.lr = 0.1 27 | _C.optimizer.lr_decay_gamma = 0.1 28 | _C.optimizer.weight_decay = 1e-4 29 | _C.optimizer.momentum = 0.9 30 | _C.optimizer.lr_decay_schedule = (60, 120) 31 | _C.optimizer.max_epoch = 160 32 | 33 | _C.log = CN() 34 | _C.log.print_interval = 20 35 | _C.log.checkpoint_interval = 2 36 | -------------------------------------------------------------------------------- /configs/imagenet.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | input_size: 224 3 | traindir: "/srv/share/datasets/ImageNet/pytorch_folder/train" 4 | valdir: "/srv/share/datasets/ImageNet/pytorch_folder/val" 5 | initialize: 6 | seed: 1 7 | training: 8 | batch_size: 512 9 | workers: 8 10 | distributed: False 11 | test: 12 | batch_size: 256 13 | workers: 8 14 | optimizer: 15 | name: 'sgd' 16 | lr: 0.1 17 | lr_decay_gamma: 0.1 18 | weight_decay: 1e-4 19 | momentum: 0.9 20 | lr_decay_schedule: (30, 60) 21 | max_epoch: 90 22 | log: 23 | print_interval: 20 24 | checkpoint_interval: 2 25 | -------------------------------------------------------------------------------- /figures/._C3Net_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwyang/C3Net.pytorch/70026fc80c5427484268c428a9dcd4cde2e8197f/figures/._C3Net_framework.png -------------------------------------------------------------------------------- /figures/C3Net_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwyang/C3Net.pytorch/70026fc80c5427484268c428a9dcd4cde2e8197f/figures/C3Net_framework.png -------------------------------------------------------------------------------- /figures/C3Net_motivation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwyang/C3Net.pytorch/70026fc80c5427484268c428a9dcd4cde2e8197f/figures/C3Net_motivation.png -------------------------------------------------------------------------------- /lib/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwyang/C3Net.pytorch/70026fc80c5427484268c428a9dcd4cde2e8197f/lib/.DS_Store -------------------------------------------------------------------------------- /lib/._.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwyang/C3Net.pytorch/70026fc80c5427484268c428a9dcd4cde2e8197f/lib/._.DS_Store -------------------------------------------------------------------------------- /lib/._model.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwyang/C3Net.pytorch/70026fc80c5427484268c428a9dcd4cde2e8197f/lib/._model.py -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import * 2 | from .model import * 3 | from .model_analysis import * 4 | from .utils import * 5 | -------------------------------------------------------------------------------- /lib/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets, transforms 3 | from 
.utils.imagenet import Loader 4 | from .utils.cub2011 import Cub2011 5 | def get_cifar10(opts): 6 | kwargs = {'num_workers': opts.training.workers, 'pin_memory': True} if opts.use_cuda else {} 7 | train_loader = torch.utils.data.DataLoader( 8 | datasets.CIFAR10('../data', train=True, download=True, 9 | transform=transforms.Compose([ 10 | transforms.RandomCrop(32, padding=4), 11 | transforms.RandomHorizontalFlip(), 12 | transforms.ToTensor(), 13 | transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262]) 14 | ])), 15 | batch_size=opts.training.batch_size, shuffle=True, **kwargs) 16 | test_loader = torch.utils.data.DataLoader( 17 | datasets.CIFAR10('../data', train=False, transform=transforms.Compose([ 18 | transforms.ToTensor(), 19 | transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262]) 20 | ])), 21 | batch_size=opts.test.batch_size, shuffle=True, **kwargs) 22 | return train_loader, test_loader 23 | 24 | def get_cifar100(opts): 25 | kwargs = {'num_workers': opts.training.workers, 'pin_memory': True} if opts.use_cuda else {} 26 | train_loader = torch.utils.data.DataLoader( 27 | datasets.CIFAR100('../data', train=True, download=True, 28 | transform=transforms.Compose([ 29 | transforms.RandomCrop(32, padding=4), 30 | transforms.RandomHorizontalFlip(), 31 | transforms.ToTensor(), 32 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]), 33 | ])), 34 | batch_size=opts.training.batch_size, shuffle=True, **kwargs) 35 | test_loader = torch.utils.data.DataLoader( 36 | datasets.CIFAR100('../data', train=False, transform=transforms.Compose([ 37 | transforms.ToTensor(), 38 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]), 39 | ])), 40 | batch_size=opts.test.batch_size, shuffle=False, **kwargs) 41 | return train_loader, test_loader 42 | 43 | def get_cub2011(opts): 44 | normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]) 45 | kwargs = {'num_workers': opts.training.workers, 'pin_memory': True} if opts.use_cuda else {} 46 | train_loader = torch.utils.data.DataLoader( 47 | Cub2011('../data', train=True, download=False, 48 | transform=transforms.Compose([ 49 | transforms.Resize(256), 50 | transforms.RandomCrop(224), 51 | transforms.RandomHorizontalFlip(), 52 | transforms.ToTensor(), 53 | normalize, 54 | ])), 55 | batch_size=opts.training.batch_size, shuffle=True, **kwargs) 56 | test_loader = torch.utils.data.DataLoader( 57 | Cub2011('../data', train=False, transform=transforms.Compose([ 58 | transforms.Resize(256), 59 | transforms.CenterCrop(224), 60 | transforms.ToTensor(), 61 | normalize, 62 | ])), 63 | batch_size=opts.test.batch_size, shuffle=True, **kwargs) 64 | return train_loader, test_loader 65 | 66 | def get_imagenet(opts): 67 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 68 | # train_dataset = datasets.ImageFolder( 69 | # opts.data.traindir, 70 | # transforms.Compose([ 71 | # transforms.RandomResizedCrop(224), 72 | # transforms.RandomHorizontalFlip(), 73 | # transforms.ToTensor(), 74 | # normalize, 75 | # ])) 76 | # 77 | # if opts.training.distributed: 78 | # train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 79 | # else: 80 | # train_sampler = None 81 | # 82 | # train_loader = torch.utils.data.DataLoader( 83 | # train_dataset, batch_size=opts.training.batch_size, shuffle=(train_sampler is None), 84 | # num_workers=opts.training.workers, pin_memory=True, sampler=train_sampler) 85 | 86 | # import pdb; 
pdb.set_trace() 87 | 88 | train_loader = Loader('train', batch_size=opts.training.batch_size, num_workers=opts.training.workers) 89 | 90 | val_dataset = datasets.ImageFolder( 91 | opts.data.valdir, 92 | transforms.Compose([ 93 | transforms.Resize(256), 94 | transforms.CenterCrop(224), 95 | transforms.ToTensor(), 96 | normalize, 97 | ])) 98 | val_loader = torch.utils.data.DataLoader( 99 | val_dataset, batch_size=256, shuffle=False, 100 | num_workers=opts.test.workers, pin_memory=True) 101 | # val_loader = Loader('val', batch_size=opts.test.batch_size, num_workers=opts.test.workers) 102 | return train_loader, val_loader 103 | 104 | def create_data_loader(opts): 105 | if opts.dataset == "cifar10": 106 | return get_cifar10(opts) 107 | elif opts.dataset == "cifar100": 108 | return get_cifar100(opts) 109 | elif opts.dataset == "cub2011": 110 | return get_cub2011(opts) 111 | elif opts.dataset == "imagenet": 112 | return get_imagenet(opts) 113 | else: 114 | raise ValueError("Unknow dataset") 115 | -------------------------------------------------------------------------------- /lib/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .cross_neuron_distributed import * 2 | from .selayer import * 3 | from .data_parallel import * 4 | -------------------------------------------------------------------------------- /lib/layers/cross_neuron.py: -------------------------------------------------------------------------------- 1 | # Non-local block using embedded gaussian 2 | # Code from 3 | # https://github.com/AlexHex7/Non-local_pytorch/blob/master/Non-Local_pytorch_0.3.1/lib/non_local_embedded_gaussian.py 4 | import math 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | import numpy as np 9 | from scipy.linalg import block_diag 10 | 11 | class _CrossNeuronBlock(nn.Module): 12 | def __init__(self, in_channels, in_height, in_width, 13 | nblocks_channel=4, 14 | spatial_height=24, spatial_width=24, 15 | reduction=8, size_is_consistant=True): 16 | # nblock_channel: number of block along channel axis 17 | # spatial_size: spatial_size 18 | super(_CrossNeuronBlock, self).__init__() 19 | 20 | # set channel splits 21 | if in_channels <= 512: 22 | self.nblocks_channel = 1 23 | else: 24 | self.nblocks_channel = in_channels // 512 25 | block_size = in_channels // self.nblocks_channel 26 | block = torch.Tensor(block_size, block_size).fill_(1) 27 | self.mask = torch.Tensor(in_channels, in_channels).fill_(0) 28 | for i in range(self.nblocks_channel): 29 | self.mask[i * block_size:(i + 1) * block_size, i * block_size:(i + 1) * block_size].copy_(block) 30 | 31 | # set spatial splits 32 | if in_height * in_width < 32 * 32 and size_is_consistant: 33 | self.spatial_area = in_height * in_width 34 | self.spatial_height = in_height 35 | self.spatial_width = in_width 36 | else: 37 | self.spatial_area = spatial_height * spatial_width 38 | self.spatial_height = spatial_height 39 | self.spatial_width = spatial_width 40 | 41 | self.fc_in = nn.Sequential( 42 | nn.Conv1d(self.spatial_area, self.spatial_area // reduction, 1, 1, 0, bias=True), 43 | nn.ReLU(True), 44 | nn.Conv1d(self.spatial_area // reduction, self.spatial_area, 1, 1, 0, bias=True), 45 | ) 46 | 47 | self.fc_out = nn.Sequential( 48 | nn.Conv1d(self.spatial_area, self.spatial_area // reduction, 1, 1, 0, bias=True), 49 | nn.ReLU(True), 50 | nn.Conv1d(self.spatial_area // reduction, self.spatial_area, 1, 1, 0, bias=True), 51 | ) 52 | 53 | self.bn = nn.BatchNorm1d(self.spatial_area) 
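The `forward` method below implements the actual communication: each channel is summarized by its mean response, channels are scored pairwise by the negative squared distance between those means, a softmax over the scores gives a channel-to-channel attention map, and a batched matmul passes messages between channels. (Note the stray `import pdb; pdb.set_trace()` left inside `forward`; it stops every call and must be removed or commented out before this module can be used.) A self-contained toy sketch of that scoring-and-aggregation step, with the `fc_in`/`fc_out` bottlenecks, BatchNorm, and residual connection omitted:

```python
import torch
import torch.nn.functional as F

def c3_message_passing(x):
    """Toy version of the block's forward: (b, c, h, w) -> (b, c, h, w)."""
    b, c, h, w = x.shape
    x_v = x.view(b, c, h * w).permute(0, 2, 1).contiguous()  # b x (hw) x c
    x_m = x_v.mean(1, keepdim=True)                          # b x 1 x c, mean response per channel
    score = -(x_m - x_m.permute(0, 2, 1).contiguous()) ** 2  # b x c x c, closer means -> higher score
    attn = F.softmax(score, dim=1)                           # normalize over source channels
    out = torch.bmm(x_v, attn)                               # b x (hw) x c, cross-channel aggregation
    return out.permute(0, 2, 1).contiguous().view(b, c, h, w)

x = torch.randn(2, 16, 8, 8)
print(c3_message_passing(x).shape)  # torch.Size([2, 16, 8, 8])
```

Channels with similar mean responses receive the largest attention weights, which is what lets the full block identify and recalibrate redundant filters.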
54 | 55 | def forward(self, x): 56 | ''' 57 | :param x: (bt, c, h, w) 58 | :return: 59 | ''' 60 | bt, c, h, w = x.shape 61 | residual = x 62 | x_stretch = x.view(bt, c, h * w) 63 | spblock_h = int(np.ceil(h / self.spatial_height)) 64 | spblock_w = int(np.ceil(w / self.spatial_width)) 65 | stride_h = int((h - self.spatial_height) / (spblock_h - 1)) if spblock_h > 1 else 0 66 | stride_w = int((w - self.spatial_width) / (spblock_w - 1)) if spblock_w > 1 else 0 67 | 68 | import pdb; pdb.set_trace() 69 | 70 | if spblock_h == 1 and spblock_w == 1: 71 | x_stacked = x_stretch # (b) x c x (h * w) 72 | x_stacked = x_stacked.view(bt * self.nblocks_channel, c // self.nblocks_channel, -1) 73 | x_v = x_stacked.permute(0, 2, 1).contiguous() # (b) x (h * w) x c 74 | x_v = self.fc_in(x_v) # (b) x (h * w) x c 75 | x_m = x_v.mean(1).view(-1, 1, c // self.nblocks_channel).detach() # (b * h * w) x 1 x c 76 | score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # (b * h * w) x c x c 77 | # score.masked_fill_(self.mask.unsqueeze(0).expand_as(score).type_as(score).eq(0), -np.inf) 78 | attn = F.softmax(score, dim=1) # (b * h * w) x c x c 79 | out = self.bn(self.fc_out(torch.bmm(x_v, attn))) # (b) x (h * w) x c 80 | out = out.permute(0, 2, 1).contiguous().view(bt, c, h, w) 81 | return F.relu(residual + out) 82 | else: 83 | # first splt input tensor into chunks 84 | ind_chunks = [] 85 | x_chunks = [] 86 | for i in range(spblock_h): 87 | for j in range(spblock_w): 88 | tl_y, tl_x = max(0, i * stride_h), max(0, j * stride_w) 89 | br_y, br_x = min(h, tl_y + self.spatial_height), min(w, tl_x + self.spatial_width) 90 | ind_y = torch.arange(tl_y, br_y).view(-1, 1) 91 | ind_x = torch.arange(tl_x, br_x).view(1, -1) 92 | ind = (ind_y * w + ind_x).view(1, 1, -1).repeat(bt, c, 1).type_as(x_stretch).long() 93 | ind_chunks.append(ind) 94 | chunk_ij = torch.gather(x_stretch, 2, ind).contiguous() 95 | x_chunks.append(chunk_ij) 96 | 97 | x_stacked = torch.cat(x_chunks, 0) # (b * nb_h * n_w) x c x (b_h * b_w) 98 | x_v = x_stacked.permute(0, 2, 1).contiguous() # (b * nb_h * n_w) x (b_h * b_w) x c 99 | x_v = self.fc_in(x_v) # (b * nb_h * n_w) x (b_h * b_w) x c 100 | x_m = x_v.mean(1).view(-1, 1, c) # (b * nb_h * n_w) x 1 x c 101 | score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # (b * nb_h * n_w) x c x c 102 | score.masked_fill_(self.mask.unsqueeze(0).expand_as(score).type_as(score).eq(0), -np.inf) 103 | attn = F.softmax(score, dim=1) # (b * nb_h * n_w) x c x c 104 | out = self.bn(self.fc_out(torch.bmm(x_v, attn))) # (b * nb_h * n_w) x (b_h * b_w) x c 105 | 106 | # put back to original shape 107 | out = out.permute(0, 2, 1).contiguous() # (b * nb_h * n_w) x c x (b_h * b_w) 108 | # x_stretch_out = x_stretch.clone().zero_() 109 | for i in range(spblock_h): 110 | for j in range(spblock_w): 111 | idx = i * spblock_w + j 112 | ind = ind_chunks[idx] 113 | chunk_ij = out[idx * bt:(idx+1) * bt] 114 | x_stretch = x_stretch.scatter_add(2, ind, chunk_ij / spblock_h / spblock_h) 115 | return F.relu(x_stretch.view(residual.shape)) 116 | 117 | class CrossNeuronlBlock2D(_CrossNeuronBlock): 118 | def __init__(self, in_channels, in_height, in_width, spatial_height, spatial_width, reduction=8, size_is_consistant=True): 119 | super(CrossNeuronlBlock2D, self).__init__(in_channels, in_height, in_width, 120 | nblocks_channel=4, 121 | spatial_height=spatial_height, 122 | spatial_width=spatial_width, 123 | reduction=reduction, 124 | size_is_consistant=size_is_consistant) 125 | 126 | 127 | class CrossNeuronWrapper(nn.Module): 128 | def __init__(self, 
block, in_channels, in_height, in_width, spatial_height, spatial_width, reduction=8): 129 | super(CrossNeuronWrapper, self).__init__() 130 | self.block = block 131 | self.cn = CrossNeuronlBlock2D(in_channels, in_height, in_width, spatial_height, spatial_width, reduction=reduction) 132 | 133 | def forward(self, x): 134 | x = self.cn(x) 135 | x = self.block(x) 136 | return x 137 | 138 | def add_cross_neuron(net, img_height, img_width, spatial_height, spatial_width, reduction=8): 139 | import torchvision 140 | import lib.networks as archs 141 | 142 | import pdb; pdb.set_trace() 143 | 144 | if isinstance(net, torchvision.models.ResNet): 145 | dummy_img = torch.randn(1, 3, img_height, img_width) 146 | out = net.conv1(dummy_img) 147 | out = net.relu(net.bn1(out)) 148 | out0 = net.maxpool(out) 149 | print("layer0 out shape: {}x{}x{}x{}".format(out0.shape[0], out0.shape[1], out0.shape[2], out0.shape[3])) 150 | out1 = net.layer1(out0) 151 | print("layer1 out shape: {}x{}x{}x{}".format(out1.shape[0], out1.shape[1], out1.shape[2], out1.shape[3])) 152 | out2 = net.layer2(out1) 153 | print("layer2 out shape: {}x{}x{}x{}".format(out2.shape[0], out2.shape[1], out2.shape[2], out2.shape[3])) 154 | out3 = net.layer3(out2) 155 | print("layer3 out shape: {}x{}x{}x{}".format(out3.shape[0], out3.shape[1], out3.shape[2], out3.shape[3])) 156 | out4 = net.layer4(out3) 157 | print("layer4 out shape: {}x{}x{}x{}".format(out4.shape[0], out4.shape[1], out4.shape[2], out4.shape[3])) 158 | 159 | # net.layer1 = CrossNeuronWrapper(net.layer1, out1.shape[1], out1.shape[2], out1.shape[3], spatial_height[0], spatial_width[0], reduction) 160 | net.layer2 = CrossNeuronWrapper(net.layer2, out2.shape[1], out2.shape[2], out2.shape[3], spatial_height[1], spatial_width[1], reduction) 161 | net.layer3 = CrossNeuronWrapper(net.layer3, out3.shape[1], out3.shape[2], out3.shape[3], spatial_height[2], spatial_width[2], reduction) 162 | net.layer4 = CrossNeuronWrapper(net.layer4, out4.shape[1], out4.shape[2], out4.shape[3], spatial_height[3], spatial_width[3], reduction) 163 | 164 | # layers = [] 165 | # l = len(net.layer2) 166 | # for i in range(l): 167 | # if i % 6 == 0 or i == (l - 1): 168 | # layers.append(CrossNeuronWrapper(net.layer2[i], out2.shape[1], out2.shape[2], out2.shape[3], 169 | # spatial_height[1], spatial_width[1], reduction[1])) 170 | # else: 171 | # layers.append(net.layer2[i]) 172 | # net.layer2 = nn.Sequential(*layers) 173 | # 174 | # # 175 | # layers = [] 176 | # l = len(net.layer3) 177 | # for i in range(0, l): 178 | # if i % 6 == 0 or i == (l - 1): 179 | # layers.append(CrossNeuronWrapper(net.layer3[i], out3.shape[1], out3.shape[2], out3.shape[3], 180 | # spatial_height[2], spatial_width[2], reduction[2])) 181 | # else: 182 | # layers.append(net.layer3[i]) 183 | # net.layer3 = nn.Sequential(*layers) 184 | # 185 | # layers = [] 186 | # l = len(net.layer4) 187 | # for i in range(0, l): 188 | # if i % 6 == 0 or i == (l - 1): 189 | # layers.append(CrossNeuronWrapper(net.layer4[i], out4.shape[1], out4.shape[2], out4.shape[3], 190 | # spatial_height[3], spatial_width[3], reduction[3])) 191 | # else: 192 | # layers.append(net.layer4[i]) 193 | # net.layer4 = nn.Sequential(*layers) 194 | 195 | 196 | elif isinstance(net, archs.resnet_cifar.ResNet_Cifar): 197 | 198 | dummy_img = torch.randn(1, 3, img_height, img_width) 199 | out = net.conv1(dummy_img) 200 | out0 = net.relu(net.bn1(out)) 201 | out1 = net.layer1(out0) 202 | out2 = net.layer2(out1) 203 | out3 = net.layer3(out2) 204 | 205 | net.layer1 = 
CrossNeuronWrapper(net.layer1, out0.shape[1], out0.shape[2], out0.shape[3], spatial_height[0], spatial_width[0]) 206 | net.layer2 = CrossNeuronWrapper(net.layer2, out1.shape[1], out1.shape[2], out1.shape[3], spatial_height[1], spatial_width[1]) 207 | net.layer3 = CrossNeuronWrapper(net.layer3, out2.shape[1], out2.shape[2], out2.shape[3], spatial_height[2], spatial_width[2]) 208 | 209 | else: 210 | dummy_img = torch.randn(1, 3, img_height, img_width) 211 | out = net.conv1(dummy_img) 212 | out = net.relu(net.bn1(out)) 213 | out1 = net.layer1(out) 214 | out2 = net.layer2(out1) 215 | out3 = net.layer3(out2) 216 | 217 | net.layer1 = CrossNeuronWrapper(net.layer1, out1.shape[1], out1.shape[2], out1.shape[3], spatial_height[0], spatial_width[0]) 218 | net.layer2 = CrossNeuronWrapper(net.layer2, out2.shape[1], out2.shape[2], out2.shape[3], spatial_height[1], spatial_width[1]) 219 | net.layer3 = CrossNeuronWrapper(net.layer3, out3.shape[1], out3.shape[2], out3.shape[3], spatial_height[2], spatial_width[2]) 220 | 221 | 222 | # layers = [] 223 | # l = len(net.layer2) 224 | # for i in range(l): 225 | # if i % 5 == 0 or i == (l - 1): 226 | # layers.append(CrossNeuronWrapper(net.layer2[i], out2.shape[1], out2.shape[2] * out2.shape[3])) 227 | # else: 228 | # layers.append(net.layer2[i]) 229 | # net.layer2 = nn.Sequential(*layers) 230 | # # 231 | # layers = [] 232 | # l = len(net.layer3) 233 | # for i in range(0, l): 234 | # if i % 5 == 0 or i == (l - 1): 235 | # layers.append(CrossNeuronWrapper(net.layer3[i], out3.shape[1], out3.shape[2] * out3.shape[3])) 236 | # else: 237 | # layers.append(net.layer3[i]) 238 | # net.layer3 = nn.Sequential(*layers) 239 | # 240 | # else: 241 | # raise NotImplementedError 242 | 243 | 244 | if __name__ == '__main__': 245 | from torch.autograd import Variable 246 | import torch 247 | 248 | sub_sample = True 249 | bn_layer = True 250 | 251 | img = torch.randn(2, 3, 10, 20, 20) 252 | net = CrossNeuronlBlock3D(3, 20 * 20) 253 | out = net(img) 254 | print(out.size()) 255 | -------------------------------------------------------------------------------- /lib/layers/cross_neuron_distributed.py: -------------------------------------------------------------------------------- 1 | # Non-local block using embedded gaussian 2 | # Code from 3 | # https://github.com/AlexHex7/Non-local_pytorch/blob/master/Non-Local_pytorch_0.3.1/lib/non_local_embedded_gaussian.py 4 | import math 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | import numpy as np 9 | from scipy.linalg import block_diag 10 | 11 | class _CrossNeuronBlock(nn.Module): 12 | def __init__(self, in_channels, in_height, in_width, 13 | nblocks_channel=8, 14 | spatial_height=32, spatial_width=32, 15 | reduction=4, size_is_consistant=True): 16 | # nblock_channel: number of block along channel axis 17 | # spatial_size: spatial_size 18 | super(_CrossNeuronBlock, self).__init__() 19 | 20 | self.corr_bf = 0 21 | self.corr_af = 0 22 | 23 | self.nblocks_channel = 1 if in_channels <= 512 else in_channels // 512 24 | 25 | block_size = in_channels // self.nblocks_channel 26 | block = torch.Tensor(block_size, block_size).fill_(1) 27 | self.mask = torch.Tensor(in_channels, in_channels).fill_(0) 28 | for i in range(self.nblocks_channel): 29 | self.mask[i * block_size:(i + 1) * block_size, i * block_size:(i + 1) * block_size].copy_(block) 30 | 31 | factor = in_height // spatial_height 32 | if factor == 0 and size_is_consistant: 33 | self.spatial_area = in_height * in_width 34 | self.spatial_height = in_height 
35 | self.spatial_width = in_width 36 | else: 37 | ds_layers = [] 38 | us_layers = [] 39 | for i in range(factor - 1): 40 | ds_layer = nn.Sequential( 41 | nn.Conv2d(in_channels, in_channels, kernel_size=2, stride=2, padding=0, groups=in_channels, bias=False), 42 | nn.BatchNorm2d(in_channels), 43 | nn.ReLU(True), 44 | ) 45 | ds_layers.append(ds_layer) 46 | 47 | us_layer = nn.Sequential( 48 | nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2, padding=0, groups=in_channels, bias=False), 49 | nn.BatchNorm2d(in_channels), 50 | nn.ReLU(True), 51 | ) 52 | us_layers.append(us_layer) 53 | self.downsample = nn.Sequential(*ds_layers) 54 | self.upsample = nn.Sequential(*us_layers) 55 | self.spatial_height = in_height // factor 56 | self.spatial_width = in_width // factor 57 | self.spatial_area = self.spatial_height * self.spatial_width 58 | 59 | self.fc_in = nn.Sequential( 60 | nn.Conv1d(self.spatial_area, self.spatial_area // reduction, 1, 1, 0, bias=True), 61 | # nn.BatchNorm1d(self.spatial_area // reduction), 62 | nn.ReLU(True), 63 | nn.Conv1d(self.spatial_area // reduction, self.spatial_area, 1, 1, 0, bias=True), 64 | ) 65 | 66 | self.fc_out = nn.Sequential( 67 | nn.Conv1d(self.spatial_area, self.spatial_area // reduction, 1, 1, 0, bias=True), 68 | # nn.BatchNorm1d(self.spatial_area // reduction), 69 | nn.ReLU(True), 70 | nn.Conv1d(self.spatial_area // reduction, self.spatial_area, 1, 1, 0, bias=True), 71 | # nn.BatchNorm1d(self.spatial_area) 72 | ) 73 | 74 | self.ln = nn.LayerNorm(in_channels) 75 | 76 | self.initialize() 77 | 78 | def initialize(self): 79 | for m in self.modules(): 80 | if isinstance(m, nn.Conv1d): 81 | nn.init.kaiming_normal_(m.weight) 82 | elif isinstance(m, nn.BatchNorm1d): 83 | nn.init.constant_(m.weight, 1) 84 | nn.init.constant_(m.bias, 0) 85 | 86 | # def _compute_correlation(self, x): 87 | # bt, c, h, w = x.shape 88 | # x_v = x.view(bt, c, h*w).detach() 89 | # x_v_mean = x_v.mean(1).unsqueeze(1) 90 | # x_v_cent = x_v - x_v_mean # bt x c x (hw) 91 | # # x_v_cent = x_v 92 | # x_v_cent = x_v_cent / (torch.norm(x_v_cent, 2, 2).unsqueeze(2) + 1e-5) 93 | # correlations = torch.bmm(x_v_cent, x_v_cent.permute(0, 2, 1).contiguous()) # btxcxc 94 | # diags = 1 - torch.eye(c).unsqueeze(0).type_as(correlations) 95 | # correlations = correlations * diags 96 | # return torch.abs(correlations).mean(0).sum() / c / (c - 1) 97 | 98 | def _compute_correlation(self, x): 99 | b, c, h, w = x.shape 100 | x_v = x.clone().detach().view(b, c, -1) # b x c x (hw) 101 | x_m = x_v.mean(1).unsqueeze(1) # b x 1 x (hw) 102 | x_c = x_v - x_m # b x c x (hw) 103 | num = torch.bmm(x_c, x_c.transpose(1, 2)) # b x c x c 104 | x_mode = torch.sqrt(torch.sum(x_c ** 2, 2).unsqueeze(2)) # b x c x 1 105 | dec = torch.bmm(x_mode, x_mode.transpose(1, 2).contiguous()) # b x c x c 106 | out = num / dec # b x c x c 107 | # diags = 1 - torch.eye(c).unsqueeze(0).type_as(out) 108 | # out = out * diags 109 | out = torch.abs(out) # .mean(0).sum() / c / (c - 1) 110 | return out.mean() 111 | 112 | def forward(self, x): 113 | ''' 114 | :param x: (bt, c, h, w) 115 | :return: 116 | ''' 117 | bt, c, h, w = x.shape 118 | residual = x 119 | x_stretch = x.view(bt, c, h * w) 120 | self.corr_bf = self._compute_correlation(x) 121 | # self.corr_af = self._compute_correlation(x) 122 | # return x 123 | 124 | if self.spatial_height == h and self.spatial_width == w: 125 | x_stacked = x_stretch # (b) x c x (h * w) 126 | x_stacked = x_stacked.view(bt * self.nblocks_channel, c // self.nblocks_channel, -1) 127 | x_v = 
x_stacked.permute(0, 2, 1).contiguous() # (b) x (h * w) x c 128 | x_v = self.fc_in(x_v) # (b) x (h * w) x c 129 | x_m = x_v.mean(1).view(-1, 1, c // self.nblocks_channel) # (b * h * w) x 1 x c 130 | # x_m = x_m - x_m.mean(2).unsqueeze(2) 131 | # x_m = self.ln(x_m) 132 | # x_m = x_m.detach() 133 | score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # (b * h * w) x c x c 134 | # score = -score / score.sum(2).unsqueeze(2) 135 | # score = torch.bmm(x_v.transpose(1, 2).contiguous(), x_v) 136 | # x_v = F.dropout(x_v, 0.1, self.training) 137 | # score.masked_fill_(self.mask.unsqueeze(0).expand_as(score).type_as(score).eq(0), -np.inf) 138 | attn = F.softmax(score, dim=1) # (b * h * w) x c x c 139 | out = self.fc_out(torch.bmm(x_v, attn)) # (b) x (h * w) x c 140 | out = F.dropout(out.permute(0, 2, 1).contiguous().view(bt, c, h, w), 0.0, self.training) 141 | out = F.relu(residual + out) 142 | self.corr_af = self._compute_correlation(out) 143 | return out 144 | else: 145 | # x = self.downsample(x) 146 | x = F.interpolate(x, (self.spatial_height, self.spatial_width)) 147 | x_stretch = x.view(bt, c, self.spatial_height * self.spatial_width) 148 | x_stretch = x.view(bt * self.nblocks_channel, c // self.nblocks_channel, self.spatial_height * self.spatial_width) 149 | 150 | x_stacked = x_stretch # (b) x c x (h * w) 151 | x_v = x_stacked.permute(0, 2, 1).contiguous() # (b) x (h * w) x c 152 | x_v = self.fc_in(x_v) # (b) x (h * w) x c 153 | x_m = x_v.mean(1).view(-1, 1, c // self.nblocks_channel) # b x 1 x c 154 | # x_m = x_m - x_m.mean(2).unsqueeze(2) 155 | # x_m = self.ln(x_m) 156 | # x_m = x_m.detach() 157 | score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # b x c x c 158 | # score = -score / score.sum(2).unsqueeze(2) 159 | # score = torch.bmm(x_v.transpose(1, 2).contiguous(), x_v) 160 | # x_v = F.dropout(x_v, 0.1, self.training) 161 | # score.masked_fill_(self.mask.unsqueeze(0).expand_as(score).type_as(score).eq(0), -np.inf) 162 | attn = F.softmax(score, dim=1) # (b * h * w) x c x c 163 | out = self.fc_out(torch.bmm(x_v, attn)) # (b) x (h * w) x c 164 | out = out.permute(0, 2, 1).contiguous().view(bt, c, self.spatial_height, self.spatial_width) 165 | # out = F.dropout(self.upsample(out), 0.0, self.training) 166 | out = F.dropout(F.interpolate(out, (h, w)), 0.0, self.training) 167 | out = F.relu(residual + out) 168 | self.corr_af = self._compute_correlation(out) 169 | return out 170 | 171 | # class _CrossNeuronBlock(nn.Module): 172 | # def __init__(self, in_channels, in_height, in_width, 173 | # nblocks_channel=8, 174 | # spatial_height=32, spatial_width=32, 175 | # reduction=4, size_is_consistant=True): 176 | # # nblock_channel: number of block along channel axis 177 | # # spatial_size: spatial_size 178 | # super(_CrossNeuronBlock, self).__init__() 179 | # 180 | # # set channel splits 181 | # if in_channels <= 512: 182 | # self.nblocks_channel = 1 183 | # else: 184 | # self.nblocks_channel = in_channels // 512 185 | # block_size = in_channels // self.nblocks_channel 186 | # block = torch.Tensor(block_size, block_size).fill_(1) 187 | # self.mask = torch.Tensor(in_channels, in_channels).fill_(0) 188 | # for i in range(self.nblocks_channel): 189 | # self.mask[i * block_size:(i + 1) * block_size, i * block_size:(i + 1) * block_size].copy_(block) 190 | # 191 | # # set spatial splits 192 | # if in_height * in_width < 32 * 32 and size_is_consistant: 193 | # self.spatial_area = in_height * in_width 194 | # self.spatial_height = in_height 195 | # self.spatial_width = in_width 196 | # else: 197 | # 
self.spatial_area = spatial_height * spatial_width 198 | # self.spatial_height = spatial_height 199 | # self.spatial_width = spatial_width 200 | # 201 | # self.conv_in = nn.Sequential( 202 | # nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False), 203 | # nn.BatchNorm2d(in_channels), 204 | # nn.ReLU(True), 205 | # nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False), 206 | # nn.BatchNorm2d(in_channels), 207 | # nn.ReLU(True), 208 | # ) 209 | # 210 | # self.conv_out = nn.Sequential( 211 | # nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False), 212 | # nn.BatchNorm2d(in_channels), 213 | # nn.ReLU(True), 214 | # nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False), 215 | # nn.BatchNorm2d(in_channels), 216 | # ) 217 | # 218 | # # self.bn = nn.BatchNorm1d(self.spatial_area) 219 | # 220 | # self.initialize() 221 | # 222 | # def initialize(self): 223 | # for m in self.modules(): 224 | # if isinstance(m, nn.Conv2d): 225 | # nn.init.kaiming_normal_(m.weight) 226 | # elif isinstance(m, nn.BatchNorm2d): 227 | # nn.init.constant_(m.weight, 1) 228 | # nn.init.constant_(m.bias, 0) 229 | # 230 | # def forward(self, x): 231 | # ''' 232 | # :param x: (bt, c, h, w) 233 | # :return: 234 | # ''' 235 | # # import pdb; pdb.set_trace() 236 | # bt, c, h, w = x.shape 237 | # residual = x 238 | # x_v = self.conv_in(x) # b x c x h x w 239 | # x_m = x_v.mean(3).mean(2).unsqueeze(2) # bt x c x 1 240 | # 241 | # score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # bt x c x c 242 | # # score = torch.bmm(x_v.transpose(1, 2).contiguous(), x_v) 243 | # # x_v = F.dropout(x_v, 0.1, self.training) 244 | # # score.masked_fill_(self.mask.unsqueeze(0).expand_as(score).type_as(score).eq(0), -np.inf) 245 | # attn = F.softmax(score, dim=2) # bt x c x c 246 | # out = self.conv_out(torch.bmm(attn, x_v.view(bt, c, h * w)).view(bt, c, h, w)) 247 | # return F.relu(residual + out) 248 | 249 | class CrossNeuronlBlock2D(_CrossNeuronBlock): 250 | def __init__(self, in_channels, in_height, in_width, spatial_height, spatial_width, reduction=8, size_is_consistant=True): 251 | super(CrossNeuronlBlock2D, self).__init__(in_channels, in_height, in_width, 252 | nblocks_channel=4, 253 | spatial_height=spatial_height, 254 | spatial_width=spatial_width, 255 | reduction=reduction, 256 | size_is_consistant=size_is_consistant) 257 | 258 | 259 | class CrossNeuronWrapper(nn.Module): 260 | def __init__(self, block, in_channels, in_height, in_width, spatial_height, spatial_width, reduction=8): 261 | super(CrossNeuronWrapper, self).__init__() 262 | self.block = block 263 | self.cn = CrossNeuronlBlock2D(in_channels, in_height, in_width, spatial_height, spatial_width, reduction=reduction) 264 | 265 | # self.conv = nn.Sequential( 266 | # nn.Conv2d(in_channels, 4 * in_channels, 3, 1, 1), 267 | # nn.ReLU(True), 268 | # nn.Conv2d(4 * in_channels, in_channels, 3, 1, 1), 269 | # # nn.ReLU(True), 270 | # # nn.Conv2d(in_channels, in_channels, 3, 1, 1), 271 | # # nn.ReLU(True), 272 | # # nn.Conv2d(in_channels, in_channels, 3, 1, 1), 273 | # # nn.ReLU(True), 274 | # # nn.Conv2d(in_channels, in_channels, 3, 1, 1), 275 | # # nn.ReLU(True), 276 | # # nn.Conv2d(in_channels, in_channels, 3, 1, 1) 277 | # ) 278 | 279 | def forward(self, x): 280 | x = self.cn(x) 281 | x = self.block(x) 282 | return x 283 | 284 | def add_cross_neuron(net, img_height, img_width, spatial_height, 
spatial_width, reduction=[4,4,4,4]): 285 | import torchvision 286 | import lib.networks as archs 287 | 288 | import pdb; pdb.set_trace() 289 | 290 | if isinstance(net, torchvision.models.ResNet): 291 | dummy_img = torch.randn(1, 3, img_height, img_width) 292 | out = net.conv1(dummy_img) 293 | out = net.relu(net.bn1(out)) 294 | out0 = net.maxpool(out) 295 | print("layer0 out shape: {}x{}x{}x{}".format(out0.shape[0], out0.shape[1], out0.shape[2], out0.shape[3])) 296 | out1 = net.layer1(out0) 297 | print("layer1 out shape: {}x{}x{}x{}".format(out1.shape[0], out1.shape[1], out1.shape[2], out1.shape[3])) 298 | out2 = net.layer2(out1) 299 | print("layer2 out shape: {}x{}x{}x{}".format(out2.shape[0], out2.shape[1], out2.shape[2], out2.shape[3])) 300 | out3 = net.layer3(out2) 301 | print("layer3 out shape: {}x{}x{}x{}".format(out3.shape[0], out3.shape[1], out3.shape[2], out3.shape[3])) 302 | out4 = net.layer4(out3) 303 | print("layer4 out shape: {}x{}x{}x{}".format(out4.shape[0], out4.shape[1], out4.shape[2], out4.shape[3])) 304 | 305 | 306 | layers = [] 307 | l = len(net.layer1) 308 | for i in range(l): 309 | if i == 0: 310 | layers.append(CrossNeuronWrapper(net.layer1[i], out0.shape[1], out0.shape[2], out0.shape[3], 311 | out0.shape[2], out0.shape[3], 4)) 312 | elif i % 4 == 0: 313 | layers.append(CrossNeuronWrapper(net.layer1[i], out1.shape[1], out1.shape[2], out1.shape[3], 314 | out1.shape[2], out1.shape[3], 4)) 315 | else: 316 | layers.append(net.layer1[i]) 317 | net.layer1 = nn.Sequential(*layers) 318 | 319 | layers = [] 320 | l = len(net.layer2) 321 | for i in range(l): 322 | if i == 0: 323 | layers.append(CrossNeuronWrapper(net.layer2[i], out1.shape[1], out1.shape[2], out1.shape[3], 324 | out1.shape[2], out1.shape[3], 4)) 325 | elif i % 4 == 0: 326 | layers.append(CrossNeuronWrapper(net.layer2[i], out2.shape[1], out2.shape[2], out2.shape[3], 327 | out2.shape[2], out2.shape[3], 4)) 328 | else: 329 | layers.append(net.layer2[i]) 330 | net.layer2 = nn.Sequential(*layers) 331 | 332 | # 333 | layers = [] 334 | l = len(net.layer3) 335 | for i in range(0, l): 336 | if i == 0: 337 | layers.append(CrossNeuronWrapper(net.layer3[i], out2.shape[1], out2.shape[2], out2.shape[3], 338 | out2.shape[2], out2.shape[3], 4)) 339 | elif i % 4 == 0: 340 | layers.append(CrossNeuronWrapper(net.layer3[i], out3.shape[1], out3.shape[2], out3.shape[3], 341 | out3.shape[2], out3.shape[3], 4)) 342 | else: 343 | layers.append(net.layer3[i]) 344 | net.layer3 = nn.Sequential(*layers) 345 | 346 | layers = [] 347 | l = len(net.layer4) 348 | for i in range(0, l): 349 | if i == 0: 350 | layers.append(CrossNeuronWrapper(net.layer4[i], out3.shape[1], out3.shape[2], out3.shape[3], 351 | out3.shape[2], out3.shape[3], 4)) 352 | else: 353 | layers.append(net.layer4[i]) 354 | net.layer4 = nn.Sequential(*layers) 355 | 356 | else: 357 | dummy_img = torch.randn(1, 3, img_height, img_width) 358 | out = net.conv1(dummy_img) 359 | out0 = net.relu(net.bn1(out)) 360 | out1 = net.layer1(out0) 361 | out2 = net.layer2(out1) 362 | out3 = net.layer3(out2) 363 | # 364 | net.layer1 = CrossNeuronWrapper(net.layer1, out0.shape[1], out0.shape[2], out0.shape[3], spatial_height[0], spatial_width[0]) 365 | net.layer2 = CrossNeuronWrapper(net.layer2, out1.shape[1], out1.shape[2], out1.shape[3], spatial_height[1], spatial_width[1]) 366 | net.layer3 = CrossNeuronWrapper(net.layer3, out2.shape[1], out2.shape[2], out2.shape[3], spatial_height[2], spatial_width[2]) 367 | 368 | ''' 369 | layers = [] 370 | l = len(net.layer1) 371 | for i in range(l): 372 | if i 
== 0: 373 | layers.append(CrossNeuronWrapper(net.layer1[i], out0.shape[1], out0.shape[2], out0.shape[3], out0.shape[2], out0.shape[3])) 374 | elif i in [4, 7]: # resnet56: [4, 7] 375 | layers.append(CrossNeuronWrapper(net.layer1[i], out1.shape[1], out1.shape[2], out1.shape[3], out1.shape[2], out1.shape[3])) 376 | else: 377 | layers.append(net.layer1[i]) 378 | net.layer1 = nn.Sequential(*layers) 379 | # 380 | layers = [] 381 | l = len(net.layer2) 382 | for i in range(l): 383 | if i in [0]: 384 | layers.append(CrossNeuronWrapper(net.layer2[i], out1.shape[1], out1.shape[2], out1.shape[3], out1.shape[2], out1.shape[3])) 385 | elif i in [4, 7]: 386 | layers.append(CrossNeuronWrapper(net.layer2[i], out2.shape[1], out2.shape[2], out2.shape[3], out2.shape[2], out2.shape[3])) 387 | else: 388 | layers.append(net.layer2[i]) 389 | net.layer2 = nn.Sequential(*layers) 390 | # 391 | layers = [] 392 | l = len(net.layer3) 393 | for i in range(0, l): 394 | if i in [0]: 395 | layers.append(CrossNeuronWrapper(net.layer3[i], out2.shape[1], out2.shape[2], out2.shape[3], out2.shape[2], out2.shape[3])) 396 | else: 397 | layers.append(net.layer3[i]) 398 | net.layer3 = nn.Sequential(*layers) 399 | # 400 | ''' 401 | # else: 402 | # raise NotImplementedError 403 | 404 | if __name__ == '__main__': 405 | from torch.autograd import Variable 406 | import torch 407 | 408 | sub_sample = True 409 | bn_layer = True 410 | 411 | img = torch.randn(2, 3, 10, 20, 20) 412 | net = CrossNeuronlBlock3D(3, 20 * 20) 413 | out = net(img) 414 | print(out.size()) 415 | -------------------------------------------------------------------------------- /lib/layers/data_parallel.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.nn.parallel import DataParallel 3 | import torch 4 | from torch.nn.parallel._functions import Scatter 5 | from torch.nn.parallel.parallel_apply import parallel_apply 6 | 7 | def scatter(inputs, target_gpus, chunk_sizes, dim=0): 8 | r""" 9 | Slices tensors into approximately equal chunks and 10 | distributes them across given GPUs. Duplicates 11 | references to objects that are not tensors. 12 | """ 13 | def scatter_map(obj): 14 | if isinstance(obj, torch.Tensor): 15 | try: 16 | return Scatter.apply(target_gpus, chunk_sizes, dim, obj) 17 | except: 18 | print('obj', obj.size()) 19 | print('dim', dim) 20 | print('chunk_sizes', chunk_sizes) 21 | quit() 22 | if isinstance(obj, tuple) and len(obj) > 0: 23 | return list(zip(*map(scatter_map, obj))) 24 | if isinstance(obj, list) and len(obj) > 0: 25 | return list(map(list, zip(*map(scatter_map, obj)))) 26 | if isinstance(obj, dict) and len(obj) > 0: 27 | return list(map(type(obj), zip(*map(scatter_map, obj.items())))) 28 | return [obj for targets in target_gpus] 29 | 30 | # After scatter_map is called, a scatter_map cell will exist. This cell 31 | # has a reference to the actual function scatter_map, which has references 32 | # to a closure that has a reference to the scatter_map cell (because the 33 | # fn is recursive). 
To avoid this reference cycle, we set the function to 34 | # None, clearing the cell 35 | try: 36 | return scatter_map(inputs) 37 | finally: 38 | scatter_map = None 39 | 40 | def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0): 41 | r"""Scatter with support for kwargs dictionary""" 42 | inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else [] 43 | kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else [] 44 | if len(inputs) < len(kwargs): 45 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 46 | elif len(kwargs) < len(inputs): 47 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 48 | inputs = tuple(inputs) 49 | kwargs = tuple(kwargs) 50 | return inputs, kwargs 51 | 52 | class BalancedDataParallel(DataParallel): 53 | def __init__(self, gpu0_bsz, *args, **kwargs): 54 | self.gpu0_bsz = gpu0_bsz 55 | super().__init__(*args, **kwargs) 56 | 57 | def forward(self, *inputs, **kwargs): 58 | if not self.device_ids: 59 | return self.module(*inputs, **kwargs) 60 | if self.gpu0_bsz == 0: 61 | device_ids = self.device_ids[1:] 62 | else: 63 | device_ids = self.device_ids 64 | inputs, kwargs = self.scatter(inputs, kwargs, device_ids) 65 | if len(self.device_ids) == 1: 66 | return self.module(*inputs[0], **kwargs[0]) 67 | replicas = self.replicate(self.module, self.device_ids) 68 | if self.gpu0_bsz == 0: 69 | replicas = replicas[1:] 70 | outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs) 71 | return self.gather(outputs, self.output_device) 72 | 73 | def parallel_apply(self, replicas, device_ids, inputs, kwargs): 74 | return parallel_apply(replicas, inputs, kwargs, device_ids) 75 | 76 | def scatter(self, inputs, kwargs, device_ids): 77 | bsz = inputs[0].size(self.dim) 78 | num_dev = len(self.device_ids) 79 | gpu0_bsz = self.gpu0_bsz 80 | bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1) 81 | if gpu0_bsz < bsz_unit: 82 | chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1) 83 | delta = bsz - sum(chunk_sizes) 84 | for i in range(delta): 85 | chunk_sizes[i + 1] += 1 86 | if gpu0_bsz == 0: 87 | chunk_sizes = chunk_sizes[1:] 88 | else: 89 | return super().scatter(inputs, kwargs, device_ids) 90 | return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim) 91 | -------------------------------------------------------------------------------- /lib/layers/selayer.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class SELayer(nn.Module): 5 | def __init__(self, channel, reduction=8): 6 | super(SELayer, self).__init__() 7 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 8 | self.fc = nn.Sequential( 9 | nn.Linear(channel, int(channel / reduction), bias=False), 10 | nn.ReLU(inplace=True), 11 | nn.Linear(int(channel / reduction), channel, bias=False), 12 | nn.Sigmoid() 13 | ) 14 | 15 | def forward(self, x): 16 | b, c, _, _ = x.size() 17 | y = self.avg_pool(x).view(b, c) 18 | y = self.fc(y).view(b, c, 1, 1) 19 | return x * y.expand_as(x) 20 | -------------------------------------------------------------------------------- /lib/model.py: -------------------------------------------------------------------------------- 1 | # Code for Cross-Neuron Communication Network 2 | # Author: Jianwei Yang (jw2yang@gatech.edu) 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torchvision.models as models 7 | 8 | from .networks import * 9 | from .layers import * 10 | from .layers.cross_neuron_distributed import _CrossNeuronBlock 11 | 12 | class 
AlexNet(nn.Module): 13 | def __init__(self, num_classes=100, has_gtlayer=False): 14 | super(AlexNet, self).__init__() 15 | 16 | if has_gtlayer: 17 | self.features = nn.Sequential( 18 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5), 19 | nn.ReLU(inplace=True), 20 | _CrossNeuronBlock(64, 8, 8, spatial_height=8, spatial_width=8, reduction=2), 21 | nn.MaxPool2d(kernel_size=2, stride=2), 22 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 23 | nn.ReLU(inplace=True), 24 | nn.MaxPool2d(kernel_size=2, stride=2), 25 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 26 | nn.ReLU(inplace=True), 27 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 28 | nn.ReLU(inplace=True), 29 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 30 | nn.ReLU(inplace=True), 31 | nn.MaxPool2d(kernel_size=2, stride=2), 32 | ) 33 | else: 34 | self.features = nn.Sequential( 35 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5), 36 | nn.ReLU(inplace=True), 37 | nn.MaxPool2d(kernel_size=2, stride=2), 38 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 39 | nn.ReLU(inplace=True), 40 | nn.MaxPool2d(kernel_size=2, stride=2), 41 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 42 | nn.ReLU(inplace=True), 43 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 44 | nn.ReLU(inplace=True), 45 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 46 | nn.ReLU(inplace=True), 47 | nn.MaxPool2d(kernel_size=2, stride=2), 48 | ) 49 | self.has_gtlayer = has_gtlayer 50 | 51 | # if has_gtlayer: 52 | # self.gtlayer = _CrossNeuronBlock(256, 16, 16, 16, 16) 53 | self.classifier = nn.Linear(256, num_classes) 54 | 55 | def forward(self, x): 56 | x = self.features(x) 57 | x = x.view(x.size(0), -1) 58 | x = self.classifier(x) 59 | return x 60 | 61 | class CrossNeuronNet(nn.Module): 62 | def __init__(self, opts): 63 | super(CrossNeuronNet, self).__init__() 64 | if opts.dataset == "imagenet": # we directly use pytorch version arch 65 | if opts.arch in models.__dict__: 66 | self.net = models.__dict__[opts.arch](pretrained=False) 67 | if opts.arch == "vgg16": 68 | for i in range(len(self.net.features)): 69 | if self.net.features[i].__class__.__name__ == "Conv2d": 70 | channels = self.net.features[i].out_channels 71 | self.net.features[i] = nn.Sequential( 72 | self.net.features[i], 73 | nn.BatchNorm2d(channels) 74 | ) 75 | elif opts.arch in imagenet_models: 76 | self.net = imagenet_models[opts.arch](num_classes=1000) 77 | else: 78 | raise ValueError("Unknow network architecture for imagenet, please refer to: https://pytorch.org/docs/0.4.0/torchvision/models.html?") 79 | # if opts.add_cross_neuron: 80 | # add_cross_neuron(self.net, opts.data.input_size, opts.data.input_size, [32, 24, 24, 16], [32, 24, 24, 16], reduction = [4, 4, 4, 2]) 81 | # add_cross_neuron(self.net, opts.data.input_size, opts.data.input_size, [32, 24, 24, 16], [32, 24, 24, 16]) 82 | 83 | elif opts.dataset == "cub2011": 84 | if opts.arch in models.__dict__: 85 | self.net = models.__dict__[opts.arch](pretrained=True) 86 | self.net.fc = nn.Linear(self.net.fc.in_features, 200, bias=True) 87 | # NOTE: replace the last fc layer 88 | else: 89 | raise ValueError("Unknow network architecture for cub2011, please refer to: https://pytorch.org/docs/0.4.0/torchvision/models.html?") 90 | if opts.add_cross_neuron: 91 | add_cross_neuron(self.net, opts.data.input_size, opts.data.input_size, [32, 32, 32, 16], [32, 32, 32, 16], 8) 92 | elif opts.dataset == "cifar10": # we use the arch following He's paper (deep residual learning) 93 | if opts.arch in cifar_models: 94 | self.net = 
cifar_models[opts.arch](num_classes=10, has_selayer=("se" in opts.arch)) 95 | else: 96 | raise ValueError("Unknow network architecture for imagenet") 97 | if opts.add_cross_neuron: 98 | add_cross_neuron(self.net, opts.data.input_size, opts.data.input_size, [32, 24, 24, 16], [32, 24, 24, 16]) 99 | elif opts.dataset == "cifar100": # we use the arch following He's paper (deep residual learning) 100 | if opts.arch in cifar_models: 101 | self.net = cifar_models[opts.arch](num_classes=100, has_selayer=("se" in opts.arch), has_gtlayer=False) 102 | elif opts.arch == "alexnet": 103 | self.net = AlexNet(has_gtlayer=opts.add_cross_neuron) 104 | else: 105 | raise ValueError("Unknow network architecture for cifar100") 106 | self.net.name = "cnn" 107 | if opts.add_cross_neuron and not opts.arch == "alexnet": 108 | add_cross_neuron(self.net, opts.data.input_size, opts.data.input_size, [16, 16, 8, 8], [16, 16, 8, 8]) 109 | else: 110 | raise ValueError("Unknow dataset, we only support cifar and imagenet for now") 111 | 112 | print(self.net) 113 | 114 | def get_optim_policies(self): 115 | resnet_param = [] 116 | ccn_param = [] 117 | for m in self.modules(): 118 | if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.BatchNorm1d): 119 | ccn_param = ccn_param + list(m.parameters()) 120 | else: 121 | resnet_param = resnet_param + list(m.parameters()) 122 | 123 | return [{"params": resnet_param, 'lr_mult': 1}, {"params": ccn_param, 'lr_mult': 10}] 124 | 125 | def forward(self, x): 126 | # x = self.net.conv1(x) 127 | # x = self.net.bn1(x) 128 | # x = self.net.relu(x) 129 | # # x = self.net.maxpool(x) 130 | # x0 = x.clone().detach() 131 | # 132 | # x = self.net.layer1(x) 133 | # x1 = x.clone().detach() 134 | # 135 | # x = self.net.layer2(x) 136 | # x2 = x.clone().detach() 137 | # 138 | # x = self.net.layer3(x) 139 | # x3 = x.clone().detach() 140 | 141 | # x = self.net.layer4(x) 142 | # x4 = x.clone().detach() 143 | 144 | # out = self.net.avgpool(x).squeeze(3).squeeze(2) 145 | # out = self.net.fc(out) 146 | out = self.net(x) 147 | return out 148 | -------------------------------------------------------------------------------- /lib/model_analysis.py: -------------------------------------------------------------------------------- 1 | # Code for Cross-Neuron Communication Network 2 | # Author: Jianwei Yang (jw2yang@gatech.edu) 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torchvision.models as models 7 | 8 | from .networks import * 9 | from .layers import * 10 | 11 | class CrossNeuronNetAnalysis(nn.Module): 12 | def __init__(self, opts): 13 | super(CrossNeuronNetAnalysis, self).__init__() 14 | if opts.dataset == "imagenet": # we directly use pytorch version arch 15 | import pdb; pdb.set_trace() 16 | if opts.arch in models.__dict__: 17 | self.net = models.__dict__[opts.arch](pretrained=False) 18 | elif opts.arch in imagenet_models: 19 | self.net = imagenet_models[opts.arch](num_classes=1000) 20 | else: 21 | raise ValueError("Unknow network architecture for imagenet, please refer to: https://pytorch.org/docs/0.4.0/torchvision/models.html?") 22 | if opts.add_cross_neuron: 23 | add_cross_neuron(self.net, opts.data.input_size, opts.data.input_size, [32, 24, 24, 16], [32, 24, 24, 16], reduction = [4, 4, 4, 2]) 24 | elif opts.dataset == "cub2011": 25 | if opts.arch in models.__dict__: 26 | self.net = models.__dict__[opts.arch](pretrained=True) 27 | self.net.fc = nn.Linear(self.net.fc.in_features, 200, bias=True) 28 | # NOTE: replace the last fc layer 29 | else: 30 | raise ValueError("Unknow 
-------------------------------------------------------------------------------- /lib/model_analysis.py: -------------------------------------------------------------------------------- 1 | # Code for Cross-Neuron Communication Network 2 | # Author: Jianwei Yang (jw2yang@gatech.edu) 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torchvision.models as models 7 | 8 | from .networks import * 9 | from .layers import * 10 | 11 | class CrossNeuronNetAnalysis(nn.Module): 12 | def __init__(self, opts): 13 | super(CrossNeuronNetAnalysis, self).__init__() 14 | if opts.dataset == "imagenet": # we directly use the pytorch version of the arch 15 | # import pdb; pdb.set_trace()  # debugging breakpoint (disabled so the imagenet path can run) 16 | if opts.arch in models.__dict__: 17 | self.net = models.__dict__[opts.arch](pretrained=False) 18 | elif opts.arch in imagenet_models: 19 | self.net = imagenet_models[opts.arch](num_classes=1000) 20 | else: 21 | raise ValueError("Unknown network architecture for imagenet, please refer to: https://pytorch.org/docs/0.4.0/torchvision/models.html?") 22 | if opts.add_cross_neuron: 23 | add_cross_neuron(self.net, opts.data.input_size, opts.data.input_size, [32, 24, 24, 16], [32, 24, 24, 16], reduction = [4, 4, 4, 2]) 24 | elif opts.dataset == "cub2011": 25 | if opts.arch in models.__dict__: 26 | self.net = models.__dict__[opts.arch](pretrained=True) 27 | self.net.fc = nn.Linear(self.net.fc.in_features, 200, bias=True) 28 | # NOTE: replace the last fc layer 29 | else: 30 | raise ValueError("Unknown network architecture for cub2011, please refer to: https://pytorch.org/docs/0.4.0/torchvision/models.html?") 31 | if opts.add_cross_neuron: 32 | add_cross_neuron(self.net, opts.data.input_size, opts.data.input_size, [32, 32, 32, 16], [32, 32, 32, 16], 8) 33 | elif opts.dataset == "cifar10": # we use the arch following He's paper (deep residual learning) 34 | if opts.arch in cifar_models: 35 | self.net = cifar_models[opts.arch](num_classes=10, has_selayer=("se" in opts.arch)) 36 | else: 37 | raise ValueError("Unknown network architecture for cifar10") 38 | if opts.add_cross_neuron: 39 | add_cross_neuron(self.net, opts.data.input_size, opts.data.input_size, [32, 24, 24, 16], [32, 24, 24, 16]) 40 | elif opts.dataset == "cifar100": # we use the arch following He's paper (deep residual learning) 41 | if opts.arch in cifar_models: 42 | self.net = cifar_models[opts.arch](num_classes=100, insert_layers=opts.layers, depth=opts.depth, 43 | has_selayer=("se" in opts.arch), has_gtlayer=opts.add_cross_neuron, communication=opts.communication) 44 | else: 45 | raise ValueError("Unknown network architecture for cifar100") 46 | # if opts.add_cross_neuron: 47 | # add_cross_neuron(self.net, opts.data.input_size, opts.data.input_size, [32, 24, 24, 16], [32, 24, 24, 16]) 48 | else: 49 | raise ValueError("Unknown dataset; we only support cifar10/cifar100, cub2011 and imagenet for now") 50 | 51 | print(self.net) 52 | 53 | def get_optim_policies(self): 54 | # as in lib/model.py: collect the cross-neuron (Conv1d/BatchNorm1d) parameters 55 | # once, then take the complement so no parameter lands in two groups 56 | ccn_param = [p for m in self.modules() 57 | if isinstance(m, (torch.nn.Conv1d, torch.nn.BatchNorm1d)) 58 | for p in m.parameters()] 59 | ccn_ids = {id(p) for p in ccn_param} 60 | resnet_param = [p for p in self.parameters() if id(p) not in ccn_ids] 61 | 62 | return [{"params": resnet_param, 'lr_mult': 1}, {"params": ccn_param, 'lr_mult': 10}] 63 | 64 | def forward(self, x): 65 | out = self.net(x) 66 | return out 67 | -------------------------------------------------------------------------------- /lib/networks/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwyang/C3Net.pytorch/70026fc80c5427484268c428a9dcd4cde2e8197f/lib/networks/.DS_Store -------------------------------------------------------------------------------- /lib/networks/._.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwyang/C3Net.pytorch/70026fc80c5427484268c428a9dcd4cde2e8197f/lib/networks/._.DS_Store -------------------------------------------------------------------------------- /lib/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet_cifar import * 2 | from .resnet_cifar_analysis import * 3 | from .resnext_cifar import * 4 | from .wide_resnet_cifar import * 5 | from .senet_pytorch import * 6 | from .ncnet_pytorch import * 7 | from .invcnn_pytorch import * 8 | from .resnext import * 9 | from .mobilenet_v2 import * 10 | 11 | cifar_models = {"resnet20": resnet20a_cifar, 12 | "resnet56": resnet56a_cifar, 13 | "resnet110": resnet110a_cifar, 14 | "resnet164": resnet164_cifar, 15 | 16 | "resnet20a": resnet20a_cifar, 17 | "resnet56a": resnet56a_cifar, 18 | "resnet62a": resnet62a_cifar, 19 | "resnet68a": resnet68a_cifar, 20 | "resnet74a": resnet74a_cifar, 21 | "resnet80a": resnet80a_cifar, 22 | "resnet86a": resnet86a_cifar, 23 | "resnet92a": resnet92a_cifar, 24 | "resnet98a": resnet98a_cifar, 25 | "resnet104a": resnet104a_cifar, 26 | "resnet110a": resnet110a_cifar, 27 | 28 | 
"seresnet20": resnet20_cifar, 29 | "seresnet56": resnet56_cifar, 30 | "seresnet110": resnet110_cifar, 31 | "seresnet164": resnet164_cifar, 32 | 33 | "seresnet20a": resnet20a_cifar, 34 | "seresnet56a": resnet56a_cifar, 35 | "seresnet62a": resnet62a_cifar, 36 | "seresnet68a": resnet68a_cifar, 37 | "seresnet74a": resnet74a_cifar, 38 | "seresnet80a": resnet80a_cifar, 39 | "seresnet86a": resnet86a_cifar, 40 | "seresnet92a": resnet92a_cifar, 41 | "seresnet98a": resnet98a_cifar, 42 | "seresnet104a": resnet104a_cifar, 43 | "seresnet110a": resnet110a_cifar, 44 | 45 | "plainnet20": resnet20plain_cifar, 46 | "plainnet110": resnet110plain_cifar, 47 | 48 | "seplainnet20": resnet20plain_cifar, 49 | "seplainnet110": resnet110plain_cifar, 50 | 51 | "wresnet20": wresnet20_cifar, 52 | "sewresnet20": wresnet20_cifar, 53 | 54 | "resnext110": resneXt110_cifar, 55 | "seresnext110": resneXt110_cifar, 56 | 57 | "invcnn4": inv_cnn_4, 58 | } 59 | 60 | imagenet_models = {"seresnet18": se_resnet18, 61 | "seresnet50": se_resnet50, 62 | "seresnet101": se_resnet101, 63 | "ncresnet101": nc_resnet101, 64 | "ncvgg16": nc_vgg16, 65 | "resnext50": resnext50, 66 | "seresnext50": seresnext50, 67 | "ncresnext50": ncresnext50, 68 | "mobilenetv2": MobileNetV2, 69 | "semobilenetv2": SEMobileNetV2, 70 | "ncmobilenetv2": NCMobileNetV2, 71 | "sedensenet121": se_densenet121, 72 | } 73 | -------------------------------------------------------------------------------- /lib/networks/densenet_cifar.py: -------------------------------------------------------------------------------- 1 | """ 2 | DenseNet for cifar with pytorch 3 | 4 | Reference: 5 | [1] H. Gao, Z. Liu, L. Maaten and K. Weinberger. Densely connected convolutional networks. In CVPR, 2017 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from collections import OrderedDict 12 | 13 | import math 14 | 15 | class _DenseLayer(nn.Sequential): 16 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 17 | super(_DenseLayer, self).__init__() 18 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 19 | self.add_module('relu1', nn.ReLU(inplace=True)), 20 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 21 | growth_rate, kernel_size=1, stride=1, bias=False)), 22 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 23 | self.add_module('relu2', nn.ReLU(inplace=True)), 24 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 25 | kernel_size=3, stride=1, padding=1, bias=False)), 26 | self.drop_rate = drop_rate 27 | 28 | def forward(self, x): 29 | new_features = super(_DenseLayer, self).forward(x) 30 | if self.drop_rate > 0: 31 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 32 | return torch.cat([x, new_features], 1) 33 | 34 | 35 | class _DenseBlock(nn.Sequential): 36 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 37 | super(_DenseBlock, self).__init__() 38 | for i in range(num_layers): 39 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate) 40 | self.add_module('denselayer%d' % (i + 1), layer) 41 | 42 | 43 | class _Transition(nn.Sequential): 44 | def __init__(self, num_input_features, num_output_features): 45 | super(_Transition, self).__init__() 46 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 47 | self.add_module('relu', nn.ReLU(inplace=True)) 48 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 49 | 
kernel_size=1, stride=1, bias=False)) 50 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 51 | 52 | 53 | class DenseNet_Cifar(nn.Module): 54 | r"""Densenet-BC model class, based on 55 | `"Densely Connected Convolutional Networks" `_ 56 | 57 | Args: 58 | growth_rate (int) - how many filters to add each layer (`k` in paper) 59 | block_config (list of 4 ints) - how many layers in each pooling block 60 | num_init_features (int) - the number of filters to learn in the first convolution layer 61 | bn_size (int) - multiplicative factor for number of bottle neck layers 62 | (i.e. bn_size * k features in the bottleneck layer) 63 | drop_rate (float) - dropout rate after each dense layer 64 | num_classes (int) - number of classification classes 65 | """ 66 | def __init__(self, growth_rate=12, block_config=(16, 16, 16), 67 | num_init_features=24, bn_size=4, drop_rate=0, num_classes=10): 68 | 69 | super(DenseNet_Cifar, self).__init__() 70 | 71 | # First convolution 72 | self.features = nn.Sequential(OrderedDict([ 73 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=3, stride=1, padding=1, bias=False)), 74 | ])) 75 | 76 | # Each denseblock 77 | num_features = num_init_features 78 | for i, num_layers in enumerate(block_config): 79 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 80 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) 81 | self.features.add_module('denseblock%d' % (i + 1), block) 82 | num_features = num_features + num_layers * growth_rate 83 | if i != len(block_config) - 1: 84 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) 85 | self.features.add_module('transition%d' % (i + 1), trans) 86 | num_features = num_features // 2 87 | 88 | # Final batch norm 89 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 90 | 91 | # Linear layer 92 | self.classifier = nn.Linear(num_features, num_classes) 93 | 94 | # initialize conv and bn parameters 95 | for m in self.modules(): 96 | if isinstance(m, nn.Conv2d): 97 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 98 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 99 | elif isinstance(m, nn.BatchNorm2d): 100 | m.weight.data.fill_(1) 101 | m.bias.data.zero_() 102 | 103 | def forward(self, x): 104 | features = self.features(x) 105 | out = F.relu(features, inplace=True) 106 | out = F.avg_pool2d(out, kernel_size=8, stride=1).view(features.size(0), -1) 107 | out = self.classifier(out) 108 | return out 109 | 110 | 111 | def densenet_BC_cifar(depth, k, **kwargs): 112 | N = (depth - 4) // 6 113 | model = DenseNet_Cifar(growth_rate=k, block_config=[N, N, N], num_init_features=2*k, **kwargs) 114 | return model 115 | 116 | 117 | if __name__ == '__main__': 118 | net = densenet_BC_cifar(190, 40, num_classes=100) 119 | input = torch.randn(1, 3, 32, 32) 120 | y = net(input) 121 | print(net) 122 | print(y.size()) 123 | -------------------------------------------------------------------------------- /lib/networks/invcnn_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from .invnet import * 2 | -------------------------------------------------------------------------------- /lib/networks/invcnn_pytorch/invnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .inv_cnn import * 2 | -------------------------------------------------------------------------------- /lib/networks/invcnn_pytorch/invnet/inv_cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | class BasicBlock(nn.Module): 6 | expansion=1 7 | 8 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=1, downsample=None): 9 | super(BasicBlock, self).__init__() 10 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) 11 | self.bn1 = nn.BatchNorm2d(planes) 12 | self.relu = nn.ReLU(inplace=True) 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(planes) 15 | self.downsample = downsample 16 | self.stride = stride 17 | 18 | def forward(self, x): 19 | residual = x 20 | 21 | out = self.conv1(x) 22 | out = self.bn1(out) 23 | out = self.relu(out) 24 | 25 | out = self.conv2(out) 26 | out = self.bn2(out) 27 | 28 | if self.downsample is not None: 29 | residual = self.downsample(x) 30 | 31 | out += residual 32 | out = self.relu(out) 33 | return out 34 | 35 | class attention(nn.Module): 36 | def __init__(self, dim): 37 | super(attention, self).__init__() 38 | self.net1 = nn.Conv1d(dim, dim, 1, 1, 0) 39 | self.net2 = nn.Conv1d(dim, dim, 1, 1, 0) 40 | 41 | def forward(self, x): 42 | x1 = x[0].view(x[0].shape[0], x[0].shape[1], -1) 43 | x2 = x[1].view(x[1].shape[0], x[1].shape[1], -1) 44 | x1 = F.relu(self.net1(x1)) # B x C x M 45 | x2 = F.relu(self.net2(x2)) # B x C x 1 46 | 47 | attn = torch.softmax((x1 * x2).sum(1), dim=1) 48 | attn = attn.unsqueeze(2) 49 | 50 | out = (x2 + torch.bmm(x1, attn)).unsqueeze(3) 51 | return out 52 | 53 | class INVCNN(nn.Module): 54 | def __init__(self, nlayers, num_classes=100, has_gtlayer=False, has_selayer=False): 55 | super(INVCNN, self).__init__() 56 | self.name = "invcnn" 57 | self.planes = 3 58 | 59 | self.layer1 = nn.Sequential( 60 | nn.Conv2d(3, 64, kernel_size=4, stride=4, padding=0, bias=False), 61 | nn.BatchNorm2d(64), 62 | nn.ReLU(True), 63 | ) 64 | self.fc1 = nn.Linear(64, num_classes) 65 | 66 | self.layer2 = self._make_layer(BasicBlock, 64, 1, kernel_size=4) 67 | self.conv1x1_2 = nn.Sequential( 68 | 
nn.Conv2d(64, 64, 1, 1, 0), 69 | # nn.ReLU(True), 70 | ) 71 | self.conv3x3_2 = nn.Sequential( 72 | nn.Conv2d(64, 64, 3, 1, 1), 73 | nn.ReLU(True), 74 | nn.AdaptiveAvgPool2d(1) 75 | ) 76 | self.attn2 = attention(64) 77 | self.fc2 = nn.Linear(64, num_classes) 78 | 79 | self.layer3 = self._make_layer(BasicBlock, 64, 2, kernel_size=4) 80 | self.conv1x1_3 = nn.Sequential( 81 | nn.Conv2d(64, 64, 1, 1, 0), 82 | # nn.ReLU(True), 83 | ) 84 | self.conv3x3_3 = nn.Sequential( 85 | nn.Conv2d(64, 64, 3, 1, 1), 86 | nn.ReLU(True), 87 | nn.AdaptiveAvgPool2d(1) 88 | ) 89 | self.attn3 = attention(64) 90 | self.fc3 = nn.Linear(64, num_classes) 91 | 92 | self.layer4 = self._make_layer(BasicBlock, 64, 3, kernel_size=4) 93 | self.conv1x1_4 = nn.Sequential( 94 | nn.Conv2d(64, 64, 1, 1, 0), 95 | # nn.ReLU(True), 96 | ) 97 | self.conv3x3_4 = nn.Sequential( 98 | nn.Conv2d(64, 64, 3, 1, 1), 99 | nn.ReLU(True), 100 | nn.AdaptiveAvgPool2d(1) 101 | ) 102 | self.attn4 = attention(64) 103 | self.fc4 = nn.Linear(64, num_classes) 104 | 105 | for m in self.modules(): 106 | if isinstance(m, nn.Conv2d): 107 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 108 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 109 | nn.init.constant_(m.weight, 1) 110 | nn.init.constant_(m.bias, 0) 111 | 112 | def _make_layer(self, block, planes, blocks, kernel_size=3): 113 | layers = [] 114 | inplanes = 3 115 | for i in range(0, blocks): 116 | if i == 0: 117 | downsample = nn.Sequential( 118 | nn.Conv2d(inplanes, planes, kernel_size=kernel_size, stride=kernel_size, padding=kernel_size // 2, bias=False), 119 | nn.BatchNorm2d(planes) 120 | ) 121 | layers.append(block(inplanes, planes, kernel_size=kernel_size, stride=kernel_size, padding=kernel_size // 2, downsample=downsample)) 122 | else: 123 | layers.append(block(inplanes, planes)) 124 | inplanes = planes 125 | # layers.append(nn.AdaptiveAvgPool2d(1)) 126 | return nn.Sequential(*layers) 127 | 128 | def forward(self, x): 129 | # import pdb; pdb.set_trace() 130 | # out1 = self.layer1(x) 131 | # score1 = self.fc1(out1.view(out1.size(0), -1)) 132 | # out2 = self.layer2(x) 133 | # out2 = self.conv1x1_2(out2) + F.interpolate(out1, (out2.shape[2], out2.shape[3])) 134 | # out2_fc = self.conv3x3_2(out2) 135 | # score2 = self.fc2(out2_fc.view(out2_fc.size(0), -1)) 136 | # out3 = self.layer3(x) 137 | # out3 = self.conv1x1_3(out3) + F.interpolate(out2, (out3.shape[2], out3.shape[3])) 138 | # out3_fc = self.conv3x3_3(out3) 139 | # score3 = self.fc3(out3_fc.view(out3_fc.size(0), -1)) 140 | # out4 = self.layer4(x) 141 | # out4 = self.conv1x1_4(out4) + F.interpolate(out3, (out4.shape[2], out4.shape[3])) 142 | # out4_fc = self.conv3x3_4(out4) 143 | # score4 = self.fc4(out4_fc.view(out4_fc.size(0), -1)) 144 | # # scores = (score1, score2, score3, score4) 145 | # return [score4] 146 | 147 | out1 = self.layer1(F.interpolate(x, (4, 4))) 148 | score1 = self.fc1(out1.view(out1.size(0), -1)) 149 | out2 = self.layer2(F.interpolate(x, (8, 8))) # BxCx2x2 150 | out2 = self.attn2((out2, out1)) 151 | score2 = self.fc2(out2.view(out2.size(0), -1)) 152 | out3 = self.layer3(F.interpolate(x, (16, 16))) 153 | out3 = self.attn3((out3, out2)) 154 | score3 = self.fc3(out3.view(out3.size(0), -1)) 155 | out4 = self.layer4(x) 156 | out4 = self.attn4((out4, out3)) 157 | score4 = self.fc4(out4.view(out4.size(0), -1)) 158 | # scores = (score1, score2, score3, score4) 159 | return [score4] 160 | 161 | def inv_cnn_4(**kwargs): 162 | return INVCNN(4, **kwargs) 163 | 
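# --- Editor's sketch (not part of the original file) -------------------------
# Smoke test in the style of the repo's other network files. INVCNN feeds the
# input image to each stage at a different resolution and fuses scales through
# the attention module; forward() currently returns only the last-scale logits:
#
#     net = inv_cnn_4(num_classes=100)
#     scores = net(torch.randn(2, 3, 32, 32))
#     print([s.size() for s in scores])   # [torch.Size([2, 100])]
# ------------------------------------------------------------------------------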
-------------------------------------------------------------------------------- /lib/networks/mobilenet_v2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | 4 | 5 | def conv_bn(inp, oup, stride): 6 | return nn.Sequential( 7 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 8 | nn.BatchNorm2d(oup), 9 | nn.ReLU6(inplace=True) 10 | ) 11 | 12 | 13 | def conv_1x1_bn(inp, oup): 14 | return nn.Sequential( 15 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 16 | nn.BatchNorm2d(oup), 17 | nn.ReLU6(inplace=True) 18 | ) 19 | 20 | 21 | class InvertedResidual(nn.Module): 22 | def __init__(self, inp, oup, stride, expand_ratio): 23 | super(InvertedResidual, self).__init__() 24 | self.stride = stride 25 | assert stride in [1, 2] 26 | 27 | hidden_dim = round(inp * expand_ratio) 28 | self.use_res_connect = self.stride == 1 and inp == oup 29 | 30 | if expand_ratio == 1: 31 | self.conv = nn.Sequential( 32 | # dw 33 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 34 | nn.BatchNorm2d(hidden_dim), 35 | nn.ReLU6(inplace=True), 36 | # pw-linear 37 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 38 | nn.BatchNorm2d(oup), 39 | ) 40 | else: 41 | self.conv = nn.Sequential( 42 | # pw 43 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 44 | nn.BatchNorm2d(hidden_dim), 45 | nn.ReLU6(inplace=True), 46 | # dw 47 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 48 | nn.BatchNorm2d(hidden_dim), 49 | nn.ReLU6(inplace=True), 50 | # pw-linear 51 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 52 | nn.BatchNorm2d(oup), 53 | ) 54 | 55 | def forward(self, x): 56 | if self.use_res_connect: 57 | return x + self.conv(x) 58 | else: 59 | return self.conv(x) 60 | 61 | 62 | class MobileNetV2(nn.Module): 63 | def __init__(self, num_classes=1000, input_size=224, width_mult=1.): 64 | super(MobileNetV2, self).__init__() 65 | block = InvertedResidual 66 | input_channel = 32 67 | last_channel = 1280 68 | interverted_residual_setting = [ 69 | # t, c, n, s 70 | [1, 16, 1, 1], 71 | [6, 24, 2, 2], 72 | [6, 32, 3, 2], 73 | [6, 64, 4, 2], 74 | [6, 96, 3, 1], 75 | [6, 160, 3, 2], 76 | [6, 320, 1, 1], 77 | ] 78 | 79 | # building first layer 80 | assert input_size % 32 == 0 81 | input_channel = int(input_channel * width_mult) 82 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 83 | self.features = [conv_bn(3, input_channel, 2)] 84 | # building inverted residual blocks 85 | for t, c, n, s in interverted_residual_setting: 86 | output_channel = int(c * width_mult) 87 | for i in range(n): 88 | if i == 0: 89 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 90 | else: 91 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 92 | input_channel = output_channel 93 | # building last several layers 94 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 95 | # make it nn.Sequential 96 | self.features = nn.Sequential(*self.features) 97 | 98 | # building classifier 99 | self.classifier = nn.Sequential( 100 | nn.Dropout(0.2), 101 | nn.Linear(self.last_channel, num_classes), 102 | ) 103 | 104 | self._initialize_weights() 105 | 106 | def forward(self, x): 107 | x = self.features(x) 108 | x = x.mean(3).mean(2) 109 | x = self.classifier(x) 110 | return x 111 | 112 | def _initialize_weights(self): 113 | for m in self.modules(): 114 | if isinstance(m, nn.Conv2d): 115 | n = m.kernel_size[0] * m.kernel_size[1] 
* m.out_channels 116 | m.weight.data.normal_(0, math.sqrt(2. / n)) 117 | if m.bias is not None: 118 | m.bias.data.zero_() 119 | elif isinstance(m, nn.BatchNorm2d): 120 | m.weight.data.fill_(1) 121 | m.bias.data.zero_() 122 | elif isinstance(m, nn.Linear): 123 | n = m.weight.size(1) 124 | m.weight.data.normal_(0, 0.01) 125 | m.bias.data.zero_() 126 | -------------------------------------------------------------------------------- /lib/networks/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | ''' 2 | resnet for cifar in pytorch 3 | 4 | Reference: 5 | [1] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016. 6 | [2] K. He, X. Zhang, S. Ren, and J. Sun. Identity mappings in deep residual networks. In ECCV, 2016. 7 | ''' 8 | 9 | import torch 10 | import torch.nn as nn 11 | import math 12 | 13 | from lib.layers import * 14 | 15 | def conv3x3(in_planes, out_planes, stride=1): 16 | " 3x3 convolution with padding " 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 18 | 19 | class BasicBlockPlain(nn.Module): 20 | expansion=1 21 | 22 | def __init__(self, insize, inplanes, planes, stride=1, downsample=None, use_se=False, reduction=8): 23 | super(BasicBlockPlain, self).__init__() 24 | self.use_se = use_se 25 | self.conv1 = conv3x3(inplanes, planes, stride) 26 | self.bn1 = nn.BatchNorm2d(planes) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.conv2 = conv3x3(planes, planes) 29 | self.bn2 = nn.BatchNorm2d(planes) 30 | if self.use_se: 31 | self.se = SELayer(planes, reduction=reduction) 32 | self.downsample = downsample 33 | self.stride = stride 34 | 35 | def forward(self, x): 36 | # out = self.gt_spatial(out) 37 | residual = x 38 | 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | if self.use_se: 46 | out = self.se(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out = residual 52 | out = self.relu(out) 53 | return out 54 | 55 | class BasicBlock(nn.Module): 56 | expansion=1 57 | 58 | def __init__(self, insize, inplanes, planes, stride=1, downsample=None, use_se=False, use_gt=False, reduction=8): 59 | super(BasicBlock, self).__init__() 60 | self.use_se = use_se 61 | self.use_gt = use_gt 62 | self.conv1 = conv3x3(inplanes, planes, stride) 63 | self.bn1 = nn.BatchNorm2d(planes) 64 | self.relu = nn.ReLU(inplace=True) 65 | self.conv2 = conv3x3(planes, planes) 66 | self.bn2 = nn.BatchNorm2d(planes) 67 | if self.use_se: 68 | self.se = SELayer(planes, reduction=reduction) 69 | if self.use_gt: 70 | self.gt = CrossNeuronlBlock2D(planes, insize, insize, insize, insize, reduction=8) 71 | 72 | self.downsample = downsample 73 | self.stride = stride 74 | 75 | def forward(self, x): 76 | # out = self.gt_spatial(out) 77 | residual = x 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | if self.use_se: 86 | out = self.se(out) 87 | if self.use_gt: 88 | out = self.gt(out) 89 | 90 | if self.downsample is not None: 91 | residual = self.downsample(x) 92 | 93 | out += residual 94 | out = self.relu(out) 95 | return out 96 | 97 | 98 | class Bottleneck(nn.Module): 99 | expansion=4 100 | 101 | def __init__(self, insize, inplanes, planes, stride=1, downsample=None, use_se=False, use_gt=False, reduction=8): 102 | super(Bottleneck, self).__init__() 103 | self.use_se = use_se 
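        # (Editor's note) use_se / use_gt attach the SE and cross-neuron layers after
        # conv3/bn3, whose output has planes * 4 channels (hence the widths below).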
104 | self.use_gt = use_gt 105 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 106 | self.bn1 = nn.BatchNorm2d(planes) 107 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 108 | self.bn2 = nn.BatchNorm2d(planes) 109 | self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False) 110 | self.bn3 = nn.BatchNorm2d(planes*4) 111 | self.relu = nn.ReLU(inplace=True) 112 | if self.use_se: 113 | self.se = SELayer(planes * 4, reduction=reduction) # applied to the conv3 output, which has planes*4 channels 114 | if self.use_gt: 115 | self.gt = CrossNeuronlBlock2D(planes * 4, insize, insize, insize, insize, reduction=8) 116 | self.downsample = downsample 117 | self.stride = stride 118 | 119 | def forward(self, x): 120 | residual = x 121 | 122 | out = self.conv1(x) 123 | out = self.bn1(out) 124 | out = self.relu(out) 125 | 126 | out = self.conv2(out) 127 | out = self.bn2(out) 128 | out = self.relu(out) 129 | 130 | out = self.conv3(out) 131 | out = self.bn3(out) 132 | 133 | if self.use_se: 134 | out = self.se(out) 135 | if self.use_gt: 136 | out = self.gt(out) 137 | 138 | if self.downsample is not None: 139 | residual = self.downsample(x) 140 | 141 | out += residual 142 | out = self.relu(out) 143 | 144 | return out 145 | 146 | # class CommunicationBlock(nn.Module): 147 | 148 | 149 | class PreActBasicBlock(nn.Module): 150 | expansion = 1 151 | 152 | def __init__(self, inplanes, planes, stride=1, downsample=None): 153 | super(PreActBasicBlock, self).__init__() 154 | self.bn1 = nn.BatchNorm2d(inplanes) 155 | self.relu = nn.ReLU(inplace=True) 156 | self.conv1 = conv3x3(inplanes, planes, stride) 157 | self.bn2 = nn.BatchNorm2d(planes) 158 | self.conv2 = conv3x3(planes, planes) 159 | self.downsample = downsample 160 | self.stride = stride 161 | 162 | def forward(self, x): 163 | residual = x 164 | 165 | out = self.bn1(x) 166 | out = self.relu(out) 167 | 168 | if self.downsample is not None: 169 | residual = self.downsample(out) 170 | 171 | out = self.conv1(out) 172 | 173 | out = self.bn2(out) 174 | out = self.relu(out) 175 | out = self.conv2(out) 176 | 177 | out += residual 178 | 179 | return out 180 | 181 | 182 | class PreActBottleneck(nn.Module): 183 | expansion = 4 184 | 185 | def __init__(self, inplanes, planes, stride=1, downsample=None): 186 | super(PreActBottleneck, self).__init__() 187 | self.bn1 = nn.BatchNorm2d(inplanes) 188 | self.relu = nn.ReLU(inplace=True) 189 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 190 | self.bn2 = nn.BatchNorm2d(planes) 191 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 192 | self.bn3 = nn.BatchNorm2d(planes) 193 | self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False) 194 | self.downsample = downsample 195 | self.stride = stride 196 | 197 | def forward(self, x): 198 | residual = x 199 | 200 | out = self.bn1(x) 201 | out = self.relu(out) 202 | 203 | if self.downsample is not None: 204 | residual = self.downsample(out) 205 | 206 | out = self.conv1(out) 207 | 208 | out = self.bn2(out) 209 | out = self.relu(out) 210 | out = self.conv2(out) 211 | 212 | out = self.bn3(out) 213 | out = self.relu(out) 214 | out = self.conv3(out) 215 | 216 | out += residual 217 | 218 | return out 219 | 220 | 221 | class ResNet_Cifar(nn.Module): 222 | def __init__(self, block, layers, num_classes=10, has_gtlayer=False, has_selayer=False): 223 | super(ResNet_Cifar, self).__init__() 224 | self.inplanes = 16 225 | self.insize = 32 226 | self.has_gtlayer = has_gtlayer 227 | self.has_selayer = has_selayer 228 | 
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 229 | self.bn1 = nn.BatchNorm2d(16) 230 | self.relu = nn.ReLU(inplace=True) 231 | self.layer1 = self._make_layer(block, 16, layers[0]) 232 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 233 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 234 | 235 | self.avgpool = nn.AvgPool2d(8, stride=1) 236 | self.fc = nn.Linear(64 * block.expansion, num_classes) 237 | 238 | def _make_layer(self, block, planes, blocks, stride=1, has_selayer=False): 239 | downsample = None 240 | if stride != 1 or self.inplanes != planes * block.expansion: 241 | downsample = nn.Sequential( 242 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 243 | nn.BatchNorm2d(planes * block.expansion) 244 | ) 245 | 246 | self.insize = int(self.insize / stride) 247 | layers = [] 248 | layers.append(block(self.insize, self.inplanes, planes, stride, downsample, use_se=self.has_selayer, use_gt=self.has_gtlayer)) 249 | self.inplanes = planes * block.expansion 250 | for _ in range(1, blocks - 1): 251 | layers.append(block(self.insize, self.inplanes, planes, use_se=self.has_selayer)) 252 | layers.append(block(self.insize, self.inplanes, planes, use_se=self.has_selayer)) 253 | return nn.Sequential(*layers) 254 | 255 | def forward(self, x): 256 | x = self.conv1(x) 257 | x = self.bn1(x) 258 | x = self.relu(x) 259 | x0 = x.clone() 260 | x = self.layer1(x) 261 | x1 = x.clone() 262 | x = self.layer2(x) 263 | x2 = x.clone() 264 | x = self.layer3(x) 265 | x3 = x.clone() 266 | 267 | x = self.avgpool(x) 268 | x = x.view(x.size(0), -1) 269 | x = self.fc(x) 270 | 271 | return x #, (x0, x1, x2, x3) 272 | 273 | 274 | class PreAct_ResNet_Cifar(nn.Module): 275 | 276 | def __init__(self, block, layers, num_classes=10, has_gtlayer=False, has_selayer=False): 277 | super(PreAct_ResNet_Cifar, self).__init__() 278 | self.inplanes = 16 279 | self.insize = 32 280 | self.has_gtlayer = has_gtlayer 281 | self.has_selayer = has_selayer 282 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 283 | self.layer1 = self._make_layer(block, 16, layers[0]) 284 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 285 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 286 | 287 | self.bn = nn.BatchNorm2d(64*block.expansion) 288 | self.relu = nn.ReLU(inplace=True) 289 | self.avgpool = nn.AvgPool2d(8, stride=1) 290 | self.fc = nn.Linear(64*block.expansion, num_classes) 291 | 292 | for m in self.modules(): 293 | if isinstance(m, nn.Conv2d): 294 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 295 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 296 | elif isinstance(m, nn.BatchNorm2d): 297 | m.weight.data.fill_(1) 298 | m.bias.data.zero_() 299 | 300 | def _make_layer(self, block, planes, blocks, stride=1): 301 | downsample = None 302 | if stride != 1 or self.inplanes != planes*block.expansion: 303 | downsample = nn.Sequential( 304 | nn.Conv2d(self.inplanes, planes*block.expansion, kernel_size=1, stride=stride, bias=False) 305 | ) 306 | 307 | layers = [] 308 | layers.append(block(self.inplanes, planes, stride, downsample)) 309 | self.inplanes = planes*block.expansion 310 | for _ in range(1, blocks): 311 | layers.append(block(self.inplanes, planes)) 312 | return nn.Sequential(*layers) 313 | 314 | def forward(self, x): 315 | x = self.conv1(x) 316 | 317 | x = self.layer1(x) 318 | x = self.layer2(x) 319 | x = self.layer3(x) 320 | 321 | x = self.bn(x) 322 | x = self.relu(x) 323 | x = self.avgpool(x) 324 | x = x.view(x.size(0), -1) 325 | x = self.fc(x) 326 | 327 | return x 328 | 329 | 330 | 331 | def resnet20_cifar(**kwargs): 332 | model = ResNet_Cifar(BasicBlock, [3, 3, 3], **kwargs) 333 | return model 334 | 335 | def resnet20plain_cifar(**kwargs): 336 | model = ResNet_Cifar(BasicBlockPlain, [3, 3, 3], **kwargs) 337 | return model 338 | 339 | def resnet32_cifar(**kwargs): 340 | model = ResNet_Cifar(BasicBlock, [5, 5, 5], **kwargs) 341 | return model 342 | 343 | 344 | def resnet44_cifar(**kwargs): 345 | model = ResNet_Cifar(BasicBlock, [7, 7, 7], **kwargs) 346 | return model 347 | 348 | 349 | def resnet56_cifar(**kwargs): 350 | model = ResNet_Cifar(BasicBlock, [9, 9, 9], **kwargs) 351 | return model 352 | 353 | 354 | def resnet110_cifar(**kwargs): 355 | model = ResNet_Cifar(BasicBlock, [18, 18, 18], **kwargs) 356 | return model 357 | 358 | def resnet110plain_cifar(**kwargs): 359 | model = ResNet_Cifar(BasicBlockPlain, [18, 18, 18], **kwargs) 360 | return model 361 | 362 | def resnet1202_cifar(**kwargs): 363 | model = ResNet_Cifar(BasicBlock, [200, 200, 200], **kwargs) 364 | return model 365 | 366 | 367 | def resnet164_cifar(**kwargs): 368 | model = ResNet_Cifar(Bottleneck, [18, 18, 18], **kwargs) 369 | return model 370 | 371 | 372 | def resnet1001_cifar(**kwargs): 373 | model = ResNet_Cifar(Bottleneck, [111, 111, 111], **kwargs) 374 | return model 375 | 376 | 377 | def preact_resnet110_cifar(**kwargs): 378 | model = PreAct_ResNet_Cifar(PreActBasicBlock, [18, 18, 18], **kwargs) 379 | return model 380 | 381 | 382 | def preact_resnet164_cifar(**kwargs): 383 | model = PreAct_ResNet_Cifar(PreActBottleneck, [18, 18, 18], **kwargs) 384 | return model 385 | 386 | 387 | def preact_resnet1001_cifar(**kwargs): 388 | model = PreAct_ResNet_Cifar(PreActBottleneck, [111, 111, 111], **kwargs) 389 | return model 390 | 391 | 392 | if __name__ == '__main__': 393 | net = resnet20_cifar() 394 | y = net(torch.randn(1, 3, 32, 32)) # CIFAR-sized input; the avgpool(8)/fc sizing assumes 32x32 395 | print(net) 396 | print(y.size()) 397 | 
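# --- Editor's sketch (not part of the original file) -------------------------
# The factory kwargs are forwarded straight to ResNet_Cifar, so the SE and
# cross-neuron variants are built the same way, e.g.:
#
#     net = resnet56_cifar(num_classes=100, has_selayer=True)  # SELayer in every block
#     net = resnet56_cifar(num_classes=100, has_gtlayer=True)  # CrossNeuronlBlock2D in the
#                                                              # first block of each stage
# ------------------------------------------------------------------------------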
-------------------------------------------------------------------------------- /lib/networks/resnet_cifar_analysis.py: -------------------------------------------------------------------------------- 1 | ''' 2 | resnet for cifar in pytorch 3 | 4 | Reference: 5 | [1] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016. 6 | [2] K. He, X. Zhang, S. Ren, and J. Sun. Identity mappings in deep residual networks. In ECCV, 2016. 7 | ''' 8 | 9 | import torch 10 | import torch.nn as nn 11 | import math 12 | import torch.nn.functional as F # F is used below (F.softmax, F.relu, F.interpolate); imported explicitly rather than via the star-import 13 | from lib.layers import * 14 | 15 | def conv3x3(in_planes, out_planes, stride=1): 16 | " 3x3 convolution with padding " 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 18 | 19 | def weights_init(m): 20 | if isinstance(m, nn.Conv1d): 21 | nn.init.xavier_uniform_(m.weight.data) # Conv1d weights get Xavier init 22 | if m.bias is not None: m.bias.data.zero_() # Xavier is undefined for 1-D bias tensors; zero them 23 | 24 | class _CrossNeuronBlockInternal(nn.Module): 25 | def __init__(self, in_channels): 26 | # nblock_channel: number of block along channel axis 27 | # spatial_size: spatial_size 28 | super(_CrossNeuronBlockInternal, self).__init__() 29 | 30 | self.conv_in = nn.Sequential( 31 | nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1, groups=in_channels, bias=True), 32 | # nn.BatchNorm2d(in_channels), 33 | nn.ReLU(True), 34 | nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1, groups=in_channels, bias=True), 35 | # nn.BatchNorm2d(in_channels), 36 | # nn.ReLU(True), 37 | ) 38 | 39 | self.conv_out = nn.Sequential( 40 | nn.ConvTranspose2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1, groups=in_channels, bias=True), 41 | # nn.BatchNorm2d(in_channels), 42 | nn.ReLU(True), 43 | nn.ConvTranspose2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1, groups=in_channels, bias=True), 44 | # nn.BatchNorm2d(in_channels), 45 | ) 46 | 47 | # self.bn = nn.BatchNorm1d(self.spatial_area) 48 | 49 | self.initialize() 50 | 51 | def initialize(self): 52 | for m in self.modules(): 53 | if isinstance(m, nn.Conv2d): 54 | nn.init.kaiming_normal_(m.weight) 55 | elif isinstance(m, nn.BatchNorm2d): 56 | nn.init.constant_(m.weight, 1) 57 | nn.init.constant_(m.bias, 0) 58 | 59 | def forward(self, x): 60 | ''' 61 | :param x: (bt, c, h, w) 62 | :return: 63 | ''' 64 | bt, c, h, w = x.shape 65 | residual = x 66 | x_v = self.conv_in(x) # b x c x h x w 67 | x_m = x_v.mean(3).mean(2).unsqueeze(2) # bt x c x 1 68 | 69 | score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # bt x c x c 70 | # score = torch.bmm(x_v.transpose(1, 2).contiguous(), x_v) 71 | # x_v = F.dropout(x_v, 0.1, self.training) 72 | # score.masked_fill_(self.mask.unsqueeze(0).expand_as(score).type_as(score).eq(0), -np.inf) 73 | attn = F.softmax(score, dim=2) # bt x c x c 74 | out = self.conv_out(torch.bmm(attn, x_v.view(bt, c, -1)).view(bt, c, x_v.shape[2], x_v.shape[3])) 75 | return F.relu(residual + out) 76 | 77 | class _CrossNeuronBlock(nn.Module): 78 | def __init__(self, in_channels, in_height, in_width, 79 | spatial_height=32, spatial_width=32, 80 | reduction=16, 81 | size_is_consistant=True, 82 | communication=True, 83 | enc_dec=True): 84 | # nblock_channel: number of block along channel axis 85 | # spatial_size: spatial_size 86 | super(_CrossNeuronBlock, self).__init__() 87 | 88 | self.communication = communication 89 | self.enc_dec = enc_dec 90 | 91 | # set channel splits 92 | if in_channels <= 512: 93 | self.nblocks_channel = 1 94 | else: 95 | self.nblocks_channel = in_channels // 512 96 | block_size = in_channels // self.nblocks_channel 97 | block = torch.Tensor(block_size, block_size).fill_(1) 98 | self.mask = torch.Tensor(in_channels, in_channels).fill_(0) 99 | for i in range(self.nblocks_channel): 100 | self.mask[i * block_size:(i + 1) * block_size, i * block_size:(i + 1) * block_size].copy_(block) 101 | 102 | # set spatial splits 103 | if in_height * in_width < 16 * 16 and size_is_consistant: 104 | self.spatial_area = in_height * in_width 105 | self.spatial_height = 
in_height 106 | self.spatial_width = in_width 107 | else: 108 | self.spatial_area = spatial_height * spatial_width 109 | self.spatial_height = spatial_height 110 | self.spatial_width = spatial_width 111 | # 112 | 113 | 114 | self.fc_in = nn.Sequential( 115 | nn.Conv1d(self.spatial_area, self.spatial_area // reduction, 1, 1, 0, bias=True), 116 | # nn.BatchNorm1d(self.spatial_area // reduction), 117 | nn.ReLU(True), 118 | nn.Conv1d(self.spatial_area // reduction, self.spatial_area, 1, 1, 0, bias=True), 119 | ) 120 | 121 | self.global_context = nn.Sequential( 122 | nn.Conv1d(self.spatial_area, 1, 1, 1, 0, bias=True), 123 | nn.ReLU(True), 124 | ) 125 | 126 | self.fc_out = nn.Sequential( 127 | nn.Conv1d(self.spatial_area, self.spatial_area // reduction, 1, 1, 0, bias=True), 128 | # nn.BatchNorm1d(self.spatial_area // reduction), 129 | nn.ReLU(True), 130 | nn.Conv1d(self.spatial_area // reduction, self.spatial_area, 1, 1, 0, bias=True), 131 | ) 132 | 133 | self.bn = nn.BatchNorm1d(self.spatial_area) 134 | # self.ln = nn.LayerNorm(in_channels) 135 | 136 | self.initialize() 137 | 138 | def initialize(self): 139 | for m in self.modules(): 140 | if isinstance(m, nn.Conv1d): 141 | nn.init.kaiming_normal_(m.weight) 142 | if m.bias is not None: 143 | torch.nn.init.zeros_(m.bias) 144 | elif isinstance(m, nn.BatchNorm1d): 145 | nn.init.constant_(m.weight, 1) 146 | nn.init.constant_(m.bias, 0) 147 | 148 | def forward(self, x): 149 | ''' 150 | :param x: (bt, c, h, w) 151 | :return: 152 | ''' 153 | bt, c, h, w = x.shape 154 | residual = x 155 | x_stretch = x.view(bt, c, h * w) 156 | 157 | if self.spatial_height == h and self.spatial_width == w: 158 | x_stacked = x_stretch # (b) x c x (h * w) 159 | x_stacked = x_stacked.view(bt * self.nblocks_channel, c // self.nblocks_channel, -1) 160 | x_v = x_stacked.permute(0, 2, 1).contiguous() # (b) x (h * w) x c 161 | 162 | if self.enc_dec: 163 | x_v = self.fc_in(x_v) # (b) x c x (h * w) 164 | 165 | # x_m = self.global_context(x_v) 166 | # import pdb; pdb.set_trace() 167 | # x_m = self.global_context(x_v) # b x 1 x c 168 | x_m = x_v.mean(1).unsqueeze(1) # b x 1 x c 169 | score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # (b * h * w) x c x c 170 | # score = -torch.abs(x_m - x_m.permute(0, 2, 1).contiguous()) # (b * h * w) x c x c 171 | 172 | # score = torch.bmm(x_v.transpose(1, 2).contiguous(), x_v) 173 | # x_v = F.dropout(x_v, 0.1, self.training) 174 | # score.masked_fill_(self.mask.unsqueeze(0).expand_as(score).type_as(score).eq(0), -np.inf) 175 | attn = F.softmax(score, dim=1) # (b * h * w) x c x c 176 | # attn = F.dropout(attn, 0.2, self.training) 177 | 178 | if self.communication: 179 | if self.enc_dec: 180 | out = self.bn(self.fc_out(torch.bmm(x_v, attn))) # (b) x (h * w) x c 181 | else: 182 | out = self.bn(torch.bmm(x_v, attn)) # (b) x (h * w) x c 183 | else: 184 | out = self.bn(self.fc_out(x_v)) # (b) x (h * w) x c 185 | 186 | out = out.permute(0, 2, 1).contiguous().view(bt, c, h, w) 187 | return F.relu(residual + out) 188 | else: 189 | x = F.interpolate(x, (self.spatial_height, self.spatial_width)) 190 | x_stretch = x.view(bt, c, self.spatial_height * self.spatial_width) 191 | x_stretch = x.view(bt * self.nblocks_channel, c // self.nblocks_channel, self.spatial_height * self.spatial_width) 192 | 193 | x_stacked = x_stretch # (b) x c x (h * w) 194 | x_v = x_stacked.permute(0, 2, 1).contiguous() # (b) x (h * w) x c 195 | 196 | if self.enc_dec: 197 | x_v = self.fc_in(x_v) # (b) x (h * w) x c 198 | 199 | # x_m = self.global_context(x_v) 200 | # x_m = 
self.global_context(x_v) # b x 1 x c 201 | x_m = x_v.mean(1).unsqueeze(1) # (b) x 1 x c 202 | score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # (b * h * w) x c x c 203 | # score = -torch.abs(x_m - x_m.permute(0, 2, 1).contiguous()) # (b * h * w) x c x c 204 | # score = torch.bmm(x_v.transpose(1, 2).contiguous(), x_v) 205 | # x_v = F.dropout(x_v, 0.1, self.training) 206 | # score.masked_fill_(self.mask.unsqueeze(0).expand_as(score).type_as(score).eq(0), -np.inf) 207 | attn = F.softmax(score, dim=1) # (b * h * w) x c x c 208 | # attn = F.dropout(attn, 0.2, self.training) 209 | if self.communication: 210 | if self.enc_dec: 211 | out = self.bn(self.fc_out(torch.bmm(x_v, attn))) # (b) x (h * w) x c 212 | else: 213 | out = self.bn(torch.bmm(x_v, attn)) # (b) x (h * w) x c 214 | else: 215 | out = self.bn(self.fc_out(x_v)) # (b) x (h * w) x c 216 | out = out.permute(0, 2, 1).contiguous().view(bt, c, self.spatial_height, self.spatial_width) 217 | out = F.interpolate(out, (h, w)) 218 | return F.relu(residual + out) 219 | 220 | class BasicBlockPlain(nn.Module): 221 | expansion=1 222 | 223 | def __init__(self, insize, inplanes, planes, stride=1, downsample=None, use_se=False, reduction=8): 224 | super(BasicBlockPlain, self).__init__() 225 | self.use_se = use_se 226 | self.conv1 = conv3x3(inplanes, planes, stride) 227 | self.bn1 = nn.BatchNorm2d(planes) 228 | self.relu = nn.ReLU(inplace=True) 229 | self.conv2 = conv3x3(planes, planes) 230 | self.bn2 = nn.BatchNorm2d(planes) 231 | if self.use_se: 232 | self.se = SELayer(planes, reduction=reduction) 233 | self.downsample = downsample 234 | self.stride = stride 235 | 236 | def forward(self, x): 237 | # out = self.gt_spatial(out) 238 | residual = x 239 | 240 | out = self.conv1(x) 241 | out = self.bn1(out) 242 | out = self.relu(out) 243 | 244 | out = self.conv2(out) 245 | out = self.bn2(out) 246 | if self.use_se: 247 | out = self.se(out) 248 | 249 | if self.downsample is not None: 250 | residual = self.downsample(x) 251 | 252 | out = residual 253 | out = self.relu(out) 254 | return out 255 | 256 | class BasicBlock(nn.Module): 257 | expansion=1 258 | 259 | def __init__(self, insize, inplanes, planes, stride=1, downsample=None, use_se=False, use_gt=False, reduction=8): 260 | super(BasicBlock, self).__init__() 261 | self.use_se = use_se 262 | self.use_gt = use_gt 263 | self.conv1 = conv3x3(inplanes, planes, stride) 264 | self.bn1 = nn.BatchNorm2d(planes) 265 | self.relu = nn.ReLU(inplace=True) 266 | self.conv2 = conv3x3(planes, planes) 267 | self.bn2 = nn.BatchNorm2d(planes) 268 | if self.use_se: 269 | self.se = SELayer(planes, reduction=reduction) 270 | if self.use_gt: 271 | self.gt = _CrossNeuronBlockInternal(planes) 272 | 273 | self.downsample = downsample 274 | self.stride = stride 275 | 276 | def forward(self, x): 277 | # out = self.gt_spatial(out) 278 | residual = x 279 | 280 | out = self.conv1(x) 281 | out = self.bn1(out) 282 | out = self.relu(out) 283 | 284 | if self.use_gt: 285 | out = self.gt(out) 286 | 287 | out = self.conv2(out) 288 | out = self.bn2(out) 289 | 290 | if self.use_se: 291 | out = self.se(out) 292 | 293 | if self.downsample is not None: 294 | residual = self.downsample(x) 295 | 296 | out += residual 297 | out = self.relu(out) 298 | return out 299 | 300 | 301 | class Bottleneck(nn.Module): 302 | expansion=4 303 | 304 | def __init__(self, insize, inplanes, planes, stride=1, downsample=None, use_se=False, use_gt=False, reduction=8): 305 | super(Bottleneck, self).__init__() 306 | self.use_se = use_se 307 | self.use_gt = use_gt 308 | 
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 309 | self.bn1 = nn.BatchNorm2d(planes) 310 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 311 | self.bn2 = nn.BatchNorm2d(planes) 312 | self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False) 313 | self.bn3 = nn.BatchNorm2d(planes*4) 314 | self.relu = nn.ReLU(inplace=True) 315 | if self.use_se: 316 | self.se = SELayer(planes, reduction=reduction) 317 | if self.use_gt: 318 | self.gt = CrossNeuronlBlock2D(planes, insize, insize, insize, insize, reduction=8) 319 | self.downsample = downsample 320 | self.stride = stride 321 | 322 | def forward(self, x): 323 | residual = x 324 | 325 | out = self.conv1(x) 326 | out = self.bn1(out) 327 | out = self.relu(out) 328 | 329 | out = self.conv2(out) 330 | out = self.bn2(out) 331 | out = self.relu(out) 332 | 333 | out = self.conv3(out) 334 | out = self.bn3(out) 335 | 336 | if self.use_se: 337 | out = self.se(out) 338 | if self.use_gt: 339 | out = self.gt(out) 340 | 341 | if self.downsample is not None: 342 | residual = self.downsample(x) 343 | 344 | out += residual 345 | out = self.relu(out) 346 | 347 | return out 348 | 349 | # class CommunicationBlock(nn.Module): 350 | 351 | 352 | class PreActBasicBlock(nn.Module): 353 | expansion = 1 354 | 355 | def __init__(self, inplanes, planes, stride=1, downsample=None): 356 | super(PreActBasicBlock, self).__init__() 357 | self.bn1 = nn.BatchNorm2d(inplanes) 358 | self.relu = nn.ReLU(inplace=True) 359 | self.conv1 = conv3x3(inplanes, planes, stride) 360 | self.bn2 = nn.BatchNorm2d(planes) 361 | self.conv2 = conv3x3(planes, planes) 362 | self.downsample = downsample 363 | self.stride = stride 364 | 365 | def forward(self, x): 366 | residual = x 367 | 368 | out = self.bn1(x) 369 | out = self.relu(out) 370 | 371 | if self.downsample is not None: 372 | residual = self.downsample(out) 373 | 374 | out = self.conv1(out) 375 | 376 | out = self.bn2(out) 377 | out = self.relu(out) 378 | out = self.conv2(out) 379 | 380 | out += residual 381 | 382 | return out 383 | 384 | 385 | class PreActBottleneck(nn.Module): 386 | expansion = 4 387 | 388 | def __init__(self, inplanes, planes, stride=1, downsample=None): 389 | super(PreActBottleneck, self).__init__() 390 | self.bn1 = nn.BatchNorm2d(inplanes) 391 | self.relu = nn.ReLU(inplace=True) 392 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 393 | self.bn2 = nn.BatchNorm2d(planes) 394 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 395 | self.bn3 = nn.BatchNorm2d(planes) 396 | self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False) 397 | self.downsample = downsample 398 | self.stride = stride 399 | 400 | def forward(self, x): 401 | residual = x 402 | 403 | out = self.bn1(x) 404 | out = self.relu(out) 405 | 406 | if self.downsample is not None: 407 | residual = self.downsample(out) 408 | 409 | out = self.conv1(out) 410 | 411 | out = self.bn2(out) 412 | out = self.relu(out) 413 | out = self.conv2(out) 414 | 415 | out = self.bn3(out) 416 | out = self.relu(out) 417 | out = self.conv3(out) 418 | 419 | out += residual 420 | 421 | return out 422 | 423 | 424 | class ResNet_Cifar(nn.Module): 425 | def __init__(self, block, layers, num_classes=10, insert_layers=["1", "2", "3"], depth=1, 426 | has_gtlayer=False, has_selayer=False, communication=True, enc_dec=True): 427 | super(ResNet_Cifar, self).__init__() 428 | self.inplanes = 16 429 | self.insize = 32 430 | self.layers = insert_layers 
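        # (Editor's note) insert_layers lists the stage ids ("1".."4") at which the
        # cross-neuron blocks are applied in forward(), and depth is how many
        # _CrossNeuronBlock modules are stacked at each insertion point (nclayer1-4 below).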
431 | self.depth = depth 432 | self.has_gtlayer = has_gtlayer 433 | self.has_selayer = has_selayer 434 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 435 | self.bn1 = nn.BatchNorm2d(16) 436 | self.relu = nn.ReLU(inplace=True) 437 | 438 | self.layer1 = self._make_layer(block, 16, layers[0], has_selayer=has_selayer, has_gtlayer=has_gtlayer) 439 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2, has_selayer=has_selayer, has_gtlayer=has_gtlayer) 440 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2, has_selayer=has_selayer, has_gtlayer=has_gtlayer) 441 | 442 | if self.has_gtlayer: 443 | if "1" in self.layers: 444 | layers = [] 445 | for i in range(self.depth): 446 | layers.append(_CrossNeuronBlock(16, 32, 32, spatial_height=16, spatial_width=16, reduction=16, 447 | communication=communication, enc_dec=enc_dec)) 448 | self.nclayer1 = nn.Sequential(*layers) 449 | 450 | if "2" in self.layers: 451 | layers = [] 452 | for i in range(self.depth): 453 | layers.append(_CrossNeuronBlock(16, 32, 32, spatial_height=16, spatial_width=16, reduction=16, 454 | communication=communication, enc_dec=enc_dec)) 455 | self.nclayer2 = nn.Sequential(*layers) 456 | 457 | if "3" in self.layers: 458 | layers = [] 459 | for i in range(self.depth): 460 | layers.append(_CrossNeuronBlock(32, 16, 16, spatial_height=16, spatial_width=16, reduction=16, 461 | communication=communication, enc_dec=enc_dec)) 462 | self.nclayer3 = nn.Sequential(*layers) 463 | 464 | if "4" in self.layers: 465 | layers = [] 466 | for i in range(self.depth): 467 | layers.append(_CrossNeuronBlock(64, 8, 8, spatial_height=8, spatial_width=8, reduction=8, 468 | communication=communication, enc_dec=enc_dec)) 469 | self.nclayer4 = nn.Sequential(*layers) 470 | 471 | self.avgpool = nn.AvgPool2d(8, stride=1) 472 | self.fc = nn.Linear(64 * block.expansion, num_classes) 473 | 474 | def _make_layer(self, block, planes, blocks, stride=1, has_selayer=False, has_gtlayer=False): 475 | downsample = None 476 | if stride != 1 or self.inplanes != planes * block.expansion: 477 | downsample = nn.Sequential( 478 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 479 | nn.BatchNorm2d(planes * block.expansion) 480 | ) 481 | 482 | self.insize = int(self.insize / stride) 483 | layers = [] 484 | layers.append(block(self.insize, self.inplanes, planes, stride, downsample, use_se=has_selayer, use_gt=False)) 485 | self.inplanes = planes * block.expansion 486 | for _ in range(1, blocks - 1): 487 | layers.append(block(self.insize, self.inplanes, planes, use_se=has_selayer, use_gt=False)) 488 | layers.append(block(self.insize, self.inplanes, planes, use_se=has_selayer, use_gt=False)) 489 | return nn.Sequential(*layers) 490 | 491 | def _correlation(self, x): 492 | b, c, h, w = x.shape 493 | x_v = x.view(b, c, -1) # b x c x (hw) 494 | x_m = x_v.mean(1).unsqueeze(1) # b x 1 x (hw) 495 | x_c = x_v - x_m # b x c x (hw) 496 | # x_c = x_v 497 | num = torch.bmm(x_c, x_c.transpose(1, 2)) # b x c x c 498 | x_mode = torch.sqrt(torch.sum(x_c ** 2, 2).unsqueeze(2)) # b x c x 1 499 | dec = torch.bmm(x_mode, x_mode.transpose(1, 2).contiguous()) # b x c x c 500 | out = num / dec 501 | out = torch.abs(out) 502 | return out.mean() 503 | 504 | def forward(self, x): 505 | x = self.conv1(x) 506 | x = self.bn1(x) 507 | x = self.relu(x) 508 | 509 | corr1 = self._correlation(x) 510 | x0 = x.clone().detach() 511 | 512 | if self.has_gtlayer and "1" in self.layers: 513 | x = self.nclayer1(x) 514 | corr1_ = 
self._correlation(x) 515 | else: 516 | corr1_ = corr1.clone() 517 | 518 | x = self.layer1(x) 519 | x1 = x.clone().detach() 520 | 521 | corr2 = self._correlation(x) 522 | if self.has_gtlayer and "2" in self.layers: 523 | x = self.nclayer2(x) 524 | corr2_ = self._correlation(x) 525 | else: 526 | corr2_ = corr2.clone() 527 | 528 | x = self.layer2(x) 529 | x2 = x.clone().detach() 530 | 531 | corr3 = self._correlation(x) 532 | if self.has_gtlayer and "3" in self.layers: 533 | x = self.nclayer3(x) 534 | corr3_ = self._correlation(x) 535 | else: 536 | corr3_ = corr3.clone() 537 | 538 | x = self.layer3(x) 539 | x3 = x.clone().detach() 540 | corr4 = self._correlation(x) 541 | 542 | if self.has_gtlayer and "4" in self.layers: 543 | x = self.nclayer4(x) 544 | corr4_ = self._correlation(x) 545 | else: 546 | corr4_ = corr4.clone() 547 | 548 | # print("corr0: {}, corr1: {}, corr2: {}".format(corr0, corr1, corr2)) 549 | x = self.avgpool(x) 550 | x = x.view(x.size(0), -1) 551 | x = self.fc(x) 552 | 553 | return x, (x0, x1, x2, x3), \ 554 | (corr1, corr2, corr3, corr4), \ 555 | (corr1_, corr2_, corr3_, corr4_) # correlations after the cross-neuron layers 556 | 557 | 558 | class PreAct_ResNet_Cifar(nn.Module): 559 | 560 | def __init__(self, block, layers, num_classes=10, has_gtlayer=False, has_selayer=False): 561 | super(PreAct_ResNet_Cifar, self).__init__() 562 | self.inplanes = 16 563 | self.insize = 32 564 | self.has_gtlayer = has_gtlayer 565 | self.has_selayer = has_selayer 566 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 567 | self.layer1 = self._make_layer(block, 16, layers[0]) 568 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 569 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 570 | 571 | self.bn = nn.BatchNorm2d(64*block.expansion) 572 | self.relu = nn.ReLU(inplace=True) 573 | self.avgpool = nn.AvgPool2d(8, stride=1) 574 | self.fc = nn.Linear(64*block.expansion, num_classes) 575 | 576 | for m in self.modules(): 577 | if isinstance(m, nn.Conv2d): 578 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 579 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 580 | elif isinstance(m, nn.BatchNorm2d): 581 | m.weight.data.fill_(1) 582 | m.bias.data.zero_() 583 | 584 | def _make_layer(self, block, planes, blocks, stride=1): 585 | downsample = None 586 | if stride != 1 or self.inplanes != planes*block.expansion: 587 | downsample = nn.Sequential( 588 | nn.Conv2d(self.inplanes, planes*block.expansion, kernel_size=1, stride=stride, bias=False) 589 | ) 590 | 591 | layers = [] 592 | layers.append(block(self.inplanes, planes, stride, downsample)) 593 | self.inplanes = planes*block.expansion 594 | for _ in range(1, blocks): 595 | layers.append(block(self.inplanes, planes)) 596 | return nn.Sequential(*layers) 597 | 598 | def forward(self, x): 599 | x = self.conv1(x) 600 | 601 | x = self.layer1(x) 602 | x = self.layer2(x) 603 | x = self.layer3(x) 604 | 605 | x = self.bn(x) 606 | x = self.relu(x) 607 | x = self.avgpool(x) 608 | x = x.view(x.size(0), -1) 609 | x = self.fc(x) 610 | 611 | return x 612 | 613 | 614 | 615 | def resnet20a_cifar(**kwargs): 616 | model = ResNet_Cifar(BasicBlock, [3, 3, 3], **kwargs) 617 | return model 618 | 619 | def resnet20plain_cifar(**kwargs): 620 | model = ResNet_Cifar(BasicBlockPlain, [3, 3, 3], **kwargs) 621 | return model 622 | 623 | def resnet32_cifar(**kwargs): 624 | model = ResNet_Cifar(BasicBlock, [5, 5, 5], **kwargs) 625 | return model 626 | 627 | 628 | def resnet44_cifar(**kwargs): 629 | model = ResNet_Cifar(BasicBlock, [7, 7, 7], **kwargs) 630 | return model 631 | 632 | 633 | def resnet56a_cifar(**kwargs): 634 | model = ResNet_Cifar(BasicBlock, [9, 9, 9], **kwargs) 635 | return model 636 | 637 | def resnet62a_cifar(**kwargs): 638 | model = ResNet_Cifar(BasicBlock, [10, 10, 10], **kwargs) 639 | return model 640 | 641 | def resnet68a_cifar(**kwargs): 642 | model = ResNet_Cifar(BasicBlock, [11, 11, 11], **kwargs) 643 | return model 644 | 645 | def resnet74a_cifar(**kwargs): 646 | model = ResNet_Cifar(BasicBlock, [12, 12, 12], **kwargs) 647 | return model 648 | 649 | def resnet80a_cifar(**kwargs): 650 | model = ResNet_Cifar(BasicBlock, [13, 13, 13], **kwargs) 651 | return model 652 | 653 | def resnet86a_cifar(**kwargs): 654 | model = ResNet_Cifar(BasicBlock, [14, 14, 14], **kwargs) 655 | return model 656 | 657 | def resnet92a_cifar(**kwargs): 658 | model = ResNet_Cifar(BasicBlock, [15, 15, 15], **kwargs) 659 | return model 660 | 661 | def resnet98a_cifar(**kwargs): 662 | model = ResNet_Cifar(BasicBlock, [16, 16, 16], **kwargs) 663 | return model 664 | 665 | def resnet104a_cifar(**kwargs): 666 | model = ResNet_Cifar(BasicBlock, [17, 17, 17], **kwargs) 667 | return model 668 | 669 | def resnet110a_cifar(**kwargs): 670 | model = ResNet_Cifar(BasicBlock, [18, 18, 18], **kwargs) 671 | return model 672 | 673 | def resnet110plain_cifar(**kwargs): 674 | model = ResNet_Cifar(BasicBlockPlain, [18, 18, 18], **kwargs) 675 | return model 676 | 677 | def resnet1202_cifar(**kwargs): 678 | model = ResNet_Cifar(BasicBlock, [200, 200, 200], **kwargs) 679 | return model 680 | 681 | 682 | def resnet164_cifar(**kwargs): 683 | model = ResNet_Cifar(Bottleneck, [18, 18, 18], **kwargs) 684 | return model 685 | 686 | 687 | def resnet1001_cifar(**kwargs): 688 | model = ResNet_Cifar(Bottleneck, [111, 111, 111], **kwargs) 689 | return model 690 | 691 | 692 | def preact_resnet110_cifar(**kwargs): 693 | model = PreAct_ResNet_Cifar(PreActBasicBlock, [18, 18, 18], **kwargs) 694 | return model 695 | 696 | 697 | def preact_resnet164_cifar(**kwargs): 698 | model = PreAct_ResNet_Cifar(PreActBottleneck, [18, 18, 18], **kwargs) 699 | return model 700 | 
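# --- Editor's sketch (not part of the original file) -------------------------
# The resnetNNa_cifar factories above all build ResNet_Cifar(BasicBlock,
# [n, n, n]) with depth = 6n + 2 (20 -> n=3, 56 -> n=9, ..., 110 -> n=18), so
# they could be collapsed into a single hypothetical helper:
#
#     def _resnet_a_cifar(depth, **kwargs):
#         assert (depth - 2) % 6 == 0, "BasicBlock CIFAR ResNets need depth = 6n + 2"
#         n = (depth - 2) // 6
#         return ResNet_Cifar(BasicBlock, [n, n, n], **kwargs)
# ------------------------------------------------------------------------------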
701 | 702 | def preact_resnet1001_cifar(**kwargs): 703 | model = PreAct_ResNet_Cifar(PreActBottleneck, [111, 111, 111], **kwargs) 704 | return model 705 | 706 | 707 | if __name__ == '__main__': 708 | net = resnet20a_cifar() 709 | y = net(torch.randn(1, 3, 32, 32))[0] # CIFAR-sized input; forward returns (logits, feats, corrs, corrs_) 710 | print(net) 711 | print(y.size()) 712 | -------------------------------------------------------------------------------- /lib/networks/resnet_cifar_analysis1.py: -------------------------------------------------------------------------------- 1 | ''' 2 | resnet for cifar in pytorch 3 | 4 | Reference: 5 | [1] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016. 6 | [2] K. He, X. Zhang, S. Ren, and J. Sun. Identity mappings in deep residual networks. In ECCV, 2016. 7 | ''' 8 | import numpy as np # used below (np.ceil); imported explicitly rather than via the star-import 9 | import torch 10 | import torch.nn as nn 11 | import math 12 | import torch.nn.functional as F # used below (F.softmax, F.interpolate); imported explicitly 13 | from lib.layers import * 14 | 15 | def conv3x3(in_planes, out_planes, stride=1): 16 | " 3x3 convolution with padding " 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 18 | 19 | class _CrossNeuronBlock(nn.Module): 20 | def __init__(self, in_channels, in_height, in_width, 21 | nblocks_channel=8, 22 | spatial_height=32, spatial_width=32, 23 | reduction=8, size_is_consistant=True, 24 | communication=True, 25 | enc_dec=True): 26 | # nblock_channel: number of block along channel axis 27 | # spatial_size: spatial_size 28 | super(_CrossNeuronBlock, self).__init__() 29 | 30 | self.communication = communication 31 | self.enc_dec = enc_dec 32 | 33 | # set channel splits 34 | if in_channels <= 512: 35 | self.nblocks_channel = 1 36 | else: 37 | self.nblocks_channel = in_channels // 512 38 | block_size = in_channels // self.nblocks_channel 39 | block = torch.Tensor(block_size, block_size).fill_(1) 40 | self.mask = torch.Tensor(in_channels, in_channels).fill_(0) 41 | for i in range(self.nblocks_channel): 42 | self.mask[i * block_size:(i + 1) * block_size, i * block_size:(i + 1) * block_size].copy_(block) 43 | 44 | # set spatial splits 45 | if in_height * in_width < 16 * 16 and size_is_consistant: 46 | self.spatial_area = in_height * in_width 47 | self.spatial_height = in_height 48 | self.spatial_width = in_width 49 | else: 50 | self.spatial_area = spatial_height * spatial_width 51 | self.spatial_height = spatial_height 52 | self.spatial_width = spatial_width 53 | 54 | self.fc_in = nn.Sequential( 55 | nn.Conv1d(self.spatial_area, self.spatial_area // reduction, 1, 1, 0, bias=True), 56 | nn.ReLU(True), 57 | nn.Conv1d(self.spatial_area // reduction, self.spatial_area, 1, 1, 0, bias=True), 58 | ) 59 | 60 | self.fc_out = nn.Sequential( 61 | nn.Conv1d(self.spatial_area, self.spatial_area // reduction, 1, 1, 0, bias=True), 62 | nn.ReLU(True), 63 | nn.Conv1d(self.spatial_area // reduction, self.spatial_area, 1, 1, 0, bias=True), 64 | ) 65 | 66 | self.bn = nn.BatchNorm1d(self.spatial_area) 67 | 68 | self.initialize() 69 | 70 | def initialize(self): 71 | for m in self.modules(): 72 | if isinstance(m, nn.Conv1d): 73 | nn.init.kaiming_normal_(m.weight) 74 | elif isinstance(m, nn.BatchNorm1d): 75 | nn.init.constant_(m.weight, 1) 76 | nn.init.constant_(m.bias, 0) 77 | 78 | def forward(self, x): 79 | ''' 80 | :param x: (bt, c, h, w) 81 | :return: 82 | ''' 83 | bt, c, h, w = x.shape 84 | residual = x 85 | x_stretch = x.view(bt, c, h * w) 86 | spblock_h = int(np.ceil(h / self.spatial_height)) 87 | spblock_w = int(np.ceil(w / self.spatial_width)) 88 | stride_h = int((h - self.spatial_height) / (spblock_h - 1)) if spblock_h > 
1 else 0 89 | stride_w = int((w - self.spatial_width) / (spblock_w - 1)) if spblock_w > 1 else 0 90 | 91 | if self.spatial_height == h and self.spatial_width == w: 92 | x_stacked = x_stretch # (b) x c x (h * w) 93 | x_stacked = x_stacked.view(bt * self.nblocks_channel, c // self.nblocks_channel, -1) 94 | x_v = x_stacked.permute(0, 2, 1).contiguous() # (b) x (h * w) x c 95 | 96 | if self.enc_dec: 97 | x_v = self.fc_in(x_v) # (b) x (h * w) x c 98 | 99 | x_m = x_v.mean(1).view(-1, 1, c // self.nblocks_channel) # (b * h * w) x 1 x c 100 | score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # (b * h * w) x c x c 101 | # score = torch.bmm(x_v.transpose(1, 2).contiguous(), x_v) 102 | # x_v = F.dropout(x_v, 0.1, self.training) 103 | # score.masked_fill_(self.mask.unsqueeze(0).expand_as(score).type_as(score).eq(0), -np.inf) 104 | attn = F.softmax(score, dim=1) # (b * h * w) x c x c 105 | 106 | if self.communication: 107 | if self.enc_dec: 108 | out = self.bn(self.fc_out(torch.bmm(x_v, attn))) # (b) x (h * w) x c 109 | else: 110 | out = self.bn(torch.bmm(x_v, attn)) # (b) x (h * w) x c 111 | else: 112 | out = self.bn(self.fc_out(x_v)) # (b) x (h * w) x c 113 | 114 | out = out.permute(0, 2, 1).contiguous().view(bt, c, h, w) 115 | return (residual + out) 116 | else: 117 | x = F.interpolate(x, (self.spatial_height, self.spatial_width)) 118 | x_stretch = x.view(bt, c, self.spatial_height * self.spatial_width) 119 | x_stretch = x.view(bt * self.nblocks_channel, c // self.nblocks_channel, self.spatial_height * self.spatial_width) 120 | 121 | x_stacked = x_stretch # (b) x c x (h * w) 122 | x_v = x_stacked.permute(0, 2, 1).contiguous() # (b) x (h * w) x c 123 | 124 | if self.enc_dec: 125 | x_v = self.fc_in(x_v) # (b) x (h * w) x c 126 | 127 | x_m = x_v.mean(1).view(-1, 1, c // self.nblocks_channel) # (b * h * w) x 1 x c 128 | score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # (b * h * w) x c x c 129 | # score = torch.bmm(x_v.transpose(1, 2).contiguous(), x_v) 130 | # x_v = F.dropout(x_v, 0.1, self.training) 131 | # score.masked_fill_(self.mask.unsqueeze(0).expand_as(score).type_as(score).eq(0), -np.inf) 132 | attn = F.softmax(score, dim=1) # (b * h * w) x c x c 133 | if self.communication: 134 | if self.enc_dec: 135 | out = self.bn(self.fc_out(torch.bmm(x_v, attn))) # (b) x (h * w) x c 136 | else: 137 | out = self.bn(torch.bmm(x_v, attn)) # (b) x (h * w) x c 138 | else: 139 | out = self.bn(self.fc_out(x_v)) # (b) x (h * w) x c 140 | out = out.permute(0, 2, 1).contiguous().view(bt, c, self.spatial_height, self.spatial_width) 141 | out = F.interpolate(out, (h, w)) 142 | return (residual + out) 143 | 144 | class BasicBlockPlain(nn.Module): 145 | expansion=1 146 | 147 | def __init__(self, insize, inplanes, planes, stride=1, downsample=None, use_se=False, reduction=8): 148 | super(BasicBlockPlain, self).__init__() 149 | self.use_se = use_se 150 | self.conv1 = conv3x3(inplanes, planes, stride) 151 | self.bn1 = nn.BatchNorm2d(planes) 152 | self.relu = nn.ReLU(inplace=True) 153 | self.conv2 = conv3x3(planes, planes) 154 | self.bn2 = nn.BatchNorm2d(planes) 155 | if self.use_se: 156 | self.se = SELayer(planes, reduction=reduction) 157 | self.downsample = downsample 158 | self.stride = stride 159 | 160 | def forward(self, x): 161 | # out = self.gt_spatial(out) 162 | residual = x 163 | 164 | out = self.conv1(x) 165 | out = self.bn1(out) 166 | out = self.relu(out) 167 | 168 | out = self.conv2(out) 169 | out = self.bn2(out) 170 | if self.use_se: 171 | out = self.se(out) 172 | 173 | if self.downsample is not None: 
174 | residual = self.downsample(x) 175 | 176 | out = self.relu(out) # plain (non-residual) block: no skip connection is added 177 | return out 178 | 179 | 180 | class BasicBlock(nn.Module): 181 | expansion=1 182 | 183 | def __init__(self, insize, inplanes, planes, stride=1, downsample=None, use_se=False, use_gt=False, reduction=8): 184 | super(BasicBlock, self).__init__() 185 | self.use_se = use_se 186 | self.use_gt = use_gt 187 | self.conv1 = conv3x3(inplanes, planes, stride) 188 | self.bn1 = nn.BatchNorm2d(planes) 189 | self.relu = nn.ReLU(inplace=True) 190 | 191 | # if not self.use_gt: 192 | self.conv2 = conv3x3(planes, planes) 193 | 194 | self.bn2 = nn.BatchNorm2d(planes) 195 | if self.use_se: 196 | self.se = SELayer(planes, reduction=reduction) 197 | if self.use_gt: 198 | self.gt = _CrossNeuronBlock(planes, insize, insize, spatial_height=16, spatial_width=16, reduction=8) 199 | 200 | self.downsample = downsample 201 | self.stride = stride 202 | 203 | def forward(self, x): 204 | # out = self.gt_spatial(out) 205 | residual = x 206 | 207 | out = self.conv1(x) 208 | out = self.bn1(out) 209 | out = self.relu(out) 210 | 211 | out = self.conv2(out) 212 | out = self.bn2(out) 213 | 214 | if self.use_gt: 215 | out = self.gt(out) 216 | 217 | if self.use_se: 218 | out = self.se(out) 219 | 220 | if self.downsample is not None: 221 | residual = self.downsample(x) 222 | 223 | out += residual 224 | out = self.relu(out) 225 | return out 226 | 227 | 228 | class Bottleneck(nn.Module): 229 | expansion=4 230 | 231 | def __init__(self, insize, inplanes, planes, stride=1, downsample=None, use_se=False, use_gt=False, reduction=8): 232 | super(Bottleneck, self).__init__() 233 | self.use_se = use_se 234 | self.use_gt = use_gt 235 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 236 | self.bn1 = nn.BatchNorm2d(planes) 237 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 238 | self.bn2 = nn.BatchNorm2d(planes) 239 | self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False) 240 | self.bn3 = nn.BatchNorm2d(planes*4) 241 | self.relu = nn.ReLU(inplace=True) 242 | if self.use_se: 243 | self.se = SELayer(planes * 4, reduction=reduction) # applied to the (planes * 4)-channel output below 244 | if self.use_gt: 245 | self.gt = _CrossNeuronBlock(planes * 4, insize, insize, spatial_height=insize, spatial_width=insize, reduction=8) # uses the _CrossNeuronBlock defined above 246 | self.downsample = downsample 247 | self.stride = stride 248 | 249 | def forward(self, x): 250 | residual = x 251 | 252 | out = self.conv1(x) 253 | out = self.bn1(out) 254 | out = self.relu(out) 255 | 256 | out = self.conv2(out) 257 | out = self.bn2(out) 258 | out = self.relu(out) 259 | 260 | out = self.conv3(out) 261 | out = self.bn3(out) 262 | 263 | if self.use_se: 264 | out = self.se(out) 265 | if self.use_gt: 266 | out = self.gt(out) 267 | 268 | if self.downsample is not None: 269 | residual = self.downsample(x) 270 | 271 | out += residual 272 | out = self.relu(out) 273 | 274 | return out 275 | 276 | # class CommunicationBlock(nn.Module): 277 | 278 | 279 | class PreActBasicBlock(nn.Module): 280 | expansion = 1 281 | 282 | def __init__(self, inplanes, planes, stride=1, downsample=None): 283 | super(PreActBasicBlock, self).__init__() 284 | self.bn1 = nn.BatchNorm2d(inplanes) 285 | self.relu = nn.ReLU(inplace=True) 286 | self.conv1 = conv3x3(inplanes, planes, stride) 287 | self.bn2 = nn.BatchNorm2d(planes) 288 | self.conv2 = conv3x3(planes, planes) 289 | self.downsample = downsample 290 | self.stride = stride 291 | 292 | def forward(self, x): 293 | residual = x 294 | 295 | out = self.bn1(x) 296 | out = self.relu(out) 297 | 298 | if
self.downsample is not None: 299 | residual = self.downsample(out) 300 | 301 | out = self.conv1(out) 302 | 303 | out = self.bn2(out) 304 | out = self.relu(out) 305 | out = self.conv2(out) 306 | 307 | out += residual 308 | 309 | return out 310 | 311 | 312 | class PreActBottleneck(nn.Module): 313 | expansion = 4 314 | 315 | def __init__(self, inplanes, planes, stride=1, downsample=None): 316 | super(PreActBottleneck, self).__init__() 317 | self.bn1 = nn.BatchNorm2d(inplanes) 318 | self.relu = nn.ReLU(inplace=True) 319 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 320 | self.bn2 = nn.BatchNorm2d(planes) 321 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 322 | self.bn3 = nn.BatchNorm2d(planes) 323 | self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False) 324 | self.downsample = downsample 325 | self.stride = stride 326 | 327 | def forward(self, x): 328 | residual = x 329 | 330 | out = self.bn1(x) 331 | out = self.relu(out) 332 | 333 | if self.downsample is not None: 334 | residual = self.downsample(out) 335 | 336 | out = self.conv1(out) 337 | 338 | out = self.bn2(out) 339 | out = self.relu(out) 340 | out = self.conv2(out) 341 | 342 | out = self.bn3(out) 343 | out = self.relu(out) 344 | out = self.conv3(out) 345 | 346 | out += residual 347 | 348 | return out 349 | 350 | 351 | class ResNet_Cifar(nn.Module): 352 | def __init__(self, block, layers, num_classes=10, insert_layers=["1", "2", "3"], depth=1, 353 | has_gtlayer=False, has_selayer=False, communication=True, enc_dec=True): 354 | super(ResNet_Cifar, self).__init__() 355 | self.inplanes = 16 356 | self.insize = 32 357 | self.layers = insert_layers 358 | self.depth = depth 359 | self.has_gtlayer = has_gtlayer 360 | self.has_selayer = has_selayer 361 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 362 | self.bn1 = nn.BatchNorm2d(16) 363 | self.relu = nn.ReLU(inplace=True) 364 | self.layer1 = self._make_layer(block, 16, layers[0], has_selayer=has_selayer, 365 | has_gtlayer=(has_gtlayer if ("1" in self.layers) else False)) 366 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2, has_selayer=has_selayer, 367 | has_gtlayer=(has_gtlayer if ("2" in self.layers) else False)) 368 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2, has_selayer=has_selayer, 369 | has_gtlayer=(has_gtlayer if ("3" in self.layers) else False)) 370 | 371 | # if self.has_gtlayer: 372 | # if "1" in self.layers: 373 | # layers = [] 374 | # for i in range(self.depth): 375 | # layers.append(_CrossNeuronBlock(16, 32, 32, spatial_height=16, spatial_width=16, 376 | # communication=communication, enc_dec=enc_dec)) 377 | # self.nclayer1 = nn.Sequential(*layers) 378 | # 379 | # if "2" in self.layers: 380 | # layers = [] 381 | # for i in range(self.depth): 382 | # layers.append(_CrossNeuronBlock(16, 32, 32, spatial_height=16, spatial_width=16, 383 | # communication=communication, enc_dec=enc_dec)) 384 | # self.nclayer2 = nn.Sequential(*layers) 385 | # 386 | # if "3" in self.layers: 387 | # layers = [] 388 | # for i in range(self.depth): 389 | # layers.append(_CrossNeuronBlock(32, 16, 16, spatial_height=16, spatial_width=16, 390 | # communication=communication, enc_dec=enc_dec)) 391 | # self.nclayer3 = nn.Sequential(*layers) 392 | 393 | self.avgpool = nn.AvgPool2d(8, stride=1) 394 | self.fc = nn.Linear(64 * block.expansion, num_classes) 395 | 396 | def _make_layer(self, block, planes, blocks, stride=1, has_selayer=False, has_gtlayer=False): 397 | 
downsample = None 398 | if stride != 1 or self.inplanes != planes * block.expansion: 399 | downsample = nn.Sequential( 400 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 401 | nn.BatchNorm2d(planes * block.expansion) 402 | ) 403 | 404 | layers = [] 405 | layers.append(block(self.insize, self.inplanes, planes, stride, downsample, use_se=has_selayer, use_gt=has_gtlayer)) 406 | self.inplanes = planes * block.expansion 407 | self.insize = int(self.insize / stride) 408 | for _ in range(1, blocks - 1): 409 | layers.append(block(self.insize, self.inplanes, planes, use_se=has_selayer, use_gt=False)) 410 | layers.append(block(self.insize, self.inplanes, planes, use_se=has_selayer, use_gt=False)) 411 | return nn.Sequential(*layers) 412 | 413 | def _correlation(self, x): 414 | b, c, h, w = x.shape 415 | x_v = x.clone().detach().view(b, c, -1) # b x c x (hw) 416 | x_m = x_v.mean(1).unsqueeze(1) # b x 1 x (hw) 417 | x_c = x_v - x_m # b x c x (hw) 418 | # x_c = x_v 419 | num = torch.bmm(x_c, x_c.transpose(1, 2)) # b x c x c 420 | x_mode = torch.sqrt(torch.sum(x_c ** 2, 2).unsqueeze(2)) # b x c x 1 421 | dec = torch.bmm(x_mode, x_mode.transpose(1, 2).contiguous()) # b x c x c 422 | out = num / dec 423 | out = torch.abs(out) 424 | return out.mean() 425 | 426 | def forward(self, x): 427 | x = self.conv1(x) 428 | x = self.bn1(x) 429 | x = self.relu(x) 430 | 431 | x0 = x.clone().detach() 432 | corr1 = self._correlation(x) 433 | 434 | x = self.layer1(x) 435 | x1 = x.clone().detach() 436 | corr2 = self._correlation(x) 437 | 438 | x = self.layer2(x) 439 | x2 = x.clone().detach() 440 | corr3 = self._correlation(x) 441 | 442 | x = self.layer3(x) 443 | x3 = x.clone().detach() 444 | corr4 = self._correlation(x) 445 | 446 | x = self.avgpool(x) 447 | x = x.view(x.size(0), -1) 448 | x = self.fc(x) 449 | 450 | return x, (x0, x1, x2, x3), (corr1.item(), corr2.item(), corr3.item(), corr4.item()), (corr1.item(), corr2.item(), corr3.item(), corr4.item()) 451 | 452 | 453 | class PreAct_ResNet_Cifar(nn.Module): 454 | 455 | def __init__(self, block, layers, num_classes=10, has_gtlayer=False, has_selayer=False): 456 | super(PreAct_ResNet_Cifar, self).__init__() 457 | self.inplanes = 16 458 | self.insize = 32 459 | self.has_gtlayer = has_gtlayer 460 | self.has_selayer = has_selayer 461 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 462 | self.layer1 = self._make_layer(block, 16, layers[0]) 463 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 464 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 465 | 466 | self.bn = nn.BatchNorm2d(64*block.expansion) 467 | self.relu = nn.ReLU(inplace=True) 468 | self.avgpool = nn.AvgPool2d(8, stride=1) 469 | self.fc = nn.Linear(64*block.expansion, num_classes) 470 | 471 | for m in self.modules(): 472 | if isinstance(m, nn.Conv2d): 473 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 474 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 475 | elif isinstance(m, nn.BatchNorm2d): 476 | m.weight.data.fill_(1) 477 | m.bias.data.zero_() 478 | 479 | def _make_layer(self, block, planes, blocks, stride=1): 480 | downsample = None 481 | if stride != 1 or self.inplanes != planes*block.expansion: 482 | downsample = nn.Sequential( 483 | nn.Conv2d(self.inplanes, planes*block.expansion, kernel_size=1, stride=stride, bias=False) 484 | ) 485 | 486 | layers = [] 487 | layers.append(block(self.inplanes, planes, stride, downsample)) 488 | self.inplanes = planes*block.expansion 489 | for _ in range(1, blocks): 490 | layers.append(block(self.inplanes, planes)) 491 | return nn.Sequential(*layers) 492 | 493 | def forward(self, x): 494 | x = self.conv1(x) 495 | 496 | x = self.layer1(x) 497 | x = self.layer2(x) 498 | x = self.layer3(x) 499 | 500 | x = self.bn(x) 501 | x = self.relu(x) 502 | x = self.avgpool(x) 503 | x = x.view(x.size(0), -1) 504 | x = self.fc(x) 505 | 506 | return x 507 | 508 | 509 | 510 | def resnet20a_cifar(**kwargs): 511 | model = ResNet_Cifar(BasicBlock, [3, 3, 3], **kwargs) 512 | return model 513 | 514 | def resnet20plain_cifar(**kwargs): 515 | model = ResNet_Cifar(BasicBlockPlain, [3, 3, 3], **kwargs) 516 | return model 517 | 518 | def resnet32_cifar(**kwargs): 519 | model = ResNet_Cifar(BasicBlock, [5, 5, 5], **kwargs) 520 | return model 521 | 522 | 523 | def resnet44_cifar(**kwargs): 524 | model = ResNet_Cifar(BasicBlock, [7, 7, 7], **kwargs) 525 | return model 526 | 527 | 528 | def resnet56a_cifar(**kwargs): 529 | model = ResNet_Cifar(BasicBlock, [9, 9, 9], **kwargs) 530 | return model 531 | 532 | def resnet62a_cifar(**kwargs): 533 | model = ResNet_Cifar(BasicBlock, [10, 10, 10], **kwargs) 534 | return model 535 | 536 | def resnet68a_cifar(**kwargs): 537 | model = ResNet_Cifar(BasicBlock, [11, 11, 11], **kwargs) 538 | return model 539 | 540 | def resnet74a_cifar(**kwargs): 541 | model = ResNet_Cifar(BasicBlock, [12, 12, 12], **kwargs) 542 | return model 543 | 544 | def resnet80a_cifar(**kwargs): 545 | model = ResNet_Cifar(BasicBlock, [13, 13, 13], **kwargs) 546 | return model 547 | 548 | def resnet86a_cifar(**kwargs): 549 | model = ResNet_Cifar(BasicBlock, [14, 14, 14], **kwargs) 550 | return model 551 | 552 | def resnet92a_cifar(**kwargs): 553 | model = ResNet_Cifar(BasicBlock, [15, 15, 15], **kwargs) 554 | return model 555 | 556 | def resnet98a_cifar(**kwargs): 557 | model = ResNet_Cifar(BasicBlock, [16, 16, 16], **kwargs) 558 | return model 559 | 560 | def resnet104a_cifar(**kwargs): 561 | model = ResNet_Cifar(BasicBlock, [17, 17, 17], **kwargs) 562 | return model 563 | 564 | def resnet110a_cifar(**kwargs): 565 | model = ResNet_Cifar(BasicBlock, [18, 18, 18], **kwargs) 566 | return model 567 | 568 | def resnet110plain_cifar(**kwargs): 569 | model = ResNet_Cifar(BasicBlockPlain, [18, 18, 18], **kwargs) 570 | return model 571 | 572 | def resnet1202_cifar(**kwargs): 573 | model = ResNet_Cifar(BasicBlock, [200, 200, 200], **kwargs) 574 | return model 575 | 576 | 577 | def resnet164_cifar(**kwargs): 578 | model = ResNet_Cifar(Bottleneck, [18, 18, 18], **kwargs) 579 | return model 580 | 581 | 582 | def resnet1001_cifar(**kwargs): 583 | model = ResNet_Cifar(Bottleneck, [111, 111, 111], **kwargs) 584 | return model 585 | 586 | 587 | def preact_resnet110_cifar(**kwargs): 588 | model = PreAct_ResNet_Cifar(PreActBasicBlock, [18, 18, 18], **kwargs) 589 | return model 590 | 591 | 592 | def preact_resnet164_cifar(**kwargs): 593 | model = PreAct_ResNet_Cifar(PreActBottleneck, [18, 18, 18], **kwargs) 594 | return model 595 | 
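# --- Illustrative usage sketch (added for clarity; not part of the original training code) ---
# The analysis models built from ResNet_Cifar above return a 4-tuple of
# (logits, per-stage features, per-stage correlation before the cross-neuron layer,
# per-stage correlation after it), so callers must unpack four values. This sketch
# assumes CIFAR-sized 32x32 inputs and uses only factory names defined earlier in
# this file; treat it as a smoke test, not an API reference.
def _smoke_test_resnet_cifar_analysis():
    x = torch.randn(2, 3, 32, 32)  # AvgPool2d(8) and the fc layer assume 32x32 inputs
    model = resnet56a_cifar(num_classes=10, has_gtlayer=False, has_selayer=False)
    logits, feats, corr_before, corr_after = model(x)
    assert logits.shape == (2, 10)
    # corr_* hold the mean absolute pairwise neuron correlation at each stage
    return logits, corr_before, corr_after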
596 | 597 | def preact_resnet1001_cifar(**kwargs): 598 | model = PreAct_ResNet_Cifar(PreActBottleneck, [111, 111, 111], **kwargs) 599 | return model 600 | 601 | 602 | if __name__ == '__main__': 603 | net = resnet20a_cifar() 604 | y, feats, corrs_bf, corrs_af = net(torch.randn(1, 3, 32, 32)) 605 | print(net) 606 | print(y.size()) 607 | -------------------------------------------------------------------------------- /lib/networks/resnext.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | """ 3 | Creates a ResNeXt Model as defined in: 4 | Xie, S., Girshick, R., Dollar, P., Tu, Z., & He, K. (2016). 5 | Aggregated residual transformations for deep neural networks. 6 | arXiv preprint arXiv:1611.05431. 7 | import from https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua 8 | """ 9 | import math 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.nn import init 13 | import torch 14 | 15 | __all__ = ['resnext50', 'resnext101', 'resnext152'] 16 | 17 | class Bottleneck(nn.Module): 18 | """ 19 | ResNeXt bottleneck type C 20 | """ 21 | expansion = 4 22 | 23 | def __init__(self, inplanes, planes, baseWidth, cardinality, stride=1, downsample=None): 24 | """ Constructor 25 | Args: 26 | inplanes: input channel dimensionality 27 | planes: output channel dimensionality 28 | baseWidth: base width. 29 | cardinality: num of convolution groups. 30 | stride: conv stride. Replaces pooling layer. 31 | """ 32 | super(Bottleneck, self).__init__() 33 | 34 | D = int(math.floor(planes * (baseWidth / 64))) 35 | C = cardinality 36 | 37 | self.conv1 = nn.Conv2d(inplanes, D*C, kernel_size=1, stride=1, padding=0, bias=False) 38 | self.bn1 = nn.BatchNorm2d(D*C) 39 | self.conv2 = nn.Conv2d(D*C, D*C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False) 40 | self.bn2 = nn.BatchNorm2d(D*C) 41 | self.conv3 = nn.Conv2d(D*C, planes * 4, kernel_size=1, stride=1, padding=0, bias=False) 42 | self.bn3 = nn.BatchNorm2d(planes * 4) 43 | self.relu = nn.ReLU(inplace=True) 44 | 45 | self.downsample = downsample 46 | 47 | def forward(self, x): 48 | residual = x 49 | 50 | out = self.conv1(x) 51 | out = self.bn1(out) 52 | out = self.relu(out) 53 | 54 | out = self.conv2(out) 55 | out = self.bn2(out) 56 | out = self.relu(out) 57 | 58 | out = self.conv3(out) 59 | out = self.bn3(out) 60 | 61 | if self.downsample is not None: 62 | residual = self.downsample(x) 63 | 64 | out += residual 65 | out = self.relu(out) 66 | 67 | return out 68 | 69 | 70 | class ResNeXt(nn.Module): 71 | """ 72 | ResNext optimized for the ImageNet dataset, as specified in 73 | https://arxiv.org/pdf/1611.05431.pdf 74 | """ 75 | def __init__(self, baseWidth, cardinality, layers, num_classes): 76 | """ Constructor 77 | Args: 78 | baseWidth: baseWidth for ResNeXt. 79 | cardinality: number of convolution groups.
80 | layers: config of layers, e.g., [3, 4, 6, 3] 81 | num_classes: number of classes 82 | """ 83 | super(ResNeXt, self).__init__() 84 | block = Bottleneck 85 | 86 | self.cardinality = cardinality 87 | self.baseWidth = baseWidth 88 | self.num_classes = num_classes 89 | self.inplanes = 64 90 | self.output_size = 64 91 | 92 | self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False) 93 | self.bn1 = nn.BatchNorm2d(64) 94 | self.relu = nn.ReLU(inplace=True) 95 | self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 96 | self.layer1 = self._make_layer(block, 64, layers[0]) 97 | self.layer2 = self._make_layer(block, 128, layers[1], 2) 98 | self.layer3 = self._make_layer(block, 256, layers[2], 2) 99 | self.layer4 = self._make_layer(block, 512, layers[3], 2) 100 | self.avgpool = nn.AvgPool2d(7) 101 | self.fc = nn.Linear(512 * block.expansion, num_classes) 102 | 103 | for m in self.modules(): 104 | if isinstance(m, nn.Conv2d): 105 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 106 | m.weight.data.normal_(0, math.sqrt(2. / n)) 107 | elif isinstance(m, nn.BatchNorm2d): 108 | m.weight.data.fill_(1) 109 | m.bias.data.zero_() 110 | 111 | def _make_layer(self, block, planes, blocks, stride=1): 112 | """ Stack n bottleneck modules where n is inferred from the depth of the network. 113 | Args: 114 | block: block type used to construct ResNext 115 | planes: number of output channels (need to multiply by block.expansion) 116 | blocks: number of blocks to be built 117 | stride: factor to reduce the spatial dimensionality in the first bottleneck of the block. 118 | Returns: a Module consisting of n sequential bottlenecks. 119 | """ 120 | downsample = None 121 | if stride != 1 or self.inplanes != planes * block.expansion: 122 | downsample = nn.Sequential( 123 | nn.Conv2d(self.inplanes, planes * block.expansion, 124 | kernel_size=1, stride=stride, bias=False), 125 | nn.BatchNorm2d(planes * block.expansion), 126 | ) 127 | 128 | layers = [] 129 | layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality, stride, downsample)) 130 | self.inplanes = planes * block.expansion 131 | for i in range(1, blocks): 132 | layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality)) 133 | 134 | return nn.Sequential(*layers) 135 | 136 | def forward(self, x): 137 | x = self.conv1(x) 138 | x = self.bn1(x) 139 | x = self.relu(x) 140 | x = self.maxpool1(x) 141 | x = self.layer1(x) 142 | x = self.layer2(x) 143 | x = self.layer3(x) 144 | x = self.layer4(x) 145 | x = self.avgpool(x) 146 | x = x.view(x.size(0), -1) 147 | x = self.fc(x) 148 | 149 | return x 150 | 151 | 152 | def resnext50(baseWidth=32, cardinality=4, num_classes=1000): 153 | """ 154 | Construct ResNeXt-50. 155 | """ 156 | model = ResNeXt(baseWidth, cardinality, [3, 4, 6, 3], num_classes) 157 | return model 158 | 159 | 160 | def resnext101(baseWidth, cardinality): 161 | """ 162 | Construct ResNeXt-101. 163 | """ 164 | model = ResNeXt(baseWidth, cardinality, [3, 4, 23, 3], 1000) 165 | return model 166 | 167 | 168 | def resnext152(baseWidth, cardinality): 169 | """ 170 | Construct ResNeXt-152. 171 | """ 172 | model = ResNeXt(baseWidth, cardinality, [3, 8, 36, 3], 1000) 173 | return model 174 | -------------------------------------------------------------------------------- /lib/networks/resnext_cifar.py: -------------------------------------------------------------------------------- 1 | """ 2 | resneXt for cifar with pytorch 3 | 4 | Reference: 5 | [1] S. Xie, G. Ross, P. Dollar, Z. Tu and K. 
He. Aggregated residual transformations for deep neural networks. In CVPR, 2017 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import math, numpy as np 12 | from lib.layers import * 13 | 14 | class _CrossNeuronBlock(nn.Module): 15 | def __init__(self, in_channels, in_height, in_width, 16 | nblocks_channel=8, 17 | spatial_height=32, spatial_width=32, 18 | reduction=8, size_is_consistant=True, 19 | communication=True, 20 | enc_dec=True): 21 | # nblocks_channel: number of blocks along the channel axis 22 | # spatial_height/spatial_width: size of each spatial block 23 | super(_CrossNeuronBlock, self).__init__() 24 | 25 | self.communication = communication 26 | self.enc_dec = enc_dec 27 | 28 | # set channel splits 29 | if in_channels <= 512: 30 | self.nblocks_channel = 1 31 | else: 32 | self.nblocks_channel = in_channels // 512 33 | block_size = in_channels // self.nblocks_channel 34 | block = torch.Tensor(block_size, block_size).fill_(1) 35 | self.mask = torch.Tensor(in_channels, in_channels).fill_(0) 36 | for i in range(self.nblocks_channel): 37 | self.mask[i * block_size:(i + 1) * block_size, i * block_size:(i + 1) * block_size].copy_(block) 38 | 39 | # set spatial splits 40 | if in_height * in_width < 16 * 16 and size_is_consistant: 41 | self.spatial_area = in_height * in_width 42 | self.spatial_height = in_height 43 | self.spatial_width = in_width 44 | else: 45 | self.spatial_area = spatial_height * spatial_width 46 | self.spatial_height = spatial_height 47 | self.spatial_width = spatial_width 48 | 49 | self.fc_in = nn.Sequential( 50 | nn.Conv1d(self.spatial_area, self.spatial_area // reduction, 1, 1, 0, bias=True), 51 | nn.ReLU(True), 52 | nn.Conv1d(self.spatial_area // reduction, self.spatial_area, 1, 1, 0, bias=True), 53 | ) 54 | 55 | self.fc_out = nn.Sequential( 56 | nn.Conv1d(self.spatial_area, self.spatial_area // reduction, 1, 1, 0, bias=True), 57 | nn.ReLU(True), 58 | nn.Conv1d(self.spatial_area // reduction, self.spatial_area, 1, 1, 0, bias=True), 59 | ) 60 | 61 | self.bn = nn.BatchNorm1d(self.spatial_area) 62 | 63 | self.initialize() 64 | 65 | def initialize(self): 66 | for m in self.modules(): 67 | if isinstance(m, nn.Conv1d): 68 | nn.init.kaiming_normal_(m.weight) 69 | elif isinstance(m, nn.BatchNorm1d): 70 | nn.init.constant_(m.weight, 1) 71 | nn.init.constant_(m.bias, 0) 72 | 73 | def forward(self, x): 74 | ''' 75 | :param x: (bt, c, h, w) 76 | :return: 77 | ''' 78 | bt, c, h, w = x.shape 79 | residual = x 80 | x_stretch = x.view(bt, c, h * w) 81 | spblock_h = int(np.ceil(h / self.spatial_height)) 82 | spblock_w = int(np.ceil(w / self.spatial_width)) 83 | stride_h = int((h - self.spatial_height) / (spblock_h - 1)) if spblock_h > 1 else 0 84 | stride_w = int((w - self.spatial_width) / (spblock_w - 1)) if spblock_w > 1 else 0 85 | 86 | if self.spatial_height == h and self.spatial_width == w: 87 | x_stacked = x_stretch # (b) x c x (h * w) 88 | x_stacked = x_stacked.view(bt * self.nblocks_channel, c // self.nblocks_channel, -1) 89 | x_v = x_stacked.permute(0, 2, 1).contiguous() # (b) x (h * w) x c 90 | 91 | if self.enc_dec: 92 | x_v = self.fc_in(x_v) # (b) x (h * w) x c 93 | 94 | x_m = x_v.mean(1).view(-1, 1, c // self.nblocks_channel) # (b * h * w) x 1 x c 95 | score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # (b * h * w) x c x c 96 | # score = torch.bmm(x_v.transpose(1, 2).contiguous(), x_v) 97 | # x_v = F.dropout(x_v, 0.1, self.training) 98 | # score.masked_fill_(self.mask.unsqueeze(0).expand_as(score).type_as(score).eq(0), -np.inf) 99 | attn = F.softmax(score, dim=1) # (b * h * w) x c x c 100 | 101 |
if self.communication: 102 | if self.enc_dec: 103 | out = self.bn(self.fc_out(torch.bmm(x_v, attn))) # (b) x (h * w) x c 104 | else: 105 | out = self.bn(torch.bmm(x_v, attn)) # (b) x (h * w) x c 106 | else: 107 | out = self.bn(self.fc_out(x_v)) # (b) x (h * w) x c 108 | 109 | out = F.dropout(out.permute(0, 2, 1).contiguous().view(bt, c, h, w), 0.0, self.training) 110 | return F.relu(residual + out) 111 | else: 112 | x = F.interpolate(x, (self.spatial_height, self.spatial_width)) 113 | x_stretch = x.view(bt, c, self.spatial_height * self.spatial_width) 114 | x_stretch = x.view(bt * self.nblocks_channel, c // self.nblocks_channel, self.spatial_height * self.spatial_width) 115 | 116 | x_stacked = x_stretch # (b) x c x (h * w) 117 | x_v = x_stacked.permute(0, 2, 1).contiguous() # (b) x (h * w) x c 118 | 119 | if self.enc_dec: 120 | x_v = self.fc_in(x_v) # (b) x (h * w) x c 121 | 122 | x_m = x_v.mean(1).view(-1, 1, c // self.nblocks_channel) # (b * h * w) x 1 x c 123 | score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # (b * h * w) x c x c 124 | # score = torch.bmm(x_v.transpose(1, 2).contiguous(), x_v) 125 | # x_v = F.dropout(x_v, 0.1, self.training) 126 | # score.masked_fill_(self.mask.unsqueeze(0).expand_as(score).type_as(score).eq(0), -np.inf) 127 | attn = F.softmax(score, dim=1) # (b * h * w) x c x c 128 | if self.communication: 129 | if self.enc_dec: 130 | out = self.bn(self.fc_out(torch.bmm(x_v, attn))) # (b) x (h * w) x c 131 | else: 132 | out = self.bn(torch.bmm(x_v, attn)) # (b) x (h * w) x c 133 | else: 134 | out = self.bn(self.fc_out(x_v)) # (b) x (h * w) x c 135 | out = out.permute(0, 2, 1).contiguous().view(bt, c, self.spatial_height, self.spatial_width) 136 | out = F.dropout(F.interpolate(out, (h, w)), 0.0, self.training) 137 | return F.relu(residual + out) 138 | 139 | class Bottleneck(nn.Module): 140 | expansion = 4 141 | 142 | def __init__(self, inplanes, planes, cardinality, baseWidth, stride=1, downsample=None, use_se=False): 143 | super(Bottleneck, self).__init__() 144 | D = int(planes * (baseWidth / 64.)) 145 | C = cardinality 146 | self.conv1 = nn.Conv2d(inplanes, D*C, kernel_size=1, bias=False) 147 | self.bn1 = nn.BatchNorm2d(D*C) 148 | self.conv2 = nn.Conv2d(D*C, D*C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False) 149 | self.bn2 = nn.BatchNorm2d(D*C) 150 | self.conv3 = nn.Conv2d(D*C, planes*4, kernel_size=1, bias=False) 151 | self.bn3 = nn.BatchNorm2d(planes*4) 152 | self.se = SELayer(planes*4, reduction=8) if use_se else nn.Sequential() # SELayer comes from lib.layers; an empty Sequential is a no-op when use_se is False 153 | self.relu = nn.ReLU(inplace=True) 154 | self.downsample = downsample 155 | self.stride = stride 156 | 157 | def forward(self, x): 158 | residual = x 159 | 160 | out = self.conv1(x) 161 | out = self.bn1(out) 162 | out = self.relu(out) 163 | 164 | out = self.conv2(out) 165 | out = self.bn2(out) 166 | out = self.relu(out) 167 | 168 | out = self.conv3(out) 169 | out = self.bn3(out) 170 | out = self.se(out) 171 | 172 | if self.downsample is not None: 173 | residual = self.downsample(x) 174 | 175 | if residual.size() != out.size(): 176 | print(out.size(), residual.size()) 177 | out += residual 178 | out = self.relu(out) 179 | 180 | return out 181 | 182 | 183 | class ResNeXt_Cifar(nn.Module): 184 | 185 | def __init__(self, block, layers, cardinality, baseWidth, num_classes=10, insert_layers=["1", "2", "3"], depth=1, 186 | has_gtlayer=False, has_selayer=False, communication=True, enc_dec=True): 187 | super(ResNeXt_Cifar, self).__init__() 188 | self.inplanes = 64 189 | 190 | self.insize = 32 191 | self.layers = insert_layers 192 | self.depth =
depth 193 | self.has_gtlayer = has_gtlayer 194 | self.has_selayer = has_selayer 195 | 196 | self.cardinality = cardinality 197 | self.baseWidth = baseWidth 198 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 199 | self.bn1 = nn.BatchNorm2d(64) 200 | self.relu = nn.ReLU(inplace=True) 201 | self.layer1 = self._make_layer(block, 64, layers[0]) 202 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 203 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 204 | self.avgpool = nn.AvgPool2d(8, stride=1) 205 | self.fc = nn.Linear(256 * block.expansion, num_classes) 206 | 207 | if self.has_gtlayer: 208 | if "1" in self.layers: 209 | layers = [] 210 | for i in range(self.depth): 211 | layers.append(_CrossNeuronBlock(64, 32, 32, spatial_height=16, spatial_width=16, 212 | communication=communication, enc_dec=enc_dec)) 213 | self.nclayer1 = nn.Sequential(*layers) 214 | 215 | if "2" in self.layers: 216 | layers = [] 217 | for i in range(self.depth): 218 | layers.append(_CrossNeuronBlock(64, 32, 32, spatial_height=16, spatial_width=16, 219 | communication=communication, enc_dec=enc_dec)) 220 | self.nclayer2 = nn.Sequential(*layers) 221 | 222 | if "3" in self.layers: 223 | layers = [] 224 | for i in range(self.depth): 225 | layers.append(_CrossNeuronBlock(128, 16, 16, spatial_height=8, spatial_width=8, 226 | communication=communication, enc_dec=enc_dec)) 227 | self.nclayer3 = nn.Sequential(*layers) 228 | 229 | for m in self.modules(): 230 | if isinstance(m, nn.Conv2d): 231 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 232 | m.weight.data.normal_(0, math.sqrt(2. / n)) 233 | elif isinstance(m, nn.BatchNorm2d): 234 | m.weight.data.fill_(1) 235 | m.bias.data.zero_() 236 | 237 | def _make_layer(self, block, planes, blocks, stride=1): 238 | downsample = None 239 | if stride != 1 or self.inplanes != planes * block.expansion: 240 | downsample = nn.Sequential( 241 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 242 | nn.BatchNorm2d(planes * block.expansion) 243 | ) 244 | 245 | layers = [] 246 | layers.append(block(self.inplanes, planes, self.cardinality, self.baseWidth, stride, downsample, use_se=self.has_selayer)) 247 | self.inplanes = planes * block.expansion 248 | for _ in range(1, blocks): 249 | layers.append(block(self.inplanes, planes, self.cardinality, self.baseWidth, use_se=self.has_selayer)) 250 | 251 | return nn.Sequential(*layers) 252 | 253 | 254 | def _correlation(self, x): 255 | b, c, h, w = x.shape 256 | x_v = x.clone().detach().view(b, c, -1) # b x c x (hw) 257 | x_m = x_v.mean(1).unsqueeze(1) # b x 1 x (hw) 258 | x_c = x_v - x_m # b x c x (hw) 259 | # x_c = x_v 260 | num = torch.bmm(x_c, x_c.transpose(1, 2)) # b x c x c 261 | x_mode = torch.sqrt(torch.sum(x_c ** 2, 2).unsqueeze(2)) # b x c x 1 262 | dec = torch.bmm(x_mode, x_mode.transpose(1, 2).contiguous()) # b x c x c 263 | out = num / dec 264 | out = torch.abs(out) 265 | return out.mean() 266 | 267 | def forward(self, x): 268 | x = self.conv1(x) 269 | x = self.bn1(x) 270 | x = self.relu(x) 271 | 272 | corr1 = self._correlation(x) 273 | x0 = x.clone().detach() 274 | 275 | if self.has_gtlayer and "1" in self.layers: 276 | x = self.nclayer1(x) 277 | corr1_ = self._correlation(x) 278 | else: 279 | corr1_ = corr1.clone() 280 | 281 | x = self.layer1(x) 282 | x1 = x.clone().detach() 283 | 284 | corr2 = self._correlation(x) 285 | if self.has_gtlayer and "2" in self.layers: 286 | x = self.nclayer2(x) 287 | corr2_ = 
self._correlation(x) 288 | else: 289 | corr2_ = corr2.clone() 290 | 291 | x = self.layer2(x) 292 | x2 = x.clone().detach() 293 | 294 | corr3 = self._correlation(x) 295 | if self.has_gtlayer and "3" in self.layers: 296 | x = self.nclayer3(x) 297 | corr3_ = self._correlation(x) 298 | else: 299 | corr3_ = corr3.clone() 300 | 301 | x = self.layer3(x) 302 | x3 = x.clone().detach() 303 | 304 | # print("corr0: {}, corr1: {}, corr2: {}".format(corr0, corr1, corr2)) 305 | x = self.avgpool(x) 306 | x = x.view(x.size(0), -1) 307 | x = self.fc(x) 308 | 309 | return x, (x0, x1, x2, x3), (corr1.item(), corr2.item(), corr3.item()), (corr1_.item(), corr2_.item(), corr3_.item()) 310 | 311 | 312 | # def forward(self, x): 313 | # x = self.conv1(x) 314 | # x = self.bn1(x) 315 | # x = self.relu(x) 316 | # 317 | # x = self.layer1(x) 318 | # x = self.layer2(x) 319 | # x = self.layer3(x) 320 | # 321 | # x = self.avgpool(x) 322 | # x = x.view(x.size(0), -1) 323 | # x = self.fc(x) 324 | # 325 | # return x 326 | 327 | 328 | def resneXt110_cifar(**kwargs): 329 | model = ResNeXt_Cifar(BasicBlock, [18, 18, 18], **kwargs) # NOTE: BasicBlock is not defined in this module; this factory cannot run as written 330 | return model 331 | 332 | def resneXt164_cifar(**kwargs): 333 | model = ResNeXt_Cifar(Bottleneck, [18, 18, 18], **kwargs) 334 | return model 335 | 336 | def resneXt_cifar(depth, cardinality, baseWidth, **kwargs): 337 | assert (depth - 2) % 9 == 0 338 | n = (depth - 2) // 9 339 | model = ResNeXt_Cifar(Bottleneck, [n, n, n], cardinality, baseWidth, **kwargs) 340 | return model 341 | 342 | 343 | if __name__ == '__main__': 344 | net = resneXt_cifar(29, 16, 64) 345 | y, feats, corrs_bf, corrs_af = net(torch.randn(1, 3, 32, 32)) 346 | print(net) 347 | print(y.size()) 348 | -------------------------------------------------------------------------------- /lib/networks/wide_resnet_cifar.py: -------------------------------------------------------------------------------- 1 | """ 2 | wide resnet for cifar in pytorch 3 | 4 | Reference: 5 | [1] S. Zagoruyko and N. Komodakis. Wide residual networks. In BMVC, 2016. 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | import math 10 | from .resnet_cifar_analysis import BasicBlock 11 | 12 | 13 | class Wide_ResNet_Cifar(nn.Module): 14 | 15 | def __init__(self, block, layers, wfactor, num_classes=100, has_selayer=False, has_gtlayer=False): 16 | super(Wide_ResNet_Cifar, self).__init__() 17 | self.inplanes = 16 18 | self.insize = 32 19 | self.has_selayer = has_selayer 20 | self.has_gtlayer = has_gtlayer 21 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 22 | self.bn1 = nn.BatchNorm2d(16) 23 | self.relu = nn.ReLU(inplace=True) 24 | self.layer1 = self._make_layer(block, 16*wfactor, layers[0]) 25 | self.layer2 = self._make_layer(block, 32*wfactor, layers[1], stride=2) 26 | self.layer3 = self._make_layer(block, 64*wfactor, layers[2], stride=2) 27 | self.avgpool = nn.AvgPool2d(8, stride=1) 28 | self.fc = nn.Linear(64*block.expansion*wfactor, num_classes) 29 | 30 | for m in self.modules(): 31 | if isinstance(m, nn.Conv2d): 32 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 33 | m.weight.data.normal_(0, math.sqrt(2.
/ n)) 34 | elif isinstance(m, nn.BatchNorm2d): 35 | m.weight.data.fill_(1) 36 | m.bias.data.zero_() 37 | 38 | def _make_layer(self, block, planes, blocks, stride=1): 39 | downsample = None 40 | if stride != 1 or self.inplanes != planes * block.expansion: 41 | downsample = nn.Sequential( 42 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 43 | nn.BatchNorm2d(planes * block.expansion) 44 | ) 45 | 46 | self.insize = int(self.insize / stride) 47 | layers = [] 48 | layers.append(block(self.insize, self.inplanes, planes, stride, downsample, use_se=self.has_selayer, use_gt=self.has_gtlayer)) 49 | self.inplanes = planes * block.expansion 50 | for _ in range(1, blocks): 51 | layers.append(block(self.insize, self.inplanes, planes, use_se=self.has_selayer)) 52 | 53 | return nn.Sequential(*layers) 54 | 55 | def forward(self, x): 56 | x = self.conv1(x) 57 | x = self.bn1(x) 58 | x = self.relu(x) 59 | 60 | x = self.layer1(x) 61 | x = self.layer2(x) 62 | x = self.layer3(x) 63 | 64 | x = self.avgpool(x) 65 | x = x.view(x.size(0), -1) 66 | x = self.fc(x) 67 | 68 | return x 69 | 70 | 71 | def wresnet20_cifar(**kwargs): 72 | return Wide_ResNet_Cifar(BasicBlock, [3, 3, 3], 10, **kwargs) 73 | 74 | def wide_resnet_cifar(depth, width, **kwargs): 75 | assert (depth - 2) % 6 == 0 76 | n = (depth - 2) // 6 77 | return Wide_ResNet_Cifar(BasicBlock, [n, n, n], width, **kwargs) 78 | 79 | if __name__=='__main__': 80 | net = wide_resnet_cifar(20, 10) 81 | y = net(torch.randn(1, 3, 32, 32)) 82 | print(isinstance(net, Wide_ResNet_Cifar)) 83 | print(y.size()) 84 | -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .verbo import * 2 | -------------------------------------------------------------------------------- /lib/utils/cub2011.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from torchvision.datasets.folder import default_loader 4 | from torchvision.datasets.utils import download_url 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class Cub2011(Dataset): 9 | base_folder = 'CUB_200_2011/images' 10 | url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz' 11 | filename = 'CUB_200_2011.tgz' 12 | tgz_md5 = '97eceeb196236b17998738112f37df78' 13 | 14 | def __init__(self, root, train=True, transform=None, loader=default_loader, download=True): 15 | self.root = os.path.expanduser(root) 16 | self.transform = transform 17 | self.loader = loader 18 | self.train = train 19 | 20 | if download: 21 | self._download() 22 | 23 | if not self._check_integrity(): 24 | raise RuntimeError('Dataset not found or corrupted.'
+ 25 | ' You can use download=True to download it') 26 | 27 | def _load_metadata(self): 28 | images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ', 29 | names=['img_id', 'filepath']) 30 | image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'), 31 | sep=' ', names=['img_id', 'target']) 32 | train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'), 33 | sep=' ', names=['img_id', 'is_training_img']) 34 | 35 | data = images.merge(image_class_labels, on='img_id') 36 | self.data = data.merge(train_test_split, on='img_id') 37 | 38 | if self.train: 39 | self.data = self.data[self.data.is_training_img == 1] 40 | else: 41 | self.data = self.data[self.data.is_training_img == 0] 42 | 43 | def _check_integrity(self): 44 | try: 45 | self._load_metadata() 46 | except Exception: 47 | return False 48 | 49 | for index, row in self.data.iterrows(): 50 | filepath = os.path.join(self.root, self.base_folder, row.filepath) 51 | if not os.path.isfile(filepath): 52 | print(filepath) 53 | return False 54 | 55 | return True 56 | 57 | def _download(self): 58 | import tarfile 59 | 60 | if self._check_integrity(): 61 | print('Files already downloaded and verified') 62 | return 63 | 64 | download_url(self.url, self.root, self.filename, self.tgz_md5) 65 | 66 | with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar: 67 | tar.extractall(path=self.root) 68 | 69 | def __len__(self): 70 | return len(self.data) 71 | 72 | def __getitem__(self, idx): 73 | sample = self.data.iloc[idx] 74 | path = os.path.join(self.root, self.base_folder, sample.filepath) 75 | target = sample.target - 1 # Targets start at 1 by default, so shift to 0 76 | img = self.loader(path) 77 | 78 | if self.transform is not None: 79 | img = self.transform(img) 80 | 81 | return img, target 82 | -------------------------------------------------------------------------------- /lib/utils/imagenet.py: -------------------------------------------------------------------------------- 1 | # dataloader respecting the PyTorch conventions, but using tensorpack to load and process 2 | # includes typical augmentations for ImageNet training 3 | 4 | import collections 5 | import os 6 | import re 7 | import cv2 8 | import torch 9 | from torch._six import string_classes # string types used by default_collate below (old-PyTorch API) 10 | import numpy as np 11 | import tensorpack.dataflow as td 12 | from tensorpack import imgaug 13 | from tensorpack.dataflow import (AugmentImageComponent, PrefetchDataZMQ, BatchData, MultiThreadMapData) 14 | _use_shared_memory = False # flag read by default_collate; the original PyTorch loader sets it to True inside worker processes 15 | ##################################################################################################### 16 | # copied from: https://github.com/ppwwyyxx/tensorpack/blob/master/examples/ResNet/imagenet_utils.py # 17 | ##################################################################################################### 18 | class GoogleNetResize(imgaug.ImageAugmentor): 19 | """ 20 | crop 8%~100% of the original image 21 | See `Going Deeper with Convolutions` by Google.
22 | """ 23 | def __init__(self, crop_area_fraction=0.08, 24 | aspect_ratio_low=0.75, aspect_ratio_high=1.333, 25 | target_shape=224): 26 | self._init(locals()) 27 | 28 | def _augment(self, img, _): 29 | h, w = img.shape[:2] 30 | area = h * w 31 | for _ in range(10): 32 | targetArea = self.rng.uniform(self.crop_area_fraction, 1.0) * area 33 | aspectR = self.rng.uniform(self.aspect_ratio_low, self.aspect_ratio_high) 34 | ww = int(np.sqrt(targetArea * aspectR) + 0.5) 35 | hh = int(np.sqrt(targetArea / aspectR) + 0.5) 36 | if self.rng.uniform() < 0.5: 37 | ww, hh = hh, ww 38 | if hh <= h and ww <= w: 39 | x1 = 0 if w == ww else self.rng.randint(0, w - ww) 40 | y1 = 0 if h == hh else self.rng.randint(0, h - hh) 41 | out = img[y1:y1 + hh, x1:x1 + ww] 42 | out = cv2.resize(out, (self.target_shape, self.target_shape), interpolation=cv2.INTER_CUBIC) 43 | return out 44 | out = imgaug.ResizeShortestEdge(self.target_shape, interp=cv2.INTER_CUBIC).augment(img) 45 | out = imgaug.CenterCrop(self.target_shape).augment(out) 46 | return out 47 | 48 | 49 | def fbresnet_augmentor(isTrain): 50 | """ 51 | Augmentor used in fb.resnet.torch, for BGR images in range [0,255]. 52 | """ 53 | if isTrain: 54 | augmentors = [ 55 | GoogleNetResize(), 56 | imgaug.RandomOrderAug( 57 | [imgaug.BrightnessScale((0.6, 1.4), clip=False), 58 | imgaug.Contrast((0.6, 1.4), clip=False), 59 | imgaug.Saturation(0.4, rgb=False), 60 | # rgb-bgr conversion for the constants copied from fb.resnet.torch 61 | imgaug.Lighting(0.1, 62 | eigval=np.asarray( 63 | [0.2175, 0.0188, 0.0045][::-1]) * 255.0, 64 | eigvec=np.array( 65 | [[-0.5675, 0.7192, 0.4009], 66 | [-0.5808, -0.0045, -0.8140], 67 | [-0.5836, -0.6948, 0.4203]], 68 | dtype='float32')[::-1, ::-1] 69 | )]), 70 | imgaug.Flip(horiz=True), 71 | ] 72 | else: 73 | augmentors = [ 74 | imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC), 75 | imgaug.CenterCrop((224, 224)), 76 | ] 77 | return augmentors 78 | ##################################################################################################### 79 | ##################################################################################################### 80 | 81 | 82 | numpy_type_map = { 83 | 'float64': torch.DoubleTensor, 84 | 'float32': torch.FloatTensor, 85 | 'float16': torch.HalfTensor, 86 | 'int64': torch.LongTensor, 87 | 'int32': torch.IntTensor, 88 | 'int16': torch.ShortTensor, 89 | 'int8': torch.CharTensor, 90 | 'uint8': torch.ByteTensor, 91 | } 92 | 93 | 94 | def default_collate(batch): 95 | "Puts each data field into a tensor with outer dimension batch size" 96 | 97 | error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" 98 | elem_type = type(batch[0]) 99 | if torch.is_tensor(batch[0]): 100 | out = None 101 | if _use_shared_memory: 102 | # If we're in a background process, concatenate directly into a 103 | # shared memory tensor to avoid an extra copy 104 | numel = sum([x.numel() for x in batch]) 105 | storage = batch[0].storage()._new_shared(numel) 106 | out = batch[0].new(storage) 107 | return torch.stack(batch, 0, out=out) 108 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 109 | and elem_type.__name__ != 'string_': 110 | elem = batch[0] 111 | if elem_type.__name__ == 'ndarray': 112 | # array of string classes and object 113 | if re.search('[SaUO]', elem.dtype.str) is not None: 114 | raise TypeError(error_msg.format(elem.dtype)) 115 | 116 | return torch.stack([torch.from_numpy(b) for b in batch], 0) 117 | if elem.shape == (): # scalars 118 | py_type = float if 
elem.dtype.name.startswith('float') else int 119 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) 120 | elif isinstance(batch[0], int): 121 | return torch.LongTensor(batch) 122 | elif isinstance(batch[0], float): 123 | return torch.DoubleTensor(batch) 124 | elif isinstance(batch[0], string_classes): 125 | return batch 126 | elif isinstance(batch[0], collections.Mapping): 127 | return {key: default_collate([d[key] for d in batch]) for key in batch[0]} 128 | elif isinstance(batch[0], collections.Sequence): 129 | transposed = zip(*batch) 130 | return [default_collate(samples) for samples in transposed] 131 | 132 | raise TypeError((error_msg.format(type(batch[0])))) 133 | 134 | 135 | class Loader(object): 136 | """ 137 | Data loader. Combines a dataset and a sampler, and provides 138 | single- or multi-process iterators over the dataset. 139 | 140 | Arguments: 141 | mode (str, required): mode of dataset to operate in, one of ['train', 'val'] 142 | batch_size (int, optional): how many samples per batch to load 143 | (default: 1). 144 | shuffle (bool, optional): set to ``True`` to have the data reshuffled 145 | at every epoch (default: False). 146 | num_workers (int, optional): how many subprocesses to use for data 147 | loading. 0 means that the data will be loaded in the main process 148 | (default: 0) 149 | cache (int, optional): cache size to use when loading data, 150 | drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, 151 | if the dataset size is not divisible by the batch size. If ``False`` and 152 | the size of dataset is not divisible by the batch size, then the last batch 153 | will be smaller. (default: False) 154 | cuda (bool, optional): set to ``True`` and the PyTorch tensors will get preloaded 155 | to the GPU for you (necessary because this lets us to uint8 conversion on the 156 | GPU, which is faster). 157 | """ 158 | 159 | def __init__(self, mode, batch_size=256, shuffle=False, num_workers=25, cache=50000, 160 | collate_fn=default_collate, drop_last=False, cuda=False): 161 | # enumerate standard imagenet augmentors 162 | imagenet_augmentors = fbresnet_augmentor(mode == 'train') 163 | 164 | # load the lmdb if we can find it 165 | lmdb_loc = os.path.join('/srv/share/datasets/ImageNet','ILSVRC-%s.lmdb'%mode) 166 | ds = td.LMDBSerializer.load(lmdb_loc, shuffle=False) 167 | ds = td.LocallyShuffleData(ds, cache) 168 | ds = td.PrefetchData(ds, 5000, 1) 169 | # ds = td.LMDBDataPoint(ds) 170 | ds = td.MapDataComponent(ds, lambda x: cv2.imdecode(x, cv2.IMREAD_COLOR), 0) 171 | ds = td.AugmentImageComponent(ds, imagenet_augmentors) 172 | ds = td.PrefetchDataZMQ(ds, num_workers) 173 | self.ds = td.BatchData(ds, batch_size) 174 | self.ds.reset_state() 175 | 176 | self.batch_size = batch_size 177 | self.num_workers = num_workers 178 | self.cuda = cuda 179 | #self.drop_last = drop_last 180 | 181 | def __iter__(self): 182 | for x, y in self.ds.get_data(): 183 | if self.cuda: 184 | # images come out as uint8, which are faster to copy onto the gpu 185 | x = torch.ByteTensor(x).cuda() 186 | y = torch.IntTensor(y).cuda() 187 | # but once they're on the gpu, we'll need them in 188 | yield uint8_to_float(x), y.long() 189 | # yield (x), y.long() 190 | else: 191 | yield uint8_to_float(torch.ByteTensor(x)), torch.IntTensor(y).long() 192 | 193 | def __len__(self): 194 | return self.ds.size() 195 | 196 | def uint8_to_float(x): 197 | x = x.permute(0,3,1,2) # pytorch is (n,c,w,h) 198 | return x.float()/128. - 1. 
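# --- Illustrative usage sketch (added for clarity; not part of the original loader) ---
# Loader wraps a tensorpack LMDB pipeline and yields (images, labels) batches, where
# images come back as float tensors in [-1, 1) with layout (N, C, H, W) and labels as
# LongTensors. The LMDB path is hard-coded above ('/srv/share/datasets/ImageNet'), so
# this sketch only runs on a machine with that layout; all names match this file.
def _loader_smoke_test():
    val_loader = Loader('val', batch_size=32, num_workers=4, cuda=False)
    for images, labels in val_loader:
        print(images.shape, labels.shape)  # e.g. (32, 3, 224, 224) and (32,)
        break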
199 | -------------------------------------------------------------------------------- /lib/utils/net_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | import torchvision.models as models 7 | import cv2 8 | import pdb 9 | import random 10 | 11 | def adjust_learning_rate(optimizer, decay=0.1): 12 | """Sets the learning rate to the initial LR decayed by 0.5 every 20 epochs""" 13 | for param_group in optimizer.param_groups: 14 | param_group['lr'] = decay * param_group['lr'] 15 | 16 | def weights_normal_init(model, dev=0.01): 17 | if isinstance(model, list): 18 | for m in model: 19 | weights_normal_init(m, dev) 20 | else: 21 | for m in model.modules(): 22 | if isinstance(m, nn.Conv2d): 23 | m.weight.data.normal_(0.0, dev) 24 | elif isinstance(m, nn.Linear): 25 | m.weight.data.normal_(0.0, dev) 26 | -------------------------------------------------------------------------------- /lib/utils/verbo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | def __init__(self, name, fmt=':f'): 7 | self.name = name 8 | self.fmt = fmt 9 | self.reset() 10 | 11 | def reset(self): 12 | self.val = 0 13 | self.avg = 0 14 | self.sum = 0 15 | self.count = 0 16 | 17 | def update(self, val, n=1): 18 | self.val = val 19 | self.sum += val * n 20 | self.count += n 21 | self.avg = self.sum / self.count 22 | 23 | def __str__(self): 24 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 25 | return fmtstr.format(**self.__dict__) 26 | 27 | class ProgressMeter(object): 28 | def __init__(self, num_batches, *meters, prefix=""): 29 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 30 | self.meters = meters 31 | self.prefix = prefix 32 | 33 | def print(self, batch): 34 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 35 | entries += [str(meter) for meter in self.meters] 36 | print(' '.join(entries)) 37 | 38 | def _get_batch_fmtstr(self, num_batches): 39 | num_digits = len(str(num_batches // 1)) 40 | fmt = '{:' + str(num_digits) + 'd}' 41 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 42 | -------------------------------------------------------------------------------- /lib/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import cv2 4 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | import pdb 5 | import time 6 | import yaml 7 | import os 8 | import os.path as osp 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | from torchvision import datasets, transforms 13 | from lib.utils.net_utils import weights_normal_init, adjust_learning_rate 14 | from configs import cfg 15 | from lib import * 16 | from ptflops import get_model_complexity_info 17 | from thop import profile 18 | 19 | def compute_accuracy(output, target, topk=(1,)): 20 | """Computes the accuracy over the k top predictions for the specified values of k""" 21 | with torch.no_grad(): 22 | maxk = max(topk) 23 | batch_size = 
target.size(0) 24 | 25 | _, pred = output.topk(maxk, 1, True, True) 26 | pred = pred.t() 27 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 28 | 29 | res = [] 30 | for k in topk: 31 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 32 | res.append(correct_k.mul_(100.0 / batch_size)) 33 | return res 34 | 35 | def graph_node_loss(graphs, target): 36 | loss = 0 37 | for graph in graphs: 38 | node = graph[0] 39 | B, N = node.shape 40 | mean_node = torch.mm(node, node.transpose(0, 1).contiguous()) 41 | loss += (mean_node.sum() - torch.diag(mean_node).sum()) / N / N 42 | return loss 43 | 44 | def graph_edge_loss(graphs, target): 45 | loss = 0 46 | for graph in graphs: 47 | edge = graph[1] 48 | B, N, _ = edge.shape 49 | edge = F.relu(edge) 50 | loss += (edge.mean(0).sum() - torch.diag(edge.mean(0)).sum()) / N / N 51 | return loss 52 | 53 | def train(args, model, device, train_loader, optimizer, epoch): 54 | model.train() 55 | batch_time = AverageMeter('Time', ':3.3f') 56 | data_time = AverageMeter('Data', ':3.3f') 57 | mem_cost = AverageMeter('Mem', ':3.3f') 58 | losses = AverageMeter('Loss', ':.3f') 59 | top1 = AverageMeter('Acc@1', ':3.2f') 60 | top5 = AverageMeter('Acc@5', ':3.2f') 61 | layer1_bf = AverageMeter('Corr@bf', ':3.2f') 62 | layer1_af = AverageMeter('Corr@af', ':3.2f') 63 | layer2_bf = AverageMeter('Corr@bf', ':3.2f') 64 | layer2_af = AverageMeter('Corr@af', ':3.2f') 65 | layer3_bf = AverageMeter('Corr@bf', ':3.2f') 66 | layer3_af = AverageMeter('Corr@af', ':3.2f') 67 | 68 | progress = ProgressMeter(len(train_loader), mem_cost, batch_time, data_time, losses, top1, 69 | top5, layer1_bf, layer1_af, layer2_bf, layer2_af, layer3_bf, layer3_af, 70 | prefix="Epoch: [{}]".format(epoch)) 71 | end = time.time() 72 | for batch_idx, (data, target) in enumerate(train_loader): 73 | data_time.update(time.time() - end) 74 | data, target = data.to(device), target.to(device) 75 | optimizer.zero_grad() 76 | output = model(data) 77 | mem = torch.cuda.max_memory_allocated() 78 | mem_cost.update(mem / 1024 / 1024 / 1024) 79 | 80 | if model.net.name == "invcnn": 81 | loss = 0 82 | for out in output: 83 | out = F.log_softmax(out, dim=1) 84 | loss += F.nll_loss(out, target) 85 | loss /= len(output) 86 | else: 87 | output = F.log_softmax(output, dim=1) 88 | loss = F.nll_loss(output, target) 89 | 90 | acc1, acc5 = compute_accuracy(output[-1] if model.net.name == "invcnn" else output, target, topk=(1, 5)) 91 | losses.update(loss.item(), data.size(0)) 92 | top1.update(acc1[0], data.size(0)) 93 | top5.update(acc5[0], data.size(0)) 94 | layer1_bf.update(model.net.layer1.cn.corr_bf.item()) 95 | layer1_af.update(model.net.layer1.cn.corr_af.item()) 96 | layer2_bf.update(model.net.layer2.cn.corr_bf.item()) 97 | layer2_af.update(model.net.layer2.cn.corr_af.item()) 98 | layer3_bf.update(model.net.layer3.cn.corr_bf.item()) 99 | layer3_af.update(model.net.layer3.cn.corr_af.item()) 100 | # import pdb; pdb.set_trace() 101 | # node_loss = graph_node_loss(graphs, target) 102 | # edge_loss = graph_edge_loss(graphs, target) 103 | loss.backward() 104 | optimizer.step() 105 | 106 | # measure elapsed time 107 | batch_time.update(time.time() - end) 108 | end = time.time() 109 | 110 | if batch_idx % args.log.print_interval == 0: 111 | progress.print(batch_idx) 112 | # print('Train Epoch: {} [{}/{} ({:.0f}%)]\tmem: {:.3f}m\tLoss: {:.6f}'.format( 113 | # epoch, batch_idx * len(data), len(train_loader.dataset), mem / 1024 / 1024, 114 | # 100. 
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    batch_time = AverageMeter('Time', ':3.3f')
    data_time = AverageMeter('Data', ':3.3f')
    mem_cost = AverageMeter('Mem', ':3.3f')
    losses = AverageMeter('Loss', ':.3f')
    top1 = AverageMeter('Acc@1', ':3.2f')
    top5 = AverageMeter('Acc@5', ':3.2f')
    layer1_bf = AverageMeter('Corr@bf', ':3.2f')
    layer1_af = AverageMeter('Corr@af', ':3.2f')
    layer2_bf = AverageMeter('Corr@bf', ':3.2f')
    layer2_af = AverageMeter('Corr@af', ':3.2f')
    layer3_bf = AverageMeter('Corr@bf', ':3.2f')
    layer3_af = AverageMeter('Corr@af', ':3.2f')

    progress = ProgressMeter(len(train_loader), mem_cost, batch_time, data_time, losses, top1,
                             top5, layer1_bf, layer1_af, layer2_bf, layer2_af, layer3_bf, layer3_af,
                             prefix="Epoch: [{}]".format(epoch))

    # unwrap (Balanced)DataParallel, if any, to reach the underlying network
    net = model.module.net if hasattr(model, 'module') else model.net

    end = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):
        data_time.update(time.time() - end)
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        mem = torch.cuda.max_memory_allocated()
        mem_cost.update(mem / 1024 / 1024 / 1024)  # bytes -> GiB

        if net.name == "invcnn":
            # invcnn returns one prediction per stage; average their NLL losses
            loss = 0
            for out in output:
                out = F.log_softmax(out, dim=1)
                loss += F.nll_loss(out, target)
            loss /= len(output)
        else:
            output = F.log_softmax(output, dim=1)
            loss = F.nll_loss(output, target)

        acc1, acc5 = compute_accuracy(output[-1] if net.name == "invcnn" else output, target, topk=(1, 5))
        losses.update(loss.item(), data.size(0))
        top1.update(acc1[0], data.size(0))
        top5.update(acc5[0], data.size(0))
        # cross-neuron correlation stats exist only when the model is built with --add-ccn yes
        if hasattr(net.layer1, 'cn'):
            layer1_bf.update(net.layer1.cn.corr_bf.item())
            layer1_af.update(net.layer1.cn.corr_af.item())
            layer2_bf.update(net.layer2.cn.corr_bf.item())
            layer2_af.update(net.layer2.cn.corr_af.item())
            layer3_bf.update(net.layer3.cn.corr_bf.item())
            layer3_af.update(net.layer3.cn.corr_af.item())
        # node_loss = graph_node_loss(graphs, target)
        # edge_loss = graph_edge_loss(graphs, target)
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % args.log.print_interval == 0:
            progress.print(batch_idx)
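# Side note (sanity check, illustrative and not called anywhere): the two-step
# F.log_softmax + F.nll_loss used above is numerically equivalent to applying
# F.cross_entropy directly to the raw logits.
def _check_loss_equivalence():
    logits = torch.randn(4, 10)
    target = torch.randint(0, 10, (4,))
    a = F.nll_loss(F.log_softmax(logits, dim=1), target)
    b = F.cross_entropy(logits, target)
    assert torch.allclose(a, b)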
def test(args, model, device, test_loader):
    batch_time = AverageMeter('Time', ':3.3f')
    losses = AverageMeter('Loss', ':.3f')
    mem_cost = AverageMeter('Mem', ':3.3f')
    top1 = AverageMeter('Acc@1', ':3.2f')
    top5 = AverageMeter('Acc@5', ':3.2f')
    progress = ProgressMeter(len(test_loader), mem_cost, batch_time, losses, top1, top5,
                             prefix='Test: ')

    # unwrap (Balanced)DataParallel, if any, to reach the underlying network
    net = model.module.net if hasattr(model, 'module') else model.net

    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        end = time.time()
        for batch_idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            if net.name == "invcnn":
                # fuse the per-stage predictions by elementwise max
                output = torch.stack(output, 1).max(1)[0]
            mem = torch.cuda.max_memory_allocated()
            mem_cost.update(mem / 1024 / 1024 / 1024)  # bytes -> GiB

            # measure accuracy and record loss
            acc1, acc5 = compute_accuracy(output, target, topk=(1, 5))
            top1.update(acc1[0], data.size(0))
            top5.update(acc5[0], data.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # compute val loss (mean over this batch)
            output = F.log_softmax(output, dim=1)
            loss = F.nll_loss(output, target).item()
            test_loss += loss
            losses.update(loss, data.size(0))

            # compute accuracy
            pred = output.max(1, keepdim=True)[1]  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

            if batch_idx % args.log.print_interval == 0:
                progress.print(batch_idx)
    test_loss /= (batch_idx + 1)  # batch_idx is zero-based, so this counts all batches
    data_size = len(test_loader.dataset) if hasattr(test_loader, 'dataset') else len(test_loader)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(
        test_loss, correct, data_size,
        100. * float(correct) / data_size))
    accuracy = 100. * float(correct) / data_size
    print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))
    return accuracy

def count_parameters(model):
    """Number of trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch Image Classification')
    parser.add_argument('--dataset', type=str, default='cifar100',
                        help='specify training dataset')
    parser.add_argument('--session', type=int, default=1,
                        help='training session to record multiple runs')
    parser.add_argument('--arch', type=str, default='resnet110',
                        help='specify network architecture')
    parser.add_argument('--bs', dest="batch_size", type=int, default=128,
                        help='training batch size')
    parser.add_argument('--gpu0-bs', dest="gpu0_bs", type=int, default=0,
                        help='training batch size on gpu0')
    parser.add_argument('--add-ccn', type=str, default='no',
                        help='add cross-neuron communication')
    parser.add_argument('--mgpus', type=str, default="no",
                        help='multi-gpu training')
    parser.add_argument('--resume', dest="resume", type=int, default=0,
                        help='epoch to resume from (0 trains from scratch)')

    args = parser.parse_args()
    cfg.merge_from_file(osp.join("configs", args.dataset + ".yaml"))
    cfg.dataset = args.dataset
    cfg.arch = args.arch
    cfg.add_cross_neuron = (args.add_ccn == "yes")
    use_cuda = torch.cuda.is_available()
    cfg.use_cuda = use_cuda
    cfg.training.batch_size = args.batch_size
    cfg.mGPUs = (args.mgpus == "yes")

    torch.manual_seed(cfg.initialize.seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    train_loader, test_loader = create_data_loader(cfg)
    model = CrossNeuronNet(cfg)
    print("parameter number: %d" % count_parameters(model))
    with torch.cuda.device(0):
        if args.dataset == "imagenet":
            flops, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True, print_per_layer_stat=True)
        else:
            # CIFAR-style 32x32 inputs; adjust the resolution for other datasets
            flops, params = get_model_complexity_info(model, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
        print('Flops: {}'.format(flops))
        print('Params: {}'.format(params))

    model = model.to(device)

    optimizer = optim.SGD(model.parameters(), lr=cfg.optimizer.lr,
                          momentum=cfg.optimizer.momentum, weight_decay=cfg.optimizer.weight_decay)
    if cfg.mGPUs:
        if args.gpu0_bs > 0:
            model = BalancedDataParallel(args.gpu0_bs, model).to(device)
        else:
            model = nn.DataParallel(model).to(device)

    lr = cfg.optimizer.lr
    checkpoint_tag = osp.join("checkpoints", args.dataset, args.arch)
    if not osp.exists(checkpoint_tag):
        os.makedirs(checkpoint_tag)

    if args.resume > 0:
        ckpt_path = osp.join(checkpoint_tag,
                             ("ccn" if cfg.add_cross_neuron else "plain") + "_{}_{}.pth".format(args.session, args.resume))
        print("resume model from {}".format(ckpt_path))
        ckpt = torch.load(ckpt_path)
        model.load_state_dict(ckpt["model"])
        print("resumed model successfully")
        acc = test(cfg, model, device, test_loader)
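        # Caveat (not handled here): a checkpoint saved while the model was wrapped
        # in nn.DataParallel stores every state_dict key with a "module." prefix, so
        # loading it into an unwrapped model fails with missing/unexpected keys.
        # A common workaround (hypothetical snippet, not part of this repo):
        #     ckpt_state = {k[len("module."):] if k.startswith("module.") else k: v
        #                   for k, v in ckpt["model"].items()}
        #     model.load_state_dict(ckpt_state)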
    best_acc = 0
    for epoch in range(args.resume + 1, cfg.optimizer.max_epoch + 1):
        if epoch in cfg.optimizer.lr_decay_schedule:
            adjust_learning_rate(optimizer, cfg.optimizer.lr_decay_gamma)
            lr *= cfg.optimizer.lr_decay_gamma
        print('Train Epoch: {} learning rate: {}'.format(epoch, lr))
        tic = time.time()
        train(cfg, model, device, train_loader, optimizer, epoch)
        acc = test(cfg, model, device, test_loader)
        time_cost = time.time() - tic
        if acc > best_acc:
            best_acc = acc
        print('\nModel: {} Best Accuracy: {}\tTime Cost per Epoch: {}\n'.format(
            checkpoint_tag + ("_ccn" if cfg.add_cross_neuron else "_plain"),
            best_acc,
            time_cost))

        if epoch % cfg.log.checkpoint_interval == 0:
            checkpoint = {"arch": cfg.arch,
                          "model": model.state_dict(),
                          "epoch": epoch,
                          "lr": lr,
                          "test_acc": acc,
                          "best_acc": best_acc}
            torch.save(checkpoint, osp.join(checkpoint_tag,
                       ("ccn" if cfg.add_cross_neuron else "plain") + "_{}_{}.pth".format(args.session, epoch)))


if __name__ == '__main__':
    # global seeding for reproducibility; cudnn.deterministic trades speed for determinism
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    main()

--------------------------------------------------------------------------------
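Typical invocations, for reference (the flag names come from the argparse definitions in main(); the batch sizes and multi-GPU setup here are illustrative, not prescribed by the repo):

    # baseline, single GPU, CIFAR-100
    python train.py --dataset cifar100 --arch resnet110 --add-ccn no --bs 128

    # with cross-neuron communication, multi-GPU
    python train.py --dataset cifar100 --arch resnet110 --add-ccn yes --mgpus yes --bs 128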