├── README.md
├── algorithm.png
├── args_tiered.py
├── args_xent.py
├── test.py
├── test_transductive.py
├── torchFewShot
│   ├── __init__.py
│   ├── data_manager.py
│   ├── dataset_loader
│   │   ├── __init__.py
│   │   ├── test_loader.py
│   │   └── train_loader.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── miniImageNet.py
│   │   ├── miniImageNet_load.py
│   │   └── tieredImageNet.py
│   ├── losses.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── cam.py
│   │   ├── net.py
│   │   └── resnet12.py
│   ├── optimizers.py
│   ├── transforms.py
│   └── utils
│       ├── __init__.py
│       ├── avgmeter.py
│       ├── iotools.py
│       ├── logger.py
│       └── torchtools.py
└── train.py
/README.md: -------------------------------------------------------------------------------- 1 | # fewshot-CAN 2 | This repository contains the code for the paper: 3 |
4 | [**Cross Attention Network for Few-shot Classification**](https://arxiv.org/pdf/1910.07677.pdf) 5 |
6 | Ruibing Hou, Hong Chang, Bingpeng Ma, Shiguang Shan, Xilin Chen 7 |
8 | NeurIPS 2019 9 |

10 | ![algorithm.png](algorithm.png) 11 |

12 | 13 | ### Abstract 14 | 15 | Few-shot classification aims to recognize unlabeled samples from unseen classes given only a few labeled samples. The unseen classes and the low-data problem make few-shot classification very challenging. Many existing approaches extract features from the labeled and unlabeled samples independently; as a result, the features are not discriminative enough. In this work, we propose a novel Cross Attention 16 | Network to address the challenging problems in few-shot classification. First, a Cross Attention Module is introduced to deal with the problem of unseen classes. The module generates cross attention maps for each pair of class feature and query sample feature so as to highlight the target object regions, making the extracted features more discriminative. Second, a transductive inference algorithm is proposed to alleviate the low-data problem; it iteratively uses the unlabeled query set to augment the support set, thereby making the class features more representative. Extensive experiments on two benchmarks show that our method is a simple, effective and computationally efficient framework that outperforms the state of the art. 17 | 18 | ### Citation 19 | 20 | If you use this code for your research, please cite our paper: 21 | ``` 22 | @inproceedings{CAN, 23 | title={Cross Attention Network for Few-shot Classification}, 24 | author={Ruibing Hou and Hong Chang and Bingpeng Ma and Shiguang Shan and Xilin Chen}, 25 | booktitle={NeurIPS}, 26 | year={2019} 27 | } 28 | ``` 29 | 30 | ### Trained models 31 | We have released our trained CAN models for the miniImageNet and tieredImageNet benchmarks at [Google Drive](https://drive.google.com/drive/folders/1Gi3GvQB3Ypwu-QW_CVmToLC64fMv8TqO?usp=sharing). 32 | 33 | ### Platform 34 | This code was developed and tested with PyTorch 1.0.1. 35 | 36 | ### Acknowledgments 37 | 38 | This code is based on the implementation of [**Dynamic Few-Shot Visual Learning without Forgetting**](https://github.com/gidariss/FewShotWithoutForgetting).
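### Loading a trained model

To sanity-check a downloaded checkpoint, it can be restored with the `Model` class used by `test.py`. The snippet below is a minimal sketch, run from the repository root and mirroring the imports in `test.py`; the checkpoint path is a placeholder, and `scale_cls=7`, `num_classes=64` are the miniImageNet defaults from `args_xent.py`:

```python
# Minimal sketch of restoring a checkpoint; the path below is a placeholder.
import sys
sys.path.append('./torchFewShot')  # test.py does the same before importing

import torch
from torchFewShot.models.net import Model

model = Model(scale_cls=7, num_classes=64)  # defaults from args_xent.py
checkpoint = torch.load('path/to/checkpoint.pth.tar', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])  # key written by the training code
model.eval()
```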
39 | -------------------------------------------------------------------------------- /algorithm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blue-blue272/fewshot-CAN/6055985e82944c305f0fbf4d1b3e0eef1833d06c/algorithm.png -------------------------------------------------------------------------------- /args_tiered.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torchFewShot 3 | 4 | def argument_parser(): 5 | 6 | parser = argparse.ArgumentParser(description='Train image model with cross entropy loss') 7 | # ************************************************************ 8 | # Datasets (general) 9 | # ************************************************************ 10 | parser.add_argument('-d', '--dataset', type=str, default='tieredImageNet') 11 | parser.add_argument('--load', default=False) 12 | 13 | parser.add_argument('-j', '--workers', default=4, type=int, 14 | help="number of data loading workers (default: 4)") 15 | parser.add_argument('--height', type=int, default=84, 16 | help="height of an image (default: 84)") 17 | parser.add_argument('--width', type=int, default=84, 18 | help="width of an image (default: 84)") 19 | 20 | # ************************************************************ 21 | # Optimization options 22 | # ************************************************************ 23 | parser.add_argument('--optim', type=str, default='sgd', 24 | help="optimization algorithm (see optimizers.py)") 25 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 26 | help="initial learning rate") 27 | parser.add_argument('--weight-decay', default=5e-04, type=float, 28 | help="weight decay (default: 5e-04)") 29 | 30 | parser.add_argument('--max-epoch', default=80, type=int, 31 | help="maximum epochs to run") 32 | parser.add_argument('--start-epoch', default=0, type=int, 33 | help="manual epoch number (useful on restarts)") 34 | parser.add_argument('--stepsize', default=[60], nargs='+', type=int, 35 | help="stepsize to decay learning rate") 36 | parser.add_argument('--LUT_lr', default=[(20, 0.1), (40, 0.01), (60, 0.001), (80, 0.0001)], 37 | help="multistep to decay learning rate") 38 | 39 | parser.add_argument('--train-batch', default=4, type=int, 40 | help="train batch size") 41 | parser.add_argument('--test-batch', default=8, type=int, 42 | help="test batch size") 43 | 44 | 45 | # ************************************************************ 46 | # Architecture settings 47 | # ************************************************************ 48 | parser.add_argument('--num_classes', type=int, default=351) 49 | parser.add_argument('--scale_cls', type=int, default=7) 50 | 51 | # ************************************************************ 52 | # Miscs 53 | # ************************************************************ 54 | parser.add_argument('--save-dir', type=str, default='./result/tieredImageNet/CAM/5-shot-v2/') 55 | parser.add_argument('--resume', type=str, default='', metavar='PATH') 56 | parser.add_argument('--gpu-devices', default='0', type=str) 57 | 58 | # ************************************************************ 59 | # FewShot settting 60 | # ************************************************************ 61 | parser.add_argument('--nKnovel', type=int, default=5, 62 | help='number of novel categories') 63 | parser.add_argument('--nExemplars', type=int, default=5, 64 | help='number of training examples per novel category.') 65 | 66 | 
parser.add_argument('--train_nTestNovel', type=int, default=6 * 5, 67 | help='number of test examples for all the novel category when training') 68 | parser.add_argument('--train_epoch_size', type=int, default=13980, 69 | help='number of batches per epoch when training') 70 | parser.add_argument('--nTestNovel', type=int, default=15 * 5, 71 | help='number of test examples for all the novel category') 72 | parser.add_argument('--epoch_size', type=int, default=2000, 73 | help='number of batches per epoch') 74 | 75 | parser.add_argument('--phase', default='test', type=str, 76 | help='use test or val dataset to early stop') 77 | parser.add_argument('--seed', type=int, default=1) 78 | 79 | return parser 80 | 81 | -------------------------------------------------------------------------------- /args_xent.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torchFewShot 3 | 4 | def argument_parser(): 5 | 6 | parser = argparse.ArgumentParser(description='Train image model with cross entropy loss') 7 | # ************************************************************ 8 | # Datasets (general) 9 | # ************************************************************ 10 | parser.add_argument('-d', '--dataset', type=str, default='miniImageNet_load') 11 | parser.add_argument('--load', default=True) 12 | 13 | parser.add_argument('-j', '--workers', default=4, type=int, 14 | help="number of data loading workers (default: 4)") 15 | parser.add_argument('--height', type=int, default=84, 16 | help="height of an image (default: 84)") 17 | parser.add_argument('--width', type=int, default=84, 18 | help="width of an image (default: 84)") 19 | 20 | # ************************************************************ 21 | # Optimization options 22 | # ************************************************************ 23 | parser.add_argument('--optim', type=str, default='sgd', 24 | help="optimization algorithm (see optimizers.py)") 25 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 26 | help="initial learning rate") 27 | parser.add_argument('--weight-decay', default=5e-04, type=float, 28 | help="weight decay (default: 5e-04)") 29 | 30 | parser.add_argument('--max-epoch', default=90, type=int, 31 | help="maximum epochs to run") 32 | parser.add_argument('--start-epoch', default=0, type=int, 33 | help="manual epoch number (useful on restarts)") 34 | parser.add_argument('--stepsize', default=[60], nargs='+', type=int, 35 | help="stepsize to decay learning rate") 36 | parser.add_argument('--LUT_lr', default=[(60, 0.1), (70, 0.006), (80, 0.0012), (90, 0.00024)], 37 | help="multistep to decay learning rate") 38 | 39 | parser.add_argument('--train-batch', default=4, type=int, 40 | help="train batch size") 41 | parser.add_argument('--test-batch', default=4, type=int, 42 | help="test batch size") 43 | 44 | # ************************************************************ 45 | # Architecture settings 46 | # ************************************************************ 47 | parser.add_argument('--num_classes', type=int, default=64) 48 | parser.add_argument('--scale_cls', type=int, default=7) 49 | 50 | # ************************************************************ 51 | # Miscs 52 | # ************************************************************ 53 | parser.add_argument('--save-dir', type=str, default='./result/miniImageNet/CAM/5-shot-seed112/') 54 | parser.add_argument('--resume', type=str, default='', metavar='PATH') 55 | parser.add_argument('--gpu-devices', 
default='2', type=str) 56 | 57 | # ************************************************************ 58 | # FewShot setting 59 | # ************************************************************ 60 | parser.add_argument('--nKnovel', type=int, default=5, 61 | help='number of novel categories') 62 | parser.add_argument('--nExemplars', type=int, default=5, 63 | help='number of training examples per novel category.') 64 | 65 | parser.add_argument('--train_nTestNovel', type=int, default=6 * 5, 66 | help='number of test examples for all the novel categories when training') 67 | parser.add_argument('--train_epoch_size', type=int, default=1200, 68 | help='number of episodes per epoch when training') 69 | parser.add_argument('--nTestNovel', type=int, default=15 * 5, 70 | help='number of test examples for all the novel categories') 71 | parser.add_argument('--epoch_size', type=int, default=2000, 72 | help='number of episodes per epoch') 73 | 74 | parser.add_argument('--phase', default='test', type=str, 75 | help='use test or val dataset to early stop') 76 | parser.add_argument('--seed', type=int, default=1) 77 | 78 | return parser 79 | 80 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | 4 | import os 5 | import sys 6 | import time 7 | import datetime 8 | import argparse 9 | import os.path as osp 10 | import numpy as np 11 | import random 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.backends.cudnn as cudnn 16 | from torch.utils.data import DataLoader 17 | from torch.optim import lr_scheduler 18 | import torch.nn.functional as F 19 | sys.path.append('./torchFewShot') 20 | 21 | from torchFewShot.models.net import Model 22 | from torchFewShot.data_manager import DataManager 23 | from torchFewShot.losses import CrossEntropyLoss 24 | from torchFewShot.optimizers import init_optimizer 25 | 26 | from torchFewShot.utils.iotools import save_checkpoint, check_isfile 27 | from torchFewShot.utils.avgmeter import AverageMeter 28 | from torchFewShot.utils.logger import Logger 29 | from torchFewShot.utils.torchtools import one_hot, adjust_learning_rate 30 | 31 | parser = argparse.ArgumentParser(description='Test image model with 5-way classification') 32 | # Datasets 33 | parser.add_argument('-d', '--dataset', type=str, default='miniImageNet_load') 34 | parser.add_argument('--load', default=True) 35 | parser.add_argument('-j', '--workers', default=4, type=int, 36 | help="number of data loading workers (default: 4)") 37 | parser.add_argument('--height', type=int, default=84, 38 | help="height of an image (default: 84)") 39 | parser.add_argument('--width', type=int, default=84, 40 | help="width of an image (default: 84)") 41 | # Optimization options 42 | parser.add_argument('--train-batch', default=4, type=int, 43 | help="train batch size") 44 | parser.add_argument('--test-batch', default=8, type=int, 45 | help="test batch size") 46 | # Architecture 47 | parser.add_argument('--num_classes', type=int, default=64) 48 | parser.add_argument('--scale_cls', type=int, default=7) 49 | parser.add_argument('--save-dir', type=str, default='') 50 | parser.add_argument('--resume', type=str, default='', metavar='PATH') 51 | # FewShot setting 52 | parser.add_argument('--nKnovel', type=int, default=5, 53 | help='number of novel categories') 54 | parser.add_argument('--nExemplars', type=int, default=1, 55 | help='number
of training examples per novel category.') 56 | parser.add_argument('--train_nTestNovel', type=int, default=6 * 5, 57 | help='number of test examples for all the novel categories when training') 58 | parser.add_argument('--train_epoch_size', type=int, default=1200, 59 | help='number of episodes per epoch when training') 60 | parser.add_argument('--nTestNovel', type=int, default=15 * 5, 61 | help='number of test examples for all the novel categories') 62 | parser.add_argument('--epoch_size', type=int, default=2000, 63 | help='number of episodes per epoch') 64 | # Miscs 65 | parser.add_argument('--phase', default='test', type=str) 66 | parser.add_argument('--seed', type=int, default=1) 67 | parser.add_argument('--gpu-devices', default='1', type=str) 68 | 69 | args = parser.parse_args() 70 | 71 | 72 | def main(): 73 | torch.manual_seed(args.seed) 74 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices 75 | use_gpu = torch.cuda.is_available() 76 | 77 | sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt')) 78 | print("==========\nArgs:{}\n==========".format(args)) 79 | 80 | if use_gpu: 81 | print("Currently using GPU {}".format(args.gpu_devices)) 82 | cudnn.benchmark = True 83 | torch.cuda.manual_seed_all(args.seed) 84 | else: 85 | print("Currently using CPU (GPU is highly recommended)") 86 | 87 | print('Initializing image data manager') 88 | dm = DataManager(args, use_gpu) 89 | trainloader, testloader = dm.return_dataloaders() 90 | 91 | model = Model(scale_cls=args.scale_cls, num_classes=args.num_classes) 92 | # load the trained model weights 93 | checkpoint = torch.load(args.resume) 94 | model.load_state_dict(checkpoint['state_dict']) 95 | print("Loaded checkpoint from '{}'".format(args.resume)) 96 | 97 | if use_gpu: 98 | model = model.cuda() 99 | 100 | test(model, testloader, use_gpu) 101 | 102 | 103 | def test(model, testloader, use_gpu): 104 | accs = AverageMeter() 105 | test_accuracies = [] 106 | model.eval() 107 | 108 | with torch.no_grad(): 109 | for batch_idx, (images_train, labels_train, images_test, labels_test) in enumerate(testloader): 110 | if use_gpu: 111 | images_train = images_train.cuda() 112 | images_test = images_test.cuda() 113 | 114 | end = time.time() 115 | 116 | batch_size, num_train_examples, channels, height, width = images_train.size() 117 | num_test_examples = images_test.size(1) 118 | 119 | labels_train_1hot = one_hot(labels_train).cuda() 120 | labels_test_1hot = one_hot(labels_test).cuda() 121 | 122 | cls_scores = model(images_train, images_test, labels_train_1hot, labels_test_1hot) 123 | cls_scores = cls_scores.view(batch_size * num_test_examples, -1) 124 | labels_test = labels_test.view(batch_size * num_test_examples) 125 | 126 | _, preds = torch.max(cls_scores.detach().cpu(), 1) 127 | acc = (torch.sum(preds == labels_test.detach().cpu()).float()) / labels_test.size(0) 128 | accs.update(acc.item(), labels_test.size(0)) 129 | 130 | gt = (preds == labels_test.detach().cpu()).float() 131 | gt = gt.view(batch_size, num_test_examples).numpy() # [b, n] 132 | acc = np.sum(gt, 1) / num_test_examples 133 | acc = np.reshape(acc, (batch_size)) 134 | test_accuracies.append(acc) 135 | 136 | accuracy = accs.avg 137 | test_accuracies = np.array(test_accuracies) 138 | test_accuracies = np.reshape(test_accuracies, -1) 139 | stds = np.std(test_accuracies, 0) 140 | ci95 = 1.96 * stds / np.sqrt(args.epoch_size) 141 | print('Accuracy: {:.2%}, ci95: {:.2%}'.format(accuracy, ci95)) 142 | 143 | return accuracy 144 | 145 | 146 | if __name__ == '__main__': 147 | main() 148 |
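The `test()` function above reports the mean episode accuracy together with a 95% confidence interval, `1.96 * std / sqrt(epoch_size)`, computed over the `epoch_size` test episodes. Below is a self-contained sketch of the same statistic using synthetic per-episode accuracies; in the real script these values come from the evaluation loop:

```python
# Illustration of the interval printed by test(), on synthetic accuracies.
import numpy as np

episode_acc = np.random.uniform(0.70, 0.90, size=2000)  # stand-in for 2000 episodes
mean_acc = episode_acc.mean()
ci95 = 1.96 * episode_acc.std() / np.sqrt(episode_acc.size)  # normal approximation
print('Accuracy: {:.2%} +- {:.2%}'.format(mean_acc, ci95))
```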
-------------------------------------------------------------------------------- /test_transductive.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | 4 | import os 5 | import sys 6 | import time 7 | import datetime 8 | import argparse 9 | import os.path as osp 10 | import numpy as np 11 | import random 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.backends.cudnn as cudnn 16 | from torch.utils.data import DataLoader 17 | from torch.optim import lr_scheduler 18 | import torch.nn.functional as F 19 | sys.path.append('./torchFewShot') 20 | 21 | from torchFewShot.models.net import Model 22 | from torchFewShot.data_manager import DataManager 23 | from torchFewShot.losses import CrossEntropyLoss 24 | from torchFewShot.optimizers import init_optimizer 25 | 26 | from torchFewShot.utils.iotools import save_checkpoint, check_isfile 27 | from torchFewShot.utils.avgmeter import AverageMeter 28 | from torchFewShot.utils.logger import Logger 29 | from torchFewShot.utils.torchtools import one_hot, adjust_learning_rate 30 | 31 | parser = argparse.ArgumentParser(description='Test image model with 5-way classification') 32 | # Datasets 33 | parser.add_argument('-d', '--dataset', type=str, default='miniImageNet_load') 34 | parser.add_argument('--load', default=True) 35 | parser.add_argument('-j', '--workers', default=4, type=int, 36 | help="number of data loading workers (default: 4)") 37 | parser.add_argument('--height', type=int, default=84, 38 | help="height of an image (default: 84)") 39 | parser.add_argument('--width', type=int, default=84, 40 | help="width of an image (default: 84)") 41 | # Optimization options 42 | parser.add_argument('--train-batch', default=4, type=int, 43 | help="train batch size") 44 | parser.add_argument('--test-batch', default=1, type=int, 45 | help="test batch size") 46 | # Architecture 47 | parser.add_argument('--num_classes', type=int, default=64) 48 | parser.add_argument('--scale_cls', type=int, default=7) 49 | parser.add_argument('--save-dir', type=str, default='') 50 | parser.add_argument('--resume', type=str, default='', metavar='PATH') 51 | # FewShot setting 52 | parser.add_argument('--nKnovel', type=int, default=5, 53 | help='number of novel categories') 54 | parser.add_argument('--nExemplars', type=int, default=1, 55 | help='number of training examples per novel category.') 56 | parser.add_argument('--train_nTestNovel', type=int, default=6 * 5, 57 | help='number of test examples for all the novel categories when training') 58 | parser.add_argument('--train_epoch_size', type=int, default=1200, 59 | help='number of episodes per epoch when training') 60 | parser.add_argument('--nTestNovel', type=int, default=15 * 5, 61 | help='number of test examples for all the novel categories') 62 | parser.add_argument('--epoch_size', type=int, default=2000, 63 | help='number of episodes per epoch') 64 | # Miscs 65 | parser.add_argument('--phase', default='test', type=str) 66 | parser.add_argument('--seed', type=int, default=1) 67 | parser.add_argument('--gpu-devices', default='3', type=str) 68 | 69 | args = parser.parse_args() 70 | 71 | 72 | def main(): 73 | torch.manual_seed(args.seed) 74 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices 75 | use_gpu = torch.cuda.is_available() 76 | 77 | sys.stdout = Logger(osp.join(args.save_dir, 'log_test_transductive.txt')) 78 | print("==========\nArgs:{}\n==========".format(args)) 79 | 80 | if use_gpu: 81 |
print("Currently using GPU {}".format(args.gpu_devices)) 82 | cudnn.benchmark = True 83 | torch.cuda.manual_seed_all(args.seed) 84 | else: 85 | print("Currently using CPU (GPU is highly recommended)") 86 | 87 | print('Initializing image data manager') 88 | dm = DataManager(args, use_gpu) 89 | trainloader, testloader = dm.return_dataloaders() 90 | 91 | model = Model(scale_cls=args.scale_cls, num_classes=args.num_classes) 92 | # load the trained model weights 93 | checkpoint = torch.load(args.resume) 94 | model.load_state_dict(checkpoint['state_dict']) 95 | print("Loaded checkpoint from '{}'".format(args.resume)) 96 | 97 | if use_gpu: 98 | model = model.cuda() 99 | 100 | test(model, testloader, use_gpu) 101 | 102 | 103 | def test(model, testloader, use_gpu): 104 | accs = AverageMeter() 105 | test_accuracies = [] 106 | model.eval() 107 | 108 | with torch.no_grad(): 109 | for batch_idx, (images_train, labels_train, images_test, labels_test) in enumerate(testloader): 110 | if use_gpu: 111 | images_train = images_train.cuda() 112 | images_test = images_test.cuda() 113 | 114 | end = time.time() 115 | 116 | batch_size, num_train_examples, channels, height, width = images_train.size() 117 | num_test_examples = images_test.size(1) 118 | 119 | labels_train_1hot = one_hot(labels_train).cuda() 120 | labels_test_1hot = one_hot(labels_test).cuda() 121 | 122 | cls_scores = model.test_transductive(images_train, images_test, labels_train_1hot, labels_test_1hot) 123 | cls_scores = cls_scores.view(batch_size * num_test_examples, -1) 124 | labels_test = labels_test.view(batch_size * num_test_examples) 125 | 126 | _, preds = torch.max(cls_scores.detach().cpu(), 1) 127 | acc = (torch.sum(preds == labels_test.detach().cpu()).float()) / labels_test.size(0) 128 | accs.update(acc.item(), labels_test.size(0)) 129 | 130 | gt = (preds == labels_test.detach().cpu()).float() 131 | gt = gt.view(batch_size, num_test_examples).numpy() # [b, n] 132 | acc = np.sum(gt, 1) / num_test_examples 133 | acc = np.reshape(acc, (batch_size)) 134 | test_accuracies.append(acc) 135 | 136 | accuracy = accs.avg 137 | test_accuracies = np.array(test_accuracies) 138 | test_accuracies = np.reshape(test_accuracies, -1) 139 | stds = np.std(test_accuracies, 0) 140 | ci95 = 1.96 * stds / np.sqrt(args.epoch_size) 141 | print('Accuracy: {:.2%}, ci95: {:.2%}'.format(accuracy, ci95)) 142 | 143 | return accuracy 144 | 145 | 146 | if __name__ == '__main__': 147 | main() 148 | -------------------------------------------------------------------------------- /torchFewShot/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /torchFewShot/data_manager.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | from torch.utils.data import DataLoader 5 | 6 | import transforms as T 7 | import datasets 8 | import dataset_loader 9 | 10 | class DataManager(object): 11 | """ 12 | Few shot data manager 13 | """ 14 | 15 | def __init__(self, args, use_gpu): 16 | super(DataManager, self).__init__() 17 | self.args = args 18 | self.use_gpu = use_gpu 19 | 20 | print("Initializing dataset {}".format(args.dataset)) 21 | dataset = datasets.init_imgfewshot_dataset(name=args.dataset) 22 | 23 | if args.load: 24 | transform_train = T.Compose([ 25 | T.RandomCrop(84, padding=8), 26 | T.RandomHorizontalFlip(), 27 | T.ToTensor(), 28 |
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 29 | T.RandomErasing(0.5) 30 | ]) 31 | 32 | transform_test = T.Compose([ 33 | T.ToTensor(), 34 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 35 | ]) 36 | 37 | else: 38 | transform_train = T.Compose([ 39 | T.Resize((args.height, args.width), interpolation=3), 40 | T.RandomCrop(args.height, padding=8), 41 | T.RandomHorizontalFlip(), 42 | T.ToTensor(), 43 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 44 | T.RandomErasing(0.5) 45 | ]) 46 | 47 | transform_test = T.Compose([ 48 | T.Resize((args.height, args.width), interpolation=3), 49 | T.ToTensor(), 50 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 51 | ]) 52 | 53 | pin_memory = True if use_gpu else False 54 | 55 | self.trainloader = DataLoader( 56 | dataset_loader.init_loader(name='train_loader', 57 | dataset=dataset.train, 58 | labels2inds=dataset.train_labels2inds, 59 | labelIds=dataset.train_labelIds, 60 | nKnovel=args.nKnovel, 61 | nExemplars=args.nExemplars, 62 | nTestNovel=args.train_nTestNovel, 63 | epoch_size=args.train_epoch_size, 64 | transform=transform_train, 65 | load=args.load, 66 | ), 67 | batch_size=args.train_batch, shuffle=True, num_workers=args.workers, 68 | pin_memory=pin_memory, drop_last=True, 69 | ) 70 | 71 | self.valloader = DataLoader( 72 | dataset_loader.init_loader(name='test_loader', 73 | dataset=dataset.val, 74 | labels2inds=dataset.val_labels2inds, 75 | labelIds=dataset.val_labelIds, 76 | nKnovel=args.nKnovel, 77 | nExemplars=args.nExemplars, 78 | nTestNovel=args.nTestNovel, 79 | epoch_size=args.epoch_size, 80 | transform=transform_test, 81 | load=args.load, 82 | ), 83 | batch_size=args.test_batch, shuffle=False, num_workers=args.workers, 84 | pin_memory=pin_memory, drop_last=False, 85 | ) 86 | self.testloader = DataLoader( 87 | dataset_loader.init_loader(name='test_loader', 88 | dataset=dataset.test, 89 | labels2inds=dataset.test_labels2inds, 90 | labelIds=dataset.test_labelIds, 91 | nKnovel=args.nKnovel, 92 | nExemplars=args.nExemplars, 93 | nTestNovel=args.nTestNovel, 94 | epoch_size=args.epoch_size, 95 | transform=transform_test, 96 | load=args.load, 97 | ), 98 | batch_size=args.test_batch, shuffle=False, num_workers=args.workers, 99 | pin_memory=pin_memory, drop_last=False, 100 | ) 101 | 102 | def return_dataloaders(self): 103 | if self.args.phase == 'test': 104 | return self.trainloader, self.testloader 105 | elif self.args.phase == 'val': 106 | return self.trainloader, self.valloader 107 | -------------------------------------------------------------------------------- /torchFewShot/dataset_loader/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .train_loader import FewShotDataset_train 4 | from .test_loader import FewShotDataset_test 5 | 6 | 7 | __loader_factory = { 8 | 'train_loader': FewShotDataset_train, 9 | 'test_loader': FewShotDataset_test, 10 | } 11 | 12 | 13 | 14 | def get_names(): 15 | return list(__loader_factory.keys()) 16 | 17 | 18 | def init_loader(name, *args, **kwargs): 19 | if name not in list(__loader_factory.keys()): 20 | raise KeyError("Unknown model: {}".format(name)) 21 | return __loader_factory[name](*args, **kwargs) 22 | 23 | -------------------------------------------------------------------------------- /torchFewShot/dataset_loader/test_loader.py: -------------------------------------------------------------------------------- 1 | from 
__future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import os 6 | from PIL import Image 7 | import numpy as np 8 | import os.path as osp 9 | import lmdb 10 | import io 11 | import random 12 | 13 | import torch 14 | from torch.utils.data import Dataset 15 | 16 | 17 | def read_image(img_path): 18 | """Keep reading image until succeed. 19 | This can avoid IOError incurred by heavy IO process.""" 20 | got_img = False 21 | if not osp.exists(img_path): 22 | raise IOError("{} does not exist".format(img_path)) 23 | while not got_img: 24 | try: 25 | img = Image.open(img_path).convert('RGB') 26 | got_img = True 27 | except IOError: 28 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 29 | pass 30 | return img 31 | 32 | 33 | class FewShotDataset_test(Dataset): 34 | """Few shot epoish Dataset 35 | 36 | Returns a task (Xtrain, Ytrain, Xtest, Ytest) to classify' 37 | Xtrain: [nKnovel*nExpemplars, c, h, w]. 38 | Ytrain: [nKnovel*nExpemplars]. 39 | Xtest: [nTestNovel, c, h, w]. 40 | Ytest: [nTestNovel]. 41 | """ 42 | 43 | def __init__(self, 44 | dataset, # dataset of [(img_path, cats), ...]. 45 | labels2inds, # labels of index {(cats: index1, index2, ...)}. 46 | labelIds, # train labels [0, 1, 2, 3, ...,]. 47 | nKnovel=5, # number of novel categories. 48 | nExemplars=1, # number of training examples per novel category. 49 | nTestNovel=2*5, # number of test examples for all the novel categories. 50 | epoch_size=2000, # number of tasks per eooch. 51 | transform=None, 52 | load=True, 53 | **kwargs 54 | ): 55 | 56 | self.dataset = dataset 57 | self.labels2inds = labels2inds 58 | self.labelIds = labelIds 59 | self.nKnovel = nKnovel 60 | self.transform = transform 61 | 62 | self.nExemplars = nExemplars 63 | self.nTestNovel = nTestNovel 64 | self.epoch_size = epoch_size 65 | self.load = load 66 | 67 | seed = 112 68 | random.seed(seed) 69 | np.random.seed(seed) 70 | 71 | self.Epoch_Exemplar = [] 72 | self.Epoch_Tnovel = [] 73 | for i in range(epoch_size): 74 | Tnovel, Exemplar = self._sample_episode() 75 | self.Epoch_Exemplar.append(Exemplar) 76 | self.Epoch_Tnovel.append(Tnovel) 77 | 78 | def __len__(self): 79 | return self.epoch_size 80 | 81 | def _sample_episode(self): 82 | """sampels a training epoish indexs. 83 | Returns: 84 | Tnovel: a list of length 'nTestNovel' with 2-element tuples. (sample_index, label) 85 | Exemplars: a list of length 'nKnovel * nExemplars' with 2-element tuples. (sample_index, label) 86 | """ 87 | 88 | Knovel = random.sample(self.labelIds, self.nKnovel) 89 | nKnovel = len(Knovel) 90 | assert((self.nTestNovel % nKnovel) == 0) 91 | nEvalExamplesPerClass = int(self.nTestNovel / nKnovel) 92 | 93 | Tnovel = [] 94 | Exemplars = [] 95 | for Knovel_idx in range(len(Knovel)): 96 | ids = (nEvalExamplesPerClass + self.nExemplars) 97 | img_ids = random.sample(self.labels2inds[Knovel[Knovel_idx]], ids) 98 | 99 | imgs_tnovel = img_ids[:nEvalExamplesPerClass] 100 | imgs_emeplars = img_ids[nEvalExamplesPerClass:] 101 | 102 | Tnovel += [(img_id, Knovel_idx) for img_id in imgs_tnovel] 103 | Exemplars += [(img_id, Knovel_idx) for img_id in imgs_emeplars] 104 | assert(len(Tnovel) == self.nTestNovel) 105 | assert(len(Exemplars) == nKnovel * self.nExemplars) 106 | random.shuffle(Exemplars) 107 | random.shuffle(Tnovel) 108 | 109 | return Tnovel, Exemplars 110 | 111 | def _creatExamplesTensorData(self, examples): 112 | """ 113 | Creats the examples image label tensor data. 
114 | 115 | Args: 116 | examples: a list of 2-element tuples. (sample_index, label). 117 | 118 | Returns: 119 | images: a tensor [nExemplars, c, h, w] 120 | labels: a tensor [nExemplars] 121 | """ 122 | 123 | images = [] 124 | labels = [] 125 | for (img_idx, label) in examples: 126 | img = self.dataset[img_idx][0] 127 | if self.load: 128 | img = Image.fromarray(img) 129 | else: 130 | img = read_image(img) 131 | if self.transform is not None: 132 | img = self.transform(img) 133 | images.append(img) 134 | labels.append(label) 135 | images = torch.stack(images, dim=0) 136 | labels = torch.LongTensor(labels) 137 | return images, labels 138 | 139 | def __getitem__(self, index): 140 | Tnovel = self.Epoch_Tnovel[index] 141 | Exemplars = self.Epoch_Exemplar[index] 142 | Xt, Yt = self._creatExamplesTensorData(Exemplars) 143 | Xe, Ye = self._creatExamplesTensorData(Tnovel) 144 | return Xt, Yt, Xe, Ye 145 | 146 | -------------------------------------------------------------------------------- /torchFewShot/dataset_loader/train_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import os 6 | from PIL import Image 7 | import numpy as np 8 | import os.path as osp 9 | import lmdb 10 | import io 11 | import random 12 | 13 | import torch 14 | from torch.utils.data import Dataset 15 | 16 | 17 | def read_image(img_path): 18 | """Keep reading image until succeed. 19 | This can avoid IOError incurred by heavy IO process.""" 20 | got_img = False 21 | if not osp.exists(img_path): 22 | raise IOError("{} does not exist".format(img_path)) 23 | while not got_img: 24 | try: 25 | img = Image.open(img_path).convert('RGB') 26 | got_img = True 27 | except IOError: 28 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 29 | pass 30 | return img 31 | 32 | 33 | class FewShotDataset_train(Dataset): 34 | """Few shot epoish Dataset 35 | 36 | Returns a task (Xtrain, Ytrain, Xtest, Ytest, Ycls) to classify' 37 | Xtrain: [nKnovel*nExpemplars, c, h, w]. 38 | Ytrain: [nKnovel*nExpemplars]. 39 | Xtest: [nTestNovel, c, h, w]. 40 | Ytest: [nTestNovel]. 41 | Ycls: [nTestNovel]. 42 | """ 43 | 44 | def __init__(self, 45 | dataset, # dataset of [(img_path, cats), ...]. 46 | labels2inds, # labels of index {(cats: index1, index2, ...)}. 47 | labelIds, # train labels [0, 1, 2, 3, ...,]. 48 | nKnovel=5, # number of novel categories. 49 | nExemplars=1, # number of training examples per novel category. 50 | nTestNovel=6*5, # number of test examples for all the novel categories. 51 | epoch_size=2000, # number of tasks per eooch. 52 | transform=None, 53 | load=False, 54 | **kwargs 55 | ): 56 | 57 | self.dataset = dataset 58 | self.labels2inds = labels2inds 59 | self.labelIds = labelIds 60 | self.nKnovel = nKnovel 61 | self.transform = transform 62 | 63 | self.nExemplars = nExemplars 64 | self.nTestNovel = nTestNovel 65 | self.epoch_size = epoch_size 66 | self.load = load 67 | 68 | def __len__(self): 69 | return self.epoch_size 70 | 71 | def _sample_episode(self): 72 | """sampels a training epoish indexs. 73 | Returns: 74 | Tnovel: a list of length 'nTestNovel' with 2-element tuples. (sample_index, label) 75 | Exemplars: a list of length 'nKnovel * nExemplars' with 2-element tuples. 
(sample_index, label) 76 | """ 77 | 78 | Knovel = random.sample(self.labelIds, self.nKnovel) 79 | nKnovel = len(Knovel) 80 | assert((self.nTestNovel % nKnovel) == 0) 81 | nEvalExamplesPerClass = int(self.nTestNovel / nKnovel) 82 | 83 | Tnovel = [] 84 | Exemplars = [] 85 | for Knovel_idx in range(len(Knovel)): 86 | ids = (nEvalExamplesPerClass + self.nExemplars) 87 | img_ids = random.sample(self.labels2inds[Knovel[Knovel_idx]], ids) 88 | 89 | imgs_tnovel = img_ids[:nEvalExamplesPerClass] 90 | imgs_emeplars = img_ids[nEvalExamplesPerClass:] 91 | 92 | Tnovel += [(img_id, Knovel_idx) for img_id in imgs_tnovel] 93 | Exemplars += [(img_id, Knovel_idx) for img_id in imgs_emeplars] 94 | assert(len(Tnovel) == self.nTestNovel) 95 | assert(len(Exemplars) == nKnovel * self.nExemplars) 96 | random.shuffle(Exemplars) 97 | random.shuffle(Tnovel) 98 | 99 | return Tnovel, Exemplars 100 | 101 | def _creatExamplesTensorData(self, examples): 102 | """ 103 | Creats the examples image label tensor data. 104 | 105 | Args: 106 | examples: a list of 2-element tuples. (sample_index, label). 107 | 108 | Returns: 109 | images: a tensor [nExemplars, c, h, w] 110 | labels: a tensor [nExemplars] 111 | cls: a tensor [nExemplars] 112 | """ 113 | 114 | images = [] 115 | labels = [] 116 | cls = [] 117 | for (img_idx, label) in examples: 118 | img, ids = self.dataset[img_idx] 119 | if self.load: 120 | img = Image.fromarray(img) 121 | else: 122 | img = read_image(img) 123 | if self.transform is not None: 124 | img = self.transform(img) 125 | images.append(img) 126 | labels.append(label) 127 | cls.append(ids) 128 | images = torch.stack(images, dim=0) 129 | labels = torch.LongTensor(labels) 130 | cls = torch.LongTensor(cls) 131 | return images, labels, cls 132 | 133 | 134 | def __getitem__(self, index): 135 | Tnovel, Exemplars = self._sample_episode() 136 | Xt, Yt, Ytc = self._creatExamplesTensorData(Exemplars) 137 | Xe, Ye, Yec = self._creatExamplesTensorData(Tnovel) 138 | return Xt, Yt, Xe, Ye, Yec 139 | -------------------------------------------------------------------------------- /torchFewShot/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .miniImageNet import miniImageNet 6 | from .tieredImageNet import tieredImageNet 7 | from .miniImageNet_load import miniImageNet_load 8 | 9 | 10 | __imgfewshot_factory = { 11 | 'miniImageNet': miniImageNet, 12 | 'tieredImageNet': tieredImageNet, 13 | 'miniImageNet_load': miniImageNet_load, 14 | } 15 | 16 | 17 | def get_names(): 18 | return list(__imgfewshot_factory.keys()) 19 | 20 | 21 | def init_imgfewshot_dataset(name, **kwargs): 22 | if name not in list(__imgfewshot_factory.keys()): 23 | raise KeyError("Invalid dataset, got '{}', but expected to be one of {}".format(name, list(__imgfewshot_factory.keys()))) 24 | return __imgfewshot_factory[name](**kwargs) 25 | 26 | -------------------------------------------------------------------------------- /torchFewShot/datasets/miniImageNet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import torch 7 | 8 | class miniImageNet(object): 9 | """ 10 | Dataset statistics: 11 | # 64 * 600 (train) + 16 * 600 (val) + 20 * 600 (test) 12 | """ 13 | dataset_dir = 
'/home/houruibing/data/few-shot/mini-imagenet/' 14 | 15 | def __init__(self): 16 | super(miniImageNet, self).__init__() 17 | self.train_dir = os.path.join(self.dataset_dir, 'train') 18 | self.val_dir = os.path.join(self.dataset_dir, 'val') 19 | self.test_dir = os.path.join(self.dataset_dir, 'test') 20 | 21 | self.train, self.train_labels2inds, self.train_labelIds = self._process_dir(self.train_dir) 22 | self.val, self.val_labels2inds, self.val_labelIds = self._process_dir(self.val_dir) 23 | self.test, self.test_labels2inds, self.test_labelIds = self._process_dir(self.test_dir) 24 | 25 | self.num_train_cats = len(self.train_labelIds) 26 | num_total_cats = len(self.train_labelIds) + len(self.val_labelIds) + len(self.test_labelIds) 27 | num_total_imgs = len(self.train + self.val + self.test) 28 | 29 | print("=> MiniImageNet loaded") 30 | print("Dataset statistics:") 31 | print(" ------------------------------") 32 | print(" subset | # cats | # images") 33 | print(" ------------------------------") 34 | print(" train | {:5d} | {:8d}".format(len(self.train_labelIds), len(self.train))) 35 | print(" val | {:5d} | {:8d}".format(len(self.val_labelIds), len(self.val))) 36 | print(" test | {:5d} | {:8d}".format(len(self.test_labelIds), len(self.test))) 37 | print(" ------------------------------") 38 | print(" total | {:5d} | {:8d}".format(num_total_cats, num_total_imgs)) 39 | print(" ------------------------------") 40 | 41 | def _check_before_run(self): 42 | """Check if all files are available before going deeper""" 43 | if not osp.exists(self.dataset_dir): 44 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 45 | if not osp.exists(self.train_dir): 46 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 47 | if not osp.exists(self.val_dir): 48 | raise RuntimeError("'{}' is not available".format(self.val_dir)) 49 | if not osp.exists(self.test_dir): 50 | raise RuntimeError("'{}' is not available".format(self.test_dir)) 51 | 52 | def _process_dir(self, dir_path): 53 | cat_container = sorted(os.listdir(dir_path)) 54 | cats2label = {cat:label for label, cat in enumerate(cat_container)} 55 | 56 | dataset = [] 57 | labels = [] 58 | for cat in cat_container: 59 | for img_path in sorted(os.listdir(os.path.join(dir_path, cat))): 60 | if '.jpg' not in img_path: 61 | continue 62 | label = cats2label[cat] 63 | dataset.append((os.path.join(dir_path, cat, img_path), label)) 64 | labels.append(label) 65 | 66 | labels2inds = {} 67 | for idx, label in enumerate(labels): 68 | if label not in labels2inds: 69 | labels2inds[label] = [] 70 | labels2inds[label].append(idx) 71 | 72 | labelIds = sorted(labels2inds.keys()) 73 | return dataset, labels2inds, labelIds 74 | -------------------------------------------------------------------------------- /torchFewShot/datasets/miniImageNet_load.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import torch 7 | import pickle 8 | 9 | 10 | def load_data(file): 11 | with open(file, 'rb') as fo: 12 | data = pickle.load(fo) 13 | return data 14 | 15 | 16 | def buildLabelIndex(labels): 17 | label2inds = {} 18 | for idx, label in enumerate(labels): 19 | if label not in label2inds: 20 | label2inds[label] = [] 21 | label2inds[label].append(idx) 22 | 23 | return label2inds 24 | 25 | 26 | class miniImageNet_load(object): 27 | """ 28 | Dataset statistics: 29 | # 64 * 600 (train) + 
16 * 600 (val) + 20 * 600 (test) 30 | """ 31 | dataset_dir = '/home/houruibing/code/few-shot/MiniImagenet/' 32 | 33 | def __init__(self, **kwargs): 34 | super(miniImageNet_load, self).__init__() 35 | self.train_dir = os.path.join(self.dataset_dir, 'miniImageNet_category_split_train_phase_train.pickle') 36 | self.val_dir = os.path.join(self.dataset_dir, 'miniImageNet_category_split_val.pickle') 37 | self.test_dir = os.path.join(self.dataset_dir, 'miniImageNet_category_split_test.pickle') 38 | 39 | self.train, self.train_labels2inds, self.train_labelIds = self._process_dir(self.train_dir) 40 | self.val, self.val_labels2inds, self.val_labelIds = self._process_dir(self.val_dir) 41 | self.test, self.test_labels2inds, self.test_labelIds = self._process_dir(self.test_dir) 42 | 43 | self.num_train_cats = len(self.train_labelIds) 44 | num_total_cats = len(self.train_labelIds) + len(self.val_labelIds) + len(self.test_labelIds) 45 | num_total_imgs = len(self.train + self.val + self.test) 46 | 47 | print("=> MiniImageNet loaded") 48 | print("Dataset statistics:") 49 | print(" ------------------------------") 50 | print(" subset | # cats | # images") 51 | print(" ------------------------------") 52 | print(" train | {:5d} | {:8d}".format(len(self.train_labelIds), len(self.train))) 53 | print(" val | {:5d} | {:8d}".format(len(self.val_labelIds), len(self.val))) 54 | print(" test | {:5d} | {:8d}".format(len(self.test_labelIds), len(self.test))) 55 | print(" ------------------------------") 56 | print(" total | {:5d} | {:8d}".format(num_total_cats, num_total_imgs)) 57 | print(" ------------------------------") 58 | 59 | def _check_before_run(self): 60 | """Check if all files are available before going deeper""" 61 | if not osp.exists(self.dataset_dir): 62 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 63 | if not osp.exists(self.train_dir): 64 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 65 | if not osp.exists(self.val_dir): 66 | raise RuntimeError("'{}' is not available".format(self.val_dir)) 67 | if not osp.exists(self.test_dir): 68 | raise RuntimeError("'{}' is not available".format(self.test_dir)) 69 | 70 | def _get_pair(self, data, labels): 71 | assert (data.shape[0] == len(labels)) 72 | data_pair = [] 73 | for i in range(data.shape[0]): 74 | data_pair.append((data[i], labels[i])) 75 | return data_pair 76 | 77 | def _process_dir(self, file_path): 78 | dataset = load_data(file_path) 79 | data = dataset['data'] 80 | print(data.shape) 81 | labels = dataset['labels'] 82 | data_pair = self._get_pair(data, labels) 83 | labels2inds = buildLabelIndex(labels) 84 | labelIds = sorted(labels2inds.keys()) 85 | return data_pair, labels2inds, labelIds 86 | 87 | if __name__ == '__main__': 88 | miniImageNet_load() 89 | -------------------------------------------------------------------------------- /torchFewShot/datasets/tieredImageNet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import torch 7 | 8 | class tieredImageNet(object): 9 | """ 10 | Dataset statistics: 11 | # 64 * 600 (train) + 16 * 600 (val) + 20 * 600 (test) 12 | """ 13 | dataset_dir = '/home/houruibing/data/few-shot/tieredImagenet/images/images/' 14 | 15 | def __init__(self): 16 | super(tieredImageNet, self).__init__() 17 | self.train_dir = os.path.join(self.dataset_dir, 'train') 18 | self.val_dir = 
os.path.join(self.dataset_dir, 'val') 19 | self.test_dir = os.path.join(self.dataset_dir, 'test') 20 | 21 | self.train, self.train_labels2inds, self.train_labelIds = self._process_dir(self.train_dir) 22 | self.val, self.val_labels2inds, self.val_labelIds = self._process_dir(self.val_dir) 23 | self.test, self.test_labels2inds, self.test_labelIds = self._process_dir(self.test_dir) 24 | 25 | self.num_train_cats = len(self.train_labelIds) 26 | num_total_cats = len(self.train_labelIds) + len(self.val_labelIds) + len(self.test_labelIds) 27 | num_total_imgs = len(self.train + self.val + self.test) 28 | 29 | print("=> tieredImageNet loaded") 30 | print("Dataset statistics:") 31 | print(" ------------------------------") 32 | print(" subset | # cats | # images") 33 | print(" ------------------------------") 34 | print(" train | {:5d} | {:8d}".format(len(self.train_labelIds), len(self.train))) 35 | print(" val | {:5d} | {:8d}".format(len(self.val_labelIds), len(self.val))) 36 | print(" test | {:5d} | {:8d}".format(len(self.test_labelIds), len(self.test))) 37 | print(" ------------------------------") 38 | print(" total | {:5d} | {:8d}".format(num_total_cats, num_total_imgs)) 39 | print(" ------------------------------") 40 | 41 | def _check_before_run(self): 42 | """Check if all files are available before going deeper""" 43 | if not osp.exists(self.dataset_dir): 44 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 45 | if not osp.exists(self.train_dir): 46 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 47 | if not osp.exists(self.val_dir): 48 | raise RuntimeError("'{}' is not available".format(self.val_dir)) 49 | if not osp.exists(self.test_dir): 50 | raise RuntimeError("'{}' is not available".format(self.test_dir)) 51 | 52 | def _process_dir(self, dir_path): 53 | cat_container = sorted(os.listdir(dir_path)) 54 | cats2label = {cat:label for label, cat in enumerate(cat_container)} 55 | 56 | dataset = [] 57 | labels = [] 58 | for cat in cat_container: 59 | for img_path in sorted(os.listdir(os.path.join(dir_path, cat))): 60 | if '.JPEG' not in img_path: 61 | continue 62 | label = cats2label[cat] 63 | dataset.append((os.path.join(dir_path, cat, img_path), label)) 64 | labels.append(label) 65 | 66 | labels2inds = {} 67 | for idx, label in enumerate(labels): 68 | if label not in labels2inds: 69 | labels2inds[label] = [] 70 | labels2inds[label].append(idx) 71 | 72 | labelIds = sorted(labels2inds.keys()) 73 | return dataset, labels2inds, labelIds 74 | 75 | 76 | if __name__ == '__main__': 77 | tieredImageNet() 78 | -------------------------------------------------------------------------------- /torchFewShot/losses.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | class CrossEntropyLoss(nn.Module): 8 | def __init__(self): 9 | super(CrossEntropyLoss, self).__init__() 10 | self.logsoftmax = nn.LogSoftmax(dim=1) 11 | 12 | def forward(self, inputs, targets): 13 | inputs = inputs.view(inputs.size(0), inputs.size(1), -1) 14 | 15 | log_probs = self.logsoftmax(inputs) 16 | targets = torch.zeros(inputs.size(0), inputs.size(1)).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 17 | targets = targets.unsqueeze(-1) 18 | targets = targets.cuda() 19 | loss = (- targets * log_probs).mean(0).sum() 20 | return loss / inputs.size(2) 21 | -------------------------------------------------------------------------------- 
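The `CrossEntropyLoss` defined in losses.py above is a spatial cross-entropy: it takes per-position class scores of shape `[batch, num_classes, h, w]`, flattens the spatial dimensions, and averages the negative log-likelihood of the target class over the batch and all positions. A small sketch of this equivalence, assuming a CUDA device (the loss moves its one-hot targets to the GPU internally) and execution from the repository root:

```python
# Sketch: the loss equals standard cross-entropy averaged over every position.
import torch
import torch.nn.functional as F
from torchFewShot.losses import CrossEntropyLoss

b, num_classes, h, w = 4, 64, 6, 6
scores = torch.randn(b, num_classes, h, w).cuda()    # per-position class scores
labels = torch.randint(0, num_classes, (b,)).cuda()  # one label per sample

loss = CrossEntropyLoss()(scores, labels)
# Reference: cross-entropy at each of the h*w positions, mean-reduced.
reference = F.cross_entropy(scores.view(b, num_classes, -1),
                            labels.view(b, 1).repeat(1, h * w))
print(loss.item(), reference.item())  # the two values should match
```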
/torchFewShot/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /torchFewShot/models/cam.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | import math 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | 10 | class ConvBlock(nn.Module): 11 | """Basic convolutional block: 12 | convolution + batch normalization. 13 | 14 | Args (following http://pytorch.org/docs/master/nn.html#torch.nn.Conv2d): 15 | - in_c (int): number of input channels. 16 | - out_c (int): number of output channels. 17 | - k (int or tuple): kernel size. 18 | - s (int or tuple): stride. 19 | - p (int or tuple): padding. 20 | """ 21 | def __init__(self, in_c, out_c, k, s=1, p=0): 22 | super(ConvBlock, self).__init__() 23 | self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p) 24 | self.bn = nn.BatchNorm2d(out_c) 25 | 26 | def forward(self, x): 27 | return self.bn(self.conv(x)) 28 | 29 | 30 | class CAM(nn.Module): 31 | def __init__(self): 32 | super(CAM, self).__init__() 33 | self.conv1 = ConvBlock(36, 6, 1) 34 | self.conv2 = nn.Conv2d(6, 36, 1, stride=1, padding=0) 35 | 36 | for m in self.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 39 | m.weight.data.normal_(0, math.sqrt(2. / n)) 40 | 41 | def get_attention(self, a): 42 | input_a = a 43 | 44 | a = a.mean(3) 45 | a = a.transpose(1, 3) 46 | a = F.relu(self.conv1(a)) 47 | a = self.conv2(a) 48 | a = a.transpose(1, 3) 49 | a = a.unsqueeze(3) 50 | 51 | a = torch.mean(input_a * a, -1) 52 | a = F.softmax(a / 0.025, dim=-1) + 1 53 | return a 54 | 55 | def forward(self, f1, f2): 56 | b, n1, c, h, w = f1.size() 57 | n2 = f2.size(1) 58 | 59 | f1 = f1.view(b, n1, c, -1) 60 | f2 = f2.view(b, n2, c, -1) 61 | 62 | f1_norm = F.normalize(f1, p=2, dim=2, eps=1e-12) 63 | f2_norm = F.normalize(f2, p=2, dim=2, eps=1e-12) 64 | 65 | f1_norm = f1_norm.transpose(2, 3).unsqueeze(2) 66 | f2_norm = f2_norm.unsqueeze(1) 67 | 68 | a1 = torch.matmul(f1_norm, f2_norm) 69 | a2 = a1.transpose(3, 4) 70 | 71 | a1 = self.get_attention(a1) 72 | a2 = self.get_attention(a2) 73 | 74 | f1 = f1.unsqueeze(2) * a1.unsqueeze(3) 75 | f1 = f1.view(b, n1, n2, c, h, w) 76 | f2 = f2.unsqueeze(1) * a2.unsqueeze(3) 77 | f2 = f2.view(b, n1, n2, c, h, w) 78 | 79 | return f1.transpose(1, 2), f2.transpose(1, 2) 80 | -------------------------------------------------------------------------------- /torchFewShot/models/net.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from resnet12 import resnet12 7 | from cam import CAM 8 | 9 | 10 | def one_hot(labels_train): 11 | labels_train = labels_train.cpu() 12 | nKnovel = 5 13 | labels_train_1hot_size = list(labels_train.size()) + [nKnovel,] 14 | labels_train_unsqueeze = labels_train.unsqueeze(dim=labels_train.dim()) 15 | labels_train_1hot = torch.zeros(labels_train_1hot_size).scatter_(len(labels_train_1hot_size) - 1, labels_train_unsqueeze, 1) 16 | return labels_train_1hot 17 | 18 | 19 | class Model(nn.Module): 20 | def __init__(self, scale_cls, iter_num_prob=35.0/75, num_classes=64): 21 | super(Model, self).__init__() 22 | self.scale_cls = scale_cls 23 | self.iter_num_prob = iter_num_prob 24 
| 25 | self.base = resnet12() 26 | self.cam = CAM() 27 | 28 | self.nFeat = self.base.nFeat 29 | self.clasifier = nn.Conv2d(self.nFeat, num_classes, kernel_size=1) 30 | 31 | def test(self, ftrain, ftest): 32 | ftest = ftest.mean(4) 33 | ftest = ftest.mean(4) 34 | ftest = F.normalize(ftest, p=2, dim=ftest.dim()-1, eps=1e-12) 35 | ftrain = F.normalize(ftrain, p=2, dim=ftrain.dim()-1, eps=1e-12) 36 | scores = self.scale_cls * torch.sum(ftest * ftrain, dim=-1) 37 | return scores 38 | 39 | def forward(self, xtrain, xtest, ytrain, ytest): 40 | batch_size, num_train = xtrain.size(0), xtrain.size(1) 41 | num_test = xtest.size(1) 42 | K = ytrain.size(2) 43 | ytrain = ytrain.transpose(1, 2) 44 | 45 | xtrain = xtrain.view(-1, xtrain.size(2), xtrain.size(3), xtrain.size(4)) 46 | xtest = xtest.view(-1, xtest.size(2), xtest.size(3), xtest.size(4)) 47 | x = torch.cat((xtrain, xtest), 0) 48 | f = self.base(x) 49 | 50 | ftrain = f[:batch_size * num_train] 51 | ftrain = ftrain.view(batch_size, num_train, -1) 52 | ftrain = torch.bmm(ytrain, ftrain) 53 | ftrain = ftrain.div(ytrain.sum(dim=2, keepdim=True).expand_as(ftrain)) 54 | ftrain = ftrain.view(batch_size, -1, *f.size()[1:]) 55 | ftest = f[batch_size * num_train:] 56 | ftest = ftest.view(batch_size, num_test, *f.size()[1:]) 57 | ftrain, ftest = self.cam(ftrain, ftest) 58 | ftrain = ftrain.mean(4) 59 | ftrain = ftrain.mean(4) 60 | 61 | if not self.training: 62 | return self.test(ftrain, ftest) 63 | 64 | ftest_norm = F.normalize(ftest, p=2, dim=3, eps=1e-12) 65 | ftrain_norm = F.normalize(ftrain, p=2, dim=3, eps=1e-12) 66 | ftrain_norm = ftrain_norm.unsqueeze(4) 67 | ftrain_norm = ftrain_norm.unsqueeze(5) 68 | cls_scores = self.scale_cls * torch.sum(ftest_norm * ftrain_norm, dim=3) 69 | cls_scores = cls_scores.view(batch_size * num_test, *cls_scores.size()[2:]) 70 | 71 | ftest = ftest.view(batch_size, num_test, K, -1) 72 | ftest = ftest.transpose(2, 3) 73 | ytest = ytest.unsqueeze(3) 74 | ftest = torch.matmul(ftest, ytest) 75 | ftest = ftest.view(batch_size * num_test, -1, 6, 6) 76 | ytest = self.clasifier(ftest) 77 | return ytest, cls_scores 78 | 79 | def helper(self, ftrain, ftest, ytrain): 80 | b, n, c, h, w = ftrain.size() 81 | k = ytrain.size(2) 82 | 83 | ytrain_transposed = ytrain.transpose(1, 2) 84 | ftrain = torch.bmm(ytrain_transposed, ftrain.view(b, n, -1)) 85 | ftrain = ftrain.div(ytrain_transposed.sum(dim=2, keepdim=True).expand_as(ftrain)) 86 | ftrain = ftrain.view(b, -1, c, h, w) 87 | 88 | ftrain, ftest = self.cam(ftrain, ftest) 89 | ftrain = ftrain.mean(-1).mean(-1) 90 | ftest = ftest.mean(-1).mean(-1) 91 | 92 | ftest = F.normalize(ftest, p=2, dim=ftest.dim()-1, eps=1e-12) 93 | ftrain = F.normalize(ftrain, p=2, dim=ftrain.dim()-1, eps=1e-12) 94 | scores = self.scale_cls * torch.sum(ftest * ftrain, dim=-1) 95 | return scores 96 | 97 | def test_transductive(self, xtrain, xtest, ytrain, ytest): 98 | iter_num_prob = self.iter_num_prob 99 | batch_size, num_train = xtrain.size(0), xtrain.size(1) 100 | num_test = xtest.size(1) 101 | K = ytrain.size(2) 102 | 103 | xtrain = xtrain.view(-1, xtrain.size(2), xtrain.size(3), xtrain.size(4)) 104 | xtest = xtest.view(-1, xtest.size(2), xtest.size(3), xtest.size(4)) 105 | x = torch.cat((xtrain, xtest), 0) 106 | f = self.base(x) 107 | 108 | ftrain = f[: batch_size*num_train].view(batch_size, num_train, *f.size()[1:]) 109 | ftest = f[batch_size*num_train:].view(batch_size, num_test, *f.size()[1:]) 110 | cls_scores = self.helper(ftrain, ftest, ytrain) 111 | 112 | num_images_per_iter = int(num_test * 
iter_num_prob) 113 | num_iter = num_test // num_images_per_iter 114 | 115 | for i in range(num_iter): 116 | max_scores, preds = torch.max(cls_scores, 2) 117 | chose_index = torch.argsort(max_scores.view(-1), descending=True) 118 | chose_index = chose_index[: num_images_per_iter * (i + 1)] 119 | 120 | ftest_iter = ftest[0, chose_index].unsqueeze(0) 121 | preds_iter = preds[0, chose_index].unsqueeze(0) 122 | preds_iter = one_hot(preds_iter).cuda() 123 | 124 | ftrain_iter = torch.cat((ftrain, ftest_iter), 1) 125 | ytrain_iter = torch.cat((ytrain, preds_iter), 1) 126 | cls_scores = self.helper(ftrain_iter, ftest, ytrain_iter) 127 | 128 | return cls_scores 129 | 130 | 131 | 132 | 133 | if __name__ == '__main__': 134 | torch.manual_seed(0) 135 | 136 | net = Model(scale_cls=7) 137 | net.eval() 138 | 139 | x1 = torch.rand(1, 5, 3, 84, 84) 140 | x2 = torch.rand(1, 75, 3, 84, 84) 141 | y1 = torch.rand(1, 5, 5) 142 | y2 = torch.rand(1, 75, 5) 143 | 144 | y1 = net.test_transductive(x1, x2, y1, y2) 145 | print(y1.size()) 146 | -------------------------------------------------------------------------------- /torchFewShot/models/resnet12.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | def conv3x3(in_planes, out_planes, stride=1): 7 | """3x3 convolution with padding""" 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 9 | padding=1, bias=False) 10 | 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, inplanes, planes, kernel=3, stride=1, downsample=None): 16 | super(BasicBlock, self).__init__() 17 | if kernel == 1: 18 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 19 | elif kernel == 3: 20 | self.conv1 = conv3x3(inplanes, planes) 21 | self.bn1 = nn.BatchNorm2d(planes) 22 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 23 | padding=1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | if kernel == 1: 26 | self.conv3 = nn.Conv2d(planes, planes, kernel_size=1, bias=False) 27 | elif kernel == 3: 28 | self.conv3 = conv3x3(planes, planes) 29 | self.bn3 = nn.BatchNorm2d(planes) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.downsample = downsample 32 | self.stride = stride 33 | 34 | def forward(self, x): 35 | residual = x 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv2(out) 42 | out = self.bn2(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv3(out) 46 | out = self.bn3(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | 57 | class Bottleneck(nn.Module): 58 | expansion = 4 59 | 60 | def __init__(self, inplanes, planes, kernel=1, stride=1, downsample=None): 61 | super(Bottleneck, self).__init__() 62 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 63 | self.bn1 = nn.BatchNorm2d(planes) 64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 65 | padding=1, bias=False) 66 | self.bn2 = nn.BatchNorm2d(planes) 67 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 68 | self.bn3 = nn.BatchNorm2d(planes * 4) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.downsample = downsample 71 | self.stride = stride 72 | 73 | def forward(self, x): 74 | residual = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 
80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv3(out) 85 | out = self.bn3(out) 86 | 87 | if self.downsample is not None: 88 | residual = self.downsample(x) 89 | 90 | out += residual 91 | out = self.relu(out) 92 | 93 | return out 94 | 95 | 96 | class ResNet(nn.Module): 97 | 98 | def __init__(self, block, layers, kernel=3): 99 | self.inplanes = 64 100 | self.kernel = kernel 101 | super(ResNet, self).__init__() 102 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 103 | self.bn1 = nn.BatchNorm2d(64) 104 | self.relu = nn.ReLU(inplace=True) 105 | 106 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2) 107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 110 | 111 | self.nFeat = 512 * block.expansion 112 | 113 | for m in self.modules(): 114 | if isinstance(m, nn.Conv2d): 115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 116 | m.weight.data.normal_(0, math.sqrt(2. / n)) 117 | elif isinstance(m, nn.BatchNorm2d): 118 | m.weight.data.fill_(1) 119 | m.bias.data.zero_() 120 | 121 | def _make_layer(self, block, planes, blocks, stride=1): 122 | downsample = None 123 | if stride != 1 or self.inplanes != planes * block.expansion: 124 | downsample = nn.Sequential( 125 | nn.Conv2d(self.inplanes, planes * block.expansion, 126 | kernel_size=1, stride=stride, bias=False), 127 | nn.BatchNorm2d(planes * block.expansion), 128 | ) 129 | 130 | layers = [] 131 | layers.append(block(self.inplanes, planes, self.kernel, stride, downsample)) 132 | self.inplanes = planes * block.expansion 133 | for i in range(1, blocks): 134 | layers.append(block(self.inplanes, planes, self.kernel)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | x = self.conv1(x) 140 | x = self.bn1(x) 141 | x = self.relu(x) 142 | x = self.layer1(x) 143 | x = self.layer2(x) 144 | x = self.layer3(x) 145 | x = self.layer4(x) 146 | 147 | return x 148 | 149 | 150 | def resnet12(): 151 | model = ResNet(BasicBlock, [1,1,1,1], kernel=3) 152 | return model 153 | -------------------------------------------------------------------------------- /torchFewShot/optimizers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def init_optimizer(optim, params, lr, weight_decay): 7 | if optim == 'adam': 8 | return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay) 9 | elif optim == 'amsgrad': 10 | return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay, amsgrad=True) 11 | elif optim == 'sgd': 12 | return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=weight_decay, nesterov=True) 13 | elif optim == 'rmsprop': 14 | return torch.optim.RMSprop(params, lr=lr, momentum=0.9, weight_decay=weight_decay) 15 | else: 16 | raise KeyError("Unsupported optimizer: {}".format(optim)) 17 | -------------------------------------------------------------------------------- /torchFewShot/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | from torchvision.transforms import * 5 | 6 | from PIL import Image 7 | import random 8 | import numpy as np 9 | import math 10 | import torch 11 | 12 | 13 | class Random2DTranslation(object): 14 | """ 
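Randomly translate an image by resizing it and then taking a random crop.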
15 | With probability p, first enlarge the image to 1.125x (1 + 1/8) the target size and then take a random crop; otherwise, resize directly to (width, height). 16 | 17 | Args: 18 | - height (int): target height. 19 | - width (int): target width. 20 | - p (float): probability of performing this transformation. Default: 0.5. 21 | """ 22 | def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR): 23 | self.height = height 24 | self.width = width 25 | self.p = p 26 | self.interpolation = interpolation 27 | 28 | def __call__(self, img): 29 | """ 30 | Args: 31 | - img (PIL Image): Image to be cropped. 32 | """ 33 | if random.uniform(0, 1) > self.p: 34 | return img.resize((self.width, self.height), self.interpolation) 35 | 36 | new_width, new_height = int(round(self.width * 1.125)), int(round(self.height * 1.125)) 37 | resized_img = img.resize((new_width, new_height), self.interpolation) 38 | x_maxrange = new_width - self.width 39 | y_maxrange = new_height - self.height 40 | x1 = int(round(random.uniform(0, x_maxrange))) 41 | y1 = int(round(random.uniform(0, y_maxrange))) 42 | cropped_img = resized_img.crop((x1, y1, x1 + self.width, y1 + self.height)) 43 | return cropped_img 44 | 45 | 46 | class RandomErasing(object): 47 | """ Randomly selects a rectangle region in an image and erases its pixels. 48 | 'Random Erasing Data Augmentation' by Zhong et al. 49 | See https://arxiv.org/pdf/1708.04896.pdf 50 | Args: 51 | probability: The probability that the Random Erasing operation will be performed. 52 | sl: Minimum proportion of erased area against input image. 53 | sh: Maximum proportion of erased area against input image. 54 | r1: Minimum aspect ratio of erased area. 55 | mean: per-channel values used to fill the erased region. 56 | """ 57 | 58 | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]): 59 | self.probability = probability 60 | self.mean = mean 61 | self.sl = sl 62 | self.sh = sh 63 | self.r1 = r1 64 | 65 | def __call__(self, img): 66 | 67 | if random.uniform(0, 1) > self.probability: 68 | return img 69 | 70 | for attempt in range(100): 71 | area = img.size()[1] * img.size()[2] 72 | 73 | target_area = random.uniform(self.sl, self.sh) * area 74 | aspect_ratio = random.uniform(self.r1, 1/self.r1) 75 | 76 | h = int(round(math.sqrt(target_area * aspect_ratio))) 77 | w = int(round(math.sqrt(target_area / aspect_ratio))) 78 | 79 | if w < img.size()[2] and h < img.size()[1]: 80 | x1 = random.randint(0, img.size()[1] - h) 81 | y1 = random.randint(0, img.size()[2] - w) 82 | if img.size()[0] == 3: 83 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 84 | img[1, x1:x1+h, y1:y1+w] = self.mean[1] 85 | img[2, x1:x1+h, y1:y1+w] = self.mean[2] 86 | else: # single-channel image 87 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 88 | return img 89 | 90 | return img 91 | -------------------------------------------------------------------------------- /torchFewShot/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /torchFewShot/utils/avgmeter.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value.
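Example (illustrative values):
    meter = AverageMeter()
    meter.update(0.5, n=4)   # val=0.5, sum=2.0, count=4, avg=0.5
    meter.update(1.0, n=1)   # avg = 3.0 / 5 = 0.6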
7 | 8 | Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 9 | """ 10 | def __init__(self): 11 | self.reset() 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /torchFewShot/utils/iotools.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import os.path as osp 5 | import errno 6 | import json 7 | import shutil 8 | 9 | import torch 10 | 11 | 12 | def mkdir_if_missing(directory): 13 | if not osp.exists(directory): 14 | try: 15 | os.makedirs(directory) 16 | except OSError as e: 17 | if e.errno != errno.EEXIST: 18 | raise 19 | 20 | 21 | def check_isfile(path): 22 | isfile = osp.isfile(path) 23 | if not isfile: 24 | print("=> Warning: no file found at '{}' (ignored)".format(path)) 25 | return isfile 26 | 27 | 28 | def read_json(fpath): 29 | with open(fpath, 'r') as f: 30 | obj = json.load(f) 31 | return obj 32 | 33 | 34 | def write_json(obj, fpath): 35 | mkdir_if_missing(osp.dirname(fpath)) 36 | with open(fpath, 'w') as f: 37 | json.dump(obj, f, indent=4, separators=(',', ': ')) 38 | 39 | 40 | def save_checkpoint(state, is_best=False, fpath='checkpoint.pth.tar'): 41 | if len(osp.dirname(fpath)) != 0: 42 | mkdir_if_missing(osp.dirname(fpath)) 43 | torch.save(state, fpath) 44 | if is_best: 45 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar')) -------------------------------------------------------------------------------- /torchFewShot/utils/logger.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import sys 4 | import os 5 | import os.path as osp 6 | 7 | from .iotools import mkdir_if_missing 8 | 9 | 10 | class Logger(object): 11 | """ 12 | Write console output to external text file. 13 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py. 14 | """ 15 | def __init__(self, fpath=None, mode='a'): 16 | self.console = sys.stdout 17 | self.file = None 18 | if fpath is not None: 19 | mkdir_if_missing(osp.dirname(fpath)) 20 | self.file = open(fpath, mode) 21 | 22 | def __del__(self): 23 | self.close() 24 | 25 | def __enter__(self): 26 | pass 27 | 28 | def __exit__(self, *args): 29 | self.close() 30 | 31 | def write(self, msg): 32 | self.console.write(msg) 33 | if self.file is not None: 34 | self.file.write(msg) 35 | 36 | def flush(self): 37 | self.console.flush() 38 | if self.file is not None: 39 | self.file.flush() 40 | os.fsync(self.file.fileno()) 41 | 42 | def close(self): 43 | self.console.close() 44 | if self.file is not None: 45 | self.file.close() 46 | -------------------------------------------------------------------------------- /torchFewShot/utils/torchtools.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | def open_all_layers(model): 9 | """ 10 | Open all layers in model for training. 
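Sets the model to train mode and makes every parameter require gradients.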
11 | """ 12 | model.train() 13 | for p in model.parameters(): 14 | p.requires_grad = True 15 | 16 | 17 | def open_specified_layers(model, open_layers): 18 | """ 19 | Open specified layers in model for training while keeping 20 | other layers frozen. 21 | 22 | Args: 23 | - model (nn.Module): neural net model. 24 | - open_layers (list): list of layer names. 25 | """ 26 | if isinstance(model, nn.DataParallel): 27 | model = model.module 28 | 29 | for layer in open_layers: 30 | assert hasattr(model, layer), "'{}' is not an attribute of the model, please provide the correct name".format(layer) 31 | 32 | for name, module in model.named_children(): 33 | if name in open_layers: 34 | #print(module) 35 | module.train() 36 | for p in module.parameters(): 37 | p.requires_grad = True 38 | else: 39 | module.eval() 40 | for p in module.parameters(): 41 | p.requires_grad = False 42 | 43 | 44 | 45 | def adjust_learning_rate(optimizer, iters, LUT): 46 | # piecewise-constant schedule: use the base_lr of the first LUT entry whose stepvalue exceeds iters 47 | lr = LUT[-1][1] # fall back to the last entry if iters runs past the table 48 | for (stepvalue, base_lr) in LUT: 49 | if iters < stepvalue: 50 | lr = base_lr 51 | break 52 | for param_group in optimizer.param_groups: 53 | param_group['lr'] = lr 54 | return lr 55 | 56 | 57 | def adjust_lambda(iters, LUT): 58 | lambda_xent = LUT[-1][1] # same fallback as adjust_learning_rate 59 | for (stepvalue, base_lambda) in LUT: 60 | if iters < stepvalue: 61 | lambda_xent = base_lambda 62 | break 63 | return lambda_xent 64 | 65 | def set_bn_to_eval(m): 66 | # 1. no update for running mean and var 67 | # 2. scale and shift parameters are still trainable 68 | classname = m.__class__.__name__ 69 | if classname.find('BatchNorm') != -1: 70 | m.eval() 71 | 72 | 73 | def count_num_param(model): 74 | # returns the parameter count in millions 75 | num_param = sum(p.numel() for p in model.parameters()) / 1e+06 76 | if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Module): 77 | # we ignore the classifier because it is unused at test time 78 | num_param -= sum(p.numel() for p in model.classifier.parameters()) / 1e+06 79 | return num_param 80 | 81 | def one_hot(labels_train): 82 | """ 83 | Turn the labels_train to one-hot encoding.
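K is inferred as 1 + labels_train.max().
Example (illustrative values):
    one_hot(torch.tensor([[0, 2, 1]]))  # -> float tensor of shape [1, 3, 3]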
84 | Args: 85 | labels_train: [batch_size, num_train_examples] 86 | Return: 87 | labels_train_1hot: [batch_size, num_train_examples, K] 88 | """ 89 | labels_train = labels_train.cpu() 90 | nKnovel = 1 + labels_train.max() 91 | labels_train_1hot_size = list(labels_train.size()) + [nKnovel,] 92 | labels_train_unsqueeze = labels_train.unsqueeze(dim=labels_train.dim()) 93 | labels_train_1hot = torch.zeros(labels_train_1hot_size).scatter_(len(labels_train_1hot_size) - 1, labels_train_unsqueeze, 1) 94 | return labels_train_1hot 95 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | 4 | import os 5 | import sys 6 | import time 7 | import datetime 8 | import argparse 9 | import os.path as osp 10 | import numpy as np 11 | import random 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.backends.cudnn as cudnn 16 | from torch.utils.data import DataLoader 17 | from torch.optim import lr_scheduler 18 | import torch.nn.functional as F 19 | sys.path.append('./torchFewShot') 20 | 21 | from args_tiered import argument_parser 22 | 23 | from torchFewShot.models.net import Model 24 | from torchFewShot.data_manager import DataManager 25 | from torchFewShot.losses import CrossEntropyLoss 26 | from torchFewShot.optimizers import init_optimizer 27 | 28 | from torchFewShot.utils.iotools import save_checkpoint, check_isfile 29 | from torchFewShot.utils.avgmeter import AverageMeter 30 | from torchFewShot.utils.logger import Logger 31 | from torchFewShot.utils.torchtools import one_hot, adjust_learning_rate 32 | 33 | 34 | parser = argument_parser() 35 | args = parser.parse_args() 36 | 37 | def main(): 38 | torch.manual_seed(args.seed) 39 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices 40 | use_gpu = torch.cuda.is_available() 41 | 42 | sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt')) 43 | print("==========\nArgs:{}\n==========".format(args)) 44 | 45 | if use_gpu: 46 | print("Currently using GPU {}".format(args.gpu_devices)) 47 | cudnn.benchmark = True 48 | torch.cuda.manual_seed_all(args.seed) 49 | else: 50 | print("Currently using CPU (GPU is highly recommended)") 51 | 52 | print('Initializing image data manager') 53 | dm = DataManager(args, use_gpu) 54 | trainloader, testloader = dm.return_dataloaders() 55 | 56 | model = Model(scale_cls=args.scale_cls, num_classes=args.num_classes) 57 | criterion = CrossEntropyLoss() 58 | optimizer = init_optimizer(args.optim, model.parameters(), args.lr, args.weight_decay) 59 | 60 | if use_gpu: 61 | model = model.cuda() 62 | 63 | start_time = time.time() 64 | train_time = 0 65 | best_acc = -np.inf 66 | best_epoch = 0 67 | print("==> Start training") 68 | 69 | for epoch in range(args.max_epoch): 70 | learning_rate = adjust_learning_rate(optimizer, epoch, args.LUT_lr) 71 | 72 | start_train_time = time.time() 73 | train(epoch, model, criterion, optimizer, trainloader, learning_rate, use_gpu) 74 | train_time += round(time.time() - start_train_time) 75 | 76 | if epoch == 0 or epoch > (args.stepsize[0]-1) or (epoch + 1) % 10 == 0: 77 | acc = test(model, testloader, use_gpu) 78 | is_best = acc > best_acc 79 | 80 | if is_best: 81 | best_acc = acc 82 | best_epoch = epoch + 1 83 | 84 | save_checkpoint({ 85 | 'state_dict': model.state_dict(), 86 | 'acc': acc, 87 | 'epoch': epoch, 88 | }, is_best, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) 
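# e.g. 'checkpoint_ep80.pth.tar' inside args.save_dir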
+ '.pth.tar')) 89 | 90 | print("==> Test 5-way Best accuracy {:.2%}, achieved at epoch {}".format(best_acc, best_epoch)) 91 | 92 | elapsed = round(time.time() - start_time) 93 | elapsed = str(datetime.timedelta(seconds=elapsed)) 94 | train_time = str(datetime.timedelta(seconds=train_time)) 95 | print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time)) 96 | print("==========\nArgs:{}\n==========".format(args)) 97 | 98 | 99 | def train(epoch, model, criterion, optimizer, trainloader, learning_rate, use_gpu): 100 | losses = AverageMeter() 101 | batch_time = AverageMeter() 102 | data_time = AverageMeter() 103 | 104 | model.train() 105 | 106 | end = time.time() 107 | for batch_idx, (images_train, labels_train, images_test, labels_test, pids) in enumerate(trainloader): 108 | data_time.update(time.time() - end) 109 | 110 | if use_gpu: 111 | images_train, labels_train = images_train.cuda(), labels_train.cuda() 112 | images_test, labels_test = images_test.cuda(), labels_test.cuda() 113 | pids = pids.cuda() 114 | 115 | batch_size, num_train_examples, channels, height, width = images_train.size() 116 | num_test_examples = images_test.size(1) 117 | 118 | labels_train_1hot = one_hot(labels_train).cuda() if use_gpu else one_hot(labels_train) 119 | labels_test_1hot = one_hot(labels_test).cuda() if use_gpu else one_hot(labels_test) 120 | 121 | ytest, cls_scores = model(images_train, images_test, labels_train_1hot, labels_test_1hot) 122 | 123 | loss1 = criterion(ytest, pids.view(-1)) 124 | loss2 = criterion(cls_scores, labels_test.view(-1)) 125 | loss = loss1 + 0.5 * loss2 126 | 127 | optimizer.zero_grad() 128 | loss.backward() 129 | optimizer.step() 130 | 131 | losses.update(loss.item(), pids.size(0)) 132 | batch_time.update(time.time() - end) 133 | end = time.time() 134 | 135 | print('Epoch {0} ' 136 | 'lr: {1} ' 137 | 'Time:{batch_time.sum:.1f}s ' 138 | 'Data:{data_time.sum:.1f}s ' 139 | 'Loss:{loss.avg:.4f} '.format( 140 | epoch+1, learning_rate, batch_time=batch_time, 141 | data_time=data_time, loss=losses)) 142 | 143 | 144 | def test(model, testloader, use_gpu): 145 | accs = AverageMeter() 146 | test_accuracies = [] 147 | model.eval() 148 | 149 | with torch.no_grad(): 150 | for batch_idx, (images_train, labels_train, images_test, labels_test) in enumerate(testloader): 151 | if use_gpu: 152 | images_train = images_train.cuda() 153 | images_test = images_test.cuda() 154 | 155 | end = time.time() 156 | 157 | batch_size, num_train_examples, channels, height, width = images_train.size() 158 | num_test_examples = images_test.size(1) 159 | 160 | labels_train_1hot = one_hot(labels_train).cuda() if use_gpu else one_hot(labels_train) 161 | labels_test_1hot = one_hot(labels_test).cuda() if use_gpu else one_hot(labels_test) 162 | 163 | cls_scores = model(images_train, images_test, labels_train_1hot, labels_test_1hot) 164 | cls_scores = cls_scores.view(batch_size * num_test_examples, -1) 165 | labels_test = labels_test.view(batch_size * num_test_examples) 166 | 167 | _, preds = torch.max(cls_scores.detach().cpu(), 1) 168 | acc = (torch.sum(preds == labels_test.detach().cpu()).float()) / labels_test.size(0) 169 | accs.update(acc.item(), labels_test.size(0)) 170 | 171 | gt = (preds == labels_test.detach().cpu()).float() 172 | gt = gt.view(batch_size, num_test_examples).numpy() #[b, n] 173 | acc = np.sum(gt, 1) / num_test_examples 174 | acc = np.reshape(acc, (batch_size)) 175 | test_accuracies.append(acc) 176 | 177 | accuracy = accs.avg 178 | test_accuracies = np.array(test_accuracies) 179 | test_accuracies = np.reshape(test_accuracies, -1) 180 | stds = np.std(test_accuracies, 0) 181 | ci95 = 1.96 * stds / 
np.sqrt(args.epoch_size) 182 | print('Accuracy: {:.2%}, ci95: {:.2%}'.format(accuracy, ci95)) 183 | 184 | return accuracy 185 | 186 | 187 | if __name__ == '__main__': 188 | main() 189 | --------------------------------------------------------------------------------
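For reference, `test()` above reports a 95% confidence interval over per-episode accuracies using the normal approximation (z = 1.96). A minimal standalone sketch of the same computation, with the per-episode accuracies assumed purely for illustration:
```
import numpy as np

episode_accs = np.array([0.81, 0.78, 0.90, 0.84])  # assumed per-episode accuracies
mean_acc = episode_accs.mean()
# 1.96 is the two-sided 95% z-score under a normal approximation;
# np.std defaults to ddof=0, matching np.std(test_accuracies, 0) in test()
ci95 = 1.96 * episode_accs.std() / np.sqrt(len(episode_accs))
print('Accuracy: {:.2%}, ci95: {:.2%}'.format(mean_acc, ci95))
```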