├── .gitignore
├── images
│   ├── Fig1.png
│   ├── Fig2a.PNG
│   ├── Fig2b.PNG
│   ├── Fig3.PNG
│   ├── 007_raw.jpg
│   ├── 007_heat_atten.jpg
│   └── 007_raw_atten.jpg
├── models
│   ├── __init__.py
│   ├── blocks.py
│   ├── vgg.py
│   ├── wsdan.py
│   ├── resnet.py
│   └── inception.py
├── datasets
│   ├── __init__.py
│   ├── dog_dataset.py
│   ├── car_dataset.py
│   ├── aircraft_dataset.py
│   └── bird_dataset.py
├── LICENSE
├── config.py
├── eval.py
├── README.md
├── utils.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.iml
2 | *.xml
3 |
--------------------------------------------------------------------------------
/images/Fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuYuc/WS-DAN.PyTorch/HEAD/images/Fig1.png
--------------------------------------------------------------------------------
/images/Fig2a.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuYuc/WS-DAN.PyTorch/HEAD/images/Fig2a.PNG
--------------------------------------------------------------------------------
/images/Fig2b.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuYuc/WS-DAN.PyTorch/HEAD/images/Fig2b.PNG
--------------------------------------------------------------------------------
/images/Fig3.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuYuc/WS-DAN.PyTorch/HEAD/images/Fig3.PNG
--------------------------------------------------------------------------------
/images/007_raw.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuYuc/WS-DAN.PyTorch/HEAD/images/007_raw.jpg
--------------------------------------------------------------------------------
/images/007_heat_atten.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuYuc/WS-DAN.PyTorch/HEAD/images/007_heat_atten.jpg
--------------------------------------------------------------------------------
/images/007_raw_atten.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuYuc/WS-DAN.PyTorch/HEAD/images/007_raw_atten.jpg
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
2 | from .vgg import *
3 | from .inception import *
4 | from .wsdan import *
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .aircraft_dataset import AircraftDataset
2 | from .bird_dataset import BirdDataset
3 | from .car_dataset import CarDataset
4 | from .dog_dataset import DogDataset
5 |
6 |
7 | def get_trainval_datasets(tag, resize):
8 | if tag == 'aircraft':
9 | return AircraftDataset(phase='train', resize=resize), AircraftDataset(phase='val', resize=resize)
10 | elif tag == 'bird':
11 | return BirdDataset(phase='train', resize=resize), BirdDataset(phase='val', resize=resize)
12 | elif tag == 'car':
13 | return CarDataset(phase='train', resize=resize), CarDataset(phase='val', resize=resize)
14 | elif tag == 'dog':
15 | return DogDataset(phase='train', resize=resize), DogDataset(phase='val', resize=resize)
16 | else:
17 | raise ValueError('Unsupported Tag {}'.format(tag))
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Yuchong Gu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | ##################################################
2 | # Training Config
3 | ##################################################
4 | GPU = '0' # GPU
5 | workers = 4 # number of Dataloader workers
6 | epochs = 160 # number of epochs
7 | batch_size = 12 # batch size
8 | learning_rate = 1e-3 # initial learning rate
9 |
10 | ##################################################
11 | # Model Config
12 | ##################################################
13 | image_size = (448, 448) # size of training images
14 | net = 'inception_mixed_6e' # feature extractor
15 | num_attentions = 32 # number of attention maps
16 | beta = 5e-2                 # coefficient for updating feature centers
17 |
18 | ##################################################
19 | # Dataset/Path Config
20 | ##################################################
21 | tag = 'bird' # 'aircraft', 'bird', 'car', or 'dog'
22 |
23 | # saving directory of .ckpt models
24 | save_dir = './FGVC/CUB-200-2011/ckpt/'
25 | model_name = 'model.ckpt'
26 | log_name = 'train.log'
27 |
28 | # checkpoint model for resume training
29 | ckpt = False
30 | # ckpt = save_dir + model_name
31 |
32 | ##################################################
33 | # Eval Config
34 | ##################################################
35 | visualize = True
36 | eval_ckpt = save_dir + model_name
37 | eval_savepath = './FGVC/CUB-200-2011/visualize/'
--------------------------------------------------------------------------------
/datasets/dog_dataset.py:
--------------------------------------------------------------------------------
1 | """ Stanford Dogs (Dog) Dataset
2 | Created: Nov 15,2019 - Yuchong Gu
3 | Revised: Nov 15,2019 - Yuchong Gu
4 | """
5 | import os
6 | import pdb
7 | from PIL import Image
8 | from scipy.io import loadmat
9 | from torch.utils.data import Dataset
10 | from utils import get_transform
11 |
12 | DATAPATH = '/home/guyuchong/DATA/FGVC/StanfordDogs'
13 |
14 |
15 | class DogDataset(Dataset):
16 | """
17 | # Description:
18 | Dataset for retrieving Stanford Dogs images and labels
19 |
20 | # Member Functions:
21 | __init__(self, phase, resize): initializes a dataset
22 | phase: a string in ['train', 'val', 'test']
23 | resize: output shape/size of an image
24 |
25 |         __getitem__(self, item): returns an image and its label
26 |             item: the index of the image in the whole dataset
27 |
28 | __len__(self): returns the length of dataset
29 | """
30 |
31 | def __init__(self, phase='train', resize=500):
32 | assert phase in ['train', 'val', 'test']
33 | self.phase = phase
34 | self.resize = resize
35 | self.num_classes = 120
36 |
37 | if phase == 'train':
38 | list_path = os.path.join(DATAPATH, 'train_list.mat')
39 | else:
40 | list_path = os.path.join(DATAPATH, 'test_list.mat')
41 |
42 | list_mat = loadmat(list_path)
43 | self.images = [f.item().item() for f in list_mat['file_list']]
44 | self.labels = [f.item() for f in list_mat['labels']]
45 |
46 | # transform
47 | self.transform = get_transform(self.resize, self.phase)
48 |
49 | def __getitem__(self, item):
50 | # image
51 | image = Image.open(os.path.join(DATAPATH, 'Images', self.images[item])).convert('RGB') # (C, H, W)
52 | image = self.transform(image)
53 |
54 | # return image and label
55 |         return image, self.labels[item] - 1  # shift labels to start from zero
56 |
57 | def __len__(self):
58 | return len(self.images)
59 |
60 |
61 | if __name__ == '__main__':
62 | ds = DogDataset('train')
63 | # print(len(ds))
64 | for i in range(0, 1000):
65 | image, label = ds[i]
66 | # print(image.shape, label)
67 |
--------------------------------------------------------------------------------
/datasets/car_dataset.py:
--------------------------------------------------------------------------------
1 | """ Stanford Cars (Car) Dataset
2 | Created: Nov 15,2019 - Yuchong Gu
3 | Revised: Nov 15,2019 - Yuchong Gu
4 | """
5 | import os
6 | import pdb
7 | from PIL import Image
8 | from scipy.io import loadmat
9 | from torch.utils.data import Dataset
10 | from utils import get_transform
11 |
12 | DATAPATH = '/home/guyuchong/DATA/FGVC/StanfordCars'
13 |
14 |
15 | class CarDataset(Dataset):
16 | """
17 | # Description:
18 | Dataset for retrieving Stanford Cars images and labels
19 |
20 | # Member Functions:
21 | __init__(self, phase, resize): initializes a dataset
22 | phase: a string in ['train', 'val', 'test']
23 | resize: output shape/size of an image
24 |
25 |         __getitem__(self, item): returns an image and its label
26 |             item: the index of the image in the whole dataset
27 |
28 | __len__(self): returns the length of dataset
29 | """
30 |
31 | def __init__(self, phase='train', resize=500):
32 | assert phase in ['train', 'val', 'test']
33 | self.phase = phase
34 | self.resize = resize
35 | self.num_classes = 196
36 |
37 | if phase == 'train':
38 | list_path = os.path.join(DATAPATH, 'devkit', 'cars_train_annos.mat')
39 | self.image_path = os.path.join(DATAPATH, 'cars_train')
40 | else:
41 | list_path = os.path.join(DATAPATH, 'cars_test_annos_withlabels.mat')
42 | self.image_path = os.path.join(DATAPATH, 'cars_test')
43 |
44 | list_mat = loadmat(list_path)
45 | self.images = [f.item() for f in list_mat['annotations']['fname'][0]]
46 | self.labels = [f.item() for f in list_mat['annotations']['class'][0]]
47 |
48 | # transform
49 | self.transform = get_transform(self.resize, self.phase)
50 |
51 | def __getitem__(self, item):
52 | # image
53 | image = Image.open(os.path.join(self.image_path, self.images[item])).convert('RGB') # (C, H, W)
54 | image = self.transform(image)
55 |
56 | # return image and label
57 |         return image, self.labels[item] - 1  # shift labels to start from zero
58 |
59 | def __len__(self):
60 | return len(self.images)
61 |
62 |
63 | if __name__ == '__main__':
64 | ds = CarDataset('val')
65 | # print(len(ds))
66 | for i in range(0, 100):
67 | image, label = ds[i]
68 | # print(image.shape, label)
69 |
--------------------------------------------------------------------------------
/models/blocks.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 |
6 | __all__ = ['CBAMLayer', 'SPPLayer']
7 |
8 | '''
9 | Woo et al.,
10 | "CBAM: Convolutional Block Attention Module",
11 | ECCV 2018,
12 | arXiv:1807.06521
13 | '''
14 | class CBAMLayer(nn.Module):
15 | def __init__(self, channel, reduction=16, spatial_kernel=7):
16 | super(CBAMLayer, self).__init__()
17 | # channel attention
18 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
19 | self.max_pool = nn.AdaptiveMaxPool2d(1)
20 | self.mlp = nn.Sequential(
21 | nn.Conv2d(channel, channel // reduction, 1, bias=False),
22 | nn.ReLU(inplace=True),
23 | nn.Conv2d(channel // reduction, channel, 1, bias=False),
24 | )
25 | # spatial attention
26 | self.conv = nn.Conv2d(2, 1, kernel_size=spatial_kernel, padding=spatial_kernel//2, bias=False)
27 | self.sigmoid = nn.Sigmoid()
28 |
29 | def forward(self, x):
30 | # channel attention
31 | max_out = self.mlp(self.max_pool(x))
32 | avg_out = self.mlp(self.avg_pool(x))
33 | channel_out = self.sigmoid(max_out + avg_out)
34 | x = channel_out * x
35 |
36 | # spatial attention
37 | max_out, _ = torch.max(x, dim=1, keepdim=True)
38 | avg_out = torch.mean(x, dim=1, keepdim=True)
39 | spatial_out = self.sigmoid(self.conv(torch.cat([max_out, avg_out], dim=1)))
40 | x = spatial_out * x
41 | return x
42 |
43 |
44 | '''
45 | He et al.,
46 | "Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition",
47 | TPAMI 2015,
48 | arXiv:1406.4729
49 | '''
50 | class SPPLayer(nn.Module):
51 | def __init__(self, pool_size, pool=nn.MaxPool2d):
52 | super(SPPLayer, self).__init__()
53 | self.pool_size = pool_size
54 | self.pool = pool
55 | self.out_length = np.sum(np.array(self.pool_size) ** 2)
56 |
57 | def forward(self, x):
58 | B, C, H, W = x.size()
59 | for i in range(len(self.pool_size)):
60 | h_wid = int(math.ceil(H / self.pool_size[i]))
61 | w_wid = int(math.ceil(W / self.pool_size[i]))
62 |             h_pad = (h_wid * self.pool_size[i] - H + 1) // 2  # integer padding required by nn.MaxPool2d
63 |             w_pad = (w_wid * self.pool_size[i] - W + 1) // 2
64 | out = self.pool((h_wid, w_wid), stride=(h_wid, w_wid), padding=(h_pad, w_pad))(x)
65 | if i == 0:
66 | spp = out.view(B, -1)
67 | else:
68 | spp = torch.cat([spp, out.view(B, -1)], dim=1)
69 | return spp
70 |
--------------------------------------------------------------------------------
/datasets/aircraft_dataset.py:
--------------------------------------------------------------------------------
1 | """ FGVC Aircraft (Aircraft) Dataset
2 | Created: Nov 15,2019 - Yuchong Gu
3 | Revised: Nov 15,2019 - Yuchong Gu
4 | """
5 | import os
6 | import pdb
7 | from PIL import Image
8 | from torch.utils.data import Dataset
9 | from utils import get_transform
10 |
11 | DATAPATH = '/home/guyuchong/DATA/FGVC/FGVC-Aircraft/data'
12 | FILENAME_LENGTH = 7
13 |
14 |
15 | class AircraftDataset(Dataset):
16 | """
17 | # Description:
18 | Dataset for retrieving FGVC Aircraft images and labels
19 |
20 | # Member Functions:
21 | __init__(self, phase, resize): initializes a dataset
22 | phase: a string in ['train', 'val', 'test']
23 | resize: output shape/size of an image
24 |
25 |         __getitem__(self, item): returns an image and its label
26 |             item: the index of the image in the whole dataset
27 |
28 | __len__(self): returns the length of dataset
29 | """
30 |
31 | def __init__(self, phase='train', resize=500):
32 | assert phase in ['train', 'val', 'test']
33 | self.phase = phase
34 | self.resize = resize
35 |
36 | variants_dict = {}
37 | with open(os.path.join(DATAPATH, 'variants.txt'), 'r') as f:
38 | for idx, line in enumerate(f.readlines()):
39 | variants_dict[line.strip()] = idx
40 | self.num_classes = len(variants_dict)
41 |
42 | if phase == 'train':
43 | list_path = os.path.join(DATAPATH, 'images_variant_trainval.txt')
44 | else:
45 | list_path = os.path.join(DATAPATH, 'images_variant_test.txt')
46 |
47 | self.images = []
48 | self.labels = []
49 | with open(list_path, 'r') as f:
50 | for line in f.readlines():
51 | fname_and_variant = line.strip()
52 | self.images.append(fname_and_variant[:FILENAME_LENGTH])
53 | self.labels.append(variants_dict[fname_and_variant[FILENAME_LENGTH + 1:]])
54 |
55 | # transform
56 | self.transform = get_transform(self.resize, self.phase)
57 |
58 | def __getitem__(self, item):
59 | # image
60 | image = Image.open(os.path.join(DATAPATH, 'images', '%s.jpg' % self.images[item])).convert('RGB') # (C, H, W)
61 | image = self.transform(image)
62 |
63 | # return image and label
64 |         return image, self.labels[item]  # labels already start from zero
65 |
66 | def __len__(self):
67 | return len(self.images)
68 |
69 |
70 | if __name__ == '__main__':
71 | ds = AircraftDataset('test', 448)
72 | # print(len(ds))
73 | from utils import AverageMeter
74 | height_meter = AverageMeter('height')
75 | width_meter = AverageMeter('width')
76 |
77 | for i in range(len(ds)):
78 | image, label = ds[i]
79 | avgH = height_meter(image.size(1))
80 | avgW = width_meter(image.size(2))
81 | print('H: %.2f, W: %.2f' % (avgH, avgW))
82 |
--------------------------------------------------------------------------------
/datasets/bird_dataset.py:
--------------------------------------------------------------------------------
1 | """ CUB-200-2011 (Bird) Dataset
2 | Created: Oct 11,2019 - Yuchong Gu
3 | Revised: Oct 11,2019 - Yuchong Gu
4 | """
5 | import os
6 | import pdb
7 | from PIL import Image
8 | from torch.utils.data import Dataset
9 | from utils import get_transform
10 |
11 | DATAPATH = '/home/guyuchong/DATA/FGVC/CUB-200-2011'
12 | image_path = {}
13 | image_label = {}
14 |
15 |
16 | class BirdDataset(Dataset):
17 | """
18 | # Description:
19 | Dataset for retrieving CUB-200-2011 images and labels
20 |
21 | # Member Functions:
22 | __init__(self, phase, resize): initializes a dataset
23 | phase: a string in ['train', 'val', 'test']
24 | resize: output shape/size of an image
25 |
26 |         __getitem__(self, item): returns an image and its label
27 |             item: the index of the image in the whole dataset
28 |
29 | __len__(self): returns the length of dataset
30 | """
31 |
32 | def __init__(self, phase='train', resize=500):
33 | assert phase in ['train', 'val', 'test']
34 | self.phase = phase
35 | self.resize = resize
36 | self.image_id = []
37 | self.num_classes = 200
38 |
39 | # get image path from images.txt
40 | with open(os.path.join(DATAPATH, 'images.txt')) as f:
41 | for line in f.readlines():
42 | id, path = line.strip().split(' ')
43 | image_path[id] = path
44 |
45 | # get image label from image_class_labels.txt
46 | with open(os.path.join(DATAPATH, 'image_class_labels.txt')) as f:
47 | for line in f.readlines():
48 | id, label = line.strip().split(' ')
49 | image_label[id] = int(label)
50 |
51 | # get train/test image id from train_test_split.txt
52 | with open(os.path.join(DATAPATH, 'train_test_split.txt')) as f:
53 | for line in f.readlines():
54 | image_id, is_training_image = line.strip().split(' ')
55 | is_training_image = int(is_training_image)
56 |
57 | if self.phase == 'train' and is_training_image:
58 | self.image_id.append(image_id)
59 | if self.phase in ('val', 'test') and not is_training_image:
60 | self.image_id.append(image_id)
61 |
62 | # transform
63 | self.transform = get_transform(self.resize, self.phase)
64 |
65 | def __getitem__(self, item):
66 | # get image id
67 | image_id = self.image_id[item]
68 |
69 | # image
70 | image = Image.open(os.path.join(DATAPATH, 'images', image_path[image_id])).convert('RGB') # (C, H, W)
71 | image = self.transform(image)
72 |
73 | # return image and label
74 |         return image, image_label[image_id] - 1  # shift labels to start from zero
75 |
76 | def __len__(self):
77 | return len(self.image_id)
78 |
79 |
80 | if __name__ == '__main__':
81 | ds = BirdDataset('train')
82 | print(len(ds))
83 | for i in range(0, 10):
84 | image, label = ds[i]
85 | print(image.shape, label)
86 |
--------------------------------------------------------------------------------
/models/vgg.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.utils.model_zoo as model_zoo
5 | from models.blocks import SPPLayer
6 | import logging
7 |
8 |
9 | __all__ = ['vgg19_bn', 'vgg19']
10 |
11 |
12 | model_urls = {
13 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
14 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
15 | }
16 |
17 |
18 | class VGG(nn.Module):
19 |
20 | def __init__(self, features, num_classes=1000, init_weights=True):
21 | super(VGG, self).__init__()
22 | self.features = features
23 | self.spp = SPPLayer(pool_size=[1, 2, 4], pool=nn.MaxPool2d)
24 | self.fc = nn.Sequential(
25 | nn.Linear(512 * self.spp.out_length, 1024),
26 | nn.ReLU(True),
27 | nn.Dropout(),
28 | nn.Linear(1024, num_classes))
29 |
30 | if init_weights:
31 | self._initialize_weights()
32 |
33 | def forward(self, x):
34 | x = self.features(x)
35 | x = self.spp(x)
36 | x = self.fc(x)
37 | return x
38 |
39 | def _initialize_weights(self):
40 | for m in self.modules():
41 | if isinstance(m, nn.Conv2d):
42 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
43 | m.weight.data.normal_(0, math.sqrt(2. / n))
44 | if m.bias is not None:
45 | m.bias.data.zero_()
46 | elif isinstance(m, nn.BatchNorm2d):
47 | m.weight.data.fill_(1)
48 | m.bias.data.zero_()
49 | elif isinstance(m, nn.Linear):
50 | m.weight.data.normal_(0, 0.01)
51 | m.bias.data.zero_()
52 |
53 | def get_features(self):
54 | return self.features
55 |
56 | def load_state_dict(self, state_dict, strict=True):
57 | model_dict = self.state_dict()
58 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
59 | if len(pretrained_dict) == len(state_dict):
60 | logging.info('%s: All params loaded' % type(self).__name__)
61 | else:
62 | logging.info('%s: Some params were not loaded:' % type(self).__name__)
63 | not_loaded_keys = [k for k in state_dict.keys() if k not in pretrained_dict.keys()]
64 | logging.info(('%s, ' * (len(not_loaded_keys) - 1) + '%s') % tuple(not_loaded_keys))
65 | model_dict.update(pretrained_dict)
66 | super(VGG, self).load_state_dict(model_dict)
67 |
68 |
69 | def make_layers(cfg, batch_norm=False):
70 | layers = []
71 | in_channels = 3
72 | for v in cfg:
73 | if v == 'M':
74 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
75 | else:
76 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
77 | if batch_norm:
78 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
79 | else:
80 | layers += [conv2d, nn.ReLU(inplace=True)]
81 | in_channels = v
82 | return nn.Sequential(*layers)
83 |
84 |
85 | cfg = {
86 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
87 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
88 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
89 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
90 | }
91 |
92 |
93 | def vgg19(pretrained=False, **kwargs):
94 | """VGG 19-layer model (configuration "E")
95 |
96 | Args:
97 | pretrained (bool): If True, returns a model pre-trained on ImageNet
98 | """
99 | if pretrained:
100 | kwargs['init_weights'] = False
101 | model = VGG(make_layers(cfg['E']), **kwargs)
102 | if pretrained:
103 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
104 | return model
105 |
106 |
107 | def vgg19_bn(pretrained=False, **kwargs):
108 | """VGG 19-layer model (configuration 'E') with batch normalization
109 |
110 | Args:
111 | pretrained (bool): If True, returns a model pre-trained on ImageNet
112 | """
113 | if pretrained:
114 | kwargs['init_weights'] = False
115 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
116 | if pretrained:
117 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn']))
118 | return model
119 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | """EVALUATION
2 | Created: Nov 22,2019 - Yuchong Gu
3 | Revised: Dec 03,2019 - Yuchong Gu
4 | """
5 | import os
6 | import logging
7 | import warnings
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torchvision import transforms
12 | from torch.utils.data import DataLoader
13 | from tqdm import tqdm
14 |
15 | import config
16 | from models import WSDAN
17 | from datasets import get_trainval_datasets
18 | from utils import TopKAccuracyMetric, batch_augment
19 |
20 | # GPU settings
21 | assert torch.cuda.is_available()
22 | os.environ['CUDA_VISIBLE_DEVICES'] = config.GPU
23 | device = torch.device("cuda")
24 | torch.backends.cudnn.benchmark = True
25 |
26 | # visualize
27 | visualize = config.visualize
28 | savepath = config.eval_savepath
29 | if visualize:
30 | os.makedirs(savepath, exist_ok=True)
31 |
32 | ToPILImage = transforms.ToPILImage()
33 | MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
34 | STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
35 |
36 |
37 | def generate_heatmap(attention_maps):
38 | heat_attention_maps = []
39 | heat_attention_maps.append(attention_maps[:, 0, ...]) # R
40 | heat_attention_maps.append(attention_maps[:, 0, ...] * (attention_maps[:, 0, ...] < 0.5).float() + \
41 | (1. - attention_maps[:, 0, ...]) * (attention_maps[:, 0, ...] >= 0.5).float()) # G
42 | heat_attention_maps.append(1. - attention_maps[:, 0, ...]) # B
43 | return torch.stack(heat_attention_maps, dim=1)
44 |
45 |
46 | def main():
47 | logging.basicConfig(
48 | format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
49 | level=logging.INFO)
50 | warnings.filterwarnings("ignore")
51 |
52 | try:
53 | ckpt = config.eval_ckpt
54 |     except AttributeError:
55 | logging.info('Set ckpt for evaluation in config.py')
56 | return
57 |
58 | ##################################
59 | # Dataset for testing
60 | ##################################
61 | _, test_dataset = get_trainval_datasets(config.tag, resize=config.image_size)
62 | test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False,
63 | num_workers=2, pin_memory=True)
64 |
65 | ##################################
66 | # Initialize model
67 | ##################################
68 | net = WSDAN(num_classes=test_dataset.num_classes, M=config.num_attentions, net=config.net)
69 |
70 | # Load ckpt and get state_dict
71 | checkpoint = torch.load(ckpt)
72 | state_dict = checkpoint['state_dict']
73 |
74 | # Load weights
75 | net.load_state_dict(state_dict)
76 | logging.info('Network loaded from {}'.format(ckpt))
77 |
78 | ##################################
79 | # use cuda
80 | ##################################
81 | net.to(device)
82 | if torch.cuda.device_count() > 1:
83 | net = nn.DataParallel(net)
84 |
85 | ##################################
86 | # Prediction
87 | ##################################
88 | raw_accuracy = TopKAccuracyMetric(topk=(1, 5))
89 | ref_accuracy = TopKAccuracyMetric(topk=(1, 5))
90 | raw_accuracy.reset()
91 | ref_accuracy.reset()
92 |
93 | net.eval()
94 | with torch.no_grad():
95 | pbar = tqdm(total=len(test_loader), unit=' batches')
96 | pbar.set_description('Validation')
97 | for i, (X, y) in enumerate(test_loader):
98 | X = X.to(device)
99 | y = y.to(device)
100 |
101 | # WS-DAN
102 | y_pred_raw, _, attention_maps = net(X)
103 |
104 | # Augmentation with crop_mask
105 | crop_image = batch_augment(X, attention_maps, mode='crop', theta=0.1, padding_ratio=0.05)
106 |
107 | y_pred_crop, _, _ = net(crop_image)
108 | y_pred = (y_pred_raw + y_pred_crop) / 2.
109 |
110 | if visualize:
111 | # reshape attention maps
112 | attention_maps = F.upsample_bilinear(attention_maps, size=(X.size(2), X.size(3)))
113 | attention_maps = torch.sqrt(attention_maps.cpu() / attention_maps.max().item())
114 |
115 | # get heat attention maps
116 | heat_attention_maps = generate_heatmap(attention_maps)
117 |
118 | # raw_image, heat_attention, raw_attention
119 | raw_image = X.cpu() * STD + MEAN
120 | heat_attention_image = raw_image * 0.5 + heat_attention_maps * 0.5
121 | raw_attention_image = raw_image * attention_maps
122 |
123 | for batch_idx in range(X.size(0)):
124 | rimg = ToPILImage(raw_image[batch_idx])
125 | raimg = ToPILImage(raw_attention_image[batch_idx])
126 | haimg = ToPILImage(heat_attention_image[batch_idx])
127 | rimg.save(os.path.join(savepath, '%03d_raw.jpg' % (i * config.batch_size + batch_idx)))
128 | raimg.save(os.path.join(savepath, '%03d_raw_atten.jpg' % (i * config.batch_size + batch_idx)))
129 | haimg.save(os.path.join(savepath, '%03d_heat_atten.jpg' % (i * config.batch_size + batch_idx)))
130 |
131 | # Top K
132 | epoch_raw_acc = raw_accuracy(y_pred_raw, y)
133 | epoch_ref_acc = ref_accuracy(y_pred, y)
134 |
135 | # end of this batch
136 | batch_info = 'Val Acc: Raw ({:.2f}, {:.2f}), Refine ({:.2f}, {:.2f})'.format(
137 | epoch_raw_acc[0], epoch_raw_acc[1], epoch_ref_acc[0], epoch_ref_acc[1])
138 | pbar.update()
139 | pbar.set_postfix_str(batch_info)
140 |
141 | pbar.close()
142 |
143 |
144 | if __name__ == '__main__':
145 | main()
146 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # WS-DAN.PyTorch
2 | A neat PyTorch implementation of WS-DAN (Weakly Supervised Data Augmentation Network) for FGVC (Fine-Grained Visual Classification). (_Hu et al._, ["See Better Before Looking Closer: Weakly Supervised Data Augmentation
3 | Network for Fine-Grained Visual Classification"](https://arxiv.org/abs/1901.09891v2), arXiv:1901.09891)
4 |
5 | **NOTICE: This is NOT an official implementation by authors of WS-DAN. The official implementation is available at [tau-yihouxiang/WS_DAN](https://github.com/tau-yihouxiang/WS_DAN) (and there's another unofficial PyTorch version [wvinzh/WS_DAN_PyTorch](https://github.com/wvinzh/WS_DAN_PyTorch)).**
6 |
7 |
8 |
9 |
10 | ## Innovations
11 | 1. Data Augmentation: Attention Cropping and Attention Dropping (see the sketch at the end of this section)
12 |
13 |

14 |
15 |
16 | 2. Bilinear Attention Pooling (BAP) for Features Generation
17 |
18 |

19 |
20 |
21 | 3. Training Process and Testing Process
22 |
23 |

24 |

25 |
26 |
27 |
28 |
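A minimal sketch of how Attention Cropping and Attention Dropping are applied via ```utils.batch_augment```. The attention maps below are synthetic and the ```theta``` ranges are illustrative assumptions; during training, WS-DAN samples two attention maps per image, one for cropping and one for dropping:

```python
# Illustrative sketch only: synthetic attention maps and assumed theta ranges.
import torch
from utils import batch_augment

images = torch.rand(4, 3, 448, 448)        # dummy batch of input images
attention_map = torch.rand(4, 2, 26, 26)   # (B, 2, H', W'): two sampled attention maps

# Attention Cropping: zoom into the region where the first map responds most strongly
crop_images = batch_augment(images, attention_map[:, :1, ...], mode='crop',
                            theta=(0.4, 0.6), padding_ratio=0.1)

# Attention Dropping: erase the region where the second map responds most strongly
drop_images = batch_augment(images, attention_map[:, 1:, ...], mode='drop',
                            theta=(0.2, 0.5))
```
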
29 | ## Performance
30 | * PyTorch experiments were done on a Titan Xp GPU (batch_size = 12).
31 |
32 | |Dataset|Object|Category|Train|Test|Accuracy (Paper)|Accuracy (PyTorch)|Feature Net|
33 | |-------|------|--------|-----|----|----------------|--------------------|---|
34 | |[FGVC-Aircraft](http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/)|Aircraft|100|6,667|3,333|93.0|93.28|inception_mixed_6e|
35 | |[CUB-200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html)|Bird|200|5,994|5,794|89.4|88.28|inception_mixed_6e|
36 | |[Stanford Cars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html)|Car|196|8,144|8,041|94.5|94.38|inception_mixed_6e|
37 | |[Stanford Dogs](http://vision.stanford.edu/aditya86/ImageNetDogs/)|Dog|120|12,000|8,580|92.2|89.66|inception_mixed_7c|
38 |
39 |
40 |
41 | ## Usage
42 |
43 | ### WS-DAN
44 | This repo implements WS-DAN with a choice of feature extractors: VGG19 (```'vgg19', 'vgg19_bn'```),
45 | ResNet34/50/101/152 (```'resnet34', 'resnet50', 'resnet101', 'resnet152'```),
46 | and Inception_v3 (```'inception_mixed_6e', 'inception_mixed_7c'```); see ```./models/wsdan.py```.
47 |
48 | ```python
49 | net = WSDAN(num_classes=num_classes, M=num_attentions, net='inception_mixed_6e', pretrained=True)
50 | net = WSDAN(num_classes=num_classes, M=num_attentions, net='inception_mixed_7c', pretrained=True)
51 | net = WSDAN(num_classes=num_classes, M=num_attentions, net='vgg19_bn', pretrained=True)
52 | net = WSDAN(num_classes=num_classes, M=num_attentions, net='resnet50', pretrained=True)
53 | ```
54 |
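In eval mode, the forward pass returns the raw logits, the pooled feature matrix, and a single averaged attention map of shape ```(B, 1, H, W)```; ```eval.py``` refines the raw prediction by re-classifying an attention-cropped copy of the input. A minimal sketch (random input is used purely to illustrate shapes, and ```pretrained=False``` avoids a weight download):

```python
import torch
from models import WSDAN
from utils import batch_augment

net = WSDAN(num_classes=200, M=32, net='inception_mixed_6e', pretrained=False)
net.eval()

with torch.no_grad():
    X = torch.randn(2, 3, 448, 448)                       # stand-in for a normalized batch
    y_pred_raw, feature_matrix, attention_maps = net(X)   # attention_maps: (B, 1, H, W)
    crop_images = batch_augment(X, attention_maps, mode='crop', theta=0.1, padding_ratio=0.05)
    y_pred_crop, _, _ = net(crop_images)
    y_pred = (y_pred_raw + y_pred_crop) / 2.              # refined prediction, as in eval.py
```
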
55 | ### Dataset Directory
56 |
57 | * [FGVC-Aircraft](http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/) (Aircraft)
58 |
59 | ```
60 | -/FGVC-Aircraft/data/
61 | └─── images
62 | └─── 0034309.jpg
63 | └─── 0034958.jpg
64 | └─── ...
65 | └─── variants.txt
66 | └─── images_variant_trainval.txt
67 | └─── images_variant_test.txt
68 | ```
69 |
70 | * [CUB-200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) (Bird)
71 |
72 | ```
73 | -/CUB-200-2011
74 | └─── images.txt
75 | └─── image_class_labels.txt
76 | └─── train_test_split.txt
77 | └─── images
78 | └─── 001.Black_footed_Albatross
79 | └─── Black_Footed_Albatross_0001_796111.jpg
80 | └─── ...
81 | └─── 002.Laysan_Albatross
82 | └─── ...
83 | ```
84 |
85 | * [Stanford Cars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html) (Car)
86 |
87 | ```
88 | -/StanfordCars
89 | └─── cars_test
90 | └─── 00001.jpg
91 | └─── 00002.jpg
92 | └─── ...
93 | └─── cars_train
94 | └─── 00001.jpg
95 | └─── 00002.jpg
96 | └─── ...
97 | └─── devkit
98 | └─── cars_train_annos.mat
99 | └─── cars_test_annos_withlabels.mat
100 | ```
101 |
102 | * [Stanford Dogs](http://vision.stanford.edu/aditya86/ImageNetDogs/) (Dog)
103 |
104 | ```
105 | -/StanfordDogs
106 | └─── Images
107 | └─── n02085620-Chihuahua
108 | └─── n02085620_10074.jpg
109 | └─── ...
110 | └─── n02085782-Japanese_spaniel
111 | └─── ...
112 | └─── train_list.mat
113 | └─── test_list.mat
114 | ```
115 |
116 |
117 | ### Run
118 |
119 | 1. ```git clone``` this repo.
120 |
121 | 2. Prepare data and **modify DATAPATH** in each ```datasets/*_dataset.py``` (see the example below this list).
122 |
123 | 3. **Set configurations** in ```config.py``` (Training Config, Model Config, Dataset/Path Config):
124 |
125 | ```python
126 | tag = 'aircraft' # 'aircraft', 'bird', 'car', or 'dog'
127 | ```
128 |
129 | 4. ```$ nohup python3 train.py > progress.bar &``` for training.
130 |
131 | 5. ```$ tail -f progress.bar``` to monitor training progress (requires the tqdm package; other logs are written to ```train.log``` under ```save_dir```).
132 |
133 | 6. Set configurations in ```config.py``` (Eval Config) and run ```$ python3 eval.py``` for evaluation and visualization.
134 |
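Each dataset module defines a ```DATAPATH``` constant at the top of the file; point it at your local dataset root, for example:

```python
# datasets/bird_dataset.py (dog_dataset.py, car_dataset.py and aircraft_dataset.py are analogous)
DATAPATH = '/home/guyuchong/DATA/FGVC/CUB-200-2011'   # replace with your own path
```
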
135 | ### Attention Maps Visualization
136 |
137 | The code in ```eval.py``` generates attention-map visualizations: the raw image, the heat attention map, and the image × attention map.
138 |
139 |
144 |
145 |
146 |
147 |
--------------------------------------------------------------------------------
/models/wsdan.py:
--------------------------------------------------------------------------------
1 | """
2 | WS-DAN models
3 |
4 | Hu et al.,
5 | "See Better Before Looking Closer: Weakly Supervised Data Augmentation Network for Fine-Grained Visual Classification",
6 | arXiv:1901.09891
7 |
8 | Created: May 04,2019 - Yuchong Gu
9 | Revised: Dec 03,2019 - Yuchong Gu
10 | """
11 | import logging
12 | import numpy as np
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 |
17 | import models.vgg as vgg
18 | import models.resnet as resnet
19 | from models.inception import inception_v3, BasicConv2d
20 |
21 | __all__ = ['WSDAN']
22 | EPSILON = 1e-12
23 |
24 |
25 | # Bilinear Attention Pooling
26 | class BAP(nn.Module):
27 | def __init__(self, pool='GAP'):
28 | super(BAP, self).__init__()
29 | assert pool in ['GAP', 'GMP']
30 | if pool == 'GAP':
31 | self.pool = None
32 | else:
33 | self.pool = nn.AdaptiveMaxPool2d(1)
34 |
35 | def forward(self, features, attentions):
36 | B, C, H, W = features.size()
37 | _, M, AH, AW = attentions.size()
38 |
39 | # match size
40 | if AH != H or AW != W:
41 | attentions = F.upsample_bilinear(attentions, size=(H, W))
42 |
43 | # feature_matrix: (B, M, C) -> (B, M * C)
44 | if self.pool is None:
45 | feature_matrix = (torch.einsum('imjk,injk->imn', (attentions, features)) / float(H * W)).view(B, -1)
46 | else:
47 | feature_matrix = []
48 | for i in range(M):
49 | AiF = self.pool(features * attentions[:, i:i + 1, ...]).view(B, -1)
50 | feature_matrix.append(AiF)
51 | feature_matrix = torch.cat(feature_matrix, dim=1)
52 |
53 | # sign-sqrt
54 | feature_matrix = torch.sign(feature_matrix) * torch.sqrt(torch.abs(feature_matrix) + EPSILON)
55 |
56 | # l2 normalization along dimension M and C
57 | feature_matrix = F.normalize(feature_matrix, dim=-1)
58 | return feature_matrix
59 |
60 |
61 | # WS-DAN: Weakly Supervised Data Augmentation Network for FGVC
62 | class WSDAN(nn.Module):
63 | def __init__(self, num_classes, M=32, net='inception_mixed_6e', pretrained=False):
64 | super(WSDAN, self).__init__()
65 | self.num_classes = num_classes
66 | self.M = M
67 | self.net = net
68 |
69 | # Network Initialization
70 | if 'inception' in net:
71 | if net == 'inception_mixed_6e':
72 | self.features = inception_v3(pretrained=pretrained).get_features_mixed_6e()
73 | self.num_features = 768
74 | elif net == 'inception_mixed_7c':
75 | self.features = inception_v3(pretrained=pretrained).get_features_mixed_7c()
76 | self.num_features = 2048
77 | else:
78 | raise ValueError('Unsupported net: %s' % net)
79 | elif 'vgg' in net:
80 | self.features = getattr(vgg, net)(pretrained=pretrained).get_features()
81 | self.num_features = 512
82 | elif 'resnet' in net:
83 | self.features = getattr(resnet, net)(pretrained=pretrained).get_features()
84 | self.num_features = 512 * self.features[-1][-1].expansion
85 | else:
86 | raise ValueError('Unsupported net: %s' % net)
87 |
88 | # Attention Maps
89 | self.attentions = BasicConv2d(self.num_features, self.M, kernel_size=1)
90 |
91 | # Bilinear Attention Pooling
92 | self.bap = BAP(pool='GAP')
93 |
94 | # Classification Layer
95 | self.fc = nn.Linear(self.M * self.num_features, self.num_classes, bias=False)
96 |
97 | logging.info('WSDAN: using {} as feature extractor, num_classes: {}, num_attentions: {}'.format(net, self.num_classes, self.M))
98 |
99 | def forward(self, x):
100 | batch_size = x.size(0)
101 |
102 | # Feature Maps, Attention Maps and Feature Matrix
103 | feature_maps = self.features(x)
104 | if self.net != 'inception_mixed_7c':
105 | attention_maps = self.attentions(feature_maps)
106 | else:
107 | attention_maps = feature_maps[:, :self.M, ...]
108 | feature_matrix = self.bap(feature_maps, attention_maps)
109 |
110 | # Classification
111 | p = self.fc(feature_matrix * 100.)
112 |
113 | # Generate Attention Map
114 | if self.training:
115 | # Randomly choose one of attention maps Ak
116 | attention_map = []
117 | for i in range(batch_size):
118 | attention_weights = torch.sqrt(attention_maps[i].sum(dim=(1, 2)).detach() + EPSILON)
119 | attention_weights = F.normalize(attention_weights, p=1, dim=0)
120 | k_index = np.random.choice(self.M, 2, p=attention_weights.cpu().numpy())
121 | attention_map.append(attention_maps[i, k_index, ...])
122 | attention_map = torch.stack(attention_map) # (B, 2, H, W) - one for cropping, the other for dropping
123 | else:
124 | # Object Localization Am = mean(Ak)
125 | attention_map = torch.mean(attention_maps, dim=1, keepdim=True) # (B, 1, H, W)
126 |
127 | # p: (B, self.num_classes)
128 | # feature_matrix: (B, M * C)
129 | # attention_map: (B, 2, H, W) in training, (B, 1, H, W) in val/testing
130 | return p, feature_matrix, attention_map
131 |
132 | def load_state_dict(self, state_dict, strict=True):
133 | model_dict = self.state_dict()
134 | pretrained_dict = {k: v for k, v in state_dict.items()
135 | if k in model_dict and model_dict[k].size() == v.size()}
136 |
137 | if len(pretrained_dict) == len(state_dict):
138 | logging.info('%s: All params loaded' % type(self).__name__)
139 | else:
140 | logging.info('%s: Some params were not loaded:' % type(self).__name__)
141 | not_loaded_keys = [k for k in state_dict.keys() if k not in pretrained_dict.keys()]
142 | logging.info(('%s, ' * (len(not_loaded_keys) - 1) + '%s') % tuple(not_loaded_keys))
143 |
144 | model_dict.update(pretrained_dict)
145 | super(WSDAN, self).load_state_dict(model_dict)
146 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """Utils
2 | Created: Nov 11,2019 - Yuchong Gu
3 | Revised: Dec 03,2019 - Yuchong Gu
4 | """
5 | import torch
6 | import random
7 | import numpy as np
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torchvision.transforms as transforms
11 |
12 |
13 | ##############################################
14 | # Center Loss for Attention Regularization
15 | ##############################################
16 | class CenterLoss(nn.Module):
17 | def __init__(self):
18 | super(CenterLoss, self).__init__()
19 | self.l2_loss = nn.MSELoss(reduction='sum')
20 |
21 | def forward(self, outputs, targets):
22 | return self.l2_loss(outputs, targets) / outputs.size(0)
23 |
24 |
25 | ##################################
26 | # Metric
27 | ##################################
28 | class Metric(object):
29 | pass
30 |
31 |
32 | class AverageMeter(Metric):
33 | def __init__(self, name='loss'):
34 | self.name = name
35 | self.reset()
36 |
37 | def reset(self):
38 | self.scores = 0.
39 | self.total_num = 0.
40 |
41 | def __call__(self, batch_score, sample_num=1):
42 | self.scores += batch_score
43 | self.total_num += sample_num
44 | return self.scores / self.total_num
45 |
46 |
47 | class TopKAccuracyMetric(Metric):
48 | def __init__(self, topk=(1,)):
49 | self.name = 'topk_accuracy'
50 | self.topk = topk
51 | self.maxk = max(topk)
52 | self.reset()
53 |
54 | def reset(self):
55 | self.corrects = np.zeros(len(self.topk))
56 | self.num_samples = 0.
57 |
58 | def __call__(self, output, target):
59 | """Computes the precision@k for the specified values of k"""
60 | self.num_samples += target.size(0)
61 | _, pred = output.topk(self.maxk, 1, True, True)
62 | pred = pred.t()
63 | correct = pred.eq(target.view(1, -1).expand_as(pred))
64 |
65 | for i, k in enumerate(self.topk):
66 | correct_k = correct[:k].view(-1).float().sum(0)
67 | self.corrects[i] += correct_k.item()
68 |
69 | return self.corrects * 100. / self.num_samples
70 |
71 |
72 | ##################################
73 | # Callback
74 | ##################################
75 | class Callback(object):
76 | def __init__(self):
77 | pass
78 |
79 | def on_epoch_begin(self):
80 | pass
81 |
82 | def on_epoch_end(self, *args):
83 | pass
84 |
85 |
86 | class ModelCheckpoint(Callback):
87 | def __init__(self, savepath, monitor='val_topk_accuracy', mode='max'):
88 | self.savepath = savepath
89 | self.monitor = monitor
90 | self.mode = mode
91 | self.reset()
92 | super(ModelCheckpoint, self).__init__()
93 |
94 | def reset(self):
95 | if self.mode == 'max':
96 | self.best_score = float('-inf')
97 | else:
98 | self.best_score = float('inf')
99 |
100 | def set_best_score(self, score):
101 | if isinstance(score, np.ndarray):
102 | self.best_score = score[0]
103 | else:
104 | self.best_score = score
105 |
106 | def on_epoch_begin(self):
107 | pass
108 |
109 | def on_epoch_end(self, logs, net, **kwargs):
110 | current_score = logs[self.monitor]
111 | if isinstance(current_score, np.ndarray):
112 | current_score = current_score[0]
113 |
114 | if (self.mode == 'max' and current_score > self.best_score) or \
115 | (self.mode == 'min' and current_score < self.best_score):
116 | self.best_score = current_score
117 |
118 | if isinstance(net, torch.nn.DataParallel):
119 | state_dict = net.module.state_dict()
120 | else:
121 | state_dict = net.state_dict()
122 |
123 | for key in state_dict.keys():
124 | state_dict[key] = state_dict[key].cpu()
125 |
126 | if 'feature_center' in kwargs:
127 | feature_center = kwargs['feature_center']
128 | feature_center = feature_center.cpu()
129 |
130 | torch.save({
131 | 'logs': logs,
132 | 'state_dict': state_dict,
133 | 'feature_center': feature_center}, self.savepath)
134 | else:
135 | torch.save({
136 | 'logs': logs,
137 | 'state_dict': state_dict}, self.savepath)
138 |
139 |
140 | ##################################
141 | # augment function
142 | ##################################
143 | def batch_augment(images, attention_map, mode='crop', theta=0.5, padding_ratio=0.1):
144 | batches, _, imgH, imgW = images.size()
145 |
146 | if mode == 'crop':
147 | crop_images = []
148 | for batch_index in range(batches):
149 | atten_map = attention_map[batch_index:batch_index + 1]
150 | if isinstance(theta, tuple):
151 | theta_c = random.uniform(*theta) * atten_map.max()
152 | else:
153 | theta_c = theta * atten_map.max()
154 |
155 | crop_mask = F.upsample_bilinear(atten_map, size=(imgH, imgW)) >= theta_c
156 | nonzero_indices = torch.nonzero(crop_mask[0, 0, ...])
157 | height_min = max(int(nonzero_indices[:, 0].min().item() - padding_ratio * imgH), 0)
158 | height_max = min(int(nonzero_indices[:, 0].max().item() + padding_ratio * imgH), imgH)
159 | width_min = max(int(nonzero_indices[:, 1].min().item() - padding_ratio * imgW), 0)
160 | width_max = min(int(nonzero_indices[:, 1].max().item() + padding_ratio * imgW), imgW)
161 |
162 | crop_images.append(
163 | F.upsample_bilinear(images[batch_index:batch_index + 1, :, height_min:height_max, width_min:width_max],
164 | size=(imgH, imgW)))
165 | crop_images = torch.cat(crop_images, dim=0)
166 | return crop_images
167 |
168 | elif mode == 'drop':
169 | drop_masks = []
170 | for batch_index in range(batches):
171 | atten_map = attention_map[batch_index:batch_index + 1]
172 | if isinstance(theta, tuple):
173 | theta_d = random.uniform(*theta) * atten_map.max()
174 | else:
175 | theta_d = theta * atten_map.max()
176 |
177 | drop_masks.append(F.upsample_bilinear(atten_map, size=(imgH, imgW)) < theta_d)
178 | drop_masks = torch.cat(drop_masks, dim=0)
179 | drop_images = images * drop_masks.float()
180 | return drop_images
181 |
182 | else:
183 | raise ValueError('Expected mode in [\'crop\', \'drop\'], but received unsupported augmentation method %s' % mode)
184 |
185 |
186 | ##################################
187 | # transform in dataset
188 | ##################################
189 | def get_transform(resize, phase='train'):
190 |     # resize may be given as an int or an (H, W) tuple
191 |     if isinstance(resize, int):
192 |         resize = (resize, resize)
193 | 
194 |     if phase == 'train':
195 |         return transforms.Compose([
196 |             transforms.Resize(size=(int(resize[0] / 0.875), int(resize[1] / 0.875))),
197 |             transforms.RandomCrop(resize),
198 |             transforms.RandomHorizontalFlip(0.5),
199 |             transforms.ColorJitter(brightness=0.126, saturation=0.5),
200 |             transforms.ToTensor(),
201 |             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
202 |         ])
203 |     else:
204 |         return transforms.Compose([
205 |             transforms.Resize(size=(int(resize[0] / 0.875), int(resize[1] / 0.875))),
206 |             transforms.CenterCrop(resize),
207 |             transforms.ToTensor(),
208 |             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
209 |         ])
210 | 
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.utils.model_zoo as model_zoo
5 | from models.blocks import CBAMLayer, SPPLayer
6 | import logging
7 |
8 | __all__ = ['resnet34', 'resnet50', 'resnet101', 'resnet152',
9 | 'resnet34_cbam', 'resnet50_cbam', 'resnet101_cbam', 'resnet152_cbam']
10 |
11 |
12 | model_urls = {
13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
18 | }
19 |
20 |
21 | class BasicBlock(nn.Module):
22 | expansion = 1
23 |
24 | def __init__(self, inplanes, planes, stride=1, cbam=None, downsample=None):
25 | super(BasicBlock, self).__init__()
26 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, padding=1, stride=stride)
27 | self.bn1 = nn.BatchNorm2d(planes)
28 | self.relu = nn.ReLU(inplace=True)
29 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1)
30 | self.bn2 = nn.BatchNorm2d(planes)
31 | self.downsample = downsample
32 | self.stride = stride
33 |
34 | if cbam is not None:
35 | self.cbam = CBAMLayer(planes)
36 | else:
37 | self.cbam = None
38 |
39 | def forward(self, x):
40 | residual = x
41 |
42 | out = self.conv1(x)
43 | out = self.bn1(out)
44 | out = self.relu(out)
45 |
46 | out = self.conv2(out)
47 | out = self.bn2(out)
48 |
49 | if self.cbam is not None:
50 | out = self.cbam(out)
51 |
52 | if self.downsample is not None:
53 | residual = self.downsample(x)
54 |
55 | out += residual
56 | out = self.relu(out)
57 |
58 | return out
59 |
60 |
61 | class Bottleneck(nn.Module):
62 | expansion = 4
63 |
64 | def __init__(self, inplanes, planes, stride=1, cbam=None, downsample=None):
65 | super(Bottleneck, self).__init__()
66 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
67 | self.bn1 = nn.BatchNorm2d(planes)
68 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
69 | self.bn2 = nn.BatchNorm2d(planes)
70 | self.conv3 = nn.Conv2d(planes, planes * Bottleneck.expansion, kernel_size=1, bias=False)
71 | self.bn3 = nn.BatchNorm2d(planes * Bottleneck.expansion)
72 | self.relu = nn.ReLU(inplace=True)
73 | self.downsample = downsample
74 | self.stride = stride
75 |
76 | if cbam is not None:
77 | self.cbam = CBAMLayer(planes * Bottleneck.expansion)
78 | else:
79 | self.cbam = None
80 |
81 | def forward(self, x):
82 | residual = x
83 |
84 | out = self.conv1(x)
85 | out = self.bn1(out)
86 | out = self.relu(out)
87 |
88 | out = self.conv2(out)
89 | out = self.bn2(out)
90 | out = self.relu(out)
91 |
92 | out = self.conv3(out)
93 | out = self.bn3(out)
94 |
95 | if self.cbam is not None:
96 | out = self.cbam(out)
97 |
98 | if self.downsample is not None:
99 | residual = self.downsample(x)
100 |
101 | out += residual
102 | out = self.relu(out)
103 |
104 | return out
105 |
106 |
107 | class ResNet(nn.Module):
108 | def __init__(self, block, layers, cbam=None, num_classes=1000):
109 | self.inplanes = 64
110 | super(ResNet, self).__init__()
111 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
112 | self.bn1 = nn.BatchNorm2d(64)
113 | self.relu = nn.ReLU(inplace=True)
114 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
115 | self.layer1 = self._make_layer(block, 64, layers[0], cbam)
116 | self.layer2 = self._make_layer(block, 128, layers[1], cbam, stride=2)
117 | self.layer3 = self._make_layer(block, 256, layers[2], cbam, stride=2)
118 | self.layer4 = self._make_layer(block, 512, layers[3], cbam, stride=2)
119 |
120 | self.avgpool = nn.AdaptiveAvgPool2d(1)
121 | self.fc = nn.Linear(512 * block.expansion, num_classes)
122 |
123 | # self.spp = SPPLayer(pool_size=[1, 2, 4], pool=nn.MaxPool2d)
124 | # self.fc = nn.Linear(512 * block.expansion * self.spp.out_length, num_classes)
125 |
126 | for m in self.modules():
127 | if isinstance(m, nn.Conv2d):
128 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
129 | m.weight.data.normal_(0, math.sqrt(2. / n))
130 | elif isinstance(m, nn.BatchNorm2d):
131 | m.weight.data.fill_(1)
132 | m.bias.data.zero_()
133 |
134 | def _make_layer(self, block, planes, blocks, cbam=None, stride=1):
135 | downsample = None
136 | if stride != 1 or self.inplanes != planes * block.expansion:
137 | downsample = nn.Sequential(
138 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
139 | nn.BatchNorm2d(planes * block.expansion))
140 |
141 | layers = []
142 | layers.append(block(self.inplanes, planes, stride=stride, cbam=cbam, downsample=downsample))
143 | self.inplanes = planes * block.expansion
144 | for i in range(1, blocks):
145 | layers.append(block(self.inplanes, planes, cbam=cbam))
146 |
147 | return nn.Sequential(*layers)
148 |
149 | def forward(self, x):
150 | x = self.conv1(x)
151 | x = self.bn1(x)
152 | x = self.relu(x)
153 | x = self.maxpool(x)
154 |
155 | x = self.layer1(x)
156 | x = self.layer2(x)
157 | x = self.layer3(x)
158 | x = self.layer4(x)
159 |
160 | x = self.avgpool(x)
161 | x = x.view(x.size(0), -1)
162 | # x = self.spp(x)
163 | x = self.fc(x)
164 |
165 | return x
166 |
167 | def get_features(self):
168 | return nn.Sequential(
169 | self.conv1,
170 | self.bn1,
171 | self.relu,
172 | self.maxpool,
173 | self.layer1,
174 | self.layer2,
175 | self.layer3,
176 | self.layer4,
177 | )
178 |
179 | def load_state_dict(self, state_dict, strict=True):
180 | model_dict = self.state_dict()
181 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
182 | if len(pretrained_dict) == len(state_dict):
183 | logging.info('%s: All params loaded' % type(self).__name__)
184 | else:
185 | logging.info('%s: Some params were not loaded:' % type(self).__name__)
186 | not_loaded_keys = [k for k in state_dict.keys() if k not in pretrained_dict.keys()]
187 | logging.info(('%s, ' * (len(not_loaded_keys) - 1) + '%s') % tuple(not_loaded_keys))
188 | model_dict.update(pretrained_dict)
189 | super(ResNet, self).load_state_dict(model_dict)
190 |
191 |
192 | def resnet34(pretrained=False, num_classes=1000):
193 | model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)
194 | if pretrained:
195 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
196 | return model
197 |
198 |
199 | def resnet50(pretrained=False, num_classes=1000):
200 | model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes)
201 | if pretrained:
202 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
203 | return model
204 |
205 |
206 | def resnet101(pretrained=False, num_classes=1000):
207 | model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes)
208 | if pretrained:
209 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
210 | return model
211 |
212 |
213 | def resnet152(pretrained=False, num_classes=1000):
214 | model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes)
215 | if pretrained:
216 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
217 | return model
218 |
219 |
220 | def resnet34_cbam(pretrained=False, num_classes=1000):
221 | model = ResNet(BasicBlock, [3, 4, 6, 3], cbam=True, num_classes=num_classes)
222 | if pretrained:
223 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
224 | return model
225 |
226 |
227 | def resnet50_cbam(pretrained=False, num_classes=1000):
228 | model = ResNet(Bottleneck, [3, 4, 6, 3], cbam=True, num_classes=num_classes)
229 | if pretrained:
230 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
231 | return model
232 |
233 |
234 | def resnet101_cbam(pretrained=False, num_classes=1000):
235 | model = ResNet(Bottleneck, [3, 4, 23, 3], cbam=True, num_classes=num_classes)
236 | if pretrained:
237 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
238 | return model
239 |
240 |
241 | def resnet152_cbam(pretrained=False, num_classes=1000):
242 | model = ResNet(Bottleneck, [3, 8, 36, 3], cbam=True, num_classes=num_classes)
243 | if pretrained:
244 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
245 | return model
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """TRAINING
2 | Created: May 04,2019 - Yuchong Gu
3 | Revised: Dec 03,2019 - Yuchong Gu
4 | """
5 | import os
6 | import time
7 | import logging
8 | import warnings
9 | from tqdm import tqdm
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from torch.utils.data import DataLoader
14 |
15 | import config
16 | from models import WSDAN
17 | from datasets import get_trainval_datasets
18 | from utils import CenterLoss, AverageMeter, TopKAccuracyMetric, ModelCheckpoint, batch_augment
19 |
20 | # GPU settings
21 | assert torch.cuda.is_available()
22 | os.environ['CUDA_VISIBLE_DEVICES'] = config.GPU
23 | device = torch.device("cuda")
24 | torch.backends.cudnn.benchmark = True
25 |
26 | # General loss functions
27 | cross_entropy_loss = nn.CrossEntropyLoss()
28 | center_loss = CenterLoss()
29 |
30 | # loss and metric
31 | loss_container = AverageMeter(name='loss')
32 | raw_metric = TopKAccuracyMetric(topk=(1, 5))
33 | crop_metric = TopKAccuracyMetric(topk=(1, 5))
34 | drop_metric = TopKAccuracyMetric(topk=(1, 5))
35 |
36 |
37 | def main():
38 | ##################################
39 | # Initialize saving directory
40 | ##################################
41 | if not os.path.exists(config.save_dir):
42 | os.makedirs(config.save_dir)
43 |
44 | ##################################
45 | # Logging setting
46 | ##################################
47 | logging.basicConfig(
48 | filename=os.path.join(config.save_dir, config.log_name),
49 | filemode='w',
50 | format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
51 | level=logging.INFO)
52 | warnings.filterwarnings("ignore")
53 |
54 | ##################################
55 | # Load dataset
56 | ##################################
57 | train_dataset, validate_dataset = get_trainval_datasets(config.tag, config.image_size)
58 |
59 | train_loader, validate_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True,
60 | num_workers=config.workers, pin_memory=True), \
61 | DataLoader(validate_dataset, batch_size=config.batch_size * 4, shuffle=False,
62 | num_workers=config.workers, pin_memory=True)
63 | num_classes = train_dataset.num_classes
64 |
65 | ##################################
66 | # Initialize model
67 | ##################################
68 | logs = {}
69 | start_epoch = 0
70 | net = WSDAN(num_classes=num_classes, M=config.num_attentions, net=config.net, pretrained=True)
71 |
72 | # feature_center: size of (#classes, #attention_maps * #channel_features)
73 | feature_center = torch.zeros(num_classes, config.num_attentions * net.num_features).to(device)
74 |
75 | if config.ckpt:
76 | # Load ckpt and get state_dict
77 | checkpoint = torch.load(config.ckpt)
78 |
79 | # Get epoch and some logs
80 | logs = checkpoint['logs']
81 | start_epoch = int(logs['epoch'])
82 |
83 | # Load weights
84 | state_dict = checkpoint['state_dict']
85 | net.load_state_dict(state_dict)
86 | logging.info('Network loaded from {}'.format(config.ckpt))
87 |
88 | # load feature center
89 | if 'feature_center' in checkpoint:
90 | feature_center = checkpoint['feature_center'].to(device)
91 | logging.info('feature_center loaded from {}'.format(config.ckpt))
92 |
 93 |     logging.info('Network weights will be saved to {}'.format(config.save_dir))
94 |
95 | ##################################
96 | # Use cuda
97 | ##################################
98 | net.to(device)
99 | if torch.cuda.device_count() > 1:
100 | net = nn.DataParallel(net)
101 |
102 | ##################################
103 | # Optimizer, LR Scheduler
104 | ##################################
105 | learning_rate = logs['lr'] if 'lr' in logs else config.learning_rate
106 | optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5)
107 |
108 | # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=2)
109 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)
110 |
111 | ##################################
112 | # ModelCheckpoint
113 | ##################################
114 | callback_monitor = 'val_{}'.format(raw_metric.name)
115 | callback = ModelCheckpoint(savepath=os.path.join(config.save_dir, config.model_name),
116 | monitor=callback_monitor,
117 | mode='max')
118 | if callback_monitor in logs:
119 | callback.set_best_score(logs[callback_monitor])
120 | else:
121 | callback.reset()
122 |
123 | ##################################
124 | # TRAINING
125 | ##################################
126 | logging.info('Start training: Total epochs: {}, Batch size: {}, Training size: {}, Validation size: {}'.
127 | format(config.epochs, config.batch_size, len(train_dataset), len(validate_dataset)))
128 | logging.info('')
129 |
130 | for epoch in range(start_epoch, config.epochs):
131 | callback.on_epoch_begin()
132 |
133 | logs['epoch'] = epoch + 1
134 | logs['lr'] = optimizer.param_groups[0]['lr']
135 |
136 | logging.info('Epoch {:03d}, Learning Rate {:g}'.format(epoch + 1, optimizer.param_groups[0]['lr']))
137 |
138 | pbar = tqdm(total=len(train_loader), unit=' batches')
139 | pbar.set_description('Epoch {}/{}'.format(epoch + 1, config.epochs))
140 |
141 | train(logs=logs,
142 | data_loader=train_loader,
143 | net=net,
144 | feature_center=feature_center,
145 | optimizer=optimizer,
146 | pbar=pbar)
147 | validate(logs=logs,
148 | data_loader=validate_loader,
149 | net=net,
150 | pbar=pbar)
151 |
152 | if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
153 | scheduler.step(logs['val_loss'])
154 | else:
155 | scheduler.step()
156 |
157 | callback.on_epoch_end(logs, net, feature_center=feature_center)
158 | pbar.close()
159 |
160 |
161 | def train(**kwargs):
162 | # Retrieve training configuration
163 | logs = kwargs['logs']
164 | data_loader = kwargs['data_loader']
165 | net = kwargs['net']
166 | feature_center = kwargs['feature_center']
167 | optimizer = kwargs['optimizer']
168 | pbar = kwargs['pbar']
169 |
170 | # metrics initialization
171 | loss_container.reset()
172 | raw_metric.reset()
173 | crop_metric.reset()
174 | drop_metric.reset()
175 |
176 | # begin training
177 | start_time = time.time()
178 | net.train()
179 | for i, (X, y) in enumerate(data_loader):
180 | optimizer.zero_grad()
181 |
182 | # obtain data for training
183 | X = X.to(device)
184 | y = y.to(device)
185 |
186 | ##################################
187 | # Raw Image
188 | ##################################
189 | # raw images forward
190 | y_pred_raw, feature_matrix, attention_map = net(X)
191 |
192 | # Update Feature Center
193 | feature_center_batch = F.normalize(feature_center[y], dim=-1)
194 | feature_center[y] += config.beta * (feature_matrix.detach() - feature_center_batch)
195 |
196 | ##################################
197 | # Attention Cropping
198 | ##################################
199 | with torch.no_grad():
200 | crop_images = batch_augment(X, attention_map[:, :1, :, :], mode='crop', theta=(0.4, 0.6), padding_ratio=0.1)
201 |
202 | # crop images forward
203 | y_pred_crop, _, _ = net(crop_images)
204 |
205 | ##################################
206 | # Attention Dropping
207 | ##################################
208 | with torch.no_grad():
209 | drop_images = batch_augment(X, attention_map[:, 1:, :, :], mode='drop', theta=(0.2, 0.5))
210 |
211 | # drop images forward
212 | y_pred_drop, _, _ = net(drop_images)
213 |
214 | # loss
215 | batch_loss = cross_entropy_loss(y_pred_raw, y) / 3. + \
216 | cross_entropy_loss(y_pred_crop, y) / 3. + \
217 | cross_entropy_loss(y_pred_drop, y) / 3. + \
218 | center_loss(feature_matrix, feature_center_batch)
219 |
220 | # backward
221 | batch_loss.backward()
222 | optimizer.step()
223 |
224 |         # metrics: loss and top-1 / top-5 accuracy
225 | with torch.no_grad():
226 | epoch_loss = loss_container(batch_loss.item())
227 | epoch_raw_acc = raw_metric(y_pred_raw, y)
228 | epoch_crop_acc = crop_metric(y_pred_crop, y)
229 | epoch_drop_acc = drop_metric(y_pred_drop, y)
230 |
231 | # end of this batch
232 | batch_info = 'Loss {:.4f}, Raw Acc ({:.2f}, {:.2f}), Crop Acc ({:.2f}, {:.2f}), Drop Acc ({:.2f}, {:.2f})'.format(
233 | epoch_loss, epoch_raw_acc[0], epoch_raw_acc[1],
234 | epoch_crop_acc[0], epoch_crop_acc[1], epoch_drop_acc[0], epoch_drop_acc[1])
235 | pbar.update()
236 | pbar.set_postfix_str(batch_info)
237 |
238 | # end of this epoch
239 | logs['train_{}'.format(loss_container.name)] = epoch_loss
240 | logs['train_raw_{}'.format(raw_metric.name)] = epoch_raw_acc
241 | logs['train_crop_{}'.format(crop_metric.name)] = epoch_crop_acc
242 | logs['train_drop_{}'.format(drop_metric.name)] = epoch_drop_acc
243 | logs['train_info'] = batch_info
244 | end_time = time.time()
245 |
246 | # write log for this epoch
247 | logging.info('Train: {}, Time {:3.2f}'.format(batch_info, end_time - start_time))
248 |
249 |
250 | def validate(**kwargs):
251 |     # Retrieve validation configuration
252 | logs = kwargs['logs']
253 | data_loader = kwargs['data_loader']
254 | net = kwargs['net']
255 | pbar = kwargs['pbar']
256 |
257 | # metrics initialization
258 | loss_container.reset()
259 | raw_metric.reset()
260 |
261 | # begin validation
262 | start_time = time.time()
263 | net.eval()
264 | with torch.no_grad():
265 | for i, (X, y) in enumerate(data_loader):
266 | # obtain data
267 | X = X.to(device)
268 | y = y.to(device)
269 |
270 | ##################################
271 | # Raw Image
272 | ##################################
273 | y_pred_raw, _, attention_map = net(X)
274 |
275 | ##################################
276 | # Object Localization and Refinement
277 | ##################################
278 | crop_images = batch_augment(X, attention_map, mode='crop', theta=0.1, padding_ratio=0.05)
279 | y_pred_crop, _, _ = net(crop_images)
280 |
281 | ##################################
282 | # Final prediction
283 | ##################################
284 | y_pred = (y_pred_raw + y_pred_crop) / 2.
285 |
286 | # loss
287 | batch_loss = cross_entropy_loss(y_pred, y)
288 | epoch_loss = loss_container(batch_loss.item())
289 |
290 |             # metrics: top-1 / top-5 accuracy
291 | epoch_acc = raw_metric(y_pred, y)
292 |
293 | # end of validation
294 | logs['val_{}'.format(loss_container.name)] = epoch_loss
295 | logs['val_{}'.format(raw_metric.name)] = epoch_acc
296 | end_time = time.time()
297 |
298 | batch_info = 'Val Loss {:.4f}, Val Acc ({:.2f}, {:.2f})'.format(epoch_loss, epoch_acc[0], epoch_acc[1])
299 | pbar.set_postfix_str('{}, {}'.format(logs['train_info'], batch_info))
300 |
301 | # write log for this epoch
302 | logging.info('Valid: {}, Time {:3.2f}'.format(batch_info, end_time - start_time))
303 | logging.info('')
304 |
305 |
306 | if __name__ == '__main__':
307 | main()
308 |
--------------------------------------------------------------------------------
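
Note: when config.ckpt is set, train.py expects the checkpoint dict written by utils.ModelCheckpoint to carry 'logs', 'state_dict', and optionally 'feature_center'. Below is a minimal sketch of reading such a checkpoint back outside the training loop; the path, class count, M, and backbone name are placeholders and must match whatever was used for training, and whether the state_dict keys carry a 'module.' prefix depends on how ModelCheckpoint saves a DataParallel-wrapped net.

    import torch
    from models import WSDAN

    ckpt_path = 'FGVC/bird/model.ckpt'   # placeholder for save_dir/model_name
    checkpoint = torch.load(ckpt_path, map_location='cpu')

    logs = checkpoint['logs']                           # epoch, lr, train/val metrics
    feature_center = checkpoint.get('feature_center')   # (num_classes, M * num_features) or None

    # M and net must match the training configuration (placeholders below).
    net = WSDAN(num_classes=200, M=32, net='inception_mixed_6e', pretrained=False)
    net.load_state_dict(checkpoint['state_dict'])
    net.eval()

    print('checkpoint from epoch {}, lr {}'.format(logs['epoch'], logs['lr']))
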
/models/inception.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.utils.model_zoo as model_zoo
6 |
7 |
8 | __all__ = ['Inception3', 'inception_v3']
9 |
10 |
11 | model_urls = {
12 | # Inception v3 ported from TensorFlow
13 | 'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
14 | }
15 |
16 |
17 | def inception_v3(pretrained=False, **kwargs):
18 | r"""Inception v3 model architecture from
19 | `"Rethinking the Inception Architecture for Computer Vision" `_.
20 |
21 | Args:
22 | pretrained (bool): If True, returns a model pre-trained on ImageNet
23 | """
24 | if pretrained:
25 | if 'transform_input' not in kwargs:
26 | kwargs['transform_input'] = True
27 | model = Inception3(**kwargs)
28 | model.load_state_dict(model_zoo.load_url(model_urls['inception_v3_google']))
29 | return model
30 |
31 | return Inception3(**kwargs)
32 |
33 |
34 | class Inception3(nn.Module):
35 |
36 | def __init__(self, num_classes=1000, aux_logits=True, transform_input=False):
37 | super(Inception3, self).__init__()
38 | self.aux_logits = aux_logits
39 | self.transform_input = transform_input
40 | self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
41 | self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
42 | self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
43 | self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
44 | self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
45 | self.Mixed_5b = InceptionA(192, pool_features=32)
46 | self.Mixed_5c = InceptionA(256, pool_features=64)
47 | self.Mixed_5d = InceptionA(288, pool_features=64)
48 | self.Mixed_6a = InceptionB(288)
49 | self.Mixed_6b = InceptionC(768, channels_7x7=128)
50 | self.Mixed_6c = InceptionC(768, channels_7x7=160)
51 | self.Mixed_6d = InceptionC(768, channels_7x7=160)
52 | self.Mixed_6e = InceptionC(768, channels_7x7=192)
53 | if aux_logits:
54 | self.AuxLogits = InceptionAux(768, num_classes)
55 | self.Mixed_7a = InceptionD(768)
56 | self.Mixed_7b = InceptionE(1280)
57 | self.Mixed_7c = InceptionE(2048)
58 | self.fc = nn.Linear(2048, num_classes)
59 |
60 | for m in self.modules():
61 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
62 | import scipy.stats as stats
63 | stddev = m.stddev if hasattr(m, 'stddev') else 0.1
64 | X = stats.truncnorm(-2, 2, scale=stddev)
65 | values = torch.Tensor(X.rvs(m.weight.data.numel()))
66 | values = values.view(m.weight.data.size())
67 | m.weight.data.copy_(values)
68 | elif isinstance(m, nn.BatchNorm2d):
69 | m.weight.data.fill_(1)
70 | m.bias.data.zero_()
71 |
72 | def forward(self, x):
73 | if self.transform_input:
74 | x = x.clone()
75 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
76 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
77 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
78 | # 299 x 299 x 3
79 | x = self.Conv2d_1a_3x3(x)
80 | # 149 x 149 x 32
81 | x = self.Conv2d_2a_3x3(x)
82 | # 147 x 147 x 32
83 | x = self.Conv2d_2b_3x3(x)
84 | # 147 x 147 x 64
85 | x = F.max_pool2d(x, kernel_size=3, stride=2)
86 | # 73 x 73 x 64
87 | x = self.Conv2d_3b_1x1(x)
88 | # 73 x 73 x 80
89 | x = self.Conv2d_4a_3x3(x)
90 | # 71 x 71 x 192
91 | x = F.max_pool2d(x, kernel_size=3, stride=2)
92 | # 35 x 35 x 192
93 | x = self.Mixed_5b(x)
94 | # 35 x 35 x 256
95 | x = self.Mixed_5c(x)
96 | # 35 x 35 x 288
97 | x = self.Mixed_5d(x)
98 | # 35 x 35 x 288
99 | x = self.Mixed_6a(x)
100 | # 17 x 17 x 768
101 | x = self.Mixed_6b(x)
102 | # 17 x 17 x 768
103 | x = self.Mixed_6c(x)
104 | # 17 x 17 x 768
105 | x = self.Mixed_6d(x)
106 | # 17 x 17 x 768
107 | x = self.Mixed_6e(x)
108 | # 17 x 17 x 768
109 | if self.training and self.aux_logits:
110 | aux = self.AuxLogits(x)
111 | # 17 x 17 x 768
112 | x = self.Mixed_7a(x)
113 | # 8 x 8 x 1280
114 | x = self.Mixed_7b(x)
115 | # 8 x 8 x 2048
116 | x = self.Mixed_7c(x)
117 | # 8 x 8 x 2048
118 | x = F.avg_pool2d(x, kernel_size=8)
119 | # 1 x 1 x 2048
120 | x = F.dropout(x, training=self.training)
121 | # 1 x 1 x 2048
122 | x = x.view(x.size(0), -1)
123 | # 2048
124 | x = self.fc(x)
125 | # 1000 (num_classes)
126 | if self.training and self.aux_logits:
127 | return x, aux
128 | return x
129 |
130 | def get_features_mixed_6e(self):
131 | return nn.Sequential(
132 | self.Conv2d_1a_3x3,
133 | self.Conv2d_2a_3x3,
134 | self.Conv2d_2b_3x3,
135 | nn.MaxPool2d(kernel_size=3, stride=2),
136 | self.Conv2d_3b_1x1,
137 | self.Conv2d_4a_3x3,
138 | nn.MaxPool2d(kernel_size=3, stride=2),
139 | self.Mixed_5b,
140 | self.Mixed_5c,
141 | self.Mixed_5d,
142 | self.Mixed_6a,
143 | self.Mixed_6b,
144 | self.Mixed_6c,
145 | self.Mixed_6d,
146 | self.Mixed_6e,
147 | )
148 |
149 | def get_features_mixed_7c(self):
150 | return nn.Sequential(
151 | self.Conv2d_1a_3x3,
152 | self.Conv2d_2a_3x3,
153 | self.Conv2d_2b_3x3,
154 | nn.MaxPool2d(kernel_size=3, stride=2),
155 | self.Conv2d_3b_1x1,
156 | self.Conv2d_4a_3x3,
157 | nn.MaxPool2d(kernel_size=3, stride=2),
158 | self.Mixed_5b,
159 | self.Mixed_5c,
160 | self.Mixed_5d,
161 | self.Mixed_6a,
162 | self.Mixed_6b,
163 | self.Mixed_6c,
164 | self.Mixed_6d,
165 | self.Mixed_6e,
166 | self.Mixed_7a,
167 | self.Mixed_7b,
168 | self.Mixed_7c,
169 | )
170 |
171 | def load_state_dict(self, state_dict, strict=True):
172 | model_dict = self.state_dict()
173 | pretrained_dict = {k: v for k, v in state_dict.items()
174 | if k in model_dict and model_dict[k].size() == v.size()}
175 |
176 | if len(pretrained_dict) == len(state_dict):
177 | logging.info('%s: All params loaded' % type(self).__name__)
178 | else:
179 | logging.info('%s: Some params were not loaded:' % type(self).__name__)
180 | not_loaded_keys = [k for k in state_dict.keys() if k not in pretrained_dict.keys()]
181 | logging.info(('%s, ' * (len(not_loaded_keys) - 1) + '%s') % tuple(not_loaded_keys))
182 |
183 | model_dict.update(pretrained_dict)
184 | super(Inception3, self).load_state_dict(model_dict)
185 |
186 |
187 | class InceptionA(nn.Module):
188 |
189 | def __init__(self, in_channels, pool_features):
190 | super(InceptionA, self).__init__()
191 | self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1)
192 |
193 | self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1)
194 | self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2)
195 |
196 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
197 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
198 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1)
199 |
200 | self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1)
201 |
202 | def forward(self, x):
203 | branch1x1 = self.branch1x1(x)
204 |
205 | branch5x5 = self.branch5x5_1(x)
206 | branch5x5 = self.branch5x5_2(branch5x5)
207 |
208 | branch3x3dbl = self.branch3x3dbl_1(x)
209 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
210 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
211 |
212 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
213 | branch_pool = self.branch_pool(branch_pool)
214 |
215 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
216 | return torch.cat(outputs, 1)
217 |
218 |
219 | class InceptionB(nn.Module):
220 |
221 | def __init__(self, in_channels):
222 | super(InceptionB, self).__init__()
223 | self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2)
224 |
225 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
226 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
227 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2)
228 |
229 | def forward(self, x):
230 | branch3x3 = self.branch3x3(x)
231 |
232 | branch3x3dbl = self.branch3x3dbl_1(x)
233 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
234 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
235 |
236 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
237 |
238 | outputs = [branch3x3, branch3x3dbl, branch_pool]
239 | return torch.cat(outputs, 1)
240 |
241 |
242 | class InceptionC(nn.Module):
243 |
244 | def __init__(self, in_channels, channels_7x7):
245 | super(InceptionC, self).__init__()
246 | self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1)
247 |
248 | c7 = channels_7x7
249 | self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1)
250 | self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
251 | self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0))
252 |
253 | self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1)
254 | self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
255 | self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
256 | self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
257 | self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3))
258 |
259 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)
260 |
261 | def forward(self, x):
262 | branch1x1 = self.branch1x1(x)
263 |
264 | branch7x7 = self.branch7x7_1(x)
265 | branch7x7 = self.branch7x7_2(branch7x7)
266 | branch7x7 = self.branch7x7_3(branch7x7)
267 |
268 | branch7x7dbl = self.branch7x7dbl_1(x)
269 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
270 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
271 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
272 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
273 |
274 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
275 | branch_pool = self.branch_pool(branch_pool)
276 |
277 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
278 | return torch.cat(outputs, 1)
279 |
280 |
281 | class InceptionD(nn.Module):
282 |
283 | def __init__(self, in_channels):
284 | super(InceptionD, self).__init__()
285 | self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
286 | self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2)
287 |
288 | self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
289 | self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3))
290 | self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0))
291 | self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2)
292 |
293 | def forward(self, x):
294 | branch3x3 = self.branch3x3_1(x)
295 | branch3x3 = self.branch3x3_2(branch3x3)
296 |
297 | branch7x7x3 = self.branch7x7x3_1(x)
298 | branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
299 | branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
300 | branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
301 |
302 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
303 | outputs = [branch3x3, branch7x7x3, branch_pool]
304 | return torch.cat(outputs, 1)
305 |
306 |
307 | class InceptionE(nn.Module):
308 |
309 | def __init__(self, in_channels):
310 | super(InceptionE, self).__init__()
311 | self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1)
312 |
313 | self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1)
314 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
315 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))
316 |
317 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1)
318 | self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1)
319 | self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
320 | self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))
321 |
322 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)
323 |
324 | def forward(self, x):
325 | branch1x1 = self.branch1x1(x)
326 |
327 | branch3x3 = self.branch3x3_1(x)
328 | branch3x3 = [
329 | self.branch3x3_2a(branch3x3),
330 | self.branch3x3_2b(branch3x3),
331 | ]
332 | branch3x3 = torch.cat(branch3x3, 1)
333 |
334 | branch3x3dbl = self.branch3x3dbl_1(x)
335 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
336 | branch3x3dbl = [
337 | self.branch3x3dbl_3a(branch3x3dbl),
338 | self.branch3x3dbl_3b(branch3x3dbl),
339 | ]
340 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
341 |
342 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
343 | branch_pool = self.branch_pool(branch_pool)
344 |
345 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
346 | return torch.cat(outputs, 1)
347 |
348 |
349 | class InceptionAux(nn.Module):
350 |
351 | def __init__(self, in_channels, num_classes):
352 | super(InceptionAux, self).__init__()
353 | self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1)
354 | self.conv1 = BasicConv2d(128, 768, kernel_size=5)
355 | self.conv1.stddev = 0.01
356 | self.fc = nn.Linear(768, num_classes)
357 | self.fc.stddev = 0.001
358 |
359 | def forward(self, x):
360 | # 17 x 17 x 768
361 | x = F.avg_pool2d(x, kernel_size=5, stride=3)
362 | # 5 x 5 x 768
363 | x = self.conv0(x)
364 | # 5 x 5 x 128
365 | x = self.conv1(x)
366 | # 1 x 1 x 768
367 | x = x.view(x.size(0), -1)
368 | # 768
369 | x = self.fc(x)
370 | # 1000
371 | return x
372 |
373 |
374 | class BasicConv2d(nn.Module):
375 |
376 | def __init__(self, in_channels, out_channels, **kwargs):
377 | super(BasicConv2d, self).__init__()
378 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
379 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
380 |
381 | def forward(self, x):
382 | x = self.conv(x)
383 | x = self.bn(x)
384 | return F.relu(x, inplace=True)
385 |
--------------------------------------------------------------------------------
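
Note: the get_features_mixed_6e / get_features_mixed_7c helpers repackage the trunk as a plain nn.Sequential, which is what makes this port usable as a backbone feature extractor rather than a classifier. A minimal sketch, assuming a 299x299 input as in the shape comments of forward(); the Mixed_6e map should come out as 768 channels at 17x17.

    import torch
    from models.inception import inception_v3

    # aux_logits=False keeps the auxiliary head out of the module list entirely.
    backbone = inception_v3(pretrained=False, aux_logits=False)
    features = backbone.get_features_mixed_6e()   # nn.Sequential up to Mixed_6e
    features.eval()

    with torch.no_grad():
        fmap = features(torch.randn(1, 3, 299, 299))

    print(fmap.shape)   # expected: torch.Size([1, 768, 17, 17])
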