├── LICENSE
├── README.md
├── config.py
├── core
│   ├── anchors.py
│   ├── dataset.py
│   ├── model.py
│   ├── resnet.py
│   └── utils.py
├── test.py
└── train.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Ze Yang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# NTS-Net

This is a PyTorch implementation of the ECCV 2018 paper "Learning to Navigate for Fine-grained Classification" (Ze Yang, Tiange Luo, Dong Wang, Zhiqiang Hu, Jun Gao, Liwei Wang).

## Requirements
- python 3+
- pytorch 0.4+
- torchvision
- numpy
- Pillow

## Datasets
Download the [CUB-200-2011](http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz) dataset and extract it into the root directory as **CUB_200_2011**. You can also try other fine-grained datasets.

## Train the model
To train NTS-Net, just run ``python train.py``. You may need to change the configuration in ``config.py``. The parameter ``PROPOSAL_NUM`` corresponds to ``M`` in the original paper and ``CAT_NUM`` to ``K``. During training, the log file and checkpoint files are saved in the ``save_dir`` directory. Set the parameter ``resume`` to choose a checkpoint to resume from.

## Test the model
To test NTS-Net, just run ``python test.py``. You need to set ``test_model`` in ``config.py`` to choose the checkpoint used for testing.

## Model
We also provide a checkpoint trained by ourselves; you can download it from [here](https://drive.google.com/file/d/1F-eKqPRjlya5GH2HwTlLKNSPEUaxCu9H/view?usp=sharing). Testing with this model should give an 87.6% test accuracy.

## Reference
If you are interested in our work and want to cite it, please acknowledge the following paper:

```
@inproceedings{Yang2018Learning,
    author = {Yang, Ze and Luo, Tiange and Wang, Dong and Hu, Zhiqiang and Gao, Jun and Wang, Liwei},
    title = {Learning to Navigate for Fine-grained Classification},
    booktitle = {ECCV},
    year = {2018}
}
```
--------------------------------------------------------------------------------
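As a quick illustration of the checkpoint format (a sketch following what test.py does, not an extra script in the repo), the provided model can be loaded like this, assuming the downloaded file is saved as ``model.ckpt`` in the repository root and a CUDA device is available:

```python
# Minimal sketch: load the provided checkpoint for evaluation, as test.py does.
# Assumes 'model.ckpt' sits in the repository root and a GPU is available
# (attention_net hard-codes .cuda() calls in its forward pass).
import torch
from core import model
from config import PROPOSAL_NUM

net = model.attention_net(topN=PROPOSAL_NUM)
ckpt = torch.load('model.ckpt')
net.load_state_dict(ckpt['net_state_dict'])
net = net.cuda().eval()
```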
/config.py:
--------------------------------------------------------------------------------
BATCH_SIZE = 16
PROPOSAL_NUM = 6  # M in the paper: number of part proposals kept after NMS
CAT_NUM = 4       # K in the paper: number of part features used for the concat feature
INPUT_SIZE = (448, 448)  # (h, w)
LR = 0.001
WD = 1e-4
SAVE_FREQ = 1
resume = ''
test_model = 'model.ckpt'
save_dir = '/data_4t/yangz/models/'
--------------------------------------------------------------------------------
/core/anchors.py:
--------------------------------------------------------------------------------
import numpy as np
from config import INPUT_SIZE

_default_anchors_setting = (
    dict(layer='p3', stride=32, size=48, scale=[2 ** (1. / 3.), 2 ** (2. / 3.)], aspect_ratio=[0.667, 1, 1.5]),
    dict(layer='p4', stride=64, size=96, scale=[2 ** (1. / 3.), 2 ** (2. / 3.)], aspect_ratio=[0.667, 1, 1.5]),
    dict(layer='p5', stride=128, size=192, scale=[1, 2 ** (1. / 3.), 2 ** (2. / 3.)], aspect_ratio=[0.667, 1, 1.5]),
)


def generate_default_anchor_maps(anchors_setting=None, input_shape=INPUT_SIZE):
    """
    Generate the default anchors.

    :param anchors_setting: settings of all anchor layers
    :param input_shape: shape of input images, e.g. (h, w)
    :return: center_anchors: # anchors * 4 (oy, ox, h, w)
             edge_anchors: # anchors * 4 (y0, x0, y1, x1)
             anchor_areas: # anchors * 1 (area)
    """
    if anchors_setting is None:
        anchors_setting = _default_anchors_setting

    center_anchors = np.zeros((0, 4), dtype=np.float32)
    edge_anchors = np.zeros((0, 4), dtype=np.float32)
    anchor_areas = np.zeros((0,), dtype=np.float32)
    input_shape = np.array(input_shape, dtype=int)

    for anchor_info in anchors_setting:

        stride = anchor_info['stride']
        size = anchor_info['size']
        scales = anchor_info['scale']
        aspect_ratios = anchor_info['aspect_ratio']

        output_map_shape = np.ceil(input_shape.astype(np.float32) / stride)
        output_map_shape = output_map_shape.astype(int)  # np.int was removed in NumPy 1.20+
        output_shape = tuple(output_map_shape) + (4,)
        ostart = stride / 2.
        oy = np.arange(ostart, ostart + stride * output_shape[0], stride)
        oy = oy.reshape(output_shape[0], 1)
        ox = np.arange(ostart, ostart + stride * output_shape[1], stride)
        ox = ox.reshape(1, output_shape[1])
        center_anchor_map_template = np.zeros(output_shape, dtype=np.float32)
        center_anchor_map_template[:, :, 0] = oy
        center_anchor_map_template[:, :, 1] = ox
        for scale in scales:
            for aspect_ratio in aspect_ratios:
                center_anchor_map = center_anchor_map_template.copy()
                center_anchor_map[:, :, 2] = size * scale / float(aspect_ratio) ** 0.5
                center_anchor_map[:, :, 3] = size * scale * float(aspect_ratio) ** 0.5

                edge_anchor_map = np.concatenate((center_anchor_map[..., :2] - center_anchor_map[..., 2:4] / 2.,
                                                  center_anchor_map[..., :2] + center_anchor_map[..., 2:4] / 2.),
                                                 axis=-1)
                anchor_area_map = center_anchor_map[..., 2] * center_anchor_map[..., 3]
                center_anchors = np.concatenate((center_anchors, center_anchor_map.reshape(-1, 4)))
                edge_anchors = np.concatenate((edge_anchors, edge_anchor_map.reshape(-1, 4)))
                anchor_areas = np.concatenate((anchor_areas, anchor_area_map.reshape(-1)))

    return center_anchors, edge_anchors, anchor_areas


def hard_nms(cdds, topn=10, iou_thresh=0.25):
    if not (type(cdds).__module__ == 'numpy' and len(cdds.shape) == 2 and cdds.shape[1] >= 5):
        raise TypeError('cdds should be an N * 5+ ndarray (score, y0, x0, y1, x1, ...)')

    cdds = cdds.copy()
    indices = np.argsort(cdds[:, 0])
    cdds = cdds[indices]
    cdd_results = []

    res = cdds

    while res.any():
        # pop the highest-scoring remaining candidate
        cdd = res[-1]
        cdd_results.append(cdd)
        if len(cdd_results) == topn:
            return np.array(cdd_results)
        res = res[:-1]

        # suppress the remaining candidates whose IoU with it exceeds the threshold
        start_max = np.maximum(res[:, 1:3], cdd[1:3])
        end_min = np.minimum(res[:, 3:5], cdd[3:5])
        lengths = end_min - start_max
        intersec_map = lengths[:, 0] * lengths[:, 1]
        intersec_map[np.logical_or(lengths[:, 0] < 0, lengths[:, 1] < 0)] = 0
        iou_map_cur = intersec_map / ((res[:, 3] - res[:, 1]) * (res[:, 4] - res[:, 2]) + (cdd[3] - cdd[1]) * (
                cdd[4] - cdd[2]) - intersec_map)
        res = res[iou_map_cur < iou_thresh]

    return np.array(cdd_results)


if __name__ == '__main__':
    a = hard_nms(np.array([
        [0.4, 1, 10, 12, 20],
        [0.5, 1, 11, 11, 20],
        [0.55, 20, 30, 40, 50]
    ]), topn=100, iou_thresh=0.4)
    print(a)
--------------------------------------------------------------------------------
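For the default 448x448 input, the three levels produce 14x14, 7x7 and 4x4 score maps with 6, 6 and 9 anchors per cell, so ``generate_default_anchor_maps`` returns 6*196 + 6*49 + 9*16 = 1614 anchors. A quick sanity check (an illustrative snippet, not part of the repo):

```python
# Sanity check: count the default anchors for the 448x448 input size.
# Expected: 6*14*14 + 6*7*7 + 9*4*4 = 1614, matching the number of scores
# emitted by ProposalNet in core/model.py.
from core.anchors import generate_default_anchor_maps

center_anchors, edge_anchors, anchor_areas = generate_default_anchor_maps()
print(edge_anchors.shape)  # (1614, 4)
```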
/core/dataset.py:
--------------------------------------------------------------------------------
import numpy as np
import os
from PIL import Image
from torchvision import transforms
from config import INPUT_SIZE


class CUB():
    """CUB-200-2011 dataset. All images are loaded into memory up front."""

    def __init__(self, root, is_train=True, data_len=None):
        self.root = root
        self.is_train = is_train
        img_txt_file = open(os.path.join(self.root, 'images.txt'))
        label_txt_file = open(os.path.join(self.root, 'image_class_labels.txt'))
        train_val_file = open(os.path.join(self.root, 'train_test_split.txt'))
        img_name_list = []
        for line in img_txt_file:
            img_name_list.append(line[:-1].split(' ')[-1])
        label_list = []
        for line in label_txt_file:
            label_list.append(int(line[:-1].split(' ')[-1]) - 1)
        train_test_list = []
        for line in train_val_file:
            train_test_list.append(int(line[:-1].split(' ')[-1]))
        train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
        test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
        # scipy.misc.imread was removed in SciPy 1.2; read the images with PIL instead.
        if self.is_train:
            self.train_img = [np.asarray(Image.open(os.path.join(self.root, 'images', train_file)))
                              for train_file in train_file_list[:data_len]]
            self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
        if not self.is_train:
            self.test_img = [np.asarray(Image.open(os.path.join(self.root, 'images', test_file)))
                             for test_file in test_file_list[:data_len]]
            self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]

    def __getitem__(self, index):
        if self.is_train:
            img, target = self.train_img[index], self.train_label[index]
            if len(img.shape) == 2:
                # grayscale image: replicate to three channels
                img = np.stack([img] * 3, 2)
            img = Image.fromarray(img, mode='RGB')
            img = transforms.Resize((600, 600), Image.BILINEAR)(img)
            img = transforms.RandomCrop(INPUT_SIZE)(img)
            img = transforms.RandomHorizontalFlip()(img)
            img = transforms.ToTensor()(img)
            img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(img)

        else:
            img, target = self.test_img[index], self.test_label[index]
            if len(img.shape) == 2:
                img = np.stack([img] * 3, 2)
            img = Image.fromarray(img, mode='RGB')
            img = transforms.Resize((600, 600), Image.BILINEAR)(img)
            img = transforms.CenterCrop(INPUT_SIZE)(img)
            img = transforms.ToTensor()(img)
            img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(img)

        return img, target

    def __len__(self):
        if self.is_train:
            return len(self.train_label)
        else:
            return len(self.test_label)


if __name__ == '__main__':
    dataset = CUB(root='./CUB_200_2011')
    print(len(dataset.train_img))
    print(len(dataset.train_label))
    for data in dataset:
        print(data[0].size(), data[1])
    dataset = CUB(root='./CUB_200_2011', is_train=False)
    print(len(dataset.test_img))
    print(len(dataset.test_label))
    for data in dataset:
        print(data[0].size(), data[1])
--------------------------------------------------------------------------------
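The README suggests trying other fine-grained datasets. A minimal sketch of an alternative loader (an illustration using torchvision's stock ImageFolder with the same training transforms; the dataset path is a placeholder):

```python
# Illustrative alternative: a folder-per-class dataset with the same training
# transforms as core/dataset.py. 'path/to/dataset/train' is a placeholder.
from torchvision import datasets, transforms
from config import INPUT_SIZE

train_transform = transforms.Compose([
    transforms.Resize((600, 600)),
    transforms.RandomCrop(INPUT_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
trainset = datasets.ImageFolder('path/to/dataset/train', transform=train_transform)
```

Note that the number of classes (200) is hard-coded in core/model.py and would have to match the new dataset.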
/core/model.py:
--------------------------------------------------------------------------------
from torch import nn
import torch
import torch.nn.functional as F
from core import resnet
import numpy as np
from core.anchors import generate_default_anchor_maps, hard_nms
from config import CAT_NUM, PROPOSAL_NUM


class ProposalNet(nn.Module):
    def __init__(self):
        super(ProposalNet, self).__init__()
        self.down1 = nn.Conv2d(2048, 128, 3, 1, 1)
        self.down2 = nn.Conv2d(128, 128, 3, 2, 1)
        self.down3 = nn.Conv2d(128, 128, 3, 2, 1)
        self.ReLU = nn.ReLU()
        # one score per anchor: 6 anchors per cell on p3 and p4, 9 on p5
        self.tidy1 = nn.Conv2d(128, 6, 1, 1, 0)
        self.tidy2 = nn.Conv2d(128, 6, 1, 1, 0)
        self.tidy3 = nn.Conv2d(128, 9, 1, 1, 0)

    def forward(self, x):
        batch_size = x.size(0)
        d1 = self.ReLU(self.down1(x))
        d2 = self.ReLU(self.down2(d1))
        d3 = self.ReLU(self.down3(d2))
        t1 = self.tidy1(d1).view(batch_size, -1)
        t2 = self.tidy2(d2).view(batch_size, -1)
        t3 = self.tidy3(d3).view(batch_size, -1)
        return torch.cat((t1, t2, t3), dim=1)


class attention_net(nn.Module):
    def __init__(self, topN=4):
        super(attention_net, self).__init__()
        self.pretrained_model = resnet.resnet50(pretrained=True)
        self.pretrained_model.avgpool = nn.AdaptiveAvgPool2d(1)
        self.pretrained_model.fc = nn.Linear(512 * 4, 200)
        self.proposal_net = ProposalNet()
        self.topN = topN
        self.concat_net = nn.Linear(2048 * (CAT_NUM + 1), 200)
        self.partcls_net = nn.Linear(512 * 4, 200)
        _, edge_anchors, _ = generate_default_anchor_maps()
        # the image is padded by 224 on each side, so shift the anchors accordingly
        self.pad_side = 224
        self.edge_anchors = (edge_anchors + 224).astype(int)  # np.int was removed in NumPy 1.20+

    def forward(self, x):
        resnet_out, rpn_feature, feature = self.pretrained_model(x)
        x_pad = F.pad(x, (self.pad_side, self.pad_side, self.pad_side, self.pad_side), mode='constant', value=0)
        batch = x.size(0)
        # rpn scores have shape batch * nb_anchor
        rpn_score = self.proposal_net(rpn_feature.detach())
        # each candidate row is (score, y0, x0, y1, x1, anchor index)
        all_cdds = [
            np.concatenate((x.reshape(-1, 1), self.edge_anchors.copy(), np.arange(0, len(x)).reshape(-1, 1)), axis=1)
            for x in rpn_score.data.cpu().numpy()]
        top_n_cdds = [hard_nms(x, topn=self.topN, iou_thresh=0.25) for x in all_cdds]
        top_n_cdds = np.array(top_n_cdds)
        top_n_index = top_n_cdds[:, :, -1].astype(int)
        top_n_index = torch.from_numpy(top_n_index).cuda()
        top_n_prob = torch.gather(rpn_score, dim=1, index=top_n_index)
        part_imgs = torch.zeros([batch, self.topN, 3, 224, 224]).cuda()
        for i in range(batch):
            for j in range(self.topN):
                [y0, x0, y1, x1] = top_n_cdds[i][j, 1:5].astype(int)
                part_imgs[i:i + 1, j] = F.interpolate(x_pad[i:i + 1, :, y0:y1, x0:x1], size=(224, 224),
                                                      mode='bilinear', align_corners=True)
        part_imgs = part_imgs.view(batch * self.topN, 3, 224, 224)
        _, _, part_features = self.pretrained_model(part_imgs.detach())
        part_feature = part_features.view(batch, self.topN, -1)
        part_feature = part_feature[:, :CAT_NUM, ...].contiguous()
        part_feature = part_feature.view(batch, -1)
        # concat_logits has shape B * 200
        concat_out = torch.cat([part_feature, feature], dim=1)
        concat_logits = self.concat_net(concat_out)
        raw_logits = resnet_out
        # part_logits has shape B * N * 200
        part_logits = self.partcls_net(part_features).view(batch, self.topN, -1)
        return [raw_logits, concat_logits, part_logits, top_n_index, top_n_prob]


def list_loss(logits, targets):
    # per-sample negative log-likelihood of the target class
    temp = F.log_softmax(logits, -1)
    loss = [-temp[i][targets[i].item()] for i in range(logits.size(0))]
    return torch.stack(loss)


def ranking_loss(score, targets, proposal_num=PROPOSAL_NUM):
    # pairwise hinge loss: proposals with smaller classification loss
    # should receive larger navigator scores
    loss = torch.zeros(1).cuda()  # Variable() is a no-op since PyTorch 0.4
    batch_size = score.size(0)
    for i in range(proposal_num):
        targets_p = (targets > targets[:, i].unsqueeze(1)).type(torch.cuda.FloatTensor)
        pivot = score[:, i].unsqueeze(1)
        loss_p = (1 - pivot + score) * targets_p
        loss_p = torch.sum(F.relu(loss_p))
        loss += loss_p
    return loss / batch_size
--------------------------------------------------------------------------------
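ProposalNet emits exactly one score per default anchor: 6 per cell on the 14x14 and 7x7 maps and 9 on the 4x4 map. An illustrative smoke test (CPU-only, not part of the repo):

```python
# Smoke test: ProposalNet's output length equals the default anchor count.
import torch
from core.model import ProposalNet

rpn = ProposalNet()
features = torch.randn(1, 2048, 14, 14)  # layer4 feature map of a 448x448 image
scores = rpn(features)
print(scores.shape)  # torch.Size([1, 1614]) = 6*14*14 + 6*7*7 + 9*4*4
```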
/core/resnet.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
import math
import torch.utils.model_zoo as model_zoo

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        feature1 = x  # layer4 feature map, consumed by ProposalNet
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        # use the functional form so dropout is disabled in eval mode; a
        # freshly constructed nn.Dropout module is always in training mode
        x = F.dropout(x, p=0.5, training=self.training)
        feature2 = x  # pooled global feature
        x = self.fc(x)

        return x, feature1, feature2


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
    return model


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return model
--------------------------------------------------------------------------------
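Unlike the stock torchvision ResNet, this variant returns three values: the logits, the layer4 feature map, and the pooled feature vector; attention_net later swaps ``avgpool`` for ``AdaptiveAvgPool2d(1)`` to handle 448x448 inputs. An illustrative shape check at the 224x224 part-image resolution (``pretrained=False`` to avoid a download):

```python
# Shape check of the three-output forward pass at 224x224.
import torch
from core.resnet import resnet50

net = resnet50(pretrained=False).eval()
with torch.no_grad():
    logits, fmap, feat = net(torch.randn(1, 3, 224, 224))
print(logits.shape, fmap.shape, feat.shape)
# torch.Size([1, 1000]) torch.Size([1, 2048, 7, 7]) torch.Size([1, 2048])
```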
/core/utils.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import os
import shutil
import sys
import time
import logging

# shutil.get_terminal_size falls back to a default width when stdout is not a
# terminal, unlike the original `stty size` call, which crashed in that case.
term_width = shutil.get_terminal_size().columns

TOTAL_BAR_LENGTH = 40.
last_time = time.time()
begin_time = last_time


def progress_bar(current, total, msg=None):
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH * current / total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append('  Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    for i in range(term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width - int(TOTAL_BAR_LENGTH / 2)):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current + 1, total))

    if current < total - 1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()


def format_time(seconds):
    days = int(seconds / 3600 / 24)
    seconds = seconds - days * 3600 * 24
    hours = int(seconds / 3600)
    seconds = seconds - hours * 3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes * 60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds * 1000)

    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f


def init_log(output_dir):
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(message)s',
                        datefmt='%Y%m%d-%H:%M:%S',
                        filename=os.path.join(output_dir, 'log.log'),
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console)
    return logging


if __name__ == '__main__':
    pass
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import os
import torch.utils.data
from torch.nn import DataParallel
from config import BATCH_SIZE, PROPOSAL_NUM, test_model
from core import model, dataset
from core.utils import progress_bar

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
if not test_model:
    raise ValueError('please set test_model in config.py to choose the checkpoint!')
# read dataset
trainset = dataset.CUB(root='./CUB_200_2011', is_train=True, data_len=None)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=8, drop_last=False)
testset = dataset.CUB(root='./CUB_200_2011', is_train=False, data_len=None)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=8, drop_last=False)
# define model
net = model.attention_net(topN=PROPOSAL_NUM)
ckpt = torch.load(test_model)
net.load_state_dict(ckpt['net_state_dict'])
net = net.cuda()
net = DataParallel(net)
criterion = torch.nn.CrossEntropyLoss()

# evaluate on train set
train_loss = 0
train_correct = 0
total = 0
net.eval()

for i, data in enumerate(trainloader):
    with torch.no_grad():
        img, label = data[0].cuda(), data[1].cuda()
        batch_size = img.size(0)
        _, concat_logits, _, _, _ = net(img)
        # calculate loss
        concat_loss = criterion(concat_logits, label)
        # calculate accuracy
        _, concat_predict = torch.max(concat_logits, 1)
        total += batch_size
        train_correct += torch.sum(concat_predict.data == label.data)
        train_loss += concat_loss.item() * batch_size
        progress_bar(i, len(trainloader), 'eval on train set')

train_acc = float(train_correct) / total
train_loss = train_loss / total
print('train set loss: {:.3f} and train set acc: {:.3f} total sample: {}'.format(train_loss, train_acc, total))


# evaluate on test set
test_loss = 0
test_correct = 0
total = 0
for i, data in enumerate(testloader):
    with torch.no_grad():
        img, label = data[0].cuda(), data[1].cuda()
        batch_size = img.size(0)
        _, concat_logits, _, _, _ = net(img)
        # calculate loss
        concat_loss = criterion(concat_logits, label)
        # calculate accuracy
        _, concat_predict = torch.max(concat_logits, 1)
        total += batch_size
        test_correct += torch.sum(concat_predict.data == label.data)
        test_loss += concat_loss.item() * batch_size
        progress_bar(i, len(testloader), 'eval on test set')

test_acc = float(test_correct) / total
test_loss = test_loss / total
print('test set loss: {:.3f} and test set acc: {:.3f} total sample: {}'.format(test_loss, test_acc, total))

print('finished testing')
--------------------------------------------------------------------------------
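The same evaluation loop appears twice here and twice more in train.py; a sketch of factoring it into one helper (an illustration, keeping the original's CUDA assumption):

```python
# Illustrative refactor of the duplicated evaluation loop.
import torch

def evaluate(net, loader, criterion):
    net.eval()
    loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for img, label in loader:
            img, label = img.cuda(), label.cuda()
            _, concat_logits, _, _, _ = net(img)
            loss_sum += criterion(concat_logits, label).item() * img.size(0)
            _, predict = torch.max(concat_logits, 1)
            correct += torch.sum(predict == label).item()
            total += img.size(0)
    return loss_sum / total, float(correct) / total
```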
/train.py:
--------------------------------------------------------------------------------
import os
import torch.utils.data
from torch.nn import DataParallel
from datetime import datetime
from torch.optim.lr_scheduler import MultiStepLR
from config import BATCH_SIZE, PROPOSAL_NUM, SAVE_FREQ, LR, WD, resume, save_dir
from core import model, dataset
from core.utils import init_log, progress_bar

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
start_epoch = 1
save_dir = os.path.join(save_dir, datetime.now().strftime('%Y%m%d_%H%M%S'))
if os.path.exists(save_dir):
    raise FileExistsError('model dir exists!')
os.makedirs(save_dir)
logging = init_log(save_dir)
_print = logging.info

# read dataset
trainset = dataset.CUB(root='./CUB_200_2011', is_train=True, data_len=None)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=8, drop_last=False)
testset = dataset.CUB(root='./CUB_200_2011', is_train=False, data_len=None)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=8, drop_last=False)
# define model
net = model.attention_net(topN=PROPOSAL_NUM)
if resume:
    ckpt = torch.load(resume)
    net.load_state_dict(ckpt['net_state_dict'])
    start_epoch = ckpt['epoch'] + 1
criterion = torch.nn.CrossEntropyLoss()

# define optimizers: one SGD optimizer per sub-network
raw_parameters = list(net.pretrained_model.parameters())
part_parameters = list(net.proposal_net.parameters())
concat_parameters = list(net.concat_net.parameters())
partcls_parameters = list(net.partcls_net.parameters())

raw_optimizer = torch.optim.SGD(raw_parameters, lr=LR, momentum=0.9, weight_decay=WD)
concat_optimizer = torch.optim.SGD(concat_parameters, lr=LR, momentum=0.9, weight_decay=WD)
part_optimizer = torch.optim.SGD(part_parameters, lr=LR, momentum=0.9, weight_decay=WD)
partcls_optimizer = torch.optim.SGD(partcls_parameters, lr=LR, momentum=0.9, weight_decay=WD)
schedulers = [MultiStepLR(raw_optimizer, milestones=[60, 100], gamma=0.1),
              MultiStepLR(concat_optimizer, milestones=[60, 100], gamma=0.1),
              MultiStepLR(part_optimizer, milestones=[60, 100], gamma=0.1),
              MultiStepLR(partcls_optimizer, milestones=[60, 100], gamma=0.1)]
net = net.cuda()
net = DataParallel(net)

for epoch in range(start_epoch, 500):
    # step the schedulers at the start of each epoch (the PyTorch 0.4
    # convention; PyTorch >= 1.1 expects scheduler.step() to come after
    # the epoch's optimizer steps instead)
    for scheduler in schedulers:
        scheduler.step()

    # begin training
    _print('--' * 50)
    net.train()
    for i, data in enumerate(trainloader):
        img, label = data[0].cuda(), data[1].cuda()
        batch_size = img.size(0)
        raw_optimizer.zero_grad()
        part_optimizer.zero_grad()
        concat_optimizer.zero_grad()
        partcls_optimizer.zero_grad()

        raw_logits, concat_logits, part_logits, _, top_n_prob = net(img)
        part_loss = model.list_loss(part_logits.view(batch_size * PROPOSAL_NUM, -1),
                                    label.unsqueeze(1).repeat(1, PROPOSAL_NUM).view(-1)).view(batch_size, PROPOSAL_NUM)
        raw_loss = criterion(raw_logits, label)
        concat_loss = criterion(concat_logits, label)
        rank_loss = model.ranking_loss(top_n_prob, part_loss)
        partcls_loss = criterion(part_logits.view(batch_size * PROPOSAL_NUM, -1),
                                 label.unsqueeze(1).repeat(1, PROPOSAL_NUM).view(-1))

        total_loss = raw_loss + rank_loss + concat_loss + partcls_loss
        total_loss.backward()
        raw_optimizer.step()
        part_optimizer.step()
        concat_optimizer.step()
        partcls_optimizer.step()
        progress_bar(i, len(trainloader), 'train')

    if epoch % SAVE_FREQ == 0:
        # evaluate on train set
        train_loss = 0
        train_correct = 0
        total = 0
        net.eval()
        for i, data in enumerate(trainloader):
            with torch.no_grad():
                img, label = data[0].cuda(), data[1].cuda()
                batch_size = img.size(0)
                _, concat_logits, _, _, _ = net(img)
                # calculate loss
                concat_loss = criterion(concat_logits, label)
                # calculate accuracy
                _, concat_predict = torch.max(concat_logits, 1)
                total += batch_size
                train_correct += torch.sum(concat_predict.data == label.data)
                train_loss += concat_loss.item() * batch_size
                progress_bar(i, len(trainloader), 'eval train set')

        train_acc = float(train_correct) / total
        train_loss = train_loss / total

        _print(
            'epoch:{} - train loss: {:.3f} and train acc: {:.3f} total sample: {}'.format(
                epoch,
                train_loss,
                train_acc,
                total))

        # evaluate on test set
        test_loss = 0
        test_correct = 0
        total = 0
        for i, data in enumerate(testloader):
            with torch.no_grad():
                img, label = data[0].cuda(), data[1].cuda()
                batch_size = img.size(0)
                _, concat_logits, _, _, _ = net(img)
                # calculate loss
                concat_loss = criterion(concat_logits, label)
                # calculate accuracy
                _, concat_predict = torch.max(concat_logits, 1)
                total += batch_size
                test_correct += torch.sum(concat_predict.data == label.data)
                test_loss += concat_loss.item() * batch_size
                progress_bar(i, len(testloader), 'eval test set')

        test_acc = float(test_correct) / total
        test_loss = test_loss / total
        _print(
            'epoch:{} - test loss: {:.3f} and test acc: {:.3f} total sample: {}'.format(
                epoch,
                test_loss,
                test_acc,
                total))

        # save model
        net_state_dict = net.module.state_dict()
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        torch.save({
            'epoch': epoch,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'test_loss': test_loss,
            'test_acc': test_acc,
            'net_state_dict': net_state_dict},
            os.path.join(save_dir, '%03d.ckpt' % epoch))

print('finished training')
--------------------------------------------------------------------------------
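All four optimizers share one schedule: the learning rate starts at 0.001 and is multiplied by 0.1 at epochs 60 and 100. An illustrative check of MultiStepLR, mirroring the step-at-epoch-start pattern in train.py (newer PyTorch versions warn that no optimizer.step() precedes it, which is harmless here):

```python
# Illustrative check of the shared schedule:
# 1e-3, then x0.1 at epoch 60, and x0.1 again at epoch 100.
import torch
from torch.optim.lr_scheduler import MultiStepLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.001, momentum=0.9)
scheduler = MultiStepLR(optimizer, milestones=[60, 100], gamma=0.1)
for epoch in range(1, 121):
    scheduler.step()
    if epoch in (59, 60, 100):
        print(epoch, optimizer.param_groups[0]['lr'])
# prints 0.001 at epoch 59, 0.0001 at epoch 60, ~1e-05 at epoch 100
```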