├── README.md
├── algorithm.png
├── args_tiered.py
├── args_xent.py
├── test.py
├── test_transductive.py
├── torchFewShot
│   ├── __init__.py
│   ├── data_manager.py
│   ├── dataset_loader
│   │   ├── __init__.py
│   │   ├── test_loader.py
│   │   └── train_loader.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── miniImageNet.py
│   │   ├── miniImageNet_load.py
│   │   └── tieredImageNet.py
│   ├── losses.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── cam.py
│   │   ├── net.py
│   │   └── resnet12.py
│   ├── optimizers.py
│   ├── transforms.py
│   └── utils
│       ├── __init__.py
│       ├── avgmeter.py
│       ├── iotools.py
│       ├── logger.py
│       └── torchtools.py
└── train.py
/README.md:
--------------------------------------------------------------------------------
1 | # fewshot-CAN
2 | This repository contains the code for the paper:
3 |
4 | [**Cross Attention Network for Few-shot Classification**](https://arxiv.org/pdf/1910.07677.pdf)
5 |
6 | Ruibing Hou, Hong Chang, Bingpeng Ma, Shiguang Shan, Xilin Chen
7 |
8 | NeurIPS 2019
9 |
10 |
11 |
12 |
13 | ### Abstract
14 |
15 | Few-shot classification aims to recognize unlabeled samples from unseen classes given only a few labeled samples. The unseen classes and low-data problem make few-shot classification very challenging. Many existing approaches extract features from labeled and unlabeled samples independently; as a result, the features are not discriminative enough. In this work, we propose a novel Cross Attention
16 | Network to address the challenging problems in few-shot classification. First, a Cross Attention Module is introduced to deal with the problem of unseen classes. The module generates cross attention maps for each pair of class feature and query sample feature so as to highlight the target object regions, making the extracted features more discriminative. Second, a transductive inference algorithm is proposed to alleviate the low-data problem; it iteratively utilizes the unlabeled query set to augment the support set, thereby making the class features more representative. Extensive experiments on two benchmarks show that our method is a simple, effective, and computationally efficient framework that outperforms state-of-the-art methods.
17 |
18 | ### Citation
19 |
20 | If you use this code for your research, please cite our paper:
21 | ```
22 | @inproceedings{CAN,
23 | title={Cross Attention Network for Few-shot Classification},
24 | author={Ruibing Hou and Hong Chang and Bingpeng Ma and Shiguang Shan and Xilin Chen},
25 | booktitle={NeurIPS},
26 | year={2019}
27 | }
28 | ```
29 |
30 | ### Trained models
31 | We have released trained CAN models for the miniImageNet and tieredImageNet benchmarks on [Google Drive](https://drive.google.com/drive/folders/1Gi3GvQB3Ypwu-QW_CVmToLC64fMv8TqO?usp=sharing).
32 |
33 | ### Platform
34 | This code was developed and tested with PyTorch 1.0.1.
35 |
36 | ### Acknowledgments
37 |
38 | This code is based on the implementation of [**Dynamic Few-Shot Visual Learning without Forgetting**](https://github.com/gidariss/FewShotWithoutForgetting).
39 |
--------------------------------------------------------------------------------
/algorithm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blue-blue272/fewshot-CAN/6055985e82944c305f0fbf4d1b3e0eef1833d06c/algorithm.png
--------------------------------------------------------------------------------
/args_tiered.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torchFewShot
3 |
4 | def argument_parser():
5 |
6 | parser = argparse.ArgumentParser(description='Train image model with cross entropy loss')
7 | # ************************************************************
8 | # Datasets (general)
9 | # ************************************************************
10 | parser.add_argument('-d', '--dataset', type=str, default='tieredImageNet')
11 | parser.add_argument('--load', default=False)
12 |
13 | parser.add_argument('-j', '--workers', default=4, type=int,
14 | help="number of data loading workers (default: 4)")
15 | parser.add_argument('--height', type=int, default=84,
16 | help="height of an image (default: 84)")
17 | parser.add_argument('--width', type=int, default=84,
18 | help="width of an image (default: 84)")
19 |
20 | # ************************************************************
21 | # Optimization options
22 | # ************************************************************
23 | parser.add_argument('--optim', type=str, default='sgd',
24 | help="optimization algorithm (see optimizers.py)")
25 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
26 | help="initial learning rate")
27 | parser.add_argument('--weight-decay', default=5e-04, type=float,
28 | help="weight decay (default: 5e-04)")
29 |
30 | parser.add_argument('--max-epoch', default=80, type=int,
31 | help="maximum epochs to run")
32 | parser.add_argument('--start-epoch', default=0, type=int,
33 | help="manual epoch number (useful on restarts)")
34 | parser.add_argument('--stepsize', default=[60], nargs='+', type=int,
35 | help="stepsize to decay learning rate")
36 | parser.add_argument('--LUT_lr', default=[(20, 0.1), (40, 0.01), (60, 0.001), (80, 0.0001)],
37 | help="multistep to decay learning rate")
38 |
39 | parser.add_argument('--train-batch', default=4, type=int,
40 | help="train batch size")
41 | parser.add_argument('--test-batch', default=8, type=int,
42 | help="test batch size")
43 |
44 |
45 | # ************************************************************
46 | # Architecture settings
47 | # ************************************************************
48 | parser.add_argument('--num_classes', type=int, default=351)
49 | parser.add_argument('--scale_cls', type=int, default=7)
50 |
51 | # ************************************************************
52 | # Miscs
53 | # ************************************************************
54 | parser.add_argument('--save-dir', type=str, default='./result/tieredImageNet/CAM/5-shot-v2/')
55 | parser.add_argument('--resume', type=str, default='', metavar='PATH')
56 | parser.add_argument('--gpu-devices', default='0', type=str)
57 |
58 | # ************************************************************
59 | # FewShot setting
60 | # ************************************************************
61 | parser.add_argument('--nKnovel', type=int, default=5,
62 | help='number of novel categories')
63 | parser.add_argument('--nExemplars', type=int, default=5,
64 | help='number of training examples per novel category')
65 |
66 | parser.add_argument('--train_nTestNovel', type=int, default=6 * 5,
67 | help='number of test examples across all novel categories when training')
68 | parser.add_argument('--train_epoch_size', type=int, default=13980,
69 | help='number of episodes per epoch when training')
70 | parser.add_argument('--nTestNovel', type=int, default=15 * 5,
71 | help='number of test examples across all novel categories')
72 | parser.add_argument('--epoch_size', type=int, default=2000,
73 | help='number of episodes per epoch')
74 |
75 | parser.add_argument('--phase', default='test', type=str,
76 | help='use test or val dataset to early stop')
77 | parser.add_argument('--seed', type=int, default=1)
78 |
79 | return parser
80 |
81 |
--------------------------------------------------------------------------------
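The `--LUT_lr` default in `args_tiered.py` is a list of `(epoch, learning_rate)` pairs. The function that consumes it (`adjust_learning_rate` in `torchFewShot/utils/torchtools.py`) is not shown here, so the following is only a minimal sketch of one plausible reading: each pair gives the epoch at which a learning rate stops applying. The helper name `lr_from_lut` and the fallback to the last rate are assumptions made for illustration, not the repository's code.

```python
def lr_from_lut(epoch, lut):
    """Return the learning rate for `epoch`.

    `lut` is a list of (end_epoch, lr) pairs sorted by end_epoch; the first
    pair whose end_epoch is greater than `epoch` wins. This interpretation is
    an assumption, not the repository's confirmed behavior.
    """
    for end_epoch, lr in lut:
        if epoch < end_epoch:
            return lr
    # Past the last boundary: keep the final learning rate (assumed fallback).
    return lut[-1][1]


# Default table from args_tiered.py.
LUT = [(20, 0.1), (40, 0.01), (60, 0.001), (80, 0.0001)]

schedule = [lr_from_lut(e, LUT) for e in (0, 19, 20, 59, 79, 100)]
print(schedule)  # [0.1, 0.1, 0.01, 0.001, 0.0001, 0.0001]
```

Under this reading the tiered schedule divides the rate by 10 every 20 epochs, which matches the default `--max-epoch` of 80.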
/args_xent.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torchFewShot
3 |
4 | def argument_parser():
5 |
6 | parser = argparse.ArgumentParser(description='Train image model with cross entropy loss')
7 | # ************************************************************
8 | # Datasets (general)
9 | # ************************************************************
10 | parser.add_argument('-d', '--dataset', type=str, default='miniImageNet_load')
11 | parser.add_argument('--load', default=True)
12 |
13 | parser.add_argument('-j', '--workers', default=4, type=int,
14 | help="number of data loading workers (default: 4)")
15 | parser.add_argument('--height', type=int, default=84,
16 | help="height of an image (default: 84)")
17 | parser.add_argument('--width', type=int, default=84,
18 | help="width of an image (default: 84)")
19 |
20 | # ************************************************************
21 | # Optimization options
22 | # ************************************************************
23 | parser.add_argument('--optim', type=str, default='sgd',
24 | help="optimization algorithm (see optimizers.py)")
25 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
26 | help="initial learning rate")
27 | parser.add_argument('--weight-decay', default=5e-04, type=float,
28 | help="weight decay (default: 5e-04)")
29 |
30 | parser.add_argument('--max-epoch', default=90, type=int,
31 | help="maximum epochs to run")
32 | parser.add_argument('--start-epoch', default=0, type=int,
33 | help="manual epoch number (useful on restarts)")
34 | parser.add_argument('--stepsize', default=[60], nargs='+', type=int,
35 | help="stepsize to decay learning rate")
36 | parser.add_argument('--LUT_lr', default=[(60, 0.1), (70, 0.006), (80, 0.0012), (90, 0.00024)],
37 | help="multistep to decay learning rate")
38 |
39 | parser.add_argument('--train-batch', default=4, type=int,
40 | help="train batch size")
41 | parser.add_argument('--test-batch', default=4, type=int,
42 | help="test batch size")
43 |
44 | # ************************************************************
45 | # Architecture settings
46 | # ************************************************************
47 | parser.add_argument('--num_classes', type=int, default=64)
48 | parser.add_argument('--scale_cls', type=int, default=7)
49 |
50 | # ************************************************************
51 | # Miscs
52 | # ************************************************************
53 | parser.add_argument('--save-dir', type=str, default='./result/miniImageNet/CAM/5-shot-seed112/')
54 | parser.add_argument('--resume', type=str, default='', metavar='PATH')
55 | parser.add_argument('--gpu-devices', default='2', type=str)
56 |
57 | # ************************************************************
58 | # FewShot setting
59 | # ************************************************************
60 | parser.add_argument('--nKnovel', type=int, default=5,
61 | help='number of novel categories')
62 | parser.add_argument('--nExemplars', type=int, default=5,
63 | help='number of training examples per novel category')
64 |
65 | parser.add_argument('--train_nTestNovel', type=int, default=6 * 5,
66 | help='number of test examples across all novel categories when training')
67 | parser.add_argument('--train_epoch_size', type=int, default=1200,
68 | help='number of episodes per epoch when training')
69 | parser.add_argument('--nTestNovel', type=int, default=15 * 5,
70 | help='number of test examples across all novel categories')
71 | parser.add_argument('--epoch_size', type=int, default=2000,
72 | help='number of episodes per epoch')
73 |
74 | parser.add_argument('--phase', default='test', type=str,
75 | help='use test or val dataset to early stop')
76 | parser.add_argument('--seed', type=int, default=1)
77 |
78 | return parser
79 |
80 |
--------------------------------------------------------------------------------
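One detail of the parsers above is worth illustrating: `--load` is declared with a boolean default but no `type`, so any value passed on the command line arrives as a non-empty string, and `if args.load:` in `data_manager.py` treats it as true. The sketch below reproduces that behavior with plain `argparse` and contrasts it with a converter. The `str2bool` helper is hypothetical (it is not part of this repository or of `argparse`) and is shown only as one possible remedy.

```python
import argparse


def str2bool(value):
    """Hypothetical helper: convert common true/false spellings to a bool."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "1", "yes", "y"):
        return True
    if text in ("false", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got {!r}".format(value))


# Same declaration style as the repository's parsers: no type given.
naive = argparse.ArgumentParser()
naive.add_argument('--load', default=True)

# Variant using the hypothetical converter.
fixed = argparse.ArgumentParser()
fixed.add_argument('--load', type=str2bool, default=True)

naive_args = naive.parse_args(['--load', 'False'])
fixed_args = fixed.parse_args(['--load', 'False'])

# The naive parser keeps the string 'False', which is truthy.
print(bool(naive_args.load), fixed_args.load)  # True False
```

With the parsers as written, the safe way to switch datasets is therefore to rely on the per-file defaults (`args_tiered.py` versus `args_xent.py`) rather than passing `--load False`.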
/test.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import division
3 |
4 | import os
5 | import sys
6 | import time
7 | import datetime
8 | import argparse
9 | import os.path as osp
10 | import numpy as np
11 | import random
12 |
13 | import torch
14 | import torch.nn as nn
15 | import torch.backends.cudnn as cudnn
16 | from torch.utils.data import DataLoader
17 | from torch.optim import lr_scheduler
18 | import torch.nn.functional as F
19 | sys.path.append('./torchFewShot')
20 |
21 | from torchFewShot.models.net import Model
22 | from torchFewShot.data_manager import DataManager
23 | from torchFewShot.losses import CrossEntropyLoss
24 | from torchFewShot.optimizers import init_optimizer
25 |
26 | from torchFewShot.utils.iotools import save_checkpoint, check_isfile
27 | from torchFewShot.utils.avgmeter import AverageMeter
28 | from torchFewShot.utils.logger import Logger
29 | from torchFewShot.utils.torchtools import one_hot, adjust_learning_rate
30 |
31 | parser = argparse.ArgumentParser(description='Test image model with 5-way classification')
32 | # Datasets
33 | parser.add_argument('-d', '--dataset', type=str, default='miniImageNet_load')
34 | parser.add_argument('--load', default=True)
35 | parser.add_argument('-j', '--workers', default=4, type=int,
36 | help="number of data loading workers (default: 4)")
37 | parser.add_argument('--height', type=int, default=84,
38 | help="height of an image (default: 84)")
39 | parser.add_argument('--width', type=int, default=84,
40 | help="width of an image (default: 84)")
41 | # Optimization options
42 | parser.add_argument('--train-batch', default=4, type=int,
43 | help="train batch size")
44 | parser.add_argument('--test-batch', default=8, type=int,
45 | help="test batch size")
46 | # Architecture
47 | parser.add_argument('--num_classes', type=int, default=64)
48 | parser.add_argument('--scale_cls', type=int, default=7)
49 | parser.add_argument('--save-dir', type=str, default='')
50 | parser.add_argument('--resume', type=str, default='', metavar='PATH')
51 | # FewShot setting
52 | parser.add_argument('--nKnovel', type=int, default=5,
53 | help='number of novel categories')
54 | parser.add_argument('--nExemplars', type=int, default=1,
55 | help='number of training examples per novel category')
56 | parser.add_argument('--train_nTestNovel', type=int, default=6 * 5,
57 | help='number of test examples across all novel categories when training')
58 | parser.add_argument('--train_epoch_size', type=int, default=1200,
59 | help='number of episodes per epoch when training')
60 | parser.add_argument('--nTestNovel', type=int, default=15 * 5,
61 | help='number of test examples across all novel categories')
62 | parser.add_argument('--epoch_size', type=int, default=2000,
63 | help='number of episodes per epoch')
64 | # Miscs
65 | parser.add_argument('--phase', default='test', type=str)
66 | parser.add_argument('--seed', type=int, default=1)
67 | parser.add_argument('--gpu-devices', default='1', type=str)
68 |
69 | args = parser.parse_args()
70 |
71 |
72 | def main():
73 | torch.manual_seed(args.seed)
74 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
75 | use_gpu = torch.cuda.is_available()
76 |
77 | sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
78 | print("==========\nArgs:{}\n==========".format(args))
79 |
80 | if use_gpu:
81 | print("Currently using GPU {}".format(args.gpu_devices))
82 | cudnn.benchmark = True
83 | torch.cuda.manual_seed_all(args.seed)
84 | else:
85 | print("Currently using CPU (GPU is highly recommended)")
86 |
87 | print('Initializing image data manager')
88 | dm = DataManager(args, use_gpu)
89 | trainloader, testloader = dm.return_dataloaders()
90 |
91 | model = Model(scale_cls=args.scale_cls, num_classes=args.num_classes)
92 | # load the model
93 | checkpoint = torch.load(args.resume)
94 | model.load_state_dict(checkpoint['state_dict'])
95 | print("Loaded checkpoint from '{}'".format(args.resume))
96 |
97 | if use_gpu:
98 | model = model.cuda()
99 |
100 | test(model, testloader, use_gpu)
101 |
102 |
103 | def test(model, testloader, use_gpu):
104 | accs = AverageMeter()
105 | test_accuracies = []
106 | model.eval()
107 |
108 | with torch.no_grad():
109 | for batch_idx , (images_train, labels_train, images_test, labels_test) in enumerate(testloader):
110 | if use_gpu:
111 | images_train = images_train.cuda()
112 | images_test = images_test.cuda()
113 |
114 | end = time.time()
115 |
116 | batch_size, num_train_examples, channels, height, width = images_train.size()
117 | num_test_examples = images_test.size(1)
118 |
119 | labels_train_1hot = one_hot(labels_train).cuda() if use_gpu else one_hot(labels_train)
120 | labels_test_1hot = one_hot(labels_test).cuda() if use_gpu else one_hot(labels_test)
121 |
122 | cls_scores = model(images_train, images_test, labels_train_1hot, labels_test_1hot)
123 | cls_scores = cls_scores.view(batch_size * num_test_examples, -1)
124 | labels_test = labels_test.view(batch_size * num_test_examples)
125 |
126 | _, preds = torch.max(cls_scores.detach().cpu(), 1)
127 | acc = (torch.sum(preds == labels_test.detach().cpu()).float()) / labels_test.size(0)
128 | accs.update(acc.item(), labels_test.size(0))
129 |
130 | gt = (preds == labels_test.detach().cpu()).float()
131 | gt = gt.view(batch_size, num_test_examples).numpy() #[b, n]
132 | acc = np.sum(gt, 1) / num_test_examples
133 | acc = np.reshape(acc, (batch_size))
134 | test_accuracies.append(acc)
135 |
136 | accuracy = accs.avg
137 | test_accuracies = np.array(test_accuracies)
138 | test_accuracies = np.reshape(test_accuracies, -1)
139 | stds = np.std(test_accuracies, 0)
140 | ci95 = 1.96 * stds / np.sqrt(args.epoch_size)
141 | print('Accuracy: {:.2%}, 95% CI: +/-{:.2%}'.format(accuracy, ci95))
142 |
143 | return accuracy
144 |
145 |
146 | if __name__ == '__main__':
147 | main()
148 |
--------------------------------------------------------------------------------
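The end of `test()` in `test.py` reports the mean accuracy together with a 95% confidence half-width computed as `1.96 * std / sqrt(N)` over per-episode accuracies. The sketch below isolates that calculation with NumPy on made-up numbers. The function name `mean_and_ci95` and the toy accuracies are illustrative assumptions, and `N` is taken from the array size, whereas the script uses `args.epoch_size` (the two agree when every episode is evaluated).

```python
import numpy as np


def mean_and_ci95(episode_accs):
    """Mean accuracy and 95% confidence half-width over episodes.

    Uses the population standard deviation (np.std default, ddof=0),
    matching the call in test.py.
    """
    accs = np.asarray(episode_accs, dtype=np.float64).reshape(-1)
    mean = accs.mean()
    ci95 = 1.96 * accs.std() / np.sqrt(accs.size)
    return mean, ci95


# Toy per-episode accuracies (made up for illustration only).
mean, ci95 = mean_and_ci95([0.6, 0.8, 0.7, 0.9])
print("Accuracy: {:.2%} +/- {:.2%}".format(mean, ci95))
```

Because the half-width shrinks with the square root of the episode count, the default of 2000 test episodes gives a much tighter interval than a handful of episodes would.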
/test_transductive.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import division
3 |
4 | import os
5 | import sys
6 | import time
7 | import datetime
8 | import argparse
9 | import os.path as osp
10 | import numpy as np
11 | import random
12 |
13 | import torch
14 | import torch.nn as nn
15 | import torch.backends.cudnn as cudnn
16 | from torch.utils.data import DataLoader
17 | from torch.optim import lr_scheduler
18 | import torch.nn.functional as F
19 | sys.path.append('./torchFewShot')
20 |
21 | from torchFewShot.models.net import Model
22 | from torchFewShot.data_manager import DataManager
23 | from torchFewShot.losses import CrossEntropyLoss
24 | from torchFewShot.optimizers import init_optimizer
25 |
26 | from torchFewShot.utils.iotools import save_checkpoint, check_isfile
27 | from torchFewShot.utils.avgmeter import AverageMeter
28 | from torchFewShot.utils.logger import Logger
29 | from torchFewShot.utils.torchtools import one_hot, adjust_learning_rate
30 |
31 | parser = argparse.ArgumentParser(description='Test image model with 5-way classification')
32 | # Datasets
33 | parser.add_argument('-d', '--dataset', type=str, default='miniImageNet_load')
34 | parser.add_argument('--load', default=True)
35 | parser.add_argument('-j', '--workers', default=4, type=int,
36 | help="number of data loading workers (default: 4)")
37 | parser.add_argument('--height', type=int, default=84,
38 | help="height of an image (default: 84)")
39 | parser.add_argument('--width', type=int, default=84,
40 | help="width of an image (default: 84)")
41 | # Optimization options
42 | parser.add_argument('--train-batch', default=4, type=int,
43 | help="train batch size")
44 | parser.add_argument('--test-batch', default=1, type=int,
45 | help="test batch size")
46 | # Architecture
47 | parser.add_argument('--num_classes', type=int, default=64)
48 | parser.add_argument('--scale_cls', type=int, default=7)
49 | parser.add_argument('--save-dir', type=str, default='')
50 | parser.add_argument('--resume', type=str, default='', metavar='PATH')
51 | # FewShot setting
52 | parser.add_argument('--nKnovel', type=int, default=5,
53 | help='number of novel categories')
54 | parser.add_argument('--nExemplars', type=int, default=1,
55 | help='number of training examples per novel category')
56 | parser.add_argument('--train_nTestNovel', type=int, default=6 * 5,
57 | help='number of test examples across all novel categories when training')
58 | parser.add_argument('--train_epoch_size', type=int, default=1200,
59 | help='number of episodes per epoch when training')
60 | parser.add_argument('--nTestNovel', type=int, default=15 * 5,
61 | help='number of test examples across all novel categories')
62 | parser.add_argument('--epoch_size', type=int, default=2000,
63 | help='number of episodes per epoch')
64 | # Miscs
65 | parser.add_argument('--phase', default='test', type=str)
66 | parser.add_argument('--seed', type=int, default=1)
67 | parser.add_argument('--gpu-devices', default='3', type=str)
68 |
69 | args = parser.parse_args()
70 |
71 |
72 | def main():
73 | torch.manual_seed(args.seed)
74 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
75 | use_gpu = torch.cuda.is_available()
76 |
77 | sys.stdout = Logger(osp.join(args.save_dir, 'log_test_transductive.txt'))
78 | print("==========\nArgs:{}\n==========".format(args))
79 |
80 | if use_gpu:
81 | print("Currently using GPU {}".format(args.gpu_devices))
82 | cudnn.benchmark = True
83 | torch.cuda.manual_seed_all(args.seed)
84 | else:
85 | print("Currently using CPU (GPU is highly recommended)")
86 |
87 | print('Initializing image data manager')
88 | dm = DataManager(args, use_gpu)
89 | trainloader, testloader = dm.return_dataloaders()
90 |
91 | model = Model(scale_cls=args.scale_cls, num_classes=args.num_classes)
92 | # load the model
93 | checkpoint = torch.load(args.resume)
94 | model.load_state_dict(checkpoint['state_dict'])
95 | print("Loaded checkpoint from '{}'".format(args.resume))
96 |
97 | if use_gpu:
98 | model = model.cuda()
99 |
100 | test(model, testloader, use_gpu)
101 |
102 |
103 | def test(model, testloader, use_gpu):
104 | accs = AverageMeter()
105 | test_accuracies = []
106 | model.eval()
107 |
108 | with torch.no_grad():
109 | for batch_idx , (images_train, labels_train, images_test, labels_test) in enumerate(testloader):
110 | if use_gpu:
111 | images_train = images_train.cuda()
112 | images_test = images_test.cuda()
113 |
114 | end = time.time()
115 |
116 | batch_size, num_train_examples, channels, height, width = images_train.size()
117 | num_test_examples = images_test.size(1)
118 |
119 | labels_train_1hot = one_hot(labels_train).cuda() if use_gpu else one_hot(labels_train)
120 | labels_test_1hot = one_hot(labels_test).cuda() if use_gpu else one_hot(labels_test)
121 |
122 | cls_scores = model.test_transductive(images_train, images_test, labels_train_1hot, labels_test_1hot)
123 | cls_scores = cls_scores.view(batch_size * num_test_examples, -1)
124 | labels_test = labels_test.view(batch_size * num_test_examples)
125 |
126 | _, preds = torch.max(cls_scores.detach().cpu(), 1)
127 | acc = (torch.sum(preds == labels_test.detach().cpu()).float()) / labels_test.size(0)
128 | accs.update(acc.item(), labels_test.size(0))
129 |
130 | gt = (preds == labels_test.detach().cpu()).float()
131 | gt = gt.view(batch_size, num_test_examples).numpy() #[b, n]
132 | acc = np.sum(gt, 1) / num_test_examples
133 | acc = np.reshape(acc, (batch_size))
134 | test_accuracies.append(acc)
135 |
136 | accuracy = accs.avg
137 | test_accuracies = np.array(test_accuracies)
138 | test_accuracies = np.reshape(test_accuracies, -1)
139 | stds = np.std(test_accuracies, 0)
140 | ci95 = 1.96 * stds / np.sqrt(args.epoch_size)
141 | print('Accuracy: {:.2%}, 95% CI: +/-{:.2%}'.format(accuracy, ci95))
142 |
143 | return accuracy
144 |
145 |
146 | if __name__ == '__main__':
147 | main()
148 |
--------------------------------------------------------------------------------
/torchFewShot/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/torchFewShot/data_manager.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 |
4 | from torch.utils.data import DataLoader
5 |
6 | import transforms as T
7 | import datasets
8 | import dataset_loader
9 |
10 | class DataManager(object):
11 | """
12 | Few shot data manager
13 | """
14 |
15 | def __init__(self, args, use_gpu):
16 | super(DataManager, self).__init__()
17 | self.args = args
18 | self.use_gpu = use_gpu
19 |
20 | print("Initializing dataset {}".format(args.dataset))
21 | dataset = datasets.init_imgfewshot_dataset(name=args.dataset)
22 |
23 | if args.load:
24 | transform_train = T.Compose([
25 | T.RandomCrop(84, padding=8),
26 | T.RandomHorizontalFlip(),
27 | T.ToTensor(),
28 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
29 | T.RandomErasing(0.5)
30 | ])
31 |
32 | transform_test = T.Compose([
33 | T.ToTensor(),
34 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
35 | ])
36 |
37 | else:
38 | transform_train = T.Compose([
39 | T.Resize((args.height, args.width), interpolation=3),
40 | T.RandomCrop(args.height, padding=8),
41 | T.RandomHorizontalFlip(),
42 | T.ToTensor(),
43 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
44 | T.RandomErasing(0.5)
45 | ])
46 |
47 | transform_test = T.Compose([
48 | T.Resize((args.height, args.width), interpolation=3),
49 | T.ToTensor(),
50 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
51 | ])
52 |
53 | pin_memory = use_gpu
54 |
55 | self.trainloader = DataLoader(
56 | dataset_loader.init_loader(name='train_loader',
57 | dataset=dataset.train,
58 | labels2inds=dataset.train_labels2inds,
59 | labelIds=dataset.train_labelIds,
60 | nKnovel=args.nKnovel,
61 | nExemplars=args.nExemplars,
62 | nTestNovel=args.train_nTestNovel,
63 | epoch_size=args.train_epoch_size,
64 | transform=transform_train,
65 | load=args.load,
66 | ),
67 | batch_size=args.train_batch, shuffle=True, num_workers=args.workers,
68 | pin_memory=pin_memory, drop_last=True,
69 | )
70 |
71 | self.valloader = DataLoader(
72 | dataset_loader.init_loader(name='test_loader',
73 | dataset=dataset.val,
74 | labels2inds=dataset.val_labels2inds,
75 | labelIds=dataset.val_labelIds,
76 | nKnovel=args.nKnovel,
77 | nExemplars=args.nExemplars,
78 | nTestNovel=args.nTestNovel,
79 | epoch_size=args.epoch_size,
80 | transform=transform_test,
81 | load=args.load,
82 | ),
83 | batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
84 | pin_memory=pin_memory, drop_last=False,
85 | )
86 | self.testloader = DataLoader(
87 | dataset_loader.init_loader(name='test_loader',
88 | dataset=dataset.test,
89 | labels2inds=dataset.test_labels2inds,
90 | labelIds=dataset.test_labelIds,
91 | nKnovel=args.nKnovel,
92 | nExemplars=args.nExemplars,
93 | nTestNovel=args.nTestNovel,
94 | epoch_size=args.epoch_size,
95 | transform=transform_test,
96 | load=args.load,
97 | ),
98 | batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
99 | pin_memory=pin_memory, drop_last=False,
100 | )
101 |
102 | def return_dataloaders(self):
103 | if self.args.phase == 'test':
104 | return self.trainloader, self.testloader
105 | elif self.args.phase == 'val':
106 | return self.trainloader, self.valloader
107 |
--------------------------------------------------------------------------------
/torchFewShot/dataset_loader/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .train_loader import FewShotDataset_train
4 | from .test_loader import FewShotDataset_test
5 |
6 |
7 | __loader_factory = {
8 | 'train_loader': FewShotDataset_train,
9 | 'test_loader': FewShotDataset_test,
10 | }
11 |
12 |
13 |
14 | def get_names():
15 | return list(__loader_factory.keys())
16 |
17 |
18 | def init_loader(name, *args, **kwargs):
19 | if name not in list(__loader_factory.keys()):
20 | raise KeyError("Unknown loader: {}".format(name))
21 | return __loader_factory[name](*args, **kwargs)
22 |
23 |
--------------------------------------------------------------------------------
/torchFewShot/dataset_loader/test_loader.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 | from __future__ import division
4 |
5 | import os
6 | from PIL import Image
7 | import numpy as np
8 | import os.path as osp
9 | import lmdb
10 | import io
11 | import random
12 |
13 | import torch
14 | from torch.utils.data import Dataset
15 |
16 |
17 | def read_image(img_path):
18 | """Keep trying to read the image until it succeeds.
19 | This avoids IOErrors caused by heavy I/O load."""
20 | got_img = False
21 | if not osp.exists(img_path):
22 | raise IOError("{} does not exist".format(img_path))
23 | while not got_img:
24 | try:
25 | img = Image.open(img_path).convert('RGB')
26 | got_img = True
27 | except IOError:
28 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
29 | pass
30 | return img
31 |
32 |
33 | class FewShotDataset_test(Dataset):
34 | """Few-shot episode dataset.
35 |
36 | Returns a task (Xtrain, Ytrain, Xtest, Ytest) to classify.
37 | Xtrain: [nKnovel*nExemplars, c, h, w].
38 | Ytrain: [nKnovel*nExemplars].
39 | Xtest: [nTestNovel, c, h, w].
40 | Ytest: [nTestNovel].
41 | """
42 |
43 | def __init__(self,
44 | dataset, # dataset of [(img_path or image array, cats), ...].
45 | labels2inds, # maps each category to its sample indices {cat: [index1, index2, ...]}.
46 | labelIds, # list of category labels [0, 1, 2, 3, ...].
47 | nKnovel=5, # number of novel categories.
48 | nExemplars=1, # number of training examples per novel category.
49 | nTestNovel=2*5, # number of test examples for all the novel categories.
50 | epoch_size=2000, # number of tasks per epoch.
51 | transform=None,
52 | load=True,
53 | **kwargs
54 | ):
55 |
56 | self.dataset = dataset
57 | self.labels2inds = labels2inds
58 | self.labelIds = labelIds
59 | self.nKnovel = nKnovel
60 | self.transform = transform
61 |
62 | self.nExemplars = nExemplars
63 | self.nTestNovel = nTestNovel
64 | self.epoch_size = epoch_size
65 | self.load = load
66 |
67 | seed = 112
68 | random.seed(seed)
69 | np.random.seed(seed)
70 |
71 | self.Epoch_Exemplar = []
72 | self.Epoch_Tnovel = []
73 | for i in range(epoch_size):
74 | Tnovel, Exemplar = self._sample_episode()
75 | self.Epoch_Exemplar.append(Exemplar)
76 | self.Epoch_Tnovel.append(Tnovel)
77 |
78 | def __len__(self):
79 | return self.epoch_size
80 |
81 | def _sample_episode(self):
82 | """Samples the indices for one test episode.
83 | Returns:
84 | Tnovel: a list of length 'nTestNovel' of 2-element tuples (sample_index, label).
85 | Exemplars: a list of length 'nKnovel * nExemplars' of 2-element tuples (sample_index, label).
86 | """
87 |
88 | Knovel = random.sample(self.labelIds, self.nKnovel)
89 | nKnovel = len(Knovel)
90 | assert((self.nTestNovel % nKnovel) == 0)
91 | nEvalExamplesPerClass = int(self.nTestNovel / nKnovel)
92 |
93 | Tnovel = []
94 | Exemplars = []
95 | for Knovel_idx in range(len(Knovel)):
96 | ids = (nEvalExamplesPerClass + self.nExemplars)
97 | img_ids = random.sample(self.labels2inds[Knovel[Knovel_idx]], ids)
98 |
99 | imgs_tnovel = img_ids[:nEvalExamplesPerClass]
100 | imgs_emeplars = img_ids[nEvalExamplesPerClass:]
101 |
102 | Tnovel += [(img_id, Knovel_idx) for img_id in imgs_tnovel]
103 | Exemplars += [(img_id, Knovel_idx) for img_id in imgs_emeplars]
104 | assert(len(Tnovel) == self.nTestNovel)
105 | assert(len(Exemplars) == nKnovel * self.nExemplars)
106 | random.shuffle(Exemplars)
107 | random.shuffle(Tnovel)
108 |
109 | return Tnovel, Exemplars
110 |
111 | def _creatExamplesTensorData(self, examples):
112 | """
113 | Creates the image and label tensors for the given examples.
114 |
115 | Args:
116 | examples: a list of 2-element tuples (sample_index, label).
117 |
118 | Returns:
119 | images: a tensor [len(examples), c, h, w]
120 | labels: a tensor [len(examples)]
121 | """
122 |
123 | images = []
124 | labels = []
125 | for (img_idx, label) in examples:
126 | img = self.dataset[img_idx][0]
127 | if self.load:
128 | img = Image.fromarray(img)
129 | else:
130 | img = read_image(img)
131 | if self.transform is not None:
132 | img = self.transform(img)
133 | images.append(img)
134 | labels.append(label)
135 | images = torch.stack(images, dim=0)
136 | labels = torch.LongTensor(labels)
137 | return images, labels
138 |
139 | def __getitem__(self, index):
140 | Tnovel = self.Epoch_Tnovel[index]
141 | Exemplars = self.Epoch_Exemplar[index]
142 | Xt, Yt = self._creatExamplesTensorData(Exemplars)
143 | Xe, Ye = self._creatExamplesTensorData(Tnovel)
144 | return Xt, Yt, Xe, Ye
145 |
146 |
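The episode sampling in `_sample_episode` can be sketched without any dataset dependency. This is a minimal illustration under stated assumptions, not the repository's implementation: `sample_episode`, its parameter names, the `rng` argument, and the toy `labels2inds` mapping are all hypothetical.

```python
import random


def sample_episode(labels2inds, label_ids, n_way=5, n_shot=1, n_query=30, rng=None):
    """Sample one N-way episode as (query, support) lists of (sample_index, episode_label)."""
    if rng is None:
        rng = random.Random()
    if n_query % n_way != 0:
        raise ValueError("n_query must be divisible by n_way")
    per_class = n_query // n_way

    query, support = [], []
    classes = rng.sample(list(label_ids), n_way)
    for episode_label, cls in enumerate(classes):
        # Draw query and support images in one call so they never overlap.
        picked = rng.sample(labels2inds[cls], per_class + n_shot)
        query += [(i, episode_label) for i in picked[:per_class]]
        support += [(i, episode_label) for i in picked[per_class:]]

    rng.shuffle(query)
    rng.shuffle(support)
    return query, support


# Toy data (hypothetical): 8 classes with 20 sample indices each.
toy_labels2inds = {c: list(range(c * 20, (c + 1) * 20)) for c in range(8)}
query, support = sample_episode(toy_labels2inds, sorted(toy_labels2inds), rng=random.Random(0))
print(len(query), len(support))  # 30 5
```

Labels inside an episode are re-indexed to `0..n_way-1`, which is why the loaders store `Knovel_idx` rather than the dataset-level class id.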
--------------------------------------------------------------------------------
/torchFewShot/dataset_loader/train_loader.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 | from __future__ import division
4 |
5 | import os
6 | from PIL import Image
7 | import numpy as np
8 | import os.path as osp
9 | import lmdb
10 | import io
11 | import random
12 |
13 | import torch
14 | from torch.utils.data import Dataset
15 |
16 |
17 | def read_image(img_path):
18 | """Keep trying to read the image until it succeeds.
19 | This avoids IOErrors caused by heavy IO load."""
20 | got_img = False
21 | if not osp.exists(img_path):
22 | raise IOError("{} does not exist".format(img_path))
23 | while not got_img:
24 | try:
25 | img = Image.open(img_path).convert('RGB')
26 | got_img = True
27 | except IOError:
28 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
29 | pass
30 | return img
31 |
32 |
33 | class FewShotDataset_train(Dataset):
34 | """Few-shot episode dataset.
35 |
36 | Returns a task (Xtrain, Ytrain, Xtest, Ytest, Ycls) to classify.
37 | Xtrain: [nKnovel*nExemplars, c, h, w].
38 | Ytrain: [nKnovel*nExemplars].
39 | Xtest: [nTestNovel, c, h, w].
40 | Ytest: [nTestNovel].
41 | Ycls: [nTestNovel], dataset-level class labels of the test images.
42 | """
43 |
44 | def __init__(self,
45 | dataset, # dataset of [(img_path, cats), ...].
46 | labels2inds, # dict mapping each label to its sample indices {cat: [index1, index2, ...]}.
47 | labelIds, # train labels [0, 1, 2, 3, ...].
48 | nKnovel=5, # number of novel categories.
49 | nExemplars=1, # number of training examples per novel category.
50 | nTestNovel=6*5, # number of test examples for all the novel categories.
51 | epoch_size=2000, # number of tasks per epoch.
52 | transform=None,
53 | load=False,
54 | **kwargs
55 | ):
56 |
57 | self.dataset = dataset
58 | self.labels2inds = labels2inds
59 | self.labelIds = labelIds
60 | self.nKnovel = nKnovel
61 | self.transform = transform
62 |
63 | self.nExemplars = nExemplars
64 | self.nTestNovel = nTestNovel
65 | self.epoch_size = epoch_size
66 | self.load = load
67 |
68 | def __len__(self):
69 | return self.epoch_size
70 |
71 | def _sample_episode(self):
72 | """Samples the indices for one training episode.
73 | Returns:
74 | Tnovel: a list of length 'nTestNovel' with 2-element tuples. (sample_index, label)
75 | Exemplars: a list of length 'nKnovel * nExemplars' with 2-element tuples. (sample_index, label)
76 | """
77 |
78 | Knovel = random.sample(self.labelIds, self.nKnovel)
79 | nKnovel = len(Knovel)
80 | assert((self.nTestNovel % nKnovel) == 0)
81 | nEvalExamplesPerClass = int(self.nTestNovel / nKnovel)
82 |
83 | Tnovel = []
84 | Exemplars = []
85 | for Knovel_idx in range(len(Knovel)):
86 | ids = (nEvalExamplesPerClass + self.nExemplars)
87 | img_ids = random.sample(self.labels2inds[Knovel[Knovel_idx]], ids)
88 |
89 | imgs_tnovel = img_ids[:nEvalExamplesPerClass]
90 | imgs_exemplars = img_ids[nEvalExamplesPerClass:]
91 |
92 | Tnovel += [(img_id, Knovel_idx) for img_id in imgs_tnovel]
93 | Exemplars += [(img_id, Knovel_idx) for img_id in imgs_exemplars]
94 | assert(len(Tnovel) == self.nTestNovel)
95 | assert(len(Exemplars) == nKnovel * self.nExemplars)
96 | random.shuffle(Exemplars)
97 | random.shuffle(Tnovel)
98 |
99 | return Tnovel, Exemplars
100 |
101 | def _creatExamplesTensorData(self, examples):
102 | """
103 | Creates the image, label, and class tensors for the given examples.
104 |
105 | Args:
106 | examples: a list of 2-element tuples (sample_index, label).
107 |
108 | Returns:
109 | images: a tensor [len(examples), c, h, w]
110 | labels: a tensor [len(examples)] of episode-level labels
111 | cls: a tensor [len(examples)] of dataset-level class labels
112 | """
113 |
114 | images = []
115 | labels = []
116 | cls = []
117 | for (img_idx, label) in examples:
118 | img, ids = self.dataset[img_idx]
119 | if self.load:
120 | img = Image.fromarray(img)
121 | else:
122 | img = read_image(img)
123 | if self.transform is not None:
124 | img = self.transform(img)
125 | images.append(img)
126 | labels.append(label)
127 | cls.append(ids)
128 | images = torch.stack(images, dim=0)
129 | labels = torch.LongTensor(labels)
130 | cls = torch.LongTensor(cls)
131 | return images, labels, cls
132 |
133 |
134 | def __getitem__(self, index):
135 | Tnovel, Exemplars = self._sample_episode()
136 | Xt, Yt, Ytc = self._creatExamplesTensorData(Exemplars)
137 | Xe, Ye, Yec = self._creatExamplesTensorData(Tnovel)
138 | return Xt, Yt, Xe, Ye, Yec
139 |
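`read_image` above retries forever if a file exists but is persistently unreadable. A bounded variant could look like the sketch below. It is an assumption-laden illustration: `read_with_retry`, `max_attempts`, `delay`, and the `flaky_loader` stand-in are hypothetical names, and the loader is passed in as a callable so the example runs without PIL.

```python
import time


def read_with_retry(loader, path, max_attempts=5, delay=0.0):
    """Call loader(path), retrying on IOError at most max_attempts times."""
    last_error = None
    for attempt in range(1, max_attempts + 1):
        try:
            return loader(path)
        except IOError as err:
            last_error = err
            print("IOError on attempt {} for '{}'; retrying.".format(attempt, path))
            time.sleep(delay)
    raise IOError("giving up on '{}' after {} attempts: {}".format(path, max_attempts, last_error))


# Hypothetical loader that fails twice before succeeding.
calls = {"n": 0}


def flaky_loader(path):
    calls["n"] += 1
    if calls["n"] < 3:
        raise IOError("simulated transient failure")
    return "image:" + path


result = read_with_retry(flaky_loader, "a.jpg")
print(result)  # image:a.jpg
```

With PIL available, the loader would be something like `lambda p: Image.open(p).convert('RGB')`, which is the call `read_image` already makes.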
--------------------------------------------------------------------------------
/torchFewShot/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from .miniImageNet import miniImageNet
6 | from .tieredImageNet import tieredImageNet
7 | from .miniImageNet_load import miniImageNet_load
8 |
9 |
10 | __imgfewshot_factory = {
11 | 'miniImageNet': miniImageNet,
12 | 'tieredImageNet': tieredImageNet,
13 | 'miniImageNet_load': miniImageNet_load,
14 | }
15 |
16 |
17 | def get_names():
18 | return list(__imgfewshot_factory.keys())
19 |
20 |
21 | def init_imgfewshot_dataset(name, **kwargs):
22 | if name not in list(__imgfewshot_factory.keys()):
23 | raise KeyError("Invalid dataset, got '{}', but expected to be one of {}".format(name, list(__imgfewshot_factory.keys())))
24 | return __imgfewshot_factory[name](**kwargs)
25 |
26 |
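The name-to-class registry used by `init_imgfewshot_dataset` can be illustrated in isolation. The sketch below is hypothetical throughout: `ToyDataset`, `_factory`, and `init_dataset` are stand-ins chosen so the example runs without any data on disk.

```python
class ToyDataset(object):
    """Stand-in for a dataset class such as miniImageNet (hypothetical)."""

    def __init__(self, split="train"):
        self.split = split


# Registry mapping a public name to the class that builds the dataset.
_factory = {"toy": ToyDataset}


def get_names():
    return list(_factory.keys())


def init_dataset(name, **kwargs):
    if name not in _factory:
        raise KeyError("Invalid dataset, got '{}', but expected one of {}".format(name, get_names()))
    # Keyword arguments are forwarded unchanged to the dataset constructor.
    return _factory[name](**kwargs)


dataset = init_dataset("toy", split="val")
print(dataset.split)  # val
```

In the repository as listed, only `miniImageNet_load.__init__` accepts `**kwargs`; `miniImageNet.__init__` and `tieredImageNet.__init__` take no arguments, so forwarding keyword arguments to those two names would raise a `TypeError`.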
--------------------------------------------------------------------------------
/torchFewShot/datasets/miniImageNet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import os.path as osp  # used by _check_before_run
7 |
8 | class miniImageNet(object):
9 | """
10 | Dataset statistics:
11 | # 64 * 600 (train) + 16 * 600 (val) + 20 * 600 (test)
12 | """
13 | dataset_dir = '/home/houruibing/data/few-shot/mini-imagenet/'
14 |
15 | def __init__(self):
16 | super(miniImageNet, self).__init__()
17 | self.train_dir = os.path.join(self.dataset_dir, 'train')
18 | self.val_dir = os.path.join(self.dataset_dir, 'val')
19 | self.test_dir = os.path.join(self.dataset_dir, 'test')
20 |
21 | self.train, self.train_labels2inds, self.train_labelIds = self._process_dir(self.train_dir)
22 | self.val, self.val_labels2inds, self.val_labelIds = self._process_dir(self.val_dir)
23 | self.test, self.test_labels2inds, self.test_labelIds = self._process_dir(self.test_dir)
24 |
25 | self.num_train_cats = len(self.train_labelIds)
26 | num_total_cats = len(self.train_labelIds) + len(self.val_labelIds) + len(self.test_labelIds)
27 | num_total_imgs = len(self.train + self.val + self.test)
28 |
29 | print("=> MiniImageNet loaded")
30 | print("Dataset statistics:")
31 | print(" ------------------------------")
32 | print(" subset | # cats | # images")
33 | print(" ------------------------------")
34 | print(" train | {:5d} | {:8d}".format(len(self.train_labelIds), len(self.train)))
35 | print(" val | {:5d} | {:8d}".format(len(self.val_labelIds), len(self.val)))
36 | print(" test | {:5d} | {:8d}".format(len(self.test_labelIds), len(self.test)))
37 | print(" ------------------------------")
38 | print(" total | {:5d} | {:8d}".format(num_total_cats, num_total_imgs))
39 | print(" ------------------------------")
40 |
41 | def _check_before_run(self):
42 | """Check if all files are available before going deeper"""
43 | if not osp.exists(self.dataset_dir):
44 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
45 | if not osp.exists(self.train_dir):
46 | raise RuntimeError("'{}' is not available".format(self.train_dir))
47 | if not osp.exists(self.val_dir):
48 | raise RuntimeError("'{}' is not available".format(self.val_dir))
49 | if not osp.exists(self.test_dir):
50 | raise RuntimeError("'{}' is not available".format(self.test_dir))
51 |
52 | def _process_dir(self, dir_path):
53 | cat_container = sorted(os.listdir(dir_path))
54 | cats2label = {cat:label for label, cat in enumerate(cat_container)}
55 |
56 | dataset = []
57 | labels = []
58 | for cat in cat_container:
59 | for img_path in sorted(os.listdir(os.path.join(dir_path, cat))):
60 | if '.jpg' not in img_path:
61 | continue
62 | label = cats2label[cat]
63 | dataset.append((os.path.join(dir_path, cat, img_path), label))
64 | labels.append(label)
65 |
66 | labels2inds = {}
67 | for idx, label in enumerate(labels):
68 | if label not in labels2inds:
69 | labels2inds[label] = []
70 | labels2inds[label].append(idx)
71 |
72 | labelIds = sorted(labels2inds.keys())
73 | return dataset, labels2inds, labelIds
74 |
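The directory scan in `_process_dir` can be tried on a throwaway tree. This is a sketch under assumptions rather than the repository's code: `index_image_dir` and its `ext` parameter are hypothetical, and it compares the real file extension with `os.path.splitext` instead of the substring test used above.

```python
import os
import tempfile


def index_image_dir(dir_path, ext=".jpg"):
    """Return (dataset, labels2inds, label_ids) for a <dir_path>/<category>/<image> layout."""
    categories = sorted(os.listdir(dir_path))
    dataset, labels2inds = [], {}
    for label, cat in enumerate(categories):
        for name in sorted(os.listdir(os.path.join(dir_path, cat))):
            # Compare the actual extension rather than searching for a substring.
            if os.path.splitext(name)[1] != ext:
                continue
            labels2inds.setdefault(label, []).append(len(dataset))
            dataset.append((os.path.join(dir_path, cat, name), label))
    return dataset, labels2inds, sorted(labels2inds)


# Tiny temporary tree: two categories, two images each, plus one stray file.
root = tempfile.mkdtemp()
for cat in ("n01", "n02"):
    os.makedirs(os.path.join(root, cat))
    for name in ("a.jpg", "b.jpg"):
        open(os.path.join(root, cat, name), "w").close()
open(os.path.join(root, "n01", "notes.txt"), "w").close()

dataset, labels2inds, label_ids = index_image_dir(root)
print(labels2inds)  # {0: [0, 1], 1: [2, 3]}
```

Sorting both the category names and the file names keeps label ids and sample indices stable across runs, which the original code also does.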
--------------------------------------------------------------------------------
/torchFewShot/datasets/miniImageNet_load.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import os.path as osp  # used by _check_before_run
7 | import pickle
8 |
9 |
10 | def load_data(file):
11 | with open(file, 'rb') as fo:
12 | data = pickle.load(fo)
13 | return data
14 |
15 |
16 | def buildLabelIndex(labels):
17 | label2inds = {}
18 | for idx, label in enumerate(labels):
19 | if label not in label2inds:
20 | label2inds[label] = []
21 | label2inds[label].append(idx)
22 |
23 | return label2inds
24 |
25 |
26 | class miniImageNet_load(object):
27 | """
28 | Dataset statistics:
29 | # 64 * 600 (train) + 16 * 600 (val) + 20 * 600 (test)
30 | """
31 | dataset_dir = '/home/houruibing/code/few-shot/MiniImagenet/'
32 |
33 | def __init__(self, **kwargs):
34 | super(miniImageNet_load, self).__init__()
35 | self.train_dir = os.path.join(self.dataset_dir, 'miniImageNet_category_split_train_phase_train.pickle')
36 | self.val_dir = os.path.join(self.dataset_dir, 'miniImageNet_category_split_val.pickle')
37 | self.test_dir = os.path.join(self.dataset_dir, 'miniImageNet_category_split_test.pickle')
38 |
39 | self.train, self.train_labels2inds, self.train_labelIds = self._process_dir(self.train_dir)
40 | self.val, self.val_labels2inds, self.val_labelIds = self._process_dir(self.val_dir)
41 | self.test, self.test_labels2inds, self.test_labelIds = self._process_dir(self.test_dir)
42 |
43 | self.num_train_cats = len(self.train_labelIds)
44 | num_total_cats = len(self.train_labelIds) + len(self.val_labelIds) + len(self.test_labelIds)
45 | num_total_imgs = len(self.train + self.val + self.test)
46 |
47 | print("=> MiniImageNet loaded")
48 | print("Dataset statistics:")
49 | print(" ------------------------------")
50 | print(" subset | # cats | # images")
51 | print(" ------------------------------")
52 | print(" train | {:5d} | {:8d}".format(len(self.train_labelIds), len(self.train)))
53 | print(" val | {:5d} | {:8d}".format(len(self.val_labelIds), len(self.val)))
54 | print(" test | {:5d} | {:8d}".format(len(self.test_labelIds), len(self.test)))
55 | print(" ------------------------------")
56 | print(" total | {:5d} | {:8d}".format(num_total_cats, num_total_imgs))
57 | print(" ------------------------------")
58 |
59 | def _check_before_run(self):
60 | """Check if all files are available before going deeper"""
61 | if not osp.exists(self.dataset_dir):
62 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
63 | if not osp.exists(self.train_dir):
64 | raise RuntimeError("'{}' is not available".format(self.train_dir))
65 | if not osp.exists(self.val_dir):
66 | raise RuntimeError("'{}' is not available".format(self.val_dir))
67 | if not osp.exists(self.test_dir):
68 | raise RuntimeError("'{}' is not available".format(self.test_dir))
69 |
70 | def _get_pair(self, data, labels):
71 | assert (data.shape[0] == len(labels))
72 | data_pair = []
73 | for i in range(data.shape[0]):
74 | data_pair.append((data[i], labels[i]))
75 | return data_pair
76 |
77 | def _process_dir(self, file_path):
78 | dataset = load_data(file_path)
79 | data = dataset['data']
80 | print(data.shape)
81 | labels = dataset['labels']
82 | data_pair = self._get_pair(data, labels)
83 | labels2inds = buildLabelIndex(labels)
84 | labelIds = sorted(labels2inds.keys())
85 | return data_pair, labels2inds, labelIds
86 |
87 | if __name__ == '__main__':
88 | miniImageNet_load()
89 |
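`buildLabelIndex` and `_get_pair` are small enough to check by hand. The sketch below shows one possible equivalent using `collections.defaultdict` and `zip`; `build_label_index` and `pair_up` are hypothetical names, and the sample labels are made up.

```python
from collections import defaultdict


def build_label_index(labels):
    """Map each label to the list of positions where it occurs."""
    label2inds = defaultdict(list)
    for idx, label in enumerate(labels):
        label2inds[label].append(idx)
    return dict(label2inds)


def pair_up(data, labels):
    """Pair every sample with its label, as _get_pair does."""
    if len(data) != len(labels):
        raise ValueError("data and labels must have the same length")
    return list(zip(data, labels))


labels = [3, 1, 3, 2, 1]
index = build_label_index(labels)
pairs = pair_up(["a", "b", "c", "d", "e"], labels)
print(index)     # {3: [0, 2], 1: [1, 4], 2: [3]}
print(pairs[0])  # ('a', 3)
```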
--------------------------------------------------------------------------------
/torchFewShot/datasets/tieredImageNet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import os.path as osp  # used by _check_before_run
7 |
8 | class tieredImageNet(object):
9 | """
10 | Dataset statistics:
11 | # 351 (train) + 97 (val) + 160 (test) classes
12 | """
13 | dataset_dir = '/home/houruibing/data/few-shot/tieredImagenet/images/images/'
14 |
15 | def __init__(self):
16 | super(tieredImageNet, self).__init__()
17 | self.train_dir = os.path.join(self.dataset_dir, 'train')
18 | self.val_dir = os.path.join(self.dataset_dir, 'val')
19 | self.test_dir = os.path.join(self.dataset_dir, 'test')
20 |
21 | self.train, self.train_labels2inds, self.train_labelIds = self._process_dir(self.train_dir)
22 | self.val, self.val_labels2inds, self.val_labelIds = self._process_dir(self.val_dir)
23 | self.test, self.test_labels2inds, self.test_labelIds = self._process_dir(self.test_dir)
24 |
25 | self.num_train_cats = len(self.train_labelIds)
26 | num_total_cats = len(self.train_labelIds) + len(self.val_labelIds) + len(self.test_labelIds)
27 | num_total_imgs = len(self.train + self.val + self.test)
28 |
29 | print("=> tieredImageNet loaded")
30 | print("Dataset statistics:")
31 | print(" ------------------------------")
32 | print(" subset | # cats | # images")
33 | print(" ------------------------------")
34 | print(" train | {:5d} | {:8d}".format(len(self.train_labelIds), len(self.train)))
35 | print(" val | {:5d} | {:8d}".format(len(self.val_labelIds), len(self.val)))
36 | print(" test | {:5d} | {:8d}".format(len(self.test_labelIds), len(self.test)))
37 | print(" ------------------------------")
38 | print(" total | {:5d} | {:8d}".format(num_total_cats, num_total_imgs))
39 | print(" ------------------------------")
40 |
41 | def _check_before_run(self):
42 | """Check if all files are available before going deeper"""
43 | if not osp.exists(self.dataset_dir):
44 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
45 | if not osp.exists(self.train_dir):
46 | raise RuntimeError("'{}' is not available".format(self.train_dir))
47 | if not osp.exists(self.val_dir):
48 | raise RuntimeError("'{}' is not available".format(self.val_dir))
49 | if not osp.exists(self.test_dir):
50 | raise RuntimeError("'{}' is not available".format(self.test_dir))
51 |
52 | def _process_dir(self, dir_path):
53 | cat_container = sorted(os.listdir(dir_path))
54 | cats2label = {cat:label for label, cat in enumerate(cat_container)}
55 |
56 | dataset = []
57 | labels = []
58 | for cat in cat_container:
59 | for img_path in sorted(os.listdir(os.path.join(dir_path, cat))):
60 | if '.JPEG' not in img_path:
61 | continue
62 | label = cats2label[cat]
63 | dataset.append((os.path.join(dir_path, cat, img_path), label))
64 | labels.append(label)
65 |
66 | labels2inds = {}
67 | for idx, label in enumerate(labels):
68 | if label not in labels2inds:
69 | labels2inds[label] = []
70 | labels2inds[label].append(idx)
71 |
72 | labelIds = sorted(labels2inds.keys())
73 | return dataset, labels2inds, labelIds
74 |
75 |
76 | if __name__ == '__main__':
77 | tieredImageNet()
78 |
--------------------------------------------------------------------------------
/torchFewShot/losses.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | class CrossEntropyLoss(nn.Module):
8 | def __init__(self):
9 | super(CrossEntropyLoss, self).__init__()
10 | self.logsoftmax = nn.LogSoftmax(dim=1)
11 |
12 | def forward(self, inputs, targets):
13 | inputs = inputs.view(inputs.size(0), inputs.size(1), -1)
14 |
15 | log_probs = self.logsoftmax(inputs)
16 | # Build the one-hot targets on the same device as the inputs instead of hard-coding .cuda().
17 | targets = torch.zeros(inputs.size(0), inputs.size(1), device=inputs.device).scatter_(1, targets.unsqueeze(1).to(inputs.device), 1)
18 | targets = targets.unsqueeze(-1)
19 | loss = (- targets * log_probs).mean(0).sum()
20 | return loss / inputs.size(2)
21 |
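The reduction in `CrossEntropyLoss.forward` amounts to averaging the negative log-likelihood over samples and spatial positions, with every position of a sample sharing that sample's label. The pure-Python sketch below works this out on nested lists; it is only an illustration, and `log_softmax` and `dense_cross_entropy` are hypothetical helper names.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax over a list of class scores."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


def dense_cross_entropy(inputs, targets):
    """inputs[n][c][p]: score of class c at spatial position p for sample n.
    targets[n]: class index of sample n, shared by all of its positions."""
    n_samples = len(inputs)
    n_classes = len(inputs[0])
    n_positions = len(inputs[0][0])
    total = 0.0
    for n in range(n_samples):
        for p in range(n_positions):
            column = [inputs[n][c][p] for c in range(n_classes)]
            total -= log_softmax(column)[targets[n]]
    # Mean over samples and over spatial positions.
    return total / (n_samples * n_positions)


# Uniform scores over 4 classes give a loss of log(4) at every position.
inputs = [[[0.0, 0.0] for _ in range(4)] for _ in range(3)]
loss = dense_cross_entropy(inputs, [0, 1, 2])
print(abs(loss - math.log(4)) < 1e-12)  # True
```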
--------------------------------------------------------------------------------
/torchFewShot/models/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/torchFewShot/models/cam.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 |
4 | import torch
5 | import math
6 | from torch import nn
7 | from torch.nn import functional as F
8 |
9 |
10 | class ConvBlock(nn.Module):
11 | """Basic convolutional block:
12 | convolution + batch normalization.
13 |
14 | Args (following http://pytorch.org/docs/master/nn.html#torch.nn.Conv2d):
15 | - in_c (int): number of input channels.
16 | - out_c (int): number of output channels.
17 | - k (int or tuple): kernel size.
18 | - s (int or tuple): stride.
19 | - p (int or tuple): padding.
20 | """
21 | def __init__(self, in_c, out_c, k, s=1, p=0):
22 | super(ConvBlock, self).__init__()
23 | self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p)
24 | self.bn = nn.BatchNorm2d(out_c)
25 |
26 | def forward(self, x):
27 | return self.bn(self.conv(x))
28 |
29 |
30 | class CAM(nn.Module):
31 | def __init__(self):
32 | super(CAM, self).__init__()
33 | self.conv1 = ConvBlock(36, 6, 1)
34 | self.conv2 = nn.Conv2d(6, 36, 1, stride=1, padding=0)
35 |
36 | for m in self.modules():
37 | if isinstance(m, nn.Conv2d):
38 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
39 | m.weight.data.normal_(0, math.sqrt(2. / n))
40 |
41 | def get_attention(self, a):
42 | input_a = a
43 |
44 | a = a.mean(3)
45 | a = a.transpose(1, 3)
46 | a = F.relu(self.conv1(a))
47 | a = self.conv2(a)
48 | a = a.transpose(1, 3)
49 | a = a.unsqueeze(3)
50 |
51 | a = torch.mean(input_a * a, -1)
52 | a = F.softmax(a / 0.025, dim=-1) + 1
53 | return a
54 |
55 | def forward(self, f1, f2):
56 | b, n1, c, h, w = f1.size()
57 | n2 = f2.size(1)
58 |
59 | f1 = f1.view(b, n1, c, -1)
60 | f2 = f2.view(b, n2, c, -1)
61 |
62 | f1_norm = F.normalize(f1, p=2, dim=2, eps=1e-12)
63 | f2_norm = F.normalize(f2, p=2, dim=2, eps=1e-12)
64 |
65 | f1_norm = f1_norm.transpose(2, 3).unsqueeze(2)
66 | f2_norm = f2_norm.unsqueeze(1)
67 |
68 | a1 = torch.matmul(f1_norm, f2_norm)
69 | a2 = a1.transpose(3, 4)
70 |
71 | a1 = self.get_attention(a1)
72 | a2 = self.get_attention(a2)
73 |
74 | f1 = f1.unsqueeze(2) * a1.unsqueeze(3)
75 | f1 = f1.view(b, n1, n2, c, h, w)
76 | f2 = f2.unsqueeze(1) * a2.unsqueeze(3)
77 | f2 = f2.view(b, n1, n2, c, h, w)
78 |
79 | return f1.transpose(1, 2), f2.transpose(1, 2)
80 |
--------------------------------------------------------------------------------
/torchFewShot/models/net.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from resnet12 import resnet12
7 | from cam import CAM
8 |
9 |
10 | def one_hot(labels_train):
11 | labels_train = labels_train.cpu()
12 | nKnovel = 5  # number of classes per episode; hard-coded for 5-way tasks
13 | labels_train_1hot_size = list(labels_train.size()) + [nKnovel,]
14 | labels_train_unsqueeze = labels_train.unsqueeze(dim=labels_train.dim())
15 | labels_train_1hot = torch.zeros(labels_train_1hot_size).scatter_(len(labels_train_1hot_size) - 1, labels_train_unsqueeze, 1)
16 | return labels_train_1hot
17 |
18 |
19 | class Model(nn.Module):
20 | def __init__(self, scale_cls, iter_num_prob=35.0/75, num_classes=64):
21 | super(Model, self).__init__()
22 | self.scale_cls = scale_cls
23 | self.iter_num_prob = iter_num_prob
24 |
25 | self.base = resnet12()
26 | self.cam = CAM()
27 |
28 | self.nFeat = self.base.nFeat
29 | self.clasifier = nn.Conv2d(self.nFeat, num_classes, kernel_size=1)
30 |
31 | def test(self, ftrain, ftest):
32 | ftest = ftest.mean(4)
33 | ftest = ftest.mean(4)
34 | ftest = F.normalize(ftest, p=2, dim=ftest.dim()-1, eps=1e-12)
35 | ftrain = F.normalize(ftrain, p=2, dim=ftrain.dim()-1, eps=1e-12)
36 | scores = self.scale_cls * torch.sum(ftest * ftrain, dim=-1)
37 | return scores
38 |
39 | def forward(self, xtrain, xtest, ytrain, ytest):
40 | batch_size, num_train = xtrain.size(0), xtrain.size(1)
41 | num_test = xtest.size(1)
42 | K = ytrain.size(2)
43 | ytrain = ytrain.transpose(1, 2)
44 |
45 | xtrain = xtrain.view(-1, xtrain.size(2), xtrain.size(3), xtrain.size(4))
46 | xtest = xtest.view(-1, xtest.size(2), xtest.size(3), xtest.size(4))
47 | x = torch.cat((xtrain, xtest), 0)
48 | f = self.base(x)
49 |
50 | ftrain = f[:batch_size * num_train]
51 | ftrain = ftrain.view(batch_size, num_train, -1)
52 | ftrain = torch.bmm(ytrain, ftrain)
53 | ftrain = ftrain.div(ytrain.sum(dim=2, keepdim=True).expand_as(ftrain))
54 | ftrain = ftrain.view(batch_size, -1, *f.size()[1:])
55 | ftest = f[batch_size * num_train:]
56 | ftest = ftest.view(batch_size, num_test, *f.size()[1:])
57 | ftrain, ftest = self.cam(ftrain, ftest)
58 | ftrain = ftrain.mean(4)
59 | ftrain = ftrain.mean(4)
60 |
61 | if not self.training:
62 | return self.test(ftrain, ftest)
63 |
64 | ftest_norm = F.normalize(ftest, p=2, dim=3, eps=1e-12)
65 | ftrain_norm = F.normalize(ftrain, p=2, dim=3, eps=1e-12)
66 | ftrain_norm = ftrain_norm.unsqueeze(4)
67 | ftrain_norm = ftrain_norm.unsqueeze(5)
68 | cls_scores = self.scale_cls * torch.sum(ftest_norm * ftrain_norm, dim=3)
69 | cls_scores = cls_scores.view(batch_size * num_test, *cls_scores.size()[2:])
70 |
71 | ftest = ftest.view(batch_size, num_test, K, -1)
72 | ftest = ftest.transpose(2, 3)
73 | ytest = ytest.unsqueeze(3)
74 | ftest = torch.matmul(ftest, ytest)
75 | ftest = ftest.view(batch_size * num_test, -1, 6, 6)
76 | ytest = self.clasifier(ftest)
77 | return ytest, cls_scores
78 |
79 | def helper(self, ftrain, ftest, ytrain):
80 | b, n, c, h, w = ftrain.size()
81 | k = ytrain.size(2)
82 |
83 | ytrain_transposed = ytrain.transpose(1, 2)
84 | ftrain = torch.bmm(ytrain_transposed, ftrain.view(b, n, -1))
85 | ftrain = ftrain.div(ytrain_transposed.sum(dim=2, keepdim=True).expand_as(ftrain))
86 | ftrain = ftrain.view(b, -1, c, h, w)
87 |
88 | ftrain, ftest = self.cam(ftrain, ftest)
89 | ftrain = ftrain.mean(-1).mean(-1)
90 | ftest = ftest.mean(-1).mean(-1)
91 |
92 | ftest = F.normalize(ftest, p=2, dim=ftest.dim()-1, eps=1e-12)
93 | ftrain = F.normalize(ftrain, p=2, dim=ftrain.dim()-1, eps=1e-12)
94 | scores = self.scale_cls * torch.sum(ftest * ftrain, dim=-1)
95 | return scores
96 |
97 | def test_transductive(self, xtrain, xtest, ytrain, ytest):
98 | iter_num_prob = self.iter_num_prob
99 | batch_size, num_train = xtrain.size(0), xtrain.size(1)
100 | num_test = xtest.size(1)
101 | K = ytrain.size(2)
102 |
103 | xtrain = xtrain.view(-1, xtrain.size(2), xtrain.size(3), xtrain.size(4))
104 | xtest = xtest.view(-1, xtest.size(2), xtest.size(3), xtest.size(4))
105 | x = torch.cat((xtrain, xtest), 0)
106 | f = self.base(x)
107 |
108 | ftrain = f[: batch_size*num_train].view(batch_size, num_train, *f.size()[1:])
109 | ftest = f[batch_size*num_train:].view(batch_size, num_test, *f.size()[1:])
110 | cls_scores = self.helper(ftrain, ftest, ytrain)
111 |
112 | num_images_per_iter = int(num_test * iter_num_prob)
113 | num_iter = num_test // num_images_per_iter
114 |
115 | for i in range(num_iter):
116 | max_scores, preds = torch.max(cls_scores, 2)
117 | chose_index = torch.argsort(max_scores.view(-1), descending=True)
118 | chose_index = chose_index[: num_images_per_iter * (i + 1)]
119 |
120 | ftest_iter = ftest[0, chose_index].unsqueeze(0)  # indexing batch 0 assumes batch_size == 1
121 | preds_iter = preds[0, chose_index].unsqueeze(0)
122 | preds_iter = one_hot(preds_iter).to(ytrain.device)
123 |
124 | ftrain_iter = torch.cat((ftrain, ftest_iter), 1)
125 | ytrain_iter = torch.cat((ytrain, preds_iter), 1)
126 | cls_scores = self.helper(ftrain_iter, ftest, ytrain_iter)
127 |
128 | return cls_scores
129 |
130 |
131 |
132 |
133 | if __name__ == '__main__':
134 | torch.manual_seed(0)
135 |
136 | net = Model(scale_cls=7)
137 | net.eval()
138 |
139 | x1 = torch.rand(1, 5, 3, 84, 84)
140 | x2 = torch.rand(1, 75, 3, 84, 84)
141 | y1 = torch.rand(1, 5, 5)
142 | y2 = torch.rand(1, 75, 5)
143 |
144 | y1 = net.test_transductive(x1, x2, y1, y2)
145 | print(y1.size())
146 |
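The loop in `test_transductive` adds a growing number of the most confident query samples to the support set at each iteration. The arithmetic behind that schedule can be sketched as follows; `transductive_schedule` is a hypothetical helper, and the guard against a zero step is an addition that the original code does not have.

```python
def transductive_schedule(num_test, iter_num_prob):
    """Return how many pseudo-labelled queries join the support set at each iteration."""
    per_iter = int(num_test * iter_num_prob)
    if per_iter <= 0:
        raise ValueError("iter_num_prob is too small for num_test")
    num_iter = num_test // per_iter
    return [per_iter * (i + 1) for i in range(num_iter)]


# Default from Model.__init__ (iter_num_prob=35.0/75) with 75 query images.
schedule = transductive_schedule(75, 35.0 / 75)
print(schedule)  # [35, 70]
```

With these defaults the model runs two refinement passes, using the 35 and then the 70 highest-scoring queries; the remaining five queries are never added to the support set.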
--------------------------------------------------------------------------------
/torchFewShot/models/resnet12.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | def conv3x3(in_planes, out_planes, stride=1):
7 | """3x3 convolution with padding"""
8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
9 | padding=1, bias=False)
10 |
11 |
12 | class BasicBlock(nn.Module):
13 | expansion = 1
14 |
15 | def __init__(self, inplanes, planes, kernel=3, stride=1, downsample=None):
16 | super(BasicBlock, self).__init__()
17 | if kernel == 1:
18 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
19 | elif kernel == 3:
20 | self.conv1 = conv3x3(inplanes, planes)
21 | self.bn1 = nn.BatchNorm2d(planes)
22 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
23 | padding=1, bias=False)
24 | self.bn2 = nn.BatchNorm2d(planes)
25 | if kernel == 1:
26 | self.conv3 = nn.Conv2d(planes, planes, kernel_size=1, bias=False)
27 | elif kernel == 3:
28 | self.conv3 = conv3x3(planes, planes)
29 | self.bn3 = nn.BatchNorm2d(planes)
30 | self.relu = nn.ReLU(inplace=True)
31 | self.downsample = downsample
32 | self.stride = stride
33 |
34 | def forward(self, x):
35 | residual = x
36 |
37 | out = self.conv1(x)
38 | out = self.bn1(out)
39 | out = self.relu(out)
40 |
41 | out = self.conv2(out)
42 | out = self.bn2(out)
43 | out = self.relu(out)
44 |
45 | out = self.conv3(out)
46 | out = self.bn3(out)
47 |
48 | if self.downsample is not None:
49 | residual = self.downsample(x)
50 |
51 | out += residual
52 | out = self.relu(out)
53 |
54 | return out
55 |
56 |
57 | class Bottleneck(nn.Module):
58 | expansion = 4
59 |
60 | def __init__(self, inplanes, planes, kernel=1, stride=1, downsample=None):
61 | super(Bottleneck, self).__init__()
62 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
63 | self.bn1 = nn.BatchNorm2d(planes)
64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
65 | padding=1, bias=False)
66 | self.bn2 = nn.BatchNorm2d(planes)
67 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
68 | self.bn3 = nn.BatchNorm2d(planes * 4)
69 | self.relu = nn.ReLU(inplace=True)
70 | self.downsample = downsample
71 | self.stride = stride
72 |
73 | def forward(self, x):
74 | residual = x
75 |
76 | out = self.conv1(x)
77 | out = self.bn1(out)
78 | out = self.relu(out)
79 |
80 | out = self.conv2(out)
81 | out = self.bn2(out)
82 | out = self.relu(out)
83 |
84 | out = self.conv3(out)
85 | out = self.bn3(out)
86 |
87 | if self.downsample is not None:
88 | residual = self.downsample(x)
89 |
90 | out += residual
91 | out = self.relu(out)
92 |
93 | return out
94 |
95 |
96 | class ResNet(nn.Module):
97 |
98 | def __init__(self, block, layers, kernel=3):
99 | self.inplanes = 64
100 | self.kernel = kernel
101 | super(ResNet, self).__init__()
102 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
103 | self.bn1 = nn.BatchNorm2d(64)
104 | self.relu = nn.ReLU(inplace=True)
105 |
106 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
110 |
111 | self.nFeat = 512 * block.expansion
112 |
113 | for m in self.modules():
114 | if isinstance(m, nn.Conv2d):
115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
116 | m.weight.data.normal_(0, math.sqrt(2. / n))
117 | elif isinstance(m, nn.BatchNorm2d):
118 | m.weight.data.fill_(1)
119 | m.bias.data.zero_()
120 |
121 | def _make_layer(self, block, planes, blocks, stride=1):
122 | downsample = None
123 | if stride != 1 or self.inplanes != planes * block.expansion:
124 | downsample = nn.Sequential(
125 | nn.Conv2d(self.inplanes, planes * block.expansion,
126 | kernel_size=1, stride=stride, bias=False),
127 | nn.BatchNorm2d(planes * block.expansion),
128 | )
129 |
130 | layers = []
131 | layers.append(block(self.inplanes, planes, self.kernel, stride, downsample))
132 | self.inplanes = planes * block.expansion
133 | for i in range(1, blocks):
134 | layers.append(block(self.inplanes, planes, self.kernel))
135 |
136 | return nn.Sequential(*layers)
137 |
138 | def forward(self, x):
139 | x = self.conv1(x)
140 | x = self.bn1(x)
141 | x = self.relu(x)
142 | x = self.layer1(x)
143 | x = self.layer2(x)
144 | x = self.layer3(x)
145 | x = self.layer4(x)
146 |
147 | return x
148 |
149 |
150 | def resnet12():
151 | model = ResNet(BasicBlock, [1,1,1,1], kernel=3)
152 | return model
153 |
--------------------------------------------------------------------------------
/torchFewShot/optimizers.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 |
5 |
6 | def init_optimizer(optim, params, lr, weight_decay):
7 | if optim == 'adam':
8 | return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
9 | elif optim == 'amsgrad':
10 | return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay, amsgrad=True)
11 | elif optim == 'sgd':
12 | return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=weight_decay, nesterov=True)
13 | elif optim == 'rmsprop':
14 | return torch.optim.RMSprop(params, lr=lr, momentum=0.9, weight_decay=weight_decay)
15 | else:
16 | raise KeyError("Unsupported optimizer: {}".format(optim))
17 |
--------------------------------------------------------------------------------
/torchFewShot/transforms.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 |
4 | from torchvision.transforms import *
5 |
6 | from PIL import Image
7 | import random
8 | import numpy as np
9 | import math
10 | import torch
11 |
12 |
13 | class Random2DTranslation(object):
14 | """
15 | With probability p, first resize the image to 1.125x (1 + 1/8) the target size and then take a random crop of the target size; otherwise just resize to the target size.
16 |
17 | Args:
18 | - height (int): target height.
19 | - width (int): target width.
20 | - p (float): probability of performing this transformation. Default: 0.5.
21 | """
22 | def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR):
23 | self.height = height
24 | self.width = width
25 | self.p = p
26 | self.interpolation = interpolation
27 |
28 | def __call__(self, img):
29 | """
30 | Args:
31 | - img (PIL Image): Image to be cropped.
32 | """
33 | if random.uniform(0, 1) > self.p:
34 | return img.resize((self.width, self.height), self.interpolation)
35 |
36 | new_width, new_height = int(round(self.width * 1.125)), int(round(self.height * 1.125))
37 | resized_img = img.resize((new_width, new_height), self.interpolation)
38 | x_maxrange = new_width - self.width
39 | y_maxrange = new_height - self.height
40 | x1 = int(round(random.uniform(0, x_maxrange)))
41 | y1 = int(round(random.uniform(0, y_maxrange)))
42 | cropped_img = resized_img.crop((x1, y1, x1 + self.width, y1 + self.height))
43 | return cropped_img
44 |
45 |
46 | class RandomErasing(object):
47 | """ Randomly selects a rectangle region in an image and erases its pixels.
48 | 'Random Erasing Data Augmentation' by Zhong et al.
49 | See https://arxiv.org/pdf/1708.04896.pdf
50 | Args:
51 | probability: The probability that the Random Erasing operation will be performed.
52 | sl: Minimum proportion of erased area against input image.
53 | sh: Maximum proportion of erased area against input image.
54 | r1: Minimum aspect ratio of erased area.
55 | mean: Erasing value.
56 | """
57 |
58 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
59 | self.probability = probability
60 | self.mean = mean
61 | self.sl = sl
62 | self.sh = sh
63 | self.r1 = r1
64 |
65 | def __call__(self, img):
66 |
67 | if random.uniform(0, 1) > self.probability:
68 | return img
69 |
70 | for attempt in range(100):
71 | area = img.size()[1] * img.size()[2]
72 |
73 | target_area = random.uniform(self.sl, self.sh) * area
74 | aspect_ratio = random.uniform(self.r1, 1/self.r1)
75 |
76 | h = int(round(math.sqrt(target_area * aspect_ratio)))
77 | w = int(round(math.sqrt(target_area / aspect_ratio)))
78 |
79 | if w < img.size()[2] and h < img.size()[1]:
80 | x1 = random.randint(0, img.size()[1] - h)
81 | y1 = random.randint(0, img.size()[2] - w)
82 | if img.size()[0] == 3:
83 | img[0, x1:x1+h, y1:y1+w] = self.mean[0]
84 | img[1, x1:x1+h, y1:y1+w] = self.mean[1]
85 | img[2, x1:x1+h, y1:y1+w] = self.mean[2]
86 | else:
87 | img[0, x1:x1+h, y1:y1+w] = self.mean[0]
88 | return img
89 |
90 | return img
91 |
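The rectangle sampling inside `RandomErasing.__call__` can be studied without tensors. The sketch below is a stdlib-only illustration, not part of the repository: `sample_erase_box` and its `rng` parameter are hypothetical names, and it only returns the box geometry instead of writing the mean value into an image.

```python
import math
import random


def sample_erase_box(height, width, sl=0.02, sh=0.4, r1=0.3,
                     max_attempts=100, rng=random):
    """Return (top, left, h, w) of a box to erase, or None if no attempt fits."""
    area = height * width
    for _ in range(max_attempts):
        # Draw the erased area as a fraction of the image and an aspect ratio.
        target_area = rng.uniform(sl, sh) * area
        aspect_ratio = rng.uniform(r1, 1 / r1)

        h = int(round(math.sqrt(target_area * aspect_ratio)))
        w = int(round(math.sqrt(target_area / aspect_ratio)))

        # Accept only boxes that fit strictly inside the image.
        if h < height and w < width:
            top = rng.randint(0, height - h)
            left = rng.randint(0, width - w)
            return top, left, h, w
    return None


# Example with an assumed 84x84 input and a seeded generator.
box = sample_erase_box(84, 84, rng=random.Random(0))
```

Because a draw can produce a box taller or wider than the image, the loop retries up to `max_attempts` times, which matches the 100 attempts in the class above.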
--------------------------------------------------------------------------------
/torchFewShot/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/torchFewShot/utils/avgmeter.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 |
4 |
5 | class AverageMeter(object):
6 | """Computes and stores the average and current value.
7 |
8 | Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
9 | """
10 | def __init__(self):
11 | self.reset()
12 |
13 | def reset(self):
14 | self.val = 0
15 | self.avg = 0
16 | self.sum = 0
17 | self.count = 0
18 |
19 | def update(self, val, n=1):
20 | self.val = val
21 | self.sum += val * n
22 | self.count += n
23 | self.avg = self.sum / self.count
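The `n` argument of `update` makes `avg` a sample-weighted mean rather than a mean over calls. The sketch below shows this with a minimal stand-in class (`RunningAverage` is a hypothetical name that mirrors only the `sum`/`count` bookkeeping of `AverageMeter`).

```python
class RunningAverage:
    """Minimal stand-in for AverageMeter: tracks a weighted running mean."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        # `val` is a per-batch mean; `n` is the number of samples behind it.
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        return self.sum / self.count if self.count else 0.0


meter = RunningAverage()
meter.update(0.50, n=4)  # batch of 4 samples with mean 0.50
meter.update(1.00, n=1)  # batch of 1 sample with mean 1.00
weighted = meter.avg           # (0.5 * 4 + 1.0 * 1) / 5
unweighted = (0.50 + 1.00) / 2  # what a per-call mean would give
```

Here `weighted` is 0.6 while `unweighted` is 0.75, so passing the batch size as `n` matters whenever batches differ in size.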
--------------------------------------------------------------------------------
/torchFewShot/utils/iotools.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import os
4 | import os.path as osp
5 | import errno
6 | import json
7 | import shutil
8 |
9 | import torch
10 |
11 |
12 | def mkdir_if_missing(directory):
13 | if not osp.exists(directory):
14 | try:
15 | os.makedirs(directory)
16 | except OSError as e:
17 | if e.errno != errno.EEXIST:
18 | raise
19 |
20 |
21 | def check_isfile(path):
22 | isfile = osp.isfile(path)
23 | if not isfile:
24 | print("=> Warning: no file found at '{}' (ignored)".format(path))
25 | return isfile
26 |
27 |
28 | def read_json(fpath):
29 | with open(fpath, 'r') as f:
30 | obj = json.load(f)
31 | return obj
32 |
33 |
34 | def write_json(obj, fpath):
35 | mkdir_if_missing(osp.dirname(fpath))
36 | with open(fpath, 'w') as f:
37 | json.dump(obj, f, indent=4, separators=(',', ': '))
38 |
39 |
40 | def save_checkpoint(state, is_best=False, fpath='checkpoint.pth.tar'):
41 | if len(osp.dirname(fpath)) != 0:
42 | mkdir_if_missing(osp.dirname(fpath))
43 | torch.save(state, fpath)
44 | if is_best:
45 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar'))
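On Python 3 only, the existence check and `EEXIST` handling in `mkdir_if_missing` can be expressed with `os.makedirs(..., exist_ok=True)`. This is a sketch of an equivalent, assuming Python 2 support is no longer needed; the temporary directory and the `can_run` subfolder are made up for the example.

```python
import os
import tempfile


def mkdir_if_missing(directory):
    # exist_ok=True makes the call a no-op when the directory already exists,
    # which also covers the race where another process creates it first.
    os.makedirs(directory, exist_ok=True)


# Example: creating the same nested directory twice does not raise.
root = tempfile.mkdtemp()
save_dir = os.path.join(root, 'result', 'can_run')
mkdir_if_missing(save_dir)
mkdir_if_missing(save_dir)
created = os.path.isdir(save_dir)
```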
--------------------------------------------------------------------------------
/torchFewShot/utils/logger.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import sys
4 | import os
5 | import os.path as osp
6 |
7 | from .iotools import mkdir_if_missing
8 |
9 |
10 | class Logger(object):
11 | """
12 | Write console output to external text file.
13 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
14 | """
15 | def __init__(self, fpath=None, mode='a'):
16 | self.console = sys.stdout
17 | self.file = None
18 | if fpath is not None:
19 | mkdir_if_missing(osp.dirname(fpath))
20 | self.file = open(fpath, mode)
21 |
22 | def __del__(self):
23 | self.close()
24 |
25 | def __enter__(self):
26 | return self
27 |
28 | def __exit__(self, *args):
29 | self.close()
30 |
31 | def write(self, msg):
32 | self.console.write(msg)
33 | if self.file is not None:
34 | self.file.write(msg)
35 |
36 | def flush(self):
37 | self.console.flush()
38 | if self.file is not None:
39 | self.file.flush()
40 | os.fsync(self.file.fileno())
41 |
42 | def close(self):
43 | self.console.close()
44 | if self.file is not None:
45 | self.file.close()
46 |
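The tee pattern used by `Logger` can be tried without touching `sys.stdout`. The sketch below is a hypothetical variant, not the repository's class: `TeeLogger` takes the console stream as an argument, returns itself from `__enter__`, and closes only the file it opened.

```python
import io
import os
import tempfile


class TeeLogger:
    """Hypothetical variant of Logger: mirrors writes to a stream and a file."""

    def __init__(self, console, fpath=None, mode='a'):
        self.console = console
        self.file = open(fpath, mode) if fpath is not None else None

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        self.console.flush()
        if self.file is not None:
            self.file.flush()

    def close(self):
        # Only close the file we opened; the console belongs to the caller.
        if self.file is not None:
            self.file.close()
            self.file = None


console = io.StringIO()  # stands in for sys.stdout in this example
log_path = os.path.join(tempfile.mkdtemp(), 'log_train.txt')
with TeeLogger(console, log_path) as log:
    log.write('epoch 1 done\n')

with open(log_path) as f:
    file_text = f.read()
```

Leaving the console open is a deliberate design choice in this sketch: it keeps later `print` calls working after the logger is closed.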
--------------------------------------------------------------------------------
/torchFewShot/utils/torchtools.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | def open_all_layers(model):
9 | """
10 | Open all layers in model for training.
11 | """
12 | model.train()
13 | for p in model.parameters():
14 | p.requires_grad = True
15 |
16 |
17 | def open_specified_layers(model, open_layers):
18 | """
19 | Open specified layers in model for training while keeping
20 | other layers frozen.
21 |
22 | Args:
23 | - model (nn.Module): neural net model.
24 | - open_layers (list): list of layers names.
25 | """
26 | if isinstance(model, nn.DataParallel):
27 | model = model.module
28 |
29 | for layer in open_layers:
30 | assert hasattr(model, layer), "'{}' is not an attribute of the model, please provide the correct name".format(layer)
31 |
32 | for name, module in model.named_children():
33 | if name in open_layers:
34 | #print(module)
35 | module.train()
36 | for p in module.parameters():
37 | p.requires_grad = True
38 | else:
39 | module.eval()
40 | for p in module.parameters():
41 | p.requires_grad = False
42 |
43 |
44 |
45 | def adjust_learning_rate(optimizer, iters, LUT):
46 |     # look up lr in LUT, a list of (stepvalue, lr) pairs; fall back to the last lr
47 |     lr = LUT[-1][1]
48 |     for (stepvalue, base_lr) in LUT:
49 |         if iters < stepvalue:
50 |             lr = base_lr
51 |             break
52 |     for param_group in optimizer.param_groups:
53 |         param_group['lr'] = lr
54 |     return lr
55 |
56 |
57 | def adjust_lambda(iters, LUT):
58 |     # LUT is a list of (stepvalue, lambda) pairs; fall back to the last lambda
59 |     for (stepvalue, base_lambda) in LUT:
60 |         if iters < stepvalue:
61 |             return base_lambda
62 |     return LUT[-1][1]
63 |
64 |
65 | def set_bn_to_eval(m):
66 | # 1. no update for running mean and var
67 | # 2. scale and shift parameters are still trainable
68 | classname = m.__class__.__name__
69 | if classname.find('BatchNorm') != -1:
70 | m.eval()
71 |
72 |
73 | def count_num_param(model):
74 | num_param = sum(p.numel() for p in model.parameters()) / 1e+06
75 | if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Module):
76 | # we ignore the classifier because it is unused at test time
77 | num_param -= sum(p.numel() for p in model.classifier.parameters()) / 1e+06
78 | return num_param
79 |
80 |
81 | def one_hot(labels_train):
82 | """
83 | Turn the labels_train to one-hot encoding.
84 | Args:
85 | labels_train: [batch_size, num_train_examples]
86 | Return:
87 | labels_train_1hot: [batch_size, num_train_examples, K]
88 | """
89 | labels_train = labels_train.cpu()
90 | nKnovel = 1 + labels_train.max()
91 | labels_train_1hot_size = list(labels_train.size()) + [nKnovel,]
92 | labels_train_unsqueeze = labels_train.unsqueeze(dim=labels_train.dim())
93 | labels_train_1hot = torch.zeros(labels_train_1hot_size).scatter_(len(labels_train_1hot_size) - 1, labels_train_unsqueeze, 1)
94 | return labels_train_1hot
95 |
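`adjust_learning_rate` and `adjust_lambda` share the same lookup: walk a table of `(stepvalue, value)` pairs and take the first entry whose `stepvalue` exceeds the current iteration. The sketch below isolates that lookup. `lookup_schedule` is a hypothetical helper, the table values are illustrative rather than taken from the args files, and returning the last value once every `stepvalue` has passed is an assumption about the intended behaviour.

```python
def lookup_schedule(iters, lut):
    """Return the value of the first (stepvalue, value) pair with iters < stepvalue.

    Falls back to the last value once `iters` has passed every stepvalue
    (an assumption; without a fallback the lookup would have no result).
    """
    for stepvalue, value in lut:
        if iters < stepvalue:
            return value
    return lut[-1][1]


# Illustrative table: 0.1 until epoch 60, then three smaller rates.
lut_lr = [(60, 0.1), (70, 0.006), (80, 0.0012), (90, 0.00024)]
schedule = [lookup_schedule(epoch, lut_lr) for epoch in (0, 59, 60, 79, 89, 95)]
```

With this table, epochs 0 and 59 map to 0.1, epoch 60 to 0.006, epoch 79 to 0.0012, and both 89 and 95 to 0.00024.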
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import division
3 |
4 | import os
5 | import sys
6 | import time
7 | import datetime
8 | import argparse
9 | import os.path as osp
10 | import numpy as np
11 | import random
12 |
13 | import torch
14 | import torch.nn as nn
15 | import torch.backends.cudnn as cudnn
16 | from torch.utils.data import DataLoader
17 | from torch.optim import lr_scheduler
18 | import torch.nn.functional as F
19 | sys.path.append('./torchFewShot')
20 |
21 | from args_tiered import argument_parser
22 |
23 | from torchFewShot.models.net import Model
24 | from torchFewShot.data_manager import DataManager
25 | from torchFewShot.losses import CrossEntropyLoss
26 | from torchFewShot.optimizers import init_optimizer
27 |
28 | from torchFewShot.utils.iotools import save_checkpoint, check_isfile
29 | from torchFewShot.utils.avgmeter import AverageMeter
30 | from torchFewShot.utils.logger import Logger
31 | from torchFewShot.utils.torchtools import one_hot, adjust_learning_rate
32 |
33 |
34 | parser = argument_parser()
35 | args = parser.parse_args()
36 |
37 | def main():
38 | torch.manual_seed(args.seed)
39 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
40 | use_gpu = torch.cuda.is_available()
41 |
42 | sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
43 | print("==========\nArgs:{}\n==========".format(args))
44 |
45 | if use_gpu:
46 | print("Currently using GPU {}".format(args.gpu_devices))
47 | cudnn.benchmark = True
48 | torch.cuda.manual_seed_all(args.seed)
49 | else:
50 | print("Currently using CPU (GPU is highly recommended)")
51 |
52 | print('Initializing image data manager')
53 | dm = DataManager(args, use_gpu)
54 | trainloader, testloader = dm.return_dataloaders()
55 |
56 | model = Model(scale_cls=args.scale_cls, num_classes=args.num_classes)
57 | criterion = CrossEntropyLoss()
58 | optimizer = init_optimizer(args.optim, model.parameters(), args.lr, args.weight_decay)
59 |
60 | if use_gpu:
61 | model = model.cuda()
62 |
63 | start_time = time.time()
64 | train_time = 0
65 | best_acc = -np.inf
66 | best_epoch = 0
67 | print("==> Start training")
68 |
69 | for epoch in range(args.max_epoch):
70 | learning_rate = adjust_learning_rate(optimizer, epoch, args.LUT_lr)
71 |
72 | start_train_time = time.time()
73 | train(epoch, model, criterion, optimizer, trainloader, learning_rate, use_gpu)
74 | train_time += round(time.time() - start_train_time)
75 |
76 | if epoch == 0 or epoch > (args.stepsize[0]-1) or (epoch + 1) % 10 == 0:
77 | acc = test(model, testloader, use_gpu)
78 | is_best = acc > best_acc
79 |
80 | if is_best:
81 | best_acc = acc
82 | best_epoch = epoch + 1
83 |
84 | save_checkpoint({
85 | 'state_dict': model.state_dict(),
86 | 'acc': acc,
87 | 'epoch': epoch,
88 | }, is_best, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))
89 |
90 | print("==> Test 5-way Best accuracy {:.2%}, achieved at epoch {}".format(best_acc, best_epoch))
91 |
92 | elapsed = round(time.time() - start_time)
93 | elapsed = str(datetime.timedelta(seconds=elapsed))
94 | train_time = str(datetime.timedelta(seconds=train_time))
95 | print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
96 | print("==========\nArgs:{}\n==========".format(args))
97 |
98 |
99 | def train(epoch, model, criterion, optimizer, trainloader, learning_rate, use_gpu):
100 | losses = AverageMeter()
101 | batch_time = AverageMeter()
102 | data_time = AverageMeter()
103 |
104 | model.train()
105 |
106 | end = time.time()
107 | for batch_idx, (images_train, labels_train, images_test, labels_test, pids) in enumerate(trainloader):
108 | data_time.update(time.time() - end)
109 |
110 | if use_gpu:
111 | images_train, labels_train = images_train.cuda(), labels_train.cuda()
112 | images_test, labels_test = images_test.cuda(), labels_test.cuda()
113 | pids = pids.cuda()
114 |
115 | batch_size, num_train_examples, channels, height, width = images_train.size()
116 | num_test_examples = images_test.size(1)
117 |
118 | labels_train_1hot = one_hot(labels_train).to(images_train.device)
119 | labels_test_1hot = one_hot(labels_test).to(images_test.device)
120 |
121 | ytest, cls_scores = model(images_train, images_test, labels_train_1hot, labels_test_1hot)
122 |
123 | loss1 = criterion(ytest, pids.view(-1))
124 | loss2 = criterion(cls_scores, labels_test.view(-1))
125 | loss = loss1 + 0.5 * loss2
126 |
127 | optimizer.zero_grad()
128 | loss.backward()
129 | optimizer.step()
130 |
131 | losses.update(loss.item(), pids.size(0))
132 | batch_time.update(time.time() - end)
133 | end = time.time()
134 |
135 | print('Epoch{0} '
136 | 'lr: {1} '
137 | 'Time:{batch_time.sum:.1f}s '
138 | 'Data:{data_time.sum:.1f}s '
139 | 'Loss:{loss.avg:.4f} '.format(
140 | epoch+1, learning_rate, batch_time=batch_time,
141 | data_time=data_time, loss=losses))
142 |
143 |
144 | def test(model, testloader, use_gpu):
145 | accs = AverageMeter()
146 | test_accuracies = []
147 | model.eval()
148 |
149 | with torch.no_grad():
150 | for batch_idx , (images_train, labels_train, images_test, labels_test) in enumerate(testloader):
151 | if use_gpu:
152 | images_train = images_train.cuda()
153 | images_test = images_test.cuda()
154 |
155 | end = time.time()
156 |
157 | batch_size, num_train_examples, channels, height, width = images_train.size()
158 | num_test_examples = images_test.size(1)
159 |
160 | labels_train_1hot = one_hot(labels_train).to(images_train.device)
161 | labels_test_1hot = one_hot(labels_test).to(images_test.device)
162 |
163 | cls_scores = model(images_train, images_test, labels_train_1hot, labels_test_1hot)
164 | cls_scores = cls_scores.view(batch_size * num_test_examples, -1)
165 | labels_test = labels_test.view(batch_size * num_test_examples)
166 |
167 | _, preds = torch.max(cls_scores.detach().cpu(), 1)
168 | acc = (torch.sum(preds == labels_test.detach().cpu()).float()) / labels_test.size(0)
169 | accs.update(acc.item(), labels_test.size(0))
170 |
171 | gt = (preds == labels_test.detach().cpu()).float()
172 | gt = gt.view(batch_size, num_test_examples).numpy() #[b, n]
173 | acc = np.sum(gt, 1) / num_test_examples
174 | acc = np.reshape(acc, (batch_size))
175 | test_accuracies.append(acc)
176 |
177 | accuracy = accs.avg
178 | test_accuracies = np.array(test_accuracies)
179 | test_accuracies = np.reshape(test_accuracies, -1)
180 | stds = np.std(test_accuracies, 0)
181 | ci95 = 1.96 * stds / np.sqrt(args.epoch_size)
182 | print('Accuracy: {:.2%}, 95% CI: +/-{:.2%}'.format(accuracy, ci95))
183 |
184 | return accuracy
185 |
186 |
187 | if __name__ == '__main__':
188 | main()
189 |
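The interval reported at the end of `test` is the normal-approximation 95% confidence half-width, `1.96 * std / sqrt(n)`, computed over per-episode accuracies. The sketch below reproduces that arithmetic with the standard library. `mean_and_ci95` is a hypothetical helper, and it takes `n` from the number of accuracies passed in, whereas `test` uses `args.epoch_size`.

```python
import math
import statistics


def mean_and_ci95(episode_accuracies):
    """Return (mean, half-width of the 95% confidence interval)."""
    n = len(episode_accuracies)
    mean = sum(episode_accuracies) / n
    # Population standard deviation, matching np.std's default (ddof=0).
    std = statistics.pstdev(episode_accuracies)
    return mean, 1.96 * std / math.sqrt(n)


# Four made-up episode accuracies, for illustration only.
accs = [0.60, 0.80, 0.70, 0.90]
mean, ci95 = mean_and_ci95(accs)
print('Accuracy: {:.2%}, 95% CI: +/-{:.2%}'.format(mean, ci95))
```

The half-width shrinks with the square root of the number of episodes, so a tight interval requires evaluating on many episodes.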
--------------------------------------------------------------------------------