├── .gitignore ├── images ├── Fig1.png ├── Fig2a.PNG ├── Fig2b.PNG ├── Fig3.PNG ├── 007_raw.jpg ├── 007_heat_atten.jpg └── 007_raw_atten.jpg ├── models ├── __init__.py ├── blocks.py ├── vgg.py ├── wsdan.py ├── resnet.py └── inception.py ├── datasets ├── __init__.py ├── dog_dataset.py ├── car_dataset.py ├── aircraft_dataset.py └── bird_dataset.py ├── LICENSE ├── config.py ├── eval.py ├── README.md ├── utils.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.iml 2 | *.xml 3 | -------------------------------------------------------------------------------- /images/Fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuYuc/WS-DAN.PyTorch/HEAD/images/Fig1.png -------------------------------------------------------------------------------- /images/Fig2a.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuYuc/WS-DAN.PyTorch/HEAD/images/Fig2a.PNG -------------------------------------------------------------------------------- /images/Fig2b.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuYuc/WS-DAN.PyTorch/HEAD/images/Fig2b.PNG -------------------------------------------------------------------------------- /images/Fig3.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuYuc/WS-DAN.PyTorch/HEAD/images/Fig3.PNG -------------------------------------------------------------------------------- /images/007_raw.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuYuc/WS-DAN.PyTorch/HEAD/images/007_raw.jpg -------------------------------------------------------------------------------- /images/007_heat_atten.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuYuc/WS-DAN.PyTorch/HEAD/images/007_heat_atten.jpg -------------------------------------------------------------------------------- /images/007_raw_atten.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuYuc/WS-DAN.PyTorch/HEAD/images/007_raw_atten.jpg -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .vgg import * 3 | from .inception import * 4 | from .wsdan import * -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .aircraft_dataset import AircraftDataset 2 | from .bird_dataset import BirdDataset 3 | from .car_dataset import CarDataset 4 | from .dog_dataset import DogDataset 5 | 6 | 7 | def get_trainval_datasets(tag, resize): 8 | if tag == 'aircraft': 9 | return AircraftDataset(phase='train', resize=resize), AircraftDataset(phase='val', resize=resize) 10 | elif tag == 'bird': 11 | return BirdDataset(phase='train', resize=resize), BirdDataset(phase='val', resize=resize) 12 | elif tag == 'car': 13 | return CarDataset(phase='train', resize=resize), CarDataset(phase='val', resize=resize) 14 | elif tag == 'dog': 15 | return DogDataset(phase='train', resize=resize), DogDataset(phase='val', 
resize=resize) 16 | else: 17 | raise ValueError('Unsupported Tag {}'.format(tag)) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Yuchong Gu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | ################################################## 2 | # Training Config 3 | ################################################## 4 | GPU = '0' # GPU 5 | workers = 4 # number of Dataloader workers 6 | epochs = 160 # number of epochs 7 | batch_size = 12 # batch size 8 | learning_rate = 1e-3 # initial learning rate 9 | 10 | ################################################## 11 | # Model Config 12 | ################################################## 13 | image_size = (448, 448) # size of training images 14 | net = 'inception_mixed_6e' # feature extractor 15 | num_attentions = 32 # number of attention maps 16 | beta = 5e-2 # param for update feature centers 17 | 18 | ################################################## 19 | # Dataset/Path Config 20 | ################################################## 21 | tag = 'bird' # 'aircraft', 'bird', 'car', or 'dog' 22 | 23 | # saving directory of .ckpt models 24 | save_dir = './FGVC/CUB-200-2011/ckpt/' 25 | model_name = 'model.ckpt' 26 | log_name = 'train.log' 27 | 28 | # checkpoint model for resume training 29 | ckpt = False 30 | # ckpt = save_dir + model_name 31 | 32 | ################################################## 33 | # Eval Config 34 | ################################################## 35 | visualize = True 36 | eval_ckpt = save_dir + model_name 37 | eval_savepath = './FGVC/CUB-200-2011/visualize/' -------------------------------------------------------------------------------- /datasets/dog_dataset.py: -------------------------------------------------------------------------------- 1 | """ Stanford Dogs (Dog) Dataset 2 | Created: Nov 15,2019 - Yuchong Gu 3 | Revised: Nov 15,2019 - Yuchong Gu 4 | """ 5 | import os 6 | import pdb 7 | from PIL import Image 8 | from scipy.io import loadmat 9 | from torch.utils.data import Dataset 10 | from utils import get_transform 11 | 12 | DATAPATH = '/home/guyuchong/DATA/FGVC/StanfordDogs' 13 | 14 | 15 | class DogDataset(Dataset): 16 | """ 17 | # 
Description: 18 | Dataset for retrieving Stanford Dogs images and labels 19 | 20 | # Member Functions: 21 | __init__(self, phase, resize): initializes a dataset 22 | phase: a string in ['train', 'val', 'test'] 23 | resize: output shape/size of an image 24 | 25 | __getitem__(self, item): returns an image 26 | item: the idex of image in the whole dataset 27 | 28 | __len__(self): returns the length of dataset 29 | """ 30 | 31 | def __init__(self, phase='train', resize=500): 32 | assert phase in ['train', 'val', 'test'] 33 | self.phase = phase 34 | self.resize = resize 35 | self.num_classes = 120 36 | 37 | if phase == 'train': 38 | list_path = os.path.join(DATAPATH, 'train_list.mat') 39 | else: 40 | list_path = os.path.join(DATAPATH, 'test_list.mat') 41 | 42 | list_mat = loadmat(list_path) 43 | self.images = [f.item().item() for f in list_mat['file_list']] 44 | self.labels = [f.item() for f in list_mat['labels']] 45 | 46 | # transform 47 | self.transform = get_transform(self.resize, self.phase) 48 | 49 | def __getitem__(self, item): 50 | # image 51 | image = Image.open(os.path.join(DATAPATH, 'Images', self.images[item])).convert('RGB') # (C, H, W) 52 | image = self.transform(image) 53 | 54 | # return image and label 55 | return image, self.labels[item] - 1 # count begin from zero 56 | 57 | def __len__(self): 58 | return len(self.images) 59 | 60 | 61 | if __name__ == '__main__': 62 | ds = DogDataset('train') 63 | # print(len(ds)) 64 | for i in range(0, 1000): 65 | image, label = ds[i] 66 | # print(image.shape, label) 67 | -------------------------------------------------------------------------------- /datasets/car_dataset.py: -------------------------------------------------------------------------------- 1 | """ Stanford Cars (Car) Dataset 2 | Created: Nov 15,2019 - Yuchong Gu 3 | Revised: Nov 15,2019 - Yuchong Gu 4 | """ 5 | import os 6 | import pdb 7 | from PIL import Image 8 | from scipy.io import loadmat 9 | from torch.utils.data import Dataset 10 | from utils import get_transform 11 | 12 | DATAPATH = '/home/guyuchong/DATA/FGVC/StanfordCars' 13 | 14 | 15 | class CarDataset(Dataset): 16 | """ 17 | # Description: 18 | Dataset for retrieving Stanford Cars images and labels 19 | 20 | # Member Functions: 21 | __init__(self, phase, resize): initializes a dataset 22 | phase: a string in ['train', 'val', 'test'] 23 | resize: output shape/size of an image 24 | 25 | __getitem__(self, item): returns an image 26 | item: the idex of image in the whole dataset 27 | 28 | __len__(self): returns the length of dataset 29 | """ 30 | 31 | def __init__(self, phase='train', resize=500): 32 | assert phase in ['train', 'val', 'test'] 33 | self.phase = phase 34 | self.resize = resize 35 | self.num_classes = 196 36 | 37 | if phase == 'train': 38 | list_path = os.path.join(DATAPATH, 'devkit', 'cars_train_annos.mat') 39 | self.image_path = os.path.join(DATAPATH, 'cars_train') 40 | else: 41 | list_path = os.path.join(DATAPATH, 'cars_test_annos_withlabels.mat') 42 | self.image_path = os.path.join(DATAPATH, 'cars_test') 43 | 44 | list_mat = loadmat(list_path) 45 | self.images = [f.item() for f in list_mat['annotations']['fname'][0]] 46 | self.labels = [f.item() for f in list_mat['annotations']['class'][0]] 47 | 48 | # transform 49 | self.transform = get_transform(self.resize, self.phase) 50 | 51 | def __getitem__(self, item): 52 | # image 53 | image = Image.open(os.path.join(self.image_path, self.images[item])).convert('RGB') # (C, H, W) 54 | image = self.transform(image) 55 | 56 | # return image and label 57 | return 
image, self.labels[item] - 1 # count begin from zero 58 | 59 | def __len__(self): 60 | return len(self.images) 61 | 62 | 63 | if __name__ == '__main__': 64 | ds = CarDataset('val') 65 | # print(len(ds)) 66 | for i in range(0, 100): 67 | image, label = ds[i] 68 | # print(image.shape, label) 69 | -------------------------------------------------------------------------------- /models/blocks.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | __all__ = ['CBAMLayer', 'SPPLayer'] 7 | 8 | ''' 9 | Woo et al., 10 | "CBAM: Convolutional Block Attention Module", 11 | ECCV 2018, 12 | arXiv:1807.06521 13 | ''' 14 | class CBAMLayer(nn.Module): 15 | def __init__(self, channel, reduction=16, spatial_kernel=7): 16 | super(CBAMLayer, self).__init__() 17 | # channel attention 18 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 19 | self.max_pool = nn.AdaptiveMaxPool2d(1) 20 | self.mlp = nn.Sequential( 21 | nn.Conv2d(channel, channel // reduction, 1, bias=False), 22 | nn.ReLU(inplace=True), 23 | nn.Conv2d(channel // reduction, channel, 1, bias=False), 24 | ) 25 | # spatial attention 26 | self.conv = nn.Conv2d(2, 1, kernel_size=spatial_kernel, padding=spatial_kernel//2, bias=False) 27 | self.sigmoid = nn.Sigmoid() 28 | 29 | def forward(self, x): 30 | # channel attention 31 | max_out = self.mlp(self.max_pool(x)) 32 | avg_out = self.mlp(self.avg_pool(x)) 33 | channel_out = self.sigmoid(max_out + avg_out) 34 | x = channel_out * x 35 | 36 | # spatial attention 37 | max_out, _ = torch.max(x, dim=1, keepdim=True) 38 | avg_out = torch.mean(x, dim=1, keepdim=True) 39 | spatial_out = self.sigmoid(self.conv(torch.cat([max_out, avg_out], dim=1))) 40 | x = spatial_out * x 41 | return x 42 | 43 | 44 | ''' 45 | He et al., 46 | "Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition", 47 | TPAMI 2015, 48 | arXiv:1406.4729 49 | ''' 50 | class SPPLayer(nn.Module): 51 | def __init__(self, pool_size, pool=nn.MaxPool2d): 52 | super(SPPLayer, self).__init__() 53 | self.pool_size = pool_size 54 | self.pool = pool 55 | self.out_length = np.sum(np.array(self.pool_size) ** 2) 56 | 57 | def forward(self, x): 58 | B, C, H, W = x.size() 59 | for i in range(len(self.pool_size)): 60 | h_wid = int(math.ceil(H / self.pool_size[i])) 61 | w_wid = int(math.ceil(W / self.pool_size[i])) 62 | h_pad = (h_wid * self.pool_size[i] - H + 1) / 2 63 | w_pad = (w_wid * self.pool_size[i] - W + 1) / 2 64 | out = self.pool((h_wid, w_wid), stride=(h_wid, w_wid), padding=(h_pad, w_pad))(x) 65 | if i == 0: 66 | spp = out.view(B, -1) 67 | else: 68 | spp = torch.cat([spp, out.view(B, -1)], dim=1) 69 | return spp 70 | -------------------------------------------------------------------------------- /datasets/aircraft_dataset.py: -------------------------------------------------------------------------------- 1 | """ FGVC Aircraft (Aircraft) Dataset 2 | Created: Nov 15,2019 - Yuchong Gu 3 | Revised: Nov 15,2019 - Yuchong Gu 4 | """ 5 | import os 6 | import pdb 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | from utils import get_transform 10 | 11 | DATAPATH = '/home/guyuchong/DATA/FGVC/FGVC-Aircraft/data' 12 | FILENAME_LENGTH = 7 13 | 14 | 15 | class AircraftDataset(Dataset): 16 | """ 17 | # Description: 18 | Dataset for retrieving FGVC Aircraft images and labels 19 | 20 | # Member Functions: 21 | __init__(self, phase, resize): initializes a dataset 22 | phase: a string in ['train', 'val', 'test'] 23 | 
resize: output shape/size of an image 24 | 25 | __getitem__(self, item): returns an image 26 | item: the idex of image in the whole dataset 27 | 28 | __len__(self): returns the length of dataset 29 | """ 30 | 31 | def __init__(self, phase='train', resize=500): 32 | assert phase in ['train', 'val', 'test'] 33 | self.phase = phase 34 | self.resize = resize 35 | 36 | variants_dict = {} 37 | with open(os.path.join(DATAPATH, 'variants.txt'), 'r') as f: 38 | for idx, line in enumerate(f.readlines()): 39 | variants_dict[line.strip()] = idx 40 | self.num_classes = len(variants_dict) 41 | 42 | if phase == 'train': 43 | list_path = os.path.join(DATAPATH, 'images_variant_trainval.txt') 44 | else: 45 | list_path = os.path.join(DATAPATH, 'images_variant_test.txt') 46 | 47 | self.images = [] 48 | self.labels = [] 49 | with open(list_path, 'r') as f: 50 | for line in f.readlines(): 51 | fname_and_variant = line.strip() 52 | self.images.append(fname_and_variant[:FILENAME_LENGTH]) 53 | self.labels.append(variants_dict[fname_and_variant[FILENAME_LENGTH + 1:]]) 54 | 55 | # transform 56 | self.transform = get_transform(self.resize, self.phase) 57 | 58 | def __getitem__(self, item): 59 | # image 60 | image = Image.open(os.path.join(DATAPATH, 'images', '%s.jpg' % self.images[item])).convert('RGB') # (C, H, W) 61 | image = self.transform(image) 62 | 63 | # return image and label 64 | return image, self.labels[item] # count begin from zero 65 | 66 | def __len__(self): 67 | return len(self.images) 68 | 69 | 70 | if __name__ == '__main__': 71 | ds = AircraftDataset('test', 448) 72 | # print(len(ds)) 73 | from utils import AverageMeter 74 | height_meter = AverageMeter('height') 75 | width_meter = AverageMeter('width') 76 | 77 | for i in range(len(ds)): 78 | image, label = ds[i] 79 | avgH = height_meter(image.size(1)) 80 | avgW = width_meter(image.size(2)) 81 | print('H: %.2f, W: %.2f' % (avgH, avgW)) 82 | -------------------------------------------------------------------------------- /datasets/bird_dataset.py: -------------------------------------------------------------------------------- 1 | """ CUB-200-2011 (Bird) Dataset 2 | Created: Oct 11,2019 - Yuchong Gu 3 | Revised: Oct 11,2019 - Yuchong Gu 4 | """ 5 | import os 6 | import pdb 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | from utils import get_transform 10 | 11 | DATAPATH = '/home/guyuchong/DATA/FGVC/CUB-200-2011' 12 | image_path = {} 13 | image_label = {} 14 | 15 | 16 | class BirdDataset(Dataset): 17 | """ 18 | # Description: 19 | Dataset for retrieving CUB-200-2011 images and labels 20 | 21 | # Member Functions: 22 | __init__(self, phase, resize): initializes a dataset 23 | phase: a string in ['train', 'val', 'test'] 24 | resize: output shape/size of an image 25 | 26 | __getitem__(self, item): returns an image 27 | item: the idex of image in the whole dataset 28 | 29 | __len__(self): returns the length of dataset 30 | """ 31 | 32 | def __init__(self, phase='train', resize=500): 33 | assert phase in ['train', 'val', 'test'] 34 | self.phase = phase 35 | self.resize = resize 36 | self.image_id = [] 37 | self.num_classes = 200 38 | 39 | # get image path from images.txt 40 | with open(os.path.join(DATAPATH, 'images.txt')) as f: 41 | for line in f.readlines(): 42 | id, path = line.strip().split(' ') 43 | image_path[id] = path 44 | 45 | # get image label from image_class_labels.txt 46 | with open(os.path.join(DATAPATH, 'image_class_labels.txt')) as f: 47 | for line in f.readlines(): 48 | id, label = line.strip().split(' ') 49 | 
image_label[id] = int(label) 50 | 51 | # get train/test image id from train_test_split.txt 52 | with open(os.path.join(DATAPATH, 'train_test_split.txt')) as f: 53 | for line in f.readlines(): 54 | image_id, is_training_image = line.strip().split(' ') 55 | is_training_image = int(is_training_image) 56 | 57 | if self.phase == 'train' and is_training_image: 58 | self.image_id.append(image_id) 59 | if self.phase in ('val', 'test') and not is_training_image: 60 | self.image_id.append(image_id) 61 | 62 | # transform 63 | self.transform = get_transform(self.resize, self.phase) 64 | 65 | def __getitem__(self, item): 66 | # get image id 67 | image_id = self.image_id[item] 68 | 69 | # image 70 | image = Image.open(os.path.join(DATAPATH, 'images', image_path[image_id])).convert('RGB') # (C, H, W) 71 | image = self.transform(image) 72 | 73 | # return image and label 74 | return image, image_label[image_id] - 1 # count begin from zero 75 | 76 | def __len__(self): 77 | return len(self.image_id) 78 | 79 | 80 | if __name__ == '__main__': 81 | ds = BirdDataset('train') 82 | print(len(ds)) 83 | for i in range(0, 10): 84 | image, label = ds[i] 85 | print(image.shape, label) 86 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | from models.blocks import SPPLayer 6 | import logging 7 | 8 | 9 | __all__ = ['vgg19_bn', 'vgg19'] 10 | 11 | 12 | model_urls = { 13 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 14 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 15 | } 16 | 17 | 18 | class VGG(nn.Module): 19 | 20 | def __init__(self, features, num_classes=1000, init_weights=True): 21 | super(VGG, self).__init__() 22 | self.features = features 23 | self.spp = SPPLayer(pool_size=[1, 2, 4], pool=nn.MaxPool2d) 24 | self.fc = nn.Sequential( 25 | nn.Linear(512 * self.spp.out_length, 1024), 26 | nn.ReLU(True), 27 | nn.Dropout(), 28 | nn.Linear(1024, num_classes)) 29 | 30 | if init_weights: 31 | self._initialize_weights() 32 | 33 | def forward(self, x): 34 | x = self.features(x) 35 | x = self.spp(x) 36 | x = self.fc(x) 37 | return x 38 | 39 | def _initialize_weights(self): 40 | for m in self.modules(): 41 | if isinstance(m, nn.Conv2d): 42 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 43 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 44 | if m.bias is not None: 45 | m.bias.data.zero_() 46 | elif isinstance(m, nn.BatchNorm2d): 47 | m.weight.data.fill_(1) 48 | m.bias.data.zero_() 49 | elif isinstance(m, nn.Linear): 50 | m.weight.data.normal_(0, 0.01) 51 | m.bias.data.zero_() 52 | 53 | def get_features(self): 54 | return self.features 55 | 56 | def load_state_dict(self, state_dict, strict=True): 57 | model_dict = self.state_dict() 58 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict and model_dict[k].size() == v.size()} 59 | if len(pretrained_dict) == len(state_dict): 60 | logging.info('%s: All params loaded' % type(self).__name__) 61 | else: 62 | logging.info('%s: Some params were not loaded:' % type(self).__name__) 63 | not_loaded_keys = [k for k in state_dict.keys() if k not in pretrained_dict.keys()] 64 | logging.info(('%s, ' * (len(not_loaded_keys) - 1) + '%s') % tuple(not_loaded_keys)) 65 | model_dict.update(pretrained_dict) 66 | super(VGG, self).load_state_dict(model_dict) 67 | 68 | 69 | def make_layers(cfg, batch_norm=False): 70 | layers = [] 71 | in_channels = 3 72 | for v in cfg: 73 | if v == 'M': 74 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 75 | else: 76 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 77 | if batch_norm: 78 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 79 | else: 80 | layers += [conv2d, nn.ReLU(inplace=True)] 81 | in_channels = v 82 | return nn.Sequential(*layers) 83 | 84 | 85 | cfg = { 86 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 87 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 88 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 89 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 90 | } 91 | 92 | 93 | def vgg19(pretrained=False, **kwargs): 94 | """VGG 19-layer model (configuration "E") 95 | 96 | Args: 97 | pretrained (bool): If True, returns a model pre-trained on ImageNet 98 | """ 99 | if pretrained: 100 | kwargs['init_weights'] = False 101 | model = VGG(make_layers(cfg['E']), **kwargs) 102 | if pretrained: 103 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) 104 | return model 105 | 106 | 107 | def vgg19_bn(pretrained=False, **kwargs): 108 | """VGG 19-layer model (configuration 'E') with batch normalization 109 | 110 | Args: 111 | pretrained (bool): If True, returns a model pre-trained on ImageNet 112 | """ 113 | if pretrained: 114 | kwargs['init_weights'] = False 115 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) 116 | if pretrained: 117 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn'])) 118 | return model 119 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | """EVALUATION 2 | Created: Nov 22,2019 - Yuchong Gu 3 | Revised: Dec 03,2019 - Yuchong Gu 4 | """ 5 | import os 6 | import logging 7 | import warnings 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torchvision import transforms 12 | from torch.utils.data import DataLoader 13 | from tqdm import tqdm 14 | 15 | import config 16 | from models import WSDAN 17 | from datasets import get_trainval_datasets 18 | from utils import TopKAccuracyMetric, batch_augment 19 | 20 | # GPU settings 21 | assert torch.cuda.is_available() 22 | os.environ['CUDA_VISIBLE_DEVICES'] = config.GPU 23 
| device = torch.device("cuda") 24 | torch.backends.cudnn.benchmark = True 25 | 26 | # visualize 27 | visualize = config.visualize 28 | savepath = config.eval_savepath 29 | if visualize: 30 | os.makedirs(savepath, exist_ok=True) 31 | 32 | ToPILImage = transforms.ToPILImage() 33 | MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) 34 | STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) 35 | 36 | 37 | def generate_heatmap(attention_maps): 38 | heat_attention_maps = [] 39 | heat_attention_maps.append(attention_maps[:, 0, ...]) # R 40 | heat_attention_maps.append(attention_maps[:, 0, ...] * (attention_maps[:, 0, ...] < 0.5).float() + \ 41 | (1. - attention_maps[:, 0, ...]) * (attention_maps[:, 0, ...] >= 0.5).float()) # G 42 | heat_attention_maps.append(1. - attention_maps[:, 0, ...]) # B 43 | return torch.stack(heat_attention_maps, dim=1) 44 | 45 | 46 | def main(): 47 | logging.basicConfig( 48 | format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s', 49 | level=logging.INFO) 50 | warnings.filterwarnings("ignore") 51 | 52 | try: 53 | ckpt = config.eval_ckpt 54 | except: 55 | logging.info('Set ckpt for evaluation in config.py') 56 | return 57 | 58 | ################################## 59 | # Dataset for testing 60 | ################################## 61 | _, test_dataset = get_trainval_datasets(config.tag, resize=config.image_size) 62 | test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, 63 | num_workers=2, pin_memory=True) 64 | 65 | ################################## 66 | # Initialize model 67 | ################################## 68 | net = WSDAN(num_classes=test_dataset.num_classes, M=config.num_attentions, net=config.net) 69 | 70 | # Load ckpt and get state_dict 71 | checkpoint = torch.load(ckpt) 72 | state_dict = checkpoint['state_dict'] 73 | 74 | # Load weights 75 | net.load_state_dict(state_dict) 76 | logging.info('Network loaded from {}'.format(ckpt)) 77 | 78 | ################################## 79 | # use cuda 80 | ################################## 81 | net.to(device) 82 | if torch.cuda.device_count() > 1: 83 | net = nn.DataParallel(net) 84 | 85 | ################################## 86 | # Prediction 87 | ################################## 88 | raw_accuracy = TopKAccuracyMetric(topk=(1, 5)) 89 | ref_accuracy = TopKAccuracyMetric(topk=(1, 5)) 90 | raw_accuracy.reset() 91 | ref_accuracy.reset() 92 | 93 | net.eval() 94 | with torch.no_grad(): 95 | pbar = tqdm(total=len(test_loader), unit=' batches') 96 | pbar.set_description('Validation') 97 | for i, (X, y) in enumerate(test_loader): 98 | X = X.to(device) 99 | y = y.to(device) 100 | 101 | # WS-DAN 102 | y_pred_raw, _, attention_maps = net(X) 103 | 104 | # Augmentation with crop_mask 105 | crop_image = batch_augment(X, attention_maps, mode='crop', theta=0.1, padding_ratio=0.05) 106 | 107 | y_pred_crop, _, _ = net(crop_image) 108 | y_pred = (y_pred_raw + y_pred_crop) / 2. 
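            # The "refined" prediction averages logits from the raw image and its
            # attention-cropped view; Raw and Refine top-1/top-5 accuracies are
            # reported on the progress bar below.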
109 | 110 | if visualize: 111 | # reshape attention maps 112 | attention_maps = F.upsample_bilinear(attention_maps, size=(X.size(2), X.size(3))) 113 | attention_maps = torch.sqrt(attention_maps.cpu() / attention_maps.max().item()) 114 | 115 | # get heat attention maps 116 | heat_attention_maps = generate_heatmap(attention_maps) 117 | 118 | # raw_image, heat_attention, raw_attention 119 | raw_image = X.cpu() * STD + MEAN 120 | heat_attention_image = raw_image * 0.5 + heat_attention_maps * 0.5 121 | raw_attention_image = raw_image * attention_maps 122 | 123 | for batch_idx in range(X.size(0)): 124 | rimg = ToPILImage(raw_image[batch_idx]) 125 | raimg = ToPILImage(raw_attention_image[batch_idx]) 126 | haimg = ToPILImage(heat_attention_image[batch_idx]) 127 | rimg.save(os.path.join(savepath, '%03d_raw.jpg' % (i * config.batch_size + batch_idx))) 128 | raimg.save(os.path.join(savepath, '%03d_raw_atten.jpg' % (i * config.batch_size + batch_idx))) 129 | haimg.save(os.path.join(savepath, '%03d_heat_atten.jpg' % (i * config.batch_size + batch_idx))) 130 | 131 | # Top K 132 | epoch_raw_acc = raw_accuracy(y_pred_raw, y) 133 | epoch_ref_acc = ref_accuracy(y_pred, y) 134 | 135 | # end of this batch 136 | batch_info = 'Val Acc: Raw ({:.2f}, {:.2f}), Refine ({:.2f}, {:.2f})'.format( 137 | epoch_raw_acc[0], epoch_raw_acc[1], epoch_ref_acc[0], epoch_ref_acc[1]) 138 | pbar.update() 139 | pbar.set_postfix_str(batch_info) 140 | 141 | pbar.close() 142 | 143 | 144 | if __name__ == '__main__': 145 | main() 146 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WS-DAN.PyTorch 2 | A neat PyTorch implementation of WS-DAN (Weakly Supervised Data Augmentation Network) for FGVC (Fine-Grained Visual Classification). (_Hu et al._, ["See Better Before Looking Closer: Weakly Supervised Data Augmentation 3 | Network for Fine-Grained Visual Classification"](https://arxiv.org/abs/1901.09891v2), arXiv:1901.09891) 4 | 5 | **NOTICE: This is NOT an official implementation by authors of WS-DAN. The official implementation is available at [tau-yihouxiang/WS_DAN](https://github.com/tau-yihouxiang/WS_DAN) (and there's another unofficial PyTorch version [wvinzh/WS_DAN_PyTorch](https://github.com/wvinzh/WS_DAN_PyTorch)).** 6 | 7 | 8 | 9 | 10 | ## Innovations 11 | 1. Data Augmentation: Attention Cropping and Attention Dropping 12 |
13 | ![Fig1](./images/Fig1.png) 14 |
15 | 16 | 2. Bilinear Attention Pooling (BAP) for Feature Generation (a minimal sketch follows this list) 17 |
18 | ![Fig3](./images/Fig3.PNG) 19 |
20 | 21 | 3. Training Process and Testing Process 22 |
23 | ![Fig2a](./images/Fig2a.PNG) 24 | ![Fig2b](./images/Fig2b.PNG) 25 |
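A minimal sketch of Bilinear Attention Pooling (innovation 2 above), mirroring the `'GAP'` branch of `BAP` in `./models/wsdan.py`. The tensor shapes are illustrative only (768 channels as for `inception_mixed_6e`, 32 attention maps as in `config.py`):

```python
import torch
import torch.nn.functional as F

B, C, M, H, W = 2, 768, 32, 26, 26         # batch, feature channels, attention maps, spatial size
features = torch.randn(B, C, H, W)         # backbone feature maps
attentions = torch.randn(B, M, H, W)       # attention maps (1x1 conv over the feature maps)

# Each attention map weights the feature maps; spatial positions are average-pooled,
# giving a (B, M, C) matrix that is flattened to (B, M * C).
feature_matrix = torch.einsum('imjk,injk->imn', attentions, features) / float(H * W)
feature_matrix = feature_matrix.reshape(B, -1)

# Sign-sqrt and L2 normalization, as in models/wsdan.py.
feature_matrix = torch.sign(feature_matrix) * torch.sqrt(feature_matrix.abs() + 1e-12)
feature_matrix = F.normalize(feature_matrix, dim=-1)
```

The classifier in `WSDAN` then maps this `(B, M * C)` matrix to class logits with a single bias-free linear layer.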
26 | 27 | 28 | 29 | ## Performance 30 | * PyTorch experiments were done on a Titan Xp GPU (batch_size = 12). 31 | 32 | |Dataset|Object|Category|Train|Test|Accuracy (Paper)|Accuracy (PyTorch)|Feature Net| 33 | |-------|------|--------|-----|----|----------------|--------------------|---| 34 | |[FGVC-Aircraft](http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/)|Aircraft|100|6,667|3,333|93.0|93.28|inception_mixed_6e| 35 | |[CUB-200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html)|Bird|200|5,994|5,794|89.4|88.28|inception_mixed_6e| 36 | |[Stanford Cars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html)|Car|196|8,144|8,041|94.5|94.38|inception_mixed_6e| 37 | |[Stanford Dogs](http://vision.stanford.edu/aditya86/ImageNetDogs/)|Dog|120|12,000|8,580|92.2|89.66|inception_mixed_7c| 38 | 39 | 40 | 41 | ## Usage 42 | 43 | ### WS-DAN 44 | This repo contains WS-DAN with feature extractors including VGG19(```'vgg19', 'vgg19_bn'```), 45 | ResNet34/50/101/152(```'resnet34', 'resnet50', 'resnet101', 'resnet152'```), 46 | and Inception_v3(```'inception_mixed_6e', 'inception_mixed_7c'```) in PyTorch form, see ```./models/wsdan.py```. 47 | 48 | ```python 49 | net = WSDAN(num_classes=num_classes, M=num_attentions, net='inception_mixed_6e', pretrained=True) 50 | net = WSDAN(num_classes=num_classes, M=num_attentions, net='inception_mixed_7c', pretrained=True) 51 | net = WSDAN(num_classes=num_classes, M=num_attentions, net='vgg19_bn', pretrained=True) 52 | net = WSDAN(num_classes=num_classes, M=num_attentions, net='resnet50', pretrained=True) 53 | ``` 54 | 55 | ### Dataset Directory 56 | 57 | * [FGVC-Aircraft](http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/) (Aircraft) 58 | 59 | ``` 60 | -/FGVC-Aircraft/data/ 61 | └─── images 62 | └─── 0034309.jpg 63 | └─── 0034958.jpg 64 | └─── ... 65 | └─── variants.txt 66 | └─── images_variant_trainval.txt 67 | └─── images_variant_test.txt 68 | ``` 69 | 70 | * [CUB-200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) (Bird) 71 | 72 | ``` 73 | -/CUB-200-2011 74 | └─── images.txt 75 | └─── image_class_labels.txt 76 | └─── train_test_split.txt 77 | └─── images 78 | └─── 001.Black_footed_Albatross 79 | └─── Black_Footed_Albatross_0001_796111.jpg 80 | └─── ... 81 | └─── 002.Laysan_Albatross 82 | └─── ... 83 | ``` 84 | 85 | * [Stanford Cars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html) (Car) 86 | 87 | ``` 88 | -/StanfordCars 89 | └─── cars_test 90 | └─── 00001.jpg 91 | └─── 00002.jpg 92 | └─── ... 93 | └─── cars_train 94 | └─── 00001.jpg 95 | └─── 00002.jpg 96 | └─── ... 97 | └─── devkit 98 | └─── cars_train_annos.mat 99 | └─── cars_test_annos_withlabels.mat 100 | ``` 101 | 102 | * [Stanford Dogs](http://vision.stanford.edu/aditya86/ImageNetDogs/) (Dog) 103 | 104 | ``` 105 | -/StanfordDogs 106 | └─── Images 107 | └─── n02085620-Chihuahua 108 | └─── n02085620_10074.jpg 109 | └─── ... 110 | └─── n02085782-Japanese_spaniel 111 | └─── ... 112 | └─── train_list.mat 113 | └─── test_list.mat 114 | ``` 115 | 116 | 117 | ### Run 118 | 119 | 1. ``` git clone``` this repo. 120 | 121 | 2. Prepare data and **modify DATAPATH** in ```datasets/_dataset.py```. 122 | 123 | 3. **Set configurations** in ```config.py``` (Training Config, Model Config, Dataset/Path Config): 124 | 125 | ```python 126 | tag = 'aircraft' # 'aircraft', 'bird', 'car', or 'dog' 127 | ``` 128 | 129 | 4. ```$ nohup python3 train.py > progress.bar &``` for training. 130 | 131 | 5. ```$ tail -f progress.bar``` to see training process (tqdm package is required. 
Other logs are written in ```/train.log```). 132 | 133 | 6. Set configurations in ```config.py``` (Eval Config) and run ```$ python3 eval.py``` for evaluation and visualization. 134 | 135 | ### Attention Maps Visualization 136 | 137 | Code in ```eval.py``` helps generate attention maps. (Image, Heat Attention Map, Image x Attention Map) 138 | 139 |
140 | ![Raw](./images/007_raw.jpg) 141 | ![Heat](./images/007_heat_atten.jpg) 142 | ![Atten](./images/007_raw_atten.jpg) 143 |
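The heat colouring follows `generate_heatmap()` in `eval.py`: low attention renders blue, high attention red. A standalone sketch of the same per-pixel mapping (the helper name `heat_rgb` is illustrative):

```python
import torch

def heat_rgb(a: torch.Tensor) -> torch.Tensor:
    """Colour a single-channel attention map `a` (values in [0, 1]) the way
    generate_heatmap() in eval.py does."""
    r = a                                   # red grows with attention
    g = torch.where(a < 0.5, a, 1.0 - a)    # green peaks at mid attention
    b = 1.0 - a                             # blue fades as attention grows
    return torch.stack([r, g, b], dim=0)    # (3, H, W) heat image
```

`eval.py` blends this colouring 50/50 with the de-normalized input image to produce `*_heat_atten.jpg`, and multiplies the input by the attention map itself for `*_raw_atten.jpg`.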
144 | 145 | 146 | 147 | -------------------------------------------------------------------------------- /models/wsdan.py: -------------------------------------------------------------------------------- 1 | """ 2 | WS-DAN models 3 | 4 | Hu et al., 5 | "See Better Before Looking Closer: Weakly Supervised Data Augmentation Network for Fine-Grained Visual Classification", 6 | arXiv:1901.09891 7 | 8 | Created: May 04,2019 - Yuchong Gu 9 | Revised: Dec 03,2019 - Yuchong Gu 10 | """ 11 | import logging 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | 17 | import models.vgg as vgg 18 | import models.resnet as resnet 19 | from models.inception import inception_v3, BasicConv2d 20 | 21 | __all__ = ['WSDAN'] 22 | EPSILON = 1e-12 23 | 24 | 25 | # Bilinear Attention Pooling 26 | class BAP(nn.Module): 27 | def __init__(self, pool='GAP'): 28 | super(BAP, self).__init__() 29 | assert pool in ['GAP', 'GMP'] 30 | if pool == 'GAP': 31 | self.pool = None 32 | else: 33 | self.pool = nn.AdaptiveMaxPool2d(1) 34 | 35 | def forward(self, features, attentions): 36 | B, C, H, W = features.size() 37 | _, M, AH, AW = attentions.size() 38 | 39 | # match size 40 | if AH != H or AW != W: 41 | attentions = F.upsample_bilinear(attentions, size=(H, W)) 42 | 43 | # feature_matrix: (B, M, C) -> (B, M * C) 44 | if self.pool is None: 45 | feature_matrix = (torch.einsum('imjk,injk->imn', (attentions, features)) / float(H * W)).view(B, -1) 46 | else: 47 | feature_matrix = [] 48 | for i in range(M): 49 | AiF = self.pool(features * attentions[:, i:i + 1, ...]).view(B, -1) 50 | feature_matrix.append(AiF) 51 | feature_matrix = torch.cat(feature_matrix, dim=1) 52 | 53 | # sign-sqrt 54 | feature_matrix = torch.sign(feature_matrix) * torch.sqrt(torch.abs(feature_matrix) + EPSILON) 55 | 56 | # l2 normalization along dimension M and C 57 | feature_matrix = F.normalize(feature_matrix, dim=-1) 58 | return feature_matrix 59 | 60 | 61 | # WS-DAN: Weakly Supervised Data Augmentation Network for FGVC 62 | class WSDAN(nn.Module): 63 | def __init__(self, num_classes, M=32, net='inception_mixed_6e', pretrained=False): 64 | super(WSDAN, self).__init__() 65 | self.num_classes = num_classes 66 | self.M = M 67 | self.net = net 68 | 69 | # Network Initialization 70 | if 'inception' in net: 71 | if net == 'inception_mixed_6e': 72 | self.features = inception_v3(pretrained=pretrained).get_features_mixed_6e() 73 | self.num_features = 768 74 | elif net == 'inception_mixed_7c': 75 | self.features = inception_v3(pretrained=pretrained).get_features_mixed_7c() 76 | self.num_features = 2048 77 | else: 78 | raise ValueError('Unsupported net: %s' % net) 79 | elif 'vgg' in net: 80 | self.features = getattr(vgg, net)(pretrained=pretrained).get_features() 81 | self.num_features = 512 82 | elif 'resnet' in net: 83 | self.features = getattr(resnet, net)(pretrained=pretrained).get_features() 84 | self.num_features = 512 * self.features[-1][-1].expansion 85 | else: 86 | raise ValueError('Unsupported net: %s' % net) 87 | 88 | # Attention Maps 89 | self.attentions = BasicConv2d(self.num_features, self.M, kernel_size=1) 90 | 91 | # Bilinear Attention Pooling 92 | self.bap = BAP(pool='GAP') 93 | 94 | # Classification Layer 95 | self.fc = nn.Linear(self.M * self.num_features, self.num_classes, bias=False) 96 | 97 | logging.info('WSDAN: using {} as feature extractor, num_classes: {}, num_attentions: {}'.format(net, self.num_classes, self.M)) 98 | 99 | def forward(self, x): 100 | batch_size = x.size(0) 101 | 102 | # 
Feature Maps, Attention Maps and Feature Matrix 103 | feature_maps = self.features(x) 104 | if self.net != 'inception_mixed_7c': 105 | attention_maps = self.attentions(feature_maps) 106 | else: 107 | attention_maps = feature_maps[:, :self.M, ...] 108 | feature_matrix = self.bap(feature_maps, attention_maps) 109 | 110 | # Classification 111 | p = self.fc(feature_matrix * 100.) 112 | 113 | # Generate Attention Map 114 | if self.training: 115 | # Randomly choose one of attention maps Ak 116 | attention_map = [] 117 | for i in range(batch_size): 118 | attention_weights = torch.sqrt(attention_maps[i].sum(dim=(1, 2)).detach() + EPSILON) 119 | attention_weights = F.normalize(attention_weights, p=1, dim=0) 120 | k_index = np.random.choice(self.M, 2, p=attention_weights.cpu().numpy()) 121 | attention_map.append(attention_maps[i, k_index, ...]) 122 | attention_map = torch.stack(attention_map) # (B, 2, H, W) - one for cropping, the other for dropping 123 | else: 124 | # Object Localization Am = mean(Ak) 125 | attention_map = torch.mean(attention_maps, dim=1, keepdim=True) # (B, 1, H, W) 126 | 127 | # p: (B, self.num_classes) 128 | # feature_matrix: (B, M * C) 129 | # attention_map: (B, 2, H, W) in training, (B, 1, H, W) in val/testing 130 | return p, feature_matrix, attention_map 131 | 132 | def load_state_dict(self, state_dict, strict=True): 133 | model_dict = self.state_dict() 134 | pretrained_dict = {k: v for k, v in state_dict.items() 135 | if k in model_dict and model_dict[k].size() == v.size()} 136 | 137 | if len(pretrained_dict) == len(state_dict): 138 | logging.info('%s: All params loaded' % type(self).__name__) 139 | else: 140 | logging.info('%s: Some params were not loaded:' % type(self).__name__) 141 | not_loaded_keys = [k for k in state_dict.keys() if k not in pretrained_dict.keys()] 142 | logging.info(('%s, ' * (len(not_loaded_keys) - 1) + '%s') % tuple(not_loaded_keys)) 143 | 144 | model_dict.update(pretrained_dict) 145 | super(WSDAN, self).load_state_dict(model_dict) 146 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """Utils 2 | Created: Nov 11,2019 - Yuchong Gu 3 | Revised: Dec 03,2019 - Yuchong Gu 4 | """ 5 | import torch 6 | import random 7 | import numpy as np 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torchvision.transforms as transforms 11 | 12 | 13 | ############################################## 14 | # Center Loss for Attention Regularization 15 | ############################################## 16 | class CenterLoss(nn.Module): 17 | def __init__(self): 18 | super(CenterLoss, self).__init__() 19 | self.l2_loss = nn.MSELoss(reduction='sum') 20 | 21 | def forward(self, outputs, targets): 22 | return self.l2_loss(outputs, targets) / outputs.size(0) 23 | 24 | 25 | ################################## 26 | # Metric 27 | ################################## 28 | class Metric(object): 29 | pass 30 | 31 | 32 | class AverageMeter(Metric): 33 | def __init__(self, name='loss'): 34 | self.name = name 35 | self.reset() 36 | 37 | def reset(self): 38 | self.scores = 0. 39 | self.total_num = 0. 
40 | 41 | def __call__(self, batch_score, sample_num=1): 42 | self.scores += batch_score 43 | self.total_num += sample_num 44 | return self.scores / self.total_num 45 | 46 | 47 | class TopKAccuracyMetric(Metric): 48 | def __init__(self, topk=(1,)): 49 | self.name = 'topk_accuracy' 50 | self.topk = topk 51 | self.maxk = max(topk) 52 | self.reset() 53 | 54 | def reset(self): 55 | self.corrects = np.zeros(len(self.topk)) 56 | self.num_samples = 0. 57 | 58 | def __call__(self, output, target): 59 | """Computes the precision@k for the specified values of k""" 60 | self.num_samples += target.size(0) 61 | _, pred = output.topk(self.maxk, 1, True, True) 62 | pred = pred.t() 63 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 64 | 65 | for i, k in enumerate(self.topk): 66 | correct_k = correct[:k].view(-1).float().sum(0) 67 | self.corrects[i] += correct_k.item() 68 | 69 | return self.corrects * 100. / self.num_samples 70 | 71 | 72 | ################################## 73 | # Callback 74 | ################################## 75 | class Callback(object): 76 | def __init__(self): 77 | pass 78 | 79 | def on_epoch_begin(self): 80 | pass 81 | 82 | def on_epoch_end(self, *args): 83 | pass 84 | 85 | 86 | class ModelCheckpoint(Callback): 87 | def __init__(self, savepath, monitor='val_topk_accuracy', mode='max'): 88 | self.savepath = savepath 89 | self.monitor = monitor 90 | self.mode = mode 91 | self.reset() 92 | super(ModelCheckpoint, self).__init__() 93 | 94 | def reset(self): 95 | if self.mode == 'max': 96 | self.best_score = float('-inf') 97 | else: 98 | self.best_score = float('inf') 99 | 100 | def set_best_score(self, score): 101 | if isinstance(score, np.ndarray): 102 | self.best_score = score[0] 103 | else: 104 | self.best_score = score 105 | 106 | def on_epoch_begin(self): 107 | pass 108 | 109 | def on_epoch_end(self, logs, net, **kwargs): 110 | current_score = logs[self.monitor] 111 | if isinstance(current_score, np.ndarray): 112 | current_score = current_score[0] 113 | 114 | if (self.mode == 'max' and current_score > self.best_score) or \ 115 | (self.mode == 'min' and current_score < self.best_score): 116 | self.best_score = current_score 117 | 118 | if isinstance(net, torch.nn.DataParallel): 119 | state_dict = net.module.state_dict() 120 | else: 121 | state_dict = net.state_dict() 122 | 123 | for key in state_dict.keys(): 124 | state_dict[key] = state_dict[key].cpu() 125 | 126 | if 'feature_center' in kwargs: 127 | feature_center = kwargs['feature_center'] 128 | feature_center = feature_center.cpu() 129 | 130 | torch.save({ 131 | 'logs': logs, 132 | 'state_dict': state_dict, 133 | 'feature_center': feature_center}, self.savepath) 134 | else: 135 | torch.save({ 136 | 'logs': logs, 137 | 'state_dict': state_dict}, self.savepath) 138 | 139 | 140 | ################################## 141 | # augment function 142 | ################################## 143 | def batch_augment(images, attention_map, mode='crop', theta=0.5, padding_ratio=0.1): 144 | batches, _, imgH, imgW = images.size() 145 | 146 | if mode == 'crop': 147 | crop_images = [] 148 | for batch_index in range(batches): 149 | atten_map = attention_map[batch_index:batch_index + 1] 150 | if isinstance(theta, tuple): 151 | theta_c = random.uniform(*theta) * atten_map.max() 152 | else: 153 | theta_c = theta * atten_map.max() 154 | 155 | crop_mask = F.upsample_bilinear(atten_map, size=(imgH, imgW)) >= theta_c 156 | nonzero_indices = torch.nonzero(crop_mask[0, 0, ...]) 157 | height_min = max(int(nonzero_indices[:, 0].min().item() - padding_ratio 
* imgH), 0) 158 | height_max = min(int(nonzero_indices[:, 0].max().item() + padding_ratio * imgH), imgH) 159 | width_min = max(int(nonzero_indices[:, 1].min().item() - padding_ratio * imgW), 0) 160 | width_max = min(int(nonzero_indices[:, 1].max().item() + padding_ratio * imgW), imgW) 161 | 162 | crop_images.append( 163 | F.upsample_bilinear(images[batch_index:batch_index + 1, :, height_min:height_max, width_min:width_max], 164 | size=(imgH, imgW))) 165 | crop_images = torch.cat(crop_images, dim=0) 166 | return crop_images 167 | 168 | elif mode == 'drop': 169 | drop_masks = [] 170 | for batch_index in range(batches): 171 | atten_map = attention_map[batch_index:batch_index + 1] 172 | if isinstance(theta, tuple): 173 | theta_d = random.uniform(*theta) * atten_map.max() 174 | else: 175 | theta_d = theta * atten_map.max() 176 | 177 | drop_masks.append(F.upsample_bilinear(atten_map, size=(imgH, imgW)) < theta_d) 178 | drop_masks = torch.cat(drop_masks, dim=0) 179 | drop_images = images * drop_masks.float() 180 | return drop_images 181 | 182 | else: 183 | raise ValueError('Expected mode in [\'crop\', \'drop\'], but received unsupported augmentation method %s' % mode) 184 | 185 | 186 | ################################## 187 | # transform in dataset 188 | ################################## 189 | def get_transform(resize, phase='train'): 190 | if phase == 'train': 191 | return transforms.Compose([ 192 | transforms.Resize(size=(int(resize[0] / 0.875), int(resize[1] / 0.875))), 193 | transforms.RandomCrop(resize), 194 | transforms.RandomHorizontalFlip(0.5), 195 | transforms.ColorJitter(brightness=0.126, saturation=0.5), 196 | transforms.ToTensor(), 197 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 198 | ]) 199 | else: 200 | return transforms.Compose([ 201 | transforms.Resize(size=(int(resize[0] / 0.875), int(resize[1] / 0.875))), 202 | transforms.CenterCrop(resize), 203 | transforms.ToTensor(), 204 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 205 | ]) 206 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | from models.blocks import CBAMLayer, SPPLayer 6 | import logging 7 | 8 | __all__ = ['resnet34', 'resnet50', 'resnet101', 'resnet152', 9 | 'resnet34_cbam', 'resnet50_cbam', 'resnet101_cbam', 'resnet152_cbam'] 10 | 11 | 12 | model_urls = { 13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 18 | } 19 | 20 | 21 | class BasicBlock(nn.Module): 22 | expansion = 1 23 | 24 | def __init__(self, inplanes, planes, stride=1, cbam=None, downsample=None): 25 | super(BasicBlock, self).__init__() 26 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, padding=1, stride=stride) 27 | self.bn1 = nn.BatchNorm2d(planes) 28 | self.relu = nn.ReLU(inplace=True) 29 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1) 30 | self.bn2 = nn.BatchNorm2d(planes) 31 | self.downsample = downsample 32 | self.stride = stride 33 | 34 | if cbam is not None: 35 | self.cbam 
= CBAMLayer(planes) 36 | else: 37 | self.cbam = None 38 | 39 | def forward(self, x): 40 | residual = x 41 | 42 | out = self.conv1(x) 43 | out = self.bn1(out) 44 | out = self.relu(out) 45 | 46 | out = self.conv2(out) 47 | out = self.bn2(out) 48 | 49 | if self.cbam is not None: 50 | out = self.cbam(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | 55 | out += residual 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | 61 | class Bottleneck(nn.Module): 62 | expansion = 4 63 | 64 | def __init__(self, inplanes, planes, stride=1, cbam=None, downsample=None): 65 | super(Bottleneck, self).__init__() 66 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 67 | self.bn1 = nn.BatchNorm2d(planes) 68 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 69 | self.bn2 = nn.BatchNorm2d(planes) 70 | self.conv3 = nn.Conv2d(planes, planes * Bottleneck.expansion, kernel_size=1, bias=False) 71 | self.bn3 = nn.BatchNorm2d(planes * Bottleneck.expansion) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.downsample = downsample 74 | self.stride = stride 75 | 76 | if cbam is not None: 77 | self.cbam = CBAMLayer(planes * Bottleneck.expansion) 78 | else: 79 | self.cbam = None 80 | 81 | def forward(self, x): 82 | residual = x 83 | 84 | out = self.conv1(x) 85 | out = self.bn1(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv2(out) 89 | out = self.bn2(out) 90 | out = self.relu(out) 91 | 92 | out = self.conv3(out) 93 | out = self.bn3(out) 94 | 95 | if self.cbam is not None: 96 | out = self.cbam(out) 97 | 98 | if self.downsample is not None: 99 | residual = self.downsample(x) 100 | 101 | out += residual 102 | out = self.relu(out) 103 | 104 | return out 105 | 106 | 107 | class ResNet(nn.Module): 108 | def __init__(self, block, layers, cbam=None, num_classes=1000): 109 | self.inplanes = 64 110 | super(ResNet, self).__init__() 111 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 112 | self.bn1 = nn.BatchNorm2d(64) 113 | self.relu = nn.ReLU(inplace=True) 114 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 115 | self.layer1 = self._make_layer(block, 64, layers[0], cbam) 116 | self.layer2 = self._make_layer(block, 128, layers[1], cbam, stride=2) 117 | self.layer3 = self._make_layer(block, 256, layers[2], cbam, stride=2) 118 | self.layer4 = self._make_layer(block, 512, layers[3], cbam, stride=2) 119 | 120 | self.avgpool = nn.AdaptiveAvgPool2d(1) 121 | self.fc = nn.Linear(512 * block.expansion, num_classes) 122 | 123 | # self.spp = SPPLayer(pool_size=[1, 2, 4], pool=nn.MaxPool2d) 124 | # self.fc = nn.Linear(512 * block.expansion * self.spp.out_length, num_classes) 125 | 126 | for m in self.modules(): 127 | if isinstance(m, nn.Conv2d): 128 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 129 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 130 | elif isinstance(m, nn.BatchNorm2d): 131 | m.weight.data.fill_(1) 132 | m.bias.data.zero_() 133 | 134 | def _make_layer(self, block, planes, blocks, cbam=None, stride=1): 135 | downsample = None 136 | if stride != 1 or self.inplanes != planes * block.expansion: 137 | downsample = nn.Sequential( 138 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 139 | nn.BatchNorm2d(planes * block.expansion)) 140 | 141 | layers = [] 142 | layers.append(block(self.inplanes, planes, stride=stride, cbam=cbam, downsample=downsample)) 143 | self.inplanes = planes * block.expansion 144 | for i in range(1, blocks): 145 | layers.append(block(self.inplanes, planes, cbam=cbam)) 146 | 147 | return nn.Sequential(*layers) 148 | 149 | def forward(self, x): 150 | x = self.conv1(x) 151 | x = self.bn1(x) 152 | x = self.relu(x) 153 | x = self.maxpool(x) 154 | 155 | x = self.layer1(x) 156 | x = self.layer2(x) 157 | x = self.layer3(x) 158 | x = self.layer4(x) 159 | 160 | x = self.avgpool(x) 161 | x = x.view(x.size(0), -1) 162 | # x = self.spp(x) 163 | x = self.fc(x) 164 | 165 | return x 166 | 167 | def get_features(self): 168 | return nn.Sequential( 169 | self.conv1, 170 | self.bn1, 171 | self.relu, 172 | self.maxpool, 173 | self.layer1, 174 | self.layer2, 175 | self.layer3, 176 | self.layer4, 177 | ) 178 | 179 | def load_state_dict(self, state_dict, strict=True): 180 | model_dict = self.state_dict() 181 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict and model_dict[k].size() == v.size()} 182 | if len(pretrained_dict) == len(state_dict): 183 | logging.info('%s: All params loaded' % type(self).__name__) 184 | else: 185 | logging.info('%s: Some params were not loaded:' % type(self).__name__) 186 | not_loaded_keys = [k for k in state_dict.keys() if k not in pretrained_dict.keys()] 187 | logging.info(('%s, ' * (len(not_loaded_keys) - 1) + '%s') % tuple(not_loaded_keys)) 188 | model_dict.update(pretrained_dict) 189 | super(ResNet, self).load_state_dict(model_dict) 190 | 191 | 192 | def resnet34(pretrained=False, num_classes=1000): 193 | model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes) 194 | if pretrained: 195 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 196 | return model 197 | 198 | 199 | def resnet50(pretrained=False, num_classes=1000): 200 | model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes) 201 | if pretrained: 202 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 203 | return model 204 | 205 | 206 | def resnet101(pretrained=False, num_classes=1000): 207 | model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes) 208 | if pretrained: 209 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 210 | return model 211 | 212 | 213 | def resnet152(pretrained=False, num_classes=1000): 214 | model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes) 215 | if pretrained: 216 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 217 | return model 218 | 219 | 220 | def resnet34_cbam(pretrained=False, num_classes=1000): 221 | model = ResNet(BasicBlock, [3, 4, 6, 3], cbam=True, num_classes=num_classes) 222 | if pretrained: 223 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 224 | return model 225 | 226 | 227 | def resnet50_cbam(pretrained=False, num_classes=1000): 228 | model = ResNet(Bottleneck, [3, 4, 6, 3], cbam=True, num_classes=num_classes) 229 | if pretrained: 230 | 
model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 231 | return model 232 | 233 | 234 | def resnet101_cbam(pretrained=False, num_classes=1000): 235 | model = ResNet(Bottleneck, [3, 4, 23, 3], cbam=True, num_classes=num_classes) 236 | if pretrained: 237 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 238 | return model 239 | 240 | 241 | def resnet152_cbam(pretrained=False, num_classes=1000): 242 | model = ResNet(Bottleneck, [3, 8, 36, 3], cbam=True, num_classes=num_classes) 243 | if pretrained: 244 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 245 | return model -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """TRAINING 2 | Created: May 04,2019 - Yuchong Gu 3 | Revised: Dec 03,2019 - Yuchong Gu 4 | """ 5 | import os 6 | import time 7 | import logging 8 | import warnings 9 | from tqdm import tqdm 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.utils.data import DataLoader 14 | 15 | import config 16 | from models import WSDAN 17 | from datasets import get_trainval_datasets 18 | from utils import CenterLoss, AverageMeter, TopKAccuracyMetric, ModelCheckpoint, batch_augment 19 | 20 | # GPU settings 21 | assert torch.cuda.is_available() 22 | os.environ['CUDA_VISIBLE_DEVICES'] = config.GPU 23 | device = torch.device("cuda") 24 | torch.backends.cudnn.benchmark = True 25 | 26 | # General loss functions 27 | cross_entropy_loss = nn.CrossEntropyLoss() 28 | center_loss = CenterLoss() 29 | 30 | # loss and metric 31 | loss_container = AverageMeter(name='loss') 32 | raw_metric = TopKAccuracyMetric(topk=(1, 5)) 33 | crop_metric = TopKAccuracyMetric(topk=(1, 5)) 34 | drop_metric = TopKAccuracyMetric(topk=(1, 5)) 35 | 36 | 37 | def main(): 38 | ################################## 39 | # Initialize saving directory 40 | ################################## 41 | if not os.path.exists(config.save_dir): 42 | os.makedirs(config.save_dir) 43 | 44 | ################################## 45 | # Logging setting 46 | ################################## 47 | logging.basicConfig( 48 | filename=os.path.join(config.save_dir, config.log_name), 49 | filemode='w', 50 | format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s', 51 | level=logging.INFO) 52 | warnings.filterwarnings("ignore") 53 | 54 | ################################## 55 | # Load dataset 56 | ################################## 57 | train_dataset, validate_dataset = get_trainval_datasets(config.tag, config.image_size) 58 | 59 | train_loader, validate_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, 60 | num_workers=config.workers, pin_memory=True), \ 61 | DataLoader(validate_dataset, batch_size=config.batch_size * 4, shuffle=False, 62 | num_workers=config.workers, pin_memory=True) 63 | num_classes = train_dataset.num_classes 64 | 65 | ################################## 66 | # Initialize model 67 | ################################## 68 | logs = {} 69 | start_epoch = 0 70 | net = WSDAN(num_classes=num_classes, M=config.num_attentions, net=config.net, pretrained=True) 71 | 72 | # feature_center: size of (#classes, #attention_maps * #channel_features) 73 | feature_center = torch.zeros(num_classes, config.num_attentions * net.num_features).to(device) 74 | 75 | if config.ckpt: 76 | # Load ckpt and get state_dict 77 | checkpoint = torch.load(config.ckpt) 78 | 79 | # Get epoch and some logs 80 
| logs = checkpoint['logs'] 81 | start_epoch = int(logs['epoch']) 82 | 83 | # Load weights 84 | state_dict = checkpoint['state_dict'] 85 | net.load_state_dict(state_dict) 86 | logging.info('Network loaded from {}'.format(config.ckpt)) 87 | 88 | # load feature center 89 | if 'feature_center' in checkpoint: 90 | feature_center = checkpoint['feature_center'].to(device) 91 | logging.info('feature_center loaded from {}'.format(config.ckpt)) 92 | 93 | logging.info('Network weights will be saved to {}'.format(config.save_dir)) 94 | 95 | ################################## 96 | # Use cuda 97 | ################################## 98 | net.to(device) 99 | if torch.cuda.device_count() > 1: 100 | net = nn.DataParallel(net) 101 | 102 | ################################## 103 | # Optimizer, LR Scheduler 104 | ################################## 105 | learning_rate = logs['lr'] if 'lr' in logs else config.learning_rate 106 | optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5) 107 | 108 | # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=2) 109 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9) 110 | 111 | ################################## 112 | # ModelCheckpoint 113 | ################################## 114 | callback_monitor = 'val_{}'.format(raw_metric.name) 115 | callback = ModelCheckpoint(savepath=os.path.join(config.save_dir, config.model_name), 116 | monitor=callback_monitor, 117 | mode='max') 118 | if callback_monitor in logs: 119 | callback.set_best_score(logs[callback_monitor]) 120 | else: 121 | callback.reset() 122 | 123 | ################################## 124 | # TRAINING 125 | ################################## 126 | logging.info('Start training: Total epochs: {}, Batch size: {}, Training size: {}, Validation size: {}'. 
127 | format(config.epochs, config.batch_size, len(train_dataset), len(validate_dataset))) 128 | logging.info('') 129 | 130 | for epoch in range(start_epoch, config.epochs): 131 | callback.on_epoch_begin() 132 | 133 | logs['epoch'] = epoch + 1 134 | logs['lr'] = optimizer.param_groups[0]['lr'] 135 | 136 | logging.info('Epoch {:03d}, Learning Rate {:g}'.format(epoch + 1, optimizer.param_groups[0]['lr'])) 137 | 138 | pbar = tqdm(total=len(train_loader), unit=' batches') 139 | pbar.set_description('Epoch {}/{}'.format(epoch + 1, config.epochs)) 140 | 141 | train(logs=logs, 142 | data_loader=train_loader, 143 | net=net, 144 | feature_center=feature_center, 145 | optimizer=optimizer, 146 | pbar=pbar) 147 | validate(logs=logs, 148 | data_loader=validate_loader, 149 | net=net, 150 | pbar=pbar) 151 | 152 | if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): 153 | scheduler.step(logs['val_loss']) 154 | else: 155 | scheduler.step() 156 | 157 | callback.on_epoch_end(logs, net, feature_center=feature_center) 158 | pbar.close() 159 | 160 | 161 | def train(**kwargs): 162 | # Retrieve training configuration 163 | logs = kwargs['logs'] 164 | data_loader = kwargs['data_loader'] 165 | net = kwargs['net'] 166 | feature_center = kwargs['feature_center'] 167 | optimizer = kwargs['optimizer'] 168 | pbar = kwargs['pbar'] 169 | 170 | # metrics initialization 171 | loss_container.reset() 172 | raw_metric.reset() 173 | crop_metric.reset() 174 | drop_metric.reset() 175 | 176 | # begin training 177 | start_time = time.time() 178 | net.train() 179 | for i, (X, y) in enumerate(data_loader): 180 | optimizer.zero_grad() 181 | 182 | # obtain data for training 183 | X = X.to(device) 184 | y = y.to(device) 185 | 186 | ################################## 187 | # Raw Image 188 | ################################## 189 | # raw images forward 190 | y_pred_raw, feature_matrix, attention_map = net(X) 191 | 192 | # Update Feature Center 193 | feature_center_batch = F.normalize(feature_center[y], dim=-1) 194 | feature_center[y] += config.beta * (feature_matrix.detach() - feature_center_batch) 195 | 196 | ################################## 197 | # Attention Cropping 198 | ################################## 199 | with torch.no_grad(): 200 | crop_images = batch_augment(X, attention_map[:, :1, :, :], mode='crop', theta=(0.4, 0.6), padding_ratio=0.1) 201 | 202 | # crop images forward 203 | y_pred_crop, _, _ = net(crop_images) 204 | 205 | ################################## 206 | # Attention Dropping 207 | ################################## 208 | with torch.no_grad(): 209 | drop_images = batch_augment(X, attention_map[:, 1:, :, :], mode='drop', theta=(0.2, 0.5)) 210 | 211 | # drop images forward 212 | y_pred_drop, _, _ = net(drop_images) 213 | 214 | # loss 215 | batch_loss = cross_entropy_loss(y_pred_raw, y) / 3. + \ 216 | cross_entropy_loss(y_pred_crop, y) / 3. + \ 217 | cross_entropy_loss(y_pred_drop, y) / 3. 
+ \ 218 | center_loss(feature_matrix, feature_center_batch) 219 | 220 | # backward 221 | batch_loss.backward() 222 | optimizer.step() 223 | 224 | # metrics: loss and top-1,5 error 225 | with torch.no_grad(): 226 | epoch_loss = loss_container(batch_loss.item()) 227 | epoch_raw_acc = raw_metric(y_pred_raw, y) 228 | epoch_crop_acc = crop_metric(y_pred_crop, y) 229 | epoch_drop_acc = drop_metric(y_pred_drop, y) 230 | 231 | # end of this batch 232 | batch_info = 'Loss {:.4f}, Raw Acc ({:.2f}, {:.2f}), Crop Acc ({:.2f}, {:.2f}), Drop Acc ({:.2f}, {:.2f})'.format( 233 | epoch_loss, epoch_raw_acc[0], epoch_raw_acc[1], 234 | epoch_crop_acc[0], epoch_crop_acc[1], epoch_drop_acc[0], epoch_drop_acc[1]) 235 | pbar.update() 236 | pbar.set_postfix_str(batch_info) 237 | 238 | # end of this epoch 239 | logs['train_{}'.format(loss_container.name)] = epoch_loss 240 | logs['train_raw_{}'.format(raw_metric.name)] = epoch_raw_acc 241 | logs['train_crop_{}'.format(crop_metric.name)] = epoch_crop_acc 242 | logs['train_drop_{}'.format(drop_metric.name)] = epoch_drop_acc 243 | logs['train_info'] = batch_info 244 | end_time = time.time() 245 | 246 | # write log for this epoch 247 | logging.info('Train: {}, Time {:3.2f}'.format(batch_info, end_time - start_time)) 248 | 249 | 250 | def validate(**kwargs): 251 | # Retrieve training configuration 252 | logs = kwargs['logs'] 253 | data_loader = kwargs['data_loader'] 254 | net = kwargs['net'] 255 | pbar = kwargs['pbar'] 256 | 257 | # metrics initialization 258 | loss_container.reset() 259 | raw_metric.reset() 260 | 261 | # begin validation 262 | start_time = time.time() 263 | net.eval() 264 | with torch.no_grad(): 265 | for i, (X, y) in enumerate(data_loader): 266 | # obtain data 267 | X = X.to(device) 268 | y = y.to(device) 269 | 270 | ################################## 271 | # Raw Image 272 | ################################## 273 | y_pred_raw, _, attention_map = net(X) 274 | 275 | ################################## 276 | # Object Localization and Refinement 277 | ################################## 278 | crop_images = batch_augment(X, attention_map, mode='crop', theta=0.1, padding_ratio=0.05) 279 | y_pred_crop, _, _ = net(crop_images) 280 | 281 | ################################## 282 | # Final prediction 283 | ################################## 284 | y_pred = (y_pred_raw + y_pred_crop) / 2. 
285 | 286 | # loss 287 | batch_loss = cross_entropy_loss(y_pred, y) 288 | epoch_loss = loss_container(batch_loss.item()) 289 | 290 | # metrics: top-1,5 error 291 | epoch_acc = raw_metric(y_pred, y) 292 | 293 | # end of validation 294 | logs['val_{}'.format(loss_container.name)] = epoch_loss 295 | logs['val_{}'.format(raw_metric.name)] = epoch_acc 296 | end_time = time.time() 297 | 298 | batch_info = 'Val Loss {:.4f}, Val Acc ({:.2f}, {:.2f})'.format(epoch_loss, epoch_acc[0], epoch_acc[1]) 299 | pbar.set_postfix_str('{}, {}'.format(logs['train_info'], batch_info)) 300 | 301 | # write log for this epoch 302 | logging.info('Valid: {}, Time {:3.2f}'.format(batch_info, end_time - start_time)) 303 | logging.info('') 304 | 305 | 306 | if __name__ == '__main__': 307 | main() 308 | -------------------------------------------------------------------------------- /models/inception.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | 7 | 8 | __all__ = ['Inception3', 'inception_v3'] 9 | 10 | 11 | model_urls = { 12 | # Inception v3 ported from TensorFlow 13 | 'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', 14 | } 15 | 16 | 17 | def inception_v3(pretrained=False, **kwargs): 18 | r"""Inception v3 model architecture from 19 | `"Rethinking the Inception Architecture for Computer Vision" `_. 20 | 21 | Args: 22 | pretrained (bool): If True, returns a model pre-trained on ImageNet 23 | """ 24 | if pretrained: 25 | if 'transform_input' not in kwargs: 26 | kwargs['transform_input'] = True 27 | model = Inception3(**kwargs) 28 | model.load_state_dict(model_zoo.load_url(model_urls['inception_v3_google'])) 29 | return model 30 | 31 | return Inception3(**kwargs) 32 | 33 | 34 | class Inception3(nn.Module): 35 | 36 | def __init__(self, num_classes=1000, aux_logits=True, transform_input=False): 37 | super(Inception3, self).__init__() 38 | self.aux_logits = aux_logits 39 | self.transform_input = transform_input 40 | self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2) 41 | self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3) 42 | self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) 43 | self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) 44 | self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) 45 | self.Mixed_5b = InceptionA(192, pool_features=32) 46 | self.Mixed_5c = InceptionA(256, pool_features=64) 47 | self.Mixed_5d = InceptionA(288, pool_features=64) 48 | self.Mixed_6a = InceptionB(288) 49 | self.Mixed_6b = InceptionC(768, channels_7x7=128) 50 | self.Mixed_6c = InceptionC(768, channels_7x7=160) 51 | self.Mixed_6d = InceptionC(768, channels_7x7=160) 52 | self.Mixed_6e = InceptionC(768, channels_7x7=192) 53 | if aux_logits: 54 | self.AuxLogits = InceptionAux(768, num_classes) 55 | self.Mixed_7a = InceptionD(768) 56 | self.Mixed_7b = InceptionE(1280) 57 | self.Mixed_7c = InceptionE(2048) 58 | self.fc = nn.Linear(2048, num_classes) 59 | 60 | for m in self.modules(): 61 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 62 | import scipy.stats as stats 63 | stddev = m.stddev if hasattr(m, 'stddev') else 0.1 64 | X = stats.truncnorm(-2, 2, scale=stddev) 65 | values = torch.Tensor(X.rvs(m.weight.data.numel())) 66 | values = values.view(m.weight.data.size()) 67 | m.weight.data.copy_(values) 68 | elif isinstance(m, nn.BatchNorm2d): 69 | 
m.weight.data.fill_(1) 70 | m.bias.data.zero_() 71 | 72 | def forward(self, x): 73 | if self.transform_input: 74 | x = x.clone() 75 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 76 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 77 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 78 | # 299 x 299 x 3 79 | x = self.Conv2d_1a_3x3(x) 80 | # 149 x 149 x 32 81 | x = self.Conv2d_2a_3x3(x) 82 | # 147 x 147 x 32 83 | x = self.Conv2d_2b_3x3(x) 84 | # 147 x 147 x 64 85 | x = F.max_pool2d(x, kernel_size=3, stride=2) 86 | # 73 x 73 x 64 87 | x = self.Conv2d_3b_1x1(x) 88 | # 73 x 73 x 80 89 | x = self.Conv2d_4a_3x3(x) 90 | # 71 x 71 x 192 91 | x = F.max_pool2d(x, kernel_size=3, stride=2) 92 | # 35 x 35 x 192 93 | x = self.Mixed_5b(x) 94 | # 35 x 35 x 256 95 | x = self.Mixed_5c(x) 96 | # 35 x 35 x 288 97 | x = self.Mixed_5d(x) 98 | # 35 x 35 x 288 99 | x = self.Mixed_6a(x) 100 | # 17 x 17 x 768 101 | x = self.Mixed_6b(x) 102 | # 17 x 17 x 768 103 | x = self.Mixed_6c(x) 104 | # 17 x 17 x 768 105 | x = self.Mixed_6d(x) 106 | # 17 x 17 x 768 107 | x = self.Mixed_6e(x) 108 | # 17 x 17 x 768 109 | if self.training and self.aux_logits: 110 | aux = self.AuxLogits(x) 111 | # 17 x 17 x 768 112 | x = self.Mixed_7a(x) 113 | # 8 x 8 x 1280 114 | x = self.Mixed_7b(x) 115 | # 8 x 8 x 2048 116 | x = self.Mixed_7c(x) 117 | # 8 x 8 x 2048 118 | x = F.avg_pool2d(x, kernel_size=8) 119 | # 1 x 1 x 2048 120 | x = F.dropout(x, training=self.training) 121 | # 1 x 1 x 2048 122 | x = x.view(x.size(0), -1) 123 | # 2048 124 | x = self.fc(x) 125 | # 1000 (num_classes) 126 | if self.training and self.aux_logits: 127 | return x, aux 128 | return x 129 | 130 | def get_features_mixed_6e(self): 131 | return nn.Sequential( 132 | self.Conv2d_1a_3x3, 133 | self.Conv2d_2a_3x3, 134 | self.Conv2d_2b_3x3, 135 | nn.MaxPool2d(kernel_size=3, stride=2), 136 | self.Conv2d_3b_1x1, 137 | self.Conv2d_4a_3x3, 138 | nn.MaxPool2d(kernel_size=3, stride=2), 139 | self.Mixed_5b, 140 | self.Mixed_5c, 141 | self.Mixed_5d, 142 | self.Mixed_6a, 143 | self.Mixed_6b, 144 | self.Mixed_6c, 145 | self.Mixed_6d, 146 | self.Mixed_6e, 147 | ) 148 | 149 | def get_features_mixed_7c(self): 150 | return nn.Sequential( 151 | self.Conv2d_1a_3x3, 152 | self.Conv2d_2a_3x3, 153 | self.Conv2d_2b_3x3, 154 | nn.MaxPool2d(kernel_size=3, stride=2), 155 | self.Conv2d_3b_1x1, 156 | self.Conv2d_4a_3x3, 157 | nn.MaxPool2d(kernel_size=3, stride=2), 158 | self.Mixed_5b, 159 | self.Mixed_5c, 160 | self.Mixed_5d, 161 | self.Mixed_6a, 162 | self.Mixed_6b, 163 | self.Mixed_6c, 164 | self.Mixed_6d, 165 | self.Mixed_6e, 166 | self.Mixed_7a, 167 | self.Mixed_7b, 168 | self.Mixed_7c, 169 | ) 170 | 171 | def load_state_dict(self, state_dict, strict=True): 172 | model_dict = self.state_dict() 173 | pretrained_dict = {k: v for k, v in state_dict.items() 174 | if k in model_dict and model_dict[k].size() == v.size()} 175 | 176 | if len(pretrained_dict) == len(state_dict): 177 | logging.info('%s: All params loaded' % type(self).__name__) 178 | else: 179 | logging.info('%s: Some params were not loaded:' % type(self).__name__) 180 | not_loaded_keys = [k for k in state_dict.keys() if k not in pretrained_dict.keys()] 181 | logging.info(('%s, ' * (len(not_loaded_keys) - 1) + '%s') % tuple(not_loaded_keys)) 182 | 183 | model_dict.update(pretrained_dict) 184 | super(Inception3, self).load_state_dict(model_dict) 185 | 186 | 187 | class InceptionA(nn.Module): 188 | 189 | def __init__(self, in_channels, pool_features): 190 | super(InceptionA, self).__init__() 191 | self.branch1x1 
= BasicConv2d(in_channels, 64, kernel_size=1) 192 | 193 | self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1) 194 | self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2) 195 | 196 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1) 197 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) 198 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1) 199 | 200 | self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1) 201 | 202 | def forward(self, x): 203 | branch1x1 = self.branch1x1(x) 204 | 205 | branch5x5 = self.branch5x5_1(x) 206 | branch5x5 = self.branch5x5_2(branch5x5) 207 | 208 | branch3x3dbl = self.branch3x3dbl_1(x) 209 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 210 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 211 | 212 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 213 | branch_pool = self.branch_pool(branch_pool) 214 | 215 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 216 | return torch.cat(outputs, 1) 217 | 218 | 219 | class InceptionB(nn.Module): 220 | 221 | def __init__(self, in_channels): 222 | super(InceptionB, self).__init__() 223 | self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2) 224 | 225 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1) 226 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) 227 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2) 228 | 229 | def forward(self, x): 230 | branch3x3 = self.branch3x3(x) 231 | 232 | branch3x3dbl = self.branch3x3dbl_1(x) 233 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 234 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 235 | 236 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 237 | 238 | outputs = [branch3x3, branch3x3dbl, branch_pool] 239 | return torch.cat(outputs, 1) 240 | 241 | 242 | class InceptionC(nn.Module): 243 | 244 | def __init__(self, in_channels, channels_7x7): 245 | super(InceptionC, self).__init__() 246 | self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1) 247 | 248 | c7 = channels_7x7 249 | self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1) 250 | self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 251 | self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0)) 252 | 253 | self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1) 254 | self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 255 | self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 256 | self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 257 | self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3)) 258 | 259 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 260 | 261 | def forward(self, x): 262 | branch1x1 = self.branch1x1(x) 263 | 264 | branch7x7 = self.branch7x7_1(x) 265 | branch7x7 = self.branch7x7_2(branch7x7) 266 | branch7x7 = self.branch7x7_3(branch7x7) 267 | 268 | branch7x7dbl = self.branch7x7dbl_1(x) 269 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 270 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 271 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 272 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 273 | 274 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 275 | branch_pool = self.branch_pool(branch_pool) 276 | 277 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 278 | return 
torch.cat(outputs, 1) 279 | 280 | 281 | class InceptionD(nn.Module): 282 | 283 | def __init__(self, in_channels): 284 | super(InceptionD, self).__init__() 285 | self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) 286 | self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2) 287 | 288 | self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) 289 | self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)) 290 | self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)) 291 | self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2) 292 | 293 | def forward(self, x): 294 | branch3x3 = self.branch3x3_1(x) 295 | branch3x3 = self.branch3x3_2(branch3x3) 296 | 297 | branch7x7x3 = self.branch7x7x3_1(x) 298 | branch7x7x3 = self.branch7x7x3_2(branch7x7x3) 299 | branch7x7x3 = self.branch7x7x3_3(branch7x7x3) 300 | branch7x7x3 = self.branch7x7x3_4(branch7x7x3) 301 | 302 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 303 | outputs = [branch3x3, branch7x7x3, branch_pool] 304 | return torch.cat(outputs, 1) 305 | 306 | 307 | class InceptionE(nn.Module): 308 | 309 | def __init__(self, in_channels): 310 | super(InceptionE, self).__init__() 311 | self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1) 312 | 313 | self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1) 314 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 315 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 316 | 317 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1) 318 | self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1) 319 | self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 320 | self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 321 | 322 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 323 | 324 | def forward(self, x): 325 | branch1x1 = self.branch1x1(x) 326 | 327 | branch3x3 = self.branch3x3_1(x) 328 | branch3x3 = [ 329 | self.branch3x3_2a(branch3x3), 330 | self.branch3x3_2b(branch3x3), 331 | ] 332 | branch3x3 = torch.cat(branch3x3, 1) 333 | 334 | branch3x3dbl = self.branch3x3dbl_1(x) 335 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 336 | branch3x3dbl = [ 337 | self.branch3x3dbl_3a(branch3x3dbl), 338 | self.branch3x3dbl_3b(branch3x3dbl), 339 | ] 340 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 341 | 342 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 343 | branch_pool = self.branch_pool(branch_pool) 344 | 345 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 346 | return torch.cat(outputs, 1) 347 | 348 | 349 | class InceptionAux(nn.Module): 350 | 351 | def __init__(self, in_channels, num_classes): 352 | super(InceptionAux, self).__init__() 353 | self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1) 354 | self.conv1 = BasicConv2d(128, 768, kernel_size=5) 355 | self.conv1.stddev = 0.01 356 | self.fc = nn.Linear(768, num_classes) 357 | self.fc.stddev = 0.001 358 | 359 | def forward(self, x): 360 | # 17 x 17 x 768 361 | x = F.avg_pool2d(x, kernel_size=5, stride=3) 362 | # 5 x 5 x 768 363 | x = self.conv0(x) 364 | # 5 x 5 x 128 365 | x = self.conv1(x) 366 | # 1 x 1 x 768 367 | x = x.view(x.size(0), -1) 368 | # 768 369 | x = self.fc(x) 370 | # 1000 371 | return x 372 | 373 | 374 | class BasicConv2d(nn.Module): 375 | 376 | def __init__(self, in_channels, out_channels, **kwargs): 377 | super(BasicConv2d, self).__init__() 378 | 
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 379 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 380 | 381 | def forward(self, x): 382 | x = self.conv(x) 383 | x = self.bn(x) 384 | return F.relu(x, inplace=True) 385 | --------------------------------------------------------------------------------
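For quick reference, below is a minimal sketch of a single WS-DAN training step, condensed from the training loop in train.py above. It is an illustration under stated assumptions, not part of the repository: the random input batch, the batch size of 2, and the 200-class setting are placeholders (train.py derives the class count from the chosen dataset), pretrained=False is used only to avoid downloading backbone weights, and the WSDAN, batch_augment, and CenterLoss interfaces are assumed to behave exactly as they are used in train.py.

# Minimal sketch of one WS-DAN training step, mirroring the loop in train.py.
# The batch below is random data standing in for a real dataloader batch, and
# num_classes is a placeholder (train.py takes it from the dataset).
import torch
import torch.nn as nn
import torch.nn.functional as F

import config
from models import WSDAN
from utils import CenterLoss, batch_augment

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

num_classes = 200  # placeholder; train.py reads this from the dataset
net = WSDAN(num_classes=num_classes, M=config.num_attentions, net=config.net,
            pretrained=False).to(device)
feature_center = torch.zeros(num_classes, config.num_attentions * net.num_features, device=device)

cross_entropy_loss = nn.CrossEntropyLoss()
center_loss = CenterLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=config.learning_rate,
                            momentum=0.9, weight_decay=1e-5)

X = torch.randn(2, 3, *config.image_size, device=device)        # stand-in image batch
y = torch.randint(0, num_classes, (2,), device=device)          # stand-in labels

net.train()
optimizer.zero_grad()

# 1) raw images forward
y_pred_raw, feature_matrix, attention_map = net(X)

# 2) move the feature centers of the classes in this batch towards the batch features
feature_center_batch = F.normalize(feature_center[y], dim=-1)
feature_center[y] += config.beta * (feature_matrix.detach() - feature_center_batch)

# 3) attention cropping and dropping, then forward the augmented images
with torch.no_grad():
    crop_images = batch_augment(X, attention_map[:, :1, :, :], mode='crop',
                                theta=(0.4, 0.6), padding_ratio=0.1)
    drop_images = batch_augment(X, attention_map[:, 1:, :, :], mode='drop',
                                theta=(0.2, 0.5))
y_pred_crop, _, _ = net(crop_images)
y_pred_drop, _, _ = net(drop_images)

# 4) combined loss: mean cross-entropy over the three views plus the center loss
batch_loss = (cross_entropy_loss(y_pred_raw, y) +
              cross_entropy_loss(y_pred_crop, y) +
              cross_entropy_loss(y_pred_drop, y)) / 3. + \
             center_loss(feature_matrix, feature_center_batch)
batch_loss.backward()
optimizer.step()

In the actual script this step runs inside the epoch loop shown in train.py, with metrics accumulated by AverageMeter and TopKAccuracyMetric and checkpoints written by ModelCheckpoint.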