├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── __init__.py
│   ├── cifar10.yaml
│   ├── cifar100.yaml
│   ├── cub2011.yaml
│   ├── default.py
│   └── imagenet.yaml
├── figures
│   ├── ._C3Net_framework.png
│   ├── C3Net_framework.png
│   └── C3Net_motivation.png
├── lib
│   ├── .DS_Store
│   ├── ._.DS_Store
│   ├── ._model.py
│   ├── __init__.py
│   ├── data.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── cross_neuron.py
│   │   ├── cross_neuron_distributed.py
│   │   ├── data_parallel.py
│   │   └── selayer.py
│   ├── model.py
│   ├── model_analysis.py
│   ├── networks
│   │   ├── .DS_Store
│   │   ├── ._.DS_Store
│   │   ├── __init__.py
│   │   ├── densenet_cifar.py
│   │   ├── invcnn_pytorch
│   │   │   ├── __init__.py
│   │   │   └── invnet
│   │   │       ├── __init__.py
│   │   │       └── inv_cnn.py
│   │   ├── mobilenet_v2.py
│   │   ├── resnet_cifar.py
│   │   ├── resnet_cifar_analysis.py
│   │   ├── resnet_cifar_analysis1.py
│   │   ├── resnext.py
│   │   ├── resnext_cifar.py
│   │   └── wide_resnet_cifar.py
│   └── utils
│       ├── __init__.py
│       ├── cub2011.py
│       ├── imagenet.py
│       ├── net_utils.py
│       ├── verbo.py
│       └── visualize.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | # lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Jianwei Yang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # C3Net
2 | PyTorch implementation of our NeurIPS 2019 paper ["Cross-channel Communication Networks"](https://papers.nips.cc/paper/8411-cross-channel-communication-networks.pdf)
3 |
4 |
5 | ![C3Net motivation](figures/C3Net_motivation.png)
6 |
7 |
8 | ## Introduction
9 |
10 | As shown above, the motivation behind our proposed C3Net is twofold:
11 |
12 | * Neurons at the same layer do not directly interact with each other.
13 | * Different neurons might respond to the same patterns and locations.
14 |
15 | In this paper, we mainly focus on convolutional neural networks (CNNs), where channel responses naturally encode which patterns appear and where they appear. Our main idea is to enable channels at the same layer to communicate with each other and then calibrate their responses accordingly. We want different filters to learn to focus on different useful patterns, so that they complement each other (see the sketch after the contribution list below).
16 |
17 | The main contributions are:
18 |
19 | * We propose a cross-channel communication (C3) block to enable full interactions across channels at the same layer.
20 | * It achieves better performance on image classification, object detection and semantic segmentation.
21 | * It captures more diverse representations with lightweight networks.
22 |
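For intuition, here is a minimal, hedged sketch of the C3 idea (the names and the mean-based attention are illustrative simplifications, not the exact `_CrossNeuronBlock` in `lib/layers`): flatten each channel into a vector, score channel pairs by the similarity of their summaries, and recalibrate every channel as an attention-weighted mixture of all channels, added back residually.

```python
import torch
import torch.nn.functional as F

def c3_sketch(x):
    # x: (batch, channels, height, width)
    b, c, h, w = x.shape
    v = x.view(b, c, h * w)                   # one response vector per channel
    m = v.mean(dim=2, keepdim=True)           # (b, c, 1) per-channel summaries
    score = -(m - m.transpose(1, 2)) ** 2     # (b, c, c) pairwise affinities
    attn = F.softmax(score, dim=1)            # normalize over source channels
    out = torch.bmm(attn.transpose(1, 2), v)  # each channel mixes in all others
    return F.relu(x + out.view(b, c, h, w))   # residual connection
```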
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
1 | from .default import _C as cfg
2 |
--------------------------------------------------------------------------------
/configs/cifar10.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | input_size: 32
3 | initialize:
4 | seed: 1
5 | training:
6 | batch_size: 128
7 | workers: 8
8 | test:
9 | batch_size: 1000
10 | workers: 8
11 | optimizer:
12 | name: 'sgd'
13 | lr: 0.1
14 | lr_decay_gamma: 0.1
15 | weight_decay: 1e-4
16 | momentum: 0.9
17 | lr_decay_schedule: (100, 140)
18 | max_epoch: 160
19 | log:
20 | print_interval: 20
21 | checkpoint_interval: 2
22 |
--------------------------------------------------------------------------------
/configs/cifar100.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | input_size: 32
3 | initialize:
4 | seed: 1
5 | training:
6 | batch_size: 128
7 | workers: 8
8 | test:
9 | batch_size: 1000
10 | workers: 8
11 | optimizer:
12 | name: 'sgd'
13 | lr: 0.1
14 | lr_decay_gamma: 0.1
15 | weight_decay: 1e-4
16 | momentum: 0.9
17 | lr_decay_schedule: (120, 140)
18 | max_epoch: 160
19 | log:
20 | print_interval: 20
21 | checkpoint_interval: 20
22 |
--------------------------------------------------------------------------------
/configs/cub2011.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | input_size: 224
3 | initialize:
4 | seed: 1
5 | training:
6 | batch_size: 32
7 | workers: 8
8 | test:
9 | batch_size: 128
10 | workers: 8
11 | optimizer:
12 | name: 'sgd'
13 | lr: 0.001
14 | lr_decay_gamma: 0.9
15 | weight_decay: 1e-4
16 | momentum: 0.9
17 | lr_decay_schedule: (10, 20, 30, 40, 50, 60, 70, 80)
18 | max_epoch: 80
19 | log:
20 | print_interval: 20
21 | checkpoint_interval: 2
22 |
--------------------------------------------------------------------------------
/configs/default.py:
--------------------------------------------------------------------------------
1 | import os
2 | from yacs.config import CfgNode as CN
3 |
4 | _C = CN()
5 |
6 | _C.data = CN()
7 | _C.data.input_size = 32
8 | _C.data.traindir = ""
9 | _C.data.valdir = ""
10 | _C.data.testdir = ""
11 |
12 | _C.initialize = CN()
13 | _C.initialize.seed = 1
14 |
15 | _C.training = CN()
16 | _C.training.batch_size = 128
17 | _C.training.workers = 8
18 | _C.training.distributed = False
19 |
20 | _C.test = CN()
21 | _C.test.batch_size = 1000
22 | _C.test.workers = 8
23 |
24 | _C.optimizer = CN()
25 | _C.optimizer.name = "sgd"
26 | _C.optimizer.lr = 0.1
27 | _C.optimizer.lr_decay_gamma = 0.1
28 | _C.optimizer.weight_decay = 1e-4
29 | _C.optimizer.momentum = 0.9
30 | _C.optimizer.lr_decay_schedule = (60, 120)
31 | _C.optimizer.max_epoch = 160
32 |
33 | _C.log = CN()
34 | _C.log.print_interval = 20
35 | _C.log.checkpoint_interval = 2
36 |
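# A hedged usage sketch, not a file in this repo: configs/__init__.py exports
# _C above as `cfg`, and yacs specializes it with one of the YAML files.
#
#   from configs import cfg
#   cfg.merge_from_file("configs/cifar10.yaml")  # override the defaults
#   cfg.merge_from_list(["optimizer.lr", 0.05])  # optional ad-hoc overrides
#   cfg.freeze()                                 # make the config immutable
#   print(cfg.optimizer.lr_decay_schedule)       # yacs literal-evals "(100, 140)" to a tuple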
--------------------------------------------------------------------------------
/configs/imagenet.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | input_size: 224
3 | traindir: "/srv/share/datasets/ImageNet/pytorch_folder/train"
4 | valdir: "/srv/share/datasets/ImageNet/pytorch_folder/val"
5 | initialize:
6 | seed: 1
7 | training:
8 | batch_size: 512
9 | workers: 8
10 | distributed: False
11 | test:
12 | batch_size: 256
13 | workers: 8
14 | optimizer:
15 | name: 'sgd'
16 | lr: 0.1
17 | lr_decay_gamma: 0.1
18 | weight_decay: 1e-4
19 | momentum: 0.9
20 | lr_decay_schedule: (30, 60)
21 | max_epoch: 90
22 | log:
23 | print_interval: 20
24 | checkpoint_interval: 2
25 |
--------------------------------------------------------------------------------
/figures/._C3Net_framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jwyang/C3Net.pytorch/70026fc80c5427484268c428a9dcd4cde2e8197f/figures/._C3Net_framework.png
--------------------------------------------------------------------------------
/figures/C3Net_framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jwyang/C3Net.pytorch/70026fc80c5427484268c428a9dcd4cde2e8197f/figures/C3Net_framework.png
--------------------------------------------------------------------------------
/figures/C3Net_motivation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jwyang/C3Net.pytorch/70026fc80c5427484268c428a9dcd4cde2e8197f/figures/C3Net_motivation.png
--------------------------------------------------------------------------------
/lib/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jwyang/C3Net.pytorch/70026fc80c5427484268c428a9dcd4cde2e8197f/lib/.DS_Store
--------------------------------------------------------------------------------
/lib/._.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jwyang/C3Net.pytorch/70026fc80c5427484268c428a9dcd4cde2e8197f/lib/._.DS_Store
--------------------------------------------------------------------------------
/lib/._model.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jwyang/C3Net.pytorch/70026fc80c5427484268c428a9dcd4cde2e8197f/lib/._model.py
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
1 | from .data import *
2 | from .model import *
3 | from .model_analysis import *
4 | from .utils import *
5 |
--------------------------------------------------------------------------------
/lib/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import datasets, transforms
3 | from .utils.imagenet import Loader
4 | from .utils.cub2011 import Cub2011
5 | def get_cifar10(opts):
6 | kwargs = {'num_workers': opts.training.workers, 'pin_memory': True} if opts.use_cuda else {}
7 | train_loader = torch.utils.data.DataLoader(
8 | datasets.CIFAR10('../data', train=True, download=True,
9 | transform=transforms.Compose([
10 | transforms.RandomCrop(32, padding=4),
11 | transforms.RandomHorizontalFlip(),
12 | transforms.ToTensor(),
13 | transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262])
14 | ])),
15 | batch_size=opts.training.batch_size, shuffle=True, **kwargs)
16 | test_loader = torch.utils.data.DataLoader(
17 | datasets.CIFAR10('../data', train=False, transform=transforms.Compose([
18 | transforms.ToTensor(),
19 | transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262])
20 | ])),
21 |         batch_size=opts.test.batch_size, shuffle=False, **kwargs)  # evaluation does not need shuffling
22 | return train_loader, test_loader
23 |
24 | def get_cifar100(opts):
25 | kwargs = {'num_workers': opts.training.workers, 'pin_memory': True} if opts.use_cuda else {}
26 | train_loader = torch.utils.data.DataLoader(
27 | datasets.CIFAR100('../data', train=True, download=True,
28 | transform=transforms.Compose([
29 | transforms.RandomCrop(32, padding=4),
30 | transforms.RandomHorizontalFlip(),
31 | transforms.ToTensor(),
32 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]),
33 | ])),
34 | batch_size=opts.training.batch_size, shuffle=True, **kwargs)
35 | test_loader = torch.utils.data.DataLoader(
36 | datasets.CIFAR100('../data', train=False, transform=transforms.Compose([
37 | transforms.ToTensor(),
38 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]),
39 | ])),
40 | batch_size=opts.test.batch_size, shuffle=False, **kwargs)
41 | return train_loader, test_loader
42 |
43 | def get_cub2011(opts):
44 | normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
45 | kwargs = {'num_workers': opts.training.workers, 'pin_memory': True} if opts.use_cuda else {}
46 | train_loader = torch.utils.data.DataLoader(
47 | Cub2011('../data', train=True, download=False,
48 | transform=transforms.Compose([
49 | transforms.Resize(256),
50 | transforms.RandomCrop(224),
51 | transforms.RandomHorizontalFlip(),
52 | transforms.ToTensor(),
53 | normalize,
54 | ])),
55 | batch_size=opts.training.batch_size, shuffle=True, **kwargs)
56 | test_loader = torch.utils.data.DataLoader(
57 | Cub2011('../data', train=False, transform=transforms.Compose([
58 | transforms.Resize(256),
59 | transforms.CenterCrop(224),
60 | transforms.ToTensor(),
61 | normalize,
62 | ])),
63 |         batch_size=opts.test.batch_size, shuffle=False, **kwargs)  # evaluation does not need shuffling
64 | return train_loader, test_loader
65 |
66 | def get_imagenet(opts):
67 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
68 | # train_dataset = datasets.ImageFolder(
69 | # opts.data.traindir,
70 | # transforms.Compose([
71 | # transforms.RandomResizedCrop(224),
72 | # transforms.RandomHorizontalFlip(),
73 | # transforms.ToTensor(),
74 | # normalize,
75 | # ]))
76 | #
77 | # if opts.training.distributed:
78 | # train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
79 | # else:
80 | # train_sampler = None
81 | #
82 | # train_loader = torch.utils.data.DataLoader(
83 | # train_dataset, batch_size=opts.training.batch_size, shuffle=(train_sampler is None),
84 | # num_workers=opts.training.workers, pin_memory=True, sampler=train_sampler)
85 |
86 | # import pdb; pdb.set_trace()
87 |
88 | train_loader = Loader('train', batch_size=opts.training.batch_size, num_workers=opts.training.workers)
89 |
90 | val_dataset = datasets.ImageFolder(
91 | opts.data.valdir,
92 | transforms.Compose([
93 | transforms.Resize(256),
94 | transforms.CenterCrop(224),
95 | transforms.ToTensor(),
96 | normalize,
97 | ]))
98 | val_loader = torch.utils.data.DataLoader(
99 |         val_dataset, batch_size=opts.test.batch_size, shuffle=False,
100 | num_workers=opts.test.workers, pin_memory=True)
101 | # val_loader = Loader('val', batch_size=opts.test.batch_size, num_workers=opts.test.workers)
102 | return train_loader, val_loader
103 |
104 | def create_data_loader(opts):
105 | if opts.dataset == "cifar10":
106 | return get_cifar10(opts)
107 | elif opts.dataset == "cifar100":
108 | return get_cifar100(opts)
109 | elif opts.dataset == "cub2011":
110 | return get_cub2011(opts)
111 | elif opts.dataset == "imagenet":
112 | return get_imagenet(opts)
113 | else:
114 |         raise ValueError("Unknown dataset: {}".format(opts.dataset))
115 |
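# Hedged usage sketch: the functions above read `opts.dataset` and
# `opts.use_cuda`, which are not defined in configs/default.py, so train.py
# presumably attaches them at runtime. Something like:
#
#   import torch
#   from configs import cfg
#   cfg.dataset = "cifar10"
#   cfg.use_cuda = torch.cuda.is_available()
#   train_loader, test_loader = create_data_loader(cfg)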
--------------------------------------------------------------------------------
/lib/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .cross_neuron_distributed import *
2 | from .selayer import *
3 | from .data_parallel import *
4 |
--------------------------------------------------------------------------------
/lib/layers/cross_neuron.py:
--------------------------------------------------------------------------------
1 | # Non-local block using embedded gaussian
2 | # Code from
3 | # https://github.com/AlexHex7/Non-local_pytorch/blob/master/Non-Local_pytorch_0.3.1/lib/non_local_embedded_gaussian.py
4 | import math
5 | import torch
6 | from torch import nn
7 | from torch.nn import functional as F
8 | import numpy as np
9 | from scipy.linalg import block_diag
10 |
11 | class _CrossNeuronBlock(nn.Module):
12 | def __init__(self, in_channels, in_height, in_width,
13 | nblocks_channel=4,
14 | spatial_height=24, spatial_width=24,
15 | reduction=8, size_is_consistant=True):
16 |         # nblocks_channel: number of blocks along the channel axis
17 |         # spatial_height / spatial_width: target spatial size of each block
18 | super(_CrossNeuronBlock, self).__init__()
19 |
20 | # set channel splits
21 | if in_channels <= 512:
22 | self.nblocks_channel = 1
23 | else:
24 | self.nblocks_channel = in_channels // 512
25 | block_size = in_channels // self.nblocks_channel
26 | block = torch.Tensor(block_size, block_size).fill_(1)
27 | self.mask = torch.Tensor(in_channels, in_channels).fill_(0)
28 | for i in range(self.nblocks_channel):
29 | self.mask[i * block_size:(i + 1) * block_size, i * block_size:(i + 1) * block_size].copy_(block)
30 |
31 | # set spatial splits
32 | if in_height * in_width < 32 * 32 and size_is_consistant:
33 | self.spatial_area = in_height * in_width
34 | self.spatial_height = in_height
35 | self.spatial_width = in_width
36 | else:
37 | self.spatial_area = spatial_height * spatial_width
38 | self.spatial_height = spatial_height
39 | self.spatial_width = spatial_width
40 |
41 | self.fc_in = nn.Sequential(
42 | nn.Conv1d(self.spatial_area, self.spatial_area // reduction, 1, 1, 0, bias=True),
43 | nn.ReLU(True),
44 | nn.Conv1d(self.spatial_area // reduction, self.spatial_area, 1, 1, 0, bias=True),
45 | )
46 |
47 | self.fc_out = nn.Sequential(
48 | nn.Conv1d(self.spatial_area, self.spatial_area // reduction, 1, 1, 0, bias=True),
49 | nn.ReLU(True),
50 | nn.Conv1d(self.spatial_area // reduction, self.spatial_area, 1, 1, 0, bias=True),
51 | )
52 |
53 | self.bn = nn.BatchNorm1d(self.spatial_area)
54 |
55 | def forward(self, x):
56 | '''
57 | :param x: (bt, c, h, w)
58 | :return:
59 | '''
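        # Pipeline: flatten each channel into a vector, embed with fc_in,
        # build pairwise channel affinities from the per-channel means,
        # turn them into attention with a softmax, mix the channels with
        # the attention, project back with fc_out, and add the residual.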
60 | bt, c, h, w = x.shape
61 | residual = x
62 | x_stretch = x.view(bt, c, h * w)
63 | spblock_h = int(np.ceil(h / self.spatial_height))
64 | spblock_w = int(np.ceil(w / self.spatial_width))
65 | stride_h = int((h - self.spatial_height) / (spblock_h - 1)) if spblock_h > 1 else 0
66 | stride_w = int((w - self.spatial_width) / (spblock_w - 1)) if spblock_w > 1 else 0
67 |
68 |         # import pdb; pdb.set_trace()  # debug breakpoint; disabled so forward() can run
69 |
70 | if spblock_h == 1 and spblock_w == 1:
71 | x_stacked = x_stretch # (b) x c x (h * w)
72 | x_stacked = x_stacked.view(bt * self.nblocks_channel, c // self.nblocks_channel, -1)
73 | x_v = x_stacked.permute(0, 2, 1).contiguous() # (b) x (h * w) x c
74 | x_v = self.fc_in(x_v) # (b) x (h * w) x c
75 |             x_m = x_v.mean(1).view(-1, 1, c // self.nblocks_channel).detach() # (b * nb) x 1 x (c / nb)
76 |             score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # (b * nb) x (c / nb) x (c / nb)
77 |             # score.masked_fill_(self.mask.unsqueeze(0).expand_as(score).type_as(score).eq(0), -np.inf)
78 |             attn = F.softmax(score, dim=1) # (b * nb) x (c / nb) x (c / nb)
79 |             out = self.bn(self.fc_out(torch.bmm(x_v, attn))) # (b * nb) x (h * w) x (c / nb)
80 | out = out.permute(0, 2, 1).contiguous().view(bt, c, h, w)
81 | return F.relu(residual + out)
82 | else:
83 |             # first split the input tensor into spatial chunks
84 | ind_chunks = []
85 | x_chunks = []
86 | for i in range(spblock_h):
87 | for j in range(spblock_w):
88 | tl_y, tl_x = max(0, i * stride_h), max(0, j * stride_w)
89 | br_y, br_x = min(h, tl_y + self.spatial_height), min(w, tl_x + self.spatial_width)
90 | ind_y = torch.arange(tl_y, br_y).view(-1, 1)
91 | ind_x = torch.arange(tl_x, br_x).view(1, -1)
92 | ind = (ind_y * w + ind_x).view(1, 1, -1).repeat(bt, c, 1).type_as(x_stretch).long()
93 | ind_chunks.append(ind)
94 | chunk_ij = torch.gather(x_stretch, 2, ind).contiguous()
95 | x_chunks.append(chunk_ij)
96 |
97 | x_stacked = torch.cat(x_chunks, 0) # (b * nb_h * n_w) x c x (b_h * b_w)
98 | x_v = x_stacked.permute(0, 2, 1).contiguous() # (b * nb_h * n_w) x (b_h * b_w) x c
99 | x_v = self.fc_in(x_v) # (b * nb_h * n_w) x (b_h * b_w) x c
100 | x_m = x_v.mean(1).view(-1, 1, c) # (b * nb_h * n_w) x 1 x c
101 | score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # (b * nb_h * n_w) x c x c
102 | score.masked_fill_(self.mask.unsqueeze(0).expand_as(score).type_as(score).eq(0), -np.inf)
103 | attn = F.softmax(score, dim=1) # (b * nb_h * n_w) x c x c
104 | out = self.bn(self.fc_out(torch.bmm(x_v, attn))) # (b * nb_h * n_w) x (b_h * b_w) x c
105 |
106 | # put back to original shape
107 | out = out.permute(0, 2, 1).contiguous() # (b * nb_h * n_w) x c x (b_h * b_w)
108 | # x_stretch_out = x_stretch.clone().zero_()
109 | for i in range(spblock_h):
110 | for j in range(spblock_w):
111 | idx = i * spblock_w + j
112 | ind = ind_chunks[idx]
113 | chunk_ij = out[idx * bt:(idx+1) * bt]
114 |                     x_stretch = x_stretch.scatter_add(2, ind, chunk_ij / spblock_h / spblock_w)  # normalize by the chunk counts along both axes
115 | return F.relu(x_stretch.view(residual.shape))
116 |
117 | class CrossNeuronlBlock2D(_CrossNeuronBlock):
118 | def __init__(self, in_channels, in_height, in_width, spatial_height, spatial_width, reduction=8, size_is_consistant=True):
119 | super(CrossNeuronlBlock2D, self).__init__(in_channels, in_height, in_width,
120 | nblocks_channel=4,
121 | spatial_height=spatial_height,
122 | spatial_width=spatial_width,
123 | reduction=reduction,
124 | size_is_consistant=size_is_consistant)
125 |
126 |
127 | class CrossNeuronWrapper(nn.Module):
128 | def __init__(self, block, in_channels, in_height, in_width, spatial_height, spatial_width, reduction=8):
129 | super(CrossNeuronWrapper, self).__init__()
130 | self.block = block
131 | self.cn = CrossNeuronlBlock2D(in_channels, in_height, in_width, spatial_height, spatial_width, reduction=reduction)
132 |
133 | def forward(self, x):
134 | x = self.cn(x)
135 | x = self.block(x)
136 | return x
137 |
138 | def add_cross_neuron(net, img_height, img_width, spatial_height, spatial_width, reduction=8):
139 | import torchvision
140 | import lib.networks as archs
141 |
142 |     # import pdb; pdb.set_trace()  # debug breakpoint; disabled so the helper can run
143 |
144 | if isinstance(net, torchvision.models.ResNet):
145 | dummy_img = torch.randn(1, 3, img_height, img_width)
146 | out = net.conv1(dummy_img)
147 | out = net.relu(net.bn1(out))
148 | out0 = net.maxpool(out)
149 | print("layer0 out shape: {}x{}x{}x{}".format(out0.shape[0], out0.shape[1], out0.shape[2], out0.shape[3]))
150 | out1 = net.layer1(out0)
151 | print("layer1 out shape: {}x{}x{}x{}".format(out1.shape[0], out1.shape[1], out1.shape[2], out1.shape[3]))
152 | out2 = net.layer2(out1)
153 | print("layer2 out shape: {}x{}x{}x{}".format(out2.shape[0], out2.shape[1], out2.shape[2], out2.shape[3]))
154 | out3 = net.layer3(out2)
155 | print("layer3 out shape: {}x{}x{}x{}".format(out3.shape[0], out3.shape[1], out3.shape[2], out3.shape[3]))
156 | out4 = net.layer4(out3)
157 | print("layer4 out shape: {}x{}x{}x{}".format(out4.shape[0], out4.shape[1], out4.shape[2], out4.shape[3]))
158 |
159 | # net.layer1 = CrossNeuronWrapper(net.layer1, out1.shape[1], out1.shape[2], out1.shape[3], spatial_height[0], spatial_width[0], reduction)
160 | net.layer2 = CrossNeuronWrapper(net.layer2, out2.shape[1], out2.shape[2], out2.shape[3], spatial_height[1], spatial_width[1], reduction)
161 | net.layer3 = CrossNeuronWrapper(net.layer3, out3.shape[1], out3.shape[2], out3.shape[3], spatial_height[2], spatial_width[2], reduction)
162 | net.layer4 = CrossNeuronWrapper(net.layer4, out4.shape[1], out4.shape[2], out4.shape[3], spatial_height[3], spatial_width[3], reduction)
163 |
164 | # layers = []
165 | # l = len(net.layer2)
166 | # for i in range(l):
167 | # if i % 6 == 0 or i == (l - 1):
168 | # layers.append(CrossNeuronWrapper(net.layer2[i], out2.shape[1], out2.shape[2], out2.shape[3],
169 | # spatial_height[1], spatial_width[1], reduction[1]))
170 | # else:
171 | # layers.append(net.layer2[i])
172 | # net.layer2 = nn.Sequential(*layers)
173 | #
174 | # #
175 | # layers = []
176 | # l = len(net.layer3)
177 | # for i in range(0, l):
178 | # if i % 6 == 0 or i == (l - 1):
179 | # layers.append(CrossNeuronWrapper(net.layer3[i], out3.shape[1], out3.shape[2], out3.shape[3],
180 | # spatial_height[2], spatial_width[2], reduction[2]))
181 | # else:
182 | # layers.append(net.layer3[i])
183 | # net.layer3 = nn.Sequential(*layers)
184 | #
185 | # layers = []
186 | # l = len(net.layer4)
187 | # for i in range(0, l):
188 | # if i % 6 == 0 or i == (l - 1):
189 | # layers.append(CrossNeuronWrapper(net.layer4[i], out4.shape[1], out4.shape[2], out4.shape[3],
190 | # spatial_height[3], spatial_width[3], reduction[3]))
191 | # else:
192 | # layers.append(net.layer4[i])
193 | # net.layer4 = nn.Sequential(*layers)
194 |
195 |
196 | elif isinstance(net, archs.resnet_cifar.ResNet_Cifar):
197 |
198 | dummy_img = torch.randn(1, 3, img_height, img_width)
199 | out = net.conv1(dummy_img)
200 | out0 = net.relu(net.bn1(out))
201 | out1 = net.layer1(out0)
202 | out2 = net.layer2(out1)
203 | out3 = net.layer3(out2)
204 |
205 | net.layer1 = CrossNeuronWrapper(net.layer1, out0.shape[1], out0.shape[2], out0.shape[3], spatial_height[0], spatial_width[0])
206 | net.layer2 = CrossNeuronWrapper(net.layer2, out1.shape[1], out1.shape[2], out1.shape[3], spatial_height[1], spatial_width[1])
207 | net.layer3 = CrossNeuronWrapper(net.layer3, out2.shape[1], out2.shape[2], out2.shape[3], spatial_height[2], spatial_width[2])
208 |
209 | else:
210 | dummy_img = torch.randn(1, 3, img_height, img_width)
211 | out = net.conv1(dummy_img)
212 | out = net.relu(net.bn1(out))
213 | out1 = net.layer1(out)
214 | out2 = net.layer2(out1)
215 | out3 = net.layer3(out2)
216 |
217 | net.layer1 = CrossNeuronWrapper(net.layer1, out1.shape[1], out1.shape[2], out1.shape[3], spatial_height[0], spatial_width[0])
218 | net.layer2 = CrossNeuronWrapper(net.layer2, out2.shape[1], out2.shape[2], out2.shape[3], spatial_height[1], spatial_width[1])
219 | net.layer3 = CrossNeuronWrapper(net.layer3, out3.shape[1], out3.shape[2], out3.shape[3], spatial_height[2], spatial_width[2])
220 |
221 |
222 | # layers = []
223 | # l = len(net.layer2)
224 | # for i in range(l):
225 | # if i % 5 == 0 or i == (l - 1):
226 | # layers.append(CrossNeuronWrapper(net.layer2[i], out2.shape[1], out2.shape[2] * out2.shape[3]))
227 | # else:
228 | # layers.append(net.layer2[i])
229 | # net.layer2 = nn.Sequential(*layers)
230 | # #
231 | # layers = []
232 | # l = len(net.layer3)
233 | # for i in range(0, l):
234 | # if i % 5 == 0 or i == (l - 1):
235 | # layers.append(CrossNeuronWrapper(net.layer3[i], out3.shape[1], out3.shape[2] * out3.shape[3]))
236 | # else:
237 | # layers.append(net.layer3[i])
238 | # net.layer3 = nn.Sequential(*layers)
239 | #
240 | # else:
241 | # raise NotImplementedError
242 |
243 |
244 | if __name__ == '__main__':
245 |     import torch
246 |
247 |     # smoke test: no 3D variant is defined in this file, so exercise the
248 |     # 2D block on a 4D feature map of shape (batch, channels, h, w)
249 |     img = torch.randn(2, 64, 20, 20)
250 |     net = CrossNeuronlBlock2D(64, 20, 20, 20, 20)
251 |     out = net(img)
252 |     print(out.size())
253 |
--------------------------------------------------------------------------------
/lib/layers/cross_neuron_distributed.py:
--------------------------------------------------------------------------------
1 | # Non-local block using embedded gaussian
2 | # Code from
3 | # https://github.com/AlexHex7/Non-local_pytorch/blob/master/Non-Local_pytorch_0.3.1/lib/non_local_embedded_gaussian.py
4 | import math
5 | import torch
6 | from torch import nn
7 | from torch.nn import functional as F
8 | import numpy as np
9 | from scipy.linalg import block_diag
10 |
11 | class _CrossNeuronBlock(nn.Module):
12 | def __init__(self, in_channels, in_height, in_width,
13 | nblocks_channel=8,
14 | spatial_height=32, spatial_width=32,
15 | reduction=4, size_is_consistant=True):
16 |         # nblocks_channel: number of blocks along the channel axis
17 |         # spatial_height / spatial_width: target spatial size after resizing
18 | super(_CrossNeuronBlock, self).__init__()
19 |
20 | self.corr_bf = 0
21 | self.corr_af = 0
22 |
23 | self.nblocks_channel = 1 if in_channels <= 512 else in_channels // 512
24 |
25 | block_size = in_channels // self.nblocks_channel
26 | block = torch.Tensor(block_size, block_size).fill_(1)
27 | self.mask = torch.Tensor(in_channels, in_channels).fill_(0)
28 | for i in range(self.nblocks_channel):
29 | self.mask[i * block_size:(i + 1) * block_size, i * block_size:(i + 1) * block_size].copy_(block)
30 |
31 | factor = in_height // spatial_height
32 | if factor == 0 and size_is_consistant:
33 | self.spatial_area = in_height * in_width
34 | self.spatial_height = in_height
35 | self.spatial_width = in_width
36 | else:
37 | ds_layers = []
38 | us_layers = []
39 | for i in range(factor - 1):
40 | ds_layer = nn.Sequential(
41 | nn.Conv2d(in_channels, in_channels, kernel_size=2, stride=2, padding=0, groups=in_channels, bias=False),
42 | nn.BatchNorm2d(in_channels),
43 | nn.ReLU(True),
44 | )
45 | ds_layers.append(ds_layer)
46 |
47 | us_layer = nn.Sequential(
48 | nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2, padding=0, groups=in_channels, bias=False),
49 | nn.BatchNorm2d(in_channels),
50 | nn.ReLU(True),
51 | )
52 | us_layers.append(us_layer)
53 | self.downsample = nn.Sequential(*ds_layers)
54 | self.upsample = nn.Sequential(*us_layers)
55 | self.spatial_height = in_height // factor
56 | self.spatial_width = in_width // factor
57 | self.spatial_area = self.spatial_height * self.spatial_width
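        # Note: forward() below resizes with F.interpolate; the learned
        # downsample/upsample stacks built here are currently unused (their
        # calls are commented out in forward()).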
58 |
59 | self.fc_in = nn.Sequential(
60 | nn.Conv1d(self.spatial_area, self.spatial_area // reduction, 1, 1, 0, bias=True),
61 | # nn.BatchNorm1d(self.spatial_area // reduction),
62 | nn.ReLU(True),
63 | nn.Conv1d(self.spatial_area // reduction, self.spatial_area, 1, 1, 0, bias=True),
64 | )
65 |
66 | self.fc_out = nn.Sequential(
67 | nn.Conv1d(self.spatial_area, self.spatial_area // reduction, 1, 1, 0, bias=True),
68 | # nn.BatchNorm1d(self.spatial_area // reduction),
69 | nn.ReLU(True),
70 | nn.Conv1d(self.spatial_area // reduction, self.spatial_area, 1, 1, 0, bias=True),
71 | # nn.BatchNorm1d(self.spatial_area)
72 | )
73 |
74 | self.ln = nn.LayerNorm(in_channels)
75 |
76 | self.initialize()
77 |
78 | def initialize(self):
79 | for m in self.modules():
80 | if isinstance(m, nn.Conv1d):
81 | nn.init.kaiming_normal_(m.weight)
82 | elif isinstance(m, nn.BatchNorm1d):
83 | nn.init.constant_(m.weight, 1)
84 | nn.init.constant_(m.bias, 0)
85 |
86 | # def _compute_correlation(self, x):
87 | # bt, c, h, w = x.shape
88 | # x_v = x.view(bt, c, h*w).detach()
89 | # x_v_mean = x_v.mean(1).unsqueeze(1)
90 | # x_v_cent = x_v - x_v_mean # bt x c x (hw)
91 | # # x_v_cent = x_v
92 | # x_v_cent = x_v_cent / (torch.norm(x_v_cent, 2, 2).unsqueeze(2) + 1e-5)
93 | # correlations = torch.bmm(x_v_cent, x_v_cent.permute(0, 2, 1).contiguous()) # btxcxc
94 | # diags = 1 - torch.eye(c).unsqueeze(0).type_as(correlations)
95 | # correlations = correlations * diags
96 | # return torch.abs(correlations).mean(0).sum() / c / (c - 1)
97 |
98 | def _compute_correlation(self, x):
99 | b, c, h, w = x.shape
100 | x_v = x.clone().detach().view(b, c, -1) # b x c x (hw)
101 | x_m = x_v.mean(1).unsqueeze(1) # b x 1 x (hw)
102 | x_c = x_v - x_m # b x c x (hw)
103 | num = torch.bmm(x_c, x_c.transpose(1, 2)) # b x c x c
104 | x_mode = torch.sqrt(torch.sum(x_c ** 2, 2).unsqueeze(2)) # b x c x 1
105 | dec = torch.bmm(x_mode, x_mode.transpose(1, 2).contiguous()) # b x c x c
106 | out = num / dec # b x c x c
107 | # diags = 1 - torch.eye(c).unsqueeze(0).type_as(out)
108 | # out = out * diags
109 | out = torch.abs(out) # .mean(0).sum() / c / (c - 1)
110 | return out.mean()
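    # _compute_correlation returns the mean absolute (Pearson-style) correlation
    # over all channel pairs; forward() records it before (corr_bf) and after
    # (corr_af) communication to measure how much the block decorrelates channels.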
111 |
112 | def forward(self, x):
113 | '''
114 | :param x: (bt, c, h, w)
115 | :return:
116 | '''
117 | bt, c, h, w = x.shape
118 | residual = x
119 | x_stretch = x.view(bt, c, h * w)
120 | self.corr_bf = self._compute_correlation(x)
121 | # self.corr_af = self._compute_correlation(x)
122 | # return x
123 |
124 | if self.spatial_height == h and self.spatial_width == w:
125 | x_stacked = x_stretch # (b) x c x (h * w)
126 | x_stacked = x_stacked.view(bt * self.nblocks_channel, c // self.nblocks_channel, -1)
127 | x_v = x_stacked.permute(0, 2, 1).contiguous() # (b) x (h * w) x c
128 | x_v = self.fc_in(x_v) # (b) x (h * w) x c
129 |             x_m = x_v.mean(1).view(-1, 1, c // self.nblocks_channel) # (b * nb) x 1 x (c / nb)
130 | # x_m = x_m - x_m.mean(2).unsqueeze(2)
131 | # x_m = self.ln(x_m)
132 | # x_m = x_m.detach()
133 | score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # (b * h * w) x c x c
134 | # score = -score / score.sum(2).unsqueeze(2)
135 | # score = torch.bmm(x_v.transpose(1, 2).contiguous(), x_v)
136 | # x_v = F.dropout(x_v, 0.1, self.training)
137 | # score.masked_fill_(self.mask.unsqueeze(0).expand_as(score).type_as(score).eq(0), -np.inf)
138 | attn = F.softmax(score, dim=1) # (b * h * w) x c x c
139 | out = self.fc_out(torch.bmm(x_v, attn)) # (b) x (h * w) x c
140 | out = F.dropout(out.permute(0, 2, 1).contiguous().view(bt, c, h, w), 0.0, self.training)
141 | out = F.relu(residual + out)
142 | self.corr_af = self._compute_correlation(out)
143 | return out
144 | else:
145 | # x = self.downsample(x)
146 | x = F.interpolate(x, (self.spatial_height, self.spatial_width))
147 | x_stretch = x.view(bt, c, self.spatial_height * self.spatial_width)
148 | x_stretch = x.view(bt * self.nblocks_channel, c // self.nblocks_channel, self.spatial_height * self.spatial_width)
149 |
150 | x_stacked = x_stretch # (b) x c x (h * w)
151 | x_v = x_stacked.permute(0, 2, 1).contiguous() # (b) x (h * w) x c
152 | x_v = self.fc_in(x_v) # (b) x (h * w) x c
153 | x_m = x_v.mean(1).view(-1, 1, c // self.nblocks_channel) # b x 1 x c
154 | # x_m = x_m - x_m.mean(2).unsqueeze(2)
155 | # x_m = self.ln(x_m)
156 | # x_m = x_m.detach()
157 | score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # b x c x c
158 | # score = -score / score.sum(2).unsqueeze(2)
159 | # score = torch.bmm(x_v.transpose(1, 2).contiguous(), x_v)
160 | # x_v = F.dropout(x_v, 0.1, self.training)
161 | # score.masked_fill_(self.mask.unsqueeze(0).expand_as(score).type_as(score).eq(0), -np.inf)
162 | attn = F.softmax(score, dim=1) # (b * h * w) x c x c
163 | out = self.fc_out(torch.bmm(x_v, attn)) # (b) x (h * w) x c
164 | out = out.permute(0, 2, 1).contiguous().view(bt, c, self.spatial_height, self.spatial_width)
165 | # out = F.dropout(self.upsample(out), 0.0, self.training)
166 | out = F.dropout(F.interpolate(out, (h, w)), 0.0, self.training)
167 | out = F.relu(residual + out)
168 | self.corr_af = self._compute_correlation(out)
169 | return out
170 |
171 | # class _CrossNeuronBlock(nn.Module):
172 | # def __init__(self, in_channels, in_height, in_width,
173 | # nblocks_channel=8,
174 | # spatial_height=32, spatial_width=32,
175 | # reduction=4, size_is_consistant=True):
176 | # # nblock_channel: number of block along channel axis
177 | # # spatial_size: spatial_size
178 | # super(_CrossNeuronBlock, self).__init__()
179 | #
180 | # # set channel splits
181 | # if in_channels <= 512:
182 | # self.nblocks_channel = 1
183 | # else:
184 | # self.nblocks_channel = in_channels // 512
185 | # block_size = in_channels // self.nblocks_channel
186 | # block = torch.Tensor(block_size, block_size).fill_(1)
187 | # self.mask = torch.Tensor(in_channels, in_channels).fill_(0)
188 | # for i in range(self.nblocks_channel):
189 | # self.mask[i * block_size:(i + 1) * block_size, i * block_size:(i + 1) * block_size].copy_(block)
190 | #
191 | # # set spatial splits
192 | # if in_height * in_width < 32 * 32 and size_is_consistant:
193 | # self.spatial_area = in_height * in_width
194 | # self.spatial_height = in_height
195 | # self.spatial_width = in_width
196 | # else:
197 | # self.spatial_area = spatial_height * spatial_width
198 | # self.spatial_height = spatial_height
199 | # self.spatial_width = spatial_width
200 | #
201 | # self.conv_in = nn.Sequential(
202 | # nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False),
203 | # nn.BatchNorm2d(in_channels),
204 | # nn.ReLU(True),
205 | # nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False),
206 | # nn.BatchNorm2d(in_channels),
207 | # nn.ReLU(True),
208 | # )
209 | #
210 | # self.conv_out = nn.Sequential(
211 | # nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False),
212 | # nn.BatchNorm2d(in_channels),
213 | # nn.ReLU(True),
214 | # nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False),
215 | # nn.BatchNorm2d(in_channels),
216 | # )
217 | #
218 | # # self.bn = nn.BatchNorm1d(self.spatial_area)
219 | #
220 | # self.initialize()
221 | #
222 | # def initialize(self):
223 | # for m in self.modules():
224 | # if isinstance(m, nn.Conv2d):
225 | # nn.init.kaiming_normal_(m.weight)
226 | # elif isinstance(m, nn.BatchNorm2d):
227 | # nn.init.constant_(m.weight, 1)
228 | # nn.init.constant_(m.bias, 0)
229 | #
230 | # def forward(self, x):
231 | # '''
232 | # :param x: (bt, c, h, w)
233 | # :return:
234 | # '''
235 | # # import pdb; pdb.set_trace()
236 | # bt, c, h, w = x.shape
237 | # residual = x
238 | # x_v = self.conv_in(x) # b x c x h x w
239 | # x_m = x_v.mean(3).mean(2).unsqueeze(2) # bt x c x 1
240 | #
241 | # score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # bt x c x c
242 | # # score = torch.bmm(x_v.transpose(1, 2).contiguous(), x_v)
243 | # # x_v = F.dropout(x_v, 0.1, self.training)
244 | # # score.masked_fill_(self.mask.unsqueeze(0).expand_as(score).type_as(score).eq(0), -np.inf)
245 | # attn = F.softmax(score, dim=2) # bt x c x c
246 | # out = self.conv_out(torch.bmm(attn, x_v.view(bt, c, h * w)).view(bt, c, h, w))
247 | # return F.relu(residual + out)
248 |
249 | class CrossNeuronlBlock2D(_CrossNeuronBlock):
250 | def __init__(self, in_channels, in_height, in_width, spatial_height, spatial_width, reduction=8, size_is_consistant=True):
251 | super(CrossNeuronlBlock2D, self).__init__(in_channels, in_height, in_width,
252 | nblocks_channel=4,
253 | spatial_height=spatial_height,
254 | spatial_width=spatial_width,
255 | reduction=reduction,
256 | size_is_consistant=size_is_consistant)
257 |
258 |
259 | class CrossNeuronWrapper(nn.Module):
260 | def __init__(self, block, in_channels, in_height, in_width, spatial_height, spatial_width, reduction=8):
261 | super(CrossNeuronWrapper, self).__init__()
262 | self.block = block
263 | self.cn = CrossNeuronlBlock2D(in_channels, in_height, in_width, spatial_height, spatial_width, reduction=reduction)
264 |
265 | # self.conv = nn.Sequential(
266 | # nn.Conv2d(in_channels, 4 * in_channels, 3, 1, 1),
267 | # nn.ReLU(True),
268 | # nn.Conv2d(4 * in_channels, in_channels, 3, 1, 1),
269 | # # nn.ReLU(True),
270 | # # nn.Conv2d(in_channels, in_channels, 3, 1, 1),
271 | # # nn.ReLU(True),
272 | # # nn.Conv2d(in_channels, in_channels, 3, 1, 1),
273 | # # nn.ReLU(True),
274 | # # nn.Conv2d(in_channels, in_channels, 3, 1, 1),
275 | # # nn.ReLU(True),
276 | # # nn.Conv2d(in_channels, in_channels, 3, 1, 1)
277 | # )
278 |
279 | def forward(self, x):
280 | x = self.cn(x)
281 | x = self.block(x)
282 | return x
283 |
284 | def add_cross_neuron(net, img_height, img_width, spatial_height, spatial_width, reduction=[4,4,4,4]):
285 | import torchvision
286 | import lib.networks as archs
287 |
288 |     # import pdb; pdb.set_trace()  # debug breakpoint; disabled so the helper can run
289 |
290 | if isinstance(net, torchvision.models.ResNet):
291 | dummy_img = torch.randn(1, 3, img_height, img_width)
292 | out = net.conv1(dummy_img)
293 | out = net.relu(net.bn1(out))
294 | out0 = net.maxpool(out)
295 | print("layer0 out shape: {}x{}x{}x{}".format(out0.shape[0], out0.shape[1], out0.shape[2], out0.shape[3]))
296 | out1 = net.layer1(out0)
297 | print("layer1 out shape: {}x{}x{}x{}".format(out1.shape[0], out1.shape[1], out1.shape[2], out1.shape[3]))
298 | out2 = net.layer2(out1)
299 | print("layer2 out shape: {}x{}x{}x{}".format(out2.shape[0], out2.shape[1], out2.shape[2], out2.shape[3]))
300 | out3 = net.layer3(out2)
301 | print("layer3 out shape: {}x{}x{}x{}".format(out3.shape[0], out3.shape[1], out3.shape[2], out3.shape[3]))
302 | out4 = net.layer4(out3)
303 | print("layer4 out shape: {}x{}x{}x{}".format(out4.shape[0], out4.shape[1], out4.shape[2], out4.shape[3]))
304 |
305 |
306 | layers = []
307 | l = len(net.layer1)
308 | for i in range(l):
309 | if i == 0:
310 | layers.append(CrossNeuronWrapper(net.layer1[i], out0.shape[1], out0.shape[2], out0.shape[3],
311 | out0.shape[2], out0.shape[3], 4))
312 | elif i % 4 == 0:
313 | layers.append(CrossNeuronWrapper(net.layer1[i], out1.shape[1], out1.shape[2], out1.shape[3],
314 | out1.shape[2], out1.shape[3], 4))
315 | else:
316 | layers.append(net.layer1[i])
317 | net.layer1 = nn.Sequential(*layers)
318 |
319 | layers = []
320 | l = len(net.layer2)
321 | for i in range(l):
322 | if i == 0:
323 | layers.append(CrossNeuronWrapper(net.layer2[i], out1.shape[1], out1.shape[2], out1.shape[3],
324 | out1.shape[2], out1.shape[3], 4))
325 | elif i % 4 == 0:
326 | layers.append(CrossNeuronWrapper(net.layer2[i], out2.shape[1], out2.shape[2], out2.shape[3],
327 | out2.shape[2], out2.shape[3], 4))
328 | else:
329 | layers.append(net.layer2[i])
330 | net.layer2 = nn.Sequential(*layers)
331 |
332 | #
333 | layers = []
334 | l = len(net.layer3)
335 | for i in range(0, l):
336 | if i == 0:
337 | layers.append(CrossNeuronWrapper(net.layer3[i], out2.shape[1], out2.shape[2], out2.shape[3],
338 | out2.shape[2], out2.shape[3], 4))
339 | elif i % 4 == 0:
340 | layers.append(CrossNeuronWrapper(net.layer3[i], out3.shape[1], out3.shape[2], out3.shape[3],
341 | out3.shape[2], out3.shape[3], 4))
342 | else:
343 | layers.append(net.layer3[i])
344 | net.layer3 = nn.Sequential(*layers)
345 |
346 | layers = []
347 | l = len(net.layer4)
348 | for i in range(0, l):
349 | if i == 0:
350 | layers.append(CrossNeuronWrapper(net.layer4[i], out3.shape[1], out3.shape[2], out3.shape[3],
351 | out3.shape[2], out3.shape[3], 4))
352 | else:
353 | layers.append(net.layer4[i])
354 | net.layer4 = nn.Sequential(*layers)
355 |
356 | else:
357 | dummy_img = torch.randn(1, 3, img_height, img_width)
358 | out = net.conv1(dummy_img)
359 | out0 = net.relu(net.bn1(out))
360 | out1 = net.layer1(out0)
361 | out2 = net.layer2(out1)
362 | out3 = net.layer3(out2)
363 | #
364 | net.layer1 = CrossNeuronWrapper(net.layer1, out0.shape[1], out0.shape[2], out0.shape[3], spatial_height[0], spatial_width[0])
365 | net.layer2 = CrossNeuronWrapper(net.layer2, out1.shape[1], out1.shape[2], out1.shape[3], spatial_height[1], spatial_width[1])
366 | net.layer3 = CrossNeuronWrapper(net.layer3, out2.shape[1], out2.shape[2], out2.shape[3], spatial_height[2], spatial_width[2])
367 |
368 | '''
369 | layers = []
370 | l = len(net.layer1)
371 | for i in range(l):
372 | if i == 0:
373 | layers.append(CrossNeuronWrapper(net.layer1[i], out0.shape[1], out0.shape[2], out0.shape[3], out0.shape[2], out0.shape[3]))
374 | elif i in [4, 7]: # resnet56: [4, 7]
375 | layers.append(CrossNeuronWrapper(net.layer1[i], out1.shape[1], out1.shape[2], out1.shape[3], out1.shape[2], out1.shape[3]))
376 | else:
377 | layers.append(net.layer1[i])
378 | net.layer1 = nn.Sequential(*layers)
379 | #
380 | layers = []
381 | l = len(net.layer2)
382 | for i in range(l):
383 | if i in [0]:
384 | layers.append(CrossNeuronWrapper(net.layer2[i], out1.shape[1], out1.shape[2], out1.shape[3], out1.shape[2], out1.shape[3]))
385 | elif i in [4, 7]:
386 | layers.append(CrossNeuronWrapper(net.layer2[i], out2.shape[1], out2.shape[2], out2.shape[3], out2.shape[2], out2.shape[3]))
387 | else:
388 | layers.append(net.layer2[i])
389 | net.layer2 = nn.Sequential(*layers)
390 | #
391 | layers = []
392 | l = len(net.layer3)
393 | for i in range(0, l):
394 | if i in [0]:
395 | layers.append(CrossNeuronWrapper(net.layer3[i], out2.shape[1], out2.shape[2], out2.shape[3], out2.shape[2], out2.shape[3]))
396 | else:
397 | layers.append(net.layer3[i])
398 | net.layer3 = nn.Sequential(*layers)
399 | #
400 | '''
401 | # else:
402 | # raise NotImplementedError
403 |
404 | if __name__ == '__main__':
405 |     import torch
406 |
407 |     # smoke test: no 3D variant is defined in this file, so exercise the
408 |     # 2D block on a 4D feature map of shape (batch, channels, h, w)
409 |     img = torch.randn(2, 64, 20, 20)
410 |     net = CrossNeuronlBlock2D(64, 20, 20, 20, 20)
411 |     out = net(img)
412 |     print(out.size())
413 |
--------------------------------------------------------------------------------
/lib/layers/data_parallel.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.nn.parallel import DataParallel
3 | import torch
4 | from torch.nn.parallel._functions import Scatter
5 | from torch.nn.parallel.parallel_apply import parallel_apply
6 |
7 | def scatter(inputs, target_gpus, chunk_sizes, dim=0):
8 | r"""
9 | Slices tensors into approximately equal chunks and
10 | distributes them across given GPUs. Duplicates
11 | references to objects that are not tensors.
12 | """
13 | def scatter_map(obj):
14 | if isinstance(obj, torch.Tensor):
15 | try:
16 | return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
17 |             except Exception:
18 | print('obj', obj.size())
19 | print('dim', dim)
20 | print('chunk_sizes', chunk_sizes)
21 | quit()
22 | if isinstance(obj, tuple) and len(obj) > 0:
23 | return list(zip(*map(scatter_map, obj)))
24 | if isinstance(obj, list) and len(obj) > 0:
25 | return list(map(list, zip(*map(scatter_map, obj))))
26 | if isinstance(obj, dict) and len(obj) > 0:
27 | return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
28 |         return [obj for _ in target_gpus]
29 |
30 | # After scatter_map is called, a scatter_map cell will exist. This cell
31 | # has a reference to the actual function scatter_map, which has references
32 | # to a closure that has a reference to the scatter_map cell (because the
33 | # fn is recursive). To avoid this reference cycle, we set the function to
34 | # None, clearing the cell
35 | try:
36 | return scatter_map(inputs)
37 | finally:
38 | scatter_map = None
39 |
40 | def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
41 | r"""Scatter with support for kwargs dictionary"""
42 | inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
43 | kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
44 | if len(inputs) < len(kwargs):
45 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
46 | elif len(kwargs) < len(inputs):
47 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
48 | inputs = tuple(inputs)
49 | kwargs = tuple(kwargs)
50 | return inputs, kwargs
51 |
52 | class BalancedDataParallel(DataParallel):
53 | def __init__(self, gpu0_bsz, *args, **kwargs):
54 | self.gpu0_bsz = gpu0_bsz
55 | super().__init__(*args, **kwargs)
56 |
57 | def forward(self, *inputs, **kwargs):
58 | if not self.device_ids:
59 | return self.module(*inputs, **kwargs)
60 | if self.gpu0_bsz == 0:
61 | device_ids = self.device_ids[1:]
62 | else:
63 | device_ids = self.device_ids
64 | inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
65 | if len(self.device_ids) == 1:
66 | return self.module(*inputs[0], **kwargs[0])
67 | replicas = self.replicate(self.module, self.device_ids)
68 | if self.gpu0_bsz == 0:
69 | replicas = replicas[1:]
70 | outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
71 | return self.gather(outputs, self.output_device)
72 |
73 | def parallel_apply(self, replicas, device_ids, inputs, kwargs):
74 | return parallel_apply(replicas, inputs, kwargs, device_ids)
75 |
76 | def scatter(self, inputs, kwargs, device_ids):
77 | bsz = inputs[0].size(self.dim)
78 | num_dev = len(self.device_ids)
79 | gpu0_bsz = self.gpu0_bsz
80 | bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
81 | if gpu0_bsz < bsz_unit:
82 | chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
83 | delta = bsz - sum(chunk_sizes)
84 | for i in range(delta):
85 | chunk_sizes[i + 1] += 1
86 | if gpu0_bsz == 0:
87 | chunk_sizes = chunk_sizes[1:]
88 | else:
89 | return super().scatter(inputs, kwargs, device_ids)
90 | return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)
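    # Hedged usage sketch (sizes are illustrative): give GPU 0 a smaller slice
    # of the batch so it keeps headroom for gathered outputs and optimizer state.
    #
    #   model = BalancedDataParallel(16, net, dim=0).cuda()
    #   out = model(images)  # a batch of 128 on 4 GPUs splits as 16 / 38 / 37 / 37,
    #                        # since gpu0_bsz=16 < bsz_unit=37 triggers the custom scatter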
91 |
--------------------------------------------------------------------------------
/lib/layers/selayer.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class SELayer(nn.Module):
5 | def __init__(self, channel, reduction=8):
6 | super(SELayer, self).__init__()
7 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
8 | self.fc = nn.Sequential(
9 | nn.Linear(channel, int(channel / reduction), bias=False),
10 | nn.ReLU(inplace=True),
11 | nn.Linear(int(channel / reduction), channel, bias=False),
12 | nn.Sigmoid()
13 | )
14 |
15 | def forward(self, x):
16 | b, c, _, _ = x.size()
17 | y = self.avg_pool(x).view(b, c)
18 | y = self.fc(y).view(b, c, 1, 1)
19 | return x * y.expand_as(x)
20 |
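# Hedged usage sketch: squeeze-and-excitation produces one gate in (0, 1) per
# channel from a global average, then rescales the feature map channel-wise.
#
#   import torch
#   se = SELayer(channel=64, reduction=8)
#   y = se(torch.randn(2, 64, 8, 8))  # same shape, channels reweighted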
--------------------------------------------------------------------------------
/lib/model.py:
--------------------------------------------------------------------------------
1 | # Code for Cross-Neuron Communication Network
2 | # Author: Jianwei Yang (jw2yang@gatech.edu)
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | import torchvision.models as models
7 |
8 | from .networks import *
9 | from .layers import *
10 | from .layers.cross_neuron_distributed import _CrossNeuronBlock
11 |
12 | class AlexNet(nn.Module):
13 | def __init__(self, num_classes=100, has_gtlayer=False):
14 | super(AlexNet, self).__init__()
15 |
16 | if has_gtlayer:
17 | self.features = nn.Sequential(
18 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5),
19 | nn.ReLU(inplace=True),
20 | _CrossNeuronBlock(64, 8, 8, spatial_height=8, spatial_width=8, reduction=2),
21 | nn.MaxPool2d(kernel_size=2, stride=2),
22 | nn.Conv2d(64, 192, kernel_size=5, padding=2),
23 | nn.ReLU(inplace=True),
24 | nn.MaxPool2d(kernel_size=2, stride=2),
25 | nn.Conv2d(192, 384, kernel_size=3, padding=1),
26 | nn.ReLU(inplace=True),
27 | nn.Conv2d(384, 256, kernel_size=3, padding=1),
28 | nn.ReLU(inplace=True),
29 | nn.Conv2d(256, 256, kernel_size=3, padding=1),
30 | nn.ReLU(inplace=True),
31 | nn.MaxPool2d(kernel_size=2, stride=2),
32 | )
33 | else:
34 | self.features = nn.Sequential(
35 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5),
36 | nn.ReLU(inplace=True),
37 | nn.MaxPool2d(kernel_size=2, stride=2),
38 | nn.Conv2d(64, 192, kernel_size=5, padding=2),
39 | nn.ReLU(inplace=True),
40 | nn.MaxPool2d(kernel_size=2, stride=2),
41 | nn.Conv2d(192, 384, kernel_size=3, padding=1),
42 | nn.ReLU(inplace=True),
43 | nn.Conv2d(384, 256, kernel_size=3, padding=1),
44 | nn.ReLU(inplace=True),
45 | nn.Conv2d(256, 256, kernel_size=3, padding=1),
46 | nn.ReLU(inplace=True),
47 | nn.MaxPool2d(kernel_size=2, stride=2),
48 | )
49 | self.has_gtlayer = has_gtlayer
50 |
51 | # if has_gtlayer:
52 | # self.gtlayer = _CrossNeuronBlock(256, 16, 16, 16, 16)
53 | self.classifier = nn.Linear(256, num_classes)
54 |
55 | def forward(self, x):
56 | x = self.features(x)
57 | x = x.view(x.size(0), -1)
58 | x = self.classifier(x)
59 | return x
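        # Shape check for 32x32 CIFAR inputs: the stride-4 conv gives 8x8 maps,
        # and the three 2x2 max-pools reduce them to 1x1, so features yields
        # (b, 256, 1, 1) and the view above matches the 256-d linear classifier.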
60 |
61 | class CrossNeuronNet(nn.Module):
62 | def __init__(self, opts):
63 | super(CrossNeuronNet, self).__init__()
64 | if opts.dataset == "imagenet": # we directly use pytorch version arch
65 | if opts.arch in models.__dict__:
66 | self.net = models.__dict__[opts.arch](pretrained=False)
67 | if opts.arch == "vgg16":
68 | for i in range(len(self.net.features)):
69 | if self.net.features[i].__class__.__name__ == "Conv2d":
70 | channels = self.net.features[i].out_channels
71 | self.net.features[i] = nn.Sequential(
72 | self.net.features[i],
73 | nn.BatchNorm2d(channels)
74 | )
75 | elif opts.arch in imagenet_models:
76 | self.net = imagenet_models[opts.arch](num_classes=1000)
77 | else:
78 |                 raise ValueError("Unknown network architecture for imagenet, please refer to: https://pytorch.org/docs/0.4.0/torchvision/models.html?")
79 | # if opts.add_cross_neuron:
80 | # add_cross_neuron(self.net, opts.data.input_size, opts.data.input_size, [32, 24, 24, 16], [32, 24, 24, 16], reduction = [4, 4, 4, 2])
81 | # add_cross_neuron(self.net, opts.data.input_size, opts.data.input_size, [32, 24, 24, 16], [32, 24, 24, 16])
82 |
83 | elif opts.dataset == "cub2011":
84 | if opts.arch in models.__dict__:
85 | self.net = models.__dict__[opts.arch](pretrained=True)
86 | self.net.fc = nn.Linear(self.net.fc.in_features, 200, bias=True)
87 | # NOTE: replace the last fc layer
88 | else:
89 |                 raise ValueError("Unknown network architecture for cub2011, please refer to: https://pytorch.org/docs/0.4.0/torchvision/models.html?")
90 | if opts.add_cross_neuron:
91 | add_cross_neuron(self.net, opts.data.input_size, opts.data.input_size, [32, 32, 32, 16], [32, 32, 32, 16], 8)
92 | elif opts.dataset == "cifar10": # we use the arch following He's paper (deep residual learning)
93 | if opts.arch in cifar_models:
94 | self.net = cifar_models[opts.arch](num_classes=10, has_selayer=("se" in opts.arch))
95 | else:
96 |                 raise ValueError("Unknown network architecture for cifar10")
97 | if opts.add_cross_neuron:
98 | add_cross_neuron(self.net, opts.data.input_size, opts.data.input_size, [32, 24, 24, 16], [32, 24, 24, 16])
99 | elif opts.dataset == "cifar100": # we use the arch following He's paper (deep residual learning)
100 | if opts.arch in cifar_models:
101 | self.net = cifar_models[opts.arch](num_classes=100, has_selayer=("se" in opts.arch), has_gtlayer=False)
102 | elif opts.arch == "alexnet":
103 | self.net = AlexNet(has_gtlayer=opts.add_cross_neuron)
104 | else:
105 |                 raise ValueError("Unknown network architecture for cifar100")
106 | self.net.name = "cnn"
107 | if opts.add_cross_neuron and not opts.arch == "alexnet":
108 | add_cross_neuron(self.net, opts.data.input_size, opts.data.input_size, [16, 16, 8, 8], [16, 16, 8, 8])
109 | else:
110 |             raise ValueError("Unknown dataset, we only support cifar, cub2011 and imagenet for now")
111 |
112 | print(self.net)
113 |
114 |     def get_optim_policies(self):
115 |         # split parameters into backbone vs. C3 groups; take direct parameters
116 |         # only (recurse=False), since recursing from every module yielded by
117 |         # self.modules() would collect the same tensors many times over
118 |         resnet_param, ccn_param = [], []
119 |         for m in self.modules():
120 |             is_ccn = isinstance(m, (torch.nn.Conv1d, torch.nn.BatchNorm1d))
121 |             (ccn_param if is_ccn else resnet_param).extend(m.parameters(recurse=False))
122 |
123 | return [{"params": resnet_param, 'lr_mult': 1}, {"params": ccn_param, 'lr_mult': 10}]
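    # Hedged sketch of how a training script might consume these groups
    # (train.py is not shown in this dump; the scaling below is the usual
    # lr_mult convention, not necessarily what train.py does):
    #
    #   policies = model.get_optim_policies()
    #   for group in policies:
    #       group["lr"] = cfg.optimizer.lr * group.pop("lr_mult")
    #   optimizer = torch.optim.SGD(policies, lr=cfg.optimizer.lr,
    #                               momentum=cfg.optimizer.momentum,
    #                               weight_decay=cfg.optimizer.weight_decay)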
124 |
125 | def forward(self, x):
126 | # x = self.net.conv1(x)
127 | # x = self.net.bn1(x)
128 | # x = self.net.relu(x)
129 | # # x = self.net.maxpool(x)
130 | # x0 = x.clone().detach()
131 | #
132 | # x = self.net.layer1(x)
133 | # x1 = x.clone().detach()
134 | #
135 | # x = self.net.layer2(x)
136 | # x2 = x.clone().detach()
137 | #
138 | # x = self.net.layer3(x)
139 | # x3 = x.clone().detach()
140 |
141 | # x = self.net.layer4(x)
142 | # x4 = x.clone().detach()
143 |
144 | # out = self.net.avgpool(x).squeeze(3).squeeze(2)
145 | # out = self.net.fc(out)
146 | out = self.net(x)
147 | return out
148 |
--------------------------------------------------------------------------------
/lib/model_analysis.py:
--------------------------------------------------------------------------------
1 | # Code for Cross-Neuron Communication Network
2 | # Author: Jianwei Yang (jw2yang@gatech.edu)
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | import torchvision.models as models
7 |
8 | from .networks import *
9 | from .layers import *
10 |
11 | class CrossNeuronNetAnalysis(nn.Module):
12 | def __init__(self, opts):
13 | super(CrossNeuronNetAnalysis, self).__init__()
14 | if opts.dataset == "imagenet": # we directly use pytorch version arch
15 |             # import pdb; pdb.set_trace()  # debug breakpoint; disabled
16 | if opts.arch in models.__dict__:
17 | self.net = models.__dict__[opts.arch](pretrained=False)
18 | elif opts.arch in imagenet_models:
19 | self.net = imagenet_models[opts.arch](num_classes=1000)
20 | else:
21 |                 raise ValueError("Unknown network architecture for imagenet, please refer to: https://pytorch.org/docs/0.4.0/torchvision/models.html?")
22 | if opts.add_cross_neuron:
23 | add_cross_neuron(self.net, opts.data.input_size, opts.data.input_size, [32, 24, 24, 16], [32, 24, 24, 16], reduction = [4, 4, 4, 2])
24 | elif opts.dataset == "cub2011":
25 | if opts.arch in models.__dict__:
26 | self.net = models.__dict__[opts.arch](pretrained=True)
27 | self.net.fc = nn.Linear(self.net.fc.in_features, 200, bias=True)
28 | # NOTE: replace the last fc layer
29 | else:
30 |                 raise ValueError("Unknown network architecture for cub2011; please refer to: https://pytorch.org/docs/0.4.0/torchvision/models.html?")
31 | if opts.add_cross_neuron:
32 | add_cross_neuron(self.net, opts.data.input_size, opts.data.input_size, [32, 32, 32, 16], [32, 32, 32, 16], 8)
33 | elif opts.dataset == "cifar10": # we use the arch following He's paper (deep residual learning)
34 | if opts.arch in cifar_models:
35 | self.net = cifar_models[opts.arch](num_classes=10, has_selayer=("se" in opts.arch))
36 | else:
37 |                 raise ValueError("Unknown network architecture for cifar10")
38 | if opts.add_cross_neuron:
39 | add_cross_neuron(self.net, opts.data.input_size, opts.data.input_size, [32, 24, 24, 16], [32, 24, 24, 16])
40 | elif opts.dataset == "cifar100": # we use the arch following He's paper (deep residual learning)
41 | if opts.arch in cifar_models:
42 | self.net = cifar_models[opts.arch](num_classes=100, insert_layers=opts.layers, depth=opts.depth,
43 | has_selayer=("se" in opts.arch), has_gtlayer=opts.add_cross_neuron, communication=opts.communication)
44 | else:
45 |                 raise ValueError("Unknown network architecture for cifar100")
46 | # if opts.add_cross_neuron:
47 | # add_cross_neuron(self.net, opts.data.input_size, opts.data.input_size, [32, 24, 24, 16], [32, 24, 24, 16])
48 | else:
49 |             raise ValueError("Unknown dataset; supported options are imagenet, cub2011, cifar10 and cifar100")
50 |
51 | print(self.net)
52 |
53 |     def get_optim_policies(self):
54 |         # cross-neuron (Conv1d/BatchNorm1d) parameters get a 10x lr multiplier;
55 |         # collect each parameter exactly once so no tensor lands in two groups
56 |         ccn_param = []
57 |         for m in self.modules():
58 |             if isinstance(m, (torch.nn.Conv1d, torch.nn.BatchNorm1d)):
59 |                 ccn_param += list(m.parameters())
60 |         ccn_ids = set(id(p) for p in ccn_param)
61 |         resnet_param = [p for p in self.parameters() if id(p) not in ccn_ids]
62 |         return [{"params": resnet_param, 'lr_mult': 1}, {"params": ccn_param, 'lr_mult': 10}]
63 |
64 | def forward(self, x):
65 | out = self.net(x)
66 | return out
67 |
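# A sketch of consuming the analysis outputs: for the cifar100 branch self.net is
# built from resnet_cifar_analysis.py, whose forward returns the logits together
# with per-stage activations and the mean channel correlations measured before
# and after each cross-neuron block (names below are illustrative):
def _sketch_inspect_correlations(model, images):
    logits, feats, corrs, corrs_nc = model(images)
    for i, (before, after) in enumerate(zip(corrs, corrs_nc), 1):
        print("stage %d: corr %.4f -> %.4f" % (i, before.item(), after.item()))
    return logits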
--------------------------------------------------------------------------------
/lib/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet_cifar import *
2 | from .resnet_cifar_analysis import *
3 | from .resnext_cifar import *
4 | from .wide_resnet_cifar import *
5 | from .senet_pytorch import *
6 | from .ncnet_pytorch import *
7 | from .invcnn_pytorch import *
8 | from .resnext import *
9 | from .mobilenet_v2 import *
10 |
11 | cifar_models = {"resnet20": resnet20a_cifar,
12 | "resnet56": resnet56a_cifar,
13 | "resnet110": resnet110a_cifar,
14 | "resnet164": resnet164_cifar,
15 |
16 | "resnet20a": resnet20a_cifar,
17 | "resnet56a": resnet56a_cifar,
18 | "resnet62a": resnet62a_cifar,
19 | "resnet68a": resnet68a_cifar,
20 | "resnet74a": resnet74a_cifar,
21 | "resnet80a": resnet80a_cifar,
22 | "resnet86a": resnet86a_cifar,
23 | "resnet92a": resnet92a_cifar,
24 | "resnet98a": resnet98a_cifar,
25 | "resnet104a": resnet104a_cifar,
26 | "resnet110a": resnet110a_cifar,
27 |
28 | "seresnet20": resnet20_cifar,
29 | "seresnet56": resnet56_cifar,
30 | "seresnet110": resnet110_cifar,
31 | "seresnet164": resnet164_cifar,
32 |
33 | "seresnet20a": resnet20a_cifar,
34 | "seresnet56a": resnet56a_cifar,
35 | "seresnet62a": resnet62a_cifar,
36 | "seresnet68a": resnet68a_cifar,
37 | "seresnet74a": resnet74a_cifar,
38 | "seresnet80a": resnet80a_cifar,
39 | "seresnet86a": resnet86a_cifar,
40 | "seresnet92a": resnet92a_cifar,
41 | "seresnet98a": resnet98a_cifar,
42 | "seresnet104a": resnet104a_cifar,
43 | "seresnet110a": resnet110a_cifar,
44 |
45 | "plainnet20": resnet20plain_cifar,
46 | "plainnet110": resnet110plain_cifar,
47 |
48 | "seplainnet20": resnet20plain_cifar,
49 | "seplainnet110": resnet110plain_cifar,
50 |
51 | "wresnet20": wresnet20_cifar,
52 | "sewresnet20": wresnet20_cifar,
53 |
54 | "resnext110": resneXt110_cifar,
55 | "seresnext110": resneXt110_cifar,
56 |
57 | "invcnn4": inv_cnn_4,
58 | }
59 |
60 | imagenet_models = {"seresnet18": se_resnet18,
61 | "seresnet50": se_resnet50,
62 | "seresnet101": se_resnet101,
63 | "ncresnet101": nc_resnet101,
64 | "ncvgg16": nc_vgg16,
65 | "resnext50": resnext50,
66 | "seresnext50": seresnext50,
67 | "ncresnext50": ncresnext50,
68 | "mobilenetv2": MobileNetV2,
69 | "semobilenetv2": SEMobileNetV2,
70 | "ncmobilenetv2": NCMobileNetV2,
71 | "sedensenet121": se_densenet121,
72 | }
73 |
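# These dicts act as a small model registry; lib/model.py resolves opts.arch
# against them. A minimal lookup sketch (the arch name is illustrative):
def _sketch_build_cifar_model(arch="seresnet56", num_classes=100):
    return cifar_models[arch](num_classes=num_classes, has_selayer=("se" in arch))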
--------------------------------------------------------------------------------
/lib/networks/densenet_cifar.py:
--------------------------------------------------------------------------------
1 | """
2 | DenseNet for cifar with pytorch
3 |
4 | Reference:
5 | [1] G. Huang, Z. Liu, L. van der Maaten and K. Q. Weinberger. Densely connected convolutional networks. In CVPR, 2017
6 | """
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from collections import OrderedDict
12 |
13 | import math
14 |
15 | class _DenseLayer(nn.Sequential):
16 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
17 | super(_DenseLayer, self).__init__()
18 |         self.add_module('norm1', nn.BatchNorm2d(num_input_features))
19 |         self.add_module('relu1', nn.ReLU(inplace=True))
20 |         self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
21 |                         growth_rate, kernel_size=1, stride=1, bias=False))
22 |         self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
23 |         self.add_module('relu2', nn.ReLU(inplace=True))
24 |         self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
25 |                         kernel_size=3, stride=1, padding=1, bias=False))
26 | self.drop_rate = drop_rate
27 |
28 | def forward(self, x):
29 | new_features = super(_DenseLayer, self).forward(x)
30 | if self.drop_rate > 0:
31 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
32 | return torch.cat([x, new_features], 1)
33 |
34 |
35 | class _DenseBlock(nn.Sequential):
36 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
37 | super(_DenseBlock, self).__init__()
38 | for i in range(num_layers):
39 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
40 | self.add_module('denselayer%d' % (i + 1), layer)
41 |
42 |
43 | class _Transition(nn.Sequential):
44 | def __init__(self, num_input_features, num_output_features):
45 | super(_Transition, self).__init__()
46 | self.add_module('norm', nn.BatchNorm2d(num_input_features))
47 | self.add_module('relu', nn.ReLU(inplace=True))
48 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
49 | kernel_size=1, stride=1, bias=False))
50 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
51 |
52 |
53 | class DenseNet_Cifar(nn.Module):
54 | r"""Densenet-BC model class, based on
55 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/abs/1608.06993>`_
56 |
57 | Args:
58 | growth_rate (int) - how many filters to add each layer (`k` in paper)
59 | block_config (list of 4 ints) - how many layers in each pooling block
60 | num_init_features (int) - the number of filters to learn in the first convolution layer
61 | bn_size (int) - multiplicative factor for number of bottle neck layers
62 | (i.e. bn_size * k features in the bottleneck layer)
63 | drop_rate (float) - dropout rate after each dense layer
64 | num_classes (int) - number of classification classes
65 | """
66 | def __init__(self, growth_rate=12, block_config=(16, 16, 16),
67 | num_init_features=24, bn_size=4, drop_rate=0, num_classes=10):
68 |
69 | super(DenseNet_Cifar, self).__init__()
70 |
71 | # First convolution
72 | self.features = nn.Sequential(OrderedDict([
73 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=3, stride=1, padding=1, bias=False)),
74 | ]))
75 |
76 | # Each denseblock
77 | num_features = num_init_features
78 | for i, num_layers in enumerate(block_config):
79 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
80 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
81 | self.features.add_module('denseblock%d' % (i + 1), block)
82 | num_features = num_features + num_layers * growth_rate
83 | if i != len(block_config) - 1:
84 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
85 | self.features.add_module('transition%d' % (i + 1), trans)
86 | num_features = num_features // 2
87 |
88 | # Final batch norm
89 | self.features.add_module('norm5', nn.BatchNorm2d(num_features))
90 |
91 | # Linear layer
92 | self.classifier = nn.Linear(num_features, num_classes)
93 |
94 | # initialize conv and bn parameters
95 | for m in self.modules():
96 | if isinstance(m, nn.Conv2d):
97 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
98 | m.weight.data.normal_(0, math.sqrt(2. / n))
99 | elif isinstance(m, nn.BatchNorm2d):
100 | m.weight.data.fill_(1)
101 | m.bias.data.zero_()
102 |
103 | def forward(self, x):
104 | features = self.features(x)
105 | out = F.relu(features, inplace=True)
106 | out = F.avg_pool2d(out, kernel_size=8, stride=1).view(features.size(0), -1)
107 | out = self.classifier(out)
108 | return out
109 |
110 |
111 | def densenet_BC_cifar(depth, k, **kwargs):
112 | N = (depth - 4) // 6
113 | model = DenseNet_Cifar(growth_rate=k, block_config=[N, N, N], num_init_features=2*k, **kwargs)
114 | return model
115 |
116 |
117 | if __name__ == '__main__':
118 | net = densenet_BC_cifar(190, 40, num_classes=100)
119 | input = torch.randn(1, 3, 32, 32)
120 | y = net(input)
121 | print(net)
122 | print(y.size())
123 |
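# Depth accounting for densenet_BC_cifar: each of the three dense blocks holds
# N bottleneck layers with two convolutions apiece (6N weighted layers), and the
# stem conv, the two transition convs and the classifier add four more, so
# depth = 6N + 4. For the example above, depth=190 with k=40 gives
# N = (190 - 4) // 6 = 31 layers per dense block.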
--------------------------------------------------------------------------------
/lib/networks/invcnn_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from .invnet import *
2 |
--------------------------------------------------------------------------------
/lib/networks/invcnn_pytorch/invnet/__init__.py:
--------------------------------------------------------------------------------
1 | from .inv_cnn import *
2 |
--------------------------------------------------------------------------------
/lib/networks/invcnn_pytorch/invnet/inv_cnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 |
5 | class BasicBlock(nn.Module):
6 | expansion=1
7 |
8 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=1, downsample=None):
9 | super(BasicBlock, self).__init__()
10 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
11 | self.bn1 = nn.BatchNorm2d(planes)
12 | self.relu = nn.ReLU(inplace=True)
13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
14 | self.bn2 = nn.BatchNorm2d(planes)
15 | self.downsample = downsample
16 | self.stride = stride
17 |
18 | def forward(self, x):
19 | residual = x
20 |
21 | out = self.conv1(x)
22 | out = self.bn1(out)
23 | out = self.relu(out)
24 |
25 | out = self.conv2(out)
26 | out = self.bn2(out)
27 |
28 | if self.downsample is not None:
29 | residual = self.downsample(x)
30 |
31 | out += residual
32 | out = self.relu(out)
33 | return out
34 |
35 | class attention(nn.Module):
36 | def __init__(self, dim):
37 | super(attention, self).__init__()
38 | self.net1 = nn.Conv1d(dim, dim, 1, 1, 0)
39 | self.net2 = nn.Conv1d(dim, dim, 1, 1, 0)
40 |
41 | def forward(self, x):
42 | x1 = x[0].view(x[0].shape[0], x[0].shape[1], -1)
43 | x2 = x[1].view(x[1].shape[0], x[1].shape[1], -1)
44 | x1 = F.relu(self.net1(x1)) # B x C x M
45 | x2 = F.relu(self.net2(x2)) # B x C x 1
46 |
47 | attn = torch.softmax((x1 * x2).sum(1), dim=1)
48 | attn = attn.unsqueeze(2)
49 |
50 | out = (x2 + torch.bmm(x1, attn)).unsqueeze(3)
51 | return out
52 |
53 | class INVCNN(nn.Module):
54 | def __init__(self, nlayers, num_classes=100, has_gtlayer=False, has_selayer=False):
55 | super(INVCNN, self).__init__()
56 | self.name = "invcnn"
57 | self.planes = 3
58 |
59 | self.layer1 = nn.Sequential(
60 | nn.Conv2d(3, 64, kernel_size=4, stride=4, padding=0, bias=False),
61 | nn.BatchNorm2d(64),
62 | nn.ReLU(True),
63 | )
64 | self.fc1 = nn.Linear(64, num_classes)
65 |
66 | self.layer2 = self._make_layer(BasicBlock, 64, 1, kernel_size=4)
67 | self.conv1x1_2 = nn.Sequential(
68 | nn.Conv2d(64, 64, 1, 1, 0),
69 | # nn.ReLU(True),
70 | )
71 | self.conv3x3_2 = nn.Sequential(
72 | nn.Conv2d(64, 64, 3, 1, 1),
73 | nn.ReLU(True),
74 | nn.AdaptiveAvgPool2d(1)
75 | )
76 | self.attn2 = attention(64)
77 | self.fc2 = nn.Linear(64, num_classes)
78 |
79 | self.layer3 = self._make_layer(BasicBlock, 64, 2, kernel_size=4)
80 | self.conv1x1_3 = nn.Sequential(
81 | nn.Conv2d(64, 64, 1, 1, 0),
82 | # nn.ReLU(True),
83 | )
84 | self.conv3x3_3 = nn.Sequential(
85 | nn.Conv2d(64, 64, 3, 1, 1),
86 | nn.ReLU(True),
87 | nn.AdaptiveAvgPool2d(1)
88 | )
89 | self.attn3 = attention(64)
90 | self.fc3 = nn.Linear(64, num_classes)
91 |
92 | self.layer4 = self._make_layer(BasicBlock, 64, 3, kernel_size=4)
93 | self.conv1x1_4 = nn.Sequential(
94 | nn.Conv2d(64, 64, 1, 1, 0),
95 | # nn.ReLU(True),
96 | )
97 | self.conv3x3_4 = nn.Sequential(
98 | nn.Conv2d(64, 64, 3, 1, 1),
99 | nn.ReLU(True),
100 | nn.AdaptiveAvgPool2d(1)
101 | )
102 | self.attn4 = attention(64)
103 | self.fc4 = nn.Linear(64, num_classes)
104 |
105 | for m in self.modules():
106 | if isinstance(m, nn.Conv2d):
107 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
108 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
109 | nn.init.constant_(m.weight, 1)
110 | nn.init.constant_(m.bias, 0)
111 |
112 | def _make_layer(self, block, planes, blocks, kernel_size=3):
113 | layers = []
114 | inplanes = 3
115 | for i in range(0, blocks):
116 | if i == 0:
117 | downsample = nn.Sequential(
118 | nn.Conv2d(inplanes, planes, kernel_size=kernel_size, stride=kernel_size, padding=kernel_size // 2, bias=False),
119 | nn.BatchNorm2d(planes)
120 | )
121 | layers.append(block(inplanes, planes, kernel_size=kernel_size, stride=kernel_size, padding=kernel_size // 2, downsample=downsample))
122 | else:
123 | layers.append(block(inplanes, planes))
124 | inplanes = planes
125 | # layers.append(nn.AdaptiveAvgPool2d(1))
126 | return nn.Sequential(*layers)
127 |
128 | def forward(self, x):
129 |         # an earlier variant fused the scales with the conv1x1_*/conv3x3_* branches
130 |         # and bilinear interpolation; the attention modules below replaced it
146 |
147 | out1 = self.layer1(F.interpolate(x, (4, 4)))
148 | score1 = self.fc1(out1.view(out1.size(0), -1))
149 | out2 = self.layer2(F.interpolate(x, (8, 8))) # BxCx2x2
150 | out2 = self.attn2((out2, out1))
151 | score2 = self.fc2(out2.view(out2.size(0), -1))
152 | out3 = self.layer3(F.interpolate(x, (16, 16)))
153 | out3 = self.attn3((out3, out2))
154 | score3 = self.fc3(out3.view(out3.size(0), -1))
155 | out4 = self.layer4(x)
156 | out4 = self.attn4((out4, out3))
157 | score4 = self.fc4(out4.view(out4.size(0), -1))
158 | # scores = (score1, score2, score3, score4)
159 | return [score4]
160 |
161 | def inv_cnn_4(**kwargs):
162 | return INVCNN(4, **kwargs)
163 |
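# A shape walkthrough of the attention fusion above, with the two 1x1 convs
# omitted; the sizes mirror out2/out1 in INVCNN.forward and are illustrative:
def _sketch_attention_shapes():
    B, C, M = 2, 64, 4                 # fine map with M spatial positions
    x1 = torch.randn(B, C, M)          # flattened fine-scale features
    x2 = torch.randn(B, C, 1)          # coarse query from the previous scale
    attn = torch.softmax((x1 * x2).sum(1), dim=1).unsqueeze(2)  # B x M x 1
    out = (x2 + torch.bmm(x1, attn)).unsqueeze(3)               # B x C x 1 x 1
    return out.shape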
--------------------------------------------------------------------------------
/lib/networks/mobilenet_v2.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 |
4 |
5 | def conv_bn(inp, oup, stride):
6 | return nn.Sequential(
7 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
8 | nn.BatchNorm2d(oup),
9 | nn.ReLU6(inplace=True)
10 | )
11 |
12 |
13 | def conv_1x1_bn(inp, oup):
14 | return nn.Sequential(
15 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
16 | nn.BatchNorm2d(oup),
17 | nn.ReLU6(inplace=True)
18 | )
19 |
20 |
21 | class InvertedResidual(nn.Module):
22 | def __init__(self, inp, oup, stride, expand_ratio):
23 | super(InvertedResidual, self).__init__()
24 | self.stride = stride
25 | assert stride in [1, 2]
26 |
27 | hidden_dim = round(inp * expand_ratio)
28 | self.use_res_connect = self.stride == 1 and inp == oup
29 |
30 | if expand_ratio == 1:
31 | self.conv = nn.Sequential(
32 | # dw
33 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
34 | nn.BatchNorm2d(hidden_dim),
35 | nn.ReLU6(inplace=True),
36 | # pw-linear
37 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
38 | nn.BatchNorm2d(oup),
39 | )
40 | else:
41 | self.conv = nn.Sequential(
42 | # pw
43 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
44 | nn.BatchNorm2d(hidden_dim),
45 | nn.ReLU6(inplace=True),
46 | # dw
47 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
48 | nn.BatchNorm2d(hidden_dim),
49 | nn.ReLU6(inplace=True),
50 | # pw-linear
51 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
52 | nn.BatchNorm2d(oup),
53 | )
54 |
55 | def forward(self, x):
56 | if self.use_res_connect:
57 | return x + self.conv(x)
58 | else:
59 | return self.conv(x)
60 |
61 |
62 | class MobileNetV2(nn.Module):
63 | def __init__(self, num_classes=1000, input_size=224, width_mult=1.):
64 | super(MobileNetV2, self).__init__()
65 | block = InvertedResidual
66 | input_channel = 32
67 | last_channel = 1280
68 |         inverted_residual_setting = [
69 | # t, c, n, s
70 | [1, 16, 1, 1],
71 | [6, 24, 2, 2],
72 | [6, 32, 3, 2],
73 | [6, 64, 4, 2],
74 | [6, 96, 3, 1],
75 | [6, 160, 3, 2],
76 | [6, 320, 1, 1],
77 | ]
78 |
79 | # building first layer
80 | assert input_size % 32 == 0
81 | input_channel = int(input_channel * width_mult)
82 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
83 | self.features = [conv_bn(3, input_channel, 2)]
84 | # building inverted residual blocks
85 |         for t, c, n, s in inverted_residual_setting:
86 | output_channel = int(c * width_mult)
87 | for i in range(n):
88 | if i == 0:
89 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
90 | else:
91 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
92 | input_channel = output_channel
93 | # building last several layers
94 | self.features.append(conv_1x1_bn(input_channel, self.last_channel))
95 | # make it nn.Sequential
96 | self.features = nn.Sequential(*self.features)
97 |
98 | # building classifier
99 | self.classifier = nn.Sequential(
100 | nn.Dropout(0.2),
101 | nn.Linear(self.last_channel, num_classes),
102 | )
103 |
104 | self._initialize_weights()
105 |
106 | def forward(self, x):
107 | x = self.features(x)
108 | x = x.mean(3).mean(2)
109 | x = self.classifier(x)
110 | return x
111 |
112 | def _initialize_weights(self):
113 | for m in self.modules():
114 | if isinstance(m, nn.Conv2d):
115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
116 | m.weight.data.normal_(0, math.sqrt(2. / n))
117 | if m.bias is not None:
118 | m.bias.data.zero_()
119 | elif isinstance(m, nn.BatchNorm2d):
120 | m.weight.data.fill_(1)
121 | m.bias.data.zero_()
122 | elif isinstance(m, nn.Linear):
123 | n = m.weight.size(1)
124 | m.weight.data.normal_(0, 0.01)
125 | m.bias.data.zero_()
126 |
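# A quick sanity check; width_mult scales every channel count (e.g. 0.5 halves
# the 32-channel stem) and input_size must stay a multiple of 32:
def _sketch_check_output():
    import torch
    net = MobileNetV2(num_classes=1000, input_size=224, width_mult=1.)
    with torch.no_grad():
        y = net(torch.randn(1, 3, 224, 224))
    return y.shape  # torch.Size([1, 1000])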
--------------------------------------------------------------------------------
/lib/networks/resnet_cifar.py:
--------------------------------------------------------------------------------
1 | '''
2 | resnet for cifar in pytorch
3 |
4 | Reference:
5 | [1] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016.
6 | [2] K. He, X. Zhang, S. Ren, and J. Sun. Identity mappings in deep residual networks. In ECCV, 2016.
7 | '''
8 |
9 | import torch
10 | import torch.nn as nn
11 | import math
12 |
13 | from lib.layers import *
14 |
15 | def conv3x3(in_planes, out_planes, stride=1):
16 | " 3x3 convolution with padding "
17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
18 |
19 | class BasicBlockPlain(nn.Module):
20 | expansion=1
21 |
22 | def __init__(self, insize, inplanes, planes, stride=1, downsample=None, use_se=False, reduction=8):
23 | super(BasicBlockPlain, self).__init__()
24 | self.use_se = use_se
25 | self.conv1 = conv3x3(inplanes, planes, stride)
26 | self.bn1 = nn.BatchNorm2d(planes)
27 | self.relu = nn.ReLU(inplace=True)
28 | self.conv2 = conv3x3(planes, planes)
29 | self.bn2 = nn.BatchNorm2d(planes)
30 | if self.use_se:
31 | self.se = SELayer(planes, reduction=reduction)
32 | self.downsample = downsample
33 | self.stride = stride
34 |
35 | def forward(self, x):
36 | # out = self.gt_spatial(out)
37 | residual = x
38 |
39 | out = self.conv1(x)
40 | out = self.bn1(out)
41 | out = self.relu(out)
42 |
43 | out = self.conv2(out)
44 | out = self.bn2(out)
45 | if self.use_se:
46 | out = self.se(out)
47 |
48 | if self.downsample is not None:
49 | residual = self.downsample(x)
50 |
51 |         # plain (non-residual) block: the shortcut is not added back to the output
52 | out = self.relu(out)
53 | return out
54 |
55 | class BasicBlock(nn.Module):
56 | expansion=1
57 |
58 | def __init__(self, insize, inplanes, planes, stride=1, downsample=None, use_se=False, use_gt=False, reduction=8):
59 | super(BasicBlock, self).__init__()
60 | self.use_se = use_se
61 | self.use_gt = use_gt
62 | self.conv1 = conv3x3(inplanes, planes, stride)
63 | self.bn1 = nn.BatchNorm2d(planes)
64 | self.relu = nn.ReLU(inplace=True)
65 | self.conv2 = conv3x3(planes, planes)
66 | self.bn2 = nn.BatchNorm2d(planes)
67 | if self.use_se:
68 | self.se = SELayer(planes, reduction=reduction)
69 | if self.use_gt:
70 | self.gt = CrossNeuronlBlock2D(planes, insize, insize, insize, insize, reduction=8)
71 |
72 | self.downsample = downsample
73 | self.stride = stride
74 |
75 | def forward(self, x):
76 | # out = self.gt_spatial(out)
77 | residual = x
78 |
79 | out = self.conv1(x)
80 | out = self.bn1(out)
81 | out = self.relu(out)
82 |
83 | out = self.conv2(out)
84 | out = self.bn2(out)
85 | if self.use_se:
86 | out = self.se(out)
87 | if self.use_gt:
88 | out = self.gt(out)
89 |
90 | if self.downsample is not None:
91 | residual = self.downsample(x)
92 |
93 | out += residual
94 | out = self.relu(out)
95 | return out
96 |
97 |
98 | class Bottleneck(nn.Module):
99 | expansion=4
100 |
101 | def __init__(self, insize, inplanes, planes, stride=1, downsample=None, use_se=False, use_gt=False, reduction=8):
102 | super(Bottleneck, self).__init__()
103 | self.use_se = use_se
104 | self.use_gt = use_gt
105 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
106 | self.bn1 = nn.BatchNorm2d(planes)
107 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
108 | self.bn2 = nn.BatchNorm2d(planes)
109 | self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False)
110 | self.bn3 = nn.BatchNorm2d(planes*4)
111 | self.relu = nn.ReLU(inplace=True)
112 | if self.use_se:
113 |             self.se = SELayer(planes * 4, reduction=reduction)  # the block's output has planes*4 channels
114 | if self.use_gt:
115 | self.gt = CrossNeuronlBlock2D(planes, insize, insize, insize, insize, reduction=8)
116 | self.downsample = downsample
117 | self.stride = stride
118 |
119 | def forward(self, x):
120 | residual = x
121 |
122 | out = self.conv1(x)
123 | out = self.bn1(out)
124 | out = self.relu(out)
125 |
126 | out = self.conv2(out)
127 | out = self.bn2(out)
128 | out = self.relu(out)
129 |
130 | out = self.conv3(out)
131 | out = self.bn3(out)
132 |
133 | if self.use_se:
134 | out = self.se(out)
135 | if self.use_gt:
136 | out = self.gt(out)
137 |
138 | if self.downsample is not None:
139 | residual = self.downsample(x)
140 |
141 | out += residual
142 | out = self.relu(out)
143 |
144 | return out
145 |
146 | # class CommunicationBlock(nn.Module):
147 |
148 |
149 | class PreActBasicBlock(nn.Module):
150 | expansion = 1
151 |
152 | def __init__(self, inplanes, planes, stride=1, downsample=None):
153 | super(PreActBasicBlock, self).__init__()
154 | self.bn1 = nn.BatchNorm2d(inplanes)
155 | self.relu = nn.ReLU(inplace=True)
156 | self.conv1 = conv3x3(inplanes, planes, stride)
157 | self.bn2 = nn.BatchNorm2d(planes)
158 | self.conv2 = conv3x3(planes, planes)
159 | self.downsample = downsample
160 | self.stride = stride
161 |
162 | def forward(self, x):
163 | residual = x
164 |
165 | out = self.bn1(x)
166 | out = self.relu(out)
167 |
168 | if self.downsample is not None:
169 | residual = self.downsample(out)
170 |
171 | out = self.conv1(out)
172 |
173 | out = self.bn2(out)
174 | out = self.relu(out)
175 | out = self.conv2(out)
176 |
177 | out += residual
178 |
179 | return out
180 |
181 |
182 | class PreActBottleneck(nn.Module):
183 | expansion = 4
184 |
185 | def __init__(self, inplanes, planes, stride=1, downsample=None):
186 | super(PreActBottleneck, self).__init__()
187 | self.bn1 = nn.BatchNorm2d(inplanes)
188 | self.relu = nn.ReLU(inplace=True)
189 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
190 | self.bn2 = nn.BatchNorm2d(planes)
191 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
192 | self.bn3 = nn.BatchNorm2d(planes)
193 | self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False)
194 | self.downsample = downsample
195 | self.stride = stride
196 |
197 | def forward(self, x):
198 | residual = x
199 |
200 | out = self.bn1(x)
201 | out = self.relu(out)
202 |
203 | if self.downsample is not None:
204 | residual = self.downsample(out)
205 |
206 | out = self.conv1(out)
207 |
208 | out = self.bn2(out)
209 | out = self.relu(out)
210 | out = self.conv2(out)
211 |
212 | out = self.bn3(out)
213 | out = self.relu(out)
214 | out = self.conv3(out)
215 |
216 | out += residual
217 |
218 | return out
219 |
220 |
221 | class ResNet_Cifar(nn.Module):
222 | def __init__(self, block, layers, num_classes=10, has_gtlayer=False, has_selayer=False):
223 | super(ResNet_Cifar, self).__init__()
224 | self.inplanes = 16
225 | self.insize = 32
226 | self.has_gtlayer = has_gtlayer
227 | self.has_selayer = has_selayer
228 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
229 | self.bn1 = nn.BatchNorm2d(16)
230 | self.relu = nn.ReLU(inplace=True)
231 | self.layer1 = self._make_layer(block, 16, layers[0])
232 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
233 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
234 |
235 | self.avgpool = nn.AvgPool2d(8, stride=1)
236 | self.fc = nn.Linear(64 * block.expansion, num_classes)
237 |
238 |     def _make_layer(self, block, planes, blocks, stride=1):
239 | downsample = None
240 | if stride != 1 or self.inplanes != planes * block.expansion:
241 | downsample = nn.Sequential(
242 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
243 | nn.BatchNorm2d(planes * block.expansion)
244 | )
245 |
246 | self.insize = int(self.insize / stride)
247 |         layers = []
248 |         # only the first block of each stage carries the cross-neuron (gt) layer
249 |         layers.append(block(self.insize, self.inplanes, planes, stride, downsample, use_se=self.has_selayer, use_gt=self.has_gtlayer))
250 |         self.inplanes = planes * block.expansion
251 |         for _ in range(1, blocks):
252 |             layers.append(block(self.insize, self.inplanes, planes, use_se=self.has_selayer))
253 | return nn.Sequential(*layers)
254 |
255 | def forward(self, x):
256 | x = self.conv1(x)
257 | x = self.bn1(x)
258 | x = self.relu(x)
259 |         # per-stage activation taps live in the analysis variant
260 |         # (lib/networks/resnet_cifar_analysis.py)
261 |         x = self.layer1(x)
262 |         x = self.layer2(x)
263 |         x = self.layer3(x)
264 |
265 |         x = self.avgpool(x)
266 |         x = x.view(x.size(0), -1)
267 |         x = self.fc(x)
268 |
269 |         return x
272 |
273 |
274 | class PreAct_ResNet_Cifar(nn.Module):
275 |
276 | def __init__(self, block, layers, num_classes=10, has_gtlayer=False, has_selayer=False):
277 | super(PreAct_ResNet_Cifar, self).__init__()
278 | self.inplanes = 16
279 | self.insize = 32
280 | self.has_gtlayer = has_gtlayer
281 | self.has_selayer = has_selayer
282 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
283 | self.layer1 = self._make_layer(block, 16, layers[0])
284 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
285 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
286 |
287 | self.bn = nn.BatchNorm2d(64*block.expansion)
288 | self.relu = nn.ReLU(inplace=True)
289 | self.avgpool = nn.AvgPool2d(8, stride=1)
290 | self.fc = nn.Linear(64*block.expansion, num_classes)
291 |
292 | for m in self.modules():
293 | if isinstance(m, nn.Conv2d):
294 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
295 | m.weight.data.normal_(0, math.sqrt(2. / n))
296 | elif isinstance(m, nn.BatchNorm2d):
297 | m.weight.data.fill_(1)
298 | m.bias.data.zero_()
299 |
300 | def _make_layer(self, block, planes, blocks, stride=1):
301 | downsample = None
302 | if stride != 1 or self.inplanes != planes*block.expansion:
303 | downsample = nn.Sequential(
304 | nn.Conv2d(self.inplanes, planes*block.expansion, kernel_size=1, stride=stride, bias=False)
305 | )
306 |
307 | layers = []
308 | layers.append(block(self.inplanes, planes, stride, downsample))
309 | self.inplanes = planes*block.expansion
310 | for _ in range(1, blocks):
311 | layers.append(block(self.inplanes, planes))
312 | return nn.Sequential(*layers)
313 |
314 | def forward(self, x):
315 | x = self.conv1(x)
316 |
317 | x = self.layer1(x)
318 | x = self.layer2(x)
319 | x = self.layer3(x)
320 |
321 | x = self.bn(x)
322 | x = self.relu(x)
323 | x = self.avgpool(x)
324 | x = x.view(x.size(0), -1)
325 | x = self.fc(x)
326 |
327 | return x
328 |
329 |
330 |
331 | def resnet20_cifar(**kwargs):
332 | model = ResNet_Cifar(BasicBlock, [3, 3, 3], **kwargs)
333 | return model
334 |
335 | def resnet20plain_cifar(**kwargs):
336 | model = ResNet_Cifar(BasicBlockPlain, [3, 3, 3], **kwargs)
337 | return model
338 |
339 | def resnet32_cifar(**kwargs):
340 | model = ResNet_Cifar(BasicBlock, [5, 5, 5], **kwargs)
341 | return model
342 |
343 |
344 | def resnet44_cifar(**kwargs):
345 | model = ResNet_Cifar(BasicBlock, [7, 7, 7], **kwargs)
346 | return model
347 |
348 |
349 | def resnet56_cifar(**kwargs):
350 | model = ResNet_Cifar(BasicBlock, [9, 9, 9], **kwargs)
351 | return model
352 |
353 |
354 | def resnet110_cifar(**kwargs):
355 | model = ResNet_Cifar(BasicBlock, [18, 18, 18], **kwargs)
356 | return model
357 |
358 | def resnet110plain_cifar(**kwargs):
359 | model = ResNet_Cifar(BasicBlockPlain, [18, 18, 18], **kwargs)
360 | return model
361 |
362 | def resnet1202_cifar(**kwargs):
363 | model = ResNet_Cifar(BasicBlock, [200, 200, 200], **kwargs)
364 | return model
365 |
366 |
367 | def resnet164_cifar(**kwargs):
368 | model = ResNet_Cifar(Bottleneck, [18, 18, 18], **kwargs)
369 | return model
370 |
371 |
372 | def resnet1001_cifar(**kwargs):
373 | model = ResNet_Cifar(Bottleneck, [111, 111, 111], **kwargs)
374 | return model
375 |
376 |
377 | def preact_resnet110_cifar(**kwargs):
378 | model = PreAct_ResNet_Cifar(PreActBasicBlock, [18, 18, 18], **kwargs)
379 | return model
380 |
381 |
382 | def preact_resnet164_cifar(**kwargs):
383 | model = PreAct_ResNet_Cifar(PreActBottleneck, [18, 18, 18], **kwargs)
384 | return model
385 |
386 |
387 | def preact_resnet1001_cifar(**kwargs):
388 | model = PreAct_ResNet_Cifar(PreActBottleneck, [111, 111, 111], **kwargs)
389 | return model
390 |
391 |
392 | if __name__ == '__main__':
393 | net = resnet20_cifar()
394 |     y = net(torch.randn(1, 3, 32, 32))  # cifar-sized input: avgpool(8) expects an 8x8 map
395 | print(net)
396 | print(y.size())
397 |
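# Depth accounting for the constructors above: with BasicBlock (two 3x3 convs per
# block) depth = 6n + 2 for stage sizes [n, n, n] (plus the stem conv and the
# classifier), so [3, 3, 3] -> ResNet-20 and [18, 18, 18] -> ResNet-110; with
# Bottleneck (three convs per block) depth = 9n + 2, so [18, 18, 18] -> ResNet-164.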
--------------------------------------------------------------------------------
/lib/networks/resnet_cifar_analysis.py:
--------------------------------------------------------------------------------
1 | '''
2 | resnet for cifar in pytorch
3 |
4 | Reference:
5 | [1] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016.
6 | [2] K. He, X. Zhang, S. Ren, and J. Sun. Identity mappings in deep residual networks. In ECCV, 2016.
7 | '''
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import math
13 |
14 | from lib.layers import *
14 |
15 | def conv3x3(in_planes, out_planes, stride=1):
16 | " 3x3 convolution with padding "
17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
18 |
19 | def weights_init(m):
20 | if isinstance(m, nn.Conv1d):
21 |         nn.init.xavier_uniform_(m.weight.data)
22 |         if m.bias is not None:
23 |             nn.init.zeros_(m.bias.data)  # xavier init is undefined for 1-d bias tensors
23 |
24 | class _CrossNeuronBlockInternal(nn.Module):
25 | def __init__(self, in_channels):
26 | # nblock_channel: number of block along channel axis
27 | # spatial_size: spatial_size
28 | super(_CrossNeuronBlockInternal, self).__init__()
29 |
30 | self.conv_in = nn.Sequential(
31 | nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1, groups=in_channels, bias=True),
32 | # nn.BatchNorm2d(in_channels),
33 | nn.ReLU(True),
34 | nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1, groups=in_channels, bias=True),
35 | # nn.BatchNorm2d(in_channels),
36 | # nn.ReLU(True),
37 | )
38 |
39 | self.conv_out = nn.Sequential(
40 | nn.ConvTranspose2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1, groups=in_channels, bias=True),
41 | # nn.BatchNorm2d(in_channels),
42 | nn.ReLU(True),
43 | nn.ConvTranspose2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1, groups=in_channels, bias=True),
44 | # nn.BatchNorm2d(in_channels),
45 | )
46 |
47 | # self.bn = nn.BatchNorm1d(self.spatial_area)
48 |
49 | self.initialize()
50 |
51 | def initialize(self):
52 | for m in self.modules():
53 | if isinstance(m, nn.Conv2d):
54 | nn.init.kaiming_normal_(m.weight)
55 | elif isinstance(m, nn.BatchNorm2d):
56 | nn.init.constant_(m.weight, 1)
57 | nn.init.constant_(m.bias, 0)
58 |
59 |     def forward(self, x):
60 |         '''
61 |         :param x: (bt, c, h, w)
62 |         :return: (bt, c, h, w), same size as the input
63 |         '''
64 |         bt, c, h, w = x.shape
65 |         residual = x
66 |         x_v = self.conv_in(x) # bt x c x h/4 x w/4, depthwise downsampling
67 |         x_m = x_v.mean(3).mean(2).unsqueeze(2) # bt x c x 1, one scalar per neuron
68 |
69 |         # neurons with similar global mean responses communicate more strongly
70 |         score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # bt x c x c
71 |         attn = F.softmax(score, dim=2) # bt x c x c
72 |         out = self.conv_out(torch.bmm(attn, x_v.view(bt, c, -1)).view(bt, c, x_v.shape[2], x_v.shape[3]))
73 |         return F.relu(residual + out)
76 |
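# _CrossNeuronBlockInternal realizes the same neuron-to-neuron message passing
# as _CrossNeuronBlock below, but swaps the spatial fc encoder/decoder for
# strided depthwise convolutions (conv_in) and their transposed counterparts
# (conv_out), so the affinities are computed on a 4x-downsampled copy of the map.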
77 | class _CrossNeuronBlock(nn.Module):
78 | def __init__(self, in_channels, in_height, in_width,
79 | spatial_height=32, spatial_width=32,
80 | reduction=16,
81 | size_is_consistant=True,
82 | communication=True,
83 | enc_dec=True):
84 | # nblock_channel: number of block along channel axis
85 | # spatial_size: spatial_size
86 | super(_CrossNeuronBlock, self).__init__()
87 |
88 | self.communication = communication
89 | self.enc_dec = enc_dec
90 |
91 | # set channel splits
92 | if in_channels <= 512:
93 | self.nblocks_channel = 1
94 | else:
95 | self.nblocks_channel = in_channels // 512
96 | block_size = in_channels // self.nblocks_channel
97 | block = torch.Tensor(block_size, block_size).fill_(1)
98 | self.mask = torch.Tensor(in_channels, in_channels).fill_(0)
99 | for i in range(self.nblocks_channel):
100 | self.mask[i * block_size:(i + 1) * block_size, i * block_size:(i + 1) * block_size].copy_(block)
101 |
102 | # set spatial splits
103 | if in_height * in_width < 16 * 16 and size_is_consistant:
104 | self.spatial_area = in_height * in_width
105 | self.spatial_height = in_height
106 | self.spatial_width = in_width
107 | else:
108 | self.spatial_area = spatial_height * spatial_width
109 | self.spatial_height = spatial_height
110 | self.spatial_width = spatial_width
111 |
114 | self.fc_in = nn.Sequential(
115 | nn.Conv1d(self.spatial_area, self.spatial_area // reduction, 1, 1, 0, bias=True),
116 | # nn.BatchNorm1d(self.spatial_area // reduction),
117 | nn.ReLU(True),
118 | nn.Conv1d(self.spatial_area // reduction, self.spatial_area, 1, 1, 0, bias=True),
119 | )
120 |
121 | self.global_context = nn.Sequential(
122 | nn.Conv1d(self.spatial_area, 1, 1, 1, 0, bias=True),
123 | nn.ReLU(True),
124 | )
125 |
126 | self.fc_out = nn.Sequential(
127 | nn.Conv1d(self.spatial_area, self.spatial_area // reduction, 1, 1, 0, bias=True),
128 | # nn.BatchNorm1d(self.spatial_area // reduction),
129 | nn.ReLU(True),
130 | nn.Conv1d(self.spatial_area // reduction, self.spatial_area, 1, 1, 0, bias=True),
131 | )
132 |
133 | self.bn = nn.BatchNorm1d(self.spatial_area)
134 | # self.ln = nn.LayerNorm(in_channels)
135 |
136 | self.initialize()
137 |
138 | def initialize(self):
139 | for m in self.modules():
140 | if isinstance(m, nn.Conv1d):
141 | nn.init.kaiming_normal_(m.weight)
142 | if m.bias is not None:
143 | torch.nn.init.zeros_(m.bias)
144 | elif isinstance(m, nn.BatchNorm1d):
145 | nn.init.constant_(m.weight, 1)
146 | nn.init.constant_(m.bias, 0)
147 |
148 |     def forward(self, x):
149 |         '''
150 |         :param x: (bt, c, h, w)
151 |         :return: (bt, c, h, w), same size as the input
152 |         '''
153 |         bt, c, h, w = x.shape
154 |         residual = x
155 |         x_stretch = x.view(bt, c, h * w)
156 |
157 |         if self.spatial_height == h and self.spatial_width == w:
158 |             x_stacked = x_stretch.view(bt * self.nblocks_channel, c // self.nblocks_channel, -1)
159 |             x_v = x_stacked.permute(0, 2, 1).contiguous() # (b) x (h * w) x c
160 |
161 |             if self.enc_dec:
162 |                 x_v = self.fc_in(x_v) # (b) x (h * w) x c
163 |
164 |             # pairwise neuron affinities: the closer two neurons' global mean
165 |             # responses, the more strongly they communicate
166 |             x_m = x_v.mean(1).unsqueeze(1) # (b) x 1 x c
167 |             score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # (b) x c x c
168 |             attn = F.softmax(score, dim=1) # (b) x c x c, normalized over source neurons
169 |
170 |             if self.communication:
171 |                 if self.enc_dec:
172 |                     out = self.bn(self.fc_out(torch.bmm(x_v, attn))) # (b) x (h * w) x c
173 |                 else:
174 |                     out = self.bn(torch.bmm(x_v, attn)) # (b) x (h * w) x c
175 |             else:
176 |                 out = self.bn(self.fc_out(x_v)) # (b) x (h * w) x c
177 |
178 |             out = out.permute(0, 2, 1).contiguous().view(bt, c, h, w)
179 |             return F.relu(residual + out)
180 |         else:
181 |             # the map is larger than the block's working size: shrink it first,
182 |             # run the communication, then upsample the message back
183 |             x = F.interpolate(x, (self.spatial_height, self.spatial_width))
184 |             x_stretch = x.view(bt * self.nblocks_channel, c // self.nblocks_channel, self.spatial_height * self.spatial_width)
185 |             x_v = x_stretch.permute(0, 2, 1).contiguous() # (b) x (h * w) x c
186 |
187 |             if self.enc_dec:
188 |                 x_v = self.fc_in(x_v) # (b) x (h * w) x c
189 |
190 |             x_m = x_v.mean(1).unsqueeze(1) # (b) x 1 x c
191 |             score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # (b) x c x c
192 |             attn = F.softmax(score, dim=1) # (b) x c x c
193 |
194 |             if self.communication:
195 |                 if self.enc_dec:
196 |                     out = self.bn(self.fc_out(torch.bmm(x_v, attn))) # (b) x (h * w) x c
197 |                 else:
198 |                     out = self.bn(torch.bmm(x_v, attn)) # (b) x (h * w) x c
199 |             else:
200 |                 out = self.bn(self.fc_out(x_v)) # (b) x (h * w) x c
201 |             out = out.permute(0, 2, 1).contiguous().view(bt, c, self.spatial_height, self.spatial_width)
202 |             out = F.interpolate(out, (h, w))
203 |             return F.relu(residual + out)
219 |
220 | class BasicBlockPlain(nn.Module):
221 | expansion=1
222 |
223 | def __init__(self, insize, inplanes, planes, stride=1, downsample=None, use_se=False, reduction=8):
224 | super(BasicBlockPlain, self).__init__()
225 | self.use_se = use_se
226 | self.conv1 = conv3x3(inplanes, planes, stride)
227 | self.bn1 = nn.BatchNorm2d(planes)
228 | self.relu = nn.ReLU(inplace=True)
229 | self.conv2 = conv3x3(planes, planes)
230 | self.bn2 = nn.BatchNorm2d(planes)
231 | if self.use_se:
232 | self.se = SELayer(planes, reduction=reduction)
233 | self.downsample = downsample
234 | self.stride = stride
235 |
236 | def forward(self, x):
237 | # out = self.gt_spatial(out)
238 | residual = x
239 |
240 | out = self.conv1(x)
241 | out = self.bn1(out)
242 | out = self.relu(out)
243 |
244 | out = self.conv2(out)
245 | out = self.bn2(out)
246 | if self.use_se:
247 | out = self.se(out)
248 |
249 | if self.downsample is not None:
250 | residual = self.downsample(x)
251 |
252 |         # plain (non-residual) block: the shortcut is not added back to the output
253 | out = self.relu(out)
254 | return out
255 |
256 | class BasicBlock(nn.Module):
257 | expansion=1
258 |
259 | def __init__(self, insize, inplanes, planes, stride=1, downsample=None, use_se=False, use_gt=False, reduction=8):
260 | super(BasicBlock, self).__init__()
261 | self.use_se = use_se
262 | self.use_gt = use_gt
263 | self.conv1 = conv3x3(inplanes, planes, stride)
264 | self.bn1 = nn.BatchNorm2d(planes)
265 | self.relu = nn.ReLU(inplace=True)
266 | self.conv2 = conv3x3(planes, planes)
267 | self.bn2 = nn.BatchNorm2d(planes)
268 | if self.use_se:
269 | self.se = SELayer(planes, reduction=reduction)
270 | if self.use_gt:
271 | self.gt = _CrossNeuronBlockInternal(planes)
272 |
273 | self.downsample = downsample
274 | self.stride = stride
275 |
276 | def forward(self, x):
277 | # out = self.gt_spatial(out)
278 | residual = x
279 |
280 | out = self.conv1(x)
281 | out = self.bn1(out)
282 | out = self.relu(out)
283 |
284 | if self.use_gt:
285 | out = self.gt(out)
286 |
287 | out = self.conv2(out)
288 | out = self.bn2(out)
289 |
290 | if self.use_se:
291 | out = self.se(out)
292 |
293 | if self.downsample is not None:
294 | residual = self.downsample(x)
295 |
296 | out += residual
297 | out = self.relu(out)
298 | return out
299 |
300 |
301 | class Bottleneck(nn.Module):
302 | expansion=4
303 |
304 | def __init__(self, insize, inplanes, planes, stride=1, downsample=None, use_se=False, use_gt=False, reduction=8):
305 | super(Bottleneck, self).__init__()
306 | self.use_se = use_se
307 | self.use_gt = use_gt
308 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
309 | self.bn1 = nn.BatchNorm2d(planes)
310 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
311 | self.bn2 = nn.BatchNorm2d(planes)
312 | self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False)
313 | self.bn3 = nn.BatchNorm2d(planes*4)
314 | self.relu = nn.ReLU(inplace=True)
315 | if self.use_se:
316 |             self.se = SELayer(planes * 4, reduction=reduction)  # the block's output has planes*4 channels
317 | if self.use_gt:
318 | self.gt = CrossNeuronlBlock2D(planes, insize, insize, insize, insize, reduction=8)
319 | self.downsample = downsample
320 | self.stride = stride
321 |
322 | def forward(self, x):
323 | residual = x
324 |
325 | out = self.conv1(x)
326 | out = self.bn1(out)
327 | out = self.relu(out)
328 |
329 | out = self.conv2(out)
330 | out = self.bn2(out)
331 | out = self.relu(out)
332 |
333 | out = self.conv3(out)
334 | out = self.bn3(out)
335 |
336 | if self.use_se:
337 | out = self.se(out)
338 | if self.use_gt:
339 | out = self.gt(out)
340 |
341 | if self.downsample is not None:
342 | residual = self.downsample(x)
343 |
344 | out += residual
345 | out = self.relu(out)
346 |
347 | return out
348 |
349 | # class CommunicationBlock(nn.Module):
350 |
351 |
352 | class PreActBasicBlock(nn.Module):
353 | expansion = 1
354 |
355 | def __init__(self, inplanes, planes, stride=1, downsample=None):
356 | super(PreActBasicBlock, self).__init__()
357 | self.bn1 = nn.BatchNorm2d(inplanes)
358 | self.relu = nn.ReLU(inplace=True)
359 | self.conv1 = conv3x3(inplanes, planes, stride)
360 | self.bn2 = nn.BatchNorm2d(planes)
361 | self.conv2 = conv3x3(planes, planes)
362 | self.downsample = downsample
363 | self.stride = stride
364 |
365 | def forward(self, x):
366 | residual = x
367 |
368 | out = self.bn1(x)
369 | out = self.relu(out)
370 |
371 | if self.downsample is not None:
372 | residual = self.downsample(out)
373 |
374 | out = self.conv1(out)
375 |
376 | out = self.bn2(out)
377 | out = self.relu(out)
378 | out = self.conv2(out)
379 |
380 | out += residual
381 |
382 | return out
383 |
384 |
385 | class PreActBottleneck(nn.Module):
386 | expansion = 4
387 |
388 | def __init__(self, inplanes, planes, stride=1, downsample=None):
389 | super(PreActBottleneck, self).__init__()
390 | self.bn1 = nn.BatchNorm2d(inplanes)
391 | self.relu = nn.ReLU(inplace=True)
392 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
393 | self.bn2 = nn.BatchNorm2d(planes)
394 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
395 | self.bn3 = nn.BatchNorm2d(planes)
396 | self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False)
397 | self.downsample = downsample
398 | self.stride = stride
399 |
400 | def forward(self, x):
401 | residual = x
402 |
403 | out = self.bn1(x)
404 | out = self.relu(out)
405 |
406 | if self.downsample is not None:
407 | residual = self.downsample(out)
408 |
409 | out = self.conv1(out)
410 |
411 | out = self.bn2(out)
412 | out = self.relu(out)
413 | out = self.conv2(out)
414 |
415 | out = self.bn3(out)
416 | out = self.relu(out)
417 | out = self.conv3(out)
418 |
419 | out += residual
420 |
421 | return out
422 |
423 |
424 | class ResNet_Cifar(nn.Module):
425 | def __init__(self, block, layers, num_classes=10, insert_layers=["1", "2", "3"], depth=1,
426 | has_gtlayer=False, has_selayer=False, communication=True, enc_dec=True):
427 | super(ResNet_Cifar, self).__init__()
428 | self.inplanes = 16
429 | self.insize = 32
430 | self.layers = insert_layers
431 | self.depth = depth
432 | self.has_gtlayer = has_gtlayer
433 | self.has_selayer = has_selayer
434 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
435 | self.bn1 = nn.BatchNorm2d(16)
436 | self.relu = nn.ReLU(inplace=True)
437 |
438 | self.layer1 = self._make_layer(block, 16, layers[0], has_selayer=has_selayer, has_gtlayer=has_gtlayer)
439 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2, has_selayer=has_selayer, has_gtlayer=has_gtlayer)
440 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2, has_selayer=has_selayer, has_gtlayer=has_gtlayer)
441 |
442 | if self.has_gtlayer:
443 | if "1" in self.layers:
444 | layers = []
445 | for i in range(self.depth):
446 | layers.append(_CrossNeuronBlock(16, 32, 32, spatial_height=16, spatial_width=16, reduction=16,
447 | communication=communication, enc_dec=enc_dec))
448 | self.nclayer1 = nn.Sequential(*layers)
449 |
450 | if "2" in self.layers:
451 | layers = []
452 | for i in range(self.depth):
453 | layers.append(_CrossNeuronBlock(16, 32, 32, spatial_height=16, spatial_width=16, reduction=16,
454 | communication=communication, enc_dec=enc_dec))
455 | self.nclayer2 = nn.Sequential(*layers)
456 |
457 | if "3" in self.layers:
458 | layers = []
459 | for i in range(self.depth):
460 | layers.append(_CrossNeuronBlock(32, 16, 16, spatial_height=16, spatial_width=16, reduction=16,
461 | communication=communication, enc_dec=enc_dec))
462 | self.nclayer3 = nn.Sequential(*layers)
463 |
464 | if "4" in self.layers:
465 | layers = []
466 | for i in range(self.depth):
467 | layers.append(_CrossNeuronBlock(64, 8, 8, spatial_height=8, spatial_width=8, reduction=8,
468 | communication=communication, enc_dec=enc_dec))
469 | self.nclayer4 = nn.Sequential(*layers)
470 |
471 | self.avgpool = nn.AvgPool2d(8, stride=1)
472 | self.fc = nn.Linear(64 * block.expansion, num_classes)
473 |
474 | def _make_layer(self, block, planes, blocks, stride=1, has_selayer=False, has_gtlayer=False):
475 | downsample = None
476 | if stride != 1 or self.inplanes != planes * block.expansion:
477 | downsample = nn.Sequential(
478 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
479 | nn.BatchNorm2d(planes * block.expansion)
480 | )
481 |
482 | self.insize = int(self.insize / stride)
483 |         layers = []
484 |         # use_gt is forced off here: in this analysis variant the cross-neuron
485 |         # blocks are the standalone nclayer1-4 modules inserted between stages
486 |         layers.append(block(self.insize, self.inplanes, planes, stride, downsample, use_se=has_selayer, use_gt=False))
487 |         self.inplanes = planes * block.expansion
488 |         for _ in range(1, blocks):
489 |             layers.append(block(self.insize, self.inplanes, planes, use_se=has_selayer, use_gt=False))
489 | return nn.Sequential(*layers)
490 |
491 |     def _correlation(self, x):
492 |         # mean absolute normalized correlation between channel activation maps
493 |         # (responses centered by the across-channel mean map); used to measure
494 |         # how redundant the neurons are before/after communication
495 |         b, c, h, w = x.shape
496 |         x_v = x.view(b, c, -1) # b x c x (hw)
497 |         x_m = x_v.mean(1).unsqueeze(1) # b x 1 x (hw)
498 |         x_c = x_v - x_m # b x c x (hw), centered responses
499 |         num = torch.bmm(x_c, x_c.transpose(1, 2)) # b x c x c
500 |         x_mode = torch.sqrt(torch.sum(x_c ** 2, 2).unsqueeze(2)) # b x c x 1
501 |         dec = torch.bmm(x_mode, x_mode.transpose(1, 2).contiguous()) # b x c x c
502 |         out = torch.abs(num / (dec + 1e-8)) # eps guards against all-zero channels
503 |         return out.mean()
503 |
504 | def forward(self, x):
505 | x = self.conv1(x)
506 | x = self.bn1(x)
507 | x = self.relu(x)
508 |
509 | corr1 = self._correlation(x)
510 | x0 = x.clone().detach()
511 |
512 | if self.has_gtlayer and "1" in self.layers:
513 | x = self.nclayer1(x)
514 | corr1_ = self._correlation(x)
515 | else:
516 | corr1_ = corr1.clone()
517 |
518 | x = self.layer1(x)
519 | x1 = x.clone().detach()
520 |
521 | corr2 = self._correlation(x)
522 | if self.has_gtlayer and "2" in self.layers:
523 | x = self.nclayer2(x)
524 | corr2_ = self._correlation(x)
525 | else:
526 | corr2_ = corr2.clone()
527 |
528 | x = self.layer2(x)
529 | x2 = x.clone().detach()
530 |
531 | corr3 = self._correlation(x)
532 | if self.has_gtlayer and "3" in self.layers:
533 | x = self.nclayer3(x)
534 | corr3_ = self._correlation(x)
535 | else:
536 | corr3_ = corr3.clone()
537 |
538 | x = self.layer3(x)
539 | x3 = x.clone().detach()
540 | corr4 = self._correlation(x)
541 |
542 | if self.has_gtlayer and "4" in self.layers:
543 | x = self.nclayer4(x)
544 | corr4_ = self._correlation(x)
545 | else:
546 | corr4_ = corr4.clone()
547 |
548 | # print("corr0: {}, corr1: {}, corr2: {}".format(corr0, corr1, corr2))
549 | x = self.avgpool(x)
550 | x = x.view(x.size(0), -1)
551 | x = self.fc(x)
552 |
553 | return x, (x0, x1, x2, x3), \
554 | (corr1, corr2, corr3, corr4), \
555 |                (corr1_, corr2_, corr3_, corr4_)
556 |
557 |
558 | class PreAct_ResNet_Cifar(nn.Module):
559 |
560 | def __init__(self, block, layers, num_classes=10, has_gtlayer=False, has_selayer=False):
561 | super(PreAct_ResNet_Cifar, self).__init__()
562 | self.inplanes = 16
563 | self.insize = 32
564 | self.has_gtlayer = has_gtlayer
565 | self.has_selayer = has_selayer
566 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
567 | self.layer1 = self._make_layer(block, 16, layers[0])
568 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
569 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
570 |
571 | self.bn = nn.BatchNorm2d(64*block.expansion)
572 | self.relu = nn.ReLU(inplace=True)
573 | self.avgpool = nn.AvgPool2d(8, stride=1)
574 | self.fc = nn.Linear(64*block.expansion, num_classes)
575 |
576 | for m in self.modules():
577 | if isinstance(m, nn.Conv2d):
578 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
579 | m.weight.data.normal_(0, math.sqrt(2. / n))
580 | elif isinstance(m, nn.BatchNorm2d):
581 | m.weight.data.fill_(1)
582 | m.bias.data.zero_()
583 |
584 | def _make_layer(self, block, planes, blocks, stride=1):
585 | downsample = None
586 | if stride != 1 or self.inplanes != planes*block.expansion:
587 | downsample = nn.Sequential(
588 | nn.Conv2d(self.inplanes, planes*block.expansion, kernel_size=1, stride=stride, bias=False)
589 | )
590 |
591 | layers = []
592 | layers.append(block(self.inplanes, planes, stride, downsample))
593 | self.inplanes = planes*block.expansion
594 | for _ in range(1, blocks):
595 | layers.append(block(self.inplanes, planes))
596 | return nn.Sequential(*layers)
597 |
598 | def forward(self, x):
599 | x = self.conv1(x)
600 |
601 | x = self.layer1(x)
602 | x = self.layer2(x)
603 | x = self.layer3(x)
604 |
605 | x = self.bn(x)
606 | x = self.relu(x)
607 | x = self.avgpool(x)
608 | x = x.view(x.size(0), -1)
609 | x = self.fc(x)
610 |
611 | return x
612 |
613 |
614 |
615 | def resnet20a_cifar(**kwargs):
616 | model = ResNet_Cifar(BasicBlock, [3, 3, 3], **kwargs)
617 | return model
618 |
619 | def resnet20plain_cifar(**kwargs):
620 | model = ResNet_Cifar(BasicBlockPlain, [3, 3, 3], **kwargs)
621 | return model
622 |
623 | def resnet32_cifar(**kwargs):
624 | model = ResNet_Cifar(BasicBlock, [5, 5, 5], **kwargs)
625 | return model
626 |
627 |
628 | def resnet44_cifar(**kwargs):
629 | model = ResNet_Cifar(BasicBlock, [7, 7, 7], **kwargs)
630 | return model
631 |
632 |
633 | def resnet56a_cifar(**kwargs):
634 | model = ResNet_Cifar(BasicBlock, [9, 9, 9], **kwargs)
635 | return model
636 |
637 | def resnet62a_cifar(**kwargs):
638 | model = ResNet_Cifar(BasicBlock, [10, 10, 10], **kwargs)
639 | return model
640 |
641 | def resnet68a_cifar(**kwargs):
642 | model = ResNet_Cifar(BasicBlock, [11, 11, 11], **kwargs)
643 | return model
644 |
645 | def resnet74a_cifar(**kwargs):
646 | model = ResNet_Cifar(BasicBlock, [12, 12, 12], **kwargs)
647 | return model
648 |
649 | def resnet80a_cifar(**kwargs):
650 | model = ResNet_Cifar(BasicBlock, [13, 13, 13], **kwargs)
651 | return model
652 |
653 | def resnet86a_cifar(**kwargs):
654 | model = ResNet_Cifar(BasicBlock, [14, 14, 14], **kwargs)
655 | return model
656 |
657 | def resnet92a_cifar(**kwargs):
658 | model = ResNet_Cifar(BasicBlock, [15, 15, 15], **kwargs)
659 | return model
660 |
661 | def resnet98a_cifar(**kwargs):
662 | model = ResNet_Cifar(BasicBlock, [16, 16, 16], **kwargs)
663 | return model
664 |
665 | def resnet104a_cifar(**kwargs):
666 | model = ResNet_Cifar(BasicBlock, [17, 17, 17], **kwargs)
667 | return model
668 |
669 | def resnet110a_cifar(**kwargs):
670 | model = ResNet_Cifar(BasicBlock, [18, 18, 18], **kwargs)
671 | return model
672 |
673 | def resnet110plain_cifar(**kwargs):
674 | model = ResNet_Cifar(BasicBlockPlain, [18, 18, 18], **kwargs)
675 | return model
676 |
677 | def resnet1202_cifar(**kwargs):
678 | model = ResNet_Cifar(BasicBlock, [200, 200, 200], **kwargs)
679 | return model
680 |
681 |
682 | def resnet164_cifar(**kwargs):
683 | model = ResNet_Cifar(Bottleneck, [18, 18, 18], **kwargs)
684 | return model
685 |
686 |
687 | def resnet1001_cifar(**kwargs):
688 | model = ResNet_Cifar(Bottleneck, [111, 111, 111], **kwargs)
689 | return model
690 |
691 |
692 | def preact_resnet110_cifar(**kwargs):
693 | model = PreAct_ResNet_Cifar(PreActBasicBlock, [18, 18, 18], **kwargs)
694 | return model
695 |
696 |
697 | def preact_resnet164_cifar(**kwargs):
698 | model = PreAct_ResNet_Cifar(PreActBottleneck, [18, 18, 18], **kwargs)
699 | return model
700 |
701 |
702 | def preact_resnet1001_cifar(**kwargs):
703 | model = PreAct_ResNet_Cifar(PreActBottleneck, [111, 111, 111], **kwargs)
704 | return model
705 |
706 |
707 | if __name__ == '__main__':
708 |     net = resnet20a_cifar()
709 |     # forward returns logits plus activation and correlation diagnostics
710 |     y, feats, corrs, corrs_nc = net(torch.randn(1, 3, 32, 32))
711 |     print(net)
712 |     print(y.size())
712 |
--------------------------------------------------------------------------------
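The factory functions above encode the standard CIFAR ResNet depth rules: three stages of n blocks, with two convs per BasicBlock (depth = 6n + 2) or three per Bottleneck (depth = 9n + 2). A quick sketch of the arithmetic; the helper names are illustrative and not part of the repo:

# depth arithmetic behind the resnetXX_cifar factories above (illustrative sketch)
def basicblock_depth(n):
    return 6 * n + 2   # 3 stages x n blocks x 2 convs, plus stem conv and fc

def bottleneck_depth(n):
    return 9 * n + 2   # 3 stages x n blocks x 3 convs, plus stem conv and fc

assert basicblock_depth(3) == 20      # resnet20a_cifar  -> [3, 3, 3]
assert basicblock_depth(18) == 110    # resnet110a_cifar -> [18, 18, 18]
assert bottleneck_depth(18) == 164    # resnet164_cifar  -> [18, 18, 18]
assert bottleneck_depth(111) == 1001  # resnet1001_cifar -> [111, 111, 111]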
/lib/networks/resnet_cifar_analysis1.py:
--------------------------------------------------------------------------------
1 | '''
2 | resnet for cifar in pytorch
3 |
4 | Reference:
5 | [1] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016.
6 | [2] K. He, X. Zhang, S. Ren, and J. Sun. Identity mappings in deep residual networks. In ECCV, 2016.
7 | '''
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import math
13 | import numpy as np
14 | from lib.layers import *
15 | def conv3x3(in_planes, out_planes, stride=1):
16 | " 3x3 convolution with padding "
17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
18 |
19 | class _CrossNeuronBlock(nn.Module):
20 | def __init__(self, in_channels, in_height, in_width,
21 | nblocks_channel=8,
22 | spatial_height=32, spatial_width=32,
23 | reduction=8, size_is_consistant=True,
24 | communication=True,
25 | enc_dec=True):
26 |         # nblocks_channel: number of blocks along the channel axis
27 |         # spatial_height / spatial_width: target spatial size for the attention computation
28 | super(_CrossNeuronBlock, self).__init__()
29 |
30 | self.communication = communication
31 | self.enc_dec = enc_dec
32 |
33 | # set channel splits
34 | if in_channels <= 512:
35 | self.nblocks_channel = 1
36 | else:
37 | self.nblocks_channel = in_channels // 512
38 | block_size = in_channels // self.nblocks_channel
39 | block = torch.Tensor(block_size, block_size).fill_(1)
40 | self.mask = torch.Tensor(in_channels, in_channels).fill_(0)
41 | for i in range(self.nblocks_channel):
42 | self.mask[i * block_size:(i + 1) * block_size, i * block_size:(i + 1) * block_size].copy_(block)
43 |
44 | # set spatial splits
45 | if in_height * in_width < 16 * 16 and size_is_consistant:
46 | self.spatial_area = in_height * in_width
47 | self.spatial_height = in_height
48 | self.spatial_width = in_width
49 | else:
50 | self.spatial_area = spatial_height * spatial_width
51 | self.spatial_height = spatial_height
52 | self.spatial_width = spatial_width
53 |
54 | self.fc_in = nn.Sequential(
55 | nn.Conv1d(self.spatial_area, self.spatial_area // reduction, 1, 1, 0, bias=True),
56 | nn.ReLU(True),
57 | nn.Conv1d(self.spatial_area // reduction, self.spatial_area, 1, 1, 0, bias=True),
58 | )
59 |
60 | self.fc_out = nn.Sequential(
61 | nn.Conv1d(self.spatial_area, self.spatial_area // reduction, 1, 1, 0, bias=True),
62 | nn.ReLU(True),
63 | nn.Conv1d(self.spatial_area // reduction, self.spatial_area, 1, 1, 0, bias=True),
64 | )
65 |
66 | self.bn = nn.BatchNorm1d(self.spatial_area)
67 |
68 | self.initialize()
69 |
70 | def initialize(self):
71 | for m in self.modules():
72 | if isinstance(m, nn.Conv1d):
73 | nn.init.kaiming_normal_(m.weight)
74 | elif isinstance(m, nn.BatchNorm1d):
75 | nn.init.constant_(m.weight, 1)
76 | nn.init.constant_(m.bias, 0)
77 |
78 | def forward(self, x):
79 | '''
80 | :param x: (bt, c, h, w)
81 | :return:
82 | '''
83 | bt, c, h, w = x.shape
84 | residual = x
85 | x_stretch = x.view(bt, c, h * w)
86 | spblock_h = int(np.ceil(h / self.spatial_height))
87 | spblock_w = int(np.ceil(w / self.spatial_width))
88 | stride_h = int((h - self.spatial_height) / (spblock_h - 1)) if spblock_h > 1 else 0
89 | stride_w = int((w - self.spatial_width) / (spblock_w - 1)) if spblock_w > 1 else 0
90 |
91 | if self.spatial_height == h and self.spatial_width == w:
92 | x_stacked = x_stretch # (b) x c x (h * w)
93 | x_stacked = x_stacked.view(bt * self.nblocks_channel, c // self.nblocks_channel, -1)
94 | x_v = x_stacked.permute(0, 2, 1).contiguous() # (b) x (h * w) x c
95 |
96 | if self.enc_dec:
97 | x_v = self.fc_in(x_v) # (b) x (h * w) x c
98 |
99 |             x_m = x_v.mean(1).view(-1, 1, c // self.nblocks_channel) # (b) x 1 x c
100 |             score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # (b) x c x c
101 |             # score = torch.bmm(x_v.transpose(1, 2).contiguous(), x_v)
102 |             # x_v = F.dropout(x_v, 0.1, self.training)
103 |             # score.masked_fill_(self.mask.unsqueeze(0).expand_as(score).type_as(score).eq(0), -np.inf)
104 |             attn = F.softmax(score, dim=1) # (b) x c x c
105 |
106 | if self.communication:
107 | if self.enc_dec:
108 | out = self.bn(self.fc_out(torch.bmm(x_v, attn))) # (b) x (h * w) x c
109 | else:
110 | out = self.bn(torch.bmm(x_v, attn)) # (b) x (h * w) x c
111 | else:
112 | out = self.bn(self.fc_out(x_v)) # (b) x (h * w) x c
113 |
114 | out = out.permute(0, 2, 1).contiguous().view(bt, c, h, w)
115 | return (residual + out)
116 | else:
117 | x = F.interpolate(x, (self.spatial_height, self.spatial_width))
118 |             # flatten, splitting channels into per-block groups
119 |             x_stretch = x.view(bt * self.nblocks_channel, c // self.nblocks_channel, self.spatial_height * self.spatial_width)
120 |
121 | x_stacked = x_stretch # (b) x c x (h * w)
122 | x_v = x_stacked.permute(0, 2, 1).contiguous() # (b) x (h * w) x c
123 |
124 | if self.enc_dec:
125 | x_v = self.fc_in(x_v) # (b) x (h * w) x c
126 |
127 |             x_m = x_v.mean(1).view(-1, 1, c // self.nblocks_channel) # (b) x 1 x c
128 |             score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # (b) x c x c
129 |             # score = torch.bmm(x_v.transpose(1, 2).contiguous(), x_v)
130 |             # x_v = F.dropout(x_v, 0.1, self.training)
131 |             # score.masked_fill_(self.mask.unsqueeze(0).expand_as(score).type_as(score).eq(0), -np.inf)
132 |             attn = F.softmax(score, dim=1) # (b) x c x c
133 | if self.communication:
134 | if self.enc_dec:
135 | out = self.bn(self.fc_out(torch.bmm(x_v, attn))) # (b) x (h * w) x c
136 | else:
137 | out = self.bn(torch.bmm(x_v, attn)) # (b) x (h * w) x c
138 | else:
139 | out = self.bn(self.fc_out(x_v)) # (b) x (h * w) x c
140 | out = out.permute(0, 2, 1).contiguous().view(bt, c, self.spatial_height, self.spatial_width)
141 | out = F.interpolate(out, (h, w))
142 | return (residual + out)
143 |
144 | class BasicBlockPlain(nn.Module):
145 | expansion=1
146 |
147 | def __init__(self, insize, inplanes, planes, stride=1, downsample=None, use_se=False, reduction=8):
148 | super(BasicBlockPlain, self).__init__()
149 | self.use_se = use_se
150 | self.conv1 = conv3x3(inplanes, planes, stride)
151 | self.bn1 = nn.BatchNorm2d(planes)
152 | self.relu = nn.ReLU(inplace=True)
153 | self.conv2 = conv3x3(planes, planes)
154 | self.bn2 = nn.BatchNorm2d(planes)
155 | if self.use_se:
156 | self.se = SELayer(planes, reduction=reduction)
157 | self.downsample = downsample
158 | self.stride = stride
159 |
160 | def forward(self, x):
161 | # out = self.gt_spatial(out)
162 | residual = x
163 |
164 | out = self.conv1(x)
165 | out = self.bn1(out)
166 | out = self.relu(out)
167 |
168 | out = self.conv2(out)
169 | out = self.bn2(out)
170 | if self.use_se:
171 | out = self.se(out)
172 |
173 |         # "plain" variant: no skip connection is added back; the downsample branch
174 |         # is kept only so the constructor signature mirrors BasicBlock
175 |         if self.downsample is not None:
176 |             residual = self.downsample(x)
177 |         out = self.relu(out)
178 |         return out
179 |
180 | class BasicBlock(nn.Module):
181 | expansion=1
182 |
183 | def __init__(self, insize, inplanes, planes, stride=1, downsample=None, use_se=False, use_gt=False, reduction=8):
184 | super(BasicBlock, self).__init__()
185 | self.use_se = use_se
186 | self.use_gt = use_gt
187 | self.conv1 = conv3x3(inplanes, planes, stride)
188 | self.bn1 = nn.BatchNorm2d(planes)
189 | self.relu = nn.ReLU(inplace=True)
190 |
191 | # if not self.use_gt:
192 | self.conv2 = conv3x3(planes, planes)
193 |
194 | self.bn2 = nn.BatchNorm2d(planes)
195 | if self.use_se:
196 | self.se = SELayer(planes, reduction=reduction)
197 | if self.use_gt:
198 | self.gt = _CrossNeuronBlock(planes, insize, insize, spatial_height=16, spatial_width=16, reduction=8)
199 |
200 | self.downsample = downsample
201 | self.stride = stride
202 |
203 | def forward(self, x):
204 | # out = self.gt_spatial(out)
205 | residual = x
206 |
207 | out = self.conv1(x)
208 | out = self.bn1(out)
209 | out = self.relu(out)
210 |
211 | out = self.conv2(out)
212 | out = self.bn2(out)
213 |
214 | if self.use_gt:
215 | out = self.gt(out)
216 |
217 | if self.use_se:
218 | out = self.se(out)
219 |
220 | if self.downsample is not None:
221 | residual = self.downsample(x)
222 |
223 | out += residual
224 | out = self.relu(out)
225 | return out
226 |
227 |
228 | class Bottleneck(nn.Module):
229 | expansion=4
230 |
231 | def __init__(self, insize, inplanes, planes, stride=1, downsample=None, use_se=False, use_gt=False, reduction=8):
232 | super(Bottleneck, self).__init__()
233 | self.use_se = use_se
234 | self.use_gt = use_gt
235 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
236 | self.bn1 = nn.BatchNorm2d(planes)
237 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
238 | self.bn2 = nn.BatchNorm2d(planes)
239 | self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False)
240 | self.bn3 = nn.BatchNorm2d(planes*4)
241 | self.relu = nn.ReLU(inplace=True)
242 |         if self.use_se:
243 |             self.se = SELayer(planes * 4, reduction=reduction)  # conv3 outputs planes*4 channels
244 |         if self.use_gt:
245 |             self.gt = _CrossNeuronBlock(planes * 4, insize, insize, spatial_height=16, spatial_width=16, reduction=8)
246 | self.downsample = downsample
247 | self.stride = stride
248 |
249 | def forward(self, x):
250 | residual = x
251 |
252 | out = self.conv1(x)
253 | out = self.bn1(out)
254 | out = self.relu(out)
255 |
256 | out = self.conv2(out)
257 | out = self.bn2(out)
258 | out = self.relu(out)
259 |
260 | out = self.conv3(out)
261 | out = self.bn3(out)
262 |
263 | if self.use_se:
264 | out = self.se(out)
265 | if self.use_gt:
266 | out = self.gt(out)
267 |
268 | if self.downsample is not None:
269 | residual = self.downsample(x)
270 |
271 | out += residual
272 | out = self.relu(out)
273 |
274 | return out
275 |
276 | # class CommunicationBlock(nn.Module):
277 |
278 |
279 | class PreActBasicBlock(nn.Module):
280 | expansion = 1
281 |
282 | def __init__(self, inplanes, planes, stride=1, downsample=None):
283 | super(PreActBasicBlock, self).__init__()
284 | self.bn1 = nn.BatchNorm2d(inplanes)
285 | self.relu = nn.ReLU(inplace=True)
286 | self.conv1 = conv3x3(inplanes, planes, stride)
287 | self.bn2 = nn.BatchNorm2d(planes)
288 | self.conv2 = conv3x3(planes, planes)
289 | self.downsample = downsample
290 | self.stride = stride
291 |
292 | def forward(self, x):
293 | residual = x
294 |
295 | out = self.bn1(x)
296 | out = self.relu(out)
297 |
298 | if self.downsample is not None:
299 | residual = self.downsample(out)
300 |
301 | out = self.conv1(out)
302 |
303 | out = self.bn2(out)
304 | out = self.relu(out)
305 | out = self.conv2(out)
306 |
307 | out += residual
308 |
309 | return out
310 |
311 |
312 | class PreActBottleneck(nn.Module):
313 | expansion = 4
314 |
315 | def __init__(self, inplanes, planes, stride=1, downsample=None):
316 | super(PreActBottleneck, self).__init__()
317 | self.bn1 = nn.BatchNorm2d(inplanes)
318 | self.relu = nn.ReLU(inplace=True)
319 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
320 | self.bn2 = nn.BatchNorm2d(planes)
321 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
322 | self.bn3 = nn.BatchNorm2d(planes)
323 | self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False)
324 | self.downsample = downsample
325 | self.stride = stride
326 |
327 | def forward(self, x):
328 | residual = x
329 |
330 | out = self.bn1(x)
331 | out = self.relu(out)
332 |
333 | if self.downsample is not None:
334 | residual = self.downsample(out)
335 |
336 | out = self.conv1(out)
337 |
338 | out = self.bn2(out)
339 | out = self.relu(out)
340 | out = self.conv2(out)
341 |
342 | out = self.bn3(out)
343 | out = self.relu(out)
344 | out = self.conv3(out)
345 |
346 | out += residual
347 |
348 | return out
349 |
350 |
351 | class ResNet_Cifar(nn.Module):
352 | def __init__(self, block, layers, num_classes=10, insert_layers=["1", "2", "3"], depth=1,
353 | has_gtlayer=False, has_selayer=False, communication=True, enc_dec=True):
354 | super(ResNet_Cifar, self).__init__()
355 | self.inplanes = 16
356 | self.insize = 32
357 | self.layers = insert_layers
358 | self.depth = depth
359 | self.has_gtlayer = has_gtlayer
360 | self.has_selayer = has_selayer
361 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
362 | self.bn1 = nn.BatchNorm2d(16)
363 | self.relu = nn.ReLU(inplace=True)
364 | self.layer1 = self._make_layer(block, 16, layers[0], has_selayer=has_selayer,
365 | has_gtlayer=(has_gtlayer if ("1" in self.layers) else False))
366 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2, has_selayer=has_selayer,
367 | has_gtlayer=(has_gtlayer if ("2" in self.layers) else False))
368 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2, has_selayer=has_selayer,
369 | has_gtlayer=(has_gtlayer if ("3" in self.layers) else False))
370 |
371 | # if self.has_gtlayer:
372 | # if "1" in self.layers:
373 | # layers = []
374 | # for i in range(self.depth):
375 | # layers.append(_CrossNeuronBlock(16, 32, 32, spatial_height=16, spatial_width=16,
376 | # communication=communication, enc_dec=enc_dec))
377 | # self.nclayer1 = nn.Sequential(*layers)
378 | #
379 | # if "2" in self.layers:
380 | # layers = []
381 | # for i in range(self.depth):
382 | # layers.append(_CrossNeuronBlock(16, 32, 32, spatial_height=16, spatial_width=16,
383 | # communication=communication, enc_dec=enc_dec))
384 | # self.nclayer2 = nn.Sequential(*layers)
385 | #
386 | # if "3" in self.layers:
387 | # layers = []
388 | # for i in range(self.depth):
389 | # layers.append(_CrossNeuronBlock(32, 16, 16, spatial_height=16, spatial_width=16,
390 | # communication=communication, enc_dec=enc_dec))
391 | # self.nclayer3 = nn.Sequential(*layers)
392 |
393 | self.avgpool = nn.AvgPool2d(8, stride=1)
394 | self.fc = nn.Linear(64 * block.expansion, num_classes)
395 |
396 | def _make_layer(self, block, planes, blocks, stride=1, has_selayer=False, has_gtlayer=False):
397 | downsample = None
398 | if stride != 1 or self.inplanes != planes * block.expansion:
399 | downsample = nn.Sequential(
400 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
401 | nn.BatchNorm2d(planes * block.expansion)
402 | )
403 |
404 | layers = []
405 | layers.append(block(self.insize, self.inplanes, planes, stride, downsample, use_se=has_selayer, use_gt=has_gtlayer))
406 | self.inplanes = planes * block.expansion
407 | self.insize = int(self.insize / stride)
408 | for _ in range(1, blocks - 1):
409 | layers.append(block(self.insize, self.inplanes, planes, use_se=has_selayer, use_gt=False))
410 | layers.append(block(self.insize, self.inplanes, planes, use_se=has_selayer, use_gt=False))
411 | return nn.Sequential(*layers)
412 |
413 | def _correlation(self, x):
414 | b, c, h, w = x.shape
415 | x_v = x.clone().detach().view(b, c, -1) # b x c x (hw)
416 | x_m = x_v.mean(1).unsqueeze(1) # b x 1 x (hw)
417 | x_c = x_v - x_m # b x c x (hw)
418 | # x_c = x_v
419 | num = torch.bmm(x_c, x_c.transpose(1, 2)) # b x c x c
420 | x_mode = torch.sqrt(torch.sum(x_c ** 2, 2).unsqueeze(2)) # b x c x 1
421 | dec = torch.bmm(x_mode, x_mode.transpose(1, 2).contiguous()) # b x c x c
422 | out = num / dec
423 | out = torch.abs(out)
424 | return out.mean()
425 |
426 | def forward(self, x):
427 | x = self.conv1(x)
428 | x = self.bn1(x)
429 | x = self.relu(x)
430 |
431 | x0 = x.clone().detach()
432 | corr1 = self._correlation(x)
433 |
434 | x = self.layer1(x)
435 | x1 = x.clone().detach()
436 | corr2 = self._correlation(x)
437 |
438 | x = self.layer2(x)
439 | x2 = x.clone().detach()
440 | corr3 = self._correlation(x)
441 |
442 | x = self.layer3(x)
443 | x3 = x.clone().detach()
444 | corr4 = self._correlation(x)
445 |
446 | x = self.avgpool(x)
447 | x = x.view(x.size(0), -1)
448 | x = self.fc(x)
449 |
450 |         # no standalone cross-neuron layers are inserted in this variant, so the
451 |         # "before" and "after" correlation tuples coincide by construction
452 |         return x, (x0, x1, x2, x3), (corr1.item(), corr2.item(), corr3.item(), corr4.item()), (corr1.item(), corr2.item(), corr3.item(), corr4.item())
451 |
452 |
453 | class PreAct_ResNet_Cifar(nn.Module):
454 |
455 | def __init__(self, block, layers, num_classes=10, has_gtlayer=False, has_selayer=False):
456 | super(PreAct_ResNet_Cifar, self).__init__()
457 | self.inplanes = 16
458 | self.insize = 32
459 | self.has_gtlayer = has_gtlayer
460 | self.has_selayer = has_selayer
461 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
462 | self.layer1 = self._make_layer(block, 16, layers[0])
463 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
464 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
465 |
466 | self.bn = nn.BatchNorm2d(64*block.expansion)
467 | self.relu = nn.ReLU(inplace=True)
468 | self.avgpool = nn.AvgPool2d(8, stride=1)
469 | self.fc = nn.Linear(64*block.expansion, num_classes)
470 |
471 | for m in self.modules():
472 | if isinstance(m, nn.Conv2d):
473 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
474 | m.weight.data.normal_(0, math.sqrt(2. / n))
475 | elif isinstance(m, nn.BatchNorm2d):
476 | m.weight.data.fill_(1)
477 | m.bias.data.zero_()
478 |
479 | def _make_layer(self, block, planes, blocks, stride=1):
480 | downsample = None
481 | if stride != 1 or self.inplanes != planes*block.expansion:
482 | downsample = nn.Sequential(
483 | nn.Conv2d(self.inplanes, planes*block.expansion, kernel_size=1, stride=stride, bias=False)
484 | )
485 |
486 | layers = []
487 | layers.append(block(self.inplanes, planes, stride, downsample))
488 | self.inplanes = planes*block.expansion
489 | for _ in range(1, blocks):
490 | layers.append(block(self.inplanes, planes))
491 | return nn.Sequential(*layers)
492 |
493 | def forward(self, x):
494 | x = self.conv1(x)
495 |
496 | x = self.layer1(x)
497 | x = self.layer2(x)
498 | x = self.layer3(x)
499 |
500 | x = self.bn(x)
501 | x = self.relu(x)
502 | x = self.avgpool(x)
503 | x = x.view(x.size(0), -1)
504 | x = self.fc(x)
505 |
506 | return x
507 |
508 |
509 |
510 | def resnet20a_cifar(**kwargs):
511 | model = ResNet_Cifar(BasicBlock, [3, 3, 3], **kwargs)
512 | return model
513 |
514 | def resnet20plain_cifar(**kwargs):
515 | model = ResNet_Cifar(BasicBlockPlain, [3, 3, 3], **kwargs)
516 | return model
517 |
518 | def resnet32_cifar(**kwargs):
519 | model = ResNet_Cifar(BasicBlock, [5, 5, 5], **kwargs)
520 | return model
521 |
522 |
523 | def resnet44_cifar(**kwargs):
524 | model = ResNet_Cifar(BasicBlock, [7, 7, 7], **kwargs)
525 | return model
526 |
527 |
528 | def resnet56a_cifar(**kwargs):
529 | model = ResNet_Cifar(BasicBlock, [9, 9, 9], **kwargs)
530 | return model
531 |
532 | def resnet62a_cifar(**kwargs):
533 | model = ResNet_Cifar(BasicBlock, [10, 10, 10], **kwargs)
534 | return model
535 |
536 | def resnet68a_cifar(**kwargs):
537 | model = ResNet_Cifar(BasicBlock, [11, 11, 11], **kwargs)
538 | return model
539 |
540 | def resnet74a_cifar(**kwargs):
541 | model = ResNet_Cifar(BasicBlock, [12, 12, 12], **kwargs)
542 | return model
543 |
544 | def resnet80a_cifar(**kwargs):
545 | model = ResNet_Cifar(BasicBlock, [13, 13, 13], **kwargs)
546 | return model
547 |
548 | def resnet86a_cifar(**kwargs):
549 | model = ResNet_Cifar(BasicBlock, [14, 14, 14], **kwargs)
550 | return model
551 |
552 | def resnet92a_cifar(**kwargs):
553 | model = ResNet_Cifar(BasicBlock, [15, 15, 15], **kwargs)
554 | return model
555 |
556 | def resnet98a_cifar(**kwargs):
557 | model = ResNet_Cifar(BasicBlock, [16, 16, 16], **kwargs)
558 | return model
559 |
560 | def resnet104a_cifar(**kwargs):
561 | model = ResNet_Cifar(BasicBlock, [17, 17, 17], **kwargs)
562 | return model
563 |
564 | def resnet110a_cifar(**kwargs):
565 | model = ResNet_Cifar(BasicBlock, [18, 18, 18], **kwargs)
566 | return model
567 |
568 | def resnet110plain_cifar(**kwargs):
569 | model = ResNet_Cifar(BasicBlockPlain, [18, 18, 18], **kwargs)
570 | return model
571 |
572 | def resnet1202_cifar(**kwargs):
573 | model = ResNet_Cifar(BasicBlock, [200, 200, 200], **kwargs)
574 | return model
575 |
576 |
577 | def resnet164_cifar(**kwargs):
578 | model = ResNet_Cifar(Bottleneck, [18, 18, 18], **kwargs)
579 | return model
580 |
581 |
582 | def resnet1001_cifar(**kwargs):
583 | model = ResNet_Cifar(Bottleneck, [111, 111, 111], **kwargs)
584 | return model
585 |
586 |
587 | def preact_resnet110_cifar(**kwargs):
588 | model = PreAct_ResNet_Cifar(PreActBasicBlock, [18, 18, 18], **kwargs)
589 | return model
590 |
591 |
592 | def preact_resnet164_cifar(**kwargs):
593 | model = PreAct_ResNet_Cifar(PreActBottleneck, [18, 18, 18], **kwargs)
594 | return model
595 |
596 |
597 | def preact_resnet1001_cifar(**kwargs):
598 | model = PreAct_ResNet_Cifar(PreActBottleneck, [111, 111, 111], **kwargs)
599 | return model
600 |
601 |
602 | if __name__ == '__main__':
603 |     net = resnet20a_cifar()
604 |     y, feats, corr_bf, corr_af = net(torch.randn(1, 3, 32, 32))  # forward returns a 4-tuple
605 |     print(net)
606 |     print(y.size())
607 |
--------------------------------------------------------------------------------
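Unlike a plain classifier, this analysis variant's forward returns the logits plus detached per-stage feature maps and the mean absolute channel correlations; that is the contract the training script's Corr meters rely on. A minimal sketch of driving it standalone, assuming lib.layers resolves so the star import brings in SELayer:

import torch

model = resnet20a_cifar(num_classes=10)
model.eval()
with torch.no_grad():
    logits, feats, corr_bf, corr_af = model(torch.randn(4, 3, 32, 32))
print(logits.shape)                 # torch.Size([4, 10])
print([f.shape[1] for f in feats])  # per-stage channel widths: [16, 16, 32, 64]
print(corr_bf)                      # four scalars, one per tap point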
/lib/networks/resnext.py:
--------------------------------------------------------------------------------
1 | """
2 | Creates a ResNeXt Model as defined in:
3 | Xie, S., Girshick, R., Dollar, P., Tu, Z., & He, K. (2016).
4 | Aggregated residual transformations for deep neural networks.
5 | arXiv preprint arXiv:1611.05431.
6 | imported from https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua
7 | """
8 | from __future__ import division
9 | import math
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.nn import init
13 | import torch
14 |
15 | __all__ = ['resnext50', 'resnext101', 'resnext152']
16 |
17 | class Bottleneck(nn.Module):
18 | """
19 |     ResNeXt bottleneck type C
20 | """
21 | expansion = 4
22 |
23 | def __init__(self, inplanes, planes, baseWidth, cardinality, stride=1, downsample=None):
24 | """ Constructor
25 | Args:
26 | inplanes: input channel dimensionality
27 | planes: output channel dimensionality
28 | baseWidth: base width.
29 | cardinality: num of convolution groups.
30 | stride: conv stride. Replaces pooling layer.
31 | """
32 | super(Bottleneck, self).__init__()
33 |
34 | D = int(math.floor(planes * (baseWidth / 64)))
35 | C = cardinality
36 |
37 | self.conv1 = nn.Conv2d(inplanes, D*C, kernel_size=1, stride=1, padding=0, bias=False)
38 | self.bn1 = nn.BatchNorm2d(D*C)
39 | self.conv2 = nn.Conv2d(D*C, D*C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False)
40 | self.bn2 = nn.BatchNorm2d(D*C)
41 | self.conv3 = nn.Conv2d(D*C, planes * 4, kernel_size=1, stride=1, padding=0, bias=False)
42 | self.bn3 = nn.BatchNorm2d(planes * 4)
43 | self.relu = nn.ReLU(inplace=True)
44 |
45 | self.downsample = downsample
46 |
47 | def forward(self, x):
48 | residual = x
49 |
50 | out = self.conv1(x)
51 | out = self.bn1(out)
52 | out = self.relu(out)
53 |
54 | out = self.conv2(out)
55 | out = self.bn2(out)
56 | out = self.relu(out)
57 |
58 | out = self.conv3(out)
59 | out = self.bn3(out)
60 |
61 | if self.downsample is not None:
62 | residual = self.downsample(x)
63 |
64 | out += residual
65 | out = self.relu(out)
66 |
67 | return out
68 |
69 |
70 | class ResNeXt(nn.Module):
71 | """
72 | ResNext optimized for the ImageNet dataset, as specified in
73 | https://arxiv.org/pdf/1611.05431.pdf
74 | """
75 | def __init__(self, baseWidth, cardinality, layers, num_classes):
76 | """ Constructor
77 | Args:
78 | baseWidth: baseWidth for ResNeXt.
79 | cardinality: number of convolution groups.
80 | layers: config of layers, e.g., [3, 4, 6, 3]
81 | num_classes: number of classes
82 | """
83 | super(ResNeXt, self).__init__()
84 | block = Bottleneck
85 |
86 | self.cardinality = cardinality
87 | self.baseWidth = baseWidth
88 | self.num_classes = num_classes
89 | self.inplanes = 64
90 | self.output_size = 64
91 |
92 | self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
93 | self.bn1 = nn.BatchNorm2d(64)
94 | self.relu = nn.ReLU(inplace=True)
95 | self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
96 | self.layer1 = self._make_layer(block, 64, layers[0])
97 | self.layer2 = self._make_layer(block, 128, layers[1], 2)
98 | self.layer3 = self._make_layer(block, 256, layers[2], 2)
99 | self.layer4 = self._make_layer(block, 512, layers[3], 2)
100 | self.avgpool = nn.AvgPool2d(7)
101 | self.fc = nn.Linear(512 * block.expansion, num_classes)
102 |
103 | for m in self.modules():
104 | if isinstance(m, nn.Conv2d):
105 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
106 | m.weight.data.normal_(0, math.sqrt(2. / n))
107 | elif isinstance(m, nn.BatchNorm2d):
108 | m.weight.data.fill_(1)
109 | m.bias.data.zero_()
110 |
111 | def _make_layer(self, block, planes, blocks, stride=1):
112 | """ Stack n bottleneck modules where n is inferred from the depth of the network.
113 | Args:
114 | block: block type used to construct ResNext
115 | planes: number of output channels (need to multiply by block.expansion)
116 | blocks: number of blocks to be built
117 | stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
118 | Returns: a Module consisting of n sequential bottlenecks.
119 | """
120 | downsample = None
121 | if stride != 1 or self.inplanes != planes * block.expansion:
122 | downsample = nn.Sequential(
123 | nn.Conv2d(self.inplanes, planes * block.expansion,
124 | kernel_size=1, stride=stride, bias=False),
125 | nn.BatchNorm2d(planes * block.expansion),
126 | )
127 |
128 | layers = []
129 | layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality, stride, downsample))
130 | self.inplanes = planes * block.expansion
131 | for i in range(1, blocks):
132 | layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality))
133 |
134 | return nn.Sequential(*layers)
135 |
136 | def forward(self, x):
137 | x = self.conv1(x)
138 | x = self.bn1(x)
139 | x = self.relu(x)
140 | x = self.maxpool1(x)
141 | x = self.layer1(x)
142 | x = self.layer2(x)
143 | x = self.layer3(x)
144 | x = self.layer4(x)
145 | x = self.avgpool(x)
146 | x = x.view(x.size(0), -1)
147 | x = self.fc(x)
148 |
149 | return x
150 |
151 |
152 | def resnext50(baseWidth=4, cardinality=32, num_classes=1000):
153 |     """
154 |     Construct ResNeXt-50 (32x4d: cardinality 32, base width 4, as in the paper).
155 |     """
156 | model = ResNeXt(baseWidth, cardinality, [3, 4, 6, 3], num_classes)
157 | return model
158 |
159 |
160 | def resnext101(baseWidth, cardinality):
161 | """
162 | Construct ResNeXt-101.
163 | """
164 | model = ResNeXt(baseWidth, cardinality, [3, 4, 23, 3], 1000)
165 | return model
166 |
167 |
168 | def resnext152(baseWidth, cardinality):
169 | """
170 | Construct ResNeXt-152.
171 | """
172 | model = ResNeXt(baseWidth, cardinality, [3, 8, 36, 3], 1000)
173 | return model
174 |
--------------------------------------------------------------------------------
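The cardinality/baseWidth pair above controls the grouped 3x3 width D*C, and swapping the two is easy to miss: for the first stage, (cardinality 32, base width 4) and the transposed (4, 32) give the same 128-channel width, differing only in the number of groups. A small sketch; group_width is an illustrative helper, not part of the repo:

import math

def group_width(planes, baseWidth, cardinality):
    D = int(math.floor(planes * (baseWidth / 64)))   # per-group width, as in Bottleneck
    return D * cardinality

print(group_width(64, 4, 32))   # 32x4d (paper setting): 128 channels in 32 groups
print(group_width(64, 32, 4))   # transposed setting:    128 channels in only 4 groups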
/lib/networks/resnext_cifar.py:
--------------------------------------------------------------------------------
1 | """
2 | resneXt for cifar with pytorch
3 |
4 | Reference:
5 | [1] S. Xie, G. Ross, P. Dollar, Z. Tu and K. He Aggregated residual transformations for deep neural networks. In CVPR, 2017
6 | """
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import math, numpy as np
12 | from lib.layers import *
13 |
14 | class _CrossNeuronBlock(nn.Module):
15 | def __init__(self, in_channels, in_height, in_width,
16 | nblocks_channel=8,
17 | spatial_height=32, spatial_width=32,
18 | reduction=8, size_is_consistant=True,
19 | communication=True,
20 | enc_dec=True):
21 |         # nblocks_channel: number of blocks along the channel axis
22 |         # spatial_height / spatial_width: target spatial size for the attention computation
23 | super(_CrossNeuronBlock, self).__init__()
24 |
25 | self.communication = communication
26 | self.enc_dec = enc_dec
27 |
28 | # set channel splits
29 | if in_channels <= 512:
30 | self.nblocks_channel = 1
31 | else:
32 | self.nblocks_channel = in_channels // 512
33 | block_size = in_channels // self.nblocks_channel
34 | block = torch.Tensor(block_size, block_size).fill_(1)
35 | self.mask = torch.Tensor(in_channels, in_channels).fill_(0)
36 | for i in range(self.nblocks_channel):
37 | self.mask[i * block_size:(i + 1) * block_size, i * block_size:(i + 1) * block_size].copy_(block)
38 |
39 | # set spatial splits
40 | if in_height * in_width < 16 * 16 and size_is_consistant:
41 | self.spatial_area = in_height * in_width
42 | self.spatial_height = in_height
43 | self.spatial_width = in_width
44 | else:
45 | self.spatial_area = spatial_height * spatial_width
46 | self.spatial_height = spatial_height
47 | self.spatial_width = spatial_width
48 |
49 | self.fc_in = nn.Sequential(
50 | nn.Conv1d(self.spatial_area, self.spatial_area // reduction, 1, 1, 0, bias=True),
51 | nn.ReLU(True),
52 | nn.Conv1d(self.spatial_area // reduction, self.spatial_area, 1, 1, 0, bias=True),
53 | )
54 |
55 | self.fc_out = nn.Sequential(
56 | nn.Conv1d(self.spatial_area, self.spatial_area // reduction, 1, 1, 0, bias=True),
57 | nn.ReLU(True),
58 | nn.Conv1d(self.spatial_area // reduction, self.spatial_area, 1, 1, 0, bias=True),
59 | )
60 |
61 | self.bn = nn.BatchNorm1d(self.spatial_area)
62 |
63 | self.initialize()
64 |
65 | def initialize(self):
66 | for m in self.modules():
67 | if isinstance(m, nn.Conv1d):
68 | nn.init.kaiming_normal_(m.weight)
69 | elif isinstance(m, nn.BatchNorm1d):
70 | nn.init.constant_(m.weight, 1)
71 | nn.init.constant_(m.bias, 0)
72 |
73 | def forward(self, x):
74 | '''
75 | :param x: (bt, c, h, w)
76 | :return:
77 | '''
78 | bt, c, h, w = x.shape
79 | residual = x
80 | x_stretch = x.view(bt, c, h * w)
81 | spblock_h = int(np.ceil(h / self.spatial_height))
82 | spblock_w = int(np.ceil(w / self.spatial_width))
83 | stride_h = int((h - self.spatial_height) / (spblock_h - 1)) if spblock_h > 1 else 0
84 | stride_w = int((w - self.spatial_width) / (spblock_w - 1)) if spblock_w > 1 else 0
85 |
86 | if self.spatial_height == h and self.spatial_width == w:
87 | x_stacked = x_stretch # (b) x c x (h * w)
88 | x_stacked = x_stacked.view(bt * self.nblocks_channel, c // self.nblocks_channel, -1)
89 | x_v = x_stacked.permute(0, 2, 1).contiguous() # (b) x (h * w) x c
90 |
91 | if self.enc_dec:
92 | x_v = self.fc_in(x_v) # (b) x (h * w) x c
93 |
94 |             x_m = x_v.mean(1).view(-1, 1, c // self.nblocks_channel) # (b) x 1 x c
95 |             score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # (b) x c x c
96 |             # score = torch.bmm(x_v.transpose(1, 2).contiguous(), x_v)
97 |             # x_v = F.dropout(x_v, 0.1, self.training)
98 |             # score.masked_fill_(self.mask.unsqueeze(0).expand_as(score).type_as(score).eq(0), -np.inf)
99 |             attn = F.softmax(score, dim=1) # (b) x c x c
100 |
101 | if self.communication:
102 | if self.enc_dec:
103 | out = self.bn(self.fc_out(torch.bmm(x_v, attn))) # (b) x (h * w) x c
104 | else:
105 | out = self.bn(torch.bmm(x_v, attn)) # (b) x (h * w) x c
106 | else:
107 | out = self.bn(self.fc_out(x_v)) # (b) x (h * w) x c
108 |
109 | out = F.dropout(out.permute(0, 2, 1).contiguous().view(bt, c, h, w), 0.0, self.training)
110 | return F.relu(residual + out)
111 | else:
112 | x = F.interpolate(x, (self.spatial_height, self.spatial_width))
113 |             # flatten, splitting channels into per-block groups
114 |             x_stretch = x.view(bt * self.nblocks_channel, c // self.nblocks_channel, self.spatial_height * self.spatial_width)
115 |
116 | x_stacked = x_stretch # (b) x c x (h * w)
117 | x_v = x_stacked.permute(0, 2, 1).contiguous() # (b) x (h * w) x c
118 |
119 | if self.enc_dec:
120 | x_v = self.fc_in(x_v) # (b) x (h * w) x c
121 |
122 |             x_m = x_v.mean(1).view(-1, 1, c // self.nblocks_channel) # (b) x 1 x c
123 |             score = -(x_m - x_m.permute(0, 2, 1).contiguous())**2 # (b) x c x c
124 |             # score = torch.bmm(x_v.transpose(1, 2).contiguous(), x_v)
125 |             # x_v = F.dropout(x_v, 0.1, self.training)
126 |             # score.masked_fill_(self.mask.unsqueeze(0).expand_as(score).type_as(score).eq(0), -np.inf)
127 |             attn = F.softmax(score, dim=1) # (b) x c x c
128 | if self.communication:
129 | if self.enc_dec:
130 | out = self.bn(self.fc_out(torch.bmm(x_v, attn))) # (b) x (h * w) x c
131 | else:
132 | out = self.bn(torch.bmm(x_v, attn)) # (b) x (h * w) x c
133 | else:
134 | out = self.bn(self.fc_out(x_v)) # (b) x (h * w) x c
135 | out = out.permute(0, 2, 1).contiguous().view(bt, c, self.spatial_height, self.spatial_width)
136 | out = F.dropout(F.interpolate(out, (h, w)), 0.0, self.training)
137 | return F.relu(residual + out)
138 |
139 | class Bottleneck(nn.Module):
140 | expansion = 4
141 |
142 | def __init__(self, inplanes, planes, cardinality, baseWidth, stride=1, downsample=None, use_se=False):
143 | super(Bottleneck, self).__init__()
144 | D = int(planes * (baseWidth / 64.))
145 | C = cardinality
146 | self.conv1 = nn.Conv2d(inplanes, D*C, kernel_size=1, bias=False)
147 | self.bn1 = nn.BatchNorm2d(D*C)
148 | self.conv2 = nn.Conv2d(D*C, D*C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False)
149 | self.bn2 = nn.BatchNorm2d(D*C)
150 | self.conv3 = nn.Conv2d(D*C, planes*4, kernel_size=1, bias=False)
151 | self.bn3 = nn.BatchNorm2d(planes*4)
152 |         self.use_se = use_se
153 |         if use_se:
154 |             self.se = SELayer(planes * 4, reduction=8)  # SELayer comes from lib.layers, not torch.nn
153 | self.relu = nn.ReLU(inplace=True)
154 | self.downsample = downsample
155 | self.stride = stride
156 |
157 | def forward(self, x):
158 | residual = x
159 |
160 | out = self.conv1(x)
161 | out = self.bn1(out)
162 | out = self.relu(out)
163 |
164 | out = self.conv2(out)
165 | out = self.bn2(out)
166 | out = self.relu(out)
167 |
168 | out = self.conv3(out)
169 | out = self.bn3(out)
170 |         if self.use_se:
171 |             out = self.se(out)
171 |
172 | if self.downsample is not None:
173 | residual = self.downsample(x)
174 |
175 | if residual.size() != out.size():
176 | print(out.size(), residual.size())
177 | out += residual
178 | out = self.relu(out)
179 |
180 | return out
181 |
182 |
183 | class ResNeXt_Cifar(nn.Module):
184 |
185 | def __init__(self, block, layers, cardinality, baseWidth, num_classes=10, insert_layers=["1", "2", "3"], depth=1,
186 | has_gtlayer=False, has_selayer=False, communication=True, enc_dec=True):
187 | super(ResNeXt_Cifar, self).__init__()
188 | self.inplanes = 64
189 |
190 | self.insize = 32
191 | self.layers = insert_layers
192 | self.depth = depth
193 | self.has_gtlayer = has_gtlayer
194 | self.has_selayer = has_selayer
195 |
196 | self.cardinality = cardinality
197 | self.baseWidth = baseWidth
198 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
199 | self.bn1 = nn.BatchNorm2d(64)
200 | self.relu = nn.ReLU(inplace=True)
201 | self.layer1 = self._make_layer(block, 64, layers[0])
202 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
203 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
204 | self.avgpool = nn.AvgPool2d(8, stride=1)
205 | self.fc = nn.Linear(256 * block.expansion, num_classes)
206 |
207 | if self.has_gtlayer:
208 | if "1" in self.layers:
209 | layers = []
210 | for i in range(self.depth):
211 | layers.append(_CrossNeuronBlock(64, 32, 32, spatial_height=16, spatial_width=16,
212 | communication=communication, enc_dec=enc_dec))
213 | self.nclayer1 = nn.Sequential(*layers)
214 |
215 | if "2" in self.layers:
216 | layers = []
217 | for i in range(self.depth):
218 | layers.append(_CrossNeuronBlock(64, 32, 32, spatial_height=16, spatial_width=16,
219 | communication=communication, enc_dec=enc_dec))
220 | self.nclayer2 = nn.Sequential(*layers)
221 |
222 | if "3" in self.layers:
223 | layers = []
224 | for i in range(self.depth):
225 | layers.append(_CrossNeuronBlock(128, 16, 16, spatial_height=8, spatial_width=8,
226 | communication=communication, enc_dec=enc_dec))
227 | self.nclayer3 = nn.Sequential(*layers)
228 |
229 | for m in self.modules():
230 | if isinstance(m, nn.Conv2d):
231 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
232 | m.weight.data.normal_(0, math.sqrt(2. / n))
233 | elif isinstance(m, nn.BatchNorm2d):
234 | m.weight.data.fill_(1)
235 | m.bias.data.zero_()
236 |
237 | def _make_layer(self, block, planes, blocks, stride=1):
238 | downsample = None
239 | if stride != 1 or self.inplanes != planes * block.expansion:
240 | downsample = nn.Sequential(
241 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
242 | nn.BatchNorm2d(planes * block.expansion)
243 | )
244 |
245 | layers = []
246 | layers.append(block(self.inplanes, planes, self.cardinality, self.baseWidth, stride, downsample, use_se=self.has_selayer))
247 | self.inplanes = planes * block.expansion
248 | for _ in range(1, blocks):
249 | layers.append(block(self.inplanes, planes, self.cardinality, self.baseWidth, use_se=self.has_selayer))
250 |
251 | return nn.Sequential(*layers)
252 |
253 |
254 | def _correlation(self, x):
255 | b, c, h, w = x.shape
256 | x_v = x.clone().detach().view(b, c, -1) # b x c x (hw)
257 | x_m = x_v.mean(1).unsqueeze(1) # b x 1 x (hw)
258 | x_c = x_v - x_m # b x c x (hw)
259 | # x_c = x_v
260 | num = torch.bmm(x_c, x_c.transpose(1, 2)) # b x c x c
261 | x_mode = torch.sqrt(torch.sum(x_c ** 2, 2).unsqueeze(2)) # b x c x 1
262 | dec = torch.bmm(x_mode, x_mode.transpose(1, 2).contiguous()) # b x c x c
263 | out = num / dec
264 | out = torch.abs(out)
265 | return out.mean()
266 |
267 | def forward(self, x):
268 | x = self.conv1(x)
269 | x = self.bn1(x)
270 | x = self.relu(x)
271 |
272 | corr1 = self._correlation(x)
273 | x0 = x.clone().detach()
274 |
275 | if self.has_gtlayer and "1" in self.layers:
276 | x = self.nclayer1(x)
277 | corr1_ = self._correlation(x)
278 | else:
279 | corr1_ = corr1.clone()
280 |
281 | x = self.layer1(x)
282 | x1 = x.clone().detach()
283 |
284 | corr2 = self._correlation(x)
285 | if self.has_gtlayer and "2" in self.layers:
286 | x = self.nclayer2(x)
287 | corr2_ = self._correlation(x)
288 | else:
289 | corr2_ = corr2.clone()
290 |
291 | x = self.layer2(x)
292 | x2 = x.clone().detach()
293 |
294 | corr3 = self._correlation(x)
295 | if self.has_gtlayer and "3" in self.layers:
296 | x = self.nclayer3(x)
297 | corr3_ = self._correlation(x)
298 | else:
299 | corr3_ = corr3.clone()
300 |
301 | x = self.layer3(x)
302 | x3 = x.clone().detach()
303 |
304 | # print("corr0: {}, corr1: {}, corr2: {}".format(corr0, corr1, corr2))
305 | x = self.avgpool(x)
306 | x = x.view(x.size(0), -1)
307 | x = self.fc(x)
308 |
309 | return x, (x0, x1, x2, x3), (corr1.item(), corr2.item(), corr3.item()), (corr1_.item(), corr2_.item(), corr3_.item())
310 |
311 |
312 | # def forward(self, x):
313 | # x = self.conv1(x)
314 | # x = self.bn1(x)
315 | # x = self.relu(x)
316 | #
317 | # x = self.layer1(x)
318 | # x = self.layer2(x)
319 | # x = self.layer3(x)
320 | #
321 | # x = self.avgpool(x)
322 | # x = x.view(x.size(0), -1)
323 | # x = self.fc(x)
324 | #
325 | # return x
326 |
327 |
328 | def resneXt110_cifar(**kwargs):
329 |     # NOTE: relies on a basic-block ResNeXt variant that this file does not define
330 |     model = ResNeXt_Cifar(BasicBlock, [18, 18, 18], **kwargs)
331 |     return model
331 |
332 | def resneXt164_cifar(**kwargs):
333 | model = ResNeXt_Cifar(Bottleneck, [18, 18, 18], **kwargs)
334 | return model
335 |
336 | def resneXt_cifar(depth, cardinality, baseWidth, **kwargs):
337 | assert (depth - 2) % 9 == 0
338 |     n = (depth - 2) // 9  # integer division so the layer counts are ints
339 | model = ResNeXt_Cifar(Bottleneck, [n, n, n], cardinality, baseWidth, **kwargs)
340 | return model
341 |
342 |
343 | if __name__ == '__main__':
344 |     net = resneXt_cifar(29, 16, 64)
345 |     y, feats, corr_bf, corr_af = net(torch.randn(1, 3, 32, 32))  # forward returns a 4-tuple
346 |     print(net)
347 |     print(y.size())
348 |
--------------------------------------------------------------------------------
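resneXt_cifar encodes the bottleneck CIFAR depth rule depth = 9n + 2 (three stages of n three-conv blocks, plus stem and classifier), so the demo's resneXt_cifar(29, 16, 64) builds [3, 3, 3] with cardinality 16 and base width 64. A quick sketch of the rule:

for depth in (29, 38, 47):
    assert (depth - 2) % 9 == 0
    n = (depth - 2) // 9
    print(depth, '->', [n, n, n])   # 29 -> [3, 3, 3], 38 -> [4, 4, 4], 47 -> [5, 5, 5]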
/lib/networks/wide_resnet_cifar.py:
--------------------------------------------------------------------------------
1 | """
2 | wide resnet for cifar in pytorch
3 |
4 | Reference:
5 | [1] S. Zagoruyko and N. Komodakis. Wide residual networks. In BMVC, 2016.
6 | """
7 | import torch
8 | import torch.nn as nn
9 | import math
10 | from .resnet_cifar_analysis import BasicBlock
11 |
12 |
13 | class Wide_ResNet_Cifar(nn.Module):
14 |
15 | def __init__(self, block, layers, wfactor, num_classes=100, has_selayer=False, has_gtlayer=False):
16 | super(Wide_ResNet_Cifar, self).__init__()
17 | self.inplanes = 16
18 | self.insize = 32
19 | self.has_selayer = has_selayer
20 | self.has_gtlayer = has_gtlayer
21 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
22 | self.bn1 = nn.BatchNorm2d(16)
23 | self.relu = nn.ReLU(inplace=True)
24 | self.layer1 = self._make_layer(block, 16*wfactor, layers[0])
25 | self.layer2 = self._make_layer(block, 32*wfactor, layers[1], stride=2)
26 | self.layer3 = self._make_layer(block, 64*wfactor, layers[2], stride=2)
27 | self.avgpool = nn.AvgPool2d(8, stride=1)
28 | self.fc = nn.Linear(64*block.expansion*wfactor, num_classes)
29 |
30 | for m in self.modules():
31 | if isinstance(m, nn.Conv2d):
32 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
33 | m.weight.data.normal_(0, math.sqrt(2. / n))
34 | elif isinstance(m, nn.BatchNorm2d):
35 | m.weight.data.fill_(1)
36 | m.bias.data.zero_()
37 |
38 | def _make_layer(self, block, planes, blocks, stride=1):
39 | downsample = None
40 | if stride != 1 or self.inplanes != planes * block.expansion:
41 | downsample = nn.Sequential(
42 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
43 | nn.BatchNorm2d(planes * block.expansion)
44 | )
45 |
46 | self.insize = int(self.insize / stride)
47 | layers = []
48 | layers.append(block(self.insize, self.inplanes, planes, stride, downsample, use_se=self.has_selayer, use_gt=self.has_gtlayer))
49 | self.inplanes = planes * block.expansion
50 | for _ in range(1, blocks):
51 | layers.append(block(self.insize, self.inplanes, planes, use_se=self.has_selayer))
52 |
53 | return nn.Sequential(*layers)
54 |
55 | def forward(self, x):
56 | x = self.conv1(x)
57 | x = self.bn1(x)
58 | x = self.relu(x)
59 |
60 | x = self.layer1(x)
61 | x = self.layer2(x)
62 | x = self.layer3(x)
63 |
64 | x = self.avgpool(x)
65 | x = x.view(x.size(0), -1)
66 | x = self.fc(x)
67 |
68 | return x
69 |
70 |
71 | def wresnet20_cifar(**kwargs):
72 | return Wide_ResNet_Cifar(BasicBlock, [3, 3, 3], 10, **kwargs)
73 |
74 | def wide_resnet_cifar(depth, width, **kwargs):
75 | assert (depth - 2) % 6 == 0
76 |     n = (depth - 2) // 6  # integer division so the layer counts are ints
77 | return Wide_ResNet_Cifar(BasicBlock, [n, n, n], width, **kwargs)
78 |
79 | if __name__=='__main__':
80 | net = wide_resnet_cifar(20, 10)
81 | y = net(torch.randn(1, 3, 32, 32))
82 | print(isinstance(net, Wide_ResNet_Cifar))
83 | print(y.size())
84 |
--------------------------------------------------------------------------------
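wide_resnet_cifar keeps the basic-block depth rule depth = 6n + 2 and simply multiplies every stage width by `width`, so the demo's wide_resnet_cifar(20, 10) runs three stages of three blocks at 160/320/640 channels. A quick sketch:

depth, k = 20, 10
n = (depth - 2) // 6               # blocks per stage: 3
widths = [16 * k, 32 * k, 64 * k]
print(n, widths)                   # 3 [160, 320, 640]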
/lib/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .verbo import *
2 |
--------------------------------------------------------------------------------
/lib/utils/cub2011.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | from torchvision.datasets.folder import default_loader
4 | from torchvision.datasets.utils import download_url
5 | from torch.utils.data import Dataset
6 |
7 |
8 | class Cub2011(Dataset):
9 | base_folder = 'CUB_200_2011/images'
10 | url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
11 | filename = 'CUB_200_2011.tgz'
12 | tgz_md5 = '97eceeb196236b17998738112f37df78'
13 |
14 | def __init__(self, root, train=True, transform=None, loader=default_loader, download=True):
15 | self.root = os.path.expanduser(root)
16 | self.transform = transform
17 |         self.loader = loader  # honor a caller-supplied loader instead of always using the default
18 | self.train = train
19 |
20 | if download:
21 | self._download()
22 |
23 | if not self._check_integrity():
24 | raise RuntimeError('Dataset not found or corrupted.' +
25 | ' You can use download=True to download it')
26 |
27 | def _load_metadata(self):
28 | images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
29 | names=['img_id', 'filepath'])
30 | image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
31 | sep=' ', names=['img_id', 'target'])
32 | train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
33 | sep=' ', names=['img_id', 'is_training_img'])
34 |
35 | data = images.merge(image_class_labels, on='img_id')
36 | self.data = data.merge(train_test_split, on='img_id')
37 |
38 | if self.train:
39 | self.data = self.data[self.data.is_training_img == 1]
40 | else:
41 | self.data = self.data[self.data.is_training_img == 0]
42 |
43 | def _check_integrity(self):
44 | try:
45 | self._load_metadata()
46 | except Exception:
47 | return False
48 |
49 | for index, row in self.data.iterrows():
50 | filepath = os.path.join(self.root, self.base_folder, row.filepath)
51 | if not os.path.isfile(filepath):
52 |                 # a file listed in the metadata is missing from disk
53 |                 print(filepath)
54 |                 return False
55 | return True
56 |
57 | def _download(self):
58 | import tarfile
59 |
60 | if self._check_integrity():
61 | print('Files already downloaded and verified')
62 | return
63 |
64 | download_url(self.url, self.root, self.filename, self.tgz_md5)
65 |
66 | with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
67 | tar.extractall(path=self.root)
68 |
69 | def __len__(self):
70 | return len(self.data)
71 |
72 | def __getitem__(self, idx):
73 | sample = self.data.iloc[idx]
74 | path = os.path.join(self.root, self.base_folder, sample.filepath)
75 | target = sample.target - 1 # Targets start at 1 by default, so shift to 0
76 | img = self.loader(path)
77 |
78 | if self.transform is not None:
79 | img = self.transform(img)
80 |
81 | return img, target
82 |
--------------------------------------------------------------------------------
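A minimal sketch of wiring Cub2011 into a standard PyTorch input pipeline; the transform and root path below are illustrative, and download=True assumes the Caltech URL is still reachable:

import torchvision.transforms as transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
train_set = Cub2011('./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
images, targets = next(iter(train_loader))  # targets are 0-based class ids in [0, 199]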
/lib/utils/imagenet.py:
--------------------------------------------------------------------------------
1 | # dataloader respecting the PyTorch conventions, but using tensorpack to load and process
2 | # includes typical augmentations for ImageNet training
3 |
4 | import os
5 | import re
6 | import collections.abc
7 | import cv2
8 | import torch
9 | import numpy as np
10 | import tensorpack.dataflow as td
11 | from tensorpack import imgaug
12 | from tensorpack.dataflow import (AugmentImageComponent, PrefetchDataZMQ,
13 | BatchData, MultiThreadMapData)
14 |
15 | #####################################################################################################
16 | # copied from: https://github.com/ppwwyyxx/tensorpack/blob/master/examples/ResNet/imagenet_utils.py #
17 | #####################################################################################################
18 | class GoogleNetResize(imgaug.ImageAugmentor):
19 | """
20 | crop 8%~100% of the original image
21 | See `Going Deeper with Convolutions` by Google.
22 | """
23 | def __init__(self, crop_area_fraction=0.08,
24 | aspect_ratio_low=0.75, aspect_ratio_high=1.333,
25 | target_shape=224):
26 | self._init(locals())
27 |
28 | def _augment(self, img, _):
29 | h, w = img.shape[:2]
30 | area = h * w
31 | for _ in range(10):
32 | targetArea = self.rng.uniform(self.crop_area_fraction, 1.0) * area
33 | aspectR = self.rng.uniform(self.aspect_ratio_low, self.aspect_ratio_high)
34 | ww = int(np.sqrt(targetArea * aspectR) + 0.5)
35 | hh = int(np.sqrt(targetArea / aspectR) + 0.5)
36 | if self.rng.uniform() < 0.5:
37 | ww, hh = hh, ww
38 | if hh <= h and ww <= w:
39 | x1 = 0 if w == ww else self.rng.randint(0, w - ww)
40 | y1 = 0 if h == hh else self.rng.randint(0, h - hh)
41 | out = img[y1:y1 + hh, x1:x1 + ww]
42 | out = cv2.resize(out, (self.target_shape, self.target_shape), interpolation=cv2.INTER_CUBIC)
43 | return out
44 | out = imgaug.ResizeShortestEdge(self.target_shape, interp=cv2.INTER_CUBIC).augment(img)
45 | out = imgaug.CenterCrop(self.target_shape).augment(out)
46 | return out
47 |
48 |
49 | def fbresnet_augmentor(isTrain):
50 | """
51 | Augmentor used in fb.resnet.torch, for BGR images in range [0,255].
52 | """
53 | if isTrain:
54 | augmentors = [
55 | GoogleNetResize(),
56 | imgaug.RandomOrderAug(
57 | [imgaug.BrightnessScale((0.6, 1.4), clip=False),
58 | imgaug.Contrast((0.6, 1.4), clip=False),
59 | imgaug.Saturation(0.4, rgb=False),
60 | # rgb-bgr conversion for the constants copied from fb.resnet.torch
61 | imgaug.Lighting(0.1,
62 | eigval=np.asarray(
63 | [0.2175, 0.0188, 0.0045][::-1]) * 255.0,
64 | eigvec=np.array(
65 | [[-0.5675, 0.7192, 0.4009],
66 | [-0.5808, -0.0045, -0.8140],
67 | [-0.5836, -0.6948, 0.4203]],
68 | dtype='float32')[::-1, ::-1]
69 | )]),
70 | imgaug.Flip(horiz=True),
71 | ]
72 | else:
73 | augmentors = [
74 | imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC),
75 | imgaug.CenterCrop((224, 224)),
76 | ]
77 | return augmentors
78 | #####################################################################################################
79 | #####################################################################################################
80 | _use_shared_memory = False  # no shared-memory fast path outside DataLoader workers
81 | string_classes = (str, bytes)  # stand-in for the removed torch._six.string_classes
82 | numpy_type_map = {
83 | 'float64': torch.DoubleTensor,
84 | 'float32': torch.FloatTensor,
85 | 'float16': torch.HalfTensor,
86 | 'int64': torch.LongTensor,
87 | 'int32': torch.IntTensor,
88 | 'int16': torch.ShortTensor,
89 | 'int8': torch.CharTensor,
90 | 'uint8': torch.ByteTensor,
91 | }
92 |
93 |
94 | def default_collate(batch):
95 | "Puts each data field into a tensor with outer dimension batch size"
96 |
97 | error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
98 | elem_type = type(batch[0])
99 | if torch.is_tensor(batch[0]):
100 | out = None
101 | if _use_shared_memory:
102 | # If we're in a background process, concatenate directly into a
103 | # shared memory tensor to avoid an extra copy
104 | numel = sum([x.numel() for x in batch])
105 | storage = batch[0].storage()._new_shared(numel)
106 | out = batch[0].new(storage)
107 | return torch.stack(batch, 0, out=out)
108 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
109 | and elem_type.__name__ != 'string_':
110 | elem = batch[0]
111 | if elem_type.__name__ == 'ndarray':
112 | # array of string classes and object
113 | if re.search('[SaUO]', elem.dtype.str) is not None:
114 | raise TypeError(error_msg.format(elem.dtype))
115 |
116 | return torch.stack([torch.from_numpy(b) for b in batch], 0)
117 | if elem.shape == (): # scalars
118 | py_type = float if elem.dtype.name.startswith('float') else int
119 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
120 | elif isinstance(batch[0], int):
121 | return torch.LongTensor(batch)
122 | elif isinstance(batch[0], float):
123 | return torch.DoubleTensor(batch)
124 | elif isinstance(batch[0], string_classes):
125 | return batch
126 |     elif isinstance(batch[0], collections.abc.Mapping):
127 |         return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
128 |     elif isinstance(batch[0], collections.abc.Sequence):
129 |         transposed = zip(*batch)
130 |         return [default_collate(samples) for samples in transposed]
131 |
132 | raise TypeError((error_msg.format(type(batch[0]))))
133 |
134 |
135 | class Loader(object):
136 | """
137 | Data loader. Combines a dataset and a sampler, and provides
138 | single- or multi-process iterators over the dataset.
139 |
140 | Arguments:
141 | mode (str, required): mode of dataset to operate in, one of ['train', 'val']
142 | batch_size (int, optional): how many samples per batch to load
143 | (default: 1).
144 | shuffle (bool, optional): set to ``True`` to have the data reshuffled
145 | at every epoch (default: False).
146 | num_workers (int, optional): how many subprocesses to use for data
147 | loading. 0 means that the data will be loaded in the main process
148 | (default: 0)
149 |         cache (int, optional): size of the local shuffle buffer used when loading data (default: 50000).
150 | drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
151 | if the dataset size is not divisible by the batch size. If ``False`` and
152 | the size of dataset is not divisible by the batch size, then the last batch
153 | will be smaller. (default: False)
154 |         cuda (bool, optional): set to ``True`` and the PyTorch tensors will get preloaded
155 |             to the GPU for you (necessary because this lets us do the uint8-to-float
156 |             conversion on the GPU, which is faster).
157 | """
158 |
159 | def __init__(self, mode, batch_size=256, shuffle=False, num_workers=25, cache=50000,
160 | collate_fn=default_collate, drop_last=False, cuda=False):
161 | # enumerate standard imagenet augmentors
162 | imagenet_augmentors = fbresnet_augmentor(mode == 'train')
163 |
164 | # load the lmdb if we can find it
165 | lmdb_loc = os.path.join('/srv/share/datasets/ImageNet','ILSVRC-%s.lmdb'%mode)
166 | ds = td.LMDBSerializer.load(lmdb_loc, shuffle=False)
167 | ds = td.LocallyShuffleData(ds, cache)
168 | ds = td.PrefetchData(ds, 5000, 1)
169 | # ds = td.LMDBDataPoint(ds)
170 | ds = td.MapDataComponent(ds, lambda x: cv2.imdecode(x, cv2.IMREAD_COLOR), 0)
171 | ds = td.AugmentImageComponent(ds, imagenet_augmentors)
172 | ds = td.PrefetchDataZMQ(ds, num_workers)
173 | self.ds = td.BatchData(ds, batch_size)
174 | self.ds.reset_state()
175 |
176 | self.batch_size = batch_size
177 | self.num_workers = num_workers
178 | self.cuda = cuda
179 | #self.drop_last = drop_last
180 |
181 | def __iter__(self):
182 | for x, y in self.ds.get_data():
183 | if self.cuda:
184 | # images come out as uint8, which are faster to copy onto the gpu
185 | x = torch.ByteTensor(x).cuda()
186 | y = torch.IntTensor(y).cuda()
187 |                 # but once they're on the gpu, we'll need them in float
188 | yield uint8_to_float(x), y.long()
189 | # yield (x), y.long()
190 | else:
191 | yield uint8_to_float(torch.ByteTensor(x)), torch.IntTensor(y).long()
192 |
193 | def __len__(self):
194 | return self.ds.size()
195 |
196 | def uint8_to_float(x):
197 |     x = x.permute(0,3,1,2) # pytorch expects (n,c,h,w)
198 | return x.float()/128. - 1.
199 |
--------------------------------------------------------------------------------
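A minimal consumption sketch for this tensorpack-backed loader; it assumes the LMDB files exist at the hard-coded /srv/share/datasets/ImageNet location, that the installed tensorpack version still provides this dataflow API, and that a GPU is available when cuda=True:

train_loader = Loader('train', batch_size=256, num_workers=8, cuda=True)
for images, labels in train_loader:
    # images: float tensor in [-1, 1), NCHW; labels: LongTensor of class ids
    break
print(len(train_loader))  # number of batches per epoch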
/lib/utils/net_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | import numpy as np
6 | import torchvision.models as models
7 | import cv2
8 | import pdb
9 | import random
10 |
11 | def adjust_learning_rate(optimizer, decay=0.1):
12 |     """Multiplies the learning rate of every param group by ``decay``"""
13 | for param_group in optimizer.param_groups:
14 | param_group['lr'] = decay * param_group['lr']
15 |
16 | def weights_normal_init(model, dev=0.01):
17 | if isinstance(model, list):
18 | for m in model:
19 | weights_normal_init(m, dev)
20 | else:
21 | for m in model.modules():
22 | if isinstance(m, nn.Conv2d):
23 | m.weight.data.normal_(0.0, dev)
24 | elif isinstance(m, nn.Linear):
25 | m.weight.data.normal_(0.0, dev)
26 |
--------------------------------------------------------------------------------
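Both helpers operate in place on an existing model or optimizer. A minimal sketch; the toy conv module below is illustrative:

import torch.nn as nn
import torch.optim as optim

model = nn.Conv2d(3, 16, kernel_size=3)
weights_normal_init(model, dev=0.01)         # re-draws conv/linear weights from N(0, std=0.01)
optimizer = optim.SGD(model.parameters(), lr=0.1)
adjust_learning_rate(optimizer, decay=0.1)   # multiplies every group's lr by 0.1
print(optimizer.param_groups[0]['lr'])       # ~0.01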
/lib/utils/verbo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | class AverageMeter(object):
5 | """Computes and stores the average and current value"""
6 | def __init__(self, name, fmt=':f'):
7 | self.name = name
8 | self.fmt = fmt
9 | self.reset()
10 |
11 | def reset(self):
12 | self.val = 0
13 | self.avg = 0
14 | self.sum = 0
15 | self.count = 0
16 |
17 | def update(self, val, n=1):
18 | self.val = val
19 | self.sum += val * n
20 | self.count += n
21 | self.avg = self.sum / self.count
22 |
23 | def __str__(self):
24 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
25 | return fmtstr.format(**self.__dict__)
26 |
27 | class ProgressMeter(object):
28 | def __init__(self, num_batches, *meters, prefix=""):
29 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
30 | self.meters = meters
31 | self.prefix = prefix
32 |
33 | def print(self, batch):
34 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
35 | entries += [str(meter) for meter in self.meters]
36 | print(' '.join(entries))
37 |
38 |     def _get_batch_fmtstr(self, num_batches):
39 |         num_digits = len(str(num_batches))
40 | fmt = '{:' + str(num_digits) + 'd}'
41 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
42 |
--------------------------------------------------------------------------------
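The two classes compose as follows in a training loop: each AverageMeter tracks one running statistic, and ProgressMeter formats all of them per batch. A minimal sketch using only the API defined above:

losses = AverageMeter('Loss', ':.3f')
top1 = AverageMeter('Acc@1', ':3.2f')
progress = ProgressMeter(100, losses, top1, prefix='Epoch: [0]')
for batch_idx in range(100):
    losses.update(0.5, n=32)   # value, batch size
    top1.update(87.5, n=32)
    if batch_idx % 20 == 0:
        progress.print(batch_idx)   # e.g. "Epoch: [0][ 20/100] Loss 0.500 (0.500) Acc@1 87.50 (87.50)"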
/lib/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import cv2
4 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import torch
4 | import pdb
5 | import time
6 | import yaml
7 | import os
8 | import os.path as osp
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.optim as optim
12 | from torchvision import datasets, transforms
13 | from lib.utils.net_utils import weights_normal_init, adjust_learning_rate
14 | from configs import cfg
15 | from lib import *
16 | from ptflops import get_model_complexity_info
17 | from thop import profile
18 |
19 | def compute_accuracy(output, target, topk=(1,)):
20 | """Computes the accuracy over the k top predictions for the specified values of k"""
21 | with torch.no_grad():
22 | maxk = max(topk)
23 | batch_size = target.size(0)
24 |
25 | _, pred = output.topk(maxk, 1, True, True)
26 | pred = pred.t()
27 | correct = pred.eq(target.view(1, -1).expand_as(pred))
28 |
29 | res = []
30 | for k in topk:
31 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) # reshape: the strided slice may be non-contiguous, where .view() would fail
32 | res.append(correct_k.mul_(100.0 / batch_size))
33 | return res
34 |
35 | def graph_node_loss(graphs, target):
36 | loss = 0
37 | for graph in graphs:
38 | node = graph[0]
39 | B, N = node.shape
40 | mean_node = torch.mm(node, node.transpose(0, 1).contiguous())
41 | loss += (mean_node.sum() - torch.diag(mean_node).sum()) / N / N
42 | return loss
43 |
44 | def graph_edge_loss(graphs, target):
45 | loss = 0
46 | for graph in graphs:
47 | edge = graph[1]
48 | B, N, _ = edge.shape
49 | edge = F.relu(edge)
50 | loss += (edge.mean(0).sum() - torch.diag(edge.mean(0)).sum()) / N / N
51 | return loss
52 |
53 | def train(args, model, device, train_loader, optimizer, epoch):
54 | model.train()
55 | batch_time = AverageMeter('Time', ':3.3f')
56 | data_time = AverageMeter('Data', ':3.3f')
57 | mem_cost = AverageMeter('Mem', ':3.3f')
58 | losses = AverageMeter('Loss', ':.3f')
59 | top1 = AverageMeter('Acc@1', ':3.2f')
60 | top5 = AverageMeter('Acc@5', ':3.2f')
61 | layer1_bf = AverageMeter('Corr@bf', ':3.2f')
62 | layer1_af = AverageMeter('Corr@af', ':3.2f')
63 | layer2_bf = AverageMeter('Corr@bf', ':3.2f')
64 | layer2_af = AverageMeter('Corr@af', ':3.2f')
65 | layer3_bf = AverageMeter('Corr@bf', ':3.2f')
66 | layer3_af = AverageMeter('Corr@af', ':3.2f')
67 |
68 | progress = ProgressMeter(len(train_loader), mem_cost, batch_time, data_time, losses, top1,
69 | top5, layer1_bf, layer1_af, layer2_bf, layer2_af, layer3_bf, layer3_af,
70 | prefix="Epoch: [{}]".format(epoch))
71 | end = time.time()
72 | for batch_idx, (data, target) in enumerate(train_loader):
73 | data_time.update(time.time() - end)
74 | data, target = data.to(device), target.to(device)
75 | optimizer.zero_grad()
76 | output = model(data)
77 | mem = torch.cuda.max_memory_allocated()
78 | mem_cost.update(mem / 1024 / 1024 / 1024)
79 |
80 | if model.net.name == "invcnn":
81 | loss = 0
82 | for out in output:
83 | out = F.log_softmax(out, dim=1)
84 | loss += F.nll_loss(out, target)
85 | loss /= len(output)
86 | else:
87 | output = F.log_softmax(output, dim=1)
88 | loss = F.nll_loss(output, target)
89 |
90 | acc1, acc5 = compute_accuracy(output[-1] if model.net.name == "invcnn" else output, target, topk=(1, 5))
91 | losses.update(loss.item(), data.size(0))
92 | top1.update(acc1[0], data.size(0))
93 | top5.update(acc5[0], data.size(0))
94 | layer1_bf.update(model.net.layer1.cn.corr_bf.item())
95 | layer1_af.update(model.net.layer1.cn.corr_af.item())
96 | layer2_bf.update(model.net.layer2.cn.corr_bf.item())
97 | layer2_af.update(model.net.layer2.cn.corr_af.item())
98 | layer3_bf.update(model.net.layer3.cn.corr_bf.item())
99 | layer3_af.update(model.net.layer3.cn.corr_af.item())
100 | # import pdb; pdb.set_trace()
101 | # node_loss = graph_node_loss(graphs, target)
102 | # edge_loss = graph_edge_loss(graphs, target)
103 | loss.backward()
104 | optimizer.step()
105 |
106 | # measure elapsed time
107 | batch_time.update(time.time() - end)
108 | end = time.time()
109 |
110 | if batch_idx % args.log.print_interval == 0:
111 | progress.print(batch_idx)
112 | # print('Train Epoch: {} [{}/{} ({:.0f}%)]\tmem: {:.3f}m\tLoss: {:.6f}'.format(
113 | # epoch, batch_idx * len(data), len(train_loader.dataset), mem / 1024 / 1024,
114 | # 100. * batch_idx / len(train_loader), loss.item()))
115 |
116 | def test(args, model, device, test_loader):
117 | batch_time = AverageMeter('Time', ':3.3f')
118 | losses = AverageMeter('Loss', ':.3f')
119 | mem_cost = AverageMeter('Mem', ':3.3f')
120 | top1 = AverageMeter('Acc@1', ':3.2f')
121 | top5 = AverageMeter('Acc@5', ':3.2f')
122 | progress = ProgressMeter(len(test_loader), mem_cost, batch_time, losses, top1, top5,
123 | prefix='Test: ')
124 |
125 | model.eval()
126 | test_loss = 0
127 | correct = 0
128 | with torch.no_grad():
129 | end = time.time()
130 | for batch_idx, (data, target) in enumerate(test_loader):
131 | data, target = data.to(device), target.to(device)
132 | output = model(data)
133 | if model.net.name == "invcnn":
134 | output = torch.stack(output, 1).max(1)[0]
135 | mem = torch.cuda.max_memory_allocated()
136 | mem_cost.update(mem / 1024 / 1024 / 1024)
137 |
138 | # measure accuracy and record loss
139 | acc1, acc5 = compute_accuracy(output, target, topk=(1, 5))
140 | top1.update(acc1[0], data.size(0))
141 | top5.update(acc5[0], data.size(0))
142 |
143 | # measure elapsed time
144 | batch_time.update(time.time() - end)
145 | end = time.time()
146 |
147 | # compute val loss
148 | output = F.log_softmax(output, dim=1)
149 |             loss = F.nll_loss(output, target).item() # mean loss over this batch
150 | test_loss += loss
151 | losses.update(loss, data.size(0))
152 |
153 | # compute accuracy
154 | pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
155 | correct += pred.eq(target.view_as(pred)).sum().item()
156 |
157 | if batch_idx % args.log.print_interval == 0:
158 | progress.print(batch_idx)
159 |     test_loss /= (batch_idx + 1) # batch_idx is zero-based, so this averages over all batches
160 | data_size = len(test_loader.dataset) if hasattr(test_loader, 'dataset') else len(test_loader)
161 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(
162 | test_loss, correct, data_size,
163 | 100. * float(correct) / data_size))
164 | accuracy = 100. * float(correct) / data_size
165 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
166 | .format(top1=top1, top5=top5))
167 | return accuracy
168 |
169 | def count_parameters(model):
170 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
171 |
172 | def main():
173 | # Training settings
174 | parser = argparse.ArgumentParser(description='PyTorch Image Classification')
175 | parser.add_argument('--dataset', type=str, default='cifar100',
176 | help='specify training dataset')
177 |     parser.add_argument('--session', type=int, default=1,
178 |                         help='training session id to record multiple runs')
179 | parser.add_argument('--arch', type=str, default='resnet110',
180 | help='specify network architecture')
181 | parser.add_argument('--bs', dest="batch_size", type=int, default=128,
182 | help='training batch size')
183 | parser.add_argument('--gpu0-bs', dest="gpu0_bs", type=int, default=0,
184 | help='training batch size on gpu0')
185 |     parser.add_argument('--add-ccn', type=str, default='no',
186 |                         help='add cross-neuron communication')
187 | parser.add_argument('--mgpus', type=str, default="no",
188 | help='multi-gpu training')
189 | parser.add_argument('--resume', dest="resume", type=int, default=0,
190 | help='resume epoch')
191 |
192 |
193 | args = parser.parse_args()
194 | cfg.merge_from_file(osp.join("configs", args.dataset + ".yaml"))
195 | cfg.dataset = args.dataset
196 | cfg.arch = args.arch
197 |     cfg.add_cross_neuron = (args.add_ccn == "yes")
198 |     use_cuda = torch.cuda.is_available()
199 |     cfg.use_cuda = use_cuda
200 |     cfg.training.batch_size = args.batch_size
201 |     cfg.mGPUs = (args.mgpus == "yes")
202 |
203 | torch.manual_seed(cfg.initialize.seed)
204 | device = torch.device("cuda" if use_cuda else "cpu")
205 | train_loader, test_loader = create_data_loader(cfg)
206 | model = CrossNeuronNet(cfg)
207 | print("parameter numer: %d" % (count_parameters(model)))
208 |     with torch.cuda.device(0):
209 |         if args.dataset in ("cifar10", "cifar100"):
210 |             flops, params = get_model_complexity_info(model, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
211 |             # flops, params = profile(model, input_size=(1, 3, 32, 32))
212 |         else: # imagenet / cub2011: assume 224x224 inputs, so flops/params are always defined
213 |             flops, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True, print_per_layer_stat=True)
214 |             # flops, params = profile(model, input_size=(1, 3, 224, 224))
215 | print('Flops: {}'.format(flops))
216 | print('Params: {}'.format(params))
217 |
218 | model = model.to(device)
219 |
220 | # optimizer_policy = model.get_optim_policies()
221 | optimizer = optim.SGD(model.parameters(), lr=cfg.optimizer.lr, momentum=cfg.optimizer.momentum, weight_decay=cfg.optimizer.weight_decay)
222 | # optimizer = optim.Adam(model.parameters(), lr=1e-3)
223 | if cfg.mGPUs:
224 | if args.gpu0_bs > 0:
225 | model = BalancedDataParallel(args.gpu0_bs, model).to(device)
226 | else:
227 | model = nn.DataParallel(model).to(device)
228 |
229 | lr = cfg.optimizer.lr
230 |     checkpoint_tag = osp.join("checkpoints", args.dataset, args.arch)
231 | if not osp.exists(checkpoint_tag):
232 | os.makedirs(checkpoint_tag)
233 |
234 | if args.resume > 0:
235 | ckpt_path = osp.join(checkpoint_tag,
236 | ("ccn" if cfg.add_cross_neuron else "plain") + "_{}_{}.pth".format(args.session, args.resume))
237 | print("resume model from {}".format(ckpt_path))
238 | ckpt = torch.load(ckpt_path)
239 | model.load_state_dict(ckpt["model"])
240 | print("resume model succesfully")
241 | acc = test(cfg, model, device, test_loader)
242 |
243 | best_acc = 0
244 | for epoch in range(args.resume + 1, cfg.optimizer.max_epoch + 1):
245 | if epoch in cfg.optimizer.lr_decay_schedule:
246 | adjust_learning_rate(optimizer, cfg.optimizer.lr_decay_gamma)
247 | lr *= cfg.optimizer.lr_decay_gamma
248 | print('Train Epoch: {} learning rate: {}'.format(epoch, lr))
249 | tic = time.time()
250 | train(cfg, model, device, train_loader, optimizer, epoch)
251 | acc = test(cfg, model, device, test_loader)
252 | time_cost = time.time() - tic
253 | if acc > best_acc:
254 | best_acc = acc
255 |             print('\nModel: {} Best Accuracy: {}\tTime Cost per Epoch: {}\n'.format(
256 |                 checkpoint_tag + ("_ccn" if cfg.add_cross_neuron else "_plain"),
257 | best_acc,
258 | time_cost))
259 |
260 | if epoch % cfg.log.checkpoint_interval == 0:
261 | checkpoint = {"arch": cfg.arch,
262 | "model": model.state_dict(),
263 | "epoch": epoch,
264 | "lr": lr,
265 | "test_acc": acc,
266 | "best_acc": best_acc}
267 | torch.save(checkpoint, osp.join(checkpoint_tag,
268 | ("ccn" if cfg.add_cross_neuron else "plain") + "_{}_{}.pth".format(args.session, epoch)))
269 |
270 |
271 | if __name__ == '__main__':
272 | torch.manual_seed(1)
273 | torch.cuda.manual_seed(1)
274 | torch.backends.cudnn.deterministic = True
275 | torch.backends.cudnn.benchmark = False
276 | main()
277 |
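278 | # Example invocations (illustrative sketches; flags are those defined in main()
279 | # above, and the --arch values are assumptions that must match a network name
280 | # understood by CrossNeuronNet / lib/networks):
281 | #
282 | #   python train.py --dataset cifar100 --arch resnet110 --add-ccn no   # plain baseline
283 | #   python train.py --dataset cifar100 --arch resnet110 --add-ccn yes  # with cross-neuron communication
284 | #   python train.py --dataset imagenet --arch resnet50 --bs 256 --mgpus yes --gpu0-bs 32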
--------------------------------------------------------------------------------