├── Framework.png
├── README.md
├── datasets.py
├── models.py
├── train.py
└── utils.py

--------------------------------------------------------------------------------
/Framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PeiqinZhuang/API-Net/c4996c7fec3fbd46fe2873f74e8ba154218b7e7f/Framework.png

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Learning Attentive Pairwise Interaction for Fine-Grained Classification (API-Net)
Peiqin Zhuang, Yali Wang, Yu Qiao
# Introduction:
To effectively identify contrastive clues among highly confused categories, we propose a simple but effective Attentive Pairwise Interaction Network (API-Net), which progressively recognizes a pair of fine-grained images through interaction. We first learn a mutual vector to capture semantic differences in the input pair, and then compare this mutual vector with each individual vector to highlight the semantic differences of each image. We also introduce a score-ranking regularization to promote the priorities of these features. For more details, please refer to [our paper](https://www.aaai.org/Papers/AAAI/2020GB/AAAI-ZhuangP.2505.pdf).
# Framework:
![Framework](/Framework.png)
# Dependencies:
* Python 2.7
* PyTorch 0.4.1
* torchvision 0.2.0
# How to use:
```
python train.py
```
# Citing:
If you find this code helpful in your work, please kindly cite the following paper.
```
@inproceedings{zhuang2020learning,
  title={Learning Attentive Pairwise Interaction for Fine-Grained Classification},
  author={Zhuang, Peiqin and Wang, Yali and Qiao, Yu},
  booktitle={AAAI},
  pages={13130--13137},
  year={2020}
}
```
# Contact:
Please feel free to contact zpq0316@163.com or {yl.wang, yu.qiao}@siat.ac.cn if you have any questions.
# Acknowledgement:
Some of the code is borrowed from [siamese-triplet](https://github.com/adambielski/siamese-triplet) and [triplet-reid-pytorch](https://github.com/CoinCheung/triplet-reid-pytorch). Many thanks to them.
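# Data preparation:
`datasets.py` reads two plain-text list files, `train.txt` and `val.txt` (their paths are hard-coded in `datasets.py` and should be changed to your own). Each line contains an image path and an integer class label separated by a space; labels should be 0-based, since the classifier has 200 outputs indexed 0-199 for CUB-200-2011. An illustrative example (file names below are hypothetical):
```
/path/to/CUB_200_2011/images/001.Black_footed_Albatross/img_0001.jpg 0
/path/to/CUB_200_2011/images/002.Laysan_Albatross/img_0002.jpg 1
```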
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler
from PIL import Image
import numpy as np


def default_loader(path):
    """Load an image as RGB; log unreadable paths and return a blank image instead."""
    try:
        img = Image.open(path).convert('RGB')
    except (IOError, OSError):
        with open('read_error.txt', 'a') as fid:
            fid.write(path + '\n')
        return Image.new('RGB', (224, 224), 'white')
    return img


class RandomDataset(Dataset):
    """Validation dataset: each line of val.txt is 'image_path label'."""

    def __init__(self, transform=None, dataloader=default_loader):
        self.transform = transform
        self.dataloader = dataloader

        # Change this path to your own val.txt.
        with open('/home/pqzhuang/data/CUB/CUB_200_2011/val.txt', 'r') as fid:
            self.imglist = fid.readlines()

    def __getitem__(self, index):
        image_path, label = self.imglist[index].strip().split()
        img = self.dataloader(image_path)
        img = self.transform(img)
        label = torch.LongTensor([int(label)])
        return [img, label]

    def __len__(self):
        return len(self.imglist)


class BatchDataset(Dataset):
    """Training dataset: each line of train.txt is 'image_path label'."""

    def __init__(self, transform=None, dataloader=default_loader):
        self.transform = transform
        self.dataloader = dataloader

        # Change this path to your own train.txt.
        with open('/home/pqzhuang/data/CUB/CUB_200_2011/train.txt', 'r') as fid:
            self.imglist = fid.readlines()

        # Cache all labels up front so BalancedBatchSampler can group indices by class.
        labels = [int(line.strip().split()[1]) for line in self.imglist]
        self.labels = torch.LongTensor(np.array(labels))

    def __getitem__(self, index):
        image_path, label = self.imglist[index].strip().split()
        img = self.dataloader(image_path)
        img = self.transform(img)
        label = torch.LongTensor([int(label)])
        return [img, label]

    def __len__(self):
        return len(self.imglist)


class BalancedBatchSampler(BatchSampler):
    """Yields batches of n_classes * n_samples indices: n_samples images from
    each of n_classes randomly chosen classes, reshuffling a class's index
    pool once it is exhausted."""

    def __init__(self, dataset, n_classes, n_samples):
        self.labels = dataset.labels
        self.labels_set = list(set(self.labels.numpy()))
        self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0]
                                 for label in self.labels_set}
        for l in self.labels_set:
            np.random.shuffle(self.label_to_indices[l])
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.dataset = dataset
        self.batch_size = self.n_samples * self.n_classes

    def __iter__(self):
        self.count = 0
        while self.count + self.batch_size < len(self.dataset):
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            indices = []
            for class_ in classes:
                start = self.used_label_indices_count[class_]
                indices.extend(self.label_to_indices[class_][start:start + self.n_samples])
                self.used_label_indices_count[class_] += self.n_samples
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.n_classes * self.n_samples

    def __len__(self):
        return len(self.dataset) // self.batch_size
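
if __name__ == '__main__':
    # Minimal sanity check (illustrative only; assumes train.txt exists at the
    # path above). Each batch drawn by BalancedBatchSampler should contain
    # n_classes * n_samples indices, with exactly n_samples per sampled class.
    dataset = BatchDataset(transform=None)
    sampler = BalancedBatchSampler(dataset, n_classes=5, n_samples=2)
    batch = next(iter(sampler))
    print(len(batch))  # 10 indices
    print(sorted(int(dataset.labels[j]) for j in batch))  # 5 labels, each twice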
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torchvision import models
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def pdist(vectors):
    """Pairwise squared Euclidean distances: ||a - b||^2 = ||a||^2 - 2ab + ||b||^2."""
    distance_matrix = (-2 * vectors.mm(torch.t(vectors))
                       + vectors.pow(2).sum(dim=1).view(1, -1)
                       + vectors.pow(2).sum(dim=1).view(-1, 1))
    return distance_matrix


class API_Net(nn.Module):
    def __init__(self):
        super(API_Net, self).__init__()

        # ResNet-101 backbone without its average pooling and fc layers.
        resnet101 = models.resnet101(pretrained=True)
        layers = list(resnet101.children())[:-2]

        self.conv = nn.Sequential(*layers)
        # 448x448 inputs give a 14x14 feature map, so this pools it to 1x1.
        self.avg = nn.AvgPool2d(kernel_size=14, stride=1)
        self.map1 = nn.Linear(2048 * 2, 512)
        self.map2 = nn.Linear(512, 2048)
        self.fc = nn.Linear(2048, 200)
        self.drop = nn.Dropout(p=0.5)
        self.sigmoid = nn.Sigmoid()

    def forward(self, images, targets=None, flag='train'):
        conv_out = self.conv(images)
        pool_out = self.avg(conv_out).squeeze()

        if flag == 'train':
            intra_pairs, inter_pairs, \
                intra_labels, inter_labels = self.get_pairs(pool_out, targets)

            features1 = torch.cat([pool_out[intra_pairs[:, 0]], pool_out[inter_pairs[:, 0]]], dim=0)
            features2 = torch.cat([pool_out[intra_pairs[:, 1]], pool_out[inter_pairs[:, 1]]], dim=0)
            labels1 = torch.cat([intra_labels[:, 0], inter_labels[:, 0]], dim=0)
            labels2 = torch.cat([intra_labels[:, 1], inter_labels[:, 1]], dim=0)

            # Mutual vector: an MLP maps the concatenated pair features back to 2048-d.
            mutual_features = torch.cat([features1, features2], dim=1)
            map1_out = self.map1(mutual_features)
            map2_out = self.map2(self.drop(map1_out))

            # Gate vectors: the mutual vector is compared channel-wise with each
            # individual vector to highlight its distinctive channels.
            gate1 = self.sigmoid(torch.mul(map2_out, features1))
            gate2 = self.sigmoid(torch.mul(map2_out, features2))

            # Pairwise interaction: each image is activated by its own gate (self)
            # and by its partner's gate (other), with a residual connection.
            features1_self = torch.mul(gate1, features1) + features1
            features1_other = torch.mul(gate2, features1) + features1
            features2_self = torch.mul(gate2, features2) + features2
            features2_other = torch.mul(gate1, features2) + features2

            logit1_self = self.fc(self.drop(features1_self))
            logit1_other = self.fc(self.drop(features1_other))
            logit2_self = self.fc(self.drop(features2_self))
            logit2_other = self.fc(self.drop(features2_other))

            return logit1_self, logit1_other, logit2_self, logit2_other, labels1, labels2

        elif flag == 'val':
            return self.fc(pool_out)

    def get_pairs(self, embeddings, labels):
        """For every image in the batch, find its nearest neighbour of the same
        class (intra pair) and of a different class (inter pair)."""
        distance_matrix = pdist(embeddings).detach().cpu().numpy()

        labels = labels.detach().cpu().numpy().reshape(-1, 1)
        num = labels.shape[0]
        dia_inds = np.diag_indices(num)
        lb_eqs = (labels == labels.T)
        lb_eqs[dia_inds] = False
        dist_same = distance_matrix.copy()
        dist_same[~lb_eqs] = np.inf
        intra_idxs = np.argmin(dist_same, axis=1)

        dist_diff = distance_matrix.copy()
        lb_eqs[dia_inds] = True
        dist_diff[lb_eqs] = np.inf
        inter_idxs = np.argmin(dist_diff, axis=1)

        intra_pairs = np.zeros([embeddings.shape[0], 2])
        inter_pairs = np.zeros([embeddings.shape[0], 2])
        intra_labels = np.zeros([embeddings.shape[0], 2])
        inter_labels = np.zeros([embeddings.shape[0], 2])
        for i in range(embeddings.shape[0]):
            intra_labels[i, 0] = labels[i]
            intra_labels[i, 1] = labels[intra_idxs[i]]
            intra_pairs[i, 0] = i
            intra_pairs[i, 1] = intra_idxs[i]

            inter_labels[i, 0] = labels[i]
            inter_labels[i, 1] = labels[inter_idxs[i]]
            inter_pairs[i, 0] = i
            inter_pairs[i, 1] = inter_idxs[i]

        intra_labels = torch.from_numpy(intra_labels).long().to(device)
        intra_pairs = torch.from_numpy(intra_pairs).long().to(device)
        inter_labels = torch.from_numpy(inter_labels).long().to(device)
        inter_pairs = torch.from_numpy(inter_pairs).long().to(device)

        return intra_pairs, inter_pairs, intra_labels, inter_labels
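
if __name__ == '__main__':
    # Shape smoke test (illustrative only; downloads ImageNet weights for
    # ResNet-101 on first run). API-Net expects 448x448 crops: the backbone
    # then yields a 14x14 feature map, matching the 14x14 pooling window above.
    model = API_Net().to(device)
    model.eval()
    with torch.no_grad():
        images = torch.randn(2, 3, 448, 448).to(device)
        logits = model(images, flag='val')
    print(logits.shape)  # torch.Size([2, 200])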
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import os
import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import numpy as np
from models import API_Net
from datasets import RandomDataset, BatchDataset, BalancedBatchSampler
from utils import accuracy, AverageMeter, save_checkpoint


parser = argparse.ArgumentParser(description='PyTorch API-Net training')
parser.add_argument('--exp_name', default=None, type=str,
                    help='name of experiment')
parser.add_argument('--data', metavar='DIR', default='',
                    help='path to dataset')
parser.add_argument('-j', '--workers', default=10, type=int, metavar='N',
                    help='number of data loading workers (default: 10)')
parser.add_argument('--epochs', default=100, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=100, type=int,
                    metavar='N', help='mini-batch size for validation (default: 100)')
parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                    metavar='W', help='weight decay (default: 5e-4)')
parser.add_argument('--print-freq', '-p', default=1, type=int,
                    metavar='N', help='print frequency (default: 1)')
parser.add_argument('--evaluate-freq', default=10, type=int,
                    help='the evaluation frequency')
parser.add_argument('--resume', default='./checkpoint.pth.tar', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: ./checkpoint.pth.tar)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--n_classes', default=30, type=int,
                    help='the number of classes per training batch')
parser.add_argument('--n_samples', default=4, type=int,
                    help='the number of samples per class')


best_prec1 = 0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
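
# Note on batch sizes: training batches are assembled by BalancedBatchSampler
# as n_classes * n_samples images (30 * 4 = 120 with the defaults above);
# --batch-size only controls the validation loader built in train() below.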

def main():
    global args, best_prec1
    args = parser.parse_args()
    torch.manual_seed(2)
    torch.cuda.manual_seed_all(2)
    np.random.seed(2)

    # create model
    model = API_Net()
    model = model.to(device)
    model.conv = nn.DataParallel(model.conv)

    # define loss function (criterion) and two optimizers: one for the backbone,
    # one for the pairwise-interaction and classifier layers
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer_conv = torch.optim.SGD(model.conv.parameters(), args.lr,
                                     momentum=args.momentum,
                                     weight_decay=args.weight_decay)

    fc_parameters = [value for name, value in model.named_parameters() if 'conv' not in name]
    optimizer_fc = torch.optim.SGD(fc_parameters, args.lr,
                                   momentum=args.momentum,
                                   weight_decay=args.weight_decay)

    if args.resume:
        if os.path.isfile(args.resume):
            print('loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer_conv.load_state_dict(checkpoint['optimizer_conv'])
            optimizer_fc.load_state_dict(checkpoint['optimizer_fc'])
            print('loaded checkpoint {} (epoch {})'.format(args.resume, checkpoint['epoch']))
        else:
            print('no checkpoint found at {}'.format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    train_dataset = BatchDataset(transform=transforms.Compose([
        transforms.Resize([512, 512]),
        transforms.RandomCrop([448, 448]),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225)
        )]))

    train_sampler = BalancedBatchSampler(train_dataset, args.n_classes, args.n_samples)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_sampler=train_sampler,
        num_workers=args.workers, pin_memory=True)
    scheduler_conv = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_conv, 100 * len(train_loader))
    scheduler_fc = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_fc, 100 * len(train_loader))

    step = 0
    print('START TIME: {}'.format(time.asctime(time.localtime(time.time()))))
    for epoch in range(args.start_epoch, args.epochs):
        step = train(train_loader, model, criterion, optimizer_conv, scheduler_conv,
                     optimizer_fc, scheduler_fc, epoch, step)


def train(train_loader, model, criterion, optimizer_conv, scheduler_conv, optimizer_fc, scheduler_fc, epoch, step):
    global best_prec1

    batch_time = AverageMeter()
    data_time = AverageMeter()
    softmax_losses = AverageMeter()
    rank_losses = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    end = time.time()
    rank_criterion = nn.MarginRankingLoss(margin=0.05)
    softmax_layer = nn.Softmax(dim=1).to(device)

    for i, (input, target) in enumerate(train_loader):
        # switch to train mode
        model.train()

        # measure data loading time
        data_time.update(time.time() - end)
        input_var = input.to(device)
        target_var = target.to(device).squeeze()

        # compute output
        logit1_self, logit1_other, logit2_self, logit2_other, labels1, labels2 = \
            model(input_var, target_var, flag='train')
        batch_size = logit1_self.shape[0]
        labels1 = labels1.to(device)
        labels2 = labels2.to(device)

        # stack the self- and other-activated logits of both images in each pair
        self_logits = torch.zeros(2 * batch_size, 200).to(device)
        other_logits = torch.zeros(2 * batch_size, 200).to(device)
        self_logits[:batch_size] = logit1_self
        self_logits[batch_size:] = logit2_self
        other_logits[:batch_size] = logit1_other
        other_logits[batch_size:] = logit2_other
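
        # The total loss combines (i) cross-entropy over all four logit groups
        # (self and other activations for both images of each pair) and (ii)
        # the score-ranking regularization: a margin ranking loss that pushes
        # each image's softmax score for its true class under its own gate
        # (self) above its score under the partner's gate (other) by at least
        # the margin of 0.05.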
        # compute loss
        logits = torch.cat([self_logits, other_logits], dim=0)
        targets = torch.cat([labels1, labels2, labels1, labels2], dim=0)
        softmax_loss = criterion(logits, targets)

        self_scores = softmax_layer(self_logits)[torch.arange(2 * batch_size).to(device).long(),
                                                 torch.cat([labels1, labels2], dim=0)]
        other_scores = softmax_layer(other_logits)[torch.arange(2 * batch_size).to(device).long(),
                                                   torch.cat([labels1, labels2], dim=0)]
        flag = torch.ones([2 * batch_size, ]).to(device)
        rank_loss = rank_criterion(self_scores, other_scores, flag)

        loss = softmax_loss + rank_loss

        # measure accuracy and record loss
        prec1 = accuracy(logits, targets, 1)
        prec5 = accuracy(logits, targets, 5)
        losses.update(loss.item(), 2 * batch_size)
        softmax_losses.update(softmax_loss.item(), 4 * batch_size)
        rank_losses.update(rank_loss.item(), 2 * batch_size)
        top1.update(prec1, 4 * batch_size)
        top5.update(prec5, 4 * batch_size)

        # compute gradient and do SGD step; the backbone (conv) optimizer is
        # held frozen for the first 8 epochs, so only the interaction and
        # classifier layers are updated early in training
        optimizer_conv.zero_grad()
        optimizer_fc.zero_grad()
        loss.backward()
        if epoch >= 8:
            optimizer_conv.step()
        optimizer_fc.step()
        scheduler_conv.step()
        scheduler_fc.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Time: {time}\nStep: {step}\t Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'SoftmaxLoss {softmax_loss.val:.4f} ({softmax_loss.avg:.4f})\t'
                  'RankLoss {rank_loss.val:.4f} ({rank_loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, softmax_loss=softmax_losses,
                      rank_loss=rank_losses, top1=top1, top5=top5, step=step,
                      time=time.asctime(time.localtime(time.time()))))

        if i == len(train_loader) - 1:
            # validate at the end of every epoch
            val_dataset = RandomDataset(transform=transforms.Compose([
                transforms.Resize([512, 512]),
                transforms.CenterCrop([448, 448]),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225)
                )]))
            val_loader = torch.utils.data.DataLoader(
                val_dataset, batch_size=args.batch_size, shuffle=False,
                num_workers=args.workers, pin_memory=True)
            prec1 = validate(val_loader, model, criterion)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer_conv': optimizer_conv.state_dict(),
                'optimizer_fc': optimizer_fc.state_dict(),
            }, is_best)

        step = step + 1
    return step
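
# Validation is triggered on the last batch of every epoch (the
# i == len(train_loader) - 1 branch in train() above); save_checkpoint in
# utils.py copies the best top-1 checkpoint to model_best.pth.tar.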

def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    softmax_losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()
    end = time.time()

    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            input_var = input.to(device)
            target_var = target.to(device).squeeze()

            # compute output: at validation time the model classifies single
            # images without pairwise interaction
            logits = model(input_var, targets=None, flag='val')
            softmax_loss = criterion(logits, target_var)

            prec1 = accuracy(logits, target_var, 1)
            prec5 = accuracy(logits, target_var, 5)
            softmax_losses.update(softmax_loss.item(), logits.size(0))
            top1.update(prec1, logits.size(0))
            top5.update(prec5, logits.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Time: {time}\nTest: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'SoftmaxLoss {softmax_loss.val:.4f} ({softmax_loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time, softmax_loss=softmax_losses,
                          top1=top1, top5=top5, time=time.asctime(time.localtime(time.time()))))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(top1=top1, top5=top5))

    return top1.avg


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import shutil


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Save the latest checkpoint and copy it to model_best.pth.tar if it is the best so far."""
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter(object):
    """Keeps track of the most recent value, average, sum, and count of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(scores, targets, k):
    """
    Computes top-k accuracy, from predicted and true labels.

    :param scores: scores from the model
    :param targets: true labels
    :param k: k in top-k accuracy
    :return: top-k accuracy, as a percentage
    """
    batch_size = targets.size(0)
    _, ind = scores.topk(k, 1, True, True)
    correct = ind.eq(targets.view(-1, 1).expand_as(ind))
    correct_total = correct.view(-1).float().sum()  # 0-dim tensor
    return correct_total.item() * (100.0 / batch_size)
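
if __name__ == '__main__':
    # Tiny worked example of accuracy() (illustrative only): two samples,
    # three classes.
    scores = torch.FloatTensor([[0.1, 0.7, 0.2],
                                [0.8, 0.05, 0.15]])
    targets = torch.LongTensor([1, 2])
    print(accuracy(scores, targets, 1))  # 50.0  -- only sample 0 is correct at top-1
    print(accuracy(scores, targets, 2))  # 100.0 -- both targets fall in the top-2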